 import warnings
 from copy import deepcopy
 from functools import partial, wraps
-from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
+from typing import Any, Callable, Iterator

 import numpy as np
 import torch
@@ -476,7 +476,7 @@ def __init__(
         self,
         *,
         device: DEVICE_TYPING = None,
-        batch_size: Optional[torch.Size] = None,
+        batch_size: torch.Size | None = None,
         run_type_checks: bool = False,
         allow_done_after_reset: bool = False,
         spec_locked: bool = True,
@@ -587,10 +587,10 @@ def auto_specs_(
         policy: Callable[[TensorDictBase], TensorDictBase],
         *,
         tensordict: TensorDictBase | None = None,
-        action_key: NestedKey | List[NestedKey] = "action",
-        done_key: NestedKey | List[NestedKey] | None = None,
-        observation_key: NestedKey | List[NestedKey] = "observation",
-        reward_key: NestedKey | List[NestedKey] = "reward",
+        action_key: NestedKey | list[NestedKey] = "action",
+        done_key: NestedKey | list[NestedKey] | None = None,
+        observation_key: NestedKey | list[NestedKey] = "observation",
+        reward_key: NestedKey | list[NestedKey] = "reward",
     ):
         """Automatically sets the specifications (specs) of the environment based on a random rollout using a given policy.

@@ -692,7 +692,7 @@ def auto_specs_(
         if full_action_spec is not None:
             self.full_action_spec = full_action_spec
         if full_done_spec is not None:
-            self.full_done_specs = full_done_spec
+            self.full_done_spec = full_done_spec
         if full_observation_spec is not None:
             self.full_observation_spec = full_observation_spec
         if full_reward_spec is not None:
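For orientation, a minimal sketch of how auto_specs_ is typically invoked, assuming a Gymnasium backend and a TensorDict-compatible policy; the key names simply mirror the defaults in the signature shown above, and in practice the method is most useful on a custom env whose specs are not yet defined:

import torch.nn as nn
from tensordict.nn import TensorDictModule
from torchrl.envs import GymEnv  # assumes gymnasium is installed

env = GymEnv("Pendulum-v1")
# Stand-in policy reading "observation" and writing "action" (sizes are illustrative).
policy = TensorDictModule(nn.LazyLinear(1), in_keys=["observation"], out_keys=["action"])
env.auto_specs_(policy, observation_key="observation", action_key="action")
env.check_env_specs()  # sanity-check the inferred specs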
@@ -704,8 +704,7 @@ def auto_specs_(

     @wraps(check_env_specs_func)
     def check_env_specs(self, *args, **kwargs):
-        return_contiguous = kwargs.pop("return_contiguous", not self._has_dynamic_specs)
-        kwargs["return_contiguous"] = return_contiguous
+        kwargs.setdefault("return_contiguous", not self._has_dynamic_specs)
         return check_env_specs_func(self, *args, **kwargs)

     check_env_specs.__doc__ = check_env_specs_func.__doc__
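The setdefault rewrite keeps the behaviour of the two popped-and-reassigned lines it replaces: an explicit return_contiguous passed by the caller still wins, and environments with dynamic specs fall back to False. A hedged usage sketch, assuming a Gymnasium backend:

from torchrl.envs import GymEnv  # assumes gymnasium is installed

env = GymEnv("CartPole-v1")
env.check_env_specs()                        # default computed from the env's specs
env.check_env_specs(return_contiguous=True)  # an explicit value is left untouched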
@@ -850,8 +849,7 @@ def ndim(self):

     def append_transform(
         self,
-        transform: "Transform"  # noqa: F821
-        | Callable[[TensorDictBase], TensorDictBase],
+        transform: Transform | Callable[[TensorDictBase], TensorDictBase],  # noqa: F821
     ) -> EnvBase:
         """Returns a transformed environment where the callable/transform passed is applied.

@@ -995,7 +993,7 @@ def output_spec(self, value: TensorSpec) -> None:

     @property
     @_cache_value
-    def action_keys(self) -> List[NestedKey]:
+    def action_keys(self) -> list[NestedKey]:
         """The action keys of an environment.

         By default, there will only be one key named "action".
@@ -1008,7 +1006,7 @@ def action_keys(self) -> List[NestedKey]:

     @property
     @_cache_value
-    def state_keys(self) -> List[NestedKey]:
+    def state_keys(self) -> list[NestedKey]:
         """The state keys of an environment.

         By default, there will only be one key named "state".
@@ -1205,7 +1203,7 @@ def full_action_spec(self, spec: Composite) -> None:
     # Reward spec
     @property
     @_cache_value
-    def reward_keys(self) -> List[NestedKey]:
+    def reward_keys(self) -> list[NestedKey]:
         """The reward keys of an environment.

         By default, there will only be one key named "reward".
@@ -1217,7 +1215,7 @@ def reward_keys(self) -> List[NestedKey]:

     @property
     @_cache_value
-    def observation_keys(self) -> List[NestedKey]:
+    def observation_keys(self) -> list[NestedKey]:
         """The observation keys of an environment.

         By default, there will only be one key named "observation".
@@ -1416,7 +1414,7 @@ def full_reward_spec(self, spec: Composite) -> None:
     # done spec
     @property
     @_cache_value
-    def done_keys(self) -> List[NestedKey]:
+    def done_keys(self) -> list[NestedKey]:
         """The done keys of an environment.

         By default, there will only be one key named "done".
@@ -2205,8 +2203,8 @@ def register_gym(
         id: str,
         *,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         backend: str = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
@@ -2395,8 +2393,8 @@ def _register_gym(
         cls,
         id,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
         nondeterministic: bool = False,
@@ -2437,8 +2435,8 @@ def _register_gym( # noqa: F811
         cls,
         id,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
         nondeterministic: bool = False,
@@ -2485,8 +2483,8 @@ def _register_gym( # noqa: F811
         cls,
         id,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
         nondeterministic: bool = False,
@@ -2538,8 +2536,8 @@ def _register_gym( # noqa: F811
         cls,
         id,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
         nondeterministic: bool = False,
@@ -2594,8 +2592,8 @@ def _register_gym( # noqa: F811
         cls,
         id,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
         nondeterministic: bool = False,
@@ -2652,8 +2650,8 @@ def _register_gym( # noqa: F811
         cls,
         id,
         entry_point: Callable | None = None,
-        transform: "Transform" | None = None,  # noqa: F821
-        info_keys: List[NestedKey] | None = None,
+        transform: Transform | None = None,  # noqa: F821
+        info_keys: list[NestedKey] | None = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
         nondeterministic: bool = False,
@@ -2710,7 +2708,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:

     def reset(
         self,
-        tensordict: Optional[TensorDictBase] = None,
+        tensordict: TensorDictBase | None = None,
         **kwargs,
     ) -> TensorDictBase:
         """Resets the environment.
@@ -2819,8 +2817,8 @@ def numel(self) -> int:
         return prod(self.batch_size)

     def set_seed(
-        self, seed: Optional[int] = None, static_seed: bool = False
-    ) -> Optional[int]:
+        self, seed: int | None = None, static_seed: bool = False
+    ) -> int | None:
         """Sets the seed of the environment and returns the next seed to be used (which is the input seed if a single environment is present).

         Args:
@@ -2841,7 +2839,7 @@ def set_seed(
         return seed

     @abc.abstractmethod
-    def _set_seed(self, seed: Optional[int]):
+    def _set_seed(self, seed: int | None):
         raise NotImplementedError

     def set_state(self):
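A small usage sketch of the seeding API touched here; per the docstring above, the returned value is the seed to hand to the next environment (useful in batched setups). Environment names are illustrative:

from torchrl.envs import GymEnv  # assumes gymnasium is installed

env_a = GymEnv("CartPole-v1")
env_b = GymEnv("CartPole-v1")
next_seed = env_a.set_seed(0)  # returns the seed to use for the next environment
env_b.set_seed(next_seed)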
@@ -2856,9 +2854,7 @@ def _assert_tensordict_shape(self, tensordict: TensorDictBase) -> None:
                 f"got {tensordict.batch_size} and {self.batch_size}"
             )

-    def all_actions(
-        self, tensordict: Optional[TensorDictBase] = None
-    ) -> TensorDictBase:
+    def all_actions(self, tensordict: TensorDictBase | None = None) -> TensorDictBase:
         """Generates all possible actions from the action spec.

         This only works in environments with fully discrete actions.
@@ -2877,7 +2873,7 @@ def all_actions(

         return self.full_action_spec.enumerate(use_mask=True)

-    def rand_action(self, tensordict: Optional[TensorDictBase] = None):
+    def rand_action(self, tensordict: TensorDictBase | None = None):
         """Performs a random action given the action_spec attribute.

         Args:
@@ -2911,7 +2907,7 @@ def rand_action(self, tensordict: Optional[TensorDictBase] = None):
         tensordict.update(r)
         return tensordict

-    def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase:
+    def rand_step(self, tensordict: TensorDictBase | None = None) -> TensorDictBase:
         """Performs a random step in the environment given the action_spec attribute.

         Args:
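A brief sketch tying rand_action and rand_step together, assuming a Gymnasium backend; rand_action only fills in the action entry, while rand_step samples an action and steps in a single call:

from torchrl.envs import GymEnv  # assumes gymnasium is installed

env = GymEnv("CartPole-v1")
td = env.reset()
td = env.rand_action(td)  # fills "action" by sampling from the action spec
td = env.step(td)         # steps with that action
td = env.rand_step()      # or: sample an action and step in one call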
@@ -2947,15 +2943,15 @@ def _has_dynamic_specs(self) -> bool:
     def rollout(
         self,
         max_steps: int,
-        policy: Optional[Callable[[TensorDictBase], TensorDictBase]] = None,
-        callback: Optional[Callable[[TensorDictBase, ...], Any]] = None,
+        policy: Callable[[TensorDictBase], TensorDictBase] | None = None,
+        callback: Callable[[TensorDictBase, ...], Any] | None = None,
         *,
         auto_reset: bool = True,
         auto_cast_to_device: bool = False,
         break_when_any_done: bool | None = None,
         break_when_all_done: bool | None = None,
         return_contiguous: bool | None = False,
-        tensordict: Optional[TensorDictBase] = None,
+        tensordict: TensorDictBase | None = None,
         set_truncated: bool = False,
         out=None,
         trust_policy: bool = False,
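For context, a hedged sketch of a rollout call using the keyword-only arguments touched by this hunk; values are illustrative, and policy=None falls back to random actions:

from torchrl.envs import GymEnv  # assumes gymnasium is installed

env = GymEnv("CartPole-v1")
# break_when_any_done stops at the first terminated/truncated flag encountered.
rollout = env.rollout(max_steps=50, policy=None, break_when_any_done=True)
print(rollout.shape)  # torch.Size([T]) with T <= 50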
@@ -3485,7 +3481,7 @@ def _rollout_nonstop(

     def step_and_maybe_reset(
         self, tensordict: TensorDictBase
-    ) -> Tuple[TensorDictBase, TensorDictBase]:
+    ) -> tuple[TensorDictBase, TensorDictBase]:
         """Runs a step in the environment and (partially) resets it if needed.

         Args:
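A short sketch of the intended loop around step_and_maybe_reset, which returns both the stepped tensordict and the (possibly reset) tensordict to feed into the next iteration:

from torchrl.envs import GymEnv  # assumes gymnasium is installed

env = GymEnv("CartPole-v1")
td = env.reset()
for _ in range(100):
    td = env.rand_action(td)
    td, td_next = env.step_and_maybe_reset(td)
    # td holds the full transition (with its "next" entry); td_next is the root
    # of the following step, already reset wherever a done flag was raised.
    td = td_next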
@@ -3606,7 +3602,7 @@ def empty_cache(self):

     @property
     @_cache_value
-    def reset_keys(self) -> List[NestedKey]:
+    def reset_keys(self) -> list[NestedKey]:
         """Returns a list of reset keys.

         Reset keys are keys that indicate partial reset, in batched, multitask or multiagent
@@ -3763,14 +3759,14 @@ class _EnvWrapper(EnvBase):
     """

     git_url: str = ""
-    available_envs: Dict[str, Any] = {}
+    available_envs: dict[str, Any] = {}
     libname: str = ""

     def __init__(
         self,
         *args,
         device: DEVICE_TYPING = None,
-        batch_size: Optional[torch.Size] = None,
+        batch_size: torch.Size | None = None,
         allow_done_after_reset: bool = False,
         spec_locked: bool = True,
         **kwargs,
@@ -3819,7 +3815,7 @@ def _sync_device(self):
         return sync_func

     @abc.abstractmethod
-    def _check_kwargs(self, kwargs: Dict):
+    def _check_kwargs(self, kwargs: dict):
         raise NotImplementedError

     def __getattr__(self, attr: str) -> Any:
@@ -3845,7 +3841,7 @@ def __getattr__(self, attr: str) -> Any:
         )

     @abc.abstractmethod
-    def _init_env(self) -> Optional[int]:
+    def _init_env(self) -> int | None:
         """Runs all the necessary steps such that the environment is ready to use.

         This step is intended to ensure that a seed is provided to the environment (if needed) and that the environment
@@ -3859,7 +3855,7 @@ def _init_env(self) -> Optional[int]:
         raise NotImplementedError

     @abc.abstractmethod
-    def _build_env(self, **kwargs) -> "gym.Env":  # noqa: F821
+    def _build_env(self, **kwargs) -> gym.Env:  # noqa: F821
         """Creates an environment from the target library and stores it with the `_env` attribute.

         When overwritten, this function should pass all the required kwargs to the env instantiation method.
@@ -3868,7 +3864,7 @@ def _build_env(self, **kwargs) -> "gym.Env": # noqa: F821
         raise NotImplementedError

     @abc.abstractmethod
-    def _make_specs(self, env: "gym.Env") -> None:  # noqa: F821
+    def _make_specs(self, env: gym.Env) -> None:  # noqa: F821
         raise NotImplementedError

     def close(self, *, raise_if_closed: bool = True) -> None:
@@ -3882,7 +3878,7 @@ def close(self, *, raise_if_closed: bool = True) -> None:

 def make_tensordict(
     env: _EnvWrapper,
-    policy: Optional[Callable[[TensorDictBase, ...], TensorDictBase]] = None,
+    policy: Callable[[TensorDictBase, ...], TensorDictBase] | None = None,
 ) -> TensorDictBase:
     """Returns a zeroed-tensordict with fields matching those required for a full step (action selection and environment step) in the environment.
