Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/qonnx/transformation/batchnorm_to_affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

from qonnx.transformation.base import Transformation
from qonnx.transformation.infer_shapes import InferShapes
from qonnx.util.basic import get_by_name
from qonnx.util.basic import copy_metadata_props, get_by_name


class BatchNormToAffine(Transformation):
Expand Down Expand Up @@ -89,6 +89,9 @@ def apply(self, model):
# create Mul and Add nodes to replace the batchnorm
mul_node = oh.make_node("Mul", [bn_input, mul_const.name], [mul_output.name])
add_node = oh.make_node("Add", [mul_output.name, add_const.name], [bn_output])
# preserve metadata from original batchnorm node
copy_metadata_props(n, mul_node)
copy_metadata_props(n, add_node)
# insert where the batchnorm is to preserve topological ordering
graph.node.insert(node_ind, mul_node)
graph.node.insert(node_ind + 1, add_node)
Expand Down
5 changes: 4 additions & 1 deletion src/qonnx/transformation/bipolar_to_xnor.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from qonnx.transformation.base import Transformation
from qonnx.transformation.infer_datatypes import InferDataTypes
from qonnx.transformation.infer_shapes import InferShapes
from qonnx.util.basic import get_by_name
from qonnx.util.basic import copy_metadata_props, get_by_name


class ConvertBipolarMatMulToXnorPopcount(Transformation):
Expand Down Expand Up @@ -132,6 +132,9 @@ def find_prod_mt(x):
# create Mul and Add nodes to map the XnorPopcount output back to the MatMul result
mul_node = oh.make_node("Mul", [xnorpcout.name, mul_const.name], [mul_output.name])
add_node = oh.make_node("Add", [mul_output.name, add_const.name], [mm_output])
# preserve metadata from original MatMul node
copy_metadata_props(n, mul_node)
copy_metadata_props(n, add_node)
# insert where the MatMul is to preserve topological ordering
graph.node.insert(node_ind, mul_node)
graph.node.insert(node_ind + 1, add_node)
Expand Down
5 changes: 4 additions & 1 deletion src/qonnx/transformation/change_datalayout.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

from qonnx.transformation.base import Transformation
from qonnx.transformation.infer_shapes import InferShapes
from qonnx.util.basic import get_by_name
from qonnx.util.basic import copy_metadata_props, get_by_name


class ChangeDataLayoutQuantAvgPool2d(Transformation):
Expand Down Expand Up @@ -78,6 +78,7 @@ def apply(self, model):
graph.value_info.append(quantavg_out)
quantavg_out = quantavg_out.name
inp_trans_node = helper.make_node("Transpose", [node_input], [inp_trans_out], perm=[0, 2, 3, 1])
copy_metadata_props(n, inp_trans_node)
quantavg_node = helper.make_node(
"QuantAvgPool2d",
[inp_trans_out],
Expand All @@ -90,8 +91,10 @@ def apply(self, model):
signed=signed,
data_layout="NHWC",
)
copy_metadata_props(n, quantavg_node)
# NHWC -> NCHW
out_trans_node = helper.make_node("Transpose", [quantavg_out], [node_output], perm=[0, 3, 1, 2])
copy_metadata_props(n, out_trans_node)
# insert nodes
graph.node.insert(node_ind, inp_trans_node)
graph.node.insert(node_ind + 1, quantavg_node)
Expand Down
10 changes: 8 additions & 2 deletions src/qonnx/transformation/channels_last.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from qonnx.transformation.infer_shapes import InferShapes
from qonnx.transformation.make_input_chanlast import MakeInputChannelsLast
from qonnx.transformation.quant_constant_folding import FoldTransposeIntoQuantInit
from qonnx.util.basic import get_by_name
from qonnx.util.basic import copy_metadata_props, get_by_name
from qonnx.util.onnx import is_eltwise_optype

