We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent f657903 commit 9266734Copy full SHA for 9266734
torchao/quantization/pt2e/qat_utils.py
@@ -887,8 +887,8 @@ def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
887
node.target == torch.ops.aten.add_.Tensor
888
and node.args[0].op == "get_attr"
889
and node.args[1] == 1
890
- and torch.nn.modules.batchnorm.BatchNorm2d
891
- in [val[1] for val in node.meta["source_fn_stack"]]
+ and "torch.nn.modules.batchnorm.BatchNorm2d"
+ in [val[1] for _, val in node.meta["nn_module_stack"].items()]
892
):
893
m.graph.erase_node(node)
894
0 commit comments