@@ -93,7 +93,7 @@ def optimize(
9393 model_input : str ,
9494 model_output : Path ,
9595 provider : str ,
96- controlnet : bool
96+ image_encoder : bool
9797):
9898 from google .protobuf import __version__ as protobuf_version
9999
@@ -109,7 +109,6 @@ def optimize(
109109 shutil .rmtree (script_dir / "footprints" , ignore_errors = True )
110110 shutil .rmtree (model_output , ignore_errors = True )
111111
112-
113112 # Load the entire PyTorch pipeline to ensure all models and their configurations are downloaded and cached.
114113 # This avoids an issue where the non-ONNX components (tokenizer, scheduler, and feature extractor) are not
115114 # automatically cached correctly if individual models are fetched one at a time.
@@ -121,15 +120,10 @@ def optimize(
121120
122121 model_info = {}
123122
124- submodel_names = [ "text_encoder" , "decoder" , "prior" , "image_encoder" ]
125-
126- has_safety_checker = getattr (pipeline , "safety_checker" , None ) is not None
127-
128- if has_safety_checker :
129- submodel_names .append ("safety_checker" )
123+ submodel_names = [ "text_encoder" , "decoder" , "prior" , "vqgan" ]
130124
131- if controlnet :
132- submodel_names .append ("controlnet " )
125+ if image_encoder :
126+ submodel_names .append ("image_encoder " )
133127
134128 for submodel_name in submodel_names :
135129 print (f"\n Optimizing { submodel_name } " )
@@ -138,14 +132,7 @@ def optimize(
138132 with (script_dir / f"config_{ submodel_name } .json" ).open () as fin :
139133 olive_config = json .load (fin )
140134 olive_config = update_config_with_provider (olive_config , provider )
141-
142- if submodel_name in ("unet" , "controlnet" , "text_encoder" ):
143- olive_config ["input_model" ]["config" ]["model_path" ] = model_dir
144- else :
145- # Only the unet & text encoder are affected by LoRA, so it's better to use the base model ID for
146- # other models: the Olive cache is based on the JSON config, and two LoRA variants with the same
147- # base model ID should be able to reuse previously optimized copies.
148- olive_config ["input_model" ]["config" ]["model_path" ] = model_dir
135+ olive_config ["input_model" ]["config" ]["model_path" ] = model_dir
149136
150137 run_res = olive_run (olive_config )
151138
@@ -156,52 +143,22 @@ def optimize(
156143 from sd_utils .ort import save_onnx_pipeline
157144
158145 save_onnx_pipeline (
159- has_safety_checker , model_info , model_output , pipeline , submodel_names
146+ model_info , model_output , pipeline , submodel_names
160147 )
161148
162149 return model_info
163150
164151
165152def parse_common_args (raw_args ):
166153 parser = argparse .ArgumentParser ("Common arguments" )
167-
168154 parser .add_argument ("--model_input" , default = "stable-diffusion-v1-5" , type = str )
169155 parser .add_argument ("--model_output" , default = "stable-diffusion-v1-5" , type = Path )
170- parser .add_argument ("--controlnet" ,action = "store_true" , help = "Create ControlNet Unet Model" )
171- parser .add_argument (
172- "--provider" , default = "dml" , type = str , choices = ["dml" , "cuda" ], help = "Execution provider to use"
173- )
156+ parser .add_argument ("--image_encoder" ,action = "store_true" , help = "Create image encoder model" )
157+ parser .add_argument ("--provider" , default = "dml" , type = str , choices = ["dml" , "cuda" ], help = "Execution provider to use" )
174158 parser .add_argument ("--optimize" , action = "store_true" , help = "Runs the optimization step" )
175159 parser .add_argument ("--clean_cache" , action = "store_true" , help = "Deletes the Olive cache" )
176160 parser .add_argument ("--test_unoptimized" , action = "store_true" , help = "Use unoptimized model for inference" )
177- parser .add_argument ("--batch_size" , default = 1 , type = int , help = "Number of images to generate per batch" )
178- parser .add_argument (
179- "--prompt" ,
180- default = (
181- "castle surrounded by water and nature, village, volumetric lighting, photorealistic, "
182- "detailed and intricate, fantasy, epic cinematic shot, mountains, 8k ultra hd"
183- ),
184- type = str ,
185- )
186- parser .add_argument (
187- "--guidance_scale" ,
188- default = 7.5 ,
189- type = float ,
190- help = "Guidance scale as defined in Classifier-Free Diffusion Guidance" ,
191- )
192- parser .add_argument ("--num_images" , default = 1 , type = int , help = "Number of images to generate" )
193- parser .add_argument ("--num_inference_steps" , default = 50 , type = int , help = "Number of steps in diffusion process" )
194161 parser .add_argument ("--tempdir" , default = None , type = str , help = "Root directory for tempfile directories and files" )
195- parser .add_argument (
196- "--strength" ,
197- default = 1.0 ,
198- type = float ,
199- help = "Value between 0.0 and 1.0, that controls the amount of noise that is added to the input image. "
200- "Values that approach 1.0 enable lots of variations but will also produce images "
201- "that are not semantically consistent with the input." ,
202- )
203- parser .add_argument ("--image_size" , default = 512 , type = int , help = "Width and height of the images to generate" )
204-
205162 return parser .parse_known_args (raw_args )
206163
207164
@@ -231,8 +188,6 @@ def main(raw_args=None):
231188 if common_args .clean_cache :
232189 shutil .rmtree (script_dir / "cache" , ignore_errors = True )
233190
234- guidance_scale = common_args .guidance_scale
235-
236191 ort_args = None , None
237192 ort_args , extra_args = parse_ort_args (extra_args )
238193
@@ -246,27 +201,10 @@ def main(raw_args=None):
246201 from sd_utils .ort import validate_args
247202
248203 validate_args (ort_args , common_args .provider )
249- optimize (common_args .model_input , common_args .model_output , common_args .provider , common_args .controlnet )
204+ optimize (common_args .model_input , common_args .model_output , common_args .provider , common_args .image_encoder )
250205
251206 if not common_args .optimize :
252- model_dir = model_output / "F32" if common_args .test_unoptimized else model_output / "F16"
253- with warnings .catch_warnings ():
254- warnings .simplefilter ("ignore" )
255-
256- from sd_utils .ort import get_ort_pipeline
257-
258- pipeline = get_ort_pipeline (model_dir , common_args , ort_args , guidance_scale )
259- run_inference_loop (
260- pipeline ,
261- common_args .prompt ,
262- common_args .num_images ,
263- common_args .batch_size ,
264- common_args .image_size ,
265- common_args .num_inference_steps ,
266- guidance_scale ,
267- common_args .strength ,
268- provider = provider ,
269- )
207+ print ("TODO: Create OnnxStableCascadePipeline" )
270208
271209
272210if __name__ == "__main__" :