gotidb/pkg/engine/mock_engine.go

371 lines
7.5 KiB
Go

package engine
import (
"context"
"fmt"
"sync"
"time"
)
// MockEngine is an in-memory mock implementation of the engine, intended for tests.
//
// Points are stored per series ID. The injectable error fields and callbacks
// let a test force failures or intercept writes and queries.
type MockEngine struct {
	// mu is the read/write lock taken by the methods that read or modify points.
	mu sync.RWMutex
	// points maps a series ID to the data points written for that series, in write order.
	points map[string][]DataPoint // seriesID -> list of data points
	// stats holds the counters reported by Stats.
	stats EngineStats
	// opened is set by Open; closed is set by Close.
	opened bool
	closed bool
	// When non-nil, these errors are returned by Write, Query, Compact and
	// Cleanup respectively, before any other work is done.
	writeError error
	queryError error
	compactError error
	cleanupError error
	// writeCallback, if set, is invoked by Write before the points are stored;
	// a non-nil error from it aborts the write.
	writeCallback func([]DataPoint) error
	// queryCallback, if set, replaces Query's built-in evaluation entirely.
	queryCallback func(Query) (QueryResult, error)
}
// NewMockEngine returns an empty, ready-to-use mock engine.
//
// The point store starts empty and stats.LastWriteTime is initialized to the
// construction time.
func NewMockEngine() *MockEngine {
	e := &MockEngine{points: make(map[string][]DataPoint)}
	e.stats.LastWriteTime = time.Now()
	return e
}
// SetWriteError makes subsequent Write calls fail with err; pass nil to clear it.
// It takes the engine lock so it is safe to call while other goroutines use the engine.
func (m *MockEngine) SetWriteError(err error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.writeError = err
}

// SetQueryError makes subsequent Query calls fail with err; pass nil to clear it.
func (m *MockEngine) SetQueryError(err error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.queryError = err
}

// SetCompactError makes subsequent Compact calls fail with err; pass nil to clear it.
func (m *MockEngine) SetCompactError(err error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.compactError = err
}

// SetCleanupError makes subsequent Cleanup calls fail with err; pass nil to clear it.
func (m *MockEngine) SetCleanupError(err error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.cleanupError = err
}

// SetWriteCallback installs a hook that Write invokes with each batch before
// storing it; a non-nil error from the hook aborts the write. Pass nil to remove it.
func (m *MockEngine) SetWriteCallback(callback func([]DataPoint) error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.writeCallback = callback
}

// SetQueryCallback installs a hook that fully replaces Query's built-in
// evaluation. Pass nil to remove it.
func (m *MockEngine) SetQueryCallback(callback func(Query) (QueryResult, error)) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.queryCallback = callback
}
// IsOpened reports whether Open has been called on the engine.
// It reads the flag under the engine lock so it is safe to call concurrently.
func (m *MockEngine) IsOpened() bool {
	m.mu.RLock()
	defer m.mu.RUnlock()
	return m.opened
}

// IsClosed reports whether Close has been called on the engine.
// It reads the flag under the engine lock so it is safe to call concurrently.
func (m *MockEngine) IsClosed() bool {
	m.mu.RLock()
	defer m.mu.RUnlock()
	return m.closed
}
// GetPoints returns a snapshot of every stored data point, keyed by series ID.
//
// The returned map and per-series slices are freshly allocated, so callers may
// modify them freely. DataPoint values are copied shallowly: reference fields
// such as the Labels map are still shared with the engine.
func (m *MockEngine) GetPoints() map[string][]DataPoint {
	m.mu.RLock()
	defer m.mu.RUnlock()

	snapshot := make(map[string][]DataPoint, len(m.points))
	for id := range m.points {
		src := m.points[id]
		dst := make([]DataPoint, len(src))
		copy(dst, src)
		snapshot[id] = dst
	}
	return snapshot
}
// GetPointsCount reports the total number of stored data points across all series.
func (m *MockEngine) GetPointsCount() int {
	m.mu.RLock()
	defer m.mu.RUnlock()

	var total int
	for id := range m.points {
		total += len(m.points[id])
	}
	return total
}
// Open implements the Engine interface. The mock only records that it was
// opened (see IsOpened); it never fails.
func (m *MockEngine) Open() error {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.opened = true
	return nil
}

