|
6 | 6 | # ---------------------------------------------------------------------------- |
7 | 7 |
|
8 | 8 | import os |
| 9 | +from concurrent.futures import ThreadPoolExecutor, as_completed |
9 | 10 | from dataclasses import dataclass |
10 | | -from typing import List, Optional, Union |
| 11 | +from typing import Any, Dict, List, Optional, Union |
11 | 12 |
|
12 | 13 | import numpy as np |
13 | 14 | import PIL.Image |
| 15 | +from tqdm import tqdm |
14 | 16 |
|
15 | 17 | from QEfficient.utils._utils import load_json |
| 18 | +from QEfficient.utils.logging_utils import logger |
16 | 19 |
|
17 | 20 |
|
18 | 21 | def config_manager(cls, config_source: Optional[str] = None): |
@@ -51,14 +54,101 @@ def set_module_device_ids(cls): |
51 | 54 | module_obj.device_ids = config_modules[module_name]["execute"]["device_ids"] |
52 | 55 |
|
53 | 56 |
|
| 57 | +def compile_modules_parallel( |
| 58 | + modules: Dict[str, Any], |
| 59 | + config: Dict[str, Any], |
| 60 | +    specialization_updates: Optional[Dict[str, Dict[str, Any]]] = None,
| 61 | +) -> None: |
| 62 | + """ |
| 63 | + Compile multiple pipeline modules in parallel using ThreadPoolExecutor. |
| 64 | +
|
| 65 | + Args: |
| 66 | + modules: Dictionary of module_name -> module_object pairs to compile |
| 67 | + config: Configuration dictionary containing module-specific compilation settings |
| 68 | +        specialization_updates: Optional dictionary mapping module_name to specialization overrides
| 69 | +                               used to inject dynamic values (e.g., image dimensions)
| 70 | + """ |
| 71 | + |
| 72 | + def _prepare_and_compile(module_name: str, module_obj: Any) -> None: |
| 73 | + """Prepare specializations and compile a single module.""" |
| 74 | + specializations = config["modules"][module_name]["specializations"].copy() |
| 75 | + compile_kwargs = config["modules"][module_name]["compilation"] |
| 76 | + |
| 77 | + if specialization_updates and module_name in specialization_updates: |
| 78 | + specializations.update(specialization_updates[module_name]) |
| 79 | + |
| 80 | + module_obj.compile(specializations=[specializations], **compile_kwargs) |
| 81 | + |
| 82 | + # Execute compilations in parallel |
| 83 | + with ThreadPoolExecutor(max_workers=len(modules)) as executor: |
| 84 | + futures = {executor.submit(_prepare_and_compile, name, obj): name for name, obj in modules.items()} |
| 85 | + |
| 86 | + with tqdm(total=len(futures), desc="Compiling modules", unit="module") as pbar: |
| 87 | + for future in as_completed(futures): |
| 88 | + try: |
| 89 | + future.result() |
| 90 | + except Exception as e: |
| 91 | + logger.error(f"Compilation failed for {futures[future]}: {e}") |
| 92 | + raise |
| 93 | + pbar.update(1) |
| 94 | + |
| 95 | + |
| 96 | +def compile_modules_sequential( |
| 97 | + modules: Dict[str, Any], |
| 98 | + config: Dict[str, Any], |
| 99 | +    specialization_updates: Optional[Dict[str, Dict[str, Any]]] = None,
| 100 | +) -> None: |
| 101 | + """ |
| 102 | + Compile multiple pipeline modules sequentially. |
| 103 | +
|
| 104 | + This function provides a generic way to compile diffusion pipeline modules |
| 105 | + sequentially, which is the default behavior for backward compatibility. |
| 106 | +
|
| 107 | + Args: |
| 108 | + modules: Dictionary of module_name -> module_object pairs to compile |
| 109 | + config: Configuration dictionary containing module-specific compilation settings |
| 110 | +        specialization_updates: Optional dictionary mapping module_name to specialization overrides
| 111 | +                               used to inject dynamic values (e.g., image dimensions)
| 112 | +
|
| 113 | + """ |
| 114 | + for module_name, module_obj in tqdm(modules.items(), desc="Compiling modules", unit="module"): |
| 115 | + module_config = config["modules"] |
| 116 | + specializations = module_config[module_name]["specializations"].copy() |
| 117 | + compile_kwargs = module_config[module_name]["compilation"] |
| 118 | + |
| 119 | + # Apply dynamic specialization updates if provided |
| 120 | + if specialization_updates and module_name in specialization_updates: |
| 121 | + specializations.update(specialization_updates[module_name]) |
| 122 | + |
| 123 | + # Compile the module to QPC format |
| 124 | + module_obj.compile(specializations=[specializations], **compile_kwargs) |
| 125 | + |
| 126 | + |
54 | 127 | @dataclass(frozen=True) |
55 | 128 | class ModulePerf: |
| 129 | + """ |
| 130 | + Data class to store performance metrics for a pipeline module. |
| 131 | +
|
| 132 | + Attributes: |
| 133 | + module_name: Name of the pipeline module (e.g., 'text_encoder', 'transformer', 'vae_decoder') |
| 134 | + perf: Performance metric in seconds. Can be a single float for modules that run once, |
| 135 | + or a list of floats for modules that run multiple times (e.g., transformer steps) |
| 136 | + """ |
| 137 | + |
56 | 138 | module_name: str |
57 | 139 | perf: int |
58 | 140 |
|
59 | 141 |
|
60 | 142 | @dataclass(frozen=True) |
61 | 143 | class QEffPipelineOutput: |
| 144 | + """ |
| 145 | + Data class to store the output of a QEfficient diffusion pipeline. |
| 146 | +
|
| 147 | + Attributes: |
| 148 | + pipeline_module: List of ModulePerf objects containing performance metrics for each module |
| 149 | + images: Generated images as either a list of PIL Images or numpy array |
| 150 | + """ |
| 151 | + |
62 | 152 | pipeline_module: list[ModulePerf] |
63 | 153 | images: Union[List[PIL.Image.Image], np.ndarray] |
64 | 154 |
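
For context, here is a minimal usage sketch of the helpers added in this diff. It assumes the code runs in the same module as `compile_modules_parallel` / `compile_modules_sequential`, and that each pipeline module exposes a `.compile(specializations=[...], **kwargs)` method as the helpers require; the stub class, config keys, and compile kwargs below are illustrative placeholders, not QEfficient's actual configuration schema.

```python
from typing import Any, Dict, List


class _StubModule:
    """Hypothetical stand-in for a QEfficient pipeline module (text encoder, transformer, ...)."""

    def __init__(self, name: str) -> None:
        self.name = name

    def compile(self, specializations: List[Dict[str, Any]], **compile_kwargs: Any) -> None:
        # A real module would export and compile to QPC here; the stub just records the call.
        print(f"compiling {self.name}: specializations={specializations}, kwargs={compile_kwargs}")


# Config mirroring the keys the helpers read: config["modules"][name]["specializations"] / ["compilation"]
config = {
    "modules": {
        "text_encoder": {
            "specializations": {"batch_size": 1, "seq_len": 77},
            "compilation": {"num_cores": 16},
        },
        "transformer": {
            "specializations": {"batch_size": 1},
            "compilation": {"num_cores": 16},
        },
    }
}

modules = {name: _StubModule(name) for name in config["modules"]}

# Dynamic values (e.g., image dimensions) merged into a module's specializations before compiling
specialization_updates = {"transformer": {"height": 1024, "width": 1024}}

compile_modules_parallel(modules, config, specialization_updates)
# compile_modules_sequential(modules, config, specialization_updates)  # same signature, one module at a time
```

Because both helpers read the same `config["modules"]` layout and call `module.compile()` with identical arguments, they are interchangeable; the parallel variant simply overlaps the per-module compilations in a thread pool and re-raises the first failure.
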
|
|