
Commit 175d956

Add left aligned cache support. (#133)
* Add left aligned cache insertion support.
* Fix ring buffer config propagation issue; fix the left aligned insert no-return issue.
* Update the generate function to support the left aligned cache.
* Fix the cache insertion connection issue, an insertion error, and related tests.
* Fix tests and lint errors.
* Fix lint issues.
1 parent a9343b9 commit 175d956

File tree (6 files changed: +113 −43 lines)

  jetstream_pt/cache_manager.py
  jetstream_pt/config.py
  jetstream_pt/engine.py
  jetstream_pt/environment.py
  tests/test_model_impl.py
  tests/test_quantization.py

jetstream_pt/cache_manager.py

Lines changed: 34 additions & 12 deletions
@@ -91,20 +91,33 @@ def __init__(
       cache_v: torch.Tensor,  # previous cache
       position: int,  # position to store the cache
       sharding,
+      env=None,
   ):
     super().__init__()
     self.cache_k = cache_k
     self.cache_v = cache_v
     self.pos = position
     self.sharding = sharding
+    self.env = env

   def update(self, key, value):
     """Update kv cache"""
     keyj, valuej = torchjax.to_torch((key, value))
-    # pylint: disable-next=all
-    self.cache_k._elem = self.cache_k._elem.at[:, :, self.pos].set(keyj)
-    # pylint: disable-next=all
-    self.cache_v._elem = self.cache_v._elem.at[:, :, self.pos].set(valuej)
+    if self.env.ring_buffer:
+      # pylint: disable-next=all
+      self.cache_k._elem = self.cache_k._elem.at[:, :, self.pos].set(keyj)
+      # pylint: disable-next=all
+      self.cache_v._elem = self.cache_v._elem.at[:, :, self.pos].set(valuej)
+    else:
+      batch = jnp.arange(self.env.batch_size)
+      # pylint: disable-next=all
+      self.cache_k._elem = self.cache_k._elem.at[batch, :, self.pos].set(
+          keyj.squeeze(2)
+      )
+      # pylint: disable-next=all
+      self.cache_v._elem = self.cache_v._elem.at[batch, :, self.pos].set(
+          valuej.squeeze(2)
+      )
     return self.cache_k, self.cache_v

   def state(self):
@@ -113,13 +126,13 @@ def state(self):
     return self.cache_k.jax(), self.cache_v.jax()

   @classmethod
-  def empty(cls, shape, device, bf16_enable):
+  def empty(cls, shape, device, bf16_enable, env):
     """Create empty kv caches"""
     default_dtype = jnp.bfloat16 if bf16_enable else jnp.float32
     k = jnp.zeros(shape, device=device, dtype=default_dtype)
     v = jnp.zeros(shape, device=device, dtype=default_dtype)
     k, v = torchjax.to_torch((k, v))
-    return cls(k, v, 0, device)
+    return cls(k, v, 0, device, env=env)


 # pylint: disable-next=all
@@ -155,6 +168,7 @@ def __init__(
       cache_v_scaler,
       input_pos,  # used to write cache
       sharding=None,
+      env=None,
   ):
     super().__init__()
     self.cache_k = cache_k
@@ -163,6 +177,7 @@ def __init__(
     self.v_scaler = cache_v_scaler
     self.input_pos = input_pos
     self.sharding = sharding
+    self.env = env

   def state(self):
     """Get kv cache state"""
@@ -174,7 +189,7 @@ def scalers(self):

   @classmethod
   # pylint: disable-next=all
-  def empty(cls, shape, device, bf16_enable):
+  def empty(cls, shape, device, bf16_enable, env):
     """Create empty kv caches"""
     cache_k = jnp.zeros(shape, device=device, dtype=jnp.int8)
     cache_v = jnp.zeros(shape, device=device, dtype=jnp.int8)
@@ -185,7 +200,7 @@ def empty(cls, shape, device, bf16_enable):
     cache_k, cache_v, kscaler, vscaler = torchjax.to_torch(
         (cache_k, cache_v, kscaler, vscaler)
     )
-    return cls(cache_k, cache_v, kscaler, vscaler, 0, device)
+    return cls(cache_k, cache_v, kscaler, vscaler, 0, device, env=env)

   def quantize(self, val):
     """Quantize value"""
@@ -198,8 +213,15 @@ def update(self, xk, xv):
     """Update kv cache"""
     k_quant, kscale = self.quantize(xk)
     v_quant, vscale = self.quantize(xv)
-    self.cache_k[:, :, self.input_pos, :] = k_quant
-    self.cache_v[:, :, self.input_pos, :] = v_quant
-    self.k_scaler[:, :, self.input_pos, :] = kscale
-    self.v_scaler[:, :, self.input_pos, :] = vscale
+    if self.env.ring_buffer:
+      self.cache_k[:, :, self.input_pos, :] = k_quant
+      self.cache_v[:, :, self.input_pos, :] = v_quant
+      self.k_scaler[:, :, self.input_pos, :] = kscale
+      self.v_scaler[:, :, self.input_pos, :] = vscale
+    else:
+      batch = jnp.arange(self.env.batch_size)
+      self.cache_k[batch, :, self.input_pos, :] = k_quant.squeeze(2)
+      self.cache_v[batch, :, self.input_pos, :] = v_quant.squeeze(2)
+      self.k_scaler[batch, :, self.input_pos, :] = kscale.squeeze(2)
+      self.v_scaler[batch, :, self.input_pos, :] = vscale.squeeze(2)
     return self.cache_k, self.cache_v, self.k_scaler, self.v_scaler
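
Note (not part of the commit): a minimal standalone sketch of the two update modes above. In ring-buffer mode every batch slot writes at the same scalar position; in left-aligned mode each slot scatters to its own position via batch indexing. All shapes and variable names below are illustrative.

import jax.numpy as jnp

batch_size, num_heads, seq_len, head_dim = 2, 1, 8, 4
cache = jnp.zeros((batch_size, num_heads, seq_len, head_dim))
new_kv = jnp.ones((batch_size, num_heads, 1, head_dim))  # one new token per slot

# Ring-buffer mode: all slots write at the same scalar position.
pos = 3
ring = cache.at[:, :, pos].set(new_kv.squeeze(2))

# Left-aligned mode: each slot writes at its own position (its input_pos).
per_slot_pos = jnp.array([2, 5])        # one position per batch slot
batch = jnp.arange(batch_size)
left = cache.at[batch, :, per_slot_pos].set(new_kv.squeeze(2))

print(ring[:, 0, :, 0])  # column 3 filled for both slots
print(left[:, 0, :, 0])  # slot 0 fills column 2, slot 1 fills column 5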

jetstream_pt/config.py

Lines changed: 7 additions & 0 deletions
@@ -83,6 +83,12 @@
     "for performance tuning and debugging only",
     required=False,
 )
+flags.DEFINE_bool(
+    "ring_buffer",
+    True,
+    "Whether to enable ring buffer",
+    required=False,
+)
 flags.DEFINE_float(
     "temperature",
     1.0,
@@ -175,6 +181,7 @@ def create_engine_from_config_flags():
       sampling_algorithm=FLAGS.sampling_algorithm,
       nucleus_topp=FLAGS.nucleus_topp,
       topk=FLAGS.topk,
+      ring_buffer=FLAGS.ring_buffer,
   )

   print("Initialize engine", time.perf_counter() - start)

jetstream_pt/engine.py

Lines changed: 52 additions & 25 deletions
@@ -65,7 +65,7 @@ class DecodeState:
       Tuple[jax.Array, jax.Array]
   ]  # only present in quantized kv
   current_position: int
-  lens: jax.Array  # [batch_size, 1]
+  lens: jax.Array  # [batch_size, 1], the output token length
   start: jax.Array  # [batch_size, 1], the starting pos for each slot
   input_pos: jax.Array  # [batch_size, 1] input pos for each slot
   mask: jax.Array  # [batch_size, seqlen] -inf for invalid; 0 for valid
@@ -157,15 +157,17 @@ def _call_model_generate(
   ):
     if self.env.quant_config.enable_kv_quantization:
       caches_obj = [
-          cache_manager.Int8KVCacheGenerate(k, v, ks, vs, input_indexes)
+          cache_manager.Int8KVCacheGenerate(
+              k, v, ks, vs, input_indexes, env=self.env
+          )
           for (k, v), (ks, vs) in torchjax.to_torch(
               list(zip(caches, cache_scales))
           )
       ]
     else:
       caches_obj = [
           cache_manager.KVCacheGenerate(
-              k, v, input_indexes, self.cache_sharding
+              k, v, input_indexes, self.cache_sharding, env=self.env
           )
           for k, v in torchjax.to_torch(caches)
       ]
@@ -295,11 +297,16 @@ def _insert_no_wrap(
   ):
     scales = []
     caches = []
-    pos = decode_state.current_position - prefix.seq_len
+    if self.env.ring_buffer:
+      current_pos = decode_state.current_position
+    else:
+      current_pos = prefix.seq_len
+
+    pos = current_pos - prefix.seq_len
     tokens = decode_state.tokens.at[slot].set(prefix.token)

     x = jnp.arange(0, self.env.cache_sequence_length)
-    cond = jnp.logical_and(x <= decode_state.current_position, x >= pos)
+    cond = jnp.logical_and(x <= current_pos, x >= pos)
     mask_insert = jnp.where(cond, 0, float("-inf"))
     mask = decode_state.mask.at[slot].set(mask_insert)
     start = decode_state.start.at[slot].set(
@@ -470,18 +477,22 @@ def insert(
     # prefix,
     # decode_state,
     # )
-    start_insert = decode_state.current_position - prefix.seq_len
-    end_insert = start_insert + prefix.caches[0][0].shape[2]  # padded seclen
-    return jax.lax.cond(
-        jnp.logical_and(
-            start_insert >= 0, end_insert < self.env.cache_sequence_length
-        ),
-        self._insert_no_wrap,
-        self._insert_wrap,
-        prefix,
-        decode_state,
-        slot,
-    )
+    if self.env.ring_buffer:
+      start_insert = decode_state.current_position - prefix.seq_len
+      end_insert = start_insert + prefix.caches[0][0].shape[2]  # padded seclen
+      return jax.lax.cond(
+          jnp.logical_and(
+              start_insert >= 0, end_insert < self.env.cache_sequence_length
+          ),
+          self._insert_no_wrap,
+          self._insert_wrap,
+          prefix,
+          decode_state,
+          slot,
+      )
+    # Left aligned, starts from 0, guaranteed no wrap
+    else:
+      return self._insert_no_wrap(prefix, decode_state, slot)

   def precompute_ragged_block_indices(self, decode_state: DecodeState):
     """Precompute the ragged attention block indices. Ragged attention iterates the grid
@@ -545,10 +556,13 @@ def generate(
   ) -> tuple[DecodeState, engine_api.ResultTokens]:
     # seq_len = padded_tokens.shape[0]
     pos = decode_state.current_position
-    input_indexes = jnp.full((1,), pos)
-
-    # fill mask first
-    mask = decode_state.mask.at[:, decode_state.current_position].set(0)
+    if self.env.ring_buffer:
+      input_indexes = jnp.full((1,), pos)
+      mask = decode_state.mask.at[:, decode_state.current_position].set(0)
+    else:
+      input_indexes = decode_state.input_pos
+      batch = jnp.arange(self.env.batch_size)
+      mask = decode_state.mask.at[batch, decode_state.input_pos].set(0)
     ragged_batch_index, ragged_block_index = (
         self.precompute_ragged_block_indices(decode_state)
     )
@@ -570,7 +584,19 @@
     )

     next_token = self._sampling(logits, self.env.batch_size)
-    lens = decode_state.lens + 1
+    if self.env.ring_buffer:
+      input_pos = decode_state.input_pos + 1
+      lens = decode_state.lens + 1
+    else:
+      input_pos = jnp.where(
+          decode_state.input_pos == 0,
+          0,
+          decode_state.input_pos + 1 % self.env.cache_len,
+      )
+      lens = jnp.where(
+          decode_state.lens == 0, 0, decode_state.lens + 1 % self.env.cache_len
+      )
+
     data = jnp.concatenate(
         [
             decode_state.tokens,
@@ -597,15 +623,14 @@
         (decode_state.current_position + 1) % self.env.cache_sequence_length,
         lens,
         decode_state.start,
-        decode_state.input_pos + 1,
+        input_pos,
         mask,
     )
     print(
         "new_pos",
         (decode_state.current_position + 1) % self.env.cache_sequence_length,
     )
-    print("cache_seq_len", self.env.cache_sequence_length)
-
+    print(f"new_token: {jnp.squeeze(next_token)}")
     return new_decode_state, result_tokens

   # pylint: disable-next=all
@@ -782,6 +807,7 @@ def create_pytorch_engine(
     sampling_algorithm="greedy",
     nucleus_topp=None,
     topk=None,
+    ring_buffer=True,
 ) -> PyTorchEngine:
   """Returns: The pytorch engine."""

@@ -851,6 +877,7 @@ def create_pytorch_engine(
       sampling_algorithm=sampling_algorithm,
       nucleus_topp=nucleus_topp,
       topk=topk,
+      ring_buffer=ring_buffer,
   )

   if shard_on_batch and sharding_config:
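
Note (not part of the commit): the insert() change keeps the traced jax.lax.cond wrap/no-wrap dispatch when the ring buffer is on, while the left-aligned path always starts at position 0 and calls the no-wrap insert directly. The toy sketch below only mirrors that branch structure; the insert helpers, cache length, and values are stand-ins.

import jax
import jax.numpy as jnp

CACHE_LEN = 16

def insert_no_wrap(start):
  return jnp.arange(8) + start                 # pretend: write a contiguous block

def insert_wrap(start):
  return (jnp.arange(8) + start) % CACHE_LEN   # pretend: write with wraparound

def insert(current_position, seq_len, ring_buffer=True):
  start_insert = current_position - seq_len
  if ring_buffer:
    # Traced branch: both sides must be jittable, selection happens at runtime.
    return jax.lax.cond(
        jnp.logical_and(start_insert >= 0, start_insert + 8 < CACHE_LEN),
        insert_no_wrap,
        insert_wrap,
        start_insert,
    )
  # Left aligned: starts from 0, guaranteed no wrap.
  return insert_no_wrap(0)

print(insert(4, 8, ring_buffer=True))   # start_insert < 0, so the wrap path runs
print(insert(4, 8, ring_buffer=False))  # always the no-wrap path from position 0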

jetstream_pt/environment.py

Lines changed: 7 additions & 2 deletions
@@ -100,6 +100,9 @@ class JetEngineEnvironmentData:
   # Starting position
   starting_position: int = 512

+  # Ring buffer
+  ring_buffer: bool = True
+
   # Variables used in token sampling
   # sampling algorithm to use ("greedy", "weighted", "neucleus", "topk")
   sampling_algorithm: str = "greedy"
@@ -120,11 +123,13 @@ class JetEngineEnvironment:
   def __init__(self, data: JetEngineEnvironmentData):
     self._data = data

+    self.batch_size = self._data.batch_size
     self.seq_len = self._data.max_input_sequence_length
     self.cache_len = self._data.cache_sequence_length
     self.ragged_mha = self._data.ragged_mha
     self.block_size = self._data.block_size
     self.starting_position = self._data.starting_position
+    self.ring_buffer = self._data.ring_buffer
     P = jax.sharding.PartitionSpec

     num_of_partitions = jax.device_count()
@@ -202,13 +207,13 @@ def make_caches_generate(self):
       if self._data.quant_config.enable_kv_quantization:
         caches.append(
             cache_manager.Int8KVCacheGenerate.empty(
-                shape, self.cache_sharding, self.bf16_enable
+                shape, self.cache_sharding, self.bf16_enable, env=self
             )
         )
       else:
         caches.append(
             cache_manager.KVCacheGenerate.empty(
-                shape, self.cache_sharding, self.bf16_enable
+                shape, self.cache_sharding, self.bf16_enable, env=self
             )
         )
     return caches
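
Note (not part of the commit): the new field can also be toggled when building the environment data directly. The field names below are the ones the constructor reads above (batch_size, max_input_sequence_length, cache_sequence_length), but treating them all as optional keyword arguments with usable defaults is an assumption, not something this diff shows.

from jetstream_pt.environment import (
    JetEngineEnvironment,
    JetEngineEnvironmentData,
)

# Assumed constructor call; only ring_buffer is introduced by this commit.
env_data = JetEngineEnvironmentData(
    batch_size=4,
    max_input_sequence_length=1024,
    cache_sequence_length=2048,
    ring_buffer=False,  # switch from ring-buffer to left-aligned cache updates
)
env = JetEngineEnvironment(env_data)
assert env.ring_buffer is False
assert env.batch_size == 4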

tests/test_model_impl.py

Lines changed: 1 addition & 1 deletion
@@ -85,7 +85,7 @@ def _make_one_cache_for_generate(self, env, pos):
         (cache_array_k, cache_array_v)
     )
     cache_decode = cache_manager.KVCacheGenerate(
-        cache_array_k, cache_array_v, pos, None
+        cache_array_k, cache_array_v, pos, None, env
     )
     return cache_decode

tests/test_quantization.py

Lines changed: 12 additions & 3 deletions
@@ -72,7 +72,10 @@ def test_kv_cache(self):
     """test kv cache quantization"""
     cache_shape = (3, 2, 100, 2)  # bs, num heads, seqlen, dim
     with jax.default_device(jax.devices("cpu")[0]):
-      cache = cache_manager.Int8KVCacheGenerate.empty(cache_shape, None, False)
+      env, _ = helpers.make_env_tiny()
+      cache = cache_manager.Int8KVCacheGenerate.empty(
+          cache_shape, None, False, env
+      )
       # seqlen is 1
       k = self._xla_tensor((3, 2, 1, 2))
       v = self._xla_tensor((3, 2, 1, 2))
@@ -101,7 +104,7 @@ def test_kv_kernel(self):

       cache_k, cache_v = torchjax.to_torch((cache_k_jax, cache_v_jax))

-      cache = cache_manager.KVCacheGenerate(cache_k, cache_v, [0], None)
+      cache = cache_manager.KVCacheGenerate(cache_k, cache_v, [0], None, env)

       # 1 is seqlen
       xq = jax.random.normal(key, (3, 2, 1, 2))
@@ -119,7 +122,13 @@
       cache_k_int, cache_k_scaler, _ = quantize_tensor(cache_k, (1, 3))
       cache_v_int, cache_v_scaler, _ = quantize_tensor(cache_v, (1, 3))
       cache_int = cache_manager.Int8KVCacheGenerate(
-          cache_k_int, cache_v_int, cache_k_scaler, cache_v_scaler, [0], None
+          cache_k_int,
+          cache_v_int,
+          cache_k_scaler,
+          cache_v_scaler,
+          [0],
+          None,
+          env,
       )
       attention_quant = layers.Int8KVAttentionKernel(env)
       int_res = attention_quant(xq, xk, xv, None, cache_int)
