Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add mux middlewares and inject http pattern #4290

Merged
merged 6 commits into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 11 additions & 0 deletions runtime/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ var malformedHTTPHeaders = map[string]struct{}{
type (
rpcMethodKey struct{}
httpPathPatternKey struct{}
httpPatternKey struct{}

AnnotateContextOption func(ctx context.Context) context.Context
)
Expand Down Expand Up @@ -398,3 +399,13 @@ func HTTPPathPattern(ctx context.Context) (string, bool) {
func withHTTPPathPattern(ctx context.Context, httpPathPattern string) context.Context {
return context.WithValue(ctx, httpPathPatternKey{}, httpPathPattern)
}

// HTTPPattern returns the HTTP path pattern struct relating to the HTTP handler, if one exists.
func HTTPPattern(ctx context.Context) (Pattern, bool) {
v, ok := ctx.Value(httpPatternKey{}).(Pattern)
return v, ok
}

func withHTTPPattern(ctx context.Context, httpPattern Pattern) context.Context {
return context.WithValue(ctx, httpPatternKey{}, httpPattern)
}
42 changes: 40 additions & 2 deletions runtime/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,15 @@ var encodedPathSplitter = regexp.MustCompile("(/|%2F)")
// A HandlerFunc handles a specific pair of path pattern and HTTP method.
type HandlerFunc func(w http.ResponseWriter, r *http.Request, pathParams map[string]string)

// A Middleware handler is simply an HandlerFunc that wraps another HandlerFunc to do some pre- and/or post-processing of the request
type Middleware func(HandlerFunc) HandlerFunc
nikitaksv marked this conversation as resolved.
Show resolved Hide resolved

// ServeMux is a request multiplexer for grpc-gateway.
// It matches http requests to patterns and invokes the corresponding handler.
type ServeMux struct {
// handlers maps HTTP method to a list of handlers.
handlers map[string][]handler
middlewares []Middleware
forwardResponseOptions []func(context.Context, http.ResponseWriter, proto.Message) error
marshalers marshalerRegistry
incomingHeaderMatcher HeaderMatcherFunc
Expand All @@ -64,6 +68,7 @@ type ServeMux struct {
routingErrorHandler RoutingErrorHandlerFunc
disablePathLengthFallback bool
unescapingMode UnescapingMode
injectHTTPPattern bool
}

// ServeMuxOption is an option that can be given to a ServeMux on construction.
Expand All @@ -89,6 +94,20 @@ func WithUnescapingMode(mode UnescapingMode) ServeMuxOption {
}
}

// WithMiddlewares sets server middleware for all handlers
nikitaksv marked this conversation as resolved.
Show resolved Hide resolved
func WithMiddlewares(middlewares ...Middleware) ServeMuxOption {
return func(serveMux *ServeMux) {
serveMux.middlewares = append(serveMux.middlewares, middlewares...)
}
}

// WithInjectHTTPPattern sets the current HTTP Pattern in the request context
func WithInjectHTTPPattern() ServeMuxOption {
return func(serveMux *ServeMux) {
serveMux.injectHTTPPattern = true
}
}
nikitaksv marked this conversation as resolved.
Show resolved Hide resolved

// SetQueryParameterParser sets the query parameter parser, used to populate message from query parameters.
// Configuring this will mean the generated OpenAPI output is no longer correct, and it should be
// done with careful consideration.
Expand Down Expand Up @@ -305,6 +324,9 @@ func NewServeMux(opts ...ServeMuxOption) *ServeMux {

// Handle associates "h" to the pair of HTTP method and path pattern.
func (s *ServeMux) Handle(meth string, pat Pattern, h HandlerFunc) {
if len(s.middlewares) > 0 {
h = chainMiddlewares(s.middlewares)(h)
}
s.handlers[meth] = append([]handler{{pat: pat, h: h}}, s.handlers[meth]...)
}

Expand Down Expand Up @@ -405,7 +427,7 @@ func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
continue
}
h.h(w, r, pathParams)
s.handleHandler(h, w, r, pathParams)
return
}

Expand Down Expand Up @@ -458,7 +480,7 @@ func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.errorHandler(ctx, s, outboundMarshaler, w, r, sterr)
return
}
h.h(w, r, pathParams)
s.handleHandler(h, w, r, pathParams)
return
}
_, outboundMarshaler := MarshalerForRequest(s, r)
Expand All @@ -484,3 +506,19 @@ type handler struct {
pat Pattern
h HandlerFunc
}

func (s *ServeMux) handleHandler(h handler, w http.ResponseWriter, r *http.Request, pathParams map[string]string) {
if s.injectHTTPPattern {
r = r.WithContext(withHTTPPattern(r.Context(), h.pat))
}
h.h(w, r, pathParams)
}

func chainMiddlewares(mws []Middleware) Middleware {
return func(next HandlerFunc) HandlerFunc {
for i := len(mws); i > 0; i-- {
next = mws[i-1](next)
}
return next
}
}
60 changes: 60 additions & 0 deletions runtime/mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -867,3 +867,63 @@ func (g *dummyHealthCheckClient) Check(ctx context.Context, r *grpc_health_v1.He
func (g *dummyHealthCheckClient) Watch(ctx context.Context, r *grpc_health_v1.HealthCheckRequest, opts ...grpc.CallOption) (grpc_health_v1.Health_WatchClient, error) {
return nil, status.Error(codes.Unimplemented, "unimplemented")
}

func TestServeMux_HandleMiddlewares(t *testing.T) {
var mws []int
mux := runtime.NewServeMux(runtime.WithMiddlewares(
func(next runtime.HandlerFunc) runtime.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request, pathParams map[string]string) {
mws = append(mws, 1)
next(w, r, pathParams)
}
},
func(next runtime.HandlerFunc) runtime.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request, pathParams map[string]string) {
mws = append(mws, 2)
next(w, r, pathParams)
}
},
))
err := mux.HandlePath("GET", "/test", func(w http.ResponseWriter, r *http.Request, pathParams map[string]string) {
if len(mws) == 0 {
t.Errorf("middlewares not called")
} else if mws[0] != 1 {
t.Errorf("first middleware is not called first")
} else if mws[1] != 2 {
t.Errorf("second middleware is not called the second")
}
})
if err != nil {
t.Errorf("The route test with method GET and path /test invalid, got %v", err)
}

r := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
mux.ServeHTTP(w, r)
if w.Code != 200 {
t.Errorf("request not processed")
}
}

func TestServeMux_InjectPattern(t *testing.T) {
mux := runtime.NewServeMux(runtime.WithInjectHTTPPattern())
err := mux.HandlePath("GET", "/test", func(w http.ResponseWriter, r *http.Request, pathParams map[string]string) {
p, ok := runtime.HTTPPattern(r.Context())
if !ok {
t.Errorf("pattern is not injected")
}
if p.String() != "/test" {
t.Errorf("pattern not /test")
}
})
if err != nil {
t.Errorf("The route test with method GET and path /test invalid, got %v", err)
}

r := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
mux.ServeHTTP(w, r)
if w.Code != 200 {
t.Errorf("request not processed")
}
}
nikitaksv marked this conversation as resolved.
Show resolved Hide resolved