File tree Expand file tree Collapse file tree 3 files changed +15
-5
lines changed Expand file tree Collapse file tree 3 files changed +15
-5
lines changed Original file line number Diff line number Diff line change 181181 pytest.mark.filterwarnings(
182182 "ignore:dep_util is Deprecated. Use functions from setuptools instead"
183183 ),
184+ pytest.mark.filterwarnings(
185+ "ignore:The PyTorch API of nested tensors is in prototype"
186+ ),
184187]
185188
186189
@@ -16679,9 +16682,13 @@ def test_hf(self, from_text):
1667916682 tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
1668016683 tokenizer.pad_token = tokenizer.eos_token
1668116684
16682- model = OPTForCausalLM(OPTConfig())
16685+ model = OPTForCausalLM(OPTConfig()).eval()
1668316686 policy_inference = TransformersWrapper(
16684- model, tokenizer=tokenizer, generate=True, from_text=from_text
16687+ model,
16688+ tokenizer=tokenizer,
16689+ generate=True,
16690+ from_text=from_text,
16691+ return_log_probs=True,
1668516692 )
1668616693 policy_train = TransformersWrapper(
1668716694 model, tokenizer=tokenizer, generate=False, from_text=False
Original file line number Diff line number Diff line change @@ -606,8 +606,8 @@ def _step(self, tensordict):
606606 reward = torch .tensor ([reward_val ], dtype = torch .float32 )
607607 dest .set ("reward" , reward )
608608 dest .set ("turn" , turn )
609- dest .set ("done" , [done ])
610- dest .set ("terminated" , [done ])
609+ dest .set ("done" , torch . tensor ( [done ]) )
610+ dest .set ("terminated" , torch . tensor ( [done ]) )
611611 if self .pixels :
612612 dest .set ("pixels" , self ._get_tensor_image (board = self .board ))
613613 return dest
Original file line number Diff line number Diff line change @@ -584,7 +584,10 @@ def _log_weight(
584584 self .tensor_keys .sample_log_prob ,
585585 adv_shape ,
586586 )
587-
587+ if prev_log_prob is None :
588+ raise KeyError (
589+ f"Couldn't find the log-prob { self .tensor_keys .sample_log_prob } in the input data."
590+ )
588591 if prev_log_prob .requires_grad :
589592 raise RuntimeError (
590593 f"tensordict stored { self .tensor_keys .sample_log_prob } requires grad."
You can’t perform that action at this time.
0 commit comments