Compare commits

...

3 Commits

Author SHA1 Message Date
kingecg 0a193d72c2 Merge branch 'master' of ssh://git.pyer.club:2222/kingecg/gotidb 2025-06-11 19:45:13 +08:00
kingecg 11149c0c94 feat(server): 添加配置文件支持和示例配置生成
- 新增 GenerateSampleConfig 函数用于生成示例配置文件
- 在 main 函数中添加配置文件路径和生成示例配置的命令行参数
- 实现配置文件加载逻辑,替代命令行参数
- 优化命令行参数默认值,如 NATS 服务器地址
2025-06-11 00:14:52 +08:00
kingecg 9682c51336 fix code 2025-06-11 00:07:20 +08:00
19 changed files with 2153 additions and 57 deletions

117
Makefile Normal file
View File

@ -0,0 +1,117 @@
.PHONY: all build test clean fmt lint build-all
# Go parameters
GOCMD=go
GOBUILD=$(GOCMD) build
GOCLEAN=$(GOCMD) clean
GOTEST=$(GOCMD) test
GOGET=$(GOCMD) get
GOMOD=$(GOCMD) mod
GOFMT=$(GOCMD) fmt
GOLINT=golangci-lint
# Binary name
BINARY_NAME=gotidb
BINARY_UNIX=$(BINARY_NAME)_unix
BINARY_WIN=$(BINARY_NAME).exe
BINARY_MAC=$(BINARY_NAME)_mac
# Build directory
BUILD_DIR=build
# Main package path
MAIN_PACKAGE=./cmd/server
# Get the current git commit hash
COMMIT=$(shell git rev-parse --short HEAD)
BUILD_TIME=$(shell date +%FT%T%z)
# Build flags
LDFLAGS=-ldflags "-X main.commit=${COMMIT} -X main.buildTime=${BUILD_TIME}"
# Default target
all: test build
# Build the project
build:
mkdir -p $(BUILD_DIR)
$(GOBUILD) $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME) $(MAIN_PACKAGE)
# Build for all platforms
build-all: build-linux build-windows build-mac
build-linux:
mkdir -p $(BUILD_DIR)
GOOS=linux GOARCH=amd64 $(GOBUILD) $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_UNIX) $(MAIN_PACKAGE)
build-windows:
mkdir -p $(BUILD_DIR)
GOOS=windows GOARCH=amd64 $(GOBUILD) $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_WIN) $(MAIN_PACKAGE)
build-mac:
mkdir -p $(BUILD_DIR)
GOOS=darwin GOARCH=amd64 $(GOBUILD) $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_MAC) $(MAIN_PACKAGE)
# Run tests
test:
$(GOTEST) -v ./...
# Run tests with coverage
test-coverage:
$(GOTEST) -v -coverprofile=coverage.out ./...
$(GOCMD) tool cover -html=coverage.out -o coverage.html
# Clean build artifacts
clean:
$(GOCLEAN)
rm -rf $(BUILD_DIR)
rm -f coverage.out coverage.html
# Format code
fmt:
$(GOFMT) ./...
# Run linter
lint:
$(GOLINT) run
# Download dependencies
deps:
$(GOMOD) download
# Verify dependencies
verify:
$(GOMOD) verify
# Update dependencies
update-deps:
$(GOMOD) tidy
# Install development tools
install-tools:
$(GOGET) -u github.com/golangci/golangci-lint/cmd/golangci-lint
# Run the application
run:
$(GOBUILD) -o $(BUILD_DIR)/$(BINARY_NAME) $(MAIN_PACKAGE)
./$(BUILD_DIR)/$(BINARY_NAME)
# Help target
help:
@echo "Available targets:"
@echo " all : Run tests and build"
@echo " build : Build for current platform"
@echo " build-all : Build for all platforms"
@echo " test : Run tests"
@echo " test-coverage: Run tests with coverage"
@echo " clean : Clean build artifacts"
@echo " fmt : Format code"
@echo " lint : Run linter"
@echo " deps : Download dependencies"
@echo " verify : Verify dependencies"
@echo " update-deps : Update dependencies"
@echo " install-tools: Install development tools"
@echo " run : Run the application"
# Default to help if no target is specified
.DEFAULT_GOAL := help

53
cmd/server/config.go Normal file
View File

@ -0,0 +1,53 @@
package main
import (
"os"
"gopkg.in/yaml.v3"
)
// Config 应用程序配置结构
type Config struct {
RestAddr string `yaml:"rest_addr"`
WsAddr string `yaml:"ws_addr"`
MetricsAddr string `yaml:"metrics_addr"`
NATSURL string `yaml:"nats_url"`
PersistenceType string `yaml:"persistence_type"`
PersistenceDir string `yaml:"persistence_dir"`
SyncEvery int `yaml:"sync_every"`
}
func LoadConfig(path string) (*Config, error) {
file, err := os.Open(path)
if err != nil {
return nil, err
}
defer file.Close()
var config Config
decoder := yaml.NewDecoder(file)
if err := decoder.Decode(&config); err != nil {
return nil, err
}
return &config, nil
}
func GenerateSampleConfig(path string) error {
config := Config{
MetricsAddr: ":8082",
NATSURL: "nats://localhost:4222",
PersistenceDir: "./data",
PersistenceType: "memory",
RestAddr: ":8080",
SyncEvery: 1000,
WsAddr: ":8081",
}
// 序列化yaml到path
file, err := os.Create(path)
if err != nil {
return err
}
defer file.Close()
return yaml.NewEncoder(file).Encode(config)
}

View File

