From 94b95e2ce92d6af3529cbc73b49fd38e892ce483 Mon Sep 17 00:00:00 2001 From: lowdy1 Date: Tue, 4 Nov 2025 11:57:41 +0000 Subject: [PATCH] Add NPU Support for Single Agent --- sota-implementations/a2c/a2c_atari.py | 7 ++++++- sota-implementations/a2c/a2c_mujoco.py | 7 ++++++- sota-implementations/cql/cql_offline.py | 2 ++ sota-implementations/cql/cql_online.py | 2 ++ sota-implementations/cql/discrete_cql_offline.py | 7 ++++++- sota-implementations/cql/discrete_cql_online.py | 2 ++ sota-implementations/crossq/crossq.py | 2 ++ sota-implementations/ddpg/ddpg.py | 4 ++++ sota-implementations/decision_transformer/dt.py | 2 ++ sota-implementations/decision_transformer/online_dt.py | 2 ++ sota-implementations/dqn/dqn_atari.py | 2 ++ sota-implementations/dqn/dqn_cartpole.py | 2 ++ sota-implementations/gail/gail.py | 2 ++ sota-implementations/iql/discrete_iql.py | 2 ++ sota-implementations/iql/iql_offline.py | 4 +++- sota-implementations/iql/iql_online.py | 2 ++ sota-implementations/ppo/ppo_atari.py | 2 ++ sota-implementations/ppo/ppo_mujoco.py | 2 ++ sota-implementations/sac/sac-async.py | 2 ++ sota-implementations/sac/sac.py | 2 ++ sota-implementations/td3/td3.py | 2 ++ sota-implementations/td3_bc/td3_bc.py | 2 ++ 22 files changed, 59 insertions(+), 4 deletions(-) diff --git a/sota-implementations/a2c/a2c_atari.py b/sota-implementations/a2c/a2c_atari.py index 4135a45f8b3..86569bf7c39 100644 --- a/sota-implementations/a2c/a2c_atari.py +++ b/sota-implementations/a2c/a2c_atari.py @@ -35,7 +35,12 @@ def main(cfg: DictConfig): # noqa: F821 device = cfg.loss.device if not device: - device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0") + if torch.cuda.is_available(): + device = torch.device("cuda:0") + elif torch.npu.is_available(): + device = torch.device("npu:0") + else: + device = torch.device("cpu") else: device = torch.device(device) diff --git a/sota-implementations/a2c/a2c_mujoco.py b/sota-implementations/a2c/a2c_mujoco.py index 82b466e2e2e..3c56a061575 100644 --- a/sota-implementations/a2c/a2c_mujoco.py +++ b/sota-implementations/a2c/a2c_mujoco.py @@ -38,7 +38,12 @@ def main(cfg: DictConfig): # noqa: F821 device = cfg.loss.device if not device: - device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0") + if torch.cuda.is_available(): + device = torch.device("cuda:0") + elif torch.npu.is_available(): + device = torch.device("npu:0") + else: + device = torch.device("cpu") else: device = torch.device(device) diff --git a/sota-implementations/cql/cql_offline.py b/sota-implementations/cql/cql_offline.py index 9acd00b1627..85b5f49f61d 100644 --- a/sota-implementations/cql/cql_offline.py +++ b/sota-implementations/cql/cql_offline.py @@ -59,6 +59,8 @@ def main(cfg: DictConfig): # noqa: F821 if device in ("", None): if torch.cuda.is_available(): device = "cuda:0" + elif torch.npu.is_available(): + device = "npu:0" else: device = "cpu" device = torch.device(device) diff --git a/sota-implementations/cql/cql_online.py b/sota-implementations/cql/cql_online.py index 5d25a34ba10..b64c39a6a7a 100644 --- a/sota-implementations/cql/cql_online.py +++ b/sota-implementations/cql/cql_online.py @@ -64,6 +64,8 @@ def main(cfg: DictConfig): # noqa: F821 if device in ("", None): if torch.cuda.is_available(): device = "cuda:0" + elif torch.npu.is_available(): + device = "npu:0" else: device = "cpu" device = torch.device(device) diff --git a/sota-implementations/cql/discrete_cql_offline.py b/sota-implementations/cql/discrete_cql_offline.py index ff225ce3b31..6ca12ecda46 100644 --- a/sota-implementations/cql/discrete_cql_offline.py +++ b/sota-implementations/cql/discrete_cql_offline.py @@ -38,7 +38,12 @@ def main(cfg): # noqa: F821 device = cfg.optim.device if device in ("", None): - device = "cuda:0" if torch.cuda.is_available() else "cpu" + if torch.cuda.is_available(): + device = "cuda:0" + elif torch.npu.is_available(): + device = "npu:0" + else: + device = "cpu" device = torch.device(device) # Create logger diff --git a/sota-implementations/cql/discrete_cql_online.py b/sota-implementations/cql/discrete_cql_online.py index 7b0328bb73b..9a8b566131b 100644 --- a/sota-implementations/cql/discrete_cql_online.py +++ b/sota-implementations/cql/discrete_cql_online.py @@ -42,6 +42,8 @@ def main(cfg: DictConfig): # noqa: F821 if device in ("", None): if torch.cuda.is_available(): device = "cuda:0" + elif torch.npu.is_available(): + device = "npu:0" else: device = "cpu" device = torch.device(device) diff --git a/sota-implementations/crossq/crossq.py b/sota-implementations/crossq/crossq.py index 619f2395fb1..9cead247510 100644 --- a/sota-implementations/crossq/crossq.py +++ b/sota-implementations/crossq/crossq.py @@ -44,6 +44,8 @@ def main(cfg: DictConfig): # noqa: F821 if device in ("", None): if torch.cuda.is_available(): device = torch.device("cuda:0") + elif torch.npu.is_available(): + device = torch.device("npu:0") else: device = torch.device("cpu") device = torch.device(device) diff --git a/sota-implementations/ddpg/ddpg.py b/sota-implementations/ddpg/ddpg.py index 5b6d308aba2..697888ae7b1 100644 --- a/sota-implementations/ddpg/ddpg.py +++ b/sota-implementations/ddpg/ddpg.py @@ -43,6 +43,8 @@ def main(cfg: DictConfig): # noqa: F821 if device in ("", None): if torch.cuda.is_available(): device = "cuda:0" + elif torch.npu.is_available(): + device = "npu:0" else: device = "cpu" device = torch.device(device) @@ -51,6 +53,8 @@ def main(cfg: DictConfig): # noqa: F821 if collector_device in ("", None): if torch.cuda.is_available(): collector_device = "cuda:0" + elif torch.npu.is_available(): + collector_device = "npu:0" else: collector_device = "cpu" collector_device = torch.device(collector_device) diff --git a/sota-implementations/decision_transformer/dt.py b/sota-implementations/decision_transformer/dt.py index f565aafeafc..74add66a295 100644 --- a/sota-implementations/decision_transformer/dt.py +++ b/sota-implementations/decision_transformer/dt.py @@ -42,6 +42,8 @@ def main(cfg: DictConfig): # noqa: F821 if model_device in ("", None): if torch.cuda.is_available(): model_device = "cuda:0" + elif torch.npu.is_available(): + model_device = "npu:0" else: model_device = "cpu" model_device = torch.device(model_device) diff --git a/sota-implementations/decision_transformer/online_dt.py b/sota-implementations/decision_transformer/online_dt.py index baab8bbb9a6..888b66d4a81 100644 --- a/sota-implementations/decision_transformer/online_dt.py +++ b/sota-implementations/decision_transformer/online_dt.py @@ -40,6 +40,8 @@ def main(cfg: DictConfig): # noqa: F821 if model_device in ("", None): if torch.cuda.is_available(): model_device = "cuda:0" + elif torch.npu.is_available(): + model_device = "npu:0" else: model_device = "cpu" model_device = torch.device(model_device) diff --git a/sota-implementations/dqn/dqn_atari.py b/sota-implementations/dqn/dqn_atari.py index 14ef64b2b60..673d347a00d 100644 --- a/sota-implementations/dqn/dqn_atari.py +++ b/sota-implementations/dqn/dqn_atari.py @@ -37,6 +37,8 @@ def main(cfg: DictConfig): # noqa: F821 if device in ("", None): if torch.cuda.is_available(): device = "cuda:0" + elif torch.npu.is_available(): + device = "npu:0" else: device = "cpu" device = torch.device(device) diff --git a/sota-implementations/dqn/dqn_cartpole.py b/sota-implementations/dqn/dqn_cartpole.py index d532278c064..ca4151f07d7 100644 --- a/sota-implementations/dqn/dqn_cartpole.py +++ b/sota-implementations/dqn/dqn_cartpole.py @@ -32,6 +32,8 @@ def main(cfg: DictConfig): # noqa: F821 if device in ("", None): if torch.cuda.is_available(): device = "cuda:0" + elif torch.npu.is_available(): + device = "npu:0" else: device = "cpu" device = torch.device(device) diff --git a/sota-implementations/gail/gail.py b/sota-implementations/gail/gail.py index 0e89f48f108..9d7b51871da 100644 --- a/sota-implementations/gail/gail.py +++ b/sota-implementations/gail/gail.py @@ -42,6 +42,8 @@ def main(cfg: DictConfig): # noqa: F821 if device in ("", None): if torch.cuda.is_available(): device = "cuda:0" + elif torch.npu.is_available(): + device = "npu:0" else: device = "cpu" device = torch.device(device) diff --git a/sota-implementations/iql/discrete_iql.py b/sota-implementations/iql/discrete_iql.py index 43a8dcafa6e..ea54fd1af8e 100644 --- a/sota-implementations/iql/discrete_iql.py +++ b/sota-implementations/iql/discrete_iql.py @@ -67,6 +67,8 @@ def main(cfg: DictConfig): # noqa: F821 if device in ("", None): if torch.cuda.is_available(): device = "cuda:0" + elif torch.npu.is_available(): + device = "npu:0" else: device = "cpu" device = torch.device(device) diff --git a/sota-implementations/iql/iql_offline.py b/sota-implementations/iql/iql_offline.py index 6585534ff68..b71b76ae1c5 100644 --- a/sota-implementations/iql/iql_offline.py +++ b/sota-implementations/iql/iql_offline.py @@ -63,11 +63,13 @@ def main(cfg: DictConfig): # noqa: F821 if device in ("", None): if torch.cuda.is_available(): device = "cuda:0" + elif torch.npu.is_available(): + device = "npu:0" else: device = "cpu" device = torch.device(device) - # Creante env + # Create env train_env, eval_env = make_environment( cfg, cfg.logger.eval_envs, diff --git a/sota-implementations/iql/iql_online.py b/sota-implementations/iql/iql_online.py index eaa37f29176..80fbf84dc09 100644 --- a/sota-implementations/iql/iql_online.py +++ b/sota-implementations/iql/iql_online.py @@ -66,6 +66,8 @@ def main(cfg: DictConfig): # noqa: F821 if device in ("", None): if torch.cuda.is_available(): device = "cuda:0" + elif torch.npu.is_available(): + device = "npu:0" else: device = "cpu" device = torch.device(device) diff --git a/sota-implementations/ppo/ppo_atari.py b/sota-implementations/ppo/ppo_atari.py index 301e9edfa02..63f141d76eb 100644 --- a/sota-implementations/ppo/ppo_atari.py +++ b/sota-implementations/ppo/ppo_atari.py @@ -41,6 +41,8 @@ def main(cfg: DictConfig): # noqa: F821 if device in ("", None): if torch.cuda.is_available(): device = "cuda:0" + elif torch.npu.is_available(): + device = "npu:0" else: device = "cpu" device = torch.device(device) diff --git a/sota-implementations/ppo/ppo_mujoco.py b/sota-implementations/ppo/ppo_mujoco.py index a4fa1941a6b..eca6f52cca9 100644 --- a/sota-implementations/ppo/ppo_mujoco.py +++ b/sota-implementations/ppo/ppo_mujoco.py @@ -41,6 +41,8 @@ def main(cfg: DictConfig): # noqa: F821 if device in ("", None): if torch.cuda.is_available(): device = "cuda:0" + elif torch.npu.is_available(): + device = "npu:0" else: device = "cpu" device = torch.device(device) diff --git a/sota-implementations/sac/sac-async.py b/sota-implementations/sac/sac-async.py index b216f284840..b41fd9a5fb0 100644 --- a/sota-implementations/sac/sac-async.py +++ b/sota-implementations/sac/sac-async.py @@ -56,6 +56,8 @@ def main(cfg: DictConfig): # noqa: F821 if device in ("", None): if torch.cuda.is_available(): device = torch.device("cuda:0") + elif torch.npu.is_available(): + device = torch.device("npu:0") else: device = torch.device("cpu") device = torch.device(device) diff --git a/sota-implementations/sac/sac.py b/sota-implementations/sac/sac.py index 7fd6284037e..c80e8d36b1d 100644 --- a/sota-implementations/sac/sac.py +++ b/sota-implementations/sac/sac.py @@ -45,6 +45,8 @@ def main(cfg: DictConfig): # noqa: F821 if device in ("", None): if torch.cuda.is_available(): device = torch.device("cuda:0") + elif torch.npu.is_available(): + device = torch.device("npu:0") else: device = torch.device("cpu") device = torch.device(device) diff --git a/sota-implementations/td3/td3.py b/sota-implementations/td3/td3.py index f7b10e8cdf9..787cfd15fc4 100644 --- a/sota-implementations/td3/td3.py +++ b/sota-implementations/td3/td3.py @@ -43,6 +43,8 @@ def main(cfg: DictConfig): # noqa: F821 if device in ("", None): if torch.cuda.is_available(): device = torch.device("cuda:0") + elif torch.npu.is_available(): + device = torch.device("npu:0") else: device = torch.device("cpu") else: diff --git a/sota-implementations/td3_bc/td3_bc.py b/sota-implementations/td3_bc/td3_bc.py index 6c628904908..ae5fd3a1c9c 100644 --- a/sota-implementations/td3_bc/td3_bc.py +++ b/sota-implementations/td3_bc/td3_bc.py @@ -61,6 +61,8 @@ def main(cfg: DictConfig): # noqa: F821 if device in ("", None): if torch.cuda.is_available(): device = "cuda:0" + elif torch.npu.is_available(): + device = "npu:0" else: device = "cpu" device = torch.device(device)