Commit 533eee5

forward and tokenize chooser use the same shape (#3196)
* forward and tokenize chooser use the same shape: concatenation and filtering happen on CPU tensors to avoid dynamic shapes on HPU

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* use HPU set seed

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
1 parent 51a0b9d commit 533eee5
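
As a reading aid, here is a minimal sketch (not part of this commit) of the pattern the commit message describes: concatenation, filtering and padding are done on CPU tensors so that the tensors reaching the HPU always have a fixed, bucketed shape, and only the result is copied host-to-device with a non-blocking copy like the _async_h2d_tensor_copy helper added in common.py below. The BUCKET_SIZE value and the pad_to_bucket helper are hypothetical, chosen only for illustration.

import torch

BUCKET_SIZE = 8  # hypothetical bucket size, for illustration only


def pad_to_bucket(lengths: torch.Tensor, bucket_size: int = BUCKET_SIZE) -> torch.Tensor:
    # Pad a CPU tensor up to the next multiple of `bucket_size` so the shape the
    # HPU graph sees stays constant from batch to batch.
    assert lengths.device.type == "cpu"
    padded_len = ((lengths.shape[0] + bucket_size - 1) // bucket_size) * bucket_size
    padded = torch.zeros(padded_len, dtype=lengths.dtype)
    padded[: lengths.shape[0]] = lengths
    return padded


# Concatenation and filtering happen on CPU, where dynamic shapes are cheap...
input_lengths_cpu = torch.cat(
    [torch.tensor([5, 7], dtype=torch.int32), torch.tensor([3], dtype=torch.int32)]
)
input_lengths_cpu = pad_to_bucket(input_lengths_cpu)

# ...and only the fixed-shape result is copied host-to-device, non-blocking
# (requires a Gaudi/HPU device with habana_frameworks installed).
input_lengths_hpu = torch.empty(
    input_lengths_cpu.shape, dtype=input_lengths_cpu.dtype, device="hpu"
)
input_lengths_hpu.copy_(input_lengths_cpu, non_blocking=True)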

File tree

6 files changed: 376 additions, 481 deletions

backends/gaudi/server/text_generation_server/layers/attention/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -3,6 +3,7 @@
     HPUPagedAttentionMetadata,
     trim_attn_metadata,
     trim_seqlen_metadata,
+    _async_h2d_tensor_copy,
 )
 
 from .hpu import (
@@ -25,4 +26,5 @@
     "HPUPagedAttentionMetadata",
     "trim_seqlen_metadata",
     "trim_attn_metadata",
+    "_async_h2d_tensor_copy",
 ]
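
Because the helper is now re-exported here, call sites can import it from the attention package together with the other attention utilities. A hedged sketch of such a call site follows; the tensor is a stand-in for whatever the CPU-side code actually builds.

import torch
from text_generation_server.layers.attention import _async_h2d_tensor_copy

# Any metadata tensor assembled on CPU (already padded to a fixed shape) can be
# moved to the HPU without blocking the host:
input_lengths_cpu = torch.tensor([5, 7, 3, 0], dtype=torch.int32)
input_lengths_hpu = _async_h2d_tensor_copy(input_lengths_cpu)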

backends/gaudi/server/text_generation_server/layers/attention/common.py

Lines changed: 9 additions & 27 deletions
@@ -75,42 +75,27 @@ def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object:
 @dataclass
 class Seqlen:
     input_lengths: torch.Tensor
-    cache_lengths: torch.Tensor
-    cu_seqlen_q: Optional[torch.Tensor]
-    cu_seqlen_k: Optional[torch.Tensor]
 
     def __init__(
         self,
         input_lengths,
-        cache_lengths,
-        cu_seqlen_q=None,
     ):
         self.input_lengths = input_lengths
-        self.cache_lengths = cache_lengths
-        device = self.input_lengths.device
-        shape = self.input_lengths.shape
-        if cu_seqlen_q is None:
-            cu_seqlen_q = torch.arange(
-                shape[0] + 1,
-                device=device,
-                dtype=torch.int32,
-            )
-        cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
-
-        # cuda graphs don't like this and this is necessary to clamp within mistral
-        # Although FA2 might not want the clamping
-        # cu_seqlen_k[0] = 0
-        total = self.input_lengths + self.cache_lengths
-        torch.cumsum(total, -1, out=cu_seqlen_k[1:])
-
-        self.cu_seqlen_q = cu_seqlen_q
-        self.cu_seqlen_k = cu_seqlen_k
 
     def clamp(self, max):
         # Flash decoding doesn't need to clamp
         return self
 
 
+def _async_h2d_tensor_copy(source, device="hpu"):
+    if source is None:
+        return None
+    assert source.device.type == "cpu", "Source tensor is not present in host memory!"
+    target = torch.empty(source.shape, dtype=source.dtype, device=device)
+    target.copy_(source, non_blocking=True)
+    return target
+
+
 def trim_seqlen_metadata(metadata: Seqlen) -> object:
     # NOTE(kzawora): To anyone working on this in the future:
     # Trimming metadata is required when using HPUGraphs.
@@ -137,9 +122,6 @@ def trim_seqlen_metadata(metadata: Seqlen) -> object:
         "TrimmedSeqlen",
         [
             "input_lengths",
-            "cache_lengths",
-            "cu_seqlen_q",
-            "cu_seqlen_k",
         ],
     )
     return attention_metadata
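
A hedged sketch of how the slimmed-down Seqlen, the new copy helper and trim_seqlen_metadata fit together; the model code touched elsewhere in this commit may wire them differently, and the tensor values here are made up.

import torch
from text_generation_server.layers.attention.common import (
    Seqlen,
    _async_h2d_tensor_copy,
    trim_seqlen_metadata,
)

# input_lengths is assembled and padded on CPU, then copied to the HPU...
input_lengths_cpu = torch.tensor([5, 7, 3, 0], dtype=torch.int32)
seqlen = Seqlen(input_lengths=_async_h2d_tensor_copy(input_lengths_cpu))

# ...and trimmed before use under HPU graphs, as the NOTE in trim_seqlen_metadata explains.
seqlen = trim_seqlen_metadata(seqlen)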
