
Commit b096c34

Author: Amit Raj

Rebased with main and fixed some issues

Signed-off-by: Amit Raj <amitraj@qti.qualcommm.com>

1 parent b969ef8 · commit b096c34

File tree: 6 files changed (+65, -50 lines)

QEfficient/base/modeling_qeff.py
QEfficient/diffusers/pipelines/flux/pipeline_flux.py
QEfficient/diffusers/pipelines/pipeline_module.py
QEfficient/diffusers/pipelines/pipeline_utils.py
QEfficient/utils/_utils.py
QEfficient/utils/hash_utils.py

QEfficient/base/modeling_qeff.py

Lines changed: 1 addition & 1 deletion

@@ -288,7 +288,7 @@ def _export(
         self.onnx_path = onnx_path
         return onnx_path

-    # @dump_qconfig
+    @dump_qconfig
     def _compile(
         self,
         onnx_path: Optional[str] = None,
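
For readers less familiar with the pattern, re-enabling the decorator means _compile is wrapped once at class-definition time and every compile call now passes through dump_qconfig. The sketch below is a generic, hypothetical stand-in for such a wrapper; it is not the actual dump_qconfig implementation, which lives elsewhere in QEfficient and is not part of this diff.

    import functools

    def dump_qconfig_sketch(func):
        """Hypothetical stand-in for QEfficient's dump_qconfig decorator."""
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            result = func(self, *args, **kwargs)
            # A real implementation would persist the compile configuration here.
            return result
        return wrapper

    class Example:
        @dump_qconfig_sketch
        def _compile(self, onnx_path=None, **compiler_options):
            return onnx_path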

QEfficient/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 1 addition & 1 deletion

@@ -601,7 +601,7 @@ def __call__(
                     continue

                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
-                temb = self.transformer.model.time_text_embed(timestep, pooled_prompt_embeds)
+                temb = self.transformer.model.time_text_embed(timestep.cpu(), pooled_prompt_embeds.cpu())

                 adaln_emb = []
                 for i in range(len(self.transformer.model.transformer_blocks)):
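
A rough illustration of the pattern applied above (the module below is a placeholder, not the real time_text_embed): calling .cpu() on the inputs guarantees they sit on the same device as a CPU-resident PyTorch module before the call, and is a no-op when they already live on the CPU.

    import torch
    import torch.nn as nn

    # Placeholder standing in for the pipeline's time/text embedding module.
    embed = nn.Linear(8, 16)  # parameters live on the CPU

    timestep = torch.tensor([0.5, 0.5])        # may originate on an accelerator
    pooled_prompt_embeds = torch.randn(2, 8)

    # Mirroring the diff: move both tensors to host memory before the call.
    temb = embed(pooled_prompt_embeds.cpu()) + timestep.cpu().unsqueeze(-1)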

QEfficient/diffusers/pipelines/pipeline_module.py

Lines changed: 58 additions & 47 deletions

@@ -6,7 +6,6 @@
 # ----------------------------------------------------------------------------

 import copy
-import hashlib

 import torch
 import torch.nn as nn
@@ -23,7 +22,6 @@
     T5ModelTransform,
 )
 from QEfficient.utils import constants
-from QEfficient.utils.cache import to_hashable


 class QEffTextEncoder(QEFFBaseModel):
@@ -57,6 +55,18 @@ def get_onnx_config(self):

         return example_inputs, dynamic_axes, output_names

+    @property
+    def get_model_config(self) -> dict:
+        """
+        Get the model configuration as a dictionary.
+
+        Returns
+        -------
+        dict
+            The configuration dictionary of the underlying HuggingFace model.
+        """
+        return self.model.config.__dict__
+
     def export(
         self,
         inputs,
@@ -76,15 +86,6 @@ def export(
     def compile(self, specializations, **compiler_options):
         self._compile(specializations=specializations, **compiler_options)

-    @property
-    def model_hash(self) -> str:
-        # Compute the hash with: model_config, continuous_batching, transforms
-        mhash = hashlib.sha256()
-        mhash.update(to_hashable(self.model.config.to_diff_dict()))
-        mhash.update(to_hashable(self._transform_names()))
-        mhash = mhash.hexdigest()[:16]
-        return mhash
-
     @property
     def model_name(self) -> str:
         mname = self.model.__class__.__name__
@@ -125,18 +126,21 @@ def export(
             export_kwargs=export_kwargs,
         )

+    @property
+    def get_model_config(self) -> dict:
+        """
+        Get the model configuration as a dictionary.
+
+        Returns
+        -------
+        dict
+            The configuration dictionary of the underlying HuggingFace model.
+        """
+        return self.model.config.__dict__
+
     def compile(self, specializations, **compiler_options):
         self._compile(specializations=specializations, **compiler_options)

-    @property
-    def model_hash(self) -> str:
-        # Compute the hash with: model_config, continuous_batching, transforms
-        mhash = hashlib.sha256()
-        mhash.update(to_hashable(dict(self.model.config)))
-        mhash.update(to_hashable(self._transform_names()))
-        mhash = mhash.hexdigest()[:16]
-        return mhash
-
     @property
     def model_name(self) -> str:
         mname = self.model.__class__.__name__
@@ -197,14 +201,16 @@ def compile(self, specializations, **compiler_options):
         self._compile(specializations=specializations, **compiler_options)

     @property
-    def model_hash(self) -> str:
-        # Compute the hash with: model_config, continuous_batching, transforms
-        mhash = hashlib.sha256()
-        mhash.update(to_hashable(dict(self.model.config)))
-        mhash.update(to_hashable(self._transform_names()))
-        mhash.update(to_hashable(self.type))
-        mhash = mhash.hexdigest()[:16]
-        return mhash
+    def get_model_config(self) -> dict:
+        """
+        Get the model configuration as a dictionary.
+
+        Returns
+        -------
+        dict
+            The configuration dictionary of the underlying HuggingFace model.
+        """
+        return self.model.config.__dict__

     @property
     def model_name(self) -> str:
@@ -250,13 +256,16 @@ def compile(self, specializations, **compiler_options):
         self._compile(specializations=specializations, **compiler_options)

     @property
-    def model_hash(self) -> str:
-        # Compute the hash with: model_config, continuous_batching, transforms
-        mhash = hashlib.sha256()
-        mhash.update(to_hashable(self.model.config.to_diff_dict()))
-        mhash.update(to_hashable(self._transform_names()))
-        mhash = mhash.hexdigest()[:16]
-        return mhash
+    def get_model_config(self) -> dict:
+        """
+        Get the model configuration as a dictionary.
+
+        Returns
+        -------
+        dict
+            The configuration dictionary of the underlying HuggingFace model.
+        """
+        return self.model.config.__dict__

     @property
     def model_name(self) -> str:
@@ -282,7 +291,8 @@ def __init__(self, model: nn.modules, use_onnx_function):
         if use_onnx_function:
            self._pytorch_transforms.append(OnnxFunctionTransform)
            model, _ = OnnxFunctionTransform.apply(model)
-        self.model = model
+        # Ensure the model and all its submodules are on CPU to avoid meta device issues
+        self.model = model.to("cpu")

     def get_onnx_config(self, batch_size=1, seq_length=256, cl=4096):
         example_inputs = {
@@ -313,6 +323,18 @@ def get_onnx_config(self, batch_size=1, seq_length=256, cl=4096):

         return example_inputs, dynamic_axes, output_names

+    @property
+    def get_model_config(self) -> dict:
+        """
+        Get the model configuration as a dictionary.
+
+        Returns
+        -------
+        dict
+            The configuration dictionary of the underlying HuggingFace model.
+        """
+        return self.model.config.__dict__
+
     def export(
         self,
         inputs,
@@ -347,17 +369,6 @@ def get_specializations(self, batch_size: int, seq_len: int, cl: int):
     def compile(self, specializations, **compiler_options):
         self._compile(specializations=specializations, **compiler_options)

-    @property
-    def model_hash(self) -> str:
-        # Compute the hash with: model_config, continuous_batching, transforms
-        mhash = hashlib.sha256()
-        dict_model_config = dict(self.model.config)
-        dict_model_config.pop("_use_default_values", None)
-        mhash.update(to_hashable(dict_model_config))
-        mhash.update(to_hashable(self._transform_names()))
-        mhash = mhash.hexdigest()[:16]
-        return mhash
-
     @property
     def model_name(self) -> str:
         mname = self.model.__class__.__name__
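
Across these wrapper classes the per-class model_hash properties are removed and replaced by a uniform get_model_config property, with hashing presumably handled by the shared utilities in QEfficient/utils (see the hash_utils.py change below). A minimal sketch of reading the new property, assuming a QEffTextEncoder is constructed directly around a HuggingFace text encoder (the constructor call below is an assumption, not taken from this diff):

    from transformers import CLIPTextModel

    from QEfficient.diffusers.pipelines.pipeline_module import QEffTextEncoder

    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
    qeff_text_encoder = QEffTextEncoder(text_encoder)  # assumed signature

    # get_model_config is a property, so it is read without parentheses and
    # returns the underlying config's __dict__.
    config_dict = qeff_text_encoder.get_model_config
    print(config_dict.get("hidden_size"))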

QEfficient/diffusers/pipelines/pipeline_utils.py

Lines changed: 1 addition & 0 deletions

@@ -50,6 +50,7 @@ def set_module_device_ids(cls):
         for module_name, module_obj in cls.modules.items():
             module_obj.device_ids = config_modules[module_name]["execute"]["device_ids"]

+
 @dataclass(frozen=True)
 class ModulePerf:
     module_name: str

QEfficient/utils/_utils.py

Lines changed: 1 addition & 1 deletion

@@ -532,7 +532,7 @@ def create_model_params(qeff_model, **kwargs) -> Dict:
     """
     model_params = copy.deepcopy(kwargs)
     model_params = {k: v for k, v in model_params.items() if k in KWARGS_INCLUSION_LIST}
-    model_params["config"] = qeff_model.model.config.to_diff_dict()
+    model_params["config"] = qeff_model.model.config
     model_params["peft_config"] = getattr(qeff_model.model, "active_peft_config", None)
     model_params["applied_transform_names"] = qeff_model._transform_names()
     return model_params

QEfficient/utils/hash_utils.py

Lines changed: 3 additions & 0 deletions

@@ -15,6 +15,9 @@
 def json_serializable(obj):
     if isinstance(obj, set):
         return sorted(obj)
+    # Handle objects with to_dict() method (e.g., transformers config objects)
+    if hasattr(obj, "to_dict") and callable(getattr(obj, "to_dict")):
+        return obj.to_dict()
     raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
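
Together with the create_model_params change above, which now stores the raw config object rather than its to_diff_dict(), this extra branch lets a transformers config object pass straight through json.dumps when json_serializable is supplied as the default hook. A small sketch, assuming the gpt2 config is available locally or can be downloaded:

    import json

    from transformers import AutoConfig

    from QEfficient.utils.hash_utils import json_serializable

    config = AutoConfig.from_pretrained("gpt2")

    # The set exercises the existing branch; the config object exercises the new
    # to_dict() branch. Without it, json.dumps would raise TypeError on the config.
    payload = {"config": config, "applied_transform_names": {"CustomOpsTransform"}}
    print(json.dumps(payload, default=json_serializable)[:120])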
