@@ -91,20 +91,33 @@ def __init__(
       cache_v: torch.Tensor,  # previous cache
       position: int,  # position to store the cache
       sharding,
+      env=None,
   ):
     super().__init__()
     self.cache_k = cache_k
     self.cache_v = cache_v
     self.pos = position
     self.sharding = sharding
+    self.env = env
 
   def update(self, key, value):
     """Update kv cache"""
     keyj, valuej = torchjax.to_torch((key, value))
-    # pylint: disable-next=all
-    self.cache_k._elem = self.cache_k._elem.at[:, :, self.pos].set(keyj)
-    # pylint: disable-next=all
-    self.cache_v._elem = self.cache_v._elem.at[:, :, self.pos].set(valuej)
+    if self.env.ring_buffer:
+      # pylint: disable-next=all
+      self.cache_k._elem = self.cache_k._elem.at[:, :, self.pos].set(keyj)
+      # pylint: disable-next=all
+      self.cache_v._elem = self.cache_v._elem.at[:, :, self.pos].set(valuej)
+    else:
+      batch = jnp.arange(self.env.batch_size)
+      # pylint: disable-next=all
+      self.cache_k._elem = self.cache_k._elem.at[batch, :, self.pos].set(
+          keyj.squeeze(2)
+      )
+      # pylint: disable-next=all
+      self.cache_v._elem = self.cache_v._elem.at[batch, :, self.pos].set(
+          valuej.squeeze(2)
+      )
     return self.cache_k, self.cache_v
 
   def state(self):
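Note: the two branches of `update()` differ in how the write position is interpreted. With `env.ring_buffer` set, `self.pos` is shared across the batch, so a plain slice update writes the same slot for every sequence; otherwise each sequence carries its own insert index, and pairing `jnp.arange(batch_size)` with the per-sequence positions scatters one slot per row. A minimal standalone sketch of that batched scatter, assuming a cache laid out as `(batch, heads, seqlen, head_dim)`; all shapes and values here are illustrative, not the repo's:

```python
import jax.numpy as jnp

# Hypothetical shapes for illustration only.
batch, heads, seqlen, dim = 2, 4, 16, 8
cache = jnp.zeros((batch, heads, seqlen, dim))
new_kv = jnp.ones((batch, heads, 1, dim))  # one new token per sequence
pos = jnp.array([3, 7])                    # per-sequence write positions

# Advanced indexing pairs batch i with pos[i]: row i is written at
# cache[i, :, pos[i], :], which is what the non-ring-buffer branch
# does with keyj.squeeze(2).
batch_idx = jnp.arange(batch)
cache = cache.at[batch_idx, :, pos].set(new_kv.squeeze(2))

assert bool((cache[0, :, 3] == 1).all()) and bool((cache[1, :, 7] == 1).all())
```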
@@ -113,13 +126,13 @@ def state(self):
     return self.cache_k.jax(), self.cache_v.jax()
 
   @classmethod
-  def empty(cls, shape, device, bf16_enable):
+  def empty(cls, shape, device, bf16_enable, env):
     """Create empty kv caches"""
     default_dtype = jnp.bfloat16 if bf16_enable else jnp.float32
     k = jnp.zeros(shape, device=device, dtype=default_dtype)
     v = jnp.zeros(shape, device=device, dtype=default_dtype)
     k, v = torchjax.to_torch((k, v))
-    return cls(k, v, 0, device)
+    return cls(k, v, 0, device, env=env)
 
 
 # pylint: disable-next=all
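Note: `empty()` now threads `env` through to the constructor so the cache knows the ring-buffer setting and batch size at update time. A hedged sketch of what the factory builds; the `SimpleNamespace` stand-in models only the two attributes `update()` reads, and the real env object in this repo may carry more state:

```python
from types import SimpleNamespace
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the repo's env object.
env = SimpleNamespace(ring_buffer=False, batch_size=2)

# Illustrative cache shape: (batch, heads, seqlen, head_dim).
shape = (env.batch_size, 4, 16, 8)
bf16_enable = True
default_dtype = jnp.bfloat16 if bf16_enable else jnp.float32
k = jnp.zeros(shape, device=jax.devices()[0], dtype=default_dtype)
v = jnp.zeros(shape, device=jax.devices()[0], dtype=default_dtype)
# The constructor then receives env=env, which update() later consults.
```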
@@ -155,6 +168,7 @@ def __init__(
       cache_v_scaler,
       input_pos,  # used to write cache
       sharding=None,
+      env=None,
   ):
     super().__init__()
     self.cache_k = cache_k
@@ -163,6 +177,7 @@ def __init__(
     self.v_scaler = cache_v_scaler
     self.input_pos = input_pos
     self.sharding = sharding
+    self.env = env
 
   def state(self):
     """Get kv cache state"""
@@ -174,7 +189,7 @@ def scalers(self):
 
   @classmethod
   # pylint: disable-next=all
-  def empty(cls, shape, device, bf16_enable):
+  def empty(cls, shape, device, bf16_enable, env):
     """Create empty kv caches"""
     cache_k = jnp.zeros(shape, device=device, dtype=jnp.int8)
     cache_v = jnp.zeros(shape, device=device, dtype=jnp.int8)
@@ -185,7 +200,7 @@ def empty(cls, shape, device, bf16_enable):
     cache_k, cache_v, kscaler, vscaler = torchjax.to_torch(
         (cache_k, cache_v, kscaler, vscaler)
     )
-    return cls(cache_k, cache_v, kscaler, vscaler, 0, device)
+    return cls(cache_k, cache_v, kscaler, vscaler, 0, device, env=env)
 
   def quantize(self, val):
     """Quantize value"""
@@ -198,8 +213,15 @@ def update(self, xk, xv):
     """Update kv cache"""
     k_quant, kscale = self.quantize(xk)
     v_quant, vscale = self.quantize(xv)
-    self.cache_k[:, :, self.input_pos, :] = k_quant
-    self.cache_v[:, :, self.input_pos, :] = v_quant
-    self.k_scaler[:, :, self.input_pos, :] = kscale
-    self.v_scaler[:, :, self.input_pos, :] = vscale
+    if self.env.ring_buffer:
+      self.cache_k[:, :, self.input_pos, :] = k_quant
+      self.cache_v[:, :, self.input_pos, :] = v_quant
+      self.k_scaler[:, :, self.input_pos, :] = kscale
+      self.v_scaler[:, :, self.input_pos, :] = vscale
+    else:
+      batch = jnp.arange(self.env.batch_size)
+      self.cache_k[batch, :, self.input_pos, :] = k_quant.squeeze(2)
+      self.cache_v[batch, :, self.input_pos, :] = v_quant.squeeze(2)
+      self.k_scaler[batch, :, self.input_pos, :] = kscale.squeeze(2)
+      self.v_scaler[batch, :, self.input_pos, :] = vscale.squeeze(2)
     return self.cache_k, self.cache_v, self.k_scaler, self.v_scaler
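Note: the quantized path mirrors the float path, with the extra step that `quantize` splits each value into an int8 tensor and a float scale, and both get scattered at the same position. A minimal sketch of symmetric int8 quantization consistent with what `quantize` appears to return; the reduction axis and exact rounding in the repo may differ:

```python
import jax.numpy as jnp

def quantize_int8(val, reduce_axis=-1):
  """Symmetric int8 quantization: returns (int8 values, float scale).

  The scale maps the largest magnitude along reduce_axis to 127, so
  dequantization is simply int8_values * scale.
  """
  scale = jnp.amax(jnp.abs(val), axis=reduce_axis, keepdims=True) / 127.0
  # Guard all-zero slices to avoid division by zero.
  scale = jnp.where(scale == 0, 1.0, scale)
  quantized = jnp.clip(jnp.round(val / scale), -128, 127).astype(jnp.int8)
  return quantized, scale

xk = jnp.linspace(-1.0, 1.0, 2 * 4 * 1 * 8).reshape(2, 4, 1, 8)
k_quant, kscale = quantize_int8(xk)
assert k_quant.dtype == jnp.int8
# Round-trip error is bounded by half a quantization step.
assert float(jnp.max(jnp.abs(k_quant * kscale - xk))) <= float(jnp.max(kscale))
```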