1616 DiffusionPipeline ,
1717 ControlNetModel
1818)
19+ from diffusionkit .tests .torch2coreml import (
20+ convert_mmdit_to_mlpackage ,
21+ convert_vae_to_mlpackage
22+ )
1923import gc
24+ from huggingface_hub import snapshot_download
2025
2126import logging
2227
@@ -207,6 +212,26 @@ def _compile_coreml_model(source_model_path, output_dir, final_name):
207212 return target_path
208213
209214
def _download_t5_model(args, t5_save_path):
    """ Downloads the pre-converted T5 text encoder and copies it to `t5_save_path`

    `args.text_encoder_t5_url` must have the form
    `https://huggingface.co/<repo_id>/resolve/main/<model_subpath>`. Only the files
    under `<model_subpath>` are fetched from the `main` revision of `<repo_id>` via
    `snapshot_download`, then that directory is copied to `t5_save_path`.

    Raises:
        ValueError: if `args.text_encoder_t5_url` does not match the expected form
    """
    t5_url = args.text_encoder_t5_url

    # The dots of the host name are escaped so that only huggingface.co itself is accepted
    match = re.match(r'https://huggingface\.co/(.+)/resolve/main/(.+)', t5_url)
    if not match:
        raise ValueError(f"Invalid Hugging Face URL: {t5_url}")
    repo_id, model_subpath = match.groups()

    # A trailing slash would produce an allow pattern ("<subpath>//*") that matches no file
    model_subpath = model_subpath.rstrip("/")
    if not model_subpath:
        raise ValueError(f"Invalid Hugging Face URL: {t5_url}")

    download_path = snapshot_download(
        repo_id=repo_id,
        revision="main",
        allow_patterns=[f"{model_subpath}/*"]
    )
    logger.info(f"Downloaded T5 model to {download_path}")

    # Copy the downloaded model out of the download location to the requested destination
    logger.info(f"Copying T5 model from {download_path} to {t5_save_path}")
    cache_path = os.path.join(download_path, model_subpath)
    shutil.copytree(cache_path, t5_save_path)
