File tree Expand file tree Collapse file tree 2 files changed +10
-3
lines changed Expand file tree Collapse file tree 2 files changed +10
-3
lines changed Original file line number Diff line number Diff line change @@ -3940,10 +3940,17 @@ def test_device_cast(self):
39403940 assert comp ["nontensor" ].device == torch .device ("cpu" )
39413941
39423942 def test_encode (self ):
3943+ nt = NonTensor (shape = (1 ,))
3944+ r = nt .encode ("a string" )
3945+ assert isinstance (r , NonTensorData )
3946+ assert r .shape == nt .shape
3947+
39433948 comp = Composite (device = "cpu" )
3944- comp ["nontensor" ] = NonTensor ( shape = ())
3949+ comp ["nontensor" ] = nt
39453950 r = comp .encode ({"nontensor" : "a string" })
3946- assert isinstance (r ["nontensor" ], str )
3951+ assert isinstance (r , TensorDict )
3952+ assert isinstance (r .get ("nontensor" ), NonTensorData )
3953+ assert r .get ("nontensor" ).shape == (1 ,)
39473954
39483955
39493956@pytest .mark .skipif (not torch .cuda .is_available (), reason = "not cuda device" )
Original file line number Diff line number Diff line change @@ -2918,7 +2918,7 @@ def _encode_eager(
29182918 * ,
29192919 ignore_device : bool = False ,
29202920 ) -> torch .Tensor | TensorDictBase :
2921- return val
2921+ return NonTensorData ( val , device = self . device , batch_size = self . shape )
29222922
29232923
29242924class _UnboundedMeta (abc .ABCMeta ):
You can’t perform that action at this time.
0 commit comments