diff --git a/README.md b/README.md index c5789cc..c4172db 100644 --- a/README.md +++ b/README.md @@ -75,6 +75,12 @@ func main() { foo := x.(*MyStruct) // ... } + + // Set new value if the item is missing or expired + item, found := c.SetNx("foo", cache.DefaultExpiration, func(s string) (interface{}, error) { + // return value to set if cache miss + return "bar", nil + }) } ``` diff --git a/cache.go b/cache.go index db88d2f..7d6ff53 100644 --- a/cache.go +++ b/cache.go @@ -135,6 +135,25 @@ func (c *cache) Get(k string) (interface{}, bool) { return item.Object, true } +func (c *cache) SetNx(k string, expiry time.Duration, funcToCall func(string) (interface{}, error)) (interface{}, +bool) { + c.mu.RLock() + item, found := c.items[k] + if !found || (item.Expiration > 0 && time.Now().UnixNano() > item.Expiration){ + val, err := funcToCall(k) + if err != nil { + c.mu.RUnlock() + return nil, false + } + + c.set(k, val, expiry) + c.mu.RUnlock() + return val, true + } + c.mu.RUnlock() + return item.Object, true +} + // GetWithExpiration returns an item and its expiration time from the cache. // It returns the item or nil, the expiration time if one is set (if the item // never expires a zero value for time.Time is returned), and a bool indicating diff --git a/cache_test.go b/cache_test.go index 47a3d53..1ca4c09 100644 --- a/cache_test.go +++ b/cache_test.go @@ -8,6 +8,8 @@ import ( "sync" "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/pkg/errors" ) type TestStruct struct { @@ -1769,3 +1771,51 @@ func TestGetWithExpiration(t *testing.T) { t.Error("expiration for e is in the past") } } + +func TestSetNx(t *testing.T) { + tc := New(DefaultExpiration, 0) + val1, isPresent := tc.SetNx("key1", 10 * time.Second, func(s string) (interface{}, error) { + return "a for apple", nil + }) + + assert.True(t, isPresent) + assert.Equal(t, "a for apple", val1) + + val2, isPresent := tc.SetNx("key2", 10 * time.Second, func(s string) (interface{}, error) { + return 10, nil + }) + + assert.True(t, isPresent) + assert.Equal(t, 10, val2) + + tc.set("existingKey", []int{1,2,3}, 10 * time.Second) + val3, isPresent := tc.SetNx("existingKey", 10 * time.Second, func(s string) (interface{}, error) { + return []int{4}, nil + }) + + assert.True(t, isPresent) + assert.Equal(t, []int{1,2,3}, val3) + + tc.set("expiredKey", "expiredValue", 10 * time.Millisecond) + time.Sleep(20 * time.Millisecond) + val4, isPresent := tc.SetNx("expiredKey", 10 * time.Second, func(s string) (interface{}, error) { + return "newValue", nil + }) + + assert.True(t, isPresent) + assert.Equal(t, "newValue", val4) + + val5, isPresent := tc.SetNx("errorKey", 10 * time.Second, func(s string) (interface{}, error) { + return "doesn't matter return value", errors.New("some error") + }) + + assert.False(t, isPresent) + assert.Nil(t, val5) + + val6, isPresent := tc.SetNx("errorKey", 10 * time.Second, func(s string) (interface{}, error) { + return nil, nil + }) + + assert.True(t, isPresent) + assert.Nil(t, val6) +}