@@ -87,6 +87,8 @@ def __call__(
8787 previous_chunk = None ,
8888 slot : None | int = None ,
8989 page_state : None | page_manager .PageState = None ,
90+ kv_cache : jax .Array | None = None ,
91+ attention_metadata : dict [str , Any ] | None = None ,
9092 ):
9193 cfg = self .config
9294 mesh = self .mesh
@@ -149,13 +151,15 @@ def __call__(
149151 model_mode = model_mode ,
150152 )
151153
152- attention_lnx = attention_layer (
154+ attention_lnx , kv_cache = attention_layer (
153155 lnx ,
154156 lnx ,
155157 decoder_positions ,
156158 decoder_segment_ids = decoder_segment_ids ,
157159 deterministic = deterministic ,
158160 model_mode = model_mode ,
161+ kv_cache = kv_cache ,
162+ attention_metadata = attention_metadata ,
159163 )
160164
161165 if model_mode == MODEL_MODE_PREFILL :
@@ -209,7 +213,10 @@ def __call__(
209213 jnp .sum (layer_output == 0 ) / jnp .size (layer_output ),
210214 )
211215
212- return layer_output , None if cfg .scan_layers else layer_output
216+ if cfg .scan_layers :
217+ return layer_output , None
218+ else :
219+ return layer_output , kv_cache
213220
214221
215222class SequentialBlockDecoderLayers (nn .Module ):
@@ -691,6 +698,8 @@ def __call__(
691698 bidirectional_mask : None | Any = None ,
692699 image_embeddings : None | jnp .ndarray = None ,
693700 image_masks : None | jnp .ndarray = None ,
701+ kv_caches : list [jax .Array ] | None = None ,
702+ attention_metadata = None ,
694703 ):
695704 cfg = self .config
696705 mesh = self .mesh
@@ -844,7 +853,8 @@ def __call__(
844853 # Iterate over the two layer groups (dense and MoE) and apply layer transformation
845854 for layer , num_layers , layer_prefix in zip (layers , num_layers_list , layer_prefixes ):
846855 for index in range (num_layers ):
847- y = layer (
856+ kv_cache = kv_caches [index ] if kv_caches is not None else None
857+ y , kv_cache = layer (
848858 config = cfg , mesh = mesh , name = f"{ layer_prefix } _{ index } " , quant = self .quant , model_mode = self .model_mode
849859 )(
850860 y ,
@@ -855,7 +865,11 @@ def __call__(
855865 previous_chunk = previous_chunk ,
856866 page_state = page_state ,
857867 slot = slot ,
868+ kv_cache = kv_cache ,
869+ attention_metadata = attention_metadata ,
858870 )
871+ if kv_caches is not None and kv_cache is not None :
872+ kv_caches [index ] = kv_cache
859873 else :
860874 for lyr in range (cfg .num_decoder_layers ):
861875 RemattedBlockLayer = RemattedBlockLayers [0 ]
@@ -877,7 +891,8 @@ def __call__(
877891 layer = RemattedBlockLayer (
878892 config = cfg , mesh = mesh , name = f"layers_{ lyr } " , quant = self .quant , model_mode = self .model_mode , ** layer_kwargs
879893 )
880- y = layer (
894+ kv_cache = kv_caches [lyr ] if kv_caches is not None else None
895+ y , kv_cache = layer (
881896 y ,
882897 decoder_segment_ids ,
883898 decoder_positions ,
@@ -886,8 +901,12 @@ def __call__(
886901 previous_chunk = previous_chunk ,
887902 page_state = page_state ,
888903 slot = slot ,
904+ kv_cache = kv_cache ,
905+ attention_metadata = attention_metadata ,
889906 ** layer_call_kwargs ,
890907 )
908+ if kv_caches is not None and kv_cache is not None :
909+ kv_caches [lyr ] = kv_cache
891910
892911 assert isinstance (y , jax .Array )
893912
@@ -904,7 +923,7 @@ def __call__(
904923
905924 # The API of the Decoder is now a 3-tuple: the main output (logits), the raw
906925 # hidden state needed for auxiliary tasks, and the kv_caches argument.
907- return logits , hidden_state
926+ return logits , hidden_state , kv_caches
908927
909928 def _apply_gemma3_scanned_blocks (
910929 self ,
@@ -957,10 +976,9 @@ def _apply_gemma3_scanned_blocks(
957976 if num_remaining_layers > 0 :
958977 # We name the remainder block with a 'remainder' suffix to avoid parameter name collisions
959978 rem_layer_kwargs = {"num_of_layers" : num_remaining_layers }
960- # pytype: disable=wrong-keyword-args
961979 layer = RemattedGemma3Block (
962980 config = cfg , mesh = mesh , quant = self .quant , model_mode = self .model_mode , name = "layers_remainder" , ** rem_layer_kwargs
963- )
981+ ) # pytype: disable=wrong-keyword-args
964982 y , _ = layer (
965983 y ,
966984 decoder_segment_ids ,
0 commit comments