@@ -733,7 +733,7 @@ def __init__(
733733 elif torch .backends .mps .is_available () and hasattr (torch , "mps" ):
734734 # Will break for older PT versions which don't have torch.mps
735735 self ._sync_storage = torch .mps .synchronize
736- elif torch . npu . is_available () and hasattr (torch , "npu" ):
736+ elif hasattr (torch , "npu" ) and torch . npu . is_available ( ):
737737 self ._sync_storage = torch .npu .synchronize
738738 elif self .storing_device .type == "cpu" :
739739 self ._sync_storage = _do_nothing
@@ -749,7 +749,7 @@ def __init__(
749749 self ._sync_env = torch .cuda .synchronize
750750 elif torch .backends .mps .is_available () and hasattr (torch , "mps" ):
751751 self ._sync_env = torch .mps .synchronize
752- elif torch . npu . is_available () and hasattr (torch , "npu" ):
752+ elif hasattr (torch , "npu" ) and torch . npu . is_available ( ):
753753 self ._sync_env = torch .npu .synchronize
754754 elif self .env_device .type == "cpu" :
755755 self ._sync_env = _do_nothing
@@ -764,7 +764,7 @@ def __init__(
764764 self ._sync_policy = torch .cuda .synchronize
765765 elif torch .backends .mps .is_available () and hasattr (torch , "mps" ):
766766 self ._sync_policy = torch .mps .synchronize
767- elif torch . npu . is_available () and hasattr (torch , "npu" ):
767+ elif hasattr (torch , "npu" ) and torch . npu . is_available ( ):
768768 self ._sync_policy = torch .npu .synchronize
769769 elif self .policy_device .type == "cpu" :
770770 self ._sync_policy = _do_nothing
0 commit comments