Skip to content

Commit 57612f1

Browse files
committed
lib.enum: add Enum wrappers that allow specifying shape.
See #756 and amaranth-lang/rfcs#3.
1 parent ef2e9fa commit 57612f1

File tree

10 files changed

+343
-40
lines changed

10 files changed

+343
-40
lines changed

amaranth/hdl/ast.py

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,34 @@ def __init__(self, width=1, signed=False):
7878
self.width = width
7979
self.signed = signed
8080

81+
# The algorithm for inferring shape for standard Python enumerations is factored out so that
82+
# `Shape.cast()` and Amaranth's `EnumMeta.as_shape()` can both use it.
83+
@staticmethod
84+
def _cast_plain_enum(obj):
85+
signed = False
86+
width = 0
87+
for member in obj:
88+
try:
89+
member_shape = Const.cast(member.value).shape()
90+
except TypeError as e:
91+
raise TypeError("Only enumerations whose members have constant-castable "
92+
"values can be used in Amaranth code")
93+
if not signed and member_shape.signed:
94+
signed = True
95+
width = max(width + 1, member_shape.width)
96+
elif signed and not member_shape.signed:
97+
width = max(width, member_shape.width + 1)
98+
else:
99+
width = max(width, member_shape.width)
100+
return Shape(width, signed)
101+
81102
@staticmethod
82103
def cast(obj, *, src_loc_at=0):
83104
while True:
84105
if isinstance(obj, Shape):
85106
return obj
107+
elif isinstance(obj, ShapeCastable):
108+
new_obj = obj.as_shape()
86109
elif isinstance(obj, int):
87110
return Shape(obj)
88111
elif isinstance(obj, range):
@@ -93,24 +116,9 @@ def cast(obj, *, src_loc_at=0):
93116
bits_for(obj.stop - obj.step, signed))
94117
return Shape(width, signed)
95118
elif isinstance(obj, type) and issubclass(obj, Enum):
96-
signed = False
97-
width = 0
98-
for member in obj:
99-
try:
100-
member_shape = Const.cast(member.value).shape()
101-
except TypeError as e:
102-
raise TypeError("Only enumerations whose members have constant-castable "
103-
"values can be used in Amaranth code")
104-
if not signed and member_shape.signed:
105-
signed = True
106-
width = max(width + 1, member_shape.width)
107-
elif signed and not member_shape.signed:
108-
width = max(width, member_shape.width + 1)
109-
else:
110-
width = max(width, member_shape.width)
111-
return Shape(width, signed)
112-
elif isinstance(obj, ShapeCastable):
113-
new_obj = obj.as_shape()
119+
# For compatibility with third party enumerations, handle them as if they were
120+
# defined as subclasses of lib.enum.Enum with no explicitly specified shape.
121+
return Shape._cast_plain_enum(obj)
114122
else:
115123
raise TypeError("Object {!r} cannot be converted to an Amaranth shape".format(obj))
116124
if new_obj is obj:
@@ -866,9 +874,17 @@ def __init__(self, *args, src_loc_at=0):
866874
super().__init__(src_loc_at=src_loc_at)
867875
self.parts = []
868876
for index, arg in enumerate(flatten(args)):
877+
if isinstance(arg, Enum) and (not isinstance(type(arg), ShapeCastable) or
878+
not hasattr(arg, "_amaranth_shape_")):
879+
warnings.warn("Argument #{} of Cat() is an enumerated value {!r} without "
880+
"a defined shape used in bit vector context; define the enumeration "
881+
"by inheriting from the class in amaranth.lib.enum and specifying "
882+
"the 'shape=' keyword argument"
883+
.format(index + 1, arg),
884+
SyntaxWarning, stacklevel=2 + src_loc_at)
869885
if isinstance(arg, int) and not isinstance(arg, Enum) and arg not in [0, 1]:
870886
warnings.warn("Argument #{} of Cat() is a bare integer {} used in bit vector "
871-
"context; consider specifying explicit width using C({}, {}) instead"
887+
"context; specify the width explicitly using C({}, {})"
872888
.format(index + 1, arg, arg, bits_for(arg)),
873889
SyntaxWarning, stacklevel=2 + src_loc_at)
874890
self.parts.append(Value.cast(arg))

