185 lines
4.3 KiB
Go
185 lines
4.3 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
"net/http"
|
|
|
|
"git.pyer.club/kingecg/gologger"
|
|
"git.pyer.club/kingecg/gotunnelserver/util"
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
var upgrader = websocket.Upgrader{
|
|
ReadBufferSize: 4096,
|
|
WriteBufferSize: 4096,
|
|
}
|
|
|
|
type ServerConfig struct {
|
|
Port string `json:"port"`
|
|
Host string `json:"host"`
|
|
Cert string `json:"cert"`
|
|
Key string `json:"key"`
|
|
Salt string `json:"salt"`
|
|
Logger *gologger.LoggersConfig `json:"log"`
|
|
}
|
|
|
|
type Server struct {
|
|
*http.Server
|
|
mux *http.ServeMux
|
|
Config *ServerConfig
|
|
logger *gologger.Logger
|
|
clientSession map[string]*Session
|
|
agentSession map[string]*Session
|
|
pipes map[string]*Pipe
|
|
}
|
|
|
|
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
s.logger.Info("request:", r.URL.Path)
|
|
nr, ok := s.auth(r)
|
|
if !ok {
|
|
s.HandleUnAutherized(w, nr)
|
|
} else {
|
|
s.mux.ServeHTTP(w, nr)
|
|
}
|
|
}
|
|
|
|
func (s *Server) auth(r *http.Request) (*http.Request, bool) {
|
|
s.logger.Debug("auth:", r.RemoteAddr)
|
|
lpath := r.URL.Path
|
|
if lpath == "/" || lpath == "/hello" {
|
|
return r, true
|
|
}
|
|
token := r.Header.Get("Authorization")
|
|
if token == "" {
|
|
return r, false
|
|
}
|
|
authEntity, err := util.ParseAuthEntity(token)
|
|
if err != nil {
|
|
return r, false
|
|
}
|
|
// if strings.HasPrefix(lpath, "/ws/pipe") && authEntity.Token == "nil" {
|
|
// return r, false
|
|
// }
|
|
if !util.VerifyAuth(authEntity, s.Config.Salt) {
|
|
return r, false
|
|
}
|
|
//set auth entity to request context
|
|
ctx := r.Context()
|
|
nctx := context.WithValue(ctx, "authEntity", authEntity)
|
|
return r.WithContext(nctx), true
|
|
}
|
|
func (s *Server) Shutdown() {
|
|
s.logger.Info("shutdown server")
|
|
s.Close()
|
|
}
|
|
|
|
func (s *Server) HandleHello(w http.ResponseWriter, r *http.Request) {
|
|
w.Write([]byte("hello"))
|
|
}
|
|
|
|
func (s *Server) HandleUnAutherized(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
}
|
|
func (s *Server) HandleWs(w http.ResponseWriter, r *http.Request) {
|
|
lpath := r.URL.Path
|
|
lpath = strings.TrimPrefix(lpath, "/ws/")
|
|
c, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
s.logger.Error("upgrade failed:", err)
|
|
return
|
|
}
|
|
if strings.HasPrefix(lpath, "client") {
|
|
s.HandleClient(c, r)
|
|
} else if strings.HasPrefix(lpath, "agent") {
|
|
s.HandleAgent(c, r)
|
|
} else if strings.HasPrefix(lpath, "pipe") {
|
|
s.HandlePipe(c, r)
|
|
} else {
|
|
s.logger.Error("unknown path:", lpath)
|
|
c.Close()
|
|
}
|
|
}
|
|
func (s *Server) registHandler() {
|
|
if s.mux == nil {
|
|
s.mux = http.NewServeMux()
|
|
}
|
|
s.mux.HandleFunc("/hello", s.HandleHello)
|
|
s.mux.HandleFunc("/ws/", s.HandleWs)
|
|
}
|
|
func (s *Server) Start() {
|
|
addr := s.Config.Host + ":" + s.Config.Port
|
|
s.Server.Addr = addr
|
|
|
|
s.Server.Handler = s
|
|
s.registHandler()
|
|
if s.Config.Cert != "" && s.Config.Key != "" {
|
|
s.logger.Info("start https server")
|
|
s.ListenAndServeTLS(s.Config.Cert, s.Config.Key)
|
|
} else {
|
|
s.logger.Info("start http server")
|
|
s.ListenAndServe()
|
|
}
|
|
}
|
|
|
|
func New(configFile string) *Server {
|
|
// load config from configFile
|
|
config, err := LoadConfig(configFile)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
gologger.Configure(*config.Logger)
|
|
logger := gologger.GetLogger("server")
|
|
logger.Info("create server")
|
|
server := &Server{
|
|
Server: &http.Server{},
|
|
Config: config,
|
|
logger: logger,
|
|
clientSession: make(map[string]*Session),
|
|
agentSession: make(map[string]*Session),
|
|
pipes: make(map[string]*Pipe),
|
|
}
|
|
return server
|
|
}
|
|
|
|
func LoadConfig(configFile string) (conf *ServerConfig, err error) {
|
|
aconfPath := configFile
|
|
if aconfPath == "" {
|
|
aconfPath = "conf.json"
|
|
}
|
|
if !filepath.IsAbs(aconfPath) {
|
|
aconfPath, _ = filepath.Abs(aconfPath)
|
|
}
|
|
fileConten, err := os.ReadFile(aconfPath)
|
|
if err != nil {
|
|
fmt.Printf("read config file failed: %s", err)
|
|
return nil, err
|
|
}
|
|
conf = new(ServerConfig)
|
|
err = json.Unmarshal(fileConten, conf)
|
|
if err != nil {
|
|
fmt.Printf("parse config file failed: %s", err)
|
|
return nil, err
|
|
}
|
|
if conf.Cert != "" && conf.Key != "" {
|
|
conf.Cert = Abs(filepath.Dir(aconfPath), conf.Cert)
|
|
conf.Key = Abs(filepath.Dir(aconfPath), conf.Key)
|
|
} else {
|
|
conf.Cert = ""
|
|
conf.Key = ""
|
|
}
|
|
return conf, nil
|
|
}
|
|
|
|
// Abs resolves path against basePath. An absolute path is returned exactly as
// given, without being cleaned. A relative path is joined onto basePath with
// filepath.Join, which also cleans the result.
func Abs(basePath string, path string) string {
	if !filepath.IsAbs(path) {
		path = filepath.Join(basePath, path)
	}
	return path
}
|