
Commit 4fb3b4a

committed
Update
[ghstack-poisoned]
1 parent 5743034 commit 4fb3b4a

File tree

7 files changed: +235 -40 lines changed


torchrl/envs/llm/transforms/dataloading.py
Lines changed: 3 additions & 7 deletions

@@ -15,14 +15,14 @@
 
 from torchrl.data.tensor_specs import Composite, DEVICE_TYPING, TensorSpec
 from torchrl.envs.common import EnvBase
+from torchrl.envs.transforms import TensorDictPrimer, Transform
 
 # Import ray service components
-from torchrl.envs.llm.transforms.ray_service import (
+from torchrl.envs.transforms.ray_service import (
     _map_input_output_device,
     _RayServiceMetaClass,
     RayTransform,
 )
-from torchrl.envs.transforms.transforms import TensorDictPrimer, Transform
 from torchrl.envs.utils import make_composite_from_td
 
 T = TypeVar("T")
@@ -259,7 +259,7 @@ def primers(self):
     @primers.setter
     def primers(self, value: TensorSpec):
         """Set primers property."""
-        self._ray.get(self._actor.set_attr.remote("primers", value))
+        self._ray.get(self._actor._set_attr.remote("primers", value))
 
     # TensorDictPrimer methods
     def init(self, tensordict: TensorDictBase | None):
@@ -857,7 +857,3 @@ def _update_primers_batch_size(self, parent_batch_size):
     def __repr__(self) -> str:
         class_name = self.__class__.__name__
         return f"{class_name}(primers={self.primers}, dataloader={self.dataloader})"
-
-    def set_attr(self, name, value):
-        """Set attribute on the remote actor or locally."""
-        setattr(self, name, value)
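
The changes above move the Ray service helpers out of the LLM subpackage and drop the public set_attr passthrough in favor of the actor's _set_attr. As a minimal sketch (assuming only the module path changed and the exported names stayed the same), downstream imports would be updated like this:

    # Old location (before this commit):
    # from torchrl.envs.llm.transforms.ray_service import RayTransform

    # New location (after this commit):
    from torchrl.envs.transforms.ray_service import RayTransform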

torchrl/envs/llm/transforms/kl.py
Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@
 from torchrl.data import Composite, Unbounded
 from torchrl.data.tensor_specs import DEVICE_TYPING
 from torchrl.envs import EnvBase, Transform
-from torchrl.envs.llm.transforms.ray_service import _RayServiceMetaClass, RayTransform
+from torchrl.envs.transforms.ray_service import _RayServiceMetaClass, RayTransform
 from torchrl.envs.transforms.transforms import Compose
 from torchrl.envs.transforms.utils import _set_missing_tolerance
 from torchrl.modules.llm.policies.common import LLMWrapperBase

torchrl/envs/llm/transforms/tools.py
Lines changed: 1 addition & 1 deletion

@@ -17,7 +17,7 @@
 import torch
 
 from tensordict import lazy_stack, TensorDictBase
-from torchrl import torchrl_logger
+from torchrl._utils import logger as torchrl_logger
 from torchrl.data.llm import History
 
 from torchrl.envs import Transform

torchrl/envs/transforms/__init__.py
Lines changed: 5 additions & 2 deletions

@@ -5,9 +5,10 @@
 
 from .gym_transforms import EndOfLifeTransform
 from .llm import KLRewardTransform
+from .module import ModuleTransform
 from .r3m import R3MTransform
+from .ray_service import RayTransform
 from .rb_transforms import MultiStepTransform
-
 from .transforms import (
     ActionDiscretizer,
     ActionMask,
@@ -85,9 +86,9 @@
     "CatFrames",
     "CatTensors",
     "CenterCrop",
-    "ConditionalPolicySwitch",
     "ClipTransform",
     "Compose",
+    "ConditionalPolicySwitch",
     "ConditionalSkip",
     "Crop",
     "DTypeCastTransform",
@@ -104,6 +105,7 @@
     "InitTracker",
     "KLRewardTransform",
     "LineariseRewards",
+    "ModuleTransform",
     "MultiAction",
     "MultiStepTransform",
     "NoopResetEnv",
@@ -113,6 +115,7 @@
     "PinMemoryTransform",
     "R3MTransform",
     "RandomCropTensorDict",
+    "RayTransform",
     "RemoveEmptySpecs",
     "RenameTransform",
     "Resize",

torchrl/envs/transforms/module.py
Lines changed: 179 additions & 0 deletions

@@ -0,0 +1,179 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from collections.abc import Callable
+from contextlib import nullcontext
+from typing import overload
+
+import torch
+from tensordict import TensorDictBase
+from tensordict.nn import TensorDictModuleBase
+from torchrl.envs.transforms.ray_service import _RayServiceMetaClass, RayTransform
+from torchrl.envs.transforms.transforms import Transform
+
+
+__all__ = ["ModuleTransform", "RayModuleTransform"]
+
+
+class RayModuleTransform(RayTransform):
+    """Ray-based ModuleTransform for distributed processing.
+
+    This transform creates a Ray actor that wraps a ModuleTransform,
+    allowing module execution in a separate Ray worker process.
+    """
+
+    def _create_actor(self, **kwargs):
+        return self._ray.remote(ModuleTransform).remote(**kwargs)
+
+    @overload
+    def update_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
+        ...
+
+    @overload
+    def update_weights(self, params: TensorDictBase) -> None:
+        ...
+
+    def update_weights(self, *args, **kwargs) -> None:
+        import ray
+
+        if self._update_weights_method == "tensordict":
+            try:
+                td = kwargs.get("params", args[0])
+            except IndexError:
+                raise ValueError("params must be provided")
+            return ray.get(self._actor._update_weights_tensordict.remote(params=td))
+        elif self._update_weights_method == "state_dict":
+            try:
+                state_dict = kwargs.get("state_dict", args[0])
+            except IndexError:
+                raise ValueError("state_dict must be provided")
+            return ray.get(
+                self._actor._update_weights_state_dict.remote(state_dict=state_dict)
+            )
+        else:
+            raise ValueError(
+                f"Invalid update_weights_method: {self._update_weights_method}"
+            )
+
+
+class ModuleTransform(Transform, metaclass=_RayServiceMetaClass):
+    """A transform that wraps a module.
+
+    Keyword Args:
+        module (TensorDictModuleBase): The module to wrap. Exclusive with `module_factory`. At least one of `module` or `module_factory` must be provided.
+        module_factory (Callable[[], TensorDictModuleBase]): The factory to create the module. Exclusive with `module`. At least one of `module` or `module_factory` must be provided.
+        no_grad (bool, optional): Whether to disable gradient computation. Default is `False`.
+        inverse (bool, optional): Whether to use the inverse of the module. Default is `False`.
+        device (torch.device, optional): The device to use. Default is `None`.
+        use_ray_service (bool, optional): Whether to use Ray service. Default is `False`.
+        actor_name (str, optional): The name of the actor to use. Default is `None`. If an actor name is provided and
+            an actor with this name already exists, the existing actor will be used.
+
+    """
+
+    _RayServiceClass = RayModuleTransform
+
+    def __init__(
+        self,
+        *,
+        module: TensorDictModuleBase | None = None,
+        module_factory: Callable[[], TensorDictModuleBase] | None = None,
+        no_grad: bool = False,
+        inverse: bool = False,
+        device: torch.device | None = None,
+        use_ray_service: bool = False,
+        actor_name: str | None = None,
+    ):
+        super().__init__()
+        if module is None and module_factory is None:
+            raise ValueError(
+                "At least one of `module` or `module_factory` must be provided."
+            )
+        if module is not None and module_factory is not None:
+            raise ValueError(
+                "Only one of `module` or `module_factory` must be provided."
+            )
+        self.module = module if module is not None else module_factory()
+        self.no_grad = no_grad
+        self.inverse = inverse
+        self.device = device
+
+    @property
+    def in_keys(self) -> list[str]:
+        return self._in_keys()
+
+    def _in_keys(self):
+        return self.module.in_keys if not self.inverse else []
+
+    @in_keys.setter
+    def in_keys(self, value: list[str] | None):
+        if value is not None:
+            raise RuntimeError(f"in_keys {value} cannot be set for ModuleTransform")
+
+    @property
+    def out_keys(self) -> list[str]:
+        return self._out_keys()
+
+    def _out_keys(self):
+        return self.module.out_keys if not self.inverse else []
+
+    @property
+    def in_keys_inv(self) -> list[str]:
+        return self._in_keys_inv()
+
+    def _in_keys_inv(self):
+        return self.module.out_keys if self.inverse else []
+
+    @in_keys_inv.setter
+    def in_keys_inv(self, value: list[str]):
+        if value is not None:
+            raise RuntimeError(f"in_keys_inv {value} cannot be set for ModuleTransform")
+
+    @property
+    def out_keys_inv(self) -> list[str]:
+        return self._out_keys_inv()
+
+    def _out_keys_inv(self):
+        return self.module.in_keys if self.inverse else []
+
+    @out_keys_inv.setter
+    def out_keys_inv(self, value: list[str] | None):
+        if value is not None:
+            raise RuntimeError(
+                f"out_keys_inv {value} cannot be set for ModuleTransform"
+            )
+
+    @out_keys.setter
+    def out_keys(self, value: list[str] | None):
+        if value is not None:
+            raise RuntimeError(f"out_keys {value} cannot be set for ModuleTransform")
+
+    def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
+        if self.inverse:
+            return tensordict
+        with torch.no_grad() if self.no_grad else nullcontext():
+            with (
+                tensordict.to(self.device)
+                if self.device is not None
+                else nullcontext(tensordict)
+            ) as td:
+                return self.module(td)
+
+    def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
+        if not self.inverse:
+            return tensordict
+        with torch.no_grad() if self.no_grad else nullcontext():
+            with (
+                tensordict.to(self.device)
+                if self.device is not None
+                else nullcontext(tensordict)
+            ) as td:
+                return self.module(td)
+
+    def _update_weights_tensordict(self, params: TensorDictBase) -> None:
+        params.to_module(self.module)
+
+    def _update_weights_state_dict(self, state_dict: dict[str, torch.Tensor]) -> None:
+        self.module.load_state_dict(state_dict)
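
The new ModuleTransform forwards a tensordict through the wrapped module in _call (or _inv_call when inverse=True), optionally under torch.no_grad() and on a chosen device, and exposes _update_weights_tensordict / _update_weights_state_dict so the Ray variant can refresh the module's parameters remotely. A minimal usage sketch, not part of the commit; the environment name, keys, and rescaling module below are illustrative assumptions (and GymEnv needs gymnasium installed):

    import torch
    from tensordict.nn import TensorDictModule
    from torchrl.envs import GymEnv, TransformedEnv
    from torchrl.envs.transforms import ModuleTransform

    # Hypothetical post-processing module: rescale the observation entry.
    scaler = TensorDictModule(
        lambda obs: obs / 10.0,
        in_keys=["observation"],
        out_keys=["observation"],
    )

    env = TransformedEnv(
        GymEnv("CartPole-v1"),
        ModuleTransform(module=scaler, no_grad=True),
    )
    rollout = env.rollout(3)  # each step's output passes through the wrapped module

Passing use_ray_service=True is wired, via _RayServiceMetaClass and the _RayServiceClass attribute, to hand construction over to RayModuleTransform, so the same module runs inside a Ray actor and its weights can be refreshed with update_weights.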
