
Commit bed22a1

tv-karthikeya authored and Amit Raj committed
adding device id support for flux for all stages
Signed-off-by: Amit Raj <amitraj@qti.qualcomm.com>
1 parent 4e2634e commit bed22a1

File tree

6 files changed: +25 -526 lines changed

QEfficient/__init__.py

Lines changed: 0 additions & 6 deletions
@@ -52,10 +52,6 @@ def check_qaic_sdk():
     from QEfficient.compile.compile_helper import compile
 
     # Imports for the diffusers
-    from QEfficient.diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import QEFFStableDiffusionPipeline
-    from QEfficient.diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion3 import (
-        QEFFStableDiffusion3Pipeline,
-    )
     from QEfficient.diffusers.pipelines.flux.pipeline_flux import QEFFFluxPipeline
     from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter
     from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv
@@ -77,8 +73,6 @@ def check_qaic_sdk():
         "QEFFAutoModelForImageTextToText",
         "QEFFAutoModelForSpeechSeq2Seq",
         "QEFFCommonLoader",
-        "QEFFStableDiffusionPipeline",
-        "QEFFStableDiffusion3Pipeline",
         "QEFFFluxPipeline",
     ]
 
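After this change, `QEFFFluxPipeline` is the only diffusers pipeline exported from the package root. A minimal sketch of the resulting import surface, assuming an environment where the QAIC SDK check passes:

```python
# Minimal sketch of the post-commit import surface; assumes the QAIC SDK
# is installed so check_qaic_sdk() passes and the diffusers imports run.
from QEfficient import QEFFFluxPipeline  # still exported

try:
    # Removed by this commit; importing it now fails.
    from QEfficient import QEFFStableDiffusionPipeline
except ImportError:
    print("Stable Diffusion pipelines are no longer exported from QEfficient")
```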

QEfficient/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 24 additions & 17 deletions
@@ -15,7 +15,7 @@
 from diffusers import FluxPipeline
 from diffusers.image_processor import VaeImageProcessor, PipelineImageInput
 from diffusers.utils.torch_utils import randn_tensor
-from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps  # TODO
 from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
 
 from QEfficient.diffusers.pipelines.pipeline_utils import QEffTextEncoder, QEffClipTextEncoder, QEffVAE, QEffFluxTransformerModel
@@ -274,7 +274,7 @@ def _get_t5_prompt_embeds(
         prompt: Union[str, List[str]] = None,
         num_images_per_prompt: int = 1,
         max_sequence_length: int = 512,
-        device_ids: List[int] = None,
+        device_ids: Optional[List[int]] = None,
         dtype: Optional[torch.dtype] = None,
     ):
         """
@@ -326,7 +326,6 @@ def _get_t5_prompt_embeds(
         aic_text_input = {"input_ids": text_input_ids.numpy().astype(np.int64)}
         prompt_embeds = torch.tensor(self.text_encoder_2.qpc_session.run(aic_text_input)["last_hidden_state"])
 
-        self.text_encoder_2.qpc_session.deactivate()
 
         # # # AIC Testing
         # prompt_embeds_pytorch = self.text_encoder_2.model(text_input_ids, output_hidden_states=False)
@@ -345,7 +344,7 @@ def _get_clip_prompt_embeds(
         self,
         prompt: Union[str, List[str]],
         num_images_per_prompt: int = 1,
-        device_ids: List[int] = None,
+        device_ids: Optional[List[int]] = None,
     ):
         """
         Get CLIP prompt embeddings for a given text encoder and tokenizer.
@@ -395,7 +394,6 @@ def _get_clip_prompt_embeds(
         aic_embeddings = self.text_encoder.qpc_session.run(aic_text_input)
         aic_text_encoder_emb = aic_embeddings["pooler_output"]
 
-        self.text_encoder.qpc_session.deactivate()  # To deactivate CLIP instance
 
         # # # [TEMP] CHECK ACC # #
         # prompt_embeds_pytorch = self.text_encoder.model(text_input_ids, output_hidden_states=False)
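Note that both hunks above drop a `deactivate()` call, so the CLIP and T5 sessions now stay active for the whole run instead of releasing their device after each stage. A hypothetical before/after sketch of that lifecycle; `session` stands in for a `QAICInferenceSession` such as `self.text_encoder_2.qpc_session`, and only `run`/`deactivate` are taken from this diff:

```python
# Hypothetical lifecycle sketch, not code from this commit.

def run_stage_before(session, inputs):
    out = session.run(inputs)
    session.deactivate()  # device released so the next stage can claim it
    return out

def run_stage_after(session, inputs):
    # With per-stage device_ids, each stage owns its own devices, so the
    # session stays active and repeated calls skip re-activation.
    return session.run(inputs)
```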
@@ -418,11 +416,12 @@ def encode_prompt(
         self,
         prompt: Union[str, List[str]],
         prompt_2: Optional[Union[str, List[str]]] = None,
-        device_ids: List[int] = None,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         max_sequence_length: int = 512,
+        device_ids_text_encoder_1: Optional[List[int]] = None,
+        device_ids_text_encoder_2: Optional[List[int]] = None,
     ):
         r"""
         Encode the given prompts into text embeddings using the two text encoders (CLIP and T5).
@@ -437,8 +436,6 @@ def encode_prompt(
             prompt_2 (`str` or `List[str]`, *optional*):
                 The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                 used in all text-encoders
-            device: (`torch.device`):
-                torch device
             num_images_per_prompt (`int`):
                 number of images that should be generated per prompt
             prompt_embeds (`torch.FloatTensor`, *optional*):
@@ -447,6 +444,8 @@ def encode_prompt(
             pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                 If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            device_ids_text_encoder_1 (`List[int]`, *optional*): List of device IDs to use for the CLIP instance.
+            device_ids_text_encoder_2 (`List[int]`, *optional*): List of device IDs to use for T5.
         """
 
         prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -458,14 +457,15 @@ def encode_prompt(
         # We only use the pooled prompt output from the CLIPTextModel
         pooled_prompt_embeds = self._get_clip_prompt_embeds(
             prompt=prompt,
-            device_ids=device_ids,
+            device_ids=device_ids_text_encoder_1,
             num_images_per_prompt=num_images_per_prompt,
+
         )
         prompt_embeds = self._get_t5_prompt_embeds(
             prompt=prompt_2,
             num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
-            device_ids=device_ids,
+            device_ids=device_ids_text_encoder_2,
         )
 
         text_ids = torch.zeros(prompt_embeds.shape[1], 3)
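With the call sites rewired as above, each encoder can be pinned to its own devices. A usage sketch, assuming `pipe` is an already compiled `QEFFFluxPipeline` and that `encode_prompt` keeps FluxPipeline's three-tuple return; the device IDs are placeholders:

```python
# Usage sketch (assumed, not from this commit): route CLIP and T5 to
# different QAIC devices while encoding a prompt.
prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
    prompt="A cat holding a sign that says hello world",
    device_ids_text_encoder_1=[0],  # CLIP runs on device 0
    device_ids_text_encoder_2=[1],  # T5 runs on device 1
)
```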
@@ -497,7 +497,10 @@ def __call__(
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 512,
-        qpc_path: str = None,
+        device_ids_text_encoder_1: Optional[List[int]] = None,
+        device_ids_text_encoder_2: Optional[List[int]] = None,
+        device_ids_transformer: Optional[List[int]] = None,
+        device_ids_vae_decoder: Optional[List[int]] = None,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -574,6 +577,10 @@ def __call__(
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                 `._callback_tensor_inputs` attribute of your pipeline class.
             max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
+            device_ids_text_encoder_1 (`List[int]`, *optional*): List of device IDs to use for the CLIP instance.
+            device_ids_text_encoder_2 (`List[int]`, *optional*): List of device IDs to use for T5.
+            device_ids_transformer (`List[int]`, *optional*): List of device IDs to use for the Flux transformer.
+            device_ids_vae_decoder (`List[int]`, *optional*): List of device IDs to use for the VAE decoder.
 
         Returns:
             [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
@@ -644,6 +651,8 @@ def __call__(
             pooled_prompt_embeds=pooled_prompt_embeds,
             num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
+            device_ids_text_encoder_1=device_ids_text_encoder_1,
+            device_ids_text_encoder_2=device_ids_text_encoder_2,
         )
         if do_true_cfg:
             (
@@ -657,6 +666,8 @@ def __call__(
             pooled_prompt_embeds=negative_pooled_prompt_embeds,
             num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
+            device_ids_text_encoder_1=device_ids_text_encoder_1,
+            device_ids_text_encoder_2=device_ids_text_encoder_2,
         )
 
         # 4. Prepare timesteps
@@ -680,7 +691,7 @@ def __call__(
         # 6. Denoising loop
         ###### AIC related changes of transformers ######
         if self.transformer.qpc_session is None:
-            self.transformer.qpc_session = QAICInferenceSession(str(self.trasformer_compile_path))
+            self.transformer.qpc_session = QAICInferenceSession(str(self.trasformer_compile_path), device_ids=device_ids_transformer)
 
         output_buffer = {
             "output": np.random.rand(
@@ -782,11 +793,8 @@ def __call__(
         latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
         latents = (latents / self.vae_decode.model.scaling_factor) + self.vae_decode.model.shift_factor
 
-
-        self.transformer.qpc_session.deactivate()
-
         if self.vae_decode.qpc_session is None:
-            self.vae_decode.qpc_session = QAICInferenceSession(str(self.vae_decoder_compile_path))
+            self.vae_decode.qpc_session = QAICInferenceSession(str(self.vae_decoder_compile_path), device_ids=device_ids_vae_decoder)
 
         output_buffer = {
             "sample": np.random.rand(
@@ -797,7 +805,6 @@ def __call__(
 
         inputs = {"latent_sample": latents.numpy()}
         image = self.vae_decode.qpc_session.run(inputs)
-        self.vae_decode.qpc_session.deactivate()
 
         ###### ACCURACY TESTING #######
         # image_torch = self.vae_decode.model(latents, return_dict=False)[0]
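Taken together, the new `__call__` keywords let every stage of the Flux pipeline be placed on its own device set. A hedged end-to-end sketch; the model ID, `compile()` step, and device IDs are illustrative assumptions, not taken from this commit:

```python
from QEfficient import QEFFFluxPipeline

# End-to-end sketch; model name, compile call, and IDs are assumptions.
pipe = QEFFFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
pipe.compile()  # assumed to produce the QPCs loaded by the sessions above

image = pipe(
    "A photo of an astronaut riding a horse",
    device_ids_text_encoder_1=[0],  # CLIP
    device_ids_text_encoder_2=[1],  # T5
    device_ids_transformer=[2, 3],  # denoising transformer
    device_ids_vae_decoder=[0],     # VAE decoder
).images[0]
```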

QEfficient/diffusers/pipelines/stable_diffusion/__init__.py

Lines changed: 0 additions & 6 deletions
This file was deleted.
