Skip to content

Commit 0101308

Browse files
Ryan Kohlercodyoss
authored andcommitted
google: support AWS 3rd party credentials
Change-Id: I655b38f7fb8023866bb284c7ce80ab9888682e73 GitHub-Last-Rev: 648f0b3 GitHub-Pull-Request: #471 Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/287752 Reviewed-by: Cody Oss <codyoss@google.com> Run-TryBot: Cody Oss <codyoss@google.com> TryBot-Result: Go Bot <gobot@golang.org> Trust: Tyler Bui-Palsulich <tbp@google.com> Trust: Cody Oss <codyoss@google.com>
1 parent f9ce19e commit 0101308

File tree

5 files changed

+779
-53
lines changed

5 files changed

+779
-53
lines changed

google/internal/externalaccount/aws.go

Lines changed: 258 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,41 +5,54 @@
55
package externalaccount
66

77
import (
8+
"context"
89
"crypto/hmac"
910
"crypto/sha256"
1011
"encoding/hex"
12+
"encoding/json"
1113
"errors"
1214
"fmt"
15+
"golang.org/x/oauth2"
1316
"io"
1417
"io/ioutil"
1518
"net/http"
19+
"os"
1620
"path"
1721
"sort"
1822
"strings"
1923
"time"
2024
)
2125

22-
// RequestSigner is a utility class to sign http requests using a AWS V4 signature.
26+
type awsSecurityCredentials struct {
27+
AccessKeyID string `json:"AccessKeyID"`
28+
SecretAccessKey string `json:"SecretAccessKey"`
29+
SecurityToken string `json:"Token"`
30+
}
31+
32+
// awsRequestSigner is a utility class to sign http requests using a AWS V4 signature.
2333
type awsRequestSigner struct {
2434
RegionName string
25-
AwsSecurityCredentials map[string]string
35+
AwsSecurityCredentials awsSecurityCredentials
2636
}
2737

38+
// getenv aliases os.Getenv for testing
39+
var getenv = os.Getenv
40+
2841
const (
29-
// AWS Signature Version 4 signing algorithm identifier.
42+
// AWS Signature Version 4 signing algorithm identifier.
3043
awsAlgorithm = "AWS4-HMAC-SHA256"
3144

32-
// The termination string for the AWS credential scope value as defined in
33-
// https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html
45+
// The termination string for the AWS credential scope value as defined in
46+
// https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html
3447
awsRequestType = "aws4_request"
3548

36-
// The AWS authorization header name for the security session token if available.
49+
// The AWS authorization header name for the security session token if available.
3750
awsSecurityTokenHeader = "x-amz-security-token"
3851

39-
// The AWS authorization header name for the auto-generated date.
52+
// The AWS authorization header name for the auto-generated date.
4053
awsDateHeader = "x-amz-date"
4154

42-
awsTimeFormatLong = "20060102T150405Z"
55+
awsTimeFormatLong = "20060102T150405Z"
4356
awsTimeFormatShort = "20060102"
4457
)
4558

