@@ -101,6 +101,7 @@ def _get_op_idx_split_location(prog: Program):
101101 """ Find the op that approximately bisects the graph as measure by weights size on each side
102102 """
103103 main_block = prog .functions ["main" ]
104+ main_block .operations = list (main_block .operations )
104105 total_size_in_mb = 0
105106
106107 for op in main_block .operations :
@@ -132,6 +133,7 @@ def _get_first_chunk_outputs(block, op_idx):
132133 # to the second program (all ops from op_idx+1 till the end). These all vars need to be made the output
133134 # of the first program and the input of the second program
134135 boundary_vars = set ()
136+ block .operations = list (block .operations )
135137 for i in range (op_idx + 1 ):
136138 op = block .operations [i ]
137139 if not op .op_type .startswith ("const" ):
@@ -181,6 +183,7 @@ def _make_second_chunk_prog(prog, op_idx):
181183 boundary_vars = _get_first_chunk_outputs (block , op_idx )
182184
183185 # This op will not be included in this program. Its output var will be made into an input
186+ block .operations = list (block .operations )
184187 boundary_op = block .operations [op_idx ]
185188
186189 # Add all boundary ops as inputs
@@ -228,7 +231,8 @@ def _make_second_chunk_prog(prog, op_idx):
228231 return prog
229232
230233
231- def main (args ):
234+ def _legacy_model_chunking (args ):
235+ # TODO: Remove this method after setting the coremltools dependency >= 8.0
232236 os .makedirs (args .o , exist_ok = True )
233237
234238 # Check filename extension
@@ -307,13 +311,6 @@ def main(args):
307311 second_chunk_model = model_chunk2 ,
308312 )
309313
310- # Remove original (non-chunked) model if requested
311- if args .remove_original :
312- logger .info (
313- "Removing original (non-chunked) model at {args.mlpackage_path}" )
314- shutil .rmtree (args .mlpackage_path )
315- logger .info ("Done." )
316-
317314 if args .merge_chunks_in_pipeline_model :
318315 # Make a single pipeline model to manage the model chunks
319316 pipeline_model = ct .utils .make_pipeline (model_chunk1 , model_chunk2 )
@@ -342,6 +339,39 @@ def main(args):
342339 logger .info ("Done." )
343340
344341
342+ def main (args ):
343+ ct_version = ct .__version__
344+
345+ if ct_version != "8.0b2" and ct_version < "8.0" :
346+ # With coremltools version <= 8.0b1,
347+ # we use the legacy implementation.
348+ # TODO: Remove the logic after setting the coremltools dependency >= 8.0.
349+ logger .info (
350+ f"coremltools version { ct_version } detected. Recommended upgrading the package version to "
351+ f"'8.0b2' when you running chunk_mlprogram.py script for the latest supports and bug fixes."
352+ )
353+ _legacy_model_chunking (args )
354+ else :
355+ # Starting from coremltools==8.0b2, there is this `bisect_model` API that
356+ # we can directly call into.
357+ from coremltools .models .utils import bisect_model
358+ logger .info (f"Start chunking model { args .mlpackage_path } into two pieces." )
359+ ct .models .utils .bisect_model (
360+ model = args .mlpackage_path ,
361+ output_dir = args .o ,
362+ merge_chunks_to_pipeline = args .merge_chunks_in_pipeline_model ,
363+ check_output_correctness = args .check_output_correctness ,
364+ )
365+ logger .info (f"Model chunking is done." )
366+
367+ # Remove original (non-chunked) model if requested
368+ if args .remove_original :
369+ logger .info (
370+ "Removing original (non-chunked) model at {args.mlpackage_path}" )
371+ shutil .rmtree (args .mlpackage_path )
372+ logger .info ("Done." )
373+
374+
345375if __name__ == "__main__" :
346376 parser = argparse .ArgumentParser ()
347377 parser .add_argument (
0 commit comments