233+
234+
210235def bundle_resources_for_swift_cli (args ):
211236 """
212237 - Compiles Core ML models from mlpackage into mlmodelc format
@@ -228,6 +253,7 @@ def bundle_resources_for_swift_cli(args):
228253 ("refiner" , "UnetRefiner" ),
229254 ("refiner_chunk1" , "UnetRefinerChunk1" ),
230255 ("refiner_chunk2" , "UnetRefinerChunk2" ),
256+ ("mmdit" , "MultiModalDiffusionTransformer" ),
231257 ("control-unet" , "ControlledUnet" ),
232258 ("control-unet_chunk1" , "ControlledUnetChunk1" ),
233259 ("control-unet_chunk2" , "ControlledUnetChunk2" ),
@@ -241,7 +267,7 @@ def bundle_resources_for_swift_cli(args):
241267 logger .warning (
242268 f"{ source_path } not found, skipping compilation to { target_name } .mlmodelc"
243269 )
244-
270+
245271 if args .convert_controlnet :
246272 for controlnet_model_version in args .convert_controlnet :
247273 controlnet_model_name = controlnet_model_version .replace ("/" , "_" )
@@ -271,6 +297,25 @@ def bundle_resources_for_swift_cli(args):
271297 f .write (requests .get (args .text_encoder_merges_url ).content )
272298 logger .info ("Done" )
273299
300+ # Fetch and save pre-converted T5 text encoder model
301+ t5_model_name = "TextEncoderT5.mlmodelc"
302+ t5_save_path = os .path .join (resources_dir , t5_model_name )
303+ if args .include_t5 :
304+ if not os .path .exists (t5_save_path ):
305+ logger .info ("Downloading pre-converted T5 encoder model TextEncoderT5.mlmodelc" )
306+ _download_t5_model (args , t5_save_path )
307+ logger .info ("Done" )
308+ else :
309+ logger .info (f"Skipping T5 download as { t5_save_path } already exists" )
310+
311+ # Fetch and save T5 text tokenizer JSON files
312+ logger .info ("Downloading and saving T5 tokenizer files tokenizer_config.json and tokenizer.json" )
313+ with open (os .path .join (resources_dir , "tokenizer_config.json" ), "wb" ) as f :
314+ f .write (requests .get (args .text_encoder_t5_config_url ).content )
315+ with open (os .path .join (resources_dir , "tokenizer.json" ), "wb" ) as f :
316+ f .write (requests .get (args .text_encoder_t5_data_url ).content )
317+ logger .info ("Done" )
318+
274319 return resources_dir
275320
276321
@@ -557,6 +602,61 @@ def forward(self, z):
557602 del traced_vae_decoder , pipe .vae .decoder , coreml_vae_decoder
558603 gc .collect ()
559604
def convert_vae_decoder_sd3(args):
    """ Converts the VAE component of Stable Diffusion 3

    The conversion itself is delegated to DiffusionKit's `convert_vae_to_mlpackage`.
    The resulting model is then loaded, annotated with metadata, saved to the path
    returned by `_get_out_path(args, "vae_decoder")`, and the intermediate package
    written by DiffusionKit is removed.

    Reads `args.model_version`, `args.latent_h`, `args.latent_w` and `args.o`.
    Returns None; does nothing if the output path already exists.
    """
    out_path = _get_out_path(args, "vae_decoder")
    if os.path.exists(out_path):
        logger.info(
            f"`vae_decoder` already exists at {out_path}, skipping conversion."
        )
        return

    # Convert the VAE Decoder model via DiffusionKit
    converted_vae_path = convert_vae_to_mlpackage(
        model_version=args.model_version,
        latent_h=args.latent_h,
        latent_w=args.latent_w,
        output_dir=args.o,
    )

    # Load converted model
    coreml_vae_decoder = ct.models.MLModel(converted_vae_path)

    # Set model metadata
    coreml_vae_decoder.author = f"Please refer to the Model Card available at huggingface.co/{args.model_version}"
    coreml_vae_decoder.license = "Stability AI Community License (https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE.md)"
    coreml_vae_decoder.version = args.model_version
    # Assign to the model attribute: a missing "." previously created an unused
    # local variable and left the model without a description
    coreml_vae_decoder.short_description = \
        "Stable Diffusion 3 generates images conditioned on text or other images as input through the diffusion process. " \
        "Please refer to https://arxiv.org/pdf/2403.03206 for details."

    # Set the input descriptions
    coreml_vae_decoder.input_description["z"] = \
        "The denoised latent embeddings from the unet model after the last step of reverse diffusion"

    # Set the output descriptions
    coreml_vae_decoder.output_description[
        "image"] = "Generated image normalized to range [-1, 1]"

    # Set package version metadata
    # Both packages expose `__version__`; alias them so that neither import shadows the other
    from python_coreml_stable_diffusion._version import __version__ as sd_version
    from diffusionkit.version import __version__ as diffusionkit_version
    coreml_vae_decoder.user_defined_metadata["com.github.apple.ml-stable-diffusion.version"] = sd_version
    coreml_vae_decoder.user_defined_metadata["com.github.argmax.diffusionkit.version"] = diffusionkit_version

    # Save the updated model
    coreml_vae_decoder.save(out_path)

    logger.info(f"Saved vae_decoder into {out_path}")

    # Delete the original file
    if os.path.exists(converted_vae_path):
        shutil.rmtree(converted_vae_path)

    del coreml_vae_decoder
    gc.collect()
659+
560660
561661def convert_vae_encoder (pipe , args ):
562662 """ Converts the VAE Encoder component of Stable Diffusion
@@ -909,6 +1009,72 @@ def convert_unet(pipe, args, model_name = None):
9091009 chunk_mlprogram .main (args )
9101010
9111011
def convert_mmdit(args):
    """ Converts the MMDiT component of Stable Diffusion 3

    The conversion itself is delegated to DiffusionKit's `convert_mmdit_to_mlpackage`.
    The resulting model is then loaded, annotated with metadata, saved to the path
    returned by `_get_out_path(args, "mmdit")`, and the intermediate package written
    by DiffusionKit is removed.

    Reads `args.model_version`, `args.latent_h`, `args.latent_w` and `args.o`.
    Returns None; does nothing if the output path already exists.
    """
    out_path = _get_out_path(args, "mmdit")
    if os.path.exists(out_path):
        logger.info(
            f"`mmdit` already exists at {out_path}, skipping conversion."
        )
        return

    # Convert the MMDiT model via DiffusionKit
    converted_mmdit_path = convert_mmdit_to_mlpackage(
        model_version=args.model_version,
        latent_h=args.latent_h,
        latent_w=args.latent_w,
        output_dir=args.o,
        # FIXME: Hardcoding to CPU_AND_GPU since ANE doesn't support FLOAT32
        compute_precision=ct.precision.FLOAT32,
        compute_unit=ct.ComputeUnit.CPU_AND_GPU,
    )

    # Load converted model
    coreml_mmdit = ct.models.MLModel(converted_mmdit_path)

    # Set model metadata
    coreml_mmdit.author = f"Please refer to the Model Card available at huggingface.co/{args.model_version}"
    coreml_mmdit.license = "Stability AI Community License (https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE.md)"
    coreml_mmdit.version = args.model_version
    coreml_mmdit.short_description = \
        "Stable Diffusion 3 generates images conditioned on text or other images as input through the diffusion process. " \
        "Please refer to https://arxiv.org/pdf/2403.03206 for details."

    # Set the input descriptions
    coreml_mmdit.input_description["latent_image_embeddings"] = \
        "The low resolution latent feature maps being denoised through reverse diffusion"
    coreml_mmdit.input_description["token_level_text_embeddings"] = \
        "Output embeddings from the associated text_encoder model to condition to generated image on text. " \
        "A maximum of 77 tokens (~40 words) are allowed. Longer text is truncated. "
    coreml_mmdit.input_description["pooled_text_embeddings"] = \
        "Additional embeddings that if specified are added to the embeddings that are passed along to the MMDiT model."
    coreml_mmdit.input_description["timestep"] = \
        "A value emitted by the associated scheduler object to condition the model on a given noise schedule"

    # Set the output descriptions
    coreml_mmdit.output_description["denoiser_output"] = \
        "Same shape and dtype as the `latent_image_embeddings` input. " \
        "The predicted noise to facilitate the reverse diffusion (denoising) process"

    # Set package version metadata
    # Both packages expose `__version__`; alias them so that neither import shadows the other
    from python_coreml_stable_diffusion._version import __version__ as sd_version
    from diffusionkit.version import __version__ as diffusionkit_version
    coreml_mmdit.user_defined_metadata["com.github.apple.ml-stable-diffusion.version"] = sd_version
    coreml_mmdit.user_defined_metadata["com.github.argmax.diffusionkit.version"] = diffusionkit_version

    # Save the updated model
    coreml_mmdit.save(out_path)

    # Name the component that was actually saved (the message previously said vae_decoder)
    logger.info(f"Saved mmdit into {out_path}")

    # Delete the original file
    if os.path.exists(converted_mmdit_path):
        shutil.rmtree(converted_mmdit_path)

    del coreml_mmdit
    gc.collect()
