diff --git a/cmd/pd-sidecar/main.go b/cmd/pd-sidecar/main.go index bd6b1b5d..cd086eb7 100644 --- a/cmd/pd-sidecar/main.go +++ b/cmd/pd-sidecar/main.go @@ -58,6 +58,7 @@ func main() { inferencePoolNamespace := flag.String("inference-pool-namespace", os.Getenv("INFERENCE_POOL_NAMESPACE"), "the Kubernetes namespace to watch for InferencePool resources (defaults to INFERENCE_POOL_NAMESPACE env var)") inferencePoolName := flag.String("inference-pool-name", os.Getenv("INFERENCE_POOL_NAME"), "the specific InferencePool name to watch (defaults to INFERENCE_POOL_NAME env var)") enablePrefillerSampling := flag.Bool("enable-prefiller-sampling", func() bool { b, _ := strconv.ParseBool(os.Getenv("ENABLE_PREFILLER_SAMPLING")); return b }(), "if true, the target prefill instance will be selected randomly from among the provided prefill host values") + poolGroup := flag.String("pool-group", proxy.DefaultPoolGroup, "group of the InferencePool this Endpoint Picker is associated with.") opts := zap.Options{} opts.BindFlags(flag.CommandLine) // optional to allow zap logging control via CLI @@ -135,7 +136,7 @@ func main() { } // Create SSRF protection validator - validator, err := proxy.NewAllowlistValidator(*enableSSRFProtection, *inferencePoolNamespace, *inferencePoolName) + validator, err := proxy.NewAllowlistValidator(*enableSSRFProtection, *poolGroup, *inferencePoolNamespace, *inferencePoolName) if err != nil { logger.Error(err, "failed to create SSRF protection validator") return diff --git a/pkg/sidecar/proxy/allowlist.go b/pkg/sidecar/proxy/allowlist.go index 8442e8b2..578d042c 100644 --- a/pkg/sidecar/proxy/allowlist.go +++ b/pkg/sidecar/proxy/allowlist.go @@ -38,12 +38,16 @@ import ( ) const ( - inferencePoolGroup = "inference.networking.x-k8s.io" - inferencePoolVersion = "v1alpha2" inferencePoolResource = "inferencepools" resyncPeriod = 30 * time.Second ) +// InferencePool API group to version mapping +var inferencePoolGroupToVersion = map[string]string{ + DefaultPoolGroup: "v1", + LegacyPoolGroup: "v1alpha2", +} + // AllowlistValidator manages allowed prefill targets based on InferencePool resources type AllowlistValidator struct { logger logr.Logger @@ -52,6 +56,8 @@ type AllowlistValidator struct { poolName string enabled bool + gvr schema.GroupVersionResource // detected GVR + // allowedTargets maps hostport -> bool for allowed prefill targets allowedTargets set.Set[string] allowedTargetsMu sync.RWMutex @@ -65,13 +71,26 @@ type AllowlistValidator struct { } // NewAllowlistValidator creates a new SSRF protection validator -func NewAllowlistValidator(enabled bool, namespace string, poolName string) (*AllowlistValidator, error) { +func NewAllowlistValidator(enabled bool, poolGroup, namespace, poolName string) (*AllowlistValidator, error) { if !enabled { return &AllowlistValidator{ enabled: false, }, nil } + // Determine version based on poolGroup + version, exists := inferencePoolGroupToVersion[poolGroup] + if !exists { + return nil, fmt.Errorf("unsupported poolGroup: %s, "+ + "must be one of %v", poolGroup, getSupportedPoolGroups()) + } + + gvr := schema.GroupVersionResource{ + Group: poolGroup, + Version: version, + Resource: inferencePoolResource, + } + loadingRules := clientcmd.NewDefaultClientConfigLoadingRules() overrides := &clientcmd.ConfigOverrides{} config, err := clientcmd.NewNonInteractiveDeferredLoadingClientConfig( @@ -92,6 +111,7 @@ func NewAllowlistValidator(enabled bool, namespace string, poolName string) (*Al dynamicClient: dynamicClient, namespace: namespace, poolName: poolName, + gvr: gvr, allowedTargets: set.New[string](), podInformers: make(map[string]cache.SharedInformer), podStopChans: make(map[string]chan struct{}), @@ -99,6 +119,14 @@ func NewAllowlistValidator(enabled bool, namespace string, poolName string) (*Al }, nil } +func getSupportedPoolGroups() []string { + groups := make([]string, 0, len(inferencePoolGroupToVersion)) + for group := range inferencePoolGroupToVersion { + groups = append(groups, group) + } + return groups +} + // Start begins watching InferencePool resources and managing the allowlist func (av *AllowlistValidator) Start(ctx context.Context) error { if !av.enabled { @@ -106,25 +134,20 @@ func (av *AllowlistValidator) Start(ctx context.Context) error { } av.logger = log.FromContext(ctx).WithName("allowlist-validator") - av.logger.Info("starting SSRF protection allowlist validator", "namespace", av.namespace, "poolName", av.poolName) - - gvr := schema.GroupVersionResource{ - Group: inferencePoolGroup, - Version: inferencePoolVersion, - Resource: inferencePoolResource, - } + av.logger.Info("starting SSRF protection allowlist validator", + "namespace", av.namespace, "poolName", av.poolName, "gvr", av.gvr.String()) // Create informer for the specific InferencePool resource lw := &cache.ListWatch{ - ListFunc: func(options metav1.ListOptions) (runtime.Object, error) { + ListWithContextFunc: func(ctx context.Context, options metav1.ListOptions) (runtime.Object, error) { // List with field selector to get only the specific InferencePool options.FieldSelector = "metadata.name=" + av.poolName - return av.dynamicClient.Resource(gvr).Namespace(av.namespace).List(ctx, options) + return av.dynamicClient.Resource(av.gvr).Namespace(av.namespace).List(ctx, options) }, - WatchFunc: func(options metav1.ListOptions) (watch.Interface, error) { + WatchFuncWithContext: func(ctx context.Context, options metav1.ListOptions) (watch.Interface, error) { // Watch the specific InferencePool by name using field selector options.FieldSelector = "metadata.name=" + av.poolName - return av.dynamicClient.Resource(gvr).Namespace(av.namespace).Watch(ctx, options) + return av.dynamicClient.Resource(av.gvr).Namespace(av.namespace).Watch(ctx, options) }, } @@ -142,7 +165,7 @@ func (av *AllowlistValidator) Start(ctx context.Context) error { // Wait for cache sync if !cache.WaitForCacheSync(av.stopCh, av.poolInformer.HasSynced) { - return fmt.Errorf("failed to sync InferencePool cache within timeout (check RBAC permissions for inferencepools.%s and that pool '%s' exists)", inferencePoolGroup, av.poolName) + return fmt.Errorf("failed to sync InferencePool cache within timeout (check RBAC permissions for inferencepools.%s and that pool '%s' exists)", av.gvr.String(), av.poolName) } av.logger.Info("allowlist validator started successfully") diff --git a/pkg/sidecar/proxy/allowlist_test.go b/pkg/sidecar/proxy/allowlist_test.go index 292f38ef..815032f2 100644 --- a/pkg/sidecar/proxy/allowlist_test.go +++ b/pkg/sidecar/proxy/allowlist_test.go @@ -28,7 +28,7 @@ var _ = Describe("AllowlistValidator", func() { BeforeEach(func() { var err error - validator, err = NewAllowlistValidator(false, "test-namespace", "test-pool") + validator, err = NewAllowlistValidator(false, DefaultPoolGroup, "test-namespace", "test-pool") Expect(err).ToNot(HaveOccurred()) }) diff --git a/pkg/sidecar/proxy/data_parallel_test.go b/pkg/sidecar/proxy/data_parallel_test.go index 3bd8118b..c957b726 100644 --- a/pkg/sidecar/proxy/data_parallel_test.go +++ b/pkg/sidecar/proxy/data_parallel_test.go @@ -59,7 +59,7 @@ var _ = Describe("Data Parallel support", func() { DataParallelSize: testDataParallelSize, } theProxy := NewProxy(strconv.Itoa(fakeProxyPort), decodeURL, cfg) - theProxy.allowlistValidator, err = NewAllowlistValidator(false, "", "") + theProxy.allowlistValidator, err = NewAllowlistValidator(false, DefaultPoolGroup, "", "") Expect(err).ToNot(HaveOccurred()) err = theProxy.startDataParallel(ctx, nil, grp) diff --git a/pkg/sidecar/proxy/proxy.go b/pkg/sidecar/proxy/proxy.go index a5467071..068d7eeb 100644 --- a/pkg/sidecar/proxy/proxy.go +++ b/pkg/sidecar/proxy/proxy.go @@ -60,6 +60,11 @@ const ( // ConnectorSGLang enables SGLang P/D disaggregation protocol ConnectorSGLang = "sglang" + + // DefaultPoolGroup is the default pool group name + DefaultPoolGroup = "inference.networking.k8s.io" + // LegacyPoolGroup is the legacy pool group name + LegacyPoolGroup = "inference.networking.x-k8s.io" ) // Config represents the proxy server configuration