1414import re
1515import warnings
1616from enum import Enum
17- from typing import Any , Dict , List
17+ from typing import Any
1818
1919import torch
2020
@@ -329,9 +329,9 @@ def step_mdp(
329329 exclude_reward : bool = True ,
330330 exclude_done : bool = False ,
331331 exclude_action : bool = True ,
332- reward_keys : NestedKey | List [NestedKey ] = "reward" ,
333- done_keys : NestedKey | List [NestedKey ] = "done" ,
334- action_keys : NestedKey | List [NestedKey ] = "action" ,
332+ reward_keys : NestedKey | list [NestedKey ] = "reward" ,
333+ done_keys : NestedKey | list [NestedKey ] = "done" ,
334+ action_keys : NestedKey | list [NestedKey ] = "action" ,
335335) -> TensorDictBase :
336336 """Creates a new tensordict that reflects a step in time of the input tensordict.
337337
@@ -680,8 +680,8 @@ def _per_level_env_check(data0, data1, check_dtype):
680680
681681
682682def check_env_specs (
683- env ,
684- return_contiguous = True ,
683+ env : torchrl . envs . EnvBase , # noqa
684+ return_contiguous : bool | None = None ,
685685 check_dtype = True ,
686686 seed : int | None = None ,
687687 tensordict : TensorDictBase | None = None ,
@@ -700,7 +700,7 @@ def check_env_specs(
700700 env (EnvBase): the env for which the specs have to be checked against data.
701701 return_contiguous (bool, optional): if ``True``, the random rollout will be called with
702702 return_contiguous=True. This will fail in some cases (e.g. heterogeneous shapes
703- of inputs/outputs). Defaults to True .
703+ of inputs/outputs). Defaults to ``None`` (determined by the presence of dynamic specs) .
704704 check_dtype (bool, optional): if False, dtype checks will be skipped.
705705 Defaults to True.
706706 seed (int, optional): for reproducibility, a seed can be set.
@@ -718,6 +718,8 @@ def check_env_specs(
718718 of an experiment and as such should be kept out of training scripts.
719719
720720 """
721+ if return_contiguous is None :
722+ return_contiguous = not env ._has_dynamic_specs
721723 if break_when_any_done == "both" :
722724 check_env_specs (
723725 env ,
@@ -746,7 +748,7 @@ def check_env_specs(
746748 )
747749
748750 fake_tensordict = env .fake_tensordict ()
749- if not env ._batch_locked and tensordict is not None :
751+ if not env .batch_locked and tensordict is not None :
750752 shape = torch .broadcast_shapes (fake_tensordict .shape , tensordict .shape )
751753 fake_tensordict = fake_tensordict .expand (shape )
752754 tensordict = tensordict .expand (shape )
@@ -786,10 +788,13 @@ def check_env_specs(
786788 - List of keys present in fake but not in real: { fake_tensordict_keys - real_tensordict_keys } .
787789"""
788790 )
789- zeroing_err_msg = (
790- "zeroing the two tensordicts did not make them identical. "
791- f"Check for discrepancies:\n Fake=\n { fake_tensordict } \n Real=\n { real_tensordict } "
792- )
791+
792+ def zeroing_err_msg ():
793+ return (
794+ "zeroing the two tensordicts did not make them identical. "
795+ f"Check for discrepancies:\n Fake=\n { fake_tensordict } \n Real=\n { real_tensordict } "
796+ )
797+
793798 from torchrl .envs .common import _has_dynamic_specs
794799
795800 if _has_dynamic_specs (env .specs ):
@@ -799,7 +804,7 @@ def check_env_specs(
799804 ):
800805 fake = fake .apply (lambda x , y : x .expand_as (y ), real )
801806 if (torch .zeros_like (real ) != torch .zeros_like (fake )).any ():
802- raise AssertionError (zeroing_err_msg )
807+ raise AssertionError (zeroing_err_msg () )
803808
804809 # Checks shapes and eventually dtypes of keys at all nesting levels
805810 _per_level_env_check (fake , real , check_dtype = check_dtype )
@@ -809,7 +814,7 @@ def check_env_specs(
809814 torch .zeros_like (fake_tensordict_select )
810815 != torch .zeros_like (real_tensordict_select )
811816 ).any ():
812- raise AssertionError (zeroing_err_msg )
817+ raise AssertionError (zeroing_err_msg () )
813818
814819 # Checks shapes and eventually dtypes of keys at all nesting levels
815820 _per_level_env_check (
@@ -1028,14 +1033,14 @@ class MarlGroupMapType(Enum):
10281033 ALL_IN_ONE_GROUP = 1
10291034 ONE_GROUP_PER_AGENT = 2
10301035
1031- def get_group_map (self , agent_names : List [str ]):
1036+ def get_group_map (self , agent_names : list [str ]):
10321037 if self == MarlGroupMapType .ALL_IN_ONE_GROUP :
10331038 return {"agents" : agent_names }
10341039 elif self == MarlGroupMapType .ONE_GROUP_PER_AGENT :
10351040 return {agent_name : [agent_name ] for agent_name in agent_names }
10361041
10371042
1038- def check_marl_grouping (group_map : Dict [str , List [str ]], agent_names : List [str ]):
1043+ def check_marl_grouping (group_map : dict [str , list [str ]], agent_names : list [str ]):
10391044 """Check MARL group map.
10401045
10411046 Performs checks on the group map of a marl environment to assess its validity.
@@ -1379,7 +1384,7 @@ def skim_through(td, reset=reset):
13791384def _update_during_reset (
13801385 tensordict_reset : TensorDictBase ,
13811386 tensordict : TensorDictBase ,
1382- reset_keys : List [NestedKey ],
1387+ reset_keys : list [NestedKey ],
13831388):
13841389 """Updates the input tensordict with the reset data, based on the reset keys."""
13851390 if not reset_keys :
0 commit comments