File tree Expand file tree Collapse file tree 1 file changed +9
-3
lines changed Expand file tree Collapse file tree 1 file changed +9
-3
lines changed Original file line number Diff line number Diff line change @@ -2451,13 +2451,15 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
24512451 if event is not None :
24522452 event .record ()
24532453 event .synchronize ()
2454- mp_event .set ()
24552454
24562455 if _non_tensor_keys :
24572456 child_pipe .send (
24582457 ("non_tensor" , cur_td .select (* _non_tensor_keys , strict = False ))
24592458 )
24602459
2460+ # Set event only after non-tensor data is sent to avoid race condition
2461+ mp_event .set ()
2462+
24612463 del cur_td
24622464
24632465 elif cmd == "step" :
@@ -2483,7 +2485,6 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
24832485 if event is not None :
24842486 event .record ()
24852487 event .synchronize ()
2486- mp_event .set ()
24872488
24882489 # Make sure the root is updated
24892490 root_shared_tensordict .update_ (env ._step_mdp (input ))
@@ -2493,6 +2494,9 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
24932494 ("non_tensor" , next_td .select (* _non_tensor_keys , strict = False ))
24942495 )
24952496
2497+ # Set event only after non-tensor data is sent to avoid race condition
2498+ mp_event .set ()
2499+
24962500 del next_td
24972501
24982502 elif cmd == "step_and_maybe_reset" :
@@ -2525,13 +2529,15 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
25252529 if event is not None :
25262530 event .record ()
25272531 event .synchronize ()
2528- mp_event .set ()
25292532
25302533 if _non_tensor_keys :
25312534 ntd = root_next_td .select (* _non_tensor_keys )
25322535 ntd .set ("next" , td_next .select (* _non_tensor_keys ))
25332536 child_pipe .send (("non_tensor" , ntd ))
25342537
2538+ # Set event only after non-tensor data is sent to avoid race condition
2539+ mp_event .set ()
2540+
25352541 del td , root_next_td
25362542
25372543 elif cmd == "close" :
You can’t perform that action at this time.
0 commit comments