@@ -222,7 +222,7 @@ class TensorDictModuleConfig(ModelConfig):
222222 """
223223
224224 module : MLPConfig = MISSING
225- _target_ : str = "tensordict.nn.TensorDictModule "
225+ _target_ : str = "torchrl.trainers.algorithms.configs.modules._make_tensordict_module "
226226 _partial_ : bool = False
227227
228228 def __post_init__ (self ) -> None :
@@ -292,6 +292,30 @@ def __post_init__(self) -> None:
292292 super ().__post_init__ ()
293293
294294
295+ def _make_tensordict_module (* args , ** kwargs ):
296+ """Helper function to create a TensorDictModule."""
297+ from hydra .utils import instantiate
298+ from tensordict .nn import TensorDictModule
299+
300+ module = kwargs .pop ("module" )
301+ shared = kwargs .pop ("shared" , False )
302+
303+ # Instantiate the module if it's a config
304+ if hasattr (module , "_target_" ):
305+ module = instantiate (module )
306+ elif callable (module ) and hasattr (module , "func" ): # partial function
307+ module = module ()
308+
309+ # Create the TensorDictModule
310+ tensordict_module = TensorDictModule (module , ** kwargs )
311+
312+ # Apply share_memory if needed
313+ if shared :
314+ tensordict_module = tensordict_module .share_memory ()
315+
316+ return tensordict_module
317+
318+
295319def _make_tanh_normal_model (* args , ** kwargs ):
296320 """Helper function to create a TanhNormal model with ProbabilisticTensorDictSequential."""
297321 from hydra .utils import instantiate
@@ -351,10 +375,24 @@ def _make_tanh_normal_model(*args, **kwargs):
351375
352376def _make_value_model (* args , ** kwargs ):
353377 """Helper function to create a ValueOperator with the given network."""
378+ from hydra .utils import instantiate
379+
354380 from torchrl .modules import ValueOperator
355381
356382 network = kwargs .pop ("network" )
357383 shared = kwargs .pop ("shared" , False )
384+
385+ # Instantiate the network if it's a config
386+ if hasattr (network , "_target_" ):
387+ network = instantiate (network )
388+ elif callable (network ) and hasattr (network , "func" ): # partial function
389+ network = network ()
390+
391+ # Create the ValueOperator
392+ value_operator = ValueOperator (network , ** kwargs )
393+
394+ # Apply share_memory if needed
358395 if shared :
359- network = network .share_memory ()
360- return ValueOperator (network , ** kwargs )
396+ value_operator = value_operator .share_memory ()
397+
398+ return value_operator
0 commit comments