11# SPDX-License-Identifier: Apache-2.0
2+ import copy
23import hashlib
3- from typing import IO , Any , List , MutableSequence , Optional , Tuple , Union , cast
4+ import logging
5+ from collections import namedtuple
6+ from typing import Any , Dict , IO , List , MutableSequence , Optional , Tuple , Union , cast
47
58from ruamel import yaml
69from schema_salad .exceptions import ValidationException
7- from schema_salad .utils import json_dumps
10+ from schema_salad .sourceline import SourceLine
11+ from schema_salad .utils import aslist , json_dumps
812
913import cwl_utils .parser
1014import cwl_utils .parser .cwl_v1_0 as cwl
1317
1418CONTENT_LIMIT : int = 64 * 1024
1519
20+ _logger = logging .getLogger ("cwl_utils" )
21+
22+ SrcSink = namedtuple ("SrcSink" , ["src" , "sink" , "linkMerge" , "message" ])
23+
24+
25+ def _compare_records (
26+ src : cwl .RecordSchema , sink : cwl .RecordSchema , strict : bool = False
27+ ) -> bool :
28+ """
29+ Compare two records, ensuring they have compatible fields.
30+
31+ This handles normalizing record names, which will be relative to workflow
32+ step, so that they can be compared.
33+ """
34+ srcfields = {cwl .shortname (field .name ): field .type for field in (src .fields or {})}
35+ sinkfields = {
36+ cwl .shortname (field .name ): field .type for field in (sink .fields or {})
37+ }
38+ for key in sinkfields .keys ():
39+ if (
40+ not can_assign_src_to_sink (
41+ srcfields .get (key , "null" ), sinkfields .get (key , "null" ), strict
42+ )
43+ and sinkfields .get (key ) is not None
44+ ):
45+ _logger .info (
46+ "Record comparison failure for %s and %s\n "
47+ "Did not match fields for %s: %s and %s" ,
48+ cast (
49+ Union [cwl .InputRecordSchema , cwl .CommandOutputRecordSchema ], src
50+ ).name ,
51+ cast (
52+ Union [cwl .InputRecordSchema , cwl .CommandOutputRecordSchema ], sink
53+ ).name ,
54+ key ,
55+ srcfields .get (key ),
56+ sinkfields .get (key ),
57+ )
58+ return False
59+ return True
60+
1661
1762def _compare_type (type1 : Any , type2 : Any ) -> bool :
1863 if isinstance (type1 , cwl .ArraySchema ) and isinstance (type2 , cwl .ArraySchema ):
@@ -38,6 +83,115 @@ def _compare_type(type1: Any, type2: Any) -> bool:
3883 return bool (type1 == type2 )
3984
4085
86+ def can_assign_src_to_sink (src : Any , sink : Any , strict : bool = False ) -> bool :
87+ """
88+ Check for identical type specifications, ignoring extra keys like inputBinding.
89+
90+ src: admissible source types
91+ sink: admissible sink types
92+
93+ In non-strict comparison, at least one source type must match one sink type,
94+ except for 'null'.
95+ In strict comparison, all source types must match at least one sink type.
96+ """
97+ if src == "Any" or sink == "Any" :
98+ return True
99+ if isinstance (src , cwl .ArraySchema ) and isinstance (sink , cwl .ArraySchema ):
100+ return can_assign_src_to_sink (src .items , sink .items , strict )
101+ if isinstance (src , cwl .RecordSchema ) and isinstance (sink , cwl .RecordSchema ):
102+ return _compare_records (src , sink , strict )
103+ if isinstance (src , MutableSequence ):
104+ if strict :
105+ for this_src in src :
106+ if not can_assign_src_to_sink (this_src , sink ):
107+ return False
108+ return True
109+ for this_src in src :
110+ if this_src != "null" and can_assign_src_to_sink (this_src , sink ):
111+ return True
112+ return False
113+ if isinstance (sink , MutableSequence ):
114+ for this_sink in sink :
115+ if can_assign_src_to_sink (src , this_sink ):
116+ return True
117+ return False
118+ return bool (src == sink )
119+
120+
121+ def check_all_types (
122+ src_dict : Dict [str , Any ],
123+ sinks : MutableSequence [Union [cwl .WorkflowStepInput , cwl .WorkflowOutputParameter ]],
124+ type_dict : Dict [str , Any ],
125+ ) -> Dict [str , List [SrcSink ]]:
126+ """Given a list of sinks, check if their types match with the types of their sources."""
127+ validation : Dict [str , List [SrcSink ]] = {"warning" : [], "exception" : []}
128+ for sink in sinks :
129+ if isinstance (sink , cwl .WorkflowOutputParameter ):
130+ sourceName = "outputSource"
131+ sourceField = sink .outputSource
132+ elif isinstance (sink , cwl .WorkflowStepInput ):
133+ sourceName = "source"
134+ sourceField = sink .source
135+ else :
136+ continue
137+ if sourceField is not None :
138+ if isinstance (sourceField , MutableSequence ):
139+ linkMerge = sink .linkMerge or (
140+ "merge_nested" if len (sourceField ) > 1 else None
141+ )
142+ srcs_of_sink = []
143+ for parm_id in sourceField :
144+ srcs_of_sink += [src_dict [parm_id ]]
145+ else :
146+ parm_id = cast (str , sourceField )
147+ if parm_id not in src_dict :
148+ raise SourceLine (sink , sourceName , ValidationException ).makeError (
149+ f"{ sourceName } not found: { parm_id } "
150+ )
151+ srcs_of_sink = [src_dict [parm_id ]]
152+ linkMerge = None
153+ for src in srcs_of_sink :
154+ check_result = check_types (
155+ type_dict [cast (str , src .id )],
156+ type_dict [cast (str , sink .id )],
157+ linkMerge ,
158+ getattr (sink , "valueFrom" , None ),
159+ )
160+ if check_result == "warning" :
161+ validation ["warning" ].append (SrcSink (src , sink , linkMerge , None ))
162+ elif check_result == "exception" :
163+ validation ["exception" ].append (SrcSink (src , sink , linkMerge , None ))
164+ return validation
165+
166+
167+ def check_types (
168+ srctype : Any ,
169+ sinktype : Any ,
170+ linkMerge : Optional [str ],
171+ valueFrom : Optional [str ] = None ,
172+ ) -> str :
173+ """
174+ Check if the source and sink types are correct.
175+
176+ Acceptable types are "pass", "warning", or "exception".
177+ """
178+ if valueFrom is not None :
179+ return "pass"
180+ if linkMerge is None :
181+ if can_assign_src_to_sink (srctype , sinktype , strict = True ):
182+ return "pass"
183+ if can_assign_src_to_sink (srctype , sinktype , strict = False ):
184+ return "warning"
185+ return "exception"
186+ if linkMerge == "merge_nested" :
187+ return check_types (
188+ cwl .ArraySchema (items = srctype , type = "array" ), sinktype , None , None
189+ )
190+ if linkMerge == "merge_flattened" :
191+ return check_types (merge_flatten_type (srctype ), sinktype , None , None )
192+ raise ValidationException (f"Invalid value { linkMerge } for linkMerge field." )
193+
194+
41195def content_limit_respected_read_bytes (f : IO [bytes ]) -> bytes :
42196 """
43197 Read file content up to 64 kB as a byte array.
@@ -96,6 +250,59 @@ def merge_flatten_type(src: Any) -> Any:
96250 return cwl .ArraySchema (type = "array" , items = src )
97251
98252
253+ def type_for_step_input (
254+ step : cwl .WorkflowStep ,
255+ in_ : cwl .WorkflowStepInput ,
256+ ) -> Any :
257+ """Determine the type for the given step input."""
258+ if in_ .valueFrom is not None :
259+ return "Any"
260+ step_run = cwl_utils .parser .utils .load_step (step )
261+ cwl_utils .parser .utils .convert_stdstreams_to_files (step_run )
262+ if step_run and step_run .inputs :
263+ for step_input in step_run .inputs :
264+ if (
265+ cast (str , step_input .id ).split ("#" )[- 1 ]
266+ == cast (str , in_ .id ).split ("#" )[- 1 ]
267+ ):
268+ input_type = step_input .type
269+ if step .scatter is not None and in_ .id in aslist (step .scatter ):
270+ input_type = cwl .ArraySchema (items = input_type , type = "array" )
271+ return input_type
272+ return "Any"
273+
274+
275+ def type_for_step_output (
276+ step : cwl .WorkflowStep ,
277+ sourcename : str ,
278+ ) -> Any :
279+ """Determine the type for the given step output."""
280+ step_run = cwl_utils .parser .utils .load_step (step )
281+ cwl_utils .parser .utils .convert_stdstreams_to_files (step_run )
282+ if step_run and step_run .outputs :
283+ for step_output in step_run .outputs :
284+ if (
285+ step_output .id .split ("#" )[- 1 ].split ("/" )[- 1 ]
286+ == sourcename .split ("#" )[- 1 ].split ("/" )[- 1 ]
287+ ):
288+ output_type = step_output .type
289+ if step .scatter is not None :
290+ if step .scatterMethod == "nested_crossproduct" :
291+ for _ in range (len (aslist (step .scatter ))):
292+ output_type = cwl .ArraySchema (
293+ items = output_type , type = "array"
294+ )
295+ else :
296+ output_type = cwl .ArraySchema (items = output_type , type = "array" )
297+ return output_type
298+ raise ValidationException (
299+ "param {} not found in {}." .format (
300+ sourcename ,
301+ yaml .main .round_trip_dump (cwl .save (step_run )),
302+ )
303+ )
304+
305+
99306def type_for_source (
100307 process : Union [cwl .CommandLineTool , cwl .Workflow , cwl .ExpressionTool ],
101308 sourcenames : Union [str , List [str ]],
@@ -142,7 +349,7 @@ def type_for_source(
142349 return cwl .ArraySchema (items = new_type , type = "array" )
143350 elif linkMerge == "merge_flattened" :
144351 return merge_flatten_type (new_type )
145- elif isinstance (sourcenames , List ):
352+ elif isinstance (sourcenames , List ) and len ( sourcenames ) > 1 :
146353 return cwl .ArraySchema (items = new_type , type = "array" )
147354 else :
148355 return new_type
@@ -181,26 +388,14 @@ def param_for_source_id(
181388 == step .id .split ("#" )[- 1 ]
182389 and step .out
183390 ):
391+ step_run = cwl_utils .parser .utils .load_step (step )
392+ cwl_utils .parser .utils .convert_stdstreams_to_files (step_run )
184393 for outp in step .out :
185394 outp_id = outp if isinstance (outp , str ) else outp .id
186395 if (
187396 outp_id .split ("#" )[- 1 ].split ("/" )[- 1 ]
188397 == sourcename .split ("#" )[- 1 ].split ("/" )[- 1 ]
189398 ):
190- step_run = step .run
191- if isinstance (step .run , str ):
192- step_run = cwl_utils .parser .load_document_by_uri (
193- path = target .loadingOptions .fetcher .urljoin (
194- base_url = cast (
195- str , target .loadingOptions .fileuri
196- ),
197- url = step .run ,
198- ),
199- loadingOptions = target .loadingOptions ,
200- )
201- cwl_utils .parser .utils .convert_stdstreams_to_files (
202- step_run
203- )
204399 if step_run and step_run .outputs :
205400 for output in step_run .outputs :
206401 if (
0 commit comments