package api import ( "context" "encoding/json" "net/http" "net/http/httptest" "strings" "testing" "time" "github.com/gorilla/websocket" "git.pyer.club/kingecg/gotidb/pkg/manager" "git.pyer.club/kingecg/gotidb/pkg/model" "git.pyer.club/kingecg/gotidb/pkg/storage" ) func setupTestWebSocketServer() *WebSocketServer { // 创建存储引擎 engine := storage.NewMemoryEngine() // 创建数据管理器 dataManager := manager.NewDataManager(engine) // 创建WebSocket服务器 server := NewWebSocketServer(dataManager) return server } func TestWebSocketServer_Connection(t *testing.T) { // 创建测试服务器 server := setupTestWebSocketServer() // 创建HTTP服务器 httpServer := httptest.NewServer(server.router) defer httpServer.Close() // 将HTTP URL转换为WebSocket URL wsURL := "ws" + strings.TrimPrefix(httpServer.URL, "http") // 连接到WebSocket服务器 ws, _, err := websocket.DefaultDialer.Dial(wsURL, nil) if err != nil { t.Fatalf("Failed to connect to WebSocket server: %v", err) } defer ws.Close() // 发送订阅消息 subscription := SubscriptionRequest{ DeviceID: "test-device", MetricCode: "temperature", Labels: map[string]string{ "location": "room1", }, } err = ws.WriteJSON(subscription) if err != nil { t.Fatalf("Failed to send subscription message: %v", err) } // 等待一段时间,确保订阅已处理 time.Sleep(100 * time.Millisecond) // 写入数据,触发WebSocket推送 id := model.DataPointID{ DeviceID: "test-device", MetricCode: "temperature", Labels: map[string]string{ "location": "room1", }, } value := model.DataValue{ Timestamp: time.Now(), Value: 25.5, } err = server.dataManager.Write(context.Background(), id, value) if err != nil { t.Fatalf("Failed to write data: %v", err) } // 设置读取超时 ws.SetReadDeadline(time.Now().Add(1 * time.Second)) // 读取WebSocket消息 _, message, err := ws.ReadMessage() if err != nil { t.Fatalf("Failed to read WebSocket message: %v", err) } // 解析消息 var update DataChangeEvent err = json.Unmarshal(message, &update) if err != nil { t.Fatalf("Failed to unmarshal WebSocket message: %v", err) } // 验证消息内容 if update.DeviceID != "test-device" { t.Errorf("Expected DeviceID to be 'test-device', got '%s'", update.DeviceID) } if update.MetricCode != "temperature" { t.Errorf("Expected MetricCode to be 'temperature', got '%s'", update.MetricCode) } if update.Value != 25.5 { t.Errorf("Expected Value to be 25.5, got %v", update.Value) } if update.Labels["location"] != "room1" { t.Errorf("Expected Labels['location'] to be 'room1', got '%s'", update.Labels["location"]) } } func TestWebSocketServer_MultipleSubscriptions(t *testing.T) { // 创建测试服务器 server := setupTestWebSocketServer() // 创建HTTP服务器 httpServer := httptest.NewServer(server.router) defer httpServer.Close() // 将HTTP URL转换为WebSocket URL wsURL := "ws" + strings.TrimPrefix(httpServer.URL, "http") // 连接到WebSocket服务器 ws, _, err := websocket.DefaultDialer.Dial(wsURL, nil) if err != nil { t.Fatalf("Failed to connect to WebSocket server: %v", err) } defer ws.Close() // 发送多个订阅消息 subscriptions := []SubscriptionRequest{ { DeviceID: "test-device", MetricCode: "temperature", Labels: map[string]string{ "location": "room1", }, }, { DeviceID: "test-device", MetricCode: "humidity", Labels: map[string]string{ "location": "room1", }, }, } for _, subscription := range subscriptions { err = ws.WriteJSON(subscription) if err != nil { t.Fatalf("Failed to send subscription message: %v", err) } } // 等待一段时间,确保订阅已处理 time.Sleep(100 * time.Millisecond) // 写入数据,触发WebSocket推送 for _, subscription := range subscriptions { id := model.DataPointID{ DeviceID: subscription.DeviceID, MetricCode: subscription.MetricCode, Labels: subscription.Labels, } value := model.DataValue{ Timestamp: time.Now(), Value: 25.5, } err = server.dataManager.Write(context.Background(), id, value) if err != nil { t.Fatalf("Failed to write data: %v", err) } // 等待一段时间,确保数据已处理 time.Sleep(100 * time.Millisecond) // 设置读取超时 ws.SetReadDeadline(time.Now().Add(1 * time.Second)) // 读取WebSocket消息 _, message, err := ws.ReadMessage() if err != nil { t.Fatalf("Failed to read WebSocket message: %v", err) } // 解析消息 var update DataChangeEvent err = json.Unmarshal(message, &update) if err != nil { t.Fatalf("Failed to unmarshal WebSocket message: %v", err) } // 验证消息内容 if update.DeviceID != subscription.DeviceID { t.Errorf("Expected DeviceID to be '%s', got '%s'", subscription.DeviceID, update.DeviceID) } if update.MetricCode != subscription.MetricCode { t.Errorf("Expected MetricCode to be '%s', got '%s'", subscription.MetricCode, update.MetricCode) } } } func TestWebSocketServer_Start(t *testing.T) { // 创建测试服务器 server := setupTestWebSocketServer() // 启动服务器(在后台) go func() { err := server.Start(":0") // 使用随机端口 if err != nil && err != http.ErrServerClosed { t.Errorf("Failed to start server: %v", err) } }() // 给服务器一点时间启动 time.Sleep(100 * time.Millisecond) // 停止服务器 err := server.Stop(context.Background()) if err != nil { t.Errorf("Failed to stop server: %v", err) } }