gohttp/server/middleware.go

184 lines
4.5 KiB
Go

package server
import (
	"container/list"
	"context"
	"crypto/subtle"
	"encoding/json"
	"fmt"
	"net/http"
	"path"
	"reflect"
	"strings"

	"git.pyer.club/kingecg/gohttpd/model"
	"git.pyer.club/kingecg/gologger"
	"github.com/golang-jwt/jwt/v5"
)
// Middleware is one step of an HTTP processing chain. It receives the response
// writer, the request, and the next handler in the chain; the middleware itself
// decides whether (and with which request) to invoke next.
type Middleware func(w http.ResponseWriter, r *http.Request, next http.Handler)
// MiddlewareLink is an ordered collection of middlewares stored in an embedded
// container/list doubly linked list. Because the list is embedded, all of its
// methods (Len, Front, Back, PushBack, ...) are available on MiddlewareLink.
type MiddlewareLink struct {
*list.List
}
// IsEqualMiddleware reports whether a and b refer to the same middleware
// function.
//
// Two nil middlewares are equal; a nil and a non-nil middleware are not.
// Non-nil values are compared by the code pointer reflect reports for them.
// Per the reflect documentation that pointer is not guaranteed to identify a
// single function uniquely, so distinct closures may compare equal.
func IsEqualMiddleware(a, b Middleware) bool {
	// With at least one nil operand the answer is decided by whether both are nil.
	if a == nil || b == nil {
		return a == nil && b == nil
	}
	ptrA := reflect.ValueOf(a).Pointer()
	ptrB := reflect.ValueOf(b).Pointer()
	return ptrA == ptrB
}
// Add registers m in the chain.
//
//   - A nil middleware is ignored.
//   - When the chain is empty, m becomes its only (and therefore last) element.
//   - Otherwise Done is never added again, and any other middleware is
//     inserted immediately before the current last element, so the element
//     that was added first stays at the tail.
func (ml *MiddlewareLink) Add(m Middleware) {
	switch {
	case m == nil:
		return
	case ml.List.Len() == 0:
		ml.PushBack(m)
	case IsEqualMiddleware(m, Done):
		// The terminator is already expected to be in place; do not duplicate it.
		return
	default:
		tail := ml.Back()
		ml.InsertBefore(m, tail)
	}
}
// func (ml *MiddlewareLink) ServeHTTP(w http.ResponseWriter, r *http.Request) bool {
// canContinue := true
// next := func() {
// canContinue = true
// }
// for e := ml.Front(); e != nil && canContinue; e = e.Next() {
// canContinue = false
// e.Value.(Middleware)(w, r, next)
// if !canContinue {
// break
// }
// }
// return canContinue
// }
// wrap adapts middleware m into an http.Handler whose ServeHTTP invokes m with
// the incoming writer and request and with next as the downstream handler.
func (ml *MiddlewareLink) wrap(m Middleware, next http.Handler) http.Handler {
	serve := func(w http.ResponseWriter, r *http.Request) {
		m(w, r, next)
	}
	return http.HandlerFunc(serve)
}
// WrapHandler builds a single http.Handler that runs the chain's middlewares in
// list order (front first) and finally next.
//
// The chain is assembled from the back of the list towards the front, each
// middleware wrapping the handler built so far. An empty list yields next
// unchanged. If an element does not hold a Middleware, assembly stops at that
// element and everything in front of it is left out of the returned handler.
func (ml *MiddlewareLink) WrapHandler(next http.Handler) http.Handler {
	chained := next
	for el := ml.Back(); el != nil; el = el.Prev() {
		mw, isMiddleware := el.Value.(Middleware)
		if !isMiddleware {
			return chained
		}
		chained = ml.wrap(mw, chained)
	}
	return chained
}
// NewMiddlewareLink returns a chain backed by a fresh list to which Done has
// already been added, so Done is the chain's initial (and last) element.
func NewMiddlewareLink() *MiddlewareLink {
	chain := &MiddlewareLink{List: list.New()}
	chain.Add(Done)
	return chain
}
// BasicAuth is a middleware that protects next with HTTP Basic authentication
// against the admin credentials from model.GetConfig().
//
// When either the configured username or password is empty, authentication is
// disabled and the request is forwarded unchanged. Otherwise the request must
// carry Basic credentials matching the configured pair; any other request is
// answered with 401 Unauthorized and a WWW-Authenticate challenge, and next is
// not called.
func BasicAuth(w http.ResponseWriter, r *http.Request, next http.Handler) {
	config := model.GetConfig()
	if config.Admin.Username == "" || config.Admin.Password == "" {
		next.ServeHTTP(w, r)
		return
	}

	user, pass, ok := r.BasicAuth()

	// Compare in constant time, and always compare both fields, so response
	// timing does not reveal how much of a guessed credential was correct.
	// (ConstantTimeCompare still returns early when the lengths differ.)
	userMatch := subtle.ConstantTimeCompare([]byte(user), []byte(config.Admin.Username))
	passMatch := subtle.ConstantTimeCompare([]byte(pass), []byte(config.Admin.Password))
	if ok && userMatch&passMatch == 1 {
		next.ServeHTTP(w, r)
		return
	}

	w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
	http.Error(w, "Unauthorized.", http.StatusUnauthorized)
}
// JwtAuth is a middleware that requires a valid JWT in the "auth_token" cookie
// before forwarding the request to next.
//
// Verification is skipped, and the request forwarded unchanged, when no JWT
// secret is configured or when the last element of the request path is
// "login". Otherwise the token must parse, be HMAC-signed with the configured
// secret, and be valid; on success the parsed *jwt.RegisteredClaims are placed
// in the request context under the string key "user". Every failure is
// answered with 401 Unauthorized and next is not called.
//
// NOTE(review): the "login" exemption matches any path whose final element is
// "login", not one specific route; confirm no protected route can end that way.
// NOTE(review): a plain string context key can collide with other packages;
// changing it would break readers of ctx.Value("user"), so it is kept as is.
func JwtAuth(w http.ResponseWriter, r *http.Request, next http.Handler) {
	logger := gologger.GetLogger("JwtAuth")
	jwtCfg := model.GetConfig().Jwt

	if jwtCfg.Secret == "" || path.Base(r.URL.Path) == "login" {
		next.ServeHTTP(w, r)
		return
	}

	// The token travels in a cookie rather than in a header.
	cookie, err := r.Cookie("auth_token")
	if err != nil {
		http.Error(w, "Unauthorized.", http.StatusUnauthorized)
		return
	}

	// Reject any token that is not HMAC-signed before handing out the key, so
	// a token cannot pick a different algorithm for itself.
	keyFunc := func(tok *jwt.Token) (interface{}, error) {
		if _, isHMAC := tok.Method.(*jwt.SigningMethodHMAC); isHMAC {
			return []byte(jwtCfg.Secret), nil
		}
		return nil, fmt.Errorf("unexpected signing method: %v", tok.Header["alg"])
	}

	parsed, err := jwt.ParseWithClaims(cookie.Value, &jwt.RegisteredClaims{}, keyFunc)
	if err != nil {
		logger.Error("Failed to parse JWT: %v", err)
		http.Error(w, "Unauthorized.", http.StatusUnauthorized)
		return
	}

	claims, ok := parsed.Claims.(*jwt.RegisteredClaims)
	if !ok || !parsed.Valid {
		http.Error(w, "Unauthorized.", http.StatusUnauthorized)
		return
	}

	// Verification passed: expose the claims to downstream handlers.
	authed := r.WithContext(context.WithValue(r.Context(), "user", claims))
	next.ServeHTTP(w, authed)
}
// RecordAccess is a middleware that passes the request's Host to model.Incr
// and then forwards the request to next.
// NOTE(review): model.Incr presumably bumps a per-host access counter; confirm
// against the model package.
func RecordAccess(w http.ResponseWriter, r *http.Request, next http.Handler) {
	host := r.Host
	model.Incr(host)
	next.ServeHTTP(w, r)
}
// Parse is a middleware that pre-parses the body of POST and PUT requests
// before forwarding the request to next. Other methods are forwarded untouched.
//
// The Content-Type header is matched case-insensitively and any parameters
// (charset, boundary) are tolerated:
//
//   - application/x-www-form-urlencoded: the body is parsed with r.ParseForm.
//   - multipart/form-data: the body is parsed with r.ParseMultipartForm.
//   - anything containing "application/json": the body is decoded into a value
//     of type T. If the request context holds a non-nil map[string]any under
//     RequestCtxKey("data"), the decoded value is stored in that map under the
//     key "data"; otherwise the decoded value is discarded.
//
// A body that fails to parse is answered with 400 Bad Request carrying the
// error text, and next is not called.
func Parse[T any](w http.ResponseWriter, r *http.Request, next http.Handler) {
	// Cap on multipart data held in memory; anything beyond it is spilled to
	// temporary files by net/http. A fixed cap is used because the request's
	// Content-Length is chosen by the client (and is -1 when unknown).
	const maxMultipartMemory = 32 << 20

	if r.Method == http.MethodPost || r.Method == http.MethodPut {
		contentType := strings.ToLower(r.Header.Get("Content-Type"))

		var err error
		switch {
		case strings.HasPrefix(contentType, "application/x-www-form-urlencoded"):
			err = r.ParseForm()
		case strings.HasPrefix(contentType, "multipart/form-data"):
			err = r.ParseMultipartForm(maxMultipartMemory)
		case strings.Contains(contentType, "application/json"):
			var payload T
			if err = json.NewDecoder(r.Body).Decode(&payload); err == nil {
				// Checked assertion: a request without the shared data map
				// must not panic, it simply has nowhere to put the payload.
				store, ok := r.Context().Value(RequestCtxKey("data")).(map[string]any)
				if ok && store != nil {
					store["data"] = payload
				}
			}
		}
		if err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
	}
	next.ServeHTTP(w, r)
}
// Done is a pass-through middleware: it forwards the request to next without
// touching the writer or the request. NewMiddlewareLink adds it to every new
// chain, and Add refuses to add it again to a non-empty chain.
func Done(w http.ResponseWriter, r *http.Request, next http.Handler) {
next.ServeHTTP(w, r)
}