Skip to content

Commit f95fe45

Browse files
committed
Implement RFC 22: Add ValueCastable.shape().
Fixes #794. Closes #876.
1 parent 7714ce3 commit f95fe45

File tree

5 files changed

+67
-33
lines changed

5 files changed

+67
-33
lines changed

amaranth/hdl/ast.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1110,7 +1110,11 @@ def like(other, *, name=None, name_suffix=None, src_loc_at=0, **kwargs):
11101110
new_name = other.name + str(name_suffix)
11111111
else:
11121112
new_name = tracer.get_var_name(depth=2 + src_loc_at, default="$like")
1113-
kw = dict(shape=Value.cast(other).shape(), name=new_name)
1113+
if isinstance(other, ValueCastable):
1114+
shape = other.shape()
1115+
else:
1116+
shape = Value.cast(other).shape()
1117+
kw = dict(shape=shape, name=new_name)
11141118
if isinstance(other, Signal):
11151119
kw.update(reset=other.reset, reset_less=other.reset_less,
11161120
attrs=other.attrs, decoder=other.decoder)
@@ -1363,6 +1367,9 @@ def __init_subclass__(cls, **kwargs):
13631367
if not hasattr(cls, "as_value"):
13641368
raise TypeError(f"Class '{cls.__name__}' deriving from `ValueCastable` must override "
13651369
"the `as_value` method")
1370+
if not hasattr(cls, "shape"):
1371+
raise TypeError(f"Class '{cls.__name__}' deriving from `ValueCastable` must override "
1372+
"the `shape` method")
13661373
if not hasattr(cls.as_value, "_ValueCastable__memoized"):
13671374
raise TypeError(f"Class '{cls.__name__}' deriving from `ValueCastable` must decorate "
13681375
"the `as_value` method with the `ValueCastable.lowermethod` decorator")

amaranth/lib/data.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -115,20 +115,6 @@ def cast(obj):
115115
raise TypeError("Object {!r} cannot be converted to a data layout"
116116
.format(obj))
117117

118-
@staticmethod
119-
def of(obj):
120-
"""Extract the layout that was used to create a view.
121-
122-
Raises
123-
------
124-
TypeError
125-
If ``obj`` is not a :class:`View` instance.
126-
"""
127-
if not isinstance(obj, View):
128-
raise TypeError("Object {!r} is not a data view"
129-
.format(obj))
130-
return obj._View__orig_layout
131-
132118
@abstractmethod
133119
def __iter__(self):
134120
"""Iterate fields in the layout.
@@ -611,6 +597,16 @@ def __init__(self, layout, target):
611597
self.__layout = cast_layout
612598
self.__target = cast_target
613599

600+
def shape(self):
601+
"""Get layout of this view.
602+
603+
Returns
604+
-------
605+
:class:`Layout`
606+
The ``layout`` provided when constructing the view.
607+
"""
608+
return self.__orig_layout
609+
614610
@ValueCastable.lowermethod
615611
def as_value(self):
616612
"""Get underlying value.

docs/changes.rst

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,18 +42,20 @@ Implemented RFCs
4242
.. _RFC 9: https://amaranth-lang.org/rfcs/0009-const-init-shape-castable.html
4343
.. _RFC 10: https://amaranth-lang.org/rfcs/0010-move-repl-to-value.html
4444
.. _RFC 15: https://amaranth-lang.org/rfcs/0015-lifting-shape-castables.html
45+
.. _RFC 22: https://amaranth-lang.org/rfcs/0022-valuecastable-shape.html
4546

