From e22b39513e0055efff1d0999e292f941ffa0f9b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=A8=8B=E5=B9=BF?= Date: Mon, 11 Dec 2023 18:15:29 +0800 Subject: [PATCH] refactor and add middleware --- admin/admin.go | 6 ++++++ gohttp.go | 7 +++---- model/model.go | 28 ++++++++++++++++++++++++++++ server/middleware.go | 32 ++++++++++++++++++++++++++++++++ server/server.go | 21 ++++++++++++++++++--- 5 files changed, 87 insertions(+), 7 deletions(-) create mode 100644 server/middleware.go diff --git a/admin/admin.go b/admin/admin.go index 5c9b713..f4db749 100644 --- a/admin/admin.go +++ b/admin/admin.go @@ -18,9 +18,14 @@ func about(w http.ResponseWriter, r *http.Request) { } +func setConfig(w http.ResponseWriter, r *http.Request) { + +} + var AdminRoutes = []Route{ // Admin Routes {"GET", "/about", about}, + {"Post", "/config", setConfig}, } var AdminServerMux *server.RestMux @@ -31,4 +36,5 @@ func init() { for _, route := range AdminRoutes { AdminServerMux.HandleFunc(route.Method, route.Path, route.Handle) } + AdminServerMux.Use(server.BasicAuth) } diff --git a/gohttp.go b/gohttp.go index 7b0bac5..2b92ac2 100644 --- a/gohttp.go +++ b/gohttp.go @@ -12,13 +12,12 @@ import ( "git.pyer.club/kingecg/gologger" ) -var conf model.GoHttpdConfig - type GoHttp struct { logger gologger.Logger } func (g *GoHttp) Start() { + conf := model.GetConfig() g.logger = gologger.GetLogger("Server") g.logger.Info("start gohttpd") // if g.conf != nil { @@ -78,8 +77,8 @@ func LoadConfig(configPath string) { // read content from cpath content, _ := os.ReadFile(cpath) - json.Unmarshal(content, &conf) - gologger.Configure(conf.Logging) + json.Unmarshal(content, &model.Config) + gologger.Configure(model.Config.Logging) logger := gologger.GetLogger("Server") logger.Info("Load config success") } diff --git a/model/model.go b/model/model.go index 0a0785e..5055517 100644 --- a/model/model.go +++ b/model/model.go @@ -16,10 +16,13 @@ type PathRewrite struct { } type HttpServerConfig struct { + Name string `json:"name"` ServerName string `json:"server"` Port int `json:"port"` Host string `json:"host"` Paths []HttpPath + Username string `json:"username"` + Password string `json:"password"` } type GoHttpdConfig struct { @@ -32,3 +35,28 @@ var DefaultAdminConfig HttpServerConfig = HttpServerConfig{ ServerName: "admin", Port: 8080, } + +var Config GoHttpdConfig = GoHttpdConfig{} + +func GetConfig() *GoHttpdConfig { + return &Config +} + +func SetServerConfig(c *HttpServerConfig) { + for i, s := range Config.Servers { + if s.Name == c.Name { + Config.Servers[i] = c + return + } + } + Config.Servers = append(Config.Servers, c) +} + +func GetServerConfig(name string) *HttpServerConfig { + for _, s := range Config.Servers { + if s.Name == name { + return s + } + } + return nil +} diff --git a/server/middleware.go b/server/middleware.go new file mode 100644 index 0000000..4e27cb0 --- /dev/null +++ b/server/middleware.go @@ -0,0 +1,32 @@ +package server + +import ( + "net/http" + + "git.pyer.club/kingecg/gohttpd/model" +) + +type Middleware func(w http.ResponseWriter, r *http.Request, next func()) + +func BasicAuth(w http.ResponseWriter, r *http.Request, next func()) { + config := model.GetConfig() + + if config.Admin.Username == "" || config.Admin.Password == "" { + next() + return + } + user, pass, ok := r.BasicAuth() + if ok && user == config.Admin.Username && pass == config.Admin.Password { + next() + } else { + w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) + http.Error(w, "Unauthorized.", http.StatusUnauthorized) + } +} + +func Parse(w http.ResponseWriter, r *http.Request, next func()) { + if r.Method == "POST" || r.Method == "PUT" { + r.ParseForm() + } + next() +} diff --git a/server/server.go b/server/server.go index e7622af..313292d 100644 --- a/server/server.go +++ b/server/server.go @@ -9,12 +9,27 @@ import ( // 可以嵌套的Rest http server mux type RestMux struct { - Path string - imux *http.ServeMux - rmuxPaths []string + Path string + imux *http.ServeMux + rmuxPaths []string + middlewares []Middleware } +func (mux *RestMux) Use(m Middleware) { + mux.middlewares = append(mux.middlewares, m) +} func (mux *RestMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { + canContinue := false + if len(mux.middlewares) > 0 { + for _, m := range mux.middlewares { + canContinue = false + m(w, r, func() { canContinue = true }) + if !canContinue { + return + } + } + } + _, has := lo.Find(mux.rmuxPaths, func(s string) bool { return strings.HasPrefix(r.URL.Path, s) })