package server import ( "context" "encoding/json" "fmt" "os" "path/filepath" "strings" "net/http" "git.pyer.club/kingecg/gologger" "git.pyer.club/kingecg/gotunnelserver/util" "github.com/gorilla/websocket" ) var upgrader = websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, } type ServerConfig struct { Port string `json:"port"` Host string `json:"host"` Cert string `json:"cert"` Key string `json:"key"` Salt string `json:"salt"` Logger *gologger.LoggersConfig `json:"log"` } type Server struct { *http.Server mux *http.ServeMux Config *ServerConfig logger *gologger.Logger } func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.logger.Info("request:", r.URL.Path) nr, ok := s.auth(r) if !ok { s.HandleUnAutherized(w, nr) } else { s.mux.ServeHTTP(w, nr) } } func (s *Server) auth(r *http.Request) (*http.Request, bool) { s.logger.Debug("auth:", r.RemoteAddr) lpath := r.URL.Path if lpath == "/" || lpath == "/hello" { return r, true } token := r.Header.Get("Authorization") if token == "" { return r, false } authEntity, err := util.ParseAuthEntity(token) if err != nil { return r, false } if strings.HasPrefix(lpath, "/ws/pipe") && authEntity.Token == "nil" { return r, false } if !util.VerifyAuth(authEntity, s.Config.Key) { return r, false } //set auth entity to request context ctx := r.Context() nctx := context.WithValue(ctx, "authEntity", authEntity) return r.WithContext(nctx), true } func (s *Server) Shutdown() { s.logger.Info("shutdown server") s.Close() } func (s *Server) HandleHello(w http.ResponseWriter, r *http.Request) { w.Write([]byte("hello")) } func (s *Server) HandleUnAutherized(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusUnauthorized) } func (s *Server) HandleWs(w http.ResponseWriter, r *http.Request) { lpath := r.URL.Path lpath = strings.TrimPrefix(lpath, "/ws/") c, err := upgrader.Upgrade(w, r, nil) if err != nil { s.logger.Error("upgrade failed:", err) return } if strings.HasPrefix(lpath, "client") { s.HandleClient(c, r) } else if strings.HasPrefix(lpath, "agent") { s.HandleAgent(c, r) } else if strings.HasPrefix(lpath, "pipe") { s.HandlePipe(c, r) } else { s.logger.Error("unknown path:", lpath) c.Close() } } func (s *Server) registHandler() { if s.mux == nil { s.mux = http.NewServeMux() } s.mux.HandleFunc("/hello", s.HandleHello) s.mux.HandleFunc("/ws", s.HandleWs) } func (s *Server) Start() { addr := s.Config.Host + ":" + s.Config.Port s.Server.Addr = addr s.Server.Handler = s s.registHandler() if s.Config.Cert != "" && s.Config.Key != "" { s.logger.Info("start https server") s.ListenAndServeTLS(s.Config.Cert, s.Config.Key) } else { s.logger.Info("start http server") s.ListenAndServe() } } func New(configFile string) *Server { // load config from configFile config, err := LoadConfig(configFile) if err != nil { return nil } gologger.Configure(*config.Logger) logger := gologger.GetLogger("server") logger.Info("create server") server := &Server{ Server: &http.Server{}, Config: config, logger: logger, } return server } func LoadConfig(configFile string) (conf *ServerConfig, err error) { aconfPath := configFile if aconfPath == "" { aconfPath = "conf.json" } if !filepath.IsAbs(aconfPath) { aconfPath, _ = filepath.Abs(aconfPath) } fileConten, err := os.ReadFile(aconfPath) if err != nil { fmt.Printf("read config file failed: %s", err) return nil, err } conf = new(ServerConfig) err = json.Unmarshal(fileConten, conf) if err != nil { fmt.Printf("parse config file failed: %s", err) return nil, err } if conf.Cert != "" && conf.Key != "" { conf.Cert = Abs(filepath.Dir(aconfPath), conf.Cert) conf.Key = Abs(filepath.Dir(aconfPath), conf.Key) } else { conf.Cert = "" conf.Key = "" } return conf, nil } func Abs(basePath string, path string) string { if filepath.IsAbs(path) { return path } return filepath.Join(basePath, path) }