66# ----------------------------------------------------------------------------
77
88import copy
9- import hashlib
109
1110import torch
1211import torch .nn as nn
2322 T5ModelTransform ,
2423)
2524from QEfficient .utils import constants
26- from QEfficient .utils .cache import to_hashable
2725
2826
2927class QEffTextEncoder (QEFFBaseModel ):
@@ -57,6 +55,18 @@ def get_onnx_config(self):
5755
5856 return example_inputs , dynamic_axes , output_names
5957
58+ @property
59+ def get_model_config (self ) -> dict :
60+ """
61+ Get the model configuration as a dictionary.
62+
63+ Returns
64+ -------
65+ dict
66+ The configuration dictionary of the underlying HuggingFace model.
67+ """
68+ return self .model .config .__dict__
69+
6070 def export (
6171 self ,
6272 inputs ,
@@ -76,15 +86,6 @@ def export(
7686 def compile (self , specializations , ** compiler_options ):
7787 self ._compile (specializations = specializations , ** compiler_options )
7888
79- @property
80- def model_hash (self ) -> str :
81- # Compute the hash with: model_config, continuous_batching, transforms
82- mhash = hashlib .sha256 ()
83- mhash .update (to_hashable (self .model .config .to_diff_dict ()))
84- mhash .update (to_hashable (self ._transform_names ()))
85- mhash = mhash .hexdigest ()[:16 ]
86- return mhash
87-
8889 @property
8990 def model_name (self ) -> str :
9091 mname = self .model .__class__ .__name__
@@ -125,18 +126,21 @@ def export(
125126 export_kwargs = export_kwargs ,
126127 )
127128
129+ @property
130+ def get_model_config (self ) -> dict :
131+ """
132+ Get the model configuration as a dictionary.
133+
134+ Returns
135+ -------
136+ dict
137+ The configuration dictionary of the underlying HuggingFace model.
138+ """
139+ return self .model .config .__dict__
140+
128141 def compile (self , specializations , ** compiler_options ):
129142 self ._compile (specializations = specializations , ** compiler_options )
130143
131- @property
132- def model_hash (self ) -> str :
133- # Compute the hash with: model_config, continuous_batching, transforms
134- mhash = hashlib .sha256 ()
135- mhash .update (to_hashable (dict (self .model .config )))
136- mhash .update (to_hashable (self ._transform_names ()))
137- mhash = mhash .hexdigest ()[:16 ]
138- return mhash
139-
140144 @property
141145 def model_name (self ) -> str :
142146 mname = self .model .__class__ .__name__
@@ -197,14 +201,16 @@ def compile(self, specializations, **compiler_options):
197201 self ._compile (specializations = specializations , ** compiler_options )
198202
199203 @property
200- def model_hash (self ) -> str :
201- # Compute the hash with: model_config, continuous_batching, transforms
202- mhash = hashlib .sha256 ()
203- mhash .update (to_hashable (dict (self .model .config )))
204- mhash .update (to_hashable (self ._transform_names ()))
205- mhash .update (to_hashable (self .type ))
206- mhash = mhash .hexdigest ()[:16 ]
207- return mhash
204+ def get_model_config (self ) -> dict :
205+ """
206+ Get the model configuration as a dictionary.
207+
208+ Returns
209+ -------
210+ dict
211+ The configuration dictionary of the underlying HuggingFace model.
212+ """
213+ return self .model .config .__dict__
208214
209215 @property
210216 def model_name (self ) -> str :
@@ -250,13 +256,16 @@ def compile(self, specializations, **compiler_options):
250256 self ._compile (specializations = specializations , ** compiler_options )
251257
252258 @property
253- def model_hash (self ) -> str :
254- # Compute the hash with: model_config, continuous_batching, transforms
255- mhash = hashlib .sha256 ()
256- mhash .update (to_hashable (self .model .config .to_diff_dict ()))
257- mhash .update (to_hashable (self ._transform_names ()))
258- mhash = mhash .hexdigest ()[:16 ]
259- return mhash
259+ def get_model_config (self ) -> dict :
260+ """
261+ Get the model configuration as a dictionary.
262+
263+ Returns
264+ -------
265+ dict
266+ The configuration dictionary of the underlying HuggingFace model.
267+ """
268+ return self .model .config .__dict__
260269
261270 @property
262271 def model_name (self ) -> str :
@@ -282,7 +291,8 @@ def __init__(self, model: nn.modules, use_onnx_function):
282291 if use_onnx_function :
283292 self ._pytorch_transforms .append (OnnxFunctionTransform )
284293 model , _ = OnnxFunctionTransform .apply (model )
285- self .model = model
294+ # Ensure the model and all its submodules are on CPU to avoid meta device issues
295+ self .model = model .to ("cpu" )
286296
287297 def get_onnx_config (self , batch_size = 1 , seq_length = 256 , cl = 4096 ):
288298 example_inputs = {
@@ -313,6 +323,18 @@ def get_onnx_config(self, batch_size=1, seq_length=256, cl=4096):
313323
314324 return example_inputs , dynamic_axes , output_names
315325
326+ @property
327+ def get_model_config (self ) -> dict :
328+ """
329+ Get the model configuration as a dictionary.
330+
331+ Returns
332+ -------
333+ dict
334+ The configuration dictionary of the underlying HuggingFace model.
335+ """
336+ return self .model .config .__dict__
337+
316338 def export (
317339 self ,
318340 inputs ,
@@ -347,17 +369,6 @@ def get_specializations(self, batch_size: int, seq_len: int, cl: int):
347369 def compile (self , specializations , ** compiler_options ):
348370 self ._compile (specializations = specializations , ** compiler_options )
349371
350- @property
351- def model_hash (self ) -> str :
352- # Compute the hash with: model_config, continuous_batching, transforms
353- mhash = hashlib .sha256 ()
354- dict_model_config = dict (self .model .config )
355- dict_model_config .pop ("_use_default_values" , None )
356- mhash .update (to_hashable (dict_model_config ))
357- mhash .update (to_hashable (self ._transform_names ()))
358- mhash = mhash .hexdigest ()[:16 ]
359- return mhash
360-
361372 @property
362373 def model_name (self ) -> str :
363374 mname = self .model .__class__ .__name__
0 commit comments