File tree (3 files changed: +8 −9 lines changed)
transformer_engine/pytorch/ops (3 files changed: +8 −9 lines changed)
@@ -901,15 +901,15 @@ def _test_basic_linear(
901901 dtype = dtype ,
902902 accumulate_into_main_grad = accumulate_into_main_grad ,
903903 )
904+ forward = te_ops .Sequential (
905+ te_ops .Quantize (forward = quantized_input , backward = quantized_grad_input ),
906+ op ,
907+ te_ops .Quantize (forward = quantized_output , backward = quantized_grad_output ),
908+ )
904909 with torch .no_grad ():
905910 op .weight .copy_ (w_test )
906911 del w_test
907912 op .weight .main_grad = torch .full_like (op .weight , 0.5 , dtype = torch .float32 )
908- forward = te_ops .Sequential (
909- te_ops .Quantize (forward = quantized_input , backward = quantized_grad_input ),
910- op ,
911- te_ops .Quantize (forward = quantized_output , backward = quantized_grad_output ),
912- )
913913 with te .autocast (enabled = quantized_compute , recipe = recipe ):
914914 y_test = forward (x_test )
915915 y_test .backward (dy_test )
@@ -137,8 +137,10 @@ def __init__(
137137 out_features = out_features ,
138138 )
139139
140- # Whether weight tensor is natively quantized
140+ # Initialize recipe state if needed for natively quantized weight
141141 self ._with_quantized_weight : bool = FP8GlobalStateManager .with_fp8_parameters ()
142+ if self ._with_quantized_weight :
143+ self .reset_recipe_state (recipe = FP8GlobalStateManager .get_fp8_recipe ())
142144
143145 # Initialize parameters if needed
144146 weight = torch .empty (
@@ -188,9 +188,6 @@ def __init__(self) -> None:
188188 # Objects for quantization
189189 self ._fp8_metas : Optional [dict [str , dict [str , Any ]]] = None
190190 self ._quantizers : Optional [dict [str , list [Quantizer ]]] = None
191- with_fp8_parameters = FP8GlobalStateManager .with_fp8_parameters ()
192- recipe = FP8GlobalStateManager .get_fp8_recipe () if with_fp8_parameters else None
193- self .reset_recipe_state (recipe = recipe )
194191
195192 @property
196193 def is_fused_op (self ) -> bool :
You can’t perform that action at this time.
0 commit comments