Skip to content

Commit 54d5c4c

Browse files
committed
Implement RFC 9: Constant initialization for shape-castable objects.
See amaranth-lang/rfcs#9 and #771.
1 parent ea5a150 commit 54d5c4c

File tree

7 files changed

+181
-42
lines changed

7 files changed

+181
-42
lines changed

amaranth/hdl/ast.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ def __init_subclass__(cls, **kwargs):
4545
if not hasattr(cls, "as_shape"):
4646
raise TypeError(f"Class '{cls.__name__}' deriving from `ShapeCastable` must override "
4747
f"the `as_shape` method")
48+
if not hasattr(cls, "const"):
49+
raise TypeError(f"Class '{cls.__name__}' deriving from `ShapeCastable` must override "
50+
f"the `const` method")
4851

4952

5053
class Shape:
@@ -988,7 +991,7 @@ class Signal(Value, DUID):
988991
decoder : function
989992
"""
990993

991-
def __init__(self, shape=None, *, name=None, reset=0, reset_less=False,
994+
def __init__(self, shape=None, *, name=None, reset=None, reset_less=False,
992995
attrs=None, decoder=None, src_loc_at=0):
993996
super().__init__(src_loc_at=src_loc_at)
994997

@@ -1005,12 +1008,24 @@ def __init__(self, shape=None, *, name=None, reset=0, reset_less=False,
10051008
self.signed = shape.signed
10061009

10071010
orig_reset = reset
1008-
try:
1009-
reset = Const.cast(reset)
1010-
except TypeError:
1011-
raise TypeError("Reset value must be a constant-castable expression, not {!r}"
1012-
.format(orig_reset))
1013-
if orig_reset not in (0, -1): # Avoid false positives for all-zeroes and all-ones
1011+
if isinstance(orig_shape, ShapeCastable):
1012+
try:
1013+
reset = Const.cast(orig_shape.const(reset))
1014+
except Exception:
1015+
raise TypeError("Reset value must be a constant initializer of {!r}"
1016+
.format(orig_shape))
1017+
if reset.shape() != Shape.cast(orig_shape):
1018+
raise ValueError("Constant returned by {!r}.const() must have the shape that "
1019+
"it casts to, {!r}, and not {!r}"
1020+
.format(orig_shape, Shape.cast(orig_shape),
1021+
reset.shape()))
1022+
else:
1023+
try:
1024+
reset = Const.cast(reset or 0)
1025+
except TypeError:
1026+
raise TypeError("Reset value must be a constant-castable expression, not {!r}"
1027+
.format(orig_reset))
1028+
if orig_reset not in (None, 0, -1): # Avoid false positives for all-zeroes and all-ones
10141029
if reset.shape().signed and not self.signed:
10151030
warnings.warn(
10161031
message="Reset value {!r} is signed, but the signal shape is {!r}"

amaranth/lib/data.py

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -201,29 +201,44 @@ def __call__(self, target):
201201
"""
202202
return View(self, target)
203203

