1111
1212DTYPES = [torch .bfloat16 , torch .float ]
1313QUANT_DTYPES = [torch .int8 , torch .float8_e4m3fn ]
14- VEC_HIDDEN_SIZES = range ( 1024 , 1030 )
14+ VEC_HIDDEN_SIZES = [ 1024 , 1025 , 1027 , 1029 ]
1515# Avoid combinatorial explosion with full Cartesian product
1616NUM_TOKENS_HIDDEN_SIZES = [
1717 * [(1 , i ) for i in [1 , 64 , * VEC_HIDDEN_SIZES , 5120 , 5137 ]],
@@ -65,7 +65,7 @@ def ref_dynamic_per_token_quant(
6565 )
6666 else :
6767 assert quant_dtype == torch .int8
68- torch_out , scales = ops .scaled_int8_quant (torch_out )
68+ torch_out , scales , _ = ops .scaled_int8_quant (torch_out )
6969
7070 return torch_out , scales , residual
7171
@@ -109,7 +109,7 @@ def ops_impl(
109109
110110@pytest .mark .parametrize ("num_tokens, hidden_size" , NUM_TOKENS_HIDDEN_SIZES )
111111@pytest .mark .parametrize ("add_residual" , ADD_RESIDUAL )
112- @pytest .mark .parametrize ("scale_ub " , SCALE_UBS )
112+ @pytest .mark .parametrize ("has_scale_ub " , SCALE_UBS )
113113@pytest .mark .parametrize ("dtype" , DTYPES )
114114@pytest .mark .parametrize ("quant_dtype" , QUANT_DTYPES )
115115@pytest .mark .parametrize ("seed" , SEEDS )
@@ -119,7 +119,7 @@ def test_rms_norm(
119119 num_tokens : int ,
120120 hidden_size : int ,
121121 add_residual : bool ,
122- scale_ub : bool ,
122+ has_scale_ub : bool ,
123123 dtype : torch .dtype ,
124124 quant_dtype : torch .dtype ,
125125 seed : int ,
@@ -130,7 +130,7 @@ def test_rms_norm(
130130 torch .cuda .manual_seed (seed )
131131 torch .set_default_device (device )
132132
133- if scale_ub is not None and quant_dtype != torch .float8_e4m3fn :
133+ if has_scale_ub and quant_dtype != torch .float8_e4m3fn :
134134 # skip
135135 return
136136
@@ -143,9 +143,11 @@ def test_rms_norm(
143143 scale = 1 / (hidden_size )
144144 x = torch .randn (num_tokens , hidden_size , dtype = dtype ) * scale
145145 residual = torch .randn_like (x ) * scale if add_residual else None
146- if scale_ub is not None :
146+ if has_scale_ub :
147147 rms_x , _ = ref_rms_norm (layer , x , residual )
148148 scale_ub = torch .mean (rms_x ).to (dtype = torch .float32 , device = "cuda" )
149+ else :
150+ scale_ub = None
149151
150152 ref_out , ref_scales , ref_residual = ref_impl (
151153 layer , x , quant_dtype , residual , scale_ub
@@ -156,14 +158,27 @@ def test_rms_norm(
156158
157159 assert ref_out .dtype == quant_dtype
158160 assert ops_out .dtype == quant_dtype
159- assert torch .allclose (ref_scales , ops_scales )
160161 if quant_dtype == torch .int8 :
162+ assert torch .allclose (ref_scales , ops_scales , atol = 1e-6 )
161163 # big atol to account for round-off errors.
162164 assert torch .allclose (ref_out , ops_out , atol = 1 )
163165 else :
164- assert torch .allclose (
165- ref_out .to (dtype = torch .float32 ), ops_out .to (dtype = torch .float32 )
166- )
166+ assert torch .allclose (ref_scales , ops_scales )
167+ a = ref_out .to (dtype = torch .float32 )
168+ b = ops_out .to (dtype = torch .float32 )
169+ ok = torch .allclose (a , b )
170+ if not ok :
171+ # fallback: compare dequantized values with relaxed tolerance
172+ a_deq = a * ref_scales .view (- 1 , 1 )
173+ b_deq = b * ops_scales .view (- 1 , 1 )
174+ # NOTE: It is possible that some future test cases trigger this
175+ # max diff due to precision issues. If such an error is
176+ # encountered, it's recommended to inspect the differences between
177+ # all corresponding elements from each tensor (e.g. by looping over
178+ # them) and checking how many the max diff error shows up on (just
179+ # a few bad elements should still be considered acceptable).
180+ ok = torch .allclose (a_deq , b_deq , rtol = 5e-2 , atol = 5e-2 )
181+ assert ok
167182 if add_residual :
168183 assert torch .allclose (ref_residual , ops_residual )
169184
0 commit comments