Skip to content

Commit a089cc4

Browse files
committed
[Feature] track_policy_version in collectors.py
ghstack-source-id: 91a6b1f Pull-Request: #3170
1 parent e7583b3 commit a089cc4

File tree

1 file changed

+123
-0
lines changed

1 file changed

+123
-0
lines changed

torchrl/collectors/collectors.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@
6969
set_exploration_type,
7070
)
7171

72+
from torchrl.envs.llm.transforms.policy_version import PolicyVersion
73+
7274
try:
7375
from torch.compiler import cudagraph_mark_step_begin
7476
except ImportError:
@@ -571,6 +573,11 @@ class SyncDataCollector(DataCollectorBase):
571573
or its subclass, responsible for updating the policy weights on remote inference workers.
572574
This is typically not used in :class:`~torchrl.collectors.SyncDataCollector` as it operates in a single-process environment.
573575
Consider using a constructor if the updater needs to be serialized.
576+
track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy.
577+
This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment.
578+
Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track
579+
the policy version.
580+
Defaults to `False`.
574581
575582
Examples:
576583
>>> from torchrl.envs.libs.gym import GymEnv
@@ -665,6 +672,7 @@ def __init__(
665672
weight_updater: WeightUpdaterBase
666673
| Callable[[], WeightUpdaterBase]
667674
| None = None,
675+
track_policy_version: bool = False,
668676
**kwargs,
669677
):
670678
from torchrl.envs.batched_envs import BatchedEnvBase
@@ -783,6 +791,33 @@ def __init__(
783791

784792
self.env: EnvBase = env
785793
del env
794+
795+
# Policy version tracking setup
796+
self.policy_version_tracker = track_policy_version
797+
if PolicyVersion is not None:
798+
if isinstance(track_policy_version, bool) and track_policy_version:
799+
from torchrl.envs.batched_envs import BatchedEnvBase
800+
801+
if isinstance(self.env, BatchedEnvBase):
802+
raise RuntimeError(
803+
"BatchedEnvBase is not supported for policy version tracking. Please add the PolicyVersion transform to the environment manually, "
804+
"and pass that transform to the collector."
805+
)
806+
self.policy_version_tracker = PolicyVersion()
807+
self.env = self.env.append_transform(self.policy_version_tracker) # type: ignore
808+
elif hasattr(
809+
track_policy_version, "increment_version"
810+
): # Check if it's a PolicyVersion instance
811+
self.policy_version_tracker = track_policy_version
812+
self.env = self.env.append_transform(self.policy_version_tracker) # type: ignore
813+
else:
814+
self.policy_version_tracker = None
815+
else:
816+
if track_policy_version:
817+
raise ImportError(
818+
"PolicyVersion is not available. Please install the LLM dependencies or set track_policy_version=False."
819+
)
820+
self.policy_version_tracker = None
786821
self.replay_buffer = replay_buffer
787822
self.extend_buffer = extend_buffer
788823
if self.replay_buffer is not None:
@@ -1755,6 +1790,34 @@ def __repr__(self) -> str:
17551790
except Exception:
17561791
return f"{type(self).__name__}(not_init)"
17571792

1793+
def increment_version(self):
1794+
"""Increment the policy version."""
1795+
if self.policy_version_tracker is not None:
1796+
if not hasattr(self.policy_version_tracker, "increment_version"):
1797+
raise RuntimeError(
1798+
"Policy version tracker is not a PolicyVersion instance. Please pass a PolicyVersion instance to the collector."
1799+
)
1800+
self.policy_version_tracker.increment_version()
1801+
1802+
@property
1803+
def policy_version(self) -> str | int | None:
1804+
"""The current policy version."""
1805+
if not hasattr(self.policy_version_tracker, "version"):
1806+
return None
1807+
return self.policy_version_tracker.version
1808+
1809+
def get_policy_version(self) -> str | int | None:
1810+
"""Get the current policy version.
1811+
1812+
This method exists to support remote calls in Ray actors, since properties
1813+
cannot be accessed directly through Ray's RPC mechanism.
1814+
1815+
Returns:
1816+
The current version number (int) or UUID (str), or None if version tracking is disabled.
1817+
"""
1818+
return self.policy_version
1819+
1820+
17581821

17591822
class _MultiDataCollector(DataCollectorBase):
17601823
"""Runs a given number of DataCollectors on separate processes.
@@ -1944,6 +2007,11 @@ class _MultiDataCollector(DataCollectorBase):
19442007
If not provided, a :class:`~torchrl.collectors.MultiProcessedWeightUpdater` will be used by default,
19452008
which handles weight synchronization across multiple processes.
19462009
Consider using a constructor if the updater needs to be serialized.
2010+
track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy.
2011+
This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment.
2012+
Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track
2013+
the policy version.
2014+
Defaults to `False`.
19472015
19482016
"""
19492017

@@ -1989,6 +2057,7 @@ def __init__(
19892057
weight_updater: WeightUpdaterBase
19902058
| Callable[[], WeightUpdaterBase]
19912059
| None = None,
2060+
track_policy_version: bool = False,
19922061
):
19932062
self.closed = True
19942063
if isinstance(create_env_fn, Sequence):
@@ -2125,6 +2194,24 @@ def __init__(
21252194

21262195
self.weight_updater = weight_updater
21272196

2197+
# Policy version tracking setup
2198+
self.policy_version_tracker = track_policy_version
2199+
if PolicyVersion is not None:
2200+
if isinstance(track_policy_version, bool) and track_policy_version:
2201+
self.policy_version_tracker = PolicyVersion()
2202+
elif hasattr(
2203+
track_policy_version, "increment_version"
2204+
): # Check if it's a PolicyVersion instance
2205+
self.policy_version_tracker = track_policy_version
2206+
else:
2207+
self.policy_version_tracker = None
2208+
else:
2209+
if track_policy_version:
2210+
raise ImportError(
2211+
"PolicyVersion is not available. Please install the LLM dependencies or set track_policy_version=False."
2212+
)
2213+
self.policy_version_tracker = None
2214+
21282215
self.policy = policy
21292216
self.policy_factory = policy_factory
21302217

@@ -2668,6 +2755,34 @@ def load_state_dict(self, state_dict: OrderedDict) -> None:
26682755
self._frames = state_dict["frames"]
26692756
self._iter = state_dict["iter"]
26702757

2758+
def increment_version(self):
2759+
"""Increment the policy version."""
2760+
if self.policy_version_tracker is not None:
2761+
if not hasattr(self.policy_version_tracker, "increment_version"):
2762+
raise RuntimeError(
2763+
"Policy version tracker is not a PolicyVersion instance. Please pass a PolicyVersion instance to the collector."
2764+
)
2765+
self.policy_version_tracker.increment_version()
2766+
2767+
@property
2768+
def policy_version(self) -> str | int | None:
2769+
"""The current policy version."""
2770+
if not hasattr(self.policy_version_tracker, "version"):
2771+
return None
2772+
return self.policy_version_tracker.version
2773+
2774+
def get_policy_version(self) -> str | int | None:
2775+
"""Get the current policy version.
2776+
2777+
This method exists to support remote calls in Ray actors, since properties
2778+
cannot be accessed directly through Ray's RPC mechanism.
2779+
2780+
Returns:
2781+
The current version number (int) or UUID (str), or None if version tracking is disabled.
2782+
"""
2783+
return self.policy_version
2784+
2785+
26712786

26722787
@accept_remote_rref_udf_invocation
26732788
class MultiSyncDataCollector(_MultiDataCollector):
@@ -3473,6 +3588,11 @@ class aSyncDataCollector(MultiaSyncDataCollector):
34733588
a rollout is reached. If no ``"truncated"`` key is found, an exception is raised.
34743589
Truncated keys can be set through ``env.add_truncated_keys``.
34753590
Defaults to ``False``.
3591+
track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy.
3592+
This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment.
3593+
Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track
3594+
the policy version.
3595+
Defaults to `False`.
34763596
34773597
"""
34783598

@@ -3502,6 +3622,7 @@ def __init__(
35023622
num_threads: int | None = None,
35033623
num_sub_threads: int = 1,
35043624
set_truncated: bool = False,
3625+
track_policy_version: bool = False,
35053626
**kwargs,
35063627
):
35073628
super().__init__(
@@ -3529,6 +3650,7 @@ def __init__(
35293650
num_threads=num_threads,
35303651
num_sub_threads=num_sub_threads,
35313652
set_truncated=set_truncated,
3653+
track_policy_version=track_policy_version,
35323654
**kwargs,
35333655
)
35343656

@@ -3825,6 +3947,7 @@ def cast_tensor(x, MPS_ERROR=MPS_ERROR):
38253947
has_timed_out = False
38263948
continue
38273949

3950+
38283951
elif msg == "close":
38293952
del collected_tensordict, data, next_data, data_in
38303953
inner_collector.shutdown()

0 commit comments

Comments
 (0)