Skip to content

Commit 9ca89e9

Browse files
authored
[PyTorch] Avoid initializing recipe state in fusible op base class constructor (#2421)
Do not initialize recipe state in the base op class, since Op attrs may not be set yet. Move recipe state initialization to the linear op constructor. Signed-off-by: Tim Moon <tmoon@nvidia.com>
1 parent 9f61f8a commit 9ca89e9

File tree

3 files changed

+8
-9
lines changed

3 files changed

+8
-9
lines changed

tests/pytorch/test_fusible_ops.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -901,15 +901,15 @@ def _test_basic_linear(
901901
dtype=dtype,
902902
accumulate_into_main_grad=accumulate_into_main_grad,
903903
)
904+
forward = te_ops.Sequential(
905+
te_ops.Quantize(forward=quantized_input, backward=quantized_grad_input),
906+
op,
907+
te_ops.Quantize(forward=quantized_output, backward=quantized_grad_output),
908+
)
904909
with torch.no_grad():
905910
op.weight.copy_(w_test)
906911
del w_test
907912
op.weight.main_grad = torch.full_like(op.weight, 0.5, dtype=torch.float32)
908-
forward = te_ops.Sequential(
909-
te_ops.Quantize(forward=quantized_input, backward=quantized_grad_input),
910-
op,
911-
te_ops.Quantize(forward=quantized_output, backward=quantized_grad_output),
912-
)
913913
with te.autocast(enabled=quantized_compute, recipe=recipe):
914914
y_test = forward(x_test)
915915
y_test.backward(dy_test)

transformer_engine/pytorch/ops/basic/basic_linear.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,10 @@ def __init__(
137137
out_features=out_features,
138138
)
139139

140-
# Whether weight tensor is natively quantized
140+
# Initialize recipe state if needed for natively quantized weight
141141
self._with_quantized_weight: bool = FP8GlobalStateManager.with_fp8_parameters()
142+
if self._with_quantized_weight:
143+
self.reset_recipe_state(recipe=FP8GlobalStateManager.get_fp8_recipe())
142144

143145
# Initialize parameters if needed
144146
weight = torch.empty(

transformer_engine/pytorch/ops/op.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -188,9 +188,6 @@ def __init__(self) -> None:
188188
# Objects for quantization
189189
self._fp8_metas: Optional[dict[str, dict[str, Any]]] = None
190190
self._quantizers: Optional[dict[str, list[Quantizer]]] = None
191-
with_fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
192-
recipe = FP8GlobalStateManager.get_fp8_recipe() if with_fp8_parameters else None
193-
self.reset_recipe_state(recipe=recipe)
194191

195192
@property
196193
def is_fused_op(self) -> bool:

0 commit comments

Comments
 (0)