diff --git a/admin/admin.go b/admin/admin.go index c2fd335..a022093 100644 --- a/admin/admin.go +++ b/admin/admin.go @@ -81,7 +81,7 @@ var AdminServerMux *server.RestMux func init() { AdminServerMux = server.NewRestMux("/api") - AdminServerMux.Use(server.BasicAuth) + AdminServerMux.Use(server.JwtAuth) AdminServerMux.HandleFunc("GET", "/about", http.HandlerFunc(about)) postConfigRoute := AdminServerMux.HandleFunc("POST", "/config", http.HandlerFunc(setConfig)) postConfigRoute.Add(server.Parse[model.HttpServerConfig]) diff --git a/server/middleware.go b/server/middleware.go index 909fbcd..80b749f 100644 --- a/server/middleware.go +++ b/server/middleware.go @@ -2,12 +2,17 @@ package server import ( "container/list" + "context" "encoding/json" + "fmt" "net/http" + "path" "reflect" "strings" "git.pyer.club/kingecg/gohttpd/model" + "git.pyer.club/kingecg/gologger" + "github.com/golang-jwt/jwt/v5" ) type Middleware func(w http.ResponseWriter, r *http.Request, next http.Handler) @@ -100,6 +105,41 @@ func BasicAuth(w http.ResponseWriter, r *http.Request, next http.Handler) { http.Error(w, "Unauthorized.", http.StatusUnauthorized) } } +func JwtAuth(w http.ResponseWriter, r *http.Request, next http.Handler) { + l := gologger.GetLogger("JwtAuth") + config := model.GetConfig() + jwtConfig := config.Jwt + if jwtConfig.Secret == "" || path.Base(r.URL.Path) == "login" { + next.ServeHTTP(w, r) + return + } + // 从cookie中获取token + tokenCookie, err := r.Cookie("auth_token") + if err != nil { + http.Error(w, "Unauthorized.", http.StatusUnauthorized) + return + } + tokenString := tokenCookie.Value + token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) { + // 确保签名方法是正确的 + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return []byte(jwtConfig.Secret), nil + }) + if err != nil { + l.Error("Failed to parse JWT: %v", err) + http.Error(w, "Unauthorized.", http.StatusUnauthorized) + return + } + if claims, ok := token.Claims.(*jwt.RegisteredClaims); ok && token.Valid { + // 验证通过,将用户信息存储在请求上下文中 + ctx := context.WithValue(r.Context(), "user", claims) + next.ServeHTTP(w, r.WithContext(ctx)) + return + } + http.Error(w, "Unauthorized.", http.StatusUnauthorized) +} func RecordAccess(w http.ResponseWriter, r *http.Request, next http.Handler) { model.Incr(r.Host)