refactor: route can add middleware

This commit is contained in:
kingecg 2023-12-12 21:44:35 +08:00
parent 18e319b407
commit 8ee2bf311f
3 changed files with 91 additions and 74 deletions

View File

@ -3,6 +3,7 @@ package admin
import ( import (
"net/http" "net/http"
"git.pyer.club/kingecg/gohttpd/model"
"git.pyer.club/kingecg/gohttpd/server" "git.pyer.club/kingecg/gohttpd/server"
) )
@ -16,19 +17,13 @@ func setConfig(w http.ResponseWriter, r *http.Request) {
} }
var AdminRoutes = []server.Route{
// Admin Routes
{Method: "GET", Path: "/about", Handle: about},
{Method: "Post", Path: "/config", Handle: setConfig},
}
var AdminServerMux *server.RestMux var AdminServerMux *server.RestMux
func init() { func init() {
AdminServerMux = server.NewRestMux("/") AdminServerMux = server.NewRestMux("/")
// AdminServerMux.routes = make(map[string]map[string]http.HandlerFunc) AdminServerMux.Use(server.BasicAuth)
for _, route := range AdminRoutes { AdminServerMux.HandleFunc("GET", "/about", http.HandlerFunc(about))
AdminServerMux.HandleFunc(route.Method, route.Path, route.Handle) postConfigRoute := AdminServerMux.HandleFunc("POST", "/config", http.HandlerFunc(setConfig))
} postConfigRoute.Add(server.Parse[model.HttpServerConfig])
AdminServerMux.Use(server.BasicAuth) AdminServerMux.Use(server.BasicAuth)
} }

View File

