diff --git a/cache.go b/cache.go index db88d2f..be921ac 100644 --- a/cache.go +++ b/cache.go @@ -165,6 +165,38 @@ func (c *cache) GetWithExpiration(k string) (interface{}, time.Time, bool) { return item.Object, time.Time{}, true } +// GetWithExpirationUpdate returns item and updates its cache expiration time +// It returns the item or nil, the expiration time if one is set (if the item +// never expires a zero value for time.Time is returned), and a bool indicating +// whether the key was found. +func (c *cache) GetWithExpirationUpdate(k string, d time.Duration) (interface{}, bool) { + c.mu.RLock() + item, found := c.items[k] + if !found { + c.mu.RUnlock() + return nil, false + } + if item.Expiration > 0 { + if time.Now().UnixNano() > item.Expiration { + c.mu.RUnlock() + return nil, false + } + } + c.mu.RUnlock() + + c.mu.Lock() + if d == DefaultExpiration { + d = c.defaultExpiration + } + if d > 0 { + item.Expiration = time.Now().Add(d).UnixNano() + } + c.items[k] = item + c.mu.Unlock() + + return item.Object, true +} + func (c *cache) get(k string) (interface{}, bool) { item, found := c.items[k] if !found { diff --git a/cache_test.go b/cache_test.go index cb80b38..ed2315d 100644 --- a/cache_test.go +++ b/cache_test.go @@ -1769,3 +1769,28 @@ func TestGetWithExpiration(t *testing.T) { t.Error("expiration for e is in the past") } } + +func TestGetWithExpirationUpdate(t *testing.T) { + var found bool + + tc := New(50*time.Millisecond, 1*time.Millisecond) + tc.Set("a", 1, DefaultExpiration) + + <-time.After(25 * time.Millisecond) + _, found = tc.GetWithExpirationUpdate("a", DefaultExpiration) + if !found { + t.Error("item `a` not expired yet") + } + + <-time.After(25 * time.Millisecond) + _, found = tc.Get("a") + if !found { + t.Error("item `a` not expired yet") + } + + <-time.After(30 * time.Millisecond) + _, found = tc.Get("a") + if found { + t.Error("Found `a` when it should have been automatically deleted") + } +} \ No newline at end of file