1010
1111import argparse
1212import json
13+ import logging
1314import os
15+ import re
1416import sys
1517from collections .abc import MutableMapping
16- from typing import IO , TYPE_CHECKING , Any , Union , cast
18+ from io import TextIOWrapper
19+ from pathlib import Path
20+ from typing import (
21+ IO ,
22+ Any ,
23+ Union ,
24+ cast ,
25+ )
1726
1827from cwlformat .formatter import stringify_dict
19- from ruamel .yaml .dumper import RoundTripDumper
20- from ruamel .yaml .main import YAML , dump
28+ from ruamel .yaml .main import YAML
2129from ruamel .yaml .representer import RoundTripRepresenter
2230from schema_salad .sourceline import SourceLine , add_lc_filename
2331
24- if TYPE_CHECKING :
25- from _typeshed import StrPath
32+ from cwl_utils .loghandler import _logger as _cwlutilslogger
33+
34+ _logger = logging .getLogger ("cwl-graph-split" ) # pylint: disable=invalid-name
35+ defaultStreamHandler = logging .StreamHandler () # pylint: disable=invalid-name
36+ _logger .addHandler (defaultStreamHandler )
37+ _logger .setLevel (logging .INFO )
38+ _cwlutilslogger .setLevel (100 )
2639
2740
2841def arg_parser () -> argparse .ArgumentParser :
@@ -73,7 +86,7 @@ def run(args: list[str]) -> int:
7386 with open (options .cwlfile ) as source_handle :
7487 graph_split (
7588 source_handle ,
76- options .outdir ,
89+ Path ( options .outdir ) ,
7790 options .output_format ,
7891 options .mainfile ,
7992 options .pretty ,
@@ -83,7 +96,7 @@ def run(args: list[str]) -> int:
8396
8497def graph_split (
8598 sourceIO : IO [str ],
86- output_dir : "StrPath" ,
99+ output_dir : Path ,
87100 output_format : str ,
88101 mainfile : str ,
89102 pretty : bool ,
@@ -100,6 +113,13 @@ def graph_split(
100113
101114 version = source .pop ("cwlVersion" )
102115
116+ # Check outdir parent exists
117+ if not output_dir .parent .is_dir ():
118+ raise NotADirectoryError (f"Parent directory of { output_dir } does not exist" )
119+ # If output_dir is not a directory, create it
120+ if not output_dir .is_dir ():
121+ output_dir .mkdir ()
122+
103123 def my_represent_none (
104124 self : Any , data : Any
105125 ) -> Any : # pylint: disable=unused-argument
@@ -111,7 +131,7 @@ def my_represent_none(
111131 for entry in source ["$graph" ]:
112132 entry_id = entry .pop ("id" ).lstrip ("#" )
113133 entry ["cwlVersion" ] = version
114- imports = rewrite (entry , entry_id )
134+ imports = rewrite (entry , entry_id , output_dir )
115135 if imports :
116136 for import_name in imports :
117137 rewrite_types (entry , f"#{ import_name } " , False )
@@ -121,25 +141,28 @@ def my_represent_none(
121141 else :
122142 entry_id = mainfile
123143
124- output_file = os . path . join ( output_dir , entry_id + ".cwl" )
144+ output_file = output_dir / ( re . sub ( ".cwl$" , "" , entry_id ) + ".cwl" )
125145 if output_format == "json" :
126146 json_dump (entry , output_file )
127147 elif output_format == "yaml" :
128- yaml_dump (entry , output_file , pretty )
148+ with output_file .open ("w" , encoding = "utf-8" ) as output_handle :
149+ yaml_dump (entry , output_handle , pretty )
129150
130151
131- def rewrite (document : Any , doc_id : str ) -> set [str ]:
152+ def rewrite (
153+ document : Any , doc_id : str , output_dir : Path , pretty : bool = False
154+ ) -> set [str ]:
132155 """Rewrite the given element from the CWL $graph."""
133156 imports = set ()
134157 if isinstance (document , list ) and not isinstance (document , str ):
135158 for entry in document :
136- imports .update (rewrite (entry , doc_id ))
159+ imports .update (rewrite (entry , doc_id , output_dir , pretty ))
137160 elif isinstance (document , dict ):
138161 this_id = document ["id" ] if "id" in document else None
139162 for key , value in document .items ():
140163 with SourceLine (document , key , Exception ):
141164 if key == "run" and isinstance (value , str ) and value [0 ] == "#" :
142- document [key ] = f"{ value [1 :]} .cwl"
165+ document [key ] = f"{ re . sub ( '.cwl$' , '' , value [1 :]) } .cwl"
143166 elif key in ("id" , "outputSource" ) and value .startswith ("#" + doc_id ):
144167 document [key ] = value [len (doc_id ) + 2 :]
145168 elif key == "out" and isinstance (value , list ):
@@ -179,15 +202,15 @@ def rewrite_id(entry: Any) -> Union[MutableMapping[Any, Any], str]:
179202 elif key == "$import" :
180203 rewrite_import (document )
181204 elif key == "class" and value == "SchemaDefRequirement" :
182- return rewrite_schemadef (document )
205+ return rewrite_schemadef (document , output_dir , pretty )
183206 else :
184- imports .update (rewrite (value , doc_id ))
207+ imports .update (rewrite (value , doc_id , output_dir , pretty ))
185208 return imports
186209
187210
def rewrite_import(document: MutableMapping[str, Any]) -> None:
    """
    Adjust the $import directive in place.

    Keeps only the part of the ``$import`` value that precedes the first
    ``/``, removes any leading ``#`` characters from it, and stores the
    result back under the ``$import`` key.

    :param document: mapping with a string ``$import`` entry; it is mutated.
    :raises KeyError: if ``document`` has no ``$import`` key.
    """
    target = document["$import"]
    # Drop everything from the first "/" onward.
    file_part = target.split("/")[0]
    # lstrip rather than a fixed [1:] slice, so a value that has no leading
    # "#" does not lose its first character.
    document["$import"] = file_part.lstrip("#")
192215
193216
@@ -215,22 +238,25 @@ def rewrite_types(field: Any, entry_file: str, sameself: bool) -> None:
215238 rewrite_types (entry , entry_file , sameself )
216239
217240
218- def rewrite_schemadef (document : MutableMapping [str , Any ]) -> set [str ]:
241+ def rewrite_schemadef (
242+ document : MutableMapping [str , Any ], output_dir : Path , pretty : bool = False
243+ ) -> set [str ]:
219244 """Dump the schemadefs to their own file."""
220245 for entry in document ["types" ]:
221246 if "$import" in entry :
222247 rewrite_import (entry )
223248 elif "name" in entry and "/" in entry ["name" ]:
224- entry_file , entry ["name" ] = entry ["name" ].split ("/" )
225- for field in entry [ "fields" ] :
249+ entry_file , entry ["name" ] = entry ["name" ].lstrip ( "#" ). split ("/" )
250+ for field in entry . get ( "fields" , []) :
226251 field ["name" ] = field ["name" ].split ("/" )[2 ]
227252 rewrite_types (field , entry_file , True )
228- with open ( entry_file [ 1 :], "a" , encoding = "utf-8" ) as entry_handle :
229- dump ([ entry ] , entry_handle , Dumper = RoundTripDumper )
230- entry ["$import" ] = entry_file [ 1 :]
253+ with ( output_dir / entry_file ). open ( "a" , encoding = "utf-8" ) as entry_handle :
254+ yaml_dump ( entry , entry_handle , pretty )
255+ entry ["$import" ] = entry_file
231256 del entry ["name" ]
232257 del entry ["type" ]
233- del entry ["fields" ]
258+ if "fields" in entry :
259+ del entry ["fields" ]
234260 seen_imports = set ()
235261
236262 def seen_import (entry : MutableMapping [str , Any ]) -> bool :
@@ -247,26 +273,26 @@ def seen_import(entry: MutableMapping[str, Any]) -> bool:
247273 return seen_imports
248274
249275
def json_dump(entry: Any, output_file: Union[str, Path]) -> None:
    """
    Output object as JSON.

    The file is opened for writing (truncating any existing content) with
    UTF-8 encoding and ``entry`` is serialized with a 4-space indent.

    :param entry: JSON-serializable object to write.
    :param output_file: destination path; both ``str`` and ``Path`` are
        accepted so callers written against the earlier ``str`` signature
        keep working.
    """
    # Path(...) is a no-op for Path inputs and upgrades plain strings.
    with Path(output_file).open("w", encoding="utf-8") as result_handle:
        json.dump(entry, result_handle, indent=4)
254280
255281
def yaml_dump(
    entry: Any,
    output_handle: IO[str],
    pretty: bool,
) -> None:
    """
    Output object as YAML.

    Writes to an already-open text handle; the caller owns the handle and is
    responsible for opening and closing it.

    :param entry: object to serialize.
    :param output_handle: writable text stream that receives the YAML.
    :param pretty: when true, write the output of
        ``cwlformat.formatter.stringify_dict`` instead of using ruamel.yaml.
    """
    if pretty:
        output_handle.write(stringify_dict(entry))
        return
    # Round-trip mode with the pure-Python implementation.
    yaml = YAML(typ="rt", pure=True)
    yaml.default_flow_style = False
    yaml.indent = 4
    yaml.block_seq_indent = 2
    yaml.dump(entry, output_handle)
270296
271297
272298if __name__ == "__main__" :
0 commit comments