@@ -297,7 +297,15 @@ def set(
     def get(self, index: Union[int, Sequence[int], slice]) -> Any:
         if isinstance(index, (INT_CLASSES, slice)):
             return self._storage[index]
+        elif isinstance(index, tuple):
+            if len(index) > 1:
+                raise RuntimeError(
+                    f"{type(self).__name__} can only be indexed with one-length tuples."
+                )
+            return self.get(index[0])
         else:
+            if isinstance(index, torch.Tensor) and index.device.type != "cpu":
+                index = index.cpu().tolist()
             return [self._storage[i] for i in index]

     def __len__(self):
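A quick usage sketch (not part of the patch) of the two new index paths in `ListStorage.get`: one-length tuples are unwrapped, longer tuples are rejected, and tensor indices living off-CPU are moved to CPU before the list lookup. The public `torchrl.data.ListStorage` entry point and the example tensors are assumptions chosen for illustration.

```python
# Hedged sketch of the new ListStorage.get index handling; not part of the diff.
import torch
from torchrl.data import ListStorage

storage = ListStorage(max_size=10)
storage.set(0, torch.zeros(3))   # integer cursors append to / overwrite the list
storage.set(1, torch.ones(3))

# One-length tuples are unwrapped and dispatched back to get().
print(storage.get((1,)))         # same result as storage.get(1)

# Tuples with more than one element are rejected explicitly.
try:
    storage.get((0, 1))
except RuntimeError as err:
    print(err)                   # "ListStorage can only be indexed with one-length tuples."

# Tensor indices on an accelerator are moved to CPU and converted to a list
# before iterating over the underlying Python list (requires a CUDA device).
if torch.cuda.is_available():
    idx = torch.tensor([0, 1], device="cuda")
    print(storage.get(idx))      # list with the two stored tensors
```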
@@ -353,6 +361,77 @@ def contains(self, item):
         raise NotImplementedError(f"type {type(item)} is not supported yet.")


+class LazyStackStorage(ListStorage):
+    """A ListStorage that returns LazyStackedTensorDict instances.
+
+    This storage allows heterogeneous structures to be indexed as a single ``TensorDict`` representation.
+    It uses :class:`~tensordict.LazyStackedTensorDict`, which operates on non-contiguous lists of tensordicts,
+    lazily stacking items when queried.
+    This means that this storage will be fast to sample, but data access may be slow (as it requires a stack).
+    Tensors of heterogeneous shapes can also be stored within the storage and stacked together.
+    Because the storage is represented as a list, the number of tensors held in memory grows linearly with
+    the size of the buffer.
+
+    If possible, nested tensors can also be created via :meth:`~tensordict.LazyStackedTensorDict.densify`
+    (see :mod:`~torch.nested`).
+    Args:
+        max_size (int, optional): the maximum number of elements stored in the storage.
+            If not provided, an unlimited storage is created.
+
+    Keyword Args:
+        compilable (bool, optional): if ``True``, the storage will be made compatible with :func:`~torch.compile`
+            at the cost of not being executable in multiprocessed settings.
+        stack_dim (int, optional): the stack dimension in terms of TensorDict batch sizes. Defaults to ``-1``.
+
+    Examples:
+        >>> import torch
+        >>> from torchrl.data import ReplayBuffer, LazyStackStorage
+        >>> from tensordict import TensorDict
+        >>> _ = torch.manual_seed(0)
+        >>> rb = ReplayBuffer(storage=LazyStackStorage(max_size=1000, stack_dim=-1))
+        >>> data0 = TensorDict(a=torch.randn((10,)), b=torch.rand(4), c="a string!")
+        >>> data1 = TensorDict(a=torch.randn((11,)), b=torch.rand(4), c="another string!")
+        >>> _ = rb.add(data0)
+        >>> _ = rb.add(data1)
+        >>> rb.sample(10)
+        LazyStackedTensorDict(
+            fields={
+                a: Tensor(shape=torch.Size([10, -1]), device=cpu, dtype=torch.float32, is_shared=False),
+                b: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
+                c: NonTensorStack(
+                    ['another string!', 'another string!', 'another st...,
+                    batch_size=torch.Size([10]),
+                    device=None)},
+            exclusive_fields={
+            },
+            batch_size=torch.Size([10]),
+            device=None,
+            is_shared=False,
+            stack_dim=0)
+    """
+
+    def __init__(
+        self,
+        max_size: int | None = None,
+        *,
+        compilable: bool = False,
+        stack_dim: int = -1,
+    ):
+        super().__init__(max_size=max_size, compilable=compilable)
+        self.stack_dim = stack_dim
+
+    def get(self, index: Union[int, Sequence[int], slice]) -> Any:
+        out = super().get(index=index)
+        if isinstance(out, list):
+            stack_dim = self.stack_dim
+            if stack_dim < 0:
+                stack_dim = out[0].ndim + 1 + stack_dim
+            out = LazyStackedTensorDict(*out, stack_dim=stack_dim)
+            return out
+        return out
+
+
 class TensorStorage(Storage):
     """A storage for tensors and tensordicts.

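The docstring's note on :meth:`~tensordict.LazyStackedTensorDict.densify` can be exercised directly on the storage. The sketch below is not part of the patch: it assumes that `densify()` packs ragged leaves into torch nested tensors, which may vary across tensordict versions, and it illustrates how a negative `stack_dim` is normalised (`out[0].ndim + 1 + stack_dim`, so `-1` maps to `0` for items with an empty batch size).

```python
# Hedged sketch: densifying a lazily stacked query into nested tensors.
# Not part of the patch; densify() behaviour is assumed from the tensordict
# documentation and may differ across versions.
import torch
from tensordict import TensorDict
from torchrl.data import LazyStackStorage

storage = LazyStackStorage(max_size=100)       # stack_dim defaults to -1
storage.set(0, TensorDict(a=torch.randn(10)))  # items have an empty batch_size,
storage.set(1, TensorDict(a=torch.randn(11)))  # so stack_dim -1 normalises to 0

stacked = storage.get([0, 1])   # LazyStackedTensorDict; "a" is reported as shape [2, -1]
dense = stacked.densify()       # ragged leaves are packed as nested tensors
print(dense["a"].is_nested)     # expected: True
```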