@@ -208,7 +208,7 @@ def _get_policy_and_device(
             return policy, None
 
         if isinstance(policy, nn.Module):
-            param_and_buf = TensorDict.from_module(policy, as_module=True).data
+            param_and_buf = TensorDict.from_module(policy, as_module=True)
         else:
             # Because we want to reach the warning
             param_and_buf = TensorDict()
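Note on the change above: dropping the trailing `.data` keeps `param_and_buf` as the parameter-aware view returned by `as_module=True` rather than a detached snapshot. A minimal sketch of the distinction, assuming only the public `tensordict` API and a stand-in `nn.Linear` in place of the real policy:

```python
# Sketch only: demo_policy is a stand-in module, not the collector's policy.
import torch
from torch import nn
from tensordict import TensorDict

demo_policy = nn.Linear(4, 2)

# as_module=True keeps a parameter-aware view tied to the module ...
params_view = TensorDict.from_module(demo_policy, as_module=True)
# ... while .data yields detached tensors with no grad tracking.
params_snapshot = TensorDict.from_module(demo_policy, as_module=True).data

print(params_view["weight"].requires_grad)      # True: still the live nn.Parameter
print(params_snapshot["weight"].requires_grad)  # False: detached snapshot
```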
@@ -231,19 +231,25 @@ def _get_policy_and_device(
             return policy, None
 
         # Create a stateless policy, then populate this copy with params on device
-        def get_original_weights(policy):
+        def get_original_weights(policy=policy):
             td = TensorDict.from_module(policy)
             return td.data
 
         # We need to use ".data" otherwise buffers may disappear from the `get_original_weights` function
         with param_and_buf.data.to("meta").to_module(policy):
-            policy = deepcopy(policy)
+            policy_new_device = deepcopy(policy)
 
-        param_and_buf.apply(
+        param_and_buf_new_device = param_and_buf.apply(
             functools.partial(_map_weight, policy_device=policy_device),
             filter_empty=False,
-        ).to_module(policy)
-        return policy, get_original_weights
+        )
+        param_and_buf_new_device.to_module(policy_new_device)
+        # Sanity check
+        if set(TensorDict.from_module(policy_new_device).keys(True, True)) != set(
+            get_original_weights().keys(True, True)
+        ):
+            raise RuntimeError("Failed to map weights. The weight sets mismatch.")
+        return policy_new_device, get_original_weights
 
     def start(self):
         """Starts the collector for asynchronous data collection.
@@ -1976,17 +1982,17 @@ def __init__(
         for policy_device, env_maker, env_maker_kwargs in _zip_strict(
             self.policy_device, self.create_env_fn, self.create_env_kwargs
         ):
-            (policy_copy, get_weights_fn,) = self._get_policy_and_device(
+            (policy_new_device, get_weights_fn,) = self._get_policy_and_device(
                 policy=policy,
                 policy_device=policy_device,
                 env_maker=env_maker,
                 env_maker_kwargs=env_maker_kwargs,
             )
-            if type(policy_copy) is not type(policy):
-                policy = policy_copy
+            if type(policy_new_device) is not type(policy):
+                policy = policy_new_device
             weights = (
-                TensorDict.from_module(policy_copy).data
-                if isinstance(policy_copy, nn.Module)
+                TensorDict.from_module(policy_new_device).data
+                if isinstance(policy_new_device, nn.Module)
                 else TensorDict()
             )
             self._policy_weights_dict[policy_device] = weights
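For context, each `self._policy_weights_dict[policy_device]` entry is a detached `TensorDict` snapshot of the policy's parameters and buffers for that device. A hedged sketch of the same bookkeeping outside the collector, using hypothetical names (`policy_weights_dict`, `refresh_weights`) that are not part of the API:

```python
# Sketch only: standalone, hypothetical version of the per-device weight snapshots.
import torch
from torch import nn
from tensordict import TensorDict

demo_policy = nn.Linear(4, 2)
devices = [torch.device("cpu")]  # e.g. cuda devices when available

# One detached snapshot of the weights per policy device.
policy_weights_dict = {
    device: TensorDict.from_module(demo_policy).data.to(device) for device in devices
}

def refresh_weights() -> None:
    # Copy the current policy weights into every device snapshot, in place.
    latest = TensorDict.from_module(demo_policy).data
    for device, weights in policy_weights_dict.items():
        weights.update_(latest.to(device))
```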