|
32 | 32 | from qonnx.core.datatype import DataType |
33 | 33 | from qonnx.custom_op.base import CustomOp |
34 | 34 | from qonnx.custom_op.general.quant import max_int, min_int, resolve_rounding_mode |
| 35 | +from qonnx.util.basic import get_preferred_qonnx_opset |
35 | 36 |
|
36 | 37 |
|
37 | | -def trunc(inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode): |
| 38 | +def trunc_v2(inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode): |
38 | 39 | # Port of TruncIntQuant class from Brevitas: https://bit.ly/3wzIpTR |
39 | 40 |
|
40 | 41 | # Scaling |
@@ -65,18 +66,23 @@ def trunc(inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_sca |
65 | 66 | return y |
66 | 67 |
|
67 | 68 |
|
68 | | -class Trunc(CustomOp): |
69 | | - """Generic truncation operation for QONNX. Takes four inputs: |
70 | | - - input tensor to truncate |
71 | | - - the scale |
72 | | - - the zero-point |
73 | | - - the truncation scale |
| 69 | +class Trunc_v2(CustomOp): |
| 70 | + """Generic truncation operation for QONNX. Takes four inputs: |
| 71 | + - input tensor to truncate |
| 72 | + - the scale |
| 73 | + - the zero-point |
| 74 | + - the truncation scale |
74 | 75 | - the truncation bit-width |
75 | 76 |
|
76 | 77 | The output is a tensor of the same shape as the input tensor, with truncated |
77 | 78 | values. |
78 | 79 | """ |
79 | 80 |
|
| 81 | + def __init__(self, onnx_node, onnx_opset_version=get_preferred_qonnx_opset()): |
| 82 | + super().__init__(onnx_node, onnx_opset_version) |
| 83 | + # override any specified opset version, this instance is v2 |
| 84 | + self.onnx_opset_version = 2 |
| 85 | + |
80 | 86 | def get_nodeattr_types(self): |
81 | 87 | return { |
82 | 88 | # The rounding mode, which is used for the trunc function |
@@ -107,11 +113,79 @@ def execute_node(self, context, graph): |
107 | 113 | narrow = self.get_nodeattr("narrow") |
108 | 114 | signed = self.get_nodeattr("signed") |
109 | 115 | # calculate output |
110 | | - ret = trunc( |
| 116 | + ret = trunc_v2( |
111 | 117 | inp_tensor, scale, zeropt, input_bit_width, narrow, signed, output_scale, output_bit_width, rounding_mode |
112 | 118 | ) |
113 | 119 | # set context according to output name |
114 | 120 | context[node.output[0]] = ret |
115 | 121 |
|
116 | 122 | def verify_node(self): |
117 | 123 | pass |
| 124 | + |
| 125 | + |
def trunc_v1(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode):
    """Apply the legacy (v1) truncation to ``inp_tensor``.

    The input is divided by ``scale``, offset by ``zeropt`` and rounded with
    ``np.round``. The result is then divided by
    ``2.0 ** (input_bit_width - output_bit_width)`` and rounded with the
    function that ``resolve_rounding_mode(rounding_mode)`` returns. Finally
    the zero-point offset and the scale are undone.

    Args:
        inp_tensor: values to truncate.
        scale: scale the input is divided by, and the result multiplied by.
        zeropt: zero-point added before, and subtracted after, truncation.
        input_bit_width: bit-width of the incoming values.
        output_bit_width: bit-width to truncate to.
        rounding_mode: name handed to ``resolve_rounding_mode``.

    Returns:
        The truncated values, rescaled with ``scale`` and ``zeropt``.
    """
    # Port of TruncIntQuant class from Brevitas: https://bit.ly/3wzIpTR

    # Undo the scaling and offset, then snap to the nearest integer
    int_repr = np.round(inp_tensor / scale + zeropt)

    # Divide by 2 raised to the number of bits being removed
    removed_bits = input_bit_width - output_bit_width
    reduced = int_repr / 2.0**removed_bits

    # Back to integers using the requested rounding mode
    to_int = resolve_rounding_mode(rounding_mode)
    reduced = to_int(reduced)

    # Re-apply zero-point and scale
    return (reduced - zeropt) * scale
| 148 | + |
| 149 | + |
class Trunc_v1(CustomOp):
    """Legacy (v1) truncation operation for QONNX.

    ``execute_node`` reads five inputs, in this order:
        0. input tensor to truncate
        1. the scale
        2. the zero-point
        3. the input bit-width
        4. the truncation (output) bit-width

    The single output holds the truncated values computed by ``trunc_v1``.
    For shape inference the node is stood in for by an ``Identity`` on the
    first input.
    """

    def get_nodeattr_types(self):
        """Return the attribute table for this op."""
        attr_types = {
            # The rounding mode, which is used for the trunc function
            "rounding_mode": ("s", True, "FLOOR"),
        }
        return attr_types

    def make_shape_compatible_op(self, model):
        """Return an ``Identity`` node from the first input to the first output."""
        src_name = self.onnx_node.input[0]
        dst_name = self.onnx_node.output[0]
        return helper.make_node("Identity", [src_name], [dst_name])

    def infer_node_datatype(self, model):
        """Mark the first output tensor of this node as FLOAT32 in ``model``."""
        out_name = self.onnx_node.output[0]
        model.set_tensor_datatype(out_name, DataType["FLOAT32"])

    def execute_node(self, context, graph):
        """Run ``trunc_v1`` on the inputs found in ``context`` and store the result.

        ``context`` is indexed by tensor name; the result is written under
        the name of the node's first output. ``graph`` is not used.
        """
        node = self.onnx_node
        # fetch the five positional inputs by tensor name
        inp_tensor, scale, zeropt, input_bit_width, output_bit_width = (
            context[node.input[0]],
            context[node.input[1]],
            context[node.input[2]],
            context[node.input[3]],
            context[node.input[4]],
        )
        # node attribute selecting the final rounding function
        rounding_mode = self.get_nodeattr("rounding_mode")
        # compute and publish the output
        context[node.output[0]] = trunc_v1(
            inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode
        )

    def verify_node(self):
        """No verification is implemented for this op."""
        pass
0 commit comments