@@ -161,44 +161,31 @@ def test_int8_weight_only_training(self, compile, device):
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     def test_int8_mixed_precision_training(self, compile, config):
         _reset()
-        bsize = 4
-        embed_dim = 32
+        bsize = 64
+        embed_dim = 64
         device = "cuda"
 
-        # only use 1 matmul shape to reduce triton autotune time
-        model_ref = nn.Sequential(
-            nn.Linear(embed_dim, embed_dim, bias=False),
-            nn.GELU(),
-            nn.Linear(embed_dim, embed_dim),
-        ).to(device)
-        model_int8mp = copy.deepcopy(model_ref)
-        quantize_(model_int8mp, int8_mixed_precision_training(config), set_inductor_config=False)
+        linear = nn.Linear(embed_dim, embed_dim).cuda()
+        linear_int8mp = copy.deepcopy(linear)
+        quantize_(linear_int8mp, int8_mixed_precision_training(config), set_inductor_config=False)
 
         if compile:
-            model_ref.compile()
-            model_int8mp.compile()
+            linear.compile()
+            linear_int8mp.compile()
 
-        optim_ref = torch.optim.AdamW(model_ref.parameters())
-        optim_int8mp = torch.optim.AdamW(model_int8mp.parameters())
+        inputs = torch.randn(bsize, embed_dim, device=device)
+        grad_outputs = torch.randn(bsize, embed_dim, device=device)
 
-        for i in range(5):
-            inputs = torch.randn(bsize, embed_dim, device=device)
-            labels = torch.randint(embed_dim, size=(bsize,), device=device)
-            loss_ref = F.cross_entropy(model_ref(inputs), labels)
-            loss_int8mp = F.cross_entropy(model_int8mp(inputs), labels)
-
-            rel_error = abs(loss_int8mp.item() - loss_ref.item()) / abs(loss_ref.item())
-            assert rel_error < 3e-3, (i, rel_error)
-
-            loss_ref.backward()
-            optim_ref.step()
-            optim_ref.zero_grad()
-
-            loss_int8mp.backward()
-            for p in model_int8mp.parameters():
-                assert p.grad is not None
-            optim_int8mp.step()
-            optim_int8mp.zero_grad()
+        inputs_ref, outputs_ref = self._forward_and_backward(linear, inputs, grad_outputs)
+        inputs_int8mp, outputs_int8mp = self._forward_and_backward(linear_int8mp, inputs, grad_outputs)
+
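+        # SNR in decibels: 20 * log10(||ref|| / ||actual - ref||); higher means closer to the reference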
+        def snr(ref, actual):
+            error = actual - ref
+            return 20 * torch.log10(ref.norm() / error.norm())
+
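+        # an SNR above 20 dB means the error norm is under 10% of the reference norm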
+        assert snr(outputs_ref, outputs_int8mp) > 20
+        assert snr(inputs_ref.grad, inputs_int8mp.grad) > 20
+        assert snr(linear.weight.grad, linear_int8mp.weight.grad) > 20
 
 
 _FSDP_WORLD_SIZE = 2