@@ -73,7 +73,8 @@ def test_mse_observer_symmetric_scale_range():
 
 
 def test_mse_fp4():
-    tensor = torch.arange(24, dtype=torch.bfloat16).reshape((4, 6)) / 24
+    module = torch.nn.Linear(6, 4)
+    module.weight.data = torch.arange(24, dtype=torch.bfloat16).reshape((4, 6)) / 24
 
     weights = QuantizationArgs(
         num_bits=4,
@@ -84,8 +85,15 @@ def test_mse_fp4():
     )
 
     observer = weights.observer
-    observer = Observer.load_from_registry(observer, base_name="weight", args=weights)
-    scale, zero_point = observer(tensor)
+    observer = Observer.load_from_registry(
+        observer, base_name="weight", args=weights, module=module
+    )
 
-    qdq_tensor = fake_quantize(tensor, scale, zero_point, weights)
-    assert torch.nn.functional.mse_loss(qdq_tensor, tensor) <= 0.002
+    global_scale = observer.get_global_scale(module.weight)
+    module.weight_global_scale = global_scale
+    scale, zero_point = observer(module.weight)
+
+    qdq_tensor = fake_quantize(
+        module.weight, scale, zero_point, weights, global_scale=global_scale
+    )
+    assert torch.nn.functional.mse_loss(qdq_tensor, module.weight) <= 0.002