55import tarfile
66import tempfile
77from collections import namedtuple
8+ from copy import deepcopy
89from pathlib import Path
9- from typing import Dict , Any
10+ from typing import Dict , Any , List
1011
1112import astor
1213import nbconvert
1314import yaml
1415from nbformat .notebooknode import NotebookNode
1516
16- from .iotypes import CWLFilePathInput , CWLBooleanInput , CWLIntInput , CWLStringInput , CWLFilePathOutput
17+ from .iotypes import CWLFilePathInput , CWLBooleanInput , CWLIntInput , CWLStringInput , CWLFilePathOutput , \
18+ CWLDumpableFile , CWLDumpableBinaryFile , CWLDumpable
1719from .requirements_manager import RequirementsManager
1820
# Both templates are read once, at import time, from the `templates` directory
# that sits next to this module; the paths are resolved relative to __file__,
# not to the current working directory.

# Full text of templates/template.dockerfile.
with open(os.sep.join([os.path.abspath(os.path.dirname(__file__)), 'templates', 'template.dockerfile'])) as f:
    DOCKERFILE_TEMPLATE = f.read()
# Full text of templates/template.setup.
with open(os.sep.join([os.path.abspath(os.path.dirname(__file__)), 'templates', 'template.setup'])) as f:
    SETUP_TEMPLATE = f.read()
2325
# Record describing one variable picked up by AnnotatedVariablesExtractor.
_VariableNameTypePair = namedtuple(
    'VariableNameTypePair',
    (
        'name',             # identifier on the left-hand side of the annotated assignment
        'cwl_typeof',       # first item of the matching input_type_mapper entry; None for outputs
        'argparse_typeof',  # second item of the matching input_type_mapper entry; None for outputs
        'required',         # False when the CWL type ends with '?'; None for outputs
        'is_input',         # True for variables found through input_type_mapper
        'is_output',        # True for output / dumpable variables
        'value',            # None for inputs; for outputs the string recorded by the extractor
    ),
)
2430
25- # TODO: check if supports recursion if main function exists
2631
2732class AnnotatedVariablesExtractor (ast .NodeTransformer ):
2833 input_type_mapper = {
@@ -52,12 +57,19 @@ class AnnotatedVariablesExtractor(ast.NodeTransformer):
5257 }}
5358
5459 output_type_mapper = {
55- CWLFilePathOutput .__name__
60+ (CWLFilePathOutput .__name__ ,)
61+ }
62+
63+ dumpable_mapper = {
64+ (CWLDumpableFile .__name__ ,): "with open('{var_name}', 'w') as f:\n \t f.write({var_name})" ,
65+ (CWLDumpableBinaryFile .__name__ ,): "with open('{var_name}', 'wb') as f:\n \t f.write({var_name})" ,
66+ (CWLDumpable .__name__ , CWLDumpable .dump .__name__ ): None ,
5667 }
5768
def __init__(self, *args, **kwargs):
    """Create an extractor with empty result lists.

    All positional and keyword arguments are forwarded unchanged to the
    parent class initialiser.
    """
    super().__init__(*args, **kwargs)
    # One record per input/output variable found while visiting the tree.
    self.extracted_variables: List[_VariableNameTypePair] = []
    # Statement lists produced by the dumper handlers; each entry is a list
    # of AST statements kept aside instead of being inserted in place.
    self.to_dump: List[List[ast.stmt]] = []
6173
6274 def __get_annotation__ (self , type_annotation ):
6375 annotation = None
@@ -70,30 +82,84 @@ def __get_annotation__(self, type_annotation):
7082 annotation = self .__get_annotation__ (ann_expr .value )
7183 elif isinstance (type_annotation , ast .Subscript ):
7284 annotation = (type_annotation .value .id , * self .__get_annotation__ (type_annotation .slice .value ))
85+ elif isinstance (type_annotation , ast .Call ):
86+ annotation = (type_annotation .func .value .id , type_annotation .func .attr )
7387 return annotation
7488
@classmethod
def conv_AnnAssign_to_Assign(cls, node):
    """Return a plain assignment equivalent to the annotated assignment *node*.

    The new ``ast.Assign`` reuses the target and value objects of *node*
    (they are not copied) and carries the same ``lineno`` / ``col_offset``;
    the annotation itself is dropped.

    :param node: the ``ast.AnnAssign`` node to strip the annotation from
    :return: a new ``ast.Assign`` node
    """
    plain_assignment = ast.Assign(targets=[node.target], value=node.value)
    plain_assignment.lineno = node.lineno
    plain_assignment.col_offset = node.col_offset
    return plain_assignment
