
Commit 19e0ccd

Author: Amit Raj

Parallel compilation and ONNX subfunction support added

Signed-off-by: Amit Raj <amitraj@qti.qualcomm.com>

1 parent 03f9ded commit 19e0ccd

File tree

5 files changed: +127 -24 lines changed

QEfficient/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 22 additions & 16 deletions

```diff
@@ -24,6 +24,8 @@
 from QEfficient.diffusers.pipelines.pipeline_utils import (
     ModulePerf,
     QEffPipelineOutput,
+    compile_modules_parallel,
+    compile_modules_sequential,
     config_manager,
     set_module_device_ids,
 )
@@ -192,7 +194,7 @@ def get_default_config_path() -> str:
         """
         return os.path.join(os.path.dirname(__file__), "flux_config.json")
 
-    def compile(self, compile_config: Optional[str] = None) -> None:
+    def compile(self, compile_config: Optional[str] = None, parallel: bool = False) -> None:
         """
         Compile ONNX models for deployment on Qualcomm AI hardware.
 
@@ -202,6 +204,8 @@ def compile(self, compile_config: Optional[str] = None) -> None:
         Args:
             compile_config (str, optional): Path to JSON configuration file.
                 If None, uses default configuration.
+            parallel (bool): If True, compile modules in parallel using ThreadPoolExecutor.
+                If False, compile sequentially (default: False).
         """
         # Ensure all modules are exported to ONNX before compilation
         if any(
@@ -219,21 +223,20 @@ def compile(self, compile_config: Optional[str] = None) -> None:
         if self.custom_config is None:
             config_manager(self, config_source=compile_config)
 
-        # Compile each module with its specific configuration
-        for module_name, module_obj in tqdm(self.modules.items(), desc="Compiling modules", unit="module"):
-            module_config = self.custom_config["modules"]
-            specializations = module_config[module_name]["specializations"]
-            compile_kwargs = module_config[module_name]["compilation"]
-
-            # Set dynamic specialization values based on image dimensions
-            if module_name == "transformer":
-                specializations["cl"] = self.cl
-            elif module_name == "vae_decoder":
-                specializations["latent_height"] = self.latent_height
-                specializations["latent_width"] = self.latent_width
+        # Prepare dynamic specialization updates based on image dimensions
+        specialization_updates = {
+            "transformer": {"cl": self.cl},
+            "vae_decoder": {
+                "latent_height": self.latent_height,
+                "latent_width": self.latent_width,
+            },
+        }
 
-            # Compile the module to QPC format
-            module_obj.compile(specializations=[specializations], **compile_kwargs)
+        # Use generic utility functions for compilation
+        if parallel:
+            compile_modules_parallel(self.modules, self.custom_config, specialization_updates)
+        else:
+            compile_modules_sequential(self.modules, self.custom_config, specialization_updates)
 
     def _get_t5_prompt_embeds(
         self,
@@ -467,6 +470,7 @@ def __call__(
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 512,
         custom_config_path: Optional[str] = None,
+        parallel_compile: bool = False,
     ):
         """
         Generate images from text prompts using the Flux pipeline.
@@ -501,6 +505,8 @@ def __call__(
             callback_on_step_end_tensor_inputs (List[str]): Tensors to pass to callback
             max_sequence_length (int): Maximum sequence length for T5 (default: 512)
            custom_config_path (str, optional): Path to custom compilation config
+            parallel_compile (bool): If True, compile modules in parallel for faster compilation.
+                If False, compile sequentially (default: False).
 
         Returns:
             QEffPipelineOutput or tuple: Generated images and performance metrics
@@ -512,7 +518,7 @@ def __call__(
         config_manager(self, custom_config_path)
         set_module_device_ids(self)
 
-        self.compile(compile_config=custom_config_path)
+        self.compile(compile_config=custom_config_path, parallel=parallel_compile)
 
         # Validate all inputs
         self.check_inputs(
```
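Taken together, the new `parallel` plumbing is driven from `__call__`. A minimal usage sketch (the model ID and call pattern come from the example script updated in this commit; the import path and prompt are assumptions):

```python
import torch

from QEfficient import QEFFFluxPipeline  # import path assumed

pipeline = QEFFFluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", height=256, width=256
)

# parallel_compile=True is forwarded as compile(parallel=True), which
# dispatches to compile_modules_parallel() instead of the sequential loop.
output = pipeline(
    "A cat holding a sign that says hello world",  # illustrative prompt
    num_inference_steps=4,
    generator=torch.manual_seed(42),
    parallel_compile=True,
)
```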

QEfficient/diffusers/pipelines/pipeline_module.py

Lines changed: 7 additions & 3 deletions

```diff
@@ -374,7 +374,7 @@ class QEffFluxTransformerModel(QEFFBaseModel):
         _onnx_transforms (List): ONNX transformations applied after export
     """
 
-    _pytorch_transforms = [AttentionTransform, CustomOpsTransform, NormalizationTransform]
+    _pytorch_transforms = [AttentionTransform, NormalizationTransform, CustomOpsTransform]
     _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
 
     def __init__(self, model: nn.Module, use_onnx_function: bool) -> None:
@@ -386,13 +386,17 @@ def __init__(self, model: nn.Module, use_onnx_function: bool) -> None:
             use_onnx_function (bool): Whether to export transformer blocks as ONNX functions
                 for better modularity and potential optimization
         """
-        super().__init__(model)
 
         # Optionally apply ONNX function transform for modular export
+
         if use_onnx_function:
-            self._pytorch_transforms.append(OnnxFunctionTransform)
             model, _ = OnnxFunctionTransform.apply(model)
 
+        super().__init__(model)
+
+        if use_onnx_function:
+            self._pytorch_transforms.append(OnnxFunctionTransform)
+
         # Ensure model is on CPU to avoid meta device issues
         self.model = model.to("cpu")
```

QEfficient/diffusers/pipelines/pipeline_utils.py

Lines changed: 91 additions & 1 deletion

```diff
@@ -6,13 +6,16 @@
 # ----------------------------------------------------------------------------
 
 import os
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from dataclasses import dataclass
-from typing import List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
 import PIL.Image
+from tqdm import tqdm
 
 from QEfficient.utils._utils import load_json
+from QEfficient.utils.logging_utils import logger
 
 
 def config_manager(cls, config_source: Optional[str] = None):
@@ -51,14 +54,101 @@ def set_module_device_ids(cls):
         module_obj.device_ids = config_modules[module_name]["execute"]["device_ids"]
 
 
+def compile_modules_parallel(
+    modules: Dict[str, Any],
+    config: Dict[str, Any],
+    specialization_updates: Optional[Dict[str, Dict[str, Any]]] = None,
+) -> None:
+    """
+    Compile multiple pipeline modules in parallel using ThreadPoolExecutor.
+
+    Args:
+        modules: Dictionary of module_name -> module_object pairs to compile
+        config: Configuration dictionary containing module-specific compilation settings
+        specialization_updates: Optional dictionary of module_name -> specialization updates
+            to apply dynamic values (e.g., image dimensions)
+    """
+
+    def _prepare_and_compile(module_name: str, module_obj: Any) -> None:
+        """Prepare specializations and compile a single module."""
+        specializations = config["modules"][module_name]["specializations"].copy()
+        compile_kwargs = config["modules"][module_name]["compilation"]
+
+        if specialization_updates and module_name in specialization_updates:
+            specializations.update(specialization_updates[module_name])
+
+        module_obj.compile(specializations=[specializations], **compile_kwargs)
+
+    # Execute compilations in parallel
+    with ThreadPoolExecutor(max_workers=len(modules)) as executor:
+        futures = {executor.submit(_prepare_and_compile, name, obj): name for name, obj in modules.items()}
+
+        with tqdm(total=len(futures), desc="Compiling modules", unit="module") as pbar:
+            for future in as_completed(futures):
+                try:
+                    future.result()
+                except Exception as e:
+                    logger.error(f"Compilation failed for {futures[future]}: {e}")
+                    raise
+                pbar.update(1)
+
+
+def compile_modules_sequential(
+    modules: Dict[str, Any],
+    config: Dict[str, Any],
+    specialization_updates: Optional[Dict[str, Dict[str, Any]]] = None,
+) -> None:
+    """
+    Compile multiple pipeline modules sequentially.
+
+    This function provides a generic way to compile diffusion pipeline modules
+    sequentially, which is the default behavior for backward compatibility.
+
+    Args:
+        modules: Dictionary of module_name -> module_object pairs to compile
+        config: Configuration dictionary containing module-specific compilation settings
+        specialization_updates: Optional dictionary of module_name -> specialization updates
+            to apply dynamic values (e.g., image dimensions)
+    """
+    for module_name, module_obj in tqdm(modules.items(), desc="Compiling modules", unit="module"):
+        module_config = config["modules"]
+        specializations = module_config[module_name]["specializations"].copy()
+        compile_kwargs = module_config[module_name]["compilation"]
+
+        # Apply dynamic specialization updates if provided
+        if specialization_updates and module_name in specialization_updates:
+            specializations.update(specialization_updates[module_name])
+
+        # Compile the module to QPC format
+        module_obj.compile(specializations=[specializations], **compile_kwargs)
+
+
 @dataclass(frozen=True)
 class ModulePerf:
+    """
+    Data class to store performance metrics for a pipeline module.
+
+    Attributes:
+        module_name: Name of the pipeline module (e.g., 'text_encoder', 'transformer', 'vae_decoder')
+        perf: Performance metric in seconds. Can be a single float for modules that run once,
+            or a list of floats for modules that run multiple times (e.g., transformer steps)
+    """
+
     module_name: str
     perf: int
 
 
 @dataclass(frozen=True)
 class QEffPipelineOutput:
+    """
+    Data class to store the output of a QEfficient diffusion pipeline.
+
+    Attributes:
+        pipeline_module: List of ModulePerf objects containing performance metrics for each module
+        images: Generated images as either a list of PIL Images or a numpy array
+    """
+
     pipeline_module: list[ModulePerf]
     images: Union[List[PIL.Image.Image], np.ndarray]
```
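A hypothetical call sketch for the two new helpers (the config values are placeholders; in the pipeline they come from `config_manager`, and `pipeline.modules` maps names to objects exposing `.compile()`). Note that `ThreadPoolExecutor` only buys real concurrency if each `.compile()` call spends most of its time outside the Python GIL, e.g. waiting on an external compiler process:

```python
from QEfficient.diffusers.pipelines.pipeline_utils import (
    compile_modules_parallel,
    compile_modules_sequential,
)

# Placeholder config mirroring the layout the helpers index into.
config = {
    "modules": {
        "transformer": {
            "specializations": {"batch_size": 1},
            "compilation": {"num_cores": 16},  # illustrative compile kwargs
        },
        "vae_decoder": {
            "specializations": {"batch_size": 1},
            "compilation": {"num_cores": 16},
        },
    },
}

# Dynamic values merged into a module's specializations before compiling.
specialization_updates = {"transformer": {"cl": 4096}}

compile_modules_parallel(pipeline.modules, config, specialization_updates)
# or, for the default behavior:
# compile_modules_sequential(pipeline.modules, config, specialization_updates)
```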

QEfficient/utils/_utils.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -568,7 +568,8 @@ def wrapper(self, *args, **kwargs):
             model_params=self.hash_params,
             output_names=all_args.get("output_names"),
             dynamic_axes=all_args.get("dynamic_axes"),
-            export_kwargs=all_args.get("export_kwargs", None),
+            # TODO: Re-enable export_kwargs hashing before merging this PR
+            # export_kwargs=all_args.get("export_kwargs", None),
             onnx_transform_kwargs=all_args.get("onnx_transform_kwargs", None),
         )
         export_dir = export_dir.with_name(export_dir.name + "-" + export_hash)
```
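The TODO above affects export caching: the export directory name embeds a hash of the export arguments, so a field dropped from the hashed dict no longer differentiates cache directories. A self-contained sketch with a hypothetical hash helper (not the library's actual implementation):

```python
import hashlib
import json


def export_hash(params: dict) -> str:
    """Hypothetical stand-in for the library's export-argument hashing."""
    blob = json.dumps(params, sort_keys=True, default=str).encode()
    return hashlib.sha256(blob).hexdigest()[:16]


base = {"output_names": ["latents"], "dynamic_axes": {"x": {0: "batch"}}}
a = {**base, "export_kwargs": {"opset_version": 17}}
b = {**base, "export_kwargs": {"opset_version": 13}}

# Hashing export_kwargs keeps the two exports in separate directories:
print(export_hash(a) != export_hash(b))  # True

# Dropping the field (as the TODO does) collapses them into one:
del a["export_kwargs"]
del b["export_kwargs"]
print(export_hash(a) == export_hash(b))  # True
```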

examples/diffusers/flux/flux_1_shnell_custom.py

Lines changed: 5 additions & 3 deletions

```diff
@@ -35,10 +35,11 @@
 # Note: Smaller dimensions = faster generation but lower resolution
 
 # Option 1: Basic initialization with custom image dimensions
+# NOTE: use_onnx_function=True enables modular ONNX export optimizations (experimental, so not currently recommended).
+# This feature improves export performance by breaking down the model into smaller,
+# more manageable ONNX functions, which can lead to better compilation and runtime efficiency.
 pipeline = QEFFFluxPipeline.from_pretrained(
-    "black-forest-labs/FLUX.1-schnell",
-    height=512,
-    width=512,
+    "black-forest-labs/FLUX.1-schnell", height=256, width=256, use_onnx_function=False
 )
 
 # Option 2: Advanced initialization with custom modules
@@ -109,6 +110,7 @@
     num_inference_steps=4,
     max_sequence_length=256,
     generator=torch.manual_seed(42),
+    parallel_compile=True,
 )
 
 images = output.images[0]
```
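Continuing the example, one plausible way to consume the output, given the `QEffPipelineOutput` fields documented above (the save path is arbitrary, and `.save()` assumes the PIL branch of the `images` type):

```python
# Save the first generated image (PIL.Image branch of the output type).
images.save("flux_schnell_output.png")

# Per-module performance metrics recorded by the pipeline.
for module_perf in output.pipeline_module:
    print(module_perf.module_name, module_perf.perf)
```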
