33import sys
44import os
55import numpy as np
6+ import math
67
78# the next line can be removed after installation
89sys .path .insert (0 , os .path .dirname (os .path .dirname (
1213import veriloggen .thread as vthread
1314import veriloggen .types .axi as axi
1415
# Bit width of one external-RAM word and of one matrix element.
# An element wider than a RAM word is spread over consecutive words.
mem_datawidth = 8
datawidth = 16
addrwidth = 8

matrix_size = 10

# Number of RAM words needed to hold one matrix element.
num_pack = math.ceil(datawidth / mem_datawidth)
# Number of RAM words needed to hold one header field (the matrix size or
# an offset). The field width is the RAM address width, i.e. `addrwidth`
# extended by the extra bits required to address the narrower words.
addr_pack = math.ceil((addrwidth + math.ceil(np.log2(datawidth / mem_datawidth)))
                      / mem_datawidth)

# Word addresses of the header fields (matrix size and the three offsets).
matrix_size_addr = 0
a_offset_addr = 4
b_offset_addr = 8
c_offset_addr = 12
# Word addresses at which matrices A, B and C start. Each matrix occupies
# matrix_size * matrix_size elements of `num_pack` words each.
a_offset = 16
b_offset = a_offset + matrix_size * matrix_size * num_pack
c_offset = b_offset + matrix_size * matrix_size * num_pack
2233
2334
2435def mkLed ():
@@ -28,7 +39,8 @@ def mkLed():
2839 start = m .Input ('start' )
2940 busy = m .OutputReg ('busy' , initval = 0 )
3041
31- ram = vthread .ExtRAM (m , 'ram' , clk , rst , datawidth , addrwidth )
42+ ram = vthread .ExtRAM (m , 'ram' , clk , rst , mem_datawidth ,
43+ addrwidth + math .ceil (np .log2 (datawidth / mem_datawidth )))
3244
3345 def matmul ():
3446 while True :
@@ -46,27 +58,31 @@ def wait():
4658 busy .value = 1
4759
def read_matrix_size():
    # Assemble the matrix size from `addr_pack` consecutive RAM words
    # starting at `matrix_size_addr`, least-significant word first.
    size = 0
    for i in range(addr_pack):
        # Keep only the low `mem_datawidth` bits of each word, then shift
        # the word into its position within the result.
        size |= ((ram.read(matrix_size_addr + i) & ((1 << mem_datawidth) - 1))
                 << (mem_datawidth * i))
    return size
5366
def read_matrix_a_offset():
    # Assemble the start address of matrix A from `addr_pack` consecutive
    # RAM words starting at `a_offset_addr`, least-significant word first.
    offset = 0
    for i in range(addr_pack):
        # Keep only the low `mem_datawidth` bits of each word, then shift
        # the word into its position within the result.
        offset |= ((ram.read(a_offset_addr + i) & ((1 << mem_datawidth) - 1))
                   << (mem_datawidth * i))
    return offset
5973
def read_matrix_b_offset():
    # Assemble the start address of matrix B from `addr_pack` consecutive
    # RAM words starting at `b_offset_addr`, least-significant word first.
    offset = 0
    for i in range(addr_pack):
        # Keep only the low `mem_datawidth` bits of each word, then shift
        # the word into its position within the result.
        offset |= ((ram.read(b_offset_addr + i) & ((1 << mem_datawidth) - 1))
                   << (mem_datawidth * i))
    return offset
6580
def read_matrix_c_offset():
    # Assemble the start address of matrix C from `addr_pack` consecutive
    # RAM words starting at `c_offset_addr`, least-significant word first.
    offset = 0
    for i in range(addr_pack):
        # Keep only the low `mem_datawidth` bits of each word, then shift
        # the word into its position within the result.
        offset |= ((ram.read(c_offset_addr + i) & ((1 << mem_datawidth) - 1))
                   << (mem_datawidth * i))
    return offset
7187
7288 def comp (matrix_size , a_offset , b_offset , c_offset ):
@@ -77,15 +93,24 @@ def comp(matrix_size, a_offset, b_offset, c_offset):
7793 for j in range (matrix_size ):
7894 sum = 0
7995 for k in range (matrix_size ):
80- x = ram .read (a_addr + k )
81- y = ram .read (b_addr + k )
96+ x = int (0 , base = 2 )
97+ y = 0
98+ for l in range (num_pack ):
99+ x |= ((ram .read (a_addr + k * num_pack + l )
100+ & ((1 << mem_datawidth ) - 1 ))
101+ << (mem_datawidth * l ))
102+ y |= ((ram .read (b_addr + k * num_pack + l )
103+ & ((1 << mem_datawidth ) - 1 ))
104+ << (mem_datawidth * l ))
82105 sum += x * y
83- ram .write (c_addr + j , sum )
106+ for l in range (num_pack ):
107+ ram .write (c_addr + j * num_pack + l ,
108+ (sum >> (mem_datawidth * l )) & (1 << mem_datawidth )- 1 )
84109
85- b_addr += matrix_size * ( datawidth // 8 )
110+ b_addr += matrix_size * num_pack
86111
87- a_addr += matrix_size * ( datawidth // 8 )
88- c_addr += matrix_size * ( datawidth // 8 )
112+ a_addr += matrix_size * num_pack
113+ c_addr += matrix_size * num_pack
89114
90115 def done ():
91116 busy .value = 0
@@ -128,13 +153,11 @@ def mkTest(memimg_name=None):
128153 b [y ][x ] = 0
129154
130155 a_addr = a_offset
131- size_a = n_a * datawidth // 8
132156 b_addr = b_offset
133- size_b = n_b * datawidth // 8
134157
135- mem = np .zeros ([2 ** addrwidth * ( 8 // datawidth ) ], dtype = np .int64 )
136- axi .set_memory (mem , a , datawidth , datawidth , a_addr )
137- axi .set_memory (mem , b , datawidth , datawidth , b_addr )
158+ mem = np .zeros ([( 2 ** addrwidth ) * num_pack ], dtype = np .int64 )
159+ axi .set_memory (mem , a , mem_datawidth , datawidth , a_addr )
160+ axi .set_memory (mem , b , mem_datawidth , datawidth , b_addr )
138161
139162 led = mkLed ()
140163
@@ -149,7 +172,8 @@ def mkTest(memimg_name=None):
149172
150173 start .initval = 0
151174
152- memory = vthread .RAM (m , 'memory' , clk , rst , datawidth , addrwidth ,
175+ memory = vthread .RAM (m , 'memory' , clk , rst , mem_datawidth ,
176+ addrwidth + math .ceil (np .log2 (datawidth / mem_datawidth )),
153177 numports = 2 , initvals = mem .tolist ())
154178 memory .connect_rtl (0 , ports ['ram_0_addr' ], ports ['ram_0_wdata' ],
155179 ports ['ram_0_wenable' ], ports ['ram_0_rdata' ],
@@ -166,45 +190,33 @@ def ctrl():
166190 for i in range (100 ):
167191 pass
168192
169- awaddr = 0
170- v = (matrix_size & 0xff )
171- print ('# matrix_size[7:0] = %d' % v )
172- memory .write (awaddr , v , port = 1 )
173-
174- awaddr = 1
175- v = ((matrix_size >> 8 ) & 0xff )
176- print ('# matrix_size[15:8] = %d' % v )
177- memory .write (awaddr , v , port = 1 )
178-
179- awaddr = 4
180- v = (a_offset & 0xff )
181- print ('# a_offset[7:0] = %d' % v )
182- memory .write (awaddr , v , port = 1 )
183-
184- awaddr = 5
185- v = ((a_offset >> 8 ) & 0xff )
186- print ('# a_offset[15:8] = %d' % v )
187- memory .write (awaddr , v , port = 1 )
188-
189- awaddr = 8
190- v = (b_offset & 0xff )
191- print ('# b_offset[7:0] = %d' % v )
192- memory .write (awaddr , v , port = 1 )
193-
194- awaddr = 9
195- v = ((b_offset >> 8 ) & 0xff )
196- print ('# b_offset[15:8] = %d' % v )
197- memory .write (awaddr , v , port = 1 )
198-
199- awaddr = 12
200- v = (c_offset & 0xff )
201- print ('# c_offset[7:0] = %d' % v )
202- memory .write (awaddr , v , port = 1 )
203-
204- awaddr = 13
205- v = ((c_offset >> 8 ) & 0xff )
206- print ('# c_offset[15:8] = %d' % v )
207- memory .write (awaddr , v , port = 1 )
193+ for i in range (addr_pack ):
194+ awaddr = matrix_size_addr + i
195+ v = (matrix_size >> (mem_datawidth * i )) & ((1 << mem_datawidth ) - 1 )
196+ print ('# matrix_size[%d:%d] = %d' %
197+ (mem_datawidth * (i + 1 ) - 1 , mem_datawidth * i , v ))
198+ memory .write (awaddr , v , port = 1 )
199+
200+ for i in range (addr_pack ):
201+ awaddr = a_offset_addr + i
202+ v = (a_offset >> (mem_datawidth * i )) & ((1 << mem_datawidth ) - 1 )
203+ print ('# a_offset[%d:%d] = %d' %
204+ (mem_datawidth * (i + 1 ) - 1 , mem_datawidth * i , v ))
205+ memory .write (awaddr , v , port = 1 )
206+
207+ for i in range (addr_pack ):
208+ awaddr = b_offset_addr + i
209+ v = (b_offset >> (mem_datawidth * i )) & ((1 << mem_datawidth ) - 1 )
210+ print ('# b_offset[%d:%d] = %d' %
211+ (mem_datawidth * (i + 1 ) - 1 , mem_datawidth * i , v ))
212+ memory .write (awaddr , v , port = 1 )
213+
214+ for i in range (addr_pack ):
215+ awaddr = c_offset_addr + i
216+ v = (c_offset >> (mem_datawidth * i )) & ((1 << mem_datawidth ) - 1 )
217+ print ('# c_offset[%d:%d] = %d' %
218+ (mem_datawidth * (i + 1 ) - 1 , mem_datawidth * i , v ))
219+ memory .write (awaddr , v , port = 1 )
208220
209221 start_time = counter
210222 print ('# start time = %d' % start_time )
@@ -227,14 +239,19 @@ def ctrl():
227239 all_ok = True
228240 for y in range (matrix_size ):
229241 for x in range (matrix_size ):
230- v = memory .read (
231- c_offset + (y * matrix_size + x ) * datawidth // 8 , port = 1 )
242+ v = 0
243+ v_addr = c_offset + (y * matrix_size + x ) * num_pack
244+ for l in range (num_pack ):
245+ v |= memory .read (v_addr + l , port = 1 ) << (mem_datawidth * l )
246+ v |= ((memory .read (v_addr + l , port = 1 )
247+ & ((1 << mem_datawidth ) - 1 ))
248+ << (mem_datawidth * l ))
232249 if y == x and vthread .verilog .NotEql (v , (y + 1 ) * 2 ):
233250 all_ok = False
234- print ("NG [%d,%d] = %d" % (y , x , v ))
251+ print ("NG [%d,%d] = %d (expected: %d) " % (y , x , v , ( y + 1 ) * 2 ))
235252 if y != x and vthread .verilog .NotEql (v , 0 ):
236253 all_ok = False
237- print ("NG [%d,%d] = %d" % (y , x , v ))
254+ print ("NG [%d,%d] = %d (expected: %d) " % (y , x , v , 0 ))
238255
239256 if all_ok :
240257 print ('# verify: PASSED' )
0 commit comments