# Standard ONNX nodes which require a ChannelsLast data format to function properly
Expand Down Expand Up @@ -96,6 +96,7 @@ def move_transpose_past_eltwise(transpose_node, eltwise_node, model: ModelWrappe
new_t_inp = model.make_new_valueinfo_name()
inv_perm = np.argsort(perm)
new_transpose_node = helper.make_node("Transpose", [eltwise_inp], [new_t_inp], perm=inv_perm)
copy_metadata_props(transpose_node, new_transpose_node)
t_shape = np.transpose(np.empty(inp_shape), axes=inv_perm).shape
model.set_tensor_shape(new_t_inp, t_shape)
eltwise_node.input[ind] = new_t_inp
Expand All @@ -107,13 +108,15 @@ def move_transpose_past_eltwise(transpose_node, eltwise_node, model: ModelWrappe
model.set_initializer(unsqueeze_param_name, np.asarray(list(range(ndim_inp - ndim)), dtype=np.int64))
unsqueeze_out_name = model.make_new_valueinfo_name()
new_unsqueeze_node = helper.make_node("Unsqueeze", [eltwise_inp, unsqueeze_param_name], [unsqueeze_out_name])
# NOTE(review): eltwise_inp is a tensor name (str), not a NodeProto; copy
# metadata from the transpose node being moved, as the other new nodes do
copy_metadata_props(transpose_node, new_unsqueeze_node)
unsqueeze_out_shape = np.expand_dims(np.empty(inp_shape), axis=tuple(range(ndim_inp - ndim))).shape
model.set_tensor_shape(unsqueeze_out_name, unsqueeze_out_shape)
model.graph.node.append(new_unsqueeze_node)
# now add inverse transpose
new_t_inp = model.make_new_valueinfo_name()
inv_perm = np.argsort(perm)
new_transpose_node = helper.make_node("Transpose", [unsqueeze_out_name], [new_t_inp], perm=inv_perm)
copy_metadata_props(transpose_node, new_transpose_node)
t_shape = np.transpose(np.empty(unsqueeze_out_shape), axes=inv_perm).shape
model.set_tensor_shape(new_t_inp, t_shape)
eltwise_node.input[ind] = new_t_inp
Expand Down Expand Up @@ -239,6 +242,7 @@ def apply(self, model):
# channels last transpose
inp_trans_node = helper.make_node("Transpose", [inp], [inp_trans_out], perm=to_channels_last_args(ndim))
graph.node.insert(running_node_index, inp_trans_node)
copy_metadata_props(n, inp_trans_node)
running_node_index += 1

# Attach to original node
Expand All @@ -265,6 +269,7 @@ def apply(self, model):
"Transpose", [outp_trans_in], [outp], perm=to_channels_first_args(ndim)
)
graph.node.insert(running_node_index, outp_trans_node)
copy_metadata_props(n, outp_trans_node)
running_node_index += 1

# Attach to original node
Expand Down Expand Up @@ -567,7 +572,8 @@ def apply(self, model):
axis=1,
)
graph.node.insert(node_ind, flat_node)

copy_metadata_props(n, flat_node)

graph_modified = True
else:
warnings.warn(
Expand Down
2 changes: 2 additions & 0 deletions src/qonnx/transformation/extract_conv_bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from onnx import helper

from qonnx.transformation.base import Transformation
from qonnx.util.basic import copy_metadata_props


class ExtractBiasFromConv(Transformation):
Expand Down Expand Up @@ -75,6 +76,7 @@ def apply(self, model):
[act_add_tensor.name, n.input[2]],
[n.output[0]],
)
copy_metadata_props(n, add_node)
graph.node.insert(node_ind, add_node)

# Repoint Conv output and remove bias tensor
Expand Down
5 changes: 5 additions & 0 deletions src/qonnx/transformation/extract_quant_scale_zeropt.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from qonnx.transformation.base import Transformation
from qonnx.transformation.general import GiveUniqueParameterTensors, SortGraph
from qonnx.transformation.remove import RemoveIdentityOps
from qonnx.util.basic import copy_metadata_props


