Skip to content

Commit dbeef08

Browse files
JAORMXclaude
andauthored
Move remote MCP authentication handler to pkg/auth/remote (#2464)
Previously, RemoteAuthHandler and RemoteAuthConfig lived in pkg/runner/, mixing authentication implementation with runner execution logic. This change improves separation of concerns by moving auth logic to the auth package while keeping configuration properly organized. Changes: - Created pkg/auth/remote/ package for remote MCP authentication - Moved RemoteAuthHandler → authremote.Handler - Moved RemoteAuthConfig → authremote.Config - Updated all references in pkg/runner/, pkg/api/, and cmd/thv/app/ - Removed old remote_auth.go and remote_auth_test.go from pkg/runner/ Benefits: - Clear package boundaries (runner executes, auth authenticates) - Authentication logic can be tested independently of runner - Follows existing pattern where RunConfig uses types from domain packages (similar to how it uses registry.Header, transport/types.MiddlewareConfig) - Easier to reuse remote auth in other contexts 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-authored-by: Claude <noreply@anthropic.com>
1 parent 03302b2 commit dbeef08

File tree

14 files changed

+191
-183
lines changed

14 files changed

+191
-183
lines changed

cmd/thv/app/run_flags.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99

1010
"github.com/stacklok/toolhive/pkg/auth"
1111
authoauth "github.com/stacklok/toolhive/pkg/auth/oauth"
12+
"github.com/stacklok/toolhive/pkg/auth/remote"
1213
"github.com/stacklok/toolhive/pkg/authz"
1314
"github.com/stacklok/toolhive/pkg/cli"
1415
cfg "github.com/stacklok/toolhive/pkg/config"
@@ -615,7 +616,7 @@ func extractTelemetryValues(config *telemetry.Config) (string, float64, []string
615616
func getRemoteAuthFromRemoteServerMetadata(
616617
remoteServerMetadata *registry.RemoteServerMetadata,
617618
runFlags *RunFlags,
618-
) (*runner.RemoteAuthConfig, error) {
619+
) (*remote.Config, error) {
619620
if remoteServerMetadata == nil || remoteServerMetadata.OAuthConfig == nil {
620621
return getRemoteAuthFromRunFlags(runFlags)
621622
}
@@ -651,7 +652,7 @@ func getRemoteAuthFromRemoteServerMetadata(
651652
}
652653
}
653654

654-
authCfg := &runner.RemoteAuthConfig{
655+
authCfg := &remote.Config{
655656
ClientID: f.RemoteAuthClientID,
656657
ClientSecret: clientSecret,
657658
SkipBrowser: f.RemoteAuthSkipBrowser,
@@ -692,7 +693,7 @@ func getRemoteAuthFromRemoteServerMetadata(
692693
}
693694

694695
// getRemoteAuthFromRunFlags creates RemoteAuthConfig from RunFlags
695-
func getRemoteAuthFromRunFlags(runFlags *RunFlags) (*runner.RemoteAuthConfig, error) {
696+
func getRemoteAuthFromRunFlags(runFlags *RunFlags) (*remote.Config, error) {
696697
// Resolve OAuth client secret from multiple sources (flag, file, environment variable)
697698
// This follows the same priority as resolveSecret: flag → file → environment variable
698699
resolvedClientSecret, err := resolveSecret(
@@ -714,7 +715,7 @@ func getRemoteAuthFromRunFlags(runFlags *RunFlags) (*runner.RemoteAuthConfig, er
714715
}
715716
}
716717

717-
return &runner.RemoteAuthConfig{
718+
return &remote.Config{
718719
ClientID: runFlags.RemoteAuthFlags.RemoteAuthClientID,
719720
ClientSecret: clientSecret,
720721
Scopes: runFlags.RemoteAuthFlags.RemoteAuthScopes,

docs/server/docs.go

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/server/swagger.json

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/server/swagger.yaml

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pkg/api/v1/workload_service.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"fmt"
66
"time"
77

8+
"github.com/stacklok/toolhive/pkg/auth/remote"
89
"github.com/stacklok/toolhive/pkg/container/runtime"
910
"github.com/stacklok/toolhive/pkg/groups"
1011
"github.com/stacklok/toolhive/pkg/logger"
@@ -116,7 +117,7 @@ func (s *WorkloadService) BuildFullRunConfig(ctx context.Context, req *createReq
116117
return nil, fmt.Errorf("group '%s' does not exist", groupName)
117118
}
118119

119-
var remoteAuthConfig *runner.RemoteAuthConfig
120+
var remoteAuthConfig *remote.Config
120121
var imageURL string
121122
var imageMetadata *registry.ImageMetadata
122123
var serverMetadata registry.ServerMetadata
@@ -151,7 +152,7 @@ func (s *WorkloadService) BuildFullRunConfig(ctx context.Context, req *createReq
151152

152153
if remoteServerMetadata, ok := serverMetadata.(*registry.RemoteServerMetadata); ok {
153154
if remoteServerMetadata.OAuthConfig != nil {
154-
remoteAuthConfig = &runner.RemoteAuthConfig{
155+
remoteAuthConfig = &remote.Config{
155156
ClientID: req.OAuthConfig.ClientID,
156157
Scopes: remoteServerMetadata.OAuthConfig.Scopes,
157158
CallbackPort: remoteServerMetadata.OAuthConfig.CallbackPort,
@@ -251,10 +252,10 @@ func (s *WorkloadService) BuildFullRunConfig(ctx context.Context, req *createReq
251252
func createRequestToRemoteAuthConfig(
252253
_ context.Context,
253254
req *createRequest,
254-
) *runner.RemoteAuthConfig {
255+
) *remote.Config {
255256

256257
// Create RemoteAuthConfig
257-
remoteAuthConfig := &runner.RemoteAuthConfig{
258+
remoteAuthConfig := &remote.Config{
258259
ClientID: req.OAuthConfig.ClientID,
259260
Scopes: req.OAuthConfig.Scopes,
260261
Issuer: req.OAuthConfig.Issuer,

pkg/api/v1/workloads_types_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"github.com/stretchr/testify/require"
88

99
"github.com/stacklok/toolhive/pkg/auth"
10+
"github.com/stacklok/toolhive/pkg/auth/remote"
1011
"github.com/stacklok/toolhive/pkg/permissions"
1112
"github.com/stacklok/toolhive/pkg/runner"
1213
"github.com/stacklok/toolhive/pkg/transport/types"
@@ -152,7 +153,7 @@ func TestRunConfigToCreateRequest(t *testing.T) {
152153

153154
runConfig := &runner.RunConfig{
154155
Name: "test-workload",
155-
RemoteAuthConfig: &runner.RemoteAuthConfig{
156+
RemoteAuthConfig: &remote.Config{
156157
Issuer: "https://oauth.example.com",
157158
AuthorizeURL: "https://oauth.example.com/auth",
158159
TokenURL: "https://oauth.example.com/token",
@@ -189,7 +190,7 @@ func TestRunConfigToCreateRequest(t *testing.T) {
189190

190191
runConfig := &runner.RunConfig{
191192
Name: "test-workload",
192-
RemoteAuthConfig: &runner.RemoteAuthConfig{
193+
RemoteAuthConfig: &remote.Config{
193194
Issuer: "https://oauth.example.com",
194195
AuthorizeURL: "https://oauth.example.com/auth",
195196
TokenURL: "https://oauth.example.com/token",

pkg/auth/remote/config.go

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
package remote
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"time"
7+
8+
"github.com/stacklok/toolhive/pkg/registry"
9+
)
10+
11+
// Config holds authentication configuration for remote MCP servers.
12+
// Supports OAuth/OIDC-based authentication with automatic discovery.
13+
type Config struct {
14+
ClientID string `json:"client_id,omitempty" yaml:"client_id,omitempty"`
15+
ClientSecret string `json:"client_secret,omitempty" yaml:"client_secret,omitempty"`
16+
ClientSecretFile string `json:"client_secret_file,omitempty" yaml:"client_secret_file,omitempty"`
17+
Scopes []string `json:"scopes,omitempty" yaml:"scopes,omitempty"`
18+
SkipBrowser bool `json:"skip_browser,omitempty" yaml:"skip_browser,omitempty"`
19+
Timeout time.Duration `json:"timeout,omitempty" yaml:"timeout,omitempty" swaggertype:"string" example:"5m"`
20+
CallbackPort int `json:"callback_port,omitempty" yaml:"callback_port,omitempty"`
21+
UsePKCE bool `json:"use_pkce" yaml:"use_pkce"`
22+
23+
// OAuth endpoint configuration (from registry)
24+
Issuer string `json:"issuer,omitempty" yaml:"issuer,omitempty"`
25+
AuthorizeURL string `json:"authorize_url,omitempty" yaml:"authorize_url,omitempty"`
26+
TokenURL string `json:"token_url,omitempty" yaml:"token_url,omitempty"`
27+
28+
// Headers for HTTP requests
29+
Headers []*registry.Header `json:"headers,omitempty" yaml:"headers,omitempty"`
30+
31+
// Environment variables for the client
32+
EnvVars []*registry.EnvVar `json:"env_vars,omitempty" yaml:"env_vars,omitempty"`
33+
34+
// OAuth parameters for server-specific customization
35+
OAuthParams map[string]string `json:"oauth_params,omitempty" yaml:"oauth_params,omitempty"`
36+
}
37+
38+
// UnmarshalJSON implements custom JSON unmarshaling for backward compatibility
39+
// This handles both the old PascalCase format and the new snake_case format
40+
func (r *Config) UnmarshalJSON(data []byte) error {
41+
// Parse the JSON to check which format is being used
42+
var raw map[string]interface{}
43+
if err := json.Unmarshal(data, &raw); err != nil {
44+
return err
45+
}
46+
47+
// Check if this is the old PascalCase format by looking for old field name
48+
// if one old field is present, then it's the old format
49+
if _, isOld := raw["ClientID"]; isOld {
50+
// Unmarshal using old PascalCase format
51+
var oldFormat struct {
52+
ClientID string `json:"ClientID,omitempty"`
53+
ClientSecret string `json:"ClientSecret,omitempty"`
54+
ClientSecretFile string `json:"ClientSecretFile,omitempty"`
55+
Scopes []string `json:"Scopes,omitempty"`
56+
SkipBrowser bool `json:"SkipBrowser,omitempty"`
57+
Timeout time.Duration `json:"Timeout,omitempty"`
58+
CallbackPort int `json:"CallbackPort,omitempty"`
59+
UsePKCE bool `json:"UsePKCE,omitempty"`
60+
Issuer string `json:"Issuer,omitempty"`
61+
AuthorizeURL string `json:"AuthorizeURL,omitempty"`
62+
TokenURL string `json:"TokenURL,omitempty"`
63+
Headers []*registry.Header `json:"Headers,omitempty"`
64+
EnvVars []*registry.EnvVar `json:"EnvVars,omitempty"`
65+
OAuthParams map[string]string `json:"OAuthParams,omitempty"`
66+
}
67+
68+
if err := json.Unmarshal(data, &oldFormat); err != nil {
69+
return fmt.Errorf("failed to unmarshal Config in old format: %w", err)
70+
}
71+
72+
// Copy from old format to new format
73+
r.ClientID = oldFormat.ClientID
74+
r.ClientSecret = oldFormat.ClientSecret
75+
r.ClientSecretFile = oldFormat.ClientSecretFile
76+
r.Scopes = oldFormat.Scopes
77+
r.SkipBrowser = oldFormat.SkipBrowser
78+
r.Timeout = oldFormat.Timeout
79+
r.CallbackPort = oldFormat.CallbackPort
80+
r.UsePKCE = oldFormat.UsePKCE
81+
r.Issuer = oldFormat.Issuer
82+
r.AuthorizeURL = oldFormat.AuthorizeURL
83+
r.TokenURL = oldFormat.TokenURL
84+
r.Headers = oldFormat.Headers
85+
r.EnvVars = oldFormat.EnvVars
86+
r.OAuthParams = oldFormat.OAuthParams
87+
return nil
88+
}
89+
90+
// Use the new snake_case format
91+
type Alias Config
92+
alias := (*Alias)(r)
93+
return json.Unmarshal(data, alias)
94+
}
95+
96+
// DefaultCallbackPort is the default port for the OAuth callback server
97+
const DefaultCallbackPort = 8666

pkg/auth/remote/doc.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// Package remote provides authentication handling for remote MCP servers.
2+
//
3+
// This package implements OAuth/OIDC-based authentication with automatic
4+
// discovery support for remote MCP servers. It handles:
5+
// - OAuth issuer discovery (RFC 8414)
6+
// - Protected resource metadata (RFC 9728)
7+
// - OAuth flow execution (PKCE-based)
8+
// - Token source creation for HTTP transports
9+
//
10+
// The main entry point is Handler.Authenticate() which takes a remote URL
11+
// and performs all necessary discovery and authentication steps.
12+
//
13+
// Configuration is defined in pkg/runner.RemoteAuthConfig as part of the
14+
// runner's RunConfig structure.
15+
package remote

pkg/runner/remote_auth.go renamed to pkg/auth/remote/handler.go

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package runner
1+
package remote
22

33
import (
44
"context"
@@ -10,21 +10,21 @@ import (
1010
"github.com/stacklok/toolhive/pkg/logger"
1111
)
1212

13-
// RemoteAuthHandler handles authentication for remote MCP servers.
13+
// Handler handles authentication for remote MCP servers.
1414
// Supports OAuth/OIDC-based authentication with automatic discovery.
15-
type RemoteAuthHandler struct {
16-
config *RemoteAuthConfig
15+
type Handler struct {
16+
config *Config
1717
}
1818

19-
// NewRemoteAuthHandler creates a new remote authentication handler
20-
func NewRemoteAuthHandler(config *RemoteAuthConfig) *RemoteAuthHandler {
21-
return &RemoteAuthHandler{
19+
// NewHandler creates a new remote authentication handler
20+
func NewHandler(config *Config) *Handler {
21+
return &Handler{
2222
config: config,
2323
}
2424
}
2525

2626
// Authenticate is the main entry point for remote MCP server authentication
27-
func (h *RemoteAuthHandler) Authenticate(ctx context.Context, remoteURL string) (oauth2.TokenSource, error) {
27+
func (h *Handler) Authenticate(ctx context.Context, remoteURL string) (oauth2.TokenSource, error) {
2828

2929
// First, try to detect if authentication is required
3030
authInfo, err := discovery.DetectAuthenticationFromServer(ctx, remoteURL, nil)
@@ -89,7 +89,7 @@ func (h *RemoteAuthHandler) Authenticate(ctx context.Context, remoteURL string)
8989
// discoverIssuerAndScopes attempts to discover the OAuth issuer and scopes from various sources
9090
// following RFC 8414 and RFC 9728 standards
9191
// If the issuer is not derived from Realm and Resource Metadata, it derives from the remote URL
92-
func (h *RemoteAuthHandler) discoverIssuerAndScopes(
92+
func (h *Handler) discoverIssuerAndScopes(
9393
ctx context.Context,
9494
authInfo *discovery.AuthInfo,
9595
remoteURL string,
@@ -135,7 +135,7 @@ func (h *RemoteAuthHandler) discoverIssuerAndScopes(
135135
}
136136

137137
// tryDiscoverFromResourceMetadata attempts to discover issuer and scopes from resource metadata
138-
func (h *RemoteAuthHandler) tryDiscoverFromResourceMetadata(
138+
func (h *Handler) tryDiscoverFromResourceMetadata(
139139
ctx context.Context,
140140
resourceMetadataURL string,
141141
) (string, []string, *discovery.AuthServerInfo, error) {
@@ -172,7 +172,7 @@ func (h *RemoteAuthHandler) tryDiscoverFromResourceMetadata(
172172
}
173173

174174
// findValidAuthServer validates authorization servers and returns the first valid one
175-
func (*RemoteAuthHandler) findValidAuthServer(
175+
func (*Handler) findValidAuthServer(
176176
ctx context.Context,
177177
authServers []string,
178178
) (*discovery.AuthServerInfo, string) {
@@ -197,7 +197,7 @@ func (*RemoteAuthHandler) findValidAuthServer(
197197
// tryDiscoverFromWellKnown attempts to discover the actual OAuth issuer
198198
// by probing the server's well-known endpoints without validating issuer match
199199
// This is useful when the issuer differs from the server URL (e.g., Atlassian case)
200-
func (h *RemoteAuthHandler) tryDiscoverFromWellKnown(
200+
func (h *Handler) tryDiscoverFromWellKnown(
201201
ctx context.Context,
202202
remoteURL string,
203203
) (string, []string, *discovery.AuthServerInfo, error) {

0 commit comments

Comments
 (0)