package api import ( "bytes" "context" "encoding/json" "net/http" "net/http/httptest" "testing" "time" "github.com/gin-gonic/gin" "git.pyer.club/kingecg/gotidb/pkg/manager" "git.pyer.club/kingecg/gotidb/pkg/model" "git.pyer.club/kingecg/gotidb/pkg/storage" ) func setupTestRESTServer() *RESTServer { // 创建存储引擎 engine := storage.NewMemoryEngine() // 创建数据管理器 dataManager := manager.NewDataManager(engine) // 创建REST服务器 server := NewRESTServer(dataManager) return server } func TestRESTServer_WriteEndpoint(t *testing.T) { // 设置测试模式 gin.SetMode(gin.TestMode) // 创建测试服务器 server := setupTestRESTServer() // 创建测试请求 writeReq := RESTWriteRequest{ DeviceID: "test-device", MetricCode: "temperature", Labels: map[string]string{ "location": "room1", }, Value: 25.5, } body, _ := json.Marshal(writeReq) req, _ := http.NewRequest("POST", "/api/v1/write", bytes.NewBuffer(body)) req.Header.Set("Content-Type", "application/json") // 创建响应记录器 w := httptest.NewRecorder() // 设置路由 r := gin.New() r.POST("/api/v1/write", server.handleWrite) // 执行请求 r.ServeHTTP(w, req) // 检查响应状态码 if w.Code != http.StatusOK { t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code) } // 解析响应 var resp Response err := json.Unmarshal(w.Body.Bytes(), &resp) if err != nil { t.Errorf("Failed to unmarshal response: %v", err) } // 验证响应 if resp["status"] != "ok" { t.Errorf("Expected success to be true, got false") } } func TestRESTServer_BatchWriteEndpoint(t *testing.T) { // 设置测试模式 gin.SetMode(gin.TestMode) // 创建测试服务器 server := setupTestRESTServer() // 创建测试请求 batchReq := BatchWriteRequest{ Points: []RESTWriteRequest{ { DeviceID: "test-device", MetricCode: "temperature", Labels: map[string]string{ "location": "room1", }, Value: 25.5, }, { DeviceID: "test-device", MetricCode: "humidity", Labels: map[string]string{ "location": "room1", }, Value: 60.0, }, }, } body, _ := json.Marshal(batchReq) req, _ := http.NewRequest("POST", "/api/v1/batch_write", bytes.NewBuffer(body)) req.Header.Set("Content-Type", "application/json") // 创建响应记录器 w := httptest.NewRecorder() // 设置路由 r := gin.New() r.POST("/api/v1/batch_write", server.handleBatchWrite) // 执行请求 r.ServeHTTP(w, req) // 检查响应状态码 if w.Code != http.StatusOK { t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code) } // 解析响应 var resp Response err := json.Unmarshal(w.Body.Bytes(), &resp) if err != nil { t.Errorf("Failed to unmarshal response: %v", err) } // 验证响应 if resp["status"] != "ok" { t.Errorf("Expected success to be true, got false") } // 验证计数 if count, ok := resp["count"].(float64); !ok || count != 2 { t.Errorf("Expected count to be 2, got %v", resp["count"]) } } func TestRESTServer_QueryEndpoint(t *testing.T) { // 设置测试模式 gin.SetMode(gin.TestMode) // 创建测试服务器 server := setupTestRESTServer() // 写入测试数据 engine := storage.NewMemoryEngine() dataManager := manager.NewDataManager(engine) server.dataManager = dataManager id := model.DataPointID{ DeviceID: "test-device", MetricCode: "temperature", Labels: map[string]string{ "location": "room1", }, } now := time.Now() value := model.DataValue{ Timestamp: now, Value: 25.5, } err := dataManager.Write(context.Background(), id, value) if err != nil { t.Fatalf("Failed to write test data: %v", err) } // 创建测试请求 queryReq := QueryRequest{ DeviceID: "test-device", MetricCode: "temperature", Labels: map[string]string{ "location": "room1", }, QueryType: "latest", } body, _ := json.Marshal(queryReq) req, _ := http.NewRequest("POST", "/api/v1/query", bytes.NewBuffer(body)) req.Header.Set("Content-Type", "application/json") // 创建响应记录器 w := httptest.NewRecorder() // 设置路由 r := gin.New() r.POST("/api/v1/query", server.handleQuery) // 执行请求 r.ServeHTTP(w, req) // 检查响应状态码 if w.Code != http.StatusOK { t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code) } // 解析响应 var resp Response err = json.Unmarshal(w.Body.Bytes(), &resp) if err != nil { t.Errorf("Failed to unmarshal response: %v", err) } // 验证响应 if resp["status"] != "ok" { t.Errorf("Expected success to be true, got false") } // 验证返回的数据 if resp["timestamp"] == nil { t.Errorf("Expected data to be non-nil") } // // 验证最新值 // if resp.QueryType != "latest" { // t.Errorf("Expected query_type to be 'latest', got '%s'", resp.QueryType) // } // if resp.Data.(map[string]interface{})["value"] != 25.5 { // t.Errorf("Expected value to be 25.5, got %v", resp.Data.(map[string]interface{})["value"]) // } } func TestRESTServer_Start(t *testing.T) { // 创建测试服务器 server := setupTestRESTServer() // 启动服务器(在后台) go func() { err := server.Start(":0") // 使用随机端口 if err != nil && err != http.ErrServerClosed { t.Errorf("Failed to start server: %v", err) } }() // 给服务器一点时间启动 time.Sleep(100 * time.Millisecond) // 停止服务器 err := server.Stop(context.Background()) if err != nil { t.Errorf("Failed to stop server: %v", err) } }