diff --git a/rolling-shutter/keyper/kprapi/kprapi.go b/rolling-shutter/keyper/kprapi/kprapi.go index 40fa0305..09047887 100644 --- a/rolling-shutter/keyper/kprapi/kprapi.go +++ b/rolling-shutter/keyper/kprapi/kprapi.go @@ -150,9 +150,9 @@ func (srv *Server) waitShutdown(ctx context.Context) error { func (srv *Server) setupAPIRouter(swagger *openapi3.T) http.Handler { router := chi.NewRouter() + router.Use(chimiddleware.OapiRequestValidator(swagger)) router.Use(kproapi.ConfigMiddleware(srv.config.GetEnableWriteOperations())) - router.Use(chimiddleware.OapiRequestValidator(swagger)) _ = kproapi.HandlerFromMux(srv, router) return router diff --git a/rolling-shutter/keyper/kproapi/middleware.go b/rolling-shutter/keyper/kproapi/middleware.go index 79f32210..64a29176 100644 --- a/rolling-shutter/keyper/kproapi/middleware.go +++ b/rolling-shutter/keyper/kproapi/middleware.go @@ -3,6 +3,8 @@ package kproapi import ( "encoding/json" "net/http" + "regexp" + "strings" "github.com/getkin/kin-openapi/openapi3" ) @@ -34,8 +36,23 @@ func shouldEnableEndpoint(operation *openapi3.Operation, enableWriteOperations b // findOperation looks up the OpenAPI operation for the given path and method. func findOperation(spec *openapi3.T, path string, method string) *openapi3.Operation { - pathItem := spec.Paths.Find(path) + pathItem := spec.Paths.Find(path) // first try to find the path in the spec if pathItem == nil { + for specPath, pItem := range spec.Paths { // fallback for path containing parameters + rePath := "^" + regexp.QuoteMeta(specPath) + rePath = strings.ReplaceAll(rePath, `\{`, "{") + rePath = strings.ReplaceAll(rePath, `\}`, "}") + rePath = regexp.MustCompile(`\{[^/]+\}`).ReplaceAllString(rePath, `[^/]+`) + rePath += "$" + + if matched, _ := regexp.MatchString(rePath, path); matched { + pathItem = pItem + break + } + } + } + + if pathItem == nil { // if no path is found still, return nil return nil } diff --git a/rolling-shutter/keyper/kproapi/middleware_test.go b/rolling-shutter/keyper/kproapi/middleware_test.go index 3d5f309b..bce0b73c 100644 --- a/rolling-shutter/keyper/kproapi/middleware_test.go +++ b/rolling-shutter/keyper/kproapi/middleware_test.go @@ -182,6 +182,58 @@ func TestFindOperation(t *testing.T) { } } +func TestFindOperation_ParameterizedPaths(t *testing.T) { + spec := &openapi3.T{ + Paths: openapi3.Paths{ + "/test/{id}": &openapi3.PathItem{ + Get: &openapi3.Operation{}, + }, + "/tests/{id}/items/{name}": &openapi3.PathItem{ + Post: &openapi3.Operation{}, + }, + }, + } + + tests := []struct { + name string + path string + method string + want *openapi3.Operation + }{ + { + name: "match single parameterized path", + path: "/test/123", + method: http.MethodGet, + want: spec.Paths.Find("/test/{id}").Get, + }, + { + name: "match nested parameterized path", + path: "/tests/123/items/xyz456", + method: http.MethodPost, + want: spec.Paths.Find("/tests/{id}/items/{name}").Post, + }, + { + name: "no match for wrong structure", + path: "/tests/123/items", // missing /{name} + method: http.MethodPost, + want: nil, + }, + { + name: "non-existent parameterized path", + path: "/unknown/123", + method: http.MethodGet, + want: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := findOperation(spec, tt.path, tt.method) + assert.Equal(t, tt.want, got) + }) + } +} + func TestConfigMiddleware(t *testing.T) { // Create a test spec with both read-only and write operations spec := &openapi3.T{