Skip to content

Commit 5d400b4

Browse files
committed
Update
[ghstack-poisoned]
1 parent 9604f48 commit 5d400b4

File tree

5 files changed

+47
-3
lines changed

5 files changed

+47
-3
lines changed

test/llm/test_updaters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def get_open_port():
4141
if _has_ray:
4242
import ray
4343

44-
from ray_helpers import (
44+
from torchrl.testing import (
4545
WorkerTransformerDoubleBuffer,
4646
WorkerTransformerNCCL,
4747
WorkerVLLMDoubleBuffer,

test/test_rb.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4188,6 +4188,25 @@ def test_ray_rb_iter(self):
41884188
finally:
41894189
rb.close()
41904190

4191+
def test_ray_rb_serialization(self):
4192+
import ray
4193+
4194+
class Worker:
4195+
def __init__(self, rb):
4196+
self.rb = rb
4197+
4198+
def run(self):
4199+
self.rb.extend(TensorDict({"x": torch.ones(100)}, batch_size=100))
4200+
4201+
rb = RayReplayBuffer(
4202+
storage=partial(LazyTensorStorage, 100), ray_init_config={"num_cpus": 1}
4203+
)
4204+
try:
4205+
remote_worker = ray.remote(Worker).remote(rb)
4206+
ray.get(remote_worker.run.remote())
4207+
finally:
4208+
rb.close()
4209+
41914210

41924211
class TestSharedStorageInit:
41934212
def worker(self, rb, worker_id, queue):

torchrl/data/replay_buffers/ray_buffer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ def __init__(
147147
else:
148148
self.has_gpu = False
149149
self._rb = remote_cls(*args, **kwargs)
150+
self._delayed_init = False
150151

151152
def close(self):
152153
"""Terminates the Ray actor associated with this replay buffer."""

torchrl/testing/__init__.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
"""Testing utilities for TorchRL.
7+
8+
This module provides helper classes and utilities for testing TorchRL functionality,
9+
particularly for distributed and Ray-based tests that require importable classes.
10+
"""
11+
12+
from torchrl.testing.ray_helpers import (
13+
WorkerTransformerDoubleBuffer,
14+
WorkerTransformerNCCL,
15+
WorkerVLLMDoubleBuffer,
16+
WorkerVLLMNCCL,
17+
)
18+
19+
__all__ = [
20+
"WorkerVLLMNCCL",
21+
"WorkerTransformerNCCL",
22+
"WorkerVLLMDoubleBuffer",
23+
"WorkerTransformerDoubleBuffer",
24+
]

test/llm/ray_helpers.py renamed to torchrl/testing/ray_helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
"""Helper classes for Ray-based weight synchronization tests.
77
88
This module contains Ray actor classes that need to be importable by Ray workers.
9-
These classes are used in test_updaters.py but must be defined at module level
10-
so Ray can serialize and import them on remote workers.
9+
These classes are used in tests but must be defined at module level in a proper
10+
Python package (not in test files) so Ray can serialize and import them on remote workers.
1111
"""
1212

1313
import torch

0 commit comments

Comments
 (0)