Skip to content

Commit e5c9a32

Browse files
committed
Update (base update)
[ghstack-poisoned]
2 parents f92de9f + 963fdd4 commit e5c9a32

File tree

3 files changed

+25
-19
lines changed

3 files changed

+25
-19
lines changed

.github/unittest/llm/scripts_llm/install.sh

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,15 @@ git submodule sync && git submodule update --init --recursive
3030
#printf "Installing PyTorch with cu128"
3131
#if [[ "$TORCH_VERSION" == "nightly" ]]; then
3232
# if [ "${CU_VERSION:-}" == cpu ] ; then
33-
# pip3 install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cpu -U
33+
# pip install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cpu -U
3434
# else
35-
# pip3 install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cu128 -U
35+
# pip install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cu128 -U
3636
# fi
3737
#elif [[ "$TORCH_VERSION" == "stable" ]]; then
3838
# if [ "${CU_VERSION:-}" == cpu ] ; then
39-
# pip3 install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cpu
39+
# pip install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cpu
4040
# else
41-
# pip3 install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cu128
41+
# pip install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cu128
4242
# fi
4343
#else
4444
# printf "Failed to install pytorch"
@@ -47,9 +47,10 @@ git submodule sync && git submodule update --init --recursive
4747

4848
# install tensordict
4949
if [[ "$RELEASE" == 0 ]]; then
50-
pip3 install git+https://github.com/pytorch/tensordict.git
50+
pip install "pybind11[global]" ninja
51+
pip install git+https://github.com/pytorch/tensordict.git
5152
else
52-
pip3 install tensordict
53+
pip install tensordict
5354
fi
5455

5556
# smoke test

test/test_weightsync.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@
99
import pytest
1010
import torch
1111
import torch.nn as nn
12+
from mocking_classes import ContinuousActionVecMockEnv
1213
from tensordict import TensorDict
1314
from tensordict.nn import TensorDictModule
1415
from torch import multiprocessing as mp
1516
from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector
16-
from torchrl.envs import GymEnv
1717
from torchrl.weight_update.weight_sync_schemes import (
1818
_resolve_model,
1919
MPTransport,
@@ -274,7 +274,7 @@ def test_no_weight_sync_scheme(self):
274274
class TestCollectorIntegration:
275275
@pytest.fixture
276276
def simple_env(self):
277-
return GymEnv("CartPole-v1")
277+
return ContinuousActionVecMockEnv()
278278

279279
@pytest.fixture
280280
def simple_policy(self, simple_env):
@@ -291,7 +291,7 @@ def test_syncdatacollector_multiprocess_scheme(self, simple_policy):
291291
scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
292292

293293
collector = SyncDataCollector(
294-
create_env_fn=lambda: GymEnv("CartPole-v1"),
294+
create_env_fn=ContinuousActionVecMockEnv,
295295
policy=simple_policy,
296296
frames_per_batch=64,
297297
total_frames=128,
@@ -316,8 +316,8 @@ def test_multisyncdatacollector_multiprocess_scheme(self, simple_policy):
316316

317317
collector = MultiSyncDataCollector(
318318
create_env_fn=[
319-
lambda: GymEnv("CartPole-v1"),
320-
lambda: GymEnv("CartPole-v1"),
319+
ContinuousActionVecMockEnv,
320+
ContinuousActionVecMockEnv,
321321
],
322322
policy=simple_policy,
323323
frames_per_batch=64,
@@ -343,8 +343,8 @@ def test_multisyncdatacollector_shared_mem_scheme(self, simple_policy):
343343

344344
collector = MultiSyncDataCollector(
345345
create_env_fn=[
346-
lambda: GymEnv("CartPole-v1"),
347-
lambda: GymEnv("CartPole-v1"),
346+
ContinuousActionVecMockEnv,
347+
ContinuousActionVecMockEnv,
348348
],
349349
policy=simple_policy,
350350
frames_per_batch=64,
@@ -369,7 +369,7 @@ def test_collector_no_weight_sync(self, simple_policy):
369369
scheme = NoWeightSyncScheme()
370370

371371
collector = SyncDataCollector(
372-
create_env_fn=lambda: GymEnv("CartPole-v1"),
372+
create_env_fn=ContinuousActionVecMockEnv,
373373
policy=simple_policy,
374374
frames_per_batch=64,
375375
total_frames=128,
@@ -385,7 +385,7 @@ def test_collector_no_weight_sync(self, simple_policy):
385385

386386
class TestMultiModelUpdates:
387387
def test_multi_model_state_dict_updates(self):
388-
env = GymEnv("CartPole-v1")
388+
env = ContinuousActionVecMockEnv()
389389

390390
policy = TensorDictModule(
391391
nn.Linear(
@@ -407,7 +407,7 @@ def test_multi_model_state_dict_updates(self):
407407
}
408408

409409
collector = SyncDataCollector(
410-
create_env_fn=lambda: GymEnv("CartPole-v1"),
410+
create_env_fn=ContinuousActionVecMockEnv,
411411
policy=policy,
412412
frames_per_batch=64,
413413
total_frames=128,
@@ -438,7 +438,7 @@ def test_multi_model_state_dict_updates(self):
438438
env.close()
439439

440440
def test_multi_model_tensordict_updates(self):
441-
env = GymEnv("CartPole-v1")
441+
env = ContinuousActionVecMockEnv()
442442

443443
policy = TensorDictModule(
444444
nn.Linear(
@@ -460,7 +460,7 @@ def test_multi_model_tensordict_updates(self):
460460
}
461461

462462
collector = SyncDataCollector(
463-
create_env_fn=lambda: GymEnv("CartPole-v1"),
463+
create_env_fn=ContinuousActionVecMockEnv,
464464
policy=policy,
465465
frames_per_batch=64,
466466
total_frames=128,

torchrl/envs/libs/gym.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1255,7 +1255,12 @@ def _build_gym_env(self, env, pixels_only): # noqa: F811
12551255

12561256
@property
12571257
def lib(self) -> ModuleType:
1258-
return gym_backend()
1258+
gym = gym_backend()
1259+
if gym is None:
1260+
raise RuntimeError(
1261+
"Gym backend is not available. Please install gym or gymnasium."
1262+
)
1263+
return gym
12591264

12601265
def _set_seed(self, seed: int | None) -> None: # noqa: F811
12611266
if self._seed_calls_reset is None:

0 commit comments

Comments
 (0)