From 42c644e74a4e7f0f2ae56dfe7d5f565ffc0a0f2a Mon Sep 17 00:00:00 2001
From: "Cui, Yifeng"
Date: Sun, 9 Nov 2025 19:28:11 -0800
Subject: [PATCH 1/2] Add float4_e2m1fn_x2 support for concat

---
 src/ATen/native/xpu/sycl/Shape.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/ATen/native/xpu/sycl/Shape.cpp b/src/ATen/native/xpu/sycl/Shape.cpp
index c8e4236401..7f449f86a8 100644
--- a/src/ATen/native/xpu/sycl/Shape.cpp
+++ b/src/ATen/native/xpu/sycl/Shape.cpp
@@ -395,7 +395,8 @@ void cat_out_kernel(
       kBool,
       kBFloat16,
       AT_EXPAND(AT_FLOAT8_TYPES),
-      AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
+      AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
+      kFloat4_e2m1fn_x2);
   } else {
     offset = 0;
     for (j = 0; j < numInputs; j++) {

From d3f73a90e5cc84afb4c4e9dc40550a2771cc8da7 Mon Sep 17 00:00:00 2001
From: "Cui, Yifeng"
Date: Mon, 10 Nov 2025 04:45:40 -0800
Subject: [PATCH 2/2] Add FP4 UT

---
 test/regressions/test_cat.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/test/regressions/test_cat.py b/test/regressions/test_cat.py
index 1c4af32354..37ea1205a7 100644
--- a/test/regressions/test_cat.py
+++ b/test/regressions/test_cat.py
@@ -7,6 +7,9 @@
 from torch.testing._internal.common_dtype import float8_types_and
 from torch.testing._internal.common_utils import run_tests, TestCase
 
+cpu_device = torch.device("cpu")
+xpu_device = torch.device("xpu")
+
 
 class TestTorchMethod(TestCase):
     def _create_input_tensors(self, shape, dtype, memory_format=None):
@@ -61,6 +64,21 @@ def test_cat_simple(self, dtype):
 
             self._test_cat_float8_core(tensors, dim, dtype)
+
+    def _float4_dummy_tensor(self, shape, device):
+        data = torch.ones(shape, dtype=torch.uint8, device=device)
+        return data.view(torch.float4_e2m1fn_x2)
+
+    def test_cat_float4_simple(self):
+        input_cpu1 = self._float4_dummy_tensor([2, 2, 6], device=cpu_device)
+        input_cpu2 = self._float4_dummy_tensor([2, 2, 6], device=cpu_device)
+        output_cpu = torch.stack([input_cpu1, input_cpu2]).view(torch.uint8)
+        input_xpu1 = self._float4_dummy_tensor([2, 2, 6], device=xpu_device)
+        input_xpu2 = self._float4_dummy_tensor([2, 2, 6], device=xpu_device)
+        output_xpu = torch.stack([input_xpu1, input_xpu2]).view(torch.uint8)
+
+        self.assertEqual(output_xpu, output_cpu)
+
 
     def test_cat_8d(self, dtype=torch.float):
         input1 = torch.randn([256, 8, 8, 3, 3, 3, 3], dtype=dtype)
         input2 = torch.randn([256, 8, 8, 3, 3, 3, 3], dtype=dtype)