Skip to content

Commit f8bb71b

Browse files
authored
[Feature] Named dims in Composite (#3174)
1 parent 2e5e353 commit f8bb71b

File tree

2 files changed

+460
-2
lines changed

2 files changed

+460
-2
lines changed

test/test_specs.py

Lines changed: 285 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4301,6 +4301,291 @@ def test_composite(self):
43014301
assert c_enum["b"].shape == torch.Size((20, 3))
43024302

43034303

4304+
class TestCompositeNames:
4305+
"""Test the names functionality of Composite specs."""
4306+
4307+
def test_names_property_basic(self):
4308+
"""Test basic names property functionality."""
4309+
# Test with names
4310+
spec = Composite(
4311+
{"obs": Bounded(low=-1, high=1, shape=(10, 5, 3, 4))},
4312+
shape=(10, 5),
4313+
names=["batch", "time"],
4314+
)
4315+
assert spec.names == ["batch", "time"]
4316+
assert spec._has_names() is True
4317+
4318+
# Test without names
4319+
spec_no_names = Composite(
4320+
{"obs": Bounded(low=-1, high=1, shape=(10, 5, 3, 4))}, shape=(10, 5)
4321+
)
4322+
assert spec_no_names.names == [None, None]
4323+
assert spec_no_names._has_names() is False
4324+
4325+
def test_names_setter(self):
4326+
"""Test setting names."""
4327+
spec = Composite(
4328+
{"obs": Bounded(low=-1, high=1, shape=(10, 5, 3, 4))}, shape=(10, 5)
4329+
)
4330+
4331+
# Set names
4332+
spec.names = ["batch", "time"]
4333+
assert spec.names == ["batch", "time"]
4334+
assert spec._has_names() is True
4335+
4336+
# Clear names
4337+
spec.names = None
4338+
assert spec.names == [None, None]
4339+
assert spec._has_names() is False
4340+
4341+
def test_names_setter_validation(self):
4342+
"""Test names setter validation."""
4343+
spec = Composite(
4344+
{"obs": Bounded(low=-1, high=1, shape=(10, 5, 3, 4))}, shape=(10, 5)
4345+
)
4346+
4347+
# Test wrong number of names
4348+
with pytest.raises(ValueError, match="Expected 2 names, but got 3 names"):
4349+
spec.names = ["batch", "time", "extra"]
4350+
4351+
def test_refine_names_basic(self):
4352+
"""Test basic refine_names functionality."""
4353+
spec = Composite(
4354+
{"obs": Bounded(low=-1, high=1, shape=(10, 5, 3, 4))}, shape=(10, 5, 3)
4355+
)
4356+
4357+
# Initially no names
4358+
assert spec.names == [None, None, None]
4359+
assert spec._has_names() is False
4360+
4361+
# Refine names
4362+
spec_refined = spec.refine_names(None, None, "feature")
4363+
assert spec_refined.names == [None, None, "feature"]
4364+
assert spec_refined._has_names() is True
4365+
4366+
def test_refine_names_ellipsis(self):
4367+
"""Test refine_names with ellipsis."""
4368+
spec = Composite(
4369+
{"obs": Bounded(low=-1, high=1, shape=(10, 5, 3, 4))},
4370+
shape=(10, 5, 3),
4371+
names=["batch", None, None],
4372+
)
4373+
4374+
# Use ellipsis to fill remaining dimensions
4375+
spec_refined = spec.refine_names("batch", ...)
4376+
assert spec_refined.names == ["batch", None, None]
4377+
4378+
def test_refine_names_validation(self):
4379+
"""Test refine_names validation."""
4380+
spec = Composite(
4381+
{"obs": Bounded(low=-1, high=1, shape=(10, 5, 3, 4))},
4382+
shape=(10, 5),
4383+
names=["batch", "time"],
4384+
)
4385+
4386+
# Try to refine to different name
4387+
with pytest.raises(RuntimeError, match="cannot coerce Composite names"):
4388+
spec.refine_names("batch", "different")
4389+
4390+
def test_expand_preserves_names(self):
4391+
"""Test that expand preserves names."""
4392+
spec = Composite(
4393+
{"obs": Bounded(low=-1, high=1, shape=(10, 3, 4))},
4394+
shape=(10,),
4395+
names=["batch"],
4396+
)
4397+
4398+
expanded = spec.expand(5, 10)
4399+
assert expanded.names == [None, "batch"]
4400+
assert expanded.shape == torch.Size([5, 10])
4401+
4402+
def test_squeeze_preserves_names(self):
4403+
"""Test that squeeze preserves names."""
4404+
spec = Composite(
4405+
{"obs": Bounded(low=-1, high=1, shape=(10, 1, 5, 3, 4))},
4406+
shape=(10, 1, 5),
4407+
names=["batch", "dummy", "time"],
4408+
)
4409+
4410+
squeezed = spec.squeeze(1) # Remove the dimension with size 1
4411+
assert squeezed.names == ["batch", "time"]
4412+
assert squeezed.shape == torch.Size([10, 5])
4413+
4414+
def test_squeeze_all_ones_clears_names(self):
4415+
"""Test that squeezing all dimensions clears names if all become None."""
4416+
spec = Composite(
4417+
{"obs": Bounded(low=-1, high=1, shape=(1, 1, 3, 4))},
4418+
shape=(1, 1),
4419+
names=["dummy1", "dummy2"],
4420+
)
4421+
4422+
squeezed = spec.squeeze()
4423+
assert squeezed.names == [] # All dimensions removed, so no names
4424+
assert squeezed.shape == torch.Size([])
4425+
4426+
def test_unsqueeze_preserves_names(self):
4427+
"""Test that unsqueeze preserves names."""
4428+
spec = Composite(
4429+
{"obs": Bounded(low=-1, high=1, shape=(10, 5, 3, 4))},
4430+
shape=(10, 5),
4431+
names=["batch", "time"],
4432+
)
4433+
4434+
unsqueezed = spec.unsqueeze(1)
4435+
assert unsqueezed.names == ["batch", None, "time"]
4436+
assert unsqueezed.shape == torch.Size([10, 1, 5])
4437+
4438+
def test_unbind_preserves_names(self):
4439+
"""Test that unbind preserves names."""
4440+
spec = Composite(
4441+
{"obs": Bounded(low=-1, high=1, shape=(3, 5, 3, 4))},
4442+
shape=(3, 5),
4443+
names=["batch", "time"],
4444+
)
4445+
4446+
unbound = spec.unbind(0)
4447+
assert len(unbound) == 3
4448+
for spec_item in unbound:
4449+
assert spec_item.names == ["time"]
4450+
assert spec_item.shape == torch.Size([5])
4451+
4452+
def test_clone_preserves_names(self):
4453+
"""Test that clone preserves names."""
4454+
spec = Composite(
4455+
{"obs": Bounded(low=-1, high=1, shape=(10, 3, 4))},
4456+
shape=(10,),
4457+
names=["batch"],
4458+
)
4459+
4460+
cloned = spec.clone()
4461+
assert cloned.names == ["batch"]
4462+
assert cloned.shape == spec.shape
4463+
assert cloned is not spec # Different objects
4464+
4465+
def test_to_preserves_names(self):
4466+
"""Test that to() preserves names."""
4467+
spec = Composite(
4468+
{"obs": Bounded(low=-1, high=1, shape=(10, 3, 4))},
4469+
shape=(10,),
4470+
names=["batch"],
4471+
)
4472+
4473+
moved = spec.to("cpu")
4474+
assert moved.names == ["batch"]
4475+
assert moved.device == torch.device("cpu")
4476+
4477+
def test_indexing_preserves_names(self):
4478+
"""Test that indexing preserves names."""
4479+
spec = Composite(
4480+
{"obs": Bounded(low=-1, high=1, shape=(10, 5, 3, 4))},
4481+
shape=(10, 5),
4482+
names=["batch", "time"],
4483+
)
4484+
4485+
# Test single dimension indexing
4486+
indexed = spec[0]
4487+
assert indexed.names == ["time"]
4488+
assert indexed.shape == torch.Size([5])
4489+
4490+
# Test slice indexing
4491+
sliced = spec[0:5]
4492+
assert sliced.names == ["batch", "time"]
4493+
assert sliced.shape == torch.Size([5, 5])
4494+
4495+
def test_nested_composite_names_propagation(self):
4496+
"""Test that names are propagated to nested Composite specs."""
4497+
nested_spec = Composite(
4498+
{
4499+
"outer": Composite(
4500+
{"inner": Bounded(low=-1, high=1, shape=(10, 3, 2))}, shape=(10, 3)
4501+
)
4502+
},
4503+
shape=(10,),
4504+
names=["batch"],
4505+
)
4506+
4507+
assert nested_spec.names == ["batch"]
4508+
assert nested_spec["outer"].names == ["batch", None]
4509+
4510+
def test_erase_names(self):
4511+
"""Test erasing names."""
4512+
spec = Composite(
4513+
{"obs": Bounded(low=-1, high=1, shape=(10, 3, 4))},
4514+
shape=(10,),
4515+
names=["batch"],
4516+
)
4517+
4518+
assert spec._has_names() is True
4519+
spec._erase_names()
4520+
assert spec._has_names() is False
4521+
assert spec.names == [None]
4522+
4523+
def test_names_with_different_shapes(self):
4524+
"""Test names with different spec shapes."""
4525+
spec = Composite(
4526+
{
4527+
"obs": Bounded(low=-1, high=1, shape=(10, 5, 3, 4)),
4528+
"action": Bounded(low=0, high=1, shape=(10, 5, 2)),
4529+
},
4530+
shape=(10, 5),
4531+
names=["batch", "time"],
4532+
)
4533+
4534+
assert spec.names == ["batch", "time"]
4535+
assert spec["obs"].shape == torch.Size([10, 5, 3, 4])
4536+
assert spec["action"].shape == torch.Size([10, 5, 2])
4537+
4538+
def test_names_constructor_parameter(self):
4539+
"""Test names parameter in constructor."""
4540+
# Test with names
4541+
spec = Composite(
4542+
{"obs": Bounded(low=-1, high=1, shape=(10, 5, 3, 4))},
4543+
shape=(10, 5),
4544+
names=["batch", "time"],
4545+
)
4546+
assert spec.names == ["batch", "time"]
4547+
4548+
# Test without names
4549+
spec_no_names = Composite(
4550+
{"obs": Bounded(low=-1, high=1, shape=(10, 5, 3, 4))}, shape=(10, 5)
4551+
)
4552+
assert spec_no_names.names == [None, None]
4553+
4554+
def test_names_with_empty_composite(self):
4555+
"""Test names with empty Composite."""
4556+
spec = Composite({}, shape=(10,), names=["batch"])
4557+
assert spec.names == ["batch"]
4558+
assert spec._has_names() is True
4559+
4560+
def test_names_equality(self):
4561+
"""Test that names don't affect equality."""
4562+
spec1 = Composite(
4563+
{"obs": Bounded(low=-1, high=1, shape=(10, 3, 4))},
4564+
shape=(10,),
4565+
names=["batch"],
4566+
)
4567+
4568+
spec2 = Composite(
4569+
{"obs": Bounded(low=-1, high=1, shape=(10, 3, 4))}, shape=(10,)
4570+
)
4571+
4572+
# They should be equal despite different names
4573+
assert spec1 == spec2
4574+
4575+
def test_names_repr(self):
4576+
"""Test that names don't break repr."""
4577+
spec = Composite(
4578+
{"obs": Bounded(low=-1, high=1, shape=(10, 3, 4))},
4579+
shape=(10,),
4580+
names=["batch"],
4581+
)
4582+
4583+
# Should not raise an error
4584+
repr_str = repr(spec)
4585+
assert "Composite" in repr_str
4586+
assert "obs" in repr_str
4587+
4588+
43044589
if __name__ == "__main__":
43054590
args, unknown = argparse.ArgumentParser().parse_known_args()
43064591
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

0 commit comments

Comments
 (0)