Skip to content

Commit 61d04b4

Browse files
authored
Secure OAuth client secret storage and prevent API exposure (#2204)
* fix: secure OAuth client secret storage and prevent API exposure - Implement secure client secret storage using secret references - Convert plain text CLI secrets to secure secret manager references - Add CLI format string support for secret parameters - Centralize OAuth client secret processing across all entry points - Add comprehensive unit tests for utility functions - Prevent plain text secret exposure in API responses Fixes security vulnerability where OAuth client secrets were stored as plain text and exposed in API responses. * secure OAuth client secret storage and prevent API exposure - Implement secure client secret storage using secret references - Convert plain text CLI secrets to secure secret manager references - Add CLI format string support for secret parameters - Centralize OAuth client secret processing across all entry points - Add comprehensive unit tests for utility functions - Prevent plain text secret exposure in API responses Fixes security vulnerability where OAuth client secrets were stored as plain text and exposed in API responses.
1 parent d4b7b0c commit 61d04b4

File tree

9 files changed

+432
-87
lines changed

9 files changed

+432
-87
lines changed

cmd/thv/app/oauth_secret_test.go

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
package app
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
8+
"github.com/stacklok/toolhive/pkg/secrets"
9+
)
10+
11+
func TestGenerateOAuthClientSecretName(t *testing.T) {
12+
t.Parallel()
13+
14+
testCases := []struct {
15+
name string
16+
workloadName string
17+
expected string
18+
}{
19+
{
20+
name: "normal workload name",
21+
workloadName: "test-workload",
22+
expected: "OAUTH_CLIENT_SECRET_test-workload",
23+
},
24+
{
25+
name: "empty workload name",
26+
workloadName: "",
27+
expected: "OAUTH_CLIENT_SECRET_",
28+
},
29+
{
30+
name: "workload name with special characters",
31+
workloadName: "test-workload-123",
32+
expected: "OAUTH_CLIENT_SECRET_test-workload-123",
33+
},
34+
{
35+
name: "workload name with underscores",
36+
workloadName: "test_workload",
37+
expected: "OAUTH_CLIENT_SECRET_test_workload",
38+
},
39+
}
40+
41+
for _, tc := range testCases {
42+
tc := tc
43+
t.Run(tc.name, func(t *testing.T) {
44+
t.Parallel()
45+
46+
result := generateOAuthClientSecretName(tc.workloadName)
47+
assert.Equal(t, tc.expected, result)
48+
})
49+
}
50+
}
51+
52+
// TestSecretParameterToCLIString tests the ToCLIString method
53+
func TestSecretParameterToCLIString(t *testing.T) {
54+
t.Parallel()
55+
56+
testCases := []struct {
57+
name string
58+
param secrets.SecretParameter
59+
expected string
60+
}{
61+
{
62+
name: "normal secret parameter",
63+
param: secrets.SecretParameter{
64+
Name: "SECRET_NAME",
65+
Target: "oauth_secret",
66+
},
67+
expected: "SECRET_NAME,target=oauth_secret",
68+
},
69+
{
70+
name: "secret parameter with different target",
71+
param: secrets.SecretParameter{
72+
Name: "API_KEY",
73+
Target: "API_KEY",
74+
},
75+
expected: "API_KEY,target=API_KEY",
76+
},
77+
{
78+
name: "secret parameter with special characters",
79+
param: secrets.SecretParameter{
80+
Name: "SECRET-NAME-123",
81+
Target: "SECRET_TARGET",
82+
},
83+
expected: "SECRET-NAME-123,target=SECRET_TARGET",
84+
},
85+
}
86+
87+
for _, tc := range testCases {
88+
tc := tc
89+
t.Run(tc.name, func(t *testing.T) {
90+
t.Parallel()
91+
92+
result := tc.param.ToCLIString()
93+
assert.Equal(t, tc.expected, result)
94+
})
95+
}
96+
}
97+
98+
// TestParseSecretParameter tests the ParseSecretParameter function
99+
func TestParseSecretParameter(t *testing.T) {
100+
t.Parallel()
101+
102+
testCases := []struct {
103+
name string
104+
parameter string
105+
expectedResult secrets.SecretParameter
106+
expectError bool
107+
errorContains string
108+
}{
109+
{
110+
name: "valid CLI format",
111+
parameter: "SECRET_NAME,target=oauth_secret",
112+
expectedResult: secrets.SecretParameter{
113+
Name: "SECRET_NAME",
114+
Target: "oauth_secret",
115+
},
116+
expectError: false,
117+
},
118+
{
119+
name: "valid CLI format with different target",
120+
parameter: "API_KEY,target=API_KEY",
121+
expectedResult: secrets.SecretParameter{
122+
Name: "API_KEY",
123+
Target: "API_KEY",
124+
},
125+
expectError: false,
126+
},
127+
{
128+
name: "empty parameter",
129+
parameter: "",
130+
expectedResult: secrets.SecretParameter{},
131+
expectError: true,
132+
errorContains: "secret parameter cannot be empty",
133+
},
134+
{
135+
name: "invalid format - no target",
136+
parameter: "SECRET_NAME",
137+
expectedResult: secrets.SecretParameter{},
138+
expectError: true,
139+
errorContains: "invalid secret parameter format",
140+
},
141+
{
142+
name: "invalid format - no comma",
143+
parameter: "SECRET_NAME target=oauth_secret",
144+
expectedResult: secrets.SecretParameter{},
145+
expectError: true,
146+
errorContains: "invalid secret parameter format",
147+
},
148+
{
149+
name: "invalid format - no equals",
150+
parameter: "SECRET_NAME,target oauth_secret",
151+
expectedResult: secrets.SecretParameter{},
152+
expectError: true,
153+
errorContains: "invalid secret parameter format",
154+
},
155+
}
156+
157+
for _, tc := range testCases {
158+
tc := tc
159+
t.Run(tc.name, func(t *testing.T) {
160+
t.Parallel()
161+
162+
result, err := secrets.ParseSecretParameter(tc.parameter)
163+
164+
if tc.expectError {
165+
assert.Error(t, err)
166+
if tc.errorContains != "" {
167+
assert.Contains(t, err.Error(), tc.errorContains)
168+
}
169+
assert.Equal(t, secrets.SecretParameter{}, result)
170+
} else {
171+
assert.NoError(t, err)
172+
assert.Equal(t, tc.expectedResult, result)
173+
}
174+
})
175+
}
176+
}

