// Source: gosocketio/vendor/github.com/googollee/go-socket.io/engineio/server.go
// (vendored copy; snapshot dated 2023-11-30 17:38:20 +08:00).

package engineio
import (
"context"
"fmt"
"io"
"log"
"net"
"net/http"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/googollee/go-socket.io/engineio/session"
"github.com/googollee/go-socket.io/engineio/transport"
)
// Server is an engine.io server. It serves engine.io HTTP traffic through
// ServeHTTP and hands each successfully initialised session to callers of
// Accept.
type Server struct {
	// Heartbeat timing handed to every new session.
	pingInterval time.Duration
	pingTimeout  time.Duration

	// Registries of the available transports and of the sessions in use.
	transports *transport.Manager
	sessions   *session.Manager

	// Hooks: requestChecker vets each request (and may supply response
	// headers); connInitor is called once for each newly created session.
	requestChecker CheckerFunc
	connInitor     ConnInitorFunc

	// connChan carries initialised sessions to Accept; closeOnce guards the
	// single close of connChan.
	connChan  chan Conn
	closeOnce sync.Once
}
// NewServer builds a Server whose settings are all resolved through the
// getter methods of opts. Whether a nil opts or zero-valued fields are
// accepted depends on those getters (not shown here) — NOTE(review): confirm
// in options.go. Accepted sessions are queued on a channel with a buffer of
// one.
func NewServer(opts *Options) *Server {
	srv := new(Server)
	// Resolve options in the same order as before, one field at a time.
	srv.transports = transport.NewManager(opts.getTransport())
	srv.pingInterval = opts.getPingInterval()
	srv.pingTimeout = opts.getPingTimeout()
	srv.requestChecker = opts.getRequestChecker()
	srv.connInitor = opts.getConnInitor()
	srv.sessions = session.NewManager(opts.getSessionIDGenerator())
	srv.connChan = make(chan Conn, 1)
	return srv
}
// Close shuts the server down by closing the channel that Accept reads from.
// Once any already-buffered sessions have been received, Accept returns
// io.EOF. Calling Close more than once is safe: the channel is closed at most
// once. The returned error is always nil.
func (s *Server) Close() error {
	closeConnChan := func() {
		close(s.connChan)
	}
	s.closeOnce.Do(closeConnChan)
	return nil
}
// Accept blocks until an initialised session is available and returns it.
// It returns io.EOF once the server has been closed and its queue drained
// (or if a nil Conn is ever received).
func (s *Server) Accept() (Conn, error) {
	if conn, open := <-s.connChan; open && conn != nil {
		return conn, nil
	}
	return nil, io.EOF
}
// Addr always returns nil: the Server is served as an http.Handler and owns
// no listening address of its own. NOTE(review): presumably present so that
// Server mirrors the net.Listener method set (Accept/Close/Addr).
func (*Server) Addr() net.Addr {
	return nil
}
// ServeHTTP is the engine.io HTTP entry point. For each request it:
//  1. resolves the "transport" query parameter to a registered transport,
//     answering 400 if it is unknown;
//  2. runs the request checker and copies the headers it returns onto the
//     response, answering 502 if the checker fails;
//  3. looks up the session named by the "sid" query parameter. An empty sid
//     creates a new session (passed to connInitor); an unknown non-empty sid
//     is rejected with 400;
//  4. if the requested transport differs from the session's current one,
//     accepts a connection on the requested transport and upgrades the
//     session to it;
//  5. otherwise delegates the request to the session itself.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	query := r.URL.Query()

	reqTransport := query.Get("transport")
	srvTransport, ok := s.transports.Get(reqTransport)
	if !ok || srvTransport == nil {
		// Report the name the client asked for; srvTransport is nil here.
		http.Error(w, fmt.Sprintf("invalid transport: %s", reqTransport), http.StatusBadRequest)
		return
	}

	// NOTE(review): a checker rejection is reported as 502 Bad Gateway even
	// though the failure concerns the client request; kept for compatibility.
	header, err := s.requestChecker(r)
	if err != nil {
		http.Error(w, fmt.Sprintf("request checker err: %s", err.Error()), http.StatusBadGateway)
		return
	}
	for k, v := range header {
		w.Header()[k] = v
	}

	sid := query.Get("sid")
	reqSession, ok := s.sessions.Get(sid)
	// Not in the pool: only a request without a sid may open a new session.
	if !ok || reqSession == nil {
		if sid != "" {
			http.Error(w, fmt.Sprintf("invalid sid value: %s", sid), http.StatusBadRequest)
			return
		}

		transportConn, err := srvTransport.Accept(w, r)
		if err != nil {
			http.Error(w, fmt.Sprintf("transport accept err: %s", err.Error()), http.StatusBadGateway)
			return
		}

		reqSession, err = s.newSession(r.Context(), transportConn, reqTransport)
		if err != nil {
			http.Error(w, fmt.Sprintf("create new session err: %s", err.Error()), http.StatusBadRequest)
			return
		}

		s.connInitor(r, reqSession)
	}

	// The session currently uses a different transport: try to upgrade it.
	if reqSession.Transport() != reqTransport {
		transportConn, err := srvTransport.Accept(w, r)
		if err != nil {
			// don't call http.Error() for HandshakeErrors because
			// they get handled by the websocket library internally.
			if _, ok := err.(websocket.HandshakeError); !ok {
				http.Error(w, err.Error(), http.StatusBadGateway)
			}
			return
		}

		reqSession.Upgrade(reqTransport, transportConn)

		if handler, ok := transportConn.(http.Handler); ok {
			handler.ServeHTTP(w, r)
		}
		return
	}

	reqSession.ServeHTTP(w, r)
}
// Count reports how many sessions the session manager currently holds.
func (s *Server) Count() int {
	total := s.sessions.Count()
	return total
}
// Remove deletes the session identified by sid from the session pool.
// Experimental API.
func (s *Server) Remove(sid string) {
	pool := s.sessions
	pool.Remove(sid)
}
// newSession creates a session over conn, using a freshly generated sid, the
// server's ping settings and the upgrades available from reqTransport.
//
// The returned session is NOT yet in the pool. A background goroutine runs
// InitSession and, only if that succeeds, adds the session to s.sessions and
// sends it on connChan for Accept. An InitSession failure is logged and the
// session is never registered. The context parameter is currently unused.
//
// NOTE(review): the goroutine's send on connChan blocks while the channel's
// buffer is full, until Accept receives. It panics if Close has already
// closed connChan.
func (s *Server) newSession(_ context.Context, conn transport.Conn, reqTransport string) (*session.Session, error) {
	params := transport.ConnParameters{
		PingInterval: s.pingInterval,
		PingTimeout:  s.pingTimeout,
		Upgrades:     s.transports.UpgradeFrom(reqTransport),
	}

	sid := s.sessions.NewID()
	newSession, err := session.New(conn, sid, reqTransport, params)
	if err != nil {
		return nil, err
	}

	go func(sess *session.Session) {
		// Goroutine-local err: writing to the enclosing function's err from
		// here would share that variable across goroutines.
		if err := sess.InitSession(); err != nil {
			log.Println("init new session:", err)
			return
		}

		s.sessions.Add(sess)
		s.connChan <- sess
	}(newSession)

	return newSession, nil
}