Skip to content

Commit 2a65147

Browse files
committed
feat(allowlist): support both v1 and v1alpha2 InferencePool APIs with auto-discovery
Signed-off-by: CYJiang <googs1025@gmail.com>
1 parent 2bc23f8 commit 2a65147

File tree

5 files changed

+47
-18
lines changed

5 files changed

+47
-18
lines changed

cmd/pd-sidecar/main.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ func main() {
5858
inferencePoolNamespace := flag.String("inference-pool-namespace", os.Getenv("INFERENCE_POOL_NAMESPACE"), "the Kubernetes namespace to watch for InferencePool resources (defaults to INFERENCE_POOL_NAMESPACE env var)")
5959
inferencePoolName := flag.String("inference-pool-name", os.Getenv("INFERENCE_POOL_NAME"), "the specific InferencePool name to watch (defaults to INFERENCE_POOL_NAME env var)")
6060
enablePrefillerSampling := flag.Bool("enable-prefiller-sampling", func() bool { b, _ := strconv.ParseBool(os.Getenv("ENABLE_PREFILLER_SAMPLING")); return b }(), "if true, the target prefill instance will be selected randomly from among the provided prefill host values")
61+
poolGroup := flag.String("pool-group", proxy.DefaultPoolGroup, "group of the InferencePool this Endpoint Picker is associated with.")
6162

6263
opts := zap.Options{}
6364
opts.BindFlags(flag.CommandLine) // optional to allow zap logging control via CLI
@@ -135,7 +136,7 @@ func main() {
135136
}
136137

137138
// Create SSRF protection validator
138-
validator, err := proxy.NewAllowlistValidator(*enableSSRFProtection, *inferencePoolNamespace, *inferencePoolName)
139+
validator, err := proxy.NewAllowlistValidator(*enableSSRFProtection, *poolGroup, *inferencePoolNamespace, *inferencePoolName)
139140
if err != nil {
140141
logger.Error(err, "failed to create SSRF protection validator")
141142
return

pkg/sidecar/proxy/allowlist.go

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,16 @@ import (
3838
)
3939

4040
const (
41-
inferencePoolGroup = "inference.networking.x-k8s.io"
42-
inferencePoolVersion = "v1alpha2"
4341
inferencePoolResource = "inferencepools"
4442
resyncPeriod = 30 * time.Second
4543
)
4644

45+
// InferencePool API group to version mapping
46+
var inferencePoolGroupToVersion = map[string]string{
47+
DefaultPoolGroup: "v1",
48+
LegacyPoolGroup: "v1alpha2",
49+
}
50+
4751
// AllowlistValidator manages allowed prefill targets based on InferencePool resources
4852
type AllowlistValidator struct {
4953
logger logr.Logger
@@ -52,6 +56,8 @@ type AllowlistValidator struct {
5256
poolName string
5357
enabled bool
5458

59+
gvr schema.GroupVersionResource // detected GVR
60+
5561
// allowedTargets maps hostport -> bool for allowed prefill targets
5662
allowedTargets set.Set[string]
5763
allowedTargetsMu sync.RWMutex
@@ -65,13 +71,26 @@ type AllowlistValidator struct {
6571
}
6672

6773
// NewAllowlistValidator creates a new SSRF protection validator
68-
func NewAllowlistValidator(enabled bool, namespace string, poolName string) (*AllowlistValidator, error) {
74+
func NewAllowlistValidator(enabled bool, poolGroup, namespace, poolName string) (*AllowlistValidator, error) {
6975
if !enabled {
7076
return &AllowlistValidator{
7177
enabled: false,
7278
}, nil
7379
}
7480

81+
// Determine version based on poolGroup
82+
version, exists := inferencePoolGroupToVersion[poolGroup]
83+
if !exists {
84+
return nil, fmt.Errorf("unsupported poolGroup: %s, "+
85+
"must be one of %v", poolGroup, getSupportedPoolGroups())
86+
}
87+
88+
gvr := schema.GroupVersionResource{
89+
Group: poolGroup,
90+
Version: version,
91+
Resource: inferencePoolResource,
92+
}
93+
7594
loadingRules := clientcmd.NewDefaultClientConfigLoadingRules()
7695
overrides := &clientcmd.ConfigOverrides{}
7796
config, err := clientcmd.NewNonInteractiveDeferredLoadingClientConfig(
@@ -92,39 +111,43 @@ func NewAllowlistValidator(enabled bool, namespace string, poolName string) (*Al
92111
dynamicClient: dynamicClient,
93112
namespace: namespace,
94113
poolName: poolName,
114+
gvr: gvr,
95115
allowedTargets: set.New[string](),
96116
podInformers: make(map[string]cache.SharedInformer),
97117
podStopChans: make(map[string]chan struct{}),
98118
stopCh: make(chan struct{}),
99119
}, nil
100120
}
101121

122+
func getSupportedPoolGroups() []string {
123+
groups := make([]string, 0, len(inferencePoolGroupToVersion))
124+
for group := range inferencePoolGroupToVersion {
125+
groups = append(groups, group)
126+
}
127+
return groups
128+
}
129+
102130
// Start begins watching InferencePool resources and managing the allowlist
103131
func (av *AllowlistValidator) Start(ctx context.Context) error {
104132
if !av.enabled {
105133
return nil
106134
}
107135

108136
av.logger = log.FromContext(ctx).WithName("allowlist-validator")
109-
av.logger.Info("starting SSRF protection allowlist validator", "namespace", av.namespace, "poolName", av.poolName)
110-
111-
gvr := schema.GroupVersionResource{
112-
Group: inferencePoolGroup,
113-
Version: inferencePoolVersion,
114-
Resource: inferencePoolResource,
115-
}
137+
av.logger.Info("starting SSRF protection allowlist validator",
138+
"namespace", av.namespace, "poolName", av.poolName, "gvr", av.gvr.String())
116139

117140
// Create informer for the specific InferencePool resource
118141
lw := &cache.ListWatch{
119-
ListFunc: func(options metav1.ListOptions) (runtime.Object, error) {
142+
ListWithContextFunc: func(ctx context.Context, options metav1.ListOptions) (runtime.Object, error) {
120143
// List with field selector to get only the specific InferencePool
121144
options.FieldSelector = "metadata.name=" + av.poolName
122-
return av.dynamicClient.Resource(gvr).Namespace(av.namespace).List(ctx, options)
145+
return av.dynamicClient.Resource(av.gvr).Namespace(av.namespace).List(ctx, options)
123146
},
124-
WatchFunc: func(options metav1.ListOptions) (watch.Interface, error) {
147+
WatchFuncWithContext: func(ctx context.Context, options metav1.ListOptions) (watch.Interface, error) {
125148
// Watch the specific InferencePool by name using field selector
126149
options.FieldSelector = "metadata.name=" + av.poolName
127-
return av.dynamicClient.Resource(gvr).Namespace(av.namespace).Watch(ctx, options)
150+
return av.dynamicClient.Resource(av.gvr).Namespace(av.namespace).Watch(ctx, options)
128151
},
129152
}
130153

@@ -142,7 +165,7 @@ func (av *AllowlistValidator) Start(ctx context.Context) error {
142165

143166
// Wait for cache sync
144167
if !cache.WaitForCacheSync(av.stopCh, av.poolInformer.HasSynced) {
145-
return fmt.Errorf("failed to sync InferencePool cache within timeout (check RBAC permissions for inferencepools.%s and that pool '%s' exists)", inferencePoolGroup, av.poolName)
168+
return fmt.Errorf("failed to sync InferencePool cache within timeout (check RBAC permissions for inferencepools.%s and that pool '%s' exists)", av.gvr.String(), av.poolName)
146169
}
147170

148171
av.logger.Info("allowlist validator started successfully")

pkg/sidecar/proxy/allowlist_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ var _ = Describe("AllowlistValidator", func() {
2828

2929
BeforeEach(func() {
3030
var err error
31-
validator, err = NewAllowlistValidator(false, "test-namespace", "test-pool")
31+
validator, err = NewAllowlistValidator(false, DefaultPoolGroup, "test-namespace", "test-pool")
3232
Expect(err).ToNot(HaveOccurred())
3333
})
3434

pkg/sidecar/proxy/data_parallel_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ var _ = Describe("Data Parallel support", func() {
5959
DataParallelSize: testDataParallelSize,
6060
}
6161
theProxy := NewProxy(strconv.Itoa(fakeProxyPort), decodeURL, cfg)
62-
theProxy.allowlistValidator, err = NewAllowlistValidator(false, "", "")
62+
theProxy.allowlistValidator, err = NewAllowlistValidator(false, DefaultPoolGroup, "", "")
6363
Expect(err).ToNot(HaveOccurred())
6464

6565
err = theProxy.startDataParallel(ctx, nil, grp)

pkg/sidecar/proxy/proxy.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,11 @@ const (
6060

6161
// ConnectorSGLang enables SGLang P/D disaggregation protocol
6262
ConnectorSGLang = "sglang"
63+
64+
// DefaultPoolGroup is the default pool group name
65+
DefaultPoolGroup = "inference.networking.k8s.io"
66+
// LegacyPoolGroup is the legacy pool group name
67+
LegacyPoolGroup = "inference.networking.x-k8s.io"
6368
)
6469

6570
// Config represents the proxy server configuration

0 commit comments

Comments
 (0)