amaranth/lib/enum.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import enum as py_enum
2+
import warnings
3+
4+
from ..hdl.ast import Shape, ShapeCastable, Const
5+
from .._utils import bits_for
6+
7+
8+
__all__ = py_enum.__all__
9+
10+
11+
for member in py_enum.__all__:
12+
globals()[member] = getattr(py_enum, member)
13+
del member
14+
15+
16+
class EnumMeta(ShapeCastable, py_enum.EnumMeta):
17+
"""Subclass of the standard :class:`enum.EnumMeta` that implements the :class:`ShapeCastable`
18+
protocol.
19+
20+
This metaclass provides the :meth:`as_shape` method, making its instances
21+
:ref:`shape-castable <lang-shapecasting>`, and accepts a ``shape=`` keyword argument
22+
to specify a shape explicitly. Other than this, it acts the same as the standard
23+
:class:`enum.EnumMeta` class; if the ``shape=`` argument is not specified and
24+
:meth:`as_shape` is never called, it places no restrictions on the enumeration class
25+
or the values of its members.
26+
"""
27+
def __new__(metacls, name, bases, namespace, shape=None, **kwargs):
28+
cls = py_enum.EnumMeta.__new__(metacls, name, bases, namespace, **kwargs)
29+
if shape is not None:
30+
# Shape is provided explicitly. Set the `_amaranth_shape_` attribute, and check that
31+
# the values of every member can be cast to the provided shape without truncation.
32+
cls._amaranth_shape_ = shape = Shape.cast(shape)
33+
for member in cls:
34+
try:
35+
Const.cast(member.value)
36+
except TypeError as e:
37+
raise TypeError("Value of enumeration member {!r} must be "
38+
"a constant-castable expression"
39+
.format(member)) from e
40+
width = bits_for(member.value, shape.signed)
41+
if member.value < 0 and not shape.signed:
42+
warnings.warn(
43+
message="Value of enumeration member {!r} is signed, but enumeration "
44+
"shape is {!r}" # the repr will be `unsigned(X)`
45+
.format(member, shape),
46+
category=RuntimeWarning,
47+
stacklevel=2)
48+
elif width > shape.width:
49+
warnings.warn(
50+
message="Value of enumeration member {!r} will be truncated to "
51+
"enumeration shape {!r}"
52+
.format(member, shape),
53+
category=RuntimeWarning,
54+
stacklevel=2)
55+
else:
56+
# Shape is not provided explicitly. Behave the same as a standard enumeration;
57+
# the lack of `_amaranth_shape_` attribute is used to emit a warning when such
58+
# an enumeration is used in a concatenation.
59+
pass
60+
return cls
61+
62+
def as_shape(cls):
63+
"""Cast this enumeration to a shape.
64+
65+
Returns
66+
-------
67+
:class:`Shape`
68+
Explicitly provided shape. If not provided, returns the result of shape-casting
69+
this class :ref:`as a standard Python enumeration <lang-shapeenum>`.
70+
71+
Raises
72+
------
73+
TypeError
74+
If the enumeration has neither an explicitly provided shape nor any members.
75+
"""
76+
if hasattr(cls, "_amaranth_shape_"):
77+
# Shape was provided explicitly; return it.
78+
return cls._amaranth_shape_
79+
elif cls.__members__:
80+
# Shape was not provided explicitly, but enumeration has members; treat it
81+
# the same way `Shape.cast` treats standard library enumerations, so that
82+
# `amaranth.lib.enum.Enum` can be a drop-in replacement for `enum.Enum`.
83+
return Shape._cast_plain_enum(cls)
84+
else:
85+
# Shape was not provided explicitly, and enumeration has no members.
86+
# This is a base or mixin class that cannot be instantiated directly.
87+
raise TypeError("Enumeration '{}.{}' does not have a defined shape"
88+
.format(cls.__module__, cls.__qualname__))
89+
90+
91+
class Enum(py_enum.Enum, metaclass=EnumMeta):
92+
"""Subclass of the standard :class:`enum.Enum` that has :class:`EnumMeta` as
93+
its metaclass."""
94+
95+
96+
class IntEnum(py_enum.IntEnum, metaclass=EnumMeta):
97+
"""Subclass of the standard :class:`enum.IntEnum` that has :class:`EnumMeta` as
98+
its metaclass."""
99+
100+
101+
class Flag(py_enum.Flag, metaclass=EnumMeta):
102+
"""Subclass of the standard :class:`enum.Flag` that has :class:`EnumMeta` as
103+
its metaclass."""
104+
105+
106+
class IntFlag(py_enum.IntFlag, metaclass=EnumMeta):
107+
"""Subclass of the standard :class:`enum.IntFlag` that has :class:`EnumMeta` as
108+
its metaclass."""

docs/changes.rst

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,17 @@ Apply the following changes to code written against Amaranth 0.3 to migrate it t
2222

2323
While code that uses the features listed as deprecated below will work in Amaranth 0.4, they will be removed in the next version.
2424

25+
Implemented RFCs
26+
----------------
27+
28+
.. _RFC 3: https://amaranth-lang.org/rfcs/0003-enumeration-shapes.html
29+
.. _RFC 4: https://amaranth-lang.org/rfcs/0004-const-castable-exprs.html
30+
.. _RFC 5: https://amaranth-lang.org/rfcs/0005-remove-const-normalize.html
31+
32+
* `RFC 3`_: Enumeration shapes
33+
* `RFC 4`_: Constant-castable expressions
34+
* `RFC 5`_: Remove Const.normalize
35+
2536