class ExtractQuantScaleZeroPt(Transformation):
Expand Down Expand Up @@ -69,6 +70,7 @@ def apply(self, model: ModelWrapper):
)
graph.value_info.append(inp_scaled)
inp_scale_node = helper.make_node("Div", [running_input, scale_nm], [inp_scaled_nm])
copy_metadata_props(node, inp_scale_node)
graph.node.append(inp_scale_node)
# create new Mul node
# remove scale from Quant node
Expand All @@ -87,6 +89,7 @@ def apply(self, model: ModelWrapper):
)
graph.value_info.append(inp_zeropt)
inp_zeropt_node = helper.make_node("Add", [running_input, zeropt_nm], [inp_zeropt_nm])
copy_metadata_props(node, inp_zeropt_node)
graph.node.append(inp_zeropt_node)
# remove zeropt from Quant node
new_zeropt_nm = model.make_new_valueinfo_name()
Expand All @@ -108,6 +111,7 @@ def apply(self, model: ModelWrapper):
)
graph.value_info.append(out_zeropt)
out_zeropt_node = helper.make_node("Sub", [out_zeropt_nm, zeropt_nm], [final_output])
copy_metadata_props(node, out_zeropt_node)
last_node.output[0] = out_zeropt_nm
graph.node.append(out_zeropt_node)
# important: when tracking a pointer to newly added nodes,
Expand All @@ -127,6 +131,7 @@ def apply(self, model: ModelWrapper):
last_node.output[0] = out_scale_nm
graph.value_info.append(out_scale)
out_scale_node = helper.make_node("Mul", [out_scale_nm, scale_nm], [final_output])
copy_metadata_props(node, out_scale_node)
graph.node.append(out_scale_node)

if extract_scale or extract_zeropt:
Expand Down
9 changes: 7 additions & 2 deletions src/qonnx/transformation/gemm_to_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from qonnx.core.datatype import DataType
from qonnx.transformation.base import Transformation
from qonnx.transformation.remove import RemoveIdentityOps
from qonnx.util.basic import get_by_name
from qonnx.util.basic import copy_metadata_props, get_by_name


class GemmToMatMul(Transformation):
Expand Down Expand Up @@ -76,6 +76,7 @@ def apply(self, model):
)
graph.value_info.append(inp_trans_out)
inp_trans_node = helper.make_node("Transpose", [n.input[0]], [inp_trans_out.name])
copy_metadata_props(n, inp_trans_node)
graph.node.insert(running_node_index, inp_trans_node)
running_node_index += 1
dt = model.get_tensor_datatype(n.input[0])
Expand All @@ -98,6 +99,7 @@ def apply(self, model):
)
graph.value_info.append(inp_trans_out)
inp_trans_node = helper.make_node("Transpose", [n.input[1]], [inp_trans_out.name])
copy_metadata_props(n, inp_trans_node)
graph.node.insert(running_node_index, inp_trans_node)
running_node_index += 1
# Copy over the datatype
Expand All @@ -109,6 +111,7 @@ def apply(self, model):

# Insert MatMul: A * B
matMul_node = helper.make_node("MatMul", [n.input[0], n.input[1]], [n.output[0]])
copy_metadata_props(n, matMul_node)
graph.node.insert(running_node_index, matMul_node)
matMul_node = graph.node[running_node_index]
running_node_index += 1
Expand Down Expand Up @@ -144,6 +147,7 @@ def apply(self, model):
[act_mul_tensor.name, mul_tensor.name],
[n.output[0]],
)
copy_metadata_props(n, mul_node)
graph.node.insert(running_node_index, mul_node)
mul_node_main_branch = graph.node[running_node_index]
running_node_index += 1
Expand Down Expand Up @@ -175,6 +179,7 @@ def apply(self, model):
[n.input[2], mul_tensor.name],
[act_mul_tensor.name],
)
copy_metadata_props(n, mul_node)
graph.node.insert(running_node_index, mul_node)
running_node_index += 1
dt = model.get_tensor_datatype(n.input[2])
Expand All @@ -196,7 +201,7 @@ def apply(self, model):
[act_add_tensor.name, n.input[2]],
[n.output[0]],
)

copy_metadata_props(n, add_node)
graph.node.insert(running_node_index, add_node)
running_node_index += 1

Expand Down
6 changes: 5 additions & 1 deletion src/qonnx/transformation/lower_convs_to_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

from qonnx.transformation.base import Transformation
from qonnx.transformation.extract_conv_bias import ExtractBiasFromConv
from qonnx.util.basic import auto_pad_to_explicit_padding, get_by_name
from qonnx.util.basic import auto_pad_to_explicit_padding, copy_metadata_props, get_by_name


