@@ -1817,6 +1817,13 @@ def get_policy_version(self) -> str | int | None:
18171817 """
18181818 return self .policy_version
18191819
1820+ def getattr_policy (self , attr ):
1821+ # send command to policy to return the attr
1822+ return getattr (self .policy , attr )
1823+
1824+ def getattr_env (self , attr ):
1825+ # send command to env to return the attr
1826+ return getattr (self .env , attr )
18201827
18211828
18221829class _MultiDataCollector (DataCollectorBase ):
@@ -2782,6 +2789,57 @@ def get_policy_version(self) -> str | int | None:
27822789 """
27832790 return self .policy_version
27842791
2792+ def getattr_policy (self , attr ):
2793+ """Get an attribute from the policy of the first worker.
2794+
2795+ Args:
2796+ attr (str): The attribute name to retrieve from the policy.
2797+
2798+ Returns:
2799+ The attribute value from the policy of the first worker.
2800+
2801+ Raises:
2802+ AttributeError: If the attribute doesn't exist on the policy.
2803+ """
2804+ _check_for_faulty_process (self .procs )
2805+
2806+ # Send command to first worker (index 0)
2807+ self .pipes [0 ].send ((attr , "getattr_policy" ))
2808+ result , msg = self .pipes [0 ].recv ()
2809+ if msg != "getattr_policy" :
2810+ raise RuntimeError (f"Expected msg='getattr_policy', got { msg } " )
2811+
2812+ # If the worker returned an AttributeError, re-raise it
2813+ if isinstance (result , AttributeError ):
2814+ raise result
2815+
2816+ return result
2817+
2818+ def getattr_env (self , attr ):
2819+ """Get an attribute from the environment of the first worker.
2820+
2821+ Args:
2822+ attr (str): The attribute name to retrieve from the environment.
2823+
2824+ Returns:
2825+ The attribute value from the environment of the first worker.
2826+
2827+ Raises:
2828+ AttributeError: If the attribute doesn't exist on the environment.
2829+ """
2830+ _check_for_faulty_process (self .procs )
2831+
2832+ # Send command to first worker (index 0)
2833+ self .pipes [0 ].send ((attr , "getattr_env" ))
2834+ result , msg = self .pipes [0 ].recv ()
2835+ if msg != "getattr_env" :
2836+ raise RuntimeError (f"Expected msg='getattr_env', got { msg } " )
2837+
2838+ # If the worker returned an AttributeError, re-raise it
2839+ if isinstance (result , AttributeError ):
2840+ raise result
2841+
2842+ return result
27852843
27862844
27872845@accept_remote_rref_udf_invocation
@@ -3947,6 +4005,25 @@ def cast_tensor(x, MPS_ERROR=MPS_ERROR):
39474005 has_timed_out = False
39484006 continue
39494007
4008+ elif msg == "getattr_policy" :
4009+ attr_name = data_in
4010+ try :
4011+ result = getattr (inner_collector .policy , attr_name )
4012+ pipe_child .send ((result , "getattr_policy" ))
4013+ except AttributeError as e :
4014+ pipe_child .send ((e , "getattr_policy" ))
4015+ has_timed_out = False
4016+ continue
4017+
4018+ elif msg == "getattr_env" :
4019+ attr_name = data_in
4020+ try :
4021+ result = getattr (inner_collector .env , attr_name )
4022+ pipe_child .send ((result , "getattr_env" ))
4023+ except AttributeError as e :
4024+ pipe_child .send ((e , "getattr_env" ))
4025+ has_timed_out = False
4026+ continue
39504027
39514028 elif msg == "close" :
39524029 del collected_tensordict , data , next_data , data_in
0 commit comments