Skip to content

Commit 82ed368

Browse files
committed
[Trunc] add v1 and v2 versions of the op separately
1 parent efdc74a commit 82ed368

File tree

2 files changed

+87
-11
lines changed

2 files changed

+87
-11
lines changed

src/qonnx/custom_op/general/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from qonnx.custom_op.general.maxpoolnhwc import MaxPoolNHWC
3636
from qonnx.custom_op.general.multithreshold import MultiThreshold
3737
from qonnx.custom_op.general.quantavgpool2d import QuantAvgPool2d
38-
from qonnx.custom_op.general.trunc import Trunc
38+
from qonnx.custom_op.general.trunc import Trunc_v1, Trunc_v2
3939
from qonnx.custom_op.general.xnorpopcount import XnorPopcountMatMul
4040

4141
custom_op = dict()
@@ -49,7 +49,7 @@
4949
custom_op["Im2Col"] = Im2Col
5050
custom_op["IntQuant"] = IntQuant
5151
custom_op["Quant"] = IntQuant
52-
custom_op["Trunc"] = Trunc
52+
custom_op["Trunc"] = Trunc_v1
5353
custom_op["BipolarQuant"] = BipolarQuant
5454
custom_op["FloatQuant"] = FloatQuant
5555

@@ -62,6 +62,8 @@
6262
custom_op["Im2Col_v1"] = Im2Col
6363
custom_op["IntQuant_v1"] = IntQuant
6464
custom_op["Quant_v1"] = IntQuant
65-
custom_op["Trunc_v1"] = Trunc
65+
custom_op["Trunc_v1"] = Trunc_v1
6666
custom_op["BipolarQuant_v1"] = BipolarQuant
6767
custom_op["FloatQuant_v1"] = FloatQuant
68+
69+
custom_op["Trunc_v2"] = Trunc_v2

src/qonnx/custom_op/general/trunc.py

Lines changed: 82 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,10 @@
3232
from qonnx.core.datatype import DataType
3333
from qonnx.custom_op.base import CustomOp
3434
from qonnx.custom_op.general.quant import max_int, min_int, resolve_rounding_mode
35+
from qonnx.util.basic import get_preferred_qonnx_opset
3536

3637

37-
def trunc(inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode):
38+
def trunc_v2(inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode):
3839
# Port of TruncIntQuant class from Brevitas: https://bit.ly/3wzIpTR
3940

4041
# Scaling
@@ -65,18 +66,23 @@ def trunc(inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_sca
6566
return y
6667

6768

68-
class Trunc(CustomOp):
69-
"""Generic truncation operation for QONNX. Takes four inputs:
70-
- input tensor to truncate
71-
- the scale
72-
- the zero-point
73-
- the truncation scale
69+
class Trunc_v2(CustomOp):
70+
"""Generic truncation operation for QONNX. Takes four inputs:
71+
- input tensor to truncate
72+
- the scale
73+
- the zero-point
74+
- the truncation scale
7475
- the truncation bit-width
7576
7677
The output is a tensor of the same shape as the input tensor, with truncated
7778
values.
7879
"""
7980

81+
def __init__(self, onnx_node, onnx_opset_version=get_preferred_qonnx_opset()):
82+
super().__init__(onnx_node, onnx_opset_version)
83+
# override any specified opset version, this instance is v2
84+
self.onnx_opset_version = 2
85+
8086
def get_nodeattr_types(self):
8187
return {
8288
# The rounding mode, which is used for the trunc function
@@ -107,11 +113,79 @@ def execute_node(self, context, graph):
107113
narrow = self.get_nodeattr("narrow")
108114
signed = self.get_nodeattr("signed")
109115
# calculate output
110-
ret = trunc(
116+
ret = trunc_v2(
111117
inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode
112118
)
113119
# set context according to output name
114120
context[node.output[0]] = ret
115121

116122
def verify_node(self):
117123
pass
124+
125+
126+
def trunc_v1(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode):
    """Truncate a quantized tensor from input_bit_width down to output_bit_width.

    The input is divided by scale and shifted by zeropt, rounded with
    np.round, divided by 2 ** (input_bit_width - output_bit_width) and then
    rounded again with the function selected by rounding_mode. Finally the
    zero-point shift is undone and the result is multiplied by scale.

    Note that the final rescaling uses scale only; the power-of-two factor
    that was divided out during truncation is not multiplied back in.
    """
    # Port of TruncIntQuant class from Brevitas: https://bit.ly/3wzIpTR

    # Bring the input onto the (rounded) integer grid
    int_repr = np.round(inp_tensor / scale + zeropt)

    # Drop the surplus bits by dividing with the matching power of two
    bits_to_drop = input_bit_width - output_bit_width
    shrunk = int_repr / 2.0**bits_to_drop

    # Snap back to integers using the requested rounding mode
    to_int = resolve_rounding_mode(rounding_mode)
    shrunk = to_int(shrunk)

    # Undo the zero-point shift and reapply the scale
    return (shrunk - zeropt) * scale
148+
149+
150+
class Trunc_v1(CustomOp):
    """Generic truncation operation for QONNX (version 1). Takes five inputs:
    - input tensor to truncate
    - the scale
    - the zero-point
    - the input bit-width
    - the output (truncation) bit-width

    The output is a tensor of the same shape as the input tensor, with truncated
    values.
    """

    def get_nodeattr_types(self):
        """Return the node attributes supported by this op."""
        return {
            # The rounding mode, which is used for the trunc function
            # NOTE(review): tuple is presumably (type, required, default) - confirm
            # against CustomOp.
            "rounding_mode": ("s", True, "FLOOR"),
        }

    def make_shape_compatible_op(self, model):
        """Return an Identity node wired from the first input to the first output,
        used as a stand-in for this node."""
        node = self.onnx_node
        return helper.make_node("Identity", [node.input[0]], [node.output[0]])

    def infer_node_datatype(self, model):
        """Annotate the first output tensor as FLOAT32 in the given model."""
        node = self.onnx_node
        model.set_tensor_datatype(node.output[0], DataType["FLOAT32"])

    def execute_node(self, context, graph):
        """Compute the truncated output and store it in the execution context.

        Reads the five inputs from context by tensor name, applies trunc_v1 with
        the node's rounding_mode attribute and writes the result under the name
        of the first output.
        """
        node = self.onnx_node
        # save inputs
        inp_tensor = context[node.input[0]]
        scale = context[node.input[1]]
        zeropt = context[node.input[2]]
        input_bit_width = context[node.input[3]]
        output_bit_width = context[node.input[4]]
        # save attributes
        rounding_mode = self.get_nodeattr("rounding_mode")
        # calculate output
        ret = trunc_v1(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode)
        # set context according to output name
        context[node.output[0]] = ret

    def verify_node(self):
        """No verification is performed for this op."""
        pass

0 commit comments

Comments
 (0)