Skip to content

Commit 6543df3

Browse files
committed
MulAdd operator
1 parent 232be1b commit 6543df3

File tree

4 files changed

+301
-0
lines changed

4 files changed

+301
-0
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
TARGET=$(shell ls *.py | grep -v test | grep -v parsetab.py)
2+
ARGS=
3+
4+
PYTHON=python3
5+
#PYTHON=python
6+
#OPT=-m pdb
7+
#OPT=-m cProfile -s time
8+
#OPT=-m cProfile -o profile.rslt
9+
10+
.PHONY: all
11+
all: test
12+
13+
.PHONY: run
14+
run:
15+
$(PYTHON) $(OPT) $(TARGET) $(ARGS)
16+
17+
.PHONY: test
18+
test:
19+
$(PYTHON) -m pytest -vv
20+
21+
.PHONY: check
22+
check:
23+
$(PYTHON) $(OPT) $(TARGET) $(ARGS) > tmp.v
24+
iverilog -tnull -Wall tmp.v
25+
rm -f tmp.v
26+
27+
.PHONY: clean
28+
clean:
29+
rm -rf *.pyc __pycache__ parsetab.py .cache *.out *.png *.dot tmp.v uut.vcd
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from __future__ import absolute_import
2+
from __future__ import print_function
3+
4+
import os
5+
import veriloggen
6+
import thread_stream_muladd
7+
8+
9+
def test(request):
10+
veriloggen.reset()
11+
12+
simtype = request.config.getoption('--sim')
13+
14+
rslt = thread_stream_muladd.run(filename=None, simtype=simtype,
15+
outputfile=os.path.splitext(os.path.basename(__file__))[0] + '.out')
16+
17+
verify_rslt = rslt.splitlines()[-1]
18+
assert(verify_rslt == '# verify: PASSED')
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
from __future__ import absolute_import
2+
from __future__ import print_function
3+
import sys
4+
import os
5+
6+
# the next line can be removed after installation
7+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(
8+
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))))
9+
10+
from veriloggen import *
11+
import veriloggen.thread as vthread
12+
import veriloggen.types.axi as axi
13+
14+
15+
def mkLed():
16+
m = Module('blinkled')
17+
clk = m.Input('CLK')
18+
rst = m.Input('RST')
19+
20+
datawidth = 32
21+
addrwidth = 10
22+
myaxi = vthread.AXIM(m, 'myaxi', clk, rst, datawidth)
23+
ram_a = vthread.RAM(m, 'ram_a', clk, rst, datawidth, addrwidth)
24+
ram_b = vthread.RAM(m, 'ram_b', clk, rst, datawidth, addrwidth)
25+
ram_c = vthread.RAM(m, 'ram_c', clk, rst, datawidth, addrwidth)
26+
27+
strm = vthread.Stream(m, 'mystream', clk, rst)
28+
a = strm.source('a')
29+
b = strm.source('b')
30+
c = strm.MulAdd(a, b, b)
31+
strm.sink(c, 'c')
32+
33+
def comp_stream(size, offset):
34+
strm.set_source('a', ram_a, offset, size)
35+
strm.set_source('b', ram_b, offset, size)
36+
strm.set_sink('c', ram_c, offset, size)
37+
strm.run()
38+
strm.join()
39+
40+
def comp_sequential(size, offset):
41+
sum = 0
42+
for i in range(size):
43+
a = ram_a.read(i + offset)
44+
b = ram_b.read(i + offset)
45+
sum = a * b + b
46+
ram_c.write(i + offset, sum)
47+
48+
def check(size, offset_stream, offset_seq):
49+
all_ok = True
50+
for i in range(size):
51+
st = ram_c.read(i + offset_stream)
52+
sq = ram_c.read(i + offset_seq)
53+
if vthread.verilog.NotEql(st, sq):
54+
all_ok = False
55+
if all_ok:
56+
print('# verify: PASSED')
57+
else:
58+
print('# verify: FAILED')
59+
60+
def comp(size):
61+
# stream
62+
offset = 0
63+
myaxi.dma_read(ram_a, offset, 0, size)
64+
myaxi.dma_read(ram_b, offset, 512, size)
65+
comp_stream(size, offset)
66+
myaxi.dma_write(ram_c, offset, 1024, size)
67+
68+
# sequential
69+
offset = size
70+
myaxi.dma_read(ram_a, offset, 0, size)
71+
myaxi.dma_read(ram_b, offset, 512, size)
72+
comp_sequential(size, offset)
73+
myaxi.dma_write(ram_c, offset, 1024 * 2, size)
74+
75+
# verification
76+
check(size, 0, offset)
77+
78+
vthread.finish()
79+
80+
th = vthread.Thread(m, 'th_comp', clk, rst, comp)
81+
fsm = th.start(32)
82+
83+
return m
84+
85+
86+
def mkTest(memimg_name=None):
87+
m = Module('test')
88+
89+
# target instance
90+
led = mkLed()
91+
92+
# copy paras and ports
93+
params = m.copy_params(led)
94+
ports = m.copy_sim_ports(led)
95+
96+
clk = ports['CLK']
97+
rst = ports['RST']
98+
99+
memory = axi.AxiMemoryModel(m, 'memory', clk, rst, memimg_name=memimg_name)
100+
memory.connect(ports, 'myaxi')
101+
102+
uut = m.Instance(led, 'uut',
103+
params=m.connect_params(led),
104+
ports=m.connect_ports(led))
105+
106+
#simulation.setup_waveform(m, uut)
107+
simulation.setup_clock(m, clk, hperiod=5)
108+
init = simulation.setup_reset(m, rst, m.make_reset(), period=100)
109+
110+
init.add(
111+
Delay(1000000),
112+
Systask('finish'),
113+
)
114+
115+
return m
116+
117+
118+
def run(filename='tmp.v', simtype='iverilog', outputfile=None):
119+
120+
if outputfile is None:
121+
outputfile = os.path.splitext(os.path.basename(__file__))[0] + '.out'
122+
123+
memimg_name = 'memimg_' + outputfile
124+
125+
test = mkTest(memimg_name=memimg_name)
126+
127+
if filename is not None:
128+
test.to_verilog(filename)
129+
130+
sim = simulation.Simulator(test, sim=simtype)
131+
rslt = sim.run(outputfile=outputfile)
132+
lines = rslt.splitlines()
133+
if simtype == 'verilator' and lines[-1].startswith('-'):
134+
rslt = '\n'.join(lines[:-1])
135+
return rslt
136+
137+
138+
if __name__ == '__main__':
139+
rslt = run(filename='tmp.v')
140+
print(rslt)