// Close implements the Engine interface. The mock only records that it was
// closed (see IsClosed); stored points are left in place and it never fails.
func (m *MockEngine) Close() error {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.closed = true
	return nil
}
// Write implements the Engine interface by appending points to their series.
//
// It returns the injected write error if one is set. Otherwise it invokes the
// write callback if one is set; a non-nil callback error aborts the write.
// Every point must have a non-empty series ID. The batch is validated up front,
// so an invalid point rejects the whole batch and leaves stored data and stats
// unchanged. On success it updates PointsCount, SeriesCount and LastWriteTime.
func (m *MockEngine) Write(ctx context.Context, points []DataPoint) error {
	m.mu.RLock()
	writeErr, callback := m.writeError, m.writeCallback
	m.mu.RUnlock()

	if writeErr != nil {
		return writeErr
	}
	// Invoke the callback without holding the lock so it may call back into the engine.
	if callback != nil {
		if err := callback(points); err != nil {
			return err
		}
	}

	// Resolve and validate all series IDs before mutating anything, so a bad
	// point cannot leave a partially-applied batch behind.
	seriesIDs := make([]string, len(points))
	for i, point := range points {
		seriesID := point.GetSeriesID()
		if seriesID == "" {
			return fmt.Errorf("invalid series ID for point: %+v", point)
		}
		seriesIDs[i] = seriesID
	}

	m.mu.Lock()
	defer m.mu.Unlock()
	for i, point := range points {
		m.points[seriesIDs[i]] = append(m.points[seriesIDs[i]], point)
	}
	m.stats.PointsCount += int64(len(points))
	m.stats.LastWriteTime = time.Now()
	m.stats.SeriesCount = int64(len(m.points))
	return nil
}
// Query implements the Engine interface.
//
// If an error was injected with SetQueryError it is returned. Otherwise, if a
// query callback is set, the query is delegated to it entirely. Otherwise the
// query is evaluated against the stored points:
//   - QueryTypeLatest: the point with the greatest timestamp in each series
//     (the time range is not applied).
//   - QueryTypeRaw: points with StartTime <= Timestamp <= EndTime, truncated to
//     Limit when Limit > 0.
//   - QueryTypeAggregate: one point per series, stamped with EndTime, whose
//     value is calculateAggregate over the in-range points.
//   - QueryTypeValueDuration: simplified; returns the in-range points as-is.
//
// Only series whose first point's labels match query.Tags are considered, and
// series that select no points are omitted. Series order follows map iteration
// and is therefore unspecified. Any other query type yields an error.
func (m *MockEngine) Query(ctx context.Context, query Query) (QueryResult, error) {
	m.mu.RLock()
	queryErr, callback := m.queryError, m.queryCallback
	m.mu.RUnlock()

	if queryErr != nil {
		return nil, queryErr
	}
	// Invoke the callback without holding the lock so it may call back into the engine.
	if callback != nil {
		return callback(query)
	}

	switch query.Type {
	case QueryTypeLatest, QueryTypeRaw, QueryTypeAggregate, QueryTypeValueDuration:
		// supported
	default:
		return nil, fmt.Errorf("unsupported query type: %s", query.Type)
	}

	// inRange returns the points whose timestamps fall within [StartTime, EndTime].
	inRange := func(points []DataPoint) []DataPoint {
		var matched []DataPoint
		for _, p := range points {
			if p.Timestamp >= query.StartTime && p.Timestamp <= query.EndTime {
				matched = append(matched, p)
			}
		}
		return matched
	}

	m.mu.RLock()
	defer m.mu.RUnlock()

	var result QueryResult
	for seriesID, points := range m.points {
		if len(points) == 0 || !matchTags(points[0].Labels, query.Tags) {
			continue
		}

		var selected []DataPoint
		switch query.Type {
		case QueryTypeLatest:
			// Select by timestamp rather than write position so out-of-order
			// writes still yield the newest point; ties go to the latest write.
			latest := points[0]
			for _, p := range points[1:] {
				if p.Timestamp >= latest.Timestamp {
					latest = p
				}
			}
			selected = []DataPoint{latest}
		case QueryTypeRaw:
			selected = inRange(points)
			if query.Limit > 0 && len(selected) > query.Limit {
				selected = selected[:query.Limit]
			}
		case QueryTypeAggregate:
			if matched := inRange(points); len(matched) > 0 {
				selected = []DataPoint{{
					Timestamp: query.EndTime,
					Value:     calculateAggregate(matched, query.AggregateType),
					Labels:    matched[0].Labels,
				}}
			}
		case QueryTypeValueDuration:
			// Simplified: a real engine would compute how long each value persisted.
			selected = inRange(points)
		}

		if len(selected) > 0 {
			result = append(result, SeriesResult{
				SeriesID: seriesID,
				Points:   selected,
			})
		}
	}
	return result, nil
}
// Compact implements the Engine interface.
//
// It returns the injected compaction error if one is set. Otherwise it only
// increments CompactionCount and records LastCompaction. The stats are
// updated under the engine lock because Write mutates the same struct.
func (m *MockEngine) Compact() error {
	m.mu.Lock()
	defer m.mu.Unlock()
	if m.compactError != nil {
		return m.compactError
	}
	m.stats.CompactionCount++
	m.stats.LastCompaction = time.Now()
	return nil
}
// Cleanup implements the Engine interface. The mock performs no work and
// returns the error injected with SetCleanupError, or nil if none is set.
func (m *MockEngine) Cleanup() error {
	m.mu.RLock()
	defer m.mu.RUnlock()
	return m.cleanupError
}
// Stats implements the Engine interface and returns a copy of the engine's
// statistics. It reads under the engine lock because Write and Compact
// update the same fields.
func (m *MockEngine) Stats() EngineStats {
	m.mu.RLock()
	defer m.mu.RUnlock()
	return m.stats
}
// matchTags reports whether every key/value pair in queryTags is present in
// pointTags. An empty or nil queryTags matches anything. A key missing from
// pointTags reads as "", so a query value of "" matches an absent key.
func matchTags(pointTags, queryTags map[string]string) bool {
	matched := true
	for key, want := range queryTags {
		if pointTags[key] != want {
			matched = false
			break
		}
	}
	return matched
}
// calculateAggregate reduces the Value fields of points according to aggregateType.
//
// Avg and Sum accumulate values in slice order; Min and Max return the extreme
// value; Count returns the number of points. It returns 0 for an empty slice
// and for any aggregate type it does not recognize.
func calculateAggregate(points []DataPoint, aggregateType AggregateType) float64 {
	n := len(points)
	if n == 0 {
		return 0
	}

	// sum adds the values in slice order.
	sum := func() float64 {
		acc := 0.0
		for i := range points {
			acc += points[i].Value
		}
		return acc
	}

	switch aggregateType {
	case AggregateTypeAvg:
		return sum() / float64(n)
	case AggregateTypeSum:
		return sum()
	case AggregateTypeMin:
		lowest := points[0].Value
		for i := 1; i < n; i++ {
			if v := points[i].Value; v < lowest {
				lowest = v
			}
		}
		return lowest
	case AggregateTypeMax:
		highest := points[0].Value
		for i := 1; i < n; i++ {
			if v := points[i].Value; v > highest {
				highest = v
			}
		}
		return highest
	case AggregateTypeCount:
		return float64(n)
	}
	return 0
}