Skip to content

Commit f6b287a

Browse files
committed
Integrate authentication registry into HTTP backend client
Refactors HTTPBackendClient to accept an OutgoingAuthRegistry and apply authentication strategies to all backend requests via a new authRoundTripper middleware. Authentication is now resolved and validated once at client creation time rather than per-request, improving performance and enabling early error detection for misconfigurations. The authRoundTripper clones requests to preserve immutability before applying authentication, ensuring thread-safety and preventing unintended side effects.
1 parent 571bdfc commit f6b287a

File tree

3 files changed

+607
-24
lines changed

3 files changed

+607
-24
lines changed

pkg/vmcp/client/client.go

Lines changed: 112 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717

1818
"github.com/stacklok/toolhive/pkg/logger"
1919
"github.com/stacklok/toolhive/pkg/vmcp"
20+
"github.com/stacklok/toolhive/pkg/vmcp/auth"
2021
)
2122

2223
const (
@@ -44,14 +45,30 @@ type httpBackendClient struct {
4445
// clientFactory creates MCP clients for backends.
4546
// Abstracted as a function to enable testing with mock clients.
4647
clientFactory func(ctx context.Context, target *vmcp.BackendTarget) (*client.Client, error)
48+
49+
// registry manages authentication strategies for outgoing requests to backend MCP servers.
50+
// Must not be nil - use UnauthenticatedStrategy for no authentication.
51+
registry auth.OutgoingAuthRegistry
4752
}
4853

4954
// NewHTTPBackendClient creates a new HTTP-based backend client.
5055
// This client supports streamable-HTTP and SSE transports.
51-
func NewHTTPBackendClient() vmcp.BackendClient {
52-
return &httpBackendClient{
53-
clientFactory: defaultClientFactory,
56+
//
57+
// The registry parameter manages authentication strategies for outgoing requests to backend MCP servers.
58+
// It must not be nil. To disable authentication, use a registry configured with the
59+
// "unauthenticated" strategy.
60+
//
61+
// Returns an error if registry is nil.
62+
func NewHTTPBackendClient(registry auth.OutgoingAuthRegistry) (vmcp.BackendClient, error) {
63+
if registry == nil {
64+
return nil, fmt.Errorf("registry cannot be nil; use UnauthenticatedStrategy for no authentication")
65+
}
66+
67+
c := &httpBackendClient{
68+
registry: registry,
5469
}
70+
c.clientFactory = c.defaultClientFactory
71+
return c, nil
5572
}
5673

5774
// roundTripperFunc is a function adapter for http.RoundTripper.
@@ -62,29 +79,103 @@ func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
6279
return f(req)
6380
}
6481

82+
// authRoundTripper is an http.RoundTripper that adds authentication to backend requests.
83+
// The authentication strategy and metadata are pre-resolved and validated at client creation time,
84+
// eliminating per-request lookups and validation overhead.
85+
type authRoundTripper struct {
86+
base http.RoundTripper
87+
authStrategy auth.Strategy
88+
authMetadata map[string]any
89+
target *vmcp.BackendTarget
90+
}
91+
92+
// RoundTrip implements http.RoundTripper by adding authentication headers to requests.
93+
// The authentication strategy was pre-resolved and validated at client creation time,
94+
// so this method simply applies the authentication without any lookups or validation.
95+
func (a *authRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
96+
// Clone request to avoid modifying the original
97+
reqClone := req.Clone(req.Context())
98+
99+
// Apply pre-resolved authentication strategy
100+
if err := a.authStrategy.Authenticate(reqClone.Context(), reqClone, a.authMetadata); err != nil {
101+
return nil, fmt.Errorf("authentication failed for backend %s: %w", a.target.WorkloadID, err)
102+
}
103+
104+
logger.Debugf("Applied authentication strategy %q to backend %s", a.authStrategy.Name(), a.target.WorkloadID)
105+
106+
return a.base.RoundTrip(reqClone)
107+
}
108+
109+
// resolveAuthStrategy resolves the authentication strategy for a backend target.
110+
// It handles defaulting to "unauthenticated" when no strategy is specified.
111+
// This method should be called once at client creation time to enable fail-fast
112+
// behavior for invalid authentication configurations.
113+
func (h *httpBackendClient) resolveAuthStrategy(target *vmcp.BackendTarget) (auth.Strategy, error) {
114+
strategyName := target.AuthStrategy
115+
116+
// Default to unauthenticated if not specified
117+
if strategyName == "" {
118+
strategyName = "unauthenticated"
119+
}
120+
121+
// Resolve strategy from registry
122+
strategy, err := h.registry.GetStrategy(strategyName)
123+
if err != nil {
124+
return nil, fmt.Errorf("authentication strategy %q not found: %w", strategyName, err)
125+
}
126+
127+
return strategy, nil
128+
}
129+
65130
// defaultClientFactory creates mark3labs MCP clients for different transport types.
66-
func defaultClientFactory(ctx context.Context, target *vmcp.BackendTarget) (*client.Client, error) {
67-
// Create HTTP client with response size limits for DoS protection
131+
func (h *httpBackendClient) defaultClientFactory(ctx context.Context, target *vmcp.BackendTarget) (*client.Client, error) {
132+
// Build transport chain: size limit → authentication → HTTP
133+
var baseTransport http.RoundTripper = http.DefaultTransport
134+
135+
// Resolve authentication strategy ONCE at client creation time
136+
authStrategy, err := h.resolveAuthStrategy(target)
137+
if err != nil {
138+
return nil, fmt.Errorf("failed to resolve authentication for backend %s: %w",
139+
target.WorkloadID, err)
140+
}
141+
142+
// Validate metadata ONCE at client creation time
143+
if err := authStrategy.Validate(target.AuthMetadata); err != nil {
144+
return nil, fmt.Errorf("invalid authentication configuration for backend %s: %w",
145+
target.WorkloadID, err)
146+
}
147+
148+
// Add authentication layer with pre-resolved strategy
149+
baseTransport = &authRoundTripper{
150+
base: baseTransport,
151+
authStrategy: authStrategy,
152+
authMetadata: target.AuthMetadata,
153+
target: target,
154+
}
155+
156+
// Add size limit layer for DoS protection
157+
sizeLimitedTransport := roundTripperFunc(func(req *http.Request) (*http.Response, error) {
158+
resp, err := baseTransport.RoundTrip(req)
159+
if err != nil {
160+
return nil, err
161+
}
162+
// Wrap response body with size limit
163+
resp.Body = struct {
164+
io.Reader
165+
io.Closer
166+
}{
167+
Reader: io.LimitReader(resp.Body, maxResponseSize),
168+
Closer: resp.Body,
169+
}
170+
return resp, nil
171+
})
172+
173+
// Create HTTP client with configured transport chain
68174
httpClient := &http.Client{
69-
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
70-
resp, err := http.DefaultTransport.RoundTrip(req)
71-
if err != nil {
72-
return nil, err
73-
}
74-
// Wrap response body with size limit
75-
resp.Body = struct {
76-
io.Reader
77-
io.Closer
78-
}{
79-
Reader: io.LimitReader(resp.Body, maxResponseSize),
80-
Closer: resp.Body,
81-
}
82-
return resp, nil
83-
}),
175+
Transport: sizeLimitedTransport,
84176
}
85177

86178
var c *client.Client
87-
var err error
88179

89180
switch target.TransportType {
90181
case "streamable-http", "streamable":
@@ -93,8 +184,6 @@ func defaultClientFactory(ctx context.Context, target *vmcp.BackendTarget) (*cli
93184
transport.WithHTTPTimeout(0),
94185
transport.WithContinuousListening(),
95186
transport.WithHTTPBasicClient(httpClient),
96-
// TODO: Add authentication header injection via WithHTTPHeaderFunc
97-
// This will be implemented when we add OutgoingAuthenticator support
98187
)
99188
if err != nil {
100189
return nil, fmt.Errorf("failed to create streamable-http client: %w", err)

0 commit comments

Comments
 (0)