Skip to content

Commit 793ce4d

Browse files
kryanbeanelaurafitzgerald
authored andcommitted
feat(RHOAIENG-26482): add gcs fault tolerance
1 parent 78e8168 commit 793ce4d

File tree

4 files changed

+83
-2
lines changed

4 files changed

+83
-2
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from .rayjob import RayJob, RayJobClusterConfig
22
from .status import RayJobDeploymentStatus, CodeflareRayJobStatus, RayJobInfo
3+
from .config import RayJobClusterConfig

src/codeflare_sdk/ray/rayjobs/config.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
"""
16-
The config sub-module contains the definition of the RayJobClusterConfigV2 dataclass,
16+
The config sub-module contains the definition of the RayJobClusterConfig dataclass,
1717
which is used to specify resource requirements and other details when creating a
1818
Cluster object.
1919
"""
@@ -141,6 +141,14 @@ class RayJobClusterConfig:
141141
A list of V1Volume objects to add to the Cluster
142142
volume_mounts:
143143
A list of V1VolumeMount objects to add to the Cluster
144+
enable_gcs_ft:
145+
A boolean indicating whether to enable GCS fault tolerance.
146+
redis_address:
147+
The address of the Redis server to use for GCS fault tolerance, required when enable_gcs_ft is True.
148+
redis_password_secret:
149+
Kubernetes secret reference containing Redis password. ex: {"name": "secret-name", "key": "password-key"}
150+
external_storage_namespace:
151+
The storage namespace to use for GCS fault tolerance. By default, KubeRay sets it to the UID of RayCluster.
144152
"""
145153

146154
head_cpu_requests: Union[int, str] = 2
@@ -167,8 +175,33 @@ class RayJobClusterConfig:
167175
annotations: Dict[str, str] = field(default_factory=dict)
168176
volumes: list[V1Volume] = field(default_factory=list)
169177
volume_mounts: list[V1VolumeMount] = field(default_factory=list)
178+
enable_gcs_ft: bool = False
179+
redis_address: Optional[str] = None
180+
redis_password_secret: Optional[Dict[str, str]] = None
181+
external_storage_namespace: Optional[str] = None
170182

171183
def __post_init__(self):
184+
if self.enable_gcs_ft:
185+
if not self.redis_address:
186+
raise ValueError(
187+
"redis_address must be provided when enable_gcs_ft is True"
188+
)
189+
190+
if self.redis_password_secret and not isinstance(
191+
self.redis_password_secret, dict
192+
):
193+
raise ValueError(
194+
"redis_password_secret must be a dictionary with 'name' and 'key' fields"
195+
)
196+
197+
if self.redis_password_secret and (
198+
"name" not in self.redis_password_secret
199+
or "key" not in self.redis_password_secret
200+
):
201+
raise ValueError(
202+
"redis_password_secret must contain both 'name' and 'key' fields"
203+
)
204+
172205
self._validate_types()
173206
self._memory_to_string()
174207
self._validate_gpu_config(self.head_accelerators)
@@ -253,6 +286,11 @@ def build_ray_cluster_spec(self, cluster_name: str) -> Dict[str, Any]:
253286
"workerGroupSpecs": [self._build_worker_group_spec(cluster_name)],
254287
}
255288

289+
# Add GCS fault tolerance if enabled
290+
if self.enable_gcs_ft:
291+
gcs_ft_options = self._build_gcs_ft_options()
292+
ray_cluster_spec["gcsFaultToleranceOptions"] = gcs_ft_options
293+
256294
return ray_cluster_spec
257295

258296
def _build_head_group_spec(self) -> Dict[str, Any]:
@@ -455,3 +493,25 @@ def _generate_volumes(self) -> list:
455493
def _build_env_vars(self) -> list:
456494
"""Build environment variables list."""
457495
return [V1EnvVar(name=key, value=value) for key, value in self.envs.items()]
496+
497+
def _build_gcs_ft_options(self) -> Dict[str, Any]:
498+
"""Build GCS fault tolerance options."""
499+
gcs_ft_options = {"redisAddress": self.redis_address}
500+
501+
if (
502+
hasattr(self, "external_storage_namespace")
503+
and self.external_storage_namespace
504+
):
505+
gcs_ft_options["externalStorageNamespace"] = self.external_storage_namespace
506+
507+
if hasattr(self, "redis_password_secret") and self.redis_password_secret:
508+
gcs_ft_options["redisPassword"] = {
509+
"valueFrom": {
510+
"secretKeyRef": {
511+
"name": self.redis_password_secret["name"],
512+
"key": self.redis_password_secret["key"],
513+
}
514+
}
515+
}
516+
517+
return gcs_ft_options

src/codeflare_sdk/ray/rayjobs/rayjob.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,6 @@ def __init__(
140140
self.cluster_name = cluster_name
141141
logger.info(f"Using existing cluster: {self.cluster_name}")
142142

143-
# Initialize the KubeRay job API client
144143
self._api = RayjobApi()
145144

146145
logger.info(f"Initialized RayJob: {self.name} in namespace: {self.namespace}")

src/codeflare_sdk/ray/rayjobs/test_rayjob.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -971,3 +971,24 @@ def test_rayjob_user_override_shutdown_behavior(mocker):
971971
)
972972

973973
assert rayjob_override_priority.shutdown_after_job_finishes is True
974+
975+
976+
def test_build_ray_cluster_spec_with_gcs_ft(mocker):
977+
"""Test build_ray_cluster_spec with GCS fault tolerance enabled."""
978+
from codeflare_sdk.ray.rayjobs.config import RayJobClusterConfig
979+
980+
# Create a test cluster config with GCS FT enabled
981+
cluster_config = RayJobClusterConfig(
982+
enable_gcs_ft=True,
983+
redis_address="redis://redis-service:6379",
984+
external_storage_namespace="storage-ns",
985+
)
986+
987+
# Build the spec using the method on the cluster config
988+
spec = cluster_config.build_ray_cluster_spec("test-cluster")
989+
990+
# Verify GCS fault tolerance options
991+
assert "gcsFaultToleranceOptions" in spec
992+
gcs_ft = spec["gcsFaultToleranceOptions"]
993+
assert gcs_ft["redisAddress"] == "redis://redis-service:6379"
994+
assert gcs_ft["externalStorageNamespace"] == "storage-ns"

0 commit comments

Comments
 (0)