This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit b1168c1

Fix SQ bloom (#1636)
Signed-off-by: changwangss <chang1.wang@intel.com>
1 parent f44bf95 commit b1168c1

File tree

3 files changed: +3 additions, −12 deletions


intel_extension_for_transformers/transformers/llm/evaluation/models.py

Lines changed: 1 addition & 2 deletions
@@ -38,8 +38,7 @@ def _reorder_cache(

         This is required to match `past_key_values` with the correct beam_idx at every generation step.
         """
-        if self.config.model_type == "bloom":
-            return self._reorder_cache_bloom(past_key_values, beam_idx)
+
         if self.config.model_type == "chatglm":
             return tuple(
                 tuple(
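With the bloom branch deleted, `_reorder_cache` now sends bloom models through the generic reordering path during beam search (the commit title suggests the SmoothQuant, "SQ", bloom path broke on the special case). A minimal sketch of the resulting control flow, with the generic-path body assumed for illustration rather than copied from the file:

import torch

def _reorder_cache(self, past_key_values, beam_idx):
    # Matches `past_key_values` with the correct beam_idx at every generation step.
    if self.config.model_type == "chatglm":
        # Assumption: chatglm caches are laid out sequence-first, so reorder dim 1.
        return tuple(
            tuple(past.index_select(1, beam_idx) for past in layer_past)
            for layer_past in past_key_values
        )
    # Generic path, now also taken by bloom: reorder the batch dimension.
    return tuple(
        tuple(past.index_select(0, beam_idx) for past in layer_past)
        for layer_past in past_key_values
    )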

intel_extension_for_transformers/transformers/modeling/modeling_auto.py

Lines changed: 1 addition & 5 deletions
@@ -946,11 +946,7 @@ def collate_batch(batch):
             )

             last_ind.append(input_ids.shape[0] - 1)
-            if model_type in ["bloom"]:
-                attention_mask = torch.ones(len(input_ids) + 1)
-                attention_mask[0] = 0
-            else:
-                attention_mask = torch.ones(len(input_ids))
+            attention_mask = torch.ones(len(input_ids))
             position_ids = torch.arange(len(input_ids))
             input_ids_padded.append(input_ids)
             attention_mask_padded.append(attention_mask)
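After this change the calibration collator builds a plain all-ones mask for every model type; bloom no longer gets a mask one token longer with its first position zeroed out. A simplified sketch of the per-sample logic (the function name and surrounding padding code are illustrative, not the file's exact contents):

import torch

def make_calibration_features(input_ids: torch.Tensor):
    # input_ids: 1-D tensor of token ids for one calibration sample.
    last_ind = input_ids.shape[0] - 1            # index of the last real token
    attention_mask = torch.ones(len(input_ids))  # uniform across model types now
    position_ids = torch.arange(len(input_ids))
    return input_ids, attention_mask, position_ids, last_ind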

intel_extension_for_transformers/transformers/utils/utility.py

Lines changed: 1 addition & 5 deletions
@@ -375,11 +375,7 @@ def get_example_inputs(model_config, batch_size=1, tokenizer=None, num_beams=4):
     past_key_values = generate_dummy_past_key_values(config=model_config, input_bs=batch_size)

     input_ids = input_ids[:, :512]
-    if model_type in ["bloom", "qwen"]:
-        attention_mask = torch.ones(input_ids.shape[0], input_ids.shape[1] + 1)
-        attention_mask[:,0] = 0
-    else:
-        attention_mask = torch.ones(input_ids.shape)
+    attention_mask = torch.ones(input_ids.shape)
     position_ids = torch.arange(input_ids.shape[1]).repeat(batch_size, 1)

     if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS:
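Likewise, the dummy inputs used for example generation now give bloom and qwen an attention mask whose shape matches input_ids exactly, instead of a mask padded by one masked-out column. A short sketch of the resulting shapes (the dummy input_ids here are made up for illustration):

import torch

batch_size = 1
input_ids = torch.randint(0, 32000, (batch_size, 600))  # hypothetical dummy ids

input_ids = input_ids[:, :512]
attention_mask = torch.ones(input_ids.shape)  # same shape for all model types
position_ids = torch.arange(input_ids.shape[1]).repeat(batch_size, 1)

assert attention_mask.shape == input_ids.shape  # (1, 512)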
