Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
ce1da1f
update extract folding config
Oct 31, 2025
a4b7f20
[Transform] Enable ApplyConfig to work on subgraph json
auphelia Nov 5, 2025
2f868af
[Transform] Ensure that subgraph_hier is not applied as a node attr i…
auphelia Nov 5, 2025
ad629dc
Merge branch 'main' of github.com:fastmachinelearning/qonnx into feat…
Nov 7, 2025
d607c27
Merge branch 'feature/config_extract_for_subgraphs' of github.com:fas…
Nov 7, 2025
cacd977
add tests for export config to json
Nov 8, 2025
1e0978a
have ai reduce test size.
Nov 8, 2025
c50e29b
adds check that subgraph_hier does not exist for top level nodes
Nov 8, 2025
c6a7365
remove the "Defaults" section
Nov 10, 2025
575962d
update calls to function call
Nov 10, 2025
6b8f715
ensure extra attributes aren't extracted
Nov 10, 2025
f4b5d6d
remove immediate node name from subgraph hierarchy.
Nov 10, 2025
de9747c
remove model.saves; fix variable name
Nov 10, 2025
4250e13
add round trip tests export->apply config
Nov 10, 2025
f80cc5c
update tests to look for hierarchy in node name rather than in n…
Nov 11, 2025
d1d17a3
tests passing now.
Nov 11, 2025
3e936e2
simplify extract config model
Nov 11, 2025
2d2ee89
simplify and remove unneeded code
Nov 11, 2025
7a5a3d2
ensure separate graphs for each branch of an if-statement.
Nov 11, 2025
d85bb82
convert tests to onnxscript rather than onnx proto
Nov 11, 2025
848bf07
reduce test size.
Nov 12, 2025
10ff868
move everything to onnxscript.
Nov 12, 2025
5ab65ef
simplify to just two models.
Nov 12, 2025
66002d9
consolidate roundtrip tests into one test
Nov 12, 2025
c67b326
further test consolidation
Nov 12, 2025
36e06f1
reduce complexity of nested model.
Nov 12, 2025
0e0b77d
readd one branch with deep hierarchy
Nov 12, 2025
049acab
reduce tests and include attribute name in hierarchy.
Nov 17, 2025
3239621
it's working
Nov 17, 2025
27e6753
update test attributes to be different
Nov 17, 2025
a94eb8f
[Util] Bring back default field for config extraction and run linting
auphelia Nov 18, 2025
7bf6349
[ApplyConfig] Check if Default field exists
auphelia Nov 18, 2025
c0da3ff
[Tests] Update config test with Defaults field
auphelia Nov 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 67 additions & 36 deletions src/qonnx/transformation/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@
import warnings

# Protobuf onnx graph node type
from onnx import NodeProto # noqa
from onnx import mapping
from onnx import AttributeProto, NodeProto, mapping # noqa
from toposort import toposort_flatten

import qonnx.util.basic as util
Expand Down Expand Up @@ -335,56 +334,88 @@ def __init__(self, config, node_filter=lambda x: True):
super().__init__()
self.config = config
self.node_filter = node_filter

def apply(self, model):
if isinstance(self.config, dict):
model_config = self.config
self.used_configurations = ["Defaults"]
self.missing_configurations = []

def configure_network(self, graph_proto, model_config, subgraph_hier):
# Configure network - graph_proto can be a GraphProto or ModelWrapper
# If it's a ModelWrapper, get the graph
if hasattr(graph_proto, "graph"):
graph = graph_proto.graph
else:
with open(self.config, "r") as f:
model_config = json.load(f)

used_configurations = ["Defaults"]
missing_configurations = []
graph = graph_proto

# Configure network
for node_idx, node in enumerate(model.graph.node):
for node in graph.node:
if not self.node_filter(node):
continue

# Build the config key by prepending hierarchy
config_key = node.name if subgraph_hier is None else str(subgraph_hier) + "_" + node.name

try:
node_config = model_config[node.name]
node_config = model_config[config_key].copy()
except KeyError:
missing_configurations += [node.name]
self.missing_configurations += [node.name]
node_config = {}

if node_config:
self.used_configurations += [config_key]

