Skip to content

Commit 5930cb3

Browse files
committed
Update
[ghstack-poisoned]
2 parents d24c52b + a5bb971 commit 5930cb3

File tree

5 files changed

+61
-25
lines changed

5 files changed

+61
-25
lines changed

.github/unittest/llm/scripts_llm/install.sh

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,15 @@ git submodule sync && git submodule update --init --recursive
3030
#printf "Installing PyTorch with cu128"
3131
#if [[ "$TORCH_VERSION" == "nightly" ]]; then
3232
# if [ "${CU_VERSION:-}" == cpu ] ; then
33-
# pip3 install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cpu -U
33+
# pip install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cpu -U
3434
# else
35-
# pip3 install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cu128 -U
35+
# pip install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cu128 -U
3636
# fi
3737
#elif [[ "$TORCH_VERSION" == "stable" ]]; then
3838
# if [ "${CU_VERSION:-}" == cpu ] ; then
39-
# pip3 install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cpu
39+
# pip install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cpu
4040
# else
41-
# pip3 install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cu128
41+
# pip install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cu128
4242
# fi
4343
#else
4444
# printf "Failed to install pytorch"
@@ -47,9 +47,10 @@ git submodule sync && git submodule update --init --recursive
4747

4848
# install tensordict
4949
if [[ "$RELEASE" == 0 ]]; then
50-
pip3 install git+https://github.com/pytorch/tensordict.git
50+
pip install "pybind11[global]" ninja
51+
pip install git+https://github.com/pytorch/tensordict.git
5152
else
52-
pip3 install tensordict
53+
pip install tensordict
5354
fi
5455

5556
# smoke test

test/test_specs.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4585,6 +4585,26 @@ def test_names_repr(self):
45854585
assert "Composite" in repr_str
45864586
assert "obs" in repr_str
45874587

4588+
def test_zero_create_names(self):
    """Test that creating tensors with 'zero' propagates names."""
    spec = Composite(
        {"obs": Bounded(low=-1, high=1, shape=(10, 3, 4))},
        shape=(10,),
        names=["batch"],
    )
    td = spec.zero()
    # Compare instead of assigning: writing to ``td.names`` would overwrite
    # the value under test, so the test would pass even if ``zero`` dropped
    # the dimension names.
    assert list(td.names) == ["batch"]
4597+
4598+
def test_rand_create_names(self):
    """Test that creating tensors with 'rand' propagates names."""
    spec = Composite(
        {"obs": Bounded(low=-1, high=1, shape=(10, 3, 4))},
        shape=(10,),
        names=["batch"],
    )
    td = spec.rand()
    # Compare instead of assigning: writing to ``td.names`` would overwrite
    # the value under test, so the test would pass even if ``rand`` dropped
    # the dimension names.
    assert list(td.names) == ["batch"]
4607+
45884608

45894609
if __name__ == "__main__":
45904610
args, unknown = argparse.ArgumentParser().parse_known_args()

