Skip to content

Commit 41a6c11

Browse files
authored
call named_modules once per model prepare
Differential Revision: D84389318 Pull Request resolved: #3159
1 parent b1097de commit 41a6c11

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

torchao/quantization/pt2e/prepare.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,7 @@ def _maybe_insert_input_and_output_observers_for_node(
585585
node: Node,
586586
model: torch.fx.GraphModule,
587587
obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
588+
named_modules: dict[str, torch.nn.Module],
588589
is_qat: bool,
589590
model_device: Optional[torch.device] = None,
590591
):
@@ -594,7 +595,6 @@ def _maybe_insert_input_and_output_observers_for_node(
594595
if this_node_quantization_annotation is None:
595596
return
596597

597-
named_modules = dict(model.named_modules(remove_duplicate=False))
598598
_maybe_insert_input_observers_for_node(
599599
node,
600600
None, # qconfig
@@ -666,13 +666,15 @@ def prepare(
666666
if obs_or_fq_callback:
667667
obs_or_fq_callback(model, obs_or_fq_map)
668668
model_device = _assert_and_get_unique_device(model)
669+
named_modules = dict(model.named_modules(remove_duplicate=False))
669670

670671
for node in nodes_before_observation:
671672
# TODO: simplify logic for inserting observers
672673
_maybe_insert_input_and_output_observers_for_node(
673674
node,
674675
model,
675676
obs_or_fq_map,
677+
named_modules,
676678
is_qat,
677679
model_device,
678680
)

0 commit comments

Comments
 (0)