97+
def _visit_input_ann_assign(self, node, annotation):
    """Record an input variable and remove its assignment from the tree.

    :param node: annotated assignment whose target is a plain name
    :param annotation: key into ``input_type_mapper`` describing the type
    :return: ``None``, so that the ``ast.NodeTransformer`` machinery drops
        the statement from the transformed tree
    """
    type_info = self.input_type_mapper[annotation]
    cwl_type = type_info[0]
    record = _VariableNameTypePair(
        name=node.target.id,
        cwl_typeof=cwl_type,
        argparse_typeof=type_info[1],
        # a trailing '?' on the CWL type marks the input as optional
        required=not cwl_type.endswith('?'),
        is_input=True,
        is_output=False,
        value=None,
    )
    self.extracted_variables.append(record)
    return None
104+
def _visit_default_dumper(self, node, dumper):
    """Handle a variable annotated with one of the built-in dumpable types.

    :param node: annotated assignment whose target is a plain name
    :param dumper: source-code template containing a ``{var_name}``
        placeholder; it is formatted with the target name and parsed
    :return: *node* converted to a plain assignment (annotation removed)
    """
    var_name = node.target.id
    dump_statements = ast.parse(dumper.format(var_name=var_name)).body
    # The generated statements are queued, not inserted at this position.
    self.to_dump.append(dump_statements)
    self.extracted_variables.append(_VariableNameTypePair(
        name=var_name,
        cwl_typeof=None,
        argparse_typeof=None,
        required=None,
        is_input=False,
        is_output=True,
        # the output value recorded for a built-in dumper is the variable name
        value=var_name,
    ))
    return self.conv_AnnAssign_to_Assign(node)
112+
def _visit_user_defined_dumper(self, node):
    """Handle a variable whose annotation is a user-supplied dump call.

    The annotation is a call whose first positional argument is an
    attribute expression (``<receiver>.<method>``) and whose remaining
    positional and keyword arguments are the arguments for that method.
    A statement ``<receiver>.<method>(<remaining args>, <keywords>)`` is
    built and queued on ``self.to_dump``, and the variable is recorded as
    an output whose value is the string literal given as the second
    positional argument of the annotation call.

    Every lookup that can fail (missing arguments, a second argument that
    is not a string literal, a target that is not a plain name) is done
    before ``self.to_dump`` or ``self.extracted_variables`` is modified.
    The caller swallows exceptions, so failing after a partial update
    would leave a dump statement queued for a variable that was never
    registered as an output.

    :param node: annotated assignment carrying the dump call as annotation
    :return: *node* converted to a plain assignment (annotation removed)
    :raises AttributeError, IndexError: when the annotation does not have
        the shape described above; no state is modified in that case
    """
    dump_call = node.annotation
    method_ref = dump_call.args[0]

    # --- fallible lookups first: nothing has been mutated yet -------------
    variable_name = node.target.id
    # `.s` is the string-literal accessor used throughout this file.
    output_value = dump_call.args[1].s
    method_name = method_ref.attr

    load_ctx = ast.Load()
    # Copy the receiver so the node still referenced by the annotation is
    # not modified when the context is overwritten.
    receiver = deepcopy(method_ref.value)
    receiver.ctx = load_ctx
    ast.fix_missing_locations(receiver)

    new_dump_node = ast.Expr(
        col_offset=0, lineno=0,
        value=ast.Call(
            args=dump_call.args[1:], keywords=dump_call.keywords, col_offset=0,
            func=ast.Attribute(
                attr=method_name,
                value=receiver,
                col_offset=0, ctx=load_ctx, lineno=0,
            ),
        )
    )
    ast.fix_missing_locations(new_dump_node)

    # --- state updates: only reached when the annotation was well-formed --
    self.to_dump.append([new_dump_node])
    self.extracted_variables.append(_VariableNameTypePair(
        variable_name, None, None, None, False, True, output_value)
    )
    # removing type annotation
    return self.conv_AnnAssign_to_Assign(node)