204-
def _convert_to_int(self, value):
205-
"""Convert ``value``, which may be a dict or an array of field values, to an integer using
206-
the representation defined by this layout.
204+
def const(self, init):
205+
"""Convert a constant initializer to a constant.
207206
208-
This method is private because Amaranth does not currently have a concept of
209-
a constant initializer; this requires an RFC. It will be renamed or removed
210-
in a future version.
207+
Converts ``init``, which may be a sequence or a mapping of field values, to a constant.
208+
209+
Returns
210+
-------
211+
:class:`Const`
212+
A constant that has the same value as a view with this layout that was initialized with
213+
an all-zero value and had every field assigned to the corresponding value in the order
214+
in which they appear in ``init``.
211215
"""
212-
if isinstance(value, Mapping):
213-
iterator = value.items()
214-
elif isinstance(value, Sequence):
215-
iterator = enumerate(value)
216+
if init is None:
217+
iterator = iter(())
218+
elif isinstance(init, Mapping):
219+
iterator = init.items()
220+
elif isinstance(init, Sequence):
221+
iterator = enumerate(init)
216222
else:
217-
raise TypeError("Layout initializer must be a mapping or a sequence, not {!r}"
218-
.format(value))
223+
raise TypeError("Layout constant initializer must be a mapping or a sequence, not {!r}"
224+
.format(init))
219225

220226
int_value = 0
221227
for key, key_value in iterator:
222228
field = self[key]
223-
if isinstance(field.shape, Layout):
224-
key_value = field.shape._convert_to_int(key_value)
225-
int_value |= Const(key_value, Shape.cast(field.shape)).value << field.offset
226-
return int_value
229+
cast_field_shape = Shape.cast(field.shape)
230+
if isinstance(field.shape, ShapeCastable):
231+
key_value = Const.cast(field.shape.const(key_value))
232+
if key_value.shape() != cast_field_shape:
233+
raise ValueError("Constant returned by {!r}.const() must have the shape that "
234+
"it casts to, {!r}, and not {!r}"
235+
.format(field.shape, cast_field_shape,
236+
key_value.shape()))
237+
else:
238+
key_value = Const(key_value, cast_field_shape)
239+
int_value &= ~(((1 << cast_field_shape.width) - 1) << field.offset)
240+
int_value |= key_value.value << field.offset
241+
return Const(int_value, self.as_shape())
227242

228243

229244
class StructLayout(Layout):
@@ -617,13 +632,9 @@ def __init__(self, layout, target=None, *, name=None, reset=None, reset_less=Non
617632
"the {} bit(s) wide view layout"
618633
.format(len(cast_target), cast_layout.size))
619634
else:
620-
if reset is None:
621-
reset = 0
622-
else:
623-
reset = cast_layout._convert_to_int(reset)
624635
if reset_less is None:
625636
reset_less = False
626-
cast_target = Signal(cast_layout, name=name, reset=reset, reset_less=reset_less,
637+
cast_target = Signal(layout, name=name, reset=reset, reset_less=reset_less,
627638
attrs=attrs, decoder=decoder, src_loc_at=src_loc_at + 1)
628639
self.__orig_layout = layout
629640
self.__layout = cast_layout
@@ -774,6 +785,9 @@ def as_shape(cls):
774785
.format(cls.__module__, cls.__qualname__))
775786
return cls.__layout
776787

788+
def const(cls, init):
789+
return cls.as_shape().const(init)
790+
777791

778792
class _Aggregate(View, metaclass=_AggregateMeta):
779793
def __init__(self, target=None, *, name=None, reset=None, reset_less=None,

amaranth/lib/enum.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,17 @@ def __call__(cls, value):
137137
return value
138138
return super().__call__(value)
139139

140+
def const(cls, init):
141+
# Same considerations apply as above.
142+
if init is None:
143+
# Signal with unspecified reset value passes ``None`` to :meth:`const`.
144+
# Before RFC 9 was implemented, the unspecified reset value was 0, so this keeps
145+
# the old behavior intact.
146+
member = cls(0)
147+
else:
148+
member = cls(init)
149+
return Const(member.value, cls.as_shape())
150+
140151

141152
class Enum(py_enum.Enum, metaclass=EnumMeta):
142153
"""Subclass of the standard :class:`enum.Enum` that has :class:`EnumMeta` as

tests/test_hdl_ast.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from enum import Enum
33

44
from amaranth.hdl.ast import *
5+
from amaranth.lib.enum import Enum as AmaranthEnum
56

67
from .utils import *
78
from amaranth._utils import _ignore_deprecated
@@ -144,6 +145,9 @@ def __init__(self, dest):
144145
def as_shape(self):
145146
return self.dest
146147

148+
def const(self, obj):
149+
return Const(obj, self.dest)
150+
147151

148152
class ShapeCastableTestCase(FHDLTestCase):
149153
def test_no_override(self):
@@ -995,6 +999,29 @@ def test_reset_enum(self):
995999
r"not <StringEnum\.FOO: 'a'>$"):
9961000
Signal(1, reset=StringEnum.FOO)
9971001

