Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
## Release (2025-xx-xx)
- `core`: [v0.18.0](core/CHANGELOG.md#v0180)
- **New:** Added duration utils
- **Chore:** Use `jwt-bearer` grant to get a fresh token instead of `refresh_token`
- `stackitmarketplace`: [v1.16.0](services/stackitmarketplace/CHANGELOG.md#v1160)
- **Breaking Change:** Remove unused `ProjectId` model struct
- `iaas`: [v1.1.0](services/iaas/CHANGELOG.md#v110)
Expand Down
1 change: 1 addition & 0 deletions core/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
## v0.18.0
- **New:** Added duration utils
- **Chore:** Use `jwt-bearer` grant to get a fresh token instead of `refresh_token`

## v0.17.3
- **Dependencies:** Bump `github.com/golang-jwt/jwt/v5` from `v5.2.2` to `v5.2.3`
Expand Down
74 changes: 14 additions & 60 deletions core/clients/key_flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,10 @@ type KeyFlowConfig struct {
// TokenResponseBody is the API response
// when requesting a new token
type TokenResponseBody struct {
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
RefreshToken string `json:"refresh_token"`
Scope string `json:"scope"`
TokenType string `json:"token_type"`
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
TokenType string `json:"token_type"`
}

// ServiceAccountKeyResponse is the API response
Expand Down Expand Up @@ -158,9 +157,9 @@ func (c *KeyFlow) Init(cfg *KeyFlowConfig) error {
return nil
}

// SetToken can be used to set an access and refresh token manually in the client.
// SetToken can be used to set an access token manually in the client.
// The other fields in the token field are determined by inspecting the token or setting default values.
func (c *KeyFlow) SetToken(accessToken, refreshToken string) error {
func (c *KeyFlow) SetToken(accessToken string) error {
// We can safely use ParseUnverified because we are not authenticating the user,
// We are parsing the token just to get the expiration time claim
parsedAccessToken, _, err := jwt.NewParser().ParseUnverified(accessToken, &jwt.RegisteredClaims{})
Expand All @@ -174,11 +173,10 @@ func (c *KeyFlow) SetToken(accessToken, refreshToken string) error {

c.tokenMutex.Lock()
c.token = &TokenResponseBody{
AccessToken: accessToken,
ExpiresIn: int(exp.Time.Unix()),
Scope: defaultScope,
RefreshToken: refreshToken,
TokenType: defaultTokenType,
AccessToken: accessToken,
ExpiresIn: int(exp.Time.Unix()),
Scope: defaultScope,
TokenType: defaultTokenType,
}
c.tokenMutex.Unlock()
return nil
Expand All @@ -198,7 +196,7 @@ func (c *KeyFlow) RoundTrip(req *http.Request) (*http.Response, error) {
return c.rt.RoundTrip(req)
}

// GetAccessToken returns a short-lived access token and saves the access and refresh tokens in the token field
// GetAccessToken returns a short-lived access token and saves the access token in the token field
func (c *KeyFlow) GetAccessToken() (string, error) {
if c.rt == nil {
return "", fmt.Errorf("nil http round tripper, please run Init()")
Expand All @@ -219,7 +217,7 @@ func (c *KeyFlow) GetAccessToken() (string, error) {
if !accessTokenExpired {
return accessToken, nil
}
if err = c.recreateAccessToken(); err != nil {
if err = c.createAccessToken(); err != nil {
var oapiErr *oapierror.GenericOpenAPIError
if ok := errors.As(err, &oapiErr); ok {
reg := regexp.MustCompile("Key with kid .*? was not found")
Expand Down Expand Up @@ -269,27 +267,6 @@ func (c *KeyFlow) validate() error {

// Flow auth functions

// recreateAccessToken is used to create a new access token
// when the existing one isn't valid anymore
func (c *KeyFlow) recreateAccessToken() error {
var refreshToken string

c.tokenMutex.RLock()
if c.token != nil {
refreshToken = c.token.RefreshToken
}
c.tokenMutex.RUnlock()

refreshTokenExpired, err := tokenExpired(refreshToken, c.tokenExpirationLeeway)
if err != nil {
return err
}
if !refreshTokenExpired {
return c.createAccessTokenWithRefreshToken()
}
return c.createAccessToken()
}

// createAccessToken creates an access token using self signed JWT
func (c *KeyFlow) createAccessToken() (err error) {
grant := "urn:ietf:params:oauth:grant-type:jwt-bearer"
Expand All @@ -310,26 +287,6 @@ func (c *KeyFlow) createAccessToken() (err error) {
return c.parseTokenResponse(res)
}

// createAccessTokenWithRefreshToken creates an access token using
// an existing pre-validated refresh token
func (c *KeyFlow) createAccessTokenWithRefreshToken() (err error) {
c.tokenMutex.RLock()
refreshToken := c.token.RefreshToken
c.tokenMutex.RUnlock()

res, err := c.requestToken("refresh_token", refreshToken)
if err != nil {
return err
}
defer func() {
tempErr := res.Body.Close()
if tempErr != nil && err == nil {
err = fmt.Errorf("close request access token with refresh token response: %w", tempErr)
}
}()
return c.parseTokenResponse(res)
}

// generateSelfSignedJWT generates JWT token
func (c *KeyFlow) generateSelfSignedJWT() (string, error) {
claims := jwt.MapClaims{
Expand All @@ -353,11 +310,8 @@ func (c *KeyFlow) generateSelfSignedJWT() (string, error) {
func (c *KeyFlow) requestToken(grant, assertion string) (*http.Response, error) {
body := url.Values{}
body.Set("grant_type", grant)
if grant == "refresh_token" {
body.Set("refresh_token", assertion)
} else {
body.Set("assertion", assertion)
}
body.Set("assertion", assertion)

payload := strings.NewReader(body.Encode())
req, err := http.NewRequest(http.MethodPost, c.config.TokenUrl, payload)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion core/clients/key_flow_continuous_refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func (refresher *continuousTokenRefresher) waitUntilTimestamp(timestamp time.Tim
// - (false, nil) if not successful but should be retried.
// - (_, err) if not successful and shouldn't be retried.
func (refresher *continuousTokenRefresher) refreshToken() (bool, error) {
err := refresher.keyFlow.recreateAccessToken()
err := refresher.keyFlow.createAccessToken()
if err == nil {
return true, nil
}
Expand Down
41 changes: 9 additions & 32 deletions core/clients/key_flow_continuous_refresh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,8 @@ func TestContinuousRefreshToken(t *testing.T) {
t.Fatalf("failed to create access token: %v", err)
}

refreshToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
}).SignedString([]byte("test"))
if err != nil {
t.Fatalf("failed to create refresh token: %v", err)
}

numberDoCalls := 0
mockDo := func(_ *http.Request) (resp *http.Response, err error) {
mockDo := func(r *http.Request) (resp *http.Response, err error) {
numberDoCalls++ // count refresh attempts
if tt.doError != nil {
return nil, tt.doError
Expand All @@ -115,8 +108,7 @@ func TestContinuousRefreshToken(t *testing.T) {
t.Fatalf("Do call: failed to create access token: %v", err)
}
responseBodyStruct := TokenResponseBody{
AccessToken: newAccessToken,
RefreshToken: refreshToken,
AccessToken: newAccessToken,
}
responseBody, err := json.Marshal(responseBodyStruct)
if err != nil {
Expand Down Expand Up @@ -153,7 +145,7 @@ func TestContinuousRefreshToken(t *testing.T) {
}

// Set the token after initialization
err = keyFlow.SetToken(accessToken, refreshToken)
err = keyFlow.SetToken(accessToken)
if err != nil {
t.Fatalf("failed to set token: %v", err)
}
Expand Down Expand Up @@ -186,7 +178,7 @@ func TestContinuousRefreshToken(t *testing.T) {
}

// Tests if
// - continuousRefreshToken() updates access token using the refresh token
// - continuousRefreshToken() updates access token
// - The access token can be accessed while continuousRefreshToken() is trying to update it
func TestContinuousRefreshTokenConcurrency(t *testing.T) {
// The times here are in the order of miliseconds (so they run faster)
Expand Down Expand Up @@ -234,14 +226,6 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) {
t.Fatalf("created tokens are equal")
}

// The refresh token used to update the access token
refreshToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
}).SignedString([]byte("test"))
if err != nil {
t.Fatalf("failed to create refresh token: %v", err)
}

ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
defer cancel() // This cancels the refresher goroutine
Expand Down Expand Up @@ -271,8 +255,7 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) {
t.Fatalf("Do call: failed to create additional access token: %v", err)
}
responseBodyStruct := TokenResponseBody{
AccessToken: newAccessToken,
RefreshToken: refreshToken,
AccessToken: newAccessToken,
}
responseBody, err := json.Marshal(responseBodyStruct)
if err != nil {
Expand Down Expand Up @@ -308,18 +291,12 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) {
t.Fatalf("Do call: failed to parse body form: %v", err)
}
reqGrantType := req.Form.Get("grant_type")
if reqGrantType != "refresh_token" {
t.Fatalf("Do call: failed request to refresh token: call to refresh access expected to have grant type as %q, found %q instead", "refresh_token", reqGrantType)
if reqGrantType != "urn:ietf:params:oauth:grant-type:jwt-bearer" {
t.Fatalf("Do call: failed request to refresh token: call to refresh access expected to have grant type as %q, found %q instead", "urn:ietf:params:oauth:grant-type:jwt-bearer", reqGrantType)
}
reqRefreshToken := req.Form.Get("refresh_token")
if reqRefreshToken != refreshToken {
t.Fatalf("Do call: failed request to refresh token: call to refresh access token did not have the expected refresh token set")
}

// Return response with accessTokenSecond
responseBodyStruct := TokenResponseBody{
AccessToken: accessTokenSecond,
RefreshToken: refreshToken,
AccessToken: accessTokenSecond,
}
responseBody, err := json.Marshal(responseBodyStruct)
if err != nil {
Expand Down Expand Up @@ -409,7 +386,7 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) {
}

// Set the token after initialization
err = keyFlow.SetToken(accessTokenFirst, refreshToken)
err = keyFlow.SetToken(accessTokenFirst)
if err != nil {
t.Fatalf("failed to set token: %v", err)
}
Expand Down
66 changes: 3 additions & 63 deletions core/clients/key_flow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,65 +130,6 @@ func TestKeyFlowInit(t *testing.T) {
}
}

func TestSetToken(t *testing.T) {
tests := []struct {
name string
tokenInvalid bool
refreshToken string
wantErr bool
}{
{
name: "ok",
tokenInvalid: false,
refreshToken: "refresh_token",
wantErr: false,
},
{
name: "invalid_token",
tokenInvalid: true,
refreshToken: "",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var accessToken string
var err error

timestamp := time.Now().Add(24 * time.Hour)
if tt.tokenInvalid {
accessToken = "foo"
} else {
accessTokenJWT := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(timestamp)})
accessToken, err = accessTokenJWT.SignedString(testSigningKey)
if err != nil {
t.Fatalf("get test access token as string: %s", err)
}
}

keyFlow := &KeyFlow{}
err = keyFlow.SetToken(accessToken, tt.refreshToken)

if (err != nil) != tt.wantErr {
t.Errorf("KeyFlow.SetToken() error = %v, wantErr %v", err, tt.wantErr)
}
if err == nil {
expectedKeyFlowToken := &TokenResponseBody{
AccessToken: accessToken,
ExpiresIn: int(timestamp.Unix()),
RefreshToken: tt.refreshToken,
Scope: defaultScope,
TokenType: defaultTokenType,
}
if !cmp.Equal(expectedKeyFlowToken, keyFlow.token) {
t.Errorf("The returned result is wrong. Expected %+v, got %+v", expectedKeyFlowToken, keyFlow.token)
}
}
})
}
}

func TestTokenExpired(t *testing.T) {
tokenExpirationLeeway := 5 * time.Second
tests := []struct {
Expand Down Expand Up @@ -442,10 +383,9 @@ func TestKeyFlow_Do(t *testing.T) {
res.Header().Set("Content-Type", "application/json")

token := &TokenResponseBody{
AccessToken: testBearerToken,
ExpiresIn: 2147483647,
RefreshToken: testBearerToken,
TokenType: "Bearer",
AccessToken: testBearerToken,
ExpiresIn: 2147483647,
TokenType: "Bearer",
}

if err := json.NewEncoder(res.Body).Encode(token); err != nil {
Expand Down