diff --git a/pkg/api/websocket.go b/pkg/api/websocket.go index 437c5d3..bf32a99 100644 --- a/pkg/api/websocket.go +++ b/pkg/api/websocket.go @@ -20,7 +20,7 @@ type WebSocketServer struct { dataManager *manager.DataManager router *gin.Engine server *http.Server - clients map[*websocket.Conn]bool + clients map[*websocket.Conn]string clientsLock sync.RWMutex upgrader websocket.Upgrader } @@ -48,7 +48,7 @@ func NewWebSocketServer(dataManager *manager.DataManager) *WebSocketServer { server := &WebSocketServer{ dataManager: dataManager, router: router, - clients: make(map[*websocket.Conn]bool), + clients: make(map[*websocket.Conn]string), upgrader: websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true // 允许所有来源的WebSocket连接 @@ -81,7 +81,7 @@ func (s *WebSocketServer) handleWebSocket(c *gin.Context) { // 注册客户端 s.clientsLock.Lock() - s.clients[conn] = true + s.clients[conn] = "" s.clientsLock.Unlock() // 处理客户端消息 @@ -116,9 +116,12 @@ func (s *WebSocketServer) handleClient(conn *websocket.Conn) { log.Printf("Failed to parse subscription request: %v", err) continue } + id := model.DataPointID{ + DeviceID: req.DeviceID, + MetricCode: req.MetricCode, + } + s.clients[conn] = id.MetricHash() - // 处理订阅请求 - // TODO: 实现订阅逻辑 } } @@ -133,6 +136,7 @@ func (s *WebSocketServer) handleDataChange(id model.DataPointID, value model.Dat Value: value.Value, } + metricHash := id.MetricHash() // 序列化事件 data, err := json.Marshal(event) if err != nil { @@ -142,7 +146,10 @@ func (s *WebSocketServer) handleDataChange(id model.DataPointID, value model.Dat // 广播事件 s.clientsLock.RLock() - for client := range s.clients { + for client, subject := range s.clients { + if subject != metricHash { + continue // 忽略非匹配的订阅 + } if err := client.WriteMessage(websocket.TextMessage, data); err != nil { log.Printf("Failed to send data change event: %v", err) client.Close() diff --git a/pkg/model/datapoint.go b/pkg/model/datapoint.go index 4ad5d90..567f9d4 100644 --- a/pkg/model/datapoint.go +++ b/pkg/model/datapoint.go @@ -19,6 +19,10 @@ func (id DataPointID) String() string { return id.Hash() } +func (id DataPointID) MetricHash() string { + return fmt.Sprintf("%s:%s", id.DeviceID, id.MetricCode) +} + // Equal 判断两个数据点标识是否相等 func (id DataPointID) Equal(other DataPointID) bool { if id.DeviceID != other.DeviceID || id.MetricCode != other.MetricCode {