add pipe
This commit is contained in:
parent
dfb199bdd1
commit
45427d2020
|
@ -3,9 +3,19 @@ package server
|
|||
import (
|
||||
"net/http"
|
||||
|
||||
"git.pyer.club/kingecg/gotunnelserver/util"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func (s *Server) HandleAgent(c *websocket.Conn, r *http.Request) {
|
||||
|
||||
agentSession := NewSession(c)
|
||||
agentSession.Start()
|
||||
s.agentSession[agentSession.Id] = agentSession
|
||||
command := &Command{
|
||||
Type: util.NewSession,
|
||||
Payload: map[string]string{
|
||||
"id": agentSession.Id,
|
||||
},
|
||||
}
|
||||
agentSession.Send(command)
|
||||
}
|
||||
|
|
|
@ -7,5 +7,31 @@ import (
|
|||
)
|
||||
|
||||
func (s *Server) HandleClient(conn *websocket.Conn, r *http.Request) {
|
||||
conn.WriteMessage(websocket.TextMessage, []byte("hello"))
|
||||
clientSession := NewSession(conn)
|
||||
clientSession.Start()
|
||||
s.clientSession[clientSession.Id] = clientSession
|
||||
command := NcmdSession(clientSession.Id)
|
||||
clientSession.Send(command)
|
||||
clientSession.On("NewConnection", func(args ...interface{}) {
|
||||
orgPlayload := args[0].(map[string]string) // orgPayload should include target, host and port field
|
||||
targetSessionId, ok := orgPlayload["target"]
|
||||
|
||||
if !ok {
|
||||
clientSession.Send(NewErrorResponse("target field is required", command))
|
||||
return
|
||||
}
|
||||
targetSession := s.findSession(targetSessionId)
|
||||
if targetSession == nil {
|
||||
clientSession.Send(NewErrorResponse("target session not found", command))
|
||||
return
|
||||
}
|
||||
command := NcmdConnectionInited(clientSession.Id)
|
||||
|
||||
for k, v := range orgPlayload {
|
||||
command.Payload[k] = v
|
||||
}
|
||||
clientSession.Send(command)
|
||||
targetSession.Send(command)
|
||||
|
||||
})
|
||||
}
|
||||
|
|
|
@ -6,6 +6,89 @@ import (
|
|||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func (s *Server) HandlePipe(conn *websocket.Conn, r *http.Request) {
|
||||
conn.WriteMessage(websocket.TextMessage, []byte("hello"))
|
||||
type Pipe struct {
|
||||
Id string
|
||||
Src string
|
||||
Dst string
|
||||
src *websocket.Conn
|
||||
dst *websocket.Conn
|
||||
stopChan chan int
|
||||
}
|
||||
|
||||
func (p *Pipe) Start() {
|
||||
p.stopChan = make(chan int)
|
||||
go p.forward(p.src, p.dst)
|
||||
go p.forward(p.dst, p.src)
|
||||
<-p.stopChan
|
||||
p.src.Close()
|
||||
p.dst.Close()
|
||||
}
|
||||
|
||||
func (p *Pipe) forward(src, dst *websocket.Conn) {
|
||||
for {
|
||||
mtype, message, err := src.ReadMessage()
|
||||
if err != nil {
|
||||
|
||||
break
|
||||
}
|
||||
err = dst.WriteMessage(mtype, message)
|
||||
if err != nil {
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
p.stopChan <- 1
|
||||
}
|
||||
|
||||
func (s *Server) HandlePipe(conn *websocket.Conn, r *http.Request) {
|
||||
session := r.Header.Get("Session")
|
||||
if session == "" {
|
||||
s.logger.Error("no session header")
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
cmdSession := s.findSession(session)
|
||||
if cmdSession == nil {
|
||||
conn.Close()
|
||||
s.logger.Error("command session not found")
|
||||
return
|
||||
}
|
||||
lpath := r.URL.Path
|
||||
pipeId := lpath[len("/ws/pipe/"):]
|
||||
pipe, ok := s.pipes[pipeId]
|
||||
|
||||
if !ok {
|
||||
conn.Close()
|
||||
s.logger.Error("pipe not found")
|
||||
cmdSession.Send(NewErrorResponse("pipe not found", nil))
|
||||
return
|
||||
}
|
||||
|
||||
if pipe.Src == session {
|
||||
pipe.src = conn
|
||||
} else if pipe.Dst == session {
|
||||
pipe.dst = conn
|
||||
} else {
|
||||
cmdSession.Send(NewErrorResponse("not endpoint of current pipe", nil))
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
if pipe.src != nil && pipe.dst != nil {
|
||||
pipe.Start()
|
||||
clientCmdSession := s.findSession(pipe.Src)
|
||||
clientCmdSession.Send(NcmdConnectionReady(pipe.Id)) // info src endpoint ready and can setup proxy listener
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) findSession(sessionId string) *Session {
|
||||
clientSession, ok := s.clientSession[sessionId]
|
||||
if ok {
|
||||
return clientSession
|
||||
}
|
||||
agentSession, ok := s.agentSession[sessionId]
|
||||
if ok {
|
||||
return agentSession
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -31,9 +31,12 @@ type ServerConfig struct {
|
|||
|
||||
type Server struct {
|
||||
*http.Server
|
||||
mux *http.ServeMux
|
||||
Config *ServerConfig
|
||||
logger *gologger.Logger
|
||||
mux *http.ServeMux
|
||||
Config *ServerConfig
|
||||
logger *gologger.Logger
|
||||
clientSession map[string]*Session
|
||||
agentSession map[string]*Session
|
||||
pipes map[string]*Pipe
|
||||
}
|
||||
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -60,10 +63,10 @@ func (s *Server) auth(r *http.Request) (*http.Request, bool) {
|
|||
if err != nil {
|
||||
return r, false
|
||||
}
|
||||
if strings.HasPrefix(lpath, "/ws/pipe") && authEntity.Token == "nil" {
|
||||
return r, false
|
||||
}
|
||||
if !util.VerifyAuth(authEntity, s.Config.Key) {
|
||||
// if strings.HasPrefix(lpath, "/ws/pipe") && authEntity.Token == "nil" {
|
||||
// return r, false
|
||||
// }
|
||||
if !util.VerifyAuth(authEntity, s.Config.Salt) {
|
||||
return r, false
|
||||
}
|
||||
//set auth entity to request context
|
||||
|
@ -134,9 +137,12 @@ func New(configFile string) *Server {
|
|||
logger := gologger.GetLogger("server")
|
||||
logger.Info("create server")
|
||||
server := &Server{
|
||||
Server: &http.Server{},
|
||||
Config: config,
|
||||
logger: logger,
|
||||
Server: &http.Server{},
|
||||
Config: config,
|
||||
logger: logger,
|
||||
clientSession: make(map[string]*Session),
|
||||
agentSession: make(map[string]*Session),
|
||||
pipes: make(map[string]*Pipe),
|
||||
}
|
||||
return server
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package server
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"git.pyer.club/kingecg/goemitter"
|
||||
"git.pyer.club/kingecg/gotunnelserver/util"
|
||||
|
@ -39,7 +40,7 @@ func (s *Session) Start() {
|
|||
return
|
||||
}
|
||||
if cmd.Type == util.NewConnection {
|
||||
s.Emit("NewSession", cmd.Payload)
|
||||
s.Emit("NewConnection", cmd.Payload)
|
||||
}
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
|
@ -63,3 +64,37 @@ func NewSession(conn *websocket.Conn) *Session {
|
|||
EventEmitter: goemitter.NewEmitter(),
|
||||
}
|
||||
}
|
||||
|
||||
func NewCommand(t int, p map[string]string) *Command {
|
||||
return &Command{
|
||||
Type: t,
|
||||
Payload: p,
|
||||
}
|
||||
}
|
||||
|
||||
func NcmdSession(sessionId string) *Command {
|
||||
return NewCommand(util.NewSession, map[string]string{"sessionId": sessionId})
|
||||
}
|
||||
|
||||
func NewErrorResponse(err string, cmd *Command) *Command {
|
||||
payload := map[string]string{
|
||||
"error": err,
|
||||
// "originalcmd": strconv.Itoa(cmd.Type),
|
||||
}
|
||||
if cmd != nil {
|
||||
|
||||
originalCmd, _ := json.Marshal(cmd)
|
||||
payload["originalcmd"] = string(originalCmd)
|
||||
|
||||
}
|
||||
|
||||
return NewCommand(util.ErrorCmd, payload)
|
||||
}
|
||||
|
||||
func NcmdConnectionInited(sessionId string) *Command {
|
||||
return NewCommand(util.ConnectionReady, map[string]string{"sessionId": sessionId})
|
||||
}
|
||||
|
||||
func NcmdConnectionReady(sessionId string) *Command {
|
||||
return NewCommand(util.ConnectionReady, map[string]string{"sessionId": sessionId})
|
||||
}
|
||||
|
|
|
@ -3,5 +3,7 @@ package util
|
|||
const (
|
||||
NewSession = iota
|
||||
NewConnection
|
||||
AgentReady
|
||||
ConnectInited //used to disptatch new connection session id
|
||||
ConnectionReady
|
||||
ErrorCmd
|
||||
)
|
||||
|
|
|
@ -12,7 +12,7 @@ import (
|
|||
|
||||
type Entity struct {
|
||||
Username string `json:"username"`
|
||||
Token string `json:"token"`
|
||||
// Token string `json:"token"`
|
||||
}
|
||||
|
||||
type AuthEntity struct {
|
||||
|
@ -31,7 +31,7 @@ func VerifyAuth(a *AuthEntity, salt string) bool {
|
|||
}
|
||||
|
||||
func GenAuthToken(a *AuthEntity, salt string) string {
|
||||
str := fmt.Sprintf("%s:%s:%d:%s", a.Username, a.Token, a.Time, salt)
|
||||
str := fmt.Sprintf("%s:%d:%s", a.Username, a.Time, salt)
|
||||
h := md5.New()
|
||||
h.Write([]byte(str))
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
|
|
Loading…
Reference in New Issue