@@ -604,8 +604,7 @@ def get_coreml_pipe(pytorch_pipe,
604604 "tokenizer" : pytorch_pipe .tokenizer ,
605605 'tokenizer_2' : pytorch_pipe .tokenizer_2 ,
606606 "scheduler" : pytorch_pipe .scheduler if scheduler_override is None else scheduler_override ,
607- "force_zeros_for_empty_prompt" : force_zeros_for_empty_prompt ,
608- 'xl' : True
607+ 'xl' : True ,
609608 }
610609
611610 model_packages_to_load = ["text_encoder", "text_encoder_2", "unet", "vae_decoder"]
@@ -618,6 +617,8 @@ def get_coreml_pipe(pytorch_pipe,
618617 }
619618 model_packages_to_load = ["text_encoder", "unet", "vae_decoder"]
620619
620+ coreml_pipe_kwargs["force_zeros_for_empty_prompt"] = force_zeros_for_empty_prompt
621+
621622 if getattr(pytorch_pipe, "safety_checker", None) is not None:
622623 model_packages_to_load.append("safety_checker")
623624 else:
@@ -713,7 +714,7 @@ def main(args):
713714
714715 # Get Force Zeros Config if it exists
715716 force_zeros_for_empty_prompt: bool = False
716- if 'force_zeros_for_empty_prompt' in pytorch_pipe.config:
717+ if 'xl' in args.model_version and 'force_zeros_for_empty_prompt' in pytorch_pipe.config:
717718 force_zeros_for_empty_prompt = pytorch_pipe.config['force_zeros_for_empty_prompt']
718719
719720 coreml_pipe = get_coreml_pipe(
0 commit comments