Skip to content

Commit 1b09dda

Browse files
committed
cleanup
1 parent a39f336 commit 1b09dda

File tree

1 file changed

+67
-64
lines changed

1 file changed

+67
-64
lines changed

ipython2cwl/cwltoolextractor.py

Lines changed: 67 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@
2929
)
3030

3131

32-
# TODO: check if supports recursion if main function exists
33-
3432
class AnnotatedVariablesExtractor(ast.NodeTransformer):
3533
input_type_mapper = {
3634
(CWLFilePathInput.__name__,): (
@@ -59,7 +57,7 @@ class AnnotatedVariablesExtractor(ast.NodeTransformer):
5957
}}
6058

6159
output_type_mapper = {
62-
CWLFilePathOutput.__name__
60+
(CWLFilePathOutput.__name__,)
6361
}
6462

6563
dumpable_mapper = {
@@ -88,75 +86,80 @@ def __get_annotation__(self, type_annotation):
8886
annotation = (type_annotation.func.value.id, type_annotation.func.attr)
8987
return annotation
9088

89+
@classmethod
90+
def conv_AnnAssign_to_Assign(cls, node):
91+
return ast.Assign(
92+
col_offset=node.col_offset,
93+
lineno=node.lineno,
94+
targets=[node.target],
95+
value=node.value
96+
)
97+
98+
def _visit_input_ann_assign(self, node, annotation):
99+
mapper = self.input_type_mapper[annotation]
100+
self.extracted_variables.append(_VariableNameTypePair(
101+
node.target.id, mapper[0], mapper[1], not mapper[0].endswith('?'), True, False, None)
102+
)
103+
return None
104+
105+
def _visit_default_dumper(self, node, dumper):
106+
dump_tree = ast.parse(dumper.format(var_name=node.target.id))
107+
self.to_dump.append(dump_tree.body)
108+
self.extracted_variables.append(_VariableNameTypePair(
109+
node.target.id, None, None, None, False, True, node.target.id)
110+
)
111+
return self.conv_AnnAssign_to_Assign(node)
112+
113+
def _visit_user_defined_dumper(self, node):
114+
load_ctx = ast.Load()
115+
func_name = deepcopy(node.annotation.args[0].value)
116+
func_name.ctx = load_ctx
117+
ast.fix_missing_locations(func_name)
118+
119+
new_dump_node = ast.Expr(
120+
col_offset=0, lineno=0,
121+
value=ast.Call(
122+
args=node.annotation.args[1:], keywords=node.annotation.keywords, col_offset=0,
123+
func=ast.Attribute(
124+
attr=node.annotation.args[0].attr,
125+
value=func_name,
126+
col_offset=0, ctx=load_ctx, lineno=0,
127+
),
128+
)
129+
)
130+
ast.fix_missing_locations(new_dump_node)
131+
self.to_dump.append([new_dump_node])
132+
self.extracted_variables.append(_VariableNameTypePair(
133+
node.target.id, None, None, None, False, True, node.annotation.args[1].s)
134+
)
135+
# removing type annotation
136+
return self.conv_AnnAssign_to_Assign(node)
137+
138+
def _visit_output_type(self, node):
139+
self.extracted_variables.append(_VariableNameTypePair(
140+
node.target.id, None, None, None, False, True, node.value.s)
141+
)
142+
# removing type annotation
143+
return ast.Assign(
144+
col_offset=node.col_offset,
145+
lineno=node.lineno,
146+
targets=[node.target],
147+
value=node.value
148+
)
149+
91150
def visit_AnnAssign(self, node):
92151
try:
93152
annotation = self.__get_annotation__(node.annotation)
94153
if annotation in self.input_type_mapper:
95-
mapper = self.input_type_mapper[annotation]
96-
self.extracted_variables.append(_VariableNameTypePair(
97-
node.target.id, mapper[0], mapper[1], not mapper[0].endswith('?'), True, False, None)
98-
)
99-
return None
154+
return self._visit_input_ann_assign(node, annotation)
100155
elif annotation in self.dumpable_mapper:
101156
dumper = self.dumpable_mapper[annotation]
102157
if dumper is not None:
103-
dump_tree = ast.parse(dumper.format(var_name=node.target.id))
104-
self.to_dump.append(dump_tree.body)
105-
self.extracted_variables.append(_VariableNameTypePair(
106-
node.target.id, None, None, None, False, True, node.target.id)
107-
)
108-
# removing type annotation
109-
return ast.Assign(
110-
col_offset=node.col_offset,
111-
lineno=node.lineno,
112-
targets=[node.target],
113-
value=node.value
114-
)
158+
return self._visit_default_dumper(node, dumper)
115159
else:
116-
load_ctx = ast.Load()
117-
func_name = deepcopy(node.annotation.args[0].value)
118-
func_name.ctx = load_ctx
119-
ast.fix_missing_locations(func_name)
120-
121-
new_dump_node = ast.Expr(
122-
col_offset=0, lineno=0,
123-
value=ast.Call(
124-
args=node.annotation.args[1:],
125-
col_offset=0,
126-
func=ast.Attribute(
127-
attr=node.annotation.args[0].attr,
128-
col_offset=0,
129-
ctx=load_ctx,
130-
lineno=0,
131-
value=func_name,
132-
),
133-
keywords=node.annotation.keywords
134-
)
135-
)
136-
ast.fix_missing_locations(new_dump_node)
137-
self.to_dump.append([new_dump_node])
138-
self.extracted_variables.append(_VariableNameTypePair(
139-
node.target.id, None, None, None, False, True, node.annotation.args[1].s)
140-
)
141-
# removing type annotation
142-
return ast.Assign(
143-
col_offset=node.col_offset,
144-
lineno=node.lineno,
145-
targets=[node.target],
146-
value=node.value
147-
)
148-
elif (isinstance(node.annotation, ast.Name) and node.annotation.id in self.output_type_mapper) or \
149-
(isinstance(node.annotation, ast.Str) and node.annotation.s in self.output_type_mapper):
150-
self.extracted_variables.append(_VariableNameTypePair(
151-
node.target.id, None, None, None, False, True, node.value.s)
152-
)
153-
# removing type annotation
154-
return ast.Assign(
155-
col_offset=node.col_offset,
156-
lineno=node.lineno,
157-
targets=[node.target],
158-
value=node.value
159-
)
160+
return self._visit_user_defined_dumper(node)
161+
elif annotation in self.output_type_mapper:
162+
return self._visit_output_type(node)
160163
except Exception:
161164
pass
162165
return node

0 commit comments

Comments
 (0)