From 4b8f0846b5554a7ab823aab486c77f97a1146c4c Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Mon, 3 Nov 2025 22:39:32 +0000 Subject: [PATCH 01/13] Reorganize incoming auth factory into subfolder to prevent an import cycle Move incoming authentication factory from pkg/vmcp/auth/ to pkg/vmcp/auth/factory/ subfolder to improve code organization. This separates factory code from core authentication types and middleware. --- pkg/vmcp/auth/{incoming_factory.go => factory/incoming.go} | 2 +- .../auth/{incoming_factory_test.go => factory/incoming_test.go} | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) rename pkg/vmcp/auth/{incoming_factory.go => factory/incoming.go} (99%) rename pkg/vmcp/auth/{incoming_factory_test.go => factory/incoming_test.go} (99%) diff --git a/pkg/vmcp/auth/incoming_factory.go b/pkg/vmcp/auth/factory/incoming.go similarity index 99% rename from pkg/vmcp/auth/incoming_factory.go rename to pkg/vmcp/auth/factory/incoming.go index 479876d2d..edb09a6cd 100644 --- a/pkg/vmcp/auth/incoming_factory.go +++ b/pkg/vmcp/auth/factory/incoming.go @@ -1,4 +1,4 @@ -package auth +package factory import ( "context" diff --git a/pkg/vmcp/auth/incoming_factory_test.go b/pkg/vmcp/auth/factory/incoming_test.go similarity index 99% rename from pkg/vmcp/auth/incoming_factory_test.go rename to pkg/vmcp/auth/factory/incoming_test.go index e3f7a22cb..10bc65344 100644 --- a/pkg/vmcp/auth/incoming_factory_test.go +++ b/pkg/vmcp/auth/factory/incoming_test.go @@ -1,4 +1,4 @@ -package auth +package factory import ( "context" From 311ead5637e49b2aef6249bee8a9719bd6958ef9 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Tue, 4 Nov 2025 13:31:15 +0000 Subject: [PATCH 02/13] Refactor outgoing auth to separate registry from strategy Rename OutgoingAuthenticator to OutgoingAuthRegistry to better reflect its responsibility as a strategy registry rather than an authenticator. The interface now focuses solely on strategy management (registration and retrieval), while authentication is performed directly by Strategy implementations. This separation improves performance by eliminating indirection in the hot path (per-request authentication) and clarifies the single responsibility of each component: the registry manages strategies, strategies perform authentication. --- pkg/vmcp/auth/auth.go | 45 +- pkg/vmcp/auth/outgoing_authenticator.go | 130 ------ pkg/vmcp/auth/outgoing_authenticator_test.go | 455 ------------------- pkg/vmcp/auth/outgoing_registry.go | 103 +++++ pkg/vmcp/auth/outgoing_registry_test.go | 263 +++++++++++ pkg/vmcp/doc.go | 14 +- pkg/vmcp/types.go | 2 +- 7 files changed, 404 insertions(+), 608 deletions(-) delete mode 100644 pkg/vmcp/auth/outgoing_authenticator.go delete mode 100644 pkg/vmcp/auth/outgoing_authenticator_test.go create mode 100644 pkg/vmcp/auth/outgoing_registry.go create mode 100644 pkg/vmcp/auth/outgoing_registry_test.go diff --git a/pkg/vmcp/auth/auth.go b/pkg/vmcp/auth/auth.go index 76f9626eb..455b6e71c 100644 --- a/pkg/vmcp/auth/auth.go +++ b/pkg/vmcp/auth/auth.go @@ -1,7 +1,7 @@ // Package auth provides authentication for Virtual MCP Server. // // This package defines: -// - OutgoingAuthenticator: Authenticates vMCP to backend servers +// - OutgoingAuthRegistry: Registry for managing backend authentication strategies // - Strategy: Pluggable authentication strategies for backends // // Incoming authentication uses pkg/auth middleware (OIDC, local, anonymous) @@ -17,24 +17,39 @@ import ( "github.com/stacklok/toolhive/pkg/auth" ) -// OutgoingAuthenticator handles authentication to backend MCP servers. -// This is responsible for obtaining and injecting appropriate credentials -// for each backend based on its authentication strategy. +// OutgoingAuthRegistry manages authentication strategies for outgoing requests to backend MCP servers. +// This is a registry that stores and retrieves Strategy implementations. // -// The specific authentication strategies and their behavior will be defined -// during implementation based on the design decisions documented in the -// Virtual MCP Server proposal. -type OutgoingAuthenticator interface { - // AuthenticateRequest adds authentication to an outgoing backend request. - // The strategy and metadata are provided in the BackendTarget. - AuthenticateRequest(ctx context.Context, req *http.Request, strategy string, metadata map[string]any) error - - // GetStrategy returns the authentication strategy handler for a given strategy name. - // This enables extensibility - new strategies can be registered. +// The registry supports dynamic strategy registration, allowing custom authentication +// strategies to be added at runtime. Once registered, strategies can be retrieved +// by name and used to authenticate requests to backends. +// +// Responsibilities: +// - Maintain registry of available strategies +// - Retrieve strategies by name +// - Register new strategies dynamically +// +// This registry does NOT perform authentication itself. Authentication is performed +// by Strategy implementations retrieved from this registry. +// +// Usage Pattern: +// 1. Register strategies during application initialization +// 2. Resolve strategy once at client creation time (cold path) +// 3. Call strategy.Authenticate() directly per-request (hot path) +// +// Thread-safety: Implementations must be safe for concurrent access. +type OutgoingAuthRegistry interface { + // GetStrategy retrieves an authentication strategy by name. + // Returns an error if the strategy is not found. GetStrategy(name string) (Strategy, error) // RegisterStrategy registers a new authentication strategy. - // This allows custom auth strategies to be added at runtime. + // The strategy name must match the name returned by strategy.Name(). + // Returns an error if: + // - name is empty + // - strategy is nil + // - a strategy with the same name is already registered + // - strategy.Name() does not match the registration name RegisterStrategy(name string, strategy Strategy) error } diff --git a/pkg/vmcp/auth/outgoing_authenticator.go b/pkg/vmcp/auth/outgoing_authenticator.go deleted file mode 100644 index 6498f68dd..000000000 --- a/pkg/vmcp/auth/outgoing_authenticator.go +++ /dev/null @@ -1,130 +0,0 @@ -package auth - -import ( - "context" - "errors" - "fmt" - "net/http" - "sync" -) - -// DefaultOutgoingAuthenticator is a thread-safe implementation of OutgoingAuthenticator -// that maintains a registry of authentication strategies. -// -// Thread-safety: Safe for concurrent calls to RegisterStrategy and AuthenticateRequest. -// Strategy implementations must be thread-safe as they are called concurrently. -// It uses sync.RWMutex for thread-safety as HTTP servers are inherently concurrent. -// -// This authenticator supports dynamic registration of strategies and dispatches -// authentication requests to the appropriate strategy based on the strategy name. -// -// Example usage: -// -// auth := NewDefaultOutgoingAuthenticator() -// auth.RegisterStrategy("bearer", NewBearerStrategy()) -// err := auth.AuthenticateRequest(ctx, req, "bearer", metadata) -type DefaultOutgoingAuthenticator struct { - strategies map[string]Strategy - mu sync.RWMutex -} - -// NewDefaultOutgoingAuthenticator creates a new DefaultOutgoingAuthenticator -// with an empty strategy registry. -// -// Strategies must be registered using RegisterStrategy before they can be used -// for authentication. -func NewDefaultOutgoingAuthenticator() *DefaultOutgoingAuthenticator { - return &DefaultOutgoingAuthenticator{ - strategies: make(map[string]Strategy), - } -} - -// RegisterStrategy registers a new authentication strategy. -// -// This method is thread-safe and validates that: -// - name is not empty -// - strategy is not nil -// - no strategy is already registered with the same name -// -// Parameters: -// - name: The unique identifier for this strategy -// - strategy: The Strategy implementation to register -// -// Returns an error if validation fails or a strategy with the same name -// already exists. -func (a *DefaultOutgoingAuthenticator) RegisterStrategy(name string, strategy Strategy) error { - if name == "" { - return errors.New("strategy name cannot be empty") - } - if strategy == nil { - return errors.New("strategy cannot be nil") - } - - a.mu.Lock() - defer a.mu.Unlock() - - if _, exists := a.strategies[name]; exists { - return fmt.Errorf("strategy %q is already registered", name) - } - - a.strategies[name] = strategy - return nil -} - -// GetStrategy retrieves an authentication strategy by name. -// -// This method is thread-safe for concurrent reads. It returns the strategy -// if found, or an error if no strategy is registered with the given name. -// -// Parameters: -// - name: The identifier of the strategy to retrieve -// -// Returns: -// - Strategy: The registered strategy -// - error: An error if the strategy is not found -func (a *DefaultOutgoingAuthenticator) GetStrategy(name string) (Strategy, error) { - a.mu.RLock() - defer a.mu.RUnlock() - - strategy, exists := a.strategies[name] - if !exists { - return nil, fmt.Errorf("strategy %q not found", name) - } - - return strategy, nil -} - -// AuthenticateRequest adds authentication to an outgoing backend request. -// -// This method retrieves the specified strategy and delegates authentication -// to it. The strategy modifies the request by adding appropriate headers, -// tokens, or other authentication artifacts. -// -// Parameters: -// - ctx: Request context (may contain identity for pass-through auth) -// - req: The HTTP request to authenticate -// - strategyName: The name of the strategy to use -// - metadata: Strategy-specific configuration -// -// Returns an error if: -// - The strategy is not found -// - The metadata validation fails -// - The strategy's Authenticate method fails -func (a *DefaultOutgoingAuthenticator) AuthenticateRequest( - ctx context.Context, - req *http.Request, - strategyName string, - metadata map[string]any, -) error { - strategy, err := a.GetStrategy(strategyName) - if err != nil { - return err - } - - // Validate metadata before using it - if err := strategy.Validate(metadata); err != nil { - return fmt.Errorf("invalid metadata for strategy %q: %w", strategyName, err) - } - - return strategy.Authenticate(ctx, req, metadata) -} diff --git a/pkg/vmcp/auth/outgoing_authenticator_test.go b/pkg/vmcp/auth/outgoing_authenticator_test.go deleted file mode 100644 index 43073bc7d..000000000 --- a/pkg/vmcp/auth/outgoing_authenticator_test.go +++ /dev/null @@ -1,455 +0,0 @@ -package auth - -import ( - "context" - "errors" - "net/http" - "net/http/httptest" - "sync" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.uber.org/mock/gomock" - - "github.com/stacklok/toolhive/pkg/vmcp/auth/mocks" -) - -type testContextKey struct{} - -var testKey = testContextKey{} - -func TestDefaultOutgoingAuthenticator_RegisterStrategy(t *testing.T) { - t.Parallel() - t.Run("register valid strategy succeeds", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - auth := NewDefaultOutgoingAuthenticator() - strategy := mocks.NewMockStrategy(ctrl) - strategy.EXPECT().Name().Return("bearer").AnyTimes() - - err := auth.RegisterStrategy("bearer", strategy) - - require.NoError(t, err) - // Verify strategy was registered - retrieved, err := auth.GetStrategy("bearer") - require.NoError(t, err) - assert.Equal(t, strategy, retrieved) - }) - - t.Run("register empty name fails", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - auth := NewDefaultOutgoingAuthenticator() - strategy := mocks.NewMockStrategy(ctrl) - - err := auth.RegisterStrategy("", strategy) - - assert.Error(t, err) - assert.Contains(t, err.Error(), "strategy name cannot be empty") - }) - - t.Run("register nil strategy fails", func(t *testing.T) { - t.Parallel() - auth := NewDefaultOutgoingAuthenticator() - - err := auth.RegisterStrategy("bearer", nil) - - assert.Error(t, err) - assert.Contains(t, err.Error(), "strategy cannot be nil") - }) - - t.Run("register duplicate name fails", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - auth := NewDefaultOutgoingAuthenticator() - strategy1 := mocks.NewMockStrategy(ctrl) - strategy1.EXPECT().Name().Return("bearer").AnyTimes() - strategy2 := mocks.NewMockStrategy(ctrl) - - err := auth.RegisterStrategy("bearer", strategy1) - require.NoError(t, err) - - err = auth.RegisterStrategy("bearer", strategy2) - assert.Error(t, err) - assert.Contains(t, err.Error(), "already registered") - assert.Contains(t, err.Error(), "bearer") - }) - - t.Run("register multiple different strategies succeeds", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - auth := NewDefaultOutgoingAuthenticator() - bearer := mocks.NewMockStrategy(ctrl) - bearer.EXPECT().Name().Return("bearer").AnyTimes() - basic := mocks.NewMockStrategy(ctrl) - basic.EXPECT().Name().Return("basic").AnyTimes() - apiKey := mocks.NewMockStrategy(ctrl) - apiKey.EXPECT().Name().Return("api-key").AnyTimes() - - require.NoError(t, auth.RegisterStrategy("bearer", bearer)) - require.NoError(t, auth.RegisterStrategy("basic", basic)) - require.NoError(t, auth.RegisterStrategy("api-key", apiKey)) - - // Verify all strategies are registered - s1, err := auth.GetStrategy("bearer") - require.NoError(t, err) - assert.Equal(t, bearer, s1) - - s2, err := auth.GetStrategy("basic") - require.NoError(t, err) - assert.Equal(t, basic, s2) - - s3, err := auth.GetStrategy("api-key") - require.NoError(t, err) - assert.Equal(t, apiKey, s3) - }) -} - -func TestDefaultOutgoingAuthenticator_GetStrategy(t *testing.T) { - t.Parallel() - t.Run("get existing strategy succeeds", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - auth := NewDefaultOutgoingAuthenticator() - strategy := mocks.NewMockStrategy(ctrl) - strategy.EXPECT().Name().Return("bearer").AnyTimes() - require.NoError(t, auth.RegisterStrategy("bearer", strategy)) - - retrieved, err := auth.GetStrategy("bearer") - - require.NoError(t, err) - assert.Equal(t, strategy, retrieved) - }) - - t.Run("get non-existent strategy fails", func(t *testing.T) { - t.Parallel() - auth := NewDefaultOutgoingAuthenticator() - - retrieved, err := auth.GetStrategy("non-existent") - - assert.Error(t, err) - assert.Nil(t, retrieved) - assert.Contains(t, err.Error(), "not found") - assert.Contains(t, err.Error(), "non-existent") - }) - - t.Run("get from empty registry fails", func(t *testing.T) { - t.Parallel() - auth := NewDefaultOutgoingAuthenticator() - - retrieved, err := auth.GetStrategy("bearer") - - assert.Error(t, err) - assert.Nil(t, retrieved) - }) -} - -func TestDefaultOutgoingAuthenticator_AuthenticateRequest(t *testing.T) { - t.Parallel() - t.Run("authenticates with valid strategy", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - auth := NewDefaultOutgoingAuthenticator() - strategy := mocks.NewMockStrategy(ctrl) - strategy.EXPECT().Name().Return("bearer").AnyTimes() - strategy.EXPECT().Validate(gomock.Any()).Return(nil) - strategy.EXPECT().Authenticate(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, req *http.Request, _ map[string]any) error { - // Add a header to verify the request was modified - req.Header.Set("Authorization", "Bearer token123") - return nil - }, - ) - require.NoError(t, auth.RegisterStrategy("bearer", strategy)) - - req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) - metadata := map[string]any{"token": "token123"} - err := auth.AuthenticateRequest(context.Background(), req, "bearer", metadata) - - require.NoError(t, err) - assert.Equal(t, "Bearer token123", req.Header.Get("Authorization")) - }) - - t.Run("fails with non-existent strategy", func(t *testing.T) { - t.Parallel() - auth := NewDefaultOutgoingAuthenticator() - req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) - - err := auth.AuthenticateRequest(context.Background(), req, "non-existent", nil) - - assert.Error(t, err) - assert.Contains(t, err.Error(), "not found") - }) - - t.Run("returns error from strategy", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - auth := NewDefaultOutgoingAuthenticator() - strategyErr := errors.New("authentication failed") - strategy := mocks.NewMockStrategy(ctrl) - strategy.EXPECT().Name().Return("bearer").AnyTimes() - strategy.EXPECT().Validate(gomock.Any()).Return(nil) - strategy.EXPECT().Authenticate(gomock.Any(), gomock.Any(), gomock.Any()).Return(strategyErr) - require.NoError(t, auth.RegisterStrategy("bearer", strategy)) - - req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) - err := auth.AuthenticateRequest(context.Background(), req, "bearer", nil) - - assert.Error(t, err) - assert.Equal(t, strategyErr, err) - }) - - t.Run("passes context and metadata to strategy", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - auth := NewDefaultOutgoingAuthenticator() - - var receivedCtx context.Context - var receivedMetadata map[string]any - - strategy := mocks.NewMockStrategy(ctrl) - strategy.EXPECT().Name().Return("bearer").AnyTimes() - strategy.EXPECT().Validate(gomock.Any()).Return(nil) - strategy.EXPECT().Authenticate(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(ctx context.Context, _ *http.Request, metadata map[string]any) error { - receivedCtx = ctx - receivedMetadata = metadata - return nil - }, - ) - require.NoError(t, auth.RegisterStrategy("bearer", strategy)) - - ctx := context.WithValue(context.Background(), testKey, "test-value") - req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) - metadata := map[string]any{ - "token": "abc123", - "scopes": []string{"read", "write"}, - } - - err := auth.AuthenticateRequest(ctx, req, "bearer", metadata) - - require.NoError(t, err) - assert.NotNil(t, receivedCtx) - assert.Equal(t, "test-value", receivedCtx.Value(testKey)) - assert.Equal(t, metadata, receivedMetadata) - }) - - t.Run("validates metadata before authentication", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - auth := NewDefaultOutgoingAuthenticator() - strategy := mocks.NewMockStrategy(ctrl) - strategy.EXPECT().Name().Return("test-strategy").AnyTimes() - - require.NoError(t, auth.RegisterStrategy("test-strategy", strategy)) - - req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) - metadata := map[string]any{"invalid": "data"} - - // Expect Validate to be called and return error - strategy.EXPECT(). - Validate(metadata). - Return(errors.New("invalid metadata")) - - // Authenticate should NOT be called if validation fails - // (no EXPECT for Authenticate) - - err := auth.AuthenticateRequest(context.Background(), req, "test-strategy", metadata) - - require.Error(t, err) - assert.Contains(t, err.Error(), "invalid metadata for strategy") - assert.Contains(t, err.Error(), "test-strategy") - }) -} - -func TestDefaultOutgoingAuthenticator_ConcurrentAccess(t *testing.T) { - t.Parallel() - t.Run("concurrent GetStrategy calls are thread-safe", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - auth := NewDefaultOutgoingAuthenticator() - - // Register multiple strategies - strategies := []string{"bearer", "basic", "api-key", "oauth2", "jwt"} - for _, name := range strategies { - strategy := mocks.NewMockStrategy(ctrl) - strategy.EXPECT().Name().Return(name).AnyTimes() - require.NoError(t, auth.RegisterStrategy(name, strategy)) - } - - // Test concurrent reads with -race detector - const numGoroutines = 100 - const numOperations = 1000 - - var wg sync.WaitGroup - wg.Add(numGoroutines) - - errs := make(chan error, numGoroutines*numOperations) - - for i := 0; i < numGoroutines; i++ { - go func(_ int) { - defer wg.Done() - for j := 0; j < numOperations; j++ { - // Rotate through strategies - strategyName := strategies[j%len(strategies)] - strategy, err := auth.GetStrategy(strategyName) - if err != nil { - errs <- err - return - } - if strategy.Name() != strategyName { - errs <- errors.New("strategy name mismatch") - return - } - } - }(i) - } - - wg.Wait() - close(errs) - - // Check for errors - var collectedErrors []error - for err := range errs { - collectedErrors = append(collectedErrors, err) - } - - if len(collectedErrors) > 0 { - t.Fatalf("concurrent access produced errors: %v", collectedErrors) - } - }) - - t.Run("concurrent AuthenticateRequest calls are thread-safe", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - auth := NewDefaultOutgoingAuthenticator() - - // Counter to verify all authentications happen - var authCount int64 - var authMu sync.Mutex - - strategy := mocks.NewMockStrategy(ctrl) - strategy.EXPECT().Name().Return("bearer").AnyTimes() - strategy.EXPECT().Validate(gomock.Any()).Return(nil).AnyTimes() - strategy.EXPECT().Authenticate(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, req *http.Request, _ map[string]any) error { - authMu.Lock() - authCount++ - authMu.Unlock() - req.Header.Set("Authorization", "Bearer test") - return nil - }, - ).AnyTimes() - require.NoError(t, auth.RegisterStrategy("bearer", strategy)) - - const numGoroutines = 100 - - var wg sync.WaitGroup - wg.Add(numGoroutines) - - errs := make(chan error, numGoroutines) - - for i := 0; i < numGoroutines; i++ { - go func() { - defer wg.Done() - req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) - err := auth.AuthenticateRequest(context.Background(), req, "bearer", nil) - if err != nil { - errs <- err - } - }() - } - - wg.Wait() - close(errs) - - // Check for errors - var collectedErrors []error - for err := range errs { - collectedErrors = append(collectedErrors, err) - } - - if len(collectedErrors) > 0 { - t.Fatalf("concurrent AuthenticateRequest produced errors: %v", collectedErrors) - } - - // Verify all authentications completed - assert.Equal(t, int64(numGoroutines), authCount) - }) - - t.Run("concurrent RegisterStrategy and GetStrategy are thread-safe", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - auth := NewDefaultOutgoingAuthenticator() - - const numRegister = 50 - const numGet = 50 - - var wg sync.WaitGroup - wg.Add(numRegister + numGet) - - errs := make(chan error, numRegister+numGet) - - // Goroutines registering strategies - for i := 0; i < numRegister; i++ { - go func(id int) { - defer wg.Done() - strategy := mocks.NewMockStrategy(ctrl) - strategy.EXPECT().Name().Return("strategy").AnyTimes() - strategyName := "strategy-" + string(rune('A'+id%26)) + string(rune('0'+id/26)) - err := auth.RegisterStrategy(strategyName, strategy) - if err != nil { - errs <- err - } - }(i) - } - - // Goroutines reading strategies (will mostly fail, but shouldn't race) - for i := 0; i < numGet; i++ { - go func(id int) { - defer wg.Done() - strategyName := "strategy-" + string(rune('A'+id%26)) + string(rune('0'+id/26)) - // GetStrategy may return error if not registered yet, that's OK - _, _ = auth.GetStrategy(strategyName) - }(i) - } - - wg.Wait() - close(errs) - - // Check for unexpected errors (registration errors are not expected) - var collectedErrors []error - for err := range errs { - collectedErrors = append(collectedErrors, err) - } - - if len(collectedErrors) > 0 { - t.Fatalf("concurrent RegisterStrategy/GetStrategy produced errors: %v", collectedErrors) - } - }) -} diff --git a/pkg/vmcp/auth/outgoing_registry.go b/pkg/vmcp/auth/outgoing_registry.go new file mode 100644 index 000000000..04f2513a3 --- /dev/null +++ b/pkg/vmcp/auth/outgoing_registry.go @@ -0,0 +1,103 @@ +package auth + +import ( + "errors" + "fmt" + "sync" +) + +// DefaultOutgoingAuthRegistry is a thread-safe implementation of OutgoingAuthRegistry +// that maintains a registry of authentication strategies. +// +// Thread-safety: Safe for concurrent calls to RegisterStrategy and GetStrategy. +// Strategy implementations must be thread-safe as they are called concurrently. +// It uses sync.RWMutex for thread-safety as HTTP servers are inherently concurrent. +// +// This registry supports dynamic registration of strategies and retrieval by name. +// It does not perform authentication itself - that is done by the Strategy implementations. +// +// Example usage: +// +// registry := NewDefaultOutgoingAuthRegistry() +// registry.RegisterStrategy("header_injection", NewHeaderInjectionStrategy()) +// strategy, err := registry.GetStrategy("header_injection") +// if err == nil { +// err = strategy.Authenticate(ctx, req, metadata) +// } +type DefaultOutgoingAuthRegistry struct { + strategies map[string]Strategy + mu sync.RWMutex +} + +// NewDefaultOutgoingAuthRegistry creates a new DefaultOutgoingAuthRegistry +// with an empty strategy registry. +// +// Strategies must be registered using RegisterStrategy before they can be used +// for authentication. +func NewDefaultOutgoingAuthRegistry() *DefaultOutgoingAuthRegistry { + return &DefaultOutgoingAuthRegistry{ + strategies: make(map[string]Strategy), + } +} + +// RegisterStrategy registers a new authentication strategy. +// +// This method is thread-safe and validates that: +// - name is not empty +// - strategy is not nil +// - strategy.Name() matches the registration name +// - no strategy is already registered with the same name +// +// Parameters: +// - name: The unique identifier for this strategy +// - strategy: The Strategy implementation to register +// +// Returns an error if validation fails or a strategy with the same name +// already exists. +func (r *DefaultOutgoingAuthRegistry) RegisterStrategy(name string, strategy Strategy) error { + if name == "" { + return errors.New("strategy name cannot be empty") + } + if strategy == nil { + return errors.New("strategy cannot be nil") + } + + // Validate that strategy name matches registration name + if name != strategy.Name() { + return fmt.Errorf("strategy name mismatch: registered as %q but strategy.Name() returns %q", + name, strategy.Name()) + } + + r.mu.Lock() + defer r.mu.Unlock() + + if _, exists := r.strategies[name]; exists { + return fmt.Errorf("strategy %q is already registered", name) + } + + r.strategies[name] = strategy + return nil +} + +// GetStrategy retrieves an authentication strategy by name. +// +// This method is thread-safe for concurrent reads. It returns the strategy +// if found, or an error if no strategy is registered with the given name. +// +// Parameters: +// - name: The identifier of the strategy to retrieve +// +// Returns: +// - Strategy: The registered strategy +// - error: An error if the strategy is not found +func (r *DefaultOutgoingAuthRegistry) GetStrategy(name string) (Strategy, error) { + r.mu.RLock() + defer r.mu.RUnlock() + + strategy, exists := r.strategies[name] + if !exists { + return nil, fmt.Errorf("strategy %q not found", name) + } + + return strategy, nil +} diff --git a/pkg/vmcp/auth/outgoing_registry_test.go b/pkg/vmcp/auth/outgoing_registry_test.go new file mode 100644 index 000000000..3d2e8a495 --- /dev/null +++ b/pkg/vmcp/auth/outgoing_registry_test.go @@ -0,0 +1,263 @@ +package auth + +import ( + "errors" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/stacklok/toolhive/pkg/vmcp/auth/mocks" +) + +func TestDefaultOutgoingAuthRegistry_RegisterStrategy(t *testing.T) { + t.Parallel() + t.Run("register valid strategy succeeds", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + registry := NewDefaultOutgoingAuthRegistry() + strategy := mocks.NewMockStrategy(ctrl) + strategy.EXPECT().Name().Return("bearer").AnyTimes() + + err := registry.RegisterStrategy("bearer", strategy) + + require.NoError(t, err) + // Verify strategy was registered + retrieved, err := registry.GetStrategy("bearer") + require.NoError(t, err) + assert.Equal(t, strategy, retrieved) + }) + + t.Run("register empty name fails", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + registry := NewDefaultOutgoingAuthRegistry() + strategy := mocks.NewMockStrategy(ctrl) + + err := registry.RegisterStrategy("", strategy) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "strategy name cannot be empty") + }) + + t.Run("register nil strategy fails", func(t *testing.T) { + t.Parallel() + registry := NewDefaultOutgoingAuthRegistry() + + err := registry.RegisterStrategy("bearer", nil) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "strategy cannot be nil") + }) + + t.Run("register duplicate name fails", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + registry := NewDefaultOutgoingAuthRegistry() + strategy1 := mocks.NewMockStrategy(ctrl) + strategy1.EXPECT().Name().Return("bearer").AnyTimes() + strategy2 := mocks.NewMockStrategy(ctrl) + strategy2.EXPECT().Name().Return("bearer").AnyTimes() + + err := registry.RegisterStrategy("bearer", strategy1) + require.NoError(t, err) + + err = registry.RegisterStrategy("bearer", strategy2) + assert.Error(t, err) + assert.Contains(t, err.Error(), "already registered") + assert.Contains(t, err.Error(), "bearer") + }) + + t.Run("register multiple different strategies succeeds", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + registry := NewDefaultOutgoingAuthRegistry() + bearer := mocks.NewMockStrategy(ctrl) + bearer.EXPECT().Name().Return("bearer").AnyTimes() + basic := mocks.NewMockStrategy(ctrl) + basic.EXPECT().Name().Return("basic").AnyTimes() + apiKey := mocks.NewMockStrategy(ctrl) + apiKey.EXPECT().Name().Return("api-key").AnyTimes() + + require.NoError(t, registry.RegisterStrategy("bearer", bearer)) + require.NoError(t, registry.RegisterStrategy("basic", basic)) + require.NoError(t, registry.RegisterStrategy("api-key", apiKey)) + + // Verify all strategies are registered + s1, err := registry.GetStrategy("bearer") + require.NoError(t, err) + assert.Equal(t, bearer, s1) + + s2, err := registry.GetStrategy("basic") + require.NoError(t, err) + assert.Equal(t, basic, s2) + + s3, err := registry.GetStrategy("api-key") + require.NoError(t, err) + assert.Equal(t, apiKey, s3) + }) +} + +func TestDefaultOutgoingAuthRegistry_GetStrategy(t *testing.T) { + t.Parallel() + t.Run("get existing strategy succeeds", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + registry := NewDefaultOutgoingAuthRegistry() + strategy := mocks.NewMockStrategy(ctrl) + strategy.EXPECT().Name().Return("bearer").AnyTimes() + require.NoError(t, registry.RegisterStrategy("bearer", strategy)) + + retrieved, err := registry.GetStrategy("bearer") + + require.NoError(t, err) + assert.Equal(t, strategy, retrieved) + }) + + t.Run("get non-existent strategy fails", func(t *testing.T) { + t.Parallel() + registry := NewDefaultOutgoingAuthRegistry() + + retrieved, err := registry.GetStrategy("non-existent") + + assert.Error(t, err) + assert.Nil(t, retrieved) + assert.Contains(t, err.Error(), "not found") + assert.Contains(t, err.Error(), "non-existent") + }) + + t.Run("get from empty registry fails", func(t *testing.T) { + t.Parallel() + registry := NewDefaultOutgoingAuthRegistry() + + retrieved, err := registry.GetStrategy("bearer") + + assert.Error(t, err) + assert.Nil(t, retrieved) + }) +} + +func TestDefaultOutgoingAuthRegistry_ConcurrentAccess(t *testing.T) { + t.Parallel() + t.Run("concurrent GetStrategy calls are thread-safe", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + registry := NewDefaultOutgoingAuthRegistry() + + // Register multiple strategies + strategies := []string{"bearer", "basic", "api-key", "oauth2", "jwt"} + for _, name := range strategies { + strategy := mocks.NewMockStrategy(ctrl) + strategy.EXPECT().Name().Return(name).AnyTimes() + require.NoError(t, registry.RegisterStrategy(name, strategy)) + } + + // Test concurrent reads with -race detector + const numGoroutines = 100 + const numOperations = 1000 + + var wg sync.WaitGroup + wg.Add(numGoroutines) + + errs := make(chan error, numGoroutines*numOperations) + + for i := 0; i < numGoroutines; i++ { + go func(_ int) { + defer wg.Done() + for j := 0; j < numOperations; j++ { + // Rotate through strategies + strategyName := strategies[j%len(strategies)] + strategy, err := registry.GetStrategy(strategyName) + if err != nil { + errs <- err + return + } + if strategy.Name() != strategyName { + errs <- errors.New("strategy name mismatch") + return + } + } + }(i) + } + + wg.Wait() + close(errs) + + // Check for errors + var collectedErrors []error + for err := range errs { + collectedErrors = append(collectedErrors, err) + } + + if len(collectedErrors) > 0 { + t.Fatalf("concurrent access produced errors: %v", collectedErrors) + } + }) + + t.Run("concurrent RegisterStrategy and GetStrategy are thread-safe", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + registry := NewDefaultOutgoingAuthRegistry() + + const numRegister = 50 + const numGet = 50 + + var wg sync.WaitGroup + wg.Add(numRegister + numGet) + + errs := make(chan error, numRegister+numGet) + + // Goroutines registering strategies + for i := 0; i < numRegister; i++ { + go func(id int) { + defer wg.Done() + strategyName := "strategy-" + string(rune('A'+id%26)) + string(rune('0'+id/26)) + strategy := mocks.NewMockStrategy(ctrl) + strategy.EXPECT().Name().Return(strategyName).AnyTimes() + err := registry.RegisterStrategy(strategyName, strategy) + if err != nil { + errs <- err + } + }(i) + } + + // Goroutines reading strategies (will mostly fail, but shouldn't race) + for i := 0; i < numGet; i++ { + go func(id int) { + defer wg.Done() + strategyName := "strategy-" + string(rune('A'+id%26)) + string(rune('0'+id/26)) + // GetStrategy may return error if not registered yet, that's OK + _, _ = registry.GetStrategy(strategyName) + }(i) + } + + wg.Wait() + close(errs) + + // Check for unexpected errors (registration errors are not expected) + var collectedErrors []error + for err := range errs { + collectedErrors = append(collectedErrors, err) + } + + if len(collectedErrors) > 0 { + t.Fatalf("concurrent RegisterStrategy/GetStrategy produced errors: %v", collectedErrors) + } + }) +} diff --git a/pkg/vmcp/doc.go b/pkg/vmcp/doc.go index 246b03d2c..f81f8561b 100644 --- a/pkg/vmcp/doc.go +++ b/pkg/vmcp/doc.go @@ -83,12 +83,11 @@ // Middleware() func(http.Handler) http.Handler // } // -// OutgoingAuthenticator (pkg/vmcp/auth): +// OutgoingAuthRegistry (pkg/vmcp/auth): // -// type OutgoingAuthenticator interface { -// AuthenticateRequest(ctx context.Context, req *http.Request, strategy string, metadata map[string]any) error -// GetStrategy(name string) (AuthStrategy, error) -// RegisterStrategy(name string, strategy AuthStrategy) error +// type OutgoingAuthRegistry interface { +// GetStrategy(name string) (Strategy, error) +// RegisterStrategy(name string, strategy Strategy) error // } // // # Design Principles @@ -137,9 +136,10 @@ // // Route to backend // target, err := rtr.RouteTool(ctx, toolName) // -// // Authenticate to backend +// // Authenticate to backend (resolve strategy and call it) // backendReq := createBackendRequest(...) -// err = outAuth.AuthenticateRequest(ctx, backendReq, target.AuthStrategy, target.AuthMetadata) +// strategy, err := outAuth.GetStrategy(target.AuthStrategy) +// err = strategy.Authenticate(ctx, backendReq, target.AuthMetadata) // // // Forward request and return response // // ... diff --git a/pkg/vmcp/types.go b/pkg/vmcp/types.go index fcb7ea2c8..70d7f3b7b 100644 --- a/pkg/vmcp/types.go +++ b/pkg/vmcp/types.go @@ -35,7 +35,7 @@ type BackendTarget struct { OriginalCapabilityName string // AuthStrategy identifies the authentication strategy for this backend. - // The actual authentication is handled by OutgoingAuthenticator interface. + // The actual authentication is handled by OutgoingAuthRegistry interface. // Examples: "pass_through", "token_exchange", "client_credentials", "oauth_proxy" AuthStrategy string From 84f3bf362e6e5adac363fbe8f6255c4681b06f12 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Mon, 3 Nov 2025 21:15:16 +0000 Subject: [PATCH 03/13] Add factory package to resolve auth import cycle Introduces pkg/vmcp/auth/factory to break the circular dependency between pkg/vmcp/auth and pkg/vmcp/auth/strategies. The import cycle occurred because: - auth package needed to import strategies to instantiate them - strategies package imported auth for Identity and context helpers The factory package sits at the composition layer and can import both auth (for interfaces) and strategies (for implementations) without creating cycles. --- pkg/vmcp/auth/factory/outgoing.go | 166 ++++++ pkg/vmcp/auth/factory/outgoing_test.go | 743 +++++++++++++++++++++++++ 2 files changed, 909 insertions(+) create mode 100644 pkg/vmcp/auth/factory/outgoing.go create mode 100644 pkg/vmcp/auth/factory/outgoing_test.go diff --git a/pkg/vmcp/auth/factory/outgoing.go b/pkg/vmcp/auth/factory/outgoing.go new file mode 100644 index 000000000..1c7cf7254 --- /dev/null +++ b/pkg/vmcp/auth/factory/outgoing.go @@ -0,0 +1,166 @@ +// Copyright 2025 Stacklok, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package factory provides factory functions for creating vMCP authentication components. +package factory + +import ( + "context" + "fmt" + "strings" + + "github.com/stacklok/toolhive/pkg/vmcp/auth" + "github.com/stacklok/toolhive/pkg/vmcp/auth/strategies" + "github.com/stacklok/toolhive/pkg/vmcp/config" +) + +// NewOutgoingAuthRegistry creates an OutgoingAuthRegistry from configuration. +// It registers all strategies found in the configuration (both default and backend-specific). +// +// The factory ALWAYS registers the "unauthenticated" strategy as a default fallback, +// ensuring that backends without explicit authentication configuration can function. +// This makes empty/nil configuration safe: the registry will have at least one +// usable strategy. +// +// Strategy Registration: +// - "unauthenticated" is always registered (default fallback) +// - Additional strategies are registered based on configuration +// - Each strategy is instantiated once and shared across backends +// - Strategies are stateless (except token_exchange which has internal caching) +// +// Parameters: +// - ctx: Context for any initialization that requires it +// - cfg: The outgoing authentication configuration (may be nil) +// +// Returns: +// - auth.OutgoingAuthRegistry: Configured registry with registered strategies +// - error: Any error during strategy initialization or registration +func NewOutgoingAuthRegistry(_ context.Context, cfg *config.OutgoingAuthConfig) (auth.OutgoingAuthRegistry, error) { + registry := auth.NewDefaultOutgoingAuthRegistry() + + // ALWAYS register the unauthenticated strategy as the default fallback. + if err := registerUnauthenticatedStrategy(registry); err != nil { + return nil, err + } + + // Handle nil config gracefully - return registry with unauthenticated strategy + if cfg == nil { + return registry, nil + } + + // Validate configuration structure + if err := validateConfig(cfg); err != nil { + return nil, err + } + + // Collect and register all unique strategy types from configuration + strategyTypes := collectStrategyTypes(cfg) + if err := registerStrategies(registry, strategyTypes); err != nil { + return nil, err + } + + return registry, nil +} + +// registerUnauthenticatedStrategy registers the default unauthenticated strategy. +func registerUnauthenticatedStrategy(registry auth.OutgoingAuthRegistry) error { + unauthStrategy := strategies.NewUnauthenticatedStrategy() + if err := registry.RegisterStrategy("unauthenticated", unauthStrategy); err != nil { + return fmt.Errorf("failed to register default unauthenticated strategy: %w", err) + } + return nil +} + +// validateConfig validates the configuration structure. +func validateConfig(cfg *config.OutgoingAuthConfig) error { + if cfg.Default != nil && strings.TrimSpace(cfg.Default.Type) == "" { + return fmt.Errorf("default auth strategy type cannot be empty") + } + + for backendID, backendCfg := range cfg.Backends { + if backendCfg != nil && strings.TrimSpace(backendCfg.Type) == "" { + return fmt.Errorf("backend %q has empty auth strategy type", backendID) + } + } + + return nil +} + +// collectStrategyTypes collects all unique strategy types from configuration. +func collectStrategyTypes(cfg *config.OutgoingAuthConfig) map[string]struct{} { + strategyTypes := make(map[string]struct{}) + + // Add default strategy type if present + if cfg.Default != nil && cfg.Default.Type != "" { + strategyTypes[cfg.Default.Type] = struct{}{} + } + + // Add all backend strategy types + for _, backendCfg := range cfg.Backends { + if backendCfg != nil && backendCfg.Type != "" { + strategyTypes[backendCfg.Type] = struct{}{} + } + } + + return strategyTypes +} + +// registerStrategies instantiates and registers each unique strategy type. +func registerStrategies(registry auth.OutgoingAuthRegistry, strategyTypes map[string]struct{}) error { + for strategyType := range strategyTypes { + // Skip "unauthenticated" - already registered + if strategyType == "unauthenticated" { + continue + } + + strategy, err := createStrategy(strategyType) + if err != nil { + return fmt.Errorf("failed to create strategy %q: %w", strategyType, err) + } + + if err := registry.RegisterStrategy(strategyType, strategy); err != nil { + return fmt.Errorf("failed to register strategy %q: %w", strategyType, err) + } + } + + return nil +} + +// createStrategy instantiates a strategy based on its type. +// +// Each strategy instance is stateless (except token_exchange which has internal caching). +// This function validates that the strategy type is not empty and returns an appropriate +// error for unknown strategy types. +// +// Parameters: +// - strategyType: The type identifier of the strategy to create +// +// Returns: +// - auth.Strategy: The instantiated strategy +// - error: Any error during strategy creation or validation +func createStrategy(strategyType string) (auth.Strategy, error) { + // Validate strategy type is not empty + if strings.TrimSpace(strategyType) == "" { + return nil, fmt.Errorf("strategy type cannot be empty") + } + + switch strategyType { + case "header_injection": + return strategies.NewHeaderInjectionStrategy(), nil + case "unauthenticated": + return strategies.NewUnauthenticatedStrategy(), nil + default: + return nil, fmt.Errorf("unknown strategy type: %s", strategyType) + } +} diff --git a/pkg/vmcp/auth/factory/outgoing_test.go b/pkg/vmcp/auth/factory/outgoing_test.go new file mode 100644 index 000000000..1b551a149 --- /dev/null +++ b/pkg/vmcp/auth/factory/outgoing_test.go @@ -0,0 +1,743 @@ +package factory + +import ( + "context" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/vmcp/auth" + "github.com/stacklok/toolhive/pkg/vmcp/config" +) + +func TestNewOutgoingAuthRegistry(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *config.OutgoingAuthConfig + wantErr bool + errContains string + checkRegistry func(t *testing.T, cfg *config.OutgoingAuthConfig, registry auth.OutgoingAuthRegistry) + }{ + { + name: "nil config returns registry with unauthenticated strategy", + cfg: nil, + wantErr: false, + checkRegistry: func(t *testing.T, _ *config.OutgoingAuthConfig, registry auth.OutgoingAuthRegistry) { + t.Helper() + require.NotNil(t, registry) + + // Registry should have unauthenticated strategy + strategy, err := registry.GetStrategy("unauthenticated") + require.NoError(t, err) + assert.NotNil(t, strategy) + }, + }, + { + name: "empty config returns registry with unauthenticated strategy", + cfg: &config.OutgoingAuthConfig{ + Source: "inline", + }, + wantErr: false, + checkRegistry: func(t *testing.T, _ *config.OutgoingAuthConfig, registry auth.OutgoingAuthRegistry) { + t.Helper() + require.NotNil(t, registry) + + // Registry should have unauthenticated strategy + strategy, err := registry.GetStrategy("unauthenticated") + require.NoError(t, err) + assert.NotNil(t, strategy) + }, + }, + { + name: "default strategy with empty type fails", + cfg: &config.OutgoingAuthConfig{ + Source: "inline", + Default: &config.BackendAuthStrategy{ + Type: "", + }, + }, + wantErr: true, + errContains: "default auth strategy type cannot be empty", + }, + { + name: "default strategy with whitespace type fails", + cfg: &config.OutgoingAuthConfig{ + Source: "inline", + Default: &config.BackendAuthStrategy{ + Type: " ", + }, + }, + wantErr: true, + errContains: "default auth strategy type cannot be empty", + }, + { + name: "backend strategy with empty type fails", + cfg: &config.OutgoingAuthConfig{ + Source: "inline", + Backends: map[string]*config.BackendAuthStrategy{ + "github": { + Type: "", + }, + }, + }, + wantErr: true, + errContains: "backend \"github\" has empty auth strategy type", + }, + { + name: "backend strategy with whitespace type fails", + cfg: &config.OutgoingAuthConfig{ + Source: "inline", + Backends: map[string]*config.BackendAuthStrategy{ + "github": { + Type: " \t ", + }, + }, + }, + wantErr: true, + errContains: "backend \"github\" has empty auth strategy type", + }, + { + name: "unknown strategy type in default fails", + cfg: &config.OutgoingAuthConfig{ + Source: "inline", + Default: &config.BackendAuthStrategy{ + Type: "unknown_strategy", + }, + }, + wantErr: true, + errContains: "unknown strategy type: unknown_strategy", + }, + { + name: "unknown strategy type in backend fails", + cfg: &config.OutgoingAuthConfig{ + Source: "inline", + Backends: map[string]*config.BackendAuthStrategy{ + "github": { + Type: "magic_auth", + }, + }, + }, + wantErr: true, + errContains: "unknown strategy type: magic_auth", + }, + { + name: "valid header_injection in default succeeds", + cfg: &config.OutgoingAuthConfig{ + Source: "inline", + Default: &config.BackendAuthStrategy{ + Type: "header_injection", + }, + }, + wantErr: false, + checkRegistry: func(t *testing.T, _ *config.OutgoingAuthConfig, registry auth.OutgoingAuthRegistry) { + t.Helper() + + // Should have both unauthenticated and header_injection + unauthStrategy, err := registry.GetStrategy("unauthenticated") + require.NoError(t, err) + assert.NotNil(t, unauthStrategy) + + headerStrategy, err := registry.GetStrategy("header_injection") + require.NoError(t, err) + assert.NotNil(t, headerStrategy) + }, + }, + { + name: "valid header_injection in backend succeeds", + cfg: &config.OutgoingAuthConfig{ + Source: "inline", + Backends: map[string]*config.BackendAuthStrategy{ + "github": { + Type: "header_injection", + }, + }, + }, + wantErr: false, + checkRegistry: func(t *testing.T, _ *config.OutgoingAuthConfig, registry auth.OutgoingAuthRegistry) { + t.Helper() + + // Should have both unauthenticated and header_injection + unauthStrategy, err := registry.GetStrategy("unauthenticated") + require.NoError(t, err) + assert.NotNil(t, unauthStrategy) + + headerStrategy, err := registry.GetStrategy("header_injection") + require.NoError(t, err) + assert.NotNil(t, headerStrategy) + }, + }, + { + name: "multiple backends with same strategy type registers once", + cfg: &config.OutgoingAuthConfig{ + Source: "inline", + Backends: map[string]*config.BackendAuthStrategy{ + "github": { + Type: "header_injection", + }, + "gitlab": { + Type: "header_injection", + }, + "jira": { + Type: "header_injection", + }, + }, + }, + wantErr: false, + checkRegistry: func(t *testing.T, _ *config.OutgoingAuthConfig, registry auth.OutgoingAuthRegistry) { + t.Helper() + + // Should have both strategies + unauthStrategy, err := registry.GetStrategy("unauthenticated") + require.NoError(t, err) + assert.NotNil(t, unauthStrategy) + + headerStrategy, err := registry.GetStrategy("header_injection") + require.NoError(t, err) + assert.NotNil(t, headerStrategy) + + // Verify all backends can use the same strategy instance + // (This tests that we collect unique strategy types) + }, + }, + { + name: "default and backend with different strategies registers both", + cfg: &config.OutgoingAuthConfig{ + Source: "inline", + Default: &config.BackendAuthStrategy{ + Type: "unauthenticated", + }, + Backends: map[string]*config.BackendAuthStrategy{ + "github": { + Type: "header_injection", + }, + }, + }, + wantErr: false, + checkRegistry: func(t *testing.T, _ *config.OutgoingAuthConfig, registry auth.OutgoingAuthRegistry) { + t.Helper() + + // Should have both strategies + unauthStrategy, err := registry.GetStrategy("unauthenticated") + require.NoError(t, err) + assert.NotNil(t, unauthStrategy) + + headerStrategy, err := registry.GetStrategy("header_injection") + require.NoError(t, err) + assert.NotNil(t, headerStrategy) + }, + }, + { + name: "default unauthenticated does not register duplicate", + cfg: &config.OutgoingAuthConfig{ + Source: "inline", + Default: &config.BackendAuthStrategy{ + Type: "unauthenticated", + }, + }, + wantErr: false, + checkRegistry: func(t *testing.T, _ *config.OutgoingAuthConfig, registry auth.OutgoingAuthRegistry) { + t.Helper() + + // Should have unauthenticated strategy + unauthStrategy, err := registry.GetStrategy("unauthenticated") + require.NoError(t, err) + assert.NotNil(t, unauthStrategy) + }, + }, + { + name: "backend unauthenticated does not register duplicate", + cfg: &config.OutgoingAuthConfig{ + Source: "inline", + Backends: map[string]*config.BackendAuthStrategy{ + "github": { + Type: "unauthenticated", + }, + }, + }, + wantErr: false, + checkRegistry: func(t *testing.T, _ *config.OutgoingAuthConfig, registry auth.OutgoingAuthRegistry) { + t.Helper() + + // Should have unauthenticated strategy + unauthStrategy, err := registry.GetStrategy("unauthenticated") + require.NoError(t, err) + assert.NotNil(t, unauthStrategy) + }, + }, + { + name: "complex config with multiple backends and strategies succeeds", + cfg: &config.OutgoingAuthConfig{ + Source: "inline", + Default: &config.BackendAuthStrategy{ + Type: "unauthenticated", + }, + Backends: map[string]*config.BackendAuthStrategy{ + "github": { + Type: "header_injection", + Metadata: map[string]any{ + "headers": map[string]string{ + "Authorization": "Bearer token", + }, + }, + }, + "gitlab": { + Type: "header_injection", + Metadata: map[string]any{ + "headers": map[string]string{ + "Private-Token": "token", + }, + }, + }, + "public-api": { + Type: "unauthenticated", + }, + }, + }, + wantErr: false, + checkRegistry: func(t *testing.T, _ *config.OutgoingAuthConfig, registry auth.OutgoingAuthRegistry) { + t.Helper() + + // Should have both strategies + unauthStrategy, err := registry.GetStrategy("unauthenticated") + require.NoError(t, err) + assert.NotNil(t, unauthStrategy) + + headerStrategy, err := registry.GetStrategy("header_injection") + require.NoError(t, err) + assert.NotNil(t, headerStrategy) + }, + }, + { + name: "nil backend in backends map is ignored", + cfg: &config.OutgoingAuthConfig{ + Source: "inline", + Backends: map[string]*config.BackendAuthStrategy{ + "github": { + Type: "header_injection", + }, + "ignored": nil, + }, + }, + wantErr: false, + checkRegistry: func(t *testing.T, _ *config.OutgoingAuthConfig, registry auth.OutgoingAuthRegistry) { + t.Helper() + + // Should have both strategies + unauthStrategy, err := registry.GetStrategy("unauthenticated") + require.NoError(t, err) + assert.NotNil(t, unauthStrategy) + + headerStrategy, err := registry.GetStrategy("header_injection") + require.NoError(t, err) + assert.NotNil(t, headerStrategy) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + registry, err := NewOutgoingAuthRegistry(ctx, tt.cfg) + + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + assert.Nil(t, registry) + } else { + require.NoError(t, err) + require.NotNil(t, registry) + if tt.checkRegistry != nil { + tt.checkRegistry(t, tt.cfg, registry) + } + } + }) + } +} + +func TestValidateConfig(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *config.OutgoingAuthConfig + wantErr bool + errContains string + }{ + { + name: "valid config with default and backends", + cfg: &config.OutgoingAuthConfig{ + Source: "inline", + Default: &config.BackendAuthStrategy{ + Type: "unauthenticated", + }, + Backends: map[string]*config.BackendAuthStrategy{ + "github": { + Type: "header_injection", + }, + }, + }, + wantErr: false, + }, + { + name: "empty default type fails", + cfg: &config.OutgoingAuthConfig{ + Default: &config.BackendAuthStrategy{ + Type: "", + }, + }, + wantErr: true, + errContains: "default auth strategy type cannot be empty", + }, + { + name: "whitespace default type fails", + cfg: &config.OutgoingAuthConfig{ + Default: &config.BackendAuthStrategy{ + Type: " \t\n ", + }, + }, + wantErr: true, + errContains: "default auth strategy type cannot be empty", + }, + { + name: "empty backend type fails with backend id in error", + cfg: &config.OutgoingAuthConfig{ + Backends: map[string]*config.BackendAuthStrategy{ + "my-backend": { + Type: "", + }, + }, + }, + wantErr: true, + errContains: `backend "my-backend" has empty auth strategy type`, + }, + { + name: "multiple backends with empty types fails on first", + cfg: &config.OutgoingAuthConfig{ + Backends: map[string]*config.BackendAuthStrategy{ + "backend1": { + Type: "", + }, + "backend2": { + Type: "", + }, + }, + }, + wantErr: true, + // Will fail on one of them (map iteration order is random) + errContains: "has empty auth strategy type", + }, + { + name: "nil backend entries are allowed", + cfg: &config.OutgoingAuthConfig{ + Backends: map[string]*config.BackendAuthStrategy{ + "backend1": nil, + "backend2": { + Type: "header_injection", + }, + }, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := validateConfig(tt.cfg) + + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + +func TestCollectStrategyTypes(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *config.OutgoingAuthConfig + wantTypes []string + wantTypeCount int + }{ + { + name: "empty config returns empty set", + cfg: &config.OutgoingAuthConfig{ + Source: "inline", + }, + wantTypes: []string{}, + wantTypeCount: 0, + }, + { + name: "default strategy only", + cfg: &config.OutgoingAuthConfig{ + Default: &config.BackendAuthStrategy{ + Type: "header_injection", + }, + }, + wantTypes: []string{"header_injection"}, + wantTypeCount: 1, + }, + { + name: "backend strategy only", + cfg: &config.OutgoingAuthConfig{ + Backends: map[string]*config.BackendAuthStrategy{ + "github": { + Type: "header_injection", + }, + }, + }, + wantTypes: []string{"header_injection"}, + wantTypeCount: 1, + }, + { + name: "duplicate strategy types are deduplicated", + cfg: &config.OutgoingAuthConfig{ + Default: &config.BackendAuthStrategy{ + Type: "header_injection", + }, + Backends: map[string]*config.BackendAuthStrategy{ + "github": { + Type: "header_injection", + }, + "gitlab": { + Type: "header_injection", + }, + }, + }, + wantTypes: []string{"header_injection"}, + wantTypeCount: 1, + }, + { + name: "multiple different strategy types", + cfg: &config.OutgoingAuthConfig{ + Default: &config.BackendAuthStrategy{ + Type: "unauthenticated", + }, + Backends: map[string]*config.BackendAuthStrategy{ + "github": { + Type: "header_injection", + }, + }, + }, + wantTypes: []string{"unauthenticated", "header_injection"}, + wantTypeCount: 2, + }, + { + name: "nil default is ignored", + cfg: &config.OutgoingAuthConfig{ + Default: nil, + Backends: map[string]*config.BackendAuthStrategy{ + "github": { + Type: "header_injection", + }, + }, + }, + wantTypes: []string{"header_injection"}, + wantTypeCount: 1, + }, + { + name: "nil backends are ignored", + cfg: &config.OutgoingAuthConfig{ + Default: &config.BackendAuthStrategy{ + Type: "unauthenticated", + }, + Backends: map[string]*config.BackendAuthStrategy{ + "github": nil, + "gitlab": nil, + "backend": {Type: "header_injection"}, + }, + }, + wantTypes: []string{"unauthenticated", "header_injection"}, + wantTypeCount: 2, + }, + { + name: "empty type strings are ignored", + cfg: &config.OutgoingAuthConfig{ + Default: &config.BackendAuthStrategy{ + Type: "", + }, + Backends: map[string]*config.BackendAuthStrategy{ + "github": { + Type: "header_injection", + }, + }, + }, + wantTypes: []string{"header_injection"}, + wantTypeCount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + strategyTypes := collectStrategyTypes(tt.cfg) + + assert.Equal(t, tt.wantTypeCount, len(strategyTypes), "unexpected number of strategy types") + + // Check all expected types are present + for _, expectedType := range tt.wantTypes { + _, exists := strategyTypes[expectedType] + assert.True(t, exists, "expected strategy type %q not found", expectedType) + } + }) + } +} + +func TestCreateStrategy(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + strategyType string + wantErr bool + errContains string + checkStrategy func(t *testing.T, strategy interface{}) + }{ + { + name: "header_injection creates strategy", + strategyType: "header_injection", + wantErr: false, + checkStrategy: func(t *testing.T, strategy interface{}) { + t.Helper() + require.NotNil(t, strategy) + // Verify it's the right type by checking Name() + named := strategy.(interface{ Name() string }) + assert.Equal(t, "header_injection", named.Name()) + }, + }, + { + name: "unauthenticated creates strategy", + strategyType: "unauthenticated", + wantErr: false, + checkStrategy: func(t *testing.T, strategy interface{}) { + t.Helper() + require.NotNil(t, strategy) + named := strategy.(interface{ Name() string }) + assert.Equal(t, "unauthenticated", named.Name()) + }, + }, + { + name: "unknown strategy type fails", + strategyType: "magic_auth", + wantErr: true, + errContains: "unknown strategy type: magic_auth", + }, + { + name: "empty strategy type fails", + strategyType: "", + wantErr: true, + errContains: "strategy type cannot be empty", + }, + { + name: "whitespace strategy type fails", + strategyType: " \t\n ", + wantErr: true, + errContains: "strategy type cannot be empty", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + strategy, err := createStrategy(tt.strategyType) + + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + assert.Nil(t, strategy) + } else { + require.NoError(t, err) + require.NotNil(t, strategy) + if tt.checkStrategy != nil { + tt.checkStrategy(t, strategy) + } + } + }) + } +} + +func TestRegisterUnauthenticatedStrategy(t *testing.T) { + t.Parallel() + + t.Run("successfully registers unauthenticated strategy", func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + registry, err := NewOutgoingAuthRegistry(ctx, nil) + require.NoError(t, err) + require.NotNil(t, registry) + + // Should be able to retrieve the strategy + strategy, err := registry.GetStrategy("unauthenticated") + require.NoError(t, err) + assert.NotNil(t, strategy) + assert.Equal(t, "unauthenticated", strategy.Name()) + }) +} + +func TestErrorMessages(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *config.OutgoingAuthConfig + wantErrSubstring []string // All of these should be in the error + }{ + { + name: "unknown strategy error includes strategy name", + cfg: &config.OutgoingAuthConfig{ + Default: &config.BackendAuthStrategy{ + Type: "nonexistent_strategy", + }, + }, + wantErrSubstring: []string{"unknown strategy type", "nonexistent_strategy"}, + }, + { + name: "empty backend type error includes backend id", + cfg: &config.OutgoingAuthConfig{ + Backends: map[string]*config.BackendAuthStrategy{ + "my-special-backend": { + Type: "", + }, + }, + }, + wantErrSubstring: []string{"backend", "my-special-backend", "empty"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + _, err := NewOutgoingAuthRegistry(ctx, tt.cfg) + require.Error(t, err) + + errMsg := err.Error() + for _, substring := range tt.wantErrSubstring { + assert.True(t, strings.Contains(errMsg, substring), + "error message %q should contain %q", errMsg, substring) + } + }) + } +} From c807a61df7f9d3599465978d12032f7fc14e11d7 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Tue, 4 Nov 2025 13:31:46 +0000 Subject: [PATCH 04/13] Integrate authentication registry into HTTP backend client Refactors HTTPBackendClient to accept an OutgoingAuthRegistry and apply authentication strategies to all backend requests via a new authRoundTripper middleware. Authentication is now resolved and validated once at client creation time rather than per-request, improving performance and enabling early error detection for misconfigurations. The authRoundTripper clones requests to preserve immutability before applying authentication, ensuring thread-safety and preventing unintended side effects. --- pkg/vmcp/client/client.go | 135 +++++- pkg/vmcp/client/client_test.go | 426 +++++++++++++++++- .../client/mocks/mock_outgoing_registry.go | 70 +++ 3 files changed, 607 insertions(+), 24 deletions(-) create mode 100644 pkg/vmcp/client/mocks/mock_outgoing_registry.go diff --git a/pkg/vmcp/client/client.go b/pkg/vmcp/client/client.go index aaaf9cc59..aadc1dae4 100644 --- a/pkg/vmcp/client/client.go +++ b/pkg/vmcp/client/client.go @@ -17,6 +17,7 @@ import ( "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/auth" ) const ( @@ -44,14 +45,30 @@ type httpBackendClient struct { // clientFactory creates MCP clients for backends. // Abstracted as a function to enable testing with mock clients. clientFactory func(ctx context.Context, target *vmcp.BackendTarget) (*client.Client, error) + + // registry manages authentication strategies for outgoing requests to backend MCP servers. + // Must not be nil - use UnauthenticatedStrategy for no authentication. + registry auth.OutgoingAuthRegistry } // NewHTTPBackendClient creates a new HTTP-based backend client. // This client supports streamable-HTTP and SSE transports. -func NewHTTPBackendClient() vmcp.BackendClient { - return &httpBackendClient{ - clientFactory: defaultClientFactory, +// +// The registry parameter manages authentication strategies for outgoing requests to backend MCP servers. +// It must not be nil. To disable authentication, use a registry configured with the +// "unauthenticated" strategy. +// +// Returns an error if registry is nil. +func NewHTTPBackendClient(registry auth.OutgoingAuthRegistry) (vmcp.BackendClient, error) { + if registry == nil { + return nil, fmt.Errorf("registry cannot be nil; use UnauthenticatedStrategy for no authentication") + } + + c := &httpBackendClient{ + registry: registry, } + c.clientFactory = c.defaultClientFactory + return c, nil } // roundTripperFunc is a function adapter for http.RoundTripper. @@ -62,29 +79,103 @@ func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { return f(req) } +// authRoundTripper is an http.RoundTripper that adds authentication to backend requests. +// The authentication strategy and metadata are pre-resolved and validated at client creation time, +// eliminating per-request lookups and validation overhead. +type authRoundTripper struct { + base http.RoundTripper + authStrategy auth.Strategy + authMetadata map[string]any + target *vmcp.BackendTarget +} + +// RoundTrip implements http.RoundTripper by adding authentication headers to requests. +// The authentication strategy was pre-resolved and validated at client creation time, +// so this method simply applies the authentication without any lookups or validation. +func (a *authRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + // Clone request to avoid modifying the original + reqClone := req.Clone(req.Context()) + + // Apply pre-resolved authentication strategy + if err := a.authStrategy.Authenticate(reqClone.Context(), reqClone, a.authMetadata); err != nil { + return nil, fmt.Errorf("authentication failed for backend %s: %w", a.target.WorkloadID, err) + } + + logger.Debugf("Applied authentication strategy %q to backend %s", a.authStrategy.Name(), a.target.WorkloadID) + + return a.base.RoundTrip(reqClone) +} + +// resolveAuthStrategy resolves the authentication strategy for a backend target. +// It handles defaulting to "unauthenticated" when no strategy is specified. +// This method should be called once at client creation time to enable fail-fast +// behavior for invalid authentication configurations. +func (h *httpBackendClient) resolveAuthStrategy(target *vmcp.BackendTarget) (auth.Strategy, error) { + strategyName := target.AuthStrategy + + // Default to unauthenticated if not specified + if strategyName == "" { + strategyName = "unauthenticated" + } + + // Resolve strategy from registry + strategy, err := h.registry.GetStrategy(strategyName) + if err != nil { + return nil, fmt.Errorf("authentication strategy %q not found: %w", strategyName, err) + } + + return strategy, nil +} + // defaultClientFactory creates mark3labs MCP clients for different transport types. -func defaultClientFactory(ctx context.Context, target *vmcp.BackendTarget) (*client.Client, error) { - // Create HTTP client with response size limits for DoS protection +func (h *httpBackendClient) defaultClientFactory(ctx context.Context, target *vmcp.BackendTarget) (*client.Client, error) { + // Build transport chain: size limit → authentication → HTTP + var baseTransport http.RoundTripper = http.DefaultTransport + + // Resolve authentication strategy ONCE at client creation time + authStrategy, err := h.resolveAuthStrategy(target) + if err != nil { + return nil, fmt.Errorf("failed to resolve authentication for backend %s: %w", + target.WorkloadID, err) + } + + // Validate metadata ONCE at client creation time + if err := authStrategy.Validate(target.AuthMetadata); err != nil { + return nil, fmt.Errorf("invalid authentication configuration for backend %s: %w", + target.WorkloadID, err) + } + + // Add authentication layer with pre-resolved strategy + baseTransport = &authRoundTripper{ + base: baseTransport, + authStrategy: authStrategy, + authMetadata: target.AuthMetadata, + target: target, + } + + // Add size limit layer for DoS protection + sizeLimitedTransport := roundTripperFunc(func(req *http.Request) (*http.Response, error) { + resp, err := baseTransport.RoundTrip(req) + if err != nil { + return nil, err + } + // Wrap response body with size limit + resp.Body = struct { + io.Reader + io.Closer + }{ + Reader: io.LimitReader(resp.Body, maxResponseSize), + Closer: resp.Body, + } + return resp, nil + }) + + // Create HTTP client with configured transport chain httpClient := &http.Client{ - Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { - resp, err := http.DefaultTransport.RoundTrip(req) - if err != nil { - return nil, err - } - // Wrap response body with size limit - resp.Body = struct { - io.Reader - io.Closer - }{ - Reader: io.LimitReader(resp.Body, maxResponseSize), - Closer: resp.Body, - } - return resp, nil - }), + Transport: sizeLimitedTransport, } var c *client.Client - var err error switch target.TransportType { case "streamable-http", "streamable": @@ -93,8 +184,6 @@ func defaultClientFactory(ctx context.Context, target *vmcp.BackendTarget) (*cli transport.WithHTTPTimeout(0), transport.WithContinuousListening(), transport.WithHTTPBasicClient(httpClient), - // TODO: Add authentication header injection via WithHTTPHeaderFunc - // This will be implemented when we add OutgoingAuthenticator support ) if err != nil { return nil, fmt.Errorf("failed to create streamable-http client: %w", err) diff --git a/pkg/vmcp/client/client_test.go b/pkg/vmcp/client/client_test.go index 2a7619cb0..4e1c38837 100644 --- a/pkg/vmcp/client/client_test.go +++ b/pkg/vmcp/client/client_test.go @@ -1,15 +1,23 @@ package client +//go:generate mockgen -destination=mocks/mock_outgoing_registry.go -package=mocks github.com/stacklok/toolhive/pkg/vmcp/auth OutgoingAuthRegistry + import ( "context" "errors" + "net/http" + "net/http/httptest" "testing" "github.com/mark3labs/mcp-go/client" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/auth" + authmocks "github.com/stacklok/toolhive/pkg/vmcp/auth/mocks" + "github.com/stacklok/toolhive/pkg/vmcp/auth/strategies" ) func TestHTTPBackendClient_ListCapabilities_WithMockFactory(t *testing.T) { @@ -76,7 +84,16 @@ func TestDefaultClientFactory_UnsupportedTransport(t *testing.T) { TransportType: tc.transportType, } - _, err := defaultClientFactory(context.Background(), target) + // Create authenticator with unauthenticated strategy for testing + mockRegistry := auth.NewDefaultOutgoingAuthRegistry() + err := mockRegistry.RegisterStrategy("unauthenticated", &strategies.UnauthenticatedStrategy{}) + require.NoError(t, err) + + backendClient, err := NewHTTPBackendClient(mockRegistry) + require.NoError(t, err) + httpClient := backendClient.(*httpBackendClient) + + _, err = httpClient.defaultClientFactory(context.Background(), target) require.Error(t, err) assert.ErrorIs(t, err, vmcp.ErrUnsupportedTransport) @@ -189,3 +206,410 @@ func TestInitializeClient_ErrorHandling(t *testing.T) { assert.NotNil(t, initializeClient) }) } + +// mockRoundTripper is a test implementation of http.RoundTripper that captures requests +type mockRoundTripper struct { + capturedReq *http.Request + response *http.Response + err error +} + +func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + m.capturedReq = req + if m.err != nil { + return nil, m.err + } + return m.response, nil +} + +func TestAuthRoundTripper_RoundTrip(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + target *vmcp.BackendTarget + setupStrategy func(*gomock.Controller) auth.Strategy + baseTransportResp *http.Response + baseTransportErr error + expectError bool + errorContains string + checkRequest func(t *testing.T, originalReq, capturedReq *http.Request) + checkBaseTransport func(t *testing.T, baseTransport *mockRoundTripper) + }{ + { + name: "successful authentication adds headers and forwards request", + target: &vmcp.BackendTarget{ + WorkloadID: "backend-1", + AuthStrategy: "pass_through", + AuthMetadata: map[string]any{"key": "value"}, + }, + setupStrategy: func(ctrl *gomock.Controller) auth.Strategy { + mockStrategy := authmocks.NewMockStrategy(ctrl) + mockStrategy.EXPECT(). + Name(). + Return("pass_through"). + AnyTimes() + mockStrategy.EXPECT(). + Authenticate( + gomock.Any(), + gomock.Any(), + map[string]any{"key": "value"}, + ). + DoAndReturn(func(_ context.Context, req *http.Request, _ map[string]any) error { + // Simulate adding auth header + req.Header.Set("Authorization", "Bearer test-token") + return nil + }) + return mockStrategy + }, + baseTransportResp: &http.Response{StatusCode: http.StatusOK}, + expectError: false, + checkRequest: func(t *testing.T, originalReq, capturedReq *http.Request) { + t.Helper() + // Original request should not be modified + assert.Empty(t, originalReq.Header.Get("Authorization")) + // Captured request should have auth header + assert.Equal(t, "Bearer test-token", capturedReq.Header.Get("Authorization")) + }, + checkBaseTransport: func(t *testing.T, baseTransport *mockRoundTripper) { + t.Helper() + // Base transport should have been called + assert.NotNil(t, baseTransport.capturedReq) + }, + }, + { + name: "unauthenticated strategy skips authentication", + target: &vmcp.BackendTarget{ + WorkloadID: "backend-1", + AuthStrategy: "unauthenticated", + AuthMetadata: nil, + }, + setupStrategy: func(ctrl *gomock.Controller) auth.Strategy { + mockStrategy := authmocks.NewMockStrategy(ctrl) + mockStrategy.EXPECT(). + Name(). + Return("unauthenticated"). + AnyTimes() + mockStrategy.EXPECT(). + Authenticate( + gomock.Any(), + gomock.Any(), + gomock.Nil(), + ). + DoAndReturn(func(_ context.Context, _ *http.Request, _ map[string]any) error { + // UnauthenticatedStrategy does nothing + return nil + }) + return mockStrategy + }, + baseTransportResp: &http.Response{StatusCode: http.StatusOK}, + expectError: false, + checkRequest: func(t *testing.T, originalReq, capturedReq *http.Request) { + t.Helper() + // Neither request should have auth headers + assert.Empty(t, originalReq.Header.Get("Authorization")) + assert.Empty(t, capturedReq.Header.Get("Authorization")) + }, + checkBaseTransport: func(t *testing.T, baseTransport *mockRoundTripper) { + t.Helper() + // Base transport should have been called + assert.NotNil(t, baseTransport.capturedReq) + }, + }, + { + name: "authentication failure returns error without calling base transport", + target: &vmcp.BackendTarget{ + WorkloadID: "backend-1", + AuthStrategy: "pass_through", + AuthMetadata: map[string]any{"key": "value"}, + }, + setupStrategy: func(ctrl *gomock.Controller) auth.Strategy { + mockStrategy := authmocks.NewMockStrategy(ctrl) + mockStrategy.EXPECT(). + Name(). + Return("pass_through"). + AnyTimes() + mockStrategy.EXPECT(). + Authenticate( + gomock.Any(), + gomock.Any(), + map[string]any{"key": "value"}, + ). + Return(errors.New("auth failed")) + return mockStrategy + }, + baseTransportResp: &http.Response{StatusCode: http.StatusOK}, + expectError: true, + errorContains: "authentication failed for backend backend-1", + checkBaseTransport: func(t *testing.T, baseTransport *mockRoundTripper) { + t.Helper() + // Base transport should NOT have been called + assert.Nil(t, baseTransport.capturedReq) + }, + }, + { + name: "base transport error propagates after successful auth", + target: &vmcp.BackendTarget{ + WorkloadID: "backend-1", + AuthStrategy: "pass_through", + AuthMetadata: map[string]any{"key": "value"}, + }, + setupStrategy: func(ctrl *gomock.Controller) auth.Strategy { + mockStrategy := authmocks.NewMockStrategy(ctrl) + mockStrategy.EXPECT(). + Name(). + Return("pass_through"). + AnyTimes() + mockStrategy.EXPECT(). + Authenticate( + gomock.Any(), + gomock.Any(), + map[string]any{"key": "value"}, + ). + Return(nil) + return mockStrategy + }, + baseTransportErr: errors.New("connection refused"), + expectError: true, + errorContains: "connection refused", + checkBaseTransport: func(t *testing.T, baseTransport *mockRoundTripper) { + t.Helper() + // Base transport should have been called + assert.NotNil(t, baseTransport.capturedReq) + }, + }, + { + name: "request immutability - original request unchanged", + target: &vmcp.BackendTarget{ + WorkloadID: "backend-1", + AuthStrategy: "pass_through", + AuthMetadata: map[string]any{"key": "value"}, + }, + setupStrategy: func(ctrl *gomock.Controller) auth.Strategy { + mockStrategy := authmocks.NewMockStrategy(ctrl) + mockStrategy.EXPECT(). + Name(). + Return("pass_through"). + AnyTimes() + mockStrategy.EXPECT(). + Authenticate( + gomock.Any(), + gomock.Any(), + map[string]any{"key": "value"}, + ). + DoAndReturn(func(_ context.Context, req *http.Request, _ map[string]any) error { + // Modify the cloned request + req.Header.Set("Authorization", "Bearer modified-token") + req.Header.Set("X-Custom-Header", "custom-value") + return nil + }) + return mockStrategy + }, + baseTransportResp: &http.Response{StatusCode: http.StatusOK}, + expectError: false, + checkRequest: func(t *testing.T, originalReq, capturedReq *http.Request) { + t.Helper() + // Original request should be completely unmodified + assert.Empty(t, originalReq.Header.Get("Authorization")) + assert.Empty(t, originalReq.Header.Get("X-Custom-Header")) + + // Captured (cloned) request should have modifications + assert.Equal(t, "Bearer modified-token", capturedReq.Header.Get("Authorization")) + assert.Equal(t, "custom-value", capturedReq.Header.Get("X-Custom-Header")) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + + // Setup mock strategy + var mockStrategy auth.Strategy + if tt.setupStrategy != nil { + mockStrategy = tt.setupStrategy(ctrl) + } + + // Setup mock base transport + baseTransport := &mockRoundTripper{ + response: tt.baseTransportResp, + err: tt.baseTransportErr, + } + + // Create authRoundTripper with pre-resolved strategy + authRT := &authRoundTripper{ + base: baseTransport, + authStrategy: mockStrategy, + authMetadata: tt.target.AuthMetadata, + target: tt.target, + } + + // Create test request + req := httptest.NewRequest(http.MethodGet, "http://backend.example.com/test", nil) + ctx := context.Background() + req = req.WithContext(ctx) + + // Execute RoundTrip + resp, err := authRT.RoundTrip(req) + + // Check error expectations + if tt.expectError { + require.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + } else { + require.NoError(t, err) + assert.NotNil(t, resp) + } + + // Check request modifications if specified + if tt.checkRequest != nil { + tt.checkRequest(t, req, baseTransport.capturedReq) + } + + // Check base transport calls if specified + if tt.checkBaseTransport != nil { + tt.checkBaseTransport(t, baseTransport) + } + }) + } +} + +func TestNewHTTPBackendClient_NilRegistry(t *testing.T) { + t.Parallel() + + t.Run("returns error when registry is nil", func(t *testing.T) { + t.Parallel() + + client, err := NewHTTPBackendClient(nil) + + require.Error(t, err) + assert.Nil(t, client) + assert.Contains(t, err.Error(), "registry cannot be nil") + assert.Contains(t, err.Error(), "UnauthenticatedStrategy") + }) + + t.Run("succeeds with valid registry", func(t *testing.T) { + t.Parallel() + + mockRegistry := auth.NewDefaultOutgoingAuthRegistry() + client, err := NewHTTPBackendClient(mockRegistry) + + require.NoError(t, err) + assert.NotNil(t, client) + }) +} + +func TestResolveAuthStrategy(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + target *vmcp.BackendTarget + setupRegistry func() auth.OutgoingAuthRegistry + expectError bool + errorContains string + checkStrategy func(t *testing.T, strategy auth.Strategy) + }{ + { + name: "defaults to unauthenticated when strategy is empty", + target: &vmcp.BackendTarget{ + WorkloadID: "backend-1", + AuthStrategy: "", + AuthMetadata: nil, + }, + setupRegistry: func() auth.OutgoingAuthRegistry { + registry := auth.NewDefaultOutgoingAuthRegistry() + err := registry.RegisterStrategy("unauthenticated", &strategies.UnauthenticatedStrategy{}) + require.NoError(t, err) + return registry + }, + expectError: false, + checkStrategy: func(t *testing.T, strategy auth.Strategy) { + t.Helper() + assert.Equal(t, "unauthenticated", strategy.Name()) + }, + }, + { + name: "resolves explicitly configured strategy", + target: &vmcp.BackendTarget{ + WorkloadID: "backend-1", + AuthStrategy: "header_injection", + AuthMetadata: map[string]any{"header_name": "X-API-Key", "api_key": "test-key"}, + }, + setupRegistry: func() auth.OutgoingAuthRegistry { + registry := auth.NewDefaultOutgoingAuthRegistry() + err := registry.RegisterStrategy("header_injection", strategies.NewHeaderInjectionStrategy()) + require.NoError(t, err) + return registry + }, + expectError: false, + checkStrategy: func(t *testing.T, strategy auth.Strategy) { + t.Helper() + assert.Equal(t, "header_injection", strategy.Name()) + }, + }, + { + name: "returns error for unknown strategy", + target: &vmcp.BackendTarget{ + WorkloadID: "backend-1", + AuthStrategy: "unknown_strategy", + AuthMetadata: nil, + }, + setupRegistry: func() auth.OutgoingAuthRegistry { + registry := auth.NewDefaultOutgoingAuthRegistry() + err := registry.RegisterStrategy("unauthenticated", &strategies.UnauthenticatedStrategy{}) + require.NoError(t, err) + return registry + }, + expectError: true, + errorContains: "authentication strategy \"unknown_strategy\" not found", + }, + { + name: "returns error when unauthenticated strategy not registered", + target: &vmcp.BackendTarget{ + WorkloadID: "backend-1", + AuthStrategy: "", // Empty strategy defaults to unauthenticated + AuthMetadata: nil, + }, + setupRegistry: func() auth.OutgoingAuthRegistry { + // Don't register unauthenticated strategy + return auth.NewDefaultOutgoingAuthRegistry() + }, + expectError: true, + errorContains: "authentication strategy \"unauthenticated\" not found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + registry := tt.setupRegistry() + backendClient, err := NewHTTPBackendClient(registry) + require.NoError(t, err) + + httpClient := backendClient.(*httpBackendClient) + + // Call resolveAuthStrategy + strategy, err := httpClient.resolveAuthStrategy(tt.target) + + if tt.expectError { + require.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + assert.Nil(t, strategy) + } else { + require.NoError(t, err) + assert.NotNil(t, strategy) + if tt.checkStrategy != nil { + tt.checkStrategy(t, strategy) + } + } + }) + } +} diff --git a/pkg/vmcp/client/mocks/mock_outgoing_registry.go b/pkg/vmcp/client/mocks/mock_outgoing_registry.go new file mode 100644 index 000000000..e18e65e05 --- /dev/null +++ b/pkg/vmcp/client/mocks/mock_outgoing_registry.go @@ -0,0 +1,70 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/stacklok/toolhive/pkg/vmcp/auth (interfaces: OutgoingAuthRegistry) +// +// Generated by this command: +// +// mockgen -destination=mocks/mock_outgoing_registry.go -package=mocks github.com/stacklok/toolhive/pkg/vmcp/auth OutgoingAuthRegistry +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + reflect "reflect" + + auth "github.com/stacklok/toolhive/pkg/vmcp/auth" + gomock "go.uber.org/mock/gomock" +) + +// MockOutgoingAuthRegistry is a mock of OutgoingAuthRegistry interface. +type MockOutgoingAuthRegistry struct { + ctrl *gomock.Controller + recorder *MockOutgoingAuthRegistryMockRecorder + isgomock struct{} +} + +// MockOutgoingAuthRegistryMockRecorder is the mock recorder for MockOutgoingAuthRegistry. +type MockOutgoingAuthRegistryMockRecorder struct { + mock *MockOutgoingAuthRegistry +} + +// NewMockOutgoingAuthRegistry creates a new mock instance. +func NewMockOutgoingAuthRegistry(ctrl *gomock.Controller) *MockOutgoingAuthRegistry { + mock := &MockOutgoingAuthRegistry{ctrl: ctrl} + mock.recorder = &MockOutgoingAuthRegistryMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockOutgoingAuthRegistry) EXPECT() *MockOutgoingAuthRegistryMockRecorder { + return m.recorder +} + +// GetStrategy mocks base method. +func (m *MockOutgoingAuthRegistry) GetStrategy(name string) (auth.Strategy, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetStrategy", name) + ret0, _ := ret[0].(auth.Strategy) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetStrategy indicates an expected call of GetStrategy. +func (mr *MockOutgoingAuthRegistryMockRecorder) GetStrategy(name any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStrategy", reflect.TypeOf((*MockOutgoingAuthRegistry)(nil).GetStrategy), name) +} + +// RegisterStrategy mocks base method. +func (m *MockOutgoingAuthRegistry) RegisterStrategy(name string, strategy auth.Strategy) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RegisterStrategy", name, strategy) + ret0, _ := ret[0].(error) + return ret0 +} + +// RegisterStrategy indicates an expected call of RegisterStrategy. +func (mr *MockOutgoingAuthRegistryMockRecorder) RegisterStrategy(name, strategy any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterStrategy", reflect.TypeOf((*MockOutgoingAuthRegistry)(nil).RegisterStrategy), name, strategy) +} From a8c3be5a4449e88cdee02a691b4030a6b147ce4b Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Mon, 3 Nov 2025 21:16:26 +0000 Subject: [PATCH 05/13] Apply auth configuration in backend discoverer The CLI backend discoverer now accepts authentication configuration and applies it to discovered backends during the discovery process. This change enables per-backend authentication by: - Adding authConfig parameter to NewCLIBackendDiscoverer constructor - Implementing resolveAuthConfig() to select backend-specific or default authentication settings with proper precedence - Populating Backend.AuthStrategy and Backend.AuthMetadata fields during backend creation Authentication configuration follows this precedence: 1. Backend-specific configuration (cfg.Backends[backendID]) 2. Default configuration (cfg.Default) 3. No authentication (if neither is configured) The populated authentication fields are later consumed when converting Backend instances to BackendTarget for use by the HTTP client's authRoundTripper. --- pkg/vmcp/aggregator/cli_discoverer.go | 45 +++++++++++++++++++++- pkg/vmcp/aggregator/cli_discoverer_test.go | 18 ++++----- 2 files changed, 53 insertions(+), 10 deletions(-) diff --git a/pkg/vmcp/aggregator/cli_discoverer.go b/pkg/vmcp/aggregator/cli_discoverer.go index b96350b53..c1dec6b41 100644 --- a/pkg/vmcp/aggregator/cli_discoverer.go +++ b/pkg/vmcp/aggregator/cli_discoverer.go @@ -8,6 +8,7 @@ import ( "github.com/stacklok/toolhive/pkg/groups" "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/config" "github.com/stacklok/toolhive/pkg/workloads" ) @@ -16,14 +17,23 @@ import ( type cliBackendDiscoverer struct { workloadsManager workloads.Manager groupsManager groups.Manager + authConfig *config.OutgoingAuthConfig } // NewCLIBackendDiscoverer creates a new CLI-based backend discoverer. // It discovers workloads from Docker/Podman containers managed by ToolHive. -func NewCLIBackendDiscoverer(workloadsManager workloads.Manager, groupsManager groups.Manager) BackendDiscoverer { +// +// The authConfig parameter configures authentication for discovered backends. +// If nil, backends will have no authentication configured. +func NewCLIBackendDiscoverer( + workloadsManager workloads.Manager, + groupsManager groups.Manager, + authConfig *config.OutgoingAuthConfig, +) BackendDiscoverer { return &cliBackendDiscoverer{ workloadsManager: workloadsManager, groupsManager: groupsManager, + authConfig: authConfig, } } @@ -92,6 +102,16 @@ func (d *cliBackendDiscoverer) Discover(ctx context.Context, groupRef string) ([ Metadata: make(map[string]string), } + // Apply authentication configuration if provided + if d.authConfig != nil { + authStrategy, authMetadata := d.resolveAuthConfig(name) + backend.AuthStrategy = authStrategy + backend.AuthMetadata = authMetadata + if authStrategy != "" { + logger.Debugf("Backend %s configured with auth strategy: %s", name, authStrategy) + } + } + // Copy user labels to metadata first for k, v := range workload.Labels { backend.Metadata[k] = v @@ -116,6 +136,29 @@ func (d *cliBackendDiscoverer) Discover(ctx context.Context, groupRef string) ([ return backends, nil } +// resolveAuthConfig determines the authentication strategy and metadata for a backend. +// It checks for backend-specific configuration first, then falls back to default. +func (d *cliBackendDiscoverer) resolveAuthConfig(backendID string) (string, map[string]any) { + if d.authConfig == nil { + return "", nil + } + + // Check for backend-specific configuration + if strategy, exists := d.authConfig.Backends[backendID]; exists && strategy != nil { + logger.Debugf("Using backend-specific auth strategy for %s: %s", backendID, strategy.Type) + return strategy.Type, strategy.Metadata + } + + // Fall back to default configuration + if d.authConfig.Default != nil { + logger.Debugf("Using default auth strategy for %s: %s", backendID, d.authConfig.Default.Type) + return d.authConfig.Default.Type, d.authConfig.Default.Metadata + } + + // No authentication configured + return "", nil +} + // mapWorkloadStatusToHealth converts a workload status to a backend health status. func mapWorkloadStatusToHealth(status rt.WorkloadStatus) vmcp.BackendHealthStatus { switch status { diff --git a/pkg/vmcp/aggregator/cli_discoverer_test.go b/pkg/vmcp/aggregator/cli_discoverer_test.go index 19e1de944..9c3402fad 100644 --- a/pkg/vmcp/aggregator/cli_discoverer_test.go +++ b/pkg/vmcp/aggregator/cli_discoverer_test.go @@ -45,7 +45,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) { mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workload1, nil) mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(workload2, nil) - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups) + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) require.NoError(t, err) @@ -79,7 +79,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) { mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "running-workload").Return(runningWorkload, nil) mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "stopped-workload").Return(stoppedWorkload, nil) - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups) + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) require.NoError(t, err) @@ -108,7 +108,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) { mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workloadWithURL, nil) mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(workloadWithoutURL, nil) - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups) + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) require.NoError(t, err) @@ -133,7 +133,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) { mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload1").Return(workload1, nil) mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "workload2").Return(workload2, nil) - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups) + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) require.NoError(t, err) @@ -150,7 +150,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) { mockGroups.EXPECT().Exists(gomock.Any(), "nonexistent-group").Return(false, nil) - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups) + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), "nonexistent-group") require.Error(t, err) @@ -168,7 +168,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) { mockGroups.EXPECT().Exists(gomock.Any(), testGroupName).Return(false, errors.New("database error")) - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups) + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) require.Error(t, err) @@ -187,7 +187,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) { mockGroups.EXPECT().Exists(gomock.Any(), "empty-group").Return(true, nil) mockWorkloads.EXPECT().ListWorkloadsInGroup(gomock.Any(), "empty-group").Return([]string{}, nil) - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups) + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), "empty-group") require.NoError(t, err) @@ -214,7 +214,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) { mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "stopped1").Return(stoppedWorkload, nil) mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "error1").Return(errorWorkload, nil) - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups) + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) require.NoError(t, err) @@ -240,7 +240,7 @@ func TestCLIBackendDiscoverer_Discover(t *testing.T) { mockWorkloads.EXPECT().GetWorkload(gomock.Any(), "failing-workload"). Return(core.Workload{}, errors.New("workload query failed")) - discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups) + discoverer := NewCLIBackendDiscoverer(mockWorkloads, mockGroups, nil) backends, err := discoverer.Discover(context.Background(), testGroupName) require.NoError(t, err) From 43d322b276b83d44390804d30c45122d500c2c9c Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Mon, 3 Nov 2025 21:17:14 +0000 Subject: [PATCH 06/13] Complete outgoing authentication integration in serve command Finalizes the end-to-end authentication flow by connecting the authentication factory, backend discoverer, and HTTP client in the serve command. This enables vMCP proxy to authenticate requests to downstream MCP servers using configured authentication strategies. The serve command now: - Creates outgoing authenticator from configuration using the factory - Provides authentication config to backend discoverer for setup - Supplies authenticator to HTTP client for request signing - Uses factory for incoming authentication middleware (consistency) This completes the authentication architecture where configuration flows through the factory to create strategies that are applied by the client's round tripper to outgoing requests. Also simplifies redundant type annotation in client variable declaration for consistency with Go style conventions. --- cmd/vmcp/app/commands.go | 18 ++++++++++++++---- pkg/vmcp/client/client.go | 2 +- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/cmd/vmcp/app/commands.go b/cmd/vmcp/app/commands.go index 96007a152..dc209c81f 100644 --- a/cmd/vmcp/app/commands.go +++ b/cmd/vmcp/app/commands.go @@ -12,7 +12,7 @@ import ( "github.com/stacklok/toolhive/pkg/groups" "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/vmcp/aggregator" - vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth" + "github.com/stacklok/toolhive/pkg/vmcp/auth/factory" vmcpclient "github.com/stacklok/toolhive/pkg/vmcp/client" "github.com/stacklok/toolhive/pkg/vmcp/config" vmcprouter "github.com/stacklok/toolhive/pkg/vmcp/router" @@ -213,8 +213,15 @@ func runServe(cmd *cobra.Command, _ []string) error { return fmt.Errorf("failed to create groups manager: %w", err) } + // Create outgoing authentication registry from configuration + logger.Info("Initializing outgoing authentication") + outgoingRegistry, err := factory.NewOutgoingAuthRegistry(ctx, cfg.OutgoingAuth) + if err != nil { + return fmt.Errorf("failed to create outgoing authentication registry: %w", err) + } + // Create backend discoverer - discoverer := aggregator.NewCLIBackendDiscoverer(workloadsManager, groupsManager) + discoverer := aggregator.NewCLIBackendDiscoverer(workloadsManager, groupsManager, cfg.OutgoingAuth) // Discover backends from the configured group logger.Infof("Discovering backends in group: %s", cfg.GroupRef) @@ -230,7 +237,10 @@ func runServe(cmd *cobra.Command, _ []string) error { logger.Infof("Discovered %d backends", len(backends)) // Create backend client - backendClient := vmcpclient.NewHTTPBackendClient() + backendClient, err := vmcpclient.NewHTTPBackendClient(outgoingRegistry) + if err != nil { + return fmt.Errorf("failed to create backend client: %w", err) + } // Create conflict resolver based on configuration // Use the factory method that handles all strategies @@ -264,7 +274,7 @@ func runServe(cmd *cobra.Command, _ []string) error { // Setup authentication middleware logger.Infof("Setting up incoming authentication (type: %s)", cfg.IncomingAuth.Type) - authMiddleware, authInfoHandler, err := vmcpauth.NewIncomingAuthMiddleware(ctx, cfg.IncomingAuth) + authMiddleware, authInfoHandler, err := factory.NewIncomingAuthMiddleware(ctx, cfg.IncomingAuth) if err != nil { return fmt.Errorf("failed to create authentication middleware: %w", err) } diff --git a/pkg/vmcp/client/client.go b/pkg/vmcp/client/client.go index aadc1dae4..cd83cd061 100644 --- a/pkg/vmcp/client/client.go +++ b/pkg/vmcp/client/client.go @@ -130,7 +130,7 @@ func (h *httpBackendClient) resolveAuthStrategy(target *vmcp.BackendTarget) (aut // defaultClientFactory creates mark3labs MCP clients for different transport types. func (h *httpBackendClient) defaultClientFactory(ctx context.Context, target *vmcp.BackendTarget) (*client.Client, error) { // Build transport chain: size limit → authentication → HTTP - var baseTransport http.RoundTripper = http.DefaultTransport + var baseTransport = http.DefaultTransport // Resolve authentication strategy ONCE at client creation time authStrategy, err := h.resolveAuthStrategy(target) From 2879d63d9c7c4b9fd2c5ebf507eeb7731909a702 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Mon, 3 Nov 2025 23:21:20 +0000 Subject: [PATCH 07/13] Add explicit unauthenticated strategy for vMCP Replace the pattern of passing nil authenticators with an explicit UnauthenticatedStrategy that implements the Strategy interface as a no-op. This makes the intent clear in configuration and improves type safety by eliminating nil checks. The strategy is appropriate for backends on trusted networks or where authentication is handled at the network layer. Configuration now explicitly declares "strategy: unauthenticated" instead of relying on implicit nil behavior. --- pkg/vmcp/auth/strategies/unauthenticated.go | 72 +++++++ .../auth/strategies/unauthenticated_test.go | 196 ++++++++++++++++++ 2 files changed, 268 insertions(+) create mode 100644 pkg/vmcp/auth/strategies/unauthenticated.go create mode 100644 pkg/vmcp/auth/strategies/unauthenticated_test.go diff --git a/pkg/vmcp/auth/strategies/unauthenticated.go b/pkg/vmcp/auth/strategies/unauthenticated.go new file mode 100644 index 000000000..454495c52 --- /dev/null +++ b/pkg/vmcp/auth/strategies/unauthenticated.go @@ -0,0 +1,72 @@ +package strategies + +import ( + "context" + "net/http" +) + +// UnauthenticatedStrategy is a no-op authentication strategy that performs no authentication. +// This strategy is used when a backend MCP server requires no authentication. +// +// Unlike passing a nil authenticator (which is now an error), this strategy makes +// the intent explicit: "this backend intentionally has no authentication". +// +// The strategy performs no modifications to requests and validates all metadata. +// +// This is appropriate when: +// - The backend MCP server is on a trusted network (e.g., localhost) +// - The backend has no authentication requirements +// - Authentication is handled by network-level security (e.g., VPC, firewall) +// +// Security Warning: Only use this strategy when you are certain the backend +// requires no authentication. For production deployments, prefer explicit +// authentication strategies (pass_through, header_injection, token_exchange). +// +// Configuration: No metadata required, but any metadata is accepted and ignored. +// +// Example configuration: +// +// backends: +// local-backend: +// strategy: "unauthenticated" +type UnauthenticatedStrategy struct{} + +// NewUnauthenticatedStrategy creates a new UnauthenticatedStrategy instance. +func NewUnauthenticatedStrategy() *UnauthenticatedStrategy { + return &UnauthenticatedStrategy{} +} + +// Name returns the strategy identifier. +func (*UnauthenticatedStrategy) Name() string { + return "unauthenticated" +} + +// Authenticate performs no authentication and returns immediately. +// +// This method: +// 1. Does not modify the request in any way +// 2. Always returns nil (success) +// +// Parameters: +// - ctx: Request context (unused) +// - req: The HTTP request (not modified) +// - metadata: Strategy-specific configuration (ignored) +// +// Returns nil (always succeeds). +func (*UnauthenticatedStrategy) Authenticate(_ context.Context, _ *http.Request, _ map[string]any) error { + // No-op: intentionally does nothing + return nil +} + +// Validate checks if the strategy configuration is valid. +// +// UnauthenticatedStrategy accepts any metadata (including nil or empty), +// so this always returns nil. +// +// This permissive validation allows the strategy to be used without +// configuration or with arbitrary configuration that may be present +// for documentation purposes. +func (*UnauthenticatedStrategy) Validate(_ map[string]any) error { + // No-op: accepts any metadata + return nil +} diff --git a/pkg/vmcp/auth/strategies/unauthenticated_test.go b/pkg/vmcp/auth/strategies/unauthenticated_test.go new file mode 100644 index 000000000..43ee62bb2 --- /dev/null +++ b/pkg/vmcp/auth/strategies/unauthenticated_test.go @@ -0,0 +1,196 @@ +package strategies + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUnauthenticatedStrategy_Name(t *testing.T) { + t.Parallel() + + strategy := NewUnauthenticatedStrategy() + assert.Equal(t, "unauthenticated", strategy.Name()) +} + +func TestUnauthenticatedStrategy_Authenticate(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + metadata map[string]any + setupRequest func() *http.Request + checkRequest func(t *testing.T, req *http.Request) + }{ + { + name: "does not modify request with no metadata", + metadata: nil, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "http://backend.example.com/test", nil) + req.Header.Set("X-Custom-Header", "original-value") + return req + }, + checkRequest: func(t *testing.T, req *http.Request) { + t.Helper() + // Original headers should be unchanged + assert.Equal(t, "original-value", req.Header.Get("X-Custom-Header")) + // No auth headers should be added + assert.Empty(t, req.Header.Get("Authorization")) + }, + }, + { + name: "does not modify request with metadata present", + metadata: map[string]any{ + "some_key": "some_value", + "count": 42, + }, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "http://backend.example.com/test", nil) + req.Header.Set("X-Existing", "existing-value") + return req + }, + checkRequest: func(t *testing.T, req *http.Request) { + t.Helper() + // Original headers should be unchanged + assert.Equal(t, "existing-value", req.Header.Get("X-Existing")) + // No auth headers should be added + assert.Empty(t, req.Header.Get("Authorization")) + }, + }, + { + name: "preserves existing Authorization header", + metadata: nil, + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "http://backend.example.com/test", nil) + req.Header.Set("Authorization", "Bearer existing-token") + return req + }, + checkRequest: func(t *testing.T, req *http.Request) { + t.Helper() + // Should not modify existing Authorization header + assert.Equal(t, "Bearer existing-token", req.Header.Get("Authorization")) + }, + }, + { + name: "works with empty request", + metadata: nil, + setupRequest: func() *http.Request { + return httptest.NewRequest(http.MethodGet, "http://backend.example.com/test", nil) + }, + checkRequest: func(t *testing.T, req *http.Request) { + t.Helper() + // Request should have no auth headers + assert.Empty(t, req.Header.Get("Authorization")) + // Headers should be empty or minimal + assert.LessOrEqual(t, len(req.Header), 1) // May have Host header + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + strategy := NewUnauthenticatedStrategy() + req := tt.setupRequest() + ctx := context.Background() + + err := strategy.Authenticate(ctx, req, tt.metadata) + + require.NoError(t, err) + tt.checkRequest(t, req) + }) + } +} + +func TestUnauthenticatedStrategy_Validate(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + metadata map[string]any + }{ + { + name: "accepts nil metadata", + metadata: nil, + }, + { + name: "accepts empty metadata", + metadata: map[string]any{}, + }, + { + name: "accepts arbitrary metadata", + metadata: map[string]any{ + "key1": "value1", + "key2": 42, + "key3": []string{"a", "b", "c"}, + "nested": map[string]any{"inner": "value"}, + }, + }, + { + name: "accepts metadata with typical auth fields", + metadata: map[string]any{ + "token_url": "https://example.com/token", + "client_id": "client-123", + "header_name": "X-API-Key", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + strategy := NewUnauthenticatedStrategy() + err := strategy.Validate(tt.metadata) + + require.NoError(t, err) + }) + } +} + +func TestUnauthenticatedStrategy_IntegrationBehavior(t *testing.T) { + t.Parallel() + + t.Run("strategy can be called multiple times safely", func(t *testing.T) { + t.Parallel() + + strategy := NewUnauthenticatedStrategy() + ctx := context.Background() + + // Call multiple times with different requests + for i := 0; i < 5; i++ { + req := httptest.NewRequest(http.MethodGet, "http://backend.example.com/test", nil) + err := strategy.Authenticate(ctx, req, nil) + require.NoError(t, err) + assert.Empty(t, req.Header.Get("Authorization")) + } + }) + + t.Run("strategy is safe for concurrent use", func(t *testing.T) { + t.Parallel() + + strategy := NewUnauthenticatedStrategy() + ctx := context.Background() + + // Run authentication concurrently + done := make(chan bool, 10) + for i := 0; i < 10; i++ { + go func() { + req := httptest.NewRequest(http.MethodGet, "http://backend.example.com/test", nil) + err := strategy.Authenticate(ctx, req, nil) + assert.NoError(t, err) + done <- true + }() + } + + // Wait for all goroutines + for i := 0; i < 10; i++ { + <-done + } + }) +} From 1e328732bedc87617ec2cc9bdbf7bfd725a12de1 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Mon, 3 Nov 2025 22:54:36 +0000 Subject: [PATCH 08/13] Implement HeaderInjection authentication strategy Add HeaderInjectionStrategy for injecting static header values into backend requests. This general-purpose strategy supports any HTTP header with any static value, enabling flexible authentication schemes like API keys, bearer tokens, and custom auth headers. The strategy extracts header_name and api_key from metadata configuration and validates them to prevent CRLF injection attacks using pkg/validation functions. Validation occurs at configuration time for fail-fast behavior. --- pkg/vmcp/auth/strategies/header_injection.go | 113 +++++ .../auth/strategies/header_injection_test.go | 408 ++++++++++++++++++ 2 files changed, 521 insertions(+) create mode 100644 pkg/vmcp/auth/strategies/header_injection.go create mode 100644 pkg/vmcp/auth/strategies/header_injection_test.go diff --git a/pkg/vmcp/auth/strategies/header_injection.go b/pkg/vmcp/auth/strategies/header_injection.go new file mode 100644 index 000000000..07fccc084 --- /dev/null +++ b/pkg/vmcp/auth/strategies/header_injection.go @@ -0,0 +1,113 @@ +// Package strategies provides authentication strategy implementations for Virtual MCP Server. +package strategies + +import ( + "context" + "fmt" + "net/http" + + "github.com/stacklok/toolhive/pkg/validation" +) + +// HeaderInjectionStrategy injects a static header value into request headers. +// This is a general-purpose strategy that can inject any header with any value, +// commonly used for API keys, bearer tokens, or custom authentication headers. +// +// The strategy extracts the header name and value from the metadata +// configuration and injects them into the backend request headers. +// +// Required metadata fields: +// - header_name: The HTTP header name to use (e.g., "X-API-Key", "Authorization") +// - api_key: The header value to inject (can be an API key, token, or any value) +// +// This strategy is appropriate when: +// - The backend requires a static header value for authentication +// - The header value is stored securely in the vMCP configuration +// - No dynamic token exchange or user-specific authentication is required +// +// Future enhancements may include: +// - Secret reference resolution (e.g., ${SECRET_REF:...}) +// - Support for multiple header formats (e.g., "Bearer ") +// - Value rotation and refresh mechanisms +type HeaderInjectionStrategy struct{} + +// NewHeaderInjectionStrategy creates a new HeaderInjectionStrategy instance. +func NewHeaderInjectionStrategy() *HeaderInjectionStrategy { + return &HeaderInjectionStrategy{} +} + +// Name returns the strategy identifier. +func (*HeaderInjectionStrategy) Name() string { + return "header_injection" +} + +// Authenticate injects the header value from metadata into the request header. +// +// This method: +// 1. Validates that header_name and api_key are present in metadata +// 2. Sets the specified header with the provided value +// +// Parameters: +// - ctx: Request context (currently unused, reserved for future secret resolution) +// - req: The HTTP request to authenticate +// - metadata: Strategy-specific configuration containing header_name and api_key +// +// Returns an error if: +// - header_name is missing or empty +// - api_key is missing or empty +func (*HeaderInjectionStrategy) Authenticate(_ context.Context, req *http.Request, metadata map[string]any) error { + headerName, ok := metadata["header_name"].(string) + if !ok || headerName == "" { + return fmt.Errorf("header_name required in metadata") + } + + apiKey, ok := metadata["api_key"].(string) + if !ok || apiKey == "" { + return fmt.Errorf("api_key required in metadata") + } + + // TODO: Future enhancement - resolve secret references + // if strings.HasPrefix(apiKey, "${SECRET_REF:") { + // apiKey, err = s.secretResolver.Resolve(ctx, apiKey) + // if err != nil { + // return fmt.Errorf("failed to resolve secret reference: %w", err) + // } + // } + + req.Header.Set(headerName, apiKey) + return nil +} + +// Validate checks if the required metadata fields are present and valid. +// +// This method verifies that: +// - header_name is present and non-empty +// - api_key is present and non-empty +// - header_name is a valid HTTP header name (prevents CRLF injection) +// - api_key is a valid HTTP header value (prevents CRLF injection) +// +// This validation is typically called during configuration parsing to fail fast +// if the strategy is misconfigured. +func (*HeaderInjectionStrategy) Validate(metadata map[string]any) error { + headerName, ok := metadata["header_name"].(string) + if !ok || headerName == "" { + return fmt.Errorf("header_name required in metadata") + } + + apiKey, ok := metadata["api_key"].(string) + if !ok || apiKey == "" { + return fmt.Errorf("api_key required in metadata") + } + + // Validate header name to prevent injection attacks + if err := validation.ValidateHTTPHeaderName(headerName); err != nil { + return fmt.Errorf("invalid header_name: %w", err) + } + + // Validate API key value to prevent injection attacks + if err := validation.ValidateHTTPHeaderValue(apiKey); err != nil { + return fmt.Errorf("invalid api_key: %w", err) + } + + return nil +} diff --git a/pkg/vmcp/auth/strategies/header_injection_test.go b/pkg/vmcp/auth/strategies/header_injection_test.go new file mode 100644 index 000000000..537fd3d86 --- /dev/null +++ b/pkg/vmcp/auth/strategies/header_injection_test.go @@ -0,0 +1,408 @@ +package strategies + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestHeaderInjectionStrategy_Name(t *testing.T) { + t.Parallel() + + strategy := NewHeaderInjectionStrategy() + assert.Equal(t, "header_injection", strategy.Name()) +} + +func TestHeaderInjectionStrategy_Authenticate(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + metadata map[string]any + expectError bool + errorContains string + checkHeader func(t *testing.T, req *http.Request) + }{ + { + name: "sets X-API-Key header correctly", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": "secret-key-123", + }, + expectError: false, + checkHeader: func(t *testing.T, req *http.Request) { + t.Helper() + assert.Equal(t, "secret-key-123", req.Header.Get("X-API-Key")) + }, + }, + { + name: "sets Authorization header with API key", + metadata: map[string]any{ + "header_name": "Authorization", + "api_key": "ApiKey my-secret-key", + }, + expectError: false, + checkHeader: func(t *testing.T, req *http.Request) { + t.Helper() + assert.Equal(t, "ApiKey my-secret-key", req.Header.Get("Authorization")) + }, + }, + { + name: "sets custom header name", + metadata: map[string]any{ + "header_name": "X-Custom-Auth-Token", + "api_key": "custom-token-value", + }, + expectError: false, + checkHeader: func(t *testing.T, req *http.Request) { + t.Helper() + assert.Equal(t, "custom-token-value", req.Header.Get("X-Custom-Auth-Token")) + }, + }, + { + name: "handles complex API key values", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.test", + }, + expectError: false, + checkHeader: func(t *testing.T, req *http.Request) { + t.Helper() + assert.Equal(t, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.test", + req.Header.Get("X-API-Key")) + }, + }, + { + name: "handles API key with special characters", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": "key-with-!@#$%^&*()-_=+[]{}|;:,.<>?", + }, + expectError: false, + checkHeader: func(t *testing.T, req *http.Request) { + t.Helper() + assert.Equal(t, "key-with-!@#$%^&*()-_=+[]{}|;:,.<>?", req.Header.Get("X-API-Key")) + }, + }, + { + name: "ignores additional metadata fields", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": "my-key", + "extra_field": "ignored", + "another": 123, + }, + expectError: false, + checkHeader: func(t *testing.T, req *http.Request) { + t.Helper() + assert.Equal(t, "my-key", req.Header.Get("X-API-Key")) + }, + }, + { + name: "returns error when header_name is missing", + metadata: map[string]any{ + "api_key": "my-key", + }, + expectError: true, + errorContains: "header_name required", + }, + { + name: "returns error when header_name is empty string", + metadata: map[string]any{ + "header_name": "", + "api_key": "my-key", + }, + expectError: true, + errorContains: "header_name required", + }, + { + name: "returns error when header_name is not a string", + metadata: map[string]any{ + "header_name": 123, + "api_key": "my-key", + }, + expectError: true, + errorContains: "header_name required", + }, + { + name: "returns error when api_key is missing", + metadata: map[string]any{ + "header_name": "X-API-Key", + }, + expectError: true, + errorContains: "api_key required", + }, + { + name: "returns error when api_key is empty string", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": "", + }, + expectError: true, + errorContains: "api_key required", + }, + { + name: "returns error when api_key is not a string", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": 123, + }, + expectError: true, + errorContains: "api_key required", + }, + { + name: "returns error when metadata is nil", + metadata: nil, + expectError: true, + errorContains: "header_name required", + }, + { + name: "returns error when metadata is empty", + metadata: map[string]any{}, + expectError: true, + errorContains: "header_name required", + }, + { + name: "returns error when both fields are missing", + metadata: map[string]any{ + "unrelated": "field", + }, + expectError: true, + errorContains: "header_name required", + }, + { + name: "overwrites existing header value", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": "new-key", + }, + expectError: false, + checkHeader: func(t *testing.T, req *http.Request) { + t.Helper() + // Verify the new key was set (old-key was already set before Authenticate) + assert.Equal(t, "new-key", req.Header.Get("X-API-Key")) + }, + }, + { + name: "handles very long API keys", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": string(make([]byte, 10000)) + "very-long-key", + }, + expectError: false, + checkHeader: func(t *testing.T, req *http.Request) { + t.Helper() + expected := string(make([]byte, 10000)) + "very-long-key" + assert.Equal(t, expected, req.Header.Get("X-API-Key")) + }, + }, + { + name: "handles case-sensitive header names", + metadata: map[string]any{ + "header_name": "x-api-key", // lowercase + "api_key": "my-key", + }, + expectError: false, + checkHeader: func(t *testing.T, req *http.Request) { + t.Helper() + // HTTP headers are case-insensitive, but Go normalizes them + assert.Equal(t, "my-key", req.Header.Get("x-api-key")) + assert.Equal(t, "my-key", req.Header.Get("X-Api-Key")) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + strategy := NewHeaderInjectionStrategy() + ctx := context.Background() + req := httptest.NewRequest(http.MethodGet, "/test", nil) + + // Special setup for the "overwrites existing header value" test + if tt.name == "overwrites existing header value" { + req.Header.Set("X-API-Key", "old-key") + } + + err := strategy.Authenticate(ctx, req, tt.metadata) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorContains) + return + } + + require.NoError(t, err) + if tt.checkHeader != nil { + tt.checkHeader(t, req) + } + }) + } +} + +func TestHeaderInjectionStrategy_Validate(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + metadata map[string]any + expectError bool + errorContains string + }{ + { + name: "valid metadata with all required fields", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": "secret-key", + }, + expectError: false, + }, + { + name: "valid with extra metadata fields", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": "secret-key", + "extra": "ignored", + "count": 123, + }, + expectError: false, + }, + { + name: "valid with different header name", + metadata: map[string]any{ + "header_name": "Authorization", + "api_key": "Bearer token", + }, + expectError: false, + }, + { + name: "returns error when header_name is missing", + metadata: map[string]any{ + "api_key": "secret-key", + }, + expectError: true, + errorContains: "header_name required", + }, + { + name: "returns error when header_name is empty", + metadata: map[string]any{ + "header_name": "", + "api_key": "secret-key", + }, + expectError: true, + errorContains: "header_name required", + }, + { + name: "returns error when header_name is not a string", + metadata: map[string]any{ + "header_name": 123, + "api_key": "secret-key", + }, + expectError: true, + errorContains: "header_name required", + }, + { + name: "returns error when header_name is a boolean", + metadata: map[string]any{ + "header_name": true, + "api_key": "secret-key", + }, + expectError: true, + errorContains: "header_name required", + }, + { + name: "returns error when api_key is missing", + metadata: map[string]any{ + "header_name": "X-API-Key", + }, + expectError: true, + errorContains: "api_key required", + }, + { + name: "returns error when api_key is empty", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": "", + }, + expectError: true, + errorContains: "api_key required", + }, + { + name: "returns error when api_key is not a string", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": 123, + }, + expectError: true, + errorContains: "api_key required", + }, + { + name: "returns error when api_key is a map", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": map[string]any{"nested": "value"}, + }, + expectError: true, + errorContains: "api_key required", + }, + { + name: "returns error when metadata is nil", + metadata: nil, + expectError: true, + errorContains: "header_name required", + }, + { + name: "returns error when metadata is empty", + metadata: map[string]any{}, + expectError: true, + errorContains: "header_name required", + }, + { + name: "returns error when both fields are wrong type", + metadata: map[string]any{ + "header_name": 123, + "api_key": false, + }, + expectError: true, + errorContains: "header_name required", + }, + { + name: "returns error for whitespace in header_name", + metadata: map[string]any{ + "header_name": "X-Custom Header", + "api_key": "key", + }, + expectError: true, + errorContains: "invalid header_name", + }, + { + name: "accepts unicode in api_key", + metadata: map[string]any{ + "header_name": "X-API-Key", + "api_key": "key-with-unicode-日本語", + }, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + strategy := NewHeaderInjectionStrategy() + err := strategy.Validate(tt.metadata) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorContains) + } else { + require.NoError(t, err) + } + }) + } +} From 28f80dcef98ea51cf1fcfff86422e246a31bb4ac Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Tue, 4 Nov 2025 16:10:19 +0000 Subject: [PATCH 09/13] Update validator to only accept implemented strategies Limit validTypes to strategies actually implemented in this PR: - unauthenticated - header_injection Comment out unimplemented strategies with TODO to add them as they are implemented in future PRs. This prevents accepting configuration for strategies that don't exist yet. --- pkg/vmcp/config/validator.go | 6 ++- pkg/vmcp/config/validator_test.go | 77 ++++++++++++++++++++--------- pkg/vmcp/config/yaml_loader_test.go | 6 +-- 3 files changed, 60 insertions(+), 29 deletions(-) diff --git a/pkg/vmcp/config/validator.go b/pkg/vmcp/config/validator.go index dc863cef6..a7beb0384 100644 --- a/pkg/vmcp/config/validator.go +++ b/pkg/vmcp/config/validator.go @@ -169,8 +169,10 @@ func (*DefaultValidator) validateBackendAuthStrategy(_ string, strategy *Backend } validTypes := []string{ - "pass_through", "token_exchange", "client_credentials", - "service_account", "header_injection", "oauth_proxy", + "unauthenticated", "header_injection", + // TODO: Add more as strategies are implemented: + // "pass_through", "token_exchange", "client_credentials", + // "service_account", "oauth_proxy", } if !contains(validTypes, strategy.Type) { return fmt.Errorf("type must be one of: %s", strings.Join(validTypes, ", ")) diff --git a/pkg/vmcp/config/validator_test.go b/pkg/vmcp/config/validator_test.go index c7bb64678..4eebd929c 100644 --- a/pkg/vmcp/config/validator_test.go +++ b/pkg/vmcp/config/validator_test.go @@ -187,32 +187,60 @@ func TestValidator_ValidateOutgoingAuth(t *testing.T) { errMsg string }{ { - name: "valid inline source with pass_through default", + name: "valid inline source with unauthenticated default", auth: &OutgoingAuthConfig{ Source: "inline", Default: &BackendAuthStrategy{ - Type: "pass_through", + Type: "unauthenticated", }, }, wantErr: false, }, { - name: "valid token_exchange backend", + name: "valid header_injection backend", auth: &OutgoingAuthConfig{ Source: "inline", Backends: map[string]*BackendAuthStrategy{ "github": { - Type: "token_exchange", + Type: "header_injection", Metadata: map[string]any{ - "token_url": "https://example.com/token", - "client_id": "test-client", - "audience": "github-api", + "header_name": "Authorization", + "api_key": "secret-token", }, }, }, }, wantErr: false, }, + // TODO: Uncomment when pass_through strategy is implemented + // { + // name: "valid inline source with pass_through default", + // auth: &OutgoingAuthConfig{ + // Source: "inline", + // Default: &BackendAuthStrategy{ + // Type: "pass_through", + // }, + // }, + // wantErr: false, + // }, + // TODO: Uncomment when token_exchange strategy is implemented + // { + // name: "valid token_exchange backend", + // auth: &OutgoingAuthConfig{ + // Source: "inline", + // Backends: map[string]*BackendAuthStrategy{ + // "github": { + // Type: "token_exchange", + // Metadata: map[string]any{ + // "token_url": "https://example.com/token", + // "client_id": "test-client", + // "audience": "github-api", + // }, + // }, + // }, + // }, + // wantErr: false, + // }, { name: "invalid source", auth: &OutgoingAuthConfig{ @@ -234,23 +262,24 @@ func TestValidator_ValidateOutgoingAuth(t *testing.T) { wantErr: true, errMsg: "type must be one of", }, - { - name: "token_exchange missing required metadata", - auth: &OutgoingAuthConfig{ - Source: "inline", - Backends: map[string]*BackendAuthStrategy{ - "github": { - Type: "token_exchange", - Metadata: map[string]any{ - "client_id": "test-client", - // Missing token_url and audience - }, - }, - }, - }, - wantErr: true, - errMsg: "token_exchange requires metadata field", - }, + // TODO: Uncomment when token_exchange strategy is implemented + // { + // name: "token_exchange missing required metadata", + // auth: &OutgoingAuthConfig{ + // Source: "inline", + // Backends: map[string]*BackendAuthStrategy{ + // "github": { + // Type: "token_exchange", + // Metadata: map[string]any{ + // "client_id": "test-client", + // // Missing token_url and audience + // }, + // }, + // }, + // }, + // wantErr: true, + // errMsg: "token_exchange requires metadata field", + // }, } for _, tt := range tests { diff --git a/pkg/vmcp/config/yaml_loader_test.go b/pkg/vmcp/config/yaml_loader_test.go index acd96df34..c5eac6184 100644 --- a/pkg/vmcp/config/yaml_loader_test.go +++ b/pkg/vmcp/config/yaml_loader_test.go @@ -387,7 +387,7 @@ incoming_auth: outgoing_auth: source: inline default: - type: pass_through + type: unauthenticated aggregation: conflict_resolution: prefix @@ -407,7 +407,7 @@ incoming_auth: outgoing_auth: source: inline default: - type: pass_through + type: unauthenticated aggregation: conflict_resolution: prefix @@ -429,7 +429,7 @@ incoming_auth: outgoing_auth: source: inline default: - type: pass_through + type: unauthenticated aggregation: conflict_resolution: prefix From 572e0e4c93e95c03ecd5d81d8f002e8e536e7358 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Tue, 4 Nov 2025 16:22:07 +0000 Subject: [PATCH 10/13] Update example configs to use implemented strategies Update example configuration files to use only implemented authentication strategies (unauthenticated and header_injection). Changes: - Replace pass_through with unauthenticated in defaults - Show header_injection example for backends - Comment out unimplemented strategies (pass_through, token_exchange, service_account) with TODOs - Add clear notes about which strategies are currently implemented This ensures example configs are valid and can be used immediately without validation errors. --- cmd/vmcp/example-config.yaml | 31 ++++++++++------ examples/vmcp-config.yaml | 70 +++++++++++++++++++++--------------- 2 files changed, 62 insertions(+), 39 deletions(-) diff --git a/cmd/vmcp/example-config.yaml b/cmd/vmcp/example-config.yaml index 9360f5d40..bed36e65f 100644 --- a/cmd/vmcp/example-config.yaml +++ b/cmd/vmcp/example-config.yaml @@ -34,25 +34,34 @@ incoming_auth: # scopes: ["openid", "profile", "email"] # ===== OUTGOING AUTHENTICATION (Virtual MCP → Backends) ===== -# Currently not implemented - this configuration is a placeholder for -# future implementation (Issue #160) +# Implemented strategies: unauthenticated, header_injection outgoing_auth: source: inline # Options: inline | discovered # Default behavior for backends without explicit config default: - type: pass_through # Options: pass_through | token_exchange | service_account + type: unauthenticated # Options: unauthenticated | header_injection + # TODO: Uncomment when pass_through is implemented + # type: pass_through - # Per-backend authentication (not yet implemented) + # Per-backend authentication examples # backends: + # # Example: API key authentication # github: - # type: token_exchange - # token_exchange: - # token_url: "https://keycloak.example.com/realms/myrealm/protocol/openid-connect/token" - # client_id: "vmcp-github-exchange" - # client_secret_env: "GITHUB_EXCHANGE_SECRET" - # audience: "github-api" - # scopes: ["repo", "read:org"] + # type: header_injection + # header_injection: + # header_name: "Authorization" + # api_key: "${GITHUB_API_TOKEN}" + # + # # TODO: Uncomment when token_exchange is implemented + # # jira: + # # type: token_exchange + # # metadata: + # # token_url: "https://keycloak.example.com/realms/myrealm/protocol/openid-connect/token" + # # client_id: "vmcp-github-exchange" + # # client_secret_env: "GITHUB_EXCHANGE_SECRET" + # # audience: "github-api" + # # scopes: ["repo", "read:org"] # ===== TOOL AGGREGATION ===== aggregation: diff --git a/examples/vmcp-config.yaml b/examples/vmcp-config.yaml index 2e73001c1..a31d2c8cc 100644 --- a/examples/vmcp-config.yaml +++ b/examples/vmcp-config.yaml @@ -13,6 +13,7 @@ incoming_auth: client_id: "vmcp-client" client_secret_env: "VMCP_CLIENT_SECRET" # Read from environment variable audience: "vmcp" # Token must have aud=vmcp + resource: "http://localhost:4483/mcp" scopes: ["openid", "profile", "email"] # Optional: Authorization policies (Cedar) @@ -33,42 +34,55 @@ outgoing_auth: # Default behavior for backends without explicit config default: - type: pass_through # pass_through | error + type: unauthenticated # unauthenticated | header_injection + # TODO: Uncomment when pass_through is implemented + # type: pass_through # Forward client token unchanged # Per-backend authentication configurations # IMPORTANT: These tokens are for backend APIs (e.g., github-api, jira-api), # NOT for authenticating Virtual MCP to backend MCP servers. # Backend MCP servers receive properly scoped tokens and use them to call upstream APIs. backends: + # Example: API key authentication using header_injection github: - type: token_exchange - token_exchange: - # RFC 8693 token exchange for GitHub API access - token_url: "https://keycloak.example.com/realms/myrealm/protocol/openid-connect/token" - client_id: "vmcp-github-exchange" - client_secret_env: "GITHUB_EXCHANGE_SECRET" - audience: "github-api" # Token audience for GitHub API - scopes: ["repo", "read:org"] # GitHub API scopes - subject_token_type: "access_token" # access_token | id_token - - jira: - type: token_exchange - token_exchange: - token_url: "https://keycloak.example.com/realms/myrealm/protocol/openid-connect/token" - client_id: "vmcp-jira-exchange" - client_secret_env: "JIRA_EXCHANGE_SECRET" - audience: "jira-api" # Token audience for Jira API - scopes: ["read:jira-work", "write:jira-work"] - - slack: - type: service_account - service_account: - credentials_env: "SLACK_BOT_TOKEN" + type: header_injection + header_injection: header_name: "Authorization" - header_format: "Bearer {token}" - - internal-db: - type: pass_through # Forward client token unchanged + api_key: "${GITHUB_API_TOKEN}" # Read from environment variable + + # TODO: Uncomment when token_exchange strategy is implemented + # github: + # type: token_exchange + # metadata: + # # RFC 8693 token exchange for GitHub API access + # token_url: "https://keycloak.example.com/realms/myrealm/protocol/openid-connect/token" + # client_id: "vmcp-github-exchange" + # client_secret_env: "GITHUB_EXCHANGE_SECRET" + # audience: "github-api" # Token audience for GitHub API + # scopes: ["repo", "read:org"] # GitHub API scopes + # subject_token_type: "access_token" # access_token | id_token + + # TODO: Uncomment when token_exchange strategy is implemented + # jira: + # type: token_exchange + # metadata: + # token_url: "https://keycloak.example.com/realms/myrealm/protocol/openid-connect/token" + # client_id: "vmcp-jira-exchange" + # client_secret_env: "JIRA_EXCHANGE_SECRET" + # audience: "jira-api" # Token audience for Jira API + # scopes: ["read:jira-work", "write:jira-work"] + + # TODO: Uncomment when service_account strategy is implemented + # slack: + # type: service_account + # metadata: + # credentials_env: "SLACK_BOT_TOKEN" + # header_name: "Authorization" + # header_format: "Bearer {token}" + + # TODO: Uncomment when pass_through strategy is implemented + # internal-db: + # type: pass_through # Forward client token unchanged # ===== TOKEN CACHING ===== token_cache: From 904739c9bf8c3d4401bfa1223957e01c961360be Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Tue, 4 Nov 2025 21:31:02 +0000 Subject: [PATCH 11/13] Rename api_key to header_value in HeaderInjectionStrategy Rename the misleading 'api_key' field to 'header_value' to better reflect that this strategy can inject any HTTP header value, not just API keys. This improves semantic clarity and matches the general- purpose nature of the strategy. --- cmd/vmcp/example-config.yaml | 2 +- examples/vmcp-config.yaml | 2 +- pkg/vmcp/auth/strategies/header_injection.go | 36 ++--- .../auth/strategies/header_injection_test.go | 146 +++++++++--------- pkg/vmcp/client/client_test.go | 2 +- pkg/vmcp/config/validator.go | 5 +- pkg/vmcp/config/validator_test.go | 4 +- 7 files changed, 100 insertions(+), 97 deletions(-) diff --git a/cmd/vmcp/example-config.yaml b/cmd/vmcp/example-config.yaml index bed36e65f..4ce9848f1 100644 --- a/cmd/vmcp/example-config.yaml +++ b/cmd/vmcp/example-config.yaml @@ -51,7 +51,7 @@ outgoing_auth: # type: header_injection # header_injection: # header_name: "Authorization" - # api_key: "${GITHUB_API_TOKEN}" + # header_value: "${GITHUB_API_TOKEN}" # # # TODO: Uncomment when token_exchange is implemented # # jira: diff --git a/examples/vmcp-config.yaml b/examples/vmcp-config.yaml index a31d2c8cc..849c49d43 100644 --- a/examples/vmcp-config.yaml +++ b/examples/vmcp-config.yaml @@ -48,7 +48,7 @@ outgoing_auth: type: header_injection header_injection: header_name: "Authorization" - api_key: "${GITHUB_API_TOKEN}" # Read from environment variable + header_value: "${GITHUB_API_TOKEN}" # Read from environment variable # TODO: Uncomment when token_exchange strategy is implemented # github: diff --git a/pkg/vmcp/auth/strategies/header_injection.go b/pkg/vmcp/auth/strategies/header_injection.go index 07fccc084..52779f2b4 100644 --- a/pkg/vmcp/auth/strategies/header_injection.go +++ b/pkg/vmcp/auth/strategies/header_injection.go @@ -18,7 +18,7 @@ import ( // // Required metadata fields: // - header_name: The HTTP header name to use (e.g., "X-API-Key", "Authorization") -// - api_key: The header value to inject (can be an API key, token, or any value) +// - header_value: The header value to inject (can be an API key, token, or any value) // // This strategy is appropriate when: // - The backend requires a static header value for authentication @@ -44,37 +44,37 @@ func (*HeaderInjectionStrategy) Name() string { // Authenticate injects the header value from metadata into the request header. // // This method: -// 1. Validates that header_name and api_key are present in metadata +// 1. Validates that header_name and header_value are present in metadata // 2. Sets the specified header with the provided value // // Parameters: // - ctx: Request context (currently unused, reserved for future secret resolution) // - req: The HTTP request to authenticate -// - metadata: Strategy-specific configuration containing header_name and api_key +// - metadata: Strategy-specific configuration containing header_name and header_value // // Returns an error if: // - header_name is missing or empty -// - api_key is missing or empty +// - header_value is missing or empty func (*HeaderInjectionStrategy) Authenticate(_ context.Context, req *http.Request, metadata map[string]any) error { headerName, ok := metadata["header_name"].(string) if !ok || headerName == "" { return fmt.Errorf("header_name required in metadata") } - apiKey, ok := metadata["api_key"].(string) - if !ok || apiKey == "" { - return fmt.Errorf("api_key required in metadata") + headerValue, ok := metadata["header_value"].(string) + if !ok || headerValue == "" { + return fmt.Errorf("header_value required in metadata") } // TODO: Future enhancement - resolve secret references - // if strings.HasPrefix(apiKey, "${SECRET_REF:") { - // apiKey, err = s.secretResolver.Resolve(ctx, apiKey) + // if strings.HasPrefix(headerValue, "${SECRET_REF:") { + // headerValue, err = s.secretResolver.Resolve(ctx, headerValue) // if err != nil { // return fmt.Errorf("failed to resolve secret reference: %w", err) // } // } - req.Header.Set(headerName, apiKey) + req.Header.Set(headerName, headerValue) return nil } @@ -82,9 +82,9 @@ func (*HeaderInjectionStrategy) Authenticate(_ context.Context, req *http.Reques // // This method verifies that: // - header_name is present and non-empty -// - api_key is present and non-empty +// - header_value is present and non-empty // - header_name is a valid HTTP header name (prevents CRLF injection) -// - api_key is a valid HTTP header value (prevents CRLF injection) +// - header_value is a valid HTTP header value (prevents CRLF injection) // // This validation is typically called during configuration parsing to fail fast // if the strategy is misconfigured. @@ -94,9 +94,9 @@ func (*HeaderInjectionStrategy) Validate(metadata map[string]any) error { return fmt.Errorf("header_name required in metadata") } - apiKey, ok := metadata["api_key"].(string) - if !ok || apiKey == "" { - return fmt.Errorf("api_key required in metadata") + headerValue, ok := metadata["header_value"].(string) + if !ok || headerValue == "" { + return fmt.Errorf("header_value required in metadata") } // Validate header name to prevent injection attacks @@ -104,9 +104,9 @@ func (*HeaderInjectionStrategy) Validate(metadata map[string]any) error { return fmt.Errorf("invalid header_name: %w", err) } - // Validate API key value to prevent injection attacks - if err := validation.ValidateHTTPHeaderValue(apiKey); err != nil { - return fmt.Errorf("invalid api_key: %w", err) + // Validate header value to prevent injection attacks + if err := validation.ValidateHTTPHeaderValue(headerValue); err != nil { + return fmt.Errorf("invalid header_value: %w", err) } return nil diff --git a/pkg/vmcp/auth/strategies/header_injection_test.go b/pkg/vmcp/auth/strategies/header_injection_test.go index 537fd3d86..86e70ef27 100644 --- a/pkg/vmcp/auth/strategies/header_injection_test.go +++ b/pkg/vmcp/auth/strategies/header_injection_test.go @@ -30,8 +30,8 @@ func TestHeaderInjectionStrategy_Authenticate(t *testing.T) { { name: "sets X-API-Key header correctly", metadata: map[string]any{ - "header_name": "X-API-Key", - "api_key": "secret-key-123", + "header_name": "X-API-Key", + "header_value": "secret-key-123", }, expectError: false, checkHeader: func(t *testing.T, req *http.Request) { @@ -42,8 +42,8 @@ func TestHeaderInjectionStrategy_Authenticate(t *testing.T) { { name: "sets Authorization header with API key", metadata: map[string]any{ - "header_name": "Authorization", - "api_key": "ApiKey my-secret-key", + "header_name": "Authorization", + "header_value": "ApiKey my-secret-key", }, expectError: false, checkHeader: func(t *testing.T, req *http.Request) { @@ -54,8 +54,8 @@ func TestHeaderInjectionStrategy_Authenticate(t *testing.T) { { name: "sets custom header name", metadata: map[string]any{ - "header_name": "X-Custom-Auth-Token", - "api_key": "custom-token-value", + "header_name": "X-Custom-Auth-Token", + "header_value": "custom-token-value", }, expectError: false, checkHeader: func(t *testing.T, req *http.Request) { @@ -64,10 +64,10 @@ func TestHeaderInjectionStrategy_Authenticate(t *testing.T) { }, }, { - name: "handles complex API key values", + name: "handles complex header values", metadata: map[string]any{ - "header_name": "X-API-Key", - "api_key": "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.test", + "header_name": "X-API-Key", + "header_value": "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.test", }, expectError: false, checkHeader: func(t *testing.T, req *http.Request) { @@ -77,10 +77,10 @@ func TestHeaderInjectionStrategy_Authenticate(t *testing.T) { }, }, { - name: "handles API key with special characters", + name: "handles header value with special characters", metadata: map[string]any{ - "header_name": "X-API-Key", - "api_key": "key-with-!@#$%^&*()-_=+[]{}|;:,.<>?", + "header_name": "X-API-Key", + "header_value": "key-with-!@#$%^&*()-_=+[]{}|;:,.<>?", }, expectError: false, checkHeader: func(t *testing.T, req *http.Request) { @@ -91,10 +91,10 @@ func TestHeaderInjectionStrategy_Authenticate(t *testing.T) { { name: "ignores additional metadata fields", metadata: map[string]any{ - "header_name": "X-API-Key", - "api_key": "my-key", - "extra_field": "ignored", - "another": 123, + "header_name": "X-API-Key", + "header_value": "my-key", + "extra_field": "ignored", + "another": 123, }, expectError: false, checkHeader: func(t *testing.T, req *http.Request) { @@ -105,7 +105,7 @@ func TestHeaderInjectionStrategy_Authenticate(t *testing.T) { { name: "returns error when header_name is missing", metadata: map[string]any{ - "api_key": "my-key", + "header_value": "my-key", }, expectError: true, errorContains: "header_name required", @@ -113,8 +113,8 @@ func TestHeaderInjectionStrategy_Authenticate(t *testing.T) { { name: "returns error when header_name is empty string", metadata: map[string]any{ - "header_name": "", - "api_key": "my-key", + "header_name": "", + "header_value": "my-key", }, expectError: true, errorContains: "header_name required", @@ -122,37 +122,37 @@ func TestHeaderInjectionStrategy_Authenticate(t *testing.T) { { name: "returns error when header_name is not a string", metadata: map[string]any{ - "header_name": 123, - "api_key": "my-key", + "header_name": 123, + "header_value": "my-key", }, expectError: true, errorContains: "header_name required", }, { - name: "returns error when api_key is missing", + name: "returns error when header_value is missing", metadata: map[string]any{ "header_name": "X-API-Key", }, expectError: true, - errorContains: "api_key required", + errorContains: "header_value required", }, { name: "returns error when api_key is empty string", metadata: map[string]any{ - "header_name": "X-API-Key", - "api_key": "", + "header_name": "X-API-Key", + "header_value": "", }, expectError: true, - errorContains: "api_key required", + errorContains: "header_value required", }, { - name: "returns error when api_key is not a string", + name: "returns error when header_value is not a string", metadata: map[string]any{ - "header_name": "X-API-Key", - "api_key": 123, + "header_name": "X-API-Key", + "header_value": 123, }, expectError: true, - errorContains: "api_key required", + errorContains: "header_value required", }, { name: "returns error when metadata is nil", @@ -177,8 +177,8 @@ func TestHeaderInjectionStrategy_Authenticate(t *testing.T) { { name: "overwrites existing header value", metadata: map[string]any{ - "header_name": "X-API-Key", - "api_key": "new-key", + "header_name": "X-API-Key", + "header_value": "new-key", }, expectError: false, checkHeader: func(t *testing.T, req *http.Request) { @@ -188,10 +188,10 @@ func TestHeaderInjectionStrategy_Authenticate(t *testing.T) { }, }, { - name: "handles very long API keys", + name: "handles very long header values", metadata: map[string]any{ - "header_name": "X-API-Key", - "api_key": string(make([]byte, 10000)) + "very-long-key", + "header_name": "X-API-Key", + "header_value": string(make([]byte, 10000)) + "very-long-key", }, expectError: false, checkHeader: func(t *testing.T, req *http.Request) { @@ -203,8 +203,8 @@ func TestHeaderInjectionStrategy_Authenticate(t *testing.T) { { name: "handles case-sensitive header names", metadata: map[string]any{ - "header_name": "x-api-key", // lowercase - "api_key": "my-key", + "header_name": "x-api-key", // lowercase + "header_value": "my-key", }, expectError: false, checkHeader: func(t *testing.T, req *http.Request) { @@ -257,33 +257,33 @@ func TestHeaderInjectionStrategy_Validate(t *testing.T) { { name: "valid metadata with all required fields", metadata: map[string]any{ - "header_name": "X-API-Key", - "api_key": "secret-key", + "header_name": "X-API-Key", + "header_value": "secret-key", }, expectError: false, }, { name: "valid with extra metadata fields", metadata: map[string]any{ - "header_name": "X-API-Key", - "api_key": "secret-key", - "extra": "ignored", - "count": 123, + "header_name": "X-API-Key", + "header_value": "secret-key", + "extra": "ignored", + "count": 123, }, expectError: false, }, { name: "valid with different header name", metadata: map[string]any{ - "header_name": "Authorization", - "api_key": "Bearer token", + "header_name": "Authorization", + "header_value": "Bearer token", }, expectError: false, }, { name: "returns error when header_name is missing", metadata: map[string]any{ - "api_key": "secret-key", + "header_value": "secret-key", }, expectError: true, errorContains: "header_name required", @@ -291,8 +291,8 @@ func TestHeaderInjectionStrategy_Validate(t *testing.T) { { name: "returns error when header_name is empty", metadata: map[string]any{ - "header_name": "", - "api_key": "secret-key", + "header_name": "", + "header_value": "secret-key", }, expectError: true, errorContains: "header_name required", @@ -300,8 +300,8 @@ func TestHeaderInjectionStrategy_Validate(t *testing.T) { { name: "returns error when header_name is not a string", metadata: map[string]any{ - "header_name": 123, - "api_key": "secret-key", + "header_name": 123, + "header_value": "secret-key", }, expectError: true, errorContains: "header_name required", @@ -309,46 +309,46 @@ func TestHeaderInjectionStrategy_Validate(t *testing.T) { { name: "returns error when header_name is a boolean", metadata: map[string]any{ - "header_name": true, - "api_key": "secret-key", + "header_name": true, + "header_value": "secret-key", }, expectError: true, errorContains: "header_name required", }, { - name: "returns error when api_key is missing", + name: "returns error when header_value is missing", metadata: map[string]any{ "header_name": "X-API-Key", }, expectError: true, - errorContains: "api_key required", + errorContains: "header_value required", }, { - name: "returns error when api_key is empty", + name: "returns error when header_value is empty", metadata: map[string]any{ - "header_name": "X-API-Key", - "api_key": "", + "header_name": "X-API-Key", + "header_value": "", }, expectError: true, - errorContains: "api_key required", + errorContains: "header_value required", }, { - name: "returns error when api_key is not a string", + name: "returns error when header_value is not a string", metadata: map[string]any{ - "header_name": "X-API-Key", - "api_key": 123, + "header_name": "X-API-Key", + "header_value": 123, }, expectError: true, - errorContains: "api_key required", + errorContains: "header_value required", }, { - name: "returns error when api_key is a map", + name: "returns error when header_value is a map", metadata: map[string]any{ - "header_name": "X-API-Key", - "api_key": map[string]any{"nested": "value"}, + "header_name": "X-API-Key", + "header_value": map[string]any{"nested": "value"}, }, expectError: true, - errorContains: "api_key required", + errorContains: "header_value required", }, { name: "returns error when metadata is nil", @@ -365,8 +365,8 @@ func TestHeaderInjectionStrategy_Validate(t *testing.T) { { name: "returns error when both fields are wrong type", metadata: map[string]any{ - "header_name": 123, - "api_key": false, + "header_name": 123, + "header_value": false, }, expectError: true, errorContains: "header_name required", @@ -374,17 +374,17 @@ func TestHeaderInjectionStrategy_Validate(t *testing.T) { { name: "returns error for whitespace in header_name", metadata: map[string]any{ - "header_name": "X-Custom Header", - "api_key": "key", + "header_name": "X-Custom Header", + "header_value": "key", }, expectError: true, errorContains: "invalid header_name", }, { - name: "accepts unicode in api_key", + name: "accepts unicode in header_value", metadata: map[string]any{ - "header_name": "X-API-Key", - "api_key": "key-with-unicode-日本語", + "header_name": "X-API-Key", + "header_value": "key-with-unicode-日本語", }, expectError: false, }, diff --git a/pkg/vmcp/client/client_test.go b/pkg/vmcp/client/client_test.go index 4e1c38837..7b609464f 100644 --- a/pkg/vmcp/client/client_test.go +++ b/pkg/vmcp/client/client_test.go @@ -538,7 +538,7 @@ func TestResolveAuthStrategy(t *testing.T) { target: &vmcp.BackendTarget{ WorkloadID: "backend-1", AuthStrategy: "header_injection", - AuthMetadata: map[string]any{"header_name": "X-API-Key", "api_key": "test-key"}, + AuthMetadata: map[string]any{"header_name": "X-API-Key", "header_value": "test-key"}, }, setupRegistry: func() auth.OutgoingAuthRegistry { registry := auth.NewDefaultOutgoingAuthRegistry() diff --git a/pkg/vmcp/config/validator.go b/pkg/vmcp/config/validator.go index a7beb0384..c448c85e1 100644 --- a/pkg/vmcp/config/validator.go +++ b/pkg/vmcp/config/validator.go @@ -196,10 +196,13 @@ func (*DefaultValidator) validateBackendAuthStrategy(_ string, strategy *Backend } case "header_injection": - // Header injection requires header name and value/format + // Header injection requires header name and value if _, ok := strategy.Metadata["header_name"]; !ok { return fmt.Errorf("header_injection requires metadata field: header_name") } + if _, ok := strategy.Metadata["header_value"]; !ok { + return fmt.Errorf("header_injection requires metadata field: header_value") + } } return nil diff --git a/pkg/vmcp/config/validator_test.go b/pkg/vmcp/config/validator_test.go index 4eebd929c..ace6ed938 100644 --- a/pkg/vmcp/config/validator_test.go +++ b/pkg/vmcp/config/validator_test.go @@ -204,8 +204,8 @@ func TestValidator_ValidateOutgoingAuth(t *testing.T) { "github": { Type: "header_injection", Metadata: map[string]any{ - "header_name": "Authorization", - "api_key": "secret-token", + "header_name": "Authorization", + "header_value": "secret-token", }, }, }, From dfa7ebe8ba260719fadc1d523378870f7c93c9eb Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Tue, 4 Nov 2025 22:56:19 +0000 Subject: [PATCH 12/13] Fix header_injection metadata parsing in YAML loader Add dedicated rawHeaderInjectionAuth struct and update transformBackendAuthStrategy to properly parse header_injection configuration from YAML files. Previously, the header_injection strategy used a generic metadata field in YAML, but the transform function had no case to handle it, resulting in empty metadata maps. This follows the established pattern used by token_exchange and service_account strategies. The YAML format is now consistent across all strategies: backends: github: type: header_injection header_injection: header_name: Authorization header_value: Bearer xxx --- pkg/vmcp/config/yaml_loader.go | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/pkg/vmcp/config/yaml_loader.go b/pkg/vmcp/config/yaml_loader.go index 222aa1f6c..7535d29e7 100644 --- a/pkg/vmcp/config/yaml_loader.go +++ b/pkg/vmcp/config/yaml_loader.go @@ -78,9 +78,15 @@ type rawOutgoingAuth struct { } type rawBackendAuthStrategy struct { - Type string `yaml:"type"` - TokenExchange *rawTokenExchangeAuth `yaml:"token_exchange"` - ServiceAccount *rawServiceAccountAuth `yaml:"service_account"` + Type string `yaml:"type"` + HeaderInjection *rawHeaderInjectionAuth `yaml:"header_injection"` + TokenExchange *rawTokenExchangeAuth `yaml:"token_exchange"` + ServiceAccount *rawServiceAccountAuth `yaml:"service_account"` +} + +type rawHeaderInjectionAuth struct { + HeaderName string `yaml:"header_name"` + HeaderValue string `yaml:"header_value"` } type rawTokenExchangeAuth struct { @@ -304,6 +310,19 @@ func (*YAMLLoader) transformBackendAuthStrategy(raw *rawBackendAuthStrategy) (*B } switch raw.Type { + case "header_injection": + if raw.HeaderInjection == nil { + return nil, fmt.Errorf("header_injection configuration is required") + } + + strategy.Metadata = map[string]any{ + "header_name": raw.HeaderInjection.HeaderName, + "header_value": raw.HeaderInjection.HeaderValue, + } + + case "unauthenticated": + // No metadata required for unauthenticated strategy + case "token_exchange": if raw.TokenExchange == nil { return nil, fmt.Errorf("token_exchange configuration is required") From f1dd6a21144495b4ef40da4d655ef8b0b89fd485 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Tue, 4 Nov 2025 23:30:15 +0000 Subject: [PATCH 13/13] Fix auth loss when querying backend capabilities QueryCapabilities was manually creating BackendTarget but omitted AuthStrategy and AuthMetadata fields, causing all backends to fall back to unauthenticated strategy during capability queries. Replace manual struct creation with BackendToTarget() helper to ensure all fields (including auth) are properly copied from Backend to BackendTarget. This bug prevented per-backend authentication from working during the initial capability discovery phase, even though auth was correctly configured by the discoverer. --- pkg/vmcp/aggregator/default_aggregator.go | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/pkg/vmcp/aggregator/default_aggregator.go b/pkg/vmcp/aggregator/default_aggregator.go index 069ca2727..202ff8752 100644 --- a/pkg/vmcp/aggregator/default_aggregator.go +++ b/pkg/vmcp/aggregator/default_aggregator.go @@ -49,14 +49,8 @@ func (a *defaultAggregator) QueryCapabilities(ctx context.Context, backend vmcp. logger.Debugf("Querying capabilities from backend %s", backend.ID) // Create a BackendTarget from the Backend - target := &vmcp.BackendTarget{ - WorkloadID: backend.ID, - WorkloadName: backend.Name, - BaseURL: backend.BaseURL, - TransportType: backend.TransportType, - HealthStatus: backend.HealthStatus, - Metadata: backend.Metadata, - } + // Use BackendToTarget helper to ensure all fields (including auth) are copied + target := vmcp.BackendToTarget(&backend) // Query capabilities using the backend client capabilities, err := a.backendClient.ListCapabilities(ctx, target)