11# SPDX-License-Identifier: Apache-2.0
22import hashlib
3- from typing import Any , IO , List , Optional , Union
3+ from typing import Any , IO , List , MutableSequence , Optional , Tuple , Union , cast
44
55from ruamel import yaml
66from schema_salad .exceptions import ValidationException
77from schema_salad .utils import json_dumps
88
9+ import cwl_utils .parser
910import cwl_utils .parser .cwl_v1_0 as cwl
11+ import cwl_utils .parser .utils
1012from cwl_utils .errors import WorkflowException
1113
12-
1314CONTENT_LIMIT : int = 64 * 1024
1415
1516
def _compare_type(type1: Any, type2: Any) -> bool:
    """
    Report whether two CWL type descriptions are structurally equivalent.

    Array schemas are compared through their item types and record schemas
    through their fields, matched on the short form of each field name.
    Sequences are compared as unordered collections: they must have the same
    length and every member of each side needs an equivalent member on the
    other side. Any other pair of values is compared with ``==``.
    """
    if isinstance(type1, cwl.ArraySchema) and isinstance(type2, cwl.ArraySchema):
        return _compare_type(type1.items, type2.items)
    if isinstance(type1, cwl.RecordSchema) and isinstance(type2, cwl.RecordSchema):
        fields1 = {
            cwl.shortname(field.name): field.type for field in (type1.fields or [])
        }
        fields2 = {
            cwl.shortname(field.name): field.type for field in (type2.fields or [])
        }
        if fields1.keys() != fields2.keys():
            return False
        return all(
            _compare_type(field_type, fields2[name])
            for name, field_type in fields1.items()
        )
    if isinstance(type1, MutableSequence) and isinstance(type2, MutableSequence):
        if len(type1) != len(type2):
            return False
        # Check both directions. A one-way check would treat [a, a] and
        # [a, b] as equal while reporting [a, b] and [a, a] as different.
        forward = all(any(_compare_type(t1, t2) for t2 in type2) for t1 in type1)
        return forward and all(
            any(_compare_type(t1, t2) for t1 in type1) for t2 in type2
        )
    return bool(type1 == type2)
