@@ -549,3 +549,122 @@ def update(
            ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
            v_out = torch.where((is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), v_out, ctx_v_out)
        return k_out, v_out


# This is a stopgap for now, until this code gets merged with the HybridCache class.
# We do not need to inherit from the transformers cache classes: their caches are built to run
# with PyTorch, while ours are built to run on AIC.
class QEffHybridCacheForGPTOSS:
    def __init__(self, config, batch_size, max_cache_len, sliding_window_len):
        # NOTE: `config` is currently not used by this constructor.
        self.max_cache_len = max_cache_len
        self.batch_size = batch_size
        self.sliding_window_len = sliding_window_len
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []

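    # A minimal usage sketch for this constructor (the sizes below are illustrative assumptions,
    # not values taken from a real GPT-OSS config):
    #
    #     cache = QEffHybridCacheForGPTOSS(config, batch_size=1, max_cache_len=4096, sliding_window_len=128)
    #     len(cache)  # -> 0; per-layer buffers are only created by the first `update` call for each layer
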
    @classmethod
    def from_legacy_cache(
        cls, config, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    ) -> "QEffHybridCacheForGPTOSS":
        """Converts a cache in the legacy (tuple-of-tuples) format into an equivalent
        `QEffHybridCacheForGPTOSS`. Used for backward compatibility.

        Note: despite the `Optional` annotation, `past_key_values` must contain at least two
        layers here, since the sliding-window length is read from layer 0 and the full context
        length from layer 1.
        """
        cache = cls(
            config,
            batch_size=past_key_values[0][0].shape[0],
            max_cache_len=past_key_values[1][0].shape[2],
            sliding_window_len=past_key_values[0][0].shape[2],
        )
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
                key_states, value_states = past_key_values[layer_idx]
                cache.update(key_states, value_states, layer_idx)
        return cache

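    # A hedged sketch of feeding this converter a legacy tuple-of-tuples. The shapes are
    # illustrative assumptions that only respect the layout read above: layer 0 is a
    # sliding-window layer ([batch, kv_heads, sliding_window_len, head_dim]) and layer 1 is a
    # full-attention layer ([batch, kv_heads, max_cache_len, head_dim]):
    #
    #     sliding_k = sliding_v = torch.zeros(1, 8, 128, 64)
    #     full_k = full_v = torch.zeros(1, 8, 4096, 64)
    #     legacy = ((sliding_k, sliding_v), (full_k, full_v))
    #     cache = QEffHybridCacheForGPTOSS.from_legacy_cache(config, legacy)
    #     cache.sliding_window_len, cache.max_cache_len  # -> (128, 4096)
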
    def __len__(self):
        """
        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
        to the number of layers in the model.
        """
        return len(self.key_cache)

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # TODO: deprecate this function in favor of `cache_position`
        is_empty_layer = (
            len(self.key_cache) == 0  # no cache in any layer
            or len(self.key_cache) <= layer_idx  # skipped `layer_idx` and hasn't run a layer with cache after it
            or len(self.key_cache[layer_idx]) == 0  # the layer has no cache
        )
        layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
        return layer_seq_length

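    # A small sketch of what this accessor returns (tensor shapes are illustrative assumptions):
    #
    #     cache = QEffHybridCacheForGPTOSS(config, batch_size=1, max_cache_len=4096, sliding_window_len=128)
    #     cache.get_seq_length(0)  # -> 0, no layer has a cache yet
    #     cache.update(torch.zeros(1, 8, 16, 64), torch.zeros(1, 8, 16, 64), layer_idx=0)
    #     cache.get_seq_length(0)  # -> 16, i.e. key_cache[0].shape[-2]
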
    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        """Converts the `QEffHybridCacheForGPTOSS` instance into its equivalent in the legacy cache format. Used for
        backward compatibility."""
        legacy_cache = ()
        for layer_idx in range(len(self)):
            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
        return legacy_cache

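    # A sketch of the round trip with `from_legacy_cache` (assuming the cache already holds at
    # least two layers, since `from_legacy_cache` reads shapes from layers 0 and 1):
    #
    #     legacy = cache.to_legacy_cache()          # ((k_0, v_0), (k_1, v_1), ...)
    #     restored = QEffHybridCacheForGPTOSS.from_legacy_cache(config, legacy)
    #     assert len(restored) == len(cache)
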
    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if len(self.key_cache) <= layer_idx:
            # First time this layer is seen: the incoming states become the layer's cache.
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
            k_out, v_out = key_states, value_states
        else:
            position_ids = cache_kwargs.get("position_ids")
            is_sliding_layer = cache_kwargs.get("is_sliding")
            sliding_window = cache_kwargs.get("sliding_window")
            batch_index = cache_kwargs.get("batch_index", None)  # Check and fetch batch index value from the kwargs

            # Sliding-window layers write into a ring buffer: valid positions wrap modulo the window.
            if is_sliding_layer:
                kv_position_ids = torch.where(position_ids == -1, position_ids, position_ids % sliding_window)
            else:
                kv_position_ids = position_ids

            # Scatter the new key/value states into the cache at the computed positions.
            if batch_index is not None:
                if torch.onnx.is_in_onnx_export():
                    invalid_scatter_index = torch.iinfo(torch.int32).max
                    scatter_position_ids = torch.where(kv_position_ids < 0, invalid_scatter_index, kv_position_ids)
                else:
                    scatter_position_ids = kv_position_ids
                self.key_cache[layer_idx] = CtxScatterFuncCB.apply(
                    self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states
                )
                self.value_cache[layer_idx] = CtxScatterFuncCB.apply(
                    self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states
                )
            else:
                self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
                self.value_cache[layer_idx] = CtxScatterFunc.apply(
                    self.value_cache[layer_idx], kv_position_ids, value_states
                )

            k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]

            # Original Gather: read the cached context back, masking positions beyond the current one.
            ctx_len = self.key_cache[layer_idx].shape[2]
            ctx_indices = torch.arange(ctx_len)[None, None, ...]
            gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
            invalid_mask = ctx_indices > gather_limit
            if torch.onnx.is_in_onnx_export():
                invalid_idx_value = torch.iinfo(torch.int32).max
            else:
                invalid_idx_value = 0
            ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)

            if batch_index is not None:
                k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices)
                v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices)
            else:
                k_out = CtxGatherFunc.apply(k_out, ctx_indices)
                v_out = CtxGatherFunc.apply(v_out, ctx_indices)

            # Zero out gathered values at positions beyond the current sequence length.
            v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
        return k_out, v_out
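
    # A hedged sketch of a single decode-step call on a sliding-window layer, assuming the layer's
    # cache was already created by an earlier (prefill) update. The kwarg names come from this
    # method; the tensors and sizes are illustrative only:
    #
    #     new_k = new_v = torch.zeros(1, 8, 1, 64)             # current step's key/value states
    #     cache_kwargs = {
    #         "position_ids": torch.tensor([[135]]),            # absolute position of the new token
    #         "is_sliding": True,
    #         "sliding_window": 128,                            # write index wraps to 135 % 128 = 7
    #         "batch_index": None,                              # no continuous batching in this sketch
    #     }
    #     k_out, v_out = cache.update(new_k, new_v, layer_idx=0, cache_kwargs=cache_kwargs)
    #     # v_out is zeroed (via `invalid_mask`) at context positions beyond position_ids.max().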