1616from collections import defaultdict
1717from functools import partial
1818from sys import platform
19- from typing import Optional
19+ from typing import Any , Optional
2020
2121import numpy as np
2222import pytest
3333 TensorDictBase ,
3434)
3535from tensordict .nn import TensorDictModuleBase
36- from tensordict .tensorclass import NonTensorStack
36+ from tensordict .tensorclass import NonTensorStack , TensorClass
3737from tensordict .utils import _unravel_key_to_tuple
3838from torch import nn
3939
@@ -340,7 +340,8 @@ def forward(self, values):
340340 )
341341 env .rollout (10 , policy )
342342
343- def test_make_spec_from_td (self ):
343+ @pytest .mark .parametrize ("dynamic_shape" , [True , False ])
344+ def test_make_spec_from_td (self , dynamic_shape ):
344345 data = TensorDict (
345346 {
346347 "obs" : torch .randn (3 ),
@@ -353,10 +354,44 @@ def test_make_spec_from_td(self):
353354 },
354355 [],
355356 )
356- spec = make_composite_from_td (data )
357+ spec = make_composite_from_td (data , dynamic_shape = dynamic_shape )
357358 assert (spec .zero () == data .zero_ ()).all ()
358359 for key , val in data .items (True , True ):
359360 assert val .dtype is spec [key ].dtype
361+ if dynamic_shape :
362+ assert all (s .shape [- 1 ] == - 1 for s in spec .values (True , True ))
363+
364+ def test_make_spec_from_tc (self ):
365+ class Scratch (TensorClass ):
366+ obs : torch .Tensor
367+ string : str
368+ some_object : Any
369+
370+ class Whatever :
371+ ...
372+
373+ td = TensorDict (
374+ a = Scratch (
375+ obs = torch .ones (5 , 3 ),
376+ string = "another string!" ,
377+ some_object = Whatever (),
378+ batch_size = (5 ,),
379+ ),
380+ b = "a string!" ,
381+ batch_size = (5 ,),
382+ )
383+ spec = make_composite_from_td (td )
384+ assert isinstance (spec , Composite )
385+ assert isinstance (spec ["a" ], Composite )
386+ assert isinstance (spec ["b" ], NonTensor )
387+ assert spec ["b" ].example_data == "a string!" , spec ["b" ].example_data
388+ assert spec ["a" , "string" ].example_data == "another string!"
389+ one = spec .one ()
390+ assert isinstance (one ["a" ], Scratch )
391+ assert isinstance (one ["b" ], str )
392+ assert isinstance (one ["a" ].string , str )
393+ assert isinstance (one ["a" ].some_object , Whatever )
394+ assert (one == td ).all ()
360395
361396 def test_env_that_does_nothing (self ):
362397 env = EnvThatDoesNothing ()
0 commit comments