14 | 14 | import torch |
15 | 15 | from diffusers import FluxPipeline |
16 | 16 | from diffusers.image_processor import VaeImageProcessor |
17 | | -from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput |
18 | 17 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps # TODO |
19 | 18 | |
20 | | -from QEfficient.diffusers.pipelines.config_manager import config_manager, set_module_device_ids |
21 | | -from QEfficient.diffusers.pipelines.pipeline_utils import ( |
| 19 | +from QEfficient.diffusers.pipelines.pipeline_module import ( |
22 | 20 | QEffFluxTransformerModel, |
23 | 21 | QEffTextEncoder, |
24 | 22 | QEffVAE, |
25 | 23 | ) |
| 24 | +from QEfficient.diffusers.pipelines.pipeline_utils import QEffPipelineOutput, config_manager, set_module_device_ids |
26 | 25 | from QEfficient.generation.cloud_infer import QAICInferenceSession |
27 | 26 | |
28 | 27 | |
@@ -259,10 +258,10 @@ def _get_t5_prompt_embeds( |
259 | 258 | aic_text_input = {"input_ids": text_input_ids.numpy().astype(np.int64)} |
260 | 259 | import time |
261 | 260 | |
262 | | - start_time = time.time() |
| 261 | + start_t5_time = time.time() |
263 | 262 | prompt_embeds = torch.tensor(self.text_encoder_2.qpc_session.run(aic_text_input)["last_hidden_state"]) |
264 | | - end_time = time.time() |
265 | | - print(f"T5 Text encoder inference time: {end_time - start_time:.4f} seconds") |
| 263 | + end_t5_time = time.time() |
| 264 | + self.text_encoder_2.inference_time = end_t5_time - start_t5_time |
266 | 265 | |
267 | 266 | _, seq_len, _ = prompt_embeds.shape |
268 | 267 | # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method |
@@ -325,10 +324,11 @@ def _get_clip_prompt_embeds( |
325 | 324 | |
326 | 325 | import time |
327 | 326 | |
328 | | - start_time = time.time() |
| 327 | + # record the encode start time on the pipeline so __call__ can report end-to-end latency |
| 328 | + self._start_text_encoder_time = time.time() |
329 | 329 | aic_embeddings = self.text_encoder.qpc_session.run(aic_text_input) |
330 | | - end_time = time.time() |
331 | | - print(f"CLIP Text encoder inference time: {end_time - start_time:.4f} seconds") |
| 330 | + end_text_encoder_time = time.time() |
| 331 | + self.text_encoder.inference_time = end_text_encoder_time - self._start_text_encoder_time |
332 | 332 | prompt_embeds = torch.tensor(aic_embeddings["pooler_output"]) |
333 | 333 | |
334 | 334 | # duplicate text embeddings for each generation per prompt, using mps friendly method |
@@ -595,7 +595,7 @@ def __call__( |
595 | 595 | } |
596 | 596 | |
597 | 597 | self.transformer.qpc_session.set_buffers(output_buffer) |
598 | | - |
| 598 | + self.transformer.inference_time = [] |
599 | 599 | self.scheduler.set_begin_index(0) |
600 | 600 | with self.progress_bar(total=num_inference_steps) as progress_bar: |
601 | 601 | for i, t in enumerate(timesteps): |
@@ -650,10 +650,10 @@ def __call__( |
650 | 650 | "adaln_out": adaln_out.detach().numpy(), |
651 | 651 | } |
652 | 652 | |
653 | | - start_time = time.time() |
| 653 | + start_transformer_step_time = time.time() |
654 | 654 | outputs = self.transformer.qpc_session.run(inputs_aic) |
655 | | - end_time = time.time() |
656 | | - print(f"Transformers inference time : {end_time - start_time:.2f} seconds") |
| 655 | + end_transformer_step_time = time.time() |
| 656 | + self.transformer.inference_time.append(end_transformer_step_time - start_transformer_step_time) |
657 | 657 | |
658 | 658 | noise_pred = torch.from_numpy(outputs["output"]) |
659 | 659 | |
@@ -701,14 +701,17 @@ def __call__( |
701 | 701 | self.vae_decode.qpc_session.set_buffers(output_buffer) |
702 | 702 | |
703 | 703 | inputs = {"latent_sample": latents.numpy()} |
704 | | - start_time = time.time() |
| 704 | + start_decode_time = time.time() |
705 | 705 | image = self.vae_decode.qpc_session.run(inputs) |
706 | | - end_time = time.time() |
707 | | - print(f"Decoder Text encoder inference time: {end_time - start_time:.4f} seconds") |
| 706 | + end_decode_time = time.time() |
| 707 | + self.vae_decode.inference_time = end_decode_time - start_decode_time |
708 | 708 | image_tensor = torch.from_numpy(image["sample"]) |
709 | 709 | image = self.image_processor.postprocess(image_tensor, output_type=output_type) |
710 | 710 | |
711 | | - if not return_dict: |
712 | | - return (image,) |
| 711 | + total_time_taken = end_decode_time - self._start_text_encoder_time |
713 | 712 | |
714 | | - return FluxPipelineOutput(images=image) |
| 713 | + return QEffPipelineOutput( |
| 714 | + pipeline=self, |
| 715 | + images=image, |
| 716 | + E2E_time=total_time_taken, |
| 717 | + ) |
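
For context, here is a minimal usage sketch (not part of this diff) of how the timings recorded by this change could be read back after a generation call. The already-compiled `pipeline` object, its call arguments, and the output filename are assumptions; the per-module `inference_time` attributes and the `E2E_time` field on `QEffPipelineOutput` are the ones introduced above.

```python
# Hypothetical usage sketch: `pipeline` is assumed to be an already-compiled
# QEfficient Flux pipeline exposing the modules referenced in the diff
# (text_encoder, text_encoder_2, transformer, vae_decode).
output = pipeline("A cat holding a sign that says hello world", num_inference_steps=4)

# Per-module timings recorded by this change, in seconds.
print(f"CLIP text encoder : {pipeline.text_encoder.inference_time:.4f}")
print(f"T5 text encoder   : {pipeline.text_encoder_2.inference_time:.4f}")
print(
    f"Transformer       : {sum(pipeline.transformer.inference_time):.4f} "
    f"over {len(pipeline.transformer.inference_time)} denoising steps"
)
print(f"VAE decode        : {pipeline.vae_decode.inference_time:.4f}")
print(f"End to end        : {output.E2E_time:.4f}")

# With the default output type, `images` follows the diffusers convention of a list of PIL images.
output.images[0].save("flux_output.png")
```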