     TensorDictSequential as Seq,
     WrapModule,
 )
-from tensordict.utils import _zip_strict
+from tensordict.utils import _zip_strict, expand_as_right
 
 from torchrl.data import LLMData
 
@@ -130,6 +130,9 @@ def from_vllm(
     token_key: NestedKey = ("tokens",)
     attention_mask_key: NestedKey = ("attention_mask",)
 
     # TODO: Seq should have a return_log_prob and be of ProbabilisticTDSequential type for instance checks
     if tokenizer is None:
         tokenizer = model.get_tokenizer()
+
+    # retrieve the padding value - we use it to zero the log-probs of pad tokens (i.e. prob = 1)
+    padding_value = tokenizer(tokenizer.pad_token)["input_ids"][0]
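Note on the pad-token lookup above: calling the tokenizer on its own pad token and taking the first entry of `input_ids` recovers the pad token's integer id. A minimal standalone sketch, assuming a Hugging Face-style tokenizer (the model name is illustrative, not part of this PR):

```python
from transformers import AutoTokenizer

# illustrative model; any tokenizer exposing `pad_token` works the same way
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 defines no pad token by default

# calling the tokenizer returns a dict; "input_ids" holds the token ids
padding_value = tokenizer(tokenizer.pad_token)["input_ids"][0]
print(padding_value)  # 50256: gpt2's eos token, reused here as pad
```

The lookup has to come after the `if tokenizer is None` fallback; performed earlier, it would raise an `AttributeError` whenever no tokenizer is passed explicitly.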
@@ -264,8 +267,6 @@ def to_list(tokens, attention_mask):
         strict=True,
     )
 
-    padding_value = tokenizer(tokenizer.pad_token)["input_ids"][0]
-
     def get_output_tokens_and_log_probs(td, padding_value=padding_value):
         td["tokens_out"] = _RequestOutput_tc.from_request_output(td["tokens_out"])
         if pad_output and td.ndim and not isinstance(td, LazyStackedTensorDict):
@@ -280,10 +281,14 @@ def get_output_tokens_and_log_probs(td, padding_value=padding_value):
                 layout=torch.strided
             ).to_padded_tensor(padding=padding_value)
             tokens_response_td.rename_key_("token_ids", "tokens_response")
-            # td["tokens_response"] = outputs.token_ids
             if return_log_probs:
+                padded_values = tokens_response_td["tokens_response"] == padding_value
                 tokens_response_td.rename_key_("logprobs", "log_probs")
-                # td["log_probs"] = outputs.logprobs.unsqueeze(-1)
+                # zero the log-probs of pad tokens (prob = 1) so that padding
+                # does not contribute when log-probs are summed over the response
+                lps = tokens_response_td["log_probs"]
+                lps = torch.where(expand_as_right(~padded_values, lps), lps, 0.0)
+                tokens_response_td["log_probs"] = lps
             td.update(tokens_response_td)
         elif not generate:
             td["prompt_logprobs"] = td["tokens_out"].prompt_logprobs.unsqueeze(-1)
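For context on the masking added above: a log-prob of 0.0 corresponds to a probability of 1, so zeroed pad positions drop out of any sum of log-probs over the response. A minimal sketch of the same pattern with toy shapes (tensors are illustrative, not part of this PR):

```python
import torch
from tensordict.utils import expand_as_right

padding_value = 0
tokens_response = torch.tensor([[5, 7, 0, 0]])  # one sequence, padded to length 4
log_probs = torch.randn(1, 4, 1)                # one log-prob per token position

padded = tokens_response == padding_value       # (1, 4) boolean mask of pad slots
# expand the mask on the right so it broadcasts against log_probs' trailing dim
log_probs = torch.where(expand_as_right(~padded, log_probs), log_probs, 0.0)

assert (log_probs[0, 2:] == 0).all()  # pad positions now carry log-prob 0 (prob 1)
```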
@@ -295,7 +300,10 @@ def get_output_tokens_and_log_probs(td, padding_value=padding_value):
 
     def translate_lps(tokens_response, x):
         # we disregard the tokens from the prompt to focus on those of the response
-        return x[..., -tokens_response.shape[-1]:, :]
+        padded = tokens_response == padding_value
+        lps = x[..., -tokens_response.shape[-1]:, :]
+        # zero the log-probs of pad tokens without mutating x in place
+        return lps.masked_fill(expand_as_right(padded, lps), 0.0)
 
     module_dict["translate_lps"] = Mod(
         translate_lps,
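A quick sanity check for `translate_lps` above: it should return only the last `tokens_response.shape[-1]` positions of the log-prob tensor, with pad slots zeroed and `x` left unmodified. A toy check under those assumptions (shapes and values are illustrative):

```python
import torch
from tensordict.utils import expand_as_right

padding_value = 0

def translate_lps(tokens_response, x):
    # keep only the log-probs of the response, dropping the prompt positions
    padded = tokens_response == padding_value
    lps = x[..., -tokens_response.shape[-1]:, :]
    # zero the log-probs of pad tokens without mutating x in place
    return lps.masked_fill(expand_as_right(padded, lps), 0.0)

tokens_response = torch.tensor([[3, 4, 0]])  # response of length 3, last slot padded
x = torch.randn(1, 5, 1)                     # log-probs over prompt (2) + response (3)

lps = translate_lps(tokens_response, x)
assert lps.shape == (1, 3, 1)
assert (lps[0, 2] == 0).all()  # pad position zeroed in the returned slice
```

Returning the masked slice rather than the full `x` is what actually restricts the result to the response tokens, matching the comment inside the function.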