veriloggen/stream/stypes.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from . import mul
1313
from . import div
14+
from . import madd
1415

1516

1617
# Object ID counter for object sorting key
@@ -1624,6 +1625,109 @@ def _implement(self, m, seq, svalid=None, senable=None):
16241625
m.Instance(inst, self.name('lut'), ports=ports)
16251626

16261627

1628+
class MulAdd(_SpecialOperator):
1629+
latency = 6 + 1
1630+
1631+
def __init__(self, a, b, c):
1632+
_SpecialOperator.__init__(self, a, b, c)
1633+
1634+
@property
1635+
def a(self):
1636+
return self.args[0]
1637+
1638+
@a.setter
1639+
def a(self, a):
1640+
self.args[0] = a
1641+
1642+
@property
1643+
def b(self):
1644+
return self.args[1]
1645+
1646+
@b.setter
1647+
def b(self, b):
1648+
self.args[1] = b
1649+
1650+
@property
1651+
def c(self):
1652+
return self.args[2]
1653+
1654+
@c.setter
1655+
def c(self, c):
1656+
self.args[2] = c
1657+
1658+
def _set_attributes(self):
1659+
a_fp = self.a.get_point()
1660+
b_fp = self.b.get_point()
1661+
c_fp = self.c.get_point()
1662+
a = self.a.bit_length()
1663+
b = self.b.bit_length()
1664+
c = self.c.bit_length()
1665+
self.width = max(a, b, c)
1666+
self.point = a_fp + b_fp
1667+
self.signed = self.a.get_signed() and self.b.get_signed() and self.c.get_signed()
1668+
1669+
def _implement(self, m, seq, svalid=None, senable=None):
1670+
if self.latency <= 3:
1671+
raise ValueError("Latency of '*' operator must be greater than 3")
1672+
1673+
width = self.bit_length()
1674+
signed = self.get_signed()
1675+
1676+
apoint = self.a.get_point()
1677+
bpoint = self.b.get_point()
1678+
cpoint = self.c.get_point()
1679+
awidth = self.a.bit_length()
1680+
bwidth = self.b.bit_length()
1681+
cwidth = self.c.bit_length()
1682+
asigned = self.a.get_signed()
1683+
bsigned = self.b.get_signed()
1684+
csigned = self.c.get_signed()
1685+
adata = self.a.sig_data
1686+
bdata = self.b.sig_data
1687+
cdata = self.c.sig_data
1688+
1689+
if apoint + bpoint != cpoint:
1690+
raise ValueError('apoint + bpoint == cpoint')
1691+
1692+
odata = m.Wire(self.name('madd_odata'),
1693+
max(awidth + bwidth, cwidth), signed=signed)
1694+
odata_reg = m.Reg(self.name('madd_odata_reg'),
1695+
max(awidth + bwidth, cwidth), signed=signed, initval=0)
1696+
1697+
data = m.Wire(self.name('data'), width, signed=signed)
1698+
self.sig_data = data
1699+
1700+
seq(odata_reg(odata), cond=senable)
1701+
1702+
m.Assign(data(odata_reg))
1703+
1704+
depth = self.latency - 1
1705+
1706+
inst = madd.get_madd(awidth, bwidth, cwidth,
1707+
asigned, bsigned, csigned, depth)
1708+
clk = m._clock
1709+
1710+
update = m.Wire(self.name('madd_update'))
1711+
1712+
if senable is not None:
1713+
m.Assign(update(senable))
1714+
else:
1715+
m.Assign(update(1))
1716+
1717+
ports = [('CLK', clk), ('update', update),
1718+
('a', adata), ('b', bdata), ('c', cdata), ('d', odata)]
1719+
1720+
m.Instance(inst, self.name('madd'), ports=ports)
1721+
1722+
def eval(self):
1723+
vars = [var.eval() for var in self.vars]
1724+
for var in vars:
1725+
if not isinstance(var, int):
1726+
return MulAdd(*vars)
1727+
1728+
return vars[0] * vars[1] + vars[2]
1729+
1730+
16271731
class PlusN(_SpecialOperator):
16281732
latency = 1
16291733

@@ -1638,6 +1742,16 @@ def func(*args):
16381742

16391743
self.op = func
16401744

1745+
def eval(self):
1746+
vars = [var.eval() for var in self.vars]
1747+
for var in vars:
1748+
if not isinstance(var, int):
1749+
return PlusN(*vars)
1750+
ret = 0
1751+
for var in vars:
1752+
ret += var
1753+
return ret
1754+
16411755

16421756
def AddN(*vars):
16431757
return PlusN(*vars)

0 commit comments

Comments
 (0)