66from __future__ import (print_function , division , unicode_literals ,
77 absolute_import )
88from builtins import open
9+ import pytest
910
11+ from .... import config
1012from ... import engine as pe
1113from ....interfaces import base as nib
1214from ....interfaces .utility import IdentityInterface , Function , Merge
@@ -45,19 +47,15 @@ class IncrementOutputSpec(nib.TraitedSpec):
4547 output1 = nib .traits .Int (desc = 'ouput' )
4648
4749
48- class IncrementInterface (nib .BaseInterface ):
50+ class IncrementInterface (nib .SimpleInterface ):
4951 input_spec = IncrementInputSpec
5052 output_spec = IncrementOutputSpec
5153
5254 def _run_interface (self , runtime ):
5355 runtime .returncode = 0
56+ self ._results ['output1' ] = self .inputs .input1 + self .inputs .inc
5457 return runtime
5558
56- def _list_outputs (self ):
57- outputs = self ._outputs ().get ()
58- outputs ['output1' ] = self .inputs .input1 + self .inputs .inc
59- return outputs
60-
6159
6260_sums = []
6361
@@ -73,23 +71,19 @@ class SumOutputSpec(nib.TraitedSpec):
7371 operands = nib .traits .List (nib .traits .Int , desc = 'operands' )
7472
7573
76- class SumInterface (nib .BaseInterface ):
74+ class SumInterface (nib .SimpleInterface ):
7775 input_spec = SumInputSpec
7876 output_spec = SumOutputSpec
7977
8078 def _run_interface (self , runtime ):
81- runtime .returncode = 0
82- return runtime
83-
84- def _list_outputs (self ):
8579 global _sum
8680 global _sum_operands
87- outputs = self . _outputs (). get ()
88- outputs ['operands' ] = self .inputs .input1
89- _sum_operands . append ( outputs [ 'operands' ] )
90- outputs [ 'output1' ] = sum (self .inputs .input1 )
91- _sums .append (outputs [ 'output1' ] )
92- return outputs
81+ runtime . returncode = 0
82+ self . _results ['operands' ] = self .inputs .input1
83+ self . _results [ 'output1' ] = sum ( self . inputs . input1 )
84+ _sum_operands . append (self .inputs .input1 )
85+ _sums .append (sum ( self . inputs . input1 ) )
86+ return runtime
9387
9488
9589_set_len = None
@@ -148,35 +142,48 @@ def _list_outputs(self):
148142 return outputs
149143
150144
151- def test_join_expansion (tmpdir ):
145+ @pytest .mark .parametrize ('needed_outputs' , ['true' , 'false' ])
146+ def test_join_expansion (tmpdir , needed_outputs ):
147+ global _sums
148+ global _sum_operands
149+ global _products
152150 tmpdir .chdir ()
153151
152+ # Clean up, just in case some other test modified them
153+ _products = []
154+ _sum_operands = []
155+ _sums = []
156+
157+ prev_state = config .get ('execution' , 'remove_unnecessary_outputs' )
158+ config .set ('execution' , 'remove_unnecessary_outputs' , needed_outputs )
154159 # Make the workflow.
155160 wf = pe .Workflow (name = 'test' )
156161 # the iterated input node
157162 inputspec = pe .Node (IdentityInterface (fields = ['n' ]), name = 'inputspec' )
158163 inputspec .iterables = [('n' , [1 , 2 ])]
159164 # a pre-join node in the iterated path
160165 pre_join1 = pe .Node (IncrementInterface (), name = 'pre_join1' )
161- wf .connect (inputspec , 'n' , pre_join1 , 'input1' )
162166 # another pre-join node in the iterated path
163167 pre_join2 = pe .Node (IncrementInterface (), name = 'pre_join2' )
164- wf .connect (pre_join1 , 'output1' , pre_join2 , 'input1' )
165168 # the join node
166169 join = pe .JoinNode (
167170 SumInterface (),
168171 joinsource = 'inputspec' ,
169172 joinfield = 'input1' ,
170173 name = 'join' )
171- wf .connect (pre_join2 , 'output1' , join , 'input1' )
172174 # an uniterated post-join node
173175 post_join1 = pe .Node (IncrementInterface (), name = 'post_join1' )
174- wf .connect (join , 'output1' , post_join1 , 'input1' )
175176 # a post-join node in the iterated path
176177 post_join2 = pe .Node (ProductInterface (), name = 'post_join2' )
177- wf .connect (join , 'output1' , post_join2 , 'input1' )
178- wf .connect (pre_join1 , 'output1' , post_join2 , 'input2' )
179178
179+ wf .connect ([
180+ (inputspec , pre_join1 , [('n' , 'input1' )]),
181+ (pre_join1 , pre_join2 , [('output1' , 'input1' )]),
182+ (pre_join1 , post_join2 , [('output1' , 'input2' )]),
183+ (pre_join2 , join , [('output1' , 'input1' )]),
184+ (join , post_join1 , [('output1' , 'input1' )]),
185+ (join , post_join2 , [('output1' , 'input1' )]),
186+ ])
180187 result = wf .run ()
181188
182189 # the two expanded pre-join predecessor nodes feed into one join node
@@ -185,8 +192,8 @@ def test_join_expansion(tmpdir):
185192 # the expanded graph contains 2 * 2 = 4 iteration pre-join nodes, 1 join
186193 # node, 1 non-iterated post-join node and 2 * 1 iteration post-join nodes.
187194 # Nipype factors away the IdentityInterface.
188- assert len (
189- result . nodes ()) == 8 , "The number of expanded nodes is incorrect."
195+ assert len (result . nodes ()) == 8 , "The number of expanded nodes is incorrect."
196+
190197 # the join Sum result is (1 + 1 + 1) + (2 + 1 + 1)
191198 assert len (_sums ) == 1 , "The number of join outputs is incorrect"
192199 assert _sums [
@@ -197,6 +204,7 @@ def test_join_expansion(tmpdir):
197204 # there are two iterations of the post-join node in the iterable path
198205 assert len (_products ) == 2 ,\
199206 "The number of iterated post-join outputs is incorrect"
207+ config .set ('execution' , 'remove_unnecessary_outputs' , prev_state )
200208
201209
202210def test_node_joinsource (tmpdir ):
0 commit comments