@ -21,14 +21,38 @@ var (
restAddr = flag.String("rest-addr", ":8080", "REST API服务地址")
wsAddr = flag.String("ws-addr", ":8081", "WebSocket服务地址")
metricsAddr = flag.String("metrics-addr", ":8082", "指标服务地址")
natsURL = flag.String("nats-url", "nats://localhost:4222", "NATS服务器地址")
natsURL = flag.String("nats-url", "", "NATS服务器地址")
persistenceType = flag.String("persistence", "none", "持久化类型 (none, wal)")
persistenceDir = flag.String("persistence-dir", "./data", "持久化目录")
syncEvery = flag.Int("sync-every", 100, "每写入多少条数据同步一次")
configPath = flag.String("config", "config.yaml", "配置文件路径")
genSampleConfig = flag.Bool("gen-sample-config", false, "生成示例配置文件")
)
func main() {
if *genSampleConfig {
err := GenerateSampleConfig("./config.yaml.sample")
if err != nil {
log.Fatalf("生成示例配置文件失败: %v", err)
}
log.Println("示例配置文件已生成")
return
}
flag.Parse()
if *configPath != "" {
config, err := LoadConfig(*configPath)
if err != nil {
log.Fatalf("Failed to load config: %v", err)
}
restAddr = &config.RestAddr
wsAddr = &config.WsAddr
metricsAddr = &config.MetricsAddr
natsURL = &config.NATSURL
persistenceType = &config.PersistenceType
persistenceDir = &config.PersistenceDir
syncEvery = &config.SyncEvery
}
// 创建存储引擎
engine := storage.NewMemoryEngine()

36
docs/design/task.md Normal file
View File

@ -0,0 +1,36 @@
0. 构建工具
添加构建脚本,要求:
添加Makefile
多平台构建
构建时,如果有单元测试,先执行单元测试
1. 测试用例编写
为各个组件编写单元测试
添加集成测试
进行性能测试和基准测试
2. 功能增强
实现数据压缩
添加更多查询类型
实现数据备份和恢复
添加访问控制和认证
3. 部署相关
添加Docker支持
创建Kubernetes部署配置
编写运维文档
4. 性能优化
优化内存使用
实现数据分片
添加缓存层
5. 监控和告警
完善监控指标
添加告警规则
实现日志聚合

9
go.mod
View File

@ -9,6 +9,7 @@ require (
github.com/gorilla/websocket v1.5.1
github.com/nats-io/nats.go v1.31.0
github.com/prometheus/client_golang v1.19.1
github.com/stretchr/testify v1.9.0
)
require (
@ -17,6 +18,7 @@ require (
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d // indirect
github.com/chenzhuoyu/iasm v0.9.1 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
@ -24,7 +26,6 @@ require (
github.com/go-playground/validator/v10 v10.16.0 // indirect
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.17.4 // indirect
@ -32,17 +33,17 @@ require (
github.com/kr/text v0.2.0 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/nats-io/nkeys v0.4.6 // indirect
github.com/nats-io/nuid v1.0.1 // indirect
github.com/onsi/ginkgo/v2 v2.9.5 // indirect
github.com/pelletier/go-toml/v2 v2.1.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.5.0 // indirect
github.com/prometheus/common v0.48.0 // indirect
github.com/prometheus/procfs v0.12.0 // indirect
github.com/quic-go/quic-go v0.52.0 // indirect
github.com/quic-go/quic-go v0.52.0
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
go.uber.org/mock v0.5.0 // indirect
@ -55,5 +56,5 @@ require (
golang.org/x/text v0.17.0 // indirect
golang.org/x/tools v0.22.0 // indirect
google.golang.org/protobuf v1.33.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
gopkg.in/yaml.v3 v3.0.1
)

41
go.sum
View File

@ -26,7 +26,10 @@ github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
@ -37,12 +40,10 @@ github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEe
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE=
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
@ -58,14 +59,13 @@ github.com/klauspost/cpuid/v2 v2.2.6 h1:ndNyv040zDGIDh8thGkXYjnFtiN02M1PVVF+JE/4
github.com/klauspost/cpuid/v2 v2.2.6/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo=
github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@ -79,29 +79,24 @@ github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw=
github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c=
github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q=
github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k=
github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE=
github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg=
github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4=
github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v1.17.0 h1:rl2sfwZMtSthVU752MqfjQozy7blglC+1SOtjMAMh+Q=
github.com/prometheus/client_golang v1.17.0/go.mod h1:VeL+gMmOAxkS2IqfCq0ZmHSL+LjWfWDUmp1mBz9JgUY=
github.com/prometheus/client_golang v1.19.1 h1:wZWJDwK+NameRJuPGDhlnFgx8e8HN3XHQeLaYJFJBOE=
github.com/prometheus/client_golang v1.19.1/go.mod h1:mP78NwGzrVks5S2H6ab8+ZZGJLZUq1hoULYBAYBw1Ho=
github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16 h1:v7DLqVdK4VrYkVD5diGdl4sxJurKJEMnODWRJlxV9oM=
github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16/go.mod h1:oMQmHW1/JoDwqLtg57MGgP/Fb1CJEYF2imWWhWtMkYU=
github.com/prometheus/client_model v0.5.0 h1:VQw1hfvPvk3Uv6Qf29VrPF32JB6rtbgI6cYPYQjL0Qw=
github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk5Pm3gvsjB5tr+kI=
github.com/prometheus/common v0.44.0 h1:+5BrQJwiBB9xsMygAB3TNvpQKOwlkc25LbISbrdOOfY=
github.com/prometheus/common v0.44.0/go.mod h1:ofAIvZbQ1e/nugmZGz4/qCb9Ap1VoSTIO7x0VV9VvuY=
github.com/prometheus/common v0.48.0 h1:QO8U2CdOzSn1BBsmXJXduaaW+dY/5QLjfB8svtSzKKE=
github.com/prometheus/common v0.48.0/go.mod h1:0/KsvlIEfPQCQ5I2iNSAWKPZziNCvRs5EC6ILDTlAPc=
github.com/prometheus/procfs v0.11.1 h1:xRC8Iq1yyca5ypa9n1EZnWZkt7dwcoRPQwX/5gwaUuI=
github.com/prometheus/procfs v0.11.1/go.mod h1:eesXgaPo1q7lBpVMoMy0ZOFTth9hBn4W/y0/p/ScXhY=
github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo=
github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo=
github.com/quic-go/quic-go v0.52.0 h1:/SlHrCRElyaU6MaEPKqKr9z83sBg2v4FLLvWM+Z47pA=
github.com/quic-go/quic-go v0.52.0/go.mod h1:MFlGGpcpJqRAfmYi6NC2cptDPSxRWTOGNuP4wqrWmzQ=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
@ -112,8 +107,9 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
@ -123,41 +119,28 @@ go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.6.0 h1:S0JTfE48HbRj80+4tbvZDYsJ3tGv6BUU3XxyZ7CirAc=
golang.org/x/arch v0.6.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.16.0 h1:mMMrFzRSCF0GvB7Ne27XVtVAaXLrPmgPC7/v0tkwHaY=
golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw=
golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54=
golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0=
golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE=
golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.23.0 h1:YfKFowiIMvtgl1UERQoTPPToxltDeZfbj4H7dVUCwmM=
golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc=
golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA=
golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -27,6 +27,8 @@ type WriteRequest struct {
Timestamp *time.Time `json:"timestamp,omitempty"`
}
type Response map[string]any
// BatchWriteRequest 批量写入请求
type BatchWriteRequest struct {
Points []WriteRequest `json:"points"`
@ -153,7 +155,8 @@ func (s *RESTServer) handleBatchWrite(c *gin.Context) {
}
// 批量写入数据
if err := s.dataManager.BatchWrite(c.Request.Context(), batch); err != nil {
// 使用当前时间作为批量写入的时间戳
if err := s.dataManager.BatchWrite(c.Request.Context(), convertBatch(batch), time.Now()); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "Failed to batch write data: " + err.Error(),
})
@ -162,6 +165,7 @@ func (s *RESTServer) handleBatchWrite(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"status": "ok",
"count": len(batch),
})
}
@ -255,3 +259,24 @@ func (s *RESTServer) Start(addr string) error {
func (s *RESTServer) Stop(ctx context.Context) error {
return s.server.Shutdown(ctx)
}
// convertBatch 将内部批处理格式转换为DataManager.BatchWrite所需的格式
func convertBatch(batch []struct {
ID model.DataPointID
Value model.DataValue
}) []struct {
ID model.DataPointID
Value interface{}
} {
result := make([]struct {
ID model.DataPointID
Value interface{}
}, len(batch))
for i, item := range batch {
result[i].ID = item.ID
result[i].Value = item.Value.Value
}
return result
}

253
pkg/api/rest_test.go Normal file
View File

@ -0,0 +1,253 @@
package api
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gin-gonic/gin"
"git.pyer.club/kingecg/gotidb/pkg/manager"
"git.pyer.club/kingecg/gotidb/pkg/model"
"git.pyer.club/kingecg/gotidb/pkg/storage"
)
func setupTestRESTServer() *RESTServer {
// 创建存储引擎
engine := storage.NewMemoryEngine()
// 创建数据管理器
dataManager := manager.NewDataManager(engine)
// 创建REST服务器
server := NewRESTServer(dataManager)
return server
}
func TestRESTServer_WriteEndpoint(t *testing.T) {
// 设置测试模式
gin.SetMode(gin.TestMode)
// 创建测试服务器
server := setupTestRESTServer()
// 创建测试请求
writeReq := WriteRequest{
DeviceID: "test-device",
MetricCode: "temperature",
Labels: map[string]string{
"location": "room1",
},
Value: 25.5,
}
body, _ := json.Marshal(writeReq)
req, _ := http.NewRequest("POST", "/api/v1/write", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
// 创建响应记录器
w := httptest.NewRecorder()
// 设置路由
r := gin.New()
r.POST("/api/v1/write", server.handleWrite)
// 执行请求
r.ServeHTTP(w, req)
// 检查响应状态码
if w.Code != http.StatusOK {
t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code)
}
// 解析响应
var resp Response
err := json.Unmarshal(w.Body.Bytes(), &resp)
if err != nil {
t.Errorf("Failed to unmarshal response: %v", err)
}
// 验证响应
if resp["status"] != "ok" {
t.Errorf("Expected success to be true, got false")
}
}
func TestRESTServer_BatchWriteEndpoint(t *testing.T) {
// 设置测试模式
gin.SetMode(gin.TestMode)
// 创建测试服务器
server := setupTestRESTServer()
// 创建测试请求
batchReq := BatchWriteRequest{
Points: []WriteRequest{
{
DeviceID: "test-device",
MetricCode: "temperature",
Labels: map[string]string{
"location": "room1",
},
Value: 25.5,
},
{
DeviceID: "test-device",
MetricCode: "humidity",
Labels: map[string]string{
"location": "room1",
},
Value: 60.0,
},
},
}
body, _ := json.Marshal(batchReq)
req, _ := http.NewRequest("POST", "/api/v1/batch_write", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
// 创建响应记录器
w := httptest.NewRecorder()
// 设置路由
r := gin.New()
r.POST("/api/v1/batch_write", server.handleBatchWrite)
// 执行请求
r.ServeHTTP(w, req)
// 检查响应状态码
if w.Code != http.StatusOK {
t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code)
}
// 解析响应
var resp Response
err := json.Unmarshal(w.Body.Bytes(), &resp)
if err != nil {
t.Errorf("Failed to unmarshal response: %v", err)
}
// 验证响应
if resp["status"] != "ok" {
t.Errorf("Expected success to be true, got false")
}
if resp["count"] != 2 {
t.Errorf("Expected count to be 2, got %d", resp["count"])
}
}
func TestRESTServer_QueryEndpoint(t *testing.T) {
// 设置测试模式
gin.SetMode(gin.TestMode)
// 创建测试服务器
server := setupTestRESTServer()
// 写入测试数据
engine := storage.NewMemoryEngine()
dataManager := manager.NewDataManager(engine)
server.dataManager = dataManager
id := model.DataPointID{
DeviceID: "test-device",
MetricCode: "temperature",
Labels: map[string]string{
"location": "room1",
},
}
now := time.Now()
value := model.DataValue{
Timestamp: now,
Value: 25.5,
}
err := dataManager.Write(context.Background(), id, value)
if err != nil {
t.Fatalf("Failed to write test data: %v", err)
}
// 创建测试请求
queryReq := QueryRequest{
DeviceID: "test-device",
MetricCode: "temperature",
Labels: map[string]string{
"location": "room1",
},
QueryType: "latest",
}
body, _ := json.Marshal(queryReq)
req, _ := http.NewRequest("POST", "/api/v1/query", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
// 创建响应记录器
w := httptest.NewRecorder()
// 设置路由
r := gin.New()
r.POST("/api/v1/query", server.handleQuery)
// 执行请求
r.ServeHTTP(w, req)
// 检查响应状态码
if w.Code != http.StatusOK {
t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code)
}
// 解析响应
var resp Response
err = json.Unmarshal(w.Body.Bytes(), &resp)
if err != nil {
t.Errorf("Failed to unmarshal response: %v", err)
}
// 验证响应
if resp["status"] != "ok" {
t.Errorf("Expected success to be true, got false")
}
// 验证返回的数据
if resp["timestamp"] == nil {
t.Errorf("Expected data to be non-nil")
}
// // 验证最新值
// if resp.QueryType != "latest" {
// t.Errorf("Expected query_type to be 'latest', got '%s'", resp.QueryType)
// }
// if resp.Data.(map[string]interface{})["value"] != 25.5 {
// t.Errorf("Expected value to be 25.5, got %v", resp.Data.(map[string]interface{})["value"])
// }
}
func TestRESTServer_Start(t *testing.T) {
// 创建测试服务器
server := setupTestRESTServer()
// 启动服务器(在后台)
go func() {
err := server.Start(":0") // 使用随机端口
if err != nil && err != http.ErrServerClosed {
t.Errorf("Failed to start server: %v", err)
}
}()
// 给服务器一点时间启动
time.Sleep(100 * time.Millisecond)
// 停止服务器
err := server.Stop(context.Background())
if err != nil {
t.Errorf("Failed to stop server: %v", err)
}
}

234
pkg/api/websocket_test.go Normal file
View File

@ -0,0 +1,234 @@
package api
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gorilla/websocket"
"git.pyer.club/kingecg/gotidb/pkg/manager"
"git.pyer.club/kingecg/gotidb/pkg/model"
"git.pyer.club/kingecg/gotidb/pkg/storage"
)
func setupTestWebSocketServer() *WebSocketServer {
// 创建存储引擎
engine := storage.NewMemoryEngine()
// 创建数据管理器
dataManager := manager.NewDataManager(engine)
// 创建WebSocket服务器
server := NewWebSocketServer(dataManager)
return server
}
func TestWebSocketServer_Connection(t *testing.T) {
// 创建测试服务器
server := setupTestWebSocketServer()
// 创建HTTP服务器
httpServer := httptest.NewServer(server.router)
defer httpServer.Close()
// 将HTTP URL转换为WebSocket URL
wsURL := "ws" + strings.TrimPrefix(httpServer.URL, "http")
// 连接到WebSocket服务器
ws, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
t.Fatalf("Failed to connect to WebSocket server: %v", err)
}
defer ws.Close()
// 发送订阅消息
subscription := SubscriptionRequest{
DeviceID: "test-device",
MetricCode: "temperature",
Labels: map[string]string{
"location": "room1",
},
}
err = ws.WriteJSON(subscription)
if err != nil {
t.Fatalf("Failed to send subscription message: %v", err)
}
// 等待一段时间,确保订阅已处理
time.Sleep(100 * time.Millisecond)
// 写入数据触发WebSocket推送
id := model.DataPointID{
DeviceID: "test-device",
MetricCode: "temperature",
Labels: map[string]string{
"location": "room1",
},
}
value := model.DataValue{
Timestamp: time.Now(),
Value: 25.5,
}
err = server.dataManager.Write(context.Background(), id, value)
if err != nil {
t.Fatalf("Failed to write data: %v", err)
}
// 设置读取超时
ws.SetReadDeadline(time.Now().Add(1 * time.Second))
// 读取WebSocket消息
_, message, err := ws.ReadMessage()
if err != nil {
t.Fatalf("Failed to read WebSocket message: %v", err)
}
// 解析消息
var update DataChangeEvent
err = json.Unmarshal(message, &update)
if err != nil {
t.Fatalf("Failed to unmarshal WebSocket message: %v", err)
}
// 验证消息内容
if update.DeviceID != "test-device" {
t.Errorf("Expected DeviceID to be 'test-device', got '%s'", update.DeviceID)
}
if update.MetricCode != "temperature" {
t.Errorf("Expected MetricCode to be 'temperature', got '%s'", update.MetricCode)
}
if update.Value != 25.5 {
t.Errorf("Expected Value to be 25.5, got %v", update.Value)
}
if update.Labels["location"] != "room1" {
t.Errorf("Expected Labels['location'] to be 'room1', got '%s'", update.Labels["location"])
}
}
func TestWebSocketServer_MultipleSubscriptions(t *testing.T) {
// 创建测试服务器
server := setupTestWebSocketServer()
// 创建HTTP服务器
httpServer := httptest.NewServer(server.router)
defer httpServer.Close()
// 将HTTP URL转换为WebSocket URL
wsURL := "ws" + strings.TrimPrefix(httpServer.URL, "http")
// 连接到WebSocket服务器
ws, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
t.Fatalf("Failed to connect to WebSocket server: %v", err)
}
defer ws.Close()
// 发送多个订阅消息
subscriptions := []SubscriptionRequest{
{
DeviceID: "test-device",
MetricCode: "temperature",
Labels: map[string]string{
"location": "room1",
},
},
{
DeviceID: "test-device",
MetricCode: "humidity",
Labels: map[string]string{
"location": "room1",
},
},
}
for _, subscription := range subscriptions {
err = ws.WriteJSON(subscription)
if err != nil {
t.Fatalf("Failed to send subscription message: %v", err)
}
}
// 等待一段时间,确保订阅已处理
time.Sleep(100 * time.Millisecond)
// 写入数据触发WebSocket推送
for _, subscription := range subscriptions {
id := model.DataPointID{
DeviceID: subscription.DeviceID,
MetricCode: subscription.MetricCode,
Labels: subscription.Labels,
}
value := model.DataValue{
Timestamp: time.Now(),
Value: 25.5,
}
err = server.dataManager.Write(context.Background(), id, value)
if err != nil {
t.Fatalf("Failed to write data: %v", err)
}
// 等待一段时间,确保数据已处理
time.Sleep(100 * time.Millisecond)
// 设置读取超时
ws.SetReadDeadline(time.Now().Add(1 * time.Second))
// 读取WebSocket消息
_, message, err := ws.ReadMessage()
if err != nil {
t.Fatalf("Failed to read WebSocket message: %v", err)
}
// 解析消息
var update DataChangeEvent
err = json.Unmarshal(message, &update)
if err != nil {
t.Fatalf("Failed to unmarshal WebSocket message: %v", err)
}
// 验证消息内容
if update.DeviceID != subscription.DeviceID {
t.Errorf("Expected DeviceID to be '%s', got '%s'", subscription.DeviceID, update.DeviceID)
}
if update.MetricCode != subscription.MetricCode {
t.Errorf("Expected MetricCode to be '%s', got '%s'", subscription.MetricCode, update.MetricCode)
}
}
}
func TestWebSocketServer_Start(t *testing.T) {
// 创建测试服务器
server := setupTestWebSocketServer()
// 启动服务器(在后台)
go func() {
err := server.Start(":0") // 使用随机端口
if err != nil && err != http.ErrServerClosed {
t.Errorf("Failed to start server: %v", err)
}
}()
// 给服务器一点时间启动
time.Sleep(100 * time.Millisecond)
// 停止服务器
err := server.Stop(context.Background())
if err != nil {
t.Errorf("Failed to stop server: %v", err)
}
}

