@@ -426,107 +426,6 @@ def get_model_config(self) -> dict:
         return self.model.model.vision_model.config.__dict__
 
 
-class QEffSD3Transformer2DModel(QEFFBaseModel):
-    _pytorch_transforms = [AttentionTransform, CustomOpsTransform]
-    _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
-
-    """
-    QEffSD3Transformer2DModel is a wrapper class for Stable Diffusion 3 Transformer2D models that provides ONNX export and compilation capabilities.
-
-    This class extends QEFFBaseModel to handle SD3 Transformer2D models with specific transformations and optimizations
-    for efficient inference on Qualcomm AI hardware. It is designed for the newer Stable Diffusion 3 architecture
-    that uses transformer-based diffusion models instead of traditional UNet architectures.
-    """
-
-    def __init__(self, model: nn.modules):
-        super().__init__(model)
-        self.model = model
-
-    def get_onnx_config(self):
-        example_inputs = {
-            "hidden_states": torch.randn(
-                2,
-                self.model.config.in_channels,
-                self.model.config.sample_size,
-                self.model.config.sample_size,
-            ),
-            "encoder_hidden_states": torch.randn(2, 333, self.model.config.joint_attention_dim),
-            "pooled_projections": torch.randn(2, self.model.config.pooled_projection_dim),
-            "timestep": torch.randint(0, 20, (2,), dtype=torch.int64),
-        }
-
-        output_names = ["output"]
-
-        dynamic_axes = {
-            "hidden_states": {0: "batch_size", 1: "latent_channels", 2: "latent_height", 3: "latent_width"},
-            "encoder_hidden_states": {0: "batch_size", 1: "seq_len"},
-            "pooled_projections": {0: "batch_size"},
-            "timestep": {0: "steps"},
-            "output": {0: "batch_size", 1: "latent_channels", 2: "latent_height", 3: "latent_width"},
-        }
-        return example_inputs, dynamic_axes, output_names
-
-    def export(self, inputs, output_names, dynamic_axes, export_dir=None):
-        return self._export(inputs, output_names, dynamic_axes, export_dir)
-
-    def get_specializations(
-        self,
-        batch_size: int,
-        seq_len: int,
-    ):
-        specializations = [
-            {
-                "batch_size": 2 * batch_size,
-                "latent_channels": 16,
-                "latent_height": self.model.config.sample_size,
-                "latent_width": self.model.config.sample_size,
-                "seq_len": seq_len,
-                "steps": 1,
-            }
-        ]
-
-        return specializations
-
-    def compile(
-        self,
-        compile_dir,
-        compile_only,
-        specializations,
-        convert_to_fp16,
-        mxfp6_matmul,
-        mdp_ts_num_devices,
-        aic_num_cores,
-        custom_io,
-        **compiler_options,
-    ) -> str:
-        return self._compile(
-            compile_dir=compile_dir,
-            compile_only=compile_only,
-            specializations=specializations,
-            convert_to_fp16=convert_to_fp16,
-            mxfp6_matmul=mxfp6_matmul,
-            mdp_ts_num_devices=mdp_ts_num_devices,
-            aic_num_cores=aic_num_cores,
-            custom_io=custom_io,
-            **compiler_options,
-        )
-
-    @property
-    def model_hash(self) -> str:
-        # Compute the hash with: model_config, continuous_batching, transforms
-        mhash = hashlib.sha256()
-        mhash.update(to_hashable(dict(self.model.config)))
-        mhash.update(to_hashable(self._transform_names()))
-        mhash = mhash.hexdigest()[:16]
-        return mhash
-
-    @property
-    def model_name(self) -> str:
-        mname = self.model.__class__.__name__
-        if mname.startswith("QEff") or mname.startswith("QEFF"):
-            mname = mname[4:]
-        return mname
-
 class QEffFluxTransformerModel(QEFFBaseModel):
     _pytorch_transforms = [AttentionTransform, CustomOpsTransform, NormalizationTransform]
     _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
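The removed class followed the same export/compile contract as the wrappers that remain: `get_onnx_config` assigns a name to every dynamic ONNX axis, and `get_specializations` later pins those same names to concrete values for compilation (the leading dimension of 2 in the example inputs, mirrored by `2 * batch_size` in the specialization, presumably covers classifier-free guidance's paired conditional/unconditional batch). A minimal, self-contained sketch of that name correspondence, using a toy module and a hypothetical output path rather than anything from this repo:

    import torch
    import torch.nn as nn

    # Toy stand-in whose I/O shape mirrors the wrapper's "hidden_states" input.
    model = nn.Identity()
    hidden_states = torch.randn(2, 16, 128, 128)

    # Export names each dynamic axis...
    torch.onnx.export(
        model,
        (hidden_states,),
        "sd3_toy.onnx",  # hypothetical path
        input_names=["hidden_states"],
        output_names=["output"],
        dynamic_axes={
            "hidden_states": {0: "batch_size", 1: "latent_channels", 2: "latent_height", 3: "latent_width"},
            "output": {0: "batch_size", 1: "latent_channels", 2: "latent_height", 3: "latent_width"},
        },
    )

    # ...and a compile-time specialization pins the same names to fixed values.
    specialization = {"batch_size": 2, "latent_channels": 16, "latent_height": 128, "latent_width": 128}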
@@ -618,7 +517,9 @@ def compile(
     def model_hash(self) -> str:
         # Compute the hash with: model_config, continuous_batching, transforms
         mhash = hashlib.sha256()
-        mhash.update(to_hashable(dict(self.model.config)))
+        dict_model_config = dict(self.model.config)
+        dict_model_config.pop("_use_default_values", None)
+        mhash.update(to_hashable(dict_model_config))
         mhash.update(to_hashable(self._transform_names()))
         mhash = mhash.hexdigest()[:16]
         return mhash
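The change above makes the Flux hash ignore `_use_default_values`, a bookkeeping key that diffusers configs may carry to record which fields were left at their defaults; since it reflects how the config object was constructed rather than the model itself, two otherwise-identical configs could hash differently without the pop. A minimal sketch of the effect, with a stand-in `to_hashable` (deterministic JSON encoding) since the repo's helper isn't shown here:

    import hashlib
    import json

    def to_hashable(obj) -> bytes:
        # Stand-in for the repo's to_hashable: stable, sorted JSON bytes.
        return json.dumps(obj, sort_keys=True).encode()

    def model_hash(config: dict) -> str:
        cfg = dict(config)
        cfg.pop("_use_default_values", None)  # drop construction metadata before hashing
        mhash = hashlib.sha256()
        mhash.update(to_hashable(cfg))
        return mhash.hexdigest()[:16]

    # Same model settings, differing only in the bookkeeping key: hashes now agree.
    config_a = {"in_channels": 16, "sample_size": 128}
    config_b = {"in_channels": 16, "sample_size": 128, "_use_default_values": ["sample_size"]}
    assert model_hash(config_a) == model_hash(config_b)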