From f97b0b2ea7fc323f9065e552eb43edb7b3c593b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=A8=8B=E5=B9=BF?= Date: Fri, 13 Jun 2025 13:31:01 +0800 Subject: [PATCH] =?UTF-8?q?feat(api):=20=E5=AE=9E=E7=8E=B0=20WebSocket=20?= =?UTF-8?q?=E6=9C=8D=E5=8A=A1=E5=99=A8=E7=9A=84=E8=AE=A2=E9=98=85=E5=8A=9F?= =?UTF-8?q?=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 将 clients map 的值从 bool 改为 string,用于存储订阅的主题 - 在 handleWebSocket 函数中实现客户端注册逻辑 - 在 handleClient 函数中处理订阅请求,并保存订阅的主题 - 修改 handleDataChange 函数,仅向订阅了相应主题的客户端广播事件 - 在 DataPointID 结构中添加 MetricHash 方法,用于生成主题哈希值 --- pkg/api/websocket.go | 19 +++++++++++++------ pkg/model/datapoint.go | 4 ++++ 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/pkg/api/websocket.go b/pkg/api/websocket.go index 437c5d3..bf32a99 100644 --- a/pkg/api/websocket.go +++ b/pkg/api/websocket.go @@ -20,7 +20,7 @@ type WebSocketServer struct { dataManager *manager.DataManager router *gin.Engine server *http.Server - clients map[*websocket.Conn]bool + clients map[*websocket.Conn]string clientsLock sync.RWMutex upgrader websocket.Upgrader } @@ -48,7 +48,7 @@ func NewWebSocketServer(dataManager *manager.DataManager) *WebSocketServer { server := &WebSocketServer{ dataManager: dataManager, router: router, - clients: make(map[*websocket.Conn]bool), + clients: make(map[*websocket.Conn]string), upgrader: websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true // 允许所有来源的WebSocket连接 @@ -81,7 +81,7 @@ func (s *WebSocketServer) handleWebSocket(c *gin.Context) { // 注册客户端 s.clientsLock.Lock() - s.clients[conn] = true + s.clients[conn] = "" s.clientsLock.Unlock() // 处理客户端消息 @@ -116,9 +116,12 @@ func (s *WebSocketServer) handleClient(conn *websocket.Conn) { log.Printf("Failed to parse subscription request: %v", err) continue } + id := model.DataPointID{ + DeviceID: req.DeviceID, + MetricCode: req.MetricCode, + } + s.clients[conn] = id.MetricHash() - // 处理订阅请求 - // TODO: 实现订阅逻辑 } } @@ -133,6 +136,7 @@ func (s *WebSocketServer) handleDataChange(id model.DataPointID, value model.Dat Value: value.Value, } + metricHash := id.MetricHash() // 序列化事件 data, err := json.Marshal(event) if err != nil { @@ -142,7 +146,10 @@ func (s *WebSocketServer) handleDataChange(id model.DataPointID, value model.Dat // 广播事件 s.clientsLock.RLock() - for client := range s.clients { + for client, subject := range s.clients { + if subject != metricHash { + continue // 忽略非匹配的订阅 + } if err := client.WriteMessage(websocket.TextMessage, data); err != nil { log.Printf("Failed to send data change event: %v", err) client.Close() diff --git a/pkg/model/datapoint.go b/pkg/model/datapoint.go index 4ad5d90..567f9d4 100644 --- a/pkg/model/datapoint.go +++ b/pkg/model/datapoint.go @@ -19,6 +19,10 @@ func (id DataPointID) String() string { return id.Hash() } +func (id DataPointID) MetricHash() string { + return fmt.Sprintf("%s:%s", id.DeviceID, id.MetricCode) +} + // Equal 判断两个数据点标识是否相等 func (id DataPointID) Equal(other DataPointID) bool { if id.DeviceID != other.DeviceID || id.MetricCode != other.MetricCode {