@@ -625,6 +625,53 @@ def filter(self, indices):
         return self


+class LogitBiasProcessor:
+    """Process logits with logit biases."""
+
+    def __init__(
+        self, logit_biases: Optional[dict], tokenizer: PreTrainedTokenizerBase
+    ):
+        self.tokenizer = tokenizer
+        self.logit_biases = logit_biases or {}
+
+        # Lazily populated cache of token IDs for each token string
+        self.token_id_mapping = {}
+
+    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
+        # If no logit biases, return scores unchanged
+        if not self.logit_biases:
+            return scores
+
+        # Apply bias to the corresponding scores
+        for token_str, bias_value in self.logit_biases.items():
+            # Get token ID, either from cache or by computing it
+            if token_str not in self.token_id_mapping:
+                if token_str.isdigit():
+                    # The token string is already a numeric ID
+                    token_id = int(token_str)
+                else:
+                    # Otherwise, use the tokenizer to get the ID
+                    tokens = self.tokenizer.encode(token_str, add_special_tokens=False)
+                    token_id = tokens[0] if tokens else -1  # Use -1 for not found
+
+                self.token_id_mapping[token_str] = token_id
+
+            token_id = self.token_id_mapping[token_str]
+
+            # Apply bias if the token ID is valid for this vocabulary
+            if 0 <= token_id < scores.size(-1):
+                scores[:, token_id] += bias_value
+
+        return scores
+
+    def filter(self, indices):
+        """Keep only the logit biases for the specified indices."""
+        new_logit_biases = {
+            k: self.logit_biases[k] for k in indices if k in self.logit_biases
+        }
+        return LogitBiasProcessor(new_logit_biases, self.tokenizer)
+
+
 class HeterogeneousLogitBiasProcessor:
     """Process logits with different logit biases for each sequence in the batch."""

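For review, a minimal usage sketch of the new `LogitBiasProcessor` (not part of the diff). It assumes a GPT-2 tokenizer purely for illustration; the token strings and bias values are made up.

```python
# Minimal sketch, assuming LogitBiasProcessor is importable from this module
# and using a GPT-2 tokenizer for illustration; biases below are arbitrary.
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Boost the word "hello" and strongly suppress token ID 50256 (GPT-2's EOS token).
processor = LogitBiasProcessor({"hello": 5.0, "50256": -100.0}, tokenizer)

input_ids = torch.tensor([[15496]])              # arbitrary prompt token
scores = torch.zeros((1, tokenizer.vocab_size))  # dummy logits for one sequence
scores = processor(input_ids, scores)            # biases applied to the logits

# filter() keeps only the biases whose keys are listed
processor = processor.filter(["hello"])
```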