@@ -13686,7 +13686,8 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
             assert target_val.device == source_val.device, key
             if target_val.dtype == torch.long:
                 continue
-            d0 += (target_val - source_val).norm().item()
+            with torch.no_grad():
+                d0 += (target_val - source_val).norm().item()

         assert d0 > 0
         if mode == "hard":
@@ -13700,7 +13701,8 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
                     target_val = upd._targets[key]
                     if target_val.dtype == torch.long:
                         continue
-                    d1 += (target_val - source_val).norm().item()
+                    with torch.no_grad():
+                        d1 += (target_val - source_val).norm().item()

                 assert d1 == d0, i
                 assert upd.counter == i
@@ -13715,7 +13717,8 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
                 target_val = upd._targets[key]
                 if target_val.dtype == torch.long:
                     continue
-                d1 += (target_val - source_val).norm().item()
+                with torch.no_grad():
+                    d1 += (target_val - source_val).norm().item()
             assert d1 < d0

         elif mode == "soft":
@@ -13728,7 +13731,8 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
                 target_val = upd._targets[key]
                 if target_val.dtype == torch.long:
                     continue
-                d1 += (target_val - source_val).norm().item()
+                with torch.no_grad():
+                    d1 += (target_val - source_val).norm().item()
             assert d1 < d0
         with pytest.warns(UserWarning, match="already"):
             upd.init_()
@@ -13741,7 +13745,8 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
             target_val = upd._targets[key]
             if target_val.dtype == torch.long:
                 continue
-            d2 += (target_val - source_val).norm().item()
+            with torch.no_grad():
+                d2 += (target_val - source_val).norm().item()
         assert d2 < 1e-6


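Why the norm computations move under torch.no_grad(): the test only reads each distance as a Python float, and the source parameters require gradients, so without no_grad every comparison builds a throwaway autograd graph. A minimal sketch, with plain tensors standing in for the updater's source/target parameters (not the torchrl classes themselves):

import torch

# Stand-ins for a trainable source parameter and its target copy.
source = torch.nn.Parameter(torch.randn(4, 4))
target = torch.randn(4, 4)

# Without no_grad, the subtraction is tracked because `source` requires grad,
# so a small autograd graph is built just to read one scalar.
tracked = (target - source).norm()
assert tracked.requires_grad

# Under no_grad the same comparison stays out of autograd entirely.
with torch.no_grad():
    d0 = (target - source).norm().item()
assert isinstance(d0, float)
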
@@ -16668,17 +16673,17 @@ class TestPPO4LLMs:
     @pytest.mark.parametrize("from_text", [True, False])
     def test_hf(self, from_text):
         from torchrl.envs import LLMEnv, Transform
-        from torchrl.modules import from_hf_transformers
+        from torchrl.modules import TransformersWrapper
         from transformers import AutoTokenizer, OPTConfig, OPTForCausalLM

         tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
         tokenizer.pad_token = tokenizer.eos_token

         model = OPTForCausalLM(OPTConfig())
-        policy_inference = from_hf_transformers(
+        policy_inference = TransformersWrapper(
             model, tokenizer=tokenizer, generate=True, from_text=from_text
         )
-        policy_train = from_hf_transformers(
+        policy_train = TransformersWrapper(
             model, tokenizer=tokenizer, generate=False, from_text=False
         )
         for p in policy_train.parameters():
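The wrapper swap keeps the test's two-policy setup: one TransformersWrapper configured to generate new tokens (the inference policy) and one configured only to score the tokens it is given (the training policy), both built around the same model instance. A condensed sketch of that pattern, reusing the constructor arguments shown in the diff; the parameter-identity check at the end is illustrative and assumes the wrapper exposes the wrapped model's parameters, as the test's own policy_train.parameters() loop suggests:

from torchrl.modules import TransformersWrapper
from transformers import AutoTokenizer, OPTConfig, OPTForCausalLM

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
tokenizer.pad_token = tokenizer.eos_token

# A small, randomly initialized OPT model, as in the test.
model = OPTForCausalLM(OPTConfig())

# generate=True: the wrapper samples new tokens (rollout / inference policy).
policy_inference = TransformersWrapper(
    model, tokenizer=tokenizer, generate=True, from_text=True
)
# generate=False: the wrapper only computes log-probs for given tokens,
# which is what the training loss consumes.
policy_train = TransformersWrapper(
    model, tokenizer=tokenizer, generate=False, from_text=False
)

# The same `model` instance backs both wrappers, so the trainable weights
# used for training are the ones the inference policy reads.
assert {id(p) for p in policy_inference.parameters()} == {
    id(p) for p in policy_train.parameters()
}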