@@ -733,6 +733,8 @@ def __init__(
733733 elif torch .backends .mps .is_available () and hasattr (torch , "mps" ):
734734 # Will break for older PT versions which don't have torch.mps
735735 self ._sync_storage = torch .mps .synchronize
736+ elif torch .npu .is_available () and hasattr (torch , "npu" ):
737+ self ._sync_storage = torch .npu .synchronize
736738 elif self .storing_device .type == "cpu" :
737739 self ._sync_storage = _do_nothing
738740 else :
@@ -747,6 +749,8 @@ def __init__(
747749 self ._sync_env = torch .cuda .synchronize
748750 elif torch .backends .mps .is_available () and hasattr (torch , "mps" ):
749751 self ._sync_env = torch .mps .synchronize
752+ elif torch .npu .is_available () and hasattr (torch , "npu" ):
753+ self ._sync_env = torch .npu .synchronize
750754 elif self .env_device .type == "cpu" :
751755 self ._sync_env = _do_nothing
752756 else :
@@ -760,6 +764,8 @@ def __init__(
760764 self ._sync_policy = torch .cuda .synchronize
761765 elif torch .backends .mps .is_available () and hasattr (torch , "mps" ):
762766 self ._sync_policy = torch .mps .synchronize
767+ elif torch .npu .is_available () and hasattr (torch , "npu" ):
768+ self ._sync_policy = torch .npu .synchronize
763769 elif self .policy_device .type == "cpu" :
764770 self ._sync_policy = _do_nothing
765771 else :
0 commit comments