     QEffTextEncoder,
     QEffVAE,
 )
-from QEfficient.diffusers.pipelines.pipeline_utils import QEffPipelineOutput, config_manager, set_module_device_ids
+from QEfficient.diffusers.pipelines.pipeline_utils import (
+    ModulePerf,
+    QEffPipelineOutput,
+    config_manager,
+    set_module_device_ids,
+)
 from QEfficient.generation.cloud_infer import QAICInferenceSession


@@ -42,13 +47,13 @@ def __init__(self, model, use_onnx_function, *args, **kwargs):
         self.vae_decode = QEffVAE(model, "decoder")
         self.use_onnx_function = use_onnx_function

-        # Add all modules of FluxPipeline
-        self.has_module = [
-            ("text_encoder", self.text_encoder),
-            ("text_encoder_2", self.text_encoder_2),
-            ("transformer", self.transformer),
-            ("vae_decoder", self.vae_decode),
-        ]
+        # All modules of FluxPipeline stored in a dictionary for easy access and iteration
+        self.modules = {
+            "text_encoder": self.text_encoder,
+            "text_encoder_2": self.text_encoder_2,
+            "transformer": self.transformer,
+            "vae_decoder": self.vae_decode,
+        }

         self.tokenizer = model.tokenizer
         self.text_encoder.tokenizer = model.tokenizer
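
Review note: switching `has_module` from a list of tuples to a dict keeps the same iteration order (Python dicts preserve insertion order since 3.7) and additionally allows direct lookup by module name. An illustrative sketch of what the refactor enables:

```python
# Iteration is unchanged relative to the old list of tuples...
for name, module in self.modules.items():
    ...
# ...but a single module can now be fetched by name without scanning a list.
transformer = self.modules["transformer"]
```
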
@@ -127,7 +132,7 @@ def export(self, export_dir: Optional[str] = None) -> str:
         :str: Path of the generated ``ONNX`` graph.
         """

-        for module_name, module_obj in self.has_module:
+        for module_name, module_obj in self.modules.items():
             example_inputs_text_encoder, dynamic_axes_text_encoder, output_names_text_encoder = (
                 module_obj.get_onnx_config()
             )
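
Side note on the unchanged loop body: the `example_inputs_text_encoder` / `dynamic_axes_text_encoder` / `output_names_text_encoder` names predate this refactor and now hold values for whichever module is currently being exported. A hedged rename sketch for a possible follow-up:

```python
# Generic names make it clear the ONNX config belongs to the current module,
# not specifically to the text encoder.
for module_name, module_obj in self.modules.items():
    example_inputs, dynamic_axes, output_names = module_obj.get_onnx_config()
```
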
@@ -183,7 +188,7 @@ def compile(
         if self.custom_config is None:
             config_manager(self, config_source=compile_config)

-        for module_name, module_obj in self.has_module:
+        for module_name, module_obj in self.modules.items():
             # Get specialization values directly from config
             module_config = self.custom_config["modules"]
             specializations = [module_config[module_name]["specializations"]]
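
For readers unfamiliar with the compile config: the loop expects one entry per module name under a `modules` key. A hedged sketch of the expected shape; only the `modules` and `specializations` keys are confirmed by this diff, the fields inside are illustrative assumptions:

```python
custom_config = {
    "modules": {
        "text_encoder": {"specializations": {"batch_size": 1, "seq_len": 77}},
        "transformer": {"specializations": {"batch_size": 1, "seq_len": 256}},
        # ... one entry per name in self.modules
    }
}
```
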
@@ -256,19 +261,18 @@ def _get_t5_prompt_embeds(
         self.text_encoder_2.qpc_session.set_buffers(text_encoder_2_output)

         aic_text_input = {"input_ids": text_input_ids.numpy().astype(np.int64)}
-        import time

         start_t5_time = time.time()
         prompt_embeds = torch.tensor(self.text_encoder_2.qpc_session.run(aic_text_input)["last_hidden_state"])
         end_t5_time = time.time()
-        self.text_encoder_2.inference_time = end_t5_time - start_t5_time
+        text_encoder_2_perf = end_t5_time - start_t5_time

         _, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        return prompt_embeds
+        return prompt_embeds, text_encoder_2_perf

     def _get_clip_prompt_embeds(
         self,
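
Returning the measurement instead of stashing it on the module (`self.text_encoder_2.inference_time`) keeps the encoder helpers side-effect free; the local `import time` can be dropped because `time` is presumably imported at module level. One design note: `time.time()` works, but `time.perf_counter()` is the idiomatic monotonic clock for intervals. A sketch of a small helper that would factor out the repeated pattern (the name is hypothetical):

```python
import time

def timed_run(session, inputs):
    """Run a QPC session and return (outputs, elapsed seconds)."""
    start = time.perf_counter()
    outputs = session.run(inputs)
    return outputs, time.perf_counter() - start
```
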
@@ -322,20 +326,17 @@ def _get_clip_prompt_embeds(

         aic_text_input = {"input_ids": text_input_ids.numpy().astype(np.int64)}

-        import time
-
-        global start_text_encoder_time
         start_text_encoder_time = time.time()
         aic_embeddings = self.text_encoder.qpc_session.run(aic_text_input)
         end_text_encoder_time = time.time()
-        self.text_encoder.inference_time = end_text_encoder_time - start_text_encoder_time
+        text_encoder_perf = end_text_encoder_time - start_text_encoder_time
         prompt_embeds = torch.tensor(aic_embeddings["pooler_output"])

         # duplicate text embeddings for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)

-        return prompt_embeds
+        return prompt_embeds, text_encoder_perf

     def encode_prompt(
         self,
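
Dropping `global start_text_encoder_time` also removes the old way of computing `E2E_time` (decode end minus text-encoder start). If callers still need an end-to-end figure, it can be approximated by summing the per-module timings; a hedged sketch using the values this patch collects (illustrative only, since these variables live in different scopes in the pipeline):

```python
# Approximate end-to-end time from the per-module measurements.
# Note this excludes host-side work between module invocations.
e2e_time = (
    text_encoder_perf
    + text_encoder_2_perf
    + sum(transformer_perf)  # one entry per denoising step
    + vae_decode_perf
)
```
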
@@ -378,20 +379,20 @@ def encode_prompt(
         prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

         # We only use the pooled prompt output from the CLIPTextModel
-        pooled_prompt_embeds = self._get_clip_prompt_embeds(
+        pooled_prompt_embeds, text_encoder_perf = self._get_clip_prompt_embeds(
             prompt=prompt,
             device_ids=self.text_encoder.device_ids,
             num_images_per_prompt=num_images_per_prompt,
         )
-        prompt_embeds = self._get_t5_prompt_embeds(
+        prompt_embeds, text_encoder_2_perf = self._get_t5_prompt_embeds(
             prompt=prompt_2,
             num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
             device_ids=self.text_encoder_2.device_ids,
         )

         text_ids = torch.zeros(prompt_embeds.shape[1], 3)
-        return prompt_embeds, pooled_prompt_embeds, text_ids
+        return prompt_embeds, pooled_prompt_embeds, text_ids, [text_encoder_perf, text_encoder_2_perf]

     def __call__(
         self,
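
Any external caller of `encode_prompt` must now unpack four values instead of three. A hedged usage sketch (`pipe` and the prompt are illustrative):

```python
prompt_embeds, pooled_prompt_embeds, text_ids, encoder_perf = pipe.encode_prompt(
    prompt="a cat holding a sign",
    prompt_2=None,
    num_images_per_prompt=1,
    max_sequence_length=512,
)
clip_seconds, t5_seconds = encoder_perf  # [CLIP time, T5 time], in that order
```
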
@@ -539,18 +540,15 @@ def __call__(
             negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
         )
         do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
-        (
-            prompt_embeds,
-            pooled_prompt_embeds,
-            text_ids,
-        ) = self.encode_prompt(
+        (prompt_embeds, pooled_prompt_embeds, text_ids, text_encoder_perf) = self.encode_prompt(
             prompt=prompt,
             prompt_2=prompt_2,
             prompt_embeds=prompt_embeds,
             pooled_prompt_embeds=pooled_prompt_embeds,
             num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
         )
+
         if do_true_cfg:
             (
                 negative_prompt_embeds,
@@ -595,7 +593,7 @@ def __call__(
         }

         self.transformer.qpc_session.set_buffers(output_buffer)
-        self.transformer.inference_time = []
+        transformer_perf = []
         self.scheduler.set_begin_index(0)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -653,7 +651,7 @@ def __call__(
                 start_transformer_step_time = time.time()
                 outputs = self.transformer.qpc_session.run(inputs_aic)
                 end_transformer_step_time = time.time()
-                self.transformer.inference_time.append(end_transformer_step_time - start_transformer_step_time)
+                transformer_perf.append(end_transformer_step_time - start_transformer_step_time)

                 noise_pred = torch.from_numpy(outputs["output"])
@@ -678,7 +676,6 @@ def __call__(
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
-
         if output_type == "latent":
             image = latents
         else:
@@ -704,14 +701,22 @@ def __call__(
         start_decode_time = time.time()
         image = self.vae_decode.qpc_session.run(inputs)
         end_decode_time = time.time()
-        self.vae_decode.inference_time = end_decode_time - start_decode_time
+        vae_decode_perf = end_decode_time - start_decode_time
         image_tensor = torch.from_numpy(image["sample"])
         image = self.image_processor.postprocess(image_tensor, output_type=output_type)

-        total_time_taken = end_decode_time - start_text_encoder_time
+        # Collect performance data in a dict
+        perf_data = {
+            "text_encoder": text_encoder_perf[0],
+            "text_encoder_2": text_encoder_perf[1],
+            "transformer": transformer_perf,
+            "vae_decoder": vae_decode_perf,
+        }

-        return QEffPipelineOutput(
-            pipeline=self,
-            images=image,
-            E2E_time=total_time_taken,
-        )
+        # Build performance metrics dynamically
+        perf_metrics = [ModulePerf(module_name=name, perf=perf_data[name]) for name in self.modules.keys()]
+
+        return QEffPipelineOutput(
+            pipeline_module=perf_metrics,
+            images=image,
+        )
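
A hedged sketch of consuming the new output; it assumes `ModulePerf` exposes the `module_name` and `perf` fields it is constructed with above, and that `images` holds PIL images for the default output type:

```python
result = pipe(prompt="a cat holding a sign")  # `pipe` is illustrative
for entry in result.pipeline_module:
    # transformer perf is a per-step list; the other modules report floats
    seconds = sum(entry.perf) if isinstance(entry.perf, list) else entry.perf
    print(f"{entry.module_name}: {seconds:.3f}s")
result.images[0].save("flux_output.png")
```
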