1002+
def test_reset_shape_castable_const(self):
1003+
class CastableFromHex(ShapeCastable):
1004+
def as_shape(self):
1005+
return unsigned(8)
1006+
1007+
def const(self, init):
1008+
return int(init, 16)
1009+
1010+
s1 = Signal(CastableFromHex(), reset="aa")
1011+
self.assertEqual(s1.reset, 0xaa)
1012+
1013+
with self.assertRaisesRegex(ValueError,
1014+
r"^Constant returned by <.+?CastableFromHex.+?>\.const\(\) must have the shape "
1015+
r"that it casts to, unsigned\(8\), and not unsigned\(1\)$"):
1016+
Signal(CastableFromHex(), reset="01")
1017+
1018+
def test_reset_shape_castable_enum_wrong(self):
1019+
class EnumA(AmaranthEnum):
1020+
X = 1
1021+
with self.assertRaisesRegex(TypeError,
1022+
r"^Reset value must be a constant initializer of <enum 'EnumA'>$"):
1023+
Signal(EnumA) # implied reset=0
1024+
9981025
def test_reset_signed_mismatch(self):
9991026
with self.assertWarnsRegex(SyntaxWarning,
10001027
r"^Reset value -2 is signed, but the signal shape is unsigned\(2\)$"):

tests/test_hdl_dsl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,7 @@ class Color(Enum):
436436
RED = 1
437437
BLUE = 2
438438
m = Module()
439-
se = Signal(Color)
439+
se = Signal(Color, reset=Color.RED)
440440
with m.Switch(se):
441441
with m.Case(Color.RED):
442442
m.d.comb += self.c1.eq(1)

tests/test_lib_data.py

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ def __init__(self, shape):
1616
def as_shape(self):
1717
return self.shape
1818

19+
def const(self, init):
20+
return Const(init, self.shape)
21+
1922

2023
class FieldTestCase(TestCase):
2124
def test_construct(self):
@@ -332,7 +335,7 @@ def test_key_wrong_type(self):
332335
il[object()]
333336

334337

335-
class LayoutTestCase(TestCase):
338+
class LayoutTestCase(FHDLTestCase):
336339
def test_cast(self):
337340
sl = StructLayout({})
338341
self.assertIs(Layout.cast(sl), sl)
@@ -371,6 +374,53 @@ def test_call(self):
371374
self.assertIs(Layout.of(v), sl)
372375
self.assertIs(v.as_value(), s)
373376

377+
def test_const(self):
378+
sl = StructLayout({
379+
"a": unsigned(1),
380+
"b": unsigned(2)
381+
})
382+
self.assertRepr(sl.const(None), "(const 3'd0)")
383+
self.assertRepr(sl.const({"a": 0b1, "b": 0b10}), "(const 3'd5)")
384+
385+
ul = UnionLayout({
386+
"a": unsigned(1),
387+
"b": unsigned(2)
388+
})
389+
self.assertRepr(ul.const({"a": 0b11}), "(const 2'd1)")
390+
self.assertRepr(ul.const({"b": 0b10}), "(const 2'd2)")
391+
self.assertRepr(ul.const({"a": 0b1, "b": 0b10}), "(const 2'd2)")
392+
393+
def test_const_wrong(self):
394+
sl = StructLayout({"f": unsigned(1)})
395+
with self.assertRaisesRegex(TypeError,
396+
r"^Layout constant initializer must be a mapping or a sequence, not "
397+
r"<.+?object.+?>$"):
398+
sl.const(object())
399+
400+
def test_const_field_shape_castable(self):
401+
class CastableFromHex(ShapeCastable):
402+
def as_shape(self):
403+
return unsigned(8)
404+
405+
def const(self, init):
406+
return int(init, 16)
407+
408+
sl = StructLayout({"f": CastableFromHex()})
409+
self.assertRepr(sl.const({"f": "aa"}), "(const 8'd170)")
410+
411+
with self.assertRaisesRegex(ValueError,
412+
r"^Constant returned by <.+?CastableFromHex.+?>\.const\(\) must have the shape "
413+
r"that it casts to, unsigned\(8\), and not unsigned\(1\)$"):
414+
sl.const({"f": "01"})
415+
416+
def test_signal_reset(self):
417+
sl = StructLayout({
418+
"a": unsigned(1),
419+
"b": unsigned(2)
420+
})
421+
self.assertEqual(Signal(sl).reset, 0)
422+
self.assertEqual(Signal(sl, reset={"a": 0b1, "b": 0b10}).reset, 5)
423+
374424

