Skip to content

Commit af13f52

Browse files
Ryan Kohlercodyoss
authored andcommitted
google: Create AWS V4 Signing Utility
Change-Id: I59b4a13ed0433de7dfaa064a0f7dc1f3dd724518 GitHub-Last-Rev: 8cdc6a9 GitHub-Pull-Request: #467 Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/284632 Run-TryBot: Cody Oss <codyoss@google.com> TryBot-Result: Go Bot <gobot@golang.org> Trust: Cody Oss <codyoss@google.com> Trust: Tyler Bui-Palsulich <tbp@google.com> Reviewed-by: Cody Oss <codyoss@google.com>
1 parent d3ed898 commit af13f52

File tree

2 files changed

+626
-0
lines changed

2 files changed

+626
-0
lines changed
Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
// Copyright 2021 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package externalaccount
6+
7+
import (
8+
"crypto/hmac"
9+
"crypto/sha256"
10+
"encoding/hex"
11+
"errors"
12+
"fmt"
13+
"io"
14+
"io/ioutil"
15+
"net/http"
16+
"path"
17+
"sort"
18+
"strings"
19+
"time"
20+
)
21+
22+
// RequestSigner is a utility class to sign http requests using a AWS V4 signature.
23+
type awsRequestSigner struct {
24+
RegionName string
25+
AwsSecurityCredentials map[string]string
26+
}
27+
28+
const (
29+
// AWS Signature Version 4 signing algorithm identifier.
30+
awsAlgorithm = "AWS4-HMAC-SHA256"
31+
32+
// The termination string for the AWS credential scope value as defined in
33+
// https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html
34+
awsRequestType = "aws4_request"
35+
36+
// The AWS authorization header name for the security session token if available.
37+
awsSecurityTokenHeader = "x-amz-security-token"
38+
39+
// The AWS authorization header name for the auto-generated date.
40+
awsDateHeader = "x-amz-date"
41+
42+
awsTimeFormatLong = "20060102T150405Z"
43+
awsTimeFormatShort = "20060102"
44+
)
45+
46+
func getSha256(input []byte) (string, error) {
47+
hash := sha256.New()
48+
if _, err := hash.Write(input); err != nil {
49+
return "", err
50+
}
51+
return hex.EncodeToString(hash.Sum(nil)), nil
52+
}
53+
54+
func getHmacSha256(key, input []byte) ([]byte, error) {
55+
hash := hmac.New(sha256.New, key)
56+
if _, err := hash.Write(input); err != nil {
57+
return nil, err
58+
}
59+
return hash.Sum(nil), nil
60+
}
61+
62+
func cloneRequest(r *http.Request) *http.Request {
63+
r2 := new(http.Request)
64+
*r2 = *r
65+
if r.Header != nil {
66+
r2.Header = make(http.Header, len(r.Header))
67+
68+
// Find total number of values.
69+
headerCount := 0
70+
for _, headerValues := range r.Header {
71+
headerCount += len(headerValues)
72+
}
73+
copiedHeaders := make([]string, headerCount) // shared backing array for headers' values
74+
75+
for headerKey, headerValues := range r.Header {
76+
headerCount = copy(copiedHeaders, headerValues)
77+
r2.Header[headerKey] = copiedHeaders[:headerCount:headerCount]
78+
copiedHeaders = copiedHeaders[headerCount:]
79+
}
80+
}
81+
return r2
82+
}
83+
84+
func canonicalPath(req *http.Request) string {
85+
result := req.URL.EscapedPath()
86+
if result == "" {
87+
return "/"
88+
}
89+
return path.Clean(result)
90+
}
91+
92+
func canonicalQuery(req *http.Request) string {
93+
queryValues := req.URL.Query()
94+
for queryKey := range queryValues {
95+
sort.Strings(queryValues[queryKey])
96+
}
97+
return queryValues.Encode()
98+
}
99+
100+
func canonicalHeaders(req *http.Request) (string, string) {
101+
// Header keys need to be sorted alphabetically.
102+
var headers []string
103+
lowerCaseHeaders := make(http.Header)
104+
for k, v := range req.Header {
105+
k := strings.ToLower(k)
106+
if _, ok := lowerCaseHeaders[k]; ok {
107+
// include additional values
108+
lowerCaseHeaders[k] = append(lowerCaseHeaders[k], v...)
109+
} else {
110+
headers = append(headers, k)
111+
lowerCaseHeaders[k] = v
112+
}
113+
}
114+
sort.Strings(headers)
115+
116+
var fullHeaders strings.Builder
117+
for _, header := range headers {
118+
headerValue := strings.Join(lowerCaseHeaders[header], ",")
119+
fullHeaders.WriteString(header)
120+
fullHeaders.WriteRune(':')
121+
fullHeaders.WriteString(headerValue)
122+
fullHeaders.WriteRune('\n')
123+
}
124+
125+
return strings.Join(headers, ";"), fullHeaders.String()
126+
}
127+
128+
func requestDataHash(req *http.Request) (string, error) {
129+
var requestData []byte
130+
if req.Body != nil {
131+
requestBody, err := req.GetBody()
132+
if err != nil {
133+
return "", err
134+
}
135+
defer requestBody.Close()
136+
137+
requestData, err = ioutil.ReadAll(io.LimitReader(requestBody, 1<<20))
138+
if err != nil {
139+
return "", err
140+
}
141+
}
142+
143+
return getSha256(requestData)
144+
}
145+
146+
func requestHost(req *http.Request) string {
147+
if req.Host != "" {
148+
return req.Host
149+
}
150+
return req.URL.Host
151+
}
152+
153+
func canonicalRequest(req *http.Request, canonicalHeaderColumns, canonicalHeaderData string) (string, error) {
154+
dataHash, err := requestDataHash(req)
155+
if err != nil {
156+
return "", err
157+
}
158+
159+
return fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s", req.Method, canonicalPath(req), canonicalQuery(req), canonicalHeaderData, canonicalHeaderColumns, dataHash), nil
160+
}
161+
162+
// SignRequest adds the appropriate headers to an http.Request
163+
// or returns an error if something prevented this.
164+
func (rs *awsRequestSigner) SignRequest(req *http.Request) error {
165+
signedRequest := cloneRequest(req)
166+
timestamp := now()
167+
168+
signedRequest.Header.Add("host", requestHost(req))
169+
170+
if securityToken, ok := rs.AwsSecurityCredentials["security_token"]; ok {
171+
signedRequest.Header.Add(awsSecurityTokenHeader, securityToken)
172+
}
173+
174+
if signedRequest.Header.Get("date") == "" {
175+
signedRequest.Header.Add(awsDateHeader, timestamp.Format(awsTimeFormatLong))
176+
}
177+
178+
authorizationCode, err := rs.generateAuthentication(signedRequest, timestamp)
179+
if err != nil {
180+
return err
181+
}
182+
signedRequest.Header.Set("Authorization", authorizationCode)
183+
184+
req.Header = signedRequest.Header
185+
return nil
186+
}
187+
188+
func (rs *awsRequestSigner) generateAuthentication(req *http.Request, timestamp time.Time) (string, error) {
189+
secretAccessKey, ok := rs.AwsSecurityCredentials["secret_access_key"]
190+
if !ok {
191+
return "", errors.New("oauth2/google: missing secret_access_key header")
192+
}
193+
accessKeyId, ok := rs.AwsSecurityCredentials["access_key_id"]
194+
if !ok {
195+
return "", errors.New("oauth2/google: missing access_key_id header")
196+
}
197+
198+
canonicalHeaderColumns, canonicalHeaderData := canonicalHeaders(req)
199+
200+
dateStamp := timestamp.Format(awsTimeFormatShort)
201+
serviceName := ""
202+
if splitHost := strings.Split(requestHost(req), "."); len(splitHost) > 0 {
203+
serviceName = splitHost[0]
204+
}
205+
206+
credentialScope := fmt.Sprintf("%s/%s/%s/%s",dateStamp, rs.RegionName, serviceName, awsRequestType)
207+
208+
requestString, err := canonicalRequest(req, canonicalHeaderColumns, canonicalHeaderData)
209+
if err != nil {
210+
return "", err
211+
}
212+
requestHash, err := getSha256([]byte(requestString))
213+
if err != nil{
214+
return "", err
215+
}
216+
217+
stringToSign := fmt.Sprintf("%s\n%s\n%s\n%s", awsAlgorithm, timestamp.Format(awsTimeFormatLong), credentialScope, requestHash)
218+
219+
signingKey := []byte("AWS4" + secretAccessKey)
220+
for _, signingInput := range []string{
221+
dateStamp, rs.RegionName, serviceName, awsRequestType, stringToSign,
222+
} {
223+
signingKey, err = getHmacSha256(signingKey, []byte(signingInput))
224+
if err != nil{
225+
return "", err
226+
}
227+
}
228+
229+
return fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", awsAlgorithm, accessKeyId, credentialScope, canonicalHeaderColumns, hex.EncodeToString(signingKey)), nil
230+
}

0 commit comments

Comments
 (0)