Skip to content

Commit 3982260

Browse files
committed
add custom dumpables functionality
1 parent 6853f88 commit 3982260

File tree

4 files changed

+111
-21
lines changed

4 files changed

+111
-21
lines changed

ipython2cwl/cwltoolextractor.py

Lines changed: 52 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import tarfile
66
import tempfile
77
from collections import namedtuple
8+
from copy import deepcopy
89
from pathlib import Path
910
from typing import Dict, Any, List
1011

@@ -14,7 +15,7 @@
1415
from nbformat.notebooknode import NotebookNode
1516

1617
from .iotypes import CWLFilePathInput, CWLBooleanInput, CWLIntInput, CWLStringInput, CWLFilePathOutput, \
17-
CWLDumpableFile, CWLDumpableBinaryFile
18+
CWLDumpableFile, CWLDumpableBinaryFile, CWLDumpable
1819
from .requirements_manager import RequirementsManager
1920

2021
with open(os.sep.join([os.path.abspath(os.path.dirname(__file__)), 'templates', 'template.dockerfile'])) as f:
@@ -64,6 +65,7 @@ class AnnotatedVariablesExtractor(ast.NodeTransformer):
6465
dumpable_mapper = {
6566
(CWLDumpableFile.__name__,): "with open('{var_name}', 'w') as f:\n\tf.write({var_name})",
6667
(CWLDumpableBinaryFile.__name__,): "with open('{var_name}', 'wb') as f:\n\tf.write({var_name})",
68+
(CWLDumpable.__name__, CWLDumpable.dump.__name__): None,
6769
}
6870

6971
def __init__(self, *args, **kwargs):
@@ -82,6 +84,8 @@ def __get_annotation__(self, type_annotation):
8284
annotation = self.__get_annotation__(ann_expr.value)
8385
elif isinstance(type_annotation, ast.Subscript):
8486
annotation = (type_annotation.value.id, *self.__get_annotation__(type_annotation.slice.value))
87+
elif isinstance(type_annotation, ast.Call):
88+
annotation = (type_annotation.func.value.id, type_annotation.func.attr)
8589
return annotation
8690

