Skip to content

Commit 2d8c9bb

Browse files
committed
Add 'pkg/' from commit 'bba261e1eec59928a5a8463d8844c9c525f51fc9'
git-subtree-dir: pkg git-subtree-mainline: 04a0ef6 git-subtree-split: bba261e
2 parents 04a0ef6 + bba261e commit 2d8c9bb

File tree

3 files changed

+778
-0
lines changed

3 files changed

+778
-0
lines changed

pkg/oapi_validate.go

Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
// Copyright 2019 DeepMap, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package middleware
16+
17+
import (
18+
"context"
19+
"errors"
20+
"fmt"
21+
"log"
22+
"net/http"
23+
"os"
24+
"strings"
25+
26+
"github.com/getkin/kin-openapi/openapi3"
27+
"github.com/getkin/kin-openapi/openapi3filter"
28+
"github.com/getkin/kin-openapi/routers"
29+
"github.com/getkin/kin-openapi/routers/gorillamux"
30+
"github.com/labstack/echo/v4"
31+
echomiddleware "github.com/labstack/echo/v4/middleware"
32+
)
33+
34+
const (
35+
EchoContextKey = "oapi-codegen/echo-context"
36+
UserDataKey = "oapi-codegen/user-data"
37+
)
38+
39+
// OapiValidatorFromYamlFile is an Echo middleware function which validates incoming HTTP requests
40+
// to make sure that they conform to the given OAPI 3.0 specification. When
41+
// OAPI validation fails on the request, we return an HTTP/400.
42+
// Create validator middleware from a YAML file path
43+
func OapiValidatorFromYamlFile(path string) (echo.MiddlewareFunc, error) {
44+
data, err := os.ReadFile(path)
45+
if err != nil {
46+
return nil, fmt.Errorf("error reading %s: %w", path, err)
47+
}
48+
49+
swagger, err := openapi3.NewLoader().LoadFromData(data)
50+
if err != nil {
51+
return nil, fmt.Errorf("error parsing %s as Swagger YAML: %w", path, err)
52+
}
53+
return OapiRequestValidator(swagger), nil
54+
}
55+
56+
// OapiRequestValidator creates a validator from a swagger object.
57+
func OapiRequestValidator(swagger *openapi3.T) echo.MiddlewareFunc {
58+
return OapiRequestValidatorWithOptions(swagger, nil)
59+
}
60+
61+
// ErrorHandler is called when there is an error in validation
62+
type ErrorHandler func(c echo.Context, err *echo.HTTPError) error
63+
64+
// MultiErrorHandler is called when oapi returns a MultiError type
65+
type MultiErrorHandler func(openapi3.MultiError) *echo.HTTPError
66+
67+
// Options to customize request validation. These are passed through to
68+
// openapi3filter.
69+
type Options struct {
70+
ErrorHandler ErrorHandler
71+
Options openapi3filter.Options
72+
ParamDecoder openapi3filter.ContentParameterDecoder
73+
UserData interface{}
74+
Skipper echomiddleware.Skipper
75+
MultiErrorHandler MultiErrorHandler
76+
// SilenceServersWarning allows silencing a warning for https://github.com/deepmap/oapi-codegen/issues/882 that reports when an OpenAPI spec has `spec.Servers != nil`
77+
SilenceServersWarning bool
78+
}
79+
80+
// OapiRequestValidatorWithOptions creates a validator from a swagger object, with validation options
81+
func OapiRequestValidatorWithOptions(swagger *openapi3.T, options *Options) echo.MiddlewareFunc {
82+
if swagger.Servers != nil && (options == nil || !options.SilenceServersWarning) {
83+
log.Println("WARN: OapiRequestValidatorWithOptions called with an OpenAPI spec that has `Servers` set. This may lead to an HTTP 400 with `no matching operation was found` when sending a valid request, as the validator performs `Host` header validation. If you're expecting `Host` header validation, you can silence this warning by setting `Options.SilenceServersWarning = true`. See https://github.com/deepmap/oapi-codegen/issues/882 for more information.")
84+
}
85+
86+
router, err := gorillamux.NewRouter(swagger)
87+
if err != nil {
88+
panic(err)
89+
}
90+
91+
skipper := getSkipperFromOptions(options)
92+
return func(next echo.HandlerFunc) echo.HandlerFunc {
93+
return func(c echo.Context) error {
94+
if skipper(c) {
95+
return next(c)
96+
}
97+
98+
err := ValidateRequestFromContext(c, router, options)
99+
if err != nil {
100+
if options != nil && options.ErrorHandler != nil {
101+
return options.ErrorHandler(c, err)
102+
}
103+
return err
104+
}
105+
return next(c)
106+
}
107+
}
108+
}
109+
110+
// ValidateRequestFromContext is called from the middleware above and actually does the work
111+
// of validating a request.
112+
func ValidateRequestFromContext(ctx echo.Context, router routers.Router, options *Options) *echo.HTTPError {
113+
req := ctx.Request()
114+
route, pathParams, err := router.FindRoute(req)
115+
116+
// We failed to find a matching route for the request.
117+
if err != nil {
118+
switch e := err.(type) {
119+
case *routers.RouteError:
120+
// We've got a bad request, the path requested doesn't match
121+
// either server, or path, or something.
122+
return echo.NewHTTPError(http.StatusNotFound, e.Reason)
123+
default:
124+
// This should never happen today, but if our upstream code changes,
125+
// we don't want to crash the server, so handle the unexpected error.
126+
return echo.NewHTTPError(http.StatusInternalServerError,
127+
fmt.Sprintf("error validating route: %s", err.Error()))
128+
}
129+
}
130+
131+
validationInput := &openapi3filter.RequestValidationInput{
132+
Request: req,
133+
PathParams: pathParams,
134+
Route: route,
135+
}
136+
137+
// Pass the Echo context into the request validator, so that any callbacks
138+
// which it invokes make it available.
139+
requestContext := context.WithValue(context.Background(), EchoContextKey, ctx) //nolint:staticcheck
140+
141+
if options != nil {
142+
validationInput.Options = &options.Options
143+
validationInput.ParamDecoder = options.ParamDecoder
144+
requestContext = context.WithValue(requestContext, UserDataKey, options.UserData) //nolint:staticcheck
145+
}
146+
147+
err = openapi3filter.ValidateRequest(requestContext, validationInput)
148+
if err != nil {
149+
me := openapi3.MultiError{}
150+
if errors.As(err, &me) {
151+
errFunc := getMultiErrorHandlerFromOptions(options)
152+
return errFunc(me)
153+
}
154+
155+
switch e := err.(type) {
156+
case *openapi3filter.RequestError:
157+
// We've got a bad request
158+
// Split up the verbose error by lines and return the first one
159+
// openapi errors seem to be multi-line with a decent message on the first
160+
errorLines := strings.Split(e.Error(), "\n")
161+
return &echo.HTTPError{
162+
Code: http.StatusBadRequest,
163+
Message: errorLines[0],
164+
Internal: err,
165+
}
166+
case *openapi3filter.SecurityRequirementsError:
167+
for _, err := range e.Errors {
168+
httpErr, ok := err.(*echo.HTTPError)
169+
if ok {
170+
return httpErr
171+
}
172+
}
173+
return &echo.HTTPError{
174+
Code: http.StatusForbidden,
175+
Message: e.Error(),
176+
Internal: err,
177+
}
178+
default:
179+
// This should never happen today, but if our upstream code changes,
180+
// we don't want to crash the server, so handle the unexpected error.
181+
return &echo.HTTPError{
182+
Code: http.StatusInternalServerError,
183+
Message: fmt.Sprintf("error validating request: %s", err),
184+
Internal: err,
185+
}
186+
}
187+
}
188+
return nil
189+
}
190+
191+
// GetEchoContext gets the echo context from within requests. It returns
192+
// nil if not found or wrong type.
193+
func GetEchoContext(c context.Context) echo.Context {
194+
iface := c.Value(EchoContextKey)
195+
if iface == nil {
196+
return nil
197+
}
198+
eCtx, ok := iface.(echo.Context)
199+
if !ok {
200+
return nil
201+
}
202+
return eCtx
203+
}
204+
205+
func GetUserData(c context.Context) interface{} {
206+
return c.Value(UserDataKey)
207+
}
208+
209+
// attempt to get the skipper from the options whether it is set or not
210+
func getSkipperFromOptions(options *Options) echomiddleware.Skipper {
211+
if options == nil {
212+
return echomiddleware.DefaultSkipper
213+
}
214+
215+
if options.Skipper == nil {
216+
return echomiddleware.DefaultSkipper
217+
}
218+
219+
return options.Skipper
220+
}
221+
222+
// attempt to get the MultiErrorHandler from the options. If it is not set,
223+
// return a default handler
224+
func getMultiErrorHandlerFromOptions(options *Options) MultiErrorHandler {
225+
if options == nil {
226+
return defaultMultiErrorHandler
227+
}
228+
229+
if options.MultiErrorHandler == nil {
230+
return defaultMultiErrorHandler
231+
}
232+
233+
return options.MultiErrorHandler
234+
}
235+
236+
// defaultMultiErrorHandler returns a StatusBadRequest (400) and a list
237+
// of all of the errors. This method is called if there are no other
238+
// methods defined on the options.
239+
func defaultMultiErrorHandler(me openapi3.MultiError) *echo.HTTPError {
240+
return &echo.HTTPError{
241+
Code: http.StatusBadRequest,
242+
Message: me.Error(),
243+
Internal: me,
244+
}
245+
}

0 commit comments

Comments
 (0)