@@ -27,7 +27,7 @@ def optimize(
2727 model_input : str ,
2828 model_output : Path ,
2929 provider : str ,
30- controlnet : bool
30+ submodel_names : list [ str ]
3131):
3232 from google .protobuf import __version__ as protobuf_version
3333
@@ -51,10 +51,6 @@ def optimize(
5151 config .unet_sample_size = pipeline .unet .config .sample_size
5252
5353 model_info = {}
54- submodel_names = ["tokenizer" , "tokenizer_2" , "vae_encoder" , "vae_decoder" , "unet" , "text_encoder" , "text_encoder_2" ]
55-
56- if controlnet :
57- submodel_names .append ("controlnet" )
5854
5955 for submodel_name in submodel_names :
6056 if submodel_name == "tokenizer" or submodel_name == "tokenizer_2" :
@@ -81,16 +77,18 @@ def save_onnx_Models(model_dir, model_info, model_output, submodel_names):
8177 conversion_dir = model_output / conversion_type
8278 conversion_dir .mkdir (parents = True , exist_ok = True )
8379
80+ only_unet = "unet" in submodel_names and len (submodel_names ) <= 2
8481 # Copy the config and other files required by some applications
85- model_index_path = model_dir / "model_index.json"
86- if os .path .exists (model_index_path ):
87- shutil .copy (model_index_path , conversion_dir )
88- if os .path .exists (model_dir / "tokenizer" ):
89- shutil .copytree (model_dir / "tokenizer" , conversion_dir / "tokenizer" )
90- if os .path .exists (model_dir / "tokenizer_2" ):
91- shutil .copytree (model_dir / "tokenizer_2" , conversion_dir / "tokenizer_2" )
92- if os .path .exists (model_dir / "scheduler" ):
93- shutil .copytree (model_dir / "scheduler" , conversion_dir / "scheduler" )
82+ if only_unet is False :
83+ model_index_path = model_dir / "model_index.json"
84+ if os .path .exists (model_index_path ):
85+ shutil .copy (model_index_path , conversion_dir )
86+ if os .path .exists (model_dir / "tokenizer" ):
87+ shutil .copytree (model_dir / "tokenizer" , conversion_dir / "tokenizer" )
88+ if os .path .exists (model_dir / "tokenizer_2" ):
89+ shutil .copytree (model_dir / "tokenizer_2" , conversion_dir / "tokenizer_2" )
90+ if os .path .exists (model_dir / "scheduler" ):
91+ shutil .copytree (model_dir / "scheduler" , conversion_dir / "scheduler" )
9492
9593 # Save models files
9694 for submodel_name in submodel_names :
@@ -212,6 +210,8 @@ def parse_common_args(raw_args):
212210 parser .add_argument ("--controlnet" , action = "store_true" , help = "Create ControlNet Unet Model" )
213211 parser .add_argument ("--clean" , action = "store_true" , help = "Deletes the Olive cache" )
214212 parser .add_argument ("--tempdir" , default = None , type = str , help = "Root directory for tempfile directories and files" )
213+ parser .add_argument ("--only_unet" , action = "store_true" , help = "Only convert UNET model" )
214+
215215 return parser .parse_known_args (raw_args )
216216
217217
@@ -237,10 +237,17 @@ def main(raw_args=None):
237237
238238 set_tempdir (common_args .tempdir )
239239
240+ submodel_names = ["tokenizer" , "tokenizer_2" , "vae_encoder" , "vae_decoder" , "unet" , "text_encoder" , "text_encoder_2" ]
241+
242+ if common_args .only_unet :
243+ submodel_names = ["unet" ]
244+
245+ if common_args .controlnet :
246+ submodel_names .append ("controlnet" )
247+
240248 with warnings .catch_warnings ():
241249 warnings .simplefilter ("ignore" )
242- optimize (script_dir , common_args .model_input ,
243- model_output , provider , common_args .controlnet )
250+ optimize (script_dir , common_args .model_input , model_output , provider , submodel_names )
244251
245252
246253if __name__ == "__main__" :
0 commit comments