diff --git a/cache/cache.go b/cache/cache.go index a57f46a..b9fd8ad 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -1,6 +1,9 @@ package libpack_cache import ( + "bytes" + "compress/gzip" + "io" "sync" "time" ) @@ -46,15 +49,20 @@ func (c *Cache) Set(key string, value []byte, ttl time.Duration) { defer c.Unlock() expiresAt := time.Now().Add(ttl) - // Get a byte slice from the pool and ensure it's properly sized. - b := c.bytePool.Get().([]byte) - if cap(b) < len(value) { - b = make([]byte, len(value)) - } else { - b = b[:len(value)] + compressedValue, err := c.compress(value) + if err != nil { + return } - copy(b, value) + // Get a byte slice from the pool and ensure it's properly sized. + b := c.bytePool.Get().([]byte) + if cap(b) < len(compressedValue) { + b = make([]byte, len(compressedValue)) + } else { + b = b[:len(compressedValue)] + } + + copy(b, compressedValue) entry := CacheEntry{ Value: b, @@ -71,10 +79,12 @@ func (c *Cache) Get(key string) ([]byte, bool) { if !ok || entry.(CacheEntry).ExpiresAt.Before(time.Now()) { return nil, false } + compressedValue := entry.(CacheEntry).Value + value, err := c.decompress(compressedValue) + if err != nil { + return nil, false + } - // Copy the value from the byte slice. - value := make([]byte, len(entry.(CacheEntry).Value)) - copy(value, entry.(CacheEntry).Value) return value, true } @@ -110,3 +120,26 @@ func (c *Cache) CleanExpiredEntries() { return true }) } + +func (c *Cache) compress(data []byte) ([]byte, error) { + var buf bytes.Buffer + w := gzip.NewWriter(&buf) + _, err := w.Write(data) + if err != nil { + return nil, err + } + err = w.Close() + if err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func (c *Cache) decompress(data []byte) ([]byte, error) { + r, err := gzip.NewReader(bytes.NewBuffer(data)) + if err != nil { + return nil, err + } + defer r.Close() + return io.ReadAll(r) +} diff --git a/cache/cache_test.go b/cache/cache_test.go new file mode 100644 index 0000000..b0e2cc7 --- /dev/null +++ b/cache/cache_test.go @@ -0,0 +1,112 @@ +package libpack_cache + +import ( + "testing" + "time" + + "github.com/stretchr/testify/suite" +) + +type CacheTestSuite struct { + suite.Suite +} + +func (suite *CacheTestSuite) SetupTest() { +} + +func TestCachingTestSuite(t *testing.T) { + suite.Run(t, new(CacheTestSuite)) +} + +func (suite *CacheTestSuite) Test_New() { + suite.T().Run("should return a new cache", func(t *testing.T) { + cache := New(2 * time.Second) + suite.NotNil(cache) + }) +} + +func (suite *CacheTestSuite) Test_CacheUse() { + cache := New(30 * time.Second) + tests := []struct { + name string + cache_value string + }{ + { + name: "test1", + cache_value: "test1-123", + }, + { + name: "test2", + cache_value: "test2-123", + }, + } + for _, tt := range tests { + suite.T().Run(tt.name, func(t *testing.T) { + cache.Set(tt.name, []byte(tt.name), 5*time.Second) + c, ok := cache.Get(tt.name) + suite.Equal(true, ok) + suite.Equal(tt.name, string(c)) + }) + } +} + +func (suite *CacheTestSuite) Test_CacheDelete() { + cache := New(30 * time.Second) + tests := []struct { + name string + cache_value string + }{ + { + name: "test1", + cache_value: "test1-123", + }, + { + name: "test2", + cache_value: "test2-123", + }, + } + for _, tt := range tests { + suite.T().Run(tt.name, func(t *testing.T) { + cache.Set(tt.name, []byte(tt.name), 5*time.Second) + c, ok := cache.Get(tt.name) + suite.Equal(true, ok) + suite.Equal(tt.name, string(c)) + cache.Delete(tt.name) + c, ok = cache.Get(tt.name) + suite.Equal(false, ok) + suite.Equal("", string(c)) + }) + } +} + +func (suite *CacheTestSuite) Test_CacheExpire() { + cache := New(30 * time.Second) + tests := []struct { + name string + cache_value string + ttl time.Duration + }{ + { + name: "test1", + cache_value: "test1-123", + ttl: 2 * time.Second, + }, + { + name: "test2", + cache_value: "test2-123", + ttl: 5 * time.Second, + }, + } + for _, tt := range tests { + suite.T().Run(tt.name, func(t *testing.T) { + cache.Set(tt.name, []byte(tt.name), tt.ttl) + c, ok := cache.Get(tt.name) + suite.Equal(true, ok) + suite.Equal(tt.name, string(c)) + time.Sleep(tt.ttl) + c, ok = cache.Get(tt.name) + suite.Equal(false, ok) + suite.Equal("", string(c)) + }) + } +}