Skip to content

Commit e245624

Browse files
committed
improve typing parsing on list & optionals
1 parent 9c9a15b commit e245624

File tree

2 files changed

+79
-26
lines changed

2 files changed

+79
-26
lines changed

ipython2cwl/cwltoolextractor.py

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -26,23 +26,31 @@
2626

2727
class AnnotatedVariablesExtractor(ast.NodeTransformer):
2828
input_type_mapper = {
29-
CWLFilePathInput.__name__: (
29+
(CWLFilePathInput.__name__,): (
3030
'File',
3131
'pathlib.Path',
3232
),
33-
CWLBooleanInput.__name__: (
33+
(CWLBooleanInput.__name__,): (
3434
'boolean',
3535
'lambda flag: flag.upper() == "TRUE"',
3636
),
37-
CWLIntInput.__name__: (
37+
(CWLIntInput.__name__,): (
3838
'int',
3939
'int',
4040
),
41-
CWLStringInput.__name__: (
41+
(CWLStringInput.__name__,): (
4242
'string',
4343
'str',
4444
),
4545
}
46+
input_type_mapper = {**input_type_mapper, **{
47+
('List', *(t for t in types_names)): (types[0] + "[]", types[1])
48+
for types_names, types in input_type_mapper.items()
49+
}, **{
50+
('Optional', *(t for t in types_names)): (types[0] + "?", types[1])
51+
for types_names, types in input_type_mapper.items()
52+
}}
53+
4654
output_type_mapper = {
4755
CWLFilePathOutput.__name__
4856
}
@@ -51,33 +59,29 @@ def __init__(self, *args, **kwargs):
5159
super().__init__(*args, **kwargs)
5260
self.extracted_nodes = []
5361

62+
def __get_annotation__(self, type_annotation):
63+
annotation = None
64+
if isinstance(type_annotation, ast.Name):
65+
annotation = (type_annotation.id,)
66+
elif isinstance(type_annotation, ast.Str):
67+
annotation = (type_annotation.s,)
68+
ann_expr = ast.parse(type_annotation.s.strip()).body[0]
69+
if hasattr(ann_expr, 'value') and isinstance(ann_expr.value, ast.Subscript):
70+
annotation = self.__get_annotation__(ann_expr.value)
71+
elif isinstance(type_annotation, ast.Subscript):
72+
annotation = (type_annotation.value.id, *self.__get_annotation__(type_annotation.slice.value))
73+
return annotation
74+
5475
def visit_AnnAssign(self, node):
5576
try:
56-
if (isinstance(node.annotation, ast.Name) and node.annotation.id in self.input_type_mapper) or \
57-
(isinstance(node.annotation, ast.Str) and node.annotation.s in self.input_type_mapper):
58-
if hasattr(node.annotation, 'id'):
59-
mapper = self.input_type_mapper[node.annotation.id]
60-
else:
61-
mapper = self.input_type_mapper[node.annotation.s]
77+
annotation = self.__get_annotation__(node.annotation)
78+
if annotation in self.input_type_mapper:
79+
mapper = self.input_type_mapper[annotation]
6280
self.extracted_nodes.append(
6381
(node, mapper[0], mapper[1], True, True, False)
6482
)
6583
return None
66-
elif isinstance(node.annotation, ast.Subscript):
67-
if node.annotation.value.id == "Optional" \
68-
and node.annotation.slice.value.id in self.input_type_mapper:
69-
mapper = self.input_type_mapper[node.annotation.slice.value.id]
70-
self.extracted_nodes.append(
71-
(node, mapper[0] + '?', mapper[1], False, True, False)
72-
)
73-
return None
74-
elif node.annotation.value.id == "List" \
75-
and node.annotation.slice.value.id in self.input_type_mapper:
76-
mapper = self.input_type_mapper[node.annotation.slice.value.id]
77-
self.extracted_nodes.append(
78-
(node, mapper[0] + '[]', mapper[1], True, True, False)
79-
)
80-
return None
84+
8185
elif (isinstance(node.annotation, ast.Name) and node.annotation.id in self.output_type_mapper) or \
8286
(isinstance(node.annotation, ast.Str) and node.annotation.s in self.output_type_mapper):
8387
self.extracted_nodes.append(
@@ -90,7 +94,7 @@ def visit_AnnAssign(self, node):
9094
targets=[node.target],
9195
value=node.value
9296
)
93-
except AttributeError:
97+
except Exception:
9498
pass
9599
return node
96100

tests/test_cwltoolextractor.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,3 +328,52 @@ def test_AnnotatedIPython2CWLToolConverter_exclamation_mark_command(self):
328328
exec(new_script_without_magics)
329329
locals()['main']('original\n!ls -l')
330330
self.assertEqual('original\n!ls -l', globals()['printed_message'])
331+
332+
def test_AnnotatedIPython2CWLToolConverter_optional_array_input(self):
333+
s1 = os.linesep.join([
334+
'x1: CWLBooleanInput = True',
335+
])
336+
s2 = os.linesep.join([
337+
'x1: "CWLBooleanInput" = True',
338+
])
339+
# all variables must be the same
340+
self.assertEqual(
341+
AnnotatedIPython2CWLToolConverter(s1)._variables[0],
342+
AnnotatedIPython2CWLToolConverter(s2)._variables[0],
343+
)
344+
345+
s1 = os.linesep.join([
346+
'x1: Optional[CWLBooleanInput] = True',
347+
])
348+
s2 = os.linesep.join([
349+
'x1: "Optional[CWLBooleanInput]" = True',
350+
])
351+
s3 = os.linesep.join([
352+
'x1: Optional["CWLBooleanInput"] = True',
353+
])
354+
# all variables must be the same
355+
self.assertEqual(
356+
AnnotatedIPython2CWLToolConverter(s1)._variables[0],
357+
AnnotatedIPython2CWLToolConverter(s2)._variables[0],
358+
)
359+
self.assertEqual(
360+
AnnotatedIPython2CWLToolConverter(s1)._variables[0],
361+
AnnotatedIPython2CWLToolConverter(s3)._variables[0],
362+
)
363+
364+
# test that does not crash
365+
self.assertListEqual([], AnnotatedIPython2CWLToolConverter(os.linesep.join([
366+
'x1: RandomHint = True'
367+
]))._variables)
368+
self.assertListEqual([], AnnotatedIPython2CWLToolConverter(os.linesep.join([
369+
'x1: List[RandomHint] = True'
370+
]))._variables)
371+
self.assertListEqual([], AnnotatedIPython2CWLToolConverter(os.linesep.join([
372+
'x1: List["RandomHint"] = True'
373+
]))._variables)
374+
self.assertListEqual([], AnnotatedIPython2CWLToolConverter(os.linesep.join([
375+
'x1: "List[List[Union[RandomHint, Foo]]]" = True'
376+
]))._variables)
377+
self.assertListEqual([], AnnotatedIPython2CWLToolConverter(os.linesep.join([
378+
'x1: "RANDOM CHARACTERS!!!!!!" = True'
379+
]))._variables)

0 commit comments

Comments
 (0)