
Commit 572303e

Author: Amit Raj (committed)

Modification of Pipeline-1

Signed-off-by: Amit Raj <amitraj@qti.qualcommm.com>

1 parent: 24a4e39

File tree: 4 files changed, +50 −134 lines


QEfficient/diffusers/pipelines/config_manager.py

Lines changed: 13 additions & 1 deletion
````diff
@@ -32,4 +32,16 @@ def config_manager(cls, config_source: Optional[str] = None):
     if not os.path.exists(config_source):
         raise FileNotFoundError(f"Configuration file not found: {config_source}")
 
-    cls._compile_config = load_json(config_source)
+    cls.custom_config = load_json(config_source)
+
+
+def set_module_device_ids(cls):
+    """
+    Set device IDs for each module based on the custom configuration.
+
+    Iterates through all modules in the pipeline and assigns device IDs
+    from the configuration file to each module's device_ids attribute.
+    """
+    config_modules = cls.custom_config["modules"]
+    for module_name, module_obj in cls.has_module:
+        module_obj.device_ids = config_modules[module_name]["execute"]["device_ids"]
````
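
`set_module_device_ids` walks the pipeline's `has_module` list and expects each module to appear under the config's `modules` key with an `execute.device_ids` list; the `compile` changes further down read a `specializations` entry from the same section. A minimal sketch of that assumed layout, as the dict `load_json` would return — the key names come from this diff, but the concrete values are illustrative placeholders only:

```python
# Hypothetical shape of the dict returned by load_json(config_source).
# Key names ("modules", "specializations", "execute", "device_ids") are the
# ones this commit reads; all values below are made-up placeholders.
custom_config = {
    "modules": {
        "text_encoder": {
            "specializations": {"batch_size": 1, "seq_len": 77},
            "execute": {"device_ids": [0]},
        },
        "text_encoder_2": {
            "specializations": {"batch_size": 1, "seq_len": 256},
            "execute": {"device_ids": [1]},
        },
        "transformer": {
            "specializations": {"batch_size": 1},
            "execute": {"device_ids": [2, 3]},
        },
        "vae_decoder": {
            "specializations": {"batch_size": 1},
            "execute": {"device_ids": [0]},
        },
    }
}

# Mirrors the loop in set_module_device_ids: each module named in
# pipeline.has_module picks up its device_ids from the config.
for name, cfg in custom_config["modules"].items():
    print(name, "->", cfg["execute"]["device_ids"])
```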

QEfficient/diffusers/pipelines/flux/config/default_flux_execute_config.json

Whitespace-only changes.

QEfficient/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 26 additions & 129 deletions
````diff
@@ -17,7 +17,7 @@
 from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps  # TODO
 
-from QEfficient.diffusers.pipelines.config_manager import config_manager
+from QEfficient.diffusers.pipelines.config_manager import config_manager, set_module_device_ids
 from QEfficient.diffusers.pipelines.pipeline_utils import (
     QEffClipTextEncoder,
     QEffFluxTransformerModel,
````
````diff
@@ -37,85 +37,27 @@ class QEFFFluxPipeline(FluxPipeline):
     It provides methods for text-to-image generation leveraging these optimized components.
     """
 
-    def __init__(
-        self,
-        model=None,
-        use_onnx_function=False,
-        text_encoder=None,
-        text_encoder_2=None,
-        transformer=None,
-        vae=None,
-        tokenizer=None,
-        tokenizer_2=None,
-        scheduler=None,
-        height: Optional[int] = 512,
-        width: Optional[int] = 512,
-        *args,
-        **kwargs,
-    ):
-        # Validate input: either model or individual components
-        has_model = model is not None
-        has_components = all(
-            [
-                text_encoder is not None,
-                text_encoder_2 is not None,
-                transformer is not None,
-                vae is not None,
-                tokenizer is not None,
-                tokenizer_2 is not None,
-                scheduler is not None,
-            ]
-        )
-
-        if not has_model and not has_components:
-            raise ValueError(
-                "Either provide 'model' parameter OR all individual components "
-                "(text_encoder, text_encoder_2, transformer, vae, "
-                "tokenizer, tokenizer_2, scheduler)"
-            )
-
-        if has_model and has_components:
-            raise ValueError("Cannot provide both 'model' and individual components. Choose one approach.")
-
-        # Use ONNX subfunction to optimize the onnx graph
+    def __init__(self, model, use_onnx_function, *args, **kwargs):
+        self.text_encoder = QEffClipTextEncoder(model.text_encoder)
+        self.text_encoder_2 = QEffTextEncoder(model.text_encoder_2)
+        self.transformer = QEffFluxTransformerModel(model.transformer, use_onnx_function=use_onnx_function)
+        self.vae_decode = QEffVAE(model, "decoder")
         self.use_onnx_function = use_onnx_function
-
-        self.text_encoder = (
-            QEffClipTextEncoder(text_encoder) if has_components else QEffClipTextEncoder(model.text_encoder)
-        )
-        self.text_encoder_2 = (
-            QEffTextEncoder(text_encoder_2) if has_components else QEffTextEncoder(model.text_encoder_2)
-        )
-        self.transformer = QEffFluxTransformerModel(
-            transformer if has_components else model.transformer, self.use_onnx_function
-        )
-        self.vae_decode = QEffVAE(vae if has_components else model, "decoder")
-
         self.has_module = [
             ("text_encoder", self.text_encoder),
             ("text_encoder_2", self.text_encoder_2),
             ("transformer", self.transformer),
             ("vae_decoder", self.vae_decode),
         ]
 
-        self.tokenizer = tokenizer if has_components else model.tokenizer
-        self.text_encoder.tokenizer = tokenizer if has_components else model.tokenizer
-        self.text_encoder_2.tokenizer = tokenizer_2 if has_components else model.tokenizer_2
+        self.tokenizer = model.tokenizer
+        self.text_encoder.tokenizer = model.tokenizer
+        self.text_encoder_2.tokenizer = model.tokenizer_2
         self.tokenizer_max_length = model.tokenizer_max_length
         self.scheduler = model.scheduler
 
-        self.height = height
-        self.width = width
-
-        self.register_modules(
-            vae=self.vae_decode,
-            text_encoder=self.text_encoder,
-            text_encoder_2=self.text_encoder_2,
-            tokenizer=self.tokenizer,
-            tokenizer_2=self.text_encoder_2.tokenizer,
-            transformer=self.transformer,
-            scheduler=self.scheduler,
-        )
+        self.height = kwargs.get("height", 256)
+        self.width = kwargs.get("width", 256)
 
         self.vae_decode.model.forward = lambda latent_sample, return_dict: self.vae_decode.model.decode(
             latent_sample, return_dict
````
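
The slimmed-down constructor always wraps the submodules of an assembled `model`, dropping both component-level injection and the `register_modules` call; `height` and `width` now arrive through `kwargs` with 256 defaults. A minimal construction sketch, assuming `from_pretrained` forwards keyword arguments to this `__init__` as the updated example script below does:

```python
from QEfficient import QEFFFluxPipeline

# height/width are read via kwargs.get(..., 256) in the new __init__,
# so omitting them selects a 256x256 output resolution.
pipeline = QEFFFluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    height=256,
    width=256,
)
```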
````diff
@@ -200,7 +142,6 @@ def export(self, export_dir: Optional[str] = None) -> str:
 
     def compile(
         self,
-        *,
         compile_config: Optional[str] = None,
     ) -> str:
         """
````
````diff
@@ -226,11 +167,12 @@ def compile(
         self.export()
 
         # Initialize configuration manager (JSON-only approach)
-        config_manager(self, config_source=compile_config)
+        if self.custom_config is None:
+            config_manager(self, config_source=compile_config)
 
         for module_name, module_obj in self.has_module:
             # Get specialization values directly from config
-            module_config = self._compile_config["modules"]
+            module_config = self.custom_config["modules"]
             specializations = [module_config[module_name]["specializations"]]
 
             # Get compilation parameters from configuration
````
````diff
@@ -396,8 +338,6 @@ def encode_prompt(
         prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         max_sequence_length: int = 512,
-        device_ids_text_encoder_1: Optional[List[int]] = None,
-        device_ids_text_encoder_2: Optional[List[int]] = None,
     ):
         r"""
         Encode the given prompts into text embeddings using the two text encoders (CLIP and T5).
````
````diff
@@ -433,14 +373,14 @@ def encode_prompt(
         # We only use the pooled prompt output from the CLIPTextModel
         pooled_prompt_embeds = self._get_clip_prompt_embeds(
             prompt=prompt,
-            device_ids=device_ids_text_encoder_1,
+            device_ids=self.text_encoder.device_ids,
             num_images_per_prompt=num_images_per_prompt,
         )
         prompt_embeds = self._get_t5_prompt_embeds(
             prompt=prompt_2,
             num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
-            device_ids=device_ids_text_encoder_2,
+            device_ids=self.text_encoder_2.device_ids,
         )
 
         text_ids = torch.zeros(prompt_embeds.shape[1], 3)
````
````diff
@@ -453,8 +393,6 @@ def __call__(
         negative_prompt: Union[str, List[str]] = None,
         negative_prompt_2: Optional[Union[str, List[str]]] = None,
         true_cfg_scale: float = 1.0,
-        height: Optional[int] = None,
-        width: Optional[int] = None,
         num_inference_steps: int = 28,
         timesteps: List[int] = None,
         guidance_scale: float = 3.5,
````
````diff
@@ -471,10 +409,7 @@ def __call__(
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 512,
-        device_ids_text_encoder_1: Optional[List[int]] = None,
-        device_ids_text_encoder_2: Optional[List[int]] = None,
-        device_ids_transformer: Optional[List[int]] = None,
-        device_ids_vae_decoder: Optional[List[int]] = None,
+        custom_config_path: Optional[str] = None,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
````
````diff
@@ -551,10 +486,6 @@ def __call__(
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                 `._callback_tensor_inputs` attribute of your pipeline class.
             max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
-            device_ids_text_encoder1 (List[int], optional): List of device IDs to use for CLIP instance.
-            device_ids_text_encoder2 (List[int], optional): List of device IDs to use for T5.
-            device_ids_transformer (List[int], optional): List of device IDs to use for Flux transformer.
-            device_ids_vae_decoder (List[int], optional): List of device IDs to use for VAE decoder.
 
         Returns:
             [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
````
````diff
@@ -578,11 +509,16 @@ def __call__(
             image.save("flux-schnell_aic.png")
             ```
         """
-        height = self.height
-        width = self.width
         device = "cpu"
 
         # 1. Check inputs. Raise error if not correct
+        if custom_config_path is not None:
+            config_manager(self, custom_config_path)
+        set_module_device_ids(self)
+
+        # Calling compile with custom config
+        self.compile(compile_config=custom_config_path)
+
         self.check_inputs(
             prompt,
             prompt_2,
````
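
With this hunk, passing `custom_config_path` drives the whole setup from `__call__`: the JSON is loaded via `config_manager`, `set_module_device_ids` pins each module to its configured devices, and `compile` runs with the same config (its new `custom_config is None` guard avoids parsing the file twice). A sketch of the two resulting entry points, assuming a fresh pipeline starts with `custom_config` unset (`None`):

```python
cfg = "QEfficient/diffusers/pipelines/flux/config/default_flux_compile_config.json"

# (a) Explicit: compile up front; the parsed config is cached on the pipeline.
pipeline.compile(compile_config=cfg)

# (b) Implicit: let __call__ load the config, assign device_ids, and compile,
# as the updated example script below does.
image = pipeline(
    "A cat holding a sign that says hello world",
    guidance_scale=0.0,
    num_inference_steps=4,
    max_sequence_length=256,
    custom_config_path=cfg,
).images[0]
```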
````diff
@@ -624,8 +560,6 @@ def __call__(
             pooled_prompt_embeds=pooled_prompt_embeds,
             num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
-            device_ids_text_encoder_1=device_ids_text_encoder_1,
-            device_ids_text_encoder_2=device_ids_text_encoder_2,
         )
         if do_true_cfg:
             (
````
````diff
@@ -639,8 +573,6 @@ def __call__(
                 pooled_prompt_embeds=negative_pooled_prompt_embeds,
                 num_images_per_prompt=num_images_per_prompt,
                 max_sequence_length=max_sequence_length,
-                device_ids_text_encoder_1=device_ids_text_encoder_1,
-                device_ids_text_encoder_2=device_ids_text_encoder_2,
             )
 
         # 4. Prepare timesteps
````
````diff
@@ -665,7 +597,7 @@ def __call__(
         ###### AIC related changes of transformers ######
         if self.transformer.qpc_session is None:
             self.transformer.qpc_session = QAICInferenceSession(
-                str(self.transformer.qpc_path), device_ids=device_ids_transformer
+                str(self.transformer.qpc_path), device_ids=self.transformer.device_ids
             )
 
         output_buffer = {
````
````diff
@@ -728,40 +660,13 @@ def __call__(
                     "adaln_out": adaln_out.detach().numpy(),
                 }
 
-                # noise_pred_torch = self.transformer.model(
-                #     hidden_states=latents,
-                #     encoder_hidden_states = prompt_embeds,
-                #     pooled_projections=pooled_prompt_embeds,
-                #     timestep=torch.tensor(timestep),
-                #     img_ids = latent_image_ids,
-                #     txt_ids = text_ids,
-                #     adaln_emb = adaln_dual_emb,
-                #     adaln_single_emb=adaln_single_emb,
-                #     adaln_out = adaln_out,
-                #     return_dict=False,
-                # )[0]
-
                 start_time = time.time()
                 outputs = self.transformer.qpc_session.run(inputs_aic)
                 end_time = time.time()
                 print(f"Time : {end_time - start_time:.2f} seconds")
 
-                # ########################## Onnx
-                # input_names = [i.name for i in session.get_inputs()]
-                # output_names = [o.name for o in session.get_outputs()]
-                # start_time = time.time()
-                # ort_pred = session.run(None, {k: v for k, v in inputs_aic.items() if k in input_names})
-                # end_time = time.time()
-                # print(f"Onnx positive transformer denoising Time: {end_time - start_time:.2f} seconds")
-
                 noise_pred = torch.from_numpy(outputs["output"])
 
-                # # # # ###### ACCURACY TESTING #######
-                # mad=np.max(np.abs(noise_pred_torch.detach().numpy()-outputs['output']))
-                # print(f">>>>>>>>> at t = {t} FLUX transfromer model MAD pytorch vs QAIC :{mad}")
-                # mad_O=np.max(np.abs(ort_pred[0]- outputs['output']))
-                # print(f">>>>>>>>> FLUX transfromer model MAD Onnxrt vs QAIC : {mad_O}")
-
                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
````
````diff
@@ -792,7 +697,7 @@ def __call__(
 
         if self.vae_decode.qpc_session is None:
             self.vae_decode.qpc_session = QAICInferenceSession(
-                str(self.vae_decode.qpc_path), device_ids=device_ids_vae_decoder
+                str(self.vae_decode.qpc_path), device_ids=self.vae_decode.device_ids
             )
 
         output_buffer = {
````
````diff
@@ -808,17 +713,9 @@ def __call__(
         inputs = {"latent_sample": latents.numpy()}
         image = self.vae_decode.qpc_session.run(inputs)
 
-        # ##### ACCURACY TESTING #######
-        # image_torch = self.vae_decode.model(latents, return_dict=False)[0]
-        # mad= np.max(np.abs(image['sample']-image_torch.detach().numpy()))
-        # print(">>>>>>>>>>>> VAE mad: ",mad)
-
         image_tensor = torch.from_numpy(image["sample"])
         image = self.image_processor.postprocess(image_tensor, output_type=output_type)
 
-        # Offload all models
-        # self.maybe_free_model_hooks()
-
         if not return_dict:
             return (image,)
````

examples/diffusers/flux/flux_1_schnell.py

Lines changed: 11 additions & 4 deletions
````diff
@@ -8,7 +8,13 @@
 
 from QEfficient import QEFFFluxPipeline
 
-pipeline = QEFFFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", height=256, width=256)
+pipeline = QEFFFluxPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-schnell",
+    height=256,
+    width=256,
+)
+
+# pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
 
 ######## for single layer
 original_blocks = pipeline.transformer.model.transformer_blocks
````
````diff
@@ -20,15 +26,16 @@
 pipeline.transformer.model.config.num_single_layers = 1
 
 
-pipeline.compile(compile_config="QEfficient/diffusers/pipelines/flux/config/default_flux_compile_config.json")
+# pipeline.compile(compile_config="QEfficient/diffusers/pipelines/flux/config/default_flux_compile_config.json")
 
-generator = torch.manual_seed(42)
 # NOTE: guidance_scale <=1 is not supported
 image = pipeline(
     "A cat holding a sign that says hello world",
     guidance_scale=0.0,
     num_inference_steps=4,
     max_sequence_length=256,
-    generator=generator,
+    custom_config_path="QEfficient/diffusers/pipelines/flux/config/default_flux_compile_config.json",
+    generator=torch.manual_seed(42),
 ).images[0]
+
 image.save("flux-schnell_aic_1024.png")
````
