diff --git a/kvstore.go b/kvstore.go index 306c7cd..18bd468 100644 --- a/kvstore.go +++ b/kvstore.go @@ -13,31 +13,34 @@ import ( // KVStore represents a simple in-memory key/value store. type KVStore struct { mu sync.RWMutex - store map[string]string + store map[string]interface{} // 修改为interface{}以支持多种类型 file string transaction *Transaction logFile string memoryLimit int64 memoryUsage int64 - dirtyKeys map[string]bool // 记录发生变化的Key - bucketCount int // 哈希桶数量 + dirtyKeys map[string]bool + bucketCount int + index map[string]int64 // 添加索引字段 + wg sync.WaitGroup // 添加WaitGroup以等待异步操作 } // Transaction represents an ongoing transaction type Transaction struct { - store map[string]string + store map[string]interface{} } // NewKVStore creates a new instance of KVStore. func NewKVStore(file string, memoryLimit int64, bucketCount int) *KVStore { store := &KVStore{ - store: make(map[string]string), + store: make(map[string]interface{}), file: file, logFile: file + ".log", memoryLimit: memoryLimit, memoryUsage: 0, dirtyKeys: make(map[string]bool), bucketCount: bucketCount, + index: make(map[string]int64), // 初始化索引 } // 启动时自动恢复日志 if err := store.RecoverFromLog(); err != nil { @@ -67,23 +70,117 @@ func (k *KVStore) periodicSave() { } } +func (k *KVStore) getBucketNo(key string) int { + return hashKey(key) % k.bucketCount +} + // saveKeyToBucket saves a key to its corresponding bucket file func (k *KVStore) saveKeyToBucket(key string) error { // 计算哈希值确定桶 - hash := hashKey(key) % k.bucketCount + hash := k.getBucketNo(key) bucketFile := fmt.Sprintf("%s.bucket%d", k.file, hash) - file, err := os.OpenFile(bucketFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + // 创建索引文件 + indexFile := fmt.Sprintf("%s.index%d", k.file, hash) + indexData := make(map[string]int64) + + // 读取整个桶文件,过滤掉旧键值对 + var entries [][]byte + if _, err := os.Stat(bucketFile); err == nil { + file, err := os.Open(bucketFile) + if err != nil { + return err + } + defer file.Close() + + // 读取索引文件 + if _, err := os.Stat(indexFile); err == nil { + indexFileData, err := os.ReadFile(indexFile) + if err == nil { + json.Unmarshal(indexFileData, &indexData) + } + } + + for { + // 读取键长度 + var keyLen int64 + if err := binary.Read(file, binary.LittleEndian, &keyLen); err != nil { + break + } + + // 读取键 + keyBytes := make([]byte, keyLen) + if _, err := file.Read(keyBytes); err != nil { + return err + } + + // 读取值长度 + var valueLen int64 + if err := binary.Read(file, binary.LittleEndian, &valueLen); err != nil { + return err + } + + // 读取值 + valueBytes := make([]byte, valueLen) + if _, err := file.Read(valueBytes); err != nil { + return err + } + + // 如果键不等于当前键,保留该记录 + if string(keyBytes) != key { + entry := new(bytes.Buffer) + if err := binary.Write(entry, binary.LittleEndian, keyLen); err != nil { + return err + } + entry.Write(keyBytes) + if err := binary.Write(entry, binary.LittleEndian, valueLen); err != nil { + return err + } + entry.Write(valueBytes) + entries = append(entries, entry.Bytes()) + } + } + } + + // 重新写入所有保留的记录和新记录 + file, err := os.Create(bucketFile) if err != nil { return err } defer file.Close() + // 更新索引 + offset, _ := file.Seek(0, 1) + indexData[key] = offset + + // 写入保留的记录 + for _, entry := range entries { + if _, err := file.Write(entry); err != nil { + return err + } + } + + // 写入新记录 value, exists := k.store[key] if !exists { return nil } + // 序列化值 + valueBytes, _ := json.Marshal(value) + // switch v := value.(type) { + // case string: + // valueBytes = []byte(v) + // case int: + // valueBytes = []byte(fmt.Sprintf("%d", v)) + // case bool: + // valueBytes = []byte(fmt.Sprintf("%t", v)) + // case []string, []int, []bool: + // valueBytes, _ = json.Marshal(v) + // default: + // return fmt.Errorf("unsupported value type") + // } + // 写入键长度和键 if err := binary.Write(file, binary.LittleEndian, int64(len(key))); err != nil { return err @@ -93,16 +190,75 @@ func (k *KVStore) saveKeyToBucket(key string) error { } // 写入值长度和值 - if err := binary.Write(file, binary.LittleEndian, int64(len(value))); err != nil { + if err := binary.Write(file, binary.LittleEndian, int64(len(valueBytes))); err != nil { return err } - if _, err := file.WriteString(value); err != nil { + if _, err := file.Write(valueBytes); err != nil { return err } + // 保存索引文件 + indexDataBytes, _ := json.Marshal(indexData) + os.WriteFile(indexFile, indexDataBytes, 0644) + return nil } +func (k *KVStore) getKeyFromBucket(key string) (interface{}, bool) { + // 计算哈希值确定桶 + hash := k.getBucketNo(key) + bucketFile := fmt.Sprintf("%s.bucket%d", k.file, hash) + bucketIndexFile := fmt.Sprintf("%s.index%d", k.file, hash) + // 读取索引文件 + indexData := make(map[string]int64) + if _, err := os.Stat(bucketIndexFile); err == nil { + indexFileData, err := os.ReadFile(bucketIndexFile) + if err == nil { + json.Unmarshal(indexFileData, &indexData) + } + } + // 检查索引中是否 + if offset, exists := indexData[key]; exists { + file, err := os.Open(bucketFile) + if err != nil { + return nil, false + } + defer file.Close() + // 定位到偏移量 + _, err = file.Seek(offset, 0) + if err != nil { + return nil, false + } + // 读取键长度 + var keyLen int64 + if err := binary.Read(file, binary.LittleEndian, &keyLen); err != nil { + return nil, false + } + // 读取键 + keyBytes := make([]byte, keyLen) + if _, err := file.Read(keyBytes); err != nil { + return nil, false + } + // 读取值长度 + var valueLen int64 + if err := binary.Read(file, binary.LittleEndian, &valueLen); err != nil { + return nil, false + } + // 读取值 + valueBytes := make([]byte, valueLen) + if _, err := file.Read(valueBytes); err != nil { + return nil, false + } + // 反序列化值 + var value interface{} + if err := json.Unmarshal(valueBytes, &value); err != nil { + return nil, false + } + return value, true + } + return nil, false +} + // hashKey computes a hash for the key func hashKey(key string) int { hash := 0 @@ -136,36 +292,71 @@ func (k *KVStore) Flush() error { } // 清空内存 - k.store = make(map[string]string) + k.store = make(map[string]interface{}) k.memoryUsage = 0 return nil } +func (k *KVStore) valueSize(value interface{}) int64 { + // 计算值的内存占用 + var valueSize int64 + switch v := value.(type) { + case string: + valueSize = int64(len(v)) + case int: + valueSize = 8 // int在64位系统上通常占8字节 + case bool: + valueSize = 1 // bool占1字节 + case []string: + for _, s := range v { + valueSize += int64(len(s)) + } + case []int: + valueSize = int64(len(v)) * 8 + case []bool: + valueSize = int64(len(v)) + } + return valueSize +} + // Put adds a key/value pair to the store. -func (k *KVStore) Put(key, value string) { +func (k *KVStore) Put(key string, value interface{}) { k.mu.Lock() defer k.mu.Unlock() + // 检查值类型 + switch value.(type) { + case string, int, bool, []string, []int, []bool: + // 有效类型,继续 + default: + panic("unsupported value type") + } + // 如果处于事务中,不检查内存限制 if k.transaction != nil { k.transaction.store[key] = value return } + // 计算值的内存占用 + valueSize := k.valueSize(value) + // 检查内存使用量 - newMemoryUsage := k.memoryUsage + int64(len(key)+len(value)) + newMemoryUsage := k.memoryUsage + int64(len(key)) + valueSize if newMemoryUsage > k.memoryLimit { if err := k.Flush(); err != nil { panic("Failed to flush store") } - newMemoryUsage = int64(len(key) + len(value)) + newMemoryUsage = int64(len(key)) + valueSize } k.store[key] = value k.memoryUsage = newMemoryUsage k.dirtyKeys[key] = true + k.wg.Add(1) // 增加WaitGroup计数器 go func() { + defer k.wg.Done() // 减少WaitGroup计数器 k.mu.Lock() defer k.mu.Unlock() k.LogOperation("put", key, value) // 异步记录操作日志 @@ -179,9 +370,12 @@ func (k *KVStore) Delete(key string) { if value, exists := k.store[key]; exists { delete(k.store, key) - k.memoryUsage -= int64(len(key) + len(value)) + valueSize := k.valueSize(value) + k.memoryUsage -= int64(len(key)) + valueSize k.dirtyKeys[key] = true + k.wg.Add(1) // 增加WaitGroup计数器 go func() { + defer k.wg.Done() // 减少WaitGroup计数器 k.mu.Lock() defer k.mu.Unlock() k.LogOperation("delete", key, "") // 异步记录操作日志 @@ -190,7 +384,7 @@ func (k *KVStore) Delete(key string) { } // Get retrieves the value associated with the given key. -func (k *KVStore) Get(key string) (string, bool) { +func (k *KVStore) Get(key string) (interface{}, bool) { k.mu.RLock() defer k.mu.RUnlock() if k.transaction != nil { @@ -200,6 +394,14 @@ func (k *KVStore) Get(key string) (string, bool) { } } value, exists := k.store[key] + if !exists { + value, exists = k.getKeyFromBucket(key) + if exists { + k.store[key] = value + } else { + return nil, false + } + } return value, exists } @@ -208,7 +410,24 @@ func (k *KVStore) BeginTransaction() { k.mu.Lock() defer k.mu.Unlock() k.transaction = &Transaction{ - store: make(map[string]string), + store: make(map[string]interface{}), // 修改为interface{}以支持多种类型 + } +} + +// PutInTransaction puts a key/value pair in the current transaction +func (k *KVStore) PutInTransaction(key string, value interface{}) { + k.mu.Lock() + defer k.mu.Unlock() + if k.transaction != nil { + // 检查值类型 + switch value.(type) { + case string, int, bool, []string, []int, []bool: + k.transaction.store[key] = value + default: + panic("unsupported value type") + } + } else { + k.store[key] = value } } @@ -223,7 +442,24 @@ func (k *KVStore) Commit() error { // 计算事务中键值对的内存使用量 var transactionMemoryUsage int64 for key, value := range k.transaction.store { - transactionMemoryUsage += int64(len(key) + len(value)) + var valueSize int64 + switch v := value.(type) { + case string: + valueSize = int64(len(v)) + case int: + valueSize = 8 // int在64位系统上通常占8字节 + case bool: + valueSize = 1 // bool占1字节 + case []string: + for _, s := range v { + valueSize += int64(len(s)) + } + case []int: + valueSize = int64(len(v)) * 8 + case []bool: + valueSize = int64(len(v)) + } + transactionMemoryUsage += int64(len(key)) + valueSize } // 检查内存使用量 @@ -253,17 +489,6 @@ func (k *KVStore) Rollback() { k.transaction = nil } -// PutInTransaction puts a key/value pair in the current transaction -func (k *KVStore) PutInTransaction(key, value string) { - k.mu.Lock() - defer k.mu.Unlock() - if k.transaction != nil { - k.transaction.store[key] = value - } else { - k.store[key] = value - } -} - // RecoverFromLog recovers data from log file func (k *KVStore) RecoverFromLog() error { k.mu.Lock() @@ -286,15 +511,15 @@ func (k *KVStore) RecoverFromLog() error { if len(line) == 0 { continue } - var logEntry map[string]string + var logEntry map[string]interface{} if err := json.Unmarshal(line, &logEntry); err != nil { return err } switch logEntry["op"] { case "put": - k.store[logEntry["key"]] = logEntry["value"] + k.store[logEntry["key"].(string)] = logEntry["value"] case "delete": - delete(k.store, logEntry["key"]) + delete(k.store, logEntry["key"].(string)) } } @@ -302,47 +527,9 @@ func (k *KVStore) RecoverFromLog() error { return os.Truncate(k.logFile, 0) } -// SaveToFile saves the current store to a file using binary format. -func (k *KVStore) SaveToFile() error { - k.mu.RLock() - defer k.mu.RUnlock() - - file, err := os.Create(k.file) - if err != nil { - return err - } - defer file.Close() - - // 写入键值对数量 - if err := binary.Write(file, binary.LittleEndian, int64(len(k.store))); err != nil { - return err - } - - // 逐个写入键值对 - for key, value := range k.store { - // 写入键长度和键 - if err := binary.Write(file, binary.LittleEndian, int64(len(key))); err != nil { - return err - } - if _, err := file.WriteString(key); err != nil { - return err - } - - // 写入值长度和值 - if err := binary.Write(file, binary.LittleEndian, int64(len(value))); err != nil { - return err - } - if _, err := file.WriteString(value); err != nil { - return err - } - } - - return nil -} - // LogOperation logs a key/value operation to the log file -func (k *KVStore) LogOperation(op string, key, value string) error { - logEntry := map[string]string{ +func (k *KVStore) LogOperation(op string, key, value interface{}) error { + logEntry := map[string]interface{}{ "op": op, "key": key, "value": value, diff --git a/kvstore_test.go b/kvstore_test.go index d706104..13f8f7b 100644 --- a/kvstore_test.go +++ b/kvstore_test.go @@ -17,47 +17,52 @@ func TestNewKVStore(t *testing.T) { func TestPutAndGet(t *testing.T) { store := NewKVStore("test.db", mumlimit, bucketCount) + + // 测试字符串类型 store.Put("key1", "value1") + store.wg.Wait() // 等待异步操作完成 value, exists := store.Get("key1") if !exists || value != "value1" { t.Error("Expected value 'value1' for key 'key1'") } + + // 测试整数类型 + store.Put("key2", 123) + store.wg.Wait() // 等待异步操作完成 + value, exists = store.Get("key2") + if !exists || value != 123 { + t.Error("Expected value 123 for key 'key2'") + } + + // 测试布尔类型 + store.Put("key3", true) + store.wg.Wait() // 等待异步操作完成 + value, exists = store.Get("key3") + if !exists || value != true { + t.Error("Expected value true for key 'key3'") + } + + // 测试数组类型 + store.Put("key4", []string{"a", "b", "c"}) + store.wg.Wait() // 等待异步操作完成 + value, exists = store.Get("key4") + if !exists { + t.Error("Expected value for key 'key4'") + } } func TestDelete(t *testing.T) { store := NewKVStore("test.db", mumlimit, bucketCount) store.Put("key1", "value1") + store.wg.Wait() // 等待异步操作完成 store.Delete("key1") + store.wg.Wait() // 等待异步操作完成 _, exists := store.Get("key1") if exists { t.Error("Expected key 'key1' to be deleted") } } -func TestSaveAndLoadFromFile(t *testing.T) { - store := NewKVStore("test.db", mumlimit, bucketCount) - store.Put("key1", "value1") - err := store.SaveToFile() - if err != nil { - t.Error("Failed to save store to file") - } - - newStore := NewKVStore("test.db", mumlimit, bucketCount) - err = newStore.LoadFromFile() - if err != nil { - t.Error("Failed to load store from file") - } - - value, exists := newStore.Get("key1") - if !exists || value != "value1" { - t.Error("Expected value 'value1' for key 'key1' after loading from file") - } - - // Clean up - os.Remove("test.db") - os.Remove("test.db.log") -} - func TestLogOperation(t *testing.T) { store := NewKVStore("test.db", mumlimit, bucketCount) err := store.LogOperation("put", "key1", "value1") @@ -69,8 +74,29 @@ func TestLogOperation(t *testing.T) { os.Remove("test.db.log") } +func createDB() { + store := NewKVStore("test.db", mumlimit, bucketCount) + store.Put("key1", "value1") + store.Put("key2", 123) + store.Put("key3", true) + store.Put("key4", []string{"a", "b", "c"}) + store.wg.Wait() +} + +func TestRecoverFromLog(t *testing.T) { + createDB() + store2 := NewKVStore("test.db", mumlimit, bucketCount) + // store2.RecoverFromLog() + value, exists := store2.Get("key1") + if !exists || value != "value1" { + t.Error("Expected value 'value1' for key 'key1' after recovery") + } +} + func TestTransaction(t *testing.T) { store := NewKVStore("test.db", mumlimit, bucketCount) + + // 测试字符串类型 store.BeginTransaction() store.PutInTransaction("key1", "value1") store.Commit() @@ -80,27 +106,24 @@ func TestTransaction(t *testing.T) { t.Error("Expected value 'value1' for key 'key1' after commit") } + // 测试整数类型 store.BeginTransaction() - store.PutInTransaction("key2", "value2") - store.Rollback() - - _, exists = store.Get("key2") - if exists { - t.Error("Expected key 'key2' to be rolled back") - } - - // 新增测试:验证事务中的PutInTransaction - store.BeginTransaction() - store.PutInTransaction("key3", "value3") - value, exists = store.Get("key3") - if !exists || value != "value3" { - t.Error("Expected value 'value3' for key 'key3' during transaction") - } + store.PutInTransaction("key2", 123) store.Commit() - value, exists = store.Get("key3") - if !exists || value != "value3" { - t.Error("Expected value 'value3' for key 'key3' after commit") + value, exists = store.Get("key2") + if !exists || value != 123 { + t.Error("Expected value 123 for key 'key2' after commit") + } + + // 测试回滚 + store.BeginTransaction() + store.PutInTransaction("key3", "value3") + store.Rollback() + + _, exists = store.Get("key3") + if exists { + t.Error("Expected key 'key3' to be rolled back") } // Clean up