diff --git a/internal/controller/nodeset/nodeset_sync_test.go b/internal/controller/nodeset/nodeset_sync_test.go index 20dd9a36..18049658 100644 --- a/internal/controller/nodeset/nodeset_sync_test.go +++ b/internal/controller/nodeset/nodeset_sync_test.go @@ -864,7 +864,8 @@ func TestNodeSetReconciler_processCondemned(t *testing.T) { condemned: pods, i: 0, }, - wantErr: false, + // Expect error when Slurm node is not found + wantErr: true, wantDrain: false, wantDelete: true, } @@ -1054,7 +1055,9 @@ func TestNodeSetReconciler_processCondemned(t *testing.T) { } pod := tt.args.condemned[tt.args.i] if isDrain, err := r.slurmControl.IsNodeDrain(tt.args.ctx, tt.args.nodeset, pod); err != nil { - t.Errorf("slurmControl.IsNodeDrain() error = %v", err) + if !tt.wantDelete { + t.Errorf("slurmControl.IsNodeDrain() error = %v", err) + } } else if isDrain != tt.wantDrain && !tt.wantDelete { t.Errorf("slurmControl.IsNodeDrain() = %v, wantDrain %v", isDrain, tt.wantDrain) } @@ -1897,7 +1900,8 @@ func TestNodeSetReconciler_syncRollingUpdate(t *testing.T) { pods: []*corev1.Pod{pod1, pod2}, hash: hash, }, - wantErr: false, + // Expect error when Slurm node is not found + wantErr: true, } }(), } diff --git a/internal/controller/nodeset/slurmcontrol/slurmcontrol.go b/internal/controller/nodeset/slurmcontrol/slurmcontrol.go index 5c520be1..5913a793 100644 --- a/internal/controller/nodeset/slurmcontrol/slurmcontrol.go +++ b/internal/controller/nodeset/slurmcontrol/slurmcontrol.go @@ -280,15 +280,12 @@ func (r *realSlurmControl) IsNodeDrain(ctx context.Context, nodeset *slinkyv1bet if slurmClient == nil { logger.V(2).Info("no client for nodeset, cannot do IsNodeDrain()", "pod", klog.KObj(pod)) - return true, nil + return false, nil } slurmNode := &slurmtypes.V0044Node{} key := slurmobject.ObjectKey(nodesetutils.GetNodeName(pod)) if err := slurmClient.Get(ctx, key, slurmNode); err != nil { - if tolerateError(err) { - return true, nil - } return false, err } @@ -304,15 +301,12 @@ func (r *realSlurmControl) IsNodeDrained(ctx context.Context, nodeset *slinkyv1b if slurmClient == nil { logger.V(2).Info("no client for nodeset, cannot do IsNodeDrained()", "pod", klog.KObj(pod)) - return true, nil + return false, nil } slurmNode := &slurmtypes.V0044Node{} key := slurmobject.ObjectKey(nodesetutils.GetNodeName(pod)) if err := slurmClient.Get(ctx, key, slurmNode); err != nil { - if tolerateError(err) { - return true, nil - } return false, err } diff --git a/internal/controller/nodeset/slurmcontrol/slurmcontrol_test.go b/internal/controller/nodeset/slurmcontrol/slurmcontrol_test.go index 696b98d8..ae279a5d 100644 --- a/internal/controller/nodeset/slurmcontrol/slurmcontrol_test.go +++ b/internal/controller/nodeset/slurmcontrol/slurmcontrol_test.go @@ -465,6 +465,35 @@ func Test_realSlurmControl_IsNodeDrain(t *testing.T) { want: true, wantErr: false, }, + { + name: "No Slurm client - fail closed", + fields: fields{ + clientMap: clientmap.NewClientMap(), + }, + args: args{ + ctx: ctx, + nodeset: nodeset, + pod: pod, + }, + want: false, + wantErr: false, + }, + { + name: "Node not found - fail closed", + fields: func() fields { + sclient := fake.NewClientBuilder().Build() + return fields{ + clientMap: newSlurmClientMap(controller.Name, sclient), + } + }(), + args: args{ + ctx: ctx, + nodeset: nodeset, + pod: pod, + }, + want: false, + wantErr: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -750,6 +779,35 @@ func Test_realSlurmControl_IsNodeDrained(t *testing.T) { }, want: false, }, + { + name: "No Slurm client - fail closed", + fields: fields{ + clientMap: clientmap.NewClientMap(), + }, + args: args{ + ctx: ctx, + nodeset: nodeset, + pod: pod, + }, + want: false, + wantErr: false, + }, + { + name: "Node not found - fail closed", + fields: func() fields { + sclient := fake.NewClientBuilder().Build() + return fields{ + clientMap: newSlurmClientMap(controller.Name, sclient), + } + }(), + args: args{ + ctx: ctx, + nodeset: nodeset, + pod: pod, + }, + want: false, + wantErr: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {