
Conversation

@wdziurdz
Contributor

@wdziurdz wdziurdz commented Sep 17, 2025

Fixes #5074

@whitneywhtsang
Contributor

Is this for testing? If it is not ready for review, please convert to draft.

@whitneywhtsang whitneywhtsang marked this pull request as draft September 25, 2025 00:33
@wdziurdz wdziurdz force-pushed the dev/wdziurdz/test-matmul-1 branch 2 times, most recently from 84862d8 to aafbe1a on October 6, 2025 12:37
@wdziurdz wdziurdz marked this pull request as ready for review October 6, 2025 12:39
@wdziurdz wdziurdz force-pushed the dev/wdziurdz/test-matmul-1 branch 3 times, most recently from 5b2d42b to 990069a on October 8, 2025 10:32
@wdziurdz wdziurdz force-pushed the dev/wdziurdz/test-matmul-1 branch from 990069a to f51fe27 on October 10, 2025 07:57
@wdziurdz wdziurdz self-assigned this Oct 10, 2025
@wdziurdz wdziurdz force-pushed the dev/wdziurdz/test-matmul-1 branch 5 times, most recently from 2d19f0a to 8c7870b on November 2, 2025 19:41
@wdziurdz wdziurdz force-pushed the dev/wdziurdz/test-matmul-1 branch 4 times, most recently from 50639fa to e5b7881 on November 12, 2025 09:08
@wdziurdz wdziurdz force-pushed the dev/wdziurdz/test-matmul-1 branch 4 times, most recently from d291f7c to b42922d on November 12, 2025 12:18
@wdziurdz
Contributor Author

The last eight failing tests could be fixed by this PR: #5128.
Example of one failing test:

AssertionError: ref_y_scale: 0.004773152060806751, tri_y_scale: 0.005022321827709675
  assert tensor(False, device='xpu:0')
   +  where tensor(False, device='xpu:0') = <built-in method all of type object at 0x7f4175d82400>(tensor([0.0002], device='xpu:0', grad_fn=<AbsBackward0>) < 1e-10)
   +    where <built-in method all of type object at 0x7f4175d82400> = torch.all
   +    and   tensor([0.0002], device='xpu:0', grad_fn=<AbsBackward0>) = <built-in method abs of Tensor object at 0x7f406d589260>()
   +      where <built-in method abs of Tensor object at 0x7f406d589260> = (tensor(0.0048, device='xpu:0', grad_fn=<DivBackward0>) - tensor([0.0050], device='xpu:0')).abs
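
For context, here is a minimal sketch (not from the PR; it assumes a PyTorch build with float8_e4m3fn support and reuses the two scale values from the log above) of the size of error a float8 round trip introduces, which a 1e-10 tolerance cannot absorb:

    import torch

    # Hypothetical repro: float8_e4m3fn has 3 mantissa bits, so a round
    # trip quantizes to a relative step of about 2**-3, far above the
    # 1e-10 threshold used by the assertion.
    y = torch.tensor([0.0050])
    y_scale = 0.0048
    rounded = (y / y_scale).to(torch.float8_e4m3fn).to(torch.float32) * y_scale
    print((rounded - y).abs())  # ~0.0002, matching the difference in the log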

With this rounding, the previously failing tests pass for the test parameters below:

    sep_scatter = mode == "ragged" and do_scatter and n_expts_act > 1 and split_k == 1
    if is_xpu():  # additional rounding on XPU for reference values
        sep_scatter = sep_scatter or (do_scatter and not fused_scatter and n_expts_tot > 1 and split_k == 1 and act_dtype_str == "float8_e4m3fn")
    y_scale = flex.out_data.expected_scale if act_is_float8 else 1

    def round_x(x, idx):
        return x.to(act_dtype).to(torch.float32) if sep_gather else x

    round_y = lambda y: (y / y_scale).to(act_dtype).to(torch.float32) * y_scale if sep_scatter else y
    ref_y = matmul_ogs_torch(x_ref, w_ref, bias_ref,  #
                             rdata, gindx, sindx, round_x=round_x, round_y=round_y, gammas=gs1_ref,
                             inner_routing_data=inner_routing_data, device=device)
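
In other words, the torch reference is pushed through the same float8 quantization round trip that the kernel output undergoes, so both sides should accumulate the same rounding error and the scale comparison holds.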

Without this fix there is a small precision difference. It is small enough to pass all the other assertions, but the test fails when comparing the actual_scale for these parameters.

Please take a look at this @etiotto @whitneywhtsang

@wdziurdz wdziurdz force-pushed the dev/wdziurdz/test-matmul-1 branch from b42922d to 914acf2 on November 12, 2025 15:30
@wdziurdz wdziurdz force-pushed the dev/wdziurdz/test-matmul-1 branch from 914acf2 to 12a2098 on November 13, 2025 08:16
Signed-off-by: Witold Dziurdz <witold.dziurdz@intel.com>
@wdziurdz wdziurdz force-pushed the dev/wdziurdz/test-matmul-1 branch from 12a2098 to d9f8c41 on November 13, 2025 13:56
@etiotto
Contributor

etiotto commented Nov 13, 2025

Given that there is a proposed fix upstream, we should wait for that change to land.

@etiotto etiotto marked this pull request as draft November 13, 2025 14:51
@@ -1,3 +1,2 @@
tests/test_matmul.py::test_op
Contributor

Most likely it will not work as-is on a770, arl-h, arl-s, and mtl.

@wdziurdz wdziurdz marked this pull request as ready for review November 14, 2025 09:34

Development

Successfully merging this pull request may close these issues.

Some python/triton_kernels/tests/test_matmul.py::test_op test cases don't work
