From 133b774867fab9b2cc19321bce97b9eca332c682 Mon Sep 17 00:00:00 2001 From: Michael Dresner Date: Thu, 24 Mar 2022 23:41:39 +0300 Subject: [PATCH] Return zero value instead of nil --- cache.go | 47 +++++++++++++++++++++---------------- cache_test.go | 61 ++++++++++++++++++++++++------------------------- ordered.go | 7 +++--- ordered_test.go | 22 ++++++++---------- sharded.go | 2 +- 5 files changed, 72 insertions(+), 67 deletions(-) diff --git a/cache.go b/cache.go index a6d8f01..83d6ed4 100644 --- a/cache.go +++ b/cache.go @@ -115,68 +115,74 @@ func (c *cache[K, V]) Replace(k K, x V, d time.Duration) error { return nil } -// Get an item from the cache. Returns the item or nil, and a bool indicating +// Get an item from the cache. Returns the item or zero value, and a bool indicating // whether the key was found. -func (c *cache[K, V]) Get(k K) (*V, bool) { +func (c *cache[K, V]) Get(k K) (V, bool) { + var zeroValue V + c.mu.RLock() // "Inlining" of get and Expired item, found := c.items[k] if !found { c.mu.RUnlock() - return nil, false + return zeroValue, false } if item.Expiration > 0 { if time.Now().UnixNano() > item.Expiration { c.mu.RUnlock() - return nil, false + return zeroValue, false } } c.mu.RUnlock() - return &item.Object, true + return item.Object, true } // GetWithExpiration returns an item and its expiration time from the cache. -// It returns the item or nil, the expiration time if one is set (if the item +// It returns the item or zero value, the expiration time if one is set (if the item // never expires a zero value for time.Time is returned), and a bool indicating // whether the key was found. -func (c *cache[K, V]) GetWithExpiration(k K) (*V, time.Time, bool) { +func (c *cache[K, V]) GetWithExpiration(k K) (V, time.Time, bool) { + var zeroValue V + c.mu.RLock() // "Inlining" of get and Expired item, found := c.items[k] if !found { c.mu.RUnlock() - return nil, time.Time{}, false + return zeroValue, time.Time{}, false } if item.Expiration > 0 { if time.Now().UnixNano() > item.Expiration { c.mu.RUnlock() - return nil, time.Time{}, false + return zeroValue, time.Time{}, false } // Return the item and the expiration time c.mu.RUnlock() - return &item.Object, time.Unix(0, item.Expiration), true + return item.Object, time.Unix(0, item.Expiration), true } // If expiration <= 0 (i.e. no expiration time set) then return the item // and a zeroed time.Time c.mu.RUnlock() - return &item.Object, time.Time{}, true + return item.Object, time.Time{}, true } -func (c *cache[K, V]) get(k K) (*V, bool) { +func (c *cache[K, V]) get(k K) (V, bool) { + var zeroValue V + item, found := c.items[k] if !found { - return nil, false + return zeroValue, false } // "Inlining" of Expired if item.Expiration > 0 { if time.Now().UnixNano() > item.Expiration { - return nil, false + return zeroValue, false } } - return &item.Object, true + return item.Object, true } // Delete an item from the cache. Does nothing if the key is not in the cache. @@ -185,19 +191,20 @@ func (c *cache[K, V]) Delete(k K) { v, evicted := c.delete(k) c.mu.Unlock() if evicted { - c.onEvicted(k, *v) + c.onEvicted(k, v) } } -func (c *cache[K, V]) delete(k K) (*V, bool) { +func (c *cache[K, V]) delete(k K) (V, bool) { + var zeroValue V if c.onEvicted != nil { if v, found := c.items[k]; found { delete(c.items, k) - return &v.Object, true + return v.Object, true } } delete(c.items, k) - return nil, false + return zeroValue, false } type keyAndValue[K comparable, V any] struct { @@ -215,7 +222,7 @@ func (c *cache[K, V]) DeleteExpired() { if v.Expiration > 0 && now > v.Expiration { ov, evicted := c.delete(k) if evicted { - evictedItems = append(evictedItems, keyAndValue[K, V]{k, *ov}) + evictedItems = append(evictedItems, keyAndValue[K, V]{k, ov}) } } } diff --git a/cache_test.go b/cache_test.go index 84ebefa..3e0b6d3 100644 --- a/cache_test.go +++ b/cache_test.go @@ -19,17 +19,17 @@ func TestCache(t *testing.T) { tc := New[string, int](DefaultExpiration, 0) a, found := tc.Get("a") - if found || a != nil { + if found || a != 0 { t.Error("Getting A found value that shouldn't exist:", a) } b, found := tc.Get("b") - if found || b != nil { + if found || b != 0 { t.Error("Getting B found value that shouldn't exist:", b) } c, found := tc.Get("c") - if found || c != nil { + if found || c != 0 { t.Error("Getting C found value that shouldn't exist:", c) } @@ -39,9 +39,9 @@ func TestCache(t *testing.T) { if !found { t.Error("a was not found while getting a2") } - if x == nil { - t.Error("x for a is nil") - } else if a2 := *x; a2+2 != 3 { + if x == 0 { + t.Error("x for a is zero value") + } else if a2 := x; a2+2 != 3 { t.Error("a2 (which should be 1) plus 2 does not equal 3; value:", a2) } } @@ -100,14 +100,14 @@ func TestNewFrom(t *testing.T) { if !found { t.Fatal("Did not find a") } - if *a != 1 { + if a != 1 { t.Fatal("a is not 1") } b, found := tc.Get("b") if !found { t.Fatal("Did not find b") } - if *b != 2 { + if b != 2 { t.Fatal("b is not 2") } } @@ -163,8 +163,8 @@ func TestDelete(t *testing.T) { if found { t.Error("foo was found, but it should have been deleted") } - if x != nil { - t.Error("x is not nil:", x) + if x != "" { + t.Error("x is not zero value:", x) } } @@ -187,15 +187,15 @@ func TestFlush(t *testing.T) { if found { t.Error("foo was found, but it should have been deleted") } - if x != nil { - t.Error("x is not nil:", x) + if x != "" { + t.Error("x is not zero value:", x) } x, found = tc.Get("baz") if found { t.Error("baz was found, but it should have been deleted") } - if x != nil { - t.Error("x is not nil:", x) + if x != "" { + t.Error("x is not zero value:", x) } } @@ -217,7 +217,7 @@ func TestOnEvicted(t *testing.T) { if !works { t.Error("works bool not true") } - if *x != 4 { + if x != 4 { t.Error("bar was not 4") } } @@ -269,7 +269,7 @@ func testFillAndSerialize(t *testing.T, tc *Cache[string, any]) { if !found { t.Error("a was not found") } - if (*a).(string) != "a" { + if a.(string) != "a" { t.Error("a is not a") } @@ -277,7 +277,7 @@ func testFillAndSerialize(t *testing.T, tc *Cache[string, any]) { if !found { t.Error("b was not found") } - if (*b).(string) != "b" { + if b.(string) != "b" { t.Error("b is not b") } @@ -285,7 +285,7 @@ func testFillAndSerialize(t *testing.T, tc *Cache[string, any]) { if !found { t.Error("c was not found") } - if (*c).(string) != "c" { + if c.(string) != "c" { t.Error("c is not c") } @@ -299,7 +299,7 @@ func testFillAndSerialize(t *testing.T, tc *Cache[string, any]) { if !found { t.Error("*struct was not found") } - if (*s1).(*TestStruct).Num != 1 { + if s1.(*TestStruct).Num != 1 { t.Error("*struct.Num is not 1") } @@ -307,7 +307,7 @@ func testFillAndSerialize(t *testing.T, tc *Cache[string, any]) { if !found { t.Error("[]struct was not found") } - s2r := (*s2).([]TestStruct) + s2r := s2.([]TestStruct) if len(s2r) != 2 { t.Error("Length of s2r is not 2") } @@ -322,7 +322,7 @@ func testFillAndSerialize(t *testing.T, tc *Cache[string, any]) { if !found { t.Error("[]*struct was not found") } - s3r := (*s3).([]*TestStruct) + s3r := s3.([]*TestStruct) if len(s3r) != 2 { t.Error("Length of s3r is not 2") } @@ -337,7 +337,7 @@ func testFillAndSerialize(t *testing.T, tc *Cache[string, any]) { if !found { t.Error("structception was not found") } - s4r := (*s4).(*TestStruct) + s4r := s4.(*TestStruct) if len(s4r.Children) != 2 { t.Error("Length of s4r.Children is not 2") } @@ -371,9 +371,8 @@ func TestFileSerialization(t *testing.T) { if !found { t.Error("a was not found") } - astr := *a - if astr != "aa" { - if astr == "a" { + if a != "aa" { + if a == "a" { t.Error("a was overwritten") } else { t.Error("a is not aa") @@ -383,7 +382,7 @@ func TestFileSerialization(t *testing.T) { if !found { t.Error("b was not found") } - if *b != "b" { + if b != "b" { t.Error("b is not b") } } @@ -672,7 +671,7 @@ func TestGetWithExpiration(t *testing.T) { } if x == nil { t.Error("x for a is nil") - } else if a2 := (*x).(int); a2+2 != 3 { + } else if a2 := x.(int); a2+2 != 3 { t.Error("a2 (which should be 1) plus 2 does not equal 3; value:", a2) } if !expiration.IsZero() { @@ -685,7 +684,7 @@ func TestGetWithExpiration(t *testing.T) { } if x == nil { t.Error("x for b is nil") - } else if b2 := (*x).(string); b2+"B" != "bB" { + } else if b2 := x.(string); b2+"B" != "bB" { t.Error("b2 (which should be b) plus B does not equal bB; value:", b2) } if !expiration.IsZero() { @@ -698,7 +697,7 @@ func TestGetWithExpiration(t *testing.T) { } if x == nil { t.Error("x for c is nil") - } else if c2 := (*x).(float64); c2+1.2 != 4.7 { + } else if c2 := x.(float64); c2+1.2 != 4.7 { t.Error("c2 (which should be 3.5) plus 1.2 does not equal 4.7; value:", c2) } if !expiration.IsZero() { @@ -711,7 +710,7 @@ func TestGetWithExpiration(t *testing.T) { } if x == nil { t.Error("x for d is nil") - } else if d2 := (*x).(int); d2+2 != 3 { + } else if d2 := x.(int); d2+2 != 3 { t.Error("d (which should be 1) plus 2 does not equal 3; value:", d2) } if !expiration.IsZero() { @@ -724,7 +723,7 @@ func TestGetWithExpiration(t *testing.T) { } if x == nil { t.Error("x for e is nil") - } else if e2 := (*x).(int); e2+2 != 3 { + } else if e2 := x.(int); e2+2 != 3 { t.Error("e (which should be 1) plus 2 does not equal 3; value:", e2) } if expiration.UnixNano() != tc.items["e"].Expiration { diff --git a/ordered.go b/ordered.go index ff304b3..ba92683 100644 --- a/ordered.go +++ b/ordered.go @@ -17,18 +17,19 @@ type orderedCache[K comparable, V constraints.Ordered] struct { // Increment an item of type by n. // Returns incremented item or an error if it was not found. -func (c *orderedCache[K, V]) Increment(k K, n V) (*V, error) { +func (c *orderedCache[K, V]) Increment(k K, n V) (V, error) { + var zeroValue V c.mu.Lock() v, found := c.items[k] if !found || v.Expired() { c.mu.Unlock() - return nil, fmt.Errorf("Item %v not found", k) + return zeroValue, fmt.Errorf("Item %v not found", k) } res := v.Object + n v.Object = res c.items[k] = v c.mu.Unlock() - return &res, nil + return res, nil } // Return a new ordered cache with a given default expiration duration and cleanup diff --git a/ordered_test.go b/ordered_test.go index 1784b7c..c77e3fe 100644 --- a/ordered_test.go +++ b/ordered_test.go @@ -9,14 +9,14 @@ func TestIncrementWithInt(t *testing.T) { if err != nil { t.Error("Error incrementing:", err) } - if *n != 3 { + if n != 3 { t.Error("Returned number is not 3:", n) } x, found := tc.Get("tint") if !found { t.Error("tint was not found") } - if *x != 3 { + if x != 3 { t.Error("tint is not 3:", x) } } @@ -28,14 +28,14 @@ func TestIncrementInt8(t *testing.T) { if err != nil { t.Error("Error decrementing:", err) } - if *n != 3 { + if n != 3 { t.Error("Returned number is not 3:", n) } x, found := tc.Get("int8") if !found { t.Error("int8 was not found") } - if *x != 3 { + if x != 3 { t.Error("int8 is not 3:", x) } } @@ -47,13 +47,12 @@ func TestIncrementOverflowInt(t *testing.T) { if err != nil { t.Error("Error incrementing int8:", err) } - if *n != -128 { + if n != -128 { t.Error("Returned number is not -128:", n) } x, _ := tc.Get("int8") - int8 := *x - if int8 != -128 { - t.Error("int8 did not overflow as expected; value:", int8) + if x != -128 { + t.Error("int8 did not overflow as expected; value:", x) } } @@ -65,13 +64,12 @@ func TestIncrementOverflowUint(t *testing.T) { if err != nil { t.Error("Error incrementing int8:", err) } - if *n != 0 { + if n != 0 { t.Error("Returned number is not 0:", n) } x, _ := tc.Get("uint8") - uint8 := *x - if uint8 != 0 { - t.Error("uint8 did not overflow as expected; value:", uint8) + if x != 0 { + t.Error("uint8 did not overflow as expected; value:", x) } } diff --git a/sharded.go b/sharded.go index a8f7d8b..b2758cb 100644 --- a/sharded.go +++ b/sharded.go @@ -80,7 +80,7 @@ func (sc *shardedCache[V]) Replace(k string, x V, d time.Duration) error { return sc.bucket(k).Replace(k, x, d) } -func (sc *shardedCache[V]) Get(k string) (*V, bool) { +func (sc *shardedCache[V]) Get(k string) (V, bool) { return sc.bucket(k).Get(k) }