Skip to content

Commit 5364dff

Browse files
committed
Update
[ghstack-poisoned]
1 parent f21e492 commit 5364dff

File tree

3 files changed

+45
-5
lines changed

3 files changed

+45
-5
lines changed

test/test_configs.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -835,7 +835,10 @@ def test_tensor_dict_module_config(self):
835835
in_keys=["observation"],
836836
out_keys=["action"],
837837
)
838-
assert cfg._target_ == "tensordict.nn.TensorDictModule"
838+
assert (
839+
cfg._target_
840+
== "torchrl.trainers.algorithms.configs.modules._make_tensordict_module"
841+
)
839842
assert cfg.module._target_ == "torchrl.modules.MLP"
840843
assert cfg.in_keys == ["observation"]
841844
assert cfg.out_keys == ["action"]

torchrl/trainers/algorithms/configs/modules.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ class TensorDictModuleConfig(ModelConfig):
222222
"""
223223

224224
module: MLPConfig = MISSING
225-
_target_: str = "tensordict.nn.TensorDictModule"
225+
_target_: str = "torchrl.trainers.algorithms.configs.modules._make_tensordict_module"
226226
_partial_: bool = False
227227

228228
def __post_init__(self) -> None:
@@ -292,6 +292,30 @@ def __post_init__(self) -> None:
292292
super().__post_init__()
293293

294294

295+
def _make_tensordict_module(*args, **kwargs):
296+
"""Helper function to create a TensorDictModule."""
297+
from hydra.utils import instantiate
298+
from tensordict.nn import TensorDictModule
299+
300+
module = kwargs.pop("module")
301+
shared = kwargs.pop("shared", False)
302+
303+
# Instantiate the module if it's a config
304+
if hasattr(module, "_target_"):
305+
module = instantiate(module)
306+
elif callable(module) and hasattr(module, "func"): # partial function
307+
module = module()
308+
309+
# Create the TensorDictModule
310+
tensordict_module = TensorDictModule(module, **kwargs)
311+
312+
# Apply share_memory if needed
313+
if shared:
314+
tensordict_module = tensordict_module.share_memory()
315+
316+
return tensordict_module
317+
318+
295319
def _make_tanh_normal_model(*args, **kwargs):
296320
"""Helper function to create a TanhNormal model with ProbabilisticTensorDictSequential."""
297321
from hydra.utils import instantiate
@@ -351,10 +375,24 @@ def _make_tanh_normal_model(*args, **kwargs):
351375

352376
def _make_value_model(*args, **kwargs):
353377
"""Helper function to create a ValueOperator with the given network."""
378+
from hydra.utils import instantiate
379+
354380
from torchrl.modules import ValueOperator
355381

356382
network = kwargs.pop("network")
357383
shared = kwargs.pop("shared", False)
384+
385+
# Instantiate the network if it's a config
386+
if hasattr(network, "_target_"):
387+
network = instantiate(network)
388+
elif callable(network) and hasattr(network, "func"): # partial function
389+
network = network()
390+
391+
# Create the ValueOperator
392+
value_operator = ValueOperator(network, **kwargs)
393+
394+
# Apply share_memory if needed
358395
if shared:
359-
network = network.share_memory()
360-
return ValueOperator(network, **kwargs)
396+
value_operator = value_operator.share_memory()
397+
398+
return value_operator

torchrl/trainers/algorithms/ppo.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,6 @@ def __init__(
212212

213213
if not self.async_collection:
214214
# rb has been extended by the collector
215-
raise RuntimeError
216215
self.register_op("pre_epoch", rb_trainer.extend)
217216
self.register_op("process_optim_batch", rb_trainer.sample)
218217
self.register_op("post_loss", rb_trainer.update_priority)

0 commit comments

Comments
 (0)