55package externalaccount
66
77import (
8+ "context"
89 "crypto/hmac"
910 "crypto/sha256"
1011 "encoding/hex"
12+ "encoding/json"
1113 "errors"
1214 "fmt"
15+ "golang.org/x/oauth2"
1316 "io"
1417 "io/ioutil"
1518 "net/http"
19+ "os"
1620 "path"
1721 "sort"
1822 "strings"
1923 "time"
2024)
2125
22- // RequestSigner is a utility class to sign http requests using a AWS V4 signature.
26+ type awsSecurityCredentials struct {
27+ AccessKeyID string `json:"AccessKeyID"`
28+ SecretAccessKey string `json:"SecretAccessKey"`
29+ SecurityToken string `json:"Token"`
30+ }
31+
32+ // awsRequestSigner is a utility class to sign http requests using a AWS V4 signature.
2333type awsRequestSigner struct {
2434 RegionName string
25- AwsSecurityCredentials map [ string ] string
35+ AwsSecurityCredentials awsSecurityCredentials
2636}
2737
38+ // getenv aliases os.Getenv for testing
39+ var getenv = os .Getenv
40+
2841const (
29- // AWS Signature Version 4 signing algorithm identifier.
42+ // AWS Signature Version 4 signing algorithm identifier.
3043 awsAlgorithm = "AWS4-HMAC-SHA256"
3144
32- // The termination string for the AWS credential scope value as defined in
33- // https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html
45+ // The termination string for the AWS credential scope value as defined in
46+ // https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html
3447 awsRequestType = "aws4_request"
3548
36- // The AWS authorization header name for the security session token if available.
49+ // The AWS authorization header name for the security session token if available.
3750 awsSecurityTokenHeader = "x-amz-security-token"
3851
39- // The AWS authorization header name for the auto-generated date.
52+ // The AWS authorization header name for the auto-generated date.
4053 awsDateHeader = "x-amz-date"
4154
42- awsTimeFormatLong = "20060102T150405Z"
55+ awsTimeFormatLong = "20060102T150405Z"
4356 awsTimeFormatShort = "20060102"
4457)
4558
@@ -167,8 +180,8 @@ func (rs *awsRequestSigner) SignRequest(req *http.Request) error {
167180
168181 signedRequest .Header .Add ("host" , requestHost (req ))
169182
170- if securityToken , ok := rs .AwsSecurityCredentials [ "security_token" ]; ok {
171- signedRequest .Header .Add (awsSecurityTokenHeader , securityToken )
183+ if rs .AwsSecurityCredentials . SecurityToken != "" {
184+ signedRequest .Header .Add (awsSecurityTokenHeader , rs . AwsSecurityCredentials . SecurityToken )
172185 }
173186
174187 if signedRequest .Header .Get ("date" ) == "" {
@@ -186,15 +199,6 @@ func (rs *awsRequestSigner) SignRequest(req *http.Request) error {
186199}
187200
188201func (rs * awsRequestSigner ) generateAuthentication (req * http.Request , timestamp time.Time ) (string , error ) {
189- secretAccessKey , ok := rs .AwsSecurityCredentials ["secret_access_key" ]
190- if ! ok {
191- return "" , errors .New ("oauth2/google: missing secret_access_key header" )
192- }
193- accessKeyId , ok := rs .AwsSecurityCredentials ["access_key_id" ]
194- if ! ok {
195- return "" , errors .New ("oauth2/google: missing access_key_id header" )
196- }
197-
198202 canonicalHeaderColumns , canonicalHeaderData := canonicalHeaders (req )
199203
200204 dateStamp := timestamp .Format (awsTimeFormatShort )
@@ -203,28 +207,258 @@ func (rs *awsRequestSigner) generateAuthentication(req *http.Request, timestamp
203207 serviceName = splitHost [0 ]
204208 }
205209
206- credentialScope := fmt .Sprintf ("%s/%s/%s/%s" ,dateStamp , rs .RegionName , serviceName , awsRequestType )
210+ credentialScope := fmt .Sprintf ("%s/%s/%s/%s" , dateStamp , rs .RegionName , serviceName , awsRequestType )
207211
208212 requestString , err := canonicalRequest (req , canonicalHeaderColumns , canonicalHeaderData )
209213 if err != nil {
210214 return "" , err
211215 }
212216 requestHash , err := getSha256 ([]byte (requestString ))
213- if err != nil {
217+ if err != nil {
214218 return "" , err
215219 }
216220
217221 stringToSign := fmt .Sprintf ("%s\n %s\n %s\n %s" , awsAlgorithm , timestamp .Format (awsTimeFormatLong ), credentialScope , requestHash )
218222
219- signingKey := []byte ("AWS4" + secretAccessKey )
223+ signingKey := []byte ("AWS4" + rs . AwsSecurityCredentials . SecretAccessKey )
220224 for _ , signingInput := range []string {
221225 dateStamp , rs .RegionName , serviceName , awsRequestType , stringToSign ,
222226 } {
223227 signingKey , err = getHmacSha256 (signingKey , []byte (signingInput ))
224- if err != nil {
228+ if err != nil {
229+ return "" , err
230+ }
231+ }
232+
233+ return fmt .Sprintf ("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s" , awsAlgorithm , rs .AwsSecurityCredentials .AccessKeyID , credentialScope , canonicalHeaderColumns , hex .EncodeToString (signingKey )), nil
234+ }
235+
236+ type awsCredentialSource struct {
237+ EnvironmentID string
238+ RegionURL string
239+ RegionalCredVerificationURL string
240+ CredVerificationURL string
241+ TargetResource string
242+ requestSigner * awsRequestSigner
243+ region string
244+ ctx context.Context
245+ client * http.Client
246+ }
247+
248+ type awsRequestHeader struct {
249+ Key string `json:"key"`
250+ Value string `json:"value"`
251+ }
252+
253+ type awsRequest struct {
254+ URL string `json:"url"`
255+ Method string `json:"method"`
256+ Headers []awsRequestHeader `json:"headers"`
257+ }
258+
259+ func (cs awsCredentialSource ) doRequest (req * http.Request ) (* http.Response , error ) {
260+ if cs .client == nil {
261+ cs .client = oauth2 .NewClient (cs .ctx , nil )
262+ }
263+ return cs .client .Do (req .WithContext (cs .ctx ))
264+ }
265+
266+ func (cs awsCredentialSource ) subjectToken () (string , error ) {
267+ if cs .requestSigner == nil {
268+ awsSecurityCredentials , err := cs .getSecurityCredentials ()
269+ if err != nil {
225270 return "" , err
226271 }
272+
273+ if cs .region , err = cs .getRegion (); err != nil {
274+ return "" , err
275+ }
276+
277+ cs .requestSigner = & awsRequestSigner {
278+ RegionName : cs .region ,
279+ AwsSecurityCredentials : awsSecurityCredentials ,
280+ }
281+ }
282+
283+ // Generate the signed request to AWS STS GetCallerIdentity API.
284+ // Use the required regional endpoint. Otherwise, the request will fail.
285+ req , err := http .NewRequest ("POST" , strings .Replace (cs .RegionalCredVerificationURL , "{region}" , cs .region , 1 ), nil )
286+ if err != nil {
287+ return "" , err
288+ }
289+ // The full, canonical resource name of the workload identity pool
290+ // provider, with or without the HTTPS prefix.
291+ // Including this header as part of the signature is recommended to
292+ // ensure data integrity.
293+ if cs .TargetResource != "" {
294+ req .Header .Add ("x-goog-cloud-target-resource" , cs .TargetResource )
295+ }
296+ cs .requestSigner .SignRequest (req )
297+
298+ /*
299+ The GCP STS endpoint expects the headers to be formatted as:
300+ # [
301+ # {key: 'x-amz-date', value: '...'},
302+ # {key: 'Authorization', value: '...'},
303+ # ...
304+ # ]
305+ # And then serialized as:
306+ # quote(json.dumps({
307+ # url: '...',
308+ # method: 'POST',
309+ # headers: [{key: 'x-amz-date', value: '...'}, ...]
310+ # }))
311+ */
312+
313+ awsSignedReq := awsRequest {
314+ URL : req .URL .String (),
315+ Method : "POST" ,
316+ }
317+ for headerKey , headerList := range req .Header {
318+ for _ , headerValue := range headerList {
319+ awsSignedReq .Headers = append (awsSignedReq .Headers , awsRequestHeader {
320+ Key : headerKey ,
321+ Value : headerValue ,
322+ })
323+ }
324+ }
325+ sort .Slice (awsSignedReq .Headers , func (i , j int ) bool {
326+ headerCompare := strings .Compare (awsSignedReq .Headers [i ].Key , awsSignedReq .Headers [j ].Key )
327+ if headerCompare == 0 {
328+ return strings .Compare (awsSignedReq .Headers [i ].Value , awsSignedReq .Headers [j ].Value ) < 0
329+ }
330+ return headerCompare < 0
331+ })
332+
333+ result , err := json .Marshal (awsSignedReq )
334+ if err != nil {
335+ return "" , err
336+ }
337+ return string (result ), nil
338+ }
339+
340+ func (cs * awsCredentialSource ) getRegion () (string , error ) {
341+ if envAwsRegion := getenv ("AWS_REGION" ); envAwsRegion != "" {
342+ return envAwsRegion , nil
343+ }
344+
345+ if cs .RegionURL == "" {
346+ return "" , errors .New ("oauth2/google: unable to determine AWS region" )
347+ }
348+
349+ req , err := http .NewRequest ("GET" , cs .RegionURL , nil )
350+ if err != nil {
351+ return "" , err
352+ }
353+
354+ resp , err := cs .doRequest (req )
355+ if err != nil {
356+ return "" , err
357+ }
358+ defer resp .Body .Close ()
359+
360+ respBody , err := ioutil .ReadAll (io .LimitReader (resp .Body , 1 << 20 ))
361+ if err != nil {
362+ return "" , err
363+ }
364+
365+ if resp .StatusCode != 200 {
366+ return "" , fmt .Errorf ("oauth2/google: unable to retrieve AWS region - %s" , string (respBody ))
367+ }
368+
369+ // This endpoint will return the region in format: us-east-2b.
370+ // Only the us-east-2 part should be used.
371+ respBodyEnd := 0
372+ if len (respBody ) > 1 {
373+ respBodyEnd = len (respBody ) - 1
374+ }
375+ return string (respBody [:respBodyEnd ]), nil
376+ }
377+
378+ func (cs * awsCredentialSource ) getSecurityCredentials () (result awsSecurityCredentials , err error ) {
379+ if accessKeyID := getenv ("AWS_ACCESS_KEY_ID" ); accessKeyID != "" {
380+ if secretAccessKey := getenv ("AWS_SECRET_ACCESS_KEY" ); secretAccessKey != "" {
381+ return awsSecurityCredentials {
382+ AccessKeyID : accessKeyID ,
383+ SecretAccessKey : secretAccessKey ,
384+ SecurityToken : getenv ("AWS_SESSION_TOKEN" ),
385+ }, nil
386+ }
387+ }
388+
389+ roleName , err := cs .getMetadataRoleName ()
390+ if err != nil {
391+ return
392+ }
393+
394+ credentials , err := cs .getMetadataSecurityCredentials (roleName )
395+ if err != nil {
396+ return
397+ }
398+
399+ if credentials .AccessKeyID == "" {
400+ return result , errors .New ("oauth2/google: missing AccessKeyId credential" )
401+ }
402+
403+ if credentials .SecretAccessKey == "" {
404+ return result , errors .New ("oauth2/google: missing SecretAccessKey credential" )
405+ }
406+
407+ return credentials , nil
408+ }
409+
410+ func (cs * awsCredentialSource ) getMetadataSecurityCredentials (roleName string ) (awsSecurityCredentials , error ) {
411+ var result awsSecurityCredentials
412+
413+ req , err := http .NewRequest ("GET" , fmt .Sprintf ("%s/%s" , cs .CredVerificationURL , roleName ), nil )
414+ if err != nil {
415+ return result , err
416+ }
417+ req .Header .Add ("Content-Type" , "application/json" )
418+
419+ resp , err := cs .doRequest (req )
420+ if err != nil {
421+ return result , err
422+ }
423+ defer resp .Body .Close ()
424+
425+ respBody , err := ioutil .ReadAll (io .LimitReader (resp .Body , 1 << 20 ))
426+ if err != nil {
427+ return result , err
428+ }
429+
430+ if resp .StatusCode != 200 {
431+ return result , fmt .Errorf ("oauth2/google: unable to retrieve AWS security credentials - %s" , string (respBody ))
432+ }
433+
434+ err = json .Unmarshal (respBody , & result )
435+ return result , err
436+ }
437+
438+ func (cs * awsCredentialSource ) getMetadataRoleName () (string , error ) {
439+ if cs .CredVerificationURL == "" {
440+ return "" , errors .New ("oauth2/google: unable to determine the AWS metadata server security credentials endpoint" )
441+ }
442+
443+ req , err := http .NewRequest ("GET" , cs .CredVerificationURL , nil )
444+ if err != nil {
445+ return "" , err
446+ }
447+
448+ resp , err := cs .doRequest (req )
449+ if err != nil {
450+ return "" , err
451+ }
452+ defer resp .Body .Close ()
453+
454+ respBody , err := ioutil .ReadAll (io .LimitReader (resp .Body , 1 << 20 ))
455+ if err != nil {
456+ return "" , err
457+ }
458+
459+ if resp .StatusCode != 200 {
460+ return "" , fmt .Errorf ("oauth2/google: unable to retrieve AWS role name - %s" , string (respBody ))
227461 }
228462
229- return fmt . Sprintf ( "%s Credential=%s/%s, SignedHeaders=%s, Signature=%s" , awsAlgorithm , accessKeyId , credentialScope , canonicalHeaderColumns , hex . EncodeToString ( signingKey ) ), nil
463+ return string ( respBody ), nil
230464}
0 commit comments