1077+
9121078def convert_safety_checker (pipe , args ):
9131079 """ Converts the Safety Checker component of Stable Diffusion
9141080 """
@@ -1288,6 +1454,16 @@ def get_pipeline(args):
12881454 use_safetensors = True ,
12891455 vae = vae ,
12901456 use_auth_token = True )
1457+ elif args .sd3_version :
1458+ # SD3 uses standard SDXL diffusers pipeline besides the vae, denoiser, and T5 text encoder
1459+ sdxl_base_version = "stabilityai/stable-diffusion-xl-base-1.0"
1460+ args .xl_version = True
1461+ logger .info (f"SD3 version specified, initializing DiffusionPipeline with { sdxl_base_version } for non-SD3 components.." )
1462+ pipe = DiffusionPipeline .from_pretrained (sdxl_base_version ,
1463+ torch_dtype = torch .float16 ,
1464+ variant = "fp16" ,
1465+ use_safetensors = True ,
1466+ use_auth_token = True )
12911467 else :
12921468 pipe = DiffusionPipeline .from_pretrained (model_version ,
12931469 torch_dtype = torch .float16 ,
@@ -1316,7 +1492,10 @@ def main(args):
13161492 # Convert models
13171493 if args .convert_vae_decoder :
13181494 logger .info ("Converting vae_decoder" )
1319- convert_vae_decoder (pipe , args )
1495+ if args .sd3_version :
1496+ convert_vae_decoder_sd3 (args )
1497+ else :
1498+ convert_vae_decoder (pipe , args )
13201499 logger .info ("Converted vae_decoder" )
13211500
13221501 if args .convert_vae_encoder :
@@ -1363,6 +1542,11 @@ def main(args):
13631542 del pipe
13641543 gc .collect ()
13651544 logger .info (f"Converted refiner" )
1545+
1546+ if args .convert_mmdit :
1547+ logger .info ("Converting mmdit" )
1548+ convert_mmdit (args )
1549+ logger .info ("Converted mmdit" )
13661550
13671551 if args .quantize_nbits is not None :
13681552 logger .info (f"Quantizing weights to { args .quantize_nbits } -bit precision" )
@@ -1383,6 +1567,7 @@ def parser_spec():
13831567 parser .add_argument ("--convert-vae-decoder" , action = "store_true" )
13841568 parser .add_argument ("--convert-vae-encoder" , action = "store_true" )
13851569 parser .add_argument ("--convert-unet" , action = "store_true" )
1570+ parser .add_argument ("--convert-mmdit" , action = "store_true" )
13861571 parser .add_argument ("--convert-safety-checker" , action = "store_true" )
13871572 parser .add_argument (
13881573 "--convert-controlnet" ,
@@ -1489,6 +1674,7 @@ def parser_spec():
14891674 "If specified, enable unet to receive additional inputs from controlnet. "
14901675 "Each input added to corresponding resnet output."
14911676 )
1677+ parser .add_argument ("--include-t5" , action = "store_true" )
14921678
14931679 # Swift CLI Resource Bundling
14941680 parser .add_argument (
@@ -1508,11 +1694,30 @@ def parser_spec():
15081694 default =
15091695 "https://huggingface.co/openai/clip-vit-base-patch32/resolve/main/merges.txt" ,
15101696 help = "The URL to the merged pairs used in by the text tokenizer." )
1697+ parser .add_argument (
1698+ "--text-encoder-t5-url" ,
1699+ default =
1700+ "https://huggingface.co/argmaxinc/coreml-stable-diffusion-3-medium/resolve/main/TextEncoderT5.mlmodelc" ,
1701+ help = "The URL to the pre-converted T5 encoder model." )
1702+ parser .add_argument (
1703+ "--text-encoder-t5-config-url" ,
1704+ default =
1705+ "https://huggingface.co/google-t5/t5-small/resolve/main/tokenizer_config.json" ,
1706+ help = "The URL to the merged pairs used in by the text tokenizer." )
1707+ parser .add_argument (
1708+ "--text-encoder-t5-data-url" ,
1709+ default =
1710+ "https://huggingface.co/google-t5/t5-small/resolve/main/tokenizer.json" ,
1711+ help = "The URL to the merged pairs used in by the text tokenizer." )
15111712 parser .add_argument (
15121713 "--xl-version" ,
15131714 action = "store_true" ,
15141715 help = ("If specified, the pre-trained model will be treated as an instantiation of "
15151716 "`diffusers.pipelines.StableDiffusionXLPipeline` instead of `diffusers.pipelines.StableDiffusionPipeline`" ))
1717+ parser .add_argument (
1718+ "--sd3-version" ,
1719+ action = "store_true" ,
1720+ help = ("If specified, the pre-trained model will be treated as an SD3 model." ))
15161721
15171722 return parser
15181723
0 commit comments