add session
This commit is contained in:
parent
547930d1ce
commit
dfb199bdd1
1
go.mod
1
go.mod
|
@ -3,6 +3,7 @@ module git.pyer.club/kingecg/gotunnelserver
|
|||
go 1.23.1
|
||||
|
||||
require (
|
||||
git.pyer.club/kingecg/goemitter v0.0.0-20240919084107-533c3d1be082 // indirect
|
||||
git.pyer.club/kingecg/gologger v1.0.5 // indirect
|
||||
github.com/gorilla/websocket v1.5.3 // indirect
|
||||
)
|
||||
|
|
2
go.sum
2
go.sum
|
@ -1,3 +1,5 @@
|
|||
git.pyer.club/kingecg/goemitter v0.0.0-20240919084107-533c3d1be082 h1:U7Jbet3zObT2qPJ2g408Z9OUvR6phQyHOoHeidM5zUg=
|
||||
git.pyer.club/kingecg/goemitter v0.0.0-20240919084107-533c3d1be082/go.mod h1:2jbknDqoWH41M3MQ9pQZDKBiNtDmNgPcM3XfkE9YkbQ=
|
||||
git.pyer.club/kingecg/gologger v1.0.5 h1:L/N/bleGHhEiaBYBf9U1z2ni0HfhaU71pk8ik/D11oo=
|
||||
git.pyer.club/kingecg/gologger v1.0.5/go.mod h1:SNSl2jRHPzIpHSzdKOoVG798rtYMjPDPFyxUrEgivkY=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
|
@ -10,6 +11,7 @@ import (
|
|||
"net/http"
|
||||
|
||||
"git.pyer.club/kingecg/gologger"
|
||||
"git.pyer.club/kingecg/gotunnelserver/util"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
|
@ -23,6 +25,7 @@ type ServerConfig struct {
|
|||
Host string `json:"host"`
|
||||
Cert string `json:"cert"`
|
||||
Key string `json:"key"`
|
||||
Salt string `json:"salt"`
|
||||
Logger *gologger.LoggersConfig `json:"log"`
|
||||
}
|
||||
|
||||
|
@ -35,17 +38,38 @@ type Server struct {
|
|||
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
s.logger.Info("request:", r.URL.Path)
|
||||
|
||||
if !s.auth(r) {
|
||||
s.HandleUnAutherized(w, r)
|
||||
nr, ok := s.auth(r)
|
||||
if !ok {
|
||||
s.HandleUnAutherized(w, nr)
|
||||
} else {
|
||||
s.mux.ServeHTTP(w, r)
|
||||
s.mux.ServeHTTP(w, nr)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) auth(r *http.Request) bool {
|
||||
s.logger.Info("auth:", r.RemoteAddr)
|
||||
return true
|
||||
func (s *Server) auth(r *http.Request) (*http.Request, bool) {
|
||||
s.logger.Debug("auth:", r.RemoteAddr)
|
||||
lpath := r.URL.Path
|
||||
if lpath == "/" || lpath == "/hello" {
|
||||
return r, true
|
||||
}
|
||||
token := r.Header.Get("Authorization")
|
||||
if token == "" {
|
||||
return r, false
|
||||
}
|
||||
authEntity, err := util.ParseAuthEntity(token)
|
||||
if err != nil {
|
||||
return r, false
|
||||
}
|
||||
if strings.HasPrefix(lpath, "/ws/pipe") && authEntity.Token == "nil" {
|
||||
return r, false
|
||||
}
|
||||
if !util.VerifyAuth(authEntity, s.Config.Key) {
|
||||
return r, false
|
||||
}
|
||||
//set auth entity to request context
|
||||
ctx := r.Context()
|
||||
nctx := context.WithValue(ctx, "authEntity", authEntity)
|
||||
return r.WithContext(nctx), true
|
||||
}
|
||||
func (s *Server) Shutdown() {
|
||||
s.logger.Info("shutdown server")
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"git.pyer.club/kingecg/goemitter"
|
||||
"git.pyer.club/kingecg/gotunnelserver/util"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
Id string
|
||||
conn *websocket.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
*goemitter.EventEmitter
|
||||
}
|
||||
|
||||
type Command struct {
|
||||
Type int `json:"type"`
|
||||
Payload map[string]string `json:"payload"`
|
||||
}
|
||||
|
||||
func (s *Session) Send(cmd *Command) (err error) {
|
||||
return s.conn.WriteJSON(cmd)
|
||||
}
|
||||
func (s *Session) Close() {
|
||||
s.conn.Close()
|
||||
}
|
||||
|
||||
func (s *Session) Start() {
|
||||
s.ctx, s.cancel = context.WithCancel(context.Background())
|
||||
go func() {
|
||||
defer s.conn.Close()
|
||||
for {
|
||||
var cmd Command
|
||||
err := s.conn.ReadJSON(&cmd)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if cmd.Type == util.NewConnection {
|
||||
s.Emit("NewSession", cmd.Payload)
|
||||
}
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
}
|
||||
|
||||
func (s *Session) Stop() {
|
||||
s.cancel()
|
||||
}
|
||||
|
||||
func NewSession(conn *websocket.Conn) *Session {
|
||||
return &Session{
|
||||
conn: conn,
|
||||
Id: util.GenRandomstring(16),
|
||||
EventEmitter: goemitter.NewEmitter(),
|
||||
}
|
||||
}
|
|
@ -0,0 +1,7 @@
|
|||
package util
|
||||
|
||||
const (
|
||||
NewSession = iota
|
||||
NewConnection
|
||||
AgentReady
|
||||
)
|
|
@ -0,0 +1,59 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Entity struct {
|
||||
Username string `json:"username"`
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
type AuthEntity struct {
|
||||
Entity
|
||||
Time int64 `json:"time"`
|
||||
Authtoken string `json:"authtoken"`
|
||||
}
|
||||
|
||||
func VerifyAuth(a *AuthEntity, salt string) bool {
|
||||
current := time.Now()
|
||||
if current.Sub(time.Unix(a.Time, 0)).Minutes() > 5 {
|
||||
return false
|
||||
}
|
||||
|
||||
return a.Authtoken == GenAuthToken(a, salt)
|
||||
}
|
||||
|
||||
func GenAuthToken(a *AuthEntity, salt string) string {
|
||||
str := fmt.Sprintf("%s:%s:%d:%s", a.Username, a.Token, a.Time, salt)
|
||||
h := md5.New()
|
||||
h.Write([]byte(str))
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
func ParseAuthEntity(str string) (a *AuthEntity, err error) {
|
||||
dbytes, err := base64.StdEncoding.DecodeString(str)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
a = new(AuthEntity)
|
||||
err = json.Unmarshal(dbytes, a)
|
||||
return a, err
|
||||
}
|
||||
|
||||
func GenRandomstring(n int) string {
|
||||
// 生成长度为n的随机字符串
|
||||
// rand.Seed(time.Now().UnixNano())
|
||||
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
b := make([]byte, n)
|
||||
for i := range b {
|
||||
b[i] = letterBytes[rand.Intn(len(letterBytes))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
Loading…
Reference in New Issue