1515"""
1616
1717from abc import ABC , abstractmethod
18- from typing import Mapping
18+ from typing import ClassVar , Mapping
1919
2020import numpy as np
2121
@@ -33,16 +33,6 @@ def _reshape_for_broadcast(x: np.ndarray, rank: int, axis: int = 1) -> np.ndarra
3333class _FuseBatchNormBase (RewriteRuleClassBase , ABC ):
3434 """Interface for BatchNormalization nodes fusion."""
3535
36- def __init__ (
37- self ,
38- op_type : str ,
39- name : str | None = None ,
40- remove_nodes : bool = True ,
41- as_function : bool = False ,
42- ) -> None :
43- super ().__init__ (name = name , remove_nodes = remove_nodes , as_function = as_function )
44- self .op_type = op_type
45-
4636 @abstractmethod
4737 def get_filters_axis (self , attributes : Mapping [str , ir .Attr ]) -> int :
4838 """Return the axis along which BatchNorm scale should be broadcasted."""
@@ -116,8 +106,7 @@ def check(self, context, x, inbound_out: ir.Value, batchnorm_out: ir.Value) -> M
116106class FuseBatchNormIntoConv (_FuseBatchNormBase ):
117107 """Replaces ``BatchNormalization(Conv(x))`` with ``Conv(x)``."""
118108
119- def __init__ (self ):
120- super ().__init__ ("Conv" )
109+ op_type : ClassVar = "Conv"
121110
122111 def get_filters_axis (self , attributes : Mapping [str , ir .Attr ]) -> int :
123112 return 0
@@ -133,8 +122,7 @@ def pattern(self, op, x):
133122class FuseBatchNormIntoConvTranspose (_FuseBatchNormBase ):
134123 """Replaces ``BatchNormalization(ConvTranspose(x))`` with ``ConvTranspose(x)``."""
135124
136- def __init__ (self ):
137- super ().__init__ ("ConvTranspose" )
125+ op_type : ClassVar = "ConvTranspose"
138126
139127 def get_filters_axis (self , attributes : Mapping [str , ir .Attr ]) -> int :
140128 return 1
@@ -150,8 +138,7 @@ def pattern(self, op, x):
150138class FuseBatchNormIntoGemm (_FuseBatchNormBase ):
151139 """Replaces ``BatchNormalization(Gemm(x))`` with ``Gemm(x)``."""
152140
153- def __init__ (self ):
154- super ().__init__ ("Gemm" )
141+ op_type : ClassVar = "Gemm"
155142
156143 def get_filters_axis (self , attributes : Mapping [str , ir .Attr ]) -> int :
157144 return (
0 commit comments