Skip to content

Commit 940b6d9

Browse files
authored
Merge branch 'llm-d:main' into refactor_1
2 parents 986acc7 + b661e65 commit 940b6d9

File tree

3 files changed

+484
-0
lines changed

3 files changed

+484
-0
lines changed

pkg/plugins/filter/by_label_test.go

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
package filter
22

33
import (
4+
"context"
45
"encoding/json"
56
"fmt"
67
"testing"
78

89
"github.com/stretchr/testify/assert"
10+
"github.com/stretchr/testify/require"
11+
k8stypes "k8s.io/apimachinery/pkg/types"
12+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
13+
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
14+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
915
)
1016

1117
func TestByLabelFactory(t *testing.T) {
@@ -130,3 +136,126 @@ func TestByLabelFactoryInvalidJSON(t *testing.T) {
130136
})
131137
}
132138
}
139+
140+
// Helper functions
141+
func createPod(nsn k8stypes.NamespacedName, ipaddr string, labels map[string]string) types.Pod {
142+
return &types.PodMetrics{
143+
Pod: &backend.Pod{
144+
NamespacedName: nsn,
145+
Address: ipaddr,
146+
Labels: labels,
147+
},
148+
MetricsState: &backendmetrics.MetricsState{},
149+
}
150+
}
151+
152+
func TestByLabelFiltering(t *testing.T) {
153+
pods := []types.Pod{
154+
createPod(k8stypes.NamespacedName{Namespace: "default", Name: "nginx-1"},
155+
"10.0.0.1",
156+
map[string]string{
157+
"app": "nginx",
158+
"version": "v1.0",
159+
"tier": "frontend",
160+
}),
161+
createPod(k8stypes.NamespacedName{Namespace: "default", Name: "nginx-2"},
162+
"10.0.0.2",
163+
map[string]string{
164+
"app": "nginx",
165+
"version": "v1.1",
166+
"tier": "frontend",
167+
}),
168+
createPod(k8stypes.NamespacedName{Namespace: "kube-system", Name: "coredns-1"},
169+
"10.0.0.3",
170+
map[string]string{
171+
"app": "coredns",
172+
"tier": "system",
173+
}),
174+
createPod(k8stypes.NamespacedName{Namespace: "default", Name: "redis-1"},
175+
"10.0.0.4",
176+
map[string]string{
177+
"app": "redis",
178+
"tier": "backend",
179+
"deprecated": "true",
180+
}),
181+
createPod(k8stypes.NamespacedName{Namespace: "default", Name: "web-1"},
182+
"10.0.0.5",
183+
map[string]string{
184+
"app": "web",
185+
"tier": "frontend",
186+
"environment": "production",
187+
}),
188+
createPod(k8stypes.NamespacedName{Namespace: "default", Name: "no-tier-pod"},
189+
"10.0.0.6",
190+
map[string]string{
191+
"app": "unknown",
192+
}),
193+
}
194+
195+
tests := []struct {
196+
testName string
197+
label string
198+
validValues []string
199+
allowsNoLabel bool
200+
expectedPods []string // pod names that should match
201+
}{
202+
{
203+
testName: "match app nginx",
204+
label: "app",
205+
validValues: []string{"nginx"},
206+
expectedPods: []string{"nginx-1", "nginx-2"},
207+
allowsNoLabel: false,
208+
},
209+
{
210+
testName: "match exact version v1.0",
211+
label: "version",
212+
validValues: []string{"v1.0"},
213+
expectedPods: []string{"nginx-1"},
214+
allowsNoLabel: false,
215+
},
216+
{
217+
testName: "allow pods without 'tier' label",
218+
label: "tier",
219+
allowsNoLabel: true,
220+
expectedPods: []string{"no-tier-pod"}, // only "no-tier-pod" doesn't have a "tier" label
221+
},
222+
{
223+
testName: "allow all known tier values",
224+
label: "tier",
225+
validValues: []string{"frontend", "backend", "system"},
226+
allowsNoLabel: false,
227+
expectedPods: []string{"nginx-1", "nginx-2", "coredns-1", "redis-1", "web-1"},
228+
},
229+
}
230+
231+
for _, tt := range tests {
232+
t.Run(tt.testName, func(t *testing.T) {
233+
rawParams, err := json.Marshal(byLabelParameters{
234+
Label: tt.label,
235+
ValidValues: tt.validValues,
236+
AllowsNoLabel: tt.allowsNoLabel,
237+
})
238+
require.NoError(t, err)
239+
240+
plugin, err := ByLabelFactory("test-label", rawParams, nil)
241+
require.NoError(t, err)
242+
require.NotNil(t, plugin)
243+
244+
blf, ok := plugin.(*ByLabel)
245+
require.True(t, ok, "plugin should be of type *ByLabel")
246+
247+
ctx := context.Background()
248+
filteredPods := blf.Filter(ctx, nil, nil, pods)
249+
250+
var actualPodNames []string
251+
for _, pod := range filteredPods {
252+
actualPodNames = append(actualPodNames, pod.GetPod().NamespacedName.Name)
253+
}
254+
255+
assert.ElementsMatch(t, tt.expectedPods, actualPodNames,
256+
"filtered pods should match expected pods")
257+
assert.Len(t, filteredPods, len(tt.expectedPods),
258+
"filtered pods count should match expected count")
259+
})
260+
}
261+
}

