diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 51be18640955..2baf5a24090c 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -250,11 +250,16 @@ jax_multiplatform_test( srcs = [ "mosaic_gpu_test.py", ], + disable_configs = [ + # TODO(b/462499936): Re-enable when test passes on MIG partition. + "gpu_b200", + ], enable_backends = [], enable_configs = [ "gpu_h100_x32", "gpu_h100", - "gpu_b200", + # TODO(b/462499936): Remove gpu_b200_full when test passes onMIG partition. + "gpu_b200_full", ], shard_count = 4, tags = [