cmd/thv/app/run_flags.go

Lines changed: 135 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"fmt"
66
"strings"
7+
"time"
78

89
"github.com/spf13/cobra"
910

@@ -15,11 +16,13 @@ import (
1516
"github.com/stacklok/toolhive/pkg/container/runtime"
1617
"github.com/stacklok/toolhive/pkg/environment"
1718
"github.com/stacklok/toolhive/pkg/ignore"
19+
"github.com/stacklok/toolhive/pkg/logger"
1820
"github.com/stacklok/toolhive/pkg/networking"
1921
"github.com/stacklok/toolhive/pkg/process"
2022
"github.com/stacklok/toolhive/pkg/registry"
2123
"github.com/stacklok/toolhive/pkg/runner"
2224
"github.com/stacklok/toolhive/pkg/runner/retriever"
25+
"github.com/stacklok/toolhive/pkg/secrets"
2326
"github.com/stacklok/toolhive/pkg/telemetry"
2427
"github.com/stacklok/toolhive/pkg/transport"
2528
"github.com/stacklok/toolhive/pkg/transport/types"
@@ -548,7 +551,10 @@ func configureRemoteAuth(runFlags *RunFlags, serverMetadata registry.ServerMetad
548551
var opts []runner.RunConfigBuilderOption
549552

550553
if remoteServerMetadata, ok := serverMetadata.(*registry.RemoteServerMetadata); ok {
551-
remoteAuthConfig := getRemoteAuthFromRemoteServerMetadata(remoteServerMetadata, runFlags)
554+
remoteAuthConfig, err := getRemoteAuthFromRemoteServerMetadata(remoteServerMetadata, runFlags)
555+
if err != nil {
556+
return nil, err
557+
}
552558

553559
// Validate OAuth callback port availability upfront for better user experience
554560
if err := networking.ValidateCallbackPort(remoteAuthConfig.CallbackPort, remoteAuthConfig.ClientID); err != nil {
@@ -559,7 +565,10 @@ func configureRemoteAuth(runFlags *RunFlags, serverMetadata registry.ServerMetad
559565
}
560566

561567
if runFlags.RemoteURL != "" {
562-
remoteAuthConfig := getRemoteAuthFromRunFlags(runFlags)
568+
remoteAuthConfig, err := getRemoteAuthFromRunFlags(runFlags)
569+
if err != nil {
570+
return nil, err
571+
}
563572

564573
// Validate OAuth callback port availability upfront for better user experience
565574
if err := networking.ValidateCallbackPort(remoteAuthConfig.CallbackPort, remoteAuthConfig.ClientID); err != nil {
@@ -593,7 +602,7 @@ func extractTelemetryValues(config *telemetry.Config) (string, float64, []string
593602
func getRemoteAuthFromRemoteServerMetadata(
594603
remoteServerMetadata *registry.RemoteServerMetadata,
595604
runFlags *RunFlags,
596-
) *runner.RemoteAuthConfig {
605+
) (*runner.RemoteAuthConfig, error) {
597606
if remoteServerMetadata == nil || remoteServerMetadata.OAuthConfig == nil {
598607
return getRemoteAuthFromRunFlags(runFlags)
599608
}
@@ -608,9 +617,26 @@ func getRemoteAuthFromRemoteServerMetadata(
608617
return b
609618
}
610619

620+
// Resolve OAuth client secret from multiple sources (flag, file, environment variable)
621+
// This follows the same priority as resolveSecret: flag → file → environment variable
622+
resolvedClientSecret, err := resolveSecret(
623+
f.RemoteAuthClientSecret,
624+
f.RemoteAuthClientSecretFile,
625+
"", // No specific environment variable for OAuth client secret
626+
)
627+
if err != nil {
628+
return nil, fmt.Errorf("failed to resolve OAuth client secret: %w", err)
629+
}
630+
631+
// Process the resolved client secret (convert plain text to secret reference if needed)
632+
clientSecret, err := processOAuthClientSecret(resolvedClientSecret, runFlags.Name)
633+
if err != nil {
634+
return nil, fmt.Errorf("failed to process OAuth client secret: %w", err)
635+
}
636+
611637
authCfg := &runner.RemoteAuthConfig{
612638
ClientID: f.RemoteAuthClientID,
613-
ClientSecret: f.RemoteAuthClientSecret,
639+
ClientSecret: clientSecret,
614640
SkipBrowser: f.RemoteAuthSkipBrowser,
615641
Timeout: f.RemoteAuthTimeout,
616642
Headers: remoteServerMetadata.Headers,
@@ -645,14 +671,31 @@ func getRemoteAuthFromRemoteServerMetadata(
645671
authCfg.OAuthParams = oc.OAuthParams
646672
}
647673

648-
return authCfg
674+
return authCfg, nil
649675
}
650676

651677
// getRemoteAuthFromRunFlags creates RemoteAuthConfig from RunFlags
652-
func getRemoteAuthFromRunFlags(runFlags *RunFlags) *runner.RemoteAuthConfig {
678+
func getRemoteAuthFromRunFlags(runFlags *RunFlags) (*runner.RemoteAuthConfig, error) {
679+
// Resolve OAuth client secret from multiple sources (flag, file, environment variable)
680+
// This follows the same priority as resolveSecret: flag → file → environment variable
681+
resolvedClientSecret, err := resolveSecret(
682+
runFlags.RemoteAuthFlags.RemoteAuthClientSecret,
683+
runFlags.RemoteAuthFlags.RemoteAuthClientSecretFile,
684+
"", // No specific environment variable for OAuth client secret
685+
)
686+
if err != nil {
687+
return nil, fmt.Errorf("failed to resolve OAuth client secret: %w", err)
688+
}
689+
690+
// Process the resolved client secret (convert plain text to secret reference if needed)
691+
clientSecret, err := processOAuthClientSecret(resolvedClientSecret, runFlags.Name)
692+
if err != nil {
693+
return nil, fmt.Errorf("failed to process OAuth client secret: %w", err)
694+
}
695+
653696
return &runner.RemoteAuthConfig{
654697
ClientID: runFlags.RemoteAuthFlags.RemoteAuthClientID,
655-
ClientSecret: runFlags.RemoteAuthFlags.RemoteAuthClientSecret,
698+
ClientSecret: clientSecret,
656699
Scopes: runFlags.RemoteAuthFlags.RemoteAuthScopes,
657700
SkipBrowser: runFlags.RemoteAuthFlags.RemoteAuthSkipBrowser,
658701
Timeout: runFlags.RemoteAuthFlags.RemoteAuthTimeout,
@@ -661,7 +704,7 @@ func getRemoteAuthFromRunFlags(runFlags *RunFlags) *runner.RemoteAuthConfig {
661704
AuthorizeURL: runFlags.RemoteAuthFlags.RemoteAuthAuthorizeURL,
662705
TokenURL: runFlags.RemoteAuthFlags.RemoteAuthTokenURL,
663706
OAuthParams: runFlags.OAuthParams,
664-
}
707+
}, nil
665708
}
666709

667710
// getOidcFromFlags extracts OIDC configuration from command flags
@@ -778,3 +821,87 @@ func createTelemetryConfig(otelEndpoint string, otelEnablePrometheusMetricsPath
778821
EnvironmentVariables: processedEnvVars,
779822
}
780823
}
824+
825+
// processOAuthClientSecret processes an OAuth client secret, converting plain text to secret reference if needed
826+
func processOAuthClientSecret(clientSecret, workloadName string) (string, error) {
827+
if clientSecret == "" {
828+
return "", nil
829+
}
830+
831+
// Check if it's already in CLI format (contains ",target=")
832+
if _, err := secrets.ParseSecretParameter(clientSecret); err == nil {
833+
// Already in CLI format, use as-is
834+
return clientSecret, nil
835+
}
836+
837+
// It's plain text, we must convert to secret reference
838+
uniqueSecretName, err := findUniqueSecretName(workloadName)
839+
if err != nil {
840+
logger.Errorf("Failed to find unique secret name: %v", err)
841+
return "", err
842+
}
843+
844+
if err := storeSecretInManager(uniqueSecretName, clientSecret); err != nil {
845+
logger.Errorf("Failed to store OAuth client secret: %v", err)
846+
// This is a critical error - we cannot proceed without storing the secret
847+
return "", err
848+
}
849+
850+
// Return CLI format reference to the stored secret
851+
return secrets.SecretParameter{Name: uniqueSecretName, Target: "oauth_secret"}.ToCLIString(), nil
852+
}
853+
854+
// generateOAuthClientSecretName generates a base secret name for an OAuth client secret
855+
func generateOAuthClientSecretName(workloadName string) string {
856+
return fmt.Sprintf("OAUTH_CLIENT_SECRET_%s", workloadName)
857+
}
858+
859+
// findUniqueSecretName finds a unique secret name, handling conflicts by appending timestamps
860+
func findUniqueSecretName(workloadName string) (string, error) {
861+
baseName := generateOAuthClientSecretName(workloadName)
862+
863+
// Get the secrets manager to check for existing secrets
864+
secretManager, err := getSecretsManager()
865+
if err != nil {
866+
return "", fmt.Errorf("failed to get secrets manager: %w", err)
867+
}
868+
869+
// Check if the base name is available
870+
ctx := context.Background()
871+
_, err = secretManager.GetSecret(ctx, baseName)
872+
if err != nil {
873+
// Secret doesn't exist, we can use the base name
874+
return baseName, nil
875+
}
876+
877+
// Secret exists, generate a unique name with timestamp
878+
timestamp := time.Now().Unix()
879+
uniqueName := fmt.Sprintf("%s-%d", baseName, timestamp)
880+
return uniqueName, nil
881+
}
882+
883+
// storeSecretInManager stores a secret in the configured secret manager
884+
func storeSecretInManager(secretName, secretValue string) error {
885+
// Use existing getSecretsManager function from secret.go
886+
secretManager, err := getSecretsManager()
887+
if err != nil {
888+
return fmt.Errorf("failed to get secrets manager: %w", err)
889+
}
890+
891+
// Check if the provider supports writing secrets
892+
if !secretManager.Capabilities().CanWrite {
893+
configProvider := cfg.NewDefaultProvider()
894+
config := configProvider.GetConfig()
895+
providerType, _ := config.Secrets.GetProviderType()
896+
return fmt.Errorf("secrets provider %s does not support writing secrets (read-only)", providerType)
897+
}
898+
899+
// Store the secret
900+
ctx := context.Background()
901+
if err := secretManager.SetSecret(ctx, secretName, secretValue); err != nil {
902+
return fmt.Errorf("failed to store secret %s: %w", secretName, err)
903+
}
904+
905+
logger.Debugf("Stored secret: %s", secretName)
906+
return nil
907+
}

0 commit comments

Comments
 (0)