
Commit 63056ec

Improve MPT series SQ (#1640)

Signed-off-by: Wang, Chang <chang1.wang@intel.com>

1 parent a9a0e93
File tree

1 file changed: +7 -0 lines changed

intel_extension_for_transformers/transformers/modeling/modeling_auto.py
Lines changed: 7 additions & 0 deletions
@@ -840,6 +840,12 @@ def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, None]:
             or device_map == torch.device("cpu")
         ) and model.config.model_type == "chatglm":
             model = model.float()
+        if (
+            not torch.cuda.is_available()
+            or device_map == "cpu"
+            or device_map == torch.device("cpu")
+        ) and model.config.model_type == "mpt":
+            model.config.architectures = ["MptForCausalLM"]
         model.eval()
         model_type = model.config.model_type.replace("_", "-")

@@ -1077,6 +1083,7 @@ def calib_func(model):
                 recipes=quantization_config.recipes,
                 example_inputs=example_inputs,
             )
+
             model = quantization.fit(
                 model,
                 conf,
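
The substantive change is the first hunk: on a CPU-only placement, the MPT config's architectures field is pinned to transformers' native "MptForCausalLM". Below is a minimal, self-contained sketch of that gate, using transformers' MptConfig purely to obtain a config whose model_type is "mpt"; the device_map value is an assumed example, not taken from the commit.

```python
import torch
from transformers import MptConfig

# Assumed caller-supplied placement; could also be torch.device("cpu").
device_map = "cpu"

# MptConfig is used here only to get a config with model_type == "mpt"
# without downloading any model weights.
config = MptConfig()

if (
    not torch.cuda.is_available()
    or device_map == "cpu"
    or device_map == torch.device("cpu")
) and config.model_type == "mpt":
    # Pin the architectures field to the native transformers MPT class so
    # downstream smooth-quantization code resolves a consistent class name.
    config.architectures = ["MptForCausalLM"]

print(config.architectures)  # ['MptForCausalLM']
```

The likely motivation, inferred from the diff rather than stated in it, is that MPT checkpoints loaded via remote code can report a custom architecture name, and normalizing it keeps the CPU smooth-quantization path consistent.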
