Return zero value instead of nil

This commit is contained in:
Michael Dresner 2022-03-24 23:41:39 +03:00
parent 0d5b08999a
commit 133b774867
No known key found for this signature in database
GPG Key ID: 039C3C305BAC5C23
5 changed files with 72 additions and 67 deletions

View File

@ -115,68 +115,74 @@ func (c *cache[K, V]) Replace(k K, x V, d time.Duration) error {
return nil return nil
} }
// Get an item from the cache. Returns the item or nil, and a bool indicating // Get an item from the cache. Returns the item or zero value, and a bool indicating
// whether the key was found. // whether the key was found.
func (c *cache[K, V]) Get(k K) (*V, bool) { func (c *cache[K, V]) Get(k K) (V, bool) {
var zeroValue V
c.mu.RLock() c.mu.RLock()
// "Inlining" of get and Expired // "Inlining" of get and Expired
item, found := c.items[k] item, found := c.items[k]
if !found { if !found {
c.mu.RUnlock() c.mu.RUnlock()
return nil, false return zeroValue, false
} }
if item.Expiration > 0 { if item.Expiration > 0 {
if time.Now().UnixNano() > item.Expiration { if time.Now().UnixNano() > item.Expiration {
c.mu.RUnlock() c.mu.RUnlock()
return nil, false return zeroValue, false
} }
} }
c.mu.RUnlock() c.mu.RUnlock()
return &item.Object, true return item.Object, true
} }
// GetWithExpiration returns an item and its expiration time from the cache. // GetWithExpiration returns an item and its expiration time from the cache.
// It returns the item or nil, the expiration time if one is set (if the item // It returns the item or zero value, the expiration time if one is set (if the item
// never expires a zero value for time.Time is returned), and a bool indicating // never expires a zero value for time.Time is returned), and a bool indicating
// whether the key was found. // whether the key was found.
func (c *cache[K, V]) GetWithExpiration(k K) (*V, time.Time, bool) { func (c *cache[K, V]) GetWithExpiration(k K) (V, time.Time, bool) {
var zeroValue V
c.mu.RLock() c.mu.RLock()
// "Inlining" of get and Expired // "Inlining" of get and Expired
item, found := c.items[k] item, found := c.items[k]
if !found { if !found {
c.mu.RUnlock() c.mu.RUnlock()
return nil, time.Time{}, false return zeroValue, time.Time{}, false
} }
if item.Expiration > 0 { if item.Expiration > 0 {
if time.Now().UnixNano() > item.Expiration { if time.Now().UnixNano() > item.Expiration {
c.mu.RUnlock() c.mu.RUnlock()
return nil, time.Time{}, false return zeroValue, time.Time{}, false
} }
// Return the item and the expiration time // Return the item and the expiration time
c.mu.RUnlock() c.mu.RUnlock()
return &item.Object, time.Unix(0, item.Expiration), true return item.Object, time.Unix(0, item.Expiration), true
} }
// If expiration <= 0 (i.e. no expiration time set) then return the item // If expiration <= 0 (i.e. no expiration time set) then return the item
// and a zeroed time.Time // and a zeroed time.Time
c.mu.RUnlock() c.mu.RUnlock()
return &item.Object, time.Time{}, true return item.Object, time.Time{}, true
} }
func (c *cache[K, V]) get(k K) (*V, bool) { func (c *cache[K, V]) get(k K) (V, bool) {
var zeroValue V
item, found := c.items[k] item, found := c.items[k]
if !found { if !found {
return nil, false return zeroValue, false
} }
// "Inlining" of Expired // "Inlining" of Expired
if item.Expiration > 0 { if item.Expiration > 0 {
if time.Now().UnixNano() > item.Expiration { if time.Now().UnixNano() > item.Expiration {
return nil, false return zeroValue, false
} }
} }
return &item.Object, true return item.Object, true
} }
// Delete an item from the cache. Does nothing if the key is not in the cache. // Delete an item from the cache. Does nothing if the key is not in the cache.
@ -185,19 +191,20 @@ func (c *cache[K, V]) Delete(k K) {
v, evicted := c.delete(k) v, evicted := c.delete(k)
c.mu.Unlock() c.mu.Unlock()
if evicted { if evicted {
c.onEvicted(k, *v) c.onEvicted(k, v)
} }
} }
func (c *cache[K, V]) delete(k K) (*V, bool) { func (c *cache[K, V]) delete(k K) (V, bool) {
var zeroValue V
if c.onEvicted != nil { if c.onEvicted != nil {
if v, found := c.items[k]; found { if v, found := c.items[k]; found {
delete(c.items, k) delete(c.items, k)
return &v.Object, true return v.Object, true
} }
} }
delete(c.items, k) delete(c.items, k)
return nil, false return zeroValue, false
} }
type keyAndValue[K comparable, V any] struct { type keyAndValue[K comparable, V any] struct {
@ -215,7 +222,7 @@ func (c *cache[K, V]) DeleteExpired() {
if v.Expiration > 0 && now > v.Expiration { if v.Expiration > 0 && now > v.Expiration {
ov, evicted := c.delete(k) ov, evicted := c.delete(k)
if evicted { if evicted {
evictedItems = append(evictedItems, keyAndValue[K, V]{k, *ov}) evictedItems = append(evictedItems, keyAndValue[K, V]{k, ov})
} }
} }
} }

View File

@ -19,17 +19,17 @@ func TestCache(t *testing.T) {
tc := New[string, int](DefaultExpiration, 0) tc := New[string, int](DefaultExpiration, 0)
a, found := tc.Get("a") a, found := tc.Get("a")
if found || a != nil { if found || a != 0 {
t.Error("Getting A found value that shouldn't exist:", a) t.Error("Getting A found value that shouldn't exist:", a)
} }
b, found := tc.Get("b") b, found := tc.Get("b")
if found || b != nil { if found || b != 0 {
t.Error("Getting B found value that shouldn't exist:", b) t.Error("Getting B found value that shouldn't exist:", b)
} }
c, found := tc.Get("c") c, found := tc.Get("c")
if found || c != nil { if found || c != 0 {
t.Error("Getting C found value that shouldn't exist:", c) t.Error("Getting C found value that shouldn't exist:", c)
} }
@ -39,9 +39,9 @@ func TestCache(t *testing.T) {
if !found { if !found {
t.Error("a was not found while getting a2") t.Error("a was not found while getting a2")
} }
if x == nil { if x == 0 {
t.Error("x for a is nil") t.Error("x for a is zero value")
} else if a2 := *x; a2+2 != 3 { } else if a2 := x; a2+2 != 3 {
t.Error("a2 (which should be 1) plus 2 does not equal 3; value:", a2) t.Error("a2 (which should be 1) plus 2 does not equal 3; value:", a2)
} }
} }
@ -100,14 +100,14 @@ func TestNewFrom(t *testing.T) {
if !found { if !found {
t.Fatal("Did not find a") t.Fatal("Did not find a")
} }
if *a != 1 { if a != 1 {
t.Fatal("a is not 1") t.Fatal("a is not 1")
} }
b, found := tc.Get("b") b, found := tc.Get("b")
if !found { if !found {
t.Fatal("Did not find b") t.Fatal("Did not find b")
} }
if *b != 2 { if b != 2 {
t.Fatal("b is not 2") t.Fatal("b is not 2")
} }
} }
@ -163,8 +163,8 @@ func TestDelete(t *testing.T) {
if found { if found {
t.Error("foo was found, but it should have been deleted") t.Error("foo was found, but it should have been deleted")
} }
if x != nil { if x != "" {
t.Error("x is not nil:", x) t.Error("x is not zero value:", x)
} }
} }
@ -187,15 +187,15 @@ func TestFlush(t *testing.T) {
if found { if found {
t.Error("foo was found, but it should have been deleted") t.Error("foo was found, but it should have been deleted")
} }
if x != nil { if x != "" {
t.Error("x is not nil:", x) t.Error("x is not zero value:", x)
} }
x, found = tc.Get("baz") x, found = tc.Get("baz")
if found { if found {
t.Error("baz was found, but it should have been deleted") t.Error("baz was found, but it should have been deleted")
} }
if x != nil { if x != "" {
t.Error("x is not nil:", x) t.Error("x is not zero value:", x)
} }
} }
@ -217,7 +217,7 @@ func TestOnEvicted(t *testing.T) {
if !works { if !works {
t.Error("works bool not true") t.Error("works bool not true")
} }
if *x != 4 { if x != 4 {
t.Error("bar was not 4") t.Error("bar was not 4")
} }
} }
@ -269,7 +269,7 @@ func testFillAndSerialize(t *testing.T, tc *Cache[string, any]) {
if !found { if !found {
t.Error("a was not found") t.Error("a was not found")
} }
if (*a).(string) != "a" { if a.(string) != "a" {
t.Error("a is not a") t.Error("a is not a")
} }
@ -277,7 +277,7 @@ func testFillAndSerialize(t *testing.T, tc *Cache[string, any]) {
if !found { if !found {
t.Error("b was not found") t.Error("b was not found")
} }
if (*b).(string) != "b" { if b.(string) != "b" {
t.Error("b is not b") t.Error("b is not b")
} }
@ -285,7 +285,7 @@ func testFillAndSerialize(t *testing.T, tc *Cache[string, any]) {
if !found { if !found {
t.Error("c was not found") t.Error("c was not found")
} }
if (*c).(string) != "c" { if c.(string) != "c" {
t.Error("c is not c") t.Error("c is not c")
} }
@ -299,7 +299,7 @@ func testFillAndSerialize(t *testing.T, tc *Cache[string, any]) {
if !found { if !found {
t.Error("*struct was not found") t.Error("*struct was not found")
} }
if (*s1).(*TestStruct).Num != 1 { if s1.(*TestStruct).Num != 1 {
t.Error("*struct.Num is not 1") t.Error("*struct.Num is not 1")
} }
@ -307,7 +307,7 @@ func testFillAndSerialize(t *testing.T, tc *Cache[string, any]) {
if !found { if !found {
t.Error("[]struct was not found") t.Error("[]struct was not found")
} }
s2r := (*s2).([]TestStruct) s2r := s2.([]TestStruct)
if len(s2r) != 2 { if len(s2r) != 2 {
t.Error("Length of s2r is not 2") t.Error("Length of s2r is not 2")
} }
@ -322,7 +322,7 @@ func testFillAndSerialize(t *testing.T, tc *Cache[string, any]) {
if !found { if !found {
t.Error("[]*struct was not found") t.Error("[]*struct was not found")
} }
s3r := (*s3).([]*TestStruct) s3r := s3.([]*TestStruct)
if len(s3r) != 2 { if len(s3r) != 2 {
t.Error("Length of s3r is not 2") t.Error("Length of s3r is not 2")
} }
@ -337,7 +337,7 @@ func testFillAndSerialize(t *testing.T, tc *Cache[string, any]) {
if !found { if !found {
t.Error("structception was not found") t.Error("structception was not found")
} }
s4r := (*s4).(*TestStruct) s4r := s4.(*TestStruct)
if len(s4r.Children) != 2 { if len(s4r.Children) != 2 {
t.Error("Length of s4r.Children is not 2") t.Error("Length of s4r.Children is not 2")
} }
@ -371,9 +371,8 @@ func TestFileSerialization(t *testing.T) {
if !found { if !found {
t.Error("a was not found") t.Error("a was not found")
} }
astr := *a if a != "aa" {
if astr != "aa" { if a == "a" {
if astr == "a" {
t.Error("a was overwritten") t.Error("a was overwritten")
} else { } else {
t.Error("a is not aa") t.Error("a is not aa")
@ -383,7 +382,7 @@ func TestFileSerialization(t *testing.T) {
if !found { if !found {
t.Error("b was not found") t.Error("b was not found")
} }
if *b != "b" { if b != "b" {
t.Error("b is not b") t.Error("b is not b")
} }
} }
@ -672,7 +671,7 @@ func TestGetWithExpiration(t *testing.T) {
} }
if x == nil { if x == nil {
t.Error("x for a is nil") t.Error("x for a is nil")
} else if a2 := (*x).(int); a2+2 != 3 { } else if a2 := x.(int); a2+2 != 3 {
t.Error("a2 (which should be 1) plus 2 does not equal 3; value:", a2) t.Error("a2 (which should be 1) plus 2 does not equal 3; value:", a2)
} }
if !expiration.IsZero() { if !expiration.IsZero() {
@ -685,7 +684,7 @@ func TestGetWithExpiration(t *testing.T) {
} }
if x == nil { if x == nil {
t.Error("x for b is nil") t.Error("x for b is nil")
} else if b2 := (*x).(string); b2+"B" != "bB" { } else if b2 := x.(string); b2+"B" != "bB" {
t.Error("b2 (which should be b) plus B does not equal bB; value:", b2) t.Error("b2 (which should be b) plus B does not equal bB; value:", b2)
} }
if !expiration.IsZero() { if !expiration.IsZero() {
@ -698,7 +697,7 @@ func TestGetWithExpiration(t *testing.T) {
} }
if x == nil { if x == nil {
t.Error("x for c is nil") t.Error("x for c is nil")
} else if c2 := (*x).(float64); c2+1.2 != 4.7 { } else if c2 := x.(float64); c2+1.2 != 4.7 {
t.Error("c2 (which should be 3.5) plus 1.2 does not equal 4.7; value:", c2) t.Error("c2 (which should be 3.5) plus 1.2 does not equal 4.7; value:", c2)
} }
if !expiration.IsZero() { if !expiration.IsZero() {
@ -711,7 +710,7 @@ func TestGetWithExpiration(t *testing.T) {
} }
if x == nil { if x == nil {
t.Error("x for d is nil") t.Error("x for d is nil")
} else if d2 := (*x).(int); d2+2 != 3 { } else if d2 := x.(int); d2+2 != 3 {
t.Error("d (which should be 1) plus 2 does not equal 3; value:", d2) t.Error("d (which should be 1) plus 2 does not equal 3; value:", d2)
} }
if !expiration.IsZero() { if !expiration.IsZero() {
@ -724,7 +723,7 @@ func TestGetWithExpiration(t *testing.T) {
} }
if x == nil { if x == nil {
t.Error("x for e is nil") t.Error("x for e is nil")
} else if e2 := (*x).(int); e2+2 != 3 { } else if e2 := x.(int); e2+2 != 3 {
t.Error("e (which should be 1) plus 2 does not equal 3; value:", e2) t.Error("e (which should be 1) plus 2 does not equal 3; value:", e2)
} }
if expiration.UnixNano() != tc.items["e"].Expiration { if expiration.UnixNano() != tc.items["e"].Expiration {

View File

@ -17,18 +17,19 @@ type orderedCache[K comparable, V constraints.Ordered] struct {
// Increment an item of type by n. // Increment an item of type by n.
// Returns incremented item or an error if it was not found. // Returns incremented item or an error if it was not found.
func (c *orderedCache[K, V]) Increment(k K, n V) (*V, error) { func (c *orderedCache[K, V]) Increment(k K, n V) (V, error) {
var zeroValue V
c.mu.Lock() c.mu.Lock()
v, found := c.items[k] v, found := c.items[k]
if !found || v.Expired() { if !found || v.Expired() {
c.mu.Unlock() c.mu.Unlock()
return nil, fmt.Errorf("Item %v not found", k) return zeroValue, fmt.Errorf("Item %v not found", k)
} }
res := v.Object + n res := v.Object + n
v.Object = res v.Object = res
c.items[k] = v c.items[k] = v
c.mu.Unlock() c.mu.Unlock()
return &res, nil return res, nil
} }
// Return a new ordered cache with a given default expiration duration and cleanup // Return a new ordered cache with a given default expiration duration and cleanup

View File

@ -9,14 +9,14 @@ func TestIncrementWithInt(t *testing.T) {
if err != nil { if err != nil {
t.Error("Error incrementing:", err) t.Error("Error incrementing:", err)
} }
if *n != 3 { if n != 3 {
t.Error("Returned number is not 3:", n) t.Error("Returned number is not 3:", n)
} }
x, found := tc.Get("tint") x, found := tc.Get("tint")
if !found { if !found {
t.Error("tint was not found") t.Error("tint was not found")
} }
if *x != 3 { if x != 3 {
t.Error("tint is not 3:", x) t.Error("tint is not 3:", x)
} }
} }
@ -28,14 +28,14 @@ func TestIncrementInt8(t *testing.T) {
if err != nil { if err != nil {
t.Error("Error decrementing:", err) t.Error("Error decrementing:", err)
} }
if *n != 3 { if n != 3 {
t.Error("Returned number is not 3:", n) t.Error("Returned number is not 3:", n)
} }
x, found := tc.Get("int8") x, found := tc.Get("int8")
if !found { if !found {
t.Error("int8 was not found") t.Error("int8 was not found")
} }
if *x != 3 { if x != 3 {
t.Error("int8 is not 3:", x) t.Error("int8 is not 3:", x)
} }
} }
@ -47,13 +47,12 @@ func TestIncrementOverflowInt(t *testing.T) {
if err != nil { if err != nil {
t.Error("Error incrementing int8:", err) t.Error("Error incrementing int8:", err)
} }
if *n != -128 { if n != -128 {
t.Error("Returned number is not -128:", n) t.Error("Returned number is not -128:", n)
} }
x, _ := tc.Get("int8") x, _ := tc.Get("int8")
int8 := *x if x != -128 {
if int8 != -128 { t.Error("int8 did not overflow as expected; value:", x)
t.Error("int8 did not overflow as expected; value:", int8)
} }
} }
@ -65,13 +64,12 @@ func TestIncrementOverflowUint(t *testing.T) {
if err != nil { if err != nil {
t.Error("Error incrementing int8:", err) t.Error("Error incrementing int8:", err)
} }
if *n != 0 { if n != 0 {
t.Error("Returned number is not 0:", n) t.Error("Returned number is not 0:", n)
} }
x, _ := tc.Get("uint8") x, _ := tc.Get("uint8")
uint8 := *x if x != 0 {
if uint8 != 0 { t.Error("uint8 did not overflow as expected; value:", x)
t.Error("uint8 did not overflow as expected; value:", uint8)
} }
} }

View File

@ -80,7 +80,7 @@ func (sc *shardedCache[V]) Replace(k string, x V, d time.Duration) error {
return sc.bucket(k).Replace(k, x, d) return sc.bucket(k).Replace(k, x, d)
} }
func (sc *shardedCache[V]) Get(k string) (*V, bool) { func (sc *shardedCache[V]) Get(k string) (V, bool) {
return sc.bucket(k).Get(k) return sc.bucket(k).Get(k)
} }