Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add mux middlewares and inject http pattern #4290

Merged
merged 6 commits into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion protoc-gen-grpc-gateway/internal/gengateway/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,7 @@ func local_request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ct
// UnaryRPC :call {{$svc.GetName}}Server directly.
// StreamingRPC :currently unsupported pending https://github.com/grpc/grpc-go/issues/906.
// Note that using this registration option will cause many gRPC library features to stop working. Consider using Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}FromEndpoint instead.
// GRPC interceptors will not work for this type of registration. To use interceptors, you must use the "runtime.WithMiddlewares" option in the "runtime.NewServeMux" call.
func Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}Server(ctx context.Context, mux *runtime.ServeMux, server {{$svc.InstanceName}}Server) error {
{{range $m := $svc.Methods}}
{{range $b := $m.Bindings}}
Expand Down Expand Up @@ -703,7 +704,7 @@ func Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}(ctx context.Context, mux *
// to "mux". The handlers forward requests to the grpc endpoint over the given implementation of "{{$svc.InstanceName}}Client".
// Note: the gRPC framework executes interceptors within the gRPC handler. If the passed in "{{$svc.InstanceName}}Client"
// doesn't go through the normal gRPC flow (creating a gRPC client etc.) then it will be up to the passed in
// "{{$svc.InstanceName}}Client" to call the correct interceptors.
// "{{$svc.InstanceName}}Client" to call the correct interceptors. This client ignores the HTTP middlewares.
func Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}Client(ctx context.Context, mux *runtime.ServeMux, client {{$svc.InstanceName}}Client) error {
{{range $m := $svc.Methods}}
{{range $b := $m.Bindings}}
Expand Down
11 changes: 11 additions & 0 deletions runtime/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ var malformedHTTPHeaders = map[string]struct{}{
type (
rpcMethodKey struct{}
httpPathPatternKey struct{}
httpPatternKey struct{}

AnnotateContextOption func(ctx context.Context) context.Context
)
Expand Down Expand Up @@ -398,3 +399,13 @@ func HTTPPathPattern(ctx context.Context) (string, bool) {
func withHTTPPathPattern(ctx context.Context, httpPathPattern string) context.Context {
return context.WithValue(ctx, httpPathPatternKey{}, httpPathPattern)
}

// HTTPPattern returns the HTTP path pattern struct relating to the HTTP handler, if one exists.
func HTTPPattern(ctx context.Context) (Pattern, bool) {
v, ok := ctx.Value(httpPatternKey{}).(Pattern)
return v, ok
}

func withHTTPPattern(ctx context.Context, httpPattern Pattern) context.Context {
return context.WithValue(ctx, httpPatternKey{}, httpPattern)
}
35 changes: 33 additions & 2 deletions runtime/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,17 @@ var encodedPathSplitter = regexp.MustCompile("(/|%2F)")
// A HandlerFunc handles a specific pair of path pattern and HTTP method.
type HandlerFunc func(w http.ResponseWriter, r *http.Request, pathParams map[string]string)

// A Middleware handler wraps another HandlerFunc to do some pre- and/or post-processing of the request. This is used as an alternative to gRPC interceptors when using the direct-to-implementation
// registration methods. It is generally recommended to use gRPC client or server interceptors instead
// where possible.
type Middleware func(HandlerFunc) HandlerFunc

// ServeMux is a request multiplexer for grpc-gateway.
// It matches http requests to patterns and invokes the corresponding handler.
type ServeMux struct {
// handlers maps HTTP method to a list of handlers.
handlers map[string][]handler
middlewares []Middleware
forwardResponseOptions []func(context.Context, http.ResponseWriter, proto.Message) error
marshalers marshalerRegistry
incomingHeaderMatcher HeaderMatcherFunc
Expand Down Expand Up @@ -89,6 +95,15 @@ func WithUnescapingMode(mode UnescapingMode) ServeMuxOption {
}
}

// WithMiddlewares sets server middleware for all handlers. This is useful as an alternative to gRPC
// interceptors when using the direct-to-implementation registration methods and cannot rely
// on gRPC interceptors. It's recommended to use gRPC interceptors instead if possible.
func WithMiddlewares(middlewares ...Middleware) ServeMuxOption {
return func(serveMux *ServeMux) {
serveMux.middlewares = append(serveMux.middlewares, middlewares...)
}
}

// SetQueryParameterParser sets the query parameter parser, used to populate message from query parameters.
// Configuring this will mean the generated OpenAPI output is no longer correct, and it should be
// done with careful consideration.
Expand Down Expand Up @@ -305,6 +320,9 @@ func NewServeMux(opts ...ServeMuxOption) *ServeMux {

// Handle associates "h" to the pair of HTTP method and path pattern.
func (s *ServeMux) Handle(meth string, pat Pattern, h HandlerFunc) {
if len(s.middlewares) > 0 {
h = chainMiddlewares(s.middlewares)(h)
}
s.handlers[meth] = append([]handler{{pat: pat, h: h}}, s.handlers[meth]...)
}

Expand Down Expand Up @@ -405,7 +423,7 @@ func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
continue
}
h.h(w, r, pathParams)
s.handleHandler(h, w, r, pathParams)
return
}

Expand Down Expand Up @@ -458,7 +476,7 @@ func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.errorHandler(ctx, s, outboundMarshaler, w, r, sterr)
return
}
h.h(w, r, pathParams)
s.handleHandler(h, w, r, pathParams)
return
}
_, outboundMarshaler := MarshalerForRequest(s, r)
Expand All @@ -484,3 +502,16 @@ type handler struct {
pat Pattern
h HandlerFunc
}