test/test_weightsync.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@
99
import pytest
1010
import torch
1111
import torch.nn as nn
12+
from mocking_classes import ContinuousActionVecMockEnv
1213
from tensordict import TensorDict
1314
from tensordict.nn import TensorDictModule
1415
from torch import multiprocessing as mp
1516
from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector
16-
from torchrl.envs import GymEnv
1717
from torchrl.weight_update.weight_sync_schemes import (
1818
_resolve_model,
1919
MPTransport,
@@ -274,7 +274,7 @@ def test_no_weight_sync_scheme(self):
274274
class TestCollectorIntegration:
275275
@pytest.fixture
276276
def simple_env(self):
277-
return GymEnv("CartPole-v1")
277+
return ContinuousActionVecMockEnv()
278278

279279
@pytest.fixture
280280
def simple_policy(self, simple_env):
@@ -291,7 +291,7 @@ def test_syncdatacollector_multiprocess_scheme(self, simple_policy):
291291
scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
292292

293293
collector = SyncDataCollector(
294-
create_env_fn=lambda: GymEnv("CartPole-v1"),
294+
create_env_fn=ContinuousActionVecMockEnv,
295295
policy=simple_policy,
296296
frames_per_batch=64,
297297
total_frames=128,
@@ -316,8 +316,8 @@ def test_multisyncdatacollector_multiprocess_scheme(self, simple_policy):
316316

317317
collector = MultiSyncDataCollector(
318318
create_env_fn=[
319-
lambda: GymEnv("CartPole-v1"),
320-
lambda: GymEnv("CartPole-v1"),
319+
ContinuousActionVecMockEnv,
320+
ContinuousActionVecMockEnv,
321321
],
322322
policy=simple_policy,
323323
frames_per_batch=64,
@@ -343,8 +343,8 @@ def test_multisyncdatacollector_shared_mem_scheme(self, simple_policy):
343343

344344
collector = MultiSyncDataCollector(
345345
create_env_fn=[
346-
lambda: GymEnv("CartPole-v1"),
347-
lambda: GymEnv("CartPole-v1"),
346+
ContinuousActionVecMockEnv,
347+
ContinuousActionVecMockEnv,
348348
],
349349
policy=simple_policy,
350350
frames_per_batch=64,
@@ -369,7 +369,7 @@ def test_collector_no_weight_sync(self, simple_policy):
369369
scheme = NoWeightSyncScheme()
370370

371371
collector = SyncDataCollector(
372-
create_env_fn=lambda: GymEnv("CartPole-v1"),
372+
create_env_fn=ContinuousActionVecMockEnv,
373373
policy=simple_policy,
374374
frames_per_batch=64,
375375
total_frames=128,
@@ -385,7 +385,7 @@ def test_collector_no_weight_sync(self, simple_policy):
385385

386386
class TestMultiModelUpdates:
387387
def test_multi_model_state_dict_updates(self):
388-
env = GymEnv("CartPole-v1")
388+
env = ContinuousActionVecMockEnv()
389389

390390
policy = TensorDictModule(
391391
nn.Linear(
@@ -407,7 +407,7 @@ def test_multi_model_state_dict_updates(self):
407407
}
408408

409409
collector = SyncDataCollector(
410-
create_env_fn=lambda: GymEnv("CartPole-v1"),
410+
create_env_fn=ContinuousActionVecMockEnv,
411411
policy=policy,
412412
frames_per_batch=64,
413413
total_frames=128,
@@ -438,7 +438,7 @@ def test_multi_model_state_dict_updates(self):
438438
env.close()
439439

440440
def test_multi_model_tensordict_updates(self):
441-
env = GymEnv("CartPole-v1")
441+
env = ContinuousActionVecMockEnv()
442442

443443
policy = TensorDictModule(
444444
nn.Linear(
@@ -460,7 +460,7 @@ def test_multi_model_tensordict_updates(self):
460460
}
461461

462462
collector = SyncDataCollector(
463-
create_env_fn=lambda: GymEnv("CartPole-v1"),
463+
create_env_fn=ContinuousActionVecMockEnv,
464464
policy=policy,
465465
frames_per_batch=64,
466466
total_frames=128,

torchrl/data/tensor_specs.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5740,16 +5740,22 @@ def rand(self, shape: torch.Size = None) -> TensorDictBase:
57405740
for key, item in self.items():
57415741
if item is not None:
57425742
_dict[key] = item.rand(shape)
5743-
if self.data_cls is None:
5744-
cls = TensorDict
5743+
5744+
cls = self.data_cls if self.data_cls is not None else TensorDict
5745+
if cls is not TensorDict:
5746+
kwargs = {}
5747+
if self._td_dim_names is not None:
5748+
warnings.warn(f"names for cls {cls} is not supported for rand.")
57455749
else:
5746-
cls = self.data_cls
5750+
kwargs = {"names": self._td_dim_names}
5751+
57475752
# No need to run checks since we know Composite is compliant with
57485753
# TensorDict requirements
57495754
return cls.from_dict(
57505755
_dict,
57515756
batch_size=_size([*shape, *_remove_neg_shapes(self.shape)]),
57525757
device=self.device,
5758+
**kwargs,
57535759
)
57545760

57555761
def keys(
@@ -6017,10 +6023,13 @@ def zero(self, shape: torch.Size = None) -> TensorDictBase:
60176023
except RuntimeError:
60186024
device = self._device
60196025

6020-
if self.data_cls is not None:
6021-
cls = self.data_cls
6026+
cls = self.data_cls if self.data_cls is not None else TensorDict
6027+
if cls is not TensorDict:
6028+
kwargs = {}
6029+
if self._td_dim_names is not None:
6030+
warnings.warn(f"names for cls {cls} is not supported for zero.")
60226031
else:
6023-
cls = TensorDict
6032+
kwargs = {"names": self._td_dim_names}
60246033

60256034
return cls.from_dict(
60266035
{
@@ -6030,6 +6039,7 @@ def zero(self, shape: torch.Size = None) -> TensorDictBase:
60306039
},
60316040
batch_size=_size([*shape, *self._safe_shape]),
60326041
device=device,
6042+
**kwargs,
60336043
)
60346044

60356045
def __eq__(self, other: object) -> bool:

torchrl/envs/libs/gym.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1255,7 +1255,12 @@ def _build_gym_env(self, env, pixels_only): # noqa: F811
12551255

12561256
@property
12571257
def lib(self) -> ModuleType:
1258-
return gym_backend()
1258+
gym = gym_backend()
1259+
if gym is None:
1260+
raise RuntimeError(
1261+
"Gym backend is not available. Please install gym or gymnasium."
1262+
)
1263+
return gym
12591264

12601265
def _set_seed(self, seed: int | None) -> None: # noqa: F811
12611266
if self._seed_calls_reset is None:

0 commit comments

Comments
 (0)