@@ -223,7 +223,7 @@ def __init__(
223223 if from_text :
224224 self .out_keys += [self .text_response_key , self .token_key ]
225225 if self .return_log_probs :
226- self .out_keys += ["log_probs" ]
226+ self .out_keys += [self . log_prob_key ]
227227
228228 def forward (
229229 self ,
@@ -303,7 +303,7 @@ def _from_vllm_generate_text(self, td):
303303 ),
304304 )
305305 in_keys = [
306- "log_probs" ,
306+ self . log_prob_key ,
307307 self .token_response_key ,
308308 self .text_response_key ,
309309 self .token_key ,
@@ -394,7 +394,7 @@ def _from_vllm_logprobs_text(self, td):
394394 if isinstance (input_ids_response , list ):
395395 input_ids_response = torch .nested .nested_tensor (input_ids_response )
396396 out ["tokens_response" ] = input_ids_response
397- out ["log_probs" ] = lps
397+ out [self . log_prob_key ] = lps
398398 inputs = td .select (* self .in_keys , strict = False )
399399 if inputs .ndim < out .ndim :
400400 # This happens when n > 1
@@ -423,18 +423,19 @@ def _from_vllm_generate_tokens(self, td):
423423 ).to_padded_tensor (padding = self .padding_value )
424424 tokens_response_td .rename_key_ ("token_ids" , "tokens_response" )
425425 if self .return_log_probs :
426- tokens_response_td .rename_key_ ("logprobs" , "log_probs" )
426+ tokens_response_td .rename_key_ ("logprobs" , self . log_prob_key )
427427 if self .pad_output :
428428 padded_values = (
429429 tokens_response_td ["tokens_response" ] == self .padding_value
430430 )
431431 if padded_values .any ():
432- lps = tokens_response_td ["log_probs" ]
432+ lps = tokens_response_td [self . log_prob_key ]
433433 lps = torch .where (expand_as_right (~ padded_values , lps ), lps , 0.0 )
434- tokens_response_td ["log_probs" ] = lps
434+ tokens_response_td [self . log_prob_key ] = lps
435435 out = tokens_response_td .empty (recurse = True )
436436 out .update (
437- tokens_response_td , keys_to_update = (self .token_response_key , "log_probs" )
437+ tokens_response_td ,
438+ keys_to_update = (self .token_response_key , self .log_prob_key ),
438439 )
439440 inputs = td .select (* self .in_keys , strict = False )
440441 if inputs .ndim < out .ndim :
@@ -467,7 +468,7 @@ def _from_vllm_logprobs_tokens(self, td):
467468 padded = tokens_response == self .padding_value
468469 prompt_logprobs = torch .where (~ padded , prompt_logprobs , 0.0 )
469470 out = tokens_out ._tensordict .empty (recurse = True )
470- out .set ("log_probs" , prompt_logprobs )
471+ out .set (self . log_prob_key , prompt_logprobs )
471472 out .set (self .token_response_key , tokens_response )
472473 inputs = td .select (* self .in_keys , strict = False )
473474 if inputs .ndim < out .ndim :
@@ -501,13 +502,13 @@ def _get_output_tokens_and_log_probs(self, tokens_out):
501502 )
502503
503504 if self .return_log_probs or "logprobs" in tokens_response_td :
504- tokens_response_td .rename_key_ ("logprobs" , "log_probs" )
505+ tokens_response_td .rename_key_ ("logprobs" , self . log_prob_key )
505506 if self .pad_output :
506507 padded_values = tokens_response_td ["tokens_response" ] == padding_value
507508 if padded_values .any ():
508- lps = tokens_response_td ["log_probs" ]
509+ lps = tokens_response_td [self . log_prob_key ]
509510 lps = torch .where (expand_as_right (~ padded_values , lps ), lps , 0.0 )
510- tokens_response_td ["log_probs" ] = lps
511+ tokens_response_td [self . log_prob_key ] = lps
511512 return tokens_response_td
512513
513514 def _to_list (self , tokens , attention_mask ):
0 commit comments