refactor and add middleware

This commit is contained in:
程广 2023-12-11 18:15:29 +08:00
parent 8c4972acb4
commit e22b39513e
5 changed files with 87 additions and 7 deletions

View File

@ -18,9 +18,14 @@ func about(w http.ResponseWriter, r *http.Request) {
} }
func setConfig(w http.ResponseWriter, r *http.Request) {
}
var AdminRoutes = []Route{ var AdminRoutes = []Route{
// Admin Routes // Admin Routes
{"GET", "/about", about}, {"GET", "/about", about},
{"Post", "/config", setConfig},
} }
var AdminServerMux *server.RestMux var AdminServerMux *server.RestMux
@ -31,4 +36,5 @@ func init() {
for _, route := range AdminRoutes { for _, route := range AdminRoutes {
AdminServerMux.HandleFunc(route.Method, route.Path, route.Handle) AdminServerMux.HandleFunc(route.Method, route.Path, route.Handle)
} }
AdminServerMux.Use(server.BasicAuth)
} }

View File

@ -12,13 +12,12 @@ import (
"git.pyer.club/kingecg/gologger" "git.pyer.club/kingecg/gologger"
) )
var conf model.GoHttpdConfig
type GoHttp struct { type GoHttp struct {
logger gologger.Logger logger gologger.Logger
} }
func (g *GoHttp) Start() { func (g *GoHttp) Start() {
conf := model.GetConfig()
g.logger = gologger.GetLogger("Server") g.logger = gologger.GetLogger("Server")
g.logger.Info("start gohttpd") g.logger.Info("start gohttpd")
// if g.conf != nil { // if g.conf != nil {
@ -78,8 +77,8 @@ func LoadConfig(configPath string) {
// read content from cpath // read content from cpath
content, _ := os.ReadFile(cpath) content, _ := os.ReadFile(cpath)
json.Unmarshal(content, &conf) json.Unmarshal(content, &model.Config)
gologger.Configure(conf.Logging) gologger.Configure(model.Config.Logging)
logger := gologger.GetLogger("Server") logger := gologger.GetLogger("Server")
logger.Info("Load config success") logger.Info("Load config success")
} }

View File

@ -16,10 +16,13 @@ type PathRewrite struct {
} }
type HttpServerConfig struct { type HttpServerConfig struct {
Name string `json:"name"`
ServerName string `json:"server"` ServerName string `json:"server"`
Port int `json:"port"` Port int `json:"port"`
Host string `json:"host"` Host string `json:"host"`
Paths []HttpPath Paths []HttpPath
Username string `json:"username"`
Password string `json:"password"`
} }
type GoHttpdConfig struct { type GoHttpdConfig struct {
@ -32,3 +35,28 @@ var DefaultAdminConfig HttpServerConfig = HttpServerConfig{
ServerName: "admin", ServerName: "admin",
Port: 8080, Port: 8080,
} }
var Config GoHttpdConfig = GoHttpdConfig{}
func GetConfig() *GoHttpdConfig {
return &Config
}
func SetServerConfig(c *HttpServerConfig) {
for i, s := range Config.Servers {
if s.Name == c.Name {
Config.Servers[i] = c
return
}
}
Config.Servers = append(Config.Servers, c)
}
func GetServerConfig(name string) *HttpServerConfig {
for _, s := range Config.Servers {
if s.Name == name {
return s
}
}
return nil
}

32
server/middleware.go Normal file
View File

@ -0,0 +1,32 @@
package server
import (
"net/http"
"git.pyer.club/kingecg/gohttpd/model"
)
type Middleware func(w http.ResponseWriter, r *http.Request, next func())
func BasicAuth(w http.ResponseWriter, r *http.Request, next func()) {
config := model.GetConfig()
if config.Admin.Username == "" || config.Admin.Password == "" {
next()
return
}
user, pass, ok := r.BasicAuth()
if ok && user == config.Admin.Username && pass == config.Admin.Password {
next()
} else {
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
http.Error(w, "Unauthorized.", http.StatusUnauthorized)
}
}
func Parse(w http.ResponseWriter, r *http.Request, next func()) {
if r.Method == "POST" || r.Method == "PUT" {
r.ParseForm()
}
next()
}

View File

@ -9,12 +9,27 @@ import (
// 可以嵌套的Rest http server mux // 可以嵌套的Rest http server mux
type RestMux struct { type RestMux struct {
Path string Path string
imux *http.ServeMux imux *http.ServeMux
rmuxPaths []string rmuxPaths []string
middlewares []Middleware
} }
func (mux *RestMux) Use(m Middleware) {
mux.middlewares = append(mux.middlewares, m)
}
func (mux *RestMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (mux *RestMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
canContinue := false
if len(mux.middlewares) > 0 {
for _, m := range mux.middlewares {
canContinue = false
m(w, r, func() { canContinue = true })
if !canContinue {
return
}
}
}
_, has := lo.Find(mux.rmuxPaths, func(s string) bool { _, has := lo.Find(mux.rmuxPaths, func(s string) bool {
return strings.HasPrefix(r.URL.Path, s) return strings.HasPrefix(r.URL.Path, s)
}) })