From ba24ecccf410680dea0ec09d872502a530b613a4 Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Thu, 30 Oct 2025 20:56:41 +0000 Subject: [PATCH 01/12] copy in metadata preservation --- .../transformation/extract_quant_scale_zeropt.py | 8 ++++++++ src/qonnx/transformation/gemm_to_matmul.py | 13 ++++++++++++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/src/qonnx/transformation/extract_quant_scale_zeropt.py b/src/qonnx/transformation/extract_quant_scale_zeropt.py index 58863f08..614df416 100644 --- a/src/qonnx/transformation/extract_quant_scale_zeropt.py +++ b/src/qonnx/transformation/extract_quant_scale_zeropt.py @@ -69,6 +69,8 @@ def apply(self, model: ModelWrapper): ) graph.value_info.append(inp_scaled) inp_scale_node = helper.make_node("Div", [running_input, scale_nm], [inp_scaled_nm]) + if hasattr(node, "metadata_props"): + inp_scale_node.metadata_props.extend(node.metadata_props) graph.node.append(inp_scale_node) # create new Mul node # remove scale from Quant node @@ -87,6 +89,8 @@ def apply(self, model: ModelWrapper): ) graph.value_info.append(inp_zeropt) inp_zeropt_node = helper.make_node("Add", [running_input, zeropt_nm], [inp_zeropt_nm]) + if hasattr(node, "metadata_props"): + inp_zeropt_node.metadata_props.extend(node.metadata_props) graph.node.append(inp_zeropt_node) # remove zeropt from Quant node new_zeropt_nm = model.make_new_valueinfo_name() @@ -108,6 +112,8 @@ def apply(self, model: ModelWrapper): ) graph.value_info.append(out_zeropt) out_zeropt_node = helper.make_node("Sub", [out_zeropt_nm, zeropt_nm], [final_output]) + if hasattr(node, "metadata_props"): + out_zeropt_node.metadata_props.extend(node.metadata_props) last_node.output[0] = out_zeropt_nm graph.node.append(out_zeropt_node) # important: when tracking a pointer to newly added nodes, @@ -127,6 +133,8 @@ def apply(self, model: ModelWrapper): last_node.output[0] = out_scale_nm graph.value_info.append(out_scale) out_scale_node = helper.make_node("Mul", [out_scale_nm, scale_nm], [final_output]) + if hasattr(node, "metadata_props"): + out_scale_node.metadata_props.extend(node.metadata_props) graph.node.append(out_scale_node) if extract_scale or extract_zeropt: diff --git a/src/qonnx/transformation/gemm_to_matmul.py b/src/qonnx/transformation/gemm_to_matmul.py index 5396a7d6..1298f3d6 100644 --- a/src/qonnx/transformation/gemm_to_matmul.py +++ b/src/qonnx/transformation/gemm_to_matmul.py @@ -76,6 +76,8 @@ def apply(self, model): ) graph.value_info.append(inp_trans_out) inp_trans_node = helper.make_node("Transpose", [n.input[0]], [inp_trans_out.name]) + if hasattr(n, "metadata_props"): + inp_trans_node.metadata_props.extend(n.metadata_props) graph.node.insert(running_node_index, inp_trans_node) running_node_index += 1 dt = model.get_tensor_datatype(n.input[0]) @@ -98,6 +100,8 @@ def apply(self, model): ) graph.value_info.append(inp_trans_out) inp_trans_node = helper.make_node("Transpose", [n.input[1]], [inp_trans_out.name]) + if hasattr(n, "metadata_props"): + inp_trans_node.metadata_props.extend(n.metadata_props) graph.node.insert(running_node_index, inp_trans_node) running_node_index += 1 # Copy over the datatype @@ -109,6 +113,8 @@ def apply(self, model): # Insert MatMul: A * B matMul_node = helper.make_node("MatMul", [n.input[0], n.input[1]], [n.output[0]]) + if hasattr(n, "metadata_props"): + matMul_node.metadata_props.extend(n.metadata_props) graph.node.insert(running_node_index, matMul_node) matMul_node = graph.node[running_node_index] running_node_index += 1 @@ -144,6 +150,8 @@ def apply(self, model): [act_mul_tensor.name, mul_tensor.name], [n.output[0]], ) + if hasattr(n, "metadata_props"): + mul_node.metadata_props.extend(n.metadata_props) graph.node.insert(running_node_index, mul_node) mul_node_main_branch = graph.node[running_node_index] running_node_index += 1 @@ -175,6 +183,8 @@ def apply(self, model): [n.input[2], mul_tensor.name], [act_mul_tensor.name], ) + if hasattr(n, "metadata_props"): + mul_node.metadata_props.extend(n.metadata_props) graph.node.insert(running_node_index, mul_node) running_node_index += 1 dt = model.get_tensor_datatype(n.input[2]) @@ -196,7 +206,8 @@ def apply(self, model): [act_add_tensor.name, n.input[2]], [n.output[0]], ) - + if hasattr(n, "metadata_props"): + add_node.metadata_props.extend(n.metadata_props) graph.node.insert(running_node_index, add_node) running_node_index += 1 From da632b9f9b94817a6e9dba51f2ac665818a55f84 Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Thu, 30 Oct 2025 21:59:38 +0000 Subject: [PATCH 02/12] expand metadata copy coverage to other transforms --- src/qonnx/transformation/change_datalayout.py | 6 ++++++ src/qonnx/transformation/channels_last.py | 6 ++++++ src/qonnx/transformation/extract_conv_bias.py | 2 ++ src/qonnx/transformation/lower_convs_to_matmul.py | 4 ++++ src/qonnx/transformation/qcdq_to_qonnx.py | 4 ++++ src/qonnx/transformation/rebalance_conv.py | 2 ++ src/qonnx/transformation/resize_conv_to_deconv.py | 3 +++ src/qonnx/transformation/subpixel_to_deconv.py | 3 +++ 8 files changed, 30 insertions(+) diff --git a/src/qonnx/transformation/change_datalayout.py b/src/qonnx/transformation/change_datalayout.py index 7b73e4bf..07fbe400 100644 --- a/src/qonnx/transformation/change_datalayout.py +++ b/src/qonnx/transformation/change_datalayout.py @@ -78,6 +78,8 @@ def apply(self, model): graph.value_info.append(quantavg_out) quantavg_out = quantavg_out.name inp_trans_node = helper.make_node("Transpose", [node_input], [inp_trans_out], perm=[0, 2, 3, 1]) + if hasattr(n, "metadata_props"): + inp_trans_node.metadata_props.extend(n.metadata_props) quantavg_node = helper.make_node( "QuantAvgPool2d", [inp_trans_out], @@ -90,8 +92,12 @@ def apply(self, model): signed=signed, data_layout="NHWC", ) + if hasattr(n, "metadata_props"): + quantavg_node.metadata_props.extend(n.metadata_props) # NHWC -> NCHW out_trans_node = helper.make_node("Transpose", [quantavg_out], [node_output], perm=[0, 3, 1, 2]) + if hasattr(n, "metadata_props"): + out_trans_node.metadata_props.extend(n.metadata_props) # insert nodes graph.node.insert(node_ind, inp_trans_node) graph.node.insert(node_ind + 1, quantavg_node) diff --git a/src/qonnx/transformation/channels_last.py b/src/qonnx/transformation/channels_last.py index 175af058..8c934190 100644 --- a/src/qonnx/transformation/channels_last.py +++ b/src/qonnx/transformation/channels_last.py @@ -96,6 +96,8 @@ def move_transpose_past_eltwise(transpose_node, eltwise_node, model: ModelWrappe new_t_inp = model.make_new_valueinfo_name() inv_perm = np.argsort(perm) new_transpose_node = helper.make_node("Transpose", [eltwise_inp], [new_t_inp], perm=inv_perm) + if hasattr(transpose_node, "metadata_props"): + new_transpose_node.metadata_props.extend(transpose_node.metadata_props) t_shape = np.transpose(np.empty(inp_shape), axes=inv_perm).shape model.set_tensor_shape(new_t_inp, t_shape) eltwise_node.input[ind] = new_t_inp @@ -107,6 +109,8 @@ def move_transpose_past_eltwise(transpose_node, eltwise_node, model: ModelWrappe model.set_initializer(unsqueeze_param_name, np.asarray(list(range(ndim_inp - ndim)), dtype=np.int64)) unsqueeze_out_name = model.make_new_valueinfo_name() new_unsqueeze_node = helper.make_node("Unsqueeze", [eltwise_inp, unsqueeze_param_name], [unsqueeze_out_name]) + if hasattr(eltwise_inp, "metadata_props"): + new_unsqueeze_node.metadata_props.extend(eltwise_inp.metadata_props) unsqueeze_out_shape = np.expand_dims(np.empty(inp_shape), axis=tuple(range(ndim_inp - ndim))).shape model.set_tensor_shape(unsqueeze_out_name, unsqueeze_out_shape) model.graph.node.append(new_unsqueeze_node) @@ -114,6 +118,8 @@ def move_transpose_past_eltwise(transpose_node, eltwise_node, model: ModelWrappe new_t_inp = model.make_new_valueinfo_name() inv_perm = np.argsort(perm) new_transpose_node = helper.make_node("Transpose", [unsqueeze_out_name], [new_t_inp], perm=inv_perm) + if hasattr(transpose_node, "metadata_props"): + new_transpose_node.metadata_props.extend(transpose_node.metadata_props) t_shape = np.transpose(np.empty(unsqueeze_out_shape), axes=inv_perm).shape model.set_tensor_shape(new_t_inp, t_shape) eltwise_node.input[ind] = new_t_inp diff --git a/src/qonnx/transformation/extract_conv_bias.py b/src/qonnx/transformation/extract_conv_bias.py index bf2cf8b4..1bf264e3 100644 --- a/src/qonnx/transformation/extract_conv_bias.py +++ b/src/qonnx/transformation/extract_conv_bias.py @@ -75,6 +75,8 @@ def apply(self, model): [act_add_tensor.name, n.input[2]], [n.output[0]], ) + if hasattr(n, "metadata_props"): + add_node.metadata_props.extend(n.metadata_props) graph.node.insert(node_ind, add_node) # Repoint Conv output and remove bias tensor diff --git a/src/qonnx/transformation/lower_convs_to_matmul.py b/src/qonnx/transformation/lower_convs_to_matmul.py index 81f0b713..d864d8d2 100644 --- a/src/qonnx/transformation/lower_convs_to_matmul.py +++ b/src/qonnx/transformation/lower_convs_to_matmul.py @@ -178,8 +178,12 @@ def apply(self, model): matmul_input = im2col_out if need_im2col else inp_trans_out # do matmul matmul_node = helper.make_node("MatMul", [matmul_input, conv_weight_inp_name], [matmul_out]) + if hasattr(node, "metadata_props"): + matmul_node.metadata_props.extend(node.metadata_props) # NHWC -> NCHW out_trans_node = helper.make_node("Transpose", [matmul_out], [cnv_output], perm=[0, 3, 1, 2]) + if hasattr(node, "metadata_props"): + out_trans_node.metadata_props.extend(node.metadata_props) nodes_to_insert.extend([matmul_node, out_trans_node]) diff --git a/src/qonnx/transformation/qcdq_to_qonnx.py b/src/qonnx/transformation/qcdq_to_qonnx.py index b7e35c0d..b122d840 100644 --- a/src/qonnx/transformation/qcdq_to_qonnx.py +++ b/src/qonnx/transformation/qcdq_to_qonnx.py @@ -203,6 +203,10 @@ def apply(self, model: ModelWrapper) -> Tuple[ModelWrapper, bool]: rounding_mode="ROUND", # round-to-even signed=signed, ) + # Pass on metadata from DequantizeLinear node since it's the only node that + # must be present to be able to perform this transformation. + if hasattr(node, "metadata_props"): + fused_node.metadata_props.extend(node.metadata_props) model.graph.node.insert(dequant_node_index, fused_node) for node_to_remove in nodes_to_remove: model.graph.node.remove(node_to_remove) diff --git a/src/qonnx/transformation/rebalance_conv.py b/src/qonnx/transformation/rebalance_conv.py index ecb2b5e4..098bff20 100644 --- a/src/qonnx/transformation/rebalance_conv.py +++ b/src/qonnx/transformation/rebalance_conv.py @@ -103,6 +103,8 @@ def apply(self, model): inp_reshape_node = helper.make_node( "Reshape", [node.input[0], inp_shapedata.name], [inp_reshape_out.name] ) + if hasattr(node, "metadata_props"): + inp_reshape_node.metadata_props.extend(node.metadata_props) graph.node.insert(running_node_index, inp_reshape_node) # rewire Im2Col input node.input[0] = inp_reshape_out.name diff --git a/src/qonnx/transformation/resize_conv_to_deconv.py b/src/qonnx/transformation/resize_conv_to_deconv.py index 0dd40972..30bc6c3c 100644 --- a/src/qonnx/transformation/resize_conv_to_deconv.py +++ b/src/qonnx/transformation/resize_conv_to_deconv.py @@ -242,6 +242,9 @@ def apply(self, model): group=group, dilations=dilation, ) + # Save metadata from the convolution node + if hasattr(conv, "metadata_props"): + deconv_node.metadata_props.extend(conv.metadata_props) W_deconv_init = weight_name if weight_prod is not None: W_deconv_init = q_w_name diff --git a/src/qonnx/transformation/subpixel_to_deconv.py b/src/qonnx/transformation/subpixel_to_deconv.py index 3f330c99..241422c8 100644 --- a/src/qonnx/transformation/subpixel_to_deconv.py +++ b/src/qonnx/transformation/subpixel_to_deconv.py @@ -197,6 +197,9 @@ def apply(self, model): group=group, dilations=dilation, ) + # Save metadata from the original convolution node + if hasattr(n, "metadata_props"): + deconv_node.metadata_props.extend(n.metadata_props) W_deconv_init = weight_name if weight_prod is not None: W_deconv_init = q_w_name From 0d9d3e56ad6cc899872f5730b8be0af10f7ead02 Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Thu, 30 Oct 2025 22:06:39 +0000 Subject: [PATCH 03/12] add copy metadata props function --- src/qonnx/transformation/change_datalayout.py | 11 +++---- .../extract_quant_scale_zeropt.py | 13 +++----- src/qonnx/transformation/gemm_to_matmul.py | 20 ++++-------- .../transformation/lower_convs_to_matmul.py | 8 ++--- src/qonnx/transformation/qcdq_to_qonnx.py | 8 ++--- .../transformation/resize_conv_to_deconv.py | 6 ++-- .../transformation/subpixel_to_deconv.py | 6 ++-- src/qonnx/util/basic.py | 32 +++++++++++++++++++ 8 files changed, 58 insertions(+), 46 deletions(-) diff --git a/src/qonnx/transformation/change_datalayout.py b/src/qonnx/transformation/change_datalayout.py index 07fbe400..62e6140b 100644 --- a/src/qonnx/transformation/change_datalayout.py +++ b/src/qonnx/transformation/change_datalayout.py @@ -30,7 +30,7 @@ from qonnx.transformation.base import Transformation from qonnx.transformation.infer_shapes import InferShapes -from qonnx.util.basic import get_by_name +from qonnx.util.basic import copy_metadata_props, get_by_name class ChangeDataLayoutQuantAvgPool2d(Transformation): @@ -78,8 +78,7 @@ def apply(self, model): graph.value_info.append(quantavg_out) quantavg_out = quantavg_out.name inp_trans_node = helper.make_node("Transpose", [node_input], [inp_trans_out], perm=[0, 2, 3, 1]) - if hasattr(n, "metadata_props"): - inp_trans_node.metadata_props.extend(n.metadata_props) + copy_metadata_props(n, inp_trans_node) quantavg_node = helper.make_node( "QuantAvgPool2d", [inp_trans_out], @@ -92,12 +91,10 @@ def apply(self, model): signed=signed, data_layout="NHWC", ) - if hasattr(n, "metadata_props"): - quantavg_node.metadata_props.extend(n.metadata_props) + copy_metadata_props(n, quantavg_node) # NHWC -> NCHW out_trans_node = helper.make_node("Transpose", [quantavg_out], [node_output], perm=[0, 3, 1, 2]) - if hasattr(n, "metadata_props"): - out_trans_node.metadata_props.extend(n.metadata_props) + copy_metadata_props(n, out_trans_node) # insert nodes graph.node.insert(node_ind, inp_trans_node) graph.node.insert(node_ind + 1, quantavg_node) diff --git a/src/qonnx/transformation/extract_quant_scale_zeropt.py b/src/qonnx/transformation/extract_quant_scale_zeropt.py index 614df416..f76e5555 100644 --- a/src/qonnx/transformation/extract_quant_scale_zeropt.py +++ b/src/qonnx/transformation/extract_quant_scale_zeropt.py @@ -33,6 +33,7 @@ from qonnx.transformation.base import Transformation from qonnx.transformation.general import GiveUniqueParameterTensors, SortGraph from qonnx.transformation.remove import RemoveIdentityOps +from qonnx.util.basic import copy_metadata_props class ExtractQuantScaleZeroPt(Transformation): @@ -69,8 +70,7 @@ def apply(self, model: ModelWrapper): ) graph.value_info.append(inp_scaled) inp_scale_node = helper.make_node("Div", [running_input, scale_nm], [inp_scaled_nm]) - if hasattr(node, "metadata_props"): - inp_scale_node.metadata_props.extend(node.metadata_props) + copy_metadata_props(node, inp_scale_node) graph.node.append(inp_scale_node) # create new Mul node # remove scale from Quant node @@ -89,8 +89,7 @@ def apply(self, model: ModelWrapper): ) graph.value_info.append(inp_zeropt) inp_zeropt_node = helper.make_node("Add", [running_input, zeropt_nm], [inp_zeropt_nm]) - if hasattr(node, "metadata_props"): - inp_zeropt_node.metadata_props.extend(node.metadata_props) + copy_metadata_props(node, inp_zeropt_node) graph.node.append(inp_zeropt_node) # remove zeropt from Quant node new_zeropt_nm = model.make_new_valueinfo_name() @@ -112,8 +111,7 @@ def apply(self, model: ModelWrapper): ) graph.value_info.append(out_zeropt) out_zeropt_node = helper.make_node("Sub", [out_zeropt_nm, zeropt_nm], [final_output]) - if hasattr(node, "metadata_props"): - out_zeropt_node.metadata_props.extend(node.metadata_props) + copy_metadata_props(node, out_zeropt_node) last_node.output[0] = out_zeropt_nm graph.node.append(out_zeropt_node) # important: when tracking a pointer to newly added nodes, @@ -133,8 +131,7 @@ def apply(self, model: ModelWrapper): last_node.output[0] = out_scale_nm graph.value_info.append(out_scale) out_scale_node = helper.make_node("Mul", [out_scale_nm, scale_nm], [final_output]) - if hasattr(node, "metadata_props"): - out_scale_node.metadata_props.extend(node.metadata_props) + copy_metadata_props(node, out_scale_node) graph.node.append(out_scale_node) if extract_scale or extract_zeropt: diff --git a/src/qonnx/transformation/gemm_to_matmul.py b/src/qonnx/transformation/gemm_to_matmul.py index 1298f3d6..245a0a2a 100644 --- a/src/qonnx/transformation/gemm_to_matmul.py +++ b/src/qonnx/transformation/gemm_to_matmul.py @@ -32,7 +32,7 @@ from qonnx.core.datatype import DataType from qonnx.transformation.base import Transformation from qonnx.transformation.remove import RemoveIdentityOps -from qonnx.util.basic import get_by_name +from qonnx.util.basic import copy_metadata_props, get_by_name class GemmToMatMul(Transformation): @@ -76,8 +76,7 @@ def apply(self, model): ) graph.value_info.append(inp_trans_out) inp_trans_node = helper.make_node("Transpose", [n.input[0]], [inp_trans_out.name]) - if hasattr(n, "metadata_props"): - inp_trans_node.metadata_props.extend(n.metadata_props) + copy_metadata_props(n, inp_trans_node) graph.node.insert(running_node_index, inp_trans_node) running_node_index += 1 dt = model.get_tensor_datatype(n.input[0]) @@ -100,8 +99,7 @@ def apply(self, model): ) graph.value_info.append(inp_trans_out) inp_trans_node = helper.make_node("Transpose", [n.input[1]], [inp_trans_out.name]) - if hasattr(n, "metadata_props"): - inp_trans_node.metadata_props.extend(n.metadata_props) + copy_metadata_props(n, inp_trans_node) graph.node.insert(running_node_index, inp_trans_node) running_node_index += 1 # Copy over the datatype @@ -113,8 +111,7 @@ def apply(self, model): # Insert MatMul: A * B matMul_node = helper.make_node("MatMul", [n.input[0], n.input[1]], [n.output[0]]) - if hasattr(n, "metadata_props"): - matMul_node.metadata_props.extend(n.metadata_props) + copy_metadata_props(n, matMul_node) graph.node.insert(running_node_index, matMul_node) matMul_node = graph.node[running_node_index] running_node_index += 1 @@ -150,8 +147,7 @@ def apply(self, model): [act_mul_tensor.name, mul_tensor.name], [n.output[0]], ) - if hasattr(n, "metadata_props"): - mul_node.metadata_props.extend(n.metadata_props) + copy_metadata_props(n, mul_node) graph.node.insert(running_node_index, mul_node) mul_node_main_branch = graph.node[running_node_index] running_node_index += 1 @@ -183,8 +179,7 @@ def apply(self, model): [n.input[2], mul_tensor.name], [act_mul_tensor.name], ) - if hasattr(n, "metadata_props"): - mul_node.metadata_props.extend(n.metadata_props) + copy_metadata_props(n, mul_node) graph.node.insert(running_node_index, mul_node) running_node_index += 1 dt = model.get_tensor_datatype(n.input[2]) @@ -206,8 +201,7 @@ def apply(self, model): [act_add_tensor.name, n.input[2]], [n.output[0]], ) - if hasattr(n, "metadata_props"): - add_node.metadata_props.extend(n.metadata_props) + copy_metadata_props(n, add_node) graph.node.insert(running_node_index, add_node) running_node_index += 1 diff --git a/src/qonnx/transformation/lower_convs_to_matmul.py b/src/qonnx/transformation/lower_convs_to_matmul.py index d864d8d2..5140b71d 100644 --- a/src/qonnx/transformation/lower_convs_to_matmul.py +++ b/src/qonnx/transformation/lower_convs_to_matmul.py @@ -32,7 +32,7 @@ from qonnx.transformation.base import Transformation from qonnx.transformation.extract_conv_bias import ExtractBiasFromConv -from qonnx.util.basic import auto_pad_to_explicit_padding, get_by_name +from qonnx.util.basic import auto_pad_to_explicit_padding, copy_metadata_props, get_by_name class LowerConvsToMatMul(Transformation): @@ -178,12 +178,10 @@ def apply(self, model): matmul_input = im2col_out if need_im2col else inp_trans_out # do matmul matmul_node = helper.make_node("MatMul", [matmul_input, conv_weight_inp_name], [matmul_out]) - if hasattr(node, "metadata_props"): - matmul_node.metadata_props.extend(node.metadata_props) + copy_metadata_props(node, matmul_node) # NHWC -> NCHW out_trans_node = helper.make_node("Transpose", [matmul_out], [cnv_output], perm=[0, 3, 1, 2]) - if hasattr(node, "metadata_props"): - out_trans_node.metadata_props.extend(node.metadata_props) + copy_metadata_props(node, out_trans_node) nodes_to_insert.extend([matmul_node, out_trans_node]) diff --git a/src/qonnx/transformation/qcdq_to_qonnx.py b/src/qonnx/transformation/qcdq_to_qonnx.py index b122d840..7aaf9271 100644 --- a/src/qonnx/transformation/qcdq_to_qonnx.py +++ b/src/qonnx/transformation/qcdq_to_qonnx.py @@ -34,7 +34,7 @@ from qonnx.core.modelwrapper import ModelWrapper from qonnx.transformation.base import Transformation -from qonnx.util.basic import get_by_name +from qonnx.util.basic import copy_metadata_props, get_by_name def extract_elem_type(elem_type: int, clip_range=None) -> Tuple[int, int, bool]: @@ -203,10 +203,8 @@ def apply(self, model: ModelWrapper) -> Tuple[ModelWrapper, bool]: rounding_mode="ROUND", # round-to-even signed=signed, ) - # Pass on metadata from DequantizeLinear node since it's the only node that - # must be present to be able to perform this transformation. - if hasattr(node, "metadata_props"): - fused_node.metadata_props.extend(node.metadata_props) + # Preserve metadata from all nodes being fused + copy_metadata_props(nodes_to_remove, fused_node) model.graph.node.insert(dequant_node_index, fused_node) for node_to_remove in nodes_to_remove: model.graph.node.remove(node_to_remove) diff --git a/src/qonnx/transformation/resize_conv_to_deconv.py b/src/qonnx/transformation/resize_conv_to_deconv.py index 30bc6c3c..7eda4fa7 100644 --- a/src/qonnx/transformation/resize_conv_to_deconv.py +++ b/src/qonnx/transformation/resize_conv_to_deconv.py @@ -33,7 +33,7 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.general.quant import quant, resolve_rounding_mode from qonnx.transformation.base import Transformation -from qonnx.util.basic import auto_pad_to_explicit_padding, get_by_name +from qonnx.util.basic import auto_pad_to_explicit_padding, copy_metadata_props, get_by_name def _weight_convolution(cnv_weights: np.ndarray, scale: int) -> np.ndarray: @@ -242,9 +242,7 @@ def apply(self, model): group=group, dilations=dilation, ) - # Save metadata from the convolution node - if hasattr(conv, "metadata_props"): - deconv_node.metadata_props.extend(conv.metadata_props) + copy_metadata_props(conv, deconv_node) W_deconv_init = weight_name if weight_prod is not None: W_deconv_init = q_w_name diff --git a/src/qonnx/transformation/subpixel_to_deconv.py b/src/qonnx/transformation/subpixel_to_deconv.py index 241422c8..73ef3f8f 100644 --- a/src/qonnx/transformation/subpixel_to_deconv.py +++ b/src/qonnx/transformation/subpixel_to_deconv.py @@ -31,7 +31,7 @@ from onnx import helper from qonnx.transformation.base import Transformation -from qonnx.util.basic import auto_pad_to_explicit_padding, get_by_name +from qonnx.util.basic import auto_pad_to_explicit_padding, copy_metadata_props, get_by_name def _weight_shuffle(cnv_weights: np.ndarray, block_size: int) -> np.ndarray: @@ -197,9 +197,7 @@ def apply(self, model): group=group, dilations=dilation, ) - # Save metadata from the original convolution node - if hasattr(n, "metadata_props"): - deconv_node.metadata_props.extend(n.metadata_props) + copy_metadata_props(n, deconv_node) W_deconv_init = weight_name if weight_prod is not None: W_deconv_init = q_w_name diff --git a/src/qonnx/util/basic.py b/src/qonnx/util/basic.py index 4e300dd1..1696c42b 100644 --- a/src/qonnx/util/basic.py +++ b/src/qonnx/util/basic.py @@ -350,3 +350,35 @@ def auto_pad_to_explicit_padding(autopad_str, idim_h, idim_w, k_h, k_w, stride_h return [pad_half_large_h, pad_half_large_w, pad_half_small_h, pad_half_small_w] else: raise Exception("Unsupported auto_pad: " + autopad_str) + + +def copy_metadata_props(source_node, target_node): + """Copy metadata properties from source node(s) to target node. + + Parameters + ---------- + source_node : onnx.NodeProto or list of onnx.NodeProto + Source node(s) from which to copy metadata_props. If a list is provided, + metadata from all nodes will be merged into the target node. + target_node : onnx.NodeProto + Target node to which metadata_props will be copied. + + Returns + ------- + None + Modifies target_node in place by extending its metadata_props. + + Examples + -------- + >>> # Copy from single node + >>> copy_metadata_props(old_node, new_node) + >>> + >>> # Copy from multiple nodes (e.g., when fusing) + >>> copy_metadata_props([quant_node, dequant_node], fused_node) + """ + # Handle both single node and list of nodes + source_nodes = source_node if isinstance(source_node, list) else [source_node] + + for node in source_nodes: + if hasattr(node, "metadata_props"): + target_node.metadata_props.extend(node.metadata_props) From 6f3a631abc80d1a9bed7865fcd485d5a9bb212ad Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Thu, 30 Oct 2025 22:19:46 +0000 Subject: [PATCH 04/12] convert missed functions --- src/qonnx/transformation/channels_last.py | 11 ++++------- src/qonnx/transformation/extract_conv_bias.py | 4 ++-- src/qonnx/transformation/rebalance_conv.py | 4 ++-- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/src/qonnx/transformation/channels_last.py b/src/qonnx/transformation/channels_last.py index 8c934190..444a326c 100644 --- a/src/qonnx/transformation/channels_last.py +++ b/src/qonnx/transformation/channels_last.py @@ -40,7 +40,7 @@ from qonnx.transformation.infer_shapes import InferShapes from qonnx.transformation.make_input_chanlast import MakeInputChannelsLast from qonnx.transformation.quant_constant_folding import FoldTransposeIntoQuantInit -from qonnx.util.basic import get_by_name +from qonnx.util.basic import copy_metadata_props, get_by_name from qonnx.util.onnx import is_eltwise_optype # Standard ONNX nodes which require a ChannelsLast data format to function properly @@ -96,8 +96,7 @@ def move_transpose_past_eltwise(transpose_node, eltwise_node, model: ModelWrappe new_t_inp = model.make_new_valueinfo_name() inv_perm = np.argsort(perm) new_transpose_node = helper.make_node("Transpose", [eltwise_inp], [new_t_inp], perm=inv_perm) - if hasattr(transpose_node, "metadata_props"): - new_transpose_node.metadata_props.extend(transpose_node.metadata_props) + copy_metadata_props(transpose_node, new_transpose_node) t_shape = np.transpose(np.empty(inp_shape), axes=inv_perm).shape model.set_tensor_shape(new_t_inp, t_shape) eltwise_node.input[ind] = new_t_inp @@ -109,8 +108,7 @@ def move_transpose_past_eltwise(transpose_node, eltwise_node, model: ModelWrappe model.set_initializer(unsqueeze_param_name, np.asarray(list(range(ndim_inp - ndim)), dtype=np.int64)) unsqueeze_out_name = model.make_new_valueinfo_name() new_unsqueeze_node = helper.make_node("Unsqueeze", [eltwise_inp, unsqueeze_param_name], [unsqueeze_out_name]) - if hasattr(eltwise_inp, "metadata_props"): - new_unsqueeze_node.metadata_props.extend(eltwise_inp.metadata_props) + copy_metadata_props(eltwise_inp, new_unsqueeze_node) unsqueeze_out_shape = np.expand_dims(np.empty(inp_shape), axis=tuple(range(ndim_inp - ndim))).shape model.set_tensor_shape(unsqueeze_out_name, unsqueeze_out_shape) model.graph.node.append(new_unsqueeze_node) @@ -118,8 +116,7 @@ def move_transpose_past_eltwise(transpose_node, eltwise_node, model: ModelWrappe new_t_inp = model.make_new_valueinfo_name() inv_perm = np.argsort(perm) new_transpose_node = helper.make_node("Transpose", [unsqueeze_out_name], [new_t_inp], perm=inv_perm) - if hasattr(transpose_node, "metadata_props"): - new_transpose_node.metadata_props.extend(transpose_node.metadata_props) + copy_metadata_props(transpose_node, new_transpose_node) t_shape = np.transpose(np.empty(unsqueeze_out_shape), axes=inv_perm).shape model.set_tensor_shape(new_t_inp, t_shape) eltwise_node.input[ind] = new_t_inp diff --git a/src/qonnx/transformation/extract_conv_bias.py b/src/qonnx/transformation/extract_conv_bias.py index 1bf264e3..34b017bd 100644 --- a/src/qonnx/transformation/extract_conv_bias.py +++ b/src/qonnx/transformation/extract_conv_bias.py @@ -30,6 +30,7 @@ from onnx import helper from qonnx.transformation.base import Transformation +from qonnx.util.basic import copy_metadata_props class ExtractBiasFromConv(Transformation): @@ -75,8 +76,7 @@ def apply(self, model): [act_add_tensor.name, n.input[2]], [n.output[0]], ) - if hasattr(n, "metadata_props"): - add_node.metadata_props.extend(n.metadata_props) + copy_metadata_props(n, add_node) graph.node.insert(node_ind, add_node) # Repoint Conv output and remove bias tensor diff --git a/src/qonnx/transformation/rebalance_conv.py b/src/qonnx/transformation/rebalance_conv.py index 098bff20..0107a62a 100644 --- a/src/qonnx/transformation/rebalance_conv.py +++ b/src/qonnx/transformation/rebalance_conv.py @@ -31,6 +31,7 @@ from qonnx.custom_op.registry import getCustomOp from qonnx.transformation.base import Transformation +from qonnx.util.basic import copy_metadata_props class RebalanceIm2Col(Transformation): @@ -103,8 +104,7 @@ def apply(self, model): inp_reshape_node = helper.make_node( "Reshape", [node.input[0], inp_shapedata.name], [inp_reshape_out.name] ) - if hasattr(node, "metadata_props"): - inp_reshape_node.metadata_props.extend(node.metadata_props) + copy_metadata_props(node, inp_reshape_node) graph.node.insert(running_node_index, inp_reshape_node) # rewire Im2Col input node.input[0] = inp_reshape_out.name From 59ca168ebd4fa744b1ee4de15f8ac849f43464e1 Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Thu, 30 Oct 2025 22:23:11 +0000 Subject: [PATCH 05/12] correct fused node source mistake --- src/qonnx/transformation/qcdq_to_qonnx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/qonnx/transformation/qcdq_to_qonnx.py b/src/qonnx/transformation/qcdq_to_qonnx.py index 7aaf9271..b4e18f25 100644 --- a/src/qonnx/transformation/qcdq_to_qonnx.py +++ b/src/qonnx/transformation/qcdq_to_qonnx.py @@ -204,7 +204,7 @@ def apply(self, model: ModelWrapper) -> Tuple[ModelWrapper, bool]: signed=signed, ) # Preserve metadata from all nodes being fused - copy_metadata_props(nodes_to_remove, fused_node) + copy_metadata_props(node, fused_node) model.graph.node.insert(dequant_node_index, fused_node) for node_to_remove in nodes_to_remove: model.graph.node.remove(node_to_remove) From 159060670cf5692ffb6b79d88b65e66bec39327c Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Sat, 22 Nov 2025 00:41:31 +0000 Subject: [PATCH 06/12] add metadata preservation to batchnorm transform. --- src/qonnx/transformation/batchnorm_to_affine.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/qonnx/transformation/batchnorm_to_affine.py b/src/qonnx/transformation/batchnorm_to_affine.py index c89d2bdc..d63dd178 100644 --- a/src/qonnx/transformation/batchnorm_to_affine.py +++ b/src/qonnx/transformation/batchnorm_to_affine.py @@ -32,7 +32,7 @@ from qonnx.transformation.base import Transformation from qonnx.transformation.infer_shapes import InferShapes -from qonnx.util.basic import get_by_name +from qonnx.util.basic import get_by_name, copy_metadata_props class BatchNormToAffine(Transformation): @@ -89,6 +89,9 @@ def apply(self, model): # create Mul and Add nodes to replace the batchnorm mul_node = oh.make_node("Mul", [bn_input, mul_const.name], [mul_output.name]) add_node = oh.make_node("Add", [mul_output.name, add_const.name], [bn_output]) + # preserve metadata from original batchnorm node + copy_metadata_props(n, mul_node) + copy_metadata_props(n, add_node) # insert where the batchnorm is to preserve topological ordering graph.node.insert(node_ind, mul_node) graph.node.insert(node_ind + 1, add_node) From 5690a790291a0cc9a54fd9fb96d8a25080d63101 Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Sat, 22 Nov 2025 01:00:42 +0000 Subject: [PATCH 07/12] adding more copy metadata nodes. --- src/qonnx/transformation/batchnorm_to_affine.py | 2 +- src/qonnx/transformation/bipolar_to_xnor.py | 5 ++++- src/qonnx/transformation/channels_last.py | 5 ++++- src/qonnx/transformation/lower_convs_to_matmul.py | 2 ++ 4 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/qonnx/transformation/batchnorm_to_affine.py b/src/qonnx/transformation/batchnorm_to_affine.py index d63dd178..6190f867 100644 --- a/src/qonnx/transformation/batchnorm_to_affine.py +++ b/src/qonnx/transformation/batchnorm_to_affine.py @@ -32,7 +32,7 @@ from qonnx.transformation.base import Transformation from qonnx.transformation.infer_shapes import InferShapes -from qonnx.util.basic import get_by_name, copy_metadata_props +from qonnx.util.basic import copy_metadata_props, get_by_name class BatchNormToAffine(Transformation): diff --git a/src/qonnx/transformation/bipolar_to_xnor.py b/src/qonnx/transformation/bipolar_to_xnor.py index 37f939a2..0b764ef8 100644 --- a/src/qonnx/transformation/bipolar_to_xnor.py +++ b/src/qonnx/transformation/bipolar_to_xnor.py @@ -36,7 +36,7 @@ from qonnx.transformation.base import Transformation from qonnx.transformation.infer_datatypes import InferDataTypes from qonnx.transformation.infer_shapes import InferShapes -from qonnx.util.basic import get_by_name +from qonnx.util.basic import copy_metadata_props, get_by_name class ConvertBipolarMatMulToXnorPopcount(Transformation): @@ -132,6 +132,9 @@ def find_prod_mt(x): # create Mul and Add nodes to replace the batchnorm mul_node = oh.make_node("Mul", [xnorpcout.name, mul_const.name], [mul_output.name]) add_node = oh.make_node("Add", [mul_output.name, add_const.name], [mm_output]) + # preserve metadata from original MatMul node + copy_metadata_props(n, mul_node) + copy_metadata_props(n, add_node) # insert where the batchnorm is to preserve topological ordering graph.node.insert(node_ind, mul_node) graph.node.insert(node_ind + 1, add_node) diff --git a/src/qonnx/transformation/channels_last.py b/src/qonnx/transformation/channels_last.py index 1a2a0dcd..f9ca62bb 100644 --- a/src/qonnx/transformation/channels_last.py +++ b/src/qonnx/transformation/channels_last.py @@ -242,6 +242,7 @@ def apply(self, model): # channels last transpose inp_trans_node = helper.make_node("Transpose", [inp], [inp_trans_out], perm=to_channels_last_args(ndim)) graph.node.insert(running_node_index, inp_trans_node) + copy_metadata_props(n, inp_trans_node) running_node_index += 1 # Attach to original node @@ -268,6 +269,7 @@ def apply(self, model): "Transpose", [outp_trans_in], [outp], perm=to_channels_first_args(ndim) ) graph.node.insert(running_node_index, outp_trans_node) + copy_metadata_props(n, outp_trans_node) running_node_index += 1 # Attach to original node @@ -570,7 +572,8 @@ def apply(self, model): axis=1, ) graph.node.insert(node_ind, flat_node) - + copy_metadata_props(n, flat_node) + graph_modified = True else: warnings.warn( diff --git a/src/qonnx/transformation/lower_convs_to_matmul.py b/src/qonnx/transformation/lower_convs_to_matmul.py index 5140b71d..f0981b34 100644 --- a/src/qonnx/transformation/lower_convs_to_matmul.py +++ b/src/qonnx/transformation/lower_convs_to_matmul.py @@ -152,6 +152,7 @@ def apply(self, model): # create new nodes # NCHW -> NHWC inp_trans_node = helper.make_node("Transpose", [cnv_input], [inp_trans_out], perm=[0, 2, 3, 1]) + copy_metadata_props(node, inp_trans_node) nodes_to_insert = [inp_trans_node] if need_im2col: @@ -174,6 +175,7 @@ def apply(self, model): dilations=dilation, ) nodes_to_insert.append(im2col_node) + copy_metadata_props(node, im2col_node) matmul_input = im2col_out if need_im2col else inp_trans_out # do matmul From 7c04726f8bb9c51fef07ca4a05dd647053a70e2a Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Tue, 25 Nov 2025 23:42:20 +0000 Subject: [PATCH 08/12] added overwrite mode flag and basic unit tests --- src/qonnx/util/basic.py | 22 ++++++++++-- tests/util/test_copy_metadata.py | 61 ++++++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 2 deletions(-) create mode 100644 tests/util/test_copy_metadata.py diff --git a/src/qonnx/util/basic.py b/src/qonnx/util/basic.py index 73f4cca2..2c52285a 100644 --- a/src/qonnx/util/basic.py +++ b/src/qonnx/util/basic.py @@ -362,7 +362,7 @@ def auto_pad_to_explicit_padding(autopad_str, idim_h, idim_w, k_h, k_w, stride_h raise Exception("Unsupported auto_pad: " + autopad_str) -def copy_metadata_props(source_node, target_node): +def copy_metadata_props(source_node, target_node, mode="overwrite"): """Copy metadata properties from source node(s) to target node. Parameters @@ -386,9 +386,27 @@ def copy_metadata_props(source_node, target_node): >>> # Copy from multiple nodes (e.g., when fusing) >>> copy_metadata_props([quant_node, dequant_node], fused_node) """ + assert mode in ["overwrite", "keep_existing"], "Copy Metadata Mode must be either 'overwrite' or 'keep_existing'." + # Handle both single node and list of nodes source_nodes = source_node if isinstance(source_node, list) else [source_node] for node in source_nodes: if hasattr(node, "metadata_props"): - target_node.metadata_props.extend(node.metadata_props) + + # check for existing keys in target_node to avoid duplicates + if hasattr(target_node, "metadata_props"): + existing_keys = {prop.key for prop in target_node.metadata_props} + else: + existing_keys = set() + + for prop in node.metadata_props: + if prop.key in existing_keys: + if mode == "overwrite": + # Overwrite existing metadata property + for existing_prop in target_node.metadata_props: + if existing_prop.key == prop.key: + existing_prop.value = prop.value + break + else: + target_node.metadata_props.append(prop) \ No newline at end of file diff --git a/tests/util/test_copy_metadata.py b/tests/util/test_copy_metadata.py new file mode 100644 index 00000000..f976eabb --- /dev/null +++ b/tests/util/test_copy_metadata.py @@ -0,0 +1,61 @@ + +import onnx +import pytest +from qonnx.util.basic import copy_metadata_props + + +def add_metadata(key, value): + return onnx.StringStringEntryProto(key=key, value=value) + + +def test_copy_metadata_props(): + + # Create source node with metadata + src_node = onnx.NodeProto( + metadata_props=[add_metadata("key1", "value1"), add_metadata("key2", "value2")] + ) + dst_node = onnx.NodeProto() + + copy_metadata_props(src_node, dst_node) + + assert len(dst_node.metadata_props) == 2 + assert dst_node.metadata_props[0].key == "key1" + assert dst_node.metadata_props[0].value == "value1" + assert dst_node.metadata_props[1].key == "key2" + assert dst_node.metadata_props[1].value == "value2" + + +@pytest.mark.parametrize("mode", ["keep_existing", "overwrite"]) +def test_copy_metadata_props_existing_target_md(mode): + + # Create source node with metadata + src_node = onnx.NodeProto( + metadata_props=[add_metadata("key1", "value1"), add_metadata("key2", "value2")] + ) + # Create destination node with existing metadata + dst_node = onnx.NodeProto( + metadata_props=[add_metadata("key1", "value3")] + ) + + copy_metadata_props(src_node, dst_node, mode=mode) + + assert len(dst_node.metadata_props) == 2 + assert dst_node.metadata_props[0].key == "key1" + + if mode == "keep_existing": + assert dst_node.metadata_props[0].value == "value3" # Should keep existing + elif mode == "overwrite": + assert dst_node.metadata_props[0].value == "value1" # Should be overwritten + + assert dst_node.metadata_props[1].key == "key2" + assert dst_node.metadata_props[1].value == "value2" + + +def test_copy_metadata_props_bad_mode(): + src_node = onnx.NodeProto( + metadata_props=[add_metadata("key1", "value1")] + ) + dst_node = onnx.NodeProto() + + with pytest.raises(AssertionError): + copy_metadata_props(src_node, dst_node, mode="invalid_mode") \ No newline at end of file From 3085eee25c656cca08eed48b89f9e61be45e562f Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Tue, 25 Nov 2025 23:43:15 +0000 Subject: [PATCH 09/12] update documention copy_metadata_props --- src/qonnx/util/basic.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/qonnx/util/basic.py b/src/qonnx/util/basic.py index 2c52285a..2752212b 100644 --- a/src/qonnx/util/basic.py +++ b/src/qonnx/util/basic.py @@ -372,6 +372,14 @@ def copy_metadata_props(source_node, target_node, mode="overwrite"): metadata from all nodes will be merged into the target node. target_node : onnx.NodeProto Target node to which metadata_props will be copied. + mode : str, optional + Mode for handling existing metadata properties in the target node. + Options are: + - "overwrite": Existing properties in the target node will be overwritten + by those from the source node(s) if they share the same key. + - "keep_existing": Existing properties in the target node will be kept, + and only new properties from the source node(s) will be added. + Default is "overwrite". Returns ------- From 43793044915aa32ab47c2430a11f7dd80b74d987 Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Wed, 26 Nov 2025 17:25:56 +0000 Subject: [PATCH 10/12] add gemm2matmul test --- tests/util/test_copy_metadata.py | 34 +++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/tests/util/test_copy_metadata.py b/tests/util/test_copy_metadata.py index f976eabb..81f39a09 100644 --- a/tests/util/test_copy_metadata.py +++ b/tests/util/test_copy_metadata.py @@ -1,6 +1,7 @@ import onnx import pytest +from qonnx.transformation.infer_shapes import InferShapes from qonnx.util.basic import copy_metadata_props @@ -58,4 +59,35 @@ def test_copy_metadata_props_bad_mode(): dst_node = onnx.NodeProto() with pytest.raises(AssertionError): - copy_metadata_props(src_node, dst_node, mode="invalid_mode") \ No newline at end of file + copy_metadata_props(src_node, dst_node, mode="invalid_mode") + + +from onnxscript import script +from onnxscript import opset9 as op +from onnxscript import FLOAT +from qonnx.core.modelwrapper import ModelWrapper +from qonnx.transformation.gemm_to_matmul import GemmToMatMul + +def test_copy_metadata_props_gemm2matmul(): + + @script() + def MyGemm(A: FLOAT[4, 5], B: FLOAT[5, 4], C: FLOAT[4, 4]) -> FLOAT[4, 4]: + return op.Gemm(A, B, C) + + model_proto = MyGemm.to_model_proto() + gemm_node = model_proto.graph.node[0] + gemm_node.metadata_props.extend([ + add_metadata("key1", "value1"), + add_metadata("key2", "value2") + ]) + + # Create Model Wrapper + mw = ModelWrapper(model_proto) + + transformed_mw = mw.transform(GemmToMatMul()) + + for node in transformed_mw.graph.node: + assert node.metadata_props[0].key == 'key1' + assert node.metadata_props[0].value == 'value1' + assert node.metadata_props[1].key == 'key2' + assert node.metadata_props[1].value == 'value2' \ No newline at end of file From 9cbd803969fe0d844aef65ef96e9754130a63af8 Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Wed, 26 Nov 2025 18:55:18 +0000 Subject: [PATCH 11/12] add batchnorm to affine test --- tests/util/test_copy_metadata.py | 46 +++++++++++++++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/tests/util/test_copy_metadata.py b/tests/util/test_copy_metadata.py index 81f39a09..32a9ed1f 100644 --- a/tests/util/test_copy_metadata.py +++ b/tests/util/test_copy_metadata.py @@ -63,7 +63,7 @@ def test_copy_metadata_props_bad_mode(): from onnxscript import script -from onnxscript import opset9 as op +from onnxscript import opset17 as op from onnxscript import FLOAT from qonnx.core.modelwrapper import ModelWrapper from qonnx.transformation.gemm_to_matmul import GemmToMatMul @@ -86,6 +86,50 @@ def MyGemm(A: FLOAT[4, 5], B: FLOAT[5, 4], C: FLOAT[4, 4]) -> FLOAT[4, 4]: transformed_mw = mw.transform(GemmToMatMul()) + for node in transformed_mw.graph.node: + assert node.metadata_props[0].key == 'key1' + assert node.metadata_props[0].value == 'value1' + assert node.metadata_props[1].key == 'key2' + assert node.metadata_props[1].value == 'value2' + + +from onnx import helper as oh +import numpy as np +import onnxscript +from onnxscript.ir.passes.common import LiftConstantsToInitializersPass + + +def test_copy_metadata_props_batchnorm2affine(): + @script() + def MyBatchNorm(X: FLOAT[1, 3, 4, 4]) -> FLOAT[1, 3, 4, 4]: + scale = op.Constant(value=[[1.0, 1.0, 1.0]]) + B = op.Constant(value=[[0.0, 0.0, 0.0]]) + var = op.Constant(value=[[1.0, 1.0, 1.0]]) + mean = op.Constant(value=[[0.0, 0.0, 0.0]]) + return op.BatchNormalization(X, scale, B, mean, var, epsilon=1e-5, momentum=0.9) + + # remove cast-like nodes + model_proto = onnxscript.optimizer.optimize(MyBatchNorm.to_model_proto()) + + # batchnorm_to_affine requires initializers for scale/mean/var/bias + model_ir = onnxscript.ir.serde.deserialize_model(model_proto) + pass_ = LiftConstantsToInitializersPass(lift_all_constants=True, size_limit=1) + PassResult = pass_.call(model_ir) + model_proto = onnxscript.ir.serde.serialize_model(PassResult.model) + + # Add metadata to BatchNorm node + bn_node = model_proto.graph.node[0] + bn_node.metadata_props.extend([ + add_metadata("key1", "value1"), + add_metadata("key2", "value2") + ]) + + # Create Model Wrapper + mw = ModelWrapper(model_proto) + from qonnx.transformation.batchnorm_to_affine import BatchNormToAffine + transformed_mw = mw.transform(BatchNormToAffine()) + + # Check that metadata was copied for node in transformed_mw.graph.node: assert node.metadata_props[0].key == 'key1' assert node.metadata_props[0].value == 'value1' From 47c9bb855417e97bbda776c527f1a67a155811e1 Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Wed, 26 Nov 2025 19:05:44 +0000 Subject: [PATCH 12/12] force precommit run --- tests/util/test_copy_metadata.py | 109 +++++++++++++------------------ 1 file changed, 44 insertions(+), 65 deletions(-) diff --git a/tests/util/test_copy_metadata.py b/tests/util/test_copy_metadata.py index 32a9ed1f..1cc913b9 100644 --- a/tests/util/test_copy_metadata.py +++ b/tests/util/test_copy_metadata.py @@ -1,7 +1,14 @@ +import pytest import onnx -import pytest -from qonnx.transformation.infer_shapes import InferShapes +import onnxscript +from onnxscript import FLOAT +from onnxscript import opset17 as op +from onnxscript import script +from onnxscript.ir.passes.common import LiftConstantsToInitializersPass + +from qonnx.core.modelwrapper import ModelWrapper +from qonnx.transformation.gemm_to_matmul import GemmToMatMul from qonnx.util.basic import copy_metadata_props @@ -10,15 +17,12 @@ def add_metadata(key, value): def test_copy_metadata_props(): - # Create source node with metadata - src_node = onnx.NodeProto( - metadata_props=[add_metadata("key1", "value1"), add_metadata("key2", "value2")] - ) + src_node = onnx.NodeProto(metadata_props=[add_metadata("key1", "value1"), add_metadata("key2", "value2")]) dst_node = onnx.NodeProto() - + copy_metadata_props(src_node, dst_node) - + assert len(dst_node.metadata_props) == 2 assert dst_node.metadata_props[0].key == "key1" assert dst_node.metadata_props[0].value == "value1" @@ -28,77 +32,54 @@ def test_copy_metadata_props(): @pytest.mark.parametrize("mode", ["keep_existing", "overwrite"]) def test_copy_metadata_props_existing_target_md(mode): - # Create source node with metadata - src_node = onnx.NodeProto( - metadata_props=[add_metadata("key1", "value1"), add_metadata("key2", "value2")] - ) + src_node = onnx.NodeProto(metadata_props=[add_metadata("key1", "value1"), add_metadata("key2", "value2")]) # Create destination node with existing metadata - dst_node = onnx.NodeProto( - metadata_props=[add_metadata("key1", "value3")] - ) - + dst_node = onnx.NodeProto(metadata_props=[add_metadata("key1", "value3")]) + copy_metadata_props(src_node, dst_node, mode=mode) - + assert len(dst_node.metadata_props) == 2 assert dst_node.metadata_props[0].key == "key1" - + if mode == "keep_existing": assert dst_node.metadata_props[0].value == "value3" # Should keep existing elif mode == "overwrite": assert dst_node.metadata_props[0].value == "value1" # Should be overwritten - + assert dst_node.metadata_props[1].key == "key2" assert dst_node.metadata_props[1].value == "value2" - - + + def test_copy_metadata_props_bad_mode(): - src_node = onnx.NodeProto( - metadata_props=[add_metadata("key1", "value1")] - ) + src_node = onnx.NodeProto(metadata_props=[add_metadata("key1", "value1")]) dst_node = onnx.NodeProto() - + with pytest.raises(AssertionError): copy_metadata_props(src_node, dst_node, mode="invalid_mode") - -from onnxscript import script -from onnxscript import opset17 as op -from onnxscript import FLOAT -from qonnx.core.modelwrapper import ModelWrapper -from qonnx.transformation.gemm_to_matmul import GemmToMatMul - -def test_copy_metadata_props_gemm2matmul(): +def test_copy_metadata_props_gemm2matmul(): @script() def MyGemm(A: FLOAT[4, 5], B: FLOAT[5, 4], C: FLOAT[4, 4]) -> FLOAT[4, 4]: return op.Gemm(A, B, C) model_proto = MyGemm.to_model_proto() gemm_node = model_proto.graph.node[0] - gemm_node.metadata_props.extend([ - add_metadata("key1", "value1"), - add_metadata("key2", "value2") - ]) + gemm_node.metadata_props.extend([add_metadata("key1", "value1"), add_metadata("key2", "value2")]) # Create Model Wrapper mw = ModelWrapper(model_proto) - + transformed_mw = mw.transform(GemmToMatMul()) - + for node in transformed_mw.graph.node: - assert node.metadata_props[0].key == 'key1' - assert node.metadata_props[0].value == 'value1' - assert node.metadata_props[1].key == 'key2' - assert node.metadata_props[1].value == 'value2' - - -from onnx import helper as oh -import numpy as np -import onnxscript -from onnxscript.ir.passes.common import LiftConstantsToInitializersPass - - + assert node.metadata_props[0].key == "key1" + assert node.metadata_props[0].value == "value1" + assert node.metadata_props[1].key == "key2" + assert node.metadata_props[1].value == "value2" + + def test_copy_metadata_props_batchnorm2affine(): @script() def MyBatchNorm(X: FLOAT[1, 3, 4, 4]) -> FLOAT[1, 3, 4, 4]: @@ -107,31 +88,29 @@ def MyBatchNorm(X: FLOAT[1, 3, 4, 4]) -> FLOAT[1, 3, 4, 4]: var = op.Constant(value=[[1.0, 1.0, 1.0]]) mean = op.Constant(value=[[0.0, 0.0, 0.0]]) return op.BatchNormalization(X, scale, B, mean, var, epsilon=1e-5, momentum=0.9) - + # remove cast-like nodes - model_proto = onnxscript.optimizer.optimize(MyBatchNorm.to_model_proto()) - + model_proto = onnxscript.optimizer.optimize(MyBatchNorm.to_model_proto()) + # batchnorm_to_affine requires initializers for scale/mean/var/bias model_ir = onnxscript.ir.serde.deserialize_model(model_proto) pass_ = LiftConstantsToInitializersPass(lift_all_constants=True, size_limit=1) PassResult = pass_.call(model_ir) model_proto = onnxscript.ir.serde.serialize_model(PassResult.model) - + # Add metadata to BatchNorm node bn_node = model_proto.graph.node[0] - bn_node.metadata_props.extend([ - add_metadata("key1", "value1"), - add_metadata("key2", "value2") - ]) - + bn_node.metadata_props.extend([add_metadata("key1", "value1"), add_metadata("key2", "value2")]) + # Create Model Wrapper mw = ModelWrapper(model_proto) from qonnx.transformation.batchnorm_to_affine import BatchNormToAffine + transformed_mw = mw.transform(BatchNormToAffine()) - + # Check that metadata was copied for node in transformed_mw.graph.node: - assert node.metadata_props[0].key == 'key1' - assert node.metadata_props[0].value == 'value1' - assert node.metadata_props[1].key == 'key2' - assert node.metadata_props[1].value == 'value2' \ No newline at end of file + assert node.metadata_props[0].key == "key1" + assert node.metadata_props[0].value == "value1" + assert node.metadata_props[1].key == "key2" + assert node.metadata_props[1].value == "value2"