From d6d95c57a5aa160afcb5adcde2a6383385277c58 Mon Sep 17 00:00:00 2001 From: Thomas Keller Date: Sun, 15 Jun 2025 17:12:15 -0700 Subject: [PATCH 01/14] Fix channels_last transformation for new registry --- docs/overview.rst | 7 ++ setup.cfg | 4 ++ src/qonnx/__init__.py | 26 +++++++ src/qonnx/custom_op/channels_last/__init__.py | 6 +- .../channels_last/batch_normalization.py | 2 + src/qonnx/custom_op/channels_last/conv.py | 2 + src/qonnx/custom_op/channels_last/max_pool.py | 2 + src/qonnx/custom_op/general/__init__.py | 24 +++---- src/qonnx/custom_op/general/bipolar_quant.py | 2 + src/qonnx/custom_op/general/debugmarker.py | 2 + .../custom_op/general/genericpartition.py | 2 + src/qonnx/custom_op/general/im2col.py | 2 + src/qonnx/custom_op/general/maxpoolnhwc.py | 2 + src/qonnx/custom_op/general/multithreshold.py | 2 + src/qonnx/custom_op/general/quant.py | 2 + src/qonnx/custom_op/general/quantavgpool2d.py | 2 + src/qonnx/custom_op/general/trunc.py | 2 + src/qonnx/custom_op/general/xnorpopcount.py | 2 + src/qonnx/custom_op/registry.py | 70 ++++++++++++++++--- src/qonnx/transformation/channels_last.py | 3 +- tests/custom_op/test_attr.py | 5 +- 21 files changed, 142 insertions(+), 29 deletions(-) diff --git a/docs/overview.rst b/docs/overview.rst index 8e2002d7..935ef4d9 100644 --- a/docs/overview.rst +++ b/docs/overview.rst @@ -45,6 +45,13 @@ Custom Operations/Nodes QONNX uses many custom operations (op_type in ONNX NodeProto) that are not defined in the ONNX operator schema. These custom nodes are marked with domain="qonnx.*" in the protobuf to identify them as such. These nodes can represent specific operations that we need for low-bit networks, or operations that are specific to a particular hardware backend. To get more familiar with custom operations and how they are created, please take a look in the Jupyter notebook about CustomOps (see chapter :ref:`tutorials` for details) or directly in the module :py:mod:`qonnx.custom_op`. +Custom ops can be registered automatically via Python entry points using the +``qonnx_custom_ops`` group. Each operator class should be decorated with +``@register_op(domain="...", op_type="...")`` from +``qonnx.custom_op.registry``. Packages installed with such an entry point will +be discovered on import and their ops made available through +``getCustomOp``. + Custom ONNX Execution Flow ========================== diff --git a/setup.cfg b/setup.cfg index 9b71bb56..a3038f13 100644 --- a/setup.cfg +++ b/setup.cfg @@ -98,6 +98,10 @@ console_scripts = qonnx-tensor-stats = qonnx.analysis.tensor_stats:main pytest_randomly.random_seeder = qonnx = qonnx.util.random_reseed:reseed +# entry points for custom op modules +qonnx_custom_ops = + qonnx = qonnx.custom_op.general + qonnx_channels_last = qonnx.custom_op.channels_last # Add here console scripts like: # console_scripts = # script_name = qonnx.module:function diff --git a/src/qonnx/__init__.py b/src/qonnx/__init__.py index e69de29b..bb2c88d0 100644 --- a/src/qonnx/__init__.py +++ b/src/qonnx/__init__.py @@ -0,0 +1,26 @@ +"""QONNX package initialization.""" + +import warnings +from importlib import metadata + + +def _load_custom_op_entry_points(): + """Import modules registered under the ``qonnx_custom_ops`` entry point.""" + + try: + eps = metadata.entry_points() + if hasattr(eps, "select"): + eps = eps.select(group="qonnx_custom_ops") + else: + eps = eps.get("qonnx_custom_ops", []) + for ep in eps: + try: + ep.load() + except Exception as e: # pragma: no cover - import failure warning + warnings.warn(f"Failed to load custom op entry point {ep.name}: {e}") + except Exception as e: # pragma: no cover - metadata failure warning + warnings.warn(f"Failed to query custom op entry points: {e}") + + +_load_custom_op_entry_points() + diff --git a/src/qonnx/custom_op/channels_last/__init__.py b/src/qonnx/custom_op/channels_last/__init__.py index f1d7c39b..1b2ebe01 100644 --- a/src/qonnx/custom_op/channels_last/__init__.py +++ b/src/qonnx/custom_op/channels_last/__init__.py @@ -2,8 +2,4 @@ from qonnx.custom_op.channels_last.conv import Conv from qonnx.custom_op.channels_last.max_pool import MaxPool -custom_op = dict() - -custom_op["Conv"] = Conv -custom_op["MaxPool"] = MaxPool -custom_op["BatchNormalization"] = BatchNormalization +__all__ = ["Conv", "MaxPool", "BatchNormalization"] diff --git a/src/qonnx/custom_op/channels_last/batch_normalization.py b/src/qonnx/custom_op/channels_last/batch_normalization.py index f3b3f872..bd5d3b60 100644 --- a/src/qonnx/custom_op/channels_last/batch_normalization.py +++ b/src/qonnx/custom_op/channels_last/batch_normalization.py @@ -30,8 +30,10 @@ from onnx import TensorProto, helper from qonnx.custom_op.channels_last.base_wrapped_op import ChannelsLastWrappedOp +from qonnx.custom_op.registry import register_op +@register_op(domain="qonnx.custom_op.channels_last", op_type="BatchNormalization") class BatchNormalization(ChannelsLastWrappedOp): def get_nodeattr_types(self): """Returns a dict of permitted attributes for node, where: diff --git a/src/qonnx/custom_op/channels_last/conv.py b/src/qonnx/custom_op/channels_last/conv.py index b0ff237b..06a25508 100644 --- a/src/qonnx/custom_op/channels_last/conv.py +++ b/src/qonnx/custom_op/channels_last/conv.py @@ -31,8 +31,10 @@ from qonnx.custom_op.channels_last.base_wrapped_op import ChannelsLastWrappedOp from qonnx.custom_op.general.im2col import compute_conv_output_dim +from qonnx.custom_op.registry import register_op +@register_op(domain="qonnx.custom_op.channels_last", op_type="Conv") class Conv(ChannelsLastWrappedOp): def get_nodeattr_types(self): """Returns a dict of permitted attributes for node, where: diff --git a/src/qonnx/custom_op/channels_last/max_pool.py b/src/qonnx/custom_op/channels_last/max_pool.py index 383f3008..1bb9a1ce 100644 --- a/src/qonnx/custom_op/channels_last/max_pool.py +++ b/src/qonnx/custom_op/channels_last/max_pool.py @@ -30,9 +30,11 @@ from onnx import TensorProto, helper from qonnx.custom_op.channels_last.base_wrapped_op import ChannelsLastWrappedOp +from qonnx.custom_op.registry import register_op from qonnx.custom_op.general.maxpoolnhwc import compute_pool_output_dim +@register_op(domain="qonnx.custom_op.channels_last", op_type="MaxPool") class MaxPool(ChannelsLastWrappedOp): def get_nodeattr_types(self): """Returns a dict of permitted attributes for node, where: diff --git a/src/qonnx/custom_op/general/__init__.py b/src/qonnx/custom_op/general/__init__.py index a656d4a5..09b9380c 100644 --- a/src/qonnx/custom_op/general/__init__.py +++ b/src/qonnx/custom_op/general/__init__.py @@ -37,15 +37,15 @@ from qonnx.custom_op.general.trunc import Trunc from qonnx.custom_op.general.xnorpopcount import XnorPopcountMatMul -custom_op = dict() - -custom_op["DebugMarker"] = DebugMarker -custom_op["QuantAvgPool2d"] = QuantAvgPool2d -custom_op["MaxPoolNHWC"] = MaxPoolNHWC -custom_op["GenericPartition"] = GenericPartition -custom_op["MultiThreshold"] = MultiThreshold -custom_op["XnorPopcountMatMul"] = XnorPopcountMatMul -custom_op["Im2Col"] = Im2Col -custom_op["Quant"] = Quant -custom_op["Trunc"] = Trunc -custom_op["BipolarQuant"] = BipolarQuant +__all__ = [ + "DebugMarker", + "QuantAvgPool2d", + "MaxPoolNHWC", + "GenericPartition", + "MultiThreshold", + "XnorPopcountMatMul", + "Im2Col", + "Quant", + "Trunc", + "BipolarQuant", +] diff --git a/src/qonnx/custom_op/general/bipolar_quant.py b/src/qonnx/custom_op/general/bipolar_quant.py index 986a7082..e6a72486 100644 --- a/src/qonnx/custom_op/general/bipolar_quant.py +++ b/src/qonnx/custom_op/general/bipolar_quant.py @@ -31,6 +31,7 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp +from qonnx.custom_op.registry import register_op def binary_quant(inp_tensor, scale): @@ -47,6 +48,7 @@ def binary_quant(inp_tensor, scale): return out_tensor +@register_op(domain="qonnx.custom_op.general", op_type="BipolarQuant") class BipolarQuant(CustomOp): """Bipolar quantization operation for QONNX. Takes four inputs: - input tensor to quantize diff --git a/src/qonnx/custom_op/general/debugmarker.py b/src/qonnx/custom_op/general/debugmarker.py index ae8cbce5..15e88d8e 100644 --- a/src/qonnx/custom_op/general/debugmarker.py +++ b/src/qonnx/custom_op/general/debugmarker.py @@ -29,8 +29,10 @@ from onnx import helper from qonnx.custom_op.base import CustomOp +from qonnx.custom_op.registry import register_op +@register_op(domain="qonnx.custom_op.general", op_type="DebugMarker") class DebugMarker(CustomOp): def get_nodeattr_types(self): return {"export_debug_name": ("s", True, "")} diff --git a/src/qonnx/custom_op/general/genericpartition.py b/src/qonnx/custom_op/general/genericpartition.py index 841e4e9b..0f6fa104 100755 --- a/src/qonnx/custom_op/general/genericpartition.py +++ b/src/qonnx/custom_op/general/genericpartition.py @@ -29,8 +29,10 @@ from qonnx.core.modelwrapper import ModelWrapper from qonnx.core.onnx_exec import execute_onnx from qonnx.custom_op.base import CustomOp +from qonnx.custom_op.registry import register_op +@register_op(domain="qonnx.custom_op.general", op_type="GenericPartition") class GenericPartition(CustomOp): """Class that corresponds to the meta/container node GenericPartition which is a placeholder for a group of nodes that have been separated diff --git a/src/qonnx/custom_op/general/im2col.py b/src/qonnx/custom_op/general/im2col.py index 42477832..22b08a5a 100644 --- a/src/qonnx/custom_op/general/im2col.py +++ b/src/qonnx/custom_op/general/im2col.py @@ -31,6 +31,7 @@ import qonnx.util.basic as util from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp +from qonnx.custom_op.registry import register_op # adapted from A. Karpathy's CS231 im2col code # utilities to generate a patch matrix from a multichannel image @@ -140,6 +141,7 @@ def im2col_indices_nchw( # oh/ow and kh/kw will also be 1 in this case +@register_op(domain="qonnx.custom_op.general", op_type="Im2Col") class Im2Col(CustomOp): def get_nodeattr_types(self): return { diff --git a/src/qonnx/custom_op/general/maxpoolnhwc.py b/src/qonnx/custom_op/general/maxpoolnhwc.py index eb964fc4..81a6c4cb 100644 --- a/src/qonnx/custom_op/general/maxpoolnhwc.py +++ b/src/qonnx/custom_op/general/maxpoolnhwc.py @@ -33,6 +33,7 @@ from qonnx.core.modelwrapper import ModelWrapper from qonnx.custom_op.base import CustomOp +from qonnx.custom_op.registry import register_op from qonnx.util.basic import qonnx_make_model @@ -44,6 +45,7 @@ def compute_pool_output_dim(ifm_dim, k, stride, pad=0, ceil_mode=0): return int(np.floor(((ifm_dim + 2 * pad - k) / stride) + 1)) +@register_op(domain="qonnx.custom_op.general", op_type="MaxPoolNHWC") class MaxPoolNHWC(CustomOp): # a MaxPool node, but using the NHWC data layout diff --git a/src/qonnx/custom_op/general/multithreshold.py b/src/qonnx/custom_op/general/multithreshold.py index 6df58f95..0a5ec596 100644 --- a/src/qonnx/custom_op/general/multithreshold.py +++ b/src/qonnx/custom_op/general/multithreshold.py @@ -31,6 +31,7 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp +from qonnx.custom_op.registry import register_op def multithreshold(v, thresholds, out_scale=None, out_bias=None): @@ -84,6 +85,7 @@ def multithreshold(v, thresholds, out_scale=None, out_bias=None): return out_scale * ret.reshape(v.shape) + out_bias +@register_op(domain="qonnx.custom_op.general", op_type="MultiThreshold") class MultiThreshold(CustomOp): """Class that corresponds to a multithresholding node.""" diff --git a/src/qonnx/custom_op/general/quant.py b/src/qonnx/custom_op/general/quant.py index f81495d2..39c9f0f4 100644 --- a/src/qonnx/custom_op/general/quant.py +++ b/src/qonnx/custom_op/general/quant.py @@ -31,6 +31,7 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp +from qonnx.custom_op.registry import register_op def min_int(signed: bool, narrow_range: bool, bit_width: int) -> int: @@ -165,6 +166,7 @@ def round_half_down(x): raise ValueError(f"Could not resolve rounding mode called: {normalized_mode_string}") +@register_op(domain="qonnx.custom_op.general", op_type="Quant") class Quant(CustomOp): """Generic quantization operation for QONNX. Takes four inputs: - input tensor to quantize diff --git a/src/qonnx/custom_op/general/quantavgpool2d.py b/src/qonnx/custom_op/general/quantavgpool2d.py index c0e24071..344d999b 100644 --- a/src/qonnx/custom_op/general/quantavgpool2d.py +++ b/src/qonnx/custom_op/general/quantavgpool2d.py @@ -32,10 +32,12 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp +from qonnx.custom_op.registry import register_op from qonnx.custom_op.general.maxpoolnhwc import compute_pool_output_dim from qonnx.util.basic import qonnx_make_model +@register_op(domain="qonnx.custom_op.general", op_type="QuantAvgPool2d") class QuantAvgPool2d(CustomOp): """CustomOp that corresponds to the quantized average pooling layer from Brevitas""" diff --git a/src/qonnx/custom_op/general/trunc.py b/src/qonnx/custom_op/general/trunc.py index 8e2eaa19..ca2310b0 100644 --- a/src/qonnx/custom_op/general/trunc.py +++ b/src/qonnx/custom_op/general/trunc.py @@ -31,6 +31,7 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp +from qonnx.custom_op.registry import register_op from qonnx.custom_op.general.quant import resolve_rounding_mode @@ -58,6 +59,7 @@ def trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding return y +@register_op(domain="qonnx.custom_op.general", op_type="Trunc") class Trunc(CustomOp): """Generic truncation operation for QONNX. Takes four inputs: - input tensor to truncate diff --git a/src/qonnx/custom_op/general/xnorpopcount.py b/src/qonnx/custom_op/general/xnorpopcount.py index 9a640599..a91d412b 100644 --- a/src/qonnx/custom_op/general/xnorpopcount.py +++ b/src/qonnx/custom_op/general/xnorpopcount.py @@ -31,6 +31,7 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp +from qonnx.custom_op.registry import register_op def xnorpopcountmatmul(inp0, inp1): @@ -60,6 +61,7 @@ def xnorpopcountmatmul(inp0, inp1): return (out + K) * 0.5 +@register_op(domain="qonnx.custom_op.general", op_type="XnorPopcountMatMul") class XnorPopcountMatMul(CustomOp): """Class that corresponds to a XNOR-popcount matrix multiplication node.""" diff --git a/src/qonnx/custom_op/registry.py b/src/qonnx/custom_op/registry.py index 3540bb5a..a3403918 100644 --- a/src/qonnx/custom_op/registry.py +++ b/src/qonnx/custom_op/registry.py @@ -27,24 +27,78 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import importlib +import warnings +from importlib import metadata from qonnx.util.basic import get_preferred_onnx_opset +# global registry mapping (domain, op_type) -> CustomOp subclass +CUSTOM_OP_REGISTRY = {} + + +def register_op(domain, op_type): + """Decorator for registering CustomOp classes.""" + + def decorator(cls): + CUSTOM_OP_REGISTRY[(domain, op_type)] = cls + return cls + + return decorator + + +def _load_entry_points(): + """Load custom op modules registered via entry points.""" + + try: + eps = metadata.entry_points() + # compatibility between Python versions + if hasattr(eps, "select"): + eps = eps.select(group="qonnx_custom_ops") + else: + eps = eps.get("qonnx_custom_ops", []) + for ep in eps: + try: + ep.load() + except Exception as e: # pragma: no cover - import failure warning + warnings.warn(f"Failed to load custom op entry point {ep.name}: {e}") + except Exception as e: # pragma: no cover - metadata failure warning + warnings.warn(f"Failed to query custom op entry points: {e}") + + +# load entry points on module import +_load_entry_points() + def getCustomOp(node, onnx_opset_version=get_preferred_onnx_opset(), brevitas_exception=True): - "Return a QONNX CustomOp instance for the given ONNX node, if it exists." + """Return a QONNX CustomOp instance for the given ONNX node, if it exists.""" + op_type = node.op_type domain = node.domain if brevitas_exception: # transparently resolve Brevitas domain ops to qonnx ones domain = domain.replace("onnx.brevitas", "qonnx.custom_op.general") + + key = (domain, op_type) + cls = CUSTOM_OP_REGISTRY.get(key) + if cls is not None: + return cls(node, onnx_opset_version=onnx_opset_version) + try: opset_module = importlib.import_module(domain) - assert type(opset_module.custom_op) is dict, "custom_op dict not found in Python module %s" % domain - inst_wrapper = opset_module.custom_op[op_type] - inst = inst_wrapper(node, onnx_opset_version=onnx_opset_version) - return inst except ModuleNotFoundError: - raise Exception("Could not load custom opset %s, check your PYTHONPATH" % domain) - except KeyError: - raise Exception("Op %s not found in custom opset %s" % (op_type, domain)) + raise Exception(f"Could not load custom opset {domain}, check your PYTHONPATH") + + # op may have registered itself on import + cls = CUSTOM_OP_REGISTRY.get(key) + if cls is not None: + return cls(node, onnx_opset_version=onnx_opset_version) + + # fallback to legacy custom_op dictionary + if hasattr(opset_module, "custom_op") and isinstance(opset_module.custom_op, dict): + try: + inst_wrapper = opset_module.custom_op[op_type] + return inst_wrapper(node, onnx_opset_version=onnx_opset_version) + except KeyError: + pass + + raise Exception(f"Op {op_type} not found in custom opset {domain}") diff --git a/src/qonnx/transformation/channels_last.py b/src/qonnx/transformation/channels_last.py index 175af058..5d585d0c 100644 --- a/src/qonnx/transformation/channels_last.py +++ b/src/qonnx/transformation/channels_last.py @@ -44,7 +44,8 @@ from qonnx.util.onnx import is_eltwise_optype # Standard ONNX nodes which require a ChannelsLast data format to function properly -_channelsLast_node_types = list(channels_last.custom_op.keys()) +# use the list of exported op names from the channels_last package +_channelsLast_node_types = list(channels_last.__all__) # Nodes, which do not modify the shape of the tensor # And modify all values in the same way. diff --git a/tests/custom_op/test_attr.py b/tests/custom_op/test_attr.py index cde5a321..4ea18d30 100644 --- a/tests/custom_op/test_attr.py +++ b/tests/custom_op/test_attr.py @@ -29,12 +29,12 @@ import numpy as np import onnx.parser as oprs -import qonnx.custom_op.general as general from qonnx.core.modelwrapper import ModelWrapper from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import getCustomOp +from qonnx.custom_op.registry import getCustomOp, register_op +@register_op(domain="qonnx.custom_op.general", op_type="AttrTestOp") class AttrTestOp(CustomOp): def get_nodeattr_types(self): my_attrs = {"tensor_attr": ("t", True, np.asarray([])), "strings_attr": ("strings", True, [""])} @@ -60,7 +60,6 @@ def verify_node(self): def test_attr(): - general.custom_op["AttrTestOp"] = AttrTestOp ishp = (1, 10) wshp = (1, 3) oshp = wshp From 858cf562508f8e70d0600dcab59a00823e905609 Mon Sep 17 00:00:00 2001 From: Thomas Keller Date: Sun, 15 Jun 2025 18:17:15 -0700 Subject: [PATCH 02/14] Add legacy domain fallback test --- tests/custom_op/legacy_custom_op.py | 21 ++++++++++++++++++ tests/custom_op/test_old_domain.py | 33 +++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+) create mode 100644 tests/custom_op/legacy_custom_op.py create mode 100644 tests/custom_op/test_old_domain.py diff --git a/tests/custom_op/legacy_custom_op.py b/tests/custom_op/legacy_custom_op.py new file mode 100644 index 00000000..95a1302b --- /dev/null +++ b/tests/custom_op/legacy_custom_op.py @@ -0,0 +1,21 @@ +from qonnx.custom_op.base import CustomOp +from qonnx.custom_op.registry import register_op + +@register_op(domain="legacy_custom_op", op_type="LegacyAdd") +class LegacyAdd(CustomOp): + def get_nodeattr_types(self): + return {} + + def make_shape_compatible_op(self, model): + return super().make_const_shape_op([1]) + + def infer_node_datatype(self, model): + pass + + def execute_node(self, context, graph): + a = context[self.onnx_node.input[0]] + b = context[self.onnx_node.input[1]] + context[self.onnx_node.output[0]] = a + b + + def verify_node(self): + pass diff --git a/tests/custom_op/test_old_domain.py b/tests/custom_op/test_old_domain.py new file mode 100644 index 00000000..eb3479e9 --- /dev/null +++ b/tests/custom_op/test_old_domain.py @@ -0,0 +1,33 @@ +import sys +from onnx import helper, TensorProto + +from qonnx.core.modelwrapper import ModelWrapper +from qonnx.custom_op.registry import getCustomOp +from qonnx.util.basic import qonnx_make_model + + +def test_get_custom_op_old_domain(): + print('sys.path0', sys.path[0]) + assert "legacy_custom_op" not in sys.modules + + node = helper.make_node( + "LegacyAdd", + ["a", "b"], + ["c"], + domain="legacy_custom_op", + ) + + graph = helper.make_graph( + [node], + "legacy_graph", + inputs=[ + helper.make_tensor_value_info("a", TensorProto.FLOAT, [1]), + helper.make_tensor_value_info("b", TensorProto.FLOAT, [1]), + ], + outputs=[helper.make_tensor_value_info("c", TensorProto.FLOAT, [1])], + ) + model = qonnx_make_model(graph, producer_name="legacy-test") + model = ModelWrapper(model) + + inst = getCustomOp(model.graph.node[0]) + assert inst.__class__.__name__ == "LegacyAdd" From 5036a7af98c63c1e02e557a474629ff63f530d35 Mon Sep 17 00:00:00 2001 From: Thomas Keller Date: Sun, 15 Jun 2025 18:54:16 -0700 Subject: [PATCH 03/14] Remove debug output from old domain test --- tests/custom_op/test_old_domain.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/custom_op/test_old_domain.py b/tests/custom_op/test_old_domain.py index eb3479e9..5f9ec57d 100644 --- a/tests/custom_op/test_old_domain.py +++ b/tests/custom_op/test_old_domain.py @@ -7,7 +7,6 @@ def test_get_custom_op_old_domain(): - print('sys.path0', sys.path[0]) assert "legacy_custom_op" not in sys.modules node = helper.make_node( From e59e558c75d154e7eafdb79671373032a0c17106 Mon Sep 17 00:00:00 2001 From: tafk7 Date: Fri, 20 Jun 2025 17:54:52 +0000 Subject: [PATCH 04/14] Added passthrough Quant class --- src/qonnx/custom_op/general/__init__.py | 1 + src/qonnx/custom_op/general/quant.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/qonnx/custom_op/general/__init__.py b/src/qonnx/custom_op/general/__init__.py index 0e5d9f53..c2bb7a82 100644 --- a/src/qonnx/custom_op/general/__init__.py +++ b/src/qonnx/custom_op/general/__init__.py @@ -35,6 +35,7 @@ from qonnx.custom_op.general.maxpoolnhwc import MaxPoolNHWC from qonnx.custom_op.general.multithreshold import MultiThreshold from qonnx.custom_op.general.quantavgpool2d import QuantAvgPool2d +from qonnx.custom_op.general.quant import Quant from qonnx.custom_op.general.trunc import Trunc from qonnx.custom_op.general.xnorpopcount import XnorPopcountMatMul diff --git a/src/qonnx/custom_op/general/quant.py b/src/qonnx/custom_op/general/quant.py index 3d448dc3..d204bce6 100644 --- a/src/qonnx/custom_op/general/quant.py +++ b/src/qonnx/custom_op/general/quant.py @@ -26,11 +26,18 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from qonnx.custom_op.general.intquant import IntQuant as Quant +from qonnx.custom_op.general.intquant import IntQuant from qonnx.custom_op.general.intquant import int_quant as quant from qonnx.custom_op.general.intquant import max_int, min_int, resolve_rounding_mode +from qonnx.custom_op.registry import register_op -Quant = Quant +# Create alias and register it separately for "Quant" op_type +@register_op(domain="qonnx.custom_op.general", op_type="Quant") +class Quant(IntQuant): + """Alias for IntQuant to support legacy \"Quant\" op_type.""" + pass + +# Re-export functions quant = quant max_int = max_int min_int = min_int From 30df133a2c11a2660d4b23bd8d42a05672566c97 Mon Sep 17 00:00:00 2001 From: auphelia Date: Mon, 23 Jun 2025 10:06:54 +0100 Subject: [PATCH 05/14] Bring back lost changes from custom/brainsmith branch --- setup.cfg | 2 +- .../transformation/extract_quant_scale_zeropt.py | 8 ++++++++ src/qonnx/transformation/gemm_to_matmul.py | 13 ++++++++++++- src/qonnx/util/basic.py | 2 +- 4 files changed, 22 insertions(+), 3 deletions(-) diff --git a/setup.cfg b/setup.cfg index a8b8f915..2fc28b09 100644 --- a/setup.cfg +++ b/setup.cfg @@ -47,7 +47,7 @@ install_requires = importlib-metadata attrs>=22.2.0 clize>=5.0.1 - protobuf==3.20.3 + protobuf>=3.20.3 bitstring>=3.1.7 numpy>=1.24.1 onnx>=1.13.0 diff --git a/src/qonnx/transformation/extract_quant_scale_zeropt.py b/src/qonnx/transformation/extract_quant_scale_zeropt.py index 58863f08..614df416 100644 --- a/src/qonnx/transformation/extract_quant_scale_zeropt.py +++ b/src/qonnx/transformation/extract_quant_scale_zeropt.py @@ -69,6 +69,8 @@ def apply(self, model: ModelWrapper): ) graph.value_info.append(inp_scaled) inp_scale_node = helper.make_node("Div", [running_input, scale_nm], [inp_scaled_nm]) + if hasattr(node, "metadata_props"): + inp_scale_node.metadata_props.extend(node.metadata_props) graph.node.append(inp_scale_node) # create new Mul node # remove scale from Quant node @@ -87,6 +89,8 @@ def apply(self, model: ModelWrapper): ) graph.value_info.append(inp_zeropt) inp_zeropt_node = helper.make_node("Add", [running_input, zeropt_nm], [inp_zeropt_nm]) + if hasattr(node, "metadata_props"): + inp_zeropt_node.metadata_props.extend(node.metadata_props) graph.node.append(inp_zeropt_node) # remove zeropt from Quant node new_zeropt_nm = model.make_new_valueinfo_name() @@ -108,6 +112,8 @@ def apply(self, model: ModelWrapper): ) graph.value_info.append(out_zeropt) out_zeropt_node = helper.make_node("Sub", [out_zeropt_nm, zeropt_nm], [final_output]) + if hasattr(node, "metadata_props"): + out_zeropt_node.metadata_props.extend(node.metadata_props) last_node.output[0] = out_zeropt_nm graph.node.append(out_zeropt_node) # important: when tracking a pointer to newly added nodes, @@ -127,6 +133,8 @@ def apply(self, model: ModelWrapper): last_node.output[0] = out_scale_nm graph.value_info.append(out_scale) out_scale_node = helper.make_node("Mul", [out_scale_nm, scale_nm], [final_output]) + if hasattr(node, "metadata_props"): + out_scale_node.metadata_props.extend(node.metadata_props) graph.node.append(out_scale_node) if extract_scale or extract_zeropt: diff --git a/src/qonnx/transformation/gemm_to_matmul.py b/src/qonnx/transformation/gemm_to_matmul.py index 5396a7d6..1298f3d6 100644 --- a/src/qonnx/transformation/gemm_to_matmul.py +++ b/src/qonnx/transformation/gemm_to_matmul.py @@ -76,6 +76,8 @@ def apply(self, model): ) graph.value_info.append(inp_trans_out) inp_trans_node = helper.make_node("Transpose", [n.input[0]], [inp_trans_out.name]) + if hasattr(n, "metadata_props"): + inp_trans_node.metadata_props.extend(n.metadata_props) graph.node.insert(running_node_index, inp_trans_node) running_node_index += 1 dt = model.get_tensor_datatype(n.input[0]) @@ -98,6 +100,8 @@ def apply(self, model): ) graph.value_info.append(inp_trans_out) inp_trans_node = helper.make_node("Transpose", [n.input[1]], [inp_trans_out.name]) + if hasattr(n, "metadata_props"): + inp_trans_node.metadata_props.extend(n.metadata_props) graph.node.insert(running_node_index, inp_trans_node) running_node_index += 1 # Copy over the datatype @@ -109,6 +113,8 @@ def apply(self, model): # Insert MatMul: A * B matMul_node = helper.make_node("MatMul", [n.input[0], n.input[1]], [n.output[0]]) + if hasattr(n, "metadata_props"): + matMul_node.metadata_props.extend(n.metadata_props) graph.node.insert(running_node_index, matMul_node) matMul_node = graph.node[running_node_index] running_node_index += 1 @@ -144,6 +150,8 @@ def apply(self, model): [act_mul_tensor.name, mul_tensor.name], [n.output[0]], ) + if hasattr(n, "metadata_props"): + mul_node.metadata_props.extend(n.metadata_props) graph.node.insert(running_node_index, mul_node) mul_node_main_branch = graph.node[running_node_index] running_node_index += 1 @@ -175,6 +183,8 @@ def apply(self, model): [n.input[2], mul_tensor.name], [act_mul_tensor.name], ) + if hasattr(n, "metadata_props"): + mul_node.metadata_props.extend(n.metadata_props) graph.node.insert(running_node_index, mul_node) running_node_index += 1 dt = model.get_tensor_datatype(n.input[2]) @@ -196,7 +206,8 @@ def apply(self, model): [act_add_tensor.name, n.input[2]], [n.output[0]], ) - + if hasattr(n, "metadata_props"): + add_node.metadata_props.extend(n.metadata_props) graph.node.insert(running_node_index, add_node) running_node_index += 1 diff --git a/src/qonnx/util/basic.py b/src/qonnx/util/basic.py index 303331c5..a7ee197b 100644 --- a/src/qonnx/util/basic.py +++ b/src/qonnx/util/basic.py @@ -64,7 +64,7 @@ def qonnx_make_model(graph_proto, **kwargs): def is_finn_op(op_type): "Return whether given op_type string is a QONNX or FINN custom op" - return op_type.startswith("finn") or op_type.startswith("qonnx.custom_op") or op_type.startswith("onnx.brevitas") + return op_type.startswith("finn") or op_type.startswith("qonnx.custom_op") or op_type.startswith("onnx.brevitas") or op_type.startswith("brainsmith") def get_num_default_workers(): From dad06c748d933b2807cb7b7f1077b2ced1abcc8d Mon Sep 17 00:00:00 2001 From: Thomas Keller Date: Tue, 8 Jul 2025 00:09:15 +0000 Subject: [PATCH 06/14] Refined domain-based registration --- src/qonnx/custom_op/__init__.py | 35 +++++++++++++ src/qonnx/custom_op/registry.py | 90 ++++++++++++++++++++++++--------- src/qonnx/util/basic.py | 3 +- 3 files changed, 103 insertions(+), 25 deletions(-) diff --git a/src/qonnx/custom_op/__init__.py b/src/qonnx/custom_op/__init__.py index e69de29b..95378696 100644 --- a/src/qonnx/custom_op/__init__.py +++ b/src/qonnx/custom_op/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) 2020 Xilinx, Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of Xilinx nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from qonnx.custom_op.registry import register_custom_domain + +# Pre-register known custom op domains +register_custom_domain("qonnx.custom_op") +register_custom_domain("finn") +register_custom_domain("brainsmith") +register_custom_domain("onnx.brevitas") \ No newline at end of file diff --git a/src/qonnx/custom_op/registry.py b/src/qonnx/custom_op/registry.py index a3403918..f7ad3ad8 100644 --- a/src/qonnx/custom_op/registry.py +++ b/src/qonnx/custom_op/registry.py @@ -35,38 +35,80 @@ # global registry mapping (domain, op_type) -> CustomOp subclass CUSTOM_OP_REGISTRY = {} - -def register_op(domain, op_type): - """Decorator for registering CustomOp classes.""" +# global registry for custom op domains +_CUSTOM_DOMAINS = set() + +# global registry for custom op metadata +_OP_METADATA = {} + + +def register_custom_domain(domain): + """Register a domain as containing custom ops.""" + _CUSTOM_DOMAINS.add(domain) + + +def is_custom_op_domain(domain): + """Check if domain is registered for custom ops.""" + return any(domain.startswith(d) for d in _CUSTOM_DOMAINS) + + +def hasCustomOp(domain, op_type): + """Check if a custom op exists without creating an instance. + + Args: + domain: The domain of the custom op + op_type: The op_type of the custom op + + Returns: + bool: True if the op is registered, False otherwise + """ + return (domain, op_type) in CUSTOM_OP_REGISTRY + + +def get_ops_in_domain(domain): + """Get all registered ops in a domain. + + Args: + domain: The domain to query + + Returns: + List[Tuple[str, Type[CustomOp]]]: List of (op_type, class) tuples + """ + return [(op_type, cls) for (d, op_type), cls in CUSTOM_OP_REGISTRY.items() + if d == domain] + + +def register_op(domain, op_type, metadata=None): + """Decorator for registering CustomOp classes. + + Args: + domain: The domain for the custom op + op_type: The op_type for the custom op + metadata: Optional dict of metadata about the op (backend, version, etc.) + """ def decorator(cls): + # Auto-register the domain when an op is registered + register_custom_domain(domain) CUSTOM_OP_REGISTRY[(domain, op_type)] = cls + if metadata is not None: + _OP_METADATA[(domain, op_type)] = metadata return cls return decorator -def _load_entry_points(): - """Load custom op modules registered via entry points.""" - - try: - eps = metadata.entry_points() - # compatibility between Python versions - if hasattr(eps, "select"): - eps = eps.select(group="qonnx_custom_ops") - else: - eps = eps.get("qonnx_custom_ops", []) - for ep in eps: - try: - ep.load() - except Exception as e: # pragma: no cover - import failure warning - warnings.warn(f"Failed to load custom op entry point {ep.name}: {e}") - except Exception as e: # pragma: no cover - metadata failure warning - warnings.warn(f"Failed to query custom op entry points: {e}") - - -# load entry points on module import -_load_entry_points() +def get_op_metadata(domain, op_type): + """Get metadata for a registered custom op. + + Args: + domain: The domain of the custom op + op_type: The op_type of the custom op + + Returns: + dict: The metadata dict if available, None otherwise + """ + return _OP_METADATA.get((domain, op_type)) def getCustomOp(node, onnx_opset_version=get_preferred_onnx_opset(), brevitas_exception=True): diff --git a/src/qonnx/util/basic.py b/src/qonnx/util/basic.py index a7ee197b..abc5449f 100644 --- a/src/qonnx/util/basic.py +++ b/src/qonnx/util/basic.py @@ -64,7 +64,8 @@ def qonnx_make_model(graph_proto, **kwargs): def is_finn_op(op_type): "Return whether given op_type string is a QONNX or FINN custom op" - return op_type.startswith("finn") or op_type.startswith("qonnx.custom_op") or op_type.startswith("onnx.brevitas") or op_type.startswith("brainsmith") + from qonnx.custom_op.registry import is_custom_op_domain + return is_custom_op_domain(op_type) def get_num_default_workers(): From f6806f6186955e7665b2edbe603cc9d1e2d9e28a Mon Sep 17 00:00:00 2001 From: Thomas Keller Date: Wed, 16 Jul 2025 05:25:28 +0000 Subject: [PATCH 07/14] Refined custom_op registration --- docs/overview.rst | 6 +- src/qonnx/custom_op/__init__.py | 17 ++- .../channels_last/batch_normalization.py | 4 +- src/qonnx/custom_op/channels_last/conv.py | 4 +- src/qonnx/custom_op/channels_last/max_pool.py | 4 +- src/qonnx/custom_op/general/bipolar_quant.py | 4 +- src/qonnx/custom_op/general/debugmarker.py | 4 +- src/qonnx/custom_op/general/floatquant.py | 4 +- .../custom_op/general/genericpartition.py | 4 +- src/qonnx/custom_op/general/im2col.py | 4 +- src/qonnx/custom_op/general/intquant.py | 4 +- src/qonnx/custom_op/general/maxpoolnhwc.py | 4 +- src/qonnx/custom_op/general/multithreshold.py | 4 +- src/qonnx/custom_op/general/quant.py | 4 +- src/qonnx/custom_op/general/quantavgpool2d.py | 4 +- src/qonnx/custom_op/general/trunc.py | 4 +- src/qonnx/custom_op/general/xnorpopcount.py | 4 +- src/qonnx/custom_op/registry.py | 129 ++++++++++++++---- src/qonnx/util/basic.py | 6 +- tests/custom_op/legacy_custom_op.py | 22 --- tests/custom_op/test_attr.py | 4 +- tests/custom_op/test_old_domain.py | 32 ----- 22 files changed, 148 insertions(+), 128 deletions(-) delete mode 100644 tests/custom_op/legacy_custom_op.py delete mode 100644 tests/custom_op/test_old_domain.py diff --git a/docs/overview.rst b/docs/overview.rst index 935ef4d9..5fd87de9 100644 --- a/docs/overview.rst +++ b/docs/overview.rst @@ -47,9 +47,9 @@ QONNX uses many custom operations (op_type in ONNX NodeProto) that are not defin Custom ops can be registered automatically via Python entry points using the ``qonnx_custom_ops`` group. Each operator class should be decorated with -``@register_op(domain="...", op_type="...")`` from -``qonnx.custom_op.registry``. Packages installed with such an entry point will -be discovered on import and their ops made available through +``@register_custom_op`` from ``qonnx.custom_op.registry``, which automatically +infers the domain from the module path. Packages installed with such an entry +point will be discovered on import and their ops made available through ``getCustomOp``. diff --git a/src/qonnx/custom_op/__init__.py b/src/qonnx/custom_op/__init__.py index 95378696..be4f2162 100644 --- a/src/qonnx/custom_op/__init__.py +++ b/src/qonnx/custom_op/__init__.py @@ -26,10 +26,15 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from qonnx.custom_op.registry import register_custom_domain +from qonnx.custom_op.registry import register_domain -# Pre-register known custom op domains -register_custom_domain("qonnx.custom_op") -register_custom_domain("finn") -register_custom_domain("brainsmith") -register_custom_domain("onnx.brevitas") \ No newline at end of file +# Register QONNX domains (module path defaults to domain name) +register_domain("qonnx.custom_op.general") +register_domain("qonnx.custom_op.channels_last") + +# Register parent domain for hierarchy checking +register_domain("qonnx.custom_op") + +# Special case: Brevitas compatibility domain +# (QONNX handles Brevitas ops for backward compatibility) +register_domain("onnx.brevitas") \ No newline at end of file diff --git a/src/qonnx/custom_op/channels_last/batch_normalization.py b/src/qonnx/custom_op/channels_last/batch_normalization.py index bd5d3b60..b97ab1a8 100644 --- a/src/qonnx/custom_op/channels_last/batch_normalization.py +++ b/src/qonnx/custom_op/channels_last/batch_normalization.py @@ -30,10 +30,10 @@ from onnx import TensorProto, helper from qonnx.custom_op.channels_last.base_wrapped_op import ChannelsLastWrappedOp -from qonnx.custom_op.registry import register_op +from qonnx.custom_op.registry import register_custom_op -@register_op(domain="qonnx.custom_op.channels_last", op_type="BatchNormalization") +@register_custom_op class BatchNormalization(ChannelsLastWrappedOp): def get_nodeattr_types(self): """Returns a dict of permitted attributes for node, where: diff --git a/src/qonnx/custom_op/channels_last/conv.py b/src/qonnx/custom_op/channels_last/conv.py index 06a25508..dc78a0fd 100644 --- a/src/qonnx/custom_op/channels_last/conv.py +++ b/src/qonnx/custom_op/channels_last/conv.py @@ -31,10 +31,10 @@ from qonnx.custom_op.channels_last.base_wrapped_op import ChannelsLastWrappedOp from qonnx.custom_op.general.im2col import compute_conv_output_dim -from qonnx.custom_op.registry import register_op +from qonnx.custom_op.registry import register_custom_op -@register_op(domain="qonnx.custom_op.channels_last", op_type="Conv") +@register_custom_op class Conv(ChannelsLastWrappedOp): def get_nodeattr_types(self): """Returns a dict of permitted attributes for node, where: diff --git a/src/qonnx/custom_op/channels_last/max_pool.py b/src/qonnx/custom_op/channels_last/max_pool.py index aec2c908..53a7b617 100644 --- a/src/qonnx/custom_op/channels_last/max_pool.py +++ b/src/qonnx/custom_op/channels_last/max_pool.py @@ -31,10 +31,10 @@ from qonnx.custom_op.channels_last.base_wrapped_op import ChannelsLastWrappedOp from qonnx.custom_op.general.maxpoolnhwc import compute_pool_output_dim -from qonnx.custom_op.registry import register_op +from qonnx.custom_op.registry import register_custom_op -@register_op(domain="qonnx.custom_op.channels_last", op_type="MaxPool") +@register_custom_op class MaxPool(ChannelsLastWrappedOp): def get_nodeattr_types(self): """Returns a dict of permitted attributes for node, where: diff --git a/src/qonnx/custom_op/general/bipolar_quant.py b/src/qonnx/custom_op/general/bipolar_quant.py index e6a72486..102f5210 100644 --- a/src/qonnx/custom_op/general/bipolar_quant.py +++ b/src/qonnx/custom_op/general/bipolar_quant.py @@ -31,7 +31,7 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import register_op +from qonnx.custom_op.registry import register_custom_op def binary_quant(inp_tensor, scale): @@ -48,7 +48,7 @@ def binary_quant(inp_tensor, scale): return out_tensor -@register_op(domain="qonnx.custom_op.general", op_type="BipolarQuant") +@register_custom_op class BipolarQuant(CustomOp): """Bipolar quantization operation for QONNX. Takes four inputs: - input tensor to quantize diff --git a/src/qonnx/custom_op/general/debugmarker.py b/src/qonnx/custom_op/general/debugmarker.py index 15e88d8e..3da80521 100644 --- a/src/qonnx/custom_op/general/debugmarker.py +++ b/src/qonnx/custom_op/general/debugmarker.py @@ -29,10 +29,10 @@ from onnx import helper from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import register_op +from qonnx.custom_op.registry import register_custom_op -@register_op(domain="qonnx.custom_op.general", op_type="DebugMarker") +@register_custom_op class DebugMarker(CustomOp): def get_nodeattr_types(self): return {"export_debug_name": ("s", True, "")} diff --git a/src/qonnx/custom_op/general/floatquant.py b/src/qonnx/custom_op/general/floatquant.py index 56698efb..ab74b1df 100644 --- a/src/qonnx/custom_op/general/floatquant.py +++ b/src/qonnx/custom_op/general/floatquant.py @@ -33,7 +33,7 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp from qonnx.custom_op.general.quant import resolve_rounding_mode -from qonnx.custom_op.registry import register_op +from qonnx.custom_op.registry import register_custom_op def compute_default_exponent_bias(exponent_bitwidth): @@ -121,7 +121,7 @@ def inf_nan_clamp(X, inf_mask, p_max_val_mask, n_max_val_mask): return x_q * scale # , self.saturating, self.inf_values, self.nan_values -@register_op(domain="qonnx.custom_op.general", op_type="FloatQuant") +@register_custom_op class FloatQuant(CustomOp): """Floating point quantization operation for QONNX. diff --git a/src/qonnx/custom_op/general/genericpartition.py b/src/qonnx/custom_op/general/genericpartition.py index 0f6fa104..3418f9a5 100755 --- a/src/qonnx/custom_op/general/genericpartition.py +++ b/src/qonnx/custom_op/general/genericpartition.py @@ -29,10 +29,10 @@ from qonnx.core.modelwrapper import ModelWrapper from qonnx.core.onnx_exec import execute_onnx from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import register_op +from qonnx.custom_op.registry import register_custom_op -@register_op(domain="qonnx.custom_op.general", op_type="GenericPartition") +@register_custom_op class GenericPartition(CustomOp): """Class that corresponds to the meta/container node GenericPartition which is a placeholder for a group of nodes that have been separated diff --git a/src/qonnx/custom_op/general/im2col.py b/src/qonnx/custom_op/general/im2col.py index 22b08a5a..276caf7d 100644 --- a/src/qonnx/custom_op/general/im2col.py +++ b/src/qonnx/custom_op/general/im2col.py @@ -31,7 +31,7 @@ import qonnx.util.basic as util from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import register_op +from qonnx.custom_op.registry import register_custom_op # adapted from A. Karpathy's CS231 im2col code # utilities to generate a patch matrix from a multichannel image @@ -141,7 +141,7 @@ def im2col_indices_nchw( # oh/ow and kh/kw will also be 1 in this case -@register_op(domain="qonnx.custom_op.general", op_type="Im2Col") +@register_custom_op class Im2Col(CustomOp): def get_nodeattr_types(self): return { diff --git a/src/qonnx/custom_op/general/intquant.py b/src/qonnx/custom_op/general/intquant.py index 7663e95f..a053e7ef 100644 --- a/src/qonnx/custom_op/general/intquant.py +++ b/src/qonnx/custom_op/general/intquant.py @@ -31,7 +31,7 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import register_op +from qonnx.custom_op.registry import register_custom_op def min_int(signed: bool, narrow_range: bool, bit_width: int) -> int: @@ -166,7 +166,7 @@ def round_half_down(x): raise ValueError(f"Could not resolve rounding mode called: {normalized_mode_string}") -@register_op(domain="qonnx.custom_op.general", op_type="IntQuant") +@register_custom_op class IntQuant(CustomOp): """Generic quantization operation for QONNX. Takes four inputs: - input tensor to quantize diff --git a/src/qonnx/custom_op/general/maxpoolnhwc.py b/src/qonnx/custom_op/general/maxpoolnhwc.py index 81a6c4cb..44aa04bc 100644 --- a/src/qonnx/custom_op/general/maxpoolnhwc.py +++ b/src/qonnx/custom_op/general/maxpoolnhwc.py @@ -33,7 +33,7 @@ from qonnx.core.modelwrapper import ModelWrapper from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import register_op +from qonnx.custom_op.registry import register_custom_op from qonnx.util.basic import qonnx_make_model @@ -45,7 +45,7 @@ def compute_pool_output_dim(ifm_dim, k, stride, pad=0, ceil_mode=0): return int(np.floor(((ifm_dim + 2 * pad - k) / stride) + 1)) -@register_op(domain="qonnx.custom_op.general", op_type="MaxPoolNHWC") +@register_custom_op class MaxPoolNHWC(CustomOp): # a MaxPool node, but using the NHWC data layout diff --git a/src/qonnx/custom_op/general/multithreshold.py b/src/qonnx/custom_op/general/multithreshold.py index 0a5ec596..76df4752 100644 --- a/src/qonnx/custom_op/general/multithreshold.py +++ b/src/qonnx/custom_op/general/multithreshold.py @@ -31,7 +31,7 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import register_op +from qonnx.custom_op.registry import register_custom_op def multithreshold(v, thresholds, out_scale=None, out_bias=None): @@ -85,7 +85,7 @@ def multithreshold(v, thresholds, out_scale=None, out_bias=None): return out_scale * ret.reshape(v.shape) + out_bias -@register_op(domain="qonnx.custom_op.general", op_type="MultiThreshold") +@register_custom_op class MultiThreshold(CustomOp): """Class that corresponds to a multithresholding node.""" diff --git a/src/qonnx/custom_op/general/quant.py b/src/qonnx/custom_op/general/quant.py index d204bce6..858eafc2 100644 --- a/src/qonnx/custom_op/general/quant.py +++ b/src/qonnx/custom_op/general/quant.py @@ -29,10 +29,10 @@ from qonnx.custom_op.general.intquant import IntQuant from qonnx.custom_op.general.intquant import int_quant as quant from qonnx.custom_op.general.intquant import max_int, min_int, resolve_rounding_mode -from qonnx.custom_op.registry import register_op +from qonnx.custom_op.registry import register_custom_op # Create alias and register it separately for "Quant" op_type -@register_op(domain="qonnx.custom_op.general", op_type="Quant") +@register_custom_op class Quant(IntQuant): """Alias for IntQuant to support legacy \"Quant\" op_type.""" pass diff --git a/src/qonnx/custom_op/general/quantavgpool2d.py b/src/qonnx/custom_op/general/quantavgpool2d.py index b152171f..7e51f3e3 100644 --- a/src/qonnx/custom_op/general/quantavgpool2d.py +++ b/src/qonnx/custom_op/general/quantavgpool2d.py @@ -33,11 +33,11 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp from qonnx.custom_op.general.maxpoolnhwc import compute_pool_output_dim -from qonnx.custom_op.registry import register_op +from qonnx.custom_op.registry import register_custom_op from qonnx.util.basic import qonnx_make_model -@register_op(domain="qonnx.custom_op.general", op_type="QuantAvgPool2d") +@register_custom_op class QuantAvgPool2d(CustomOp): """CustomOp that corresponds to the quantized average pooling layer from Brevitas""" diff --git a/src/qonnx/custom_op/general/trunc.py b/src/qonnx/custom_op/general/trunc.py index 6a59e91b..d2921262 100644 --- a/src/qonnx/custom_op/general/trunc.py +++ b/src/qonnx/custom_op/general/trunc.py @@ -32,7 +32,7 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp from qonnx.custom_op.general.quant import resolve_rounding_mode -from qonnx.custom_op.registry import register_op +from qonnx.custom_op.registry import register_custom_op def trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode): @@ -59,7 +59,7 @@ def trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding return y -@register_op(domain="qonnx.custom_op.general", op_type="Trunc") +@register_custom_op class Trunc(CustomOp): """Generic truncation operation for QONNX. Takes four inputs: - input tensor to truncate diff --git a/src/qonnx/custom_op/general/xnorpopcount.py b/src/qonnx/custom_op/general/xnorpopcount.py index a91d412b..c068cb9d 100644 --- a/src/qonnx/custom_op/general/xnorpopcount.py +++ b/src/qonnx/custom_op/general/xnorpopcount.py @@ -31,7 +31,7 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import register_op +from qonnx.custom_op.registry import register_custom_op def xnorpopcountmatmul(inp0, inp1): @@ -61,7 +61,7 @@ def xnorpopcountmatmul(inp0, inp1): return (out + K) * 0.5 -@register_op(domain="qonnx.custom_op.general", op_type="XnorPopcountMatMul") +@register_custom_op class XnorPopcountMatMul(CustomOp): """Class that corresponds to a XNOR-popcount matrix multiplication node.""" diff --git a/src/qonnx/custom_op/registry.py b/src/qonnx/custom_op/registry.py index f7ad3ad8..045f4c70 100644 --- a/src/qonnx/custom_op/registry.py +++ b/src/qonnx/custom_op/registry.py @@ -35,21 +35,20 @@ # global registry mapping (domain, op_type) -> CustomOp subclass CUSTOM_OP_REGISTRY = {} -# global registry for custom op domains -_CUSTOM_DOMAINS = set() - # global registry for custom op metadata _OP_METADATA = {} +# global registry mapping domains to their module paths +# Structure: DOMAIN_REGISTRY[domain] = module_path (or None if module_path == domain) +DOMAIN_REGISTRY = {} + -def register_custom_domain(domain): - """Register a domain as containing custom ops.""" - _CUSTOM_DOMAINS.add(domain) def is_custom_op_domain(domain): """Check if domain is registered for custom ops.""" - return any(domain.startswith(d) for d in _CUSTOM_DOMAINS) + # Check if domain is directly registered or starts with a registered domain + return domain in DOMAIN_REGISTRY or any(domain.startswith(d) for d in DOMAIN_REGISTRY) def hasCustomOp(domain, op_type): @@ -88,8 +87,6 @@ def register_op(domain, op_type, metadata=None): """ def decorator(cls): - # Auto-register the domain when an op is registered - register_custom_domain(domain) CUSTOM_OP_REGISTRY[(domain, op_type)] = cls if metadata is not None: _OP_METADATA[(domain, op_type)] = metadata @@ -111,36 +108,108 @@ def get_op_metadata(domain, op_type): return _OP_METADATA.get((domain, op_type)) -def getCustomOp(node, onnx_opset_version=get_preferred_onnx_opset(), brevitas_exception=True): - """Return a QONNX CustomOp instance for the given ONNX node, if it exists.""" +def register_domain(domain, module_path=None): + """Register a domain with its associated module path. + + This function registers the domain and its module path, allowing classes + defined in any direct child module of this path to use @register_custom_op. + Subfolders/subpackages must be registered separately. + + Args: + domain: The domain to register (e.g., "finn.custom_op.fpgadataflow") + module_path: The Python module path. If None, uses the domain as the path. + """ + DOMAIN_REGISTRY[domain] = module_path + + +# Keep register_domain_path as deprecated alias for backward compatibility +def register_domain_path(module_path, domain): + """Deprecated: Use register_domain instead.""" + return register_domain(domain, module_path) + +def register_custom_op(cls=None, *, op_type=None): + """Register a custom op, inferring domain from parent module path. + + Can be used as @register_custom_op or @register_custom_op(op_type="CustomName"). + Domain is inferred from registered module paths. Op type defaults to class name. + + Args: + cls: The class to register (when used without parentheses) + op_type: Optional custom op_type (defaults to class name) + + Returns: + Decorated class or decorator function + """ + def decorator(cls): + # Get module path + module = cls.__module__ + + # Check if module is a direct child of any registered domain's module path + domain = None + for registered_domain, module_path in DOMAIN_REGISTRY.items(): + # Use domain as module path if not specified + if module_path is None: + module_path = registered_domain + # Check if module is direct child of registered path + if module.startswith(module_path + "."): + # Ensure it's a direct child, not nested deeper + remainder = module[len(module_path) + 1:] + if "." not in remainder: # No more dots = direct child + domain = registered_domain + break + # Also check exact match (for __init__.py files) + elif module == module_path: + domain = registered_domain + break + + if domain is None: + raise ValueError( + f"Module '{module}' is not in a registered domain path. " + f"Either:\n" + f"1. Use @register_op(domain='...', op_type='{cls.__name__}')\n" + f"2. Register the domain: register_domain('your.domain', '{'.'.join(module.split('.')[:-1])}')" + ) + + # Use class name as op_type if not specified + final_op_type = op_type or cls.__name__ + + # Register using the standard mechanism + return register_op(domain=domain, op_type=final_op_type)(cls) + + # Handle both @register_custom_op and @register_custom_op() + if cls is None: + return decorator + return decorator(cls) + + +def getCustomOp(node, onnx_opset_version=get_preferred_onnx_opset(), brevitas_exception=True): + """Return a QONNX CustomOp instance for the given ONNX node.""" op_type = node.op_type domain = node.domain + if brevitas_exception: # transparently resolve Brevitas domain ops to qonnx ones domain = domain.replace("onnx.brevitas", "qonnx.custom_op.general") - + key = (domain, op_type) cls = CUSTOM_OP_REGISTRY.get(key) if cls is not None: return cls(node, onnx_opset_version=onnx_opset_version) - - try: - opset_module = importlib.import_module(domain) - except ModuleNotFoundError: - raise Exception(f"Could not load custom opset {domain}, check your PYTHONPATH") - - # op may have registered itself on import - cls = CUSTOM_OP_REGISTRY.get(key) - if cls is not None: - return cls(node, onnx_opset_version=onnx_opset_version) - - # fallback to legacy custom_op dictionary - if hasattr(opset_module, "custom_op") and isinstance(opset_module.custom_op, dict): + + # Check if we need to import the module to trigger registration + if domain.startswith("finn.custom_op"): try: - inst_wrapper = opset_module.custom_op[op_type] - return inst_wrapper(node, onnx_opset_version=onnx_opset_version) - except KeyError: + importlib.import_module(domain) + # Check again after import + cls = CUSTOM_OP_REGISTRY.get(key) + if cls is not None: + return cls(node, onnx_opset_version=onnx_opset_version) + except ImportError: pass - - raise Exception(f"Op {op_type} not found in custom opset {domain}") + + available_domains = sorted(DOMAIN_REGISTRY.keys()) + raise Exception( + f"Op '{op_type}' not found in domain '{domain}'. " + f"Available domains: {available_domains}" + ) diff --git a/src/qonnx/util/basic.py b/src/qonnx/util/basic.py index abc5449f..68d691e1 100644 --- a/src/qonnx/util/basic.py +++ b/src/qonnx/util/basic.py @@ -62,10 +62,10 @@ def qonnx_make_model(graph_proto, **kwargs): return make_model(graph_proto, **kwargs) -def is_finn_op(op_type): - "Return whether given op_type string is a QONNX or FINN custom op" +def is_finn_op(domain): + "Return whether given domain string is a QONNX or FINN custom op domain" from qonnx.custom_op.registry import is_custom_op_domain - return is_custom_op_domain(op_type) + return is_custom_op_domain(domain) def get_num_default_workers(): diff --git a/tests/custom_op/legacy_custom_op.py b/tests/custom_op/legacy_custom_op.py deleted file mode 100644 index 95b689b9..00000000 --- a/tests/custom_op/legacy_custom_op.py +++ /dev/null @@ -1,22 +0,0 @@ -from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import register_op - - -@register_op(domain="legacy_custom_op", op_type="LegacyAdd") -class LegacyAdd(CustomOp): - def get_nodeattr_types(self): - return {} - - def make_shape_compatible_op(self, model): - return super().make_const_shape_op([1]) - - def infer_node_datatype(self, model): - pass - - def execute_node(self, context, graph): - a = context[self.onnx_node.input[0]] - b = context[self.onnx_node.input[1]] - context[self.onnx_node.output[0]] = a + b - - def verify_node(self): - pass diff --git a/tests/custom_op/test_attr.py b/tests/custom_op/test_attr.py index 4ea18d30..592a99ae 100644 --- a/tests/custom_op/test_attr.py +++ b/tests/custom_op/test_attr.py @@ -31,10 +31,10 @@ from qonnx.core.modelwrapper import ModelWrapper from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import getCustomOp, register_op +from qonnx.custom_op.registry import getCustomOp, register_custom_op -@register_op(domain="qonnx.custom_op.general", op_type="AttrTestOp") +@register_custom_op class AttrTestOp(CustomOp): def get_nodeattr_types(self): my_attrs = {"tensor_attr": ("t", True, np.asarray([])), "strings_attr": ("strings", True, [""])} diff --git a/tests/custom_op/test_old_domain.py b/tests/custom_op/test_old_domain.py deleted file mode 100644 index 88ec226f..00000000 --- a/tests/custom_op/test_old_domain.py +++ /dev/null @@ -1,32 +0,0 @@ -import sys -from onnx import TensorProto, helper - -from qonnx.core.modelwrapper import ModelWrapper -from qonnx.custom_op.registry import getCustomOp -from qonnx.util.basic import qonnx_make_model - - -def test_get_custom_op_old_domain(): - assert "legacy_custom_op" not in sys.modules - - node = helper.make_node( - "LegacyAdd", - ["a", "b"], - ["c"], - domain="legacy_custom_op", - ) - - graph = helper.make_graph( - [node], - "legacy_graph", - inputs=[ - helper.make_tensor_value_info("a", TensorProto.FLOAT, [1]), - helper.make_tensor_value_info("b", TensorProto.FLOAT, [1]), - ], - outputs=[helper.make_tensor_value_info("c", TensorProto.FLOAT, [1])], - ) - model = qonnx_make_model(graph, producer_name="legacy-test") - model = ModelWrapper(model) - - inst = getCustomOp(model.graph.node[0]) - assert inst.__class__.__name__ == "LegacyAdd" From f7ab4b5cb8e5e56dc9dce0c22acdb85e95de737e Mon Sep 17 00:00:00 2001 From: Thomas Keller Date: Wed, 16 Jul 2025 21:50:57 +0000 Subject: [PATCH 08/14] Dependency resolution --- src/qonnx/custom_op/registry.py | 228 +++++++++++++++++++++++--------- 1 file changed, 163 insertions(+), 65 deletions(-) diff --git a/src/qonnx/custom_op/registry.py b/src/qonnx/custom_op/registry.py index 045f4c70..e8be6f22 100644 --- a/src/qonnx/custom_op/registry.py +++ b/src/qonnx/custom_op/registry.py @@ -42,7 +42,73 @@ # Structure: DOMAIN_REGISTRY[domain] = module_path (or None if module_path == domain) DOMAIN_REGISTRY = {} +# Track which domains have been loaded +_LOADED_DOMAINS = set() +# Track domain dependencies discovered through inheritance +_DOMAIN_DEPENDENCIES = {} # domain -> set of dependency domains + + + + +def _ensure_domain_loaded(domain): + """Ensure a domain and its dependencies are loaded.""" + if domain in _LOADED_DOMAINS: + return + + # Mark as loaded first to prevent infinite recursion + _LOADED_DOMAINS.add(domain) + + # First load any known dependencies + if domain in _DOMAIN_DEPENDENCIES: + for dep_domain in _DOMAIN_DEPENDENCIES[domain]: + if dep_domain != domain: # Avoid self-dependencies + _ensure_domain_loaded(dep_domain) + + # Try to import the domain module + if domain in DOMAIN_REGISTRY: + module_path = DOMAIN_REGISTRY[domain] or domain + try: + importlib.import_module(module_path) + except ImportError as e: + # Remove from loaded if import failed + _LOADED_DOMAINS.discard(domain) + # Continue without raising - domain might still work + elif domain.startswith(("finn.", "qonnx.")): + # Try importing even if not in registry + try: + importlib.import_module(domain) + except ImportError: + # Remove from loaded if import failed + _LOADED_DOMAINS.discard(domain) + + +def _register_op_with_dependencies(domain, op_type, cls, metadata=None): + """Register an op and track its inheritance dependencies.""" + # Register the op + CUSTOM_OP_REGISTRY[(domain, op_type)] = cls + if metadata is not None: + _OP_METADATA[(domain, op_type)] = metadata + + # Detect dependencies from inheritance + for base in cls.__bases__: + # Skip abstract base classes and non-custom ops + if base.__name__ in ('CustomOp', 'ABC', 'object', 'HWCustomOp', 'HLSBackend', 'RTLBackend'): + continue + + # Check if base class is a registered custom op + for (reg_domain, reg_op), reg_cls in CUSTOM_OP_REGISTRY.items(): + if reg_cls == base: + # Found a dependency - track it + if domain not in _DOMAIN_DEPENDENCIES: + _DOMAIN_DEPENDENCIES[domain] = set() + _DOMAIN_DEPENDENCIES[domain].add(reg_domain) + + # Immediately ensure the dependency is loaded + _ensure_domain_loaded(reg_domain) + break + + return cls def is_custom_op_domain(domain): @@ -61,6 +127,8 @@ def hasCustomOp(domain, op_type): Returns: bool: True if the op is registered, False otherwise """ + # Ensure domain is loaded first + _ensure_domain_loaded(domain) return (domain, op_type) in CUSTOM_OP_REGISTRY @@ -77,24 +145,6 @@ def get_ops_in_domain(domain): if d == domain] -def register_op(domain, op_type, metadata=None): - """Decorator for registering CustomOp classes. - - Args: - domain: The domain for the custom op - op_type: The op_type for the custom op - metadata: Optional dict of metadata about the op (backend, version, etc.) - """ - - def decorator(cls): - CUSTOM_OP_REGISTRY[(domain, op_type)] = cls - if metadata is not None: - _OP_METADATA[(domain, op_type)] = metadata - return cls - - return decorator - - def get_op_metadata(domain, op_type): """Get metadata for a registered custom op. @@ -128,59 +178,102 @@ def register_domain_path(module_path, domain): return register_domain(domain, module_path) -def register_custom_op(cls=None, *, op_type=None): - """Register a custom op, inferring domain from parent module path. + +def register_custom_op(domain=None, op_type=None, *, metadata=None): + """Register a custom op with flexible domain and op_type specification. - Can be used as @register_custom_op or @register_custom_op(op_type="CustomName"). - Domain is inferred from registered module paths. Op type defaults to class name. + Can be used in three ways: + 1. @register_custom_op("domain", "OpType") - Explicit domain and op_type + 2. @register_custom_op("domain") - Explicit domain, class name as op_type + 3. @register_custom_op - Automatic domain inference, class name as op_type Args: - cls: The class to register (when used without parentheses) - op_type: Optional custom op_type (defaults to class name) + domain: The domain for the custom op (optional) + op_type: The op_type for the custom op (optional) + metadata: Optional dict of metadata about the op (backend, version, etc.) Returns: Decorated class or decorator function """ - def decorator(cls): - # Get module path - module = cls.__module__ - - # Check if module is a direct child of any registered domain's module path - domain = None - for registered_domain, module_path in DOMAIN_REGISTRY.items(): - # Use domain as module path if not specified - if module_path is None: - module_path = registered_domain - # Check if module is direct child of registered path - if module.startswith(module_path + "."): - # Ensure it's a direct child, not nested deeper - remainder = module[len(module_path) + 1:] - if "." not in remainder: # No more dots = direct child - domain = registered_domain + # Determine which mode we're in based on arguments + if domain is not None and isinstance(domain, str): + # Mode 1 or 2: Explicit domain provided + if op_type is not None and isinstance(op_type, str): + # Mode 1: Both domain and op_type provided + def decorator(cls): + return _register_op_with_dependencies(domain, op_type, cls, metadata) + return decorator + else: + # Mode 2: Only domain provided, use class name as op_type + def decorator(cls): + final_op_type = cls.__name__ + return _register_op_with_dependencies(domain, final_op_type, cls, metadata) + return decorator + else: + # Mode 3: No domain provided, or called without arguments + # Handle the case where it's used as @register_custom_op (no parentheses) + if domain is not None and not isinstance(domain, str): + # This means domain is actually the class (decorator without parentheses) + cls = domain + module = cls.__module__ + + # Find domain from registered domains + inferred_domain = None + for registered_domain, module_path in DOMAIN_REGISTRY.items(): + if module_path is None: + module_path = registered_domain + if module.startswith(module_path + "."): + remainder = module[len(module_path) + 1:] + if "." not in remainder: + inferred_domain = registered_domain + break + elif module == module_path: + inferred_domain = registered_domain break - # Also check exact match (for __init__.py files) - elif module == module_path: - domain = registered_domain - break - - if domain is None: - raise ValueError( - f"Module '{module}' is not in a registered domain path. " - f"Either:\n" - f"1. Use @register_op(domain='...', op_type='{cls.__name__}')\n" - f"2. Register the domain: register_domain('your.domain', '{'.'.join(module.split('.')[:-1])}')" - ) - - # Use class name as op_type if not specified - final_op_type = op_type or cls.__name__ - - # Register using the standard mechanism - return register_op(domain=domain, op_type=final_op_type)(cls) - - # Handle both @register_custom_op and @register_custom_op() - if cls is None: - return decorator - return decorator(cls) + + if inferred_domain is None: + raise ValueError( + f"Module '{module}' is not in a registered domain path. " + f"Either:\n" + f"1. Use @register_custom_op('domain', 'OpType')\n" + f"2. Use @register_custom_op('domain') to use class name as op_type\n" + f"3. Register the domain: register_domain('your.domain', '{'.'.join(module.split('.')[:-1])}')" + ) + + final_op_type = cls.__name__ + return _register_op_with_dependencies(inferred_domain, final_op_type, cls, metadata) + else: + # Decorator called with parentheses but no domain + def decorator(cls): + module = cls.__module__ + + # Find domain from registered domains + inferred_domain = None + for registered_domain, module_path in DOMAIN_REGISTRY.items(): + if module_path is None: + module_path = registered_domain + if module.startswith(module_path + "."): + remainder = module[len(module_path) + 1:] + if "." not in remainder: + inferred_domain = registered_domain + break + elif module == module_path: + inferred_domain = registered_domain + break + + if inferred_domain is None: + raise ValueError( + f"Module '{module}' is not in a registered domain path. " + f"Either:\n" + f"1. Use @register_custom_op('domain', 'OpType')\n" + f"2. Use @register_custom_op('domain') to use class name as op_type\n" + f"3. Register the domain: register_domain('your.domain', '{'.'.join(module.split('.')[:-1])}')" + ) + + # Use provided op_type or default to class name + final_op_type = op_type or cls.__name__ + return _register_op_with_dependencies(inferred_domain, final_op_type, cls, metadata) + return decorator def getCustomOp(node, onnx_opset_version=get_preferred_onnx_opset(), brevitas_exception=True): @@ -192,15 +285,20 @@ def getCustomOp(node, onnx_opset_version=get_preferred_onnx_opset(), brevitas_ex # transparently resolve Brevitas domain ops to qonnx ones domain = domain.replace("onnx.brevitas", "qonnx.custom_op.general") + # Ensure the domain is loaded (will load dependencies automatically) + _ensure_domain_loaded(domain) + key = (domain, op_type) cls = CUSTOM_OP_REGISTRY.get(key) if cls is not None: return cls(node, onnx_opset_version=onnx_opset_version) - # Check if we need to import the module to trigger registration - if domain.startswith("finn.custom_op"): + # If not found and domain starts with finn, try explicit import as fallback + # This handles cases where domain isn't registered but module exists + if domain.startswith("finn.custom_op") and domain not in _LOADED_DOMAINS: try: importlib.import_module(domain) + _LOADED_DOMAINS.add(domain) # Check again after import cls = CUSTOM_OP_REGISTRY.get(key) if cls is not None: From d08c33dba97f2ed9ec0c9426c284696c40dd5270 Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Wed, 16 Jul 2025 23:00:47 +0000 Subject: [PATCH 09/14] help multithreshold handle 3-dim more efficiently --- src/qonnx/custom_op/general/multithreshold.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/qonnx/custom_op/general/multithreshold.py b/src/qonnx/custom_op/general/multithreshold.py index 0a5ec596..f54fb408 100644 --- a/src/qonnx/custom_op/general/multithreshold.py +++ b/src/qonnx/custom_op/general/multithreshold.py @@ -145,6 +145,10 @@ def execute_node(self, context, graph): # TODO: Seems like a rather sketchy solution to support arbitrary data # layouts. This does not even validate the assumption of channel last # layout. + if v.ndim == 3: + orig_shape = v.shape + v = np.expand_dims(v, axis=0) + if v.ndim not in {2, 4}: # Remember the original shape to be restored later orig_shape = v.shape From d76507a66b87c6b7fbc7566a457ed8bb05b68c40 Mon Sep 17 00:00:00 2001 From: Joshua Monson Date: Thu, 17 Jul 2025 22:53:02 +0000 Subject: [PATCH 10/14] update extract model config to export config for subgraphs --- src/qonnx/util/config.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/src/qonnx/util/config.py b/src/qonnx/util/config.py index 63661862..2f6383d3 100644 --- a/src/qonnx/util/config.py +++ b/src/qonnx/util/config.py @@ -27,13 +27,15 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import json +import onnx from qonnx.custom_op.registry import getCustomOp - -def extract_model_config_to_json(model, json_filename, attr_names_to_extract): - """Create a json file with layer name -> attribute mappings extracted from the - model. The created json file can be later applied on a model with +# update this code to handle export configs from subgraphs +# where the subgraph is found in a node's attribute as a graph type +def extract_model_config(model, attr_names_to_extract): + """Create a dictionary with layer name -> attribute mappings extracted from the + model. The created dictionary can be later applied on a model with qonnx.transform.general.ApplyConfig.""" cfg = dict() @@ -41,12 +43,22 @@ def extract_model_config_to_json(model, json_filename, attr_names_to_extract): for n in model.graph.node: oi = getCustomOp(n) layer_dict = dict() - for attr in attr_names_to_extract: - try: - layer_dict[attr] = oi.get_nodeattr(attr) - except AttributeError: - pass + for attr in n.attribute: + if attr.type == onnx.AttributeProto.GRAPH: # Graph type + # If the attribute is a graph, we need to extract the attributes from the subgraph + cfg.update(extract_model_config(model.make_subgraph_modelwrapper(attr.g), attr_names_to_extract)) + elif attr.name in attr_names_to_extract: + # If the attribute name is in the list, we can add it directly + layer_dict[attr.name] = oi.get_nodeattr(attr.name) if len(layer_dict) > 0: cfg[n.name] = layer_dict + return cfg + + +def extract_model_config_to_json(model, json_filename, attr_names_to_extract): + """Create a json file with layer name -> attribute mappings extracted from the + model. The created json file can be later applied on a model with + qonnx.transform.general.ApplyConfig.""" + with open(json_filename, "w") as f: - json.dump(cfg, f, indent=2) + json.dump(extract_model_config(model, attr_names_to_extract), f, indent=2) From fa3e0a81adb933847d0655140d15374bceaa6498 Mon Sep 17 00:00:00 2001 From: Thomas Keller Date: Fri, 18 Jul 2025 02:00:42 +0000 Subject: [PATCH 11/14] Removed decorators in favor of pure domain --- docs/overview.rst | 18 +- notebooks/3_custom_op.ipynb | 31 +- src/qonnx/custom_op/__init__.py | 14 +- .../channels_last/batch_normalization.py | 2 - src/qonnx/custom_op/channels_last/conv.py | 2 - src/qonnx/custom_op/channels_last/max_pool.py | 2 - src/qonnx/custom_op/general/bipolar_quant.py | 2 - src/qonnx/custom_op/general/debugmarker.py | 2 - src/qonnx/custom_op/general/floatquant.py | 2 - .../custom_op/general/genericpartition.py | 2 - src/qonnx/custom_op/general/im2col.py | 2 - src/qonnx/custom_op/general/intquant.py | 2 - src/qonnx/custom_op/general/maxpoolnhwc.py | 2 - src/qonnx/custom_op/general/multithreshold.py | 2 - src/qonnx/custom_op/general/quant.py | 17 +- src/qonnx/custom_op/general/quantavgpool2d.py | 2 - src/qonnx/custom_op/general/trunc.py | 2 - src/qonnx/custom_op/general/xnorpopcount.py | 2 - src/qonnx/custom_op/registry.py | 395 +++++++----------- src/qonnx/util/basic.py | 16 +- tests/custom_op/test_attr.py | 8 +- 21 files changed, 184 insertions(+), 343 deletions(-) diff --git a/docs/overview.rst b/docs/overview.rst index 5fd87de9..2f0f3577 100644 --- a/docs/overview.rst +++ b/docs/overview.rst @@ -45,12 +45,18 @@ Custom Operations/Nodes QONNX uses many custom operations (op_type in ONNX NodeProto) that are not defined in the ONNX operator schema. These custom nodes are marked with domain="qonnx.*" in the protobuf to identify them as such. These nodes can represent specific operations that we need for low-bit networks, or operations that are specific to a particular hardware backend. To get more familiar with custom operations and how they are created, please take a look in the Jupyter notebook about CustomOps (see chapter :ref:`tutorials` for details) or directly in the module :py:mod:`qonnx.custom_op`. -Custom ops can be registered automatically via Python entry points using the -``qonnx_custom_ops`` group. Each operator class should be decorated with -``@register_custom_op`` from ``qonnx.custom_op.registry``, which automatically -infers the domain from the module path. Packages installed with such an entry -point will be discovered on import and their ops made available through -``getCustomOp``. +Custom ops are automatically discovered through Python module namespaces. +Simply define your CustomOp subclass in the appropriate domain module +(e.g., ``qonnx.custom_op.general`` for general ops) and it will be automatically +available through ``getCustomOp``. + +For dynamic registration (e.g., in tests), use the registry functions: + +* ``getCustomOp(node)`` - Get a custom op instance from an ONNX node +* ``add_op_to_domain(domain, op_type, op_class)`` - Add an op to a domain's namespace +* ``add_domain_alias(domain, module_path)`` - Map a domain to a different module path +* ``hasCustomOp(domain, op_type)`` - Check if an op exists in a domain +* ``get_ops_in_domain(domain)`` - List all ops available in a domain Custom ONNX Execution Flow diff --git a/notebooks/3_custom_op.ipynb b/notebooks/3_custom_op.ipynb index d0cd10fd..cd01686c 100644 --- a/notebooks/3_custom_op.ipynb +++ b/notebooks/3_custom_op.ipynb @@ -129,35 +129,24 @@ { "cell_type": "markdown", "metadata": {}, - "source": [ - "To make sure our custom op is available, it needs to be registered. The best practice for this is to create a submodule under `qonnx.custom_op` which includes a `custom_op` dictionary that maps strings (op names) to classes (op implementations). Since we're in a Jupyter notebook we'll just hijack it at runtime like this:" - ] + "source": "To make sure our custom op is available, we need to add it to the domain's namespace. For production code, you would place your CustomOp class directly in the appropriate module (e.g., in a file under `qonnx/custom_op/general/`). For testing and experimentation like in this notebook, we can use the `add_op_to_domain` function:" }, { "cell_type": "code", - "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "import qonnx.custom_op.general as general\n", - "general.custom_op[\"MyPythonPowerOp\"] = MyPythonPowerOp" - ] + "source": "from qonnx.custom_op.registry import add_op_to_domain\n\n# Add our custom op to the general domain namespace\nadd_op_to_domain(\"qonnx.custom_op.general\", \"MyPythonPowerOp\", MyPythonPowerOp)" }, { "cell_type": "markdown", "metadata": {}, - "source": [ - "We can see which custom ops are registered under this submodule by looking at the dictionary:" - ] + "source": "We can see which custom ops are available in a domain by using the registry function:" }, { "cell_type": "code", - "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "general.custom_op" - ] + "source": "from qonnx.custom_op.registry import get_ops_in_domain\n\n# See all ops in the general domain\nops = get_ops_in_domain(\"qonnx.custom_op.general\")\nprint(f\"Available ops: {[op[0] for op in ops]}\")\n\n# Check if our op is there\nfrom qonnx.custom_op.registry import hasCustomOp\nprint(f\"MyPythonPowerOp available: {hasCustomOp('qonnx.custom_op.general', 'MyPythonPowerOp')}\")" }, { "cell_type": "markdown", @@ -462,17 +451,9 @@ }, { "cell_type": "code", - "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "# register our new op\n", - "general.custom_op[\"MyMixedPowerOp\"] = MyMixedPowerOp\n", - "\n", - "# make graph with new op\n", - "mixedop_graph = make_graph(input_shape, 2, op_type = \"MyMixedPowerOp\")\n", - "mixedop_graph.graph.node" - ] + "source": "# register our new op\nadd_op_to_domain(\"qonnx.custom_op.general\", \"MyMixedPowerOp\", MyMixedPowerOp)\n\n# make graph with new op\nmixedop_graph = make_graph(input_shape, 2, op_type = \"MyMixedPowerOp\")\nmixedop_graph.graph.node" }, { "cell_type": "markdown", @@ -744,4 +725,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/src/qonnx/custom_op/__init__.py b/src/qonnx/custom_op/__init__.py index be4f2162..7c38a8df 100644 --- a/src/qonnx/custom_op/__init__.py +++ b/src/qonnx/custom_op/__init__.py @@ -26,15 +26,5 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from qonnx.custom_op.registry import register_domain - -# Register QONNX domains (module path defaults to domain name) -register_domain("qonnx.custom_op.general") -register_domain("qonnx.custom_op.channels_last") - -# Register parent domain for hierarchy checking -register_domain("qonnx.custom_op") - -# Special case: Brevitas compatibility domain -# (QONNX handles Brevitas ops for backward compatibility) -register_domain("onnx.brevitas") \ No newline at end of file +# Domain aliases are automatically handled by the registry +# The onnx.brevitas -> qonnx.custom_op.general mapping is built into the registry \ No newline at end of file diff --git a/src/qonnx/custom_op/channels_last/batch_normalization.py b/src/qonnx/custom_op/channels_last/batch_normalization.py index b97ab1a8..f3b3f872 100644 --- a/src/qonnx/custom_op/channels_last/batch_normalization.py +++ b/src/qonnx/custom_op/channels_last/batch_normalization.py @@ -30,10 +30,8 @@ from onnx import TensorProto, helper from qonnx.custom_op.channels_last.base_wrapped_op import ChannelsLastWrappedOp -from qonnx.custom_op.registry import register_custom_op -@register_custom_op class BatchNormalization(ChannelsLastWrappedOp): def get_nodeattr_types(self): """Returns a dict of permitted attributes for node, where: diff --git a/src/qonnx/custom_op/channels_last/conv.py b/src/qonnx/custom_op/channels_last/conv.py index dc78a0fd..b0ff237b 100644 --- a/src/qonnx/custom_op/channels_last/conv.py +++ b/src/qonnx/custom_op/channels_last/conv.py @@ -31,10 +31,8 @@ from qonnx.custom_op.channels_last.base_wrapped_op import ChannelsLastWrappedOp from qonnx.custom_op.general.im2col import compute_conv_output_dim -from qonnx.custom_op.registry import register_custom_op -@register_custom_op class Conv(ChannelsLastWrappedOp): def get_nodeattr_types(self): """Returns a dict of permitted attributes for node, where: diff --git a/src/qonnx/custom_op/channels_last/max_pool.py b/src/qonnx/custom_op/channels_last/max_pool.py index 53a7b617..383f3008 100644 --- a/src/qonnx/custom_op/channels_last/max_pool.py +++ b/src/qonnx/custom_op/channels_last/max_pool.py @@ -31,10 +31,8 @@ from qonnx.custom_op.channels_last.base_wrapped_op import ChannelsLastWrappedOp from qonnx.custom_op.general.maxpoolnhwc import compute_pool_output_dim -from qonnx.custom_op.registry import register_custom_op -@register_custom_op class MaxPool(ChannelsLastWrappedOp): def get_nodeattr_types(self): """Returns a dict of permitted attributes for node, where: diff --git a/src/qonnx/custom_op/general/bipolar_quant.py b/src/qonnx/custom_op/general/bipolar_quant.py index 102f5210..986a7082 100644 --- a/src/qonnx/custom_op/general/bipolar_quant.py +++ b/src/qonnx/custom_op/general/bipolar_quant.py @@ -31,7 +31,6 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import register_custom_op def binary_quant(inp_tensor, scale): @@ -48,7 +47,6 @@ def binary_quant(inp_tensor, scale): return out_tensor -@register_custom_op class BipolarQuant(CustomOp): """Bipolar quantization operation for QONNX. Takes four inputs: - input tensor to quantize diff --git a/src/qonnx/custom_op/general/debugmarker.py b/src/qonnx/custom_op/general/debugmarker.py index 3da80521..ae8cbce5 100644 --- a/src/qonnx/custom_op/general/debugmarker.py +++ b/src/qonnx/custom_op/general/debugmarker.py @@ -29,10 +29,8 @@ from onnx import helper from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import register_custom_op -@register_custom_op class DebugMarker(CustomOp): def get_nodeattr_types(self): return {"export_debug_name": ("s", True, "")} diff --git a/src/qonnx/custom_op/general/floatquant.py b/src/qonnx/custom_op/general/floatquant.py index ab74b1df..a34f6c01 100644 --- a/src/qonnx/custom_op/general/floatquant.py +++ b/src/qonnx/custom_op/general/floatquant.py @@ -33,7 +33,6 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp from qonnx.custom_op.general.quant import resolve_rounding_mode -from qonnx.custom_op.registry import register_custom_op def compute_default_exponent_bias(exponent_bitwidth): @@ -121,7 +120,6 @@ def inf_nan_clamp(X, inf_mask, p_max_val_mask, n_max_val_mask): return x_q * scale # , self.saturating, self.inf_values, self.nan_values -@register_custom_op class FloatQuant(CustomOp): """Floating point quantization operation for QONNX. diff --git a/src/qonnx/custom_op/general/genericpartition.py b/src/qonnx/custom_op/general/genericpartition.py index 3418f9a5..841e4e9b 100755 --- a/src/qonnx/custom_op/general/genericpartition.py +++ b/src/qonnx/custom_op/general/genericpartition.py @@ -29,10 +29,8 @@ from qonnx.core.modelwrapper import ModelWrapper from qonnx.core.onnx_exec import execute_onnx from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import register_custom_op -@register_custom_op class GenericPartition(CustomOp): """Class that corresponds to the meta/container node GenericPartition which is a placeholder for a group of nodes that have been separated diff --git a/src/qonnx/custom_op/general/im2col.py b/src/qonnx/custom_op/general/im2col.py index 276caf7d..42477832 100644 --- a/src/qonnx/custom_op/general/im2col.py +++ b/src/qonnx/custom_op/general/im2col.py @@ -31,7 +31,6 @@ import qonnx.util.basic as util from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import register_custom_op # adapted from A. Karpathy's CS231 im2col code # utilities to generate a patch matrix from a multichannel image @@ -141,7 +140,6 @@ def im2col_indices_nchw( # oh/ow and kh/kw will also be 1 in this case -@register_custom_op class Im2Col(CustomOp): def get_nodeattr_types(self): return { diff --git a/src/qonnx/custom_op/general/intquant.py b/src/qonnx/custom_op/general/intquant.py index a053e7ef..69920b97 100644 --- a/src/qonnx/custom_op/general/intquant.py +++ b/src/qonnx/custom_op/general/intquant.py @@ -31,7 +31,6 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import register_custom_op def min_int(signed: bool, narrow_range: bool, bit_width: int) -> int: @@ -166,7 +165,6 @@ def round_half_down(x): raise ValueError(f"Could not resolve rounding mode called: {normalized_mode_string}") -@register_custom_op class IntQuant(CustomOp): """Generic quantization operation for QONNX. Takes four inputs: - input tensor to quantize diff --git a/src/qonnx/custom_op/general/maxpoolnhwc.py b/src/qonnx/custom_op/general/maxpoolnhwc.py index 44aa04bc..eb964fc4 100644 --- a/src/qonnx/custom_op/general/maxpoolnhwc.py +++ b/src/qonnx/custom_op/general/maxpoolnhwc.py @@ -33,7 +33,6 @@ from qonnx.core.modelwrapper import ModelWrapper from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import register_custom_op from qonnx.util.basic import qonnx_make_model @@ -45,7 +44,6 @@ def compute_pool_output_dim(ifm_dim, k, stride, pad=0, ceil_mode=0): return int(np.floor(((ifm_dim + 2 * pad - k) / stride) + 1)) -@register_custom_op class MaxPoolNHWC(CustomOp): # a MaxPool node, but using the NHWC data layout diff --git a/src/qonnx/custom_op/general/multithreshold.py b/src/qonnx/custom_op/general/multithreshold.py index 76df4752..6df58f95 100644 --- a/src/qonnx/custom_op/general/multithreshold.py +++ b/src/qonnx/custom_op/general/multithreshold.py @@ -31,7 +31,6 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import register_custom_op def multithreshold(v, thresholds, out_scale=None, out_bias=None): @@ -85,7 +84,6 @@ def multithreshold(v, thresholds, out_scale=None, out_bias=None): return out_scale * ret.reshape(v.shape) + out_bias -@register_custom_op class MultiThreshold(CustomOp): """Class that corresponds to a multithresholding node.""" diff --git a/src/qonnx/custom_op/general/quant.py b/src/qonnx/custom_op/general/quant.py index 858eafc2..a7356a8f 100644 --- a/src/qonnx/custom_op/general/quant.py +++ b/src/qonnx/custom_op/general/quant.py @@ -26,19 +26,12 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# Import IntQuant to create alias from qonnx.custom_op.general.intquant import IntQuant + +# Re-export functions from intquant for backward compatibility from qonnx.custom_op.general.intquant import int_quant as quant from qonnx.custom_op.general.intquant import max_int, min_int, resolve_rounding_mode -from qonnx.custom_op.registry import register_custom_op - -# Create alias and register it separately for "Quant" op_type -@register_custom_op -class Quant(IntQuant): - """Alias for IntQuant to support legacy \"Quant\" op_type.""" - pass -# Re-export functions -quant = quant -max_int = max_int -min_int = min_int -resolve_rounding_mode = resolve_rounding_mode +# Create alias for backward compatibility - Quant is just IntQuant +Quant = IntQuant \ No newline at end of file diff --git a/src/qonnx/custom_op/general/quantavgpool2d.py b/src/qonnx/custom_op/general/quantavgpool2d.py index 7e51f3e3..c0e24071 100644 --- a/src/qonnx/custom_op/general/quantavgpool2d.py +++ b/src/qonnx/custom_op/general/quantavgpool2d.py @@ -33,11 +33,9 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp from qonnx.custom_op.general.maxpoolnhwc import compute_pool_output_dim -from qonnx.custom_op.registry import register_custom_op from qonnx.util.basic import qonnx_make_model -@register_custom_op class QuantAvgPool2d(CustomOp): """CustomOp that corresponds to the quantized average pooling layer from Brevitas""" diff --git a/src/qonnx/custom_op/general/trunc.py b/src/qonnx/custom_op/general/trunc.py index d2921262..8e2eaa19 100644 --- a/src/qonnx/custom_op/general/trunc.py +++ b/src/qonnx/custom_op/general/trunc.py @@ -32,7 +32,6 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp from qonnx.custom_op.general.quant import resolve_rounding_mode -from qonnx.custom_op.registry import register_custom_op def trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode): @@ -59,7 +58,6 @@ def trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding return y -@register_custom_op class Trunc(CustomOp): """Generic truncation operation for QONNX. Takes four inputs: - input tensor to truncate diff --git a/src/qonnx/custom_op/general/xnorpopcount.py b/src/qonnx/custom_op/general/xnorpopcount.py index c068cb9d..9a640599 100644 --- a/src/qonnx/custom_op/general/xnorpopcount.py +++ b/src/qonnx/custom_op/general/xnorpopcount.py @@ -31,7 +31,6 @@ from qonnx.core.datatype import DataType from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import register_custom_op def xnorpopcountmatmul(inp0, inp1): @@ -61,7 +60,6 @@ def xnorpopcountmatmul(inp0, inp1): return (out + K) * 0.5 -@register_custom_op class XnorPopcountMatMul(CustomOp): """Class that corresponds to a XNOR-popcount matrix multiplication node.""" diff --git a/src/qonnx/custom_op/registry.py b/src/qonnx/custom_op/registry.py index e8be6f22..c7d964f6 100644 --- a/src/qonnx/custom_op/registry.py +++ b/src/qonnx/custom_op/registry.py @@ -27,287 +27,172 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import importlib -import warnings -from importlib import metadata +import inspect +from typing import Dict +from qonnx.custom_op.base import CustomOp from qonnx.util.basic import get_preferred_onnx_opset -# global registry mapping (domain, op_type) -> CustomOp subclass -CUSTOM_OP_REGISTRY = {} +# Domain to module path mapping (only when different) +DOMAIN_MODULES: Dict[str, str] = { + "onnx.brevitas": "qonnx.custom_op.general", # Built-in compatibility +} -# global registry for custom op metadata -_OP_METADATA = {} -# global registry mapping domains to their module paths -# Structure: DOMAIN_REGISTRY[domain] = module_path (or None if module_path == domain) -DOMAIN_REGISTRY = {} - -# Track which domains have been loaded -_LOADED_DOMAINS = set() - -# Track domain dependencies discovered through inheritance -_DOMAIN_DEPENDENCIES = {} # domain -> set of dependency domains - - - - -def _ensure_domain_loaded(domain): - """Ensure a domain and its dependencies are loaded.""" - if domain in _LOADED_DOMAINS: - return - - # Mark as loaded first to prevent infinite recursion - _LOADED_DOMAINS.add(domain) - - # First load any known dependencies - if domain in _DOMAIN_DEPENDENCIES: - for dep_domain in _DOMAIN_DEPENDENCIES[domain]: - if dep_domain != domain: # Avoid self-dependencies - _ensure_domain_loaded(dep_domain) - - # Try to import the domain module - if domain in DOMAIN_REGISTRY: - module_path = DOMAIN_REGISTRY[domain] or domain - try: - importlib.import_module(module_path) - except ImportError as e: - # Remove from loaded if import failed - _LOADED_DOMAINS.discard(domain) - # Continue without raising - domain might still work - elif domain.startswith(("finn.", "qonnx.")): - # Try importing even if not in registry - try: - importlib.import_module(domain) - except ImportError: - # Remove from loaded if import failed - _LOADED_DOMAINS.discard(domain) - - -def _register_op_with_dependencies(domain, op_type, cls, metadata=None): - """Register an op and track its inheritance dependencies.""" - # Register the op - CUSTOM_OP_REGISTRY[(domain, op_type)] = cls - if metadata is not None: - _OP_METADATA[(domain, op_type)] = metadata - - # Detect dependencies from inheritance - for base in cls.__bases__: - # Skip abstract base classes and non-custom ops - if base.__name__ in ('CustomOp', 'ABC', 'object', 'HWCustomOp', 'HLSBackend', 'RTLBackend'): - continue - - # Check if base class is a registered custom op - for (reg_domain, reg_op), reg_cls in CUSTOM_OP_REGISTRY.items(): - if reg_cls == base: - # Found a dependency - track it - if domain not in _DOMAIN_DEPENDENCIES: - _DOMAIN_DEPENDENCIES[domain] = set() - _DOMAIN_DEPENDENCIES[domain].add(reg_domain) - - # Immediately ensure the dependency is loaded - _ensure_domain_loaded(reg_domain) - break - - return cls - - -def is_custom_op_domain(domain): - """Check if domain is registered for custom ops.""" - # Check if domain is directly registered or starts with a registered domain - return domain in DOMAIN_REGISTRY or any(domain.startswith(d) for d in DOMAIN_REGISTRY) - - -def hasCustomOp(domain, op_type): - """Check if a custom op exists without creating an instance. +def add_domain_alias(domain: str, module_path: str) -> None: + """Map a domain name to a different module path. Args: - domain: The domain of the custom op - op_type: The op_type of the custom op + domain: The ONNX domain name + module_path: The Python module path to use instead - Returns: - bool: True if the op is registered, False otherwise + Example: + add_domain_alias("finn.custom_op.fpgadataflow", "finn_custom_ops.fpgadataflow") """ - # Ensure domain is loaded first - _ensure_domain_loaded(domain) - return (domain, op_type) in CUSTOM_OP_REGISTRY + DOMAIN_MODULES[domain] = module_path -def get_ops_in_domain(domain): - """Get all registered ops in a domain. +def add_op_to_domain(domain: str, op_type: str, op_class: type) -> None: + """Add a custom op directly to a domain's module namespace. - Args: - domain: The domain to query - - Returns: - List[Tuple[str, Type[CustomOp]]]: List of (op_type, class) tuples - """ - return [(op_type, cls) for (d, op_type), cls in CUSTOM_OP_REGISTRY.items() - if d == domain] - - -def get_op_metadata(domain, op_type): - """Get metadata for a registered custom op. + This function dynamically adds custom ops to module namespaces at runtime. + Useful for test cases or dynamic op registration. Args: - domain: The domain of the custom op - op_type: The op_type of the custom op + domain: The ONNX domain name (e.g., "qonnx.custom_op.general") + op_type: The operation type name (e.g., "MyCustomOp") + op_class: The CustomOp subclass to add - Returns: - dict: The metadata dict if available, None otherwise + Example: + add_op_to_domain("qonnx.custom_op.general", "TestOp", TestOp) """ - return _OP_METADATA.get((domain, op_type)) - - -def register_domain(domain, module_path=None): - """Register a domain with its associated module path. + if not inspect.isclass(op_class) or not issubclass(op_class, CustomOp): + raise ValueError(f"{op_class} must be a subclass of CustomOp") - This function registers the domain and its module path, allowing classes - defined in any direct child module of this path to use @register_custom_op. - Subfolders/subpackages must be registered separately. + # Get the actual module path + module_path = DOMAIN_MODULES.get(domain, domain) - Args: - domain: The domain to register (e.g., "finn.custom_op.fpgadataflow") - module_path: The Python module path. If None, uses the domain as the path. - """ - DOMAIN_REGISTRY[domain] = module_path - - -# Keep register_domain_path as deprecated alias for backward compatibility -def register_domain_path(module_path, domain): - """Deprecated: Use register_domain instead.""" - return register_domain(domain, module_path) - + try: + # Import the module and add the op to its namespace + module = importlib.import_module(module_path) + setattr(module, op_type, op_class) + except ModuleNotFoundError: + raise ValueError(f"Could not find module for domain '{domain}' (tried: {module_path})") -def register_custom_op(domain=None, op_type=None, *, metadata=None): - """Register a custom op with flexible domain and op_type specification. +def getCustomOp(node, onnx_opset_version=get_preferred_onnx_opset()): + """Get a custom op instance for an ONNX node. - Can be used in three ways: - 1. @register_custom_op("domain", "OpType") - Explicit domain and op_type - 2. @register_custom_op("domain") - Explicit domain, class name as op_type - 3. @register_custom_op - Automatic domain inference, class name as op_type - - Args: - domain: The domain for the custom op (optional) - op_type: The op_type for the custom op (optional) - metadata: Optional dict of metadata about the op (backend, version, etc.) - - Returns: - Decorated class or decorator function + Lookup order: + 1. Direct attribute lookup in module namespace + 2. Legacy custom_op dictionary (backward compatibility) + 3. Search all CustomOp subclasses (fallback) """ - # Determine which mode we're in based on arguments - if domain is not None and isinstance(domain, str): - # Mode 1 or 2: Explicit domain provided - if op_type is not None and isinstance(op_type, str): - # Mode 1: Both domain and op_type provided - def decorator(cls): - return _register_op_with_dependencies(domain, op_type, cls, metadata) - return decorator - else: - # Mode 2: Only domain provided, use class name as op_type - def decorator(cls): - final_op_type = cls.__name__ - return _register_op_with_dependencies(domain, final_op_type, cls, metadata) - return decorator - else: - # Mode 3: No domain provided, or called without arguments - # Handle the case where it's used as @register_custom_op (no parentheses) - if domain is not None and not isinstance(domain, str): - # This means domain is actually the class (decorator without parentheses) - cls = domain - module = cls.__module__ - - # Find domain from registered domains - inferred_domain = None - for registered_domain, module_path in DOMAIN_REGISTRY.items(): - if module_path is None: - module_path = registered_domain - if module.startswith(module_path + "."): - remainder = module[len(module_path) + 1:] - if "." not in remainder: - inferred_domain = registered_domain - break - elif module == module_path: - inferred_domain = registered_domain - break - - if inferred_domain is None: - raise ValueError( - f"Module '{module}' is not in a registered domain path. " - f"Either:\n" - f"1. Use @register_custom_op('domain', 'OpType')\n" - f"2. Use @register_custom_op('domain') to use class name as op_type\n" - f"3. Register the domain: register_domain('your.domain', '{'.'.join(module.split('.')[:-1])}')" - ) - - final_op_type = cls.__name__ - return _register_op_with_dependencies(inferred_domain, final_op_type, cls, metadata) - else: - # Decorator called with parentheses but no domain - def decorator(cls): - module = cls.__module__ - - # Find domain from registered domains - inferred_domain = None - for registered_domain, module_path in DOMAIN_REGISTRY.items(): - if module_path is None: - module_path = registered_domain - if module.startswith(module_path + "."): - remainder = module[len(module_path) + 1:] - if "." not in remainder: - inferred_domain = registered_domain - break - elif module == module_path: - inferred_domain = registered_domain - break - - if inferred_domain is None: - raise ValueError( - f"Module '{module}' is not in a registered domain path. " - f"Either:\n" - f"1. Use @register_custom_op('domain', 'OpType')\n" - f"2. Use @register_custom_op('domain') to use class name as op_type\n" - f"3. Register the domain: register_domain('your.domain', '{'.'.join(module.split('.')[:-1])}')" - ) - - # Use provided op_type or default to class name - final_op_type = op_type or cls.__name__ - return _register_op_with_dependencies(inferred_domain, final_op_type, cls, metadata) - return decorator - - -def getCustomOp(node, onnx_opset_version=get_preferred_onnx_opset(), brevitas_exception=True): - """Return a QONNX CustomOp instance for the given ONNX node.""" op_type = node.op_type domain = node.domain - if brevitas_exception: - # transparently resolve Brevitas domain ops to qonnx ones - domain = domain.replace("onnx.brevitas", "qonnx.custom_op.general") + # Get module path (handles brevitas via DOMAIN_MODULES mapping) + module_path = DOMAIN_MODULES.get(domain, domain) - # Ensure the domain is loaded (will load dependencies automatically) - _ensure_domain_loaded(domain) - - key = (domain, op_type) - cls = CUSTOM_OP_REGISTRY.get(key) - if cls is not None: - return cls(node, onnx_opset_version=onnx_opset_version) - - # If not found and domain starts with finn, try explicit import as fallback - # This handles cases where domain isn't registered but module exists - if domain.startswith("finn.custom_op") and domain not in _LOADED_DOMAINS: - try: - importlib.import_module(domain) - _LOADED_DOMAINS.add(domain) - # Check again after import - cls = CUSTOM_OP_REGISTRY.get(key) - if cls is not None: + try: + # Import the domain module + module = importlib.import_module(module_path) + + # Strategy 1: Direct namespace lookup (preferred) + if hasattr(module, op_type): + obj = getattr(module, op_type) + if inspect.isclass(obj) and issubclass(obj, CustomOp): + return obj(node, onnx_opset_version=onnx_opset_version) + + # Strategy 2: Legacy custom_op dict (backward compatibility) + if hasattr(module, 'custom_op') and isinstance(module.custom_op, dict): + if op_type in module.custom_op: + cls = module.custom_op[op_type] return cls(node, onnx_opset_version=onnx_opset_version) - except ImportError: + + # Strategy 3: Search module for CustomOp subclasses (fallback) + # Useful for debugging and error messages + custom_ops = {} + for name, obj in inspect.getmembers(module): + if (inspect.isclass(obj) and + issubclass(obj, CustomOp) and + obj is not CustomOp and + not name.startswith('_')): # Skip private classes + custom_ops[name] = obj + + # Try case-insensitive match as last resort + for name, cls in custom_ops.items(): + if name.lower() == op_type.lower(): + return cls(node, onnx_opset_version=onnx_opset_version) + + # Not found - provide helpful error + available = list(custom_ops.keys()) + raise KeyError( + f"Op '{op_type}' not found in domain '{domain}' (module: {module_path}). " + f"Available ops: {available}" + ) + + except ModuleNotFoundError: + raise Exception( + f"Could not load module '{module_path}' for domain '{domain}'. " + f"Ensure the module is installed and on your PYTHONPATH." + ) + + +# Legacy functions for backward compatibility +def hasCustomOp(domain, op_type): + """Check if a custom op exists in the domain's module namespace.""" + try: + # Create a dummy node to test + class DummyNode: pass + node = DummyNode() + node.op_type = op_type + node.domain = domain + + # Try to get the op class + module_path = DOMAIN_MODULES.get(domain, domain) + module = importlib.import_module(module_path) + + # Check namespace first + if hasattr(module, op_type): + obj = getattr(module, op_type) + if inspect.isclass(obj) and issubclass(obj, CustomOp): + return True + + # Check legacy dict + if hasattr(module, 'custom_op') and isinstance(module.custom_op, dict): + return op_type in module.custom_op + + return False + except: + return False + + +def get_ops_in_domain(domain): + """Get all ops in a domain by inspecting the module namespace.""" + ops = [] - available_domains = sorted(DOMAIN_REGISTRY.keys()) - raise Exception( - f"Op '{op_type}' not found in domain '{domain}'. " - f"Available domains: {available_domains}" - ) + try: + module_path = DOMAIN_MODULES.get(domain, domain) + module = importlib.import_module(module_path) + + # Check module namespace + for name, obj in inspect.getmembers(module): + if (inspect.isclass(obj) and + issubclass(obj, CustomOp) and + obj is not CustomOp and + not name.startswith('_')): + ops.append((name, obj)) + + # Also check legacy dict if present + if hasattr(module, 'custom_op') and isinstance(module.custom_op, dict): + for name, cls in module.custom_op.items(): + if not any(op[0] == name for op in ops): + ops.append((name, cls)) + + return ops + except: + return [] + + diff --git a/src/qonnx/util/basic.py b/src/qonnx/util/basic.py index 68d691e1..92ba5d2e 100644 --- a/src/qonnx/util/basic.py +++ b/src/qonnx/util/basic.py @@ -33,6 +33,7 @@ import warnings from qonnx.core.datatype import DataType +from qonnx.custom_op.registry import get_ops_in_domain # TODO solve by moving onnx-dependent fxns to onnx.py # finn-examples uses parts of qonnx without having @@ -63,9 +64,18 @@ def qonnx_make_model(graph_proto, **kwargs): def is_finn_op(domain): - "Return whether given domain string is a QONNX or FINN custom op domain" - from qonnx.custom_op.registry import is_custom_op_domain - return is_custom_op_domain(domain) + """Return whether given domain string is a QONNX or FINN custom op domain. + + Validates that: + 1. The domain starts with known custom op prefixes (qonnx., finn., onnx.brevitas) + 2. The domain exists and contains at least one CustomOp + """ + # Check if domain has known custom op prefix + if not domain.startswith(("qonnx.", "finn.", "onnx.brevitas")): + return False + + # Validate that the domain actually exists and has CustomOps + return len(get_ops_in_domain(domain)) > 0 def get_num_default_workers(): diff --git a/tests/custom_op/test_attr.py b/tests/custom_op/test_attr.py index 592a99ae..906e154a 100644 --- a/tests/custom_op/test_attr.py +++ b/tests/custom_op/test_attr.py @@ -31,10 +31,9 @@ from qonnx.core.modelwrapper import ModelWrapper from qonnx.custom_op.base import CustomOp -from qonnx.custom_op.registry import getCustomOp, register_custom_op +from qonnx.custom_op.registry import getCustomOp, add_op_to_domain -@register_custom_op class AttrTestOp(CustomOp): def get_nodeattr_types(self): my_attrs = {"tensor_attr": ("t", True, np.asarray([])), "strings_attr": ("strings", True, [""])} @@ -60,6 +59,9 @@ def verify_node(self): def test_attr(): + # Add the test op to the domain + add_op_to_domain("qonnx.custom_op.general", "AttrTestOp", AttrTestOp) + ishp = (1, 10) wshp = (1, 3) oshp = wshp @@ -86,6 +88,8 @@ def test_attr(): """ model = oprs.parse_model(input) model = ModelWrapper(model) + + # Now getCustomOp should find it through the manual registry inst = getCustomOp(model.graph.node[0]) w_prod = inst.get_nodeattr("tensor_attr") From 68346e3879936b0003d98c224bd56d4bebe267a5 Mon Sep 17 00:00:00 2001 From: Thomas Keller Date: Fri, 18 Jul 2025 04:54:40 +0000 Subject: [PATCH 12/14] Circular import fix --- src/qonnx/util/basic.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/qonnx/util/basic.py b/src/qonnx/util/basic.py index 92ba5d2e..dae5fbd4 100644 --- a/src/qonnx/util/basic.py +++ b/src/qonnx/util/basic.py @@ -33,7 +33,6 @@ import warnings from qonnx.core.datatype import DataType -from qonnx.custom_op.registry import get_ops_in_domain # TODO solve by moving onnx-dependent fxns to onnx.py # finn-examples uses parts of qonnx without having @@ -75,6 +74,8 @@ def is_finn_op(domain): return False # Validate that the domain actually exists and has CustomOps + # Lazy import to avoid circular dependency + from qonnx.custom_op.registry import get_ops_in_domain return len(get_ops_in_domain(domain)) > 0 From 93fd8d004e398371dcdf212b9ccb41dbd048df93 Mon Sep 17 00:00:00 2001 From: Thomas Keller Date: Fri, 18 Jul 2025 17:54:15 +0000 Subject: [PATCH 13/14] Added brainsmith to hide finn ops --- src/qonnx/util/basic.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/qonnx/util/basic.py b/src/qonnx/util/basic.py index dae5fbd4..72fe18c2 100644 --- a/src/qonnx/util/basic.py +++ b/src/qonnx/util/basic.py @@ -63,14 +63,14 @@ def qonnx_make_model(graph_proto, **kwargs): def is_finn_op(domain): - """Return whether given domain string is a QONNX or FINN custom op domain. + """Return whether given domain string is a QONNX, FINN, or Brainsmith custom op domain. Validates that: - 1. The domain starts with known custom op prefixes (qonnx., finn., onnx.brevitas) + 1. The domain starts with known custom op prefixes (qonnx., finn., onnx.brevitas, brainsmith.) 2. The domain exists and contains at least one CustomOp """ # Check if domain has known custom op prefix - if not domain.startswith(("qonnx.", "finn.", "onnx.brevitas")): + if not domain.startswith(("qonnx.", "finn.", "onnx.brevitas", "brainsmith.")): return False # Validate that the domain actually exists and has CustomOps From f2c4ccd3e71795c9f116ee5a0c87a7dfd590c6d0 Mon Sep 17 00:00:00 2001 From: Thomas Keller Date: Sat, 18 Oct 2025 22:55:02 -0700 Subject: [PATCH 14/14] refactor: migrate registry to thread-safe, cache-based architecture Replace namespace-based custom op registration with centralized registry using (domain, op_type) keys. Add thread-safe operations with RLock, lazy discovery, and caching. Deprecate is_finn_op and hasCustomOp in favor of is_custom_op. Simplify add_op_to_domain signature to derive op_type from class name. Update all call sites across codebase. --- docs/overview.rst | 7 +- notebooks/3_custom_op.ipynb | 11 +- src/qonnx/core/modelwrapper.py | 5 +- src/qonnx/core/onnx_exec.py | 4 +- src/qonnx/custom_op/general/__init__.py | 29 +- src/qonnx/custom_op/registry.py | 318 +++++++++++------- .../transformation/infer_data_layouts.py | 7 +- src/qonnx/transformation/infer_datatypes.py | 5 +- src/qonnx/transformation/infer_shapes.py | 6 +- src/qonnx/util/basic.py | 29 +- tests/custom_op/test_attr.py | 2 +- tests/transformation/test_channelslast.py | 4 +- 12 files changed, 249 insertions(+), 178 deletions(-) diff --git a/docs/overview.rst b/docs/overview.rst index 2f0f3577..161d1e49 100644 --- a/docs/overview.rst +++ b/docs/overview.rst @@ -50,13 +50,14 @@ Simply define your CustomOp subclass in the appropriate domain module (e.g., ``qonnx.custom_op.general`` for general ops) and it will be automatically available through ``getCustomOp``. -For dynamic registration (e.g., in tests), use the registry functions: +For dynamic registration and querying, use the registry functions: * ``getCustomOp(node)`` - Get a custom op instance from an ONNX node -* ``add_op_to_domain(domain, op_type, op_class)`` - Add an op to a domain's namespace +* ``is_custom_op(domain, op_type=None)`` - Check if a specific op or domain has custom ops +* ``add_op_to_domain(domain, op_class)`` - Register an op at runtime (for testing) +* ``get_ops_in_domain(domain)`` - List all ops available in a domain * ``add_domain_alias(domain, module_path)`` - Map a domain to a different module path * ``hasCustomOp(domain, op_type)`` - Check if an op exists in a domain -* ``get_ops_in_domain(domain)`` - List all ops available in a domain Custom ONNX Execution Flow diff --git a/notebooks/3_custom_op.ipynb b/notebooks/3_custom_op.ipynb index cd01686c..1b822163 100644 --- a/notebooks/3_custom_op.ipynb +++ b/notebooks/3_custom_op.ipynb @@ -129,13 +129,14 @@ { "cell_type": "markdown", "metadata": {}, - "source": "To make sure our custom op is available, we need to add it to the domain's namespace. For production code, you would place your CustomOp class directly in the appropriate module (e.g., in a file under `qonnx/custom_op/general/`). For testing and experimentation like in this notebook, we can use the `add_op_to_domain` function:" + "source": "To make sure our custom op is available, we need to add it to the domain. For production code, you would place your CustomOp class directly in the appropriate module file (e.g., in a file under `qonnx/custom_op/general/`). For testing and experimentation like in this notebook, we can use the `add_op_to_domain` function:" }, { "cell_type": "code", "metadata": {}, "outputs": [], - "source": "from qonnx.custom_op.registry import add_op_to_domain\n\n# Add our custom op to the general domain namespace\nadd_op_to_domain(\"qonnx.custom_op.general\", \"MyPythonPowerOp\", MyPythonPowerOp)" + "source": "from qonnx.custom_op.registry import add_op_to_domain\n\n# Add our custom op to the general domain\nadd_op_to_domain(\"qonnx.custom_op.general\", MyPythonPowerOp)", + "execution_count": null }, { "cell_type": "markdown", @@ -146,7 +147,8 @@ "cell_type": "code", "metadata": {}, "outputs": [], - "source": "from qonnx.custom_op.registry import get_ops_in_domain\n\n# See all ops in the general domain\nops = get_ops_in_domain(\"qonnx.custom_op.general\")\nprint(f\"Available ops: {[op[0] for op in ops]}\")\n\n# Check if our op is there\nfrom qonnx.custom_op.registry import hasCustomOp\nprint(f\"MyPythonPowerOp available: {hasCustomOp('qonnx.custom_op.general', 'MyPythonPowerOp')}\")" + "source": "from qonnx.custom_op.registry import get_ops_in_domain, is_custom_op\n\n# See all ops in the general domain\nops = get_ops_in_domain(\"qonnx.custom_op.general\")\nprint(f\"Available ops: {[op[0] for op in ops]}\")\n\n# Check if our op is there\nprint(f\"MyPythonPowerOp available: {is_custom_op('qonnx.custom_op.general', 'MyPythonPowerOp')}\")", + "execution_count": null }, { "cell_type": "markdown", @@ -453,7 +455,8 @@ "cell_type": "code", "metadata": {}, "outputs": [], - "source": "# register our new op\nadd_op_to_domain(\"qonnx.custom_op.general\", \"MyMixedPowerOp\", MyMixedPowerOp)\n\n# make graph with new op\nmixedop_graph = make_graph(input_shape, 2, op_type = \"MyMixedPowerOp\")\nmixedop_graph.graph.node" + "source": "# register our new op\nadd_op_to_domain(\"qonnx.custom_op.general\", MyMixedPowerOp)\n\n# make graph with new op\nmixedop_graph = make_graph(input_shape, 2, op_type = \"MyMixedPowerOp\")\nmixedop_graph.graph.node", + "execution_count": null }, { "cell_type": "markdown", diff --git a/src/qonnx/core/modelwrapper.py b/src/qonnx/core/modelwrapper.py index 8890e24a..de289ef3 100644 --- a/src/qonnx/core/modelwrapper.py +++ b/src/qonnx/core/modelwrapper.py @@ -38,6 +38,7 @@ import qonnx.util.basic as util import qonnx.util.onnx as onnxutil from qonnx.core.datatype import DataType +from qonnx.custom_op.registry import is_custom_op from qonnx.transformation.double_to_single_float import DoubleToSingleFloat from qonnx.transformation.general import ( RemoveStaticGraphInputs, @@ -624,11 +625,11 @@ def get_nodes_by_op_type(self, op_type): def get_finn_nodes(self): """Returns a list of nodes where domain == 'qonnx.*'.""" - return list(filter(lambda x: util.is_finn_op(x.domain), self.graph.node)) + return list(filter(lambda x: is_custom_op(x.domain), self.graph.node)) def get_non_finn_nodes(self): """Returns a list of nodes where domain != 'qonnx.*'.""" - return list(filter(lambda x: not util.is_finn_op(x.domain), self.graph.node)) + return list(filter(lambda x: not is_custom_op(x.domain), self.graph.node)) def get_node_index(self, node): """Returns current index of given node, or None if not found.""" diff --git a/src/qonnx/core/onnx_exec.py b/src/qonnx/core/onnx_exec.py index a8f4774c..3a686f7e 100644 --- a/src/qonnx/core/onnx_exec.py +++ b/src/qonnx/core/onnx_exec.py @@ -35,10 +35,10 @@ import qonnx.analysis.topology as ta import qonnx.core.execute_custom_node as ex_cu_node +from qonnx.custom_op.registry import is_custom_op from qonnx.util.basic import ( get_preferred_onnx_opset, get_sanitize_quant_tensors, - is_finn_op, qonnx_make_model, sanitize_quant_values, ) @@ -49,7 +49,7 @@ def execute_node(node, context, graph, return_full_exec_context=False, opset_ver Input/output provided via context.""" - if is_finn_op(node.domain): + if is_custom_op(node.domain, node.op_type): ex_cu_node.execute_custom_node(node, context, graph, onnx_opset_version=opset_version) else: # onnxruntime unfortunately does not implement run_node as defined by ONNX, diff --git a/src/qonnx/custom_op/general/__init__.py b/src/qonnx/custom_op/general/__init__.py index c2bb7a82..e859d860 100644 --- a/src/qonnx/custom_op/general/__init__.py +++ b/src/qonnx/custom_op/general/__init__.py @@ -39,17 +39,18 @@ from qonnx.custom_op.general.trunc import Trunc from qonnx.custom_op.general.xnorpopcount import XnorPopcountMatMul -__all__ = [ - "DebugMarker", - "QuantAvgPool2d", - "MaxPoolNHWC", - "GenericPartition", - "MultiThreshold", - "XnorPopcountMatMul", - "Im2Col", - "IntQuant", - "Quant", - "Trunc", - "BipolarQuant", - "FloatQuant", -] +# Legacy dictionary for backward compatibility +custom_op = { + "DebugMarker": DebugMarker, + "QuantAvgPool2d": QuantAvgPool2d, + "MaxPoolNHWC": MaxPoolNHWC, + "GenericPartition": GenericPartition, + "MultiThreshold": MultiThreshold, + "XnorPopcountMatMul": XnorPopcountMatMul, + "Im2Col": Im2Col, + "IntQuant": IntQuant, + "Quant": IntQuant, # Alias + "Trunc": Trunc, + "BipolarQuant": BipolarQuant, + "FloatQuant": FloatQuant, +} diff --git a/src/qonnx/custom_op/registry.py b/src/qonnx/custom_op/registry.py index c7d964f6..8eb1b378 100644 --- a/src/qonnx/custom_op/registry.py +++ b/src/qonnx/custom_op/registry.py @@ -28,171 +28,233 @@ import importlib import inspect -from typing import Dict +from threading import RLock +from typing import Dict, List, Optional, Tuple, Type from qonnx.custom_op.base import CustomOp from qonnx.util.basic import get_preferred_onnx_opset -# Domain to module path mapping (only when different) -DOMAIN_MODULES: Dict[str, str] = { - "onnx.brevitas": "qonnx.custom_op.general", # Built-in compatibility +# Registry keyed by original ONNX domain: (domain, op_type) -> CustomOp class +_OP_REGISTRY: Dict[Tuple[str, str], Type[CustomOp]] = {} + +_REGISTRY_LOCK = RLock() + +# Maps ONNX domain names to Python module paths (used for imports only) +_DOMAIN_ALIASES: Dict[str, str] = { + "onnx.brevitas": "qonnx.custom_op.general", } def add_domain_alias(domain: str, module_path: str) -> None: """Map a domain name to a different module path. - + + Args: + domain: The ONNX domain name (e.g., "finn.custom_op.fpgadataflow") + module_path: The Python module path to use instead (e.g., "finn_custom_ops.fpgadataflow") + """ + with _REGISTRY_LOCK: + _DOMAIN_ALIASES[domain] = module_path + + +def resolve_domain(domain: str) -> str: + """Resolve a domain to its actual module path, handling aliases. + Args: domain: The ONNX domain name - module_path: The Python module path to use instead - - Example: - add_domain_alias("finn.custom_op.fpgadataflow", "finn_custom_ops.fpgadataflow") + + Returns: + Resolved module path """ - DOMAIN_MODULES[domain] = module_path + return _DOMAIN_ALIASES.get(domain, domain) + + +def add_op_to_domain(domain: str, op_class: Type[CustomOp]) -> None: + """Register a custom op directly to a domain at runtime. + The op_type is automatically derived from the class name. + Useful for testing and experimentation. For production, define CustomOps + in the appropriate module file. -def add_op_to_domain(domain: str, op_type: str, op_class: type) -> None: - """Add a custom op directly to a domain's module namespace. - - This function dynamically adds custom ops to module namespaces at runtime. - Useful for test cases or dynamic op registration. - Args: - domain: The ONNX domain name (e.g., "qonnx.custom_op.general") - op_type: The operation type name (e.g., "MyCustomOp") - op_class: The CustomOp subclass to add - + domain: ONNX domain name (e.g., "qonnx.custom_op.general") + op_class: CustomOp subclass + Example: - add_op_to_domain("qonnx.custom_op.general", "TestOp", TestOp) + add_op_to_domain("qonnx.custom_op.general", MyTestOp) """ if not inspect.isclass(op_class) or not issubclass(op_class, CustomOp): raise ValueError(f"{op_class} must be a subclass of CustomOp") - - # Get the actual module path - module_path = DOMAIN_MODULES.get(domain, domain) - + + op_type = op_class.__name__ + + with _REGISTRY_LOCK: + _OP_REGISTRY[(domain, op_type)] = op_class + + +def _discover_custom_op(domain: str, op_type: str) -> bool: + """Discover and register a single custom op. + + Args: + domain: The ONNX domain name + op_type: The specific op type to discover + + Returns: + True if op was found and registered, False otherwise + """ + module_path = resolve_domain(domain) + try: - # Import the module and add the op to its namespace module = importlib.import_module(module_path) - setattr(module, op_type, op_class) except ModuleNotFoundError: - raise ValueError(f"Could not find module for domain '{domain}' (tried: {module_path})") + return False + + # Try namespace lookup + op_class = getattr(module, op_type, None) + if inspect.isclass(op_class) and issubclass(op_class, CustomOp): + _OP_REGISTRY[(domain, op_type)] = op_class + return True + + # Try legacy dict + custom_op_dict = getattr(module, 'custom_op', None) + if isinstance(custom_op_dict, dict): + op_class = custom_op_dict.get(op_type) + if inspect.isclass(op_class) and issubclass(op_class, CustomOp): + _OP_REGISTRY[(domain, op_type)] = op_class + return True + + return False def getCustomOp(node, onnx_opset_version=get_preferred_onnx_opset()): """Get a custom op instance for an ONNX node. - - Lookup order: - 1. Direct attribute lookup in module namespace - 2. Legacy custom_op dictionary (backward compatibility) - 3. Search all CustomOp subclasses (fallback) + + Args: + node: ONNX node with domain and op_type attributes + onnx_opset_version: ONNX opset version to use + + Returns: + CustomOp instance for the node + + Raises: + KeyError: If op_type not found in domain """ op_type = node.op_type domain = node.domain - - # Get module path (handles brevitas via DOMAIN_MODULES mapping) - module_path = DOMAIN_MODULES.get(domain, domain) - - try: - # Import the domain module - module = importlib.import_module(module_path) - - # Strategy 1: Direct namespace lookup (preferred) - if hasattr(module, op_type): - obj = getattr(module, op_type) - if inspect.isclass(obj) and issubclass(obj, CustomOp): - return obj(node, onnx_opset_version=onnx_opset_version) - - # Strategy 2: Legacy custom_op dict (backward compatibility) - if hasattr(module, 'custom_op') and isinstance(module.custom_op, dict): - if op_type in module.custom_op: - cls = module.custom_op[op_type] - return cls(node, onnx_opset_version=onnx_opset_version) - - # Strategy 3: Search module for CustomOp subclasses (fallback) - # Useful for debugging and error messages - custom_ops = {} - for name, obj in inspect.getmembers(module): - if (inspect.isclass(obj) and - issubclass(obj, CustomOp) and - obj is not CustomOp and - not name.startswith('_')): # Skip private classes - custom_ops[name] = obj - - # Try case-insensitive match as last resort - for name, cls in custom_ops.items(): - if name.lower() == op_type.lower(): - return cls(node, onnx_opset_version=onnx_opset_version) - - # Not found - provide helpful error - available = list(custom_ops.keys()) + key = (domain, op_type) + + with _REGISTRY_LOCK: + if key in _OP_REGISTRY: + return _OP_REGISTRY[key](node, onnx_opset_version=onnx_opset_version) + + if _discover_custom_op(domain, op_type): + return _OP_REGISTRY[key](node, onnx_opset_version=onnx_opset_version) + + module_path = resolve_domain(domain) raise KeyError( f"Op '{op_type}' not found in domain '{domain}' (module: {module_path}). " - f"Available ops: {available}" - ) - - except ModuleNotFoundError: - raise Exception( - f"Could not load module '{module_path}' for domain '{domain}'. " - f"Ensure the module is installed and on your PYTHONPATH." + f"Ensure it's exported in the module namespace or in the custom_op dict." ) -# Legacy functions for backward compatibility -def hasCustomOp(domain, op_type): - """Check if a custom op exists in the domain's module namespace.""" - try: - # Create a dummy node to test - class DummyNode: - pass - node = DummyNode() - node.op_type = op_type - node.domain = domain - - # Try to get the op class - module_path = DOMAIN_MODULES.get(domain, domain) - module = importlib.import_module(module_path) - - # Check namespace first - if hasattr(module, op_type): - obj = getattr(module, op_type) - if inspect.isclass(obj) and issubclass(obj, CustomOp): - return True - - # Check legacy dict - if hasattr(module, 'custom_op') and isinstance(module.custom_op, dict): - return op_type in module.custom_op - - return False - except: +def is_custom_op(domain: str, op_type: Optional[str] = None) -> bool: + """Check if a custom op exists or if a domain has any custom ops. + + Args: + domain: The ONNX domain name + op_type: Optional operation type name. If None, checks if domain has any ops. + + Returns: + True if the specific op exists (when op_type given) or + if any ops exist for the domain (when op_type=None), False otherwise + """ + # Empty domain means standard ONNX op + if not domain: return False + with _REGISTRY_LOCK: + if op_type is not None: + # Check for specific op + key = (domain, op_type) + if key in _OP_REGISTRY: + return True + return _discover_custom_op(domain, op_type) + else: + # Check if domain has any registered ops + if any(d == domain for d, _ in _OP_REGISTRY.keys()): + return True + # Try to import the domain module as fallback + module_path = resolve_domain(domain) + try: + importlib.import_module(module_path) + return True + except (ModuleNotFoundError, ValueError): + return False + + +def hasCustomOp(domain: str, op_type: str) -> bool: + """Deprecated: Use is_custom_op instead. + + Check if a custom op exists. -def get_ops_in_domain(domain): - """Get all ops in a domain by inspecting the module namespace.""" + Args: + domain: The ONNX domain name + op_type: The operation type name + + Returns: + True if the op exists, False otherwise + """ + import warnings + warnings.warn( + "hasCustomOp is deprecated and will be removed in QONNX v1.0. " + "Use is_custom_op instead.", + DeprecationWarning, + stacklevel=2 + ) + return is_custom_op(domain, op_type) + + +def get_ops_in_domain(domain: str) -> List[Tuple[str, Type[CustomOp]]]: + """Get all CustomOp classes available in a domain. + + Args: + domain: ONNX domain name (e.g., "qonnx.custom_op.general") + + Returns: + List of (op_type, op_class) tuples + + Example: + ops = get_ops_in_domain("qonnx.custom_op.general") + for op_name, op_class in ops: + print(f"{op_name}: {op_class}") + """ ops = [] - - try: - module_path = DOMAIN_MODULES.get(domain, domain) - module = importlib.import_module(module_path) - - # Check module namespace - for name, obj in inspect.getmembers(module): - if (inspect.isclass(obj) and - issubclass(obj, CustomOp) and - obj is not CustomOp and - not name.startswith('_')): - ops.append((name, obj)) - - # Also check legacy dict if present - if hasattr(module, 'custom_op') and isinstance(module.custom_op, dict): - for name, cls in module.custom_op.items(): - if not any(op[0] == name for op in ops): - ops.append((name, cls)) - - return ops - except: - return [] + module_path = resolve_domain(domain) + + with _REGISTRY_LOCK: + # Strategy 1: Get cached ops (fast path) + for (d, op_type), op_class in _OP_REGISTRY.items(): + if d == domain: + ops.append((op_type, op_class)) + + # Strategy 2: Discover from module (for uncached ops) + try: + module = importlib.import_module(module_path) + + # Check namespace exports + for name, obj in inspect.getmembers(module): + if (inspect.isclass(obj) and + issubclass(obj, CustomOp) and + obj is not CustomOp and + not name.startswith('_') and + not any(op[0] == name for op in ops)): + ops.append((name, obj)) + # Check legacy custom_op dict + if hasattr(module, 'custom_op') and isinstance(module.custom_op, dict): + for name, cls in module.custom_op.items(): + if not any(op[0] == name for op in ops): + ops.append((name, cls)) + except ModuleNotFoundError: + pass # Domain doesn't exist as module, return cached ops only + return ops diff --git a/src/qonnx/transformation/infer_data_layouts.py b/src/qonnx/transformation/infer_data_layouts.py index 81143e45..2e23d771 100644 --- a/src/qonnx/transformation/infer_data_layouts.py +++ b/src/qonnx/transformation/infer_data_layouts.py @@ -30,15 +30,16 @@ import qonnx.core.data_layout as DataLayout import qonnx.custom_op.registry as registry +from qonnx.custom_op.registry import is_custom_op from qonnx.transformation.base import Transformation -from qonnx.util.basic import get_by_name, is_finn_op +from qonnx.util.basic import get_by_name def _dims_to_layout(model, node, ndims): if ndims == 2: return DataLayout.NC else: - if is_finn_op(node.domain): + if is_custom_op(node.domain): if node.op_type == "MultiThreshold" or node.op_type == "QuantAvgPool2d": mt_inst = registry.getCustomOp(node) layout = mt_inst.get_nodeattr("data_layout") @@ -72,7 +73,7 @@ def _infer_node_data_layout(model, node): Returns True if any changes were made.""" old_layouts = list(map(lambda x: model.get_tensor_layout(x), node.output)) try: - if is_finn_op(node.domain): + if is_custom_op(node.domain): # try to guess based on number of output dims for o in node.output: ndims = len(model.get_tensor_shape(o)) diff --git a/src/qonnx/transformation/infer_datatypes.py b/src/qonnx/transformation/infer_datatypes.py index d54fd34f..167e0c3e 100644 --- a/src/qonnx/transformation/infer_datatypes.py +++ b/src/qonnx/transformation/infer_datatypes.py @@ -28,9 +28,10 @@ import qonnx.custom_op.registry as registry from qonnx.core.datatype import DataType, ScaledIntType +from qonnx.custom_op.registry import is_custom_op from qonnx.transformation.base import Transformation from qonnx.transformation.qcdq_to_qonnx import extract_elem_type -from qonnx.util.basic import get_by_name, is_finn_op +from qonnx.util.basic import get_by_name def is_scaled_int(x): @@ -82,7 +83,7 @@ def _infer_node_datatype(model, node, allow_scaledint_dtypes): idtypes = list(map(lambda x: model.get_tensor_datatype(x), node.input)) odtypes = list(map(lambda x: model.get_tensor_datatype(x), node.output)) op_type = node.op_type - if is_finn_op(node.domain): + if is_custom_op(node.domain): # handle DataType inference for CustomOp try: # lookup op_type in registry of CustomOps diff --git a/src/qonnx/transformation/infer_shapes.py b/src/qonnx/transformation/infer_shapes.py index 87fbf0ee..3e532abf 100644 --- a/src/qonnx/transformation/infer_shapes.py +++ b/src/qonnx/transformation/infer_shapes.py @@ -30,14 +30,14 @@ import qonnx.custom_op.registry as registry from qonnx.core.modelwrapper import ModelWrapper +from qonnx.custom_op.registry import is_custom_op from qonnx.transformation.base import Transformation -from qonnx.util.basic import is_finn_op def _make_shape_compatible_op(node, model): """Return a shape-compatible non-QONNX op for a given QONNX op. Used for shape inference with custom ops.""" - assert is_finn_op(node.domain), "Node domain is not set to qonnx.*" + assert is_custom_op(node.domain), "Node domain is not a registered custom op domain" op_type = node.op_type try: # lookup op_type in registry of CustomOps @@ -56,7 +56,7 @@ def _hide_finn_ops(model): node_ind = 0 for node in model.graph.node: node_ind += 1 - if is_finn_op(node.domain): + if is_custom_op(node.domain): new_node = _make_shape_compatible_op(node, model) # keep old node name to help debug shape inference issues new_node.name = node.name diff --git a/src/qonnx/util/basic.py b/src/qonnx/util/basic.py index 72fe18c2..3253d873 100644 --- a/src/qonnx/util/basic.py +++ b/src/qonnx/util/basic.py @@ -62,21 +62,22 @@ def qonnx_make_model(graph_proto, **kwargs): return make_model(graph_proto, **kwargs) -def is_finn_op(domain): - """Return whether given domain string is a QONNX, FINN, or Brainsmith custom op domain. - - Validates that: - 1. The domain starts with known custom op prefixes (qonnx., finn., onnx.brevitas, brainsmith.) - 2. The domain exists and contains at least one CustomOp +def is_finn_op(op_type): + """Deprecated: Use is_custom_op from qonnx.custom_op.registry instead. + + Return whether given op_type string is a QONNX or FINN custom op. + This function uses hard-coded string matching and will be removed in QONNX v1.0. + Use the registry-based is_custom_op for better accuracy and extensibility. """ - # Check if domain has known custom op prefix - if not domain.startswith(("qonnx.", "finn.", "onnx.brevitas", "brainsmith.")): - return False - - # Validate that the domain actually exists and has CustomOps - # Lazy import to avoid circular dependency - from qonnx.custom_op.registry import get_ops_in_domain - return len(get_ops_in_domain(domain)) > 0 + import warnings + warnings.warn( + "is_finn_op is deprecated and will be removed in QONNX v1.0. " + "Use 'from qonnx.custom_op.registry import is_custom_op' instead.", + DeprecationWarning, + stacklevel=2 + ) + from qonnx.custom_op.registry import is_custom_op + return is_custom_op(op_type) def get_num_default_workers(): diff --git a/tests/custom_op/test_attr.py b/tests/custom_op/test_attr.py index 906e154a..ac4f7a5c 100644 --- a/tests/custom_op/test_attr.py +++ b/tests/custom_op/test_attr.py @@ -60,7 +60,7 @@ def verify_node(self): def test_attr(): # Add the test op to the domain - add_op_to_domain("qonnx.custom_op.general", "AttrTestOp", AttrTestOp) + add_op_to_domain("qonnx.custom_op.general", AttrTestOp) ishp = (1, 10) wshp = (1, 3) diff --git a/tests/transformation/test_channelslast.py b/tests/transformation/test_channelslast.py index 24e64b4f..30382c64 100644 --- a/tests/transformation/test_channelslast.py +++ b/tests/transformation/test_channelslast.py @@ -43,11 +43,11 @@ MoveTransposePastFork, RemoveConsecutiveChanFirstAndChanLastTrafos, ) +from qonnx.custom_op.registry import is_custom_op from qonnx.transformation.general import GiveUniqueNodeNames from qonnx.transformation.infer_shapes import InferShapes from qonnx.transformation.make_input_chanlast import MakeInputChannelsLast from qonnx.transformation.quant_constant_folding import FoldTransposeIntoQuantInit -from qonnx.util.basic import is_finn_op from qonnx.util.test import download_model, get_golden_in_and_output, test_model_details from qonnx.util.to_channels_last import to_channels_last @@ -126,7 +126,7 @@ def analysis_test_for_left_transposes(model, test_model, make_input_channels_las def verify_all_nodes(model): result = dict() for n in model.graph.node: - if is_finn_op(n.domain): + if is_custom_op(n.domain): n_instance = getCustomOp(n) verify_result = n_instance.verify_node() result[n.name] = verify_result