|
69 | 69 | set_exploration_type, |
70 | 70 | ) |
71 | 71 |
|
| 72 | +from torchrl.envs.llm.transforms.policy_version import PolicyVersion |
| 73 | + |
72 | 74 | try: |
73 | 75 | from torch.compiler import cudagraph_mark_step_begin |
74 | 76 | except ImportError: |
@@ -571,6 +573,11 @@ class SyncDataCollector(DataCollectorBase): |
571 | 573 | or its subclass, responsible for updating the policy weights on remote inference workers. |
572 | 574 | This is typically not used in :class:`~torchrl.collectors.SyncDataCollector` as it operates in a single-process environment. |
573 | 575 | Consider using a constructor if the updater needs to be serialized. |
| 576 | + track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy. |
| 577 | + This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment. |
| 578 | + Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track |
| 579 | + the policy version. |
| 580 | + Defaults to ``False``. |
574 | 581 |
|
575 | 582 | Examples: |
576 | 583 | >>> from torchrl.envs.libs.gym import GymEnv |
@@ -665,6 +672,7 @@ def __init__( |
665 | 672 | weight_updater: WeightUpdaterBase |
666 | 673 | | Callable[[], WeightUpdaterBase] |
667 | 674 | | None = None, |
| 675 | + track_policy_version: bool = False, |
668 | 676 | **kwargs, |
669 | 677 | ): |
670 | 678 | from torchrl.envs.batched_envs import BatchedEnvBase |
@@ -783,6 +791,33 @@ def __init__( |
783 | 791 |
|
784 | 792 | self.env: EnvBase = env |
785 | 793 | del env |
| 794 | + |
| 795 | + # Policy version tracking setup |
| 796 | + self.policy_version_tracker = track_policy_version |
| 797 | + if PolicyVersion is not None: |
| 798 | + if isinstance(track_policy_version, bool) and track_policy_version: |
| 799 | + from torchrl.envs.batched_envs import BatchedEnvBase |
| 800 | + |
| 801 | + if isinstance(self.env, BatchedEnvBase): |
| 802 | + raise RuntimeError( |
| 803 | + "BatchedEnvBase is not supported for policy version tracking. Please add the PolicyVersion transform to the environment manually, " |
| 804 | + "and pass that transform to the collector." |
| 805 | + ) |
| 806 | + self.policy_version_tracker = PolicyVersion() |
| 807 | + self.env = self.env.append_transform(self.policy_version_tracker) # type: ignore |
| 808 | + elif hasattr( |
| 809 | + track_policy_version, "increment_version" |
| 810 | + ): # Check if it's a PolicyVersion instance |
| 811 | + self.policy_version_tracker = track_policy_version |
| 812 | + self.env = self.env.append_transform(self.policy_version_tracker) # type: ignore |
| 813 | + else: |
| 814 | + self.policy_version_tracker = None |
| 815 | + else: |
| 816 | + if track_policy_version: |
| 817 | + raise ImportError( |
| 818 | + "PolicyVersion is not available. Please install the LLM dependencies or set track_policy_version=False." |
| 819 | + ) |
| 820 | + self.policy_version_tracker = None |
786 | 821 | self.replay_buffer = replay_buffer |
787 | 822 | self.extend_buffer = extend_buffer |
788 | 823 | if self.replay_buffer is not None: |
@@ -1755,6 +1790,34 @@ def __repr__(self) -> str: |
1755 | 1790 | except Exception: |
1756 | 1791 | return f"{type(self).__name__}(not_init)" |
1757 | 1792 |
|
| 1793 | + def increment_version(self): |
| 1794 | + """Increment the policy version.""" |
| 1795 | + if self.policy_version_tracker is not None: |
| 1796 | + if not hasattr(self.policy_version_tracker, "increment_version"): |
| 1797 | + raise RuntimeError( |
| 1798 | + "Policy version tracker is not a PolicyVersion instance. Please pass a PolicyVersion instance to the collector." |
| 1799 | + ) |
| 1800 | + self.policy_version_tracker.increment_version() |
| 1801 | + |
| 1802 | + @property |
| 1803 | + def policy_version(self) -> str | int | None: |
| 1804 | + """The current policy version.""" |
| 1805 | + if not hasattr(self.policy_version_tracker, "version"): |
| 1806 | + return None |
| 1807 | + return self.policy_version_tracker.version |
| 1808 | + |
| 1809 | + def get_policy_version(self) -> str | int | None: |
| 1810 | + """Get the current policy version. |
| 1811 | +
|
| 1812 | + This method exists to support remote calls in Ray actors, since properties |
| 1813 | + cannot be accessed directly through Ray's RPC mechanism. |
| 1814 | +
|
| 1815 | + Returns: |
| 1816 | + The current version number (int) or UUID (str), or None if version tracking is disabled. |
| 1817 | + """ |
| 1818 | + return self.policy_version |
| 1819 | + |
| 1820 | + |
1758 | 1821 |
|
1759 | 1822 | class _MultiDataCollector(DataCollectorBase): |
1760 | 1823 | """Runs a given number of DataCollectors on separate processes. |
@@ -1944,6 +2007,11 @@ class _MultiDataCollector(DataCollectorBase): |
1944 | 2007 | If not provided, a :class:`~torchrl.collectors.MultiProcessedWeightUpdater` will be used by default, |
1945 | 2008 | which handles weight synchronization across multiple processes. |
1946 | 2009 | Consider using a constructor if the updater needs to be serialized. |
| 2010 | + track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy. |
| 2011 | + This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment. |
| 2012 | + Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track |
| 2013 | + the policy version. |
| 2014 | + Defaults to ``False``. |
1947 | 2015 |
|
1948 | 2016 | """ |
1949 | 2017 |
|
@@ -1989,6 +2057,7 @@ def __init__( |
1989 | 2057 | weight_updater: WeightUpdaterBase |
1990 | 2058 | | Callable[[], WeightUpdaterBase] |
1991 | 2059 | | None = None, |
| 2060 | + track_policy_version: bool = False, |
1992 | 2061 | ): |
1993 | 2062 | self.closed = True |
1994 | 2063 | if isinstance(create_env_fn, Sequence): |
@@ -2125,6 +2194,24 @@ def __init__( |
2125 | 2194 |
|
2126 | 2195 | self.weight_updater = weight_updater |
2127 | 2196 |
|
| 2197 | + # Policy version tracking setup |
| 2198 | + self.policy_version_tracker = track_policy_version |
| 2199 | + if PolicyVersion is not None: |
| 2200 | + if isinstance(track_policy_version, bool) and track_policy_version: |
| 2201 | + self.policy_version_tracker = PolicyVersion() |
| 2202 | + elif hasattr( |
| 2203 | + track_policy_version, "increment_version" |
| 2204 | + ): # Check if it's a PolicyVersion instance |
| 2205 | + self.policy_version_tracker = track_policy_version |
| 2206 | + else: |
| 2207 | + self.policy_version_tracker = None |
| 2208 | + else: |
| 2209 | + if track_policy_version: |
| 2210 | + raise ImportError( |
| 2211 | + "PolicyVersion is not available. Please install the LLM dependencies or set track_policy_version=False." |
| 2212 | + ) |
| 2213 | + self.policy_version_tracker = None |
| 2214 | + |
2128 | 2215 | self.policy = policy |
2129 | 2216 | self.policy_factory = policy_factory |
2130 | 2217 |
|
@@ -2668,6 +2755,34 @@ def load_state_dict(self, state_dict: OrderedDict) -> None: |
2668 | 2755 | self._frames = state_dict["frames"] |
2669 | 2756 | self._iter = state_dict["iter"] |
2670 | 2757 |
|
| 2758 | + def increment_version(self): |
| 2759 | + """Increment the policy version.""" |
| 2760 | + if self.policy_version_tracker is not None: |
| 2761 | + if not hasattr(self.policy_version_tracker, "increment_version"): |
| 2762 | + raise RuntimeError( |
| 2763 | + "Policy version tracker is not a PolicyVersion instance. Please pass a PolicyVersion instance to the collector." |
| 2764 | + ) |
| 2765 | + self.policy_version_tracker.increment_version() |
| 2766 | + |
| 2767 | + @property |
| 2768 | + def policy_version(self) -> str | int | None: |
| 2769 | + """The current policy version.""" |
| 2770 | + if not hasattr(self.policy_version_tracker, "version"): |
| 2771 | + return None |
| 2772 | + return self.policy_version_tracker.version |
| 2773 | + |
| 2774 | + def get_policy_version(self) -> str | int | None: |
| 2775 | + """Get the current policy version. |
| 2776 | +
|
| 2777 | + This method exists to support remote calls in Ray actors, since properties |
| 2778 | + cannot be accessed directly through Ray's RPC mechanism. |
| 2779 | +
|
| 2780 | + Returns: |
| 2781 | + The current version number (int) or UUID (str), or None if version tracking is disabled. |
| 2782 | + """ |
| 2783 | + return self.policy_version |
| 2784 | + |
| 2785 | + |
2671 | 2786 |
|
2672 | 2787 | @accept_remote_rref_udf_invocation |
2673 | 2788 | class MultiSyncDataCollector(_MultiDataCollector): |
@@ -3473,6 +3588,11 @@ class aSyncDataCollector(MultiaSyncDataCollector): |
3473 | 3588 | a rollout is reached. If no ``"truncated"`` key is found, an exception is raised. |
3474 | 3589 | Truncated keys can be set through ``env.add_truncated_keys``. |
3475 | 3590 | Defaults to ``False``. |
| 3591 | + track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy. |
| 3592 | + This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment. |
| 3593 | + Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track |
| 3594 | + the policy version. |
| 3595 | + Defaults to ``False``. |
3476 | 3596 |
|
3477 | 3597 | """ |
3478 | 3598 |
|
@@ -3502,6 +3622,7 @@ def __init__( |
3502 | 3622 | num_threads: int | None = None, |
3503 | 3623 | num_sub_threads: int = 1, |
3504 | 3624 | set_truncated: bool = False, |
| 3625 | + track_policy_version: bool = False, |
3505 | 3626 | **kwargs, |
3506 | 3627 | ): |
3507 | 3628 | super().__init__( |
@@ -3529,6 +3650,7 @@ def __init__( |
3529 | 3650 | num_threads=num_threads, |
3530 | 3651 | num_sub_threads=num_sub_threads, |
3531 | 3652 | set_truncated=set_truncated, |
| 3653 | + track_policy_version=track_policy_version, |
3532 | 3654 | **kwargs, |
3533 | 3655 | ) |
3534 | 3656 |
|
@@ -3825,6 +3947,7 @@ def cast_tensor(x, MPS_ERROR=MPS_ERROR): |
3825 | 3947 | has_timed_out = False |
3826 | 3948 | continue |
3827 | 3949 |
|
| 3950 | + |
3828 | 3951 | elif msg == "close": |
3829 | 3952 | del collected_tensordict, data, next_data, data_in |
3830 | 3953 | inner_collector.shutdown() |
|
0 commit comments