This commit is contained in:
kingecg 2024-11-13 22:26:29 +08:00
parent dfb199bdd1
commit 45427d2020
7 changed files with 180 additions and 18 deletions

View File

@ -3,9 +3,19 @@ package server
import (
"net/http"
"git.pyer.club/kingecg/gotunnelserver/util"
"github.com/gorilla/websocket"
)
func (s *Server) HandleAgent(c *websocket.Conn, r *http.Request) {
agentSession := NewSession(c)
agentSession.Start()
s.agentSession[agentSession.Id] = agentSession
command := &Command{
Type: util.NewSession,
Payload: map[string]string{
"id": agentSession.Id,
},
}
agentSession.Send(command)
}

View File

@ -7,5 +7,31 @@ import (
)
func (s *Server) HandleClient(conn *websocket.Conn, r *http.Request) {
conn.WriteMessage(websocket.TextMessage, []byte("hello"))
clientSession := NewSession(conn)
clientSession.Start()
s.clientSession[clientSession.Id] = clientSession
command := NcmdSession(clientSession.Id)
clientSession.Send(command)
clientSession.On("NewConnection", func(args ...interface{}) {
orgPlayload := args[0].(map[string]string) // orgPayload should include target, host and port field
targetSessionId, ok := orgPlayload["target"]
if !ok {
clientSession.Send(NewErrorResponse("target field is required", command))
return
}
targetSession := s.findSession(targetSessionId)
if targetSession == nil {
clientSession.Send(NewErrorResponse("target session not found", command))
return
}
command := NcmdConnectionInited(clientSession.Id)
for k, v := range orgPlayload {
command.Payload[k] = v
}
clientSession.Send(command)
targetSession.Send(command)
})
}

View File

@ -6,6 +6,89 @@ import (
"github.com/gorilla/websocket"
)
func (s *Server) HandlePipe(conn *websocket.Conn, r *http.Request) {
conn.WriteMessage(websocket.TextMessage, []byte("hello"))
type Pipe struct {
Id string
Src string
Dst string
src *websocket.Conn
dst *websocket.Conn
stopChan chan int
}
func (p *Pipe) Start() {
p.stopChan = make(chan int)
go p.forward(p.src, p.dst)
go p.forward(p.dst, p.src)
<-p.stopChan
p.src.Close()
p.dst.Close()
}
func (p *Pipe) forward(src, dst *websocket.Conn) {
for {
mtype, message, err := src.ReadMessage()
if err != nil {
break
}
err = dst.WriteMessage(mtype, message)
if err != nil {
break
}
}
p.stopChan <- 1
}
func (s *Server) HandlePipe(conn *websocket.Conn, r *http.Request) {
session := r.Header.Get("Session")
if session == "" {
s.logger.Error("no session header")
conn.Close()
return
}
cmdSession := s.findSession(session)
if cmdSession == nil {
conn.Close()
s.logger.Error("command session not found")
return
}
lpath := r.URL.Path
pipeId := lpath[len("/ws/pipe/"):]
pipe, ok := s.pipes[pipeId]
if !ok {
conn.Close()
s.logger.Error("pipe not found")
cmdSession.Send(NewErrorResponse("pipe not found", nil))
return
}
if pipe.Src == session {
pipe.src = conn
} else if pipe.Dst == session {
pipe.dst = conn
} else {
cmdSession.Send(NewErrorResponse("not endpoint of current pipe", nil))
conn.Close()
return
}
if pipe.src != nil && pipe.dst != nil {
pipe.Start()
clientCmdSession := s.findSession(pipe.Src)
clientCmdSession.Send(NcmdConnectionReady(pipe.Id)) // info src endpoint ready and can setup proxy listener
}
}
func (s *Server) findSession(sessionId string) *Session {
clientSession, ok := s.clientSession[sessionId]
if ok {
return clientSession
}
agentSession, ok := s.agentSession[sessionId]
if ok {
return agentSession
}
return nil
}

View File

@ -34,6 +34,9 @@ type Server struct {
mux *http.ServeMux
Config *ServerConfig
logger *gologger.Logger
clientSession map[string]*Session
agentSession map[string]*Session
pipes map[string]*Pipe
}
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
@ -60,10 +63,10 @@ func (s *Server) auth(r *http.Request) (*http.Request, bool) {
if err != nil {
return r, false
}
if strings.HasPrefix(lpath, "/ws/pipe") && authEntity.Token == "nil" {
return r, false
}
if !util.VerifyAuth(authEntity, s.Config.Key) {
// if strings.HasPrefix(lpath, "/ws/pipe") && authEntity.Token == "nil" {
// return r, false
// }
if !util.VerifyAuth(authEntity, s.Config.Salt) {
return r, false
}
//set auth entity to request context
@ -137,6 +140,9 @@ func New(configFile string) *Server {
Server: &http.Server{},
Config: config,
logger: logger,
clientSession: make(map[string]*Session),
agentSession: make(map[string]*Session),
pipes: make(map[string]*Pipe),
}
return server
}

View File

@ -2,6 +2,7 @@ package server
import (
"context"
"encoding/json"
"git.pyer.club/kingecg/goemitter"
"git.pyer.club/kingecg/gotunnelserver/util"
@ -39,7 +40,7 @@ func (s *Session) Start() {
return
}
if cmd.Type == util.NewConnection {
s.Emit("NewSession", cmd.Payload)
s.Emit("NewConnection", cmd.Payload)
}
select {
case <-s.ctx.Done():
@ -63,3 +64,37 @@ func NewSession(conn *websocket.Conn) *Session {
EventEmitter: goemitter.NewEmitter(),
}
}
func NewCommand(t int, p map[string]string) *Command {
return &Command{
Type: t,
Payload: p,
}
}
func NcmdSession(sessionId string) *Command {
return NewCommand(util.NewSession, map[string]string{"sessionId": sessionId})
}
func NewErrorResponse(err string, cmd *Command) *Command {
payload := map[string]string{
"error": err,
// "originalcmd": strconv.Itoa(cmd.Type),
}
if cmd != nil {
originalCmd, _ := json.Marshal(cmd)
payload["originalcmd"] = string(originalCmd)
}
return NewCommand(util.ErrorCmd, payload)
}
func NcmdConnectionInited(sessionId string) *Command {
return NewCommand(util.ConnectionReady, map[string]string{"sessionId": sessionId})
}
func NcmdConnectionReady(sessionId string) *Command {
return NewCommand(util.ConnectionReady, map[string]string{"sessionId": sessionId})
}

View File

@ -3,5 +3,7 @@ package util
const (
NewSession = iota
NewConnection
AgentReady
ConnectInited //used to disptatch new connection session id
ConnectionReady
ErrorCmd
)

View File

@ -12,7 +12,7 @@ import (
type Entity struct {
Username string `json:"username"`
Token string `json:"token"`
// Token string `json:"token"`
}
type AuthEntity struct {
@ -31,7 +31,7 @@ func VerifyAuth(a *AuthEntity, salt string) bool {
}
func GenAuthToken(a *AuthEntity, salt string) string {
str := fmt.Sprintf("%s:%s:%d:%s", a.Username, a.Token, a.Time, salt)
str := fmt.Sprintf("%s:%d:%s", a.Username, a.Time, salt)
h := md5.New()
h.Write([]byte(str))
return hex.EncodeToString(h.Sum(nil))