Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,6 @@ permissions:
contents: read
jobs:
build:
uses: oapi-codegen/actions/.github/workflows/ci.yml@b9f2c274c1c631e648931dbbcc1942c2b2027837 # v0.4.0
uses: oapi-codegen/actions/.github/workflows/ci.yml@6cf35d4f044f2663dae54547ff6d426e565beb48 # v0.6.0
with:
lint_versions: '["1.25"]'
286 changes: 286 additions & 0 deletions internal/test/nethttp/oapi_validate_prefix_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,286 @@
package gorilla

import (
"context"
_ "embed"
"net/http"
"testing"

middleware "github.com/oapi-codegen/nethttp-middleware"

"github.com/getkin/kin-openapi/openapi3"
"github.com/getkin/kin-openapi/openapi3filter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// prefixTestSpec defines a minimal spec with /resource (GET+POST) for prefix testing
const prefixTestSpec = `
openapi: "3.0.0"
info:
version: 1.0.0
title: TestServer
paths:
/resource:
get:
operationId: getResource
parameters:
- name: id
in: query
schema:
type: integer
minimum: 10
maximum: 100
responses:
'200':
description: success
post:
operationId: createResource
responses:
'204':
description: No content
requestBody:
required: true
content:
application/json:
schema:
properties:
name:
type: string
additionalProperties: false
`

func loadPrefixSpec(t *testing.T) *openapi3.T {
t.Helper()
spec, err := openapi3.NewLoader().LoadFromData([]byte(prefixTestSpec))
require.NoError(t, err)
spec.Servers = nil
return spec
}

// setupPrefixHandler creates a mux with a handler at the given handlerPath
// that records whether it was called and what path it saw.
func setupPrefixHandler(t *testing.T, handlerPath string) (*http.ServeMux, *bool, *string) {
t.Helper()
called := new(bool)
observedPath := new(string)

mux := http.NewServeMux()
mux.HandleFunc(handlerPath, func(w http.ResponseWriter, r *http.Request) {
*called = true
*observedPath = r.URL.Path
w.WriteHeader(http.StatusNoContent)
})
return mux, called, observedPath
}

func TestPrefix_ErrorHandler_ValidRequest(t *testing.T) {
spec := loadPrefixSpec(t)
mux, called, observedPath := setupPrefixHandler(t, "/api/v1/resource")

mw := middleware.OapiRequestValidatorWithOptions(spec, &middleware.Options{
Prefix: "/api/v1",
})
server := mw(mux)

body := struct {
Name string `json:"name"`
}{Name: "test"}

rec := doPost(t, server, "http://example.com/api/v1/resource", body)
assert.Equal(t, http.StatusNoContent, rec.Code)
assert.True(t, *called, "handler should have been called")
assert.Equal(t, "/api/v1/resource", *observedPath, "handler should see the original path, not the stripped one")
}

func TestPrefix_ErrorHandler_InvalidRequest(t *testing.T) {
spec := loadPrefixSpec(t)
mux, called, _ := setupPrefixHandler(t, "/api/v1/resource")

mw := middleware.OapiRequestValidatorWithOptions(spec, &middleware.Options{
Prefix: "/api/v1",
})
server := mw(mux)

// Send a request with out-of-spec query param (id=500, max is 100)
rec := doGet(t, server, "http://example.com/api/v1/resource?id=500")
assert.Equal(t, http.StatusBadRequest, rec.Code)
assert.False(t, *called, "handler should not have been called for invalid request")
}

func TestPrefix_ErrorHandlerWithOpts_ValidRequest(t *testing.T) {
spec := loadPrefixSpec(t)
mux, called, observedPath := setupPrefixHandler(t, "/api/v1/resource")

var errHandlerCalled bool
mw := middleware.OapiRequestValidatorWithOptions(spec, &middleware.Options{
Prefix: "/api/v1",
ErrorHandlerWithOpts: func(ctx context.Context, err error, w http.ResponseWriter, r *http.Request, opts middleware.ErrorHandlerOpts) {
errHandlerCalled = true
http.Error(w, err.Error(), opts.StatusCode)
},
})
server := mw(mux)

body := struct {
Name string `json:"name"`
}{Name: "test"}

rec := doPost(t, server, "http://example.com/api/v1/resource", body)
assert.Equal(t, http.StatusNoContent, rec.Code)
assert.True(t, *called, "handler should have been called")
assert.False(t, errHandlerCalled, "error handler should not have been called")
assert.Equal(t, "/api/v1/resource", *observedPath, "handler should see the original path, not the stripped one")
}

func TestPrefix_ErrorHandlerWithOpts_InvalidRequest(t *testing.T) {
spec := loadPrefixSpec(t)
mux, called, _ := setupPrefixHandler(t, "/api/v1/resource")

var errHandlerCalled bool
mw := middleware.OapiRequestValidatorWithOptions(spec, &middleware.Options{
Prefix: "/api/v1",
ErrorHandlerWithOpts: func(ctx context.Context, err error, w http.ResponseWriter, r *http.Request, opts middleware.ErrorHandlerOpts) {
errHandlerCalled = true
http.Error(w, err.Error(), opts.StatusCode)
},
})
server := mw(mux)

rec := doGet(t, server, "http://example.com/api/v1/resource?id=500")
assert.Equal(t, http.StatusBadRequest, rec.Code)
assert.False(t, *called, "handler should not have been called")
assert.True(t, errHandlerCalled, "error handler should have been called")
}

func TestPrefix_RequestWithoutPrefix_NotMatched(t *testing.T) {
spec := loadPrefixSpec(t)
mux, called, _ := setupPrefixHandler(t, "/resource")

mw := middleware.OapiRequestValidatorWithOptions(spec, &middleware.Options{
Prefix: "/api/v1",
})
server := mw(mux)

// A request to /resource (without the prefix) should not match the
// prefix and should be treated as if no prefix stripping happened.
// Since /resource IS in the spec, this should still validate.
rec := doGet(t, server, "http://example.com/resource")
assert.Equal(t, http.StatusNoContent, rec.Code)
assert.True(t, *called, "handler should have been called for path that doesn't have the prefix")
}

func TestPrefix_PartialSegmentMatch_NotStripped(t *testing.T) {
spec := loadPrefixSpec(t)

// Register handler at the path that would result from incorrect partial stripping
mux := http.NewServeMux()

var resourceV2Called bool
mux.HandleFunc("/api-v2/resource", func(w http.ResponseWriter, r *http.Request) {
resourceV2Called = true
w.WriteHeader(http.StatusOK)
})

mw := middleware.OapiRequestValidatorWithOptions(spec, &middleware.Options{
Prefix: "/api",
})
server := mw(mux)

// /api-v2/resource should NOT have "/api" stripped to become "-v2/resource"
// The prefix must match on a path segment boundary.
rec := doGet(t, server, "http://example.com/api-v2/resource")
// The prefix doesn't match on a segment boundary, so no stripping happens.
// /api-v2/resource is not in the spec → 404.
assert.Equal(t, http.StatusNotFound, rec.Code)
assert.False(t, resourceV2Called, "handler should not have been called")
}

func TestPrefix_ExactPrefixOnly_NoTrailingSlash(t *testing.T) {
spec := loadPrefixSpec(t)
mux, called, _ := setupPrefixHandler(t, "/api/resource")

mw := middleware.OapiRequestValidatorWithOptions(spec, &middleware.Options{
Prefix: "/api",
})
server := mw(mux)

// /api/resource → strip /api → /resource (which is in the spec)
body := struct {
Name string `json:"name"`
}{Name: "test"}

rec := doPost(t, server, "http://example.com/api/resource", body)
assert.Equal(t, http.StatusNoContent, rec.Code)
assert.True(t, *called, "handler should have been called")
}

func TestPrefix_ErrorHandlerWithOpts_HandlerSeesOriginalPath(t *testing.T) {
spec := loadPrefixSpec(t)
mux, _, observedPath := setupPrefixHandler(t, "/prefix/resource")

mw := middleware.OapiRequestValidatorWithOptions(spec, &middleware.Options{
Prefix: "/prefix",
ErrorHandlerWithOpts: func(ctx context.Context, err error, w http.ResponseWriter, r *http.Request, opts middleware.ErrorHandlerOpts) {
http.Error(w, err.Error(), opts.StatusCode)
},
})
server := mw(mux)

rec := doGet(t, server, "http://example.com/prefix/resource")
assert.Equal(t, http.StatusNoContent, rec.Code)
assert.Equal(t, "/prefix/resource", *observedPath, "downstream handler must see the original un-stripped path")
}

func TestPrefix_WithAuthenticationFunc(t *testing.T) {
spec := loadPrefixSpec(t)

// Add a protected endpoint to the spec for this test
protectedSpec := `
openapi: "3.0.0"
info:
version: 1.0.0
title: TestServer
paths:
/resource:
get:
operationId: getResource
security:
- BearerAuth:
- someScope
responses:
'200':
description: success
components:
securitySchemes:
BearerAuth:
type: http
scheme: bearer
bearerFormat: JWT
`
_ = spec // unused, use protectedSpec instead
pSpec, err := openapi3.NewLoader().LoadFromData([]byte(protectedSpec))
require.NoError(t, err)
pSpec.Servers = nil

mux := http.NewServeMux()
var called bool
mux.HandleFunc("/api/resource", func(w http.ResponseWriter, r *http.Request) {
called = true
w.WriteHeader(http.StatusOK)
})

mw := middleware.OapiRequestValidatorWithOptions(pSpec, &middleware.Options{
Prefix: "/api",
Options: openapi3filter.Options{
AuthenticationFunc: func(ctx context.Context, input *openapi3filter.AuthenticationInput) error {
return nil // always allow
},
},
})
server := mw(mux)

rec := doGet(t, server, "http://example.com/api/resource")
assert.Equal(t, http.StatusOK, rec.Code)
assert.True(t, called, "handler should have been called when auth passes")
}
55 changes: 51 additions & 4 deletions oapi_validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ type Options struct {
SilenceServersWarning bool
// DoNotValidateServers ensures that there is no Host validation performed (see `SilenceServersWarning` and https://github.com/deepmap/oapi-codegen/issues/882 for more details)
DoNotValidateServers bool
// Prefix allows (optionally) trimming a prefix from the API path.
// This may be useful if your API is routed to an internal path that is different from the OpenAPI specification.
Prefix string
}

// OapiRequestValidator Creates the middleware to validate that incoming requests match the given OpenAPI 3.x spec, with a default set of configuration.
Expand Down Expand Up @@ -153,10 +156,53 @@ func performRequestValidationForErrorHandler(next http.Handler, w http.ResponseW
errorHandler(w, err.Error(), statusCode)
}

func makeRequestForValidation(r *http.Request, options *Options) *http.Request {
if options == nil || options.Prefix == "" {
return r
}

// Only strip the prefix when it matches on a path segment boundary:
// the path must equal the prefix exactly, or the character immediately
// after the prefix must be '/'.
if !hasPathPrefix(r.URL.Path, options.Prefix) {
return r
}

r = r.Clone(r.Context())

r.RequestURI = stripPrefix(r.RequestURI, options.Prefix)
r.URL.Path = stripPrefix(r.URL.Path, options.Prefix)
if r.URL.RawPath != "" {
r.URL.RawPath = stripPrefix(r.URL.RawPath, options.Prefix)
}

return r
}

// hasPathPrefix reports whether path starts with prefix on a segment boundary.
func hasPathPrefix(path, prefix string) bool {
if !strings.HasPrefix(path, prefix) {
return false
}
// The prefix matches if the path equals the prefix exactly, or
// the next character is a '/'.
return len(path) == len(prefix) || path[len(prefix)] == '/'
}

// stripPrefix removes prefix from s and returns the result.
// If s does not start with prefix it is returned unchanged.
func stripPrefix(s, prefix string) string {
return strings.TrimPrefix(s, prefix)
}

// Note that this is an inline-and-modified version of `validateRequest`, with a simplified control flow and providing full access to the `error` for the `ErrorHandlerWithOpts` function.
func performRequestValidationForErrorHandlerWithOpts(next http.Handler, w http.ResponseWriter, r *http.Request, router routers.Router, options *Options) {
// Build a (possibly prefix-stripped) request for validation, but keep
// the original so the downstream handler sees the un-modified path.
validationReq := makeRequestForValidation(r, options)

// Find route
route, pathParams, err := router.FindRoute(r)
route, pathParams, err := router.FindRoute(validationReq)
if err != nil {
errOpts := ErrorHandlerOpts{
// MatchedRoute will be nil, as we've not matched a route we know about
Expand All @@ -177,7 +223,7 @@ func performRequestValidationForErrorHandlerWithOpts(next http.Handler, w http.R

// Validate request
requestValidationInput := &openapi3filter.RequestValidationInput{
Request: r,
Request: validationReq,
PathParams: pathParams,
Route: route,
}
Expand All @@ -186,9 +232,9 @@ func performRequestValidationForErrorHandlerWithOpts(next http.Handler, w http.R
requestValidationInput.Options = &options.Options
}

err = openapi3filter.ValidateRequest(r.Context(), requestValidationInput)
err = openapi3filter.ValidateRequest(validationReq.Context(), requestValidationInput)
if err == nil {
// it's a valid request, so serve it
// it's a valid request, so serve it with the original request
next.ServeHTTP(w, r)
return
}
Expand Down Expand Up @@ -220,6 +266,7 @@ func performRequestValidationForErrorHandlerWithOpts(next http.Handler, w http.R
// validateRequest is called from the middleware above and actually does the work
// of validating a request.
func validateRequest(r *http.Request, router routers.Router, options *Options) (int, error) {
r = makeRequestForValidation(r, options)

// Find route
route, pathParams, err := router.FindRoute(r)
Expand Down
Loading
Loading