Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cmd/pd-sidecar/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ func main() {
inferencePoolNamespace := flag.String("inference-pool-namespace", os.Getenv("INFERENCE_POOL_NAMESPACE"), "the Kubernetes namespace to watch for InferencePool resources (defaults to INFERENCE_POOL_NAMESPACE env var)")
inferencePoolName := flag.String("inference-pool-name", os.Getenv("INFERENCE_POOL_NAME"), "the specific InferencePool name to watch (defaults to INFERENCE_POOL_NAME env var)")
enablePrefillerSampling := flag.Bool("enable-prefiller-sampling", func() bool { b, _ := strconv.ParseBool(os.Getenv("ENABLE_PREFILLER_SAMPLING")); return b }(), "if true, the target prefill instance will be selected randomly from among the provided prefill host values")
poolGroup := flag.String("pool-group", proxy.DefaultPoolGroup, "group of the InferencePool this Endpoint Picker is associated with.")

opts := zap.Options{}
opts.BindFlags(flag.CommandLine) // optional to allow zap logging control via CLI
Expand Down Expand Up @@ -135,7 +136,7 @@ func main() {
}

// Create SSRF protection validator
validator, err := proxy.NewAllowlistValidator(*enableSSRFProtection, *inferencePoolNamespace, *inferencePoolName)
validator, err := proxy.NewAllowlistValidator(*enableSSRFProtection, *poolGroup, *inferencePoolNamespace, *inferencePoolName)
if err != nil {
logger.Error(err, "failed to create SSRF protection validator")
return
Expand Down
53 changes: 38 additions & 15 deletions pkg/sidecar/proxy/allowlist.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,16 @@ import (
)

const (
inferencePoolGroup = "inference.networking.x-k8s.io"
inferencePoolVersion = "v1alpha2"
inferencePoolResource = "inferencepools"
resyncPeriod = 30 * time.Second
)

// InferencePool API group to version mapping
var inferencePoolGroupToVersion = map[string]string{
DefaultPoolGroup: "v1",
LegacyPoolGroup: "v1alpha2",
}

// AllowlistValidator manages allowed prefill targets based on InferencePool resources
type AllowlistValidator struct {
logger logr.Logger
Expand All @@ -52,6 +56,8 @@ type AllowlistValidator struct {
poolName string
enabled bool

gvr schema.GroupVersionResource // detected GVR

// allowedTargets maps hostport -> bool for allowed prefill targets
allowedTargets set.Set[string]
allowedTargetsMu sync.RWMutex
Expand All @@ -65,13 +71,26 @@ type AllowlistValidator struct {
}

// NewAllowlistValidator creates a new SSRF protection validator
func NewAllowlistValidator(enabled bool, namespace string, poolName string) (*AllowlistValidator, error) {
func NewAllowlistValidator(enabled bool, poolGroup, namespace, poolName string) (*AllowlistValidator, error) {
if !enabled {
return &AllowlistValidator{
enabled: false,
}, nil
}

// Determine version based on poolGroup
version, exists := inferencePoolGroupToVersion[poolGroup]
if !exists {
return nil, fmt.Errorf("unsupported poolGroup: %s, "+
"must be one of %v", poolGroup, getSupportedPoolGroups())
}

gvr := schema.GroupVersionResource{
Group: poolGroup,
Version: version,
Resource: inferencePoolResource,
}

loadingRules := clientcmd.NewDefaultClientConfigLoadingRules()
overrides := &clientcmd.ConfigOverrides{}
config, err := clientcmd.NewNonInteractiveDeferredLoadingClientConfig(
Expand All @@ -92,39 +111,43 @@ func NewAllowlistValidator(enabled bool, namespace string, poolName string) (*Al
dynamicClient: dynamicClient,
namespace: namespace,
poolName: poolName,
gvr: gvr,
allowedTargets: set.New[string](),
podInformers: make(map[string]cache.SharedInformer),
podStopChans: make(map[string]chan struct{}),
stopCh: make(chan struct{}),
}, nil
}

func getSupportedPoolGroups() []string {
groups := make([]string, 0, len(inferencePoolGroupToVersion))
for group := range inferencePoolGroupToVersion {
groups = append(groups, group)
}
return groups
}

// Start begins watching InferencePool resources and managing the allowlist
func (av *AllowlistValidator) Start(ctx context.Context) error {
if !av.enabled {
return nil
}

av.logger = log.FromContext(ctx).WithName("allowlist-validator")
av.logger.Info("starting SSRF protection allowlist validator", "namespace", av.namespace, "poolName", av.poolName)

gvr := schema.GroupVersionResource{
Group: inferencePoolGroup,
Version: inferencePoolVersion,
Resource: inferencePoolResource,
}
av.logger.Info("starting SSRF protection allowlist validator",
"namespace", av.namespace, "poolName", av.poolName, "gvr", av.gvr.String())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

av.gvr is never set

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh... my bad. done


// Create informer for the specific InferencePool resource
lw := &cache.ListWatch{
ListFunc: func(options metav1.ListOptions) (runtime.Object, error) {
ListWithContextFunc: func(ctx context.Context, options metav1.ListOptions) (runtime.Object, error) {
// List with field selector to get only the specific InferencePool
options.FieldSelector = "metadata.name=" + av.poolName
return av.dynamicClient.Resource(gvr).Namespace(av.namespace).List(ctx, options)
return av.dynamicClient.Resource(av.gvr).Namespace(av.namespace).List(ctx, options)
},
WatchFunc: func(options metav1.ListOptions) (watch.Interface, error) {
WatchFuncWithContext: func(ctx context.Context, options metav1.ListOptions) (watch.Interface, error) {
// Watch the specific InferencePool by name using field selector
options.FieldSelector = "metadata.name=" + av.poolName
return av.dynamicClient.Resource(gvr).Namespace(av.namespace).Watch(ctx, options)
return av.dynamicClient.Resource(av.gvr).Namespace(av.namespace).Watch(ctx, options)
},
}

Expand All @@ -142,7 +165,7 @@ func (av *AllowlistValidator) Start(ctx context.Context) error {

// Wait for cache sync
if !cache.WaitForCacheSync(av.stopCh, av.poolInformer.HasSynced) {
return fmt.Errorf("failed to sync InferencePool cache within timeout (check RBAC permissions for inferencepools.%s and that pool '%s' exists)", inferencePoolGroup, av.poolName)
return fmt.Errorf("failed to sync InferencePool cache within timeout (check RBAC permissions for inferencepools.%s and that pool '%s' exists)", av.gvr.String(), av.poolName)
}

av.logger.Info("allowlist validator started successfully")
Expand Down
2 changes: 1 addition & 1 deletion pkg/sidecar/proxy/allowlist_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ var _ = Describe("AllowlistValidator", func() {

BeforeEach(func() {
var err error
validator, err = NewAllowlistValidator(false, "test-namespace", "test-pool")
validator, err = NewAllowlistValidator(false, DefaultPoolGroup, "test-namespace", "test-pool")
Expect(err).ToNot(HaveOccurred())
})

Expand Down
2 changes: 1 addition & 1 deletion pkg/sidecar/proxy/data_parallel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ var _ = Describe("Data Parallel support", func() {
DataParallelSize: testDataParallelSize,
}
theProxy := NewProxy(strconv.Itoa(fakeProxyPort), decodeURL, cfg)
theProxy.allowlistValidator, err = NewAllowlistValidator(false, "", "")
theProxy.allowlistValidator, err = NewAllowlistValidator(false, DefaultPoolGroup, "", "")
Expect(err).ToNot(HaveOccurred())

err = theProxy.startDataParallel(ctx, nil, grp)
Expand Down
5 changes: 5 additions & 0 deletions pkg/sidecar/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ const (

// ConnectorSGLang enables SGLang P/D disaggregation protocol
ConnectorSGLang = "sglang"

// DefaultPoolGroup is the default pool group name
DefaultPoolGroup = "inference.networking.k8s.io"
// LegacyPoolGroup is the legacy pool group name
LegacyPoolGroup = "inference.networking.x-k8s.io"
)

// Config represents the proxy server configuration
Expand Down