137+
def _visit_output_type(self, node):
    """Record an output variable and strip its type annotation.

    The variable is registered as an output whose value is the string
    literal assigned to it (``node.value.s``).

    :param node: annotated assignment whose target is a plain name and
        whose value is a string literal
    :return: *node* converted to a plain assignment (annotation removed)
    """
    self.extracted_variables.append(_VariableNameTypePair(
        node.target.id, None, None, None, False, True, node.value.s)
    )
    # Removing the type annotation: use the same helper as the other output
    # handlers instead of rebuilding the ast.Assign node by hand.
    return self.conv_AnnAssign_to_Assign(node)
149+
def visit_AnnAssign(self, node):
    """Dispatch an annotated assignment to the handler for its annotation.

    The annotation is looked up, in this order, in ``input_type_mapper``,
    ``dumpable_mapper`` and ``output_type_mapper``; the first match decides
    which handler processes the node. A ``dumpable_mapper`` entry mapped to
    ``None`` selects the user-defined dumper handler, any other entry is
    passed to the default dumper handler.

    :param node: the ``ast.AnnAssign`` node being visited
    :return: whatever the selected handler returns, or *node* itself when
        no mapper matches or any exception is raised on the way
    """
    try:
        annotation = self.__get_annotation__(node.annotation)
        if annotation in self.input_type_mapper:
            return self._visit_input_ann_assign(node, annotation)
        if annotation in self.dumpable_mapper:
            dumper = self.dumpable_mapper[annotation]
            if dumper is None:
                return self._visit_user_defined_dumper(node)
            return self._visit_default_dumper(node, dumper)
        if annotation in self.output_type_mapper:
            return self._visit_output_type(node)
    except Exception:
        # Best effort: an annotation that cannot be analysed is left as is.
        return node
    return node
@@ -123,12 +189,6 @@ class AnnotatedIPython2CWLToolConverter:
123189 """
124190
125191 _code : str
126-
127- _VariableNameTypePair = namedtuple (
128- 'VariableNameTypePair' ,
129- ['name' , 'cwl_typeof' , 'argparse_typeof' , 'required' , 'is_input' , 'is_output' , 'value' ]
130- )
131-
132192 """The annotated python code to convert."""
133193
134194 def __init__ (self , annotated_ipython_code : str ):
@@ -137,19 +197,15 @@ def __init__(self, annotated_ipython_code: str):
137197
138198 self ._code = annotated_ipython_code
139199 extractor = AnnotatedVariablesExtractor ()
140- self ._tree = ast .fix_missing_locations (extractor .visit (ast .parse (self ._code )))
200+ self ._tree = extractor .visit (ast .parse (self ._code ))
201+ [self ._tree .body .extend (d ) for d in extractor .to_dump ]
202+ self ._tree = ast .fix_missing_locations (self ._tree )
141203 self ._variables = []
142- for node , cwl_type , click_type , required , is_input , is_output in extractor .extracted_nodes :
143- if is_input :
144- self ._variables .append (
145- self ._VariableNameTypePair (node .target .id , cwl_type , click_type , required , is_input , is_output ,
146- None )
147- )
148- if is_output :
149- self ._variables .append (
150- self ._VariableNameTypePair (node .target .id , cwl_type , click_type , required , is_input , is_output ,
151- node .value .s )
152- )
204+ for variable in extractor .extracted_variables : # type: _VariableNameTypePair
205+ if variable .is_input :
206+ self ._variables .append (variable )
207+ if variable .is_output :
208+ self ._variables .append (variable )
153209
154210 @classmethod
155211 def from_jupyter_notebook_node (cls , node : NotebookNode ) -> 'AnnotatedIPython2CWLToolConverter' :
0 commit comments