@@ -2233,7 +2233,9 @@ def __init__(self):
22332233 self .out_keys = ["action" ]
22342234
22352235 def forward (self , td ):
2236- td ["action" ] = (self .param + self .buf ).expand (td .shape )
2236+ td ["action" ] = (self .param + self .buf .to (self .param .device )).expand (
2237+ td .shape
2238+ )
22372239 return td
22382240
22392241 @pytest .mark .parametrize (
@@ -2288,6 +2290,64 @@ def test_param_sync(self, give_weights, collector, policy_device, env_device):
22882290 col .shutdown ()
22892291 del col
22902292
2293+ @pytest .mark .parametrize (
2294+ "collector" ,
2295+ [
2296+ functools .partial (MultiSyncDataCollector , cat_results = "stack" ),
2297+ MultiaSyncDataCollector ,
2298+ ],
2299+ )
2300+ @pytest .mark .parametrize ("give_weights" , [True , False ])
2301+ @pytest .mark .parametrize (
2302+ "policy_device,env_device" ,
2303+ [
2304+ ["cpu" , get_default_devices ()[0 ]],
2305+ [get_default_devices ()[0 ], "cpu" ],
2306+ # ["cpu", "cuda:0"], # 1226: faster execution
2307+ # ["cuda:0", "cpu"],
2308+ # ["cuda", "cuda:0"],
2309+ # ["cuda:0", "cuda"],
2310+ ],
2311+ )
2312+ def test_param_sync_mixed_device (
2313+ self , give_weights , collector , policy_device , env_device
2314+ ):
2315+ with torch .device ("cpu" ):
2316+ policy = TestUpdateParams .Policy ()
2317+ policy .param = nn .Parameter (policy .param .data .to (policy_device ))
2318+ assert policy .buf .device == torch .device ("cpu" )
2319+
2320+ env = EnvCreator (lambda : TestUpdateParams .DummyEnv (device = env_device ))
2321+ device = env ().device
2322+ env = [env ]
2323+ col = collector (
2324+ env , policy , device = device , total_frames = 200 , frames_per_batch = 10
2325+ )
2326+ try :
2327+ for i , data in enumerate (col ):
2328+ if i == 0 :
2329+ assert (data ["action" ] == 0 ).all ()
2330+ # update policy
2331+ policy .param .data += 1
2332+ policy .buf .data += 2
2333+ assert policy .buf .device == torch .device ("cpu" )
2334+ if give_weights :
2335+ p_w = TensorDict .from_module (policy )
2336+ else :
2337+ p_w = None
2338+ col .update_policy_weights_ (p_w )
2339+ elif i == 20 :
2340+ if (data ["action" ] == 1 ).all ():
2341+ raise RuntimeError ("Failed to update buffer" )
2342+ elif (data ["action" ] == 2 ).all ():
2343+ raise RuntimeError ("Failed to update params" )
2344+ elif (data ["action" ] == 0 ).all ():
2345+ raise RuntimeError ("Failed to update params and buffers" )
2346+ assert (data ["action" ] == 3 ).all ()
2347+ finally :
2348+ col .shutdown ()
2349+ del col
2350+
22912351
22922352class TestAggregateReset :
22932353 def test_aggregate_reset_to_root (self ):
0 commit comments