@@ -869,12 +869,16 @@ def contains(self, item: torch.Tensor | TensorDictBase) -> bool:
869869 return self .is_in (item )
870870
@abc.abstractmethod
def enumerate(self, use_mask: bool = False) -> Any:
    """Returns all the samples that can be obtained from the TensorSpec.

    The samples will be stacked along the first dimension.

    This method is only implemented for discrete specs.

    Args:
        use_mask (bool, optional): If ``True`` and the spec has a mask,
            samples that are masked are excluded. Default is ``False``.
    """
    ...
880884
@@ -1315,9 +1319,9 @@ def __eq__(self, other):
13151319 return False
13161320 return True
13171321
def enumerate(self, use_mask: bool = False) -> torch.Tensor | TensorDictBase:
    """Enumerate every child spec and stack the results.

    Each spec in ``self._specs`` is enumerated (forwarding ``use_mask``),
    and the per-spec enumerations are stacked one dimension past the
    stack dimension, keeping the enumeration axis first within each child.
    """
    per_spec = [child.enumerate(use_mask) for child in self._specs]
    return torch.stack(per_spec, dim=self.stack_dim + 1)
13221326
13231327 def __len__ (self ):
@@ -1810,7 +1814,9 @@ def to_numpy(self, val: torch.Tensor, safe: bool = None) -> np.ndarray:
18101814 return np .array (vals ).reshape (tuple (val .shape ))
18111815 return val
18121816
1813- def enumerate (self ) -> torch .Tensor :
1817+ def enumerate (self , use_mask : bool = False ) -> torch .Tensor :
1818+ if use_mask :
1819+ raise NotImplementedError
18141820 return (
18151821 torch .eye (self .n , dtype = self .dtype , device = self .device )
18161822 .expand (* self .shape , self .n )
@@ -2142,7 +2148,7 @@ def __init__(
21422148 domain = domain ,
21432149 )
21442150
def enumerate(self, use_mask: bool = False) -> Any:
    """Enumeration is unsupported for this spec type.

    Raises:
        NotImplementedError: always, naming the concrete spec class.
    """
    msg = f"enumerate is not implemented for spec of class {type(self).__name__}."
    raise NotImplementedError(msg)
@@ -2481,7 +2487,7 @@ def __eq__(self, other):
def cardinality(self) -> Any:
    """Cardinality is undefined for non-tensor specs.

    Raises:
        RuntimeError: always.  NOTE(review): the message says "enumerate"
            even though this is ``cardinality`` — presumably because
            cardinality is defined via enumeration; confirm upstream.
    """
    raise RuntimeError("Cannot enumerate a NonTensorSpec.")
24832489
def enumerate(self, use_mask: bool = False) -> Any:
    """Enumeration is not defined for non-tensor specs.

    Raises:
        RuntimeError: always.
    """
    failure = "Cannot enumerate a NonTensorSpec."
    raise RuntimeError(failure)
24862492
24872493 def to (self , dest : Union [torch .dtype , DEVICE_TYPING ]) -> NonTensor :
@@ -2779,7 +2785,7 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
27792785 val .shape [: - self .ndim ] + self .shape
27802786 )
27812787
def enumerate(self, use_mask: bool = False) -> Any:
    """Enumeration of a continuous spec is impossible (uncountable space).

    Raises:
        NotImplementedError: always.
    """
    failure = "enumerate cannot be called with continuous specs."
    raise NotImplementedError(failure)
