@@ -132,25 +132,40 @@ func testCaseWithContext(t *testing.T, httpCtx *httpContext, test func(c *httpCo
132132 test (httpCtx )
133133}
134134
135- func NewOidcTestServer (t * testing.T ) (privateKey * rsa.PrivateKey , oidcProvider * oidc.Provider , httpServer * httptest.Server ) {
135+ type OidcTestServer struct {
136+ * rsa.PrivateKey
137+ * oidc.Provider
138+ * httptest.Server
139+ TokenEndpointHandler http.HandlerFunc
140+ }
141+
142+ func NewOidcTestServer (t * testing.T ) (oidcTestServer * OidcTestServer ) {
136143 t .Helper ()
137- privateKey , err := rsa .GenerateKey (rand .Reader , 2048 )
144+ var err error
145+ oidcTestServer = & OidcTestServer {}
146+ oidcTestServer .PrivateKey , err = rsa .GenerateKey (rand .Reader , 2048 )
138147 if err != nil {
139148 t .Fatalf ("failed to generate private key for oidc: %v" , err )
140149 }
141150 oidcServer := & oidctest.Server {
142151 Algorithms : []string {oidc .RS256 , oidc .ES256 },
143152 PublicKeys : []oidctest.PublicKey {
144153 {
145- PublicKey : privateKey .Public (),
154+ PublicKey : oidcTestServer .Public (),
146155 KeyID : "test-oidc-key-id" ,
147156 Algorithm : oidc .RS256 ,
148157 },
149158 },
150159 }
151- httpServer = httptest .NewServer (oidcServer )
152- oidcServer .SetIssuer (httpServer .URL )
153- oidcProvider , err = oidc .NewProvider (t .Context (), httpServer .URL )
160+ oidcTestServer .Server = httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
161+ if r .URL .Path == "/token" && oidcTestServer .TokenEndpointHandler != nil {
162+ oidcTestServer .TokenEndpointHandler .ServeHTTP (w , r )
163+ return
164+ }
165+ oidcServer .ServeHTTP (w , r )
166+ }))
167+ oidcServer .SetIssuer (oidcTestServer .URL )
168+ oidcTestServer .Provider , err = oidc .NewProvider (t .Context (), oidcTestServer .URL )
154169 if err != nil {
155170 t .Fatalf ("failed to create OIDC provider: %v" , err )
156171 }
@@ -520,9 +535,9 @@ func TestAuthorizationUnauthorized(t *testing.T) {
520535 })
521536 })
522537 // Failed OIDC validation
523- key , oidcProvider , httpServer := NewOidcTestServer (t )
524- t .Cleanup (httpServer .Close )
525- testCaseWithContext (t , & httpContext {StaticConfig : & config.StaticConfig {RequireOAuth : true , OAuthAudience : "mcp-server" , ValidateToken : true }, OidcProvider : oidcProvider }, func (ctx * httpContext ) {
538+ oidcTestServer := NewOidcTestServer (t )
539+ t .Cleanup (oidcTestServer .Close )
540+ testCaseWithContext (t , & httpContext {StaticConfig : & config.StaticConfig {RequireOAuth : true , OAuthAudience : "mcp-server" , ValidateToken : true }, OidcProvider : oidcTestServer . Provider }, func (ctx * httpContext ) {
526541 req , err := http .NewRequest ("GET" , fmt .Sprintf ("http://%s/mcp" , ctx .HttpAddress ), nil )
527542 if err != nil {
528543 t .Fatalf ("Failed to create request: %v" , err )
@@ -554,12 +569,12 @@ func TestAuthorizationUnauthorized(t *testing.T) {
554569 })
555570 // Failed Kubernetes TokenReview
556571 rawClaims := `{
557- "iss": "` + httpServer .URL + `",
572+ "iss": "` + oidcTestServer .URL + `",
558573 "exp": ` + strconv .FormatInt (time .Now ().Add (time .Hour ).Unix (), 10 ) + `,
559574 "aud": "mcp-server"
560575 }`
561- validOidcToken := oidctest .SignIDToken (key , "test-oidc-key-id" , oidc .RS256 , rawClaims )
562- testCaseWithContext (t , & httpContext {StaticConfig : & config.StaticConfig {RequireOAuth : true , OAuthAudience : "mcp-server" , ValidateToken : true }, OidcProvider : oidcProvider }, func (ctx * httpContext ) {
576+ validOidcToken := oidctest .SignIDToken (oidcTestServer . PrivateKey , "test-oidc-key-id" , oidc .RS256 , rawClaims )
577+ testCaseWithContext (t , & httpContext {StaticConfig : & config.StaticConfig {RequireOAuth : true , OAuthAudience : "mcp-server" , ValidateToken : true }, OidcProvider : oidcTestServer . Provider }, func (ctx * httpContext ) {
563578 req , err := http .NewRequest ("GET" , fmt .Sprintf ("http://%s/mcp" , ctx .HttpAddress ), nil )
564579 if err != nil {
565580 t .Fatalf ("Failed to create request: %v" , err )
@@ -591,7 +606,6 @@ func TestAuthorizationUnauthorized(t *testing.T) {
591606 })
592607}
593608
594- // TestAuthorizationRequireOAuthFalse tests the scenario where OAuth is not required.
595609func TestAuthorizationRequireOAuthFalse (t * testing.T ) {
596610 testCaseWithContext (t , & httpContext {StaticConfig : & config.StaticConfig {RequireOAuth : false }}, func (ctx * httpContext ) {
597611 resp , err := http .Get (fmt .Sprintf ("http://%s/mcp" , ctx .HttpAddress ))
@@ -657,17 +671,17 @@ func TestAuthorizationRawToken(t *testing.T) {
657671}
658672
659673func TestAuthorizationOidcToken (t * testing.T ) {
660- key , oidcProvider , httpServer := NewOidcTestServer (t )
661- t .Cleanup (httpServer .Close )
674+ oidcTestServer := NewOidcTestServer (t )
675+ t .Cleanup (oidcTestServer .Close )
662676 rawClaims := `{
663- "iss": "` + httpServer .URL + `",
677+ "iss": "` + oidcTestServer .URL + `",
664678 "exp": ` + strconv .FormatInt (time .Now ().Add (time .Hour ).Unix (), 10 ) + `,
665679 "aud": "mcp-server"
666680 }`
667- validOidcToken := oidctest .SignIDToken (key , "test-oidc-key-id" , oidc .RS256 , rawClaims )
681+ validOidcToken := oidctest .SignIDToken (oidcTestServer . PrivateKey , "test-oidc-key-id" , oidc .RS256 , rawClaims )
668682 cases := []bool {false , true }
669683 for _ , validateToken := range cases {
670- testCaseWithContext (t , & httpContext {StaticConfig : & config.StaticConfig {RequireOAuth : true , OAuthAudience : "mcp-server" , ValidateToken : validateToken }, OidcProvider : oidcProvider }, func (ctx * httpContext ) {
684+ testCaseWithContext (t , & httpContext {StaticConfig : & config.StaticConfig {RequireOAuth : true , OAuthAudience : "mcp-server" , ValidateToken : validateToken }, OidcProvider : oidcTestServer . Provider }, func (ctx * httpContext ) {
671685 tokenReviewed := false
672686 ctx .mockServer .Handle (http .HandlerFunc (func (w http.ResponseWriter , req * http.Request ) {
673687 if req .URL .EscapedPath () == "/apis/authentication.k8s.io/v1/tokenreviews" {
@@ -701,6 +715,69 @@ func TestAuthorizationOidcToken(t *testing.T) {
701715 }
702716 })
703717 })
718+ }
719+ }
704720
721+ func TestAuthorizationOidcTokenExchange (t * testing.T ) {
722+ oidcTestServer := NewOidcTestServer (t )
723+ t .Cleanup (oidcTestServer .Close )
724+ rawClaims := `{
725+ "iss": "` + oidcTestServer .URL + `",
726+ "exp": ` + strconv .FormatInt (time .Now ().Add (time .Hour ).Unix (), 10 ) + `,
727+ "aud": "%s"
728+ }`
729+ validOidcClientToken := oidctest .SignIDToken (oidcTestServer .PrivateKey , "test-oidc-key-id" , oidc .RS256 ,
730+ fmt .Sprintf (rawClaims , "mcp-server" ))
731+ validOidcBackendToken := oidctest .SignIDToken (oidcTestServer .PrivateKey , "test-oidc-key-id" , oidc .RS256 ,
732+ fmt .Sprintf (rawClaims , "backend-audience" ))
733+ oidcTestServer .TokenEndpointHandler = func (w http.ResponseWriter , r * http.Request ) {
734+ w .Header ().Set ("Content-Type" , "application/json" )
735+ _ , _ = fmt .Fprintf (w , `{"access_token":"%s","token_type":"Bearer","expires_in":253402297199}` , validOidcBackendToken )
736+ }
737+ cases := []bool {false , true }
738+ for _ , validateToken := range cases {
739+ staticConfig := & config.StaticConfig {
740+ RequireOAuth : true ,
741+ OAuthAudience : "mcp-server" ,
742+ ValidateToken : validateToken ,
743+ StsClientId : "test-sts-client-id" ,
744+ StsClientSecret : "test-sts-client-secret" ,
745+ StsAudience : "backend-audience" ,
746+ StsScopes : []string {"backend-scope" },
747+ }
748+ testCaseWithContext (t , & httpContext {StaticConfig : staticConfig , OidcProvider : oidcTestServer .Provider }, func (ctx * httpContext ) {
749+ tokenReviewed := false
750+ ctx .mockServer .Handle (http .HandlerFunc (func (w http.ResponseWriter , req * http.Request ) {
751+ if req .URL .EscapedPath () == "/apis/authentication.k8s.io/v1/tokenreviews" {
752+ w .Header ().Set ("Content-Type" , "application/json" )
753+ _ , _ = w .Write ([]byte (tokenReviewSuccessful ))
754+ tokenReviewed = true
755+ return
756+ }
757+ }))
758+ req , err := http .NewRequest ("GET" , fmt .Sprintf ("http://%s/mcp" , ctx .HttpAddress ), nil )
759+ if err != nil {
760+ t .Fatalf ("Failed to create request: %v" , err )
761+ }
762+ req .Header .Set ("Authorization" , "Bearer " + validOidcClientToken )
763+ resp , err := http .DefaultClient .Do (req )
764+ if err != nil {
765+ t .Fatalf ("Failed to get protected endpoint: %v" , err )
766+ }
767+ t .Cleanup (func () { _ = resp .Body .Close () })
768+ t .Run (fmt .Sprintf ("Protected resource with validate-token='%t' with VALID OIDC EXCHANGE Authorization header returns 200 - OK" , validateToken ), func (t * testing.T ) {
769+ if resp .StatusCode != http .StatusOK {
770+ t .Errorf ("Expected HTTP 200 OK, got %d" , resp .StatusCode )
771+ }
772+ })
773+ t .Run (fmt .Sprintf ("Protected resource with validate-token='%t' with VALID OIDC EXCHANGE Authorization header performs token validation accordingly" , validateToken ), func (t * testing.T ) {
774+ if tokenReviewed == true && ! validateToken {
775+ t .Errorf ("Expected token review to be skipped when validate-token is false, but it was performed" )
776+ }
777+ if tokenReviewed == false && validateToken {
778+ t .Errorf ("Expected token review to be performed when validate-token is true, but it was skipped" )
779+ }
780+ })
781+ })
705782 }
706783}
0 commit comments