
Commit 4fb8a6f

astefanutti authored and openshift-merge-robot committed
test: Add MNIST training in RayCluster with CodeFlare SDK
1 parent ae870ea commit 4fb8a6f

File tree: 7 files changed (+235 −42 lines)

test/e2e/mnist_pytorch_mcad_job_test.go

Lines changed: 6 additions & 31 deletions
@@ -39,10 +39,7 @@ func TestMNISTPyTorchMCAD(t *testing.T) {
 	namespace := test.NewTestNamespace()
 
 	// MNIST training script
-	mnist, err := scripts.ReadFile("mnist.py")
-	test.Expect(err).NotTo(HaveOccurred())
-
-	mnistScript := &corev1.ConfigMap{
+	mnist := &corev1.ConfigMap{
 		TypeMeta: metav1.TypeMeta{
 			APIVersion: corev1.SchemeGroupVersion.String(),
 			Kind:       "ConfigMap",
@@ -52,13 +49,13 @@ func TestMNISTPyTorchMCAD(t *testing.T) {
 			Namespace: namespace.Name,
 		},
 		BinaryData: map[string][]byte{
-			"mnist.py": mnist,
+			"mnist.py": ReadFile(test, "mnist.py"),
 		},
 		Immutable: Ptr(true),
 	}
-	mnistScript, err = test.Client().Core().CoreV1().ConfigMaps(namespace.Name).Create(test.Ctx(), mnistScript, metav1.CreateOptions{})
+	mnist, err := test.Client().Core().CoreV1().ConfigMaps(namespace.Name).Create(test.Ctx(), mnist, metav1.CreateOptions{})
 	test.Expect(err).NotTo(HaveOccurred())
-	test.T().Logf("Created ConfigMap %s/%s successfully", mnistScript.Namespace, mnistScript.Name)
+	test.T().Logf("Created ConfigMap %s/%s successfully", mnist.Namespace, mnist.Name)
 
 	// pip requirements
 	requirements := &corev1.ConfigMap{
@@ -121,7 +118,7 @@ torchvision==0.12.0
 					VolumeSource: corev1.VolumeSource{
 						ConfigMap: &corev1.ConfigMapVolumeSource{
 							LocalObjectReference: corev1.LocalObjectReference{
-								Name: mnistScript.Name,
+								Name: mnist.Name,
 							},
 						},
 					},
@@ -182,7 +179,7 @@ torchvision==0.12.0
 	test.Eventually(AppWrapper(test, namespace, aw.Name), TestTimeoutMedium).
 		Should(WithTransform(AppWrapperState, Equal(mcadv1beta1.AppWrapperStateActive)))
 
-	defer troubleshooting(test, job)
+	defer JobTroubleshooting(test, job)
 
 	test.T().Logf("Waiting for Job %s/%s to complete successfully", job.Namespace, job.Name)
 	test.Eventually(Job(test, job.Namespace, job.Name), TestTimeoutLong).
@@ -201,25 +198,3 @@ torchvision==0.12.0
 	test.T().Logf("Printing Job %s/%s logs", job.Namespace, job.Name)
 	test.T().Log(GetPodLogs(test, &pods[0], corev1.PodLogOptions{}))
 }
-
-func troubleshooting(test Test, job *batchv1.Job) {
-	if !test.T().Failed() {
-		return
-	}
-	job = GetJob(test, job.Namespace, job.Name)
-
-	test.T().Errorf("Job %s/%s hasn't completed in time: %s", job.Namespace, job.Name, job)
-
-	pods := GetPods(test, job.Namespace, metav1.ListOptions{
-		LabelSelector: labels.FormatLabels(job.Spec.Selector.MatchLabels)},
-	)
-
-	if len(pods) == 0 {
-		test.T().Errorf("Job %s/%s has no pods scheduled", job.Namespace, job.Name)
-	} else {
-		for i, pod := range pods {
-			test.T().Logf("Printing Pod %s/%s logs", pod.Namespace, pod.Name)
-			test.T().Log(GetPodLogs(test, &pods[i], corev1.PodLogOptions{}))
-		}
-	}
-}

Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@
+/*
+Copyright 2023.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package e2e
+
+import (
+	"testing"
+
+	. "github.com/onsi/gomega"
+
+	batchv1 "k8s.io/api/batch/v1"
+	corev1 "k8s.io/api/core/v1"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+
+	. "github.com/project-codeflare/codeflare-operator/test/support"
+)
+
+func TestMNISTRayClusterSDK(t *testing.T) {
+	test := With(t)
+	test.T().Parallel()
+
+	test.T().Skip("Requires https://github.com/project-codeflare/codeflare-sdk/pull/146")
+
+	// Create a namespace
+	namespace := test.NewTestNamespace()
+
+	// SDK script
+	sdk := &corev1.ConfigMap{
+		TypeMeta: metav1.TypeMeta{
+			APIVersion: corev1.SchemeGroupVersion.String(),
+			Kind:       "ConfigMap",
+		},
+		ObjectMeta: metav1.ObjectMeta{
+			Name:      "sdk",
+			Namespace: namespace.Name,
+		},
+		BinaryData: map[string][]byte{
+			"sdk.py": ReadFile(test, "sdk.py"),
+		},
+		Immutable: Ptr(true),
+	}
+	sdk, err := test.Client().Core().CoreV1().ConfigMaps(namespace.Name).Create(test.Ctx(), sdk, metav1.CreateOptions{})
+	test.Expect(err).NotTo(HaveOccurred())
+	test.T().Logf("Created ConfigMap %s/%s successfully", sdk.Namespace, sdk.Name)
+
+	// pip requirements
+	requirements := &corev1.ConfigMap{
+		TypeMeta: metav1.TypeMeta{
+			APIVersion: corev1.SchemeGroupVersion.String(),
+			Kind:       "ConfigMap",
+		},
+		ObjectMeta: metav1.ObjectMeta{
+			Name:      "requirements",
+			Namespace: namespace.Name,
+		},
+		BinaryData: map[string][]byte{
+			"requirements.txt": ReadFile(test, "requirements.txt"),
+		},
+		Immutable: Ptr(true),
+	}
+	requirements, err = test.Client().Core().CoreV1().ConfigMaps(namespace.Name).Create(test.Ctx(), requirements, metav1.CreateOptions{})
+	test.Expect(err).NotTo(HaveOccurred())
+	test.T().Logf("Created ConfigMap %s/%s successfully", requirements.Namespace, requirements.Name)
+
+	job := &batchv1.Job{
+		TypeMeta: metav1.TypeMeta{
+			APIVersion: batchv1.SchemeGroupVersion.String(),
+			Kind:       "Job",
+		},
+		ObjectMeta: metav1.ObjectMeta{
+			Name:      "sdk",
+			Namespace: namespace.Name,
+		},
+		Spec: batchv1.JobSpec{
+			Completions:  Ptr(int32(1)),
+			Parallelism:  Ptr(int32(1)),
+			BackoffLimit: Ptr(int32(0)),
+			Template: corev1.PodTemplateSpec{
+				Spec: corev1.PodSpec{
+					Containers: []corev1.Container{
+						{
+							Name:    "sdk",
+							Image:   "quay.io/opendatahub/notebooks:jupyter-minimal-ubi8-python-3.8-4c8f26e",
+							Command: []string{"/bin/sh", "-c", "pip install -r /test/runtime/requirements.txt && python /test/job/sdk.py"},
+							VolumeMounts: []corev1.VolumeMount{
+								{
+									Name:      "sdk",
+									MountPath: "/test/job",
+								},
+								{
+									Name:      "requirements",
+									MountPath: "/test/runtime",
+								},
+							},
+						},
+					},
+					Volumes: []corev1.Volume{
+						{
+							Name: "sdk",
+							VolumeSource: corev1.VolumeSource{
+								ConfigMap: &corev1.ConfigMapVolumeSource{
+									LocalObjectReference: corev1.LocalObjectReference{
+										Name: sdk.Name,
+									},
+								},
+							},
+						},
+						{
+							Name: "requirements",
+							VolumeSource: corev1.VolumeSource{
+								ConfigMap: &corev1.ConfigMapVolumeSource{
+									LocalObjectReference: corev1.LocalObjectReference{
+										Name: requirements.Name,
+									},
+								},
+							},
+						},
+					},
+					RestartPolicy: corev1.RestartPolicyNever,
+				},
+			},
+		},
+	}
+	job, err = test.Client().Core().BatchV1().Jobs(namespace.Name).Create(test.Ctx(), job, metav1.CreateOptions{})
+	test.Expect(err).NotTo(HaveOccurred())
+
+	defer JobTroubleshooting(test, job)
+
+	test.T().Logf("Waiting for Job %s/%s to complete successfully", job.Namespace, job.Name)
+	test.Eventually(Job(test, job.Namespace, job.Name), TestTimeoutMedium).
+		Should(WithTransform(ConditionStatus(batchv1.JobComplete), Equal(corev1.ConditionTrue)))
+}

test/e2e/mnist_rayjob_mcad_raycluster_test.go

Lines changed: 5 additions & 8 deletions
@@ -39,10 +39,7 @@ func TestMNISTRayJobMCADRayCluster(t *testing.T) {
 	namespace := test.NewTestNamespace()
 
 	// MNIST training script
-	mnist, err := scripts.ReadFile("mnist.py")
-	test.Expect(err).NotTo(HaveOccurred())
-
-	mnistScript := &corev1.ConfigMap{
+	mnist := &corev1.ConfigMap{
 		TypeMeta: metav1.TypeMeta{
 			APIVersion: corev1.SchemeGroupVersion.String(),
 			Kind:       "ConfigMap",
@@ -52,13 +49,13 @@ func TestMNISTRayJobMCADRayCluster(t *testing.T) {
 			Namespace: namespace.Name,
 		},
 		BinaryData: map[string][]byte{
-			"mnist.py": mnist,
+			"mnist.py": ReadFile(test, "mnist.py"),
 		},
 		Immutable: Ptr(true),
 	}
-	mnistScript, err = test.Client().Core().CoreV1().ConfigMaps(namespace.Name).Create(test.Ctx(), mnistScript, metav1.CreateOptions{})
+	mnist, err := test.Client().Core().CoreV1().ConfigMaps(namespace.Name).Create(test.Ctx(), mnist, metav1.CreateOptions{})
 	test.Expect(err).NotTo(HaveOccurred())
-	test.T().Logf("Created ConfigMap %s/%s successfully", mnistScript.Namespace, mnistScript.Name)
+	test.T().Logf("Created ConfigMap %s/%s successfully", mnist.Namespace, mnist.Name)
 
 	// RayCluster
 	rayCluster := &rayv1alpha1.RayCluster{
@@ -127,7 +124,7 @@
 					VolumeSource: corev1.VolumeSource{
 						ConfigMap: &corev1.ConfigMapVolumeSource{
 							LocalObjectReference: corev1.LocalObjectReference{
-								Name: mnistScript.Name,
+								Name: mnist.Name,
 							},
 						},
 					},

test/e2e/requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+codeflare-sdk==0.4.4

test/e2e/sdk.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+from codeflare_sdk.cluster.cluster import Cluster, ClusterConfiguration
+# from codeflare_sdk.cluster.auth import TokenAuthentication
+from codeflare_sdk.job.jobs import DDPJobDefinition
+
+cluster = Cluster(ClusterConfiguration(
+    name='mnist',
+    # namespace='default',
+    min_worker=1,
+    max_worker=1,
+    min_cpus=0.2,
+    max_cpus=1,
+    min_memory=0.5,
+    max_memory=1,
+    gpu=0,
+    instascale=False,
+))
+
+cluster.up()
+
+cluster.status()
+
+cluster.wait_ready()
+
+cluster.status()
+
+cluster.details()
+
+jobdef = DDPJobDefinition(
+    name="mnist",
+    script="/test/job/mnist.py",
+    scheduler_args={"requirements": "/test/runtime/requirements.txt"}
+)
+job = jobdef.submit(cluster)
+
+job.status()
+
+print(job.logs())
+
+cluster.down()

test/e2e/support.go

Lines changed: 15 additions & 3 deletions
@@ -16,7 +16,19 @@ limitations under the License.
 
 package e2e
 
-import "embed"
+import (
+	"embed"
 
-//go:embed *.py
-var scripts embed.FS
+	"github.com/onsi/gomega"
+
+	"github.com/project-codeflare/codeflare-operator/test/support"
+)
+
+//go:embed *.py *.txt
+var files embed.FS
+
+func ReadFile(t support.Test, fileName string) []byte {
+	file, err := files.ReadFile(fileName)
+	t.Expect(err).NotTo(gomega.HaveOccurred())
+	return file
+}
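
Note (not part of the diff): the new ReadFile helper relies on Go's embed package, which compiles the files matching the //go:embed patterns into the test binary so reads never depend on the working directory. A minimal standalone sketch of that pattern, assuming the directory contains at least one .py and one .txt file and using an illustrative file name:

package main

import (
	"embed"
	"fmt"
	"log"
)

// Embed every .py and .txt file from this directory into the binary;
// the build fails if a pattern matches nothing.
//
//go:embed *.py *.txt
var files embed.FS

func main() {
	// "mnist.py" is illustrative; any embedded path works.
	data, err := files.ReadFile("mnist.py")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("embedded script: %d bytes\n", len(data))
}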

test/support/batch.go

Lines changed: 24 additions & 0 deletions
@@ -20,7 +20,9 @@ import (
 	"github.com/onsi/gomega"
 
 	batchv1 "k8s.io/api/batch/v1"
+	corev1 "k8s.io/api/core/v1"
 	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/apimachinery/pkg/labels"
 )
 
 func Job(t Test, namespace, name string) func(g gomega.Gomega) *batchv1.Job {
@@ -35,3 +37,25 @@ func GetJob(t Test, namespace, name string) *batchv1.Job {
 	t.T().Helper()
 	return Job(t, namespace, name)(t)
 }
+
+func JobTroubleshooting(test Test, job *batchv1.Job) {
+	if !test.T().Failed() {
+		return
+	}
+	job = GetJob(test, job.Namespace, job.Name)
+
+	test.T().Errorf("Job %s/%s hasn't completed in time: %s", job.Namespace, job.Name, job)
+
+	pods := GetPods(test, job.Namespace, metav1.ListOptions{
+		LabelSelector: labels.FormatLabels(job.Spec.Selector.MatchLabels)},
+	)
+
+	if len(pods) == 0 {
+		test.T().Errorf("Job %s/%s has no pods scheduled", job.Namespace, job.Name)
+	} else {
+		for i, pod := range pods {
+			test.T().Logf("Printing Pod %s/%s logs", pod.Namespace, pod.Name)
+			test.T().Log(GetPodLogs(test, &pods[i], corev1.PodLogOptions{}))
+		}
+	}
+}
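
Aside (not part of the change): JobTroubleshooting is meant to be registered with defer right after the Job is created, so the diagnostics above only run once the test has already been marked as failed. A rough sketch of the same defer-on-failure pattern against a plain *testing.T, with illustrative names:

package example

import "testing"

// dumpDiagnostics logs extra context, but only when the test has failed.
func dumpDiagnostics(t *testing.T) {
	if !t.Failed() {
		return
	}
	t.Log("collecting job status and pod logs for post-mortem")
}

func TestSomething(t *testing.T) {
	// Registered early so it runs even if a later assertion fails the test.
	defer dumpDiagnostics(t)

	// ... assertions that may mark the test as failed ...
}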
