|
29 | 29 | ) |
30 | 30 |
|
31 | 31 |
|
32 | | -# TODO: check if supports recursion if main function exists |
33 | | - |
34 | 32 | class AnnotatedVariablesExtractor(ast.NodeTransformer): |
35 | 33 | input_type_mapper = { |
36 | 34 | (CWLFilePathInput.__name__,): ( |
@@ -59,7 +57,7 @@ class AnnotatedVariablesExtractor(ast.NodeTransformer): |
59 | 57 | }} |
60 | 58 |
|
61 | 59 | output_type_mapper = { |
62 | | - CWLFilePathOutput.__name__ |
| 60 | + (CWLFilePathOutput.__name__,) |
63 | 61 | } |
64 | 62 |
|
65 | 63 | dumpable_mapper = { |
@@ -88,75 +86,80 @@ def __get_annotation__(self, type_annotation): |
88 | 86 | annotation = (type_annotation.func.value.id, type_annotation.func.attr) |
89 | 87 | return annotation |
90 | 88 |
|
| 89 | + @classmethod |
| 90 | + def conv_AnnAssign_to_Assign(cls, node): |
| 91 | + return ast.Assign( |
| 92 | + col_offset=node.col_offset, |
| 93 | + lineno=node.lineno, |
| 94 | + targets=[node.target], |
| 95 | + value=node.value |
| 96 | + ) |
| 97 | + |
| 98 | + def _visit_input_ann_assign(self, node, annotation): |
| 99 | + mapper = self.input_type_mapper[annotation] |
| 100 | + self.extracted_variables.append(_VariableNameTypePair( |
| 101 | + node.target.id, mapper[0], mapper[1], not mapper[0].endswith('?'), True, False, None) |
| 102 | + ) |
| 103 | + return None |
| 104 | + |
| 105 | + def _visit_default_dumper(self, node, dumper): |
| 106 | + dump_tree = ast.parse(dumper.format(var_name=node.target.id)) |
| 107 | + self.to_dump.append(dump_tree.body) |
| 108 | + self.extracted_variables.append(_VariableNameTypePair( |
| 109 | + node.target.id, None, None, None, False, True, node.target.id) |
| 110 | + ) |
| 111 | + return self.conv_AnnAssign_to_Assign(node) |
| 112 | + |
| 113 | + def _visit_user_defined_dumper(self, node): |
| 114 | + load_ctx = ast.Load() |
| 115 | + func_name = deepcopy(node.annotation.args[0].value) |
| 116 | + func_name.ctx = load_ctx |
| 117 | + ast.fix_missing_locations(func_name) |
| 118 | + |
| 119 | + new_dump_node = ast.Expr( |
| 120 | + col_offset=0, lineno=0, |
| 121 | + value=ast.Call( |
| 122 | + args=node.annotation.args[1:], keywords=node.annotation.keywords, col_offset=0, |
| 123 | + func=ast.Attribute( |
| 124 | + attr=node.annotation.args[0].attr, |
| 125 | + value=func_name, |
| 126 | + col_offset=0, ctx=load_ctx, lineno=0, |
| 127 | + ), |
| 128 | + ) |
| 129 | + ) |
| 130 | + ast.fix_missing_locations(new_dump_node) |
| 131 | + self.to_dump.append([new_dump_node]) |
| 132 | + self.extracted_variables.append(_VariableNameTypePair( |
| 133 | + node.target.id, None, None, None, False, True, node.annotation.args[1].s) |
| 134 | + ) |
| 135 | + # removing type annotation |
| 136 | + return self.conv_AnnAssign_to_Assign(node) |
| 137 | + |
| 138 | + def _visit_output_type(self, node): |
| 139 | + self.extracted_variables.append(_VariableNameTypePair( |
| 140 | + node.target.id, None, None, None, False, True, node.value.s) |
| 141 | + ) |
| 142 | + # removing type annotation |
| 143 | + return ast.Assign( |
| 144 | + col_offset=node.col_offset, |
| 145 | + lineno=node.lineno, |
| 146 | + targets=[node.target], |
| 147 | + value=node.value |
| 148 | + ) |
| 149 | + |
91 | 150 | def visit_AnnAssign(self, node): |
92 | 151 | try: |
93 | 152 | annotation = self.__get_annotation__(node.annotation) |
94 | 153 | if annotation in self.input_type_mapper: |
95 | | - mapper = self.input_type_mapper[annotation] |
96 | | - self.extracted_variables.append(_VariableNameTypePair( |
97 | | - node.target.id, mapper[0], mapper[1], not mapper[0].endswith('?'), True, False, None) |
98 | | - ) |
99 | | - return None |
| 154 | + return self._visit_input_ann_assign(node, annotation) |
100 | 155 | elif annotation in self.dumpable_mapper: |
101 | 156 | dumper = self.dumpable_mapper[annotation] |
102 | 157 | if dumper is not None: |
103 | | - dump_tree = ast.parse(dumper.format(var_name=node.target.id)) |
104 | | - self.to_dump.append(dump_tree.body) |
105 | | - self.extracted_variables.append(_VariableNameTypePair( |
106 | | - node.target.id, None, None, None, False, True, node.target.id) |
107 | | - ) |
108 | | - # removing type annotation |
109 | | - return ast.Assign( |
110 | | - col_offset=node.col_offset, |
111 | | - lineno=node.lineno, |
112 | | - targets=[node.target], |
113 | | - value=node.value |
114 | | - ) |
| 158 | + return self._visit_default_dumper(node, dumper) |
115 | 159 | else: |
116 | | - load_ctx = ast.Load() |
117 | | - func_name = deepcopy(node.annotation.args[0].value) |
118 | | - func_name.ctx = load_ctx |
119 | | - ast.fix_missing_locations(func_name) |
120 | | - |
121 | | - new_dump_node = ast.Expr( |
122 | | - col_offset=0, lineno=0, |
123 | | - value=ast.Call( |
124 | | - args=node.annotation.args[1:], |
125 | | - col_offset=0, |
126 | | - func=ast.Attribute( |
127 | | - attr=node.annotation.args[0].attr, |
128 | | - col_offset=0, |
129 | | - ctx=load_ctx, |
130 | | - lineno=0, |
131 | | - value=func_name, |
132 | | - ), |
133 | | - keywords=node.annotation.keywords |
134 | | - ) |
135 | | - ) |
136 | | - ast.fix_missing_locations(new_dump_node) |
137 | | - self.to_dump.append([new_dump_node]) |
138 | | - self.extracted_variables.append(_VariableNameTypePair( |
139 | | - node.target.id, None, None, None, False, True, node.annotation.args[1].s) |
140 | | - ) |
141 | | - # removing type annotation |
142 | | - return ast.Assign( |
143 | | - col_offset=node.col_offset, |
144 | | - lineno=node.lineno, |
145 | | - targets=[node.target], |
146 | | - value=node.value |
147 | | - ) |
148 | | - elif (isinstance(node.annotation, ast.Name) and node.annotation.id in self.output_type_mapper) or \ |
149 | | - (isinstance(node.annotation, ast.Str) and node.annotation.s in self.output_type_mapper): |
150 | | - self.extracted_variables.append(_VariableNameTypePair( |
151 | | - node.target.id, None, None, None, False, True, node.value.s) |
152 | | - ) |
153 | | - # removing type annotation |
154 | | - return ast.Assign( |
155 | | - col_offset=node.col_offset, |
156 | | - lineno=node.lineno, |
157 | | - targets=[node.target], |
158 | | - value=node.value |
159 | | - ) |
| 160 | + return self._visit_user_defined_dumper(node) |
| 161 | + elif annotation in self.output_type_mapper: |
| 162 | + return self._visit_output_type(node) |
160 | 163 | except Exception: |
161 | 164 | pass |
162 | 165 | return node |
|
0 commit comments