Skip to content
This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit 3f492c4

Browse files
Spycsh and VincyZhang authored
catch prepack error and fall back to torch bf16 (#1526)
Co-authored-by: VincyZhang <wenxin.zhang@intel.com>
1 parent ba199dc commit 3f492c4

File tree

2 files changed

+30
-9
lines changed

2 files changed

+30
-9
lines changed

intel_extension_for_transformers/neural_chat/models/model_utils.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -839,13 +839,27 @@ def load_model(
839839
import intel_extension_for_pytorch as intel_ipex
840840

841841
if not use_tpp:
842-
model = intel_ipex.optimize(
843-
model.eval(),
844-
dtype=torch_dtype,
845-
inplace=True,
846-
level="O1",
847-
auto_kernel_selection=True,
848-
)
842+
try:
843+
model = intel_ipex.optimize(
844+
model.eval(),
845+
dtype=torch_dtype,
846+
inplace=True,
847+
level="O1",
848+
auto_kernel_selection=True,
849+
)
850+
except AssertionError:
851+
model = intel_ipex.optimize(
852+
model.eval(),
853+
dtype=torch_dtype,
854+
inplace=True,
855+
level="O1",
856+
auto_kernel_selection=True,
857+
weights_prepack=False,
858+
)
859+
except Exception as e:
860+
logging.info(f"IPEX optimize failure! Skip IPEX.")
861+
model = model.eval()
862+
849863
if cpu_jit and (re.search("mpt-7b", model_name, re.IGNORECASE)
850864
or re.search("neural-chat-7b-v1", model_name, re.IGNORECASE)):
851865
from intel_extension_for_transformers.transformers.llm.utils.mpt_trace import \

intel_extension_for_transformers/neural_chat/pipeline/plugins/retrieval/retrieval_agent.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,15 @@ def __init__(self,
154154
import torch
155155
import intel_extension_for_pytorch as ipex
156156
if precision == "bf16" and CpuInfo().bf16:
157-
self.embeddings.client = ipex.optimize(
158-
self.embeddings.client.eval(), dtype=torch.bfloat16, inplace=True)
157+
try:
158+
self.embeddings.client = ipex.optimize(
159+
self.embeddings.client.eval(), dtype=torch.bfloat16, inplace=True)
160+
except AssertionError:
161+
self.embeddings.client = ipex.optimize(
162+
self.embeddings.client.eval(), dtype=torch.bfloat16, inplace=True, weights_prepack=False)
163+
except Exception as e:
164+
logging.info(f"IPEX optimize failure! Skip IPEX.")
165+
self.embeddings.client = self.embeddings.client.eval()
159166
elif precision == "fp32":
160167
self.embeddings.client = ipex.optimize(
161168
self.embeddings.client.eval(), dtype=torch.float32, inplace=True)

0 commit comments

Comments (0)