Skip to content

Commit e34e0c0

Browse files
bo3zvloncar
authored and committed
Quartus streaming batch normalisation
1 parent d9e2ce7 commit e34e0c0

File tree

4 files changed

+99
-30
lines changed

4 files changed

+99
-30
lines changed

hls4ml/backends/quartus/passes/core_templates.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def format(self, node):
8080

8181
batchnorm_function_template = 'nnet::normalize<{input_t}, {output_t}, {config}>({input}, {output}, {scale}, {bias});'
8282

83-
batchnorm_include_list = ['nnet_utils/nnet_batchnorm.h']
83+
batchnorm_include_list = ['nnet_utils/nnet_batchnorm.h', 'nnet_utils/nnet_batchnorm_stream.h']
8484

8585
class BatchNormalizationConfigTemplate(LayerConfigTemplate):
8686
def __init__(self):

hls4ml/templates/quartus/firmware/nnet_utils/nnet_batchnorm.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ void normalize(
6060
Result:
6161
#pragma unroll
6262
for (int ires = 0; ires < CONFIG_T::n_in; ires++) {
63-
// TODO - Explore MULADD instruction in HLS - less clock cycles
6463
if (CONFIG_T::n_filt==-1) {
6564
res[ires] = CONFIG_T::template product<data_T, typename CONFIG_T::scale_t>::product(data[ires], scale[ires]) + bias[ires];
6665
} else {
Lines changed: 91 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,100 @@
1-
//
2-
// rfnoc-hls-neuralnet: Vivado HLS code for neural-net building blocks
3-
//
4-
// Copyright (C) 2017 EJ Kreinar
5-
//
6-
// This program is free software: you can redistribute it and/or modify
7-
// it under the terms of the GNU General Public License as published by
8-
// the Free Software Foundation, either version 3 of the License, or
9-
// (at your option) any later version.
10-
//
11-
// This program is distributed in the hope that it will be useful,
12-
// but WITHOUT ANY WARRANTY; without even the implied warranty of
13-
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14-
// GNU General Public License for more details.
15-
//
16-
// You should have received a copy of the GNU General Public License
17-
// along with this program. If not, see <http://www.gnu.org/licenses/>.
18-
//
19-
20-
/*
21-
* PLACEHOLDER - The common pass bn_quant.py includes both parallel and streaming BN; streaming is currently not supported in Quartus
22-
*/
23-
241
#ifndef NNET_BATCHNORM_STREAM_H_
252
#define NNET_BATCHNORM_STREAM_H_
263

274
#include "nnet_common.h"
285
#include "nnet_helpers.h"
296
#include "nnet_mult.h"
7+
#include "nnet_types.h"
8+
9+
namespace nnet {
10+
11+
// ****************************************************
12+
// Streaming Batch Normalization
13+
// ****************************************************
14+
template<class data_T, class res_T, typename CONFIG_T>
15+
void normalize(
16+
stream<data_T> &data,
17+
stream<res_T> &res,
18+
const typename CONFIG_T::scale_t scale[CONFIG_T::n_in],
19+
const typename CONFIG_T::bias_t bias[CONFIG_T::n_in]
20+
) {
21+
22+
constexpr unsigned multiplier_limit = DIV_ROUNDUP(CONFIG_T::n_in, CONFIG_T::reuse_factor);
23+
constexpr unsigned pipeline = CONFIG_T::n_in / multiplier_limit;
24+
CONFIG_T::template product<typename data_T::value_type, typename CONFIG_T::scale_t>::limit(multiplier_limit);
25+
26+
BatchNormLoop:
27+
#pragma ii pipeline
28+
for (int i = 0; i < CONFIG_T::n_in / data_T::size; i++) {
29+
data_T in_data = data.read();
30+
res_T out_data;
31+
32+
BatchNormpack:
33+
#pragma unroll
34+
for (int j = 0; j < data_T::size; j++) {
35+
int norm_index;
36+
if (CONFIG_T::n_filt==-1) norm_index = i * data_T::size + j;
37+
else norm_index = j % CONFIG_T::n_filt;
38+
out_data[j] = CONFIG_T::template product<typename data_T::value_type, typename CONFIG_T::scale_t>::product(in_data[j], scale[norm_index]) + bias[norm_index];
39+
}
40+
41+
res.write(out_data);
42+
}
43+
}
44+
45+
// ****************************************************
46+
// Merged Batch Normalization and Quantized Tanh
47+
// ****************************************************
48+
template<class data_T, typename CONFIG_T>
49+
void normalize_binary_tanh(
50+
stream<data_T> &data,
51+
stream<nnet::array<ac_int<1, false>, CONFIG_T::n_in>> &res,
52+
const typename data_T::value_type threshold[CONFIG_T::n_in]
53+
) {
54+
55+
BinaryNormLoop:
56+
#pragma ii 1
57+
for (int i = 0; i < CONFIG_T::n_in / data_T::size; i++) {
58+
data_T in_data = data.read();
59+
nnet::array<ac_int<1, false>, CONFIG_T::n_in> out_data;
60+
61+
BatchNormPack:
62+
#pragma unroll
63+
for (int j = 0; j < data_T::size; j++) {
64+
out_data[j] = (in_data[j] > threshold[i * data_T::size + j]) ? 1 : 0;
65+
}
66+
67+
res.write(out_data);
68+
}
69+
}
70+
71+
template<class data_T, typename CONFIG_T>
72+
void normalize_ternary_tanh(
73+
stream<data_T> &data,
74+
stream<nnet::array<ac_int<2, true>, CONFIG_T::n_in>> &res,
75+
const typename data_T::value_type threshold_hi[CONFIG_T::n_in],
76+
const typename data_T::value_type threshold_lo[CONFIG_T::n_in]
77+
) {
78+
79+
TernaryNormLoop:
80+
#pragma ii 1
81+
for (int i = 0; i < CONFIG_T::n_in / data_T::size; i++) {
82+
data_T in_data = data.read();
83+
nnet::array<ac_int<2, true>, CONFIG_T::n_in> out_data;
84+
85+
BatchNormPack:
86+
#pragma unroll
87+
for (int j = 0; j < data_T::size; j++) {
88+
int norm_index = i * data_T::size + j;
89+
if (in_data[j] > threshold_hi[norm_index]) out_data[j] = 1;
90+
else if (in_data[j] <= threshold_lo[norm_index]) out_data[j] = -1;
91+
else out_data[j] = 0;
92+
}
93+
94+
res.write(out_data);
95+
}
96+
}
3097

31-
namespace nnet {}
98+
}
3299

33100
#endif

test/pytest/test_batchnorm.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,20 @@ def model():
2424

2525

2626
@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
27-
def test_global_pool1d(model, data, io_type):
27+
@pytest.mark.parametrize('backend', ['Vivado', 'Quartus'])
28+
def test_global_pool1d(model, data, backend, io_type):
29+
30+
default_precision = 'ac_fixed<32, 1, true>' if backend == 'Quartus' else 'ac_fixed<32, 1>'
2831

2932
config = hls4ml.utils.config_from_keras_model(model,
30-
default_precision='ap_fixed<32,1>',
33+
default_precision=default_precision,
3134
granularity='name')
3235

3336
hls_model = hls4ml.converters.convert_from_keras_model(model,
37+
backend=backend,
3438
hls_config=config,
3539
io_type=io_type,
36-
output_dir=f'hls4mlprj_batchnorm_{io_type}',
37-
part='xcvu9p-flgb2104-2-i')
40+
output_dir=f'hls4mlprj_batchnorm_{backend}_{io_type}')
3841
hls_model.compile()
3942

4043

0 commit comments

Comments
 (0)