Skip to content

Commit e92c7d9

Browse files
committed
csr.reg: check the signature of fields instantiated by Field.create().
1 parent 4464c31 commit e92c7d9

File tree

2 files changed

+87
-3
lines changed

2 files changed

+87
-3
lines changed

amaranth_soc/csr/reg.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,16 @@ class Field:
170170
Positional arguments passed to ``field_cls.__init__``.
171171
**kwargs : :class:`dict`
172172
Keyword arguments passed to ``field_cls.__init__``.
173+
174+
Raises
175+
------
176+
:exc:`TypeError`
177+
If ``field_cls`` is not a subclass of :class:`wiring.Component`.
173178
"""
174179
def __init__(self, field_cls, *args, **kwargs):
180+
if not issubclass(field_cls, wiring.Component):
181+
raise TypeError(f"{field_cls.__qualname__} must be a subclass of wiring.Component")
182+
175183
self._field_cls = field_cls
176184
self._args = args
177185
self._kwargs = kwargs
@@ -183,8 +191,22 @@ def create(self):
183191
-------
184192
:class:`object`
185193
The instance returned by ``field_cls(*args, **kwargs)``.
194+
195+
Raises
196+
------
197+
:exc:`TypeError`
198+
If the instance returned by ``field_cls(*args, **kwargs)`` doesn't have a signature
199+
with a member named "port" that is a :class:`FieldPort.Signature` with a
200+
:attr:`wiring.In` direction.
186201
"""
187-
return self._field_cls(*self._args, **self._kwargs)
202+
field = self._field_cls(*self._args, **self._kwargs)
203+
if not ("port" in field.signature.members
204+
and field.signature.members["port"].flow is In
205+
and field.signature.members["port"].is_signature
206+
and isinstance(field.signature.members["port"].signature, FieldPort.Signature)):
207+
raise TypeError(f"{self._field_cls.__qualname__} instance signature must have a "
208+
f"csr.FieldPort.Signature member named 'port' and oriented as In")
209+
return field
188210

189211

190212
class FieldMap(Mapping):

tests/test_csr_reg.py

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,10 +120,20 @@ def test_wrong_signature(self):
120120

121121

122122
class FieldTestCase(unittest.TestCase):
123+
def test_wrong_class(self):
124+
class Foo:
125+
pass
126+
with self.assertRaisesRegex(TypeError,
127+
r"Foo must be a subclass of wiring.Component"):
128+
Field(Foo)
129+
123130
def test_create(self):
124131
class MockField(wiring.Component):
125132
def __init__(self, shape, *, reset):
126-
super().__init__({"port": Out(FieldPort.Signature(shape, "rw"))})
133+
super().__init__({
134+
"port": In(FieldPort.Signature(shape, "rw")),
135+
"data": Out(shape)
136+
})
127137
self.reset = reset
128138

129139
def elaborate(self, platform):
@@ -135,7 +145,7 @@ def elaborate(self, platform):
135145

136146
def test_create_multiple(self):
137147
class MockField(wiring.Component):
138-
port: Out(FieldPort.Signature(unsigned(8), "rw"))
148+
port: In(FieldPort.Signature(unsigned(8), "rw"))
139149

140150
def elaborate(self, platform):
141151
return Module()
@@ -144,6 +154,58 @@ def elaborate(self, platform):
144154
field_2 = Field(MockField).create()
145155
self.assertIsNot(field_1, field_2)
146156

157+
def test_wrong_port_name(self):
158+
class MockField(wiring.Component):
159+
foo: In(FieldPort.Signature(unsigned(1), access="rw"))
160+
161+
def elaborate(self, platform):
162+
return Module()
163+
164+
field = Field(MockField)
165+
with self.assertRaisesRegex(TypeError,
166+
r"MockField instance signature must have a csr\.FieldPort\.Signature member named "
167+
r"'port' and oriented as In"):
168+
field.create()
169+
170+
def test_wrong_port_direction(self):
171+
class MockField(wiring.Component):
172+
port: Out(FieldPort.Signature(unsigned(1), access="rw"))
173+
174+
def elaborate(self, platform):
175+
return Module()
176+
177+
field = Field(MockField)
178+
with self.assertRaisesRegex(TypeError,
179+
r"MockField instance signature must have a csr\.FieldPort\.Signature member named "
180+
r"'port' and oriented as In"):
181+
field.create()
182+
183+
def test_wrong_port_type_port(self):
184+
class MockField(wiring.Component):
185+
port: In(unsigned(1))
186+
187+
def elaborate(self, platform):
188+
return Module()
189+
190+
field = Field(MockField)
191+
with self.assertRaisesRegex(TypeError,
192+
r"MockField instance signature must have a csr\.FieldPort\.Signature member named "
193+
r"'port' and oriented as In"):
194+
field.create()
195+
196+
def test_wrong_port_type_signature(self):
197+
class MockField(wiring.Component):
198+
port: In(wiring.Signature({"foo": Out(unsigned(1))}))
199+
200+
def elaborate(self, platform):
201+
return Module()
202+
203+
field = Field(MockField)
204+
with self.assertRaisesRegex(TypeError,
205+
r"MockField instance signature must have a csr\.FieldPort\.Signature member named "
206+
r"'port' and oriented as In"):
207+
field.create()
208+
147209

148210
class FieldMapTestCase(unittest.TestCase):
149211
def test_simple(self):

0 commit comments

Comments
 (0)