package api import ( "context" "encoding/json" "log" "net/http" "sync" "time" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "git.pyer.club/kingecg/gotidb/pkg/manager" "git.pyer.club/kingecg/gotidb/pkg/model" ) // WebSocketServer WebSocket服务 type WebSocketServer struct { dataManager *manager.DataManager router *gin.Engine server *http.Server clients map[*websocket.Conn]bool clientsLock sync.RWMutex upgrader websocket.Upgrader } // SubscriptionRequest 订阅请求 type SubscriptionRequest struct { DeviceID string `json:"device_id"` MetricCode string `json:"metric_code"` Labels map[string]string `json:"labels"` } // DataChangeEvent 数据变更事件 type DataChangeEvent struct { DeviceID string `json:"device_id"` MetricCode string `json:"metric_code"` Labels map[string]string `json:"labels"` Timestamp time.Time `json:"timestamp"` Value interface{} `json:"value"` } // NewWebSocketServer 创建一个新的WebSocket服务 func NewWebSocketServer(dataManager *manager.DataManager) *WebSocketServer { router := gin.Default() server := &WebSocketServer{ dataManager: dataManager, router: router, clients: make(map[*websocket.Conn]bool), upgrader: websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true // 允许所有来源的WebSocket连接 }, }, } server.setupRoutes() // 注册数据变更回调 dataManager.RegisterCallback(server.handleDataChange) return server } // setupRoutes 设置路由 func (s *WebSocketServer) setupRoutes() { // WebSocket连接 s.router.GET("/ws", s.handleWebSocket) } // handleWebSocket 处理WebSocket连接 func (s *WebSocketServer) handleWebSocket(c *gin.Context) { // 升级HTTP连接为WebSocket连接 conn, err := s.upgrader.Upgrade(c.Writer, c.Request, nil) if err != nil { log.Printf("Failed to upgrade connection: %v", err) return } // 注册客户端 s.clientsLock.Lock() s.clients[conn] = true s.clientsLock.Unlock() // 处理客户端消息 go s.handleClient(conn) } // handleClient 处理客户端消息 func (s *WebSocketServer) handleClient(conn *websocket.Conn) { defer func() { // 关闭连接 conn.Close() // 移除客户端 s.clientsLock.Lock() delete(s.clients, conn) s.clientsLock.Unlock() }() // 读取客户端消息 for { _, message, err := conn.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { log.Printf("WebSocket error: %v", err) } break } // 解析订阅请求 var req SubscriptionRequest if err := json.Unmarshal(message, &req); err != nil { log.Printf("Failed to parse subscription request: %v", err) continue } // 处理订阅请求 // TODO: 实现订阅逻辑 } } // handleDataChange 处理数据变更 func (s *WebSocketServer) handleDataChange(id model.DataPointID, value model.DataValue) { // 创建数据变更事件 event := DataChangeEvent{ DeviceID: id.DeviceID, MetricCode: id.MetricCode, Labels: id.Labels, Timestamp: value.Timestamp, Value: value.Value, } // 序列化事件 data, err := json.Marshal(event) if err != nil { log.Printf("Failed to marshal data change event: %v", err) return } // 广播事件 s.clientsLock.RLock() for client := range s.clients { if err := client.WriteMessage(websocket.TextMessage, data); err != nil { log.Printf("Failed to send data change event: %v", err) client.Close() delete(s.clients, client) } } s.clientsLock.RUnlock() } // Start 启动WebSocket服务 func (s *WebSocketServer) Start(addr string) error { s.server = &http.Server{ Addr: addr, Handler: s.router, } return s.server.ListenAndServe() } // Stop 停止WebSocket服务 func (s *WebSocketServer) Stop(ctx context.Context) error { return s.server.Shutdown(ctx) }