22
22
23
23
logger = logging .getLogger ("nipype2pydra" )
24
24
25
- CALLABLES_ARGS = ["inputs" , "stdout" , "stderr" , "output_dir" ]
25
+ OUT_FUNC_ARGS = ["callable" , "formatter" ] # arguments to shell.out that are functions
26
+ CALLABLE_ARGS = [
27
+ "inputs" ,
28
+ "stdout" ,
29
+ "stderr" ,
30
+ "output_dir" ,
31
+ ] # Arguments for callable methods
26
32
27
33
28
34
@attrs .define (slots = False )
@@ -93,13 +99,26 @@ def generate_code(self, input_fields, nonstd_types, output_fields) -> str:
93
99
94
100
# Pull out xor fields into task-level xor_sets
95
101
xor_sets = set ()
102
+ has_zero_pos = False
96
103
for inpt in input_fields :
97
104
if len (inpt ) == 3 :
98
105
name , _ , mdata = inpt
99
106
else :
100
107
name , _ , __ , mdata = inpt
101
108
if "xor" in mdata :
102
- xor_sets .add (frozenset (mdata ["xor" ] + [name ]))
109
+ xor_sets .add (frozenset (list (mdata ["xor" ]) + [name ]))
110
+ if mdata .get ("position" , None ) == 0 :
111
+ has_zero_pos = True
112
+
113
+ # Increment positions if there is a zero position
114
+ if has_zero_pos :
115
+ for inpt in input_fields :
116
+ if len (inpt ) == 3 :
117
+ name , _ , mdata = inpt
118
+ else :
119
+ name , _ , __ , mdata = inpt
120
+ if "position" in mdata and mdata ["position" ] >= 0 :
121
+ mdata ["position" ] = mdata .pop ("position" ) + 1
103
122
104
123
input_fields_str = ""
105
124
output_fields_str = ""
@@ -126,26 +145,24 @@ def generate_code(self, input_fields, nonstd_types, output_fields) -> str:
126
145
mdata .pop ("xor" , None )
127
146
args_str = ", " .join (f"{ k } ={ v !r} " for k , v in mdata .items ())
128
147
if "path_template" in mdata :
129
- output_fields_str = (
148
+ output_fields_str + = (
130
149
f" { name } : { type_str } = shell.outarg({ args_str } )\n "
131
150
)
132
151
else :
133
152
input_fields_str += f" { name } : { type_str } = shell.arg({ args_str } )\n "
134
153
135
- callable_fields = set (n for n , _ , __ in self .callable_output_fields )
154
+ # callable_fields = set(n for n, _, __ in self.callable_output_fields)
136
155
137
156
for outpt in output_fields :
138
157
name , type_ , mdata = outpt
139
- cllble = mdata .pop (
140
- "callable" , f"{ name } _callable" if name in callable_fields else None
141
- )
142
- args_str = ", " .join (f"{ k } ={ v !r} " for k , v in mdata .items ())
143
- if args_str :
144
- args_str += ", "
145
- if cllble :
146
- args_str += f"callable={ cllble } "
158
+ func_args = []
159
+ for func_arg in OUT_FUNC_ARGS :
160
+ if func_arg in mdata :
161
+ func_args .append (f"{ func_arg } ={ mdata [func_arg ]} " )
162
+ mdata .pop (func_arg )
163
+ args = [f"{ k } ={ v !r} " for k , v in mdata .items ()] + func_args
147
164
output_fields_str += (
148
- f" { name } : { type_to_str (type_ )} = shell.out({ args_str } )\n "
165
+ f" { name } : { type_to_str (type_ )} = shell.out({ ', ' . join ( args ) } )\n "
149
166
)
150
167
151
168
spec_str = (
@@ -182,7 +199,7 @@ def generate_code(self, input_fields, nonstd_types, output_fields) -> str:
182
199
s [0 ] == self .nipype_interface ._list_outputs
183
200
for s in self .used .method_stacks [m .__name__ ]
184
201
):
185
- additional_args = CALLABLES_ARGS
202
+ additional_args = CALLABLE_ARGS
186
203
else :
187
204
additional_args = []
188
205
method_str = self .process_method (
@@ -418,7 +435,7 @@ def callables_code(self):
418
435
if not agg_body .strip ():
419
436
return ""
420
437
agg_body = self .unwrap_nested_methods (
421
- agg_body , additional_args = CALLABLES_ARGS , inputs_as_dict = True
438
+ agg_body , additional_args = CALLABLE_ARGS , inputs_as_dict = True
422
439
)
423
440
agg_body = self .replace_supers (
424
441
agg_body ,
@@ -477,13 +494,15 @@ def callables_code(self):
477
494
)
478
495
lo_body = self ._process_inputs (lo_body )
479
496
lo_body = re .sub (
480
- r"(\w+) = self\.output_spec\(\).(?:trait_)get\(\)" , r"\1 = {}" , lo_body
497
+ r"(\w+) = self\.output_spec\(\).(?:trait_)get\(\)" ,
498
+ r"\1 = {}" ,
499
+ lo_body ,
481
500
)
482
501
483
502
if not lo_body .strip ():
484
503
return ""
485
504
lo_body = self .unwrap_nested_methods (
486
- lo_body , additional_args = CALLABLES_ARGS , inputs_as_dict = True
505
+ lo_body , additional_args = CALLABLE_ARGS , inputs_as_dict = True
487
506
)
488
507
lo_body = self .replace_supers (
489
508
lo_body ,
0 commit comments