@ -1,6 +1,7 @@
package server package server
import ( import (
"container/list"
"encoding/json" "encoding/json"
"net/http" "net/http"
"strings" "strings"
@ -10,6 +11,40 @@ import (
type Middleware func(w http.ResponseWriter, r *http.Request, next func()) type Middleware func(w http.ResponseWriter, r *http.Request, next func())
type MiddlewareLink struct {
*list.List
}
func (ml *MiddlewareLink) Add(m Middleware) {
if ml.List.Len() == 0 {
ml.PushBack(m)
} else {
el := ml.Back()
ml.InsertBefore(m, el)
}
}
func (ml *MiddlewareLink) ServeHTTP(w http.ResponseWriter, r *http.Request) bool {
canContinue := false
next := func() {
canContinue = true
}
for e := ml.Front(); e != nil && canContinue; e = e.Next() {
e.Value.(Middleware)(w, r, next)
if !canContinue {
break
} else {
canContinue = false
}
}
return canContinue
}
func NewMiddlewareLink() *MiddlewareLink {
ml := &MiddlewareLink{list.New()}
ml.Add(Done)
return ml
}
func BasicAuth(w http.ResponseWriter, r *http.Request, next func()) { func BasicAuth(w http.ResponseWriter, r *http.Request, next func()) {
config := model.GetConfig() config := model.GetConfig()
@ -60,3 +95,7 @@ func Parse[T any](w http.ResponseWriter, r *http.Request, next func()) {
} }
next() next()
} }
func Done(w http.ResponseWriter, r *http.Request, next func()) {
next()
}

View File

@ -12,17 +12,16 @@ type RequestCtxKey string
type Route struct { type Route struct {
Method string Method string
Path string Path string
Handle http.HandlerFunc
Handler http.Handler Handler http.Handler
matcher *UrlParamMatcher matcher *UrlParamMatcher
middles *MiddlewareLink
} }
func (route *Route) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (route *Route) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if route.Handler != nil { if route.middles.Len() > 0 && !route.middles.ServeHTTP(w, r) {
route.Handler.ServeHTTP(w, r) return
} else {
route.Handle(w, r)
} }
route.Handler.ServeHTTP(w, r)
} }
func (route *Route) Match(r *http.Request) bool { func (route *Route) Match(r *http.Request) bool {
@ -48,6 +47,22 @@ func (route *Route) Match(r *http.Request) bool {
return strings.HasPrefix(r.URL.Path, route.Path) return strings.HasPrefix(r.URL.Path, route.Path)
} }
func (route *Route) Add(m Middleware) {
route.middles.Add(m)
}
func NewRoute(method string, path string, handleFn http.Handler) *Route {
ret := &Route{
Method: method,
Path: path,
middles: NewMiddlewareLink(),
}
p := ParseUrl(path)
//使用handleFn构建handler
ret.Handler = handleFn
ret.matcher = &p
return ret
}
type Routes []*Route type Routes []*Route
func (rs Routes) Less(i, j int) bool { func (rs Routes) Less(i, j int) bool {
@ -79,35 +94,21 @@ func (rs Routes) Swap(i, j int) {
type RestMux struct { type RestMux struct {
Path string Path string
routes Routes routes Routes
middlewares []Middleware middlewares *MiddlewareLink
} }
func (mux *RestMux) Use(m Middleware) { func (mux *RestMux) Use(m Middleware) {
mux.middlewares = append(mux.middlewares, m) mux.middlewares.Add(m)
} }
func (mux *RestMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (mux *RestMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
canContinue := false
c := r.Context() c := r.Context()
data := map[string]interface{}{} data := map[string]interface{}{}
cn := context.WithValue(c, RequestCtxKey("data"), data) cn := context.WithValue(c, RequestCtxKey("data"), data)
newRequest := r.WithContext(cn) newRequest := r.WithContext(cn)
if len(mux.middlewares) > 0 { if mux.middlewares.Len() > 0 && !mux.middlewares.ServeHTTP(w, newRequest) {
for _, m := range mux.middlewares {
canContinue = false
m(w, newRequest, func() { canContinue = true })
if !canContinue {
return return
} }
}
}
// _, has := lo.Find(mux.rmuxPaths, func(s string) bool {
// return strings.HasPrefix(newRequest.URL.Path, s)
// })
// if has {
// mux.imux.ServeHTTP(w, newRequest)
// return
// }
// 根据r 从routes中找到匹配的路由 // 根据r 从routes中找到匹配的路由
for _, route := range mux.routes { for _, route := range mux.routes {
@ -116,76 +117,58 @@ func (mux *RestMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
} }
// newRequest.URL.Path = "/" + strings.ToLower(newRequest.Method) + newRequest.URL.Path
// newRequest.RequestURI = "/" + strings.ToLower(newRequest.Method) + newRequest.RequestURI
// h, _ := mux.imux.Handler(newRequest)
// h.ServeHTTP(w, newRequest)
http.NotFound(w, r) http.NotFound(w, r)
} }
func (mux *RestMux) HandleFunc(method string, path string, f func(http.ResponseWriter, *http.Request)) { func (mux *RestMux) HandleFunc(method string, path string, f func(http.ResponseWriter, *http.Request)) *Route {
r := &Route{ r := NewRoute(method, path, http.HandlerFunc(f))
Method: method,
Path: path,
Handle: f,
Handler: nil,
}
*r.matcher = ParseUrl(path)
mux.routes = append(mux.routes, r) mux.routes = append(mux.routes, r)
sort.Sort(mux.routes) sort.Sort(mux.routes)
return r
} }
func (mux *RestMux) Get(path string, f func(http.ResponseWriter, *http.Request)) { func (mux *RestMux) Get(path string, f func(http.ResponseWriter, *http.Request)) *Route {
mux.HandleFunc("GET", path, f) return mux.HandleFunc("GET", path, f)
} }
func (mux *RestMux) Post(path string, f func(http.ResponseWriter, *http.Request)) { func (mux *RestMux) Post(path string, f func(http.ResponseWriter, *http.Request)) *Route {
mux.HandleFunc("POST", path, f) return mux.HandleFunc("POST", path, f)
} }
func (mux *RestMux) Put(path string, f func(http.ResponseWriter, *http.Request)) { func (mux *RestMux) Put(path string, f func(http.ResponseWriter, *http.Request)) *Route {
mux.HandleFunc("PUT", path, f) return mux.HandleFunc("PUT", path, f)
} }
func (mux *RestMux) Delete(path string, f func(http.ResponseWriter, *http.Request)) { func (mux *RestMux) Delete(path string, f func(http.ResponseWriter, *http.Request)) *Route {
mux.HandleFunc("DELETE", path, f) return mux.HandleFunc("DELETE", path, f)
} }
func (mux *RestMux) Patch(path string, f func(http.ResponseWriter, *http.Request)) { func (mux *RestMux) Patch(path string, f func(http.ResponseWriter, *http.Request)) *Route {
mux.HandleFunc("PATCH", path, f) return mux.HandleFunc("PATCH", path, f)
} }
func (mux *RestMux) Head(path string, f func(http.ResponseWriter, *http.Request)) { func (mux *RestMux) Head(path string, f func(http.ResponseWriter, *http.Request)) *Route {
mux.HandleFunc("HEAD", path, f) return mux.HandleFunc("HEAD", path, f)
} }
func (mux *RestMux) Option(path string, f func(http.ResponseWriter, *http.Request)) { func (mux *RestMux) Option(path string, f func(http.ResponseWriter, *http.Request)) *Route {
mux.HandleFunc("OPTION", path, f) return mux.HandleFunc("OPTION", path, f)
} }
func (mux *RestMux) HandleMux(nmux *RestMux) { func (mux *RestMux) HandleMux(nmux *RestMux) *Route {
p := nmux.Path p := nmux.Path
if !strings.HasSuffix(p, "/") { if !strings.HasSuffix(p, "/") {
p = p + "/" p = p + "/"
} }
// mux.imux.Handle(p, http.StripPrefix(nmux.Path, nmux)) r := NewRoute("", p, nmux)
// mux.rmuxPaths = append(mux.rmuxPaths, nmux.Path) mux.routes = append(mux.routes, r)
r := &Route{
Method: "",
Path: p,
Handle: nil,
Handler: nmux,
}
r.matcher = &UrlParamMatcher{
Params: []string{},
Reg: nil,
}
sort.Sort(mux.routes) sort.Sort(mux.routes)
return r
} }
func NewRestMux(path string) *RestMux { func NewRestMux(path string) *RestMux {
ret := &RestMux{ ret := &RestMux{
Path: path, Path: path,
routes: Routes{}, routes: Routes{},
middlewares: []Middleware{}, middlewares: NewMiddlewareLink(),
} }
return ret return ret
} }