Skip to content

Commit fbdbb61

Browse files
committed
[Feature] Collectors' getattr_policy and getattr_env
ghstack-source-id: bd63914 Pull-Request: #3171
1 parent a089cc4 commit fbdbb61

File tree

1 file changed

+77
-0
lines changed

1 file changed

+77
-0
lines changed

torchrl/collectors/collectors.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1817,6 +1817,13 @@ def get_policy_version(self) -> str | int | None:
18171817
"""
18181818
return self.policy_version
18191819

1820+
def getattr_policy(self, attr):
1821+
# send command to policy to return the attr
1822+
return getattr(self.policy, attr)
1823+
1824+
def getattr_env(self, attr):
1825+
# send command to env to return the attr
1826+
return getattr(self.env, attr)
18201827

18211828

18221829
class _MultiDataCollector(DataCollectorBase):
@@ -2782,6 +2789,57 @@ def get_policy_version(self) -> str | int | None:
27822789
"""
27832790
return self.policy_version
27842791

2792+
def getattr_policy(self, attr):
2793+
"""Get an attribute from the policy of the first worker.
2794+
2795+
Args:
2796+
attr (str): The attribute name to retrieve from the policy.
2797+
2798+
Returns:
2799+
The attribute value from the policy of the first worker.
2800+
2801+
Raises:
2802+
AttributeError: If the attribute doesn't exist on the policy.
2803+
"""
2804+
_check_for_faulty_process(self.procs)
2805+
2806+
# Send command to first worker (index 0)
2807+
self.pipes[0].send((attr, "getattr_policy"))
2808+
result, msg = self.pipes[0].recv()
2809+
if msg != "getattr_policy":
2810+
raise RuntimeError(f"Expected msg='getattr_policy', got {msg}")
2811+
2812+
# If the worker returned an AttributeError, re-raise it
2813+
if isinstance(result, AttributeError):
2814+
raise result
2815+
2816+
return result
2817+
2818+
def getattr_env(self, attr):
2819+
"""Get an attribute from the environment of the first worker.
2820+
2821+
Args:
2822+
attr (str): The attribute name to retrieve from the environment.
2823+
2824+
Returns:
2825+
The attribute value from the environment of the first worker.
2826+
2827+
Raises:
2828+
AttributeError: If the attribute doesn't exist on the environment.
2829+
"""
2830+
_check_for_faulty_process(self.procs)
2831+
2832+
# Send command to first worker (index 0)
2833+
self.pipes[0].send((attr, "getattr_env"))
2834+
result, msg = self.pipes[0].recv()
2835+
if msg != "getattr_env":
2836+
raise RuntimeError(f"Expected msg='getattr_env', got {msg}")
2837+
2838+
# If the worker returned an AttributeError, re-raise it
2839+
if isinstance(result, AttributeError):
2840+
raise result
2841+
2842+
return result
27852843

27862844

27872845
@accept_remote_rref_udf_invocation
@@ -3947,6 +4005,25 @@ def cast_tensor(x, MPS_ERROR=MPS_ERROR):
39474005
has_timed_out = False
39484006
continue
39494007

4008+
elif msg == "getattr_policy":
4009+
attr_name = data_in
4010+
try:
4011+
result = getattr(inner_collector.policy, attr_name)
4012+
pipe_child.send((result, "getattr_policy"))
4013+
except AttributeError as e:
4014+
pipe_child.send((e, "getattr_policy"))
4015+
has_timed_out = False
4016+
continue
4017+
4018+
elif msg == "getattr_env":
4019+
attr_name = data_in
4020+
try:
4021+
result = getattr(inner_collector.env, attr_name)
4022+
pipe_child.send((result, "getattr_env"))
4023+
except AttributeError as e:
4024+
pipe_child.send((e, "getattr_env"))
4025+
has_timed_out = False
4026+
continue
39504027

39514028
elif msg == "close":
39524029
del collected_tensordict, data, next_data, data_in

0 commit comments

Comments
 (0)