View File

@ -50,16 +50,25 @@ func (m *DataManager) Write(ctx context.Context, id model.DataPointID, value mod
// BatchWrite 批量写入数据
func (m *DataManager) BatchWrite(ctx context.Context, batch []struct {
ID model.DataPointID
Value model.DataValue
}) error {
Value interface{}
}, timestamp time.Time) error {
for _, item := range batch {
if err := m.Write(ctx, item.ID, item.Value); err != nil {
value := model.DataValue{
Timestamp: timestamp,
Value: item.Value,
}
if err := m.Write(ctx, item.ID, value); err != nil {
return err
}
}
return nil
}
// Query 执行查询ExecuteQuery的别名
func (m *DataManager) Query(ctx context.Context, id model.DataPointID, query model.Query) (model.Result, error) {
return m.ExecuteQuery(ctx, id, query)
}
// RegisterCallback 注册数据变更回调
func (m *DataManager) RegisterCallback(callback DataChangeCallback) {
m.callbacksLock.Lock()
@ -106,18 +115,11 @@ func (m *DataManager) EnablePersistence(config storage.PersistenceConfig) error
return m.engine.EnablePersistence(config)
}
// CreateDataPoint 创建一个新的数据点
func CreateDataPoint(deviceID, metricCode string, labels map[string]string, value interface{}) (model.DataPointID, model.DataValue) {
id := model.DataPointID{
// CreateDataPoint 创建一个新的数据点ID
func CreateDataPoint(deviceID, metricCode string, labels map[string]string, value interface{}) model.DataPointID {
return model.DataPointID{
DeviceID: deviceID,
MetricCode: metricCode,
Labels: labels,
}
dataValue := model.DataValue{
Timestamp: time.Now(),
Value: value,
}
return id, dataValue
}

View File

@ -0,0 +1,246 @@
package manager
import (
"context"
"testing"
"time"
"git.pyer.club/kingecg/gotidb/pkg/model"
"git.pyer.club/kingecg/gotidb/pkg/storage"
)
func TestDataManager(t *testing.T) {
// 创建存储引擎
engine := storage.NewMemoryEngine()
// 创建数据管理器
manager := NewDataManager(engine)
// 创建测试数据
deviceID := "test-device"
metricCode := "temperature"
labels := map[string]string{
"location": "room1",
}
// 测试创建数据点
t.Run("CreateDataPoint", func(t *testing.T) {
id := CreateDataPoint(deviceID, metricCode, labels, nil)
if id.DeviceID != deviceID {
t.Errorf("CreateDataPoint() DeviceID = %v, want %v", id.DeviceID, deviceID)
}
if id.MetricCode != metricCode {
t.Errorf("CreateDataPoint() MetricCode = %v, want %v", id.MetricCode, metricCode)
}
if len(id.Labels) != len(labels) {
t.Errorf("CreateDataPoint() Labels length = %v, want %v", len(id.Labels), len(labels))
}
for k, v := range labels {
if id.Labels[k] != v {
t.Errorf("CreateDataPoint() Labels[%v] = %v, want %v", k, id.Labels[k], v)
}
}
})
// 测试写入数据
t.Run("Write", func(t *testing.T) {
id := CreateDataPoint(deviceID, metricCode, labels, nil)
value := 25.5
err := manager.Write(context.Background(), id, model.DataValue{
Timestamp: time.Now(),
Value: value,
})
if err != nil {
t.Errorf("Write() error = %v", err)
}
})
// 测试批量写入
t.Run("BatchWrite", func(t *testing.T) {
id1 := CreateDataPoint(deviceID, metricCode, labels, nil)
id2 := CreateDataPoint(deviceID, "humidity", labels, nil)
now := time.Now()
batch := []struct {
ID model.DataPointID
Value interface{}
}{
{
ID: id1,
Value: 26.0,
},
{
ID: id2,
Value: 60.0,
},
}
err := manager.BatchWrite(context.Background(), batch, now)
if err != nil {
t.Errorf("BatchWrite() error = %v", err)
}
})
// 测试查询
t.Run("Query", func(t *testing.T) {
id := CreateDataPoint(deviceID, metricCode, labels, nil)
now := time.Now()
value := 27.5
// 写入测试数据
err := manager.Write(context.Background(), id, model.DataValue{
Timestamp: now,
Value: value,
})
if err != nil {
t.Errorf("Write() for Query test error = %v", err)
}
// 测试最新值查询
t.Run("QueryLatest", func(t *testing.T) {
query := model.NewQuery(model.QueryTypeLatest, nil)
result, err := manager.Query(context.Background(), id, query)
if err != nil {
t.Errorf("Query() error = %v", err)
}
latest, ok := result.AsLatest()
if !ok {
t.Errorf("Query() result is not a latest result")
}
if latest.Value != value {
t.Errorf("Query() latest value = %v, want %v", latest.Value, value)
}
})
// 测试所有值查询
t.Run("QueryAll", func(t *testing.T) {
// 写入多个值
for i := 1; i <= 5; i++ {
newValue := model.DataValue{
Timestamp: now.Add(time.Duration(i) * time.Minute),
Value: value + float64(i),
}
err := manager.Write(context.Background(), id, newValue)
if err != nil {
t.Errorf("Write() for QueryAll error = %v", err)
}
}
// 查询所有值
query := model.NewQuery(model.QueryTypeAll, map[string]interface{}{
"limit": 10,
})
result, err := manager.Query(context.Background(), id, query)
if err != nil {
t.Errorf("Query() for QueryAll error = %v", err)
}
all, ok := result.AsAll()
if !ok {
t.Errorf("Query() result is not an all result")
}
// 验证返回的值数量
if len(all) != 6 { // 初始值 + 5个新值
t.Errorf("Query() all result length = %v, want %v", len(all), 6)
}
})
// 测试持续时间查询
t.Run("QueryDuration", func(t *testing.T) {
// 设置时间范围
from := now.Add(1 * time.Minute)
to := now.Add(3 * time.Minute)
// 查询指定时间范围内的值
query := model.NewQuery(model.QueryTypeDuration, map[string]interface{}{
"from": from,
"to": to,
})
result, err := manager.Query(context.Background(), id, query)
if err != nil {
t.Errorf("Query() for QueryDuration error = %v", err)
}
duration, ok := result.AsAll()
if !ok {
t.Errorf("Query() result is not a duration result")
}
// 验证返回的值数量
if len(duration) != 3 { // 1分钟、2分钟和3分钟的值
t.Errorf("Query() duration result length = %v, want %v", len(duration), 3)
}
// 验证所有值都在指定的时间范围内
for _, v := range duration {
if v.Timestamp.Before(from) || v.Timestamp.After(to) {
t.Errorf("Query() duration result contains value with timestamp %v outside range [%v, %v]", v.Timestamp, from, to)
}
}
})
})
// 测试回调
t.Run("Callback", func(t *testing.T) {
callbackCalled := false
var callbackID model.DataPointID
var callbackValue model.DataValue
// 注册回调
manager.RegisterCallback(func(id model.DataPointID, value model.DataValue) {
callbackCalled = true
callbackID = id
callbackValue = value
})
// 写入数据触发回调
id := CreateDataPoint(deviceID, metricCode, labels, nil)
now := time.Now()
value := 28.5
err := manager.Write(context.Background(), id, model.DataValue{
Timestamp: now,
Value: value,
})
if err != nil {
t.Errorf("Write() for Callback test error = %v", err)
}
// 验证回调是否被调用
if !callbackCalled {
t.Errorf("Callback not called")
}
// 验证回调参数
if !callbackID.Equal(id) {
t.Errorf("Callback ID = %v, want %v", callbackID, id)
}
if callbackValue.Value != value {
t.Errorf("Callback value = %v, want %v", callbackValue.Value, value)
}
})
// 测试关闭
t.Run("Close", func(t *testing.T) {
err := manager.Close()
if err != nil {
t.Errorf("Close() error = %v", err)
}
})
}

130
pkg/messaging/nats_test.go Normal file
View File

@ -0,0 +1,130 @@
package messaging
import (
"context"
"testing"
"time"
"git.pyer.club/kingecg/gotidb/pkg/model"
"github.com/nats-io/nats.go/jetstream"
"github.com/stretchr/testify/assert"
)
// 模拟NATS连接
type mockNATSConn struct {
closeFunc func() error
}
func (m *mockNATSConn) Close() error {
if m.closeFunc != nil {
return m.closeFunc()
}
return nil
}
// 模拟JetStream
type mockJetStream struct {
publishFunc func(ctx context.Context, subject string, data []byte) (jetstream.PubAck, error)
}
func (m *mockJetStream) Publish(ctx context.Context, subject string, data []byte) (jetstream.PubAck, error) {
if m.publishFunc != nil {
return m.publishFunc(ctx, subject, data)
}
return jetstream.PubAck{}, nil
}
// 模拟Stream
type mockStream struct {
createOrUpdateConsumerFunc func(ctx context.Context, cfg jetstream.ConsumerConfig) (jetstream.Consumer, error)
}
func (m *mockStream) CreateOrUpdateConsumer(ctx context.Context, cfg jetstream.ConsumerConfig) (jetstream.Consumer, error) {
if m.createOrUpdateConsumerFunc != nil {
return m.createOrUpdateConsumerFunc(ctx, cfg)
}
return nil, nil
}
// 模拟Consumer
type mockConsumer struct {
messagesFunc func() (jetstream.MessagesContext, error)
}
func (m *mockConsumer) Messages() (jetstream.MessagesContext, error) {
if m.messagesFunc != nil {
return m.messagesFunc()
}
return nil, nil
}
func TestNATSMessaging_Publish(t *testing.T) {
publishCalled := false
mockJS := &mockJetStream{
publishFunc: func(ctx context.Context, subject string, data []byte) (jetstream.PubAck, error) {
publishCalled = true
return jetstream.PubAck{}, nil
},
}
messaging := &NATSMessaging{
conn: &mockNATSConn{},
js: mockJS,
}
id := model.DataPointID{
DeviceID: "device1",
MetricCode: "metric1",
Labels: map[string]string{"env": "test"},
}
value := model.DataValue{
Timestamp: time.Now(),
Value: 42.0,
}
err := messaging.Publish(context.Background(), id, value)
assert.NoError(t, err)
assert.True(t, publishCalled)
}
func TestNATSMessaging_Subscribe(t *testing.T) {
handlerCalled := false
handler := func(msg DataMessage) error {
handlerCalled = true
return nil
}
mockConsumer := &mockConsumer{}
mockStream := &mockStream{
createOrUpdateConsumerFunc: func(ctx context.Context, cfg jetstream.ConsumerConfig) (jetstream.Consumer, error) {
return mockConsumer, nil
},
}
messaging := &NATSMessaging{
conn: &mockNATSConn{},
stream: mockStream,
}
err := messaging.Subscribe(handler)
assert.NoError(t, err)
assert.Contains(t, messaging.handlers, handler)
}
func TestNATSMessaging_Close(t *testing.T) {
closeCalled := false
mockConn := &mockNATSConn{
closeFunc: func() error {
closeCalled = true
return nil
},
}
messaging := &NATSMessaging{
conn: mockConn,
}
err := messaging.Close()
assert.NoError(t, err)
assert.True(t, closeCalled)
}

View File

@ -2,6 +2,7 @@ package model
import (
"fmt"
"sort"
"sync"
"time"
)
@ -15,7 +16,51 @@ type DataPointID struct {
// String 返回数据点标识的字符串表示
func (id DataPointID) String() string {
return fmt.Sprintf("%s:%s:%v", id.DeviceID, id.MetricCode, id.Labels)
return id.Hash()
}
// Equal 判断两个数据点标识是否相等
func (id DataPointID) Equal(other DataPointID) bool {
if id.DeviceID != other.DeviceID || id.MetricCode != other.MetricCode {
return false
}
if len(id.Labels) != len(other.Labels) {
return false
}
for k, v := range id.Labels {
if otherV, ok := other.Labels[k]; !ok || v != otherV {
return false
}
}
return true
}
// Hash 返回数据点标识的哈希值
func (id DataPointID) Hash() string {
if len(id.Labels) == 0 {
return fmt.Sprintf("%s:%s:", id.DeviceID, id.MetricCode)
}
// 提取并排序标签键
keys := make([]string, 0, len(id.Labels))
for k := range id.Labels {
keys = append(keys, k)
}
sort.Strings(keys)
// 按排序后的键顺序构建标签字符串
var labelStr string
for i, k := range keys {
if i == 0 {
labelStr = fmt.Sprintf("%s=%s", k, id.Labels[k])
} else {
labelStr = fmt.Sprintf("%s:%s=%s", labelStr, k, id.Labels[k])
}
}
return fmt.Sprintf("%s:%s:%s", id.DeviceID, id.MetricCode, labelStr)
}
// DataValue 数据值

View File

@ -2,6 +2,7 @@ package model
import (
"context"
"time"
)
// QueryType 查询类型
@ -27,7 +28,7 @@ type Result interface {
IsEmpty() bool
AsLatest() (DataValue, bool)
AsAll() ([]DataValue, bool)
AsDuration() (float64, bool)
AsDuration() (time.Duration, bool)
}
// QueryExecutor 查询执行器接口
@ -63,7 +64,7 @@ func (q *BaseQuery) Params() map[string]interface{} {
type BaseResult struct {
latest *DataValue
all []DataValue
duration *float64
duration *time.Duration
}
// NewLatestResult 创建一个最新值查询结果
@ -81,7 +82,7 @@ func NewAllResult(values []DataValue) Result {
}
// NewDurationResult 创建一个持续时间查询结果
func NewDurationResult(duration float64) Result {
func NewDurationResult(duration time.Duration) Result {
return &BaseResult{
duration: &duration,
}
@ -109,7 +110,7 @@ func (r *BaseResult) AsAll() ([]DataValue, bool) {
}
// AsDuration 将结果转换为持续时间
func (r *BaseResult) AsDuration() (float64, bool) {
func (r *BaseResult) AsDuration() (time.Duration, bool) {
if r.duration != nil {
return *r.duration, true
}

251
pkg/model/query_test.go Normal file
View File

@ -0,0 +1,251 @@
package model
import (
"testing"
"time"
)
func TestDataPointID(t *testing.T) {
tests := []struct {
name string
id DataPointID
wantEqual DataPointID
wantHash string
}{
{
name: "basic data point id",
id: DataPointID{
DeviceID: "device1",
MetricCode: "temperature",
Labels: map[string]string{
"location": "room1",
"floor": "1st",
},
},
wantEqual: DataPointID{
DeviceID: "device1",
MetricCode: "temperature",
Labels: map[string]string{
"location": "room1",
"floor": "1st",
},
},
wantHash: "device1:temperature:floor=1st:location=room1",
},
{
name: "empty labels",
id: DataPointID{
DeviceID: "device2",
MetricCode: "humidity",
Labels: map[string]string{},
},
wantEqual: DataPointID{
DeviceID: "device2",
MetricCode: "humidity",
Labels: map[string]string{},
},
wantHash: "device2:humidity:",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Test equality
if !tt.id.Equal(tt.wantEqual) {
t.Errorf("DataPointID.Equal() = false, want true")
}
// Test hash generation
if hash := tt.id.Hash(); hash != tt.wantHash {
t.Errorf("DataPointID.Hash() = %v, want %v", hash, tt.wantHash)
}
})
}
}
func TestDataValue(t *testing.T) {
now := time.Now()
tests := []struct {
name string
value DataValue
want interface{}
}{
{
name: "float value",
value: DataValue{
Timestamp: now,
Value: 25.5,
},
want: 25.5,
},
{
name: "integer value",
value: DataValue{
Timestamp: now,
Value: 100,
},
want: 100,
},
{
name: "string value",
value: DataValue{
Timestamp: now,
Value: "test",
},
want: "test",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.value.Value != tt.want {
t.Errorf("DataValue.Value = %v, want %v", tt.value.Value, tt.want)
}
if !tt.value.Timestamp.Equal(now) {
t.Errorf("DataValue.Timestamp = %v, want %v", tt.value.Timestamp, now)
}
})
}
}
func TestQuery(t *testing.T) {
tests := []struct {
name string
queryType QueryType
params map[string]interface{}
wantType QueryType
wantParams map[string]interface{}
}{
{
name: "latest query",
queryType: QueryTypeLatest,
params: nil,
wantType: QueryTypeLatest,
wantParams: map[string]interface{}{},
},
{
name: "all query",
queryType: QueryTypeAll,
params: map[string]interface{}{
"limit": 100,
},
wantType: QueryTypeAll,
wantParams: map[string]interface{}{
"limit": 100,
},
},
{
name: "duration query",
queryType: QueryTypeDuration,
params: map[string]interface{}{
"from": "2023-01-01T00:00:00Z",
"to": "2023-01-02T00:00:00Z",
},
wantType: QueryTypeDuration,
wantParams: map[string]interface{}{
"from": "2023-01-01T00:00:00Z",
"to": "2023-01-02T00:00:00Z",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
query := NewQuery(tt.queryType, tt.params)
if query.Type() != tt.wantType {
t.Errorf("Query.Type() = %v, want %v", query.Type(), tt.wantType)
}
params := query.Params()
if len(params) != len(tt.wantParams) {
t.Errorf("Query.Params() length = %v, want %v", len(params), len(tt.wantParams))
}
for k, v := range tt.wantParams {
if params[k] != v {
t.Errorf("Query.Params()[%v] = %v, want %v", k, params[k], v)
}
}
})
}
}
func TestQueryResult(t *testing.T) {
now := time.Now()
tests := []struct {
name string
result Result
wantLatest DataValue
wantAll []DataValue
wantDuration time.Duration
}{
{
name: "latest result",
result: NewLatestResult(DataValue{
Timestamp: now,
Value: 25.5,
}),
wantLatest: DataValue{
Timestamp: now,
Value: 25.5,
},
},
{
name: "all result",
result: NewAllResult([]DataValue{
{
Timestamp: now,
Value: 25.5,
},
{
Timestamp: now.Add(time.Second),
Value: 26.0,
},
}),
wantAll: []DataValue{
{
Timestamp: now,
Value: 25.5,
},
{
Timestamp: now.Add(time.Second),
Value: 26.0,
},
},
},
{
name: "duration result",
result: NewDurationResult(time.Hour),
wantDuration: time.Hour,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if latest, ok := tt.result.AsLatest(); ok {
if !latest.Timestamp.Equal(tt.wantLatest.Timestamp) || latest.Value != tt.wantLatest.Value {
t.Errorf("Result.AsLatest() = %v, want %v", latest, tt.wantLatest)
}
}
if all, ok := tt.result.AsAll(); ok {
if len(all) != len(tt.wantAll) {
t.Errorf("Result.AsAll() length = %v, want %v", len(all), len(tt.wantAll))
}
for i, v := range tt.wantAll {
if !all[i].Timestamp.Equal(v.Timestamp) || all[i].Value != v.Value {
t.Errorf("Result.AsAll()[%v] = %v, want %v", i, all[i], v)
}
}
}
if duration, ok := tt.result.AsDuration(); ok {
if duration != tt.wantDuration {
t.Errorf("Result.AsDuration() = %v, want %v", duration, tt.wantDuration)
}
}
})
}
}

159
pkg/monitoring/collector.go Normal file
View File

@ -0,0 +1,159 @@
package monitoring
import (
"time"
"github.com/prometheus/client_golang/prometheus"
)
// MetricsCollector 提供更简洁的指标收集API
type MetricsCollector struct {
writeTotal prometheus.Counter
queryTotal prometheus.Counter
writeLatency prometheus.Histogram
queryLatency prometheus.Histogram
activeConnections prometheus.Gauge
dataPointsCount prometheus.Gauge
persistenceLatency prometheus.Histogram
persistenceErrors prometheus.Counter
messagingLatency prometheus.Histogram
messagingErrors prometheus.Counter
websocketConnections prometheus.Gauge
}
// NewMetricsCollector 创建一个新的指标收集器
func NewMetricsCollector() *MetricsCollector {
return &MetricsCollector{
writeTotal: prometheus.NewCounter(prometheus.CounterOpts{
Name: "gotidb_write_total",
Help: "Total number of write operations",
}),
queryTotal: prometheus.NewCounter(prometheus.CounterOpts{
Name: "gotidb_query_total",
Help: "Total number of query operations",
}),
writeLatency: prometheus.NewHistogram(prometheus.HistogramOpts{
Name: "gotidb_write_latency_seconds",
Help: "Write operation latency in seconds",
Buckets: prometheus.DefBuckets,
}),
queryLatency: prometheus.NewHistogram(prometheus.HistogramOpts{
Name: "gotidb_query_latency_seconds",
Help: "Query operation latency in seconds",
Buckets: prometheus.DefBuckets,
}),
activeConnections: prometheus.NewGauge(prometheus.GaugeOpts{
Name: "gotidb_active_connections",
Help: "Number of active connections",
}),
dataPointsCount: prometheus.NewGauge(prometheus.GaugeOpts{
Name: "gotidb_data_points_count",
Help: "Number of data points in the database",
}),
persistenceLatency: prometheus.NewHistogram(prometheus.HistogramOpts{
Name: "gotidb_persistence_latency_seconds",
Help: "Persistence operation latency in seconds",
Buckets: prometheus.DefBuckets,
}),
persistenceErrors: prometheus.NewCounter(prometheus.CounterOpts{
Name: "gotidb_persistence_errors_total",
Help: "Total number of persistence errors",
}),
messagingLatency: prometheus.NewHistogram(prometheus.HistogramOpts{
Name: "gotidb_messaging_latency_seconds",
Help: "Messaging operation latency in seconds",
Buckets: prometheus.DefBuckets,
}),
messagingErrors: prometheus.NewCounter(prometheus.CounterOpts{
Name: "gotidb_messaging_errors_total",
Help: "Total number of messaging errors",
}),
websocketConnections: prometheus.NewGauge(prometheus.GaugeOpts{
Name: "gotidb_websocket_connections",
Help: "Number of active WebSocket connections",
}),
}
}
// RecordWrite 记录写入操作及其延迟
func (c *MetricsCollector) RecordWrite(duration time.Duration) {
c.writeTotal.Inc()
c.writeLatency.Observe(duration.Seconds())
}
// RecordQuery 记录查询操作及其延迟
func (c *MetricsCollector) RecordQuery(duration time.Duration) {
c.queryTotal.Inc()
c.queryLatency.Observe(duration.Seconds())
}
// IncActiveConnections 增加活跃连接数
func (c *MetricsCollector) IncActiveConnections() {
c.activeConnections.Inc()
}
// DecActiveConnections 减少活跃连接数
func (c *MetricsCollector) DecActiveConnections() {
c.activeConnections.Dec()
}
// SetDataPointsCount 设置数据点数量
func (c *MetricsCollector) SetDataPointsCount(count float64) {
c.dataPointsCount.Set(count)
}
// RecordPersistence 记录持久化操作及其延迟
func (c *MetricsCollector) RecordPersistence(duration time.Duration, err error) {
c.persistenceLatency.Observe(duration.Seconds())
if err != nil {
c.persistenceErrors.Inc()
}
}
// RecordMessaging 记录消息操作及其延迟
func (c *MetricsCollector) RecordMessaging(duration time.Duration, err error) {
c.messagingLatency.Observe(duration.Seconds())
if err != nil {
c.messagingErrors.Inc()
}
}
// IncWebSocketConnections 增加WebSocket连接数
func (c *MetricsCollector) IncWebSocketConnections() {
c.websocketConnections.Inc()
}
// DecWebSocketConnections 减少WebSocket连接数
func (c *MetricsCollector) DecWebSocketConnections() {
c.websocketConnections.Dec()
}
// Describe 实现prometheus.Collector接口
func (c *MetricsCollector) Describe(ch chan<- *prometheus.Desc) {
c.writeTotal.Describe(ch)
c.queryTotal.Describe(ch)
c.writeLatency.Describe(ch)
c.queryLatency.Describe(ch)
c.activeConnections.Describe(ch)
c.dataPointsCount.Describe(ch)
c.persistenceLatency.Describe(ch)
c.persistenceErrors.Describe(ch)
c.messagingLatency.Describe(ch)
c.messagingErrors.Describe(ch)
c.websocketConnections.Describe(ch)
}
// Collect 实现prometheus.Collector接口
func (c *MetricsCollector) Collect(ch chan<- prometheus.Metric) {
c.writeTotal.Collect(ch)
c.queryTotal.Collect(ch)
c.writeLatency.Collect(ch)
c.queryLatency.Collect(ch)
c.activeConnections.Collect(ch)
c.dataPointsCount.Collect(ch)
c.persistenceLatency.Collect(ch)
c.persistenceErrors.Collect(ch)
c.messagingLatency.Collect(ch)
c.messagingErrors.Collect(ch)
c.websocketConnections.Collect(ch)
}

View File

@ -0,0 +1,240 @@
package monitoring
import (
"testing"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/testutil"
)
func TestMetricsCollector(t *testing.T) {
// 创建指标收集器
collector := NewMetricsCollector()
// 测试写入操作指标
t.Run("WriteMetrics", func(t *testing.T) {
// 记录写入操作
collector.RecordWrite(10 * time.Millisecond)
// 验证写入总数
writeTotal := testutil.ToFloat64(collector.writeTotal)
if writeTotal != 1 {
t.Errorf("Expected write_total to be 1, got %v", writeTotal)
}
// 对于Histogram类型我们只验证它是否被正确注册和收集
// 而不是尝试获取其具体值
registry := prometheus.NewPedanticRegistry()
registry.MustRegister(collector.writeLatency)
metrics, err := registry.Gather()
if err != nil {
t.Errorf("Failed to gather metrics: %v", err)
}
if len(metrics) == 0 {
t.Error("Expected write_latency to be collected, but got no metrics")
}
})
// 测试查询操作指标
t.Run("QueryMetrics", func(t *testing.T) {
// 记录查询操作
collector.RecordQuery(20 * time.Millisecond)
// 验证查询总数
queryTotal := testutil.ToFloat64(collector.queryTotal)
if queryTotal != 1 {
t.Errorf("Expected query_total to be 1, got %v", queryTotal)
}
// 对于Histogram类型我们只验证它是否被正确注册和收集
registry := prometheus.NewPedanticRegistry()
registry.MustRegister(collector.queryLatency)
metrics, err := registry.Gather()
if err != nil {
t.Errorf("Failed to gather metrics: %v", err)
}
if len(metrics) == 0 {
t.Error("Expected query_latency to be collected, but got no metrics")
}
})
// 测试连接数指标
t.Run("ConnectionMetrics", func(t *testing.T) {
// 增加连接数
collector.IncActiveConnections()
collector.IncActiveConnections()
// 验证活跃连接数
activeConns := testutil.ToFloat64(collector.activeConnections)
if activeConns != 2 {
t.Errorf("Expected active_connections to be 2, got %v", activeConns)
}
// 减少连接数
collector.DecActiveConnections()
// 验证更新后的活跃连接数
activeConns = testutil.ToFloat64(collector.activeConnections)
if activeConns != 1 {
t.Errorf("Expected active_connections to be 1, got %v", activeConns)
}
})
// 测试数据点数量指标
t.Run("DataPointsMetrics", func(t *testing.T) {
// 设置数据点数量
collector.SetDataPointsCount(100)
// 验证数据点数量
dataPoints := testutil.ToFloat64(collector.dataPointsCount)
if dataPoints != 100 {
t.Errorf("Expected data_points_count to be 100, got %v", dataPoints)
}
})
// 测试持久化指标
t.Run("PersistenceMetrics", func(t *testing.T) {
// 记录持久化操作
collector.RecordPersistence(30*time.Millisecond, nil)
// 对于Histogram类型我们只验证它是否被正确注册和收集
registry := prometheus.NewPedanticRegistry()
registry.MustRegister(collector.persistenceLatency)
metrics, err := registry.Gather()
if err != nil {
t.Errorf("Failed to gather metrics: %v", err)
}
if len(metrics) == 0 {
t.Error("Expected persistence_latency to be collected, but got no metrics")
}
// 验证持久化错误数应该为0因为没有错误
persistenceErrors := testutil.ToFloat64(collector.persistenceErrors)
if persistenceErrors != 0 {
t.Errorf("Expected persistence_errors to be 0, got %v", persistenceErrors)
}
// 记录持久化错误
collector.RecordPersistence(30*time.Millisecond, errTestPersistence)
// 验证持久化错误数应该为1
persistenceErrors = testutil.ToFloat64(collector.persistenceErrors)
if persistenceErrors != 1 {
t.Errorf("Expected persistence_errors to be 1, got %v", persistenceErrors)
}
})
// 测试消息系统指标
t.Run("MessagingMetrics", func(t *testing.T) {
// 记录消息操作
collector.RecordMessaging(40*time.Millisecond, nil)
// 对于Histogram类型我们只验证它是否被正确注册和收集
registry := prometheus.NewPedanticRegistry()
registry.MustRegister(collector.messagingLatency)
metrics, err := registry.Gather()
if err != nil {
t.Errorf("Failed to gather metrics: %v", err)
}
if len(metrics) == 0 {
t.Error("Expected messaging_latency to be collected, but got no metrics")
}
// 验证消息错误数应该为0因为没有错误
messagingErrors := testutil.ToFloat64(collector.messagingErrors)
if messagingErrors != 0 {
t.Errorf("Expected messaging_errors to be 0, got %v", messagingErrors)
}
// 记录消息错误
collector.RecordMessaging(40*time.Millisecond, errTestMessaging)
// 验证消息错误数应该为1
messagingErrors = testutil.ToFloat64(collector.messagingErrors)
if messagingErrors != 1 {
t.Errorf("Expected messaging_errors to be 1, got %v", messagingErrors)
}
})
// 测试WebSocket连接指标
t.Run("WebSocketMetrics", func(t *testing.T) {
// 增加WebSocket连接数
collector.IncWebSocketConnections()
collector.IncWebSocketConnections()
// 验证WebSocket连接数
wsConns := testutil.ToFloat64(collector.websocketConnections)
if wsConns != 2 {
t.Errorf("Expected websocket_connections to be 2, got %v", wsConns)
}
// 减少WebSocket连接数
collector.DecWebSocketConnections()
// 验证更新后的WebSocket连接数
wsConns = testutil.ToFloat64(collector.websocketConnections)
if wsConns != 1 {
t.Errorf("Expected websocket_connections to be 1, got %v", wsConns)
}
})
// 测试指标注册
t.Run("MetricsRegistration", func(t *testing.T) {
registry := prometheus.NewRegistry()
// 注册指标收集器
err := registry.Register(collector)
if err != nil {
t.Errorf("Failed to register metrics collector: %v", err)
}
// 验证所有指标都已注册
metricFamilies, err := registry.Gather()
if err != nil {
t.Errorf("Failed to gather metrics: %v", err)
}
expectedMetrics := []string{
"gotidb_write_total",
"gotidb_query_total",
"gotidb_write_latency_seconds",
"gotidb_query_latency_seconds",
"gotidb_active_connections",
"gotidb_data_points_count",
"gotidb_persistence_latency_seconds",
"gotidb_persistence_errors_total",
"gotidb_messaging_latency_seconds",
"gotidb_messaging_errors_total",
"gotidb_websocket_connections",
}
for _, metricName := range expectedMetrics {
found := false
for _, mf := range metricFamilies {
if *mf.Name == metricName {
found = true
break
}
}
if !found {
t.Errorf("Expected metric %s not found in registry", metricName)
}
}
})
}
// 测试错误
var (
errTestPersistence = &testError{msg: "test persistence error"}
errTestMessaging = &testError{msg: "test messaging error"}
)
// 测试错误类型
type testError struct {
msg string
}
func (e *testError) Error() string {
return e.msg
}

View File

@ -3,6 +3,7 @@ package storage
import (
"context"
"sync"
"time"
"git.pyer.club/kingecg/gotidb/pkg/model"
)
@ -33,7 +34,7 @@ type StorageEngine interface {
// GetLatest 获取最新数据
GetLatest(ctx context.Context, id model.DataPointID) (model.DataValue, error)
// GetDuration 获取持续时间
GetDuration(ctx context.Context, id model.DataPointID) (float64, error)
GetDuration(ctx context.Context, id model.DataPointID) (time.Duration, error)
// EnablePersistence 启用持久化
EnablePersistence(config PersistenceConfig) error
// Close 关闭存储引擎
@ -47,6 +48,56 @@ type MemoryEngine struct {
persister Persister // 持久化器
}
// ReadLatest 读取最新数据GetLatest 的别名)
func (e *MemoryEngine) ReadLatest(ctx context.Context, id model.DataPointID) (model.DataValue, error) {
return e.GetLatest(ctx, id)
}
// BatchWrite 批量写入数据
func (e *MemoryEngine) BatchWrite(ctx context.Context, batch []struct {
ID model.DataPointID
Value model.DataValue
}) error {
for _, item := range batch {
if err := e.Write(ctx, item.ID, item.Value); err != nil {
return err
}
}
return nil
}
// ReadAll 读取所有数据Read 的别名)
func (e *MemoryEngine) ReadAll(ctx context.Context, id model.DataPointID) ([]model.DataValue, error) {
return e.Read(ctx, id)
}
// ReadDuration 读取指定时间范围内的数据
func (e *MemoryEngine) ReadDuration(ctx context.Context, id model.DataPointID, from, to time.Time) ([]model.DataValue, error) {
key := id.String()
e.dataLock.RLock()
buffer, exists := e.data[key]
e.dataLock.RUnlock()
if !exists {
return []model.DataValue{}, nil
}
// 读取所有数据
allValues := buffer.Read()
// 过滤出指定时间范围内的数据
var filteredValues []model.DataValue
for _, value := range allValues {
if (value.Timestamp.Equal(from) || value.Timestamp.After(from)) &&
(value.Timestamp.Equal(to) || value.Timestamp.Before(to)) {
filteredValues = append(filteredValues, value)
}
}
return filteredValues, nil
}
// NewMemoryEngine 创建一个新的内存存储引擎
func NewMemoryEngine() *MemoryEngine {
return &MemoryEngine{
@ -119,7 +170,7 @@ func (e *MemoryEngine) GetLatest(ctx context.Context, id model.DataPointID) (mod
}
// GetDuration 获取持续时间
func (e *MemoryEngine) GetDuration(ctx context.Context, id model.DataPointID) (float64, error) {
func (e *MemoryEngine) GetDuration(ctx context.Context, id model.DataPointID) (time.Duration, error) {
key := id.String()
e.dataLock.RLock()
@ -130,8 +181,7 @@ func (e *MemoryEngine) GetDuration(ctx context.Context, id model.DataPointID) (f
return 0, nil
}
duration := buffer.GetDuration()
return duration.Seconds(), nil
return buffer.GetDuration(), nil
}
// EnablePersistence 启用持久化

246
pkg/storage/engine_test.go Normal file
View File

@ -0,0 +1,246 @@
package storage
import (
"context"
"os"
"path/filepath"
"testing"
"time"
"git.pyer.club/kingecg/gotidb/pkg/model"
)
func TestMemoryEngine(t *testing.T) {
// 创建内存存储引擎
engine := NewMemoryEngine()
// 创建测试数据
id := model.DataPointID{
DeviceID: "test-device",
MetricCode: "temperature",
Labels: map[string]string{
"location": "room1",
},
}
now := time.Now()
value := model.DataValue{
Timestamp: now,
Value: 25.5,
}
// 测试写入
t.Run("Write", func(t *testing.T) {
err := engine.Write(context.Background(), id, value)
if err != nil {
t.Errorf("Write() error = %v", err)
}
})
// 测试读取最新值
t.Run("ReadLatest", func(t *testing.T) {
latest, err := engine.ReadLatest(context.Background(), id)
if err != nil {
t.Errorf("ReadLatest() error = %v", err)
}
if !latest.Timestamp.Equal(value.Timestamp) {
t.Errorf("ReadLatest() timestamp = %v, want %v", latest.Timestamp, value.Timestamp)
}
if latest.Value != value.Value {
t.Errorf("ReadLatest() value = %v, want %v", latest.Value, value.Value)
}
})
// 测试批量写入
t.Run("BatchWrite", func(t *testing.T) {
id2 := model.DataPointID{
DeviceID: "test-device",
MetricCode: "humidity",
Labels: map[string]string{
"location": "room1",
},
}
value2 := model.DataValue{
Timestamp: now,
Value: 60.0,
}
batch := []struct {
ID model.DataPointID
Value model.DataValue
}{
{
ID: id,
Value: value,
},
{
ID: id2,
Value: value2,
},
}
err := engine.BatchWrite(context.Background(), batch)
if err != nil {
t.Errorf("BatchWrite() error = %v", err)
}
// 验证批量写入的数据
latest, err := engine.ReadLatest(context.Background(), id2)
if err != nil {
t.Errorf("ReadLatest() after BatchWrite error = %v", err)
}
if !latest.Timestamp.Equal(value2.Timestamp) {
t.Errorf("ReadLatest() after BatchWrite timestamp = %v, want %v", latest.Timestamp, value2.Timestamp)
}
if latest.Value != value2.Value {
t.Errorf("ReadLatest() after BatchWrite value = %v, want %v", latest.Value, value2.Value)
}
})
// 测试读取所有值
t.Run("ReadAll", func(t *testing.T) {
// 写入多个值
for i := 1; i <= 5; i++ {
newValue := model.DataValue{
Timestamp: now.Add(time.Duration(i) * time.Minute),
Value: 25.5 + float64(i),
}
err := engine.Write(context.Background(), id, newValue)
if err != nil {
t.Errorf("Write() for ReadAll error = %v", err)
}
}
// 读取所有值
values, err := engine.ReadAll(context.Background(), id)
if err != nil {
t.Errorf("ReadAll() error = %v", err)
}
// 验证读取的值数量
if len(values) != 6 { // 初始值 + 5个新值
t.Errorf("ReadAll() returned %v values, want %v", len(values), 6)
}
// 验证值是按时间顺序排列的
for i := 1; i < len(values); i++ {
if values[i].Timestamp.Before(values[i-1].Timestamp) {
t.Errorf("ReadAll() values not in chronological order")
}
}
})
// 测试读取持续时间内的值
t.Run("ReadDuration", func(t *testing.T) {
// 设置时间范围
from := now.Add(1 * time.Minute)
to := now.Add(3 * time.Minute)
// 读取指定时间范围内的值
values, err := engine.ReadDuration(context.Background(), id, from, to)
if err != nil {
t.Errorf("ReadDuration() error = %v", err)
}
// 验证读取的值数量
if len(values) != 3 { // 1分钟、2分钟和3分钟的值
t.Errorf("ReadDuration() returned %v values, want %v", len(values), 3)
}
// 验证所有值都在指定的时间范围内
for _, v := range values {
if v.Timestamp.Before(from) || v.Timestamp.After(to) {
t.Errorf("ReadDuration() returned value with timestamp %v outside range [%v, %v]", v.Timestamp, from, to)
}
}
})
}
func TestPersistence(t *testing.T) {
// 创建临时目录
tempDir, err := os.MkdirTemp("", "gotidb-test")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tempDir)
// 创建内存存储引擎
engine := NewMemoryEngine()
// 启用WAL持久化
persistenceConfig := PersistenceConfig{
Type: PersistenceTypeWAL,
Directory: tempDir,
SyncEvery: 1, // 每次写入都同步
}
err = engine.EnablePersistence(persistenceConfig)
if err != nil {
t.Fatalf("EnablePersistence() error = %v", err)
}
// 创建测试数据
id := model.DataPointID{
DeviceID: "test-device",
MetricCode: "temperature",
Labels: map[string]string{
"location": "room1",
},
}
now := time.Now()
value := model.DataValue{
Timestamp: now,
Value: 25.5,
}
// 写入数据
err = engine.Write(context.Background(), id, value)
if err != nil {
t.Errorf("Write() with persistence error = %v", err)
}
// 关闭引擎
err = engine.Close()
if err != nil {
t.Errorf("Close() error = %v", err)
}
// 检查WAL文件是否存在
walFiles, err := filepath.Glob(filepath.Join(tempDir, "*.wal"))
if err != nil {
t.Errorf("Failed to list WAL files: %v", err)
}
if len(walFiles) == 0 {
t.Errorf("No WAL files found after persistence")
}
// 创建新的引擎并从WAL恢复
newEngine := NewMemoryEngine()
err = newEngine.EnablePersistence(persistenceConfig)
if err != nil {
t.Fatalf("EnablePersistence() for new engine error = %v", err)
}
// 读取恢复后的数据
latest, err := newEngine.ReadLatest(context.Background(), id)
if err != nil {
t.Errorf("ReadLatest() after recovery error = %v", err)
}
// 验证恢复的数据
if latest.Value != value.Value {
t.Errorf("ReadLatest() after recovery value = %v, want %v", latest.Value, value.Value)
}
// 关闭新引擎
err = newEngine.Close()
if err != nil {
t.Errorf("Close() new engine error = %v", err)
}
}