|
5 | 5 | # |
6 | 6 | # ---------------------------------------------------------------------------- |
7 | 7 |
|
8 | | -import logging |
9 | 8 | import os |
10 | 9 | import time |
11 | 10 | from typing import Any, Callable, Dict, List, Optional, Union |
|
28 | 27 | set_module_device_ids, |
29 | 28 | ) |
30 | 29 | from QEfficient.generation.cloud_infer import QAICInferenceSession |
31 | | - |
32 | | -# Initialize logger for this module |
33 | | -logger = logging.getLogger(__name__) |
| 30 | +from QEfficient.utils.logging_utils import logger |
34 | 31 |
|
35 | 32 |
|
36 | 33 | class QEFFFluxPipeline(FluxPipeline): |
@@ -224,18 +221,18 @@ def compile(self, compile_config: Optional[str] = None) -> None: |
224 | 221 | # Compile each module with its specific configuration |
225 | 222 | for module_name, module_obj in self.modules.items(): |
226 | 223 | module_config = self.custom_config["modules"] |
227 | | - specializations = [module_config[module_name]["specializations"]] |
| 224 | + specializations = module_config[module_name]["specializations"] |
228 | 225 | compile_kwargs = module_config[module_name]["compilation"] |
229 | 226 |
|
230 | 227 | # Set dynamic specialization values based on image dimensions |
231 | 228 | if module_name == "transformer": |
232 | | - specializations[0]["cl"] = self.cl |
| 229 | + specializations["cl"] = self.cl |
233 | 230 | elif module_name == "vae_decoder": |
234 | | - specializations[0]["latent_height"] = self.latent_height |
235 | | - specializations[0]["latent_width"] = self.latent_width |
| 231 | + specializations["latent_height"] = self.latent_height |
| 232 | + specializations["latent_width"] = self.latent_width |
236 | 233 |
|
237 | 234 | # Compile the module to QPC format |
238 | | - module_obj.compile(specializations=specializations, **compile_kwargs) |
| 235 | + module_obj.compile(specializations=[specializations], **compile_kwargs) |
239 | 236 |
|
240 | 237 | def _get_t5_prompt_embeds( |
241 | 238 | self, |
|
0 commit comments