diff --git a/src/qonnx/transformation/general.py b/src/qonnx/transformation/general.py index 5126bf27..d634ce9b 100644 --- a/src/qonnx/transformation/general.py +++ b/src/qonnx/transformation/general.py @@ -31,8 +31,7 @@ import warnings # Protobuf onnx graph node type -from onnx import NodeProto # noqa -from onnx import mapping +from onnx import AttributeProto, NodeProto, mapping # noqa from toposort import toposort_flatten import qonnx.util.basic as util @@ -335,56 +334,88 @@ def __init__(self, config, node_filter=lambda x: True): super().__init__() self.config = config self.node_filter = node_filter - - def apply(self, model): - if isinstance(self.config, dict): - model_config = self.config + self.used_configurations = ["Defaults"] + self.missing_configurations = [] + + def configure_network(self, graph_proto, model_config, subgraph_hier): + # Configure network - graph_proto can be a GraphProto or ModelWrapper + # If it's a ModelWrapper, get the graph + if hasattr(graph_proto, "graph"): + graph = graph_proto.graph else: - with open(self.config, "r") as f: - model_config = json.load(f) - - used_configurations = ["Defaults"] - missing_configurations = [] + graph = graph_proto - # Configure network - for node_idx, node in enumerate(model.graph.node): + for node in graph.node: if not self.node_filter(node): continue + + # Build the config key by prepending hierarchy + config_key = node.name if subgraph_hier is None else str(subgraph_hier) + "_" + node.name + try: - node_config = model_config[node.name] + node_config = model_config[config_key].copy() except KeyError: - missing_configurations += [node.name] + self.missing_configurations += [node.name] node_config = {} + if node_config: + self.used_configurations += [config_key] + from qonnx.custom_op.registry import getCustomOp try: inst = getCustomOp(node) + + if "Defaults" in model_config.keys(): + # set specified defaults + default_values = [] + for key, value in model_config["Defaults"].items(): + assert len(value) % 2 == 0 + if key not in model_config: + for val, op in zip(value[::2], value[1::2]): + default_values.append((key, val, op)) + assert not (op == "all" and len(value) > 2) + default_configs = {key: val for key, val, op in default_values if op == "all" or node.op_type in op} + for attr_name, value in default_configs.items(): + inst.set_nodeattr(attr_name, value) + + # set node attributes from specified configuration + for attr_name, value in node_config.items(): + inst.set_nodeattr(attr_name, value) except Exception: - continue - used_configurations += [node.name] - - # set specified defaults - default_values = [] - for key, value in model_config["Defaults"].items(): - assert len(value) % 2 == 0 - if key not in model_config: - for val, op in zip(value[::2], value[1::2]): - default_values.append((key, val, op)) - assert not (op == "all" and len(value) > 2) - default_configs = {key: val for key, val, op in default_values if op == "all" or node.op_type in op} - for attr, value in default_configs.items(): - inst.set_nodeattr(attr, value) - - # set node attributes from specified configuration - for attr, value in node_config.items(): - inst.set_nodeattr(attr, value) + # Node is not a custom op, but it might have subgraphs + pass + + # Recursively handle nested subgraphs + for attr in node.attribute: + if attr.type == AttributeProto.GRAPH: + # Build the subgraph hierarchy including the attribute name + if subgraph_hier is None: + new_hier = node.name + else: + new_hier = str(subgraph_hier) + "_" + node.name + # Include the subgraph attribute name in the hierarchy + new_hier = new_hier + "_" + attr.name + self.configure_network(attr.g, model_config, subgraph_hier=new_hier) + + def apply(self, model): + if isinstance(self.config, dict): + model_config = self.config + else: + with open(self.config, "r") as f: + model_config = json.load(f) + + # apply configuration on upper level + self.configure_network(model.model.graph, model_config, subgraph_hier=None) # Configuration verification - if len(missing_configurations) > 0: - warnings.warn("\nNo HW configuration for nodes: " + ", ".join(missing_configurations)) + # Remove duplicates from missing_configurations (can happen with shared subgraphs in If nodes) + unique_missing = list(dict.fromkeys(self.missing_configurations)) + if len(unique_missing) > 0: + warnings.warn("\nNo HW configuration for nodes: " + ", ".join(unique_missing)) - unused_configs = [x for x in model_config if x not in used_configurations] + # Check for unused configs (top-level configs that weren't applied) + unused_configs = [x for x in model_config if x not in self.used_configurations and x != "Defaults"] if len(unused_configs) > 0: warnings.warn("\nUnused HW configurations: " + ", ".join(unused_configs)) diff --git a/src/qonnx/util/config.py b/src/qonnx/util/config.py index 63661862..36413bc9 100644 --- a/src/qonnx/util/config.py +++ b/src/qonnx/util/config.py @@ -27,8 +27,53 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import json +import onnx -from qonnx.custom_op.registry import getCustomOp +from qonnx.custom_op.registry import getCustomOp, is_custom_op + + +# update this code to handle export configs from subgraphs +# where the subgraph is found in a node's attribute as a graph type +def extract_model_config(model, subgraph_hier, attr_names_to_extract): + """Create a dictionary with layer name -> attribute mappings extracted from the + model. The created dictionary can be later applied on a model with + qonnx.transform.general.ApplyConfig. + + Nodes in subgraphs are prefixed with their parent hierarchy using '_' as separator. + For example, a node 'Conv_0' inside a subgraph of node 'IfNode_0' will be exported + as 'IfNode_0_Conv_0' in the config.""" + + cfg = dict() + cfg["Defaults"] = dict() + for n in model.graph.node: + new_hier = n.name if subgraph_hier is None else str(subgraph_hier) + "_" + n.name + + # Check if this is a custom op and prepare to extract attributes + is_custom = is_custom_op(n.domain, n.op_type) + if is_custom: + oi = getCustomOp(n) + layer_dict = dict() + + # Process node attributes - handle both subgraphs and extractable attributes + for attr in n.attribute: + if attr.type == onnx.AttributeProto.GRAPH: + # If the attribute is a graph, extract configs from the subgraph recursively + # Include the subgraph attribute name in the hierarchy + subgraph_hier_with_attr = new_hier + "_" + attr.name + cfg.update( + extract_model_config( + model.make_subgraph_modelwrapper(attr.g), subgraph_hier_with_attr, attr_names_to_extract + ) + ) + elif is_custom and attr.name in attr_names_to_extract: + # For custom ops, extract the requested attribute + layer_dict[attr.name] = oi.get_nodeattr(attr.name) + + # Add the node's config if we extracted any attributes + if is_custom and len(layer_dict) > 0: + cfg[new_hier] = layer_dict + + return cfg def extract_model_config_to_json(model, json_filename, attr_names_to_extract): @@ -36,17 +81,5 @@ def extract_model_config_to_json(model, json_filename, attr_names_to_extract): model. The created json file can be later applied on a model with qonnx.transform.general.ApplyConfig.""" - cfg = dict() - cfg["Defaults"] = dict() - for n in model.graph.node: - oi = getCustomOp(n) - layer_dict = dict() - for attr in attr_names_to_extract: - try: - layer_dict[attr] = oi.get_nodeattr(attr) - except AttributeError: - pass - if len(layer_dict) > 0: - cfg[n.name] = layer_dict with open(json_filename, "w") as f: - json.dump(cfg, f, indent=2) + json.dump(extract_model_config(model, subgraph_hier=None, attr_names_to_extract=attr_names_to_extract), f, indent=2) diff --git a/tests/util/test_config.py b/tests/util/test_config.py new file mode 100644 index 00000000..60230f55 --- /dev/null +++ b/tests/util/test_config.py @@ -0,0 +1,167 @@ +# Copyright (c) 2025 Advanced Micro Devices, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of QONNX nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import json +import os +import pytest +import tempfile + +import onnx +import onnx.helper as helper +import numpy as np +from onnxscript import script, FLOAT, BOOL +from onnxscript import opset13 as op +from onnxscript.values import Opset + +from qonnx.core.modelwrapper import ModelWrapper +from qonnx.custom_op.registry import getCustomOp +from qonnx.util.basic import qonnx_make_model +from qonnx.util.config import extract_model_config_to_json, extract_model_config +from typing import List, Dict, Any, Tuple +# this is a pretend opset so that we can create +# qonnx custom ops with onnxscript +qops = Opset("qonnx.custom_op.general", 1) + +@script(default_opset=op) +def main_graph_fn(main_inp: FLOAT[1, 28, 28, 1], condition: BOOL, nested_condition: BOOL) -> FLOAT[1, 4, 4, 144]: + """Main graph with nested if statement in else branch.""" + im2col_0 = qops.Im2Col(main_inp, stride=[1, 1], kernel_size=[3, 3], + pad_amount=[1, 1, 1, 1], input_shape=[1, 28, 28, 1]) + + # Python if statement → ONNX If node with subgraph + # settings for Im2Col are meant to validate the extraction/application of attributes + # and are not necessarily realistic or correct + if condition: + # Then branch: simple subgraph (2 levels) + main_out = qops.Im2Col(im2col_0, stride=[2, 1], kernel_size=[5, 5], + pad_amount=[2, 2, 2, 2], input_shape=[1, 14, 14, 144]) + else: + im2col_1 = qops.Im2Col(im2col_0, stride=[2, 1], kernel_size=[6, 6], + pad_amount=[3, 3, 3, 3], input_shape=[1, 14, 14, 145]) + # Else branch: nested if statement (3 levels) + if nested_condition: + main_out = qops.Im2Col(im2col_1, stride=[3, 1], kernel_size=[7, 7], + pad_amount=[4, 4, 4, 4], input_shape=[1, 4, 4, 146]) + else: + main_out = qops.Im2Col(im2col_1, stride=[3, 2], kernel_size=[8, 8], + pad_amount=[5, 5, 5, 5], input_shape=[1, 4, 4, 147]) + + return main_out + + +def build_expected_config_from_node(node: onnx.NodeProto, prefix = '') -> Dict[str, Any]: + """Build expected config dictionary from a given ONNX node.""" + custom_op = getCustomOp(node) + attrs = {} + for attr in node.attribute: + attrs[attr.name] = custom_op.get_nodeattr(attr.name) + return {prefix + node.name: attrs} + + +def make_im2col_test_model(): + """Create a simple ONNX model with a single Im2Col node.""" + + model_proto = main_graph_fn.to_model_proto() + + im2col_node = model_proto.graph.node[0] + if_im2col_then_node = model_proto.graph.node[1].attribute[0].g.node[0] + if_im2col_else_node = model_proto.graph.node[1].attribute[1].g.node[0] + nested_if_im2col_then_node = model_proto.graph.node[1].attribute[1].g.node[1].attribute[0].g.node[0] + nested_if_im2col_else_node = model_proto.graph.node[1].attribute[1].g.node[1].attribute[1].g.node[0] + + # this test assumes that all Im2Col nodes have the same name + # to verify that node aliasing is handled correctly between nodes on + # the same and different levels of the hierarchy + assert im2col_node.name == if_im2col_then_node.name + assert im2col_node.name == if_im2col_else_node.name + assert im2col_node.name == nested_if_im2col_then_node.name + assert im2col_node.name == nested_if_im2col_else_node.name + + expected_config = {} + expected_config["Defaults"] = {} + expected_config.update(build_expected_config_from_node(im2col_node)) + expected_config.update(build_expected_config_from_node(if_im2col_then_node, prefix='n1_then_branch_')) + expected_config.update(build_expected_config_from_node(if_im2col_else_node, prefix='n1_else_branch_')) + expected_config.update(build_expected_config_from_node(nested_if_im2col_then_node, prefix='n1_else_branch_n1_then_branch_')) + expected_config.update(build_expected_config_from_node(nested_if_im2col_else_node, prefix='n1_else_branch_n1_else_branch_')) + + return ModelWrapper(model_proto), expected_config + +def test_extract_model_config(): + """Test extraction of model config from models with and without subgraphs.""" + + model, expected_config = make_im2col_test_model() + + attrs_to_extract = ["kernel_size", "stride", "pad_amount", "input_shape"] + + extracted_config = extract_model_config(model, subgraph_hier=None, attr_names_to_extract=attrs_to_extract) + assert extracted_config == expected_config, "Extracted config does not match expected config" + + +def test_roundtrip_export_import(): + """Test config extraction and re-application preserves node attributes.""" + from qonnx.transformation.general import ApplyConfig + + model, expected_config = make_im2col_test_model() + attrs_to_extract = ["kernel_size", "stride", "pad_amount", "input_shape"] + + # Extract config from model + original_config = extract_model_config(model, subgraph_hier=None, attr_names_to_extract=attrs_to_extract) + + # Modify all Im2Col nodes to different values (recursively through subgraphs) + def modify_all_im2col_nodes(graph_proto): + for node in graph_proto.node: + if node.op_type == "Im2Col": + inst = getCustomOp(node) + inst.set_nodeattr("kernel_size", [11, 11]) + inst.set_nodeattr("stride", [5, 5]) + inst.set_nodeattr("pad_amount", [7, 7, 7, 7]) + inst.set_nodeattr("input_shape", "") # input_shape is a string attribute + for attr in node.attribute: + if attr.type == onnx.AttributeProto.GRAPH: + modify_all_im2col_nodes(attr.g) + + modify_all_im2col_nodes(model.graph) + + # Apply the original config via temp JSON file + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + config_with_defaults = original_config.copy() + config_with_defaults["Defaults"] = {} + json.dump(config_with_defaults, f, indent=2) + config_json_file = f.name + + try: + model = model.transform(ApplyConfig(config_json_file)) + + # Re-extract config and verify it matches original + restored_config = extract_model_config(model, subgraph_hier=None, attr_names_to_extract=attrs_to_extract) + assert restored_config == original_config, "Config not properly restored after roundtrip" + finally: + if os.path.exists(config_json_file): + os.remove(config_json_file) +