func (s *ServeMux) handleHandler(h handler, w http.ResponseWriter, r *http.Request, pathParams map[string]string) {
h.h(w, r.WithContext(withHTTPPattern(r.Context(), h.pat)), pathParams)
}

func chainMiddlewares(mws []Middleware) Middleware {
return func(next HandlerFunc) HandlerFunc {
for i := len(mws); i > 0; i-- {
next = mws[i-1](next)
}
return next
}
}
60 changes: 60 additions & 0 deletions runtime/mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -867,3 +867,63 @@ func (g *dummyHealthCheckClient) Check(ctx context.Context, r *grpc_health_v1.He
func (g *dummyHealthCheckClient) Watch(ctx context.Context, r *grpc_health_v1.HealthCheckRequest, opts ...grpc.CallOption) (grpc_health_v1.Health_WatchClient, error) {
return nil, status.Error(codes.Unimplemented, "unimplemented")
}

func TestServeMux_HandleMiddlewares(t *testing.T) {
var mws []int
mux := runtime.NewServeMux(runtime.WithMiddlewares(
func(next runtime.HandlerFunc) runtime.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request, pathParams map[string]string) {
mws = append(mws, 1)
next(w, r, pathParams)
}
},
func(next runtime.HandlerFunc) runtime.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request, pathParams map[string]string) {
mws = append(mws, 2)
next(w, r, pathParams)
}
},
))
err := mux.HandlePath("GET", "/test", func(w http.ResponseWriter, r *http.Request, pathParams map[string]string) {
if len(mws) == 0 {
t.Errorf("middlewares not called")
} else if mws[0] != 1 {
t.Errorf("first middleware is not called first")
} else if mws[1] != 2 {
t.Errorf("second middleware is not called the second")
}
})
if err != nil {
t.Errorf("The route test with method GET and path /test invalid, got %v", err)
}

r := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
mux.ServeHTTP(w, r)
if w.Code != 200 {
t.Errorf("request not processed")
}
}

func TestServeMux_InjectPattern(t *testing.T) {
mux := runtime.NewServeMux()
err := mux.HandlePath("GET", "/test", func(w http.ResponseWriter, r *http.Request, pathParams map[string]string) {
p, ok := runtime.HTTPPattern(r.Context())
if !ok {
t.Errorf("pattern is not injected")
}
if p.String() != "/test" {
t.Errorf("pattern not /test")
}
})
if err != nil {
t.Errorf("The route test with method GET and path /test invalid, got %v", err)
}

r := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
mux.ServeHTTP(w, r)
if w.Code != 200 {
t.Errorf("request not processed")
}
}
nikitaksv marked this conversation as resolved.
Show resolved Hide resolved