from qonnx.custom_op.registry import getCustomOp

try:
inst = getCustomOp(node)

if "Defaults" in model_config.keys():
# set specified defaults
default_values = []
for key, value in model_config["Defaults"].items():
assert len(value) % 2 == 0
if key not in model_config:
for val, op in zip(value[::2], value[1::2]):
default_values.append((key, val, op))
assert not (op == "all" and len(value) > 2)
default_configs = {key: val for key, val, op in default_values if op == "all" or node.op_type in op}
for attr_name, value in default_configs.items():
inst.set_nodeattr(attr_name, value)

# set node attributes from specified configuration
for attr_name, value in node_config.items():
inst.set_nodeattr(attr_name, value)
except Exception:
continue
used_configurations += [node.name]

# set specified defaults
default_values = []
for key, value in model_config["Defaults"].items():
assert len(value) % 2 == 0
if key not in model_config:
for val, op in zip(value[::2], value[1::2]):
default_values.append((key, val, op))
assert not (op == "all" and len(value) > 2)
default_configs = {key: val for key, val, op in default_values if op == "all" or node.op_type in op}
for attr, value in default_configs.items():
inst.set_nodeattr(attr, value)

# set node attributes from specified configuration
for attr, value in node_config.items():
inst.set_nodeattr(attr, value)
# Node is not a custom op, but it might have subgraphs
pass

# Recursively handle nested subgraphs
for attr in node.attribute:
if attr.type == AttributeProto.GRAPH:
# Build the subgraph hierarchy including the attribute name
if subgraph_hier is None:
new_hier = node.name
else:
new_hier = str(subgraph_hier) + "_" + node.name
# Include the subgraph attribute name in the hierarchy
new_hier = new_hier + "_" + attr.name
self.configure_network(attr.g, model_config, subgraph_hier=new_hier)

def apply(self, model):
if isinstance(self.config, dict):
model_config = self.config
else:
with open(self.config, "r") as f:
model_config = json.load(f)

# apply configuration on upper level
self.configure_network(model.model.graph, model_config, subgraph_hier=None)

# Configuration verification
if len(missing_configurations) > 0:
warnings.warn("\nNo HW configuration for nodes: " + ", ".join(missing_configurations))
# Remove duplicates from missing_configurations (can happen with shared subgraphs in If nodes)
unique_missing = list(dict.fromkeys(self.missing_configurations))
if len(unique_missing) > 0:
warnings.warn("\nNo HW configuration for nodes: " + ", ".join(unique_missing))

unused_configs = [x for x in model_config if x not in used_configurations]
# Check for unused configs (top-level configs that weren't applied)
unused_configs = [x for x in model_config if x not in self.used_configurations and x != "Defaults"]
if len(unused_configs) > 0:
warnings.warn("\nUnused HW configurations: " + ", ".join(unused_configs))

Expand Down
61 changes: 47 additions & 14 deletions src/qonnx/util/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,26 +27,59 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import json
import onnx

from qonnx.custom_op.registry import getCustomOp
from qonnx.custom_op.registry import getCustomOp, is_custom_op