27842790
27852791 def expand (self , * shape ):
@@ -2951,9 +2957,9 @@ def __init__(
def cardinality(self) -> int:
    """Number of distinct samples: the product of the entries of ``nvec``.

    NOTE(review): the annotation says ``int`` but ``Tensor.prod()`` returns
    a 0-dim tensor — confirm whether callers rely on the tensor type before
    changing it.
    """
    nvec_tensor = torch.as_tensor(self.nvec)
    return nvec_tensor.prod()
29532959
2954- def enumerate (self ) -> torch .Tensor :
2960+ def enumerate (self , use_mask : bool = False ) -> torch .Tensor :
29552961 nvec = self .nvec
2956- enum_disc = self .to_categorical_spec ().enumerate ()
2962+ enum_disc = self .to_categorical_spec ().enumerate (use_mask )
29572963 enums = torch .cat (
29582964 [
29592965 torch .nn .functional .one_hot (enum_unb , nv ).to (self .dtype )
@@ -3417,14 +3423,18 @@ def __init__(
34173423 def _undefined_n (self ):
34183424 return self .space .n < 0
34193425
def enumerate(self, use_mask: bool = False) -> torch.Tensor:
    """Return every categorical value, stacked along a new first dimension.

    Args:
        use_mask (bool, optional): if ``True`` and ``self.mask`` is set,
            masked-out categories are dropped from the enumeration.
            Defaults to ``False``.

    Returns:
        A tensor of shape ``(count, *self.shape)`` where ``count`` is the
        number of (kept) categories.
    """
    # torch.arange does not accept bool; promote to the smallest int dtype.
    out_dtype = torch.uint8 if self.dtype is torch.bool else self.dtype
    values = torch.arange(self.n, dtype=out_dtype, device=self.device)
    if use_mask and self.mask is not None:
        # NOTE(review): assumes the mask indexes the 1-D category axis;
        # a mask shaped like (*shape, n) would not broadcast here — confirm.
        values = values[self.mask]
    count = values.shape[0]
    if self.ndim:
        # Insert singleton dims so the expand below broadcasts over shape.
        values = values.view(-1, *(1,) * self.ndim)
    return values.expand(count, *self.shape)
34283438
34293439 @property
34303440 def n (self ):
@@ -4088,7 +4098,9 @@ def __init__(
40884098 self .update_mask (mask )
40894099 self .remove_singleton = remove_singleton
40904100
4091- def enumerate (self ) -> torch .Tensor :
4101+ def enumerate (self , use_mask : bool = False ) -> torch .Tensor :
4102+ if use_mask :
4103+ raise NotImplementedError ()
40924104 if self .mask is not None :
40934105 raise RuntimeError (
40944106 "Cannot enumerate a masked TensorSpec. Submit an issue on github if this feature is requested."
@@ -5136,13 +5148,15 @@ def cardinality(self) -> int:
51365148 n = 0
51375149 return n
51385150
5139- def enumerate (self ) -> TensorDictBase :
5151+ def enumerate (self , use_mask : bool = False ) -> TensorDictBase :
51405152 # We are going to use meshgrid to create samples of all the subspecs in here
51415153 # but first let's get rid of the batch size, we'll put it back later
51425154 self_without_batch = self
51435155 while self_without_batch .ndim :
51445156 self_without_batch = self_without_batch [0 ]
5145- samples = {key : spec .enumerate () for key , spec in self_without_batch .items ()}
5157+ samples = {
5158+ key : spec .enumerate (use_mask ) for key , spec in self_without_batch .items ()
5159+ }
51465160 if self .data_cls is not None :
51475161 cls = self .data_cls
51485162 else :
@@ -5566,10 +5580,10 @@ def update(self, dict) -> None:
55665580 self [key ] = item
55675581 return self
55685582
def enumerate(self, use_mask: bool = False) -> TensorDictBase:
    """Enumerate each stacked sub-spec and lazily stack the results.

    Each spec in ``self._specs`` is enumerated (forwarding ``use_mask``)
    and the per-spec tensordicts are stacked one dimension past the
    stack dimension via ``maybe_dense_stack``.
    """
    per_spec = [child.enumerate(use_mask) for child in self._specs]
    return LazyStackedTensorDict.maybe_dense_stack(per_spec, self.stack_dim + 1)
55745588
55755589 def __eq__ (self , other ):
0 commit comments