@@ -4301,6 +4301,291 @@ def test_composite(self):
43014301 assert c_enum ["b" ].shape == torch .Size ((20 , 3 ))
43024302
43034303
4304+ class TestCompositeNames :
4305+ """Test the names functionality of Composite specs."""
4306+
4307+ def test_names_property_basic (self ):
4308+ """Test basic names property functionality."""
4309+ # Test with names
4310+ spec = Composite (
4311+ {"obs" : Bounded (low = - 1 , high = 1 , shape = (10 , 5 , 3 , 4 ))},
4312+ shape = (10 , 5 ),
4313+ names = ["batch" , "time" ],
4314+ )
4315+ assert spec .names == ["batch" , "time" ]
4316+ assert spec ._has_names () is True
4317+
4318+ # Test without names
4319+ spec_no_names = Composite (
4320+ {"obs" : Bounded (low = - 1 , high = 1 , shape = (10 , 5 , 3 , 4 ))}, shape = (10 , 5 )
4321+ )
4322+ assert spec_no_names .names == [None , None ]
4323+ assert spec_no_names ._has_names () is False
4324+
4325+ def test_names_setter (self ):
4326+ """Test setting names."""
4327+ spec = Composite (
4328+ {"obs" : Bounded (low = - 1 , high = 1 , shape = (10 , 5 , 3 , 4 ))}, shape = (10 , 5 )
4329+ )
4330+
4331+ # Set names
4332+ spec .names = ["batch" , "time" ]
4333+ assert spec .names == ["batch" , "time" ]
4334+ assert spec ._has_names () is True
4335+
4336+ # Clear names
4337+ spec .names = None
4338+ assert spec .names == [None , None ]
4339+ assert spec ._has_names () is False
4340+
4341+ def test_names_setter_validation (self ):
4342+ """Test names setter validation."""
4343+ spec = Composite (
4344+ {"obs" : Bounded (low = - 1 , high = 1 , shape = (10 , 5 , 3 , 4 ))}, shape = (10 , 5 )
4345+ )
4346+
4347+ # Test wrong number of names
4348+ with pytest .raises (ValueError , match = "Expected 2 names, but got 3 names" ):
4349+ spec .names = ["batch" , "time" , "extra" ]
4350+
4351+ def test_refine_names_basic (self ):
4352+ """Test basic refine_names functionality."""
4353+ spec = Composite (
4354+ {"obs" : Bounded (low = - 1 , high = 1 , shape = (10 , 5 , 3 , 4 ))}, shape = (10 , 5 , 3 )
4355+ )
4356+
4357+ # Initially no names
4358+ assert spec .names == [None , None , None ]
4359+ assert spec ._has_names () is False
4360+
4361+ # Refine names
4362+ spec_refined = spec .refine_names (None , None , "feature" )
4363+ assert spec_refined .names == [None , None , "feature" ]
4364+ assert spec_refined ._has_names () is True
4365+
4366+ def test_refine_names_ellipsis (self ):
4367+ """Test refine_names with ellipsis."""
4368+ spec = Composite (
4369+ {"obs" : Bounded (low = - 1 , high = 1 , shape = (10 , 5 , 3 , 4 ))},
4370+ shape = (10 , 5 , 3 ),
4371+ names = ["batch" , None , None ],
4372+ )
4373+
4374+ # Use ellipsis to fill remaining dimensions
4375+ spec_refined = spec .refine_names ("batch" , ...)
4376+ assert spec_refined .names == ["batch" , None , None ]
4377+
4378+ def test_refine_names_validation (self ):
4379+ """Test refine_names validation."""
4380+ spec = Composite (
4381+ {"obs" : Bounded (low = - 1 , high = 1 , shape = (10 , 5 , 3 , 4 ))},
4382+ shape = (10 , 5 ),
4383+ names = ["batch" , "time" ],
4384+ )
4385+
4386+ # Try to refine to different name
4387+ with pytest .raises (RuntimeError , match = "cannot coerce Composite names" ):
4388+ spec .refine_names ("batch" , "different" )
4389+
4390+ def test_expand_preserves_names (self ):
4391+ """Test that expand preserves names."""
4392+ spec = Composite (
4393+ {"obs" : Bounded (low = - 1 , high = 1 , shape = (10 , 3 , 4 ))},
4394+ shape = (10 ,),
4395+ names = ["batch" ],
4396+ )
4397+
4398+ expanded = spec .expand (5 , 10 )
4399+ assert expanded .names == [None , "batch" ]
4400+ assert expanded .shape == torch .Size ([5 , 10 ])
4401+
4402+ def test_squeeze_preserves_names (self ):
4403+ """Test that squeeze preserves names."""
4404+ spec = Composite (
4405+ {"obs" : Bounded (low = - 1 , high = 1 , shape = (10 , 1 , 5 , 3 , 4 ))},
4406+ shape = (10 , 1 , 5 ),
4407+ names = ["batch" , "dummy" , "time" ],
4408+ )
4409+
4410+ squeezed = spec .squeeze (1 ) # Remove the dimension with size 1
4411+ assert squeezed .names == ["batch" , "time" ]
4412+ assert squeezed .shape == torch .Size ([10 , 5 ])
4413+
4414+ def test_squeeze_all_ones_clears_names (self ):
4415+ """Test that squeezing all dimensions clears names if all become None."""
4416+ spec = Composite (
4417+ {"obs" : Bounded (low = - 1 , high = 1 , shape = (1 , 1 , 3 , 4 ))},
4418+ shape = (1 , 1 ),
4419+ names = ["dummy1" , "dummy2" ],
4420+ )
4421+
4422+ squeezed = spec .squeeze ()
4423+ assert squeezed .names == [] # All dimensions removed, so no names
4424+ assert squeezed .shape == torch .Size ([])
4425+
4426+ def test_unsqueeze_preserves_names (self ):
4427+ """Test that unsqueeze preserves names."""
4428+ spec = Composite (
4429+ {"obs" : Bounded (low = - 1 , high = 1 , shape = (10 , 5 , 3 , 4 ))},
4430+ shape = (10 , 5 ),
4431+ names = ["batch" , "time" ],
4432+ )
4433+
4434+ unsqueezed = spec .unsqueeze (1 )
4435+ assert unsqueezed .names == ["batch" , None , "time" ]
4436+ assert unsqueezed .shape == torch .Size ([10 , 1 , 5 ])
4437+
4438+ def test_unbind_preserves_names (self ):
4439+ """Test that unbind preserves names."""
4440+ spec = Composite (
4441+ {"obs" : Bounded (low = - 1 , high = 1 , shape = (3 , 5 , 3 , 4 ))},
4442+ shape = (3 , 5 ),
4443+ names = ["batch" , "time" ],
4444+ )
4445+
4446+ unbound = spec .unbind (0 )
4447+ assert len (unbound ) == 3
4448+ for spec_item in unbound :
4449+ assert spec_item .names == ["time" ]
4450+ assert spec_item .shape == torch .Size ([5 ])
4451+
4452+ def test_clone_preserves_names (self ):
4453+ """Test that clone preserves names."""
4454+ spec = Composite (
4455+ {"obs" : Bounded (low = - 1 , high = 1 , shape = (10 , 3 , 4 ))},
4456+ shape = (10 ,),
4457+ names = ["batch" ],
4458+ )
4459+
4460+ cloned = spec .clone ()
4461+ assert cloned .names == ["batch" ]
4462+ assert cloned .shape == spec .shape
4463+ assert cloned is not spec # Different objects
4464+
4465+ def test_to_preserves_names (self ):
4466+ """Test that to() preserves names."""
4467+ spec = Composite (
4468+ {"obs" : Bounded (low = - 1 , high = 1 , shape = (10 , 3 , 4 ))},
4469+ shape = (10 ,),
4470+ names = ["batch" ],
4471+ )
4472+
4473+ moved = spec .to ("cpu" )
4474+ assert moved .names == ["batch" ]
4475+ assert moved .device == torch .device ("cpu" )
4476+
4477+ def test_indexing_preserves_names (self ):
4478+ """Test that indexing preserves names."""
4479+ spec = Composite (
4480+ {"obs" : Bounded (low = - 1 , high = 1 , shape = (10 , 5 , 3 , 4 ))},
4481+ shape = (10 , 5 ),
4482+ names = ["batch" , "time" ],
4483+ )
4484+
4485+ # Test single dimension indexing
4486+ indexed = spec [0 ]
4487+ assert indexed .names == ["time" ]
4488+ assert indexed .shape == torch .Size ([5 ])
4489+
4490+ # Test slice indexing
4491+ sliced = spec [0 :5 ]
4492+ assert sliced .names == ["batch" , "time" ]
4493+ assert sliced .shape == torch .Size ([5 , 5 ])
4494+
4495+ def test_nested_composite_names_propagation (self ):
4496+ """Test that names are propagated to nested Composite specs."""
4497+ nested_spec = Composite (
4498+ {
4499+ "outer" : Composite (
4500+ {"inner" : Bounded (low = - 1 , high = 1 , shape = (10 , 3 , 2 ))}, shape = (10 , 3 )
4501+ )
4502+ },
4503+ shape = (10 ,),
4504+ names = ["batch" ],
4505+ )
4506+
4507+ assert nested_spec .names == ["batch" ]
4508+ assert nested_spec ["outer" ].names == ["batch" , None ]
4509+
4510+ def test_erase_names (self ):
4511+ """Test erasing names."""
4512+ spec = Composite (
4513+ {"obs" : Bounded (low = - 1 , high = 1 , shape = (10 , 3 , 4 ))},
4514+ shape = (10 ,),
4515+ names = ["batch" ],
4516+ )
4517+
4518+ assert spec ._has_names () is True
4519+ spec ._erase_names ()
4520+ assert spec ._has_names () is False
4521+ assert spec .names == [None ]
4522+
4523+ def test_names_with_different_shapes (self ):
4524+ """Test names with different spec shapes."""
4525+ spec = Composite (
4526+ {
4527+ "obs" : Bounded (low = - 1 , high = 1 , shape = (10 , 5 , 3 , 4 )),
4528+ "action" : Bounded (low = 0 , high = 1 , shape = (10 , 5 , 2 )),
4529+ },
4530+ shape = (10 , 5 ),
4531+ names = ["batch" , "time" ],
4532+ )
4533+
4534+ assert spec .names == ["batch" , "time" ]
4535+ assert spec ["obs" ].shape == torch .Size ([10 , 5 , 3 , 4 ])
4536+ assert spec ["action" ].shape == torch .Size ([10 , 5 , 2 ])
4537+
4538+ def test_names_constructor_parameter (self ):
4539+ """Test names parameter in constructor."""
4540+ # Test with names
4541+ spec = Composite (
4542+ {"obs" : Bounded (low = - 1 , high = 1 , shape = (10 , 5 , 3 , 4 ))},
4543+ shape = (10 , 5 ),
4544+ names = ["batch" , "time" ],
4545+ )
4546+ assert spec .names == ["batch" , "time" ]
4547+
4548+ # Test without names
4549+ spec_no_names = Composite (
4550+ {"obs" : Bounded (low = - 1 , high = 1 , shape = (10 , 5 , 3 , 4 ))}, shape = (10 , 5 )
4551+ )
4552+ assert spec_no_names .names == [None , None ]
4553+
4554+ def test_names_with_empty_composite (self ):
4555+ """Test names with empty Composite."""
4556+ spec = Composite ({}, shape = (10 ,), names = ["batch" ])
4557+ assert spec .names == ["batch" ]
4558+ assert spec ._has_names () is True
4559+
4560+ def test_names_equality (self ):
4561+ """Test that names don't affect equality."""
4562+ spec1 = Composite (
4563+ {"obs" : Bounded (low = - 1 , high = 1 , shape = (10 , 3 , 4 ))},
4564+ shape = (10 ,),
4565+ names = ["batch" ],
4566+ )
4567+
4568+ spec2 = Composite (
4569+ {"obs" : Bounded (low = - 1 , high = 1 , shape = (10 , 3 , 4 ))}, shape = (10 ,)
4570+ )
4571+
4572+ # They should be equal despite different names
4573+ assert spec1 == spec2
4574+
4575+ def test_names_repr (self ):
4576+ """Test that names don't break repr."""
4577+ spec = Composite (
4578+ {"obs" : Bounded (low = - 1 , high = 1 , shape = (10 , 3 , 4 ))},
4579+ shape = (10 ,),
4580+ names = ["batch" ],
4581+ )
4582+
4583+ # Should not raise an error
4584+ repr_str = repr (spec )
4585+ assert "Composite" in repr_str
4586+ assert "obs" in repr_str
4587+
4588+
43044589if __name__ == "__main__" :
43054590 args , unknown = argparse .ArgumentParser ().parse_known_args ()
43064591 pytest .main ([__file__ , "--capture" , "no" , "--exitfirst" ] + unknown )
0 commit comments