4647
* `RFC 1`_: Aggregate data structure library
4748
* `RFC 3`_: Enumeration shapes
4849
* `RFC 4`_: Constant-castable expressions
49-
* `RFC 5`_: Remove Const.normalize
50+
* `RFC 5`_: Remove ``Const.normalize``
5051
* `RFC 6`_: CRC generator
5152
* `RFC 8`_: Aggregate extensibility
5253
* `RFC 9`_: Constant initialization for shape-castable objects
5354
* `RFC 8`_: Aggregate extensibility
5455
* `RFC 9`_: Constant initialization for shape-castable objects
55-
* `RFC 10`_: Move Repl to Value.replicate
56+
* `RFC 10`_: Move ``Repl`` to ``Value.replicate``
5657
* `RFC 15`_: Lifting shape-castable objects
58+
* `RFC 22`_: Define ``ValueCastable.shape()``
5759

5860

5961
Language changes

tests/test_hdl_ast.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1183,6 +1183,9 @@ class MockValueCastable(ValueCastable):
11831183
def __init__(self, dest):
11841184
self.dest = dest
11851185

1186+
def shape(self):
1187+
return Value.cast(self.dest).shape()
1188+
11861189
@ValueCastable.lowermethod
11871190
def as_value(self):
11881191
return self.dest
@@ -1192,6 +1195,9 @@ class MockValueCastableChanges(ValueCastable):
11921195
def __init__(self, width=0):
11931196
self.width = width
11941197

1198+
def shape(self):
1199+
return unsigned(self.width)
1200+
11951201
@ValueCastable.lowermethod
11961202
def as_value(self):
11971203
return Signal(self.width)
@@ -1201,6 +1207,9 @@ class MockValueCastableCustomGetattr(ValueCastable):
12011207
def __init__(self):
12021208
pass
12031209

1210+
def shape(self):
1211+
assert False
1212+
12041213
@ValueCastable.lowermethod
12051214
def as_value(self):
12061215
return Const(0)
@@ -1218,17 +1227,30 @@ class MockValueCastableNotDecorated(ValueCastable):
12181227
def __init__(self):
12191228
pass
12201229

1230+
def shape(self):
1231+
pass
1232+
12211233
def as_value(self):
12221234
return Signal()
12231235

12241236
def test_no_override(self):
12251237
with self.assertRaisesRegex(TypeError,
1226-
r"^Class 'MockValueCastableNoOverride' deriving from `ValueCastable` must "
1238+
r"^Class 'MockValueCastableNoOverrideAsValue' deriving from `ValueCastable` must "
12271239
r"override the `as_value` method$"):
1228-
class MockValueCastableNoOverride(ValueCastable):
1240+
class MockValueCastableNoOverrideAsValue(ValueCastable):
12291241
def __init__(self):
12301242
pass
12311243

1244+
with self.assertRaisesRegex(TypeError,
1245+
r"^Class 'MockValueCastableNoOverrideShapec' deriving from `ValueCastable` must "
1246+
r"override the `shape` method$"):
1247+
class MockValueCastableNoOverrideShapec(ValueCastable):
1248+
def __init__(self):
1249+
pass
1250+
1251+
def as_value(self):
1252+
return Signal()
1253+
12321254
def test_memoized(self):
12331255
vc = MockValueCastableChanges(1)
12341256
sig1 = vc.as_value()

tests/test_lib_data.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -365,11 +365,6 @@ def test_cast_wrong_recur(self):
365365
r"^Shape-castable object <.+> casts to itself$"):
366366
Layout.cast(sc)
367367

368-
def test_of_wrong(self):
369-
with self.assertRaisesRegex(TypeError,
370-
r"^Object <.+> is not a data view$"):
371-
Layout.of(object())
372-
373368
def test_eq_wrong_recur(self):
374369
sc = MockShapeCastable(None)
375370
sc.shape = sc
@@ -379,7 +374,7 @@ def test_call(self):
379374
sl = StructLayout({"f": unsigned(1)})
380375
s = Signal(1)
381376
v = sl(s)
382-
self.assertIs(Layout.of(v), sl)
377+
self.assertIs(v.shape(), sl)
383378
self.assertIs(v.as_value(), s)
384379

