Skip to content

Commit db45b84

Browse files
committed
csr.reg: allow direct instantiation of Register class.
The access= parameter can now be set in either __init_subclass__() or __init__(). It has a default value of None, and __init__() will error out if set in both sites or none of them.
1 parent 441af5b commit db45b84

File tree

2 files changed

+99
-60
lines changed

2 files changed

+99
-60
lines changed

amaranth_soc/csr/reg.py

Lines changed: 56 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -411,15 +411,16 @@ def flatten(self):
411411

412412

413413
class Register(wiring.Component):
414-
"""Base class for CSR registers.
414+
_doc_template = """
415+
A CSR register.
415416
416417
Parameters
417418
----------
418419
fields : :class:`dict` or :class:`list`
419-
Collection of register fields. If ``None`` (default), a :class:`dict` is populated from
420-
Python :term:`variable annotations <python:variable annotations>`. If ``fields`` is a
421-
:class:`dict`, it is cast to a :class:`FieldMap`; if ``fields`` is a :class:`list`, it is
422-
cast to a :class`FieldArray`.
420+
Collection of register fields. If ``None`` (default), a dict is populated from Python
421+
:term:`variable annotations <python:variable annotations>`. ``fields`` is used to populate
422+
a :class:`FieldMap` or a :class:`FieldArray`, depending on its type (dict or list).
423+
{parameters}
423424
424425
Interface attributes
425426
--------------------
@@ -445,57 +446,59 @@ class Register(wiring.Component):
445446
If ``element.access`` is not writable and at least one field is writable.
446447
"""
447448

448-
def __new__(cls, *args, **kwargs):
449-
if cls is Register:
450-
raise TypeError("csr.Register is a base class and cannot be directly instantiated")
451-
return super().__new__(cls, *args, **kwargs)
449+
__doc__ = _doc_template.format(parameters="""
450+
access : :class:`Element.Access`
451+
Element access mode.
452+
""")
452453

453-
def __init_subclass__(cls, *, access, **kwargs):
454-
# TODO(py3.9): Remove this. Python 3.8 and below use cls.__name__ in the error message
455-
# instead of cls.__qualname__.
456-
# cls.__access = Element.Access(access)
457-
try:
458-
cls.__access = Element.Access(access)
459-
except ValueError as e:
460-
raise ValueError(f"{access!r} is not a valid Element.Access") from e
454+
def __init_subclass__(cls, *, access=None, **kwargs):
455+
if access is not None:
456+
# TODO(py3.9): Remove this. Python 3.8 and below use cls.__name__ in the error message
457+
# instead of cls.__qualname__.
458+
# cls._access = Element.Access(access)
459+
try:
460+
cls._access = Element.Access(access)
461+
except ValueError as e:
462+
raise ValueError(f"{access!r} is not a valid Element.Access") from e
463+
cls.__doc__ = cls._doc_template.format(parameters="")
461464
super().__init_subclass__(**kwargs)
462465

463-
def __init__(self, fields=None):
466+
def __init__(self, fields=None, access=None):
464467
if hasattr(self, "__annotations__"):
465-
def filter_dict(d):
466-
fields = {}
467-
for key, value in d.items():
468-
if isinstance(value, Field):
469-
fields[key] = value
470-
elif isinstance(value, dict):
471-
if sub_fields := filter_dict(value):
472-
fields[key] = sub_fields
473-
elif isinstance(value, list):
474-
if sub_fields := filter_list(value):
475-
fields[key] = sub_fields
476-
return fields
477-
478-
def filter_list(l):
479-
fields = []
480-
for item in l:
481-
if isinstance(item, Field):
482-
fields.append(item)
483-
elif isinstance(item, dict):
484-
if sub_fields := filter_dict(item):
485-
fields.append(sub_fields)
486-
elif isinstance(item, list):
487-
if sub_fields := filter_list(item):
488-
fields.append(sub_fields)
489-
return fields
490-
491-
annot_fields = filter_dict(self.__annotations__)
468+
def filter_fields(src):
469+
if isinstance(src, Field):
470+
return src
471+
if isinstance(src, (dict, list)):
472+
items = enumerate(src) if isinstance(src, list) else src.items()
473+
dst = dict()
474+
for key, value in items:
475+
if new_value := filter_fields(value):
476+
dst[key] = new_value
477+
return list(dst.values()) if isinstance(src, list) else dst
478+
479+
annot_fields = filter_fields(self.__annotations__)
492480

493481
if fields is None:
494482
fields = annot_fields
495483
elif annot_fields:
496484
raise ValueError(f"Field collection {fields} cannot be provided in addition to "
497485
f"field annotations: {', '.join(annot_fields)}")
498486

487+
if access is not None:
488+
# TODO(py3.9): Remove this (see above).
489+
try:
490+
access = Element.Access(access)
491+
except ValueError as e:
492+
raise ValueError(f"{access!r} is not a valid Element.Access") from e
493+
if hasattr(self, "_access") and access != self._access:
494+
raise ValueError(f"Element access mode {access} conflicts with the value "
495+
f"provided during class creation: {self._access}")
496+
elif hasattr(self, "_access"):
497+
access = self._access
498+
else:
499+
raise ValueError("Element access mode must be provided during class creation or "
500+
"instantiation")
501+
499502
if isinstance(fields, dict):
500503
self._fields = FieldMap(fields)
501504
elif isinstance(fields, list):
@@ -506,14 +509,14 @@ def filter_list(l):
506509
width = 0
507510
for field_path, field in self._fields.flatten():
508511
width += Shape.cast(field.port.shape).width
509-
if field.port.access.readable() and not self.__access.readable():
510-
raise ValueError(f"Field {'__'.join(field_path)} is readable, but register access "
511-
f"mode is {self.__access!r}")
512-
if field.port.access.writable() and not self.__access.writable():
513-
raise ValueError(f"Field {'__'.join(field_path)} is writable, but register access "
514-
f"mode is {self.__access!r}")
515-
516-
super().__init__({"element": Out(Element.Signature(width, self.__access))})
512+
if field.port.access.readable() and not access.readable():
513+
raise ValueError(f"Field {'__'.join(field_path)} is readable, but element access "
514+
f"mode is {access}")
515+
if field.port.access.writable() and not access.writable():
516+
raise ValueError(f"Field {'__'.join(field_path)} is writable, but element access "
517+
f"mode is {access}")
518+
519+
super().__init__({"element": Out(Element.Signature(width, access))})
517520

518521
@property
519522
def fields(self):

tests/test_csr_reg.py

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,20 @@ class FooRegister(Register, access="rw"):
387387
self.assertEqual(reg.element.access, Element.Access.RW)
388388
self.assertEqual(reg.element.width, 12)
389389

390+
def test_access_init(self):
391+
class FooRegister(Register):
392+
a: Field(action.R, unsigned(1))
393+
394+
reg = FooRegister(access="r")
395+
self.assertEqual(reg.element.access, Element.Access.R)
396+
397+
def test_access_same(self):
398+
class FooRegister(Register, access="r"):
399+
a: Field(action.R, unsigned(1))
400+
401+
reg = FooRegister(access="r")
402+
self.assertEqual(reg.element.access, Element.Access.R)
403+
390404
def test_fields_dict(self):
391405
class FooRegister(Register, access=Element.Access.RW):
392406
pass
@@ -426,16 +440,38 @@ class FooRegister(Register, access="r"):
426440
self.assertEqual(reg.element.access, Element.Access.R)
427441
self.assertEqual(reg.element.width, 2)
428442

429-
def test_subclass_requirement(self):
430-
with self.assertRaisesRegex(TypeError,
431-
r"csr\.Register is a base class and cannot be directly instantiated"):
432-
Register()
433-
434443
def test_wrong_access(self):
444+
with self.assertRaisesRegex(ValueError, r"'foo' is not a valid Element.Access"):
445+
Register({"a": Field(action.R, unsigned(1))}, access="foo")
435446
with self.assertRaisesRegex(ValueError, r"'foo' is not a valid Element.Access"):
436447
class FooRegister(Register, access="foo"):
437448
pass
438449

450+
def test_no_access(self):
451+
with self.assertRaisesRegex(ValueError,
452+
r"Element access mode must be provided during class creation or instantiation"):
453+
Register({"a": Field(action.R, unsigned(1))})
454+
455+
class FooRegister(Register, access=None):
456+
pass
457+
with self.assertRaisesRegex(ValueError,
458+
r"Element access mode must be provided during class creation or instantiation"):
459+
FooRegister({"a": Field(action.R, unsigned(1))})
460+
461+
class BarRegister(Register):
462+
pass
463+
with self.assertRaisesRegex(ValueError,
464+
r"Element access mode must be provided during class creation or instantiation"):
465+
BarRegister({"a": Field(action.R, unsigned(1))})
466+
467+
def test_access_conflict(self):
468+
class FooRegister(Register, access="r"):
469+
a: Field(action.R, unsigned(1))
470+
with self.assertRaisesRegex(ValueError,
471+
r"Element access mode Access\.RW conflicts with the value provided during class "
472+
r"creation: Access\.R"):
473+
FooRegister(access="rw")
474+
439475
def test_wrong_fields(self):
440476
class FooRegister(Register, access="w"):
441477
pass
@@ -465,10 +501,10 @@ class WRegister(Register, access="w"):
465501
class RRegister(Register, access="r"):
466502
pass
467503
with self.assertRaisesRegex(ValueError,
468-
r"Field a__b is readable, but register access mode is \<Access\.W: 'w'\>"):
504+
r"Field a__b is readable, but element access mode is Access\.W"):
469505
WRegister({"a": {"b": Field(action.RW, unsigned(1))}})
470506
with self.assertRaisesRegex(ValueError,
471-
r"Field a__b is writable, but register access mode is \<Access\.R: 'r'\>"):
507+
r"Field a__b is writable, but element access mode is Access\.R"):
472508
RRegister({"a": {"b": Field(action.RW, unsigned(1))}})
473509

474510
def test_iter(self):

0 commit comments

Comments
 (0)