@@ -625,110 +625,92 @@ def filter(self, indices):
         return self
 
 
-class LogitBiasProcessor:
-    """Process logits with logit biases."""
+class LogitBiasProcessor(LogitsProcessor):
629+ """
630+ `LogitsProcessor` creates a bias tensor from a dictionary of token IDs and their
631+ corresponding bias values. Bias are applied to the logits during each forward pass.
632+
633+ Supports token IDs provided as strings (e.g., {"9707": -100}).
634+ """
 
     def __init__(
-        self, logit_biases: Optional[dict], tokenizer: PreTrainedTokenizerBase
+        self,
+        logit_biases: dict,
+        tokenizer: PreTrainedTokenizerBase,
+        device: torch.device,
     ):
-        self.tokenizer = tokenizer
-        self.logit_biases = logit_biases or {}
+        assert logit_biases, "LogitBiasProcessor requires non-empty logit_biases"
 
-        # Pre-compute token IDs for each token string
-        self.token_id_mapping = {}
+        vocab_size = len(tokenizer)
 
-    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
-        # If no logit biases, return scores unchanged
-        if not self.logit_biases:
-            return scores
-
-        # Apply bias to the corresponding scores
-        for token_str, bias_value in self.logit_biases.items():
-            # Get token ID, either from cache or by computing it
-            if token_str not in self.token_id_mapping:
-                if token_str.isdigit():
-                    # If the token string is already a numeric ID
-                    token_id = int(token_str)
-                else:
-                    # Otherwise, use the tokenizer to get the ID
-                    tokens = self.tokenizer.encode(token_str, add_special_tokens=False)
-                    token_id = tokens[0] if tokens else -1  # Use -1 for not found
-
-                self.token_id_mapping[token_str] = token_id
-
-            token_id = self.token_id_mapping[token_str]
-
-            # Apply bias if token ID is valid
-            if 0 <= token_id < scores.size(-1):
-                scores[:, token_id] += bias_value
+        # Convert keys to integers and values to a list
+        token_ids = torch.tensor(
+            [int(k) for k in logit_biases.keys()], dtype=torch.long
+        )
+        bias_values = torch.tensor(list(logit_biases.values()), dtype=torch.float)
 
-        return scores
+        # Create a tensor and directly copy bias values at the corresponding indices
+        self.bias_tensor = torch.zeros(vocab_size, dtype=torch.float)
+        self.bias_tensor.index_put_((token_ids,), bias_values, accumulate=True)
 
-    def filter(self, indices):
-        """Keep only the logit biases for the specified indices."""
-        new_logit_biases = {
-            k: self.logit_biases[k] for k in indices if k in self.logit_biases
-        }
-        return LogitBiasProcessor(new_logit_biases, self.tokenizer)
+    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
+        # Apply bias tensor as a broadcasted addition
+        if self.bias_tensor.shape[0] != scores.shape[1]:
+            # Pad the bias tensor if it is smaller than the scores
+            self.bias_tensor = torch.nn.functional.pad(
+                self.bias_tensor, (0, scores.shape[1] - self.bias_tensor.shape[0])
+            )
+        scores.add_(self.bias_tensor.to(device=scores.device, dtype=scores.dtype))
+        return scores
 
 
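As a rough usage sketch (not part of this commit; the tokenizer checkpoint, token IDs, and bias values are illustrative assumptions, and `LogitBiasProcessor` from the patched module is assumed to be in scope), the rewritten processor is built once from a request's logit-bias map and then applied to each step's logits:

    import torch
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative tokenizer choice
    device = torch.device("cpu")

    # Keys are token IDs as strings, values are additive biases.
    processor = LogitBiasProcessor({"50256": -100.0, "1169": 5.0}, tokenizer, device)

    scores = torch.zeros((2, len(tokenizer)))           # dummy logits for a batch of 2
    scores = processor(input_ids=None, scores=scores)   # bias is added in place

Because the bias tensor is precomputed in the constructor, the per-step cost is a single broadcasted addition instead of a Python loop over the bias dictionary.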
-class HeterogeneousLogitBiasProcessor:
-    """Process logits with different logit biases for each sequence in the batch."""
+class HeterogeneousLogitBiasProcessor(LogitsProcessor):
+    """
+    Process logits with different logit biases for each sequence in the batch.
+    """
 
     def __init__(
         self,
         logit_biases: List[Optional[dict]],
         tokenizer: PreTrainedTokenizerBase,
         device: torch.device,
     ):
         self.device = device
         self.tokenizer = tokenizer
         self.logit_biases = logit_biases
-        self.batch_size = len(logit_biases)
+        self.vocab_size = len(tokenizer)
 
-        # Pre-compute token IDs for each token string
-        self.token_id_mapping = {}
+        # Create a batch_size x vocab_size bias matrix
+        self.bias_matrix = torch.zeros(
+            (len(logit_biases), self.vocab_size), dtype=torch.float, device=device
+        )
 
-        # Create a mapping of indices that have logit biases
-        self.indices_with_biases = {
-            i: bias_dict
-            for i, bias_dict in enumerate(self.logit_biases)
-            if bias_dict is not None and len(bias_dict) > 0
-        }
+        # For each logit bias dictionary, convert keys to integers and values to a list
+        for i, logit_bias in enumerate(logit_biases):
+            if not logit_bias:
+                # Skip sequences without biases (None or empty dict)
+                continue
+            token_ids = torch.tensor(
+                [int(k) for k in logit_bias.keys()], dtype=torch.long
+            ).to(device=device)
+            bias_values = torch.tensor(list(logit_bias.values()), dtype=torch.float).to(
+                device=device
+            )
+            # Copy bias values directly into row i at the corresponding token indices
+            self.bias_matrix[i].index_put_((token_ids,), bias_values, accumulate=True)
 
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
-        # If no indices have biases, return scores unchanged
-        if not self.indices_with_biases:
-            return scores
-
-        # For each index with a bias, apply the bias to the corresponding scores
-        for i, bias_dict in self.indices_with_biases.items():
-            for token_str, bias_value in bias_dict.items():
-                # Get token ID, either from cache or by computing it
-                if token_str not in self.token_id_mapping:
-                    if token_str.isdigit():
-                        # If the token string is already a numeric ID
-                        token_id = int(token_str)
-                    else:
-                        # Otherwise, use the tokenizer to get the ID
-                        tokens = self.tokenizer.encode(
-                            token_str, add_special_tokens=False
-                        )
-                        token_id = tokens[0] if tokens else -1  # Use -1 for not found
-
-                    self.token_id_mapping[token_str] = token_id
-
-                token_id = self.token_id_mapping[token_str]
-
-                # Apply bias if token ID is valid
-                if 0 <= token_id < scores.size(-1):
-                    scores[i, token_id] += bias_value
+        # Apply bias matrix as a broadcasted addition
+        if self.bias_matrix.shape[1] != scores.shape[1]:
+            # Pad the bias matrix if it is smaller than the scores
+            self.bias_matrix = torch.nn.functional.pad(
+                self.bias_matrix, (0, scores.shape[1] - self.bias_matrix.shape[1])
+            )
 
+        scores.add_(self.bias_matrix.to(device=scores.device, dtype=scores.dtype))
         return scores
 
-    def filter(self, indices: List[int]):
-        """Keep only the logit biases for the specified indices."""
+    def filter(self, indices):
+        """Keep only the logit biases for the specified indices; return None if none remain."""
         new_logit_biases = [self.logit_biases[i] for i in indices]
+        if not any(bias and len(bias) > 0 for bias in new_logit_biases):
+            return None
         return HeterogeneousLogitBiasProcessor(
             new_logit_biases, self.tokenizer, self.device
         )
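Similarly, a minimal sketch for the heterogeneous variant (again assuming the patched class is in scope and using an illustrative tokenizer and made-up token IDs); each dictionary in the list biases one sequence of the batch, and the new filter returns None once no remaining sequence carries a bias, so the caller can drop the processor entirely:

    import torch
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative tokenizer choice
    device = torch.device("cpu")

    # One bias dict per sequence in the batch; the second sequence is unbiased.
    biases = [{"50256": -100.0, "1169": 2.5}, {}]
    processor = HeterogeneousLogitBiasProcessor(biases, tokenizer, device)

    scores = torch.zeros((2, len(tokenizer)))          # dummy logits for a batch of 2
    scores = processor(input_ids=None, scores=scores)  # row 0 is biased, row 1 unchanged

    # Keeping only the unbiased sequence leaves nothing to do, so filter returns None.
    assert processor.filter([1]) is None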