|
| 1 | +### <a name="Trunc"></a><a name="trunc">**Trunc**</a> |
| 2 | + |
| 3 | +Truncates the values of one input data (Tensor<T>) at a specified bitwidth and produces one output data (Tensor<T>). |
| 4 | +Additionally, takes four tensors as input (see Inputs below for their types), which define the scale, zero-point, input bit-width and output bit-width of the quantization. |
| 5 | +The attribute rounding_mode defines how truncated values are rounded. |
| 6 | + |
| 7 | +#### Version |
| 8 | + |
| 9 | +The description of this operator in this document corresponds to `qonnx.custom_ops.general` opset version 1. |
| 10 | + |
| 11 | +#### Attributes |
| 12 | + |
| 13 | +<dl> |
| 14 | +<dt><tt>rounding_mode</tt> : string (default is "FLOOR")</dt> |
| 15 | +<dd>Defines how rounding should be applied during truncation. Currently available modes are: "ROUND", "CEIL" and "FLOOR". Here "ROUND" implies a round-to-even operation. Lowercase variants for the rounding mode string are also supported: "round", "ceil", "floor".</dd> |
| 16 | +</dl> |
| 17 | + |
| 18 | +#### Inputs |
| 19 | + |
| 20 | +<dl> |
| 21 | +<dt><tt>X</tt> (differentiable) : tensor(float32)</dt> |
| 22 | +<dd>input tensor to truncate</dd> |
| 23 | +<dt><tt>scale</tt> : float32</dt> |
| 24 | +<dd>The scale factor</dd> |
| 25 | +<dt><tt>zeropt</tt> : float32</dt> |
| 26 | +<dd>The zero-point</dd> |
| 27 | +<dt><tt>in_bitwidth</tt> : int32</dt> |
| 28 | +<dd>The number of bits used at the input of the truncation</dd> |
| 29 | +<dt><tt>out_bitwidth</tt> : int32</dt> |
| 30 | +<dd>The number of bits used at the output of the truncation</dd> |
| 31 | +</dl> |
| 32 | + |
| 33 | + |
| 34 | +#### Outputs |
| 35 | + |
| 36 | +<dl> |
| 37 | +<dt><tt>Y</tt> (differentiable) : tensor(float32)</dt> |
| 38 | +<dd>Output tensor</dd> |
| 39 | +</dl> |
| 40 | + |
| 41 | + |
| 42 | +#### Examples |
| 43 | +<details> |
| 44 | +<summary>Trunc</summary> |
| 45 | + |
| 46 | +```python |
| 47 | +from onnx import helper |
| 48 | +import numpy as np |
| 49 | + |
# Define node settings and input
x = np.random.randn(100).astype(np.float32) * 10.
scale = np.array(1.)
zeropt = np.array(0.)
in_bitwidth = np.array(10)
out_bitwidth = np.array(4)
rounding_mode = "ROUND"

# Create node
# NOTE(review): the Version section of this document names
# `qonnx.custom_ops.general`; confirm which domain string is intended here.
node = helper.make_node(
    'Trunc',
    domain='finn.custom_op.general',
    inputs=['x', 'scale', 'zeropt', 'in_bitwidth', 'out_bitwidth'],
    outputs=['y'],
    rounding_mode=rounding_mode,
)

# Execute the same settings with the reference implementation (trunc)
# See the sample implementation for more details on trunc.
output_ref = trunc(x, scale, zeropt, in_bitwidth, out_bitwidth, rounding_mode)

# Execute node and compare; the input list must supply all five inputs
# declared on the node, in the same order.
expect(
    node,
    inputs=[x, scale, zeropt, in_bitwidth, out_bitwidth],
    outputs=[output_ref],
    name='test_trunc',
)
| 73 | + |
| 74 | +``` |
| 75 | + |
| 76 | +</details> |
| 77 | + |
| 78 | + |
| 79 | +#### Sample Implementation |
| 80 | + |
| 81 | +<details> |
| 82 | +<summary>Trunc</summary> |
| 83 | + |
| 84 | +```python |
| 85 | +# SPDX-License-Identifier: Apache-2.0 |
| 86 | + |
| 87 | +from __future__ import absolute_import |
| 88 | +from __future__ import division |
| 89 | +from __future__ import print_function |
| 90 | +from __future__ import unicode_literals |
| 91 | + |
| 92 | +import numpy as np |
| 93 | + |
def trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode):
    """Reference implementation of the Trunc operator.

    Port of TruncIntQuant class from Brevitas: https://bit.ly/3wzIpTR

    Args:
        inp_tensor: Values to truncate.
        scale: Scale factor the input is divided by before truncation and
            multiplied by afterwards.
        zeropt: Zero-point added before truncation and subtracted afterwards.
        input_bit_width: Number of bits at the input of the truncation.
        output_bit_width: Number of bits at the output of the truncation.
        rounding_mode: Rounding mode string, resolved to a numpy function
            through ``resolve_rounding_mode``.

    Returns:
        The truncated values, rescaled with ``scale`` and ``zeropt``.
    """
    # Move into the integer domain: undo the scale, apply the zero-point
    # and snap to the nearest integer.
    int_repr = np.round(inp_tensor / scale + zeropt)

    # Dropping (input - output) bits is a division by 2 ** (dropped bits).
    dropped_bits = input_bit_width - output_bit_width
    shrunk = int_repr / (2.0 ** dropped_bits)

    # The fractional remainder is resolved with the requested rounding mode.
    round_fn = resolve_rounding_mode(rounding_mode)
    truncated = round_fn(shrunk)

    # Map back to the original value range.
    return (truncated - zeropt) * scale
| 116 | + |
def resolve_rounding_mode(mode_string):
    """Resolve the rounding mode string of Quant and Trunc ops
    to the corresponding numpy functions.

    The lookup is case-insensitive, so the lowercase variants ("round",
    "ceil", "floor") resolve to the same functions as "ROUND", "CEIL"
    and "FLOOR".

    Args:
        mode_string: Name of the rounding mode.

    Returns:
        ``np.round`` for "ROUND", ``np.ceil`` for "CEIL" and ``np.floor``
        for "FLOOR".

    Raises:
        ValueError: If ``mode_string`` does not name a known rounding mode.
    """
    rounding_functions = {
        "ROUND": np.round,
        "CEIL": np.ceil,
        "FLOOR": np.floor,
    }
    resolved = None
    # Only strings can name a mode; anything else falls through to the
    # same ValueError as an unknown name.
    if isinstance(mode_string, str):
        resolved = rounding_functions.get(mode_string.upper())
    if resolved is None:
        raise ValueError(f"Could not resolve rounding mode called: {mode_string}")
    return resolved
| 128 | + |
| 129 | +``` |
| 130 | + |
| 131 | +</details> |
0 commit comments