def extract_model_config(model, subgraph_hier, attr_names_to_extract):
    """Create a dictionary with layer name -> attribute mappings extracted from the
    model. The created dictionary can be later applied on a model with
    qonnx.transform.general.ApplyConfig.

    Nodes in subgraphs are prefixed with their parent hierarchy using '_' as separator.
    For example, a node 'Conv_0' inside a subgraph of node 'IfNode_0' will be exported
    as 'IfNode_0_Conv_0' in the config.

    :param model: ModelWrapper (or subgraph ModelWrapper) to extract from.
    :param subgraph_hier: hierarchy prefix accumulated so far, or None at top level.
    :param attr_names_to_extract: iterable of attribute names to extract per node.
    :return: dict mapping (hierarchy-prefixed) node name -> {attr_name: value},
        always containing an (empty) "Defaults" entry."""

    cfg = dict()
    cfg["Defaults"] = dict()
    for n in model.graph.node:
        # Prefix the node name with the accumulated subgraph hierarchy (if any)
        new_hier = n.name if subgraph_hier is None else str(subgraph_hier) + "_" + n.name

        # Only custom ops expose get_nodeattr; plain ONNX nodes are skipped for
        # attribute extraction but may still carry subgraphs to recurse into
        is_custom = is_custom_op(n.domain, n.op_type)
        if is_custom:
            oi = getCustomOp(n)
        layer_dict = dict()

        # Process node attributes - handle both subgraphs and extractable attributes
        for attr in n.attribute:
            if attr.type == onnx.AttributeProto.GRAPH:
                # If the attribute is a graph, extract configs from the subgraph
                # recursively. The subgraph attribute name is included in the
                # hierarchy so e.g. the two branches of an If node stay distinct.
                subgraph_hier_with_attr = new_hier + "_" + attr.name
                cfg.update(
                    extract_model_config(
                        model.make_subgraph_modelwrapper(attr.g), subgraph_hier_with_attr, attr_names_to_extract
                    )
                )
            elif is_custom and attr.name in attr_names_to_extract:
                # For custom ops, extract the requested attribute
                layer_dict[attr.name] = oi.get_nodeattr(attr.name)

        # Add the node's config only if we extracted any attributes
        if is_custom and len(layer_dict) > 0:
            cfg[new_hier] = layer_dict

    return cfg


def extract_model_config_to_json(model, json_filename, attr_names_to_extract):
    """Create a json file with layer name -> attribute mappings extracted from the
    model. The created json file can be later applied on a model with
    qonnx.transform.general.ApplyConfig.

    Thin wrapper around extract_model_config; nodes in subgraphs are prefixed
    with their parent hierarchy (see extract_model_config)."""

    # Extract first, then open the file: if extraction raises, we avoid
    # creating/truncating json_filename with partial or empty content.
    cfg = extract_model_config(model, subgraph_hier=None, attr_names_to_extract=attr_names_to_extract)
    with open(json_filename, "w") as f:
        json.dump(cfg, f, indent=2)
167 changes: 167 additions & 0 deletions tests/util/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
# Copyright (c) 2025 Advanced Micro Devices, Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of QONNX nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import json
import os
import pytest
import tempfile

import onnx
import onnx.helper as helper
import numpy as np
from onnxscript import script, FLOAT, BOOL
from onnxscript import opset13 as op
from onnxscript.values import Opset

from qonnx.core.modelwrapper import ModelWrapper
from qonnx.custom_op.registry import getCustomOp
from qonnx.util.basic import qonnx_make_model
from qonnx.util.config import extract_model_config_to_json, extract_model_config
from typing import List, Dict, Any, Tuple
# this is a pretend opset so that we can create
# qonnx custom ops (domain "qonnx.custom_op.general") with onnxscript
qops = Opset("qonnx.custom_op.general", 1)

@script(default_opset=op)
def main_graph_fn(main_inp: FLOAT[1, 28, 28, 1], condition: BOOL, nested_condition: BOOL) -> FLOAT[1, 4, 4, 144]:
    """Main graph with nested if statement in else branch.

    onnxscript traces the Python ``if`` statements below into ONNX ``If``
    nodes, producing subgraphs two and three hierarchy levels deep.
    """
    im2col_0 = qops.Im2Col(main_inp, stride=[1, 1], kernel_size=[3, 3],
                           pad_amount=[1, 1, 1, 1], input_shape=[1, 28, 28, 1])

    # Python if statement → ONNX If node with subgraph
    # settings for Im2Col are meant to validate the extraction/application of attributes
    # and are not necessarily realistic or correct
    if condition:
        # Then branch: simple subgraph (2 levels)
        main_out = qops.Im2Col(im2col_0, stride=[2, 1], kernel_size=[5, 5],
                               pad_amount=[2, 2, 2, 2], input_shape=[1, 14, 14, 144])
    else:
        im2col_1 = qops.Im2Col(im2col_0, stride=[2, 1], kernel_size=[6, 6],
                               pad_amount=[3, 3, 3, 3], input_shape=[1, 14, 14, 145])
        # Else branch: nested if statement (3 levels)
        if nested_condition:
            main_out = qops.Im2Col(im2col_1, stride=[3, 1], kernel_size=[7, 7],
                                   pad_amount=[4, 4, 4, 4], input_shape=[1, 4, 4, 146])
        else:
            main_out = qops.Im2Col(im2col_1, stride=[3, 2], kernel_size=[8, 8],
                                   pad_amount=[5, 5, 5, 5], input_shape=[1, 4, 4, 147])

    return main_out


