Commit 8845fb2

Avoid initializer name collision in _fuse_batchnorm.py (#2680)
Fixes pytorch/pytorch#166797. The original naming collides when there are multiple matched patterns sharing the same parent node. This PR changes the naming to depend on each pattern's own Conv weight name, which should be a non-duplicated identifier.

~~NOTE: I don't know if my understanding is correct. It seems `x` is an input of the pattern, so `x.name + "_bias"` collides with the `max_pool` bias (see the picture in the original issue)? If we check the output model after _fuse_batchnorm.py, the bias is correct but has the name `val_17` (the name may have collided and been reassigned by NameAuthority?). However, when the following rule _remove_optional_bias tries to fetch the bias, it sees all zeros for some reason.~~
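To illustrate the collision: when two Conv+BatchNormalization matches share the same parent value, both matches see the same pattern input `x`, so the old scheme derives the same bias name twice. A minimal sketch, using the value names from the test added below (the literal names are illustrative):

```python
# Two fusion matches share the parent value X1 (the MaxPool output):
#   X2 = Conv(X1, W1); Y1 = BatchNormalization(X2, ...)
#   X3 = Conv(X1, W2); Y2 = BatchNormalization(X3, ...)

x_name = "X1"                  # shared pattern input for both matches

# Old scheme: both fused biases get the same initializer name.
old_bias_1 = x_name + "_bias"  # "X1_bias"
old_bias_2 = x_name + "_bias"  # "X1_bias" -> collision on initializer creation

# New scheme: derive the name from each Conv's own weight input,
# which is unique per Conv.
new_bias_1 = "W1" + "_bias"    # "W1_bias"
new_bias_2 = "W2" + "_bias"    # "W2_bias" -> no collision
```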
1 parent: d80575d

File tree

2 files changed: +62 −1 lines


onnxscript/rewriter/rules/common/_fuse_batchnorm.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -68,7 +68,10 @@ def rewrite(self, op, x: ir.Value, inbound_out: ir.Value, batchnorm_out: ir.Valu
             bias_name = inbound_node.inputs[2].name
         else:
             original_bias = np.zeros_like(input_mean)
-            bias_name = x.name + "_bias"
+            # Use inbound input 1 (should be weight) to derive a name for the bias
+            # to avoid name collision on initializer creation when there are multiple patterns
+            # sharing the same parent nodes.
+            bias_name = inbound_node.inputs[1].name + "_bias"
         fused_bias = ir.tensor((original_bias - input_mean) * scale_factor + beta)
 
         return op.op(
```
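For context, the branch being patched handles a Conv without a bias: the rewrite synthesizes a zero bias and folds the BatchNormalization statistics into it. A minimal numeric sketch of that folding, assuming the standard BatchNormalization definition (`scale_factor = gamma / sqrt(var + eps)` is computed earlier in the rewrite):

```python
import numpy as np

C = 64  # output channels (illustrative)
gamma, beta = np.random.randn(C), np.random.randn(C)
mean, var = np.random.randn(C), np.abs(np.random.randn(C))
eps = 1e-5

# BatchNorm at inference: y = gamma * (conv_out - mean) / sqrt(var + eps) + beta
scale_factor = gamma / np.sqrt(var + eps)

# With no original Conv bias (the `else` branch above), start from zeros,
# matching the fused_bias line in the diff.
original_bias = np.zeros_like(mean)
fused_bias = (original_bias - mean) * scale_factor + beta
```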

onnxscript/rewriter/rules/common/_fuse_batchnorm_test.py

Lines changed: 58 additions & 0 deletions
```diff
@@ -253,6 +253,64 @@ def test_fuse_batchnorm_graph_inputs(self):
         # No changes were applied as W is a graph input
         self.assertEqual(count, 0)
 
+    def test_fuse_batchnorm_does_not_collide_names_with_same_parent_node(self):
+        model_proto = onnx.parser.parse_model("""
+            < ir_version: 7, opset_import: ["" : 17] >
+            test_model (float[N, 32, 14, 16] X) => (float [N, ?, ?, ?] Y1, float [N, ?, ?, ?] Y2)
+            {
+                X1 = MaxPool<kernel_shape=[3,3]>(X)
+                X2 = Conv(X1, W1)
+                Y1 = BatchNormalization(X2, gamma_64, beta_64, input_mean_64, input_var_64)
+                X3 = Conv(X1, W2)
+                Y2 = BatchNormalization(X3, gamma_256, beta_256, input_mean_256, input_var_256)
+            }
+        """)
+        initializers = [
+            onnx.numpy_helper.from_array(
+                np.random.randn(64, 32, 3, 3).astype(np.float32), name="W1"
+            ),
+            onnx.numpy_helper.from_array(
+                np.random.randn(64).astype(np.float32), name="gamma_64"
+            ),
+            onnx.numpy_helper.from_array(
+                np.random.randn(64).astype(np.float32), name="beta_64"
+            ),
+            onnx.numpy_helper.from_array(
+                np.random.randn(64).astype(np.float32), name="input_mean_64"
+            ),
+            onnx.numpy_helper.from_array(
+                np.abs(np.random.randn(64)).astype(np.float32), name="input_var_64"
+            ),
+            onnx.numpy_helper.from_array(
+                np.random.randn(256, 32, 3, 3).astype(np.float32), name="W2"
+            ),
+            onnx.numpy_helper.from_array(
+                np.random.randn(256).astype(np.float32), name="gamma_256"
+            ),
+            onnx.numpy_helper.from_array(
+                np.random.randn(256).astype(np.float32), name="beta_256"
+            ),
+            onnx.numpy_helper.from_array(
+                np.random.randn(256).astype(np.float32), name="input_mean_256"
+            ),
+            onnx.numpy_helper.from_array(
+                np.abs(np.random.randn(256)).astype(np.float32), name="input_var_256"
+            ),
+        ]
+        model_proto.graph.initializer.extend(initializers)
+        onnx.checker.check_model(model_proto, True)
+        model = ir.serde.deserialize_model(model_proto)
+        count = _fuse_batchnorm.rules.apply_to_model(model)
+
+        # Applied twice, once for each BatchNorm
+        self.assertEqual(count, 2)
+        # it should have different bias names for the two fused Conv nodes
+        conv_nodes = [node for node in model.graph if node.op_type == "Conv"]
+        self.assertEqual(len(conv_nodes), 2)
+        bias_names_1 = conv_nodes[0].inputs[2].name
+        bias_names_2 = conv_nodes[1].inputs[2].name
+        self.assertNotEqual(bias_names_1, bias_names_2)
+
 
 if __name__ == "__main__":
     unittest.main()
```
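For reference, applying the rule set outside the test harness could look like the following sketch; the model path is hypothetical, and the calls mirror those used in the test above (`_fuse_batchnorm` is a private module, so the exact public entry point may differ in practice):

```python
import onnx
from onnxscript import ir
from onnxscript.rewriter.rules.common import _fuse_batchnorm

# Load a model (hypothetical path) and deserialize it into the onnxscript IR.
model_proto = onnx.load("model.onnx")
model = ir.serde.deserialize_model(model_proto)

# Apply the Conv+BatchNormalization fusion rules; returns the number of matches.
count = _fuse_batchnorm.rules.apply_to_model(model)
print(f"Fused {count} BatchNormalization node(s)")

# Serialize back and save the fused model.
onnx.save(ir.serde.serialize_model(model), "model_fused.onnx")
```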
