Skip to content

Commit 58721ee

Browse files
committed
hdl: implement constant-castable expressions.
See #755 and amaranth-lang/rfcs#4.
1 parent bef2052 commit 58721ee

File tree

5 files changed

+203
-106
lines changed

5 files changed

+203
-106
lines changed

amaranth/hdl/ast.py

Lines changed: 52 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,21 @@ def cast(obj, *, src_loc_at=0):
9393
bits_for(obj.stop - obj.step, signed))
9494
return Shape(width, signed)
9595
elif isinstance(obj, type) and issubclass(obj, Enum):
96-
min_value = min(member.value for member in obj)
97-
max_value = max(member.value for member in obj)
98-
if not isinstance(min_value, int) or not isinstance(max_value, int):
99-
raise TypeError("Only enumerations with integer values can be used "
100-
"as value shapes")
101-
signed = min_value < 0 or max_value < 0
102-
width = max(bits_for(min_value, signed), bits_for(max_value, signed))
96+
signed = False
97+
width = 0
98+
for member in obj:
99+
try:
100+
member_shape = Const.cast(member.value).shape()
101+
except TypeError as e:
102+
raise TypeError("Only enumerations whose members have constant-castable "
103+
"values can be used in Amaranth code")
104+
if not signed and member_shape.signed:
105+
signed = True
106+
width = max(width + 1, member_shape.width)
107+
elif signed and not member_shape.signed:
108+
width = max(width, member_shape.width + 1)
109+
else:
110+
width = max(width, member_shape.width)
103111
return Shape(width, signed)
104112
elif isinstance(obj, ShapeCastable):
105113
new_obj = obj.as_shape()
@@ -402,11 +410,8 @@ def matches(self, *patterns):
402410
``1`` if any pattern matches the value, ``0`` otherwise.
403411
"""
404412
matches = []
413+
# This code should accept exactly the same patterns as `with m.Case(...):`.
405414
for pattern in patterns:
406-
if not isinstance(pattern, (int, str, Enum)):
407-
raise SyntaxError("Match pattern must be an integer, a string, or an enumeration, "
408-
"not {!r}"
409-
.format(pattern))
410415
if isinstance(pattern, str) and any(bit not in "01- \t" for bit in pattern):
411416
raise SyntaxError("Match pattern '{}' must consist of 0, 1, and - (don't care) "
412417
"bits, and may include whitespace"
@@ -416,23 +421,26 @@ def matches(self, *patterns):
416421
raise SyntaxError("Match pattern '{}' must have the same width as match value "
417422
"(which is {})"
418423
.format(pattern, len(self)))
419-
if isinstance(pattern, int) and bits_for(pattern) > len(self):
420-
warnings.warn("Match pattern '{:b}' is wider than match value "
421-
"(which has width {}); comparison will never be true"
422-
.format(pattern, len(self)),
423-
SyntaxWarning, stacklevel=3)
424-
continue
425424
if isinstance(pattern, str):
426425
pattern = "".join(pattern.split()) # remove whitespace
427426
mask = int(pattern.replace("0", "1").replace("-", "0"), 2)
428427
pattern = int(pattern.replace("-", "0"), 2)
429428
matches.append((self & mask) == pattern)
430-
elif isinstance(pattern, int):
431-
matches.append(self == pattern)
432-
elif isinstance(pattern, Enum):
433-
matches.append(self == pattern.value)
434429
else:
435-
assert False
430+
try:
431+
orig_pattern, pattern = pattern, Const.cast(pattern)
432+
except TypeError as e:
433+
raise SyntaxError("Match pattern must be a string or a constant-castable "
434+
"expression, not {!r}"
435+
.format(pattern)) from e
436+
pattern_len = bits_for(pattern.value)
437+
if pattern_len > len(self):
438+
warnings.warn("Match pattern '{!r}' ({}'{:b}) is wider than match value "
439+
"(which has width {}); comparison will never be true"
440+
.format(orig_pattern, pattern_len, pattern.value, len(self)),
441+
SyntaxWarning, stacklevel=2)
442+
continue
443+
matches.append(self == pattern)
436444
if not matches:
437445
return Const(0)
438446
elif len(matches) == 1:
@@ -560,9 +568,6 @@ def _lhs_signals(self):
560568
def _rhs_signals(self):
561569
pass # :nocov:
562570

563-
def _as_const(self):
564-
raise TypeError("Value {!r} cannot be evaluated as constant".format(self))
565-
566571
__hash__ = None
567572

568573

@@ -595,6 +600,28 @@ def normalize(value, shape):
595600
value |= ~mask
596601
return value
597602

603+
@staticmethod
604+
def cast(obj):
605+
"""Converts ``obj`` to an Amaranth constant.
606+
607+
First, ``obj`` is converted to a value using :meth:`Value.cast`. If it is a constant, it
608+
is returned. If it is a constant-castable expression, it is evaluated and returned.
609+
Otherwise, :exn:`TypeError` is raised.
610+
"""
611+
obj = Value.cast(obj)
612+
if type(obj) is Const:
613+
return obj
614+
elif type(obj) is Cat:
615+
value = 0
616+
width = 0
617+
for part in obj.parts:
618+
const = Const.cast(part)
619+
value |= const.value << width
620+
width += len(const)
621+
return Const(value, width)
622+
else:
623+
raise TypeError("Value {!r} cannot be converted to an Amaranth constant".format(obj))
624+
598625
def __init__(self, value, shape=None, *, src_loc_at=0):
599626
# We deliberately do not call Value.__init__ here.
600627
self.value = int(value)
@@ -617,9 +644,6 @@ def shape(self):
617644
def _rhs_signals(self):
618645
return SignalSet()
619646

620-
def _as_const(self):
621-
return self.value
622-
623647
def __repr__(self):
624648
return "(const {}'{}d{})".format(self.width, "s" if self.signed else "", self.value)
625649

@@ -858,13 +882,6 @@ def _lhs_signals(self):
858882
def _rhs_signals(self):
859883
return union((part._rhs_signals() for part in self.parts), start=SignalSet())
860884

861-
def _as_const(self):
862-
value = 0
863-
for part in reversed(self.parts):
864-
value <<= len(part)
865-
value |= part._as_const()
866-
return value
867-
868885
def __repr__(self):
869886
return "(cat {})".format(" ".join(map(repr, self.parts)))
870887

amaranth/hdl/dsl.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -305,11 +305,8 @@ def Case(self, *patterns):
305305
src_loc = tracer.get_src_loc(src_loc_at=1)
306306
switch_data = self._get_ctrl("Switch")
307307
new_patterns = ()
308+
# This code should accept exactly the same patterns as `v.matches(...)`.
308309
for pattern in patterns:
309-
if not isinstance(pattern, (int, str, Enum)):
310-
raise SyntaxError("Case pattern must be an integer, a string, or an enumeration, "
311-
"not {!r}"
312-
.format(pattern))
313310
if isinstance(pattern, str) and any(bit not in "01- \t" for bit in pattern):
314311
raise SyntaxError("Case pattern '{}' must consist of 0, 1, and - (don't care) "
315312
"bits, and may include whitespace"
@@ -319,20 +316,24 @@ def Case(self, *patterns):
319316
raise SyntaxError("Case pattern '{}' must have the same width as switch value "
320317
"(which is {})"
321318
.format(pattern, len(switch_data["test"])))
322-
if isinstance(pattern, int) and bits_for(pattern) > len(switch_data["test"]):
323-
warnings.warn("Case pattern '{:b}' is wider than switch value "
324-
"(which has width {}); comparison will never be true"
325-
.format(pattern, len(switch_data["test"])),
326-
SyntaxWarning, stacklevel=3)
327-
continue
328-
if isinstance(pattern, Enum) and bits_for(pattern.value) > len(switch_data["test"]):
329-
warnings.warn("Case pattern '{:b}' ({}.{}) is wider than switch value "
330-
"(which has width {}); comparison will never be true"
331-
.format(pattern.value, pattern.__class__.__name__, pattern.name,
332-
len(switch_data["test"])),
333-
SyntaxWarning, stacklevel=3)
334-
continue
335-
new_patterns = (*new_patterns, pattern)
319+
if isinstance(pattern, str):
320+
new_patterns = (*new_patterns, pattern)
321+
else:
322+
try:
323+
orig_pattern, pattern = pattern, Const.cast(pattern)
324+
except TypeError as e:
325+
raise SyntaxError("Case pattern must be a string or a constant-castable "
326+
"expression, not {!r}"
327+
.format(pattern)) from e
328+
pattern_len = bits_for(pattern.value)
329+
if pattern_len > len(switch_data["test"]):
330+
warnings.warn("Case pattern '{!r}' ({}'{:b}) is wider than switch value "
331+
"(which has width {}); comparison will never be true"
332+
.format(orig_pattern, pattern_len, pattern.value,
333+
len(switch_data["test"])),
334+
SyntaxWarning, stacklevel=3)
335+
continue
336+
new_patterns = (*new_patterns, pattern.value)
336337
try:
337338
_outer_case, self._statements = self._statements, []
338339
self._ctrl_context = None

docs/lang.rst

Lines changed: 89 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,43 @@ All of the examples below assume that a glob import is used.
3434
from amaranth import *
3535

3636

37+
.. _lang-shapes:
38+
39+
Shapes
40+
======
41+
42+
A ``Shape`` is an object with two attributes, ``.width`` and ``.signed``. It can be constructed directly:
43+
44+
.. doctest::
45+
46+
>>> Shape(width=5, signed=False)
47+
unsigned(5)
48+
>>> Shape(width=12, signed=True)
49+
signed(12)
50+
51+
However, in most cases, the shape is always constructed with the same signedness, and the aliases ``signed`` and ``unsigned`` are more convenient:
52+
53+
.. doctest::
54+
55+
>>> unsigned(5) == Shape(width=5, signed=False)
56+
True
57+
>>> signed(12) == Shape(width=12, signed=True)
58+
True
59+
60+
61+
Shapes of values
62+
----------------
63+
64+
All values have a ``.shape()`` method that computes their shape. The width of a value ``v``, ``v.shape().width``, can also be retrieved with ``len(v)``.
65+
66+
.. doctest::
67+
68+
>>> Const(5).shape()
69+
unsigned(3)
70+
>>> len(Const(5))
71+
3
72+
73+
3774
.. _lang-values:
3875

3976
Values
@@ -79,43 +116,6 @@ The shape of the constant can be specified explicitly, in which case the number'
79116
0
80117

81118

82-
.. _lang-shapes:
83-
84-
Shapes
85-
======
86-
87-
A ``Shape`` is an object with two attributes, ``.width`` and ``.signed``. It can be constructed directly:
88-
89-
.. doctest::
90-
91-
>>> Shape(width=5, signed=False)
92-
unsigned(5)
93-
>>> Shape(width=12, signed=True)
94-
signed(12)
95-
96-
However, in most cases, the shape is always constructed with the same signedness, and the aliases ``signed`` and ``unsigned`` are more convenient:
97-
98-
.. doctest::
99-
100-
>>> unsigned(5) == Shape(width=5, signed=False)
101-
True
102-
>>> signed(12) == Shape(width=12, signed=True)
103-
True
104-
105-
106-
Shapes of values
107-
----------------
108-
109-
All values have a ``.shape()`` method that computes their shape. The width of a value ``v``, ``v.shape().width``, can also be retrieved with ``len(v)``.
110-
111-
.. doctest::
112-
113-
>>> Const(5).shape()
114-
unsigned(3)
115-
>>> len(Const(5))
116-
3
117-
118-
119119
.. _lang-shapecasting:
120120

121121
Shape casting
@@ -218,7 +218,7 @@ Specifying a shape with an enumeration is convenient for finite state machines,
218218
Value casting
219219
=============
220220

221-
Like shapes, values may be *cast* from other objects, which are called *value-castable*. Casting allows objects that are not provided by Amaranth, such as integers or enumeration members, to be used in Amaranth expressions directly.
221+
Like shapes, values may be *cast* from other objects, which are called *value-castable*. Casting to values allows objects that are not provided by Amaranth, such as integers or enumeration members, to be used in Amaranth expressions directly.
222222

223223
.. TODO: link to ValueCastable
224224
@@ -228,7 +228,7 @@ Casting to a value can be done explicitly with ``Value.cast``, but is usually im
228228
Values from integers
229229
--------------------
230230

231-
Casting a value from an integer ``i`` is a shorthand for ``Const(i)``:
231+
Casting a value from an integer ``i`` is equivalent to ``Const(i)``:
232232

233233
.. doctest::
234234

@@ -242,7 +242,7 @@ Casting a value from an integer ``i`` is a shorthand for ``Const(i)``:
242242
Values from enumeration members
243243
-------------------------------
244244

245-
Casting a value from an enumeration member ``m`` is a shorthand for ``Const(m.value, type(m))``:
245+
Casting a value from an enumeration member ``m`` is equivalent to ``Const(m.value, type(m))``:
246246

247247
.. doctest::
248248

@@ -254,6 +254,55 @@ Casting a value from an enumeration member ``m`` is a shorthand for ``Const(m.va
254254

255255
If a value subclasses :class:`enum.IntEnum` or its class otherwise inherits from both :class:`int` and :class:`Enum`, it is treated as an enumeration.
256256

257+
258+
.. _lang-constcasting:
259+
260+
Constant casting
261+
================
262+
263+
A subset of :ref:`values <lang-values>` are *constant-castable*. If a value is constant-castable and all of its operands are also constant-castable, it can be converted to a :class:`Const`, the numeric value of which can then be read by Python code. This provides a way to perform computation on Amaranth values while constructing the design.
264+
265+
.. TODO: link to m.Case and v.matches() below
266+
267+
Constant-castable objects are accepted anywhere a constant integer is accepted. Casting to a constant can also be done explicitly with :meth:`Const.cast`:
268+
269+
.. doctest::
270+
271+
>>> Const.cast(Cat(Direction.TOP, Direction.LEFT))
272+
(const 4'd4)
273+
274+
.. TODO: uncomment when this actually works
275+
276+
.. comment::
277+
278+
They may be used in enumeration members:
279+
280+
.. testcode::
281+
282+
class Funct(enum.Enum):
283+
ADD = 0
284+
...
285+
286+
class Op(enum.Enum):
287+
REG = 0
288+
IMM = 1
289+
290+
class Instr(enum.Enum):
291+
ADD = Cat(Funct.ADD, Op.REG)
292+
ADDI = Cat(Funct.ADD, Op.IMM)
293+
...
294+
295+
296+
.. note::
297+
298+
At the moment, only the following expressions are constant-castable:
299+
300+
* :class:`Const`
301+
* :class:`Cat`
302+
303+
This list will be expanded in the future.
304+
305+
257306
.. _lang-signals:
258307

259308
Signals

0 commit comments

Comments
 (0)