From 773c1329980eb262393855d6bf446f30727f5993 Mon Sep 17 00:00:00 2001
From: "Su, Tong"
Date: Tue, 11 Nov 2025 23:26:47 -0800
Subject: [PATCH 1/2] [xpu] Enable xpu for test_float8_tensor.py

---
 .../workflows/float8/test_float8_tensor.py | 111 +++++++++++-------
 torchao/float8/inference.py                |   4 +-
 torchao/quantization/quant_api.py          |  10 +-
 torchao/testing/utils.py                   |   7 +-
 torchao/utils.py                           |   7 ++
 5 files changed, 88 insertions(+), 51 deletions(-)

diff --git a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
index df11b71e66..166553585e 100644
--- a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
+++ b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
@@ -31,6 +31,7 @@ from torchao.testing.utils import TorchAOIntegrationTestCase
 from torchao.utils import (
     _is_fbgemm_gpu_genai_available,
+    auto_detect_device,
     is_sm_at_least_89,
     is_sm_at_least_90,
     is_sm_at_least_100,
@@ -40,6 +41,8 @@
 # Needed since changing args to function causes recompiles
 torch._dynamo.config.cache_size_limit = 128

+_DEVICE = auto_detect_device()
+

 class ToyLinearModel(torch.nn.Module):
     def __init__(self, in_features, out_features, bias):
@@ -137,16 +140,19 @@ def check_weight_scaling(self, granularity: Granularity):

 # TODO: move tests in test_affine_quantized_float.py here after we migrated all implementations
 @unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
-@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-@unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
+@unittest.skipIf(
+    not torch.accelerator.is_available(), "skipping when gpu is not available"
+)
+@unittest.skipIf(torch.cuda.is_available() and not is_sm_at_least_89(), "Need sm89+")
 class TestFloat8Tensor(TorchAOIntegrationTestCase):
     def setUp(self):
-        self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []
+        self.GPU_DEVICES = [_DEVICE] if torch.accelerator.is_available() else []
         torch.set_grad_enabled(False)

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need accelerator available")
     @unittest.skipIf(
-        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+        torch.cuda.is_available() and not is_sm_at_least_89(),
+        "Requires GPU with compute capability >= 8.9",
     )
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
     @common_utils.parametrize("mode", ["dynamic", "weight-only"])
@@ -191,9 +197,10 @@ def test_fp8_linear_variants(
             ToyLinearModel(K, N, bias),
         )

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need accelerator available")
     @unittest.skipIf(
-        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+        torch.cuda.is_available() and not is_sm_at_least_89(),
+        "Requires GPU with compute capability >= 8.9",
     )
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
     @common_utils.parametrize("mode", ["dynamic", "weight-only"])
@@ -230,7 +237,7 @@ def test_fp8_matmul_lora_variants(
             kernel_preference,
             sizes,
             bias=False,
-            model=model.to("cuda"),
+            model=model.to(_DEVICE),
         )

     def _test_fp8_matmul_model(
@@ -253,6 +260,8 @@ def _test_fp8_matmul_model(
             return unittest.skip("unimplemented")
         elif granularity == (PerBlock([1, 128]), PerBlock([128, 128])):
+            if _DEVICE.type == "xpu":
+                return unittest.skip("PerBlock granularity not supported on XPU")
             if dtype is not torch.bfloat16:
                 return unittest.skip("unimplemented")
         elif mode != "dynamic":
@@ -284,7 +293,8 @@ def _test_fp8_matmul_model(
         )
         if kernel_preference == KernelPreference.FBGEMM and (
-            (not _is_fbgemm_gpu_genai_available()) or (not is_sm_at_least_90())
+            (not _is_fbgemm_gpu_genai_available())
+            or (not torch.cuda.is_available() and not is_sm_at_least_90())
         ):
             return unittest.skip(
                 "Requires fbgemm_gpu_genai to run fbgemm kernel preference test"
             )
@@ -298,8 +308,8 @@ def _test_fp8_matmul_model(

         with error_context:
             M, N, K = sizes
-            input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
-            model = model.eval().to(dtype).to("cuda")
+            input_tensor = torch.randn(*M, K, dtype=dtype, device=_DEVICE)
+            model = model.eval().to(dtype).to(_DEVICE)

             quantized_model = copy.deepcopy(model)
@@ -328,9 +338,10 @@ def _test_fp8_matmul_model(
                 f"Quantization error is too high got a SQNR of {error}"
             )

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need accelerator available")
     @unittest.skipIf(
-        not is_sm_at_least_100(), "Requires GPU with compute capability >= 10.0"
+        torch.cuda.is_available() and not is_sm_at_least_100(),
+        "Requires GPU with compute capability >= 10.0",
     )
     @unittest.skipIf(
         not _is_fbgemm_gpu_genai_available(),
@@ -369,7 +380,7 @@ def test_fp8_conv_variants(
         kernel_size = 3

         # Note: this is channel last memory format
-        input_tensor = torch.randn(N, C_in, *spatial_dims, dtype=dtype, device="cuda")
+        input_tensor = torch.randn(N, C_in, *spatial_dims, dtype=dtype, device=_DEVICE)
         if dim == 3:
             input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
         else:
@@ -384,7 +395,7 @@ def test_fp8_conv_variants(
             bias=False,
             padding=0,
             dtype=dtype,
-            device="cuda",
+            device=_DEVICE,
         ).eval()

         quantized_model = copy.deepcopy(model)
@@ -411,9 +422,10 @@ def test_fp8_conv_variants(
                 f"Quantization error is too high got a SQNR of {error}"
             )

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need accelerator available")
     @unittest.skipIf(
-        not is_sm_at_least_100(), "Requires GPU with compute capability >= 10.0"
+        torch.cuda.is_available() and not is_sm_at_least_100(),
+        "Requires GPU with compute capability >= 10.0",
     )
     @unittest.skipIf(
         not _is_fbgemm_gpu_genai_available(),
@@ -453,7 +465,7 @@ def test_fp8_conv_skip_quant(
         kernel_size = 3

         # Note: this is channel last memory format
-        input_tensor = torch.randn(N, C_in, *spatial_dims, dtype=dtype, device="cuda")
+        input_tensor = torch.randn(N, C_in, *spatial_dims, dtype=dtype, device=_DEVICE)
         if dim == 3:
             input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
         else:
@@ -467,7 +479,7 @@ def test_fp8_conv_skip_quant(
             bias=False,
             padding=0,
             dtype=dtype,
-            device="cuda",
+            device=_DEVICE,
         ).eval()

         quantized_model = copy.deepcopy(model)
@@ -488,14 +500,14 @@ def test_fp8_conv_skip_quant(

     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     @unittest.skipIf(
-        not is_sm_at_least_90(),
+        torch.cuda.is_available() and not is_sm_at_least_90(),
         "Failing in SM89 right now: "
         "AssertionError: tensor(False, device='cuda:0') is not true : sqnr: -2.90625, will fix a bit later",
     )
     def test_slice(self, granularity):
         config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         dtype = torch.bfloat16
-        device = "cuda"
+        device = _DEVICE
         dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
         dummy1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device)
         dummy1.weight = torch.nn.Parameter(
@@ -567,9 +579,9 @@ def test_kernel_preference_numerical_equivalence(self, granularity, sizes):
         """
         M, N, K = sizes
         dtype = torch.bfloat16
-        input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
+        input_tensor = torch.randn(*M, K, dtype=dtype, device=_DEVICE)
         # Create a linear layer with bfloat16 dtype
-        model = ToyLinearModel(K, N, bias=False).eval().to(dtype).to("cuda")
+        model = ToyLinearModel(K, N, bias=False).eval().to(dtype).to(_DEVICE)

         # reference kernel preference and results
         # we are using KerenelPreference.TORCH as the reference
@@ -586,6 +598,7 @@ def test_kernel_preference_numerical_equivalence(self, granularity, sizes):
         ]
         if (
             _is_fbgemm_gpu_genai_available()
+            and torch.cuda.is_available()
             and is_sm_at_least_90()
             and not isinstance(granularity, PerTensor)
         ):
@@ -614,9 +627,9 @@ def test_kernel_preference_numerical_equivalence(self, granularity, sizes):
     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     def test_slice_preserves_aliasing(self, granularity):
         config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
-        l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
+        l = torch.nn.Linear(1024, 1024).to(_DEVICE).to(torch.bfloat16)
         l.weight = torch.nn.Parameter(
-            torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
+            torch.zeros(1024, 1024, dtype=torch.bfloat16, device=_DEVICE)
         )
         quantize_(l, config)
         param = l.weight
@@ -631,7 +644,9 @@ def test_slice_and_copy_similar_to_vllm(self, granularity):
         config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         self._test_slice_and_copy_similar_to_vllm(config)

-    @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
+    @unittest.skipIf(
+        torch.cuda.is_available() and not is_sm_at_least_90(), "Nedd sm90+"
+    )
     @unittest.skipIf(not _is_fbgemm_gpu_genai_available(), "Need fbgemm_gpu_genai")
     def test_bmm(self):
         # only support per row quantization
@@ -646,7 +661,7 @@ def forward(self, x):
                 return torch.bmm(x, self.weight.transpose(-2, -1))

         dtype = torch.bfloat16
-        device = "cuda"
+        device = _DEVICE

         B, M, K, N = 10, 32, 128, 256
@@ -659,7 +674,9 @@ def forward(self, x):
         sqnr = compute_error(original, quantized)
         self.assertTrue(sqnr > 20)

-    @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
+    @unittest.skipIf(
+        torch.cuda.is_available() and not is_sm_at_least_90(), "Nedd sm90+"
+    )
     @unittest.skipIf(not _is_fbgemm_gpu_genai_available(), "Need fbgemm_gpu_genai")
     def test_bmm_weight_in_bkn_layout(self):
         # Tests rowwise quantization of a 3d weight stored with shape (B, K, N)
@@ -679,7 +696,7 @@ def forward(self, x):
                 return torch.bmm(x, self.weight)

         dtype = torch.bfloat16
-        device = "cuda"
+        device = _DEVICE

         B, M, K, N = 10, 32, 128, 256
@@ -739,7 +756,7 @@ def test_to_device(self, granularity, sizes):
     def test_cat(self, granularity, sizes):
         config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         dtype = torch.bfloat16
-        device = "cuda"
+        device = _DEVICE
         M, N, K = sizes
         linear1 = torch.nn.Linear(K, N, dtype=dtype, device=device)
         linear2 = torch.nn.Linear(K, N, dtype=dtype, device=device)
@@ -788,7 +805,9 @@ def test_cat(self, granularity, sizes):
     @unittest.skip(
         "This requires rowwise scaling for weight in layout BKN across axis 1 to work"
     )
-    @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
+    @unittest.skipIf(
+        torch.cuda.is_available() and not is_sm_at_least_90(), "Nedd sm90+"
+    )
     @unittest.skipIf(not _is_fbgemm_gpu_genai_available(), "Need fbgemm_gpu_genai")
     def test_moe_weight_reshape_ops(self):
         # only per row quantization is supported for bmm
@@ -799,7 +818,9 @@ def test_moe_weight_reshape_ops(self):
     # TODO: we have some other tests living in https://github.com/pytorch/ao/blob/4ecc89edd7b5cfc12e6f80854c85d04c472a0eb0/test/dtypes/test_affine_quantized_float.py#L743
     # that should be moved here after v1 config is deprecated:
     # https://github.com/pytorch/ao/issues/2649
-    @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
+    @unittest.skipIf(
+        torch.cuda.is_available() and not is_sm_at_least_90(), "Nedd sm90+"
+    )
     @unittest.skipIf(not _is_fbgemm_gpu_genai_available(), "Need fbgemm_gpu_genai")
     def test_expected_gpu_kernel_fbgemm(self):
         """Making sure KernelPreference.FBGEMM calls correct quantize and gemm kernels
@@ -809,7 +830,7 @@ def test_expected_gpu_kernel_fbgemm(self):

         M, K, N = 128, 256, 512
         m = torch.nn.Sequential(
-            torch.nn.Linear(K, N, device="cuda", dtype=torch.bfloat16)
+            torch.nn.Linear(K, N, device=_DEVICE, dtype=torch.bfloat16)
         )
         config = Float8DynamicActivationFloat8WeightConfig(
             granularity=PerRow(),
@@ -817,7 +838,7 @@ def test_expected_gpu_kernel_fbgemm(self):
         )
         quantize_(m, config)
         m = torch.compile(m)
-        x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
+        x = torch.randn(M, K, device=_DEVICE, dtype=torch.bfloat16)
         out, code = run_and_get_code(m, x)

         # 1. check at least one occurrence of the quantize op and rowwise gemm op
@@ -830,7 +851,9 @@ def test_expected_gpu_kernel_fbgemm(self):
             ".run("
         ).run(code[0])

-    @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
+    @unittest.skipIf(
+        torch.cuda.is_available() and not is_sm_at_least_90(), "Nedd sm90+"
+    )
     def test_index_select(self):
         """
         test that `x_0 = x[0]` works when `x` is a 3D `Float8Tensor`. This is
@@ -840,7 +863,7 @@ def test_index_select(self):
         """
         E, K, N = 128, 256, 512

-        x = torch.randn(E, N, K, device="cuda", dtype=torch.bfloat16)
+        x = torch.randn(E, N, K, device=_DEVICE, dtype=torch.bfloat16)
         x_fp8 = Float8Tensor.from_hp(x)
         x_fp8_1 = x_fp8[1]
         torch.testing.assert_close(
@@ -858,7 +881,7 @@ def test_index_select(self):
     def test_unsqueeze_operation(self, granularity, sizes):
         config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         dtype = torch.bfloat16
-        device = "cuda"
+        device = _DEVICE
         M, N, K = sizes

         # Create a linear layer and quantize it
@@ -914,7 +937,7 @@ def test_unsqueeze_conv2d_weight(self):
         granularity = PerTensor()
         config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         dtype = torch.bfloat16
-        device = "cuda"
+        device = _DEVICE
         N, C_in, C_out, spatial_dims = 4, 16, 64, (32, 32)
         dim = len(spatial_dims)
         kernel_size = 3
@@ -1002,7 +1025,7 @@ def test_slice_3d_operation(self, granularity, slice_dim, tensor_shape):
         """Test slicing operations on 3D Float8Tensor across all dimensions"""
         config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         dtype = torch.bfloat16
-        device = "cuda"
+        device = _DEVICE

         B, S, H = tensor_shape

@@ -1100,7 +1123,7 @@ def test_slice_3d_operation(self, granularity, slice_dim, tensor_shape):
         self.assertEqual(sliced_dequantized, sliced_original)

     def test_to_dtype_layout(self):
-        x = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
+        x = torch.randn(128, 512, device=_DEVICE, dtype=torch.bfloat16)
         x_fp8 = Float8Tensor.from_hp(x)
         y_fp8 = torch.ops.aten.to.dtype_layout(
             x_fp8, dtype=x_fp8.dtype, layout=x_fp8.layout, device="cpu"
         )
@@ -1110,9 +1133,9 @@ def test_to_dtype_layout(self):
         self.assertEqual(y_fp8.device, torch.device("cpu"))

     def test_has_compatible_shallow_copy_type(self):
-        x1 = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
-        x2 = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
-        x3 = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
+        x1 = torch.randn(128, 512, device=_DEVICE, dtype=torch.bfloat16)
+        x2 = torch.randn(128, 512, device=_DEVICE, dtype=torch.bfloat16)
+        x3 = torch.randn(128, 256, device=_DEVICE, dtype=torch.bfloat16)
         x1_fp8 = Float8Tensor.from_hp(x1)
         x2_fp8 = Float8Tensor.from_hp(x2)
         x3_fp8 = Float8Tensor.from_hp(x3)
@@ -1123,7 +1146,7 @@ def test_has_compatible_shallow_copy_type(self):
         self.assertFalse(torch._has_compatible_shallow_copy_type(x1_fp8, x3_fp8))

     def test_transpose(self):
-        x = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
+        x = torch.randn(128, 512, device=_DEVICE, dtype=torch.bfloat16)
         x_fp8 = Float8Tensor.from_hp(x)
         x_fp8_t = x_fp8.t()
         torch.testing.assert_close(x_fp8_t.qdata, x_fp8.qdata.t(), atol=0, rtol=0)
diff --git a/torchao/float8/inference.py b/torchao/float8/inference.py
index 212df9c5db..af3b3fc391 100644
--- a/torchao/float8/inference.py
+++ b/torchao/float8/inference.py
@@ -285,7 +285,9 @@ def _check_hardware_support(
     is_a_1_128_w_128_128 = _granularity_is_a_1_128_w_128_128(granularities)

     if is_per_tensor or is_per_row:
-        assert is_sm_at_least_89() or is_MI300(), (
+        assert torch.xpu.is_available() or (
+            torch.cuda.is_available() and is_sm_at_least_89() or is_MI300()
+        ), (
             "Float8 dynamic quantization requires CUDA compute capability ≥8.9 or MI300+."
         )
     elif is_a_1_128_w_128_128:
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index c29382b658..22ffee2a1f 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -1792,6 +1792,9 @@ def __post_init__(self):
         ), "unimplemented"
         assert self.version >= 2, "unimplemented"
         default_use_fast_accum = False
+        if torch.xpu.is_available():
+            # XPU does not support fast_accum for now
+            default_use_fast_accum = False
         if self.mm_config is None:
             self.mm_config = Float8MMConfig(use_fast_accum=default_use_fast_accum)
@@ -1898,9 +1901,10 @@ def _float8_dynamic_activation_float8_weight_transform(
     *,
     parameter_name: str = "weight",
 ):
-    assert is_sm_at_least_89() or is_MI300(), (
-        "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
-    )
+    if torch.cuda.is_available():
+        assert is_sm_at_least_89() or is_MI300(), (
+            "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
+        )

     if config.set_inductor_config:
         torchao.quantization.utils.recommended_inductor_config_setter()
diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py
index 10315d45f5..f89a9e7b88 100644
--- a/torchao/testing/utils.py
+++ b/torchao/testing/utils.py
@@ -26,6 +26,7 @@ from torchao.testing.model_architectures import LlamaModelsLlama4Experts
 from torchao.utils import (
     DummyModule,
+    auto_detect_device,
     get_compute_capability,
 )
@@ -433,15 +434,15 @@ def _test_slice_and_copy_similar_to_vllm(self, config: AOBaseConfig):
         # and does not use tensor parallelism
         dtype = torch.bfloat16
-        device = "cuda"
-        l = torch.nn.Linear(1024, 1024, device="cuda", dtype=dtype)
+        device = auto_detect_device()
+        l = torch.nn.Linear(1024, 1024, device=device, dtype=dtype)
         quantize_(l, config)

         # high level, we do a narrow for both param.data and the loaded_weights
         # and do inplace copy_ to copy from the loaded_weights into param.data

         # simulate loaded_weight
-        dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
+        dummy_l = torch.nn.Linear(1024, 1024).to(device).to(torch.bfloat16)
         # making the weight different
         dummy_l.weight = torch.nn.Parameter(
             dummy_l.weight
diff --git a/torchao/utils.py b/torchao/utils.py
index e123dfe891..ab211313ce 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -150,6 +150,13 @@ def compute_max_diff(output: torch.Tensor, output_ref: torch.Tensor) -> torch.Te
     )


+def auto_detect_device():
+    if torch.accelerator.is_available():
+        return torch.accelerator.current_accelerator()
+    else:
+        return None
+
+
 def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
     import torch.utils.benchmark as benchmark  # this avoids importing numpy when torchao module is loaded


From 1a8cef7a780b3337453892cd4b9564d01ef92a74 Mon Sep 17 00:00:00 2001
From: "Su, Tong"
Date: Tue, 11 Nov 2025 23:51:35 -0800
Subject: [PATCH 2/2] Fix typo

---
 .../quantize_/workflows/float8/test_float8_tensor.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
index 166553585e..7b63f84502 100644
--- a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
+++ b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
@@ -645,7 +645,7 @@ def test_slice_and_copy_similar_to_vllm(self, granularity):
         self._test_slice_and_copy_similar_to_vllm(config)

     @unittest.skipIf(
-        torch.cuda.is_available() and not is_sm_at_least_90(), "Nedd sm90+"
+        torch.cuda.is_available() and not is_sm_at_least_90(), "Need sm90+"
     )
     @unittest.skipIf(not _is_fbgemm_gpu_genai_available(), "Need fbgemm_gpu_genai")
     def test_bmm(self):
@@ -675,7 +675,7 @@ def forward(self, x):
         self.assertTrue(sqnr > 20)

     @unittest.skipIf(
-        torch.cuda.is_available() and not is_sm_at_least_90(), "Nedd sm90+"
+        torch.cuda.is_available() and not is_sm_at_least_90(), "Need sm90+"
     )
     @unittest.skipIf(not _is_fbgemm_gpu_genai_available(), "Need fbgemm_gpu_genai")
     def test_bmm_weight_in_bkn_layout(self):
@@ -806,7 +806,7 @@ def test_cat(self, granularity, sizes):
         "This requires rowwise scaling for weight in layout BKN across axis 1 to work"
     )
     @unittest.skipIf(
-        torch.cuda.is_available() and not is_sm_at_least_90(), "Nedd sm90+"
+        torch.cuda.is_available() and not is_sm_at_least_90(), "Need sm90+"
     )
     @unittest.skipIf(not _is_fbgemm_gpu_genai_available(), "Need fbgemm_gpu_genai")
     def test_moe_weight_reshape_ops(self):
@@ -819,7 +819,7 @@ def test_moe_weight_reshape_ops(self):
     # that should be moved here after v1 config is deprecated:
     # https://github.com/pytorch/ao/issues/2649
     @unittest.skipIf(
-        torch.cuda.is_available() and not is_sm_at_least_90(), "Nedd sm90+"
+        torch.cuda.is_available() and not is_sm_at_least_90(), "Need sm90+"
     )
     @unittest.skipIf(not _is_fbgemm_gpu_genai_available(), "Need fbgemm_gpu_genai")
     def test_expected_gpu_kernel_fbgemm(self):
@@ -852,7 +852,7 @@ def test_expected_gpu_kernel_fbgemm(self):
         ).run(code[0])

     @unittest.skipIf(
-        torch.cuda.is_available() and not is_sm_at_least_90(), "Nedd sm90+"
+        torch.cuda.is_available() and not is_sm_at_least_90(), "Need sm90+"
     )
     def test_index_select(self):
         """