@@ -167,8 +180,8 @@ func (rs *awsRequestSigner) SignRequest(req *http.Request) error {
167180

168181
signedRequest.Header.Add("host", requestHost(req))
169182

170-
if securityToken, ok := rs.AwsSecurityCredentials["security_token"]; ok {
171-
signedRequest.Header.Add(awsSecurityTokenHeader, securityToken)
183+
if rs.AwsSecurityCredentials.SecurityToken != "" {
184+
signedRequest.Header.Add(awsSecurityTokenHeader, rs.AwsSecurityCredentials.SecurityToken)
172185
}
173186

174187
if signedRequest.Header.Get("date") == "" {
@@ -186,15 +199,6 @@ func (rs *awsRequestSigner) SignRequest(req *http.Request) error {
186199
}
187200

188201
func (rs *awsRequestSigner) generateAuthentication(req *http.Request, timestamp time.Time) (string, error) {
189-
secretAccessKey, ok := rs.AwsSecurityCredentials["secret_access_key"]
190-
if !ok {
191-
return "", errors.New("oauth2/google: missing secret_access_key header")
192-
}
193-
accessKeyId, ok := rs.AwsSecurityCredentials["access_key_id"]
194-
if !ok {
195-
return "", errors.New("oauth2/google: missing access_key_id header")
196-
}
197-
198202
canonicalHeaderColumns, canonicalHeaderData := canonicalHeaders(req)
199203

200204
dateStamp := timestamp.Format(awsTimeFormatShort)
@@ -203,28 +207,258 @@ func (rs *awsRequestSigner) generateAuthentication(req *http.Request, timestamp
203207
serviceName = splitHost[0]
204208
}
205209

206-
credentialScope := fmt.Sprintf("%s/%s/%s/%s",dateStamp, rs.RegionName, serviceName, awsRequestType)
210+
credentialScope := fmt.Sprintf("%s/%s/%s/%s", dateStamp, rs.RegionName, serviceName, awsRequestType)
207211

208212
requestString, err := canonicalRequest(req, canonicalHeaderColumns, canonicalHeaderData)
209213
if err != nil {
210214
return "", err
211215
}
212216
requestHash, err := getSha256([]byte(requestString))
213-
if err != nil{
217+
if err != nil {
214218
return "", err
215219
}
216220

217221
stringToSign := fmt.Sprintf("%s\n%s\n%s\n%s", awsAlgorithm, timestamp.Format(awsTimeFormatLong), credentialScope, requestHash)
218222

219-
signingKey := []byte("AWS4" + secretAccessKey)
223+
signingKey := []byte("AWS4" + rs.AwsSecurityCredentials.SecretAccessKey)
220224
for _, signingInput := range []string{
221225
dateStamp, rs.RegionName, serviceName, awsRequestType, stringToSign,
222226
} {
223227
signingKey, err = getHmacSha256(signingKey, []byte(signingInput))
224-
if err != nil{
228+
if err != nil {
229+
return "", err
230+
}
231+
}
232+
233+
return fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", awsAlgorithm, rs.AwsSecurityCredentials.AccessKeyID, credentialScope, canonicalHeaderColumns, hex.EncodeToString(signingKey)), nil
234+
}
235+
236+
type awsCredentialSource struct {
237+
EnvironmentID string
238+
RegionURL string
239+
RegionalCredVerificationURL string
240+
CredVerificationURL string
241+
TargetResource string
242+
requestSigner *awsRequestSigner
243+
region string
244+
ctx context.Context
245+
client *http.Client
246+
}
247+
248+
type awsRequestHeader struct {
249+
Key string `json:"key"`
250+
Value string `json:"value"`
251+
}
252+
253+
type awsRequest struct {
254+
URL string `json:"url"`
255+
Method string `json:"method"`
256+
Headers []awsRequestHeader `json:"headers"`
257+
}
258+
259+
func (cs awsCredentialSource) doRequest(req *http.Request) (*http.Response, error) {
260+
if cs.client == nil {
261+
cs.client = oauth2.NewClient(cs.ctx, nil)
262+
}
263+
return cs.client.Do(req.WithContext(cs.ctx))
264+
}
265+
266+
func (cs awsCredentialSource) subjectToken() (string, error) {
267+
if cs.requestSigner == nil {
268+
awsSecurityCredentials, err := cs.getSecurityCredentials()
269+
if err != nil {
225270
return "", err
226271
}
272+
273+
if cs.region, err = cs.getRegion(); err != nil {
274+
return "", err
275+
}
276+
277+
cs.requestSigner = &awsRequestSigner{
278+
RegionName: cs.region,
279+
AwsSecurityCredentials: awsSecurityCredentials,
280+
}
281+
}
282+
283+
// Generate the signed request to AWS STS GetCallerIdentity API.
284+
// Use the required regional endpoint. Otherwise, the request will fail.
285+
req, err := http.NewRequest("POST", strings.Replace(cs.RegionalCredVerificationURL, "{region}", cs.region, 1), nil)
286+
if err != nil {
287+
return "", err
288+
}
289+
// The full, canonical resource name of the workload identity pool
290+
// provider, with or without the HTTPS prefix.
291+
// Including this header as part of the signature is recommended to
292+
// ensure data integrity.
293+
if cs.TargetResource != "" {
294+
req.Header.Add("x-goog-cloud-target-resource", cs.TargetResource)
295+
}
296+
cs.requestSigner.SignRequest(req)
297+
298+
/*
299+
The GCP STS endpoint expects the headers to be formatted as:
300+
# [
301+
# {key: 'x-amz-date', value: '...'},
302+
# {key: 'Authorization', value: '...'},
303+
# ...
304+
# ]
305+
# And then serialized as:
306+
# quote(json.dumps({
307+
# url: '...',
308+
# method: 'POST',
309+
# headers: [{key: 'x-amz-date', value: '...'}, ...]
310+
# }))
311+
*/
312+
313+
awsSignedReq := awsRequest{
314+
URL: req.URL.String(),
315+
Method: "POST",
316+
}
317+
for headerKey, headerList := range req.Header {
318+
for _, headerValue := range headerList {
319+
awsSignedReq.Headers = append(awsSignedReq.Headers, awsRequestHeader{
320+
Key: headerKey,
321+
Value: headerValue,
322+
})
323+
}
324+
}
325+
sort.Slice(awsSignedReq.Headers, func(i, j int) bool {
326+
headerCompare := strings.Compare(awsSignedReq.Headers[i].Key, awsSignedReq.Headers[j].Key)
327+
if headerCompare == 0 {
328+
return strings.Compare(awsSignedReq.Headers[i].Value, awsSignedReq.Headers[j].Value) < 0
329+
}
330+
return headerCompare < 0
331+
})
332+
333+
result, err := json.Marshal(awsSignedReq)
334+
if err != nil {
335+
return "", err
336+
}
337+
return string(result), nil
338+
}
339+
340+
func (cs *awsCredentialSource) getRegion() (string, error) {
341+
if envAwsRegion := getenv("AWS_REGION"); envAwsRegion != "" {
342+
return envAwsRegion, nil
343+
}
344+
345+
if cs.RegionURL == "" {
346+
return "", errors.New("oauth2/google: unable to determine AWS region")
347+
}
348+
349+
req, err := http.NewRequest("GET", cs.RegionURL, nil)
350+
if err != nil {
351+
return "", err
352+
}
353+
354+
resp, err := cs.doRequest(req)
355+
if err != nil {
356+
return "", err
357+
}
358+
defer resp.Body.Close()
359+
360+
respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
361+
if err != nil {
362+
return "", err
363+
}
364+
365+
if resp.StatusCode != 200 {
366+
return "", fmt.Errorf("oauth2/google: unable to retrieve AWS region - %s", string(respBody))
367+
}
368+
369+
// This endpoint will return the region in format: us-east-2b.
370+
// Only the us-east-2 part should be used.
371+
respBodyEnd := 0
372+
if len(respBody) > 1 {
373+
respBodyEnd = len(respBody) - 1
374+
}
375+
return string(respBody[:respBodyEnd]), nil
376+
}
377+
378+
func (cs *awsCredentialSource) getSecurityCredentials() (result awsSecurityCredentials, err error) {
379+
if accessKeyID := getenv("AWS_ACCESS_KEY_ID"); accessKeyID != "" {
380+
if secretAccessKey := getenv("AWS_SECRET_ACCESS_KEY"); secretAccessKey != "" {
381+
return awsSecurityCredentials{
382+
AccessKeyID: accessKeyID,
383+
SecretAccessKey: secretAccessKey,
384+
SecurityToken: getenv("AWS_SESSION_TOKEN"),
385+
}, nil
386+
}
387+
}
388+
389+
roleName, err := cs.getMetadataRoleName()
390+
if err != nil {
391+
return
392+
}
393+
394+
credentials, err := cs.getMetadataSecurityCredentials(roleName)
395+
if err != nil {
396+
return
397+
}
398+
399+
if credentials.AccessKeyID == "" {
400+
return result, errors.New("oauth2/google: missing AccessKeyId credential")
401+
}
402+
403+
if credentials.SecretAccessKey == "" {
404+
return result, errors.New("oauth2/google: missing SecretAccessKey credential")
405+
}
406+
407+
return credentials, nil
408+
}
409+
410+
func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string) (awsSecurityCredentials, error) {
411+
var result awsSecurityCredentials
412+
413+
req, err := http.NewRequest("GET", fmt.Sprintf("%s/%s", cs.CredVerificationURL, roleName), nil)
414+
if err != nil {
415+
return result, err
416+
}
417+
req.Header.Add("Content-Type", "application/json")
418+
419+
resp, err := cs.doRequest(req)
420+
if err != nil {
421+
return result, err
422+
}
423+
defer resp.Body.Close()
424+
425+
respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
426+
if err != nil {
427+
return result, err
428+
}
429+
430+
if resp.StatusCode != 200 {
431+
return result, fmt.Errorf("oauth2/google: unable to retrieve AWS security credentials - %s", string(respBody))
432+
}
433+
434+
err = json.Unmarshal(respBody, &result)
435+
return result, err
436+
}
437+
438+
func (cs *awsCredentialSource) getMetadataRoleName() (string, error) {
439+
if cs.CredVerificationURL == "" {
440+
return "", errors.New("oauth2/google: unable to determine the AWS metadata server security credentials endpoint")
441+
}
442+
443+
req, err := http.NewRequest("GET", cs.CredVerificationURL, nil)
444+
if err != nil {
445+
return "", err
446+
}
447+
448+
resp, err := cs.doRequest(req)
449+
if err != nil {
450+
return "", err
451+
}
452+
defer resp.Body.Close()
453+
454+
respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
455+
if err != nil {
456+
return "", err
457+
}
458+
459+
if resp.StatusCode != 200 {
460+
return "", fmt.Errorf("oauth2/google: unable to retrieve AWS role name - %s", string(respBody))
227461
}
228462

229-
return fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", awsAlgorithm, accessKeyId, credentialScope, canonicalHeaderColumns, hex.EncodeToString(signingKey)), nil
463+
return string(respBody), nil
230464
}

0 commit comments

Comments
 (0)