gotidb/pkg/api/websocket.go

176 lines
4.1 KiB
Go

package api
import (
"context"
"encoding/json"
"log"
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"git.pyer.club/kingecg/gotidb/pkg/manager"
"git.pyer.club/kingecg/gotidb/pkg/model"
)
// WebSocketServer serves the WebSocket endpoint and pushes DataManager
// change notifications to the clients subscribed to them.
type WebSocketServer struct {
	dataManager *manager.DataManager // source of data-change callbacks
	router      *gin.Engine          // HTTP routes, including the /ws endpoint
	server      *http.Server         // created by Start; nil until then

	// clients maps each connected WebSocket to its current subscription key
	// (a DataPointID.MetricHash, or "" while nothing is subscribed yet).
	clients     map[*websocket.Conn]string
	clientsLock sync.RWMutex // guards clients

	upgrader websocket.Upgrader // upgrades HTTP requests arriving on /ws
}
// SubscriptionRequest is the JSON message a client sends over the WebSocket
// to select the metric it wants to receive updates for.
type SubscriptionRequest struct {
	DeviceID   string            `json:"device_id"`   // device to follow
	MetricCode string            `json:"metric_code"` // metric on that device
	Labels     map[string]string `json:"labels"`      // NOTE(review): decoded but not used by handleClient for the subscription key
}
// DataChangeEvent is the JSON payload pushed to subscribed clients whenever
// a data point changes.
type DataChangeEvent struct {
	DeviceID   string            `json:"device_id"`   // device the data point belongs to
	MetricCode string            `json:"metric_code"` // metric of the data point
	Labels     map[string]string `json:"labels"`      // labels of the data point
	Timestamp  time.Time         `json:"timestamp"`   // timestamp of the new value
	Value      interface{}       `json:"value"`       // the new value, serialized as-is
}
// NewWebSocketServer creates a WebSocketServer backed by dataManager. It
// mounts the routes on a fresh gin.Default() engine and registers
// handleDataChange as a data-change callback on dataManager.
//
// NOTE(review): the upgrader accepts every Origin. That exposes the endpoint
// to cross-site WebSocket hijacking from browsers. Restrict CheckOrigin if the
// service is reachable from untrusted pages.
func NewWebSocketServer(dataManager *manager.DataManager) *WebSocketServer {
	allowAnyOrigin := func(*http.Request) bool {
		return true // accept WebSocket connections from any origin
	}

	ws := &WebSocketServer{
		dataManager: dataManager,
		router:      gin.Default(),
		clients:     make(map[*websocket.Conn]string),
		upgrader:    websocket.Upgrader{CheckOrigin: allowAnyOrigin},
	}
	ws.setupRoutes()

	// Receive data changes so they can be pushed to subscribed clients.
	dataManager.RegisterCallback(ws.handleDataChange)
	return ws
}
// setupRoutes mounts the WebSocket upgrade endpoint at GET /ws.
func (s *WebSocketServer) setupRoutes() {
	s.router.Handle(http.MethodGet, "/ws", s.handleWebSocket)
}
// handleWebSocket upgrades the request to a WebSocket connection and
// registers it with an empty subscription. It then hands the connection to
// handleClient on its own goroutine. Upgrade failures are logged and the
// request is abandoned.
func (s *WebSocketServer) handleWebSocket(c *gin.Context) {
	conn, err := s.upgrader.Upgrade(c.Writer, c.Request, nil)
	if err != nil {
		log.Printf("Failed to upgrade connection: %v", err)
		return
	}

	func() {
		s.clientsLock.Lock()
		defer s.clientsLock.Unlock()
		s.clients[conn] = "" // not subscribed to anything until the client asks
	}()

	go s.handleClient(conn)
}
// handleClient reads messages from one WebSocket connection until the read
// fails (peer disconnect or error).
//
// Each message is decoded as a JSON SubscriptionRequest. A valid request
// replaces the connection's subscription with the MetricHash of the requested
// device/metric, so a client follows at most one metric at a time. Malformed
// requests are logged and skipped. On exit the connection is closed and
// removed from s.clients.
//
// NOTE(review): req.Labels is decoded but not used to build the subscription
// key. Confirm whether subscriptions are meant to ignore labels.
func (s *WebSocketServer) handleClient(conn *websocket.Conn) {
	defer func() {
		conn.Close()
		s.clientsLock.Lock()
		delete(s.clients, conn)
		s.clientsLock.Unlock()
	}()

	for {
		_, message, err := conn.ReadMessage()
		if err != nil {
			// Going-away and abnormal closures are ordinary disconnects. Only log other errors.
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				log.Printf("WebSocket error: %v", err)
			}
			return
		}

		var req SubscriptionRequest
		if err := json.Unmarshal(message, &req); err != nil {
			log.Printf("Failed to parse subscription request: %v", err)
			continue
		}

		id := model.DataPointID{
			DeviceID:   req.DeviceID,
			MetricCode: req.MetricCode,
		}
		subject := id.MetricHash()

		// Reasons for the locking and the membership check:
		//   - s.clients is shared with handleWebSocket and handleDataChange,
		//     so writing it requires the write lock.
		//   - A failed broadcast may already have closed and removed this
		//     connection. Only update entries that still exist, so a dead
		//     connection is not re-registered.
		s.clientsLock.Lock()
		if _, ok := s.clients[conn]; ok {
			s.clients[conn] = subject
		}
		s.clientsLock.Unlock()
	}
}
// handleDataChange is the callback registered with the DataManager. It
// serializes the change as a DataChangeEvent and sends it as a JSON text frame
// to every client whose subscription equals id.MetricHash(). A client whose
// write fails is closed and removed from s.clients.
func (s *WebSocketServer) handleDataChange(id model.DataPointID, value model.DataValue) {
	// writeWait bounds how long one slow client can stall the broadcast.
	// Everything else waiting on clientsLock is held up for that long too.
	const writeWait = 10 * time.Second

	event := DataChangeEvent{
		DeviceID:   id.DeviceID,
		MetricCode: id.MetricCode,
		Labels:     id.Labels,
		Timestamp:  value.Timestamp,
		Value:      value.Value,
	}
	metricHash := id.MetricHash()

	data, err := json.Marshal(event)
	if err != nil {
		log.Printf("Failed to marshal data change event: %v", err)
		return
	}

	// The exclusive lock is required for two reasons:
	//   - Failed clients are deleted from the map inside the loop, and a map
	//     write is not allowed while only holding RLock.
	//   - gorilla/websocket allows one concurrent writer per connection, so
	//     overlapping broadcasts must not write to the same conn.
	s.clientsLock.Lock()
	defer s.clientsLock.Unlock()
	for client, subject := range s.clients {
		if subject != metricHash {
			continue // client is subscribed to a different metric
		}
		// Best effort: if setting the deadline fails, WriteMessage reports it.
		_ = client.SetWriteDeadline(time.Now().Add(writeWait))
		if err := client.WriteMessage(websocket.TextMessage, data); err != nil {
			log.Printf("Failed to send data change event: %v", err)
			client.Close()
			delete(s.clients, client)
		}
	}
}
// Start creates the HTTP server for addr and serves the router on it. It
// blocks until the server stops and returns the error from
// http.Server.ListenAndServe (http.ErrServerClosed after a graceful Stop).
func (s *WebSocketServer) Start(addr string) error {
	s.server = &http.Server{
		Addr:    addr,
		Handler: s.router,
		// Bound how long a client may take to send its request headers,
		// so slow-header (Slowloris) clients cannot hold connections open.
		ReadHeaderTimeout: 10 * time.Second,
	}
	return s.server.ListenAndServe()
}
// Stop gracefully shuts down the HTTP server with http.Server.Shutdown, bounded
// by ctx. It then closes every registered WebSocket connection, because
// Shutdown neither closes nor waits for hijacked connections. Stop is a no-op
// for the HTTP server if Start was never called. It returns the error from
// Shutdown, if any.
func (s *WebSocketServer) Stop(ctx context.Context) error {
	var err error
	if s.server != nil {
		err = s.server.Shutdown(ctx)
	}

	// One shared deadline keeps the total time spent sending close frames
	// bounded, regardless of how many clients are connected.
	deadline := time.Now().Add(time.Second)

	s.clientsLock.Lock()
	defer s.clientsLock.Unlock()
	for conn := range s.clients {
		// Best effort: tell the peer the server is going away. The error is
		// ignored because the connection is closed immediately afterwards.
		_ = conn.WriteControl(websocket.CloseMessage,
			websocket.FormatCloseMessage(websocket.CloseGoingAway, "server shutting down"),
			deadline)
		conn.Close()
		delete(s.clients, conn)
	}
	return err
}