@@ -17,6 +17,7 @@ import (
1717
1818 "github.com/stacklok/toolhive/pkg/logger"
1919 "github.com/stacklok/toolhive/pkg/vmcp"
20+ "github.com/stacklok/toolhive/pkg/vmcp/auth"
2021)
2122
2223const (
@@ -44,14 +45,30 @@ type httpBackendClient struct {
4445 // clientFactory creates MCP clients for backends.
4546 // Abstracted as a function to enable testing with mock clients.
4647 clientFactory func (ctx context.Context , target * vmcp.BackendTarget ) (* client.Client , error )
48+
49+ // registry manages authentication strategies for outgoing requests to backend MCP servers.
50+ // Must not be nil - use UnauthenticatedStrategy for no authentication.
51+ registry auth.OutgoingAuthRegistry
4752}
4853
4954// NewHTTPBackendClient creates a new HTTP-based backend client.
5055// This client supports streamable-HTTP and SSE transports.
51- func NewHTTPBackendClient () vmcp.BackendClient {
52- return & httpBackendClient {
53- clientFactory : defaultClientFactory ,
56+ //
57+ // The registry parameter manages authentication strategies for outgoing requests to backend MCP servers.
58+ // It must not be nil. To disable authentication, use a registry configured with the
59+ // "unauthenticated" strategy.
60+ //
61+ // Returns an error if registry is nil.
62+ func NewHTTPBackendClient (registry auth.OutgoingAuthRegistry ) (vmcp.BackendClient , error ) {
63+ if registry == nil {
64+ return nil , fmt .Errorf ("registry cannot be nil; use UnauthenticatedStrategy for no authentication" )
65+ }
66+
67+ c := & httpBackendClient {
68+ registry : registry ,
5469 }
70+ c .clientFactory = c .defaultClientFactory
71+ return c , nil
5572}
5673
5774// roundTripperFunc is a function adapter for http.RoundTripper.
@@ -62,29 +79,103 @@ func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
6279 return f (req )
6380}
6481
82+ // authRoundTripper is an http.RoundTripper that adds authentication to backend requests.
83+ // The authentication strategy and metadata are pre-resolved and validated at client creation time,
84+ // eliminating per-request lookups and validation overhead.
85+ type authRoundTripper struct {
86+ base http.RoundTripper
87+ authStrategy auth.Strategy
88+ authMetadata map [string ]any
89+ target * vmcp.BackendTarget
90+ }
91+
92+ // RoundTrip implements http.RoundTripper by adding authentication headers to requests.
93+ // The authentication strategy was pre-resolved and validated at client creation time,
94+ // so this method simply applies the authentication without any lookups or validation.
95+ func (a * authRoundTripper ) RoundTrip (req * http.Request ) (* http.Response , error ) {
96+ // Clone request to avoid modifying the original
97+ reqClone := req .Clone (req .Context ())
98+
99+ // Apply pre-resolved authentication strategy
100+ if err := a .authStrategy .Authenticate (reqClone .Context (), reqClone , a .authMetadata ); err != nil {
101+ return nil , fmt .Errorf ("authentication failed for backend %s: %w" , a .target .WorkloadID , err )
102+ }
103+
104+ logger .Debugf ("Applied authentication strategy %q to backend %s" , a .authStrategy .Name (), a .target .WorkloadID )
105+
106+ return a .base .RoundTrip (reqClone )
107+ }
108+
109+ // resolveAuthStrategy resolves the authentication strategy for a backend target.
110+ // It handles defaulting to "unauthenticated" when no strategy is specified.
111+ // This method should be called once at client creation time to enable fail-fast
112+ // behavior for invalid authentication configurations.
113+ func (h * httpBackendClient ) resolveAuthStrategy (target * vmcp.BackendTarget ) (auth.Strategy , error ) {
114+ strategyName := target .AuthStrategy
115+
116+ // Default to unauthenticated if not specified
117+ if strategyName == "" {
118+ strategyName = "unauthenticated"
119+ }
120+
121+ // Resolve strategy from registry
122+ strategy , err := h .registry .GetStrategy (strategyName )
123+ if err != nil {
124+ return nil , fmt .Errorf ("authentication strategy %q not found: %w" , strategyName , err )
125+ }
126+
127+ return strategy , nil
128+ }
129+
65130// defaultClientFactory creates mark3labs MCP clients for different transport types.
66- func defaultClientFactory (ctx context.Context , target * vmcp.BackendTarget ) (* client.Client , error ) {
67- // Create HTTP client with response size limits for DoS protection
131+ func (h * httpBackendClient ) defaultClientFactory (ctx context.Context , target * vmcp.BackendTarget ) (* client.Client , error ) {
132+ // Build transport chain: size limit → authentication → HTTP
133+ var baseTransport http.RoundTripper = http .DefaultTransport
134+
135+ // Resolve authentication strategy ONCE at client creation time
136+ authStrategy , err := h .resolveAuthStrategy (target )
137+ if err != nil {
138+ return nil , fmt .Errorf ("failed to resolve authentication for backend %s: %w" ,
139+ target .WorkloadID , err )
140+ }
141+
142+ // Validate metadata ONCE at client creation time
143+ if err := authStrategy .Validate (target .AuthMetadata ); err != nil {
144+ return nil , fmt .Errorf ("invalid authentication configuration for backend %s: %w" ,
145+ target .WorkloadID , err )
146+ }
147+
148+ // Add authentication layer with pre-resolved strategy
149+ baseTransport = & authRoundTripper {
150+ base : baseTransport ,
151+ authStrategy : authStrategy ,
152+ authMetadata : target .AuthMetadata ,
153+ target : target ,
154+ }
155+
156+ // Add size limit layer for DoS protection
157+ sizeLimitedTransport := roundTripperFunc (func (req * http.Request ) (* http.Response , error ) {
158+ resp , err := baseTransport .RoundTrip (req )
159+ if err != nil {
160+ return nil , err
161+ }
162+ // Wrap response body with size limit
163+ resp .Body = struct {
164+ io.Reader
165+ io.Closer
166+ }{
167+ Reader : io .LimitReader (resp .Body , maxResponseSize ),
168+ Closer : resp .Body ,
169+ }
170+ return resp , nil
171+ })
172+
173+ // Create HTTP client with configured transport chain
68174 httpClient := & http.Client {
69- Transport : roundTripperFunc (func (req * http.Request ) (* http.Response , error ) {
70- resp , err := http .DefaultTransport .RoundTrip (req )
71- if err != nil {
72- return nil , err
73- }
74- // Wrap response body with size limit
75- resp .Body = struct {
76- io.Reader
77- io.Closer
78- }{
79- Reader : io .LimitReader (resp .Body , maxResponseSize ),
80- Closer : resp .Body ,
81- }
82- return resp , nil
83- }),
175+ Transport : sizeLimitedTransport ,
84176 }
85177
86178 var c * client.Client
87- var err error
88179
89180 switch target .TransportType {
90181 case "streamable-http" , "streamable" :
@@ -93,8 +184,6 @@ func defaultClientFactory(ctx context.Context, target *vmcp.BackendTarget) (*cli
93184 transport .WithHTTPTimeout (0 ),
94185 transport .WithContinuousListening (),
95186 transport .WithHTTPBasicClient (httpClient ),
96- // TODO: Add authentication header injection via WithHTTPHeaderFunc
97- // This will be implemented when we add OutgoingAuthenticator support
98187 )
99188 if err != nil {
100189 return nil , fmt .Errorf ("failed to create streamable-http client: %w" , err )
0 commit comments