Skip to content

Commit a220e57

Browse files
authored
[gaudi] HuggingFaceM4/idefics2-8b issue fix (#3264)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
1 parent e07056a commit a220e57

File tree

3 files changed

+15
-4
lines changed

3 files changed

+15
-4
lines changed

backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def __init__(self, prefix: str, config, weights, layer_id):
111111
)
112112
self.num_heads = config.num_attention_heads
113113
self.hidden_size = config.hidden_size
114-
if hasattr(config, "head_dim"):
114+
if getattr(config, "head_dim", None) is not None:
115115
self.head_size = config.head_dim
116116
else:
117117
self.head_size = self.hidden_size // self.num_heads

backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,8 +1050,6 @@ def forward(
10501050
attention_mask=attention_mask_forward,
10511051
**kwargs,
10521052
)
1053-
if batch.prefill_cache_indices is not None:
1054-
batch.prefill_cache_indices = None
10551053
batch.image_grid_thw = None
10561054
batch.free_encoder_cache()
10571055
return logits, speculative_logits

backends/gaudi/server/text_generation_server/utils/debug.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
import glob
55
import time
66

7-
from optimum.habana.utils import to_gb_rounded
87
import habana_frameworks.torch as htorch
8+
import numpy as np
99

1010
START_TS = None
1111
DBG_TRACE_FILENAME = os.environ.get("DBG_TRACE_FILENAME")
@@ -14,6 +14,19 @@
1414
os.remove(f)
1515

1616

17+
def to_gb_rounded(mem: float) -> float:
    """Convert a byte count into gigabytes, rounded to two decimal places.

    Args:
        mem (float): memory in bytes

    Returns:
        float: memory in GB rounded to the second decimal
    """
    gigabytes = mem / 1024**3
    return np.round(gigabytes, 2)
28+
29+
1730
def count_hpu_graphs():
1831
return len(glob.glob(".graph_dumps/*PreGraph*"))
1932

0 commit comments

Comments
 (0)