|
| 1 | +// Copyright 2011 The Go Authors. All rights reserved. |
| 2 | +// Use of this source code is governed by a BSD-style |
| 3 | +// license that can be found in the LICENSE file. |
| 4 | + |
| 5 | +package httptest |
| 6 | + |
| 7 | +import ( |
| 8 | + "bytes" |
| 9 | + "fmt" |
| 10 | + "io" |
| 11 | + "net/http" |
| 12 | + "net/textproto" |
| 13 | + "strconv" |
| 14 | + "strings" |
| 15 | +) |
| 16 | + |
| 17 | +// ResponseRecorder is an implementation of http.ResponseWriter that |
| 18 | +// records its mutations for later inspection in tests. |
| 19 | +type ResponseRecorder struct { |
| 20 | + // Code is the HTTP response code set by WriteHeader. |
| 21 | + // |
| 22 | + // Note that if a Handler never calls WriteHeader or Write, |
| 23 | + // this might end up being 0, rather than the implicit |
| 24 | + // http.StatusOK. To get the implicit value, use the Result |
| 25 | + // method. |
| 26 | + Code int |
| 27 | + |
| 28 | + // HeaderMap contains the headers explicitly set by the Handler. |
| 29 | + // It is an internal detail. |
| 30 | + // |
| 31 | + // Deprecated: HeaderMap exists for historical compatibility |
| 32 | + // and should not be used. To access the headers returned by a handler, |
| 33 | + // use the Response.Header map as returned by the Result method. |
| 34 | + HeaderMap http.Header |
| 35 | + |
| 36 | + // Body is the buffer to which the Handler's Write calls are sent. |
| 37 | + // If nil, the Writes are silently discarded. |
| 38 | + Body *bytes.Buffer |
| 39 | + |
| 40 | + // Flushed is whether the Handler called Flush. |
| 41 | + Flushed bool |
| 42 | + |
| 43 | + result *http.Response // cache of Result's return value |
| 44 | + snapHeader http.Header // snapshot of HeaderMap at first Write |
| 45 | + wroteHeader bool |
| 46 | +} |
| 47 | + |
| 48 | +// NewRecorder returns an initialized ResponseRecorder. |
| 49 | +func NewRecorder() *ResponseRecorder { |
| 50 | + return &ResponseRecorder{ |
| 51 | + HeaderMap: make(http.Header), |
| 52 | + Body: new(bytes.Buffer), |
| 53 | + Code: 200, |
| 54 | + } |
| 55 | +} |
| 56 | + |
| 57 | +// DefaultRemoteAddr is the default remote address to return in RemoteAddr if |
| 58 | +// an explicit DefaultRemoteAddr isn't set on ResponseRecorder. |
| 59 | +const DefaultRemoteAddr = "1.2.3.4" |
| 60 | + |
| 61 | +// Header implements http.ResponseWriter. It returns the response |
| 62 | +// headers to mutate within a handler. To test the headers that were |
| 63 | +// written after a handler completes, use the Result method and see |
| 64 | +// the returned Response value's Header. |
| 65 | +func (rw *ResponseRecorder) Header() http.Header { |
| 66 | + m := rw.HeaderMap |
| 67 | + if m == nil { |
| 68 | + m = make(http.Header) |
| 69 | + rw.HeaderMap = m |
| 70 | + } |
| 71 | + return m |
| 72 | +} |
| 73 | + |
| 74 | +// writeHeader writes a header if it was not written yet and |
| 75 | +// detects Content-Type if needed. |
| 76 | +// |
| 77 | +// bytes or str are the beginning of the response body. |
| 78 | +// We pass both to avoid unnecessarily generate garbage |
| 79 | +// in rw.WriteString which was created for performance reasons. |
| 80 | +// Non-nil bytes win. |
| 81 | +func (rw *ResponseRecorder) writeHeader(b []byte, str string) { |
| 82 | + if rw.wroteHeader { |
| 83 | + return |
| 84 | + } |
| 85 | + if len(str) > 512 { |
| 86 | + str = str[:512] |
| 87 | + } |
| 88 | + |
| 89 | + m := rw.Header() |
| 90 | + |
| 91 | + _, hasType := m["Content-Type"] |
| 92 | + hasTE := m.Get("Transfer-Encoding") != "" |
| 93 | + if !hasType && !hasTE { |
| 94 | + if b == nil { |
| 95 | + b = []byte(str) |
| 96 | + } |
| 97 | + m.Set("Content-Type", http.DetectContentType(b)) |
| 98 | + } |
| 99 | + |
| 100 | + rw.WriteHeader(200) |
| 101 | +} |
| 102 | + |
| 103 | +// Write implements http.ResponseWriter. The data in buf is written to |
| 104 | +// rw.Body, if not nil. |
| 105 | +func (rw *ResponseRecorder) Write(buf []byte) (int, error) { |
| 106 | + rw.writeHeader(buf, "") |
| 107 | + if rw.Body != nil { |
| 108 | + rw.Body.Write(buf) |
| 109 | + } |
| 110 | + return len(buf), nil |
| 111 | +} |
| 112 | + |
| 113 | +// WriteString implements io.StringWriter. The data in str is written |
| 114 | +// to rw.Body, if not nil. |
| 115 | +func (rw *ResponseRecorder) WriteString(str string) (int, error) { |
| 116 | + rw.writeHeader(nil, str) |
| 117 | + if rw.Body != nil { |
| 118 | + rw.Body.WriteString(str) |
| 119 | + } |
| 120 | + return len(str), nil |
| 121 | +} |
| 122 | + |
| 123 | +func checkWriteHeaderCode(code int) { |
| 124 | + // Issue 22880: require valid WriteHeader status codes. |
| 125 | + // For now we only enforce that it's three digits. |
| 126 | + // In the future we might block things over 599 (600 and above aren't defined |
| 127 | + // at https://httpwg.org/specs/rfc7231.html#status.codes) |
| 128 | + // and we might block under 200 (once we have more mature 1xx support). |
| 129 | + // But for now any three digits. |
| 130 | + // |
| 131 | + // We used to send "HTTP/1.1 000 0" on the wire in responses but there's |
| 132 | + // no equivalent bogus thing we can realistically send in HTTP/2, |
| 133 | + // so we'll consistently panic instead and help people find their bugs |
| 134 | + // early. (We can't return an error from WriteHeader even if we wanted to.) |
| 135 | + if code < 100 || code > 999 { |
| 136 | + panic(fmt.Sprintf("invalid WriteHeader code %v", code)) |
| 137 | + } |
| 138 | +} |
| 139 | + |
| 140 | +// WriteHeader implements http.ResponseWriter. |
| 141 | +func (rw *ResponseRecorder) WriteHeader(code int) { |
| 142 | + if rw.wroteHeader { |
| 143 | + return |
| 144 | + } |
| 145 | + |
| 146 | + checkWriteHeaderCode(code) |
| 147 | + rw.Code = code |
| 148 | + rw.wroteHeader = true |
| 149 | + if rw.HeaderMap == nil { |
| 150 | + rw.HeaderMap = make(http.Header) |
| 151 | + } |
| 152 | + rw.snapHeader = rw.HeaderMap.Clone() |
| 153 | +} |
| 154 | + |
| 155 | +// Flush implements http.Flusher. To test whether Flush was |
| 156 | +// called, see rw.Flushed. |
| 157 | +func (rw *ResponseRecorder) Flush() { |
| 158 | + if !rw.wroteHeader { |
| 159 | + rw.WriteHeader(200) |
| 160 | + } |
| 161 | + rw.Flushed = true |
| 162 | +} |
| 163 | + |
| 164 | +// Result returns the response generated by the handler. |
| 165 | +// |
| 166 | +// The returned Response will have at least its StatusCode, |
| 167 | +// Header, Body, and optionally Trailer populated. |
| 168 | +// More fields may be populated in the future, so callers should |
| 169 | +// not DeepEqual the result in tests. |
| 170 | +// |
| 171 | +// The Response.Header is a snapshot of the headers at the time of the |
| 172 | +// first write call, or at the time of this call, if the handler never |
| 173 | +// did a write. |
| 174 | +// |
| 175 | +// The Response.Body is guaranteed to be non-nil and Body.Read call is |
| 176 | +// guaranteed to not return any error other than io.EOF. |
| 177 | +// |
| 178 | +// Result must only be called after the handler has finished running. |
| 179 | +func (rw *ResponseRecorder) Result() *http.Response { |
| 180 | + if rw.result != nil { |
| 181 | + return rw.result |
| 182 | + } |
| 183 | + if rw.snapHeader == nil { |
| 184 | + rw.snapHeader = rw.HeaderMap.Clone() |
| 185 | + } |
| 186 | + res := &http.Response{ |
| 187 | + Proto: "HTTP/1.1", |
| 188 | + ProtoMajor: 1, |
| 189 | + ProtoMinor: 1, |
| 190 | + StatusCode: rw.Code, |
| 191 | + Header: rw.snapHeader, |
| 192 | + } |
| 193 | + rw.result = res |
| 194 | + if res.StatusCode == 0 { |
| 195 | + res.StatusCode = 200 |
| 196 | + } |
| 197 | + res.Status = fmt.Sprintf("%03d %s", res.StatusCode, http.StatusText(res.StatusCode)) |
| 198 | + if rw.Body != nil { |
| 199 | + res.Body = io.NopCloser(bytes.NewReader(rw.Body.Bytes())) |
| 200 | + } else { |
| 201 | + res.Body = http.NoBody |
| 202 | + } |
| 203 | + res.ContentLength = parseContentLength(res.Header.Get("Content-Length")) |
| 204 | + |
| 205 | + if trailers, ok := rw.snapHeader["Trailer"]; ok { |
| 206 | + res.Trailer = make(http.Header, len(trailers)) |
| 207 | + for _, k := range trailers { |
| 208 | + for _, k := range strings.Split(k, ",") { |
| 209 | + k = http.CanonicalHeaderKey(textproto.TrimString(k)) |
| 210 | + if !validTrailerHeader(k) { |
| 211 | + // Ignore since forbidden by RFC 7230, section 4.1.2. |
| 212 | + continue |
| 213 | + } |
| 214 | + vv, ok := rw.HeaderMap[k] |
| 215 | + if !ok { |
| 216 | + continue |
| 217 | + } |
| 218 | + vv2 := make([]string, len(vv)) |
| 219 | + copy(vv2, vv) |
| 220 | + res.Trailer[k] = vv2 |
| 221 | + } |
| 222 | + } |
| 223 | + } |
| 224 | + for k, vv := range rw.HeaderMap { |
| 225 | + if !strings.HasPrefix(k, http.TrailerPrefix) { |
| 226 | + continue |
| 227 | + } |
| 228 | + if res.Trailer == nil { |
| 229 | + res.Trailer = make(http.Header) |
| 230 | + } |
| 231 | + for _, v := range vv { |
| 232 | + res.Trailer.Add(strings.TrimPrefix(k, http.TrailerPrefix), v) |
| 233 | + } |
| 234 | + } |
| 235 | + return res |
| 236 | +} |
| 237 | + |
| 238 | +// parseContentLength trims whitespace from s and returns -1 if no value |
| 239 | +// is set, or the value if it's >= 0. |
| 240 | +// |
| 241 | +// This a modified version of same function found in net/http/transfer.go. This |
| 242 | +// one just ignores an invalid header. |
| 243 | +func parseContentLength(cl string) int64 { |
| 244 | + cl = textproto.TrimString(cl) |
| 245 | + if cl == "" { |
| 246 | + return -1 |
| 247 | + } |
| 248 | + n, err := strconv.ParseUint(cl, 10, 63) |
| 249 | + if err != nil { |
| 250 | + return -1 |
| 251 | + } |
| 252 | + return int64(n) |
| 253 | +} |
| 254 | + |
| 255 | +// ValidTrailerHeader reports whether name is a valid header field name to appear |
| 256 | +// in trailers. |
| 257 | +// See RFC 7230, Section 4.1.2 |
| 258 | +// Copied from golang.org/x/net/http/httpguts |
| 259 | +func validTrailerHeader(name string) bool { |
| 260 | + name = textproto.CanonicalMIMEHeaderKey(name) |
| 261 | + if strings.HasPrefix(name, "If-") || badTrailer[name] { |
| 262 | + return false |
| 263 | + } |
| 264 | + return true |
| 265 | +} |
| 266 | + |
| 267 | +var badTrailer = map[string]bool{ |
| 268 | + "Authorization": true, |
| 269 | + "Cache-Control": true, |
| 270 | + "Connection": true, |
| 271 | + "Content-Encoding": true, |
| 272 | + "Content-Length": true, |
| 273 | + "Content-Range": true, |
| 274 | + "Content-Type": true, |
| 275 | + "Expect": true, |
| 276 | + "Host": true, |
| 277 | + "Keep-Alive": true, |
| 278 | + "Max-Forwards": true, |
| 279 | + "Pragma": true, |
| 280 | + "Proxy-Authenticate": true, |
| 281 | + "Proxy-Authorization": true, |
| 282 | + "Proxy-Connection": true, |
| 283 | + "Range": true, |
| 284 | + "Realm": true, |
| 285 | + "Te": true, |
| 286 | + "Trailer": true, |
| 287 | + "Transfer-Encoding": true, |
| 288 | + "Www-Authenticate": true, |
| 289 | +} |
0 commit comments