Skip to content

Commit 9266734

Browse files
Use nn_module_stack instead
Differential Revision: D85959415 Pull Request resolved: #3268
1 parent f657903 commit 9266734

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

torchao/quantization/pt2e/qat_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -887,8 +887,8 @@ def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
887887
node.target == torch.ops.aten.add_.Tensor
888888
and node.args[0].op == "get_attr"
889889
and node.args[1] == 1
890-
and torch.nn.modules.batchnorm.BatchNorm2d
891-
in [val[1] for val in node.meta["source_fn_stack"]]
890+
and "torch.nn.modules.batchnorm.BatchNorm2d"
891+
in [val[1] for _, val in node.meta["nn_module_stack"].items()]
892892
):
893893
m.graph.erase_node(node)
894894

0 commit comments

Comments
 (0)