diff --git a/go.mod b/go.mod index 66495b2..f3213d0 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module git.pyer.club/kingecg/gotunnelserver go 1.23.1 require ( + git.pyer.club/kingecg/goemitter v0.0.0-20240919084107-533c3d1be082 // indirect git.pyer.club/kingecg/gologger v1.0.5 // indirect github.com/gorilla/websocket v1.5.3 // indirect ) diff --git a/go.sum b/go.sum index 744a0bd..68276cd 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +git.pyer.club/kingecg/goemitter v0.0.0-20240919084107-533c3d1be082 h1:U7Jbet3zObT2qPJ2g408Z9OUvR6phQyHOoHeidM5zUg= +git.pyer.club/kingecg/goemitter v0.0.0-20240919084107-533c3d1be082/go.mod h1:2jbknDqoWH41M3MQ9pQZDKBiNtDmNgPcM3XfkE9YkbQ= git.pyer.club/kingecg/gologger v1.0.5 h1:L/N/bleGHhEiaBYBf9U1z2ni0HfhaU71pk8ik/D11oo= git.pyer.club/kingecg/gologger v1.0.5/go.mod h1:SNSl2jRHPzIpHSzdKOoVG798rtYMjPDPFyxUrEgivkY= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= diff --git a/server/main.go b/server/main.go index 68470e5..dee268b 100644 --- a/server/main.go +++ b/server/main.go @@ -1,6 +1,7 @@ package server import ( + "context" "encoding/json" "fmt" "os" @@ -10,6 +11,7 @@ import ( "net/http" "git.pyer.club/kingecg/gologger" + "git.pyer.club/kingecg/gotunnelserver/util" "github.com/gorilla/websocket" ) @@ -23,6 +25,7 @@ type ServerConfig struct { Host string `json:"host"` Cert string `json:"cert"` Key string `json:"key"` + Salt string `json:"salt"` Logger *gologger.LoggersConfig `json:"log"` } @@ -35,17 +38,38 @@ type Server struct { func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.logger.Info("request:", r.URL.Path) - - if !s.auth(r) { - s.HandleUnAutherized(w, r) + nr, ok := s.auth(r) + if !ok { + s.HandleUnAutherized(w, nr) } else { - s.mux.ServeHTTP(w, r) + s.mux.ServeHTTP(w, nr) } } -func (s *Server) auth(r *http.Request) bool { - s.logger.Info("auth:", r.RemoteAddr) - return true +func (s *Server) auth(r *http.Request) (*http.Request, bool) { + s.logger.Debug("auth:", r.RemoteAddr) + lpath := r.URL.Path + if lpath == "/" || lpath == "/hello" { + return r, true + } + token := r.Header.Get("Authorization") + if token == "" { + return r, false + } + authEntity, err := util.ParseAuthEntity(token) + if err != nil { + return r, false + } + if strings.HasPrefix(lpath, "/ws/pipe") && authEntity.Token == "nil" { + return r, false + } + if !util.VerifyAuth(authEntity, s.Config.Key) { + return r, false + } + //set auth entity to request context + ctx := r.Context() + nctx := context.WithValue(ctx, "authEntity", authEntity) + return r.WithContext(nctx), true } func (s *Server) Shutdown() { s.logger.Info("shutdown server") diff --git a/server/session.go b/server/session.go new file mode 100644 index 0000000..e8c892c --- /dev/null +++ b/server/session.go @@ -0,0 +1,65 @@ +package server + +import ( + "context" + + "git.pyer.club/kingecg/goemitter" + "git.pyer.club/kingecg/gotunnelserver/util" + "github.com/gorilla/websocket" +) + +type Session struct { + Id string + conn *websocket.Conn + ctx context.Context + cancel context.CancelFunc + *goemitter.EventEmitter +} + +type Command struct { + Type int `json:"type"` + Payload map[string]string `json:"payload"` +} + +func (s *Session) Send(cmd *Command) (err error) { + return s.conn.WriteJSON(cmd) +} +func (s *Session) Close() { + s.conn.Close() +} + +func (s *Session) Start() { + s.ctx, s.cancel = context.WithCancel(context.Background()) + go func() { + defer s.conn.Close() + for { + var cmd Command + err := s.conn.ReadJSON(&cmd) + if err != nil { + return + } + if cmd.Type == util.NewConnection { + s.Emit("NewSession", cmd.Payload) + } + select { + case <-s.ctx.Done(): + return + default: + continue + } + } + }() + +} + +func (s *Session) Stop() { + s.cancel() +} + +func NewSession(conn *websocket.Conn) *Session { + return &Session{ + conn: conn, + Id: util.GenRandomstring(16), + EventEmitter: goemitter.NewEmitter(), + } +} diff --git a/util/const.go b/util/const.go new file mode 100644 index 0000000..c3ccc21 --- /dev/null +++ b/util/const.go @@ -0,0 +1,7 @@ +package util + +const ( + NewSession = iota + NewConnection + AgentReady +) diff --git a/util/entity.go b/util/entity.go new file mode 100644 index 0000000..5f1bee1 --- /dev/null +++ b/util/entity.go @@ -0,0 +1,59 @@ +package util + +import ( + "crypto/md5" + "encoding/base64" + "encoding/hex" + "encoding/json" + "fmt" + "math/rand" + "time" +) + +type Entity struct { + Username string `json:"username"` + Token string `json:"token"` +} + +type AuthEntity struct { + Entity + Time int64 `json:"time"` + Authtoken string `json:"authtoken"` +} + +func VerifyAuth(a *AuthEntity, salt string) bool { + current := time.Now() + if current.Sub(time.Unix(a.Time, 0)).Minutes() > 5 { + return false + } + + return a.Authtoken == GenAuthToken(a, salt) +} + +func GenAuthToken(a *AuthEntity, salt string) string { + str := fmt.Sprintf("%s:%s:%d:%s", a.Username, a.Token, a.Time, salt) + h := md5.New() + h.Write([]byte(str)) + return hex.EncodeToString(h.Sum(nil)) +} + +func ParseAuthEntity(str string) (a *AuthEntity, err error) { + dbytes, err := base64.StdEncoding.DecodeString(str) + if err != nil { + return nil, err + } + a = new(AuthEntity) + err = json.Unmarshal(dbytes, a) + return a, err +} + +func GenRandomstring(n int) string { + // 生成长度为n的随机字符串 + // rand.Seed(time.Now().UnixNano()) + const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + b := make([]byte, n) + for i := range b { + b[i] = letterBytes[rand.Intn(len(letterBytes))] + } + return string(b) +}