pkg/plugins/profile/dp_profile_handler_test.go

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
package profile
22

33
import (
4+
"context"
45
"encoding/json"
6+
"net"
57
"testing"
68

79
"github.com/stretchr/testify/assert"
10+
"github.com/stretchr/testify/require"
11+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
12+
13+
"github.com/llm-d/llm-d-inference-scheduler/pkg/common"
814
)
915

1016
func TestDataParallelProfileHandlerFactory(t *testing.T) {
@@ -113,3 +119,93 @@ func TestDataParallelProfileHandlerFactoryInvalidJSON(t *testing.T) {
113119
})
114120
}
115121
}
122+
123+
func Test_DataParallelProfileHandler_ProcessResults(t *testing.T) {
124+
tests := []struct {
125+
name string
126+
primaryPort int
127+
profileResults map[string]*types.ProfileRunResult
128+
expectError bool
129+
checkResult func(*testing.T, *types.SchedulingResult, map[string]string)
130+
}{
131+
{
132+
name: "error: multiple profiles not supported",
133+
profileResults: map[string]*types.ProfileRunResult{
134+
"profile1": newMockProfileRunResult(DefaultTestPodPort, "pod1"),
135+
"profile2": newMockProfileRunResult(DefaultTestPodPort, "pod2"),
136+
},
137+
expectError: true,
138+
},
139+
{
140+
name: "error: single profile but result is nil",
141+
profileResults: map[string]*types.ProfileRunResult{
142+
"nil-profile": nil,
143+
},
144+
expectError: true,
145+
},
146+
{
147+
name: "success: single profile with primaryPort → port overridden, header set",
148+
primaryPort: 9000,
149+
profileResults: map[string]*types.ProfileRunResult{
150+
"dp-profile": newMockProfileRunResult(DefaultTestPodPort, "pod1"),
151+
},
152+
expectError: false,
153+
checkResult: func(t *testing.T, res *types.SchedulingResult, headers map[string]string) {
154+
assert.Equal(t, "dp-profile", res.PrimaryProfileName)
155+
156+
pods := res.ProfileResults["dp-profile"].TargetPods
157+
require.Len(t, pods, 1)
158+
assert.Equal(t, "9000", pods[0].GetPod().Port) // overridden
159+
expectedHeader := net.JoinHostPort("10.0.0.1", DefaultTestPodPort) // original
160+
assert.Equal(t, expectedHeader, headers[common.DataParallelPodHeader])
161+
},
162+
},
163+
{
164+
name: "success: primaryPort=0 → port becomes '0'",
165+
primaryPort: 0,
166+
profileResults: map[string]*types.ProfileRunResult{
167+
"dp": newMockProfileRunResult("8080", "pod1"),
168+
},
169+
expectError: false,
170+
checkResult: func(t *testing.T, res *types.SchedulingResult, headers map[string]string) {
171+
pod := res.ProfileResults["dp"].TargetPods[0]
172+
assert.Equal(t, "0", pod.GetPod().Port)
173+
assert.Equal(t, "10.0.0.1:8080", headers[common.DataParallelPodHeader])
174+
},
175+
},
176+
{
177+
name: "success: multiple target pods → all ports overridden",
178+
primaryPort: 8080,
179+
profileResults: map[string]*types.ProfileRunResult{
180+
"dp-profile": newMockProfileRunResult(DefaultTestPodPort, "pod1", "pod2"),
181+
},
182+
expectError: false,
183+
checkResult: func(t *testing.T, res *types.SchedulingResult, headers map[string]string) {
184+
pods := res.ProfileResults["dp-profile"].TargetPods
185+
assert.Len(t, pods, 2)
186+
for _, p := range pods {
187+
assert.Equal(t, "8080", p.GetPod().Port)
188+
}
189+
assert.Equal(t, net.JoinHostPort("10.0.0.1", DefaultTestPodPort), headers[common.DataParallelPodHeader])
190+
},
191+
},
192+
}
193+
194+
for _, tt := range tests {
195+
t.Run(tt.name, func(t *testing.T) {
196+
handler := NewDataParallelProfileHandler(tt.primaryPort).WithName("test-handler")
197+
headers := make(map[string]string)
198+
req := &types.LLMRequest{Headers: headers}
199+
result, err := handler.ProcessResults(context.Background(), &types.CycleState{}, req, tt.profileResults)
200+
201+
if tt.expectError {
202+
assert.Error(t, err)
203+
return
204+
}
205+
206+
require.NoError(t, err)
207+
require.NotNil(t, result)
208+
tt.checkResult(t, result, headers)
209+
})
210+
}
211+
}

0 commit comments

Comments
 (0)