Skip to content

Commit

Permalink
simplify middleware api (#404)
Browse files Browse the repository at this point in the history
* simplify middleware api

* update web generator
  • Loading branch information
matthewmueller committed Apr 17, 2023
1 parent a7c8911 commit 721420f
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 78 deletions.
1 change: 1 addition & 0 deletions framework/web/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ func (l *loader) Load() (state *State, err error) {
// Add initial imports
l.imports.AddStd("net/http", "context")
l.imports.AddNamed("middleware", "github.com/livebud/bud/package/middleware")
l.imports.AddNamed("methodoverride", "github.com/livebud/bud/package/middleware/methodoverride")
l.imports.AddNamed("webrt", "github.com/livebud/bud/framework/web/webrt")
l.imports.AddNamed("router", "github.com/livebud/bud/package/router")
// Show the welcome page if we don't have any web resources
Expand Down
9 changes: 4 additions & 5 deletions framework/web/web.gotext
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,11 @@ func New(
{{- end }}
{{- end }}
// Compose the middleware together
middleware := middleware.Compose(
middleware.MethodOverride(),
router,
stack := middleware.Compose(
methodoverride.New(),
)
// 404 at the bottom of the middleware
handler := middleware.Middleware(http.NotFoundHandler())
// Add the router to the bottom of the middleware
handler := stack(router)
// Return the web server
return &Server{handler}
}
Expand Down
43 changes: 19 additions & 24 deletions package/middleware/httpbuffer/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,37 +6,32 @@ import (

"github.com/felixge/httpsnoop"
"github.com/livebud/bud/package/log"
"github.com/livebud/bud/package/middleware"
)

func New(log log.Log) *Middleware {
return &Middleware{log}
}

type Middleware struct {
log log.Log
}

func (m *Middleware) Middleware(next http.Handler) http.Handler {
func New(log log.Log) middleware.Middleware {
rw := &responseWriter{
code: 0,
body: new(bytes.Buffer),
}
return http.HandlerFunc(func(original http.ResponseWriter, r *http.Request) {
w := httpsnoop.Wrap(original, httpsnoop.Hooks{
WriteHeader: func(_ httpsnoop.WriteHeaderFunc) httpsnoop.WriteHeaderFunc {
return rw.WriteHeader
},
Write: func(_ httpsnoop.WriteFunc) httpsnoop.WriteFunc {
return rw.Write
},
Flush: func(flush httpsnoop.FlushFunc) httpsnoop.FlushFunc {
rw.writeTo(original)
return flush
},
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(original http.ResponseWriter, r *http.Request) {
w := httpsnoop.Wrap(original, httpsnoop.Hooks{
WriteHeader: func(_ httpsnoop.WriteHeaderFunc) httpsnoop.WriteHeaderFunc {
return rw.WriteHeader
},
Write: func(_ httpsnoop.WriteFunc) httpsnoop.WriteFunc {
return rw.Write
},
Flush: func(flush httpsnoop.FlushFunc) httpsnoop.FlushFunc {
rw.writeTo(original)
return flush
},
})
next.ServeHTTP(w, r)
rw.writeTo(original)
})
next.ServeHTTP(w, r)
rw.writeTo(original)
})
}
}

type responseWriter struct {
Expand Down
16 changes: 8 additions & 8 deletions package/middleware/httpbuffer/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ func TestHeadersWrapped(t *testing.T) {
req, err := http.NewRequest("GET", "/", nil)
is.NoErr(err)
log := testlog.New()
wrap := httpbuffer.New(log)
h := wrap.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
middleware := httpbuffer.New(log)
h := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("X-A", "A")
w.Write([]byte("Hello, world!"))
w.Header().Add("X-B", "B")
Expand Down Expand Up @@ -74,8 +74,8 @@ func TestWriteStatusWrapped(t *testing.T) {
req, err := http.NewRequest("GET", "/", nil)
is.NoErr(err)
log := testlog.New()
wrap := httpbuffer.New(log)
h := wrap.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
middleware := httpbuffer.New(log)
h := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("X-A", "A")
w.WriteHeader(201)
w.Write([]byte("Hello, world!"))
Expand Down Expand Up @@ -123,8 +123,8 @@ func TestFlushWrapped(t *testing.T) {
req, err := http.NewRequest("GET", "/", nil)
is.NoErr(err)
log := testlog.New()
wrap := httpbuffer.New(log)
h := wrap.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
middleware := httpbuffer.New(log)
h := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Hello, world!"))
w.Header().Add("X-A", "A")
flush, ok := w.(http.Flusher)
Expand All @@ -151,8 +151,8 @@ func TestFlushStatusWrapped(t *testing.T) {
req, err := http.NewRequest("GET", "/", nil)
is.NoErr(err)
log := testlog.New()
wrap := httpbuffer.New(log)
h := wrap.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
middleware := httpbuffer.New(log)
h := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Hello, world!"))
w.WriteHeader(201)
w.Header().Add("X-A", "A")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package middleware
package methodoverride

import (
"net/http"
"strings"

"github.com/livebud/bud/package/middleware"
)

// Methods eligible for overriding
Expand All @@ -14,11 +16,11 @@ var eligible = map[string]struct{}{

const formType = "application/x-www-form-urlencoded"

// MethodOverride allows HTML <form method="post">'s to dispatch PATCH, PUT and
// New allows HTML <form method="post">'s to dispatch PATCH, PUT and
// DELETE requests by overriding the request method using a hidden "_method"
// field in the form body.
func MethodOverride() Middleware {
return Function(func(next http.Handler) http.Handler {
func New() middleware.Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Only override POST requests
if r.Method != http.MethodPost {
Expand Down Expand Up @@ -47,5 +49,5 @@ func MethodOverride() Middleware {
r.Method = override
next.ServeHTTP(w, r)
})
})
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package middleware_test
package methodoverride_test

import (
"bytes"
Expand All @@ -8,7 +8,7 @@ import (
"testing"

"github.com/livebud/bud/internal/is"
"github.com/livebud/bud/package/middleware"
"github.com/livebud/bud/package/middleware/methodoverride"
"github.com/livebud/bud/package/router"
)

Expand All @@ -27,7 +27,8 @@ func TestNoMethod404(t *testing.T) {
w := httptest.NewRecorder()
router := router.New()
router.Patch("/", ok())
middleware.MethodOverride().Middleware(router).ServeHTTP(w, req)
middleware := methodoverride.New()
middleware(router).ServeHTTP(w, req)
res := w.Result()
is.Equal(res.StatusCode, 404)
}
Expand All @@ -42,7 +43,8 @@ func TestPatch200(t *testing.T) {
w := httptest.NewRecorder()
router := router.New()
router.Patch("/", ok())
middleware.MethodOverride().Middleware(router).ServeHTTP(w, req)
middleware := methodoverride.New()
middleware(router).ServeHTTP(w, req)
res := w.Result()
is.Equal(res.StatusCode, 200)
}
Expand All @@ -55,7 +57,8 @@ func TestPatchNoBody404(t *testing.T) {
w := httptest.NewRecorder()
router := router.New()
router.Patch("/", ok())
middleware.MethodOverride().Middleware(router).ServeHTTP(w, req)
middleware := methodoverride.New()
middleware(router).ServeHTTP(w, req)
res := w.Result()
is.Equal(res.StatusCode, 404)
}
Expand All @@ -69,7 +72,8 @@ func TestPatchNoType404(t *testing.T) {
w := httptest.NewRecorder()
router := router.New()
router.Patch("/", ok())
middleware.MethodOverride().Middleware(router).ServeHTTP(w, req)
middleware := methodoverride.New()
middleware(router).ServeHTTP(w, req)
res := w.Result()
is.Equal(res.StatusCode, 404)
}
Expand All @@ -84,7 +88,8 @@ func TestPatchInsensitive200(t *testing.T) {
w := httptest.NewRecorder()
router := router.New()
router.Patch("/", ok())
middleware.MethodOverride().Middleware(router).ServeHTTP(w, req)
middleware := methodoverride.New()
middleware(router).ServeHTTP(w, req)
res := w.Result()
is.Equal(res.StatusCode, 200)
}
Expand All @@ -99,7 +104,8 @@ func TestDelete200(t *testing.T) {
w := httptest.NewRecorder()
router := router.New()
router.Delete("/", ok())
middleware.MethodOverride().Middleware(router).ServeHTTP(w, req)
middleware := methodoverride.New()
middleware(router).ServeHTTP(w, req)
res := w.Result()
is.Equal(res.StatusCode, 200)
}
Expand All @@ -114,7 +120,8 @@ func TestPut200(t *testing.T) {
w := httptest.NewRecorder()
router := router.New()
router.Put("/", ok())
middleware.MethodOverride().Middleware(router).ServeHTTP(w, req)
middleware := methodoverride.New()
middleware(router).ServeHTTP(w, req)
res := w.Result()
is.Equal(res.StatusCode, 200)
}
Expand All @@ -129,7 +136,8 @@ func TestGet404(t *testing.T) {
w := httptest.NewRecorder()
router := router.New()
router.Get("/", ok())
middleware.MethodOverride().Middleware(router).ServeHTTP(w, req)
middleware := methodoverride.New()
middleware(router).ServeHTTP(w, req)
res := w.Result()
is.Equal(res.StatusCode, 404)
}
70 changes: 44 additions & 26 deletions package/middleware/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,38 +4,56 @@ import (
"net/http"
)

// Interface for implementing middleware
type Middleware interface {
Middleware(next http.Handler) http.Handler
}

// Function for creating middleware
type Function func(next http.Handler) http.Handler

func (fn Function) Middleware(next http.Handler) http.Handler {
return fn(next)
}

// Stack of middleware
type Stack []Middleware

// Middleware fn
func (stack Stack) Middleware(next http.Handler) http.Handler {
return Compose(stack...).Middleware(next)
}
type Middleware func(http.Handler) http.Handler

// Compose a stack of middleware into a single middleware
func Compose(stack ...Middleware) Middleware {
return Function(func(h http.Handler) http.Handler {
if len(stack) == 0 {
func Compose(middlewares ...Middleware) Middleware {
return func(h http.Handler) http.Handler {
if len(middlewares) == 0 {
return h
}
for i := len(stack) - 1; i >= 0; i-- {
if stack[i] == nil {
for i := len(middlewares) - 1; i >= 0; i-- {
if middlewares[i] == nil {
continue
}
h = stack[i].Middleware(h)
h = middlewares[i](h)
}
return h
})
}
}

// // Interface for implementing middleware
// type Middleware interface {
// Middleware(next http.Handler) http.Handler
// }

// // Function for creating middleware
// type Function func(next http.Handler) http.Handler

// func (fn Function) Middleware(next http.Handler) http.Handler {
// return fn(next)
// }

// // Stack of middleware
// type Stack []Middleware

// // Middleware fn
// func (stack Stack) Middleware(next http.Handler) http.Handler {
// return Compose(stack...).Middleware(next)
// }

// // Compose a stack of middleware into a single middleware
// func Compose(stack ...Middleware) Middleware {
// return Function(func(h http.Handler) http.Handler {
// if len(stack) == 0 {
// return h
// }
// for i := len(stack) - 1; i >= 0; i-- {
// if stack[i] == nil {
// continue
// }
// h = stack[i].Middleware(h)
// }
// return h
// })
// }

0 comments on commit 721420f

Please sign in to comment.