 from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps  # TODO

-from QEfficient.diffusers.pipelines.config_manager import config_manager
+from QEfficient.diffusers.pipelines.config_manager import config_manager, set_module_device_ids
 from QEfficient.diffusers.pipelines.pipeline_utils import (
     QEffClipTextEncoder,
     QEffFluxTransformerModel,
@@ -37,85 +37,27 @@ class QEFFFluxPipeline(FluxPipeline):
     It provides methods for text-to-image generation leveraging these optimized components.
     """

-    def __init__(
-        self,
-        model=None,
-        use_onnx_function=False,
-        text_encoder=None,
-        text_encoder_2=None,
-        transformer=None,
-        vae=None,
-        tokenizer=None,
-        tokenizer_2=None,
-        scheduler=None,
-        height: Optional[int] = 512,
-        width: Optional[int] = 512,
-        *args,
-        **kwargs,
-    ):
-        # Validate input: either model or individual components
-        has_model = model is not None
-        has_components = all(
-            [
-                text_encoder is not None,
-                text_encoder_2 is not None,
-                transformer is not None,
-                vae is not None,
-                tokenizer is not None,
-                tokenizer_2 is not None,
-                scheduler is not None,
-            ]
-        )
-
-        if not has_model and not has_components:
-            raise ValueError(
-                "Either provide 'model' parameter OR all individual components "
-                "(text_encoder, text_encoder_2, transformer, vae, "
-                "tokenizer, tokenizer_2, scheduler)"
-            )
-
-        if has_model and has_components:
-            raise ValueError("Cannot provide both 'model' and individual components. Choose one approach.")
-
-        # Use ONNX subfunction to optimize the onnx graph
+    def __init__(self, model, use_onnx_function, *args, **kwargs):
+        self.text_encoder = QEffClipTextEncoder(model.text_encoder)
+        self.text_encoder_2 = QEffTextEncoder(model.text_encoder_2)
+        self.transformer = QEffFluxTransformerModel(model.transformer, use_onnx_function=use_onnx_function)
+        self.vae_decode = QEffVAE(model, "decoder")
         self.use_onnx_function = use_onnx_function
-
-        self.text_encoder = (
-            QEffClipTextEncoder(text_encoder) if has_components else QEffClipTextEncoder(model.text_encoder)
-        )
-        self.text_encoder_2 = (
-            QEffTextEncoder(text_encoder_2) if has_components else QEffTextEncoder(model.text_encoder_2)
-        )
-        self.transformer = QEffFluxTransformerModel(
-            transformer if has_components else model.transformer, self.use_onnx_function
-        )
-        self.vae_decode = QEffVAE(vae if has_components else model, "decoder")
-
         self.has_module = [
             ("text_encoder", self.text_encoder),
             ("text_encoder_2", self.text_encoder_2),
             ("transformer", self.transformer),
             ("vae_decoder", self.vae_decode),
         ]

-        self.tokenizer = tokenizer if has_components else model.tokenizer
-        self.text_encoder.tokenizer = tokenizer if has_components else model.tokenizer
-        self.text_encoder_2.tokenizer = tokenizer_2 if has_components else model.tokenizer_2
+        self.tokenizer = model.tokenizer
+        self.text_encoder.tokenizer = model.tokenizer
+        self.text_encoder_2.tokenizer = model.tokenizer_2
         self.tokenizer_max_length = model.tokenizer_max_length
         self.scheduler = model.scheduler

-        self.height = height
-        self.width = width
-
-        self.register_modules(
-            vae=self.vae_decode,
-            text_encoder=self.text_encoder,
-            text_encoder_2=self.text_encoder_2,
-            tokenizer=self.tokenizer,
-            tokenizer_2=self.text_encoder_2.tokenizer,
-            transformer=self.transformer,
-            scheduler=self.scheduler,
-        )
+        self.height = kwargs.get("height", 256)
+        self.width = kwargs.get("width", 256)

         self.vae_decode.model.forward = lambda latent_sample, return_dict: self.vae_decode.model.decode(
             latent_sample, return_dict
@@ -200,7 +142,6 @@ def export(self, export_dir: Optional[str] = None) -> str:

     def compile(
         self,
-        *,
         compile_config: Optional[str] = None,
     ) -> str:
         """
@@ -226,11 +167,12 @@ def compile(
         self.export()

         # Initialize configuration manager (JSON-only approach)
-        config_manager(self, config_source=compile_config)
+        if self.custom_config is None:
+            config_manager(self, config_source=compile_config)

         for module_name, module_obj in self.has_module:
             # Get specialization values directly from config
-            module_config = self._compile_config["modules"]
+            module_config = self.custom_config["modules"]
             specializations = [module_config[module_name]["specializations"]]

             # Get compilation parameters from configuration
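For reference, the lookups in this hunk imply a compile config shaped roughly as sketched below; only the `modules` -> `<module_name>` -> `specializations` nesting and the four module names are taken from the code, every value is illustrative:

# Hypothetical sketch of the compile config consumed by compile(); all values
# are made up, only the nesting and module names come from the code above.
example_compile_config = {
    "modules": {
        "text_encoder": {"specializations": {"batch_size": 1, "seq_len": 77}},
        "text_encoder_2": {"specializations": {"batch_size": 1, "seq_len": 512}},
        "transformer": {"specializations": {"batch_size": 1}},
        "vae_decoder": {"specializations": {"batch_size": 1}},
    }
}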
@@ -396,8 +338,6 @@ def encode_prompt(
         prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         max_sequence_length: int = 512,
-        device_ids_text_encoder_1: Optional[List[int]] = None,
-        device_ids_text_encoder_2: Optional[List[int]] = None,
     ):
         r"""
         Encode the given prompts into text embeddings using the two text encoders (CLIP and T5).
@@ -433,14 +373,14 @@ def encode_prompt(
         # We only use the pooled prompt output from the CLIPTextModel
         pooled_prompt_embeds = self._get_clip_prompt_embeds(
             prompt=prompt,
-            device_ids=device_ids_text_encoder_1,
+            device_ids=self.text_encoder.device_ids,
             num_images_per_prompt=num_images_per_prompt,
         )
         prompt_embeds = self._get_t5_prompt_embeds(
             prompt=prompt_2,
             num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
-            device_ids=device_ids_text_encoder_2,
+            device_ids=self.text_encoder_2.device_ids,
         )

         text_ids = torch.zeros(prompt_embeds.shape[1], 3)
@@ -453,8 +393,6 @@ def __call__(
         negative_prompt: Union[str, List[str]] = None,
         negative_prompt_2: Optional[Union[str, List[str]]] = None,
         true_cfg_scale: float = 1.0,
-        height: Optional[int] = None,
-        width: Optional[int] = None,
         num_inference_steps: int = 28,
         timesteps: List[int] = None,
         guidance_scale: float = 3.5,
@@ -471,10 +409,7 @@ def __call__(
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 512,
-        device_ids_text_encoder_1: Optional[List[int]] = None,
-        device_ids_text_encoder_2: Optional[List[int]] = None,
-        device_ids_transformer: Optional[List[int]] = None,
-        device_ids_vae_decoder: Optional[List[int]] = None,
+        custom_config_path: Optional[str] = None,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -551,10 +486,6 @@ def __call__(
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                 `._callback_tensor_inputs` attribute of your pipeline class.
             max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
-            device_ids_text_encoder1 (List[int], optional): List of device IDs to use for CLIP instance.
-            device_ids_text_encoder2 (List[int], optional): List of device IDs to use for T5.
-            device_ids_transformer (List[int], optional): List of device IDs to use for Flux transformer.
-            device_ids_vae_decoder (List[int], optional): List of device IDs to use for VAE decoder.

         Returns:
             [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
@@ -578,11 +509,16 @@ def __call__(
         image.save("flux-schnell_aic.png")
         ```
         """
-        height = self.height
-        width = self.width
         device = "cpu"

         # 1. Check inputs. Raise error if not correct
+        if custom_config_path is not None:
+            config_manager(self, custom_config_path)
+            set_module_device_ids(self)
+
+            # Calling compile with custom config
+            self.compile(compile_config=custom_config_path)
+
         self.check_inputs(
             prompt,
             prompt_2,
@@ -624,8 +560,6 @@ def __call__(
             pooled_prompt_embeds=pooled_prompt_embeds,
             num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
-            device_ids_text_encoder_1=device_ids_text_encoder_1,
-            device_ids_text_encoder_2=device_ids_text_encoder_2,
         )
         if do_true_cfg:
             (
@@ -639,8 +573,6 @@ def __call__(
                 pooled_prompt_embeds=negative_pooled_prompt_embeds,
                 num_images_per_prompt=num_images_per_prompt,
                 max_sequence_length=max_sequence_length,
-                device_ids_text_encoder_1=device_ids_text_encoder_1,
-                device_ids_text_encoder_2=device_ids_text_encoder_2,
             )

         # 4. Prepare timesteps
@@ -665,7 +597,7 @@ def __call__(
         ###### AIC related changes of transformers ######
         if self.transformer.qpc_session is None:
             self.transformer.qpc_session = QAICInferenceSession(
-                str(self.transformer.qpc_path), device_ids=device_ids_transformer
+                str(self.transformer.qpc_path), device_ids=self.transformer.device_ids
             )

         output_buffer = {
@@ -728,40 +660,13 @@ def __call__(
728660 "adaln_out" : adaln_out .detach ().numpy (),
729661 }
730662
731- # noise_pred_torch = self.transformer.model(
732- # hidden_states=latents,
733- # encoder_hidden_states = prompt_embeds,
734- # pooled_projections=pooled_prompt_embeds,
735- # timestep=torch.tensor(timestep),
736- # img_ids = latent_image_ids,
737- # txt_ids = text_ids,
738- # adaln_emb = adaln_dual_emb,
739- # adaln_single_emb=adaln_single_emb,
740- # adaln_out = adaln_out,
741- # return_dict=False,
742- # )[0]
743-
744663 start_time = time .time ()
745664 outputs = self .transformer .qpc_session .run (inputs_aic )
746665 end_time = time .time ()
747666 print (f"Time : { end_time - start_time :.2f} seconds" )
748667
749- # ########################## Onnx
750- # input_names = [i.name for i in session.get_inputs()]
751- # output_names = [o.name for o in session.get_outputs()]
752- # start_time = time.time()
753- # ort_pred = session.run(None, {k: v for k, v in inputs_aic.items() if k in input_names})
754- # end_time = time.time()
755- # print(f"Onnx positive transformer denoising Time: {end_time - start_time:.2f} seconds")
756-
757668 noise_pred = torch .from_numpy (outputs ["output" ])
758669
759- # # # # ###### ACCURACY TESTING #######
760- # mad=np.max(np.abs(noise_pred_torch.detach().numpy()-outputs['output']))
761- # print(f">>>>>>>>> at t = {t} FLUX transfromer model MAD pytorch vs QAIC :{mad}")
762- # mad_O=np.max(np.abs(ort_pred[0]- outputs['output']))
763- # print(f">>>>>>>>> FLUX transfromer model MAD Onnxrt vs QAIC : {mad_O}")
764-
765670 # compute the previous noisy sample x_t -> x_t-1
766671 latents_dtype = latents .dtype
767672 latents = self .scheduler .step (noise_pred , t , latents , return_dict = False )[0 ]
@@ -792,7 +697,7 @@ def __call__(

         if self.vae_decode.qpc_session is None:
             self.vae_decode.qpc_session = QAICInferenceSession(
-                str(self.vae_decode.qpc_path), device_ids=device_ids_vae_decoder
+                str(self.vae_decode.qpc_path), device_ids=self.vae_decode.device_ids
             )

         output_buffer = {
@@ -808,17 +713,9 @@ def __call__(
         inputs = {"latent_sample": latents.numpy()}
         image = self.vae_decode.qpc_session.run(inputs)

-        # ##### ACCURACY TESTING #######
-        # image_torch = self.vae_decode.model(latents, return_dict=False)[0]
-        # mad= np.max(np.abs(image['sample']-image_torch.detach().numpy()))
-        # print(">>>>>>>>>>>> VAE mad: ",mad)
-
         image_tensor = torch.from_numpy(image["sample"])
         image = self.image_processor.postprocess(image_tensor, output_type=output_type)

-        # Offload all models
-        # self.maybe_free_model_hooks()
-
         if not return_dict:
             return (image,)

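For context, a minimal usage sketch of the reworked API, assuming the pipeline wraps a preloaded diffusers FluxPipeline; the checkpoint id, config filename, and the `.images[0]` access are assumptions, not shown in this diff:

# Illustrative sketch only; the loading entry points are not part of this diff.
from diffusers import FluxPipeline

base = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")  # hypothetical checkpoint id
pipe = QEFFFluxPipeline(base, use_onnx_function=False, height=256, width=256)

# Passing custom_config_path makes __call__ run config_manager(),
# set_module_device_ids(), and compile() before the denoising loop.
image = pipe(
    "A cat holding a sign that says hello world",
    num_inference_steps=4,
    custom_config_path="flux_compile_config.json",  # hypothetical config file
).images[0]  # assumes a FluxPipelineOutput with an .images field
image.save("flux-schnell_aic.png")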