 from diffusers import FluxPipeline
 from diffusers.image_processor import VaeImageProcessor, PipelineImageInput
 from diffusers.utils.torch_utils import randn_tensor
-from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps  # TODO

 from QEfficient.diffusers.pipelines.pipeline_utils import QEffTextEncoder, QEffClipTextEncoder, QEffVAE, QEffFluxTransformerModel
@@ -274,7 +274,7 @@ def _get_t5_prompt_embeds(
         prompt: Union[str, List[str]] = None,
         num_images_per_prompt: int = 1,
         max_sequence_length: int = 512,
-        device_ids: List[int] = None,
+        device_ids: Optional[List[int]] = None,
         dtype: Optional[torch.dtype] = None,
     ):
         """
@@ -326,7 +326,6 @@ def _get_t5_prompt_embeds(
         aic_text_input = {"input_ids": text_input_ids.numpy().astype(np.int64)}
         prompt_embeds = torch.tensor(self.text_encoder_2.qpc_session.run(aic_text_input)["last_hidden_state"])

-        self.text_encoder_2.qpc_session.deactivate()

         # # # AIC Testing
         # prompt_embeds_pytorch = self.text_encoder_2.model(text_input_ids, output_hidden_states=False)
@@ -345,7 +344,7 @@ def _get_clip_prompt_embeds(
         self,
         prompt: Union[str, List[str]],
         num_images_per_prompt: int = 1,
-        device_ids: List[int] = None,
+        device_ids: Optional[List[int]] = None,
     ):
         """
         Get CLIP prompt embeddings for a given text encoder and tokenizer.
@@ -395,7 +394,6 @@ def _get_clip_prompt_embeds(
         aic_embeddings = self.text_encoder.qpc_session.run(aic_text_input)
         aic_text_encoder_emb = aic_embeddings["pooler_output"]

-        self.text_encoder.qpc_session.deactivate()  # To deactivate CLIP instance

         # # # [TEMP] CHECK ACC # #
         # prompt_embeds_pytorch = self.text_encoder.model(text_input_ids, output_hidden_states=False)
@@ -418,11 +416,12 @@ def encode_prompt(
         self,
         prompt: Union[str, List[str]],
         prompt_2: Optional[Union[str, List[str]]] = None,
-        device_ids: List[int] = None,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         max_sequence_length: int = 512,
+        device_ids_text_encoder_1: Optional[List[int]] = None,
+        device_ids_text_encoder_2: Optional[List[int]] = None
     ):
         r"""
         Encode the given prompts into text embeddings using the two text encoders (CLIP and T5).
@@ -437,8 +436,6 @@ def encode_prompt(
             prompt_2 (`str` or `List[str]`, *optional*):
                 The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                 used in all text-encoders
-            device: (`torch.device`):
-                torch device
             num_images_per_prompt (`int`):
                 number of images that should be generated per prompt
             prompt_embeds (`torch.FloatTensor`, *optional*):
@@ -447,6 +444,8 @@ def encode_prompt(
             pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                 If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            device_ids_text_encoder_1 (`List[int]`, *optional*): List of device IDs to use for the CLIP text encoder.
+            device_ids_text_encoder_2 (`List[int]`, *optional*): List of device IDs to use for the T5 text encoder.
         """

         prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -458,14 +457,15 @@ def encode_prompt(
         # We only use the pooled prompt output from the CLIPTextModel
         pooled_prompt_embeds = self._get_clip_prompt_embeds(
             prompt=prompt,
-            device_ids=device_ids,
+            device_ids=device_ids_text_encoder_1,
             num_images_per_prompt=num_images_per_prompt,
+
         )
         prompt_embeds = self._get_t5_prompt_embeds(
             prompt=prompt_2,
             num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
-            device_ids=device_ids,
+            device_ids=device_ids_text_encoder_2,
         )

         text_ids = torch.zeros(prompt_embeds.shape[1], 3)
@@ -497,7 +497,10 @@ def __call__(
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 512,
-        qpc_path: str = None,
+        device_ids_text_encoder_1: Optional[List[int]] = None,
+        device_ids_text_encoder_2: Optional[List[int]] = None,
+        device_ids_transformer: Optional[List[int]] = None,
+        device_ids_vae_decoder: Optional[List[int]] = None,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -574,6 +577,10 @@ def __call__(
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                 `._callback_tensor_inputs` attribute of your pipeline class.
             max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
+            device_ids_text_encoder_1 (`List[int]`, *optional*): List of device IDs to use for the CLIP text encoder.
+            device_ids_text_encoder_2 (`List[int]`, *optional*): List of device IDs to use for the T5 text encoder.
+            device_ids_transformer (`List[int]`, *optional*): List of device IDs to use for the Flux transformer.
+            device_ids_vae_decoder (`List[int]`, *optional*): List of device IDs to use for the VAE decoder.

         Returns:
             [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
@@ -644,6 +651,8 @@ def __call__(
             pooled_prompt_embeds=pooled_prompt_embeds,
             num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
+            device_ids_text_encoder_1=device_ids_text_encoder_1,
+            device_ids_text_encoder_2=device_ids_text_encoder_2
         )
         if do_true_cfg:
             (
@@ -657,6 +666,8 @@ def __call__(
                 pooled_prompt_embeds=negative_pooled_prompt_embeds,
                 num_images_per_prompt=num_images_per_prompt,
                 max_sequence_length=max_sequence_length,
+                device_ids_text_encoder_1=device_ids_text_encoder_1,
+                device_ids_text_encoder_2=device_ids_text_encoder_2
             )

         # 4. Prepare timesteps
@@ -680,7 +691,7 @@ def __call__(
         # 6. Denoising loop
         ###### AIC related changes of transformers ######
         if self.transformer.qpc_session is None:
-            self.transformer.qpc_session = QAICInferenceSession(str(self.trasformer_compile_path))
+            self.transformer.qpc_session = QAICInferenceSession(str(self.trasformer_compile_path), device_ids=device_ids_transformer)

         output_buffer = {
             "output": np.random.rand(
@@ -782,11 +793,8 @@ def __call__(
         latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
         latents = (latents / self.vae_decode.model.scaling_factor) + self.vae_decode.model.shift_factor

-
-        self.transformer.qpc_session.deactivate()
-
         if self.vae_decode.qpc_session is None:
-            self.vae_decode.qpc_session = QAICInferenceSession(str(self.vae_decoder_compile_path))
+            self.vae_decode.qpc_session = QAICInferenceSession(str(self.vae_decoder_compile_path), device_ids=device_ids_vae_decoder)

         output_buffer = {
             "sample": np.random.rand(
@@ -797,7 +805,6 @@ def __call__(

         inputs = {"latent_sample": latents.numpy()}
         image = self.vae_decode.qpc_session.run(inputs)
-        self.vae_decode.qpc_session.deactivate()

         ###### ACCURACY TESTING #######
         # image_torch = self.vae_decode.model(latents, return_dict=False)[0]
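
Usage note: a minimal sketch of how the per-component device ID arguments introduced in this change might be passed when invoking the pipeline. The pipeline class name, checkpoint, prompt, and step count below are illustrative assumptions and are not taken from this diff; only the device_ids_* keyword arguments and the FluxPipelineOutput return type come from the code above.

    # Hypothetical example: QEffFluxPipeline and the checkpoint name are assumptions.
    # The QPC binaries are assumed to have been compiled already; the qpc_session
    # objects are created lazily from the compile paths, as shown in the diff.
    pipe = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")

    image = pipe(
        prompt="A photo of an astronaut riding a horse",
        num_inference_steps=4,
        device_ids_text_encoder_1=[0],   # CLIP text encoder
        device_ids_text_encoder_2=[1],   # T5 text encoder
        device_ids_transformer=[2, 3],   # Flux transformer
        device_ids_vae_decoder=[0],      # VAE decoder
    ).images[0]
    image.save("flux_output.png")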