gohttp/server/middleware.go

118 lines
2.5 KiB
Go
Raw Normal View History

2023-12-11 18:15:29 +08:00
package server
import (
	"container/list"
	"crypto/subtle"
	"encoding/json"
	"net/http"
	"reflect"
	"strings"

	"git.pyer.club/kingecg/gohttpd/model"
)
// Middleware is one step of a request-processing chain. It receives the
// response writer and the request, and signals that processing may move on
// to the following step by calling next. A Middleware that returns without
// calling next stops the chain.
type Middleware func(w http.ResponseWriter, r *http.Request, next func())
2023-12-12 21:44:35 +08:00
// MiddlewareLink is an ordered chain of Middleware values. It embeds a
// *list.List, so the list's own methods (Front, Back, PushBack, ...) are
// available directly on the link.
type MiddlewareLink struct {
*list.List
}
2023-12-12 22:58:30 +08:00
// IsEqualMiddleware reports whether a and b refer to the same middleware
// function. Two nil values are equal; a nil and a non-nil value are not.
//
// Go function values cannot be compared with ==, so non-nil values are
// compared by the code pointer that reflect exposes for them.
func IsEqualMiddleware(a, b Middleware) bool {
	aNil, bNil := a == nil, b == nil
	if aNil || bNil {
		// Equal only when both sides are nil.
		return aNil && bNil
	}

	ptrA := reflect.ValueOf(a).Pointer()
	ptrB := reflect.ValueOf(b).Pointer()
	return ptrA == ptrB
}
2023-12-12 21:44:35 +08:00
func (ml *MiddlewareLink) Add(m Middleware) {
2023-12-12 22:58:30 +08:00
if m == nil {
return
}
2023-12-12 21:44:35 +08:00
if ml.List.Len() == 0 {
ml.PushBack(m)
} else {
2023-12-12 22:58:30 +08:00
if IsEqualMiddleware(m, Done) {
return
}
2023-12-12 21:44:35 +08:00
el := ml.Back()
ml.InsertBefore(m, el)
}
}
func (ml *MiddlewareLink) ServeHTTP(w http.ResponseWriter, r *http.Request) bool {
2023-12-12 22:58:30 +08:00
canContinue := true
2023-12-12 21:44:35 +08:00
next := func() {
canContinue = true
}
for e := ml.Front(); e != nil && canContinue; e = e.Next() {
2023-12-12 22:58:30 +08:00
canContinue = false
2023-12-12 21:44:35 +08:00
e.Value.(Middleware)(w, r, next)
if !canContinue {
break
}
}
return canContinue
}
// NewMiddlewareLink returns a chain backed by a fresh list whose only
// element is the Done middleware.
func NewMiddlewareLink() *MiddlewareLink {
	chain := &MiddlewareLink{List: list.New()}
	chain.Add(Done)
	return chain
}
2023-12-11 18:15:29 +08:00
// BasicAuth is a Middleware that protects the chain with HTTP Basic
// authentication against the admin credentials from model.GetConfig.
//
// If the configured admin username or password is empty, authentication is
// disabled and next is called unconditionally. Otherwise next is called only
// when the request carries Basic credentials matching the configured ones;
// any other request is answered with 401 Unauthorized and a WWW-Authenticate
// challenge, and next is not called.
func BasicAuth(w http.ResponseWriter, r *http.Request, next func()) {
	config := model.GetConfig()
	if config.Admin.Username == "" || config.Admin.Password == "" {
		next()
		return
	}

	user, pass, ok := r.BasicAuth()

	// Compare in constant time and always evaluate both fields, so response
	// timing does not reveal how much of a guessed credential was correct.
	// (ConstantTimeCompare may still reveal a length mismatch.)
	userOK := subtle.ConstantTimeCompare([]byte(user), []byte(config.Admin.Username))
	passOK := subtle.ConstantTimeCompare([]byte(pass), []byte(config.Admin.Password))
	if ok && userOK&passOK == 1 {
		next()
		return
	}

	w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
	http.Error(w, "Unauthorized.", http.StatusUnauthorized)
}
2023-12-11 23:46:40 +08:00
func Parse[T any](w http.ResponseWriter, r *http.Request, next func()) {
2023-12-11 18:15:29 +08:00
if r.Method == "POST" || r.Method == "PUT" {
2023-12-11 23:46:40 +08:00
//判断r的content-type是否是application/x-www-form-urlencoded
if r.Header.Get("Content-Type") == "application/x-www-form-urlencoded" {
r.ParseForm()
next()
return
}
if r.Header.Get("Content-Type") == "multipart/form-data" {
r.ParseMultipartForm(r.ContentLength)
next()
return
}
// 判断r的content-type是否是application/json
contentType := r.Header.Get("Content-Type")
if strings.Contains(contentType, "application/json") {
var t T
// 读取json数据
if err := json.NewDecoder(r.Body).Decode(&t); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
ctx := r.Context()
m := ctx.Value(RequestCtxKey("data")).(map[string]interface{})
if m != nil {
m["data"] = t
}
next()
return
}
2023-12-11 18:15:29 +08:00
}
next()
}
2023-12-12 21:44:35 +08:00
// Done is a pass-through Middleware: it ignores w and r and immediately
// calls next.
func Done(w http.ResponseWriter, r *http.Request, next func()) {
next()
}