diff --git a/server/handleAgent.go b/server/handleAgent.go index 4fd6597..4c0b8f6 100644 --- a/server/handleAgent.go +++ b/server/handleAgent.go @@ -3,9 +3,19 @@ package server import ( "net/http" + "git.pyer.club/kingecg/gotunnelserver/util" "github.com/gorilla/websocket" ) func (s *Server) HandleAgent(c *websocket.Conn, r *http.Request) { - + agentSession := NewSession(c) + agentSession.Start() + s.agentSession[agentSession.Id] = agentSession + command := &Command{ + Type: util.NewSession, + Payload: map[string]string{ + "id": agentSession.Id, + }, + } + agentSession.Send(command) } diff --git a/server/handleClient.go b/server/handleClient.go index f78c64e..5be212e 100644 --- a/server/handleClient.go +++ b/server/handleClient.go @@ -7,5 +7,31 @@ import ( ) func (s *Server) HandleClient(conn *websocket.Conn, r *http.Request) { - conn.WriteMessage(websocket.TextMessage, []byte("hello")) + clientSession := NewSession(conn) + clientSession.Start() + s.clientSession[clientSession.Id] = clientSession + command := NcmdSession(clientSession.Id) + clientSession.Send(command) + clientSession.On("NewConnection", func(args ...interface{}) { + orgPlayload := args[0].(map[string]string) // orgPayload should include target, host and port field + targetSessionId, ok := orgPlayload["target"] + + if !ok { + clientSession.Send(NewErrorResponse("target field is required", command)) + return + } + targetSession := s.findSession(targetSessionId) + if targetSession == nil { + clientSession.Send(NewErrorResponse("target session not found", command)) + return + } + command := NcmdConnectionInited(clientSession.Id) + + for k, v := range orgPlayload { + command.Payload[k] = v + } + clientSession.Send(command) + targetSession.Send(command) + + }) } diff --git a/server/handlePipe.go b/server/handlePipe.go index dd75c2c..9c8cf4e 100644 --- a/server/handlePipe.go +++ b/server/handlePipe.go @@ -6,6 +6,89 @@ import ( "github.com/gorilla/websocket" ) -func (s *Server) HandlePipe(conn *websocket.Conn, r *http.Request) { - conn.WriteMessage(websocket.TextMessage, []byte("hello")) +type Pipe struct { + Id string + Src string + Dst string + src *websocket.Conn + dst *websocket.Conn + stopChan chan int +} + +func (p *Pipe) Start() { + p.stopChan = make(chan int) + go p.forward(p.src, p.dst) + go p.forward(p.dst, p.src) + <-p.stopChan + p.src.Close() + p.dst.Close() +} + +func (p *Pipe) forward(src, dst *websocket.Conn) { + for { + mtype, message, err := src.ReadMessage() + if err != nil { + + break + } + err = dst.WriteMessage(mtype, message) + if err != nil { + + break + } + } + p.stopChan <- 1 +} + +func (s *Server) HandlePipe(conn *websocket.Conn, r *http.Request) { + session := r.Header.Get("Session") + if session == "" { + s.logger.Error("no session header") + conn.Close() + return + } + cmdSession := s.findSession(session) + if cmdSession == nil { + conn.Close() + s.logger.Error("command session not found") + return + } + lpath := r.URL.Path + pipeId := lpath[len("/ws/pipe/"):] + pipe, ok := s.pipes[pipeId] + + if !ok { + conn.Close() + s.logger.Error("pipe not found") + cmdSession.Send(NewErrorResponse("pipe not found", nil)) + return + } + + if pipe.Src == session { + pipe.src = conn + } else if pipe.Dst == session { + pipe.dst = conn + } else { + cmdSession.Send(NewErrorResponse("not endpoint of current pipe", nil)) + conn.Close() + return + } + + if pipe.src != nil && pipe.dst != nil { + pipe.Start() + clientCmdSession := s.findSession(pipe.Src) + clientCmdSession.Send(NcmdConnectionReady(pipe.Id)) // info src endpoint ready and can setup proxy listener + } +} + +func (s *Server) findSession(sessionId string) *Session { + clientSession, ok := s.clientSession[sessionId] + if ok { + return clientSession + } + agentSession, ok := s.agentSession[sessionId] + if ok { + return agentSession + } + return nil } diff --git a/server/main.go b/server/main.go index dee268b..4ff4a49 100644 --- a/server/main.go +++ b/server/main.go @@ -31,9 +31,12 @@ type ServerConfig struct { type Server struct { *http.Server - mux *http.ServeMux - Config *ServerConfig - logger *gologger.Logger + mux *http.ServeMux + Config *ServerConfig + logger *gologger.Logger + clientSession map[string]*Session + agentSession map[string]*Session + pipes map[string]*Pipe } func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -60,10 +63,10 @@ func (s *Server) auth(r *http.Request) (*http.Request, bool) { if err != nil { return r, false } - if strings.HasPrefix(lpath, "/ws/pipe") && authEntity.Token == "nil" { - return r, false - } - if !util.VerifyAuth(authEntity, s.Config.Key) { + // if strings.HasPrefix(lpath, "/ws/pipe") && authEntity.Token == "nil" { + // return r, false + // } + if !util.VerifyAuth(authEntity, s.Config.Salt) { return r, false } //set auth entity to request context @@ -134,9 +137,12 @@ func New(configFile string) *Server { logger := gologger.GetLogger("server") logger.Info("create server") server := &Server{ - Server: &http.Server{}, - Config: config, - logger: logger, + Server: &http.Server{}, + Config: config, + logger: logger, + clientSession: make(map[string]*Session), + agentSession: make(map[string]*Session), + pipes: make(map[string]*Pipe), } return server } diff --git a/server/session.go b/server/session.go index e8c892c..219cc38 100644 --- a/server/session.go +++ b/server/session.go @@ -2,6 +2,7 @@ package server import ( "context" + "encoding/json" "git.pyer.club/kingecg/goemitter" "git.pyer.club/kingecg/gotunnelserver/util" @@ -39,7 +40,7 @@ func (s *Session) Start() { return } if cmd.Type == util.NewConnection { - s.Emit("NewSession", cmd.Payload) + s.Emit("NewConnection", cmd.Payload) } select { case <-s.ctx.Done(): @@ -63,3 +64,37 @@ func NewSession(conn *websocket.Conn) *Session { EventEmitter: goemitter.NewEmitter(), } } + +func NewCommand(t int, p map[string]string) *Command { + return &Command{ + Type: t, + Payload: p, + } +} + +func NcmdSession(sessionId string) *Command { + return NewCommand(util.NewSession, map[string]string{"sessionId": sessionId}) +} + +func NewErrorResponse(err string, cmd *Command) *Command { + payload := map[string]string{ + "error": err, + // "originalcmd": strconv.Itoa(cmd.Type), + } + if cmd != nil { + + originalCmd, _ := json.Marshal(cmd) + payload["originalcmd"] = string(originalCmd) + + } + + return NewCommand(util.ErrorCmd, payload) +} + +func NcmdConnectionInited(sessionId string) *Command { + return NewCommand(util.ConnectionReady, map[string]string{"sessionId": sessionId}) +} + +func NcmdConnectionReady(sessionId string) *Command { + return NewCommand(util.ConnectionReady, map[string]string{"sessionId": sessionId}) +} diff --git a/util/const.go b/util/const.go index c3ccc21..b6ba165 100644 --- a/util/const.go +++ b/util/const.go @@ -3,5 +3,7 @@ package util const ( NewSession = iota NewConnection - AgentReady + ConnectInited //used to disptatch new connection session id + ConnectionReady + ErrorCmd ) diff --git a/util/entity.go b/util/entity.go index 5f1bee1..17b1cc6 100644 --- a/util/entity.go +++ b/util/entity.go @@ -12,7 +12,7 @@ import ( type Entity struct { Username string `json:"username"` - Token string `json:"token"` + // Token string `json:"token"` } type AuthEntity struct { @@ -31,7 +31,7 @@ func VerifyAuth(a *AuthEntity, salt string) bool { } func GenAuthToken(a *AuthEntity, salt string) string { - str := fmt.Sprintf("%s:%s:%d:%s", a.Username, a.Token, a.Time, salt) + str := fmt.Sprintf("%s:%d:%s", a.Username, a.Time, salt) h := md5.New() h.Write([]byte(str)) return hex.EncodeToString(h.Sum(nil))