def build_expected_config_from_node(node: onnx.NodeProto, prefix = '') -> Dict[str, Any]:
    """Map the (optionally prefixed) node name to all of its attribute values,
    as reported by the node's custom-op wrapper."""
    wrapper = getCustomOp(node)
    extracted = {a.name: wrapper.get_nodeattr(a.name) for a in node.attribute}
    return {prefix + node.name: extracted}


def make_im2col_test_model():
    """Build the nested-If Im2Col test model together with the config dict that
    extraction is expected to produce for it."""

    proto = main_graph_fn.to_model_proto()

    top_node = proto.graph.node[0]
    if_node = proto.graph.node[1]
    then_graph = if_node.attribute[0].g
    else_graph = if_node.attribute[1].g
    nested_if = else_graph.node[1]

    # (node, expected hierarchy prefix) for every Im2Col across all levels
    node_prefix_pairs = [
        (top_node, ''),
        (then_graph.node[0], 'n1_then_branch_'),
        (else_graph.node[0], 'n1_else_branch_'),
        (nested_if.attribute[0].g.node[0], 'n1_else_branch_n1_then_branch_'),
        (nested_if.attribute[1].g.node[0], 'n1_else_branch_n1_else_branch_'),
    ]

    # this test assumes that all Im2Col nodes have the same name
    # to verify that node aliasing is handled correctly between nodes on
    # the same and different levels of the hierarchy
    for node, _ in node_prefix_pairs[1:]:
        assert node.name == top_node.name

    expected_config = {"Defaults": {}}
    for node, prefix in node_prefix_pairs:
        expected_config.update(build_expected_config_from_node(node, prefix=prefix))

    return ModelWrapper(proto), expected_config

def test_extract_model_config():
    """Config extraction on a model with nested subgraphs matches the
    hand-built expectation, including hierarchy-prefixed node names."""

    attrs = ["kernel_size", "stride", "pad_amount", "input_shape"]
    model, expected = make_im2col_test_model()

    result = extract_model_config(model, subgraph_hier=None, attr_names_to_extract=attrs)
    assert result == expected, "Extracted config does not match expected config"


def test_roundtrip_export_import():
    """Extract a config, scramble the model's attributes, re-apply the saved
    config via JSON, and verify the original attributes are restored."""
    from qonnx.transformation.general import ApplyConfig

    attrs = ["kernel_size", "stride", "pad_amount", "input_shape"]
    model, _ = make_im2col_test_model()

    # Snapshot the attribute values we intend to restore later
    original_config = extract_model_config(model, subgraph_hier=None, attr_names_to_extract=attrs)

    # Overwrite every Im2Col node (at any subgraph depth) with dummy values,
    # using an explicit worklist instead of recursion
    pending = [model.graph]
    while pending:
        graph = pending.pop()
        for node in graph.node:
            if node.op_type == "Im2Col":
                inst = getCustomOp(node)
                inst.set_nodeattr("kernel_size", [11, 11])
                inst.set_nodeattr("stride", [5, 5])
                inst.set_nodeattr("pad_amount", [7, 7, 7, 7])
                inst.set_nodeattr("input_shape", "")  # input_shape is a string attribute
            pending.extend(a.g for a in node.attribute if a.type == onnx.AttributeProto.GRAPH)

    # Persist the snapshot (with an explicit Defaults section) to a temp JSON file
    with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
        cfg = dict(original_config)
        cfg["Defaults"] = {}
        json.dump(cfg, f, indent=2)
        cfg_path = f.name

    try:
        model = model.transform(ApplyConfig(cfg_path))

        # Re-extract and compare against the pre-scramble snapshot
        restored = extract_model_config(model, subgraph_hier=None, attr_names_to_extract=attrs)
        assert restored == original_config, "Config not properly restored after roundtrip"
    finally:
        if os.path.exists(cfg_path):
            os.remove(cfg_path)

Loading