2637
Language changes
2738
----------------
@@ -30,19 +41,24 @@ Language changes
3041

3142
* Added: :class:`ShapeCastable`, similar to :class:`ValueCastable`.
3243
* Added: :meth:`Value.as_signed` and :meth:`Value.as_unsigned` can be used on left-hand side of assignment (with no difference in behavior).
33-
* Added: :meth:`Const.cast`, evaluating constant-castable values and returning a :class:`Const`. (`RFC 4`_)
44+
* Added: :meth:`Const.cast`. (`RFC 4`_)
3445
* Added: :meth:`Value.matches` and ``with m.Case():`` accept any constant-castable objects. (`RFC 4`_)
3546
* Changed: :meth:`Value.cast` casts :class:`ValueCastable` objects recursively.
3647
* Changed: :meth:`Value.cast` treats instances of classes derived from both :class:`enum.Enum` and :class:`int` (including :class:`enum.IntEnum`) as enumerations rather than integers.
37-
* Changed: ``Value.matches()`` with an empty list of patterns returns ``Const(1)`` rather than ``Const(0)``, to match ``with m.Case():``.
38-
* Changed: :class:`Cat` accepts instances of classes derived from both :class:`enum.Enum` and :class:`int` (including :class:`enum.IntEnum`) without warning.
48+
* Changed: :meth:`Value.matches` with an empty list of patterns returns ``Const(1)`` rather than ``Const(0)``, to match the behavior of ``with m.Case():``.
49+
* Changed: :class:`Cat` warns if an enumeration without an explicitly specified shape is used.
3950
* Deprecated: :meth:`Const.normalize`. (`RFC 5`_)
4051
* Removed: (deprecated in 0.1) casting of :class:`Shape` to and from a ``(width, signed)`` tuple.
4152
* Removed: (deprecated in 0.3) :class:`ast.UserValue`.
4253
* Removed: (deprecated in 0.3) support for ``# nmigen:`` linter instructions at the beginning of file.
4354

44-
.. _RFC 4: https://amaranth-lang.org/rfcs/0004-const-castable-exprs.html
45-
.. _RFC 5: https://amaranth-lang.org/rfcs/0005-remove-const-normalize.html
55+
56+
Standard library changes
57+
------------------------
58+
59+
.. currentmodule:: amaranth.lib
60+
61+
* Added: :mod:`amaranth.lib.enum`.
4662

4763

4864
Toolchain changes

docs/conf.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,17 @@
2727

2828
todo_include_todos = True
2929

30+
autodoc_member_order = "bysource"
31+
autodoc_default_options = {
32+
"members": True
33+
}
34+
autodoc_preserve_defaults = True
35+
3036
napoleon_google_docstring = False
3137
napoleon_numpy_docstring = True
3238
napoleon_use_ivar = True
39+
napoleon_include_init_with_doc = True
40+
napoleon_include_special_with_doc = True
3341
napoleon_custom_sections = ["Platform overrides"]
3442

3543
html_theme = "sphinx_rtd_theme"

docs/lang.rst

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -183,11 +183,7 @@ Specifying a shape with a range is convenient for counters, indexes, and all oth
183183
Shapes from enumerations
184184
------------------------
185185

186-
Casting a shape from an :class:`enum.Enum` subclass ``E``:
187-
188-
* fails if any of the enumeration members have non-integer values,
189-
* has a width large enough to represent both ``min(m.value for m in E)`` and ``max(m.value for m in E)``, and
190-
* is signed if either ``min(m.value for m in E)`` or ``max(m.value for m in E)`` are negative, unsigned otherwise.
186+
Casting a shape from an :class:`enum.Enum` subclass requires all of the enumeration members to have :ref:`constant-castable <lang-constcasting>` values. The shape has a width large enough to represent the value of every member, and is signed only if there is a member with a negative value.
191187

192188
Specifying a shape with an enumeration is convenient for finite state machines, multiplexers, complex control signals, and all other values whose width is derived from a few distinct choices they must be able to fit:
193189

@@ -208,9 +204,27 @@ Specifying a shape with an enumeration is convenient for finite state machines,
208204
>>> Shape.cast(Direction)
209205
unsigned(2)
210206

207+
The :mod:`amaranth.lib.enum` module extends the standard enumerations such that their shape can be specified explicitly when they are defined:
208+
209+
.. testsetup::
210+
211+
import amaranth.lib.enum
212+
213+
.. testcode::
214+
215+
class Funct4(amaranth.lib.enum.Enum, shape=unsigned(4)):
216+
ADD = 0
217+
SUB = 1
218+
MUL = 2
219+
220+
.. doctest::
221+
222+
>>> Shape.cast(Funct4)
223+
unsigned(4)
224+
211225
.. note::
212226

