diff --git a/cache.go b/cache.go index de300fa..ca3e75c 100644 --- a/cache.go +++ b/cache.go @@ -59,6 +59,9 @@ type unexportedInterface interface { } type Item struct { + Right *Item + Left *Item + Key string Object interface{} Expiration *time.Time } @@ -68,7 +71,177 @@ func (item *Item) Expired() bool { if item.Expiration == nil { return false } - return item.Expiration.Before(time.Now()) + return time.Now().UnixNano() > item.Expiration.UnixNano() +} + +func (item *Item) add(key string, object interface{}, expiration *time.Time) { + if item.Key == key { + item.Expiration = expiration + return + } + if item.Key == "" { + item.Key = key + item.Object = object + item.Expiration = expiration + return + } + if item.Key < key { + if item.Left == nil { + item.Left = new(Item) + } + item.Left.add(key, object, expiration) + return + } + if item.Key > key { + if item.Right == nil { + item.Right = new(Item) + } + item.Right.add(key, object, expiration) + return + } +} + +func (item *Item) copy(res *Item) { + item.Right = res.Right + item.Left = res.Left + item.Key = res.Key + item.Object = res.Object + item.Expiration = res.Expiration +} + +func (item *Item) addItem(res *Item) { + if res == nil { + return + } + if item.Object == nil { + item.copy(res) + return + } + if item.Key == res.Key { + item.copy(res) + } + if item.Key < res.Key { + item.Left.addItem(res) + return + } + if item.Key > res.Key { + item.Right.addItem(res) + return + } +} + +func (item *Item) delExpirie() { + if item.Left != nil { + item.Left.delExpirie() + } + if item.Expired() { + newItem := new(Item) + newItem.addItem(item.Left) + newItem.addItem(item.Right) + item.copy(newItem) + item = newItem + } + if item.Right != nil { + item.Right.delExpirie() + } +} + +func (item *Item) get(key string) (res *Item, found bool) { + if item == nil { + return res, false + } + if item.Key == key { + return item, true + } + if item.Key < key { + return item.Left.get(key) + } + if item.Key > key { + return item.Right.get(key) + } + return item, found +} + +func (item *Item) set(key string, obj interface{}) { + if item == nil { + return + } + if item.Key == key { + item.Object = obj + return + } + if item.Key < key { + if item.Left != nil { + item.Left.set(key, obj) + return + } + } + if item.Key > key { + if item.Right != nil { + item.Right.set(key, obj) + return + } + } +} + +func (item *Item) del(key string) { + res, found := item.get(key) + if found { + item.Key = "" + item.Object = nil + item.Expiration = nil + if res.Right != nil { + item.addItem(res.Right) + item.Right = nil + } + if res.Left != nil { + item.addItem(res.Left) + item.Left = nil + } + res = nil + } +} + +func (item *Item) rangeFunc(body func(*Item)) { + if item.Left != nil { + item.Left.rangeFunc(body) + } + body(item) + if item.Right != nil { + item.Right.rangeFunc(body) + } +} + +func (item *Item) len() (count int) { + if item.Left != nil { + count = item.Left.len() + } + count++ + if item.Right != nil { + count += item.Right.len() + } + return count +} + +func (item *Item) toSlice() (items []Item) { + if item.Left != nil { + items = append(items, item.Left.toSlice()...) + } + items = append(items, Item{ + Key: item.Key, + Object: item.Object, + Expiration: item.Expiration, + }) + if item.Right != nil { + items = append(items, item.Right.toSlice()...) + } + return items +} + +func (item *Item) fromSlice(items []Item) { + for _, i := range items { + item.add(i.Key, i.Object, i.Expiration) + } } type Cache struct { @@ -79,7 +252,7 @@ type Cache struct { type cache struct { sync.RWMutex defaultExpiration time.Duration - items map[string]*Item + items *Item //map[string]*Item janitor *janitor } @@ -103,10 +276,13 @@ func (c *cache) set(k string, x interface{}, d time.Duration) { t := time.Now().Add(d) e = &t } - c.items[k] = &Item{ - Object: x, - Expiration: e, - } + c.items.add(k, x, e) + /* + c.items[k] = &Item{ + Object: x, + Expiration: e, + } + */ } // Add an item to the cache only if an item doesn't already exist for the given @@ -147,7 +323,7 @@ func (c *cache) Get(k string) (interface{}, bool) { } func (c *cache) get(k string) (interface{}, bool) { - item, found := c.items[k] + item, found := c.items.get(k) if !found || item.Expired() { return nil, false } @@ -161,7 +337,7 @@ func (c *cache) get(k string) (interface{}, bool) { // of the specialized methods, e.g. IncrementInt64. func (c *cache) Increment(k string, n int64) error { c.Lock() - v, found := c.items[k] + v, found := c.items.get(k) if !found || v.Expired() { c.Unlock() return fmt.Errorf("Item %s not found", k) @@ -208,7 +384,7 @@ func (c *cache) Increment(k string, n int64) error { // e.g. IncrementFloat64. func (c *cache) IncrementFloat(k string, n float64) error { c.Lock() - v, found := c.items[k] + v, found := c.items.get(k) if !found || v.Expired() { c.Unlock() return fmt.Errorf("Item %s not found", k) @@ -231,7 +407,7 @@ func (c *cache) IncrementFloat(k string, n float64) error { // value is returned. func (c *cache) IncrementInt(k string, n int) (int, error) { c.Lock() - v, found := c.items[k] + v, found := c.items.get(k) if !found || v.Expired() { c.Unlock() return 0, fmt.Errorf("Item %s not found", k) @@ -252,7 +428,7 @@ func (c *cache) IncrementInt(k string, n int) (int, error) { // value is returned. func (c *cache) IncrementInt8(k string, n int8) (int8, error) { c.Lock() - v, found := c.items[k] + v, found := c.items.get(k) if !found || v.Expired() { c.Unlock() return 0, fmt.Errorf("Item %s not found", k) @@ -273,7 +449,7 @@ func (c *cache) IncrementInt8(k string, n int8) (int8, error) { // value is returned. func (c *cache) IncrementInt16(k string, n int16) (int16, error) { c.Lock() - v, found := c.items[k] + v, found := c.items.get(k) if !found || v.Expired() { c.Unlock() return 0, fmt.Errorf("Item %s not found", k) @@ -294,7 +470,7 @@ func (c *cache) IncrementInt16(k string, n int16) (int16, error) { // value is returned. func (c *cache) IncrementInt32(k string, n int32) (int32, error) { c.Lock() - v, found := c.items[k] + v, found := c.items.get(k) if !found || v.Expired() { c.Unlock() return 0, fmt.Errorf("Item %s not found", k) @@ -315,7 +491,7 @@ func (c *cache) IncrementInt32(k string, n int32) (int32, error) { // value is returned. func (c *cache) IncrementInt64(k string, n int64) (int64, error) { c.Lock() - v, found := c.items[k] + v, found := c.items.get(k) if !found || v.Expired() { c.Unlock() return 0, fmt.Errorf("Item %s not found", k) @@ -336,7 +512,7 @@ func (c *cache) IncrementInt64(k string, n int64) (int64, error) { // value is returned. func (c *cache) IncrementUint(k string, n uint) (uint, error) { c.Lock() - v, found := c.items[k] + v, found := c.items.get(k) if !found || v.Expired() { c.Unlock() return 0, fmt.Errorf("Item %s not found", k) @@ -357,7 +533,7 @@ func (c *cache) IncrementUint(k string, n uint) (uint, error) { // incremented value is returned. func (c *cache) IncrementUintptr(k string, n uintptr) (uintptr, error) { c.Lock() - v, found := c.items[k] + v, found := c.items.get(k) if !found || v.Expired() { c.Unlock() return 0, fmt.Errorf("Item %s not found", k) @@ -378,7 +554,7 @@ func (c *cache) IncrementUintptr(k string, n uintptr) (uintptr, error) { // incremented value is returned. func (c *cache) IncrementUint8(k string, n uint8) (uint8, error) { c.Lock() - v, found := c.items[k] + v, found := c.items.get(k) if !found || v.Expired() { c.Unlock() return 0, fmt.Errorf("Item %s not found", k) @@ -399,7 +575,7 @@ func (c *cache) IncrementUint8(k string, n uint8) (uint8, error) { // incremented value is returned. func (c *cache) IncrementUint16(k string, n uint16) (uint16, error) { c.Lock() - v, found := c.items[k] + v, found := c.items.get(k) if !found || v.Expired() { c.Unlock() return 0, fmt.Errorf("Item %s not found", k) @@ -420,7 +596,7 @@ func (c *cache) IncrementUint16(k string, n uint16) (uint16, error) { // incremented value is returned. func (c *cache) IncrementUint32(k string, n uint32) (uint32, error) { c.Lock() - v, found := c.items[k] + v, found := c.items.get(k) if !found || v.Expired() { c.Unlock() return 0, fmt.Errorf("Item %s not found", k) @@ -441,7 +617,7 @@ func (c *cache) IncrementUint32(k string, n uint32) (uint32, error) { // incremented value is returned. func (c *cache) IncrementUint64(k string, n uint64) (uint64, error) { c.Lock() - v, found := c.items[k] + v, found := c.items.get(k) if !found || v.Expired() { c.Unlock() return 0, fmt.Errorf("Item %s not found", k) @@ -462,7 +638,7 @@ func (c *cache) IncrementUint64(k string, n uint64) (uint64, error) { // incremented value is returned. func (c *cache) IncrementFloat32(k string, n float32) (float32, error) { c.Lock() - v, found := c.items[k] + v, found := c.items.get(k) if !found || v.Expired() { c.Unlock() return 0, fmt.Errorf("Item %s not found", k) @@ -483,7 +659,7 @@ func (c *cache) IncrementFloat32(k string, n float32) (float32, error) { // incremented value is returned. func (c *cache) IncrementFloat64(k string, n float64) (float64, error) { c.Lock() - v, found := c.items[k] + v, found := c.items.get(k) if !found || v.Expired() { c.Unlock() return 0, fmt.Errorf("Item %s not found", k) @@ -508,7 +684,7 @@ func (c *cache) Decrement(k string, n int64) error { // TODO: Implement Increment and Decrement more cleanly. // (Cannot do Increment(k, n*-1) for uints.) c.Lock() - v, found := c.items[k] + v, found := c.items.get(k) if !found || v.Expired() { c.Unlock() return fmt.Errorf("Item not found") @@ -555,7 +731,7 @@ func (c *cache) Decrement(k string, n int64) error { // e.g. DecrementFloat64. func (c *cache) DecrementFloat(k string, n float64) error { c.Lock() - v, found := c.items[k] + v, found := c.items.get(k) if !found || v.Expired() { c.Unlock() return fmt.Errorf("Item %s not found", k) @@ -578,7 +754,7 @@ func (c *cache) DecrementFloat(k string, n float64) error { // value is returned. func (c *cache) DecrementInt(k string, n int) (int, error) { c.Lock() - v, found := c.items[k] + v, found := c.items.get(k) if !found || v.Expired() { c.Unlock() return 0, fmt.Errorf("Item %s not found", k) @@ -599,7 +775,7 @@ func (c *cache) DecrementInt(k string, n int) (int, error) { // value is returned. func (c *cache) DecrementInt8(k string, n int8) (int8, error) { c.Lock() - v, found := c.items[k] + v, found := c.items.get(k) if !found || v.Expired() { c.Unlock() return 0, fmt.Errorf("Item %s not found", k) @@ -620,7 +796,7 @@ func (c *cache) DecrementInt8(k string, n int8) (int8, error) { // value is returned. func (c *cache) DecrementInt16(k string, n int16) (int16, error) { c.Lock() - v, found := c.items[k] + v, found := c.items.get(k) if !found || v.Expired() { c.Unlock() return 0, fmt.Errorf("Item %s not found", k) @@ -641,7 +817,7 @@ func (c *cache) DecrementInt16(k string, n int16) (int16, error) { // value is returned. func (c *cache) DecrementInt32(k string, n int32) (int32, error) { c.Lock() - v, found := c.items[k] + v, found := c.items.get(k) if !found || v.Expired() { c.Unlock() return 0, fmt.Errorf("Item %s not found", k) @@ -662,7 +838,7 @@ func (c *cache) DecrementInt32(k string, n int32) (int32, error) { // value is returned. func (c *cache) DecrementInt64(k string, n int64) (int64, error) { c.Lock() - v, found := c.items[k] + v, found := c.items.get(k) if !found || v.Expired() { c.Unlock() return 0, fmt.Errorf("Item %s not found", k) @@ -683,7 +859,7 @@ func (c *cache) DecrementInt64(k string, n int64) (int64, error) { // value is returned. func (c *cache) DecrementUint(k string, n uint) (uint, error) { c.Lock() - v, found := c.items[k] + v, found := c.items.get(k) if !found || v.Expired() { c.Unlock() return 0, fmt.Errorf("Item %s not found", k) @@ -704,7 +880,7 @@ func (c *cache) DecrementUint(k string, n uint) (uint, error) { // decremented value is returned. func (c *cache) DecrementUintptr(k string, n uintptr) (uintptr, error) { c.Lock() - v, found := c.items[k] + v, found := c.items.get(k) if !found || v.Expired() { c.Unlock() return 0, fmt.Errorf("Item %s not found", k) @@ -725,7 +901,7 @@ func (c *cache) DecrementUintptr(k string, n uintptr) (uintptr, error) { // value is returned. func (c *cache) DecrementUint8(k string, n uint8) (uint8, error) { c.Lock() - v, found := c.items[k] + v, found := c.items.get(k) if !found || v.Expired() { c.Unlock() return 0, fmt.Errorf("Item %s not found", k) @@ -746,7 +922,7 @@ func (c *cache) DecrementUint8(k string, n uint8) (uint8, error) { // decremented value is returned. func (c *cache) DecrementUint16(k string, n uint16) (uint16, error) { c.Lock() - v, found := c.items[k] + v, found := c.items.get(k) if !found || v.Expired() { c.Unlock() return 0, fmt.Errorf("Item %s not found", k) @@ -767,7 +943,7 @@ func (c *cache) DecrementUint16(k string, n uint16) (uint16, error) { // decremented value is returned. func (c *cache) DecrementUint32(k string, n uint32) (uint32, error) { c.Lock() - v, found := c.items[k] + v, found := c.items.get(k) if !found || v.Expired() { c.Unlock() return 0, fmt.Errorf("Item %s not found", k) @@ -788,7 +964,7 @@ func (c *cache) DecrementUint32(k string, n uint32) (uint32, error) { // decremented value is returned. func (c *cache) DecrementUint64(k string, n uint64) (uint64, error) { c.Lock() - v, found := c.items[k] + v, found := c.items.get(k) if !found || v.Expired() { c.Unlock() return 0, fmt.Errorf("Item %s not found", k) @@ -809,7 +985,7 @@ func (c *cache) DecrementUint64(k string, n uint64) (uint64, error) { // decremented value is returned. func (c *cache) DecrementFloat32(k string, n float32) (float32, error) { c.Lock() - v, found := c.items[k] + v, found := c.items.get(k) if !found || v.Expired() { c.Unlock() return 0, fmt.Errorf("Item %s not found", k) @@ -830,7 +1006,7 @@ func (c *cache) DecrementFloat32(k string, n float32) (float32, error) { // decremented value is returned. func (c *cache) DecrementFloat64(k string, n float64) (float64, error) { c.Lock() - v, found := c.items[k] + v, found := c.items.get(k) if !found || v.Expired() { c.Unlock() return 0, fmt.Errorf("Item %s not found", k) @@ -854,17 +1030,13 @@ func (c *cache) Delete(k string) { } func (c *cache) delete(k string) { - delete(c.items, k) + c.items.del(k) } // Delete all expired items from the cache. func (c *cache) DeleteExpired() { c.Lock() - for k, v := range c.items { - if v.Expired() { - c.delete(k) - } - } + c.items.delExpirie() c.Unlock() } @@ -878,10 +1050,14 @@ func (c *cache) Save(w io.Writer) (err error) { }() c.RLock() defer c.RUnlock() - for _, v := range c.items { - gob.Register(v.Object) + //.items = new(Item) + itemSlice := c.items.toSlice() + for _, item := range itemSlice { + if item.Object != nil { + gob.Register(item.Object) + } } - err = enc.Encode(&c.items) + err = enc.Encode(itemSlice) return } @@ -904,16 +1080,13 @@ func (c *cache) SaveFile(fname string) error { // keys that already exist (and haven't expired) in the current cache. func (c *cache) Load(r io.Reader) error { dec := gob.NewDecoder(r) - items := map[string]*Item{} - err := dec.Decode(&items) + var itemSlice []Item + err := dec.Decode(&itemSlice) if err == nil { c.Lock() defer c.Unlock() - for k, v := range items { - ov, found := c.items[k] - if !found || ov.Expired() { - c.items[k] = v - } + for _, item := range itemSlice { + c.items.add(item.Key, item.Object, item.Expiration) } } return err @@ -939,7 +1112,7 @@ func (c *cache) LoadFile(fname string) error { // fields of the items should be checked. Note that explicit synchronization // is needed to use a cache and its corresponding Items() return value at // the same time, as the map is shared. -func (c *cache) Items() map[string]*Item { +func (c *cache) Items() *Item { c.RLock() defer c.RUnlock() return c.items @@ -949,7 +1122,7 @@ func (c *cache) Items() map[string]*Item { // expired, but have not yet been cleaned up. Equivalent to len(c.Items()). func (c *cache) ItemCount() int { c.RLock() - n := len(c.items) + n := c.items.len() c.RUnlock() return n } @@ -957,7 +1130,7 @@ func (c *cache) ItemCount() int { // Delete all items from the cache. func (c *cache) Flush() { c.Lock() - c.items = map[string]*Item{} + c.items = new(Item) c.Unlock() } @@ -997,7 +1170,7 @@ func newCache(de time.Duration) *cache { } c := &cache{ defaultExpiration: de, - items: map[string]*Item{}, + items: new(Item), } return c } @@ -1121,7 +1294,7 @@ func newShardedCache(n int, de time.Duration) *shardedCache { for i := 0; i < n; i++ { c := &cache{ defaultExpiration: de, - items: map[string]*Item{}, + items: new(Item), } sc.cs[i] = c } diff --git a/cache_test.go b/cache_test.go index 09980a6..ab1f40b 100644 --- a/cache_test.go +++ b/cache_test.go @@ -1246,7 +1246,6 @@ func testFillAndSerialize(t *testing.T, tc *Cache) { if a.(string) != "a" { t.Error("a is not a") } - b, found := oc.Get("b") if !found { t.Error("b was not found")