class LowerConvsToMatMul(Transformation):
Expand Down Expand Up @@ -152,6 +152,7 @@ def apply(self, model):
# create new nodes
# NCHW -> NHWC
inp_trans_node = helper.make_node("Transpose", [cnv_input], [inp_trans_out], perm=[0, 2, 3, 1])
copy_metadata_props(node, inp_trans_node)
nodes_to_insert = [inp_trans_node]

if need_im2col:
Expand All @@ -174,12 +175,15 @@ def apply(self, model):
dilations=dilation,
)
nodes_to_insert.append(im2col_node)
copy_metadata_props(node, im2col_node)

matmul_input = im2col_out if need_im2col else inp_trans_out
# do matmul
matmul_node = helper.make_node("MatMul", [matmul_input, conv_weight_inp_name], [matmul_out])
copy_metadata_props(node, matmul_node)
# NHWC -> NCHW
out_trans_node = helper.make_node("Transpose", [matmul_out], [cnv_output], perm=[0, 3, 1, 2])
copy_metadata_props(node, out_trans_node)

nodes_to_insert.extend([matmul_node, out_trans_node])

Expand Down
4 changes: 3 additions & 1 deletion src/qonnx/transformation/qcdq_to_qonnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.base import Transformation
from qonnx.util.basic import get_by_name
from qonnx.util.basic import copy_metadata_props, get_by_name


def extract_elem_type(elem_type: int, clip_range=None) -> Tuple[int, int, bool]:
Expand Down Expand Up @@ -203,6 +203,8 @@ def apply(self, model: ModelWrapper) -> Tuple[ModelWrapper, bool]:
rounding_mode="ROUND", # round-to-even
signed=signed,
)
# preserve metadata from the matched anchor node; metadata on the
# other nodes of the fused Q(Clip)DQ pattern is not carried over
copy_metadata_props(node, fused_node)
model.graph.node.insert(dequant_node_index, fused_node)
for node_to_remove in nodes_to_remove:
model.graph.node.remove(node_to_remove)
Expand Down
2 changes: 2 additions & 0 deletions src/qonnx/transformation/rebalance_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

from qonnx.custom_op.registry import getCustomOp
from qonnx.transformation.base import Transformation
from qonnx.util.basic import copy_metadata_props


class RebalanceIm2Col(Transformation):
Expand Down Expand Up @@ -103,6 +104,7 @@ def apply(self, model):
inp_reshape_node = helper.make_node(
"Reshape", [node.input[0], inp_shapedata.name], [inp_reshape_out.name]
)
copy_metadata_props(node, inp_reshape_node)
graph.node.insert(running_node_index, inp_reshape_node)
# rewire Im2Col input
node.input[0] = inp_reshape_out.name
Expand Down
3 changes: 2 additions & 1 deletion src/qonnx/transformation/resize_conv_to_deconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from qonnx.core.datatype import DataType
from qonnx.custom_op.general.quant import quant, resolve_rounding_mode
from qonnx.transformation.base import Transformation
from qonnx.util.basic import auto_pad_to_explicit_padding, get_by_name
from qonnx.util.basic import auto_pad_to_explicit_padding, copy_metadata_props, get_by_name


def _weight_convolution(cnv_weights: np.ndarray, scale: int) -> np.ndarray:
Expand Down Expand Up @@ -242,6 +242,7 @@ def apply(self, model):
group=group,
dilations=dilation,
)
copy_metadata_props(conv, deconv_node)
W_deconv_init = weight_name
if weight_prod is not None:
W_deconv_init = q_w_name
Expand Down
3 changes: 2 additions & 1 deletion src/qonnx/transformation/subpixel_to_deconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from onnx import helper

from qonnx.transformation.base import Transformation
from qonnx.util.basic import auto_pad_to_explicit_padding, get_by_name
from qonnx.util.basic import auto_pad_to_explicit_padding, copy_metadata_props, get_by_name


def _weight_shuffle(cnv_weights: np.ndarray, block_size: int) -> np.ndarray:
Expand Down Expand Up @@ -197,6 +197,7 @@ def apply(self, model):
group=group,
dilations=dilation,
)
copy_metadata_props(n, deconv_node)
W_deconv_init = weight_name
if weight_prod is not None:
W_deconv_init = q_w_name
Expand Down
Loading
Loading