66import tempfile
77from collections import namedtuple
88from pathlib import Path
9- from typing import Dict , Any
9+ from typing import Dict , Any , List
1010
1111import astor
1212import nbconvert
1313import yaml
1414from nbformat .notebooknode import NotebookNode
1515
16- from .iotypes import CWLFilePathInput , CWLBooleanInput , CWLIntInput , CWLStringInput , CWLFilePathOutput
16+ from .iotypes import CWLFilePathInput , CWLBooleanInput , CWLIntInput , CWLStringInput , CWLFilePathOutput , CWLDumpableFile , \
17+ CWLDumpableBinaryFile
1718from .requirements_manager import RequirementsManager
1819
1920with open (os .sep .join ([os .path .abspath (os .path .dirname (__file__ )), 'templates' , 'template.dockerfile' ])) as f :
2021 DOCKERFILE_TEMPLATE = f .read ()
2122with open (os .sep .join ([os .path .abspath (os .path .dirname (__file__ )), 'templates' , 'template.setup' ])) as f :
2223 SETUP_TEMPLATE = f .read ()
2324
25+ _VariableNameTypePair = namedtuple (
26+ 'VariableNameTypePair' ,
27+ ['name' , 'cwl_typeof' , 'argparse_typeof' , 'required' , 'is_input' , 'is_output' , 'value' ]
28+ )
29+
2430
2531# TODO: check if supports recursion if main function exists
2632
@@ -55,9 +61,15 @@ class AnnotatedVariablesExtractor(ast.NodeTransformer):
5561 CWLFilePathOutput .__name__
5662 }
5763
64+ dumpable_mapper = {
65+ (CWLDumpableFile .__name__ ,): "with open('{var_name}', 'w') as f:\n \t f.write({var_name})" ,
66+ (CWLDumpableBinaryFile .__name__ ,): "with open('{var_name}', 'wb') as f:\n \t f.write({var_name})" ,
67+ }
68+
5869 def __init__ (self , * args , ** kwargs ):
5970 super ().__init__ (* args , ** kwargs )
60- self .extracted_nodes = []
71+ self .extracted_variables : List = []
72+ self .to_dump : List = []
6173
6274 def __get_annotation__ (self , type_annotation ):
6375 annotation = None
@@ -77,15 +89,27 @@ def visit_AnnAssign(self, node):
7789 annotation = self .__get_annotation__ (node .annotation )
7890 if annotation in self .input_type_mapper :
7991 mapper = self .input_type_mapper [annotation ]
80- self .extracted_nodes .append (
81- ( node , mapper [0 ], mapper [1 ], not mapper [0 ].endswith ('?' ), True , False )
92+ self .extracted_variables .append ( _VariableNameTypePair (
93+ node . target . id , mapper [0 ], mapper [1 ], not mapper [0 ].endswith ('?' ), True , False , None )
8294 )
8395 return None
84-
96+ elif annotation in self .dumpable_mapper :
97+ dump_tree = ast .parse (self .dumpable_mapper [annotation ].format (var_name = node .target .id ))
98+ self .to_dump .append (dump_tree .body )
99+ self .extracted_variables .append (_VariableNameTypePair (
100+ node .target .id , None , None , None , False , True , node .target .id )
101+ )
102+ # removing type annotation
103+ return ast .Assign (
104+ col_offset = node .col_offset ,
105+ lineno = node .lineno ,
106+ targets = [node .target ],
107+ value = node .value
108+ )
85109 elif (isinstance (node .annotation , ast .Name ) and node .annotation .id in self .output_type_mapper ) or \
86110 (isinstance (node .annotation , ast .Str ) and node .annotation .s in self .output_type_mapper ):
87- self .extracted_nodes .append (
88- ( node , None , None , None , False , True )
111+ self .extracted_variables .append ( _VariableNameTypePair (
112+ node . target . id , None , None , None , False , True , node . value . s )
89113 )
90114 # removing type annotation
91115 return ast .Assign (
@@ -123,12 +147,6 @@ class AnnotatedIPython2CWLToolConverter:
123147 """
124148
125149 _code : str
126-
127- _VariableNameTypePair = namedtuple (
128- 'VariableNameTypePair' ,
129- ['name' , 'cwl_typeof' , 'argparse_typeof' , 'required' , 'is_input' , 'is_output' , 'value' ]
130- )
131-
132150 """The annotated python code to convert."""
133151
134152 def __init__ (self , annotated_ipython_code : str ):
@@ -137,19 +155,15 @@ def __init__(self, annotated_ipython_code: str):
137155
138156 self ._code = annotated_ipython_code
139157 extractor = AnnotatedVariablesExtractor ()
140- self ._tree = ast .fix_missing_locations (extractor .visit (ast .parse (self ._code )))
158+ self ._tree = extractor .visit (ast .parse (self ._code ))
159+ [self ._tree .body .extend (d ) for d in extractor .to_dump ]
160+ self ._tree = ast .fix_missing_locations (self ._tree )
141161 self ._variables = []
142- for node , cwl_type , click_type , required , is_input , is_output in extractor .extracted_nodes :
143- if is_input :
144- self ._variables .append (
145- self ._VariableNameTypePair (node .target .id , cwl_type , click_type , required , is_input , is_output ,
146- None )
147- )
148- if is_output :
149- self ._variables .append (
150- self ._VariableNameTypePair (node .target .id , cwl_type , click_type , required , is_input , is_output ,
151- node .value .s )
152- )
162+ for variable in extractor .extracted_variables : # type: _VariableNameTypePair
163+ if variable .is_input :
164+ self ._variables .append (variable )
165+ if variable .is_output :
166+ self ._variables .append (variable )
153167
154168 @classmethod
155169 def from_jupyter_notebook_node (cls , node : NotebookNode ) -> 'AnnotatedIPython2CWLToolConverter' :
0 commit comments