Skip to content

Commit bb6730b

Browse files
committed
fix: e2e on PR
1 parent d202cd4 commit bb6730b

File tree

4 files changed

+20
-0
lines changed

4 files changed

+20
-0
lines changed

tests/e2e/heterogeneous_clusters_kind_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ def test_heterogeneous_clusters(self):
3131
def run_heterogeneous_clusters(
3232
self, gpu_resource_name="nvidia.com/gpu", number_of_gpus=0
3333
):
34+
# Use GPU-enabled Ray image when GPUs are requested
35+
from codeflare_sdk.common.utils import constants
36+
ray_image = f"rayproject/ray:{constants.RAY_VERSION}-gpu" if number_of_gpus > 0 else f"rayproject/ray:{constants.RAY_VERSION}"
37+
3438
for flavor in self.resource_flavors:
3539
node_labels = (
3640
get_flavor_spec(self, flavor).get("spec", {}).get("nodeLabels", {})
@@ -58,6 +62,7 @@ def run_heterogeneous_clusters(
5862
worker_extended_resource_requests={
5963
gpu_resource_name: number_of_gpus
6064
},
65+
image=ray_image,
6166
write_to_file=True,
6267
verify_tls=False,
6368
local_queue=queue_name,

tests/e2e/local_interactive_sdk_kind_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ def run_local_interactives(
4949

5050
ray.shutdown()
5151

52+
# Use GPU-enabled Ray image when GPUs are requested
53+
from codeflare_sdk.common.utils import constants
54+
ray_image = f"rayproject/ray:{constants.RAY_VERSION}-gpu" if number_of_gpus > 0 else f"rayproject/ray:{constants.RAY_VERSION}"
55+
5256
cluster = Cluster(
5357
ClusterConfiguration(
5458
name=cluster_name,
@@ -61,6 +65,7 @@ def run_local_interactives(
6165
worker_memory_requests=1,
6266
worker_memory_limits=4,
6367
worker_extended_resource_requests={gpu_resource_name: number_of_gpus},
68+
image=ray_image,
6469
verify_tls=False,
6570
)
6671
)

tests/e2e/mnist_raycluster_sdk_aw_kind_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ def test_mnist_ray_cluster_sdk_kind_nvidia_gpu(self):
3737
def run_mnist_raycluster_sdk_kind(
3838
self, accelerator, gpu_resource_name="nvidia.com/gpu", number_of_gpus=0
3939
):
40+
# Use GPU-enabled Ray image when GPUs are requested
41+
from codeflare_sdk.common.utils import constants
42+
ray_image = f"rayproject/ray:{constants.RAY_VERSION}-gpu" if number_of_gpus > 0 else f"rayproject/ray:{constants.RAY_VERSION}"
43+
4044
cluster = Cluster(
4145
ClusterConfiguration(
4246
name="mnist",
@@ -49,6 +53,7 @@ def run_mnist_raycluster_sdk_kind(
4953
worker_memory_requests=1,
5054
worker_memory_limits=4,
5155
worker_extended_resource_requests={gpu_resource_name: number_of_gpus},
56+
image=ray_image,
5257
write_to_file=True,
5358
verify_tls=False,
5459
appwrapper=True,

tests/e2e/mnist_raycluster_sdk_kind_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ def test_mnist_ray_cluster_sdk_kind_nvidia_gpu(self):
3737
def run_mnist_raycluster_sdk_kind(
3838
self, accelerator, gpu_resource_name="nvidia.com/gpu", number_of_gpus=0
3939
):
40+
# Use GPU-enabled Ray image when GPUs are requested
41+
from codeflare_sdk.common.utils import constants
42+
ray_image = f"rayproject/ray:{constants.RAY_VERSION}-gpu" if number_of_gpus > 0 else f"rayproject/ray:{constants.RAY_VERSION}"
43+
4044
cluster = Cluster(
4145
ClusterConfiguration(
4246
name="mnist",
@@ -49,6 +53,7 @@ def run_mnist_raycluster_sdk_kind(
4953
worker_memory_requests=1,
5054
worker_memory_limits=4,
5155
worker_extended_resource_requests={gpu_resource_name: number_of_gpus},
56+
image=ray_image,
5257
write_to_file=True,
5358
verify_tls=False,
5459
)

0 commit comments

Comments
 (0)