
Commit b969ef8

Author: Amit Raj (committed)

Replaced output dict with a dataclass to make it more user-friendly
Signed-off-by: Amit Raj <amitraj@qti.qualcommm.com>
1 parent 284d83d commit b969ef8

File tree: 4 files changed, +79 -70 lines changed

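At a glance, the commit reworks what the Flux pipeline returns: the per-module `inference_time` attributes and the `E2E_time` field give way to a `QEffPipelineOutput` carrying the generated images plus a list of `ModulePerf` entries. Below is a minimal sketch of how a caller might consume the new output; the prompt and file name are illustrative, and `pipeline` is assumed to be an already-compiled `QEFFFluxPipeline`.

```python
import torch

# Assumption: `pipeline` is a compiled QEFFFluxPipeline; prompt and file name are illustrative.
output = pipeline(
    "A cat holding a sign",
    max_sequence_length=256,
    generator=torch.manual_seed(42),
)

output.images[0].save("cat_with_sign.png")  # images are read from the output object, not the pipeline

# New in this commit: per-module timings travel as ModulePerf(module_name, perf) entries.
for module_perf in output.pipeline_module:
    # perf is a single float for the encoders/VAE and a list of per-step floats for the transformer
    print(module_perf.module_name, module_perf.perf)

print(output)  # __repr__ renders the module-wise timing report
```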

QEfficient/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 41 additions & 36 deletions
```diff
@@ -21,7 +21,12 @@
     QEffTextEncoder,
     QEffVAE,
 )
-from QEfficient.diffusers.pipelines.pipeline_utils import QEffPipelineOutput, config_manager, set_module_device_ids
+from QEfficient.diffusers.pipelines.pipeline_utils import (
+    ModulePerf,
+    QEffPipelineOutput,
+    config_manager,
+    set_module_device_ids,
+)
 from QEfficient.generation.cloud_infer import QAICInferenceSession


```
```diff
@@ -42,13 +47,13 @@ def __init__(self, model, use_onnx_function, *args, **kwargs):
         self.vae_decode = QEffVAE(model, "decoder")
         self.use_onnx_function = use_onnx_function

-        # Add all modules of FluxPipeline
-        self.has_module = [
-            ("text_encoder", self.text_encoder),
-            ("text_encoder_2", self.text_encoder_2),
-            ("transformer", self.transformer),
-            ("vae_decoder", self.vae_decode),
-        ]
+        # All modules of FluxPipeline stored in a dictionary for easy access and iteration
+        self.modules = {
+            "text_encoder": self.text_encoder,
+            "text_encoder_2": self.text_encoder_2,
+            "transformer": self.transformer,
+            "vae_decoder": self.vae_decode,
+        }

         self.tokenizer = model.tokenizer
         self.text_encoder.tokenizer = model.tokenizer
```
```diff
@@ -127,7 +132,7 @@ def export(self, export_dir: Optional[str] = None) -> str:
         :str: Path of the generated ``ONNX`` graph.
         """

-        for module_name, module_obj in self.has_module:
+        for module_name, module_obj in self.modules.items():
             example_inputs_text_encoder, dynamic_axes_text_encoder, output_names_text_encoder = (
                 module_obj.get_onnx_config()
             )
@@ -183,7 +188,7 @@ def compile(
         if self.custom_config is None:
             config_manager(self, config_source=compile_config)

-        for module_name, module_obj in self.has_module:
+        for module_name, module_obj in self.modules.items():
             # Get specialization values directly from config
             module_config = self.custom_config["modules"]
             specializations = [module_config[module_name]["specializations"]]
```
```diff
@@ -256,19 +261,18 @@ def _get_t5_prompt_embeds(
         self.text_encoder_2.qpc_session.set_buffers(text_encoder_2_output)

         aic_text_input = {"input_ids": text_input_ids.numpy().astype(np.int64)}
-        import time

         start_t5_time = time.time()
         prompt_embeds = torch.tensor(self.text_encoder_2.qpc_session.run(aic_text_input)["last_hidden_state"])
         end_t5_time = time.time()
-        self.text_encoder_2.inference_time = end_t5_time - start_t5_time
+        text_encoder_2_perf = end_t5_time - start_t5_time

         _, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        return prompt_embeds
+        return prompt_embeds, text_encoder_2_perf

     def _get_clip_prompt_embeds(
         self,
```
```diff
@@ -322,20 +326,17 @@ def _get_clip_prompt_embeds(

         aic_text_input = {"input_ids": text_input_ids.numpy().astype(np.int64)}

-        import time
-
-        global start_text_encoder_time
         start_text_encoder_time = time.time()
         aic_embeddings = self.text_encoder.qpc_session.run(aic_text_input)
         end_text_encoder_time = time.time()
-        self.text_encoder.inference_time = end_text_encoder_time - start_text_encoder_time
+        text_encoder_perf = end_text_encoder_time - start_text_encoder_time
         prompt_embeds = torch.tensor(aic_embeddings["pooler_output"])

         # duplicate text embeddings for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)

-        return prompt_embeds
+        return prompt_embeds, text_encoder_perf

     def encode_prompt(
         self,
```
```diff
@@ -378,20 +379,20 @@
         prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

         # We only use the pooled prompt output from the CLIPTextModel
-        pooled_prompt_embeds = self._get_clip_prompt_embeds(
+        pooled_prompt_embeds, text_encoder_perf = self._get_clip_prompt_embeds(
             prompt=prompt,
             device_ids=self.text_encoder.device_ids,
             num_images_per_prompt=num_images_per_prompt,
         )
-        prompt_embeds = self._get_t5_prompt_embeds(
+        prompt_embeds, text_encoder_2_perf = self._get_t5_prompt_embeds(
             prompt=prompt_2,
             num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
             device_ids=self.text_encoder_2.device_ids,
         )

         text_ids = torch.zeros(prompt_embeds.shape[1], 3)
-        return prompt_embeds, pooled_prompt_embeds, text_ids
+        return prompt_embeds, pooled_prompt_embeds, text_ids, [text_encoder_perf, text_encoder_2_perf]

     def __call__(
         self,
```
```diff
@@ -539,18 +540,15 @@ def __call__(
             negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
         )
         do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
-        (
-            prompt_embeds,
-            pooled_prompt_embeds,
-            text_ids,
-        ) = self.encode_prompt(
+        (prompt_embeds, pooled_prompt_embeds, text_ids, text_encoder_perf) = self.encode_prompt(
             prompt=prompt,
             prompt_2=prompt_2,
             prompt_embeds=prompt_embeds,
             pooled_prompt_embeds=pooled_prompt_embeds,
             num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
         )
+
         if do_true_cfg:
             (
                 negative_prompt_embeds,
```
```diff
@@ -595,7 +593,7 @@ def __call__(
         }

         self.transformer.qpc_session.set_buffers(output_buffer)
-        self.transformer.inference_time = []
+        transformer_perf = []
         self.scheduler.set_begin_index(0)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
```
```diff
@@ -653,7 +651,7 @@
                 start_transformer_step_time = time.time()
                 outputs = self.transformer.qpc_session.run(inputs_aic)
                 end_transfromer_step_time = time.time()
-                self.transformer.inference_time.append(end_transfromer_step_time - start_transformer_step_time)
+                transformer_perf.append(end_transfromer_step_time - start_transformer_step_time)

                 noise_pred = torch.from_numpy(outputs["output"])

```
```diff
@@ -678,7 +676,6 @@
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
-
         if output_type == "latent":
             image = latents
         else:
```
```diff
@@ -704,14 +701,22 @@
             start_decode_time = time.time()
             image = self.vae_decode.qpc_session.run(inputs)
             end_decode_time = time.time()
-            self.vae_decode.inference_time = end_decode_time - start_decode_time
+            vae_decode_perf = end_decode_time - start_decode_time
             image_tensor = torch.from_numpy(image["sample"])
             image = self.image_processor.postprocess(image_tensor, output_type=output_type)

-        total_time_taken = end_decode_time - start_text_encoder_time
+        # Collect performance data in a dict
+        perf_data = {
+            "text_encoder": text_encoder_perf[0],
+            "text_encoder_2": text_encoder_perf[1],
+            "transformer": transformer_perf,
+            "vae_decoder": vae_decode_perf,
+        }

-        return QEffPipelineOutput(
-            pipeline=self,
-            images=image,
-            E2E_time=total_time_taken,
-        )
+        # Build performance metrics dynamically
+        perf_metrics = [ModulePerf(module_name=name, perf=perf_data[name]) for name in self.modules.keys()]
+
+        return QEffPipelineOutput(
+            pipeline_module=perf_metrics,
+            images=image,
+        )
```
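One side effect of moving the timings into the return value is that `__call__` no longer measures a wall-clock end-to-end figure (the old `total_time_taken`, spanning text encoding through VAE decode, is gone). The report in `pipeline_utils.py` instead sums the per-module numbers; the same figure can be recovered from the `output` object of the earlier sketch, keeping in mind it is a sum of module run times rather than a wall-clock span, so host-side work between modules is not included.

```python
# Sums per-module timings, mirroring what QEffPipelineOutput.__repr__ now does internally.
e2e_time = 0.0
for module_perf in output.pipeline_module:
    perf = module_perf.perf
    e2e_time += sum(perf) if isinstance(perf, list) else perf
print(f"End-to-End Inference Time: {e2e_time:.4f} s")
```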

QEfficient/diffusers/pipelines/pipeline_utils.py

Lines changed: 35 additions & 31 deletions
```diff
@@ -7,16 +7,13 @@

 import os
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import List, Optional, Union

 import numpy as np
 import PIL.Image

 from QEfficient.utils._utils import load_json

-if TYPE_CHECKING:
-    from QEfficient.diffusers.pipelines.flux.pipeline_flux import QEFFFluxPipeline
-

 def config_manager(cls, config_source: Optional[str] = None):
     """
```
```diff
@@ -50,51 +47,58 @@ def set_module_device_ids(cls):
     from the configuration file to each module's device_ids attribute.
     """
     config_modules = cls.custom_config["modules"]
-    for module_name, module_obj in cls.has_module:
+    for module_name, module_obj in cls.modules.items():
         module_obj.device_ids = config_modules[module_name]["execute"]["device_ids"]

+@dataclass(frozen=True)
+class ModulePerf:
+    module_name: str
+    perf: int
+

-@dataclass
+@dataclass(frozen=True)
 class QEffPipelineOutput:
-    pipeline: "QEFFFluxPipeline"
+    pipeline_module: list[ModulePerf]
     images: Union[List[PIL.Image.Image], np.ndarray]
-    E2E_time: int

     def __repr__(self):
         output_str = "=" * 60 + "\n"
         output_str += "QEfficient Diffusers Pipeline Inference Report\n"
         output_str += "=" * 60 + "\n\n"

-        # End-to-End time
-        output_str += f"End-to-End Inference Time: {self.E2E_time:.4f} s\n\n"
-
         # Module-wise inference times
         output_str += "Module-wise Inference Times:\n"
         output_str += "-" * 60 + "\n"

-        # Iterate through all modules using has_module
-        for module_name, module_obj in self.pipeline.has_module:
-            if hasattr(module_obj, "inference_time"):
-                inference_time = module_obj.inference_time
-
-                # Format module name for display
-                display_name = module_name.replace("_", " ").title()
-
-                # Handle transformer specially as it has a list of times
-                if isinstance(inference_time, list) and len(inference_time) > 0:
-                    total_time = sum(inference_time)
-                    avg_time = total_time / len(inference_time)
-                    output_str += f"  {display_name:25s} {total_time:.4f} s\n"
-                    output_str += f"    - Total steps: {len(inference_time)}\n"
-                    output_str += f"    - Average per step: {avg_time:.4f} s\n"
-                    output_str += f"    - Min step time: {min(inference_time):.4f} s\n"
-                    output_str += f"    - Max step time: {max(inference_time):.4f} s\n"
-                else:
-                    # Single inference time value
-                    output_str += f"  {display_name:25s} {inference_time:.4f} s\n"
+        # Calculate E2E time while iterating
+        e2e_time = 0
+        for module_perf in self.pipeline_module:
+            module_name = module_perf.module_name
+            inference_time = module_perf.perf
+
+            # Add to E2E time
+            e2e_time += sum(inference_time) if isinstance(inference_time, list) else inference_time
+
+            # Format module name for display
+            display_name = module_name.replace("_", " ").title()
+
+            # Handle transformer specially as it has a list of times
+            if isinstance(inference_time, list) and len(inference_time) > 0:
+                total_time = sum(inference_time)
+                avg_time = total_time / len(inference_time)
+                output_str += f"  {display_name:25s} {total_time:.4f} s\n"
+                output_str += f"    - Total steps: {len(inference_time)}\n"
+                output_str += f"    - Average per step: {avg_time:.4f} s\n"
+                output_str += f"    - Min step time: {min(inference_time):.4f} s\n"
+                output_str += f"    - Max step time: {max(inference_time):.4f} s\n"
+            else:
+                # Single inference time value
+                output_str += f"  {display_name:25s} {inference_time:.4f} s\n"

         output_str += "-" * 60 + "\n\n"

+        # Print E2E time after all modules
+        output_str += f"End-to-End Inference Time: {e2e_time:.4f} s\n\n"
         output_str += "=" * 60 + "\n"

         return output_str
```
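To see what the reworked report looks like without running a pipeline, here is a self-contained sketch that builds a `QEffPipelineOutput` from `ModulePerf` entries with made-up timings and prints it; the import path is the one shown in this diff.

```python
from QEfficient.diffusers.pipelines.pipeline_utils import ModulePerf, QEffPipelineOutput

# Made-up timings, for illustration only.
perf_metrics = [
    ModulePerf(module_name="text_encoder", perf=0.012),
    ModulePerf(module_name="text_encoder_2", perf=0.034),
    ModulePerf(module_name="transformer", perf=[0.21, 0.20, 0.20, 0.21]),  # per-step times
    ModulePerf(module_name="vae_decoder", perf=0.055),
]

output = QEffPipelineOutput(pipeline_module=perf_metrics, images=[])
print(output)  # prints the module-wise table, then the summed end-to-end time
```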

examples/diffusers/flux/flux_1_schnell.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -17,7 +17,7 @@
     max_sequence_length=256,
     generator=torch.manual_seed(42),
 )
-image = pipeline.images[0]
+image = output.images[0]
 image.save("cat_with_sign.png")

 print(output)
```

examples/diffusers/flux/flux_1_shnell_custom.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -112,7 +112,7 @@
     generator=torch.manual_seed(42),
 )

-image = output.images[0]
+images = output.images[0]
 # Save the generated image to disk
-image.save("girl_laughing.png")
+images.save("girl_laughing.png")
 print(output)
```
