gotidb/pkg/api/websocket_test.go

235 lines
5.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package api
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gorilla/websocket"
"git.pyer.club/kingecg/gotidb/pkg/manager"
"git.pyer.club/kingecg/gotidb/pkg/model"
"git.pyer.club/kingecg/gotidb/pkg/storage"
)
// setupTestWebSocketServer wires up a WebSocketServer for tests: a fresh
// in-memory storage engine is wrapped in a data manager, which in turn backs
// the returned server.
func setupTestWebSocketServer() *WebSocketServer {
	// NOTE(review): presumably a size/capacity limit for the memory engine —
	// confirm against storage.NewMemoryEngine.
	const memoryEngineArg = 200

	return NewWebSocketServer(manager.NewDataManager(storage.NewMemoryEngine(memoryEngineArg)))
}
// TestWebSocketServer_Connection verifies that a client subscribed to a data
// point over WebSocket receives a DataChangeEvent describing that point after
// it is written through the server's data manager.
func TestWebSocketServer_Connection(t *testing.T) {
	server := setupTestWebSocketServer()

	// Serve the server's router on a local test HTTP server.
	httpServer := httptest.NewServer(server.router)
	defer httpServer.Close()

	// Turn http://host:port into ws://host:port and append the /ws endpoint.
	wsURL := "ws" + strings.TrimPrefix(httpServer.URL, "http") + "/ws"

	ws, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
	if err != nil {
		t.Fatalf("Failed to connect to WebSocket server: %v", err)
	}
	defer ws.Close()

	// Subscribe to the exact data point that is written below.
	subscription := SubscriptionRequest{
		DeviceID:   "test-device",
		MetricCode: "temperature",
		Labels: map[string]string{
			"location": "room1",
		},
	}
	if err := ws.WriteJSON(subscription); err != nil {
		t.Fatalf("Failed to send subscription message: %v", err)
	}

	// Give the server time to register the subscription. NOTE(review): the
	// protocol visible here has no subscription acknowledgement to wait on, so
	// this fixed sleep may be flaky on a heavily loaded machine.
	time.Sleep(100 * time.Millisecond)

	// Write a value; this is expected to trigger a push to the subscriber.
	id := model.DataPointID{
		DeviceID:   "test-device",
		MetricCode: "temperature",
		Labels: map[string]string{
			"location": "room1",
		},
	}
	value := model.DataValue{
		Timestamp: time.Now(),
		Value:     25.5,
	}
	if err := server.dataManager.Write(context.Background(), id, value); err != nil {
		t.Fatalf("Failed to write data: %v", err)
	}

	// Bound the read so a missing push fails the test instead of hanging it.
	if err := ws.SetReadDeadline(time.Now().Add(1 * time.Second)); err != nil {
		t.Fatalf("Failed to set read deadline: %v", err)
	}

	_, message, err := ws.ReadMessage()
	if err != nil {
		t.Fatalf("Failed to read WebSocket message: %v", err)
	}

	var update DataChangeEvent
	if err := json.Unmarshal(message, &update); err != nil {
		t.Fatalf("Failed to unmarshal WebSocket message: %v", err)
	}

	// The pushed event must describe the data point that was written.
	if update.DeviceID != "test-device" {
		t.Errorf("Expected DeviceID to be 'test-device', got '%s'", update.DeviceID)
	}
	if update.MetricCode != "temperature" {
		t.Errorf("Expected MetricCode to be 'temperature', got '%s'", update.MetricCode)
	}
	if update.Value != 25.5 {
		t.Errorf("Expected Value to be 25.5, got %v", update.Value)
	}
	if update.Labels["location"] != "room1" {
		t.Errorf("Expected Labels['location'] to be 'room1', got '%s'", update.Labels["location"])
	}
}
// TestWebSocketServer_MultipleSubscriptions verifies that a single WebSocket
// connection can hold several subscriptions and receives one DataChangeEvent
// per subscribed data point when each point is written.
func TestWebSocketServer_MultipleSubscriptions(t *testing.T) {
	server := setupTestWebSocketServer()

	// Serve the server's router on a local test HTTP server.
	httpServer := httptest.NewServer(server.router)
	defer httpServer.Close()

	// Turn http://host:port into ws://host:port and append the /ws endpoint.
	wsURL := "ws" + strings.TrimPrefix(httpServer.URL, "http") + "/ws"

	ws, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
	if err != nil {
		t.Fatalf("Failed to connect to WebSocket server: %v", err)
	}
	defer ws.Close()

	// Two metrics of the same device, sent as separate subscription messages.
	subscriptions := []SubscriptionRequest{
		{
			DeviceID:   "test-device",
			MetricCode: "temperature",
			Labels: map[string]string{
				"location": "room1",
			},
		},
		{
			DeviceID:   "test-device",
			MetricCode: "humidity",
			Labels: map[string]string{
				"location": "room1",
			},
		},
	}
	for _, subscription := range subscriptions {
		if err := ws.WriteJSON(subscription); err != nil {
			t.Fatalf("Failed to send subscription message: %v", err)
		}
	}

	// Give the server time to register the subscriptions. NOTE(review): there
	// is no visible acknowledgement to wait on, so this fixed sleep may be flaky.
	time.Sleep(100 * time.Millisecond)

	const wantValue = 25.5

	// Write each subscribed point in turn and expect exactly its push back.
	for _, subscription := range subscriptions {
		id := model.DataPointID{
			DeviceID:   subscription.DeviceID,
			MetricCode: subscription.MetricCode,
			Labels:     subscription.Labels,
		}
		value := model.DataValue{
			Timestamp: time.Now(),
			Value:     wantValue,
		}
		if err := server.dataManager.Write(context.Background(), id, value); err != nil {
			t.Fatalf("Failed to write data: %v", err)
		}

		// ReadMessage blocks until the push arrives or the deadline passes, so
		// no extra sleep is needed between the write and the read.
		if err := ws.SetReadDeadline(time.Now().Add(1 * time.Second)); err != nil {
			t.Fatalf("Failed to set read deadline: %v", err)
		}
		_, message, err := ws.ReadMessage()
		if err != nil {
			t.Fatalf("Failed to read WebSocket message: %v", err)
		}

		var update DataChangeEvent
		if err := json.Unmarshal(message, &update); err != nil {
			t.Fatalf("Failed to unmarshal WebSocket message: %v", err)
		}

		// The pushed event must describe the data point just written.
		if update.DeviceID != subscription.DeviceID {
			t.Errorf("Expected DeviceID to be '%s', got '%s'", subscription.DeviceID, update.DeviceID)
		}
		if update.MetricCode != subscription.MetricCode {
			t.Errorf("Expected MetricCode to be '%s', got '%s'", subscription.MetricCode, update.MetricCode)
		}
		if update.Value != wantValue {
			t.Errorf("Expected Value to be %v, got %v", wantValue, update.Value)
		}
		if update.Labels["location"] != subscription.Labels["location"] {
			t.Errorf("Expected Labels['location'] to be '%s', got '%s'", subscription.Labels["location"], update.Labels["location"])
		}
	}
}
// TestWebSocketServer_Start verifies that the server can be started on a
// random free port, shut down via Stop, and that Start returns (with either
// nil or http.ErrServerClosed) once the server has been stopped.
func TestWebSocketServer_Start(t *testing.T) {
	server := setupTestWebSocketServer()

	// Start blocks while serving, so it runs in a goroutine. Its result is
	// delivered over a buffered channel instead of calling t.Errorf from the
	// goroutine: using t after the test function has returned panics, and the
	// buffer lets the goroutine exit even if nobody is receiving any more.
	startErr := make(chan error, 1)
	go func() {
		startErr <- server.Start(":0") // ":0" asks the OS for a random free port
	}()

	// Give the server a moment to start listening before stopping it.
	time.Sleep(100 * time.Millisecond)

	// Bound the shutdown so a stuck Stop cannot hang the test.
	stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := server.Stop(stopCtx); err != nil {
		t.Errorf("Failed to stop server: %v", err)
	}

	// Wait for Start to return so its goroutine does not outlive the test.
	select {
	case err := <-startErr:
		if err != nil && err != http.ErrServerClosed {
			t.Errorf("Failed to start server: %v", err)
		}
	case <-time.After(5 * time.Second):
		t.Errorf("Start did not return within 5s after Stop")
	}
}