File tree Expand file tree Collapse file tree 2 files changed +8
-7
lines changed Expand file tree Collapse file tree 2 files changed +8
-7
lines changed Original file line number Diff line number Diff line change @@ -705,15 +705,16 @@ def add(self, data: Any) -> int:
705705 make_none = False
706706 # Transforms usually expect a time batch dimension when called within a RB, so we unsqueeze the data temporarily
707707 is_tc = is_tensor_collection (data )
708- with data .unsqueeze (- 1 ) if is_tc else contextlib .nullcontext (
709- data
710- ) as data_unsq :
708+ cm = data .unsqueeze (- 1 ) if is_tc else contextlib .nullcontext (data )
709+ new_data = None
710+ with cm as data_unsq :
711711 data_unsq_r = self ._transform .inv (data_unsq )
712712 if is_tc and data_unsq_r is not None :
713713 # this is a no-op whenever the result matches the input
714- data_unsq . update ( data_unsq_r )
714+ new_data = data_unsq_r . squeeze ( - 1 )
715715 else :
716716 make_none = data_unsq_r is None
717+ data = new_data if new_data is not None else data
717718 if make_none :
718719 data = None
719720 if data is None :
Original file line number Diff line number Diff line change @@ -4507,9 +4507,9 @@ class DeviceCastTransform(Transform):
45074507 """Moves data from one device to another.
45084508
45094509 Args:
4510- device (torch.device or equivalent): the destination device.
4511- orig_device (torch.device or equivalent): the origin device. If not specified and
4512- a parent environment exists, it it retrieved from it. In all other cases,
4510+ device (torch.device or equivalent): the destination device (outside the environment or buffer) .
4511+ orig_device (torch.device or equivalent): the origin device (inside the environment or buffer).
4512+ If not specified and a parent environment exists, it it retrieved from it. In all other cases,
45134513 it remains unspecified.
45144514
45154515 Keyword Args:
You can’t perform that action at this time.
0 commit comments