Skip to content

Commit 0e79b62

Browse files
[Rewriter] Add fuse batchnorm to default rules (#2553)
This PR adds `fuse_batchnorm` rules to default rules. --------- Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 647b22a commit 0e79b62

File tree

2 files changed

+6
-17
lines changed

2 files changed

+6
-17
lines changed

onnxscript/rewriter/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
_broadcast_to_matmul,
3636
_cast_constant_of_shape,
3737
_collapse_slices,
38+
_fuse_batchnorm,
3839
_fuse_pad_into_conv,
3940
_fuse_relus_clips,
4041
_min_max_to_clip,
@@ -53,6 +54,7 @@
5354
*_basic_rules.basic_optimization_rules(),
5455
*_redundant_scatter_nd.rules,
5556
*_fuse_pad_into_conv.rules,
57+
*_fuse_batchnorm.rules,
5658
)
5759

5860

onnxscript/rewriter/rules/common/_fuse_batchnorm.py

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"""
1616

1717
from abc import ABC, abstractmethod
18-
from typing import Mapping
18+
from typing import ClassVar, Mapping
1919

2020
import numpy as np
2121

@@ -33,16 +33,6 @@ def _reshape_for_broadcast(x: np.ndarray, rank: int, axis: int = 1) -> np.ndarra
3333
class _FuseBatchNormBase(RewriteRuleClassBase, ABC):
3434
"""Interface for BatchNormalization nodes fusion."""
3535

36-
def __init__(
37-
self,
38-
op_type: str,
39-
name: str | None = None,
40-
remove_nodes: bool = True,
41-
as_function: bool = False,
42-
) -> None:
43-
super().__init__(name=name, remove_nodes=remove_nodes, as_function=as_function)
44-
self.op_type = op_type
45-
4636
@abstractmethod
4737
def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
4838
"""Return the axis along which BatchNorm scale should be broadcasted."""
@@ -116,8 +106,7 @@ def check(self, context, x, inbound_out: ir.Value, batchnorm_out: ir.Value) -> M
116106
class FuseBatchNormIntoConv(_FuseBatchNormBase):
117107
"""Replaces ``BatchNormalization(Conv(x))`` with ``Conv(x)``."""
118108

119-
def __init__(self):
120-
super().__init__("Conv")
109+
op_type: ClassVar = "Conv"
121110

122111
def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
123112
return 0
@@ -133,8 +122,7 @@ def pattern(self, op, x):
133122
class FuseBatchNormIntoConvTranspose(_FuseBatchNormBase):
134123
"""Replaces ``BatchNormalization(ConvTranspose(x))`` with ``ConvTranspose(x)``."""
135124

136-
def __init__(self):
137-
super().__init__("ConvTranspose")
125+
op_type: ClassVar = "ConvTranspose"
138126

139127
def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
140128
return 1
@@ -150,8 +138,7 @@ def pattern(self, op, x):
150138
class FuseBatchNormIntoGemm(_FuseBatchNormBase):
151139
"""Replaces ``BatchNormalization(Gemm(x))`` with ``Gemm(x)``."""
152140

153-
def __init__(self):
154-
super().__init__("Gemm")
141+
op_type: ClassVar = "Gemm"
155142

156143
def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
157144
return (

0 commit comments

Comments
 (0)