8791
def visit_AnnAssign(self, node):
@@ -94,18 +98,53 @@ def visit_AnnAssign(self, node):
9498
)
9599
return None
96100
elif annotation in self.dumpable_mapper:
97-
dump_tree = ast.parse(self.dumpable_mapper[annotation].format(var_name=node.target.id))
98-
self.to_dump.append(dump_tree.body)
99-
self.extracted_variables.append(_VariableNameTypePair(
100-
node.target.id, None, None, None, False, True, node.target.id)
101-
)
102-
# removing type annotation
103-
return ast.Assign(
104-
col_offset=node.col_offset,
105-
lineno=node.lineno,
106-
targets=[node.target],
107-
value=node.value
108-
)
101+
dumper = self.dumpable_mapper[annotation]
102+
if dumper is not None:
103+
dump_tree = ast.parse(dumper.format(var_name=node.target.id))
104+
self.to_dump.append(dump_tree.body)
105+
self.extracted_variables.append(_VariableNameTypePair(
106+
node.target.id, None, None, None, False, True, node.target.id)
107+
)
108+
# removing type annotation
109+
return ast.Assign(
110+
col_offset=node.col_offset,
111+
lineno=node.lineno,
112+
targets=[node.target],
113+
value=node.value
114+
)
115+
else:
116+
load_ctx = ast.Load()
117+
func_name = deepcopy(node.annotation.args[0].value)
118+
func_name.ctx = load_ctx
119+
ast.fix_missing_locations(func_name)
120+
121+
new_dump_node = ast.Expr(
122+
col_offset=0, lineno=0,
123+
value=ast.Call(
124+
args=node.annotation.args[1:],
125+
col_offset=0,
126+
func=ast.Attribute(
127+
attr=node.annotation.args[0].attr,
128+
col_offset=0,
129+
ctx=load_ctx,
130+
lineno=0,
131+
value=func_name,
132+
),
133+
keywords=node.annotation.keywords
134+
)
135+
)
136+
ast.fix_missing_locations(new_dump_node)
137+
self.to_dump.append([new_dump_node])
138+
self.extracted_variables.append(_VariableNameTypePair(
139+
node.target.id, None, None, None, False, True, node.annotation.args[1].s)
140+
)
141+
# removing type annotation
142+
return ast.Assign(
143+
col_offset=node.col_offset,
144+
lineno=node.lineno,
145+
targets=[node.target],
146+
value=node.value
147+
)
109148
elif (isinstance(node.annotation, ast.Name) and node.annotation.id in self.output_type_mapper) or \
110149
(isinstance(node.annotation, ast.Str) and node.annotation.s in self.output_type_mapper):
111150
self.extracted_variables.append(_VariableNameTypePair(

ipython2cwl/iotypes.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
1+
from typing import Callable
2+
3+
14
class _CWLInput:
    """Common base class of the CWL input annotation types in this module."""
36

47

5-
class CWLFilePathInput(_CWLInput):
8+
class CWLFilePathInput(str, _CWLInput):
    """Annotation type for a file-path input (presumably mapped to a CWL File input).

    Inherits from ``str``, so instances can be used wherever a plain string
    is expected.
    """
710

811

912
class CWLBooleanInput(_CWLInput):
    """Annotation type for a boolean input (presumably mapped to a CWL boolean)."""
1114

1215

13-
class CWLStringInput(_CWLInput):
16+
class CWLStringInput(str, _CWLInput):
    """Annotation type for a string input (presumably mapped to a CWL string).

    Inherits from ``str``, so instances can be used wherever a plain string
    is expected.
    """
1518

1619

@@ -22,17 +25,20 @@ class _CWLOutput:
2225
pass
2326

2427

25-
class CWLFilePathOutput(_CWLOutput):
28+
class CWLFilePathOutput(str, _CWLOutput):
    """Annotation type for a file-path output (presumably mapped to a CWL File output).

    Inherits from ``str``, so instances can be used wherever a plain string
    is expected.
    """
2730

2831

29-
class _CWLDumpable(_CWLOutput):
30-
pass
32+
class CWLDumpable(_CWLOutput):
    """Base class of the "dumpable" output annotations.

    A variable annotated with a dumpable type is written to a file by code
    that the extractor appends to the generated script.
    """

    @classmethod
    def dump(cls, dumper: Callable, *args, **kwargs):
        """Build an annotation requesting a custom dump of the annotated variable.

        Intended to be used in annotation position, for example::

            d: CWLDumpable.dump(d.to_csv, "dumpable.csv", sep="\\t") = ...

        The extractor inspects the call expression in the source code: it
        emits a call to ``dumper`` with the remaining positional and keyword
        arguments, and uses the first remaining positional argument (a string
        literal) as the name of the output file.

        :param dumper: callable that writes the variable to disk.
        :param args: positional arguments forwarded to ``dumper``.
        :param kwargs: keyword arguments forwarded to ``dumper``.
        :return: the ``_CWLOutput`` class; none of the arguments are used at
            runtime.
        """
        return _CWLOutput
3137

3238

33-
class CWLDumpableFile(_CWLDumpable):
39+
class CWLDumpableFile(CWLDumpable):
    """Annotation type for a variable to be dumped as a text file.

    The extractor generates code that opens a file named after the variable
    in ``'w'`` mode and writes the variable's value to it.
    """
3541

3642

37-
class CWLDumpableBinaryFile(_CWLDumpable):
43+
class CWLDumpableBinaryFile(CWLDumpable):
    """Annotation type for a variable to be dumped as a binary file.

    The extractor generates code that opens a file named after the variable
    in ``'wb'`` mode and writes the variable's value to it.
    """

test-requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@ coveralls>=2.0.0
44
virtualenv>=3.1.0
55
gitpython>=3.1.3
66
docker>=4.2.1
7-
git+https://github.com/giannisdoukas/cwltool.git#egg=cwltool
7+
git+https://github.com/giannisdoukas/cwltool.git#egg=cwltool
8+
pandas==1.0.5

tests/test_cwltoolextractor.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,3 +430,47 @@ def test_AnnotatedIPython2CWLToolConverter_dumpables(self):
430430
},
431431
cwl_tool['outputs']
432432
)
433+
434+
def test_AnnotatedIPython2CWLToolConverter_custom_dumpables(self):
    """Verify that a ``CWLDumpable.dump(...)`` annotation produces a dump call.

    The script annotates a DataFrame with
    ``CWLDumpable.dump(d.to_csv, "dumpable.csv", ...)``. The generated
    ``main`` function must write ``dumpable.csv`` with the DataFrame content,
    and the CWL tool must expose ``d`` as a File output globbing that file.
    """
    import pandas

    output_file = 'dumpable.csv'

    def remove_output():
        # Best-effort removal: the file legitimately may not exist.
        try:
            os.remove(output_file)
        except FileNotFoundError:
            pass

    script = os.linesep.join([
        'import pandas',
        'from ipython2cwl.iotypes import CWLDumpable',
        'd: CWLDumpable.dump(d.to_csv, "dumpable.csv", sep="\\t", index=False) = pandas.DataFrame([[1,2,3], [4,5,6], [7,8,9]])'
    ])
    converter = AnnotatedIPython2CWLToolConverter(script)
    generated_script = AnnotatedIPython2CWLToolConverter._wrap_script_to_method(
        converter._tree, converter._variables
    )

    # Start from a clean state and make sure the file is removed even when
    # one of the assertions below fails.
    remove_output()
    self.addCleanup(remove_output)

    # Execute into an explicit namespace: relying on exec() populating the
    # implicit function locals and reading it back with locals() breaks on
    # Python 3.13+, where locals() returns an independent snapshot.
    namespace = {}
    exec(generated_script, globals(), namespace)
    namespace['main']()

    expected = pandas.DataFrame([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    data_file = pandas.read_csv(output_file, sep="\t")
    self.assertListEqual(
        [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
        (data_file.to_numpy() - expected.to_numpy()).tolist()
    )

    cwl_tool = converter.cwl_command_line_tool()
    self.assertDictEqual(
        {
            'd': {
                'type': 'File',
                'outputBinding': {
                    'glob': 'dumpable.csv'
                }
            },
        },
        cwl_tool['outputs']
    )

0 commit comments

Comments
 (0)