Enable triton kernels matmul tests #5128
Is this for testing? If it is not ready for review, please convert to draft.
The last eight tests could be fixed by this PR: #5128. With the additional rounding below, the failing tests pass for the affected test parameters:

    sep_scatter = mode == "ragged" and do_scatter and n_expts_act > 1 and split_k == 1
    if is_xpu():  # additional rounding on XPU for reference values
        sep_scatter = sep_scatter or (do_scatter and not fused_scatter and n_expts_tot > 1 and split_k == 1 and act_dtype_str == "float8_e4m3fn")
    y_scale = flex.out_data.expected_scale if act_is_float8 else 1

    def round_x(x, idx):
        return x.to(act_dtype).to(torch.float32) if sep_gather else x

    round_y = lambda y: (y / y_scale).to(act_dtype).to(torch.float32) * y_scale if sep_scatter else y
    ref_y = matmul_ogs_torch(x_ref, w_ref, bias_ref,  #
                             rdata, gindx, sindx, round_x=round_x, round_y=round_y, gammas=gs1_ref,
                             inner_routing_data=inner_routing_data, device=device)

Without this fix there is a small precision difference. The results are still close enough to pass the other assertions, but the test fails when comparing `actual_scale` for these parameters. Please take a look, @etiotto @whitneywhtsang.
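To illustrate why rounding the reference matters, here is a minimal standalone sketch, not the test code itself. It assumes IEEE float16 (via the standard-library `struct` module) as a stand-in for `act_dtype`, uses made-up values for `y_ref` and `y_scale`, and assumes the compared scale behaves like a max-abs statistic of the output. The helper names are hypothetical.

```python
import struct


def round_to_half(value: float) -> float:
    """Round-trip a Python float through IEEE float16 (stand-in for act_dtype)."""
    return struct.unpack("e", struct.pack("e", value))[0]


def round_y(y, y_scale, enabled=True):
    """Mimic the shape of the quoted round_y: divide by scale, round, multiply back."""
    if not enabled:
        return list(y)
    return [round_to_half(v / y_scale) * y_scale for v in y]


# Hypothetical reference outputs and scale, chosen only for illustration.
y_ref = [0.1, 1.0009, -3.14159, 2.71828]
y_scale = 1.0

rounded = round_y(y_ref, y_scale)
unrounded = round_y(y_ref, y_scale, enabled=False)

# Assumption: the scale under comparison tracks the largest magnitude in the output.
max_abs_rounded = max(abs(v) for v in rounded)
max_abs_unrounded = max(abs(v) for v in unrounded)

print("max abs without rounding:", max_abs_unrounded)
print("max abs with rounding:   ", max_abs_rounded)
```

In this sketch the individual values differ only by a rounding step, which a tolerance-based comparison would accept, yet the max-abs statistic is no longer identical. That is the kind of mismatch the extra rounding of the reference is meant to remove.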
Signed-off-by: Witold Dziurdz <witold.dziurdz@intel.com>
Given that there is a proposed fix upstream, we should wait for that change to land.
@@ -1,3 +1,2 @@
tests/test_matmul.py::test_op
Most likely it will not work as is on a770, arl-h, arl-s and mtl.
Fixes #5074