39+
40+
1641def content_limit_respected_read_bytes (f : IO [bytes ]) -> bytes :
1742 """
1843 Read file content up to 64 kB as a byte array.
@@ -32,49 +57,102 @@ def content_limit_respected_read(f: IO[bytes]) -> str:
3257
3358
def convert_stdstreams_to_files(clt: cwl.CommandLineTool) -> None:
    """
    Rewrite ``stdout``/``stderr`` output type shortcuts as ``File`` outputs.

    For every output of *clt* whose type is one of the two shortcuts, the
    type becomes ``"File"`` and an output binding globbing the matching
    stream file is attached. When the tool does not name that stream file
    yet, a name is derived from the SHA-1 of the tool's serialised form.

    Raises ValidationException if such an output already has an
    outputBinding.
    """
    for out in clt.outputs:
        for stream in ("stdout", "stderr"):
            if out.type != stream:
                continue
            if out.outputBinding is not None:
                raise ValidationException(
                    f"Not allowed to specify outputBinding when using {stream} shortcut."
                )
            if getattr(clt, stream) is None:
                serialised = json_dumps(clt.save(), sort_keys=True).encode("utf-8")
                # SHA-1 is only used to make up a file name here.
                digest = hashlib.sha1(serialised).hexdigest()  # nosec
                setattr(clt, stream, str(digest))
            out.type = "File"
            out.outputBinding = cwl.CommandOutputBinding(glob=getattr(clt, stream))
            # The type is "File" now, so the other shortcut cannot match.
            break
5488
5589
def merge_flatten_type(src: Any) -> Any:
    """
    Return the merge flattened type of the source type.

    A sequence is processed member by member and returned as a new list, an
    existing array schema is returned unchanged, and any other type is
    wrapped in a new array schema.
    """
    if isinstance(src, MutableSequence):
        flattened = []
        for member in src:
            flattened.append(merge_flatten_type(member))
        return flattened
    already_array = isinstance(src, cwl.ArraySchema)
    return src if already_array else cwl.ArraySchema(type="array", items=src)
97+
98+
def _wrap_for_scatter(source_type: Any, scatter: Optional[Tuple[int, str]]) -> Any:
    """Wrap *source_type* in the array nesting described by *scatter*, if any."""
    if scatter is None:
        return source_type
    count, method = scatter
    # nested_crossproduct adds one array level per scattered input; every
    # other scatter method adds exactly one level.
    depth = count if method == "nested_crossproduct" else 1
    for _ in range(depth):
        source_type = cwl.ArraySchema(items=source_type, type="array")
    return source_type


def type_for_source(
    process: Union[cwl.CommandLineTool, cwl.Workflow, cwl.ExpressionTool],
    sourcenames: Union[str, List[str]],
    parent: Optional[cwl.Workflow] = None,
    linkMerge: Optional[str] = None,
) -> Any:
    """
    Determine the type for the given sourcenames.

    The matching parameters are looked up with ``param_for_source_id``, which
    also records the scatter information of the step producing each of them.
    Every source type is wrapped in the arrays implied by that scatter.

    With several matches, equivalent types are collapsed and a single
    remaining type is unwrapped from the list. ``linkMerge`` then decides the
    final shape: ``"merge_nested"`` wraps the result in an array,
    ``"merge_flattened"`` applies ``merge_flatten_type``. Without a
    ``linkMerge``, a multi-match result is wrapped in an array when
    *sourcenames* is a list and otherwise returned as is.
    """
    scatter_context: List[Optional[Tuple[int, str]]] = []
    params = param_for_source_id(process, sourcenames, parent, scatter_context)

    if not isinstance(params, list):
        single_scatter = scatter_context[0] if scatter_context else None
        single_type = _wrap_for_scatter(params.type, single_scatter)
        if linkMerge == "merge_nested":
            return cwl.ArraySchema(items=single_type, type="array")
        if linkMerge == "merge_flattened":
            return merge_flatten_type(single_type)
        return single_type

    distinct_types: List[Any] = []
    for param, scatter in zip(params, scatter_context):
        if isinstance(param, str):
            raw_type = param
        elif hasattr(param, "type"):
            raw_type = param.type
        else:
            continue
        if raw_type is None:
            continue
        # Compare after applying the scatter wrapping: the collected types
        # are already wrapped, so comparing the bare source type would drop a
        # scattered File after a plain File and keep two scattered File
        # sources as duplicates.
        candidate = _wrap_for_scatter(raw_type, scatter)
        if not any(_compare_type(known, candidate) for known in distinct_types):
            distinct_types.append(candidate)

    new_type: Any = distinct_types[0] if len(distinct_types) == 1 else distinct_types
    if linkMerge == "merge_nested":
        return cwl.ArraySchema(items=new_type, type="array")
    if linkMerge == "merge_flattened":
        return merge_flatten_type(new_type)
    if isinstance(sourcenames, list):
        return cwl.ArraySchema(items=new_type, type="array")
    return new_type
72149
73150
74151def param_for_source_id (
75152 process : Union [cwl .CommandLineTool , cwl .Workflow , cwl .ExpressionTool ],
76153 sourcenames : Union [str , List [str ]],
77154 parent : Optional [cwl .Workflow ] = None ,
155+ scatter_context : Optional [List [Optional [Tuple [int , str ]]]] = None ,
78156) -> Union [List [cwl .InputParameter ], cwl .InputParameter ]:
79157 """Find the process input parameter that matches one of the given sourcenames."""
80158 if isinstance (sourcenames , str ):
@@ -85,6 +163,8 @@ def param_for_source_id(
85163 for param in process .inputs :
86164 if param .id .split ("#" )[- 1 ] == sourcename .split ("#" )[- 1 ]:
87165 params .append (param )
166+ if scatter_context is not None :
167+ scatter_context .append (None )
88168 targets = [process ]
89169 if parent :
90170 targets .append (parent )
@@ -93,26 +173,72 @@ def param_for_source_id(
93173 for inp in target .inputs :
94174 if inp .id .split ("#" )[- 1 ] == sourcename .split ("#" )[- 1 ]:
95175 params .append (inp )
176+ if scatter_context is not None :
177+ scatter_context .append (None )
96178 for step in target .steps :
97- if sourcename .split ("#" )[- 1 ].split ("/" )[0 ] == step .id .split ("#" )[- 1 ] and step .out :
179+ if (
180+ "/" .join (sourcename .split ("#" )[- 1 ].split ("/" )[:- 1 ])
181+ == step .id .split ("#" )[- 1 ]
182+ and step .out
183+ ):
98184 for outp in step .out :
99185 outp_id = outp if isinstance (outp , str ) else outp .id
100- if outp_id .split ("#" )[- 1 ].split ("/" )[- 1 ] == sourcename .split ("#" )[- 1 ].split ("/" , 1 )[1 ]:
101- if step .run and step .run .outputs :
102- for output in step .run .outputs :
186+ if (
187+ outp_id .split ("#" )[- 1 ].split ("/" )[- 1 ]
188+ == sourcename .split ("#" )[- 1 ].split ("/" )[- 1 ]
189+ ):
190+ step_run = step .run
191+ if isinstance (step .run , str ):
192+ step_run = cwl_utils .parser .load_document_by_uri (
193+ path = target .loadingOptions .fetcher .urljoin (
194+ base_url = cast (
195+ str , target .loadingOptions .fileuri
196+ ),
197+ url = step .run ,
198+ ),
199+ loadingOptions = target .loadingOptions ,
200+ )
201+ cwl_utils .parser .utils .convert_stdstreams_to_files (
202+ step_run
203+ )
204+ if step_run and step_run .outputs :
205+ for output in step_run .outputs :
103206 if (
104- output .id .split ("#" )[- 1 ].split ('/' )[- 1 ]
105- == sourcename .split ('#' )[- 1 ].split ("/" , 1 )[ 1 ]
207+ output .id .split ("#" )[- 1 ].split ("/" )[- 1 ]
208+ == sourcename .split ("#" )[- 1 ].split ("/" )[ - 1 ]
106209 ):
107210 params .append (output )
211+ if scatter_context is not None :
212+ if isinstance (step .scatter , str ):
213+ scatter_context .append (
214+ (
215+ 1 ,
216+ step .scatterMethod
217+ or "dotproduct" ,
218+ )
219+ )
220+ elif isinstance (
221+ step .scatter , MutableSequence
222+ ):
223+ scatter_context .append (
224+ (
225+ len (step .scatter ),
226+ step .scatterMethod
227+ or "dotproduct" ,
228+ )
229+ )
230+ else :
231+ scatter_context .append (None )
108232 if len (params ) == 1 :
109233 return params [0 ]
110234 elif len (params ) > 1 :
111235 return params
112236 raise WorkflowException (
113- "param {} not found in {}\n or \n {}." .format (
237+ "param {} not found in {}\n {}." .format (
114238 sourcename ,
115239 yaml .main .round_trip_dump (cwl .save (process )),
116- yaml .main .round_trip_dump (cwl .save (parent )),
240+ " or\n {}" .format (yaml .main .round_trip_dump (cwl .save (parent )))
241+ if parent is not None
242+ else "" ,
117243 )
118244 )
0 commit comments