From 9736973286699748c72b702120f5dfe36753883f Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Sun, 1 Mar 2026 13:19:53 -0500 Subject: [PATCH] REmove store interface --- message/{store.go => cache.go} | 109 ++++++++---------- .../{store_postgres.go => cache_postgres.go} | 6 +- ...res_schema.go => cache_postgres_schema.go} | 0 message/{store_sqlite.go => cache_sqlite.go} | 10 +- ...qlite_schema.go => cache_sqlite_schema.go} | 0 ...re_sqlite_test.go => cache_sqlite_test.go} | 2 +- message/{store_test.go => cache_test.go} | 46 ++++---- server/server.go | 8 +- server/server_test.go | 2 +- server/visitor.go | 4 +- webpush/store.go | 43 +++---- webpush/store_postgres.go | 6 +- webpush/store_sqlite.go | 6 +- webpush/store_test.go | 24 ++-- 14 files changed, 122 insertions(+), 144 deletions(-) rename message/{store.go => cache.go} (79%) rename message/{store_postgres.go => cache_postgres.go} (97%) rename message/{store_postgres_schema.go => cache_postgres_schema.go} (100%) rename message/{store_sqlite.go => cache_sqlite.go} (97%) rename message/{store_sqlite_schema.go => cache_sqlite_schema.go} (100%) rename message/{store_sqlite_test.go => cache_sqlite_test.go} (99%) rename message/{store_test.go => cache_test.go} (95%) diff --git a/message/store.go b/message/cache.go similarity index 79% rename from message/store.go rename to message/cache.go index 11d261f0..953a6f7f 100644 --- a/message/store.go +++ b/message/cache.go @@ -20,32 +20,8 @@ const ( var errNoRows = errors.New("no rows found") -// Store is the interface for a message cache store -type Store interface { - AddMessage(m *model.Message) error - AddMessages(ms []*model.Message) error - Message(id string) (*model.Message, error) - MessagesCount() (int, error) - Messages(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) - MessagesDue() ([]*model.Message, error) - MessagesExpired() ([]string, error) - MarkPublished(m *model.Message) error - UpdateMessageTime(messageID string, timestamp int64) error - Topics() ([]string, error) - DeleteMessages(ids ...string) error - DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, error) - ExpireMessages(topics ...string) error - AttachmentsExpired() ([]string, error) - MarkAttachmentsDeleted(ids ...string) error - AttachmentBytesUsedBySender(sender string) (int64, error) - AttachmentBytesUsedByUser(userID string) (int64, error) - UpdateStats(messages int64) error - Stats() (int64, error) - Close() error -} - -// storeQueries holds the database-specific SQL queries -type storeQueries struct { +// queries holds the database-specific SQL queries +type queries struct { insertMessage string deleteMessage string selectScheduledMessageIDsBySeqID string @@ -71,21 +47,21 @@ type storeQueries struct { updateMessageTime string } -// commonStore implements store operations that are identical across database backends -type commonStore struct { +// Cache stores published messages +type Cache struct { db *sql.DB queue *util.BatchingQueue[*model.Message] nop bool mu *sync.Mutex // nil for PostgreSQL (concurrent writes supported), set for SQLite (single writer) - queries storeQueries + queries queries } -func newCommonStore(db *sql.DB, queries storeQueries, mu *sync.Mutex, batchSize int, batchTimeout time.Duration, nop bool) *commonStore { +func newCache(db *sql.DB, queries queries, mu *sync.Mutex, batchSize int, batchTimeout time.Duration, nop bool) *Cache { var queue *util.BatchingQueue[*model.Message] if batchSize > 0 || batchTimeout > 0 { queue = util.NewBatchingQueue[*model.Message](batchSize, batchTimeout) } - c := &commonStore{ + c := &Cache{ db: db, queue: queue, nop: nop, @@ -96,13 +72,13 @@ func newCommonStore(db *sql.DB, queries storeQueries, mu *sync.Mutex, batchSize return c } -func (c *commonStore) maybeLock() { +func (c *Cache) maybeLock() { if c.mu != nil { c.mu.Lock() } } -func (c *commonStore) maybeUnlock() { +func (c *Cache) maybeUnlock() { if c.mu != nil { c.mu.Unlock() } @@ -110,7 +86,7 @@ func (c *commonStore) maybeUnlock() { // AddMessage stores a message to the message cache synchronously, or queues it to be stored at a later date asynchronously. // The message is queued only if "batchSize" or "batchTimeout" are passed to the constructor. -func (c *commonStore) AddMessage(m *model.Message) error { +func (c *Cache) AddMessage(m *model.Message) error { if c.queue != nil { c.queue.Enqueue(m) return nil @@ -119,11 +95,11 @@ func (c *commonStore) AddMessage(m *model.Message) error { } // AddMessages synchronously stores a batch of messages to the message cache -func (c *commonStore) AddMessages(ms []*model.Message) error { +func (c *Cache) AddMessages(ms []*model.Message) error { return c.addMessages(ms) } -func (c *commonStore) addMessages(ms []*model.Message) error { +func (c *Cache) addMessages(ms []*model.Message) error { c.maybeLock() defer c.maybeUnlock() if c.nop { @@ -209,7 +185,8 @@ func (c *commonStore) addMessages(ms []*model.Message) error { return nil } -func (c *commonStore) Messages(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) { +// Messages returns messages for a topic since the given marker, optionally including scheduled messages +func (c *Cache) Messages(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) { if since.IsNone() { return make([]*model.Message, 0), nil } else if since.IsLatest() { @@ -220,7 +197,7 @@ func (c *commonStore) Messages(topic string, since model.SinceMarker, scheduled return c.messagesSinceTime(topic, since, scheduled) } -func (c *commonStore) messagesSinceTime(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) { +func (c *Cache) messagesSinceTime(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) { var rows *sql.Rows var err error if scheduled { @@ -234,7 +211,7 @@ func (c *commonStore) messagesSinceTime(topic string, since model.SinceMarker, s return readMessages(rows) } -func (c *commonStore) messagesSinceID(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) { +func (c *Cache) messagesSinceID(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) { var rows *sql.Rows var err error if scheduled { @@ -248,7 +225,7 @@ func (c *commonStore) messagesSinceID(topic string, since model.SinceMarker, sch return readMessages(rows) } -func (c *commonStore) messagesLatest(topic string) ([]*model.Message, error) { +func (c *Cache) messagesLatest(topic string) ([]*model.Message, error) { rows, err := c.db.Query(c.queries.selectMessagesLatest, topic) if err != nil { return nil, err @@ -256,7 +233,8 @@ func (c *commonStore) messagesLatest(topic string) ([]*model.Message, error) { return readMessages(rows) } -func (c *commonStore) MessagesDue() ([]*model.Message, error) { +// MessagesDue returns all messages that are due for publishing +func (c *Cache) MessagesDue() ([]*model.Message, error) { rows, err := c.db.Query(c.queries.selectMessagesDue, time.Now().Unix()) if err != nil { return nil, err @@ -265,7 +243,7 @@ func (c *commonStore) MessagesDue() ([]*model.Message, error) { } // MessagesExpired returns a list of IDs for messages that have expired (should be deleted) -func (c *commonStore) MessagesExpired() ([]string, error) { +func (c *Cache) MessagesExpired() ([]string, error) { rows, err := c.db.Query(c.queries.selectMessagesExpired, time.Now().Unix()) if err != nil { return nil, err @@ -285,7 +263,8 @@ func (c *commonStore) MessagesExpired() ([]string, error) { return ids, nil } -func (c *commonStore) Message(id string) (*model.Message, error) { +// Message returns the message with the given ID, or ErrMessageNotFound if not found +func (c *Cache) Message(id string) (*model.Message, error) { rows, err := c.db.Query(c.queries.selectMessagesByID, id) if err != nil { return nil, err @@ -298,21 +277,23 @@ func (c *commonStore) Message(id string) (*model.Message, error) { } // UpdateMessageTime updates the time column for a message by ID. This is only used for testing. -func (c *commonStore) UpdateMessageTime(messageID string, timestamp int64) error { +func (c *Cache) UpdateMessageTime(messageID string, timestamp int64) error { c.maybeLock() defer c.maybeUnlock() _, err := c.db.Exec(c.queries.updateMessageTime, timestamp, messageID) return err } -func (c *commonStore) MarkPublished(m *model.Message) error { +// MarkPublished marks a message as published +func (c *Cache) MarkPublished(m *model.Message) error { c.maybeLock() defer c.maybeUnlock() _, err := c.db.Exec(c.queries.updateMessagePublished, m.ID) return err } -func (c *commonStore) MessagesCount() (int, error) { +// MessagesCount returns the total number of messages in the cache +func (c *Cache) MessagesCount() (int, error) { rows, err := c.db.Query(c.queries.selectMessagesCount) if err != nil { return 0, err @@ -328,7 +309,8 @@ func (c *commonStore) MessagesCount() (int, error) { return count, nil } -func (c *commonStore) Topics() ([]string, error) { +// Topics returns a list of all topics with messages in the cache +func (c *Cache) Topics() ([]string, error) { rows, err := c.db.Query(c.queries.selectTopics) if err != nil { return nil, err @@ -348,7 +330,8 @@ func (c *commonStore) Topics() ([]string, error) { return topics, nil } -func (c *commonStore) DeleteMessages(ids ...string) error { +// DeleteMessages deletes the messages with the given IDs +func (c *Cache) DeleteMessages(ids ...string) error { c.maybeLock() defer c.maybeUnlock() tx, err := c.db.Begin() @@ -366,7 +349,7 @@ func (c *commonStore) DeleteMessages(ids ...string) error { // DeleteScheduledBySequenceID deletes unpublished (scheduled) messages with the given topic and sequence ID. // It returns the message IDs of the deleted messages, which can be used to clean up attachment files. -func (c *commonStore) DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, error) { +func (c *Cache) DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, error) { c.maybeLock() defer c.maybeUnlock() tx, err := c.db.Begin() @@ -402,7 +385,8 @@ func (c *commonStore) DeleteScheduledBySequenceID(topic, sequenceID string) ([]s return ids, nil } -func (c *commonStore) ExpireMessages(topics ...string) error { +// ExpireMessages marks messages in the given topics as expired +func (c *Cache) ExpireMessages(topics ...string) error { c.maybeLock() defer c.maybeUnlock() tx, err := c.db.Begin() @@ -418,7 +402,8 @@ func (c *commonStore) ExpireMessages(topics ...string) error { return tx.Commit() } -func (c *commonStore) AttachmentsExpired() ([]string, error) { +// AttachmentsExpired returns message IDs with expired attachments that have not been deleted +func (c *Cache) AttachmentsExpired() ([]string, error) { rows, err := c.db.Query(c.queries.selectAttachmentsExpired, time.Now().Unix()) if err != nil { return nil, err @@ -438,7 +423,8 @@ func (c *commonStore) AttachmentsExpired() ([]string, error) { return ids, nil } -func (c *commonStore) MarkAttachmentsDeleted(ids ...string) error { +// MarkAttachmentsDeleted marks the attachments for the given message IDs as deleted +func (c *Cache) MarkAttachmentsDeleted(ids ...string) error { c.maybeLock() defer c.maybeUnlock() tx, err := c.db.Begin() @@ -454,7 +440,8 @@ func (c *commonStore) MarkAttachmentsDeleted(ids ...string) error { return tx.Commit() } -func (c *commonStore) AttachmentBytesUsedBySender(sender string) (int64, error) { +// AttachmentBytesUsedBySender returns the total size of active attachments sent by the given sender +func (c *Cache) AttachmentBytesUsedBySender(sender string) (int64, error) { rows, err := c.db.Query(c.queries.selectAttachmentsSizeBySender, sender, time.Now().Unix()) if err != nil { return 0, err @@ -462,7 +449,8 @@ func (c *commonStore) AttachmentBytesUsedBySender(sender string) (int64, error) return c.readAttachmentBytesUsed(rows) } -func (c *commonStore) AttachmentBytesUsedByUser(userID string) (int64, error) { +// AttachmentBytesUsedByUser returns the total size of active attachments for the given user +func (c *Cache) AttachmentBytesUsedByUser(userID string) (int64, error) { rows, err := c.db.Query(c.queries.selectAttachmentsSizeByUserID, userID, time.Now().Unix()) if err != nil { return 0, err @@ -470,7 +458,7 @@ func (c *commonStore) AttachmentBytesUsedByUser(userID string) (int64, error) { return c.readAttachmentBytesUsed(rows) } -func (c *commonStore) readAttachmentBytesUsed(rows *sql.Rows) (int64, error) { +func (c *Cache) readAttachmentBytesUsed(rows *sql.Rows) (int64, error) { defer rows.Close() var size int64 if !rows.Next() { @@ -484,14 +472,16 @@ func (c *commonStore) readAttachmentBytesUsed(rows *sql.Rows) (int64, error) { return size, nil } -func (c *commonStore) UpdateStats(messages int64) error { +// UpdateStats updates the total message count statistic +func (c *Cache) UpdateStats(messages int64) error { c.maybeLock() defer c.maybeUnlock() _, err := c.db.Exec(c.queries.updateStats, messages) return err } -func (c *commonStore) Stats() (messages int64, err error) { +// Stats returns the total message count statistic +func (c *Cache) Stats() (messages int64, err error) { rows, err := c.db.Query(c.queries.selectStats) if err != nil { return 0, err @@ -506,11 +496,12 @@ func (c *commonStore) Stats() (messages int64, err error) { return messages, nil } -func (c *commonStore) Close() error { +// Close closes the underlying database connection +func (c *Cache) Close() error { return c.db.Close() } -func (c *commonStore) processMessageBatches() { +func (c *Cache) processMessageBatches() { if c.queue == nil { return } diff --git a/message/store_postgres.go b/message/cache_postgres.go similarity index 97% rename from message/store_postgres.go rename to message/cache_postgres.go index a21b8f62..f5ced3bb 100644 --- a/message/store_postgres.go +++ b/message/cache_postgres.go @@ -75,7 +75,7 @@ const ( postgresUpdateMessageTimesQuery = `UPDATE message SET time = $1 WHERE mid = $2` ) -var pgQueries = storeQueries{ +var pgQueries = queries{ insertMessage: postgresInsertMessageQuery, deleteMessage: postgresDeleteMessageQuery, selectScheduledMessageIDsBySeqID: postgresSelectScheduledMessageIDsBySeqIDQuery, @@ -102,9 +102,9 @@ var pgQueries = storeQueries{ } // NewPostgresStore creates a new PostgreSQL-backed message cache store using an existing database connection pool. -func NewPostgresStore(db *sql.DB, batchSize int, batchTimeout time.Duration) (Store, error) { +func NewPostgresStore(db *sql.DB, batchSize int, batchTimeout time.Duration) (*Cache, error) { if err := setupPostgresDB(db); err != nil { return nil, err } - return newCommonStore(db, pgQueries, nil, batchSize, batchTimeout, false), nil + return newCache(db, pgQueries, nil, batchSize, batchTimeout, false), nil } diff --git a/message/store_postgres_schema.go b/message/cache_postgres_schema.go similarity index 100% rename from message/store_postgres_schema.go rename to message/cache_postgres_schema.go diff --git a/message/store_sqlite.go b/message/cache_sqlite.go similarity index 97% rename from message/store_sqlite.go rename to message/cache_sqlite.go index 923f3480..f9d8605e 100644 --- a/message/store_sqlite.go +++ b/message/cache_sqlite.go @@ -78,7 +78,7 @@ const ( sqliteUpdateMessageTimeQuery = `UPDATE messages SET time = ? WHERE mid = ?` ) -var sqliteQueries = storeQueries{ +var sqliteQueries = queries{ insertMessage: sqliteInsertMessageQuery, deleteMessage: sqliteDeleteMessageQuery, selectScheduledMessageIDsBySeqID: sqliteSelectScheduledMessageIDsBySeqIDQuery, @@ -105,7 +105,7 @@ var sqliteQueries = storeQueries{ } // NewSQLiteStore creates a SQLite file-backed cache -func NewSQLiteStore(filename, startupQueries string, cacheDuration time.Duration, batchSize int, batchTimeout time.Duration, nop bool) (Store, error) { +func NewSQLiteStore(filename, startupQueries string, cacheDuration time.Duration, batchSize int, batchTimeout time.Duration, nop bool) (*Cache, error) { parentDir := filepath.Dir(filename) if !util.FileExists(parentDir) { return nil, fmt.Errorf("cache database directory %s does not exist or is not accessible", parentDir) @@ -117,17 +117,17 @@ func NewSQLiteStore(filename, startupQueries string, cacheDuration time.Duration if err := setupSQLite(db, startupQueries, cacheDuration); err != nil { return nil, err } - return newCommonStore(db, sqliteQueries, &sync.Mutex{}, batchSize, batchTimeout, nop), nil + return newCache(db, sqliteQueries, &sync.Mutex{}, batchSize, batchTimeout, nop), nil } // NewMemStore creates an in-memory cache -func NewMemStore() (Store, error) { +func NewMemStore() (*Cache, error) { return NewSQLiteStore(createMemoryFilename(), "", 0, 0, 0, false) } // NewNopStore creates an in-memory cache that discards all messages; // it is always empty and can be used if caching is entirely disabled -func NewNopStore() (Store, error) { +func NewNopStore() (*Cache, error) { return NewSQLiteStore(createMemoryFilename(), "", 0, 0, 0, true) } diff --git a/message/store_sqlite_schema.go b/message/cache_sqlite_schema.go similarity index 100% rename from message/store_sqlite_schema.go rename to message/cache_sqlite_schema.go diff --git a/message/store_sqlite_test.go b/message/cache_sqlite_test.go similarity index 99% rename from message/store_sqlite_test.go rename to message/cache_sqlite_test.go index e018fa28..e69488e6 100644 --- a/message/store_sqlite_test.go +++ b/message/cache_sqlite_test.go @@ -271,7 +271,7 @@ func newSqliteTestStoreFile(t *testing.T) string { return filepath.Join(t.TempDir(), "cache.db") } -func newSqliteTestStoreFromFile(t *testing.T, filename, startupQueries string) message.Store { +func newSqliteTestStoreFromFile(t *testing.T, filename, startupQueries string) *message.Cache { s, err := message.NewSQLiteStore(filename, startupQueries, time.Hour, 0, 0, false) require.Nil(t, err) t.Cleanup(func() { s.Close() }) diff --git a/message/store_test.go b/message/cache_test.go similarity index 95% rename from message/store_test.go rename to message/cache_test.go index e53c11ec..eb992381 100644 --- a/message/store_test.go +++ b/message/cache_test.go @@ -15,7 +15,7 @@ import ( "heckel.io/ntfy/v2/model" ) -func newSqliteTestStore(t *testing.T) message.Store { +func newSqliteTestStore(t *testing.T) *message.Cache { filename := filepath.Join(t.TempDir(), "cache.db") s, err := message.NewSQLiteStore(filename, "", time.Hour, 0, 0, false) require.Nil(t, err) @@ -23,21 +23,21 @@ func newSqliteTestStore(t *testing.T) message.Store { return s } -func newMemTestStore(t *testing.T) message.Store { +func newMemTestStore(t *testing.T) *message.Cache { s, err := message.NewMemStore() require.Nil(t, err) t.Cleanup(func() { s.Close() }) return s } -func newTestPostgresStore(t *testing.T) message.Store { +func newTestPostgresStore(t *testing.T) *message.Cache { testDB := dbtest.CreateTestPostgres(t) store, err := message.NewPostgresStore(testDB, 0, 0) require.Nil(t, err) return store } -func forEachBackend(t *testing.T, f func(t *testing.T, s message.Store)) { +func forEachBackend(t *testing.T, f func(t *testing.T, s *message.Cache)) { t.Run("sqlite", func(t *testing.T) { f(t, newSqliteTestStore(t)) }) @@ -50,7 +50,7 @@ func forEachBackend(t *testing.T, f func(t *testing.T, s message.Store)) { } func TestStore_Messages(t *testing.T) { - forEachBackend(t, func(t *testing.T, s message.Store) { + forEachBackend(t, func(t *testing.T, s *message.Cache) { m1 := model.NewDefaultMessage("mytopic", "my message") m1.Time = 1 @@ -113,7 +113,7 @@ func TestStore_Messages(t *testing.T) { } func TestStore_MessagesLock(t *testing.T) { - forEachBackend(t, func(t *testing.T, s message.Store) { + forEachBackend(t, func(t *testing.T, s *message.Cache) { var wg sync.WaitGroup for i := 0; i < 5000; i++ { wg.Add(1) @@ -127,7 +127,7 @@ func TestStore_MessagesLock(t *testing.T) { } func TestStore_MessagesScheduled(t *testing.T) { - forEachBackend(t, func(t *testing.T, s message.Store) { + forEachBackend(t, func(t *testing.T, s *message.Cache) { m1 := model.NewDefaultMessage("mytopic", "message 1") m2 := model.NewDefaultMessage("mytopic", "message 2") m2.Time = time.Now().Add(time.Hour).Unix() @@ -155,7 +155,7 @@ func TestStore_MessagesScheduled(t *testing.T) { } func TestStore_Topics(t *testing.T) { - forEachBackend(t, func(t *testing.T, s message.Store) { + forEachBackend(t, func(t *testing.T, s *message.Cache) { require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic1", "my example message"))) require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 1"))) require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 2"))) @@ -172,7 +172,7 @@ func TestStore_Topics(t *testing.T) { } func TestStore_MessagesTagsPrioAndTitle(t *testing.T) { - forEachBackend(t, func(t *testing.T, s message.Store) { + forEachBackend(t, func(t *testing.T, s *message.Cache) { m := model.NewDefaultMessage("mytopic", "some message") m.Tags = []string{"tag1", "tag2"} m.Priority = 5 @@ -187,7 +187,7 @@ func TestStore_MessagesTagsPrioAndTitle(t *testing.T) { } func TestStore_MessagesSinceID(t *testing.T) { - forEachBackend(t, func(t *testing.T, s message.Store) { + forEachBackend(t, func(t *testing.T, s *message.Cache) { m1 := model.NewDefaultMessage("mytopic", "message 1") m1.Time = 100 m2 := model.NewDefaultMessage("mytopic", "message 2") @@ -251,7 +251,7 @@ func TestStore_MessagesSinceID(t *testing.T) { } func TestStore_Prune(t *testing.T) { - forEachBackend(t, func(t *testing.T, s message.Store) { + forEachBackend(t, func(t *testing.T, s *message.Cache) { now := time.Now().Unix() m1 := model.NewDefaultMessage("mytopic", "my message") @@ -290,7 +290,7 @@ func TestStore_Prune(t *testing.T) { } func TestStore_Attachments(t *testing.T) { - forEachBackend(t, func(t *testing.T, s message.Store) { + forEachBackend(t, func(t *testing.T, s *message.Cache) { expires1 := time.Now().Add(-4 * time.Hour).Unix() // Expired m := model.NewDefaultMessage("mytopic", "flower for you") m.ID = "m1" @@ -369,7 +369,7 @@ func TestStore_Attachments(t *testing.T) { } func TestStore_AttachmentsExpired(t *testing.T) { - forEachBackend(t, func(t *testing.T, s message.Store) { + forEachBackend(t, func(t *testing.T, s *message.Cache) { m := model.NewDefaultMessage("mytopic", "flower for you") m.ID = "m1" m.SequenceID = "m1" @@ -422,7 +422,7 @@ func TestStore_AttachmentsExpired(t *testing.T) { } func TestStore_Sender(t *testing.T) { - forEachBackend(t, func(t *testing.T, s message.Store) { + forEachBackend(t, func(t *testing.T, s *message.Cache) { m1 := model.NewDefaultMessage("mytopic", "mymessage") m1.Sender = netip.MustParseAddr("1.2.3.4") require.Nil(t, s.AddMessage(m1)) @@ -439,7 +439,7 @@ func TestStore_Sender(t *testing.T) { } func TestStore_DeleteScheduledBySequenceID(t *testing.T) { - forEachBackend(t, func(t *testing.T, s message.Store) { + forEachBackend(t, func(t *testing.T, s *message.Cache) { // Create a scheduled (unpublished) message scheduledMsg := model.NewDefaultMessage("mytopic", "scheduled message") scheduledMsg.ID = "scheduled1" @@ -506,7 +506,7 @@ func TestStore_DeleteScheduledBySequenceID(t *testing.T) { } func TestStore_MessageByID(t *testing.T) { - forEachBackend(t, func(t *testing.T, s message.Store) { + forEachBackend(t, func(t *testing.T, s *message.Cache) { // Add a message m := model.NewDefaultMessage("mytopic", "some message") m.Title = "some title" @@ -531,7 +531,7 @@ func TestStore_MessageByID(t *testing.T) { } func TestStore_MarkPublished(t *testing.T) { - forEachBackend(t, func(t *testing.T, s message.Store) { + forEachBackend(t, func(t *testing.T, s *message.Cache) { // Add a scheduled message (future time -> unpublished) m := model.NewDefaultMessage("mytopic", "scheduled message") m.Time = time.Now().Add(time.Hour).Unix() @@ -559,7 +559,7 @@ func TestStore_MarkPublished(t *testing.T) { } func TestStore_ExpireMessages(t *testing.T) { - forEachBackend(t, func(t *testing.T, s message.Store) { + forEachBackend(t, func(t *testing.T, s *message.Cache) { // Add messages to two topics m1 := model.NewDefaultMessage("topic1", "message 1") m1.Expires = time.Now().Add(time.Hour).Unix() @@ -600,7 +600,7 @@ func TestStore_ExpireMessages(t *testing.T) { } func TestStore_MarkAttachmentsDeleted(t *testing.T) { - forEachBackend(t, func(t *testing.T, s message.Store) { + forEachBackend(t, func(t *testing.T, s *message.Cache) { // Add a message with an expired attachment (file needs cleanup) m1 := model.NewDefaultMessage("mytopic", "old file") m1.ID = "msg1" @@ -659,7 +659,7 @@ func TestStore_MarkAttachmentsDeleted(t *testing.T) { } func TestStore_Stats(t *testing.T) { - forEachBackend(t, func(t *testing.T, s message.Store) { + forEachBackend(t, func(t *testing.T, s *message.Cache) { // Initial stats should be zero messages, err := s.Stats() require.Nil(t, err) @@ -680,7 +680,7 @@ func TestStore_Stats(t *testing.T) { } func TestStore_AddMessages(t *testing.T) { - forEachBackend(t, func(t *testing.T, s message.Store) { + forEachBackend(t, func(t *testing.T, s *message.Cache) { // Batch add multiple messages msgs := []*model.Message{ model.NewDefaultMessage("mytopic", "batch 1"), @@ -711,7 +711,7 @@ func TestStore_AddMessages(t *testing.T) { } func TestStore_MessagesDue(t *testing.T) { - forEachBackend(t, func(t *testing.T, s message.Store) { + forEachBackend(t, func(t *testing.T, s *message.Cache) { // Add a message scheduled in the past (i.e. it's due now) m1 := model.NewDefaultMessage("mytopic", "due message") m1.Time = time.Now().Add(-time.Second).Unix() @@ -755,7 +755,7 @@ func TestStore_MessagesDue(t *testing.T) { } func TestStore_MessageFieldRoundTrip(t *testing.T) { - forEachBackend(t, func(t *testing.T, s message.Store) { + forEachBackend(t, func(t *testing.T, s *message.Cache) { // Create a message with all fields populated m := model.NewDefaultMessage("mytopic", "hello world") m.SequenceID = "custom_seq_id" diff --git a/server/server.go b/server/server.go index 55caaa53..a64772af 100644 --- a/server/server.go +++ b/server/server.go @@ -62,8 +62,8 @@ type Server struct { messages int64 // Total number of messages (persisted if messageCache enabled) messagesHistory []int64 // Last n values of the messages counter, used to determine rate userManager *user.Manager // Might be nil! - messageCache message.Store // Database that stores the messages - webPush webpush.Store // Database that stores web push subscriptions + messageCache *message.Cache // Database that stores the messages + webPush *webpush.Store // Database that stores web push subscriptions fileCache *fileCache // File system based cache that stores attachments stripe stripeAPI // Stripe API, can be replaced with a mock priceCache *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!) @@ -191,7 +191,7 @@ func New(conf *Config) (*Server, error) { if err != nil { return nil, err } - var wp webpush.Store + var wp *webpush.Store if conf.WebPushPublicKey != "" { if pool != nil { wp, err = webpush.NewPostgresStore(pool) @@ -277,7 +277,7 @@ func New(conf *Config) (*Server, error) { return s, nil } -func createMessageCache(conf *Config, pool *sql.DB) (message.Store, error) { +func createMessageCache(conf *Config, pool *sql.DB) (*message.Cache, error) { if conf.CacheDuration == 0 { return message.NewNopStore() } else if pool != nil { diff --git a/server/server_test.go b/server/server_test.go index 1f206562..fc85da05 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -4121,7 +4121,7 @@ func TestServer_DeleteScheduledMessage_WithAttachment(t *testing.T) { }) } -func newMemTestCache(t *testing.T) message.Store { +func newMemTestCache(t *testing.T) *message.Cache { c, err := message.NewMemStore() require.Nil(t, err) return c diff --git a/server/visitor.go b/server/visitor.go index 12217be6..6d8fe6d1 100644 --- a/server/visitor.go +++ b/server/visitor.go @@ -54,7 +54,7 @@ const ( // visitor represents an API user, and its associated rate.Limiter used for rate limiting type visitor struct { config *Config - messageCache message.Store + messageCache *message.Cache userManager *user.Manager // May be nil ip netip.Addr // Visitor IP address user *user.User // Only set if authenticated user, otherwise nil @@ -115,7 +115,7 @@ const ( visitorLimitBasisTier = visitorLimitBasis("tier") ) -func newVisitor(conf *Config, messageCache message.Store, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor { +func newVisitor(conf *Config, messageCache *message.Cache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor { var messages, emails, calls int64 if user != nil { messages = user.Stats.Messages diff --git a/webpush/store.go b/webpush/store.go index 4fc5d77f..20ff19cc 100644 --- a/webpush/store.go +++ b/webpush/store.go @@ -21,21 +21,14 @@ var ( ErrWebPushUserIDCannotBeEmpty = errors.New("user ID cannot be empty") ) -// Store is the interface for a web push subscription store. -type Store interface { - UpsertSubscription(endpoint, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error - SubscriptionsForTopic(topic string) ([]*Subscription, error) - SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscription, error) - MarkExpiryWarningSent(subscriptions []*Subscription) error - RemoveSubscriptionsByEndpoint(endpoint string) error - RemoveSubscriptionsByUserID(userID string) error - RemoveExpiredSubscriptions(expireAfter time.Duration) error - SetSubscriptionUpdatedAt(endpoint string, updatedAt int64) error - Close() error +// Store holds the database connection and queries for web push subscriptions. +type Store struct { + db *sql.DB + queries queries } -// storeQueries holds the database-specific SQL queries. -type storeQueries struct { +// queries holds the database-specific SQL queries. +type queries struct { selectSubscriptionIDByEndpoint string selectSubscriptionCountBySubscriberIP string selectSubscriptionsForTopic string @@ -51,14 +44,8 @@ type storeQueries struct { deleteSubscriptionTopicWithoutSubscription string } -// commonStore implements store operations that are identical across database backends. -type commonStore struct { - db *sql.DB - queries storeQueries -} - // UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID. -func (s *commonStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error { +func (s *Store) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error { tx, err := s.db.Begin() if err != nil { return err @@ -97,7 +84,7 @@ func (s *commonStore) UpsertSubscription(endpoint string, auth, p256dh, userID s } // SubscriptionsForTopic returns all subscriptions for the given topic. -func (s *commonStore) SubscriptionsForTopic(topic string) ([]*Subscription, error) { +func (s *Store) SubscriptionsForTopic(topic string) ([]*Subscription, error) { rows, err := s.db.Query(s.queries.selectSubscriptionsForTopic, topic) if err != nil { return nil, err @@ -107,7 +94,7 @@ func (s *commonStore) SubscriptionsForTopic(topic string) ([]*Subscription, erro } // SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period. -func (s *commonStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscription, error) { +func (s *Store) SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscription, error) { rows, err := s.db.Query(s.queries.selectSubscriptionsExpiringSoon, time.Now().Add(-warnAfter).Unix()) if err != nil { return nil, err @@ -117,7 +104,7 @@ func (s *commonStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscri } // MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon. -func (s *commonStore) MarkExpiryWarningSent(subscriptions []*Subscription) error { +func (s *Store) MarkExpiryWarningSent(subscriptions []*Subscription) error { tx, err := s.db.Begin() if err != nil { return err @@ -132,13 +119,13 @@ func (s *commonStore) MarkExpiryWarningSent(subscriptions []*Subscription) error } // RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint. -func (s *commonStore) RemoveSubscriptionsByEndpoint(endpoint string) error { +func (s *Store) RemoveSubscriptionsByEndpoint(endpoint string) error { _, err := s.db.Exec(s.queries.deleteSubscriptionByEndpoint, endpoint) return err } // RemoveSubscriptionsByUserID removes all subscriptions for the given user ID. -func (s *commonStore) RemoveSubscriptionsByUserID(userID string) error { +func (s *Store) RemoveSubscriptionsByUserID(userID string) error { if userID == "" { return ErrWebPushUserIDCannotBeEmpty } @@ -147,7 +134,7 @@ func (s *commonStore) RemoveSubscriptionsByUserID(userID string) error { } // RemoveExpiredSubscriptions removes all subscriptions that have not been updated for a given time period. -func (s *commonStore) RemoveExpiredSubscriptions(expireAfter time.Duration) error { +func (s *Store) RemoveExpiredSubscriptions(expireAfter time.Duration) error { _, err := s.db.Exec(s.queries.deleteSubscriptionByAge, time.Now().Add(-expireAfter).Unix()) if err != nil { return err @@ -158,13 +145,13 @@ func (s *commonStore) RemoveExpiredSubscriptions(expireAfter time.Duration) erro // SetSubscriptionUpdatedAt updates the updated_at timestamp for a subscription by endpoint. This is // exported for testing purposes. -func (s *commonStore) SetSubscriptionUpdatedAt(endpoint string, updatedAt int64) error { +func (s *Store) SetSubscriptionUpdatedAt(endpoint string, updatedAt int64) error { _, err := s.db.Exec(s.queries.updateSubscriptionUpdatedAt, updatedAt, endpoint) return err } // Close closes the underlying database connection. -func (s *commonStore) Close() error { +func (s *Store) Close() error { return s.db.Close() } diff --git a/webpush/store_postgres.go b/webpush/store_postgres.go index 5dd72e70..090a93c4 100644 --- a/webpush/store_postgres.go +++ b/webpush/store_postgres.go @@ -71,13 +71,13 @@ const ( ) // NewPostgresStore creates a new PostgreSQL-backed web push store using an existing database connection pool. -func NewPostgresStore(db *sql.DB) (Store, error) { +func NewPostgresStore(db *sql.DB) (*Store, error) { if err := setupPostgresDB(db); err != nil { return nil, err } - return &commonStore{ + return &Store{ db: db, - queries: storeQueries{ + queries: queries{ selectSubscriptionIDByEndpoint: postgresSelectSubscriptionIDByEndpointQuery, selectSubscriptionCountBySubscriberIP: postgresSelectSubscriptionCountBySubscriberIPQuery, selectSubscriptionsForTopic: postgresSelectSubscriptionsForTopicQuery, diff --git a/webpush/store_sqlite.go b/webpush/store_sqlite.go index a00c2e61..867891ec 100644 --- a/webpush/store_sqlite.go +++ b/webpush/store_sqlite.go @@ -76,7 +76,7 @@ const ( ) // NewSQLiteStore creates a new SQLite-backed web push store. -func NewSQLiteStore(filename, startupQueries string) (Store, error) { +func NewSQLiteStore(filename, startupQueries string) (*Store, error) { db, err := sql.Open("sqlite3", filename) if err != nil { return nil, err @@ -87,9 +87,9 @@ func NewSQLiteStore(filename, startupQueries string) (Store, error) { if err := runSQLiteStartupQueries(db, startupQueries); err != nil { return nil, err } - return &commonStore{ + return &Store{ db: db, - queries: storeQueries{ + queries: queries{ selectSubscriptionIDByEndpoint: sqliteSelectWebPushSubscriptionIDByEndpointQuery, selectSubscriptionCountBySubscriberIP: sqliteSelectWebPushSubscriptionCountBySubscriberIPQuery, selectSubscriptionsForTopic: sqliteSelectWebPushSubscriptionsForTopicQuery, diff --git a/webpush/store_test.go b/webpush/store_test.go index bb27bf5f..348f9998 100644 --- a/webpush/store_test.go +++ b/webpush/store_test.go @@ -14,7 +14,7 @@ import ( const testWebPushEndpoint = "https://updates.push.services.mozilla.com/wpush/v1/AAABBCCCDDEEEFFF" -func forEachBackend(t *testing.T, f func(t *testing.T, store webpush.Store)) { +func forEachBackend(t *testing.T, f func(t *testing.T, store *webpush.Store)) { t.Run("sqlite", func(t *testing.T) { store, err := webpush.NewSQLiteStore(filepath.Join(t.TempDir(), "webpush.db"), "") require.Nil(t, err) @@ -30,7 +30,7 @@ func forEachBackend(t *testing.T, f func(t *testing.T, store webpush.Store)) { } func TestStoreUpsertSubscriptionSubscriptionsForTopic(t *testing.T) { - forEachBackend(t, func(t *testing.T, store webpush.Store) { + forEachBackend(t, func(t *testing.T, store *webpush.Store) { require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"})) subs, err := store.SubscriptionsForTopic("test-topic") @@ -49,7 +49,7 @@ func TestStoreUpsertSubscriptionSubscriptionsForTopic(t *testing.T) { } func TestStoreUpsertSubscriptionSubscriberIPLimitReached(t *testing.T) { - forEachBackend(t, func(t *testing.T, store webpush.Store) { + forEachBackend(t, func(t *testing.T, store *webpush.Store) { // Insert 10 subscriptions with the same IP address for i := 0; i < 10; i++ { endpoint := fmt.Sprintf(testWebPushEndpoint+"%d", i) @@ -68,7 +68,7 @@ func TestStoreUpsertSubscriptionSubscriberIPLimitReached(t *testing.T) { } func TestStoreUpsertSubscriptionUpdateTopics(t *testing.T) { - forEachBackend(t, func(t *testing.T, store webpush.Store) { + forEachBackend(t, func(t *testing.T, store *webpush.Store) { // Insert subscription with two topics, and another with one topic require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"})) require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"1", "auth-key", "p256dh-key", "", netip.MustParseAddr("9.9.9.9"), []string{"topic1"})) @@ -99,7 +99,7 @@ func TestStoreUpsertSubscriptionUpdateTopics(t *testing.T) { } func TestStoreUpsertSubscriptionUpdateFields(t *testing.T) { - forEachBackend(t, func(t *testing.T, store webpush.Store) { + forEachBackend(t, func(t *testing.T, store *webpush.Store) { // Insert a subscription require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"})) @@ -124,7 +124,7 @@ func TestStoreUpsertSubscriptionUpdateFields(t *testing.T) { } func TestStoreRemoveByUserIDMultiple(t *testing.T) { - forEachBackend(t, func(t *testing.T, store webpush.Store) { + forEachBackend(t, func(t *testing.T, store *webpush.Store) { // Insert two subscriptions for u_1234 and one for u_5678 require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"})) require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"1", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"})) @@ -147,7 +147,7 @@ func TestStoreRemoveByUserIDMultiple(t *testing.T) { } func TestStoreRemoveByEndpoint(t *testing.T) { - forEachBackend(t, func(t *testing.T, store webpush.Store) { + forEachBackend(t, func(t *testing.T, store *webpush.Store) { // Insert subscription with two topics require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"})) subs, err := store.SubscriptionsForTopic("topic1") @@ -163,7 +163,7 @@ func TestStoreRemoveByEndpoint(t *testing.T) { } func TestStoreRemoveByUserID(t *testing.T) { - forEachBackend(t, func(t *testing.T, store webpush.Store) { + forEachBackend(t, func(t *testing.T, store *webpush.Store) { // Insert subscription with two topics require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"})) subs, err := store.SubscriptionsForTopic("topic1") @@ -179,13 +179,13 @@ func TestStoreRemoveByUserID(t *testing.T) { } func TestStoreRemoveByUserIDEmpty(t *testing.T) { - forEachBackend(t, func(t *testing.T, store webpush.Store) { + forEachBackend(t, func(t *testing.T, store *webpush.Store) { require.Equal(t, webpush.ErrWebPushUserIDCannotBeEmpty, store.RemoveSubscriptionsByUserID("")) }) } func TestStoreExpiryWarningSent(t *testing.T) { - forEachBackend(t, func(t *testing.T, store webpush.Store) { + forEachBackend(t, func(t *testing.T, store *webpush.Store) { // Insert subscription with two topics require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"})) @@ -209,7 +209,7 @@ func TestStoreExpiryWarningSent(t *testing.T) { } func TestStoreExpiring(t *testing.T) { - forEachBackend(t, func(t *testing.T, store webpush.Store) { + forEachBackend(t, func(t *testing.T, store *webpush.Store) { // Insert subscription with two topics require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"})) subs, err := store.SubscriptionsForTopic("topic1") @@ -231,7 +231,7 @@ func TestStoreExpiring(t *testing.T) { } func TestStoreRemoveExpired(t *testing.T) { - forEachBackend(t, func(t *testing.T, store webpush.Store) { + forEachBackend(t, func(t *testing.T, store *webpush.Store) { // Insert subscription with two topics require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"})) subs, err := store.SubscriptionsForTopic("topic1")