File tree Expand file tree Collapse file tree 1 file changed +10
-3
lines changed
torchao/prototype/tensor_conversion Expand file tree Collapse file tree 1 file changed +10
-3
lines changed Original file line number Diff line number Diff line change @@ -124,9 +124,16 @@ def _find_tied_params(model):
124124
125125
126126def _convert_model_for_aarch64 (
127- model , * , tensor_type = "auto" , intx_packing_format = "opaque_torchao_auto"
127+ model ,
128+ * ,
129+ tensor_type = "auto" ,
130+ intx_packing_format = "opaque_torchao_auto" ,
131+ convert_tied_embedding = True ,
132+ convert_linear = True ,
128133):
129- module_name_to_tied_param = _find_tied_params (model )
134+ module_name_to_tied_param = (
135+ _find_tied_params (model ) if convert_tied_embedding else {}
136+ )
130137
131138 # Iterate through modules in model and convert IntxUnpackedToInt8Tensor tensors to Int8LutTensor
132139 for name , module in model .named_modules ():
@@ -138,7 +145,7 @@ def _convert_model_for_aarch64(
138145 print ("Skipping converting nn.Embedding {name} because it is not tied" )
139146 continue
140147
141- if not isinstance (module , nn .Linear ):
148+ if not ( convert_linear and isinstance (module , nn .Linear ) ):
142149 continue
143150
144151 weight = module .weight
You can’t perform that action at this time.
0 commit comments