375425
class ViewTestCase(FHDLTestCase):
376426
def test_construct(self):
@@ -434,7 +484,7 @@ def test_target_wrong_size(self):
434484

435485
def test_signal_reset_wrong(self):
436486
with self.assertRaisesRegex(TypeError,
437-
r"^Layout initializer must be a mapping or a sequence, not 1$"):
487+
r"^Reset value must be a constant initializer of StructLayout\({}\)$"):
438488
View(StructLayout({}), reset=0b1)
439489

440490
def test_target_signal_wrong(self):
@@ -483,6 +533,9 @@ def as_shape(self):
483533
def __call__(self, value):
484534
return value[::-1]
485535

536+
def const(self, init):
537+
return Const(init, 2)
538+
486539
v = View(StructLayout({
487540
"f": Reverser()
488541
}))
@@ -497,13 +550,15 @@ def as_shape(self):
497550
def __call__(self, value):
498551
pass
499552

553+
def const(self, init):
554+
return Const(init, 2)
555+
500556
v = View(StructLayout({
501557
"f": WrongCastable()
502558
}))
503559
with self.assertRaisesRegex(TypeError,
504-
r"^<tests\.test_lib_data\.ViewTestCase\.test_getitem_custom_call_wrong\.<locals>"
505-
r"\.WrongCastable object at 0x.+?>\.__call__\(\) must return a value or "
506-
r"a value-castable object, not None$"):
560+
r"^<.+?\.WrongCastable.+?>\.__call__\(\) must return a value or a value-castable "
561+
r"object, not None$"):
507562
v.f
508563

509564
def test_index_wrong_missing(self):

tests/test_lib_enum.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,12 @@
55

66

77
class EnumTestCase(FHDLTestCase):
8-
def test_non_int_members(self):
8+
def test_members_non_int(self):
99
# Mustn't raise to be a drop-in replacement for Enum.
1010
class EnumA(Enum):
1111
A = "str"
1212

13-
def test_non_const_non_int_members_wrong(self):
14-
with self.assertRaisesRegex(TypeError,
15-
r"^Value 'str' of enumeration member 'A' must be a constant-castable expression$"):
16-
class EnumA(Enum, shape=unsigned(1)):
17-
A = "str"
18-
19-
def test_const_non_int_members(self):
13+
def test_members_const_non_int(self):
2014
class EnumA(Enum):
2115
A = C(0)
2216
B = C(1)
@@ -59,6 +53,12 @@ class EnumD(Enum):
5953
B = -5
6054
self.assertEqual(Shape.cast(EnumD), signed(4))
6155

56+
def test_shape_members_non_const_non_int_wrong(self):
57+
with self.assertRaisesRegex(TypeError,
58+
r"^Value 'str' of enumeration member 'A' must be a constant-castable expression$"):
59+
class EnumA(Enum, shape=unsigned(1)):
60+
A = "str"
61+
6262
def test_shape_explicit_wrong_signed_mismatch(self):
6363
with self.assertWarnsRegex(SyntaxWarning,
6464
r"^Value -1 of enumeration member 'A' is signed, but the enumeration "
@@ -88,6 +88,23 @@ class EnumA(Enum, shape=unsigned(10)):
8888
A = 1
8989
self.assertRepr(Value.cast(EnumA.A), "(const 10'd1)")
9090

91+
def test_const_no_shape(self):
92+
class EnumA(Enum):
93+
Z = 0
94+
A = 10
95+
B = 20
96+
self.assertRepr(EnumA.const(None), "(const 5'd0)")
97+
self.assertRepr(EnumA.const(10), "(const 5'd10)")
98+
self.assertRepr(EnumA.const(EnumA.A), "(const 5'd10)")
99+
100+
def test_const_shape(self):
101+
class EnumA(Enum, shape=8):
102+
Z = 0
103+
A = 10
104+
self.assertRepr(EnumA.const(None), "(const 8'd0)")
105+
self.assertRepr(EnumA.const(10), "(const 8'd10)")
106+
self.assertRepr(EnumA.const(EnumA.A), "(const 8'd10)")
107+
91108
def test_shape_implicit_wrong_in_concat(self):
92109
class EnumA(Enum):
93110
A = 0

0 commit comments

Comments
 (0)