385380
def test_const(self):
@@ -621,6 +616,11 @@ def test_attr_wrong_reserved(self):
621616
r"and may only be accessed by indexing$"):
622617
Signal(StructLayout({"_c": signed(1)}))._c
623618

619+
def test_signal_like(self):
620+
s1 = Signal(StructLayout({"a": unsigned(1)}))
621+
s2 = Signal.like(s1)
622+
self.assertEqual(s2.shape(), StructLayout({"a": unsigned(1)}))
623+
624624
def test_bug_837_array_layout_getitem_str(self):
625625
with self.assertRaisesRegex(TypeError,
626626
r"^Views with array layout may only be indexed with an integer or a value, "
@@ -646,7 +646,7 @@ class S(Struct):
646646
}))
647647

648648
v = Signal(S)
649-
self.assertEqual(Layout.of(v), S)
649+
self.assertEqual(v.shape(), S)
650650
self.assertEqual(Value.cast(v).shape(), Shape.cast(S))
651651
self.assertEqual(Value.cast(v).name, "v")
652652
self.assertRepr(v.a, "(slice (sig v) 0:1)")
@@ -666,11 +666,11 @@ class S(Struct):
666666
self.assertEqual(Shape.cast(S), unsigned(9))
667667

668668
v = Signal(S)
669-
self.assertIs(Layout.of(v), S)
669+
self.assertIs(v.shape(), S)
670670
self.assertIsInstance(v, S)
671-
self.assertIs(Layout.of(v.b), R)
671+
self.assertIs(v.b.shape(), R)
672672
self.assertIsInstance(v.b, R)
673-
self.assertIs(Layout.of(v.b.q), Q)
673+
self.assertIs(v.b.q.shape(), Q)
674674
self.assertIsInstance(v.b.q, View)
675675
self.assertRepr(v.b.p, "(slice (slice (sig v) 1:9) 0:4)")
676676
self.assertRepr(v.b.q.as_value(), "(slice (slice (sig v) 1:9) 4:8)")
@@ -747,10 +747,17 @@ class S(Struct):
747747
b: int
748748
c: str = "x"
749749

750-
self.assertEqual(Layout.of(Signal(S)), StructLayout({"a": unsigned(1)}))
750+
self.assertEqual(Layout.cast(S), StructLayout({"a": unsigned(1)}))
751751
self.assertEqual(S.__annotations__, {"b": int, "c": str})
752752
self.assertEqual(S.c, "x")
753753

754+
def test_signal_like(self):
755+
class S(Struct):
756+
a: 1
757+
s1 = Signal(S)
758+
s2 = Signal.like(s1)
759+
self.assertEqual(s2.shape(), S)
760+
754761

755762
class UnionTestCase(FHDLTestCase):
756763
def test_construct(self):
@@ -765,7 +772,7 @@ class U(Union):
765772
}))
766773

767774
v = Signal(U)
768-
self.assertEqual(Layout.of(v), U)
775+
self.assertEqual(v.shape(), U)
769776
self.assertEqual(Value.cast(v).shape(), Shape.cast(U))
770777
self.assertRepr(v.a, "(slice (sig v) 0:1)")
771778
self.assertRepr(v.b, "(s (slice (sig v) 0:3))")
@@ -887,7 +894,7 @@ class Kind(Enum):
887894

888895
view1 = Signal(layout1)
889896
self.assertIsInstance(view1, View)
890-
self.assertEqual(Layout.of(view1), layout1)
897+
self.assertEqual(view1.shape(), layout1)
891898
self.assertEqual(view1.as_value().shape(), unsigned(3))
892899

893900
m1 = Module()
@@ -933,4 +940,4 @@ def check_m2():
933940

934941
self.assertEqual(layout1, Layout.cast(SomeVariant))
935942

936-
self.assertIs(SomeVariant, Layout.of(view2))
943+
self.assertIs(SomeVariant, view2.shape())

0 commit comments

Comments
 (0)