@@ -540,7 +540,7 @@ def __repr__(self):
540540
541541
542542@dataclass (repr = False )
543- class TensorSpec :
543+ class TensorSpec ( metaclass = abc . ABCMeta ) :
544544 """Parent class of the tensor meta-data containers.
545545
546546 TorchRL's TensorSpec are used to present what input/output is to be expected for a specific class,
@@ -675,6 +675,11 @@ def encode(
675675 self .assert_is_in (val )
676676 return val
677677
678+ @abc .abstractmethod
679+ def __eq__ (self , other : Any ) -> bool :
680+ # Implement minimal version if super() is called
681+ return type (self ) is type (other )
682+
678683 def __ne__ (self , other ):
679684 return not (self == other )
680685
@@ -734,13 +739,31 @@ def index(
734739 ) -> torch .Tensor | TensorDictBase :
735740 """Indexes the input tensor.
736741
742+ This method is to be used with specs that encode one or more categorical variables (e.g.,
743+ :class:`~torchrl.data.OneHot` or :class:`~torchrl.data.Categorical`), such that indexing of a tensor
744+ with a sample can be done without caring about the actual representation of the index.
745+
737746 Args:
738747 index (int, torch.Tensor, slice or list): index of the tensor
739748 tensor_to_index: tensor to be indexed
740749
741750 Returns:
742751 indexed tensor
743752
753+ Exanples:
754+ >>> from torchrl.data import OneHot
755+ >>> import torch
756+ >>>
757+ >>> one_hot = OneHot(n=100)
758+ >>> categ = one_hot.to_categorical_spec()
759+ >>> idx_one_hot = torch.zeros((100,), dtype=torch.bool)
760+ >>> idx_one_hot[50] = 1
761+ >>> print(one_hot.index(idx_one_hot, torch.arange(100)))
762+ tensor(50)
763+ >>> idx_categ = one_hot.to_categorical(idx_one_hot)
764+ >>> print(categ.index(idx_categ, torch.arange(100)))
765+ tensor(50)
766+
744767 """
745768 ...
746769
@@ -1306,6 +1329,31 @@ class Stacked(_LazyStackedMixin[TensorSpec], TensorSpec):
13061329
13071330 """
13081331
1332+ def _reshape (
1333+ self ,
1334+ * args ,
1335+ ** kwargs ,
1336+ ) -> Any :
1337+ raise NotImplementedError (
1338+ f"`reshape` is not implemented for { type (self ).__name__ } specs."
1339+ )
1340+
1341+ def cardinality (
1342+ self ,
1343+ * args ,
1344+ ** kwargs ,
1345+ ) -> Any :
1346+ raise NotImplementedError (
1347+ f"`cardinality` is not implemented for { type (self ).__name__ } specs."
1348+ )
1349+
1350+ def index (
1351+ self , index : INDEX_TYPING , tensor_to_index : torch .Tensor | TensorDictBase
1352+ ) -> torch .Tensor | TensorDictBase :
1353+ raise NotImplementedError (
1354+ f"`index` is not implemented for { type (self ).__name__ } specs."
1355+ )
1356+
13091357 def __eq__ (self , other ):
13101358 if not isinstance (other , Stacked ):
13111359 return False
@@ -1829,7 +1877,7 @@ def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Ten
18291877 f"Only tensors are allowed for indexing using "
18301878 f"{ self .__class__ .__name__ } .index(...)"
18311879 )
1832- index = index .nonzero (). squeeze ()
1880+ index = index .nonzero (as_tuple = True )[ - 1 ]
18331881 index = index .expand ((* tensor_to_index .shape [:- 1 ], index .shape [- 1 ]))
18341882 return tensor_to_index .gather (- 1 , index )
18351883
@@ -2148,6 +2196,11 @@ def __init__(
21482196 domain = domain ,
21492197 )
21502198
2199+ def index (
2200+ self , index : INDEX_TYPING , tensor_to_index : torch .Tensor | TensorDictBase
2201+ ) -> torch .Tensor | TensorDictBase :
2202+ raise NotImplementedError ("Indexing not implemented for Bounded." )
2203+
21512204 def enumerate (self , use_mask : bool = False ) -> Any :
21522205 raise NotImplementedError (
21532206 f"enumerate is not implemented for spec of class { type (self ).__name__ } ."
@@ -2484,11 +2537,19 @@ def __eq__(self, other):
24842537 eq = eq & (self .example_data == getattr (other , "example_data" , None ))
24852538 return eq
24862539
2540+ def _project (self ) -> Any :
2541+ raise NotImplementedError ("Cannot project a NonTensorSpec." )
2542+
2543+ def index (
2544+ self , index : INDEX_TYPING , tensor_to_index : torch .Tensor | TensorDictBase
2545+ ) -> torch .Tensor | TensorDictBase :
2546+ raise NotImplementedError ("Cannot use index with a NonTensorSpec." )
2547+
24872548 def cardinality (self ) -> Any :
2488- raise RuntimeError ("Cannot enumerate a NonTensorSpec." )
2549+ raise NotImplementedError ("Cannot enumerate a NonTensorSpec." )
24892550
24902551 def enumerate (self , use_mask : bool = False ) -> Any :
2491- raise RuntimeError ("Cannot enumerate a NonTensorSpec." )
2552+ raise NotImplementedError ("Cannot enumerate a NonTensorSpec." )
24922553
24932554 def to (self , dest : Union [torch .dtype , DEVICE_TYPING ]) -> NonTensor :
24942555 if isinstance (dest , torch .dtype ):
@@ -2752,6 +2813,16 @@ def __init__(
27522813 shape = shape , space = box , device = device , dtype = dtype , domain = domain , ** kwargs
27532814 )
27542815
2816+ def cardinality (self ) -> int :
2817+ raise NotImplementedError (
2818+ "`cardinality` is not implemented for Unbounded specs."
2819+ )
2820+
2821+ def index (
2822+ self , index : INDEX_TYPING , tensor_to_index : torch .Tensor | TensorDictBase
2823+ ) -> torch .Tensor | TensorDictBase :
2824+ raise NotImplementedError ("`index` is not implemented for Unbounded specs." )
2825+
27552826 def to (self , dest : Union [torch .dtype , DEVICE_TYPING ]) -> Unbounded :
27562827 if isinstance (dest , torch .dtype ):
27572828 dest_dtype = dest
@@ -3527,6 +3598,14 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor:
35273598 out = torch .multinomial (mask_flat .float (), 1 ).reshape (shape_out )
35283599 return out
35293600
3601+ def index (
3602+ self , index : INDEX_TYPING , tensor_to_index : torch .Tensor | TensorDictBase
3603+ ) -> torch .Tensor | TensorDictBase :
3604+ idx = index .expand (
3605+ tensor_to_index .shape [: - self .ndim ] + torch .Size ([- 1 ] * self .ndim )
3606+ )
3607+ return tensor_to_index .gather (- 1 , idx )
3608+
35303609 def _project (self , val : torch .Tensor ) -> torch .Tensor :
35313610 if val .dtype not in (torch .int , torch .long ):
35323611 val = torch .round (val )
@@ -3863,9 +3942,50 @@ def cardinality(self) -> int:
38633942 .item ()
38643943 )
38653944
3945+ def enumerate (self , use_mask : bool = False ) -> List [Any ]:
3946+ return [s for choice in self ._choices for s in choice .enumerate ()]
3947+
3948+ def _project (
3949+ self , val : torch .Tensor | TensorDictBase
3950+ ) -> torch .Tensor | TensorDictBase :
3951+ raise NotImplementedError (
3952+ "_project is not implemented for Choice. If this feature is required, please raise "
3953+ "an issue on TorchRL github repo."
3954+ )
3955+
3956+ def _reshape (self , shape : torch .Size ) -> T :
3957+ return self .__class__ (
3958+ [choice .reshape (shape ) for choice in self ._choices ],
3959+ )
3960+
3961+ def index (
3962+ self , index : INDEX_TYPING , tensor_to_index : torch .Tensor | TensorDictBase
3963+ ) -> torch .Tensor | TensorDictBase :
3964+ raise NotImplementedError (
3965+ "index is not implemented for Choice. If this feature is required, please raise "
3966+ "an issue on TorchRL github repo."
3967+ )
3968+
3969+ @property
3970+ def num_choices (self ):
3971+ """Number of choices for the spec."""
3972+ return len (self ._choices )
3973+
38663974 def to (self , dest : Union [torch .dtype , DEVICE_TYPING ]) -> Choice :
38673975 return self .__class__ ([choice .to (dest ) for choice in self ._choices ])
38683976
3977+ def __eq__ (self , other ):
3978+ if not isinstance (other , Choice ):
3979+ return False
3980+ if self .num_choices != other .num_choices :
3981+ return False
3982+ return all (
3983+ (s0 == s1 ).all ()
3984+ if isinstance (s0 , torch .Tensor ) or is_tensor_collection (s0 )
3985+ else s0 == s1
3986+ for s0 , s1 in zip (self ._choices , other ._choices )
3987+ )
3988+
38693989
38703990@dataclass (repr = False )
38713991class Binary (Categorical ):
@@ -4643,6 +4763,24 @@ def shape(self, value: torch.Size):
46434763 )
46444764 self ._shape = _size (value )
46454765
4766+ def _project (
4767+ self , val : torch .Tensor | TensorDictBase
4768+ ) -> torch .Tensor | TensorDictBase :
4769+ if self .data_cls is None :
4770+ cls = TensorDict
4771+ else :
4772+ cls = self .data_cls
4773+ return cls .from_dict (
4774+ {k : item ._project (val [k ]) for k , item in self .items ()},
4775+ batch_size = self .shape ,
4776+ device = self .device ,
4777+ )
4778+
4779+ def index (
4780+ self , index : INDEX_TYPING , tensor_to_index : torch .Tensor | TensorDictBase
4781+ ) -> torch .Tensor | TensorDictBase :
4782+ raise NotImplementedError ("`index` is not implemented for Composite specs." )
4783+
46464784 def is_empty (self , recurse : bool = False ):
46474785 """Whether the composite spec contains specs or not.
46484786
@@ -5569,6 +5707,31 @@ class StackedComposite(_LazyStackedMixin[Composite], Composite):
55695707
55705708 """
55715709
5710+ def _reshape (
5711+ self ,
5712+ * args ,
5713+ ** kwargs ,
5714+ ) -> Any :
5715+ raise NotImplementedError (
5716+ f"`reshape` is not implemented for { type (self ).__name__ } specs."
5717+ )
5718+
5719+ def cardinality (
5720+ self ,
5721+ * args ,
5722+ ** kwargs ,
5723+ ) -> Any :
5724+ raise NotImplementedError (
5725+ f"`cardinality` is not implemented for { type (self ).__name__ } specs."
5726+ )
5727+
5728+ def index (
5729+ self , index : INDEX_TYPING , tensor_to_index : torch .Tensor | TensorDictBase
5730+ ) -> torch .Tensor | TensorDictBase :
5731+ raise NotImplementedError (
5732+ f"`index` is not implemented for { type (self ).__name__ } specs."
5733+ )
5734+
55725735 def update (self , dict ) -> None :
55735736 for key , item in dict .items ():
55745737 if key in self .keys () and isinstance (
0 commit comments