@@ -540,7 +540,7 @@ def __repr__(self):
 
 
 @dataclass(repr=False)
-class TensorSpec:
+class TensorSpec(metaclass=abc.ABCMeta):
     """Parent class of the tensor meta-data containers.
 
     TorchRL's TensorSpec are used to present what input/output is to be expected for a specific class,
@@ -675,6 +675,11 @@ def encode(
         self.assert_is_in(val)
         return val
 
+    @abc.abstractmethod
+    def __eq__(self, other: Any) -> bool:
+        # Minimal implementation, used when subclasses call super().__eq__()
+        return type(self) is type(other)
+
     def __ne__(self, other):
         return not (self == other)
 
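With `__eq__` now abstract, every concrete spec must define its own equality check; the base body only compares types, so subclasses can reuse it via `super().__eq__(other)`. A minimal sketch of what an override might look like (`MySpec` and its `n` attribute are hypothetical):

    class MySpec(TensorSpec):
        def __eq__(self, other: Any) -> bool:
            # Reuse the minimal type check from the abstract base, then
            # compare this spec's own (hypothetical) attribute.
            return super().__eq__(other) and self.n == other.n
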
@@ -734,13 +739,31 @@ def index(
     ) -> torch.Tensor | TensorDictBase:
         """Indexes the input tensor.
 
+        This method is intended for specs that encode one or more categorical variables (e.g.,
+        :class:`~torchrl.data.OneHot` or :class:`~torchrl.data.Categorical`), so that a tensor can be
+        indexed with a sample without caring about the actual representation of the index.
+
         Args:
             index (int, torch.Tensor, slice or list): index of the tensor
             tensor_to_index: tensor to be indexed
 
         Returns:
             indexed tensor
 
+        Examples:
+            >>> from torchrl.data import OneHot
+            >>> import torch
+            >>>
+            >>> one_hot = OneHot(n=100)
+            >>> categ = one_hot.to_categorical_spec()
+            >>> idx_one_hot = torch.zeros((100,), dtype=torch.bool)
+            >>> idx_one_hot[50] = 1
+            >>> print(one_hot.index(idx_one_hot, torch.arange(100)))
+            tensor(50)
+            >>> idx_categ = one_hot.to_categorical(idx_one_hot)
+            >>> print(categ.index(idx_categ, torch.arange(100)))
+            tensor(50)
+
         """
         ...
 
@@ -1302,6 +1325,31 @@ class Stacked(_LazyStackedMixin[TensorSpec], TensorSpec):
 
     """
 
+    def _reshape(
+        self,
+        *args,
+        **kwargs,
+    ) -> Any:
+        raise NotImplementedError(
+            f"`reshape` is not implemented for {type(self).__name__} specs."
+        )
+
+    def cardinality(
+        self,
+        *args,
+        **kwargs,
+    ) -> Any:
+        raise NotImplementedError(
+            f"`cardinality` is not implemented for {type(self).__name__} specs."
+        )
+
+    def index(
+        self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        raise NotImplementedError(
+            f"`index` is not implemented for {type(self).__name__} specs."
+        )
+
     def __eq__(self, other):
         if not isinstance(other, Stacked):
             return False
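With these stubs in place, the lazily stacked spec fails loudly instead of producing a silently wrong result. A hedged sketch of the expected behavior (assuming `torch.stack` over heterogeneous specs yields a `Stacked` spec, as in TorchRL's lazy-stacking support):

    >>> import torch
    >>> from torchrl.data import Unbounded
    >>> stacked = torch.stack([Unbounded(shape=(3,)), Unbounded(shape=(4,))], dim=0)
    >>> stacked.cardinality()  # the op is deliberately disabled
    Traceback (most recent call last):
        ...
    NotImplementedError: `cardinality` is not implemented for Stacked specs.
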
@@ -1823,7 +1871,7 @@ def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Ten
                 f"Only tensors are allowed for indexing using "
                 f"{self.__class__.__name__}.index(...)"
             )
-        index = index.nonzero().squeeze()
+        index = index.nonzero(as_tuple=True)[-1]
         index = index.expand((*tensor_to_index.shape[:-1], index.shape[-1]))
         return tensor_to_index.gather(-1, index)
 
@@ -2142,6 +2190,11 @@ def __init__(
             domain=domain,
         )
 
+    def index(
+        self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        raise NotImplementedError("Indexing not implemented for Bounded.")
+
     def enumerate(self) -> Any:
         raise NotImplementedError(
             f"enumerate is not implemented for spec of class {type(self).__name__}."
@@ -2478,11 +2531,19 @@ def __eq__(self, other):
         eq = eq & (self.example_data == getattr(other, "example_data", None))
         return eq
 
+    def _project(self, val: Any) -> Any:
+        raise NotImplementedError("Cannot project a NonTensorSpec.")
+
+    def index(
+        self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        raise NotImplementedError("Cannot use index with a NonTensorSpec.")
+
     def cardinality(self) -> Any:
-        raise RuntimeError("Cannot enumerate a NonTensorSpec.")
+        raise NotImplementedError("Cannot compute the cardinality of a NonTensorSpec.")
 
     def enumerate(self) -> Any:
-        raise RuntimeError("Cannot enumerate a NonTensorSpec.")
+        raise NotImplementedError("Cannot enumerate a NonTensorSpec.")
 
     def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensor:
         if isinstance(dest, torch.dtype):
@@ -2744,6 +2805,16 @@ def __init__(
             shape=shape, space=box, device=device, dtype=dtype, domain=domain, **kwargs
         )
 
+    def cardinality(self) -> int:
+        raise NotImplementedError(
+            "`cardinality` is not implemented for Unbounded specs."
+        )
+
+    def index(
+        self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        raise NotImplementedError("`index` is not implemented for Unbounded specs.")
+
     def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Unbounded:
         if isinstance(dest, torch.dtype):
             dest_dtype = dest
@@ -3515,6 +3586,14 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor:
             out = torch.multinomial(mask_flat.float(), 1).reshape(shape_out)
         return out
 
+    def index(
+        self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        idx = index.expand(
+            tensor_to_index.shape[: -self.ndim] + torch.Size([-1] * self.ndim)
+        )
+        return tensor_to_index.gather(-1, idx)
+
     def _project(self, val: torch.Tensor) -> torch.Tensor:
         if val.dtype not in (torch.int, torch.long):
             val = torch.round(val)
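The new `index` broadcasts the sample across the batch dimensions of `tensor_to_index` and gathers along the last dimension. The same gather pattern, sketched outside the spec class for a spec with `ndim == 1` (shapes are illustrative):

    >>> import torch
    >>> tensor_to_index = torch.arange(100).expand(2, 100)  # two batch rows
    >>> index = torch.tensor([50])  # a categorical sample, ndim == 1
    >>> idx = index.expand(tensor_to_index.shape[:-1] + torch.Size([-1]))
    >>> tensor_to_index.gather(-1, idx)
    tensor([[50],
            [50]])
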
@@ -3851,9 +3930,50 @@ def cardinality(self) -> int:
             .item()
         )
 
+    def enumerate(self, use_mask: bool = False) -> List[Any]:
+        return [s for choice in self._choices for s in choice.enumerate()]
+
+    def _project(
+        self, val: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        raise NotImplementedError(
+            "_project is not implemented for Choice. If this feature is required, "
+            "please raise an issue on the TorchRL GitHub repo."
+        )
+
+    def _reshape(self, shape: torch.Size) -> T:
+        return self.__class__(
+            [choice.reshape(shape) for choice in self._choices],
+        )
+
+    def index(
+        self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        raise NotImplementedError(
+            "index is not implemented for Choice. If this feature is required, "
+            "please raise an issue on the TorchRL GitHub repo."
+        )
+
+    @property
+    def num_choices(self):
+        """Number of choices for the spec."""
+        return len(self._choices)
+
     def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Choice:
         return self.__class__([choice.to(dest) for choice in self._choices])
 
+    def __eq__(self, other):
+        if not isinstance(other, Choice):
+            return False
+        if self.num_choices != other.num_choices:
+            return False
+        return all(
+            (s0 == s1).all()
+            if isinstance(s0, torch.Tensor) or is_tensor_collection(s0)
+            else s0 == s1
+            for s0, s1 in zip(self._choices, other._choices)
+        )
+
 
 @dataclass(repr=False)
 class Binary(Categorical):
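`num_choices` exposes the length of the internal choice list, and the new `__eq__` compares choices pairwise, reducing tensor- or tensordict-valued comparisons with `.all()`. A hedged usage sketch (assuming `Choice` is constructed from a list of specs; the spec parameters are hypothetical):

    >>> from torchrl.data import Categorical, Choice
    >>> spec = Choice([Categorical(2, shape=(1,)), Categorical(3, shape=(1,))])
    >>> spec.num_choices
    2
    >>> spec == Choice([Categorical(2, shape=(1,)), Categorical(3, shape=(1,))])
    True
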
@@ -4585,6 +4705,21 @@ def shape(self, value: torch.Size):
         )
         self._shape = _size(value)
 
+    def _project(
+        self, val: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        cls = TensorDict
+        return cls.from_dict(
+            {k: item._project(val[k]) for k, item in self.items()},
+            batch_size=self.shape,
+            device=self.device,
+        )
+
+    def index(
+        self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        raise NotImplementedError("`index` is not implemented for Composite specs.")
+
     def is_empty(self, recurse: bool = False):
         """Whether the composite spec contains specs or not.
 
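`Composite._project` recurses key by key and reassembles the result as a `TensorDict` with the composite's batch size, so out-of-bounds leaf values are mapped back into their sub-spec's space. A hedged doctest-style sketch via the public `project` entry point (spec names and bounds are illustrative):

    >>> import torch
    >>> from tensordict import TensorDict
    >>> from torchrl.data import Bounded, Composite
    >>> spec = Composite(a=Bounded(low=0.0, high=1.0, shape=(3,)))
    >>> val = TensorDict({"a": torch.full((3,), 2.0)}, batch_size=[])
    >>> spec.project(val)["a"]  # out-of-bounds values mapped back into [0, 1]
    tensor([1., 1., 1.])
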
@@ -5508,6 +5643,31 @@ class StackedComposite(_LazyStackedMixin[Composite], Composite):
 
     """
 
+    def _reshape(
+        self,
+        *args,
+        **kwargs,
+    ) -> Any:
+        raise NotImplementedError(
+            f"`reshape` is not implemented for {type(self).__name__} specs."
+        )
+
+    def cardinality(
+        self,
+        *args,
+        **kwargs,
+    ) -> Any:
+        raise NotImplementedError(
+            f"`cardinality` is not implemented for {type(self).__name__} specs."
+        )
+
+    def index(
+        self, index: INDEX_TYPING, tensor_to_index: torch.Tensor | TensorDictBase
+    ) -> torch.Tensor | TensorDictBase:
+        raise NotImplementedError(
+            f"`index` is not implemented for {type(self).__name__} specs."
+        )
+
     def update(self, dict) -> None:
         for key, item in dict.items():
             if key in self.keys() and isinstance(