Return zero value instead of nil

This commit is contained in:
Michael Dresner 2022-03-24 23:41:39 +03:00
parent 0d5b08999a
commit 133b774867
No known key found for this signature in database
GPG Key ID: 039C3C305BAC5C23
5 changed files with 72 additions and 67 deletions

View File

@ -115,68 +115,74 @@ func (c *cache[K, V]) Replace(k K, x V, d time.Duration) error {
return nil
}
// Get an item from the cache. Returns the item or nil, and a bool indicating
// Get an item from the cache. Returns the item or zero value, and a bool indicating
// whether the key was found.
func (c *cache[K, V]) Get(k K) (*V, bool) {
func (c *cache[K, V]) Get(k K) (V, bool) {
var zeroValue V
c.mu.RLock()
// "Inlining" of get and Expired
item, found := c.items[k]
if !found {
c.mu.RUnlock()
return nil, false
return zeroValue, false
}
if item.Expiration > 0 {
if time.Now().UnixNano() > item.Expiration {
c.mu.RUnlock()
return nil, false
return zeroValue, false
}
}
c.mu.RUnlock()
return &item.Object, true
return item.Object, true
}
// GetWithExpiration returns an item and its expiration time from the cache.
// It returns the item or nil, the expiration time if one is set (if the item
// It returns the item or zero value, the expiration time if one is set (if the item
// never expires a zero value for time.Time is returned), and a bool indicating
// whether the key was found.
func (c *cache[K, V]) GetWithExpiration(k K) (*V, time.Time, bool) {
func (c *cache[K, V]) GetWithExpiration(k K) (V, time.Time, bool) {
var zeroValue V
c.mu.RLock()
// "Inlining" of get and Expired
item, found := c.items[k]
if !found {
c.mu.RUnlock()
return nil, time.Time{}, false
return zeroValue, time.Time{}, false
}
if item.Expiration > 0 {
if time.Now().UnixNano() > item.Expiration {
c.mu.RUnlock()
return nil, time.Time{}, false
return zeroValue, time.Time{}, false
}
// Return the item and the expiration time
c.mu.RUnlock()
return &item.Object, time.Unix(0, item.Expiration), true
return item.Object, time.Unix(0, item.Expiration), true
}
// If expiration <= 0 (i.e. no expiration time set) then return the item
// and a zeroed time.Time
c.mu.RUnlock()
return &item.Object, time.Time{}, true
return item.Object, time.Time{}, true
}
func (c *cache[K, V]) get(k K) (*V, bool) {
func (c *cache[K, V]) get(k K) (V, bool) {
var zeroValue V
item, found := c.items[k]
if !found {
return nil, false
return zeroValue, false
}
// "Inlining" of Expired
if item.Expiration > 0 {
if time.Now().UnixNano() > item.Expiration {
return nil, false
return zeroValue, false
}
}
return &item.Object, true
return item.Object, true
}
// Delete an item from the cache. Does nothing if the key is not in the cache.
@ -185,19 +191,20 @@ func (c *cache[K, V]) Delete(k K) {
v, evicted := c.delete(k)
c.mu.Unlock()
if evicted {
c.onEvicted(k, *v)
c.onEvicted(k, v)
}
}
func (c *cache[K, V]) delete(k K) (*V, bool) {
func (c *cache[K, V]) delete(k K) (V, bool) {
var zeroValue V
if c.onEvicted != nil {
if v, found := c.items[k]; found {
delete(c.items, k)
return &v.Object, true
return v.Object, true
}
}
delete(c.items, k)
return nil, false
return zeroValue, false
}
type keyAndValue[K comparable, V any] struct {
@ -215,7 +222,7 @@ func (c *cache[K, V]) DeleteExpired() {
if v.Expiration > 0 && now > v.Expiration {
ov, evicted := c.delete(k)
if evicted {
evictedItems = append(evictedItems, keyAndValue[K, V]{k, *ov})
evictedItems = append(evictedItems, keyAndValue[K, V]{k, ov})
}
}
}

View File

@ -19,17 +19,17 @@ func TestCache(t *testing.T) {
tc := New[string, int](DefaultExpiration, 0)
a, found := tc.Get("a")
if found || a != nil {
if found || a != 0 {
t.Error("Getting A found value that shouldn't exist:", a)
}
b, found := tc.Get("b")
if found || b != nil {
if found || b != 0 {
t.Error("Getting B found value that shouldn't exist:", b)
}
c, found := tc.Get("c")
if found || c != nil {
if found || c != 0 {
t.Error("Getting C found value that shouldn't exist:", c)
}
@ -39,9 +39,9 @@ func TestCache(t *testing.T) {
if !found {
t.Error("a was not found while getting a2")
}
if x == nil {
t.Error("x for a is nil")
} else if a2 := *x; a2+2 != 3 {
if x == 0 {
t.Error("x for a is zero value")
} else if a2 := x; a2+2 != 3 {
t.Error("a2 (which should be 1) plus 2 does not equal 3; value:", a2)
}
}
@ -100,14 +100,14 @@ func TestNewFrom(t *testing.T) {
if !found {
t.Fatal("Did not find a")
}
if *a != 1 {
if a != 1 {
t.Fatal("a is not 1")
}
b, found := tc.Get("b")
if !found {
t.Fatal("Did not find b")
}
if *b != 2 {
if b != 2 {
t.Fatal("b is not 2")
}
}
@ -163,8 +163,8 @@ func TestDelete(t *testing.T) {
if found {
t.Error("foo was found, but it should have been deleted")
}
if x != nil {
t.Error("x is not nil:", x)
if x != "" {
t.Error("x is not zero value:", x)
}
}
@ -187,15 +187,15 @@ func TestFlush(t *testing.T) {
if found {
t.Error("foo was found, but it should have been deleted")
}
if x != nil {
t.Error("x is not nil:", x)
if x != "" {
t.Error("x is not zero value:", x)
}
x, found = tc.Get("baz")
if found {
t.Error("baz was found, but it should have been deleted")
}
if x != nil {
t.Error("x is not nil:", x)
if x != "" {
t.Error("x is not zero value:", x)
}
}
@ -217,7 +217,7 @@ func TestOnEvicted(t *testing.T) {
if !works {
t.Error("works bool not true")
}
if *x != 4 {
if x != 4 {
t.Error("bar was not 4")
}
}
@ -269,7 +269,7 @@ func testFillAndSerialize(t *testing.T, tc *Cache[string, any]) {
if !found {
t.Error("a was not found")
}
if (*a).(string) != "a" {
if a.(string) != "a" {
t.Error("a is not a")
}
@ -277,7 +277,7 @@ func testFillAndSerialize(t *testing.T, tc *Cache[string, any]) {
if !found {
t.Error("b was not found")
}
if (*b).(string) != "b" {
if b.(string) != "b" {
t.Error("b is not b")
}
@ -285,7 +285,7 @@ func testFillAndSerialize(t *testing.T, tc *Cache[string, any]) {
if !found {
t.Error("c was not found")
}
if (*c).(string) != "c" {
if c.(string) != "c" {
t.Error("c is not c")
}
@ -299,7 +299,7 @@ func testFillAndSerialize(t *testing.T, tc *Cache[string, any]) {
if !found {
t.Error("*struct was not found")
}
if (*s1).(*TestStruct).Num != 1 {
if s1.(*TestStruct).Num != 1 {
t.Error("*struct.Num is not 1")
}
@ -307,7 +307,7 @@ func testFillAndSerialize(t *testing.T, tc *Cache[string, any]) {
if !found {
t.Error("[]struct was not found")
}
s2r := (*s2).([]TestStruct)
s2r := s2.([]TestStruct)
if len(s2r) != 2 {
t.Error("Length of s2r is not 2")
}
@ -322,7 +322,7 @@ func testFillAndSerialize(t *testing.T, tc *Cache[string, any]) {
if !found {
t.Error("[]*struct was not found")
}
s3r := (*s3).([]*TestStruct)
s3r := s3.([]*TestStruct)
if len(s3r) != 2 {
t.Error("Length of s3r is not 2")
}
@ -337,7 +337,7 @@ func testFillAndSerialize(t *testing.T, tc *Cache[string, any]) {
if !found {
t.Error("structception was not found")
}
s4r := (*s4).(*TestStruct)
s4r := s4.(*TestStruct)
if len(s4r.Children) != 2 {
t.Error("Length of s4r.Children is not 2")
}
@ -371,9 +371,8 @@ func TestFileSerialization(t *testing.T) {
if !found {
t.Error("a was not found")
}
astr := *a
if astr != "aa" {
if astr == "a" {
if a != "aa" {
if a == "a" {
t.Error("a was overwritten")
} else {
t.Error("a is not aa")
@ -383,7 +382,7 @@ func TestFileSerialization(t *testing.T) {
if !found {
t.Error("b was not found")
}
if *b != "b" {
if b != "b" {
t.Error("b is not b")
}
}
@ -672,7 +671,7 @@ func TestGetWithExpiration(t *testing.T) {
}
if x == nil {
t.Error("x for a is nil")
} else if a2 := (*x).(int); a2+2 != 3 {
} else if a2 := x.(int); a2+2 != 3 {
t.Error("a2 (which should be 1) plus 2 does not equal 3; value:", a2)
}
if !expiration.IsZero() {
@ -685,7 +684,7 @@ func TestGetWithExpiration(t *testing.T) {
}
if x == nil {
t.Error("x for b is nil")
} else if b2 := (*x).(string); b2+"B" != "bB" {
} else if b2 := x.(string); b2+"B" != "bB" {
t.Error("b2 (which should be b) plus B does not equal bB; value:", b2)
}
if !expiration.IsZero() {
@ -698,7 +697,7 @@ func TestGetWithExpiration(t *testing.T) {
}
if x == nil {
t.Error("x for c is nil")
} else if c2 := (*x).(float64); c2+1.2 != 4.7 {
} else if c2 := x.(float64); c2+1.2 != 4.7 {
t.Error("c2 (which should be 3.5) plus 1.2 does not equal 4.7; value:", c2)
}
if !expiration.IsZero() {
@ -711,7 +710,7 @@ func TestGetWithExpiration(t *testing.T) {
}
if x == nil {
t.Error("x for d is nil")
} else if d2 := (*x).(int); d2+2 != 3 {
} else if d2 := x.(int); d2+2 != 3 {
t.Error("d (which should be 1) plus 2 does not equal 3; value:", d2)
}
if !expiration.IsZero() {
@ -724,7 +723,7 @@ func TestGetWithExpiration(t *testing.T) {
}
if x == nil {
t.Error("x for e is nil")
} else if e2 := (*x).(int); e2+2 != 3 {
} else if e2 := x.(int); e2+2 != 3 {
t.Error("e (which should be 1) plus 2 does not equal 3; value:", e2)
}
if expiration.UnixNano() != tc.items["e"].Expiration {

View File

@ -17,18 +17,19 @@ type orderedCache[K comparable, V constraints.Ordered] struct {
// Increment an item of type by n.
// Returns incremented item or an error if it was not found.
func (c *orderedCache[K, V]) Increment(k K, n V) (*V, error) {
func (c *orderedCache[K, V]) Increment(k K, n V) (V, error) {
var zeroValue V
c.mu.Lock()
v, found := c.items[k]
if !found || v.Expired() {
c.mu.Unlock()
return nil, fmt.Errorf("Item %v not found", k)
return zeroValue, fmt.Errorf("Item %v not found", k)
}
res := v.Object + n
v.Object = res
c.items[k] = v
c.mu.Unlock()
return &res, nil
return res, nil
}
// Return a new ordered cache with a given default expiration duration and cleanup

View File

@ -9,14 +9,14 @@ func TestIncrementWithInt(t *testing.T) {
if err != nil {
t.Error("Error incrementing:", err)
}
if *n != 3 {
if n != 3 {
t.Error("Returned number is not 3:", n)
}
x, found := tc.Get("tint")
if !found {
t.Error("tint was not found")
}
if *x != 3 {
if x != 3 {
t.Error("tint is not 3:", x)
}
}
@ -28,14 +28,14 @@ func TestIncrementInt8(t *testing.T) {
if err != nil {
t.Error("Error decrementing:", err)
}
if *n != 3 {
if n != 3 {
t.Error("Returned number is not 3:", n)
}
x, found := tc.Get("int8")
if !found {
t.Error("int8 was not found")
}
if *x != 3 {
if x != 3 {
t.Error("int8 is not 3:", x)
}
}
@ -47,13 +47,12 @@ func TestIncrementOverflowInt(t *testing.T) {
if err != nil {
t.Error("Error incrementing int8:", err)
}
if *n != -128 {
if n != -128 {
t.Error("Returned number is not -128:", n)
}
x, _ := tc.Get("int8")
int8 := *x
if int8 != -128 {
t.Error("int8 did not overflow as expected; value:", int8)
if x != -128 {
t.Error("int8 did not overflow as expected; value:", x)
}
}
@ -65,13 +64,12 @@ func TestIncrementOverflowUint(t *testing.T) {
if err != nil {
t.Error("Error incrementing int8:", err)
}
if *n != 0 {
if n != 0 {
t.Error("Returned number is not 0:", n)
}
x, _ := tc.Get("uint8")
uint8 := *x
if uint8 != 0 {
t.Error("uint8 did not overflow as expected; value:", uint8)
if x != 0 {
t.Error("uint8 did not overflow as expected; value:", x)
}
}

View File

@ -80,7 +80,7 @@ func (sc *shardedCache[V]) Replace(k string, x V, d time.Duration) error {
return sc.bucket(k).Replace(k, x, d)
}
func (sc *shardedCache[V]) Get(k string) (*V, bool) {
func (sc *shardedCache[V]) Get(k string) (V, bool) {
return sc.bucket(k).Get(k)
}