213-
The enumeration does not have to subclass :class:`enum.IntEnum`; it only needs to have integers as values of every member. Using enumerations based on :class:`enum.Enum` rather than :class:`enum.IntEnum` prevents unwanted implicit conversion of enum members to integers.
227+
The enumeration does not have to subclass :class:`enum.IntEnum` or have :class:`int` as one of its base classes; it only needs to have integers as values of every member. Using enumerations based on :class:`enum.Enum` rather than :class:`enum.IntEnum` prevents unwanted implicit conversion of enum members to integers.
214228

215229

216230
.. _lang-valuecasting:

docs/stdlib.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ Standard library
88
.. toctree::
99
:maxdepth: 2
1010

11+
stdlib/enum
1112
stdlib/coding
1213
stdlib/cdc
13-
stdlib/fifo
14+
stdlib/fifo

docs/stdlib/enum.rst

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
Enumerations
2+
############
3+
4+
.. py:module:: amaranth.lib.enum
5+
6+
The :mod:`amaranth.lib.enum` module is a drop-in replacement for the standard :mod:`enum` module that provides extended :class:`Enum`, :class:`IntEnum`, :class:`Flag`, and :class:`IntFlag` classes with the ability to specify a shape explicitly.
7+
8+
A shape can be specified for an enumeration with the ``shape=`` keyword argument:
9+
10+
.. testsetup::
11+
12+
from amaranth import *
13+
14+
.. testcode::
15+
16+
from amaranth.lib import enum
17+
18+
class Funct4(enum.Enum, shape=4):
19+
ADD = 0
20+
SUB = 1
21+
MUL = 2
22+
23+
.. doctest::
24+
25+
>>> Shape.cast(Funct4)
26+
unsigned(4)
27+
28+
This module is a drop-in replacement for the standard :mod:`enum` module, and re-exports all of its members (not just the ones described below). In an Amaranth project, all ``import enum`` statements may be replaced with ``from amaranth.lib import enum``.
29+
30+
31+
Metaclass
32+
=========
33+
34+
.. autoclass:: EnumMeta()
35+
36+
37+
Base classes
38+
============
39+
40+
.. autoclass:: Enum()
41+
.. autoclass:: IntEnum()
42+
.. autoclass:: Flag()
43+
.. autoclass:: IntFlag()

tests/test_hdl_ast.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -798,28 +798,34 @@ def test_int_01(self):
798798
warnings.filterwarnings(action="error", category=SyntaxWarning)
799799
Cat(0, 1, 1, 0)
800800

801-
def test_enum(self):
801+
def test_enum_wrong(self):
802802
class Color(Enum):
803803
RED = 1
804804
BLUE = 2
805-
with warnings.catch_warnings():
806-
warnings.filterwarnings(action="error", category=SyntaxWarning)
805+
with self.assertWarnsRegex(SyntaxWarning,
806+
r"^Argument #1 of Cat\(\) is an enumerated value <Color\.RED: 1> without "
807+
r"a defined shape used in bit vector context; define the enumeration by "
808+
r"inheriting from the class in amaranth\.lib\.enum and specifying "
809+
r"the 'shape=' keyword argument$"):
807810
c = Cat(Color.RED, Color.BLUE)
808811
self.assertEqual(repr(c), "(cat (const 2'd1) (const 2'd2))")
809812

810-
def test_intenum(self):
813+
def test_intenum_wrong(self):
811814
class Color(int, Enum):
812815
RED = 1
813816
BLUE = 2
814-
with warnings.catch_warnings():
815-
warnings.filterwarnings(action="error", category=SyntaxWarning)
817+
with self.assertWarnsRegex(SyntaxWarning,
818+
r"^Argument #1 of Cat\(\) is an enumerated value <Color\.RED: 1> without "
819+
r"a defined shape used in bit vector context; define the enumeration by "
820+
r"inheriting from the class in amaranth\.lib\.enum and specifying "
821+
r"the 'shape=' keyword argument$"):
816822
c = Cat(Color.RED, Color.BLUE)
817823
self.assertEqual(repr(c), "(cat (const 2'd1) (const 2'd2))")
818824

819825
def test_int_wrong(self):
820826
with self.assertWarnsRegex(SyntaxWarning,
821827
r"^Argument #1 of Cat\(\) is a bare integer 2 used in bit vector context; "
822-
r"consider specifying explicit width using C\(2, 2\) instead$"):
828+
r"specify the width explicitly using C\(2, 2\)$"):
823829
Cat(2)
824830

825831

0 commit comments

Comments
 (0)