@@ -795,6 +795,28 @@ def input_spec(self) -> TensorSpec:
795795 input_spec = self .__dict__ .get ("_input_spec" , None )
796796 return input_spec
797797
798+ def rand_action (self , tensordict : Optional [TensorDictBase ] = None ) -> TensorDict :
799+ if type (self .base_env ).rand_action is not EnvBase .rand_action :
800+ # TODO: this will fail if the transform modifies the input.
801+ # For instance, if an env overrides rand_action and we build a
802+ # env = PendulumEnv().append_transform(ActionDiscretizer(num_intervals=4))
803+ # env.rand_action will NOT have a discrete action!
804+ # Getting a discrete action would require coding the inverse transform of an action within
805+ # ActionDiscretizer (ie, float->int, not int->float).
806+ # We can loosely check that the action_spec isn't altered - that doesn't mean the action is
807+ # intact but it covers part of these alterations.
808+ #
809+ # The following check may be expensive to run and could be cached.
810+ if self .full_action_spec != self .base_env .full_action_spec :
811+ raise RuntimeError (
812+ f"The rand_action method from the base env { self .base_env .__class__ .__name__ } "
813+ "has been overwritten, but the transforms appended to the environment modify "
814+ "the action. To call the base env rand_action method, we should then invert the "
815+ "action transform, which is (in general) not doable."
816+ )
817+ return self .base_env .rand_action (tensordict )
818+ return super ().rand_action (tensordict )
819+
798820 def _step (self , tensordict : TensorDictBase ) -> TensorDictBase :
799821 # No need to clone here because inv does it already
800822 # tensordict = tensordict.clone(False)
0 commit comments