@@ -391,20 +391,61 @@ class ContinuousBox(Box):
# Canonical storage for the bounds. When `_batch_size` is set, `_low`/`_high`
# hold a single batch element and the public `low`/`high` properties
# re-expand them on access (memory-footprint optimization).
_low: torch.Tensor
_high: torch.Tensor
# Device the public accessors should return tensors on; `None` means
# "wherever the stored tensors already live".
device: torch.device | None = None
# Batch dims registered for compact storage, or `None` when the bounds
# are stored at full shape.
_batch_size: torch.Size | None = None
395+
@property
def batch_size(self):
    """Batch dims registered for compact bound storage, or ``None`` if unset."""
    return self._batch_size
399+
@batch_size.setter
def batch_size(self, value: torch.Size | tuple | None):
    """Register (or clear, with ``None``) the batch dims of the bounds.

    When the leading ``value`` dims of ``low``/``high`` hold one repeated
    slice, only that slice is stored and the ``low``/``high`` getters
    re-expand it on demand, shrinking the spec's memory footprint.

    Raises:
        ValueError: if ``value`` does not match the leading dims of the
            stored bounds (only checkable while no batch size is registered).
    """
    if value is None:
        # Fix: handle ``None`` before any len()/shape arithmetic (the old
        # order crashed), and materialize the expanded bounds *before*
        # clearing ``_batch_size`` so no batch dims are silently dropped.
        low = self.low.clone()
        high = self.high.clone()
        self._batch_size = None
        self._low = low
        self._high = high
        return
    value = _remove_neg_shapes(value)
    if self._batch_size is None:
        if value != self._low.shape[: len(value)]:
            raise ValueError(
                f"Batch size {value} is not compatible with low and high {self._low.shape}"
            )
    # Remove batch size from low and high
    if value:
        # Compact only if every batch element shares the same bounds.
        td_low_high = TensorDict(
            low=self.low, high=self.high, batch_size=value
        ).flatten()
        td_low_high0 = td_low_high[0]
        if torch.allclose(
            td_low_high0["low"], td_low_high["low"]
        ) and torch.allclose(td_low_high0["high"], td_low_high["high"]):
            self._low = td_low_high0["low"].clone()
            self._high = td_low_high0["high"].clone()
            self._batch_size = torch.Size(value)
        else:
            # Bounds differ across the batch: keep the full-shape tensors
            # and do NOT register the batch size — registering it while
            # storing full tensors would make the ``low``/``high`` getters
            # prepend the batch dims a second time.
            self._low = self.low.clone()
            self._high = self.high.clone()
            self._batch_size = None
394431
# We store the tensors on CPU to avoid overloading CUDA with tensors that are rarely used.
@property
def low(self):
    """Lower bound, moved to ``self.device`` and re-expanded to any registered batch size."""
    tensor = self._low
    target = self.device
    if target is not None and tensor.device != target:
        tensor = tensor.to(target)
    if self._batch_size:
        tensor = tensor.expand((*self._batch_size, *tensor.shape)).clone()
    return tensor
402441
@property
def high(self):
    """Upper bound, moved to ``self.device`` and re-expanded to any registered batch size."""
    tensor = self._high
    target = self.device
    if target is not None and tensor.device != target:
        tensor = tensor.to(target)
    if self._batch_size:
        tensor = tensor.expand((*self._batch_size, *tensor.shape)).clone()
    return tensor
409450
410451 def unbind (self , dim : int = 0 ):
@@ -417,15 +458,30 @@ def unbind(self, dim: int = 0):
def low(self, value):
    """Set the lower bound.

    Fix: validate ``value`` against the registered batch size *before*
    mutating ``self.device``/``self._low`` — the old order left the box
    in an inconsistent state when the shape check raised. When a batch
    size is registered, only the first batch slice is stored (the
    getter re-expands it); callers are trusted to pass uniform bounds.

    Raises:
        ValueError: if ``value``'s leading dims do not match the
            registered batch size.
    """
    if self._batch_size is not None:
        if value.shape[: len(self._batch_size)] != self._batch_size:
            raise ValueError(
                f"Batch size {value.shape[:len(self._batch_size)]} is not compatible with low and high {self._batch_size}"
            )
    self.device = value.device
    if self._batch_size:
        value = value.flatten(0, len(self._batch_size) - 1)[0].clone()
    self._low = value
@high.setter
def high(self, value):
    """Set the upper bound.

    Fix: validate ``value`` against the registered batch size *before*
    mutating ``self.device``/``self._high`` — the old order left the box
    in an inconsistent state when the shape check raised. When a batch
    size is registered, only the first batch slice is stored (the
    getter re-expands it); callers are trusted to pass uniform bounds.

    Raises:
        ValueError: if ``value``'s leading dims do not match the
            registered batch size.
    """
    if self._batch_size is not None:
        if value.shape[: len(self._batch_size)] != self._batch_size:
            raise ValueError(
                f"Batch size {value.shape[:len(self._batch_size)]} is not compatible with low and high {self._batch_size}"
            )
    self.device = value.device
    if self._batch_size:
        value = value.flatten(0, len(self._batch_size) - 1)[0].clone()
    self._high = value
425480
def __post_init__(self):
    """Detach the stored bounds from caller-owned tensors and reset batch compaction."""
    # Round-trip each bound through its property pair so the clone is
    # stored canonically, then clear any registered batch size.
    for attr in ("low", "high"):
        setattr(self, attr, getattr(self, attr).clone())
    self._batch_size = None
429485
430486 def __iter__ (self ):
431487 yield self .low
@@ -2366,6 +2422,10 @@ def __init__(
23662422 )
23672423 self .encode = self ._encode_eager
23682424
def _register_batch_size(self, batch_size: torch.Size | tuple) -> None:
    """Register ``batch_size`` on the underlying space.

    Delegates to ``self.space.batch_size``, which lets the space store a
    single batch element of its bounds and re-expand lazily, decreasing
    the memory footprint of the spec.
    """
    # Register batch size in the space to decrease the memory footprint of the specs
    self.space.batch_size = batch_size
23692429 def index (
23702430 self , index : INDEX_TYPING , tensor_to_index : torch .Tensor | TensorDictBase
23712431 ) -> torch .Tensor | TensorDictBase :
@@ -5191,6 +5251,8 @@ def set(self, name: str, spec: TensorSpec) -> Composite:
51915251 f"{ self .ndim } dimensions should match but got spec.shape={ spec .shape } and "
51925252 f"Composite.shape={ self .shape } ."
51935253 )
5254+ if isinstance (spec , Bounded ):
5255+ spec ._register_batch_size (self .shape )
51945256 self ._specs [name ] = spec
51955257 return self
51965258
0 commit comments