diff --git a/framework/web/loader.go b/framework/web/loader.go index 2f4e4038..421bca95 100644 --- a/framework/web/loader.go +++ b/framework/web/loader.go @@ -50,6 +50,7 @@ func (l *loader) Load() (state *State, err error) { // Add initial imports l.imports.AddStd("net/http", "context") l.imports.AddNamed("middleware", "github.com/livebud/bud/package/middleware") + l.imports.AddNamed("methodoverride", "github.com/livebud/bud/package/middleware/methodoverride") l.imports.AddNamed("webrt", "github.com/livebud/bud/framework/web/webrt") l.imports.AddNamed("router", "github.com/livebud/bud/package/router") // Show the welcome page if we don't have any web resources diff --git a/framework/web/web.gotext b/framework/web/web.gotext index 58faa168..a6fbd4f7 100644 --- a/framework/web/web.gotext +++ b/framework/web/web.gotext @@ -25,12 +25,11 @@ func New( {{- end }} {{- end }} // Compose the middleware together - middleware := middleware.Compose( - middleware.MethodOverride(), - router, + stack := middleware.Compose( + methodoverride.New(), ) - // 404 at the bottom of the middleware - handler := middleware.Middleware(http.NotFoundHandler()) + // Add the router to the bottom of the middleware + handler := stack(router) // Return the web server return &Server{handler} } diff --git a/package/middleware/httpbuffer/middleware.go b/package/middleware/httpbuffer/middleware.go index b6c8c2a6..92c623d4 100644 --- a/package/middleware/httpbuffer/middleware.go +++ b/package/middleware/httpbuffer/middleware.go @@ -6,37 +6,32 @@ import ( "github.com/felixge/httpsnoop" "github.com/livebud/bud/package/log" + "github.com/livebud/bud/package/middleware" ) -func New(log log.Log) *Middleware { - return &Middleware{log} -} - -type Middleware struct { - log log.Log -} - -func (m *Middleware) Middleware(next http.Handler) http.Handler { +func New(log log.Log) middleware.Middleware { rw := &responseWriter{ code: 0, body: new(bytes.Buffer), } - return http.HandlerFunc(func(original http.ResponseWriter, r *http.Request) { - w := httpsnoop.Wrap(original, httpsnoop.Hooks{ - WriteHeader: func(_ httpsnoop.WriteHeaderFunc) httpsnoop.WriteHeaderFunc { - return rw.WriteHeader - }, - Write: func(_ httpsnoop.WriteFunc) httpsnoop.WriteFunc { - return rw.Write - }, - Flush: func(flush httpsnoop.FlushFunc) httpsnoop.FlushFunc { - rw.writeTo(original) - return flush - }, + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(original http.ResponseWriter, r *http.Request) { + w := httpsnoop.Wrap(original, httpsnoop.Hooks{ + WriteHeader: func(_ httpsnoop.WriteHeaderFunc) httpsnoop.WriteHeaderFunc { + return rw.WriteHeader + }, + Write: func(_ httpsnoop.WriteFunc) httpsnoop.WriteFunc { + return rw.Write + }, + Flush: func(flush httpsnoop.FlushFunc) httpsnoop.FlushFunc { + rw.writeTo(original) + return flush + }, + }) + next.ServeHTTP(w, r) + rw.writeTo(original) }) - next.ServeHTTP(w, r) - rw.writeTo(original) - }) + } } type responseWriter struct { diff --git a/package/middleware/httpbuffer/middleware_test.go b/package/middleware/httpbuffer/middleware_test.go index ce731430..cafbaa0b 100644 --- a/package/middleware/httpbuffer/middleware_test.go +++ b/package/middleware/httpbuffer/middleware_test.go @@ -34,8 +34,8 @@ func TestHeadersWrapped(t *testing.T) { req, err := http.NewRequest("GET", "/", nil) is.NoErr(err) log := testlog.New() - wrap := httpbuffer.New(log) - h := wrap.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + middleware := httpbuffer.New(log) + h := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Add("X-A", "A") w.Write([]byte("Hello, world!")) w.Header().Add("X-B", "B") @@ -74,8 +74,8 @@ func TestWriteStatusWrapped(t *testing.T) { req, err := http.NewRequest("GET", "/", nil) is.NoErr(err) log := testlog.New() - wrap := httpbuffer.New(log) - h := wrap.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + middleware := httpbuffer.New(log) + h := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Add("X-A", "A") w.WriteHeader(201) w.Write([]byte("Hello, world!")) @@ -123,8 +123,8 @@ func TestFlushWrapped(t *testing.T) { req, err := http.NewRequest("GET", "/", nil) is.NoErr(err) log := testlog.New() - wrap := httpbuffer.New(log) - h := wrap.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + middleware := httpbuffer.New(log) + h := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("Hello, world!")) w.Header().Add("X-A", "A") flush, ok := w.(http.Flusher) @@ -151,8 +151,8 @@ func TestFlushStatusWrapped(t *testing.T) { req, err := http.NewRequest("GET", "/", nil) is.NoErr(err) log := testlog.New() - wrap := httpbuffer.New(log) - h := wrap.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + middleware := httpbuffer.New(log) + h := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("Hello, world!")) w.WriteHeader(201) w.Header().Add("X-A", "A") diff --git a/package/middleware/methodoverride.go b/package/middleware/methodoverride/methodoverride.go similarity index 84% rename from package/middleware/methodoverride.go rename to package/middleware/methodoverride/methodoverride.go index e22e361a..bfe885e0 100644 --- a/package/middleware/methodoverride.go +++ b/package/middleware/methodoverride/methodoverride.go @@ -1,8 +1,10 @@ -package middleware +package methodoverride import ( "net/http" "strings" + + "github.com/livebud/bud/package/middleware" ) // Methods eligible for overriding @@ -14,11 +16,11 @@ var eligible = map[string]struct{}{ const formType = "application/x-www-form-urlencoded" -// MethodOverride allows HTML
's to dispatch PATCH, PUT and +// New allows HTML 's to dispatch PATCH, PUT and // DELETE requests by overriding the request method using a hidden "_method" // field in the form body. -func MethodOverride() Middleware { - return Function(func(next http.Handler) http.Handler { +func New() middleware.Middleware { + return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Only override POST requests if r.Method != http.MethodPost { @@ -47,5 +49,5 @@ func MethodOverride() Middleware { r.Method = override next.ServeHTTP(w, r) }) - }) + } } diff --git a/package/middleware/methodoverride_test.go b/package/middleware/methodoverride/methodoverride_test.go similarity index 83% rename from package/middleware/methodoverride_test.go rename to package/middleware/methodoverride/methodoverride_test.go index a860e27c..03222ac4 100644 --- a/package/middleware/methodoverride_test.go +++ b/package/middleware/methodoverride/methodoverride_test.go @@ -1,4 +1,4 @@ -package middleware_test +package methodoverride_test import ( "bytes" @@ -8,7 +8,7 @@ import ( "testing" "github.com/livebud/bud/internal/is" - "github.com/livebud/bud/package/middleware" + "github.com/livebud/bud/package/middleware/methodoverride" "github.com/livebud/bud/package/router" ) @@ -27,7 +27,8 @@ func TestNoMethod404(t *testing.T) { w := httptest.NewRecorder() router := router.New() router.Patch("/", ok()) - middleware.MethodOverride().Middleware(router).ServeHTTP(w, req) + middleware := methodoverride.New() + middleware(router).ServeHTTP(w, req) res := w.Result() is.Equal(res.StatusCode, 404) } @@ -42,7 +43,8 @@ func TestPatch200(t *testing.T) { w := httptest.NewRecorder() router := router.New() router.Patch("/", ok()) - middleware.MethodOverride().Middleware(router).ServeHTTP(w, req) + middleware := methodoverride.New() + middleware(router).ServeHTTP(w, req) res := w.Result() is.Equal(res.StatusCode, 200) } @@ -55,7 +57,8 @@ func TestPatchNoBody404(t *testing.T) { w := httptest.NewRecorder() router := router.New() router.Patch("/", ok()) - middleware.MethodOverride().Middleware(router).ServeHTTP(w, req) + middleware := methodoverride.New() + middleware(router).ServeHTTP(w, req) res := w.Result() is.Equal(res.StatusCode, 404) } @@ -69,7 +72,8 @@ func TestPatchNoType404(t *testing.T) { w := httptest.NewRecorder() router := router.New() router.Patch("/", ok()) - middleware.MethodOverride().Middleware(router).ServeHTTP(w, req) + middleware := methodoverride.New() + middleware(router).ServeHTTP(w, req) res := w.Result() is.Equal(res.StatusCode, 404) } @@ -84,7 +88,8 @@ func TestPatchInsensitive200(t *testing.T) { w := httptest.NewRecorder() router := router.New() router.Patch("/", ok()) - middleware.MethodOverride().Middleware(router).ServeHTTP(w, req) + middleware := methodoverride.New() + middleware(router).ServeHTTP(w, req) res := w.Result() is.Equal(res.StatusCode, 200) } @@ -99,7 +104,8 @@ func TestDelete200(t *testing.T) { w := httptest.NewRecorder() router := router.New() router.Delete("/", ok()) - middleware.MethodOverride().Middleware(router).ServeHTTP(w, req) + middleware := methodoverride.New() + middleware(router).ServeHTTP(w, req) res := w.Result() is.Equal(res.StatusCode, 200) } @@ -114,7 +120,8 @@ func TestPut200(t *testing.T) { w := httptest.NewRecorder() router := router.New() router.Put("/", ok()) - middleware.MethodOverride().Middleware(router).ServeHTTP(w, req) + middleware := methodoverride.New() + middleware(router).ServeHTTP(w, req) res := w.Result() is.Equal(res.StatusCode, 200) } @@ -129,7 +136,8 @@ func TestGet404(t *testing.T) { w := httptest.NewRecorder() router := router.New() router.Get("/", ok()) - middleware.MethodOverride().Middleware(router).ServeHTTP(w, req) + middleware := methodoverride.New() + middleware(router).ServeHTTP(w, req) res := w.Result() is.Equal(res.StatusCode, 404) } diff --git a/package/middleware/middleware.go b/package/middleware/middleware.go index beb4332a..9bfc34c7 100644 --- a/package/middleware/middleware.go +++ b/package/middleware/middleware.go @@ -4,38 +4,56 @@ import ( "net/http" ) -// Interface for implementing middleware -type Middleware interface { - Middleware(next http.Handler) http.Handler -} - -// Function for creating middleware -type Function func(next http.Handler) http.Handler - -func (fn Function) Middleware(next http.Handler) http.Handler { - return fn(next) -} - -// Stack of middleware -type Stack []Middleware - -// Middleware fn -func (stack Stack) Middleware(next http.Handler) http.Handler { - return Compose(stack...).Middleware(next) -} +type Middleware func(http.Handler) http.Handler // Compose a stack of middleware into a single middleware -func Compose(stack ...Middleware) Middleware { - return Function(func(h http.Handler) http.Handler { - if len(stack) == 0 { +func Compose(middlewares ...Middleware) Middleware { + return func(h http.Handler) http.Handler { + if len(middlewares) == 0 { return h } - for i := len(stack) - 1; i >= 0; i-- { - if stack[i] == nil { + for i := len(middlewares) - 1; i >= 0; i-- { + if middlewares[i] == nil { continue } - h = stack[i].Middleware(h) + h = middlewares[i](h) } return h - }) + } } + +// // Interface for implementing middleware +// type Middleware interface { +// Middleware(next http.Handler) http.Handler +// } + +// // Function for creating middleware +// type Function func(next http.Handler) http.Handler + +// func (fn Function) Middleware(next http.Handler) http.Handler { +// return fn(next) +// } + +// // Stack of middleware +// type Stack []Middleware + +// // Middleware fn +// func (stack Stack) Middleware(next http.Handler) http.Handler { +// return Compose(stack...).Middleware(next) +// } + +// // Compose a stack of middleware into a single middleware +// func Compose(stack ...Middleware) Middleware { +// return Function(func(h http.Handler) http.Handler { +// if len(stack) == 0 { +// return h +// } +// for i := len(stack) - 1; i >= 0; i-- { +// if stack[i] == nil { +// continue +// } +// h = stack[i].Middleware(h) +// } +// return h +// }) +// }