diff --git a/admin/admin.go b/admin/admin.go index a313216..a022093 100644 --- a/admin/admin.go +++ b/admin/admin.go @@ -81,11 +81,13 @@ var AdminServerMux *server.RestMux func init() { AdminServerMux = server.NewRestMux("/api") - AdminServerMux.Use(server.BasicAuth) + AdminServerMux.Use(server.JwtAuth) AdminServerMux.HandleFunc("GET", "/about", http.HandlerFunc(about)) postConfigRoute := AdminServerMux.HandleFunc("POST", "/config", http.HandlerFunc(setConfig)) postConfigRoute.Add(server.Parse[model.HttpServerConfig]) AdminServerMux.HandleFunc("GET", "/config/:id", http.HandlerFunc(getServerConfigure)) AdminServerMux.HandleFunc("GET", "/status", http.HandlerFunc(getStatus)) + loginRoute := AdminServerMux.HandleFunc("POST", "/login", http.HandlerFunc(login)) + loginRoute.Add(server.Parse[LoginModel]) // AdminServerMux.Use(server.BasicAuth) } diff --git a/admin/login.go b/admin/login.go new file mode 100644 index 0000000..3af48c0 --- /dev/null +++ b/admin/login.go @@ -0,0 +1,94 @@ +package admin + +import ( + "errors" + "net/http" + "time" + + "git.pyer.club/kingecg/gohttpd/server" + "github.com/golang-jwt/jwt/v5" + "pyer.club/kingecg/gohttpd/model" +) + +type LoginModel struct { + Username string `json:"username"` + Encrypt string `json:"password"` +} + +func login(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + ctxData := ctx.Value(server.RequestCtxKey("data")).(map[string]interface{}) + data, ok := ctxData["data"] + if !ok { + w.WriteHeader(http.StatusBadRequest) + return + } + t := data.(LoginModel) + if t.Username == "admin" { + decryptText, _ := Decrypt(t.Encrypt) + if decryptText == model.GetConfig().Admin.Password { + token, err := GenerateToken(t.Username) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write(server.NewErrorResult(err)) + return + } + w.WriteHeader(http.StatusOK) + http.SetCookie(w, &http.Cookie{ + Name: "token", + Value: token, + Path: "/", + HttpOnly: true, + Secure: true, + SameSite: http.SameSiteStrictMode, + Expires: time.Now().Add(time.Hour * 24 * 7), + }) + w.Write(server.NewSuccessResult(token)) + } + } else { + w.WriteHeader(http.StatusForbidden) + resp := server.NewErrorResult(errors.New("Not allowed user/password")) + w.Write(resp) + } + return +} + +// 实现非对称加密 +func Encrypt(plaintext string) (string, error) { + ciphertext := make([]byte, len(plaintext)) + for i := 0; i < len(plaintext); i++ { + ciphertext[i] = plaintext[i] ^ 0xFF + } + return string(ciphertext), nil +} + +// 实现非对称解密 +func Decrypt(ciphertext string) (string, error) { + plaintext := make([]byte, len(ciphertext)) + for i := 0; i < len(ciphertext); i++ { + plaintext[i] = ciphertext[i] ^ 0xFF + } + //去除末尾13个字节 + plaintext = plaintext[:len(plaintext)-13] + return string(plaintext), nil +} + +// 生成token +func GenerateToken(username string) (string, error) { + // jwt token + jwtConfig := model.GetConfig().Jwt + secret := jwtConfig.Secret + expire := jwtConfig.Expire + issuer := jwtConfig.Issuer + audience := jwtConfig.Audience + claim := &jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Duration(expire) * time.Hour)), + Issuer: issuer, + Audience: []string{audience}, + IssuedAt: jwt.NewNumericDate(time.Now()), + Subject: username, + } + // 生成token + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claim) + return token.SignedString([]byte(secret)) +} diff --git a/go.mod b/go.mod index 11b5e9e..64f6a36 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/go-playground/locales v0.14.0 // indirect github.com/go-playground/universal-translator v0.18.0 // indirect github.com/go-playground/validator/v10 v10.9.0 // indirect + github.com/golang-jwt/jwt/v5 v5.2.1 // indirect github.com/golang/protobuf v1.5.2 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/compress v1.13.6 // indirect diff --git a/go.sum b/go.sum index 070ce83..7ef7618 100644 --- a/go.sum +++ b/go.sum @@ -21,6 +21,8 @@ github.com/go-playground/universal-translator v0.18.0/go.mod h1:UvRDBj+xPUEGrFYl github.com/go-playground/validator/v10 v10.4.1/go.mod h1:nlOn6nFhuKACm19sB/8EGNn9GlaMV7XkbRSipzJ0Ii4= github.com/go-playground/validator/v10 v10.9.0 h1:NgTtmN58D0m8+UuxtYmGztBJB7VnPgjj221I1QHci2A= github.com/go-playground/validator/v10 v10.9.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXSGrTK4nAUsbPlLADvpJkos= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= diff --git a/model/model.go b/model/model.go index fd2940d..63e3581 100644 --- a/model/model.go +++ b/model/model.go @@ -28,11 +28,20 @@ type HttpServerConfig struct { Port int `json:"port"` Host string `json:"host"` Paths []HttpPath - Username string `json:"username"` - Password string `json:"password"` - CertFile string `json:"certfile"` - KeyFile string `json:"keyfile"` - Directives []string `json:"directives"` + Username string `json:"username"` + Password string `json:"password"` + CertFile string `json:"certfile"` + KeyFile string `json:"keyfile"` + Directives []string `json:"directives"` + AuthType string `json:"auth_type"` + Jwt *JwtConfig `json:"jwt"` +} + +type JwtConfig struct { + Secret string `json:"secret"` + Expire int `json:"expire"` + Issuer string `json:"issuer"` + Audience string `json:"audience"` } type GoHttpdConfig struct { diff --git a/server/middleware.go b/server/middleware.go index d356c92..80b749f 100644 --- a/server/middleware.go +++ b/server/middleware.go @@ -2,12 +2,17 @@ package server import ( "container/list" + "context" "encoding/json" + "fmt" "net/http" + "path" "reflect" "strings" "git.pyer.club/kingecg/gohttpd/model" + "git.pyer.club/kingecg/gologger" + "github.com/golang-jwt/jwt/v5" ) type Middleware func(w http.ResponseWriter, r *http.Request, next http.Handler) @@ -55,14 +60,29 @@ func (ml *MiddlewareLink) Add(m Middleware) { // } // return canContinue // } + +func (ml *MiddlewareLink) wrap(m Middleware, next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + m(w, r, next) + }) +} func (ml *MiddlewareLink) WrapHandler(next http.Handler) http.Handler { + if ml.Back() == nil { + return next + } + + var handler http.Handler = next for e := ml.Back(); e != nil; e = e.Prev() { - next = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - e.Value.(Middleware)(w, r, next) - }) + + middleware, ok := e.Value.(Middleware) + if !ok { + break + } + + handler = ml.wrap(middleware, handler) } - return next + return handler } func NewMiddlewareLink() *MiddlewareLink { ml := &MiddlewareLink{list.New()} @@ -85,6 +105,41 @@ func BasicAuth(w http.ResponseWriter, r *http.Request, next http.Handler) { http.Error(w, "Unauthorized.", http.StatusUnauthorized) } } +func JwtAuth(w http.ResponseWriter, r *http.Request, next http.Handler) { + l := gologger.GetLogger("JwtAuth") + config := model.GetConfig() + jwtConfig := config.Jwt + if jwtConfig.Secret == "" || path.Base(r.URL.Path) == "login" { + next.ServeHTTP(w, r) + return + } + // 从cookie中获取token + tokenCookie, err := r.Cookie("auth_token") + if err != nil { + http.Error(w, "Unauthorized.", http.StatusUnauthorized) + return + } + tokenString := tokenCookie.Value + token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) { + // 确保签名方法是正确的 + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return []byte(jwtConfig.Secret), nil + }) + if err != nil { + l.Error("Failed to parse JWT: %v", err) + http.Error(w, "Unauthorized.", http.StatusUnauthorized) + return + } + if claims, ok := token.Claims.(*jwt.RegisteredClaims); ok && token.Valid { + // 验证通过,将用户信息存储在请求上下文中 + ctx := context.WithValue(r.Context(), "user", claims) + next.ServeHTTP(w, r.WithContext(ctx)) + return + } + http.Error(w, "Unauthorized.", http.StatusUnauthorized) +} func RecordAccess(w http.ResponseWriter, r *http.Request, next http.Handler) { model.Incr(r.Host) diff --git a/server/server.go b/server/server.go index 3be0dee..53d66a3 100644 --- a/server/server.go +++ b/server/server.go @@ -33,7 +33,7 @@ func (route *Route) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (route *Route) Match(r *http.Request) bool { l := logger.GetLogger("Route") - l.Debug(fmt.Sprintf("match route: %s %s", r.Method, r.URL.Path)) + l.Debug(fmt.Sprintf("matching route: %s %s with %s %s", r.Method, r.URL.Path, route.Method, route.Path)) if route.Method != "" && route.Method != r.Method { l.Debug("method not match") return false