diff --git a/message/store_postgres_test.go b/message/store_postgres_test.go deleted file mode 100644 index 522620db..00000000 --- a/message/store_postgres_test.go +++ /dev/null @@ -1,93 +0,0 @@ -package message_test - -import ( - "testing" - - dbtest "heckel.io/ntfy/v2/db/test" - "heckel.io/ntfy/v2/message" - - "github.com/stretchr/testify/require" -) - -func newTestPostgresStore(t *testing.T) message.Store { - testDB := dbtest.CreateTestPostgres(t) - store, err := message.NewPostgresStore(testDB, 0, 0) - require.Nil(t, err) - return store -} - -func TestPostgresStore_Messages(t *testing.T) { - testCacheMessages(t, newTestPostgresStore(t)) -} - -func TestPostgresStore_MessagesLock(t *testing.T) { - testCacheMessagesLock(t, newTestPostgresStore(t)) -} - -func TestPostgresStore_MessagesScheduled(t *testing.T) { - testCacheMessagesScheduled(t, newTestPostgresStore(t)) -} - -func TestPostgresStore_Topics(t *testing.T) { - testCacheTopics(t, newTestPostgresStore(t)) -} - -func TestPostgresStore_MessagesTagsPrioAndTitle(t *testing.T) { - testCacheMessagesTagsPrioAndTitle(t, newTestPostgresStore(t)) -} - -func TestPostgresStore_MessagesSinceID(t *testing.T) { - testCacheMessagesSinceID(t, newTestPostgresStore(t)) -} - -func TestPostgresStore_Prune(t *testing.T) { - testCachePrune(t, newTestPostgresStore(t)) -} - -func TestPostgresStore_Attachments(t *testing.T) { - testCacheAttachments(t, newTestPostgresStore(t)) -} - -func TestPostgresStore_AttachmentsExpired(t *testing.T) { - testCacheAttachmentsExpired(t, newTestPostgresStore(t)) -} - -func TestPostgresStore_Sender(t *testing.T) { - testSender(t, newTestPostgresStore(t)) -} - -func TestPostgresStore_DeleteScheduledBySequenceID(t *testing.T) { - testDeleteScheduledBySequenceID(t, newTestPostgresStore(t)) -} - -func TestPostgresStore_MessageByID(t *testing.T) { - testMessageByID(t, newTestPostgresStore(t)) -} - -func TestPostgresStore_MarkPublished(t *testing.T) { - testMarkPublished(t, newTestPostgresStore(t)) -} - -func TestPostgresStore_ExpireMessages(t *testing.T) { - testExpireMessages(t, newTestPostgresStore(t)) -} - -func TestPostgresStore_MarkAttachmentsDeleted(t *testing.T) { - testMarkAttachmentsDeleted(t, newTestPostgresStore(t)) -} - -func TestPostgresStore_Stats(t *testing.T) { - testStats(t, newTestPostgresStore(t)) -} - -func TestPostgresStore_AddMessages(t *testing.T) { - testAddMessages(t, newTestPostgresStore(t)) -} - -func TestPostgresStore_MessagesDue(t *testing.T) { - testMessagesDue(t, newTestPostgresStore(t)) -} - -func TestPostgresStore_MessageFieldRoundTrip(t *testing.T) { - testMessageFieldRoundTrip(t, newTestPostgresStore(t)) -} diff --git a/message/store_sqlite_test.go b/message/store_sqlite_test.go index aa102044..e018fa28 100644 --- a/message/store_sqlite_test.go +++ b/message/store_sqlite_test.go @@ -13,158 +13,6 @@ import ( "heckel.io/ntfy/v2/model" ) -func TestSqliteStore_Messages(t *testing.T) { - testCacheMessages(t, newSqliteTestStore(t)) -} - -func TestMemStore_Messages(t *testing.T) { - testCacheMessages(t, newMemTestStore(t)) -} - -func TestSqliteStore_MessagesLock(t *testing.T) { - testCacheMessagesLock(t, newSqliteTestStore(t)) -} - -func TestMemStore_MessagesLock(t *testing.T) { - testCacheMessagesLock(t, newMemTestStore(t)) -} - -func TestSqliteStore_MessagesScheduled(t *testing.T) { - testCacheMessagesScheduled(t, newSqliteTestStore(t)) -} - -func TestMemStore_MessagesScheduled(t *testing.T) { - testCacheMessagesScheduled(t, newMemTestStore(t)) -} - -func TestSqliteStore_Topics(t *testing.T) { - testCacheTopics(t, newSqliteTestStore(t)) -} - -func TestMemStore_Topics(t *testing.T) { - testCacheTopics(t, newMemTestStore(t)) -} - -func TestSqliteStore_MessagesTagsPrioAndTitle(t *testing.T) { - testCacheMessagesTagsPrioAndTitle(t, newSqliteTestStore(t)) -} - -func TestMemStore_MessagesTagsPrioAndTitle(t *testing.T) { - testCacheMessagesTagsPrioAndTitle(t, newMemTestStore(t)) -} - -func TestSqliteStore_MessagesSinceID(t *testing.T) { - testCacheMessagesSinceID(t, newSqliteTestStore(t)) -} - -func TestMemStore_MessagesSinceID(t *testing.T) { - testCacheMessagesSinceID(t, newMemTestStore(t)) -} - -func TestSqliteStore_Prune(t *testing.T) { - testCachePrune(t, newSqliteTestStore(t)) -} - -func TestMemStore_Prune(t *testing.T) { - testCachePrune(t, newMemTestStore(t)) -} - -func TestSqliteStore_Attachments(t *testing.T) { - testCacheAttachments(t, newSqliteTestStore(t)) -} - -func TestMemStore_Attachments(t *testing.T) { - testCacheAttachments(t, newMemTestStore(t)) -} - -func TestSqliteStore_AttachmentsExpired(t *testing.T) { - testCacheAttachmentsExpired(t, newSqliteTestStore(t)) -} - -func TestMemStore_AttachmentsExpired(t *testing.T) { - testCacheAttachmentsExpired(t, newMemTestStore(t)) -} - -func TestSqliteStore_Sender(t *testing.T) { - testSender(t, newSqliteTestStore(t)) -} - -func TestMemStore_Sender(t *testing.T) { - testSender(t, newMemTestStore(t)) -} - -func TestSqliteStore_DeleteScheduledBySequenceID(t *testing.T) { - testDeleteScheduledBySequenceID(t, newSqliteTestStore(t)) -} - -func TestMemStore_DeleteScheduledBySequenceID(t *testing.T) { - testDeleteScheduledBySequenceID(t, newMemTestStore(t)) -} - -func TestSqliteStore_MessageByID(t *testing.T) { - testMessageByID(t, newSqliteTestStore(t)) -} - -func TestMemStore_MessageByID(t *testing.T) { - testMessageByID(t, newMemTestStore(t)) -} - -func TestSqliteStore_MarkPublished(t *testing.T) { - testMarkPublished(t, newSqliteTestStore(t)) -} - -func TestMemStore_MarkPublished(t *testing.T) { - testMarkPublished(t, newMemTestStore(t)) -} - -func TestSqliteStore_ExpireMessages(t *testing.T) { - testExpireMessages(t, newSqliteTestStore(t)) -} - -func TestMemStore_ExpireMessages(t *testing.T) { - testExpireMessages(t, newMemTestStore(t)) -} - -func TestSqliteStore_MarkAttachmentsDeleted(t *testing.T) { - testMarkAttachmentsDeleted(t, newSqliteTestStore(t)) -} - -func TestMemStore_MarkAttachmentsDeleted(t *testing.T) { - testMarkAttachmentsDeleted(t, newMemTestStore(t)) -} - -func TestSqliteStore_Stats(t *testing.T) { - testStats(t, newSqliteTestStore(t)) -} - -func TestMemStore_Stats(t *testing.T) { - testStats(t, newMemTestStore(t)) -} - -func TestSqliteStore_AddMessages(t *testing.T) { - testAddMessages(t, newSqliteTestStore(t)) -} - -func TestMemStore_AddMessages(t *testing.T) { - testAddMessages(t, newMemTestStore(t)) -} - -func TestSqliteStore_MessagesDue(t *testing.T) { - testMessagesDue(t, newSqliteTestStore(t)) -} - -func TestMemStore_MessagesDue(t *testing.T) { - testMessagesDue(t, newMemTestStore(t)) -} - -func TestSqliteStore_MessageFieldRoundTrip(t *testing.T) { - testMessageFieldRoundTrip(t, newSqliteTestStore(t)) -} - -func TestMemStore_MessageFieldRoundTrip(t *testing.T) { - testMessageFieldRoundTrip(t, newMemTestStore(t)) -} - func TestSqliteStore_Migration_From0(t *testing.T) { filename := newSqliteTestStoreFile(t) db, err := sql.Open("sqlite3", filename) @@ -419,14 +267,6 @@ func TestNopStore(t *testing.T) { require.Empty(t, topics) } -func newSqliteTestStore(t *testing.T) message.Store { - filename := filepath.Join(t.TempDir(), "cache.db") - s, err := message.NewSQLiteStore(filename, "", time.Hour, 0, 0, false) - require.Nil(t, err) - t.Cleanup(func() { s.Close() }) - return s -} - func newSqliteTestStoreFile(t *testing.T) string { return filepath.Join(t.TempDir(), "cache.db") } @@ -438,13 +278,6 @@ func newSqliteTestStoreFromFile(t *testing.T, filename, startupQueries string) m return s } -func newMemTestStore(t *testing.T) message.Store { - s, err := message.NewMemStore() - require.Nil(t, err) - t.Cleanup(func() { s.Close() }) - return s -} - func checkSqliteSchemaVersion(t *testing.T, filename string) { db, err := sql.Open("sqlite3", filename) require.Nil(t, err) diff --git a/message/store_test.go b/message/store_test.go index e3297c96..50bf84f3 100644 --- a/message/store_test.go +++ b/message/store_test.go @@ -2,6 +2,7 @@ package message_test import ( "net/netip" + "path/filepath" "sort" "sync" "testing" @@ -9,759 +10,832 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + dbtest "heckel.io/ntfy/v2/db/test" "heckel.io/ntfy/v2/message" "heckel.io/ntfy/v2/model" ) -func testCacheMessages(t *testing.T, s message.Store) { - m1 := model.NewDefaultMessage("mytopic", "my message") - m1.Time = 1 - - m2 := model.NewDefaultMessage("mytopic", "my other message") - m2.Time = 2 - - require.Nil(t, s.AddMessage(m1)) - require.Nil(t, s.AddMessage(model.NewDefaultMessage("example", "my example message"))) - require.Nil(t, s.AddMessage(m2)) - - // Adding invalid - require.Equal(t, model.ErrUnexpectedMessageType, s.AddMessage(model.NewKeepaliveMessage("mytopic"))) // These should not be added! - require.Equal(t, model.ErrUnexpectedMessageType, s.AddMessage(model.NewOpenMessage("example"))) // These should not be added! - - // mytopic: count - counts, err := s.MessageCounts() +func newSqliteTestStore(t *testing.T) message.Store { + filename := filepath.Join(t.TempDir(), "cache.db") + s, err := message.NewSQLiteStore(filename, "", time.Hour, 0, 0, false) require.Nil(t, err) - require.Equal(t, 2, counts["mytopic"]) - - // mytopic: since all - messages, _ := s.Messages("mytopic", model.SinceAllMessages, false) - require.Equal(t, 2, len(messages)) - require.Equal(t, "my message", messages[0].Message) - require.Equal(t, "mytopic", messages[0].Topic) - require.Equal(t, model.MessageEvent, messages[0].Event) - require.Equal(t, "", messages[0].Title) - require.Equal(t, 0, messages[0].Priority) - require.Nil(t, messages[0].Tags) - require.Equal(t, "my other message", messages[1].Message) - - // mytopic: since none - messages, _ = s.Messages("mytopic", model.SinceNoMessages, false) - require.Empty(t, messages) - - // mytopic: since m1 (by ID) - messages, _ = s.Messages("mytopic", model.NewSinceID(m1.ID), false) - require.Equal(t, 1, len(messages)) - require.Equal(t, m2.ID, messages[0].ID) - require.Equal(t, "my other message", messages[0].Message) - require.Equal(t, "mytopic", messages[0].Topic) - - // mytopic: since 2 - messages, _ = s.Messages("mytopic", model.NewSinceTime(2), false) - require.Equal(t, 1, len(messages)) - require.Equal(t, "my other message", messages[0].Message) - - // mytopic: latest - messages, _ = s.Messages("mytopic", model.SinceLatestMessage, false) - require.Equal(t, 1, len(messages)) - require.Equal(t, "my other message", messages[0].Message) - - // example: count - counts, err = s.MessageCounts() - require.Nil(t, err) - require.Equal(t, 1, counts["example"]) - - // example: since all - messages, _ = s.Messages("example", model.SinceAllMessages, false) - require.Equal(t, "my example message", messages[0].Message) - - // non-existing: count - counts, err = s.MessageCounts() - require.Nil(t, err) - require.Equal(t, 0, counts["doesnotexist"]) - - // non-existing: since all - messages, _ = s.Messages("doesnotexist", model.SinceAllMessages, false) - require.Empty(t, messages) + t.Cleanup(func() { s.Close() }) + return s } -func testCacheMessagesLock(t *testing.T, s message.Store) { - var wg sync.WaitGroup - for i := 0; i < 5000; i++ { - wg.Add(1) - go func() { - assert.Nil(t, s.AddMessage(model.NewDefaultMessage("mytopic", "test message"))) - wg.Done() - }() - } - wg.Wait() +func newMemTestStore(t *testing.T) message.Store { + s, err := message.NewMemStore() + require.Nil(t, err) + t.Cleanup(func() { s.Close() }) + return s } -func testCacheMessagesScheduled(t *testing.T, s message.Store) { - m1 := model.NewDefaultMessage("mytopic", "message 1") - m2 := model.NewDefaultMessage("mytopic", "message 2") - m2.Time = time.Now().Add(time.Hour).Unix() - m3 := model.NewDefaultMessage("mytopic", "message 3") - m3.Time = time.Now().Add(time.Minute).Unix() // earlier than m2! - m4 := model.NewDefaultMessage("mytopic2", "message 4") - m4.Time = time.Now().Add(time.Minute).Unix() - require.Nil(t, s.AddMessage(m1)) - require.Nil(t, s.AddMessage(m2)) - require.Nil(t, s.AddMessage(m3)) - - messages, _ := s.Messages("mytopic", model.SinceAllMessages, false) // exclude scheduled - require.Equal(t, 1, len(messages)) - require.Equal(t, "message 1", messages[0].Message) - - messages, _ = s.Messages("mytopic", model.SinceAllMessages, true) // include scheduled - require.Equal(t, 3, len(messages)) - require.Equal(t, "message 1", messages[0].Message) - require.Equal(t, "message 3", messages[1].Message) // Order! - require.Equal(t, "message 2", messages[2].Message) - - messages, _ = s.MessagesDue() - require.Empty(t, messages) +func newTestPostgresStore(t *testing.T) message.Store { + testDB := dbtest.CreateTestPostgres(t) + store, err := message.NewPostgresStore(testDB, 0, 0) + require.Nil(t, err) + return store } -func testCacheTopics(t *testing.T, s message.Store) { - require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic1", "my example message"))) - require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 1"))) - require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 2"))) - require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 3"))) - - topics, err := s.Topics() - if err != nil { - t.Fatal(err) - } - require.Equal(t, 2, len(topics)) - require.Contains(t, topics, "topic1") - require.Contains(t, topics, "topic2") +func forEachBackend(t *testing.T, f func(t *testing.T, s message.Store)) { + t.Run("sqlite", func(t *testing.T) { + f(t, newSqliteTestStore(t)) + }) + t.Run("mem", func(t *testing.T) { + f(t, newMemTestStore(t)) + }) + t.Run("postgres", func(t *testing.T) { + f(t, newTestPostgresStore(t)) + }) } -func testCacheMessagesTagsPrioAndTitle(t *testing.T, s message.Store) { - m := model.NewDefaultMessage("mytopic", "some message") - m.Tags = []string{"tag1", "tag2"} - m.Priority = 5 - m.Title = "some title" - require.Nil(t, s.AddMessage(m)) +func TestStore_Messages(t *testing.T) { + forEachBackend(t, func(t *testing.T, s message.Store) { + m1 := model.NewDefaultMessage("mytopic", "my message") + m1.Time = 1 - messages, _ := s.Messages("mytopic", model.SinceAllMessages, false) - require.Equal(t, []string{"tag1", "tag2"}, messages[0].Tags) - require.Equal(t, 5, messages[0].Priority) - require.Equal(t, "some title", messages[0].Title) + m2 := model.NewDefaultMessage("mytopic", "my other message") + m2.Time = 2 + + require.Nil(t, s.AddMessage(m1)) + require.Nil(t, s.AddMessage(model.NewDefaultMessage("example", "my example message"))) + require.Nil(t, s.AddMessage(m2)) + + // Adding invalid + require.Equal(t, model.ErrUnexpectedMessageType, s.AddMessage(model.NewKeepaliveMessage("mytopic"))) // These should not be added! + require.Equal(t, model.ErrUnexpectedMessageType, s.AddMessage(model.NewOpenMessage("example"))) // These should not be added! + + // mytopic: count + counts, err := s.MessageCounts() + require.Nil(t, err) + require.Equal(t, 2, counts["mytopic"]) + + // mytopic: since all + messages, _ := s.Messages("mytopic", model.SinceAllMessages, false) + require.Equal(t, 2, len(messages)) + require.Equal(t, "my message", messages[0].Message) + require.Equal(t, "mytopic", messages[0].Topic) + require.Equal(t, model.MessageEvent, messages[0].Event) + require.Equal(t, "", messages[0].Title) + require.Equal(t, 0, messages[0].Priority) + require.Nil(t, messages[0].Tags) + require.Equal(t, "my other message", messages[1].Message) + + // mytopic: since none + messages, _ = s.Messages("mytopic", model.SinceNoMessages, false) + require.Empty(t, messages) + + // mytopic: since m1 (by ID) + messages, _ = s.Messages("mytopic", model.NewSinceID(m1.ID), false) + require.Equal(t, 1, len(messages)) + require.Equal(t, m2.ID, messages[0].ID) + require.Equal(t, "my other message", messages[0].Message) + require.Equal(t, "mytopic", messages[0].Topic) + + // mytopic: since 2 + messages, _ = s.Messages("mytopic", model.NewSinceTime(2), false) + require.Equal(t, 1, len(messages)) + require.Equal(t, "my other message", messages[0].Message) + + // mytopic: latest + messages, _ = s.Messages("mytopic", model.SinceLatestMessage, false) + require.Equal(t, 1, len(messages)) + require.Equal(t, "my other message", messages[0].Message) + + // example: count + counts, err = s.MessageCounts() + require.Nil(t, err) + require.Equal(t, 1, counts["example"]) + + // example: since all + messages, _ = s.Messages("example", model.SinceAllMessages, false) + require.Equal(t, "my example message", messages[0].Message) + + // non-existing: count + counts, err = s.MessageCounts() + require.Nil(t, err) + require.Equal(t, 0, counts["doesnotexist"]) + + // non-existing: since all + messages, _ = s.Messages("doesnotexist", model.SinceAllMessages, false) + require.Empty(t, messages) + }) } -func testCacheMessagesSinceID(t *testing.T, s message.Store) { - m1 := model.NewDefaultMessage("mytopic", "message 1") - m1.Time = 100 - m2 := model.NewDefaultMessage("mytopic", "message 2") - m2.Time = 200 - m3 := model.NewDefaultMessage("mytopic", "message 3") - m3.Time = time.Now().Add(time.Hour).Unix() // Scheduled, in the future, later than m7 and m5 - m4 := model.NewDefaultMessage("mytopic", "message 4") - m4.Time = 400 - m5 := model.NewDefaultMessage("mytopic", "message 5") - m5.Time = time.Now().Add(time.Minute).Unix() // Scheduled, in the future, later than m7 - m6 := model.NewDefaultMessage("mytopic", "message 6") - m6.Time = 600 - m7 := model.NewDefaultMessage("mytopic", "message 7") - m7.Time = 700 - - require.Nil(t, s.AddMessage(m1)) - require.Nil(t, s.AddMessage(m2)) - require.Nil(t, s.AddMessage(m3)) - require.Nil(t, s.AddMessage(m4)) - require.Nil(t, s.AddMessage(m5)) - require.Nil(t, s.AddMessage(m6)) - require.Nil(t, s.AddMessage(m7)) - - // Case 1: Since ID exists, exclude scheduled - messages, _ := s.Messages("mytopic", model.NewSinceID(m2.ID), false) - require.Equal(t, 3, len(messages)) - require.Equal(t, "message 4", messages[0].Message) - require.Equal(t, "message 6", messages[1].Message) // Not scheduled m3/m5! - require.Equal(t, "message 7", messages[2].Message) - - // Case 2: Since ID exists, include scheduled - messages, _ = s.Messages("mytopic", model.NewSinceID(m2.ID), true) - require.Equal(t, 5, len(messages)) - require.Equal(t, "message 4", messages[0].Message) - require.Equal(t, "message 6", messages[1].Message) - require.Equal(t, "message 7", messages[2].Message) - require.Equal(t, "message 5", messages[3].Message) // Order! - require.Equal(t, "message 3", messages[4].Message) // Order! - - // Case 3: Since ID does not exist (-> Return all messages), include scheduled - messages, _ = s.Messages("mytopic", model.NewSinceID("doesntexist"), true) - require.Equal(t, 7, len(messages)) - require.Equal(t, "message 1", messages[0].Message) - require.Equal(t, "message 2", messages[1].Message) - require.Equal(t, "message 4", messages[2].Message) - require.Equal(t, "message 6", messages[3].Message) - require.Equal(t, "message 7", messages[4].Message) - require.Equal(t, "message 5", messages[5].Message) // Order! - require.Equal(t, "message 3", messages[6].Message) // Order! - - // Case 4: Since ID exists and is last message (-> Return no messages), exclude scheduled - messages, _ = s.Messages("mytopic", model.NewSinceID(m7.ID), false) - require.Equal(t, 0, len(messages)) - - // Case 5: Since ID exists and is last message (-> Return no messages), include scheduled - messages, _ = s.Messages("mytopic", model.NewSinceID(m7.ID), true) - require.Equal(t, 2, len(messages)) - require.Equal(t, "message 5", messages[0].Message) - require.Equal(t, "message 3", messages[1].Message) +func TestStore_MessagesLock(t *testing.T) { + forEachBackend(t, func(t *testing.T, s message.Store) { + var wg sync.WaitGroup + for i := 0; i < 5000; i++ { + wg.Add(1) + go func() { + assert.Nil(t, s.AddMessage(model.NewDefaultMessage("mytopic", "test message"))) + wg.Done() + }() + } + wg.Wait() + }) } -func testCachePrune(t *testing.T, s message.Store) { - now := time.Now().Unix() +func TestStore_MessagesScheduled(t *testing.T) { + forEachBackend(t, func(t *testing.T, s message.Store) { + m1 := model.NewDefaultMessage("mytopic", "message 1") + m2 := model.NewDefaultMessage("mytopic", "message 2") + m2.Time = time.Now().Add(time.Hour).Unix() + m3 := model.NewDefaultMessage("mytopic", "message 3") + m3.Time = time.Now().Add(time.Minute).Unix() // earlier than m2! + m4 := model.NewDefaultMessage("mytopic2", "message 4") + m4.Time = time.Now().Add(time.Minute).Unix() + require.Nil(t, s.AddMessage(m1)) + require.Nil(t, s.AddMessage(m2)) + require.Nil(t, s.AddMessage(m3)) - m1 := model.NewDefaultMessage("mytopic", "my message") - m1.Time = now - 10 - m1.Expires = now - 5 + messages, _ := s.Messages("mytopic", model.SinceAllMessages, false) // exclude scheduled + require.Equal(t, 1, len(messages)) + require.Equal(t, "message 1", messages[0].Message) - m2 := model.NewDefaultMessage("mytopic", "my other message") - m2.Time = now - 5 - m2.Expires = now + 5 // In the future + messages, _ = s.Messages("mytopic", model.SinceAllMessages, true) // include scheduled + require.Equal(t, 3, len(messages)) + require.Equal(t, "message 1", messages[0].Message) + require.Equal(t, "message 3", messages[1].Message) // Order! + require.Equal(t, "message 2", messages[2].Message) - m3 := model.NewDefaultMessage("another_topic", "and another one") - m3.Time = now - 12 - m3.Expires = now - 2 - - require.Nil(t, s.AddMessage(m1)) - require.Nil(t, s.AddMessage(m2)) - require.Nil(t, s.AddMessage(m3)) - - counts, err := s.MessageCounts() - require.Nil(t, err) - require.Equal(t, 2, counts["mytopic"]) - require.Equal(t, 1, counts["another_topic"]) - - expiredMessageIDs, err := s.MessagesExpired() - require.Nil(t, err) - require.Nil(t, s.DeleteMessages(expiredMessageIDs...)) - - counts, err = s.MessageCounts() - require.Nil(t, err) - require.Equal(t, 1, counts["mytopic"]) - require.Equal(t, 0, counts["another_topic"]) - - messages, err := s.Messages("mytopic", model.SinceAllMessages, false) - require.Nil(t, err) - require.Equal(t, 1, len(messages)) - require.Equal(t, "my other message", messages[0].Message) + messages, _ = s.MessagesDue() + require.Empty(t, messages) + }) } -func testCacheAttachments(t *testing.T, s message.Store) { - expires1 := time.Now().Add(-4 * time.Hour).Unix() // Expired - m := model.NewDefaultMessage("mytopic", "flower for you") - m.ID = "m1" - m.SequenceID = "m1" - m.Sender = netip.MustParseAddr("1.2.3.4") - m.Attachment = &model.Attachment{ - Name: "flower.jpg", - Type: "image/jpeg", - Size: 5000, - Expires: expires1, - URL: "https://ntfy.sh/file/AbDeFgJhal.jpg", - } - require.Nil(t, s.AddMessage(m)) +func TestStore_Topics(t *testing.T) { + forEachBackend(t, func(t *testing.T, s message.Store) { + require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic1", "my example message"))) + require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 1"))) + require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 2"))) + require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 3"))) - expires2 := time.Now().Add(2 * time.Hour).Unix() // Future - m = model.NewDefaultMessage("mytopic", "sending you a car") - m.ID = "m2" - m.SequenceID = "m2" - m.Sender = netip.MustParseAddr("1.2.3.4") - m.Attachment = &model.Attachment{ - Name: "car.jpg", - Type: "image/jpeg", - Size: 10000, - Expires: expires2, - URL: "https://ntfy.sh/file/aCaRURL.jpg", - } - require.Nil(t, s.AddMessage(m)) - - expires3 := time.Now().Add(1 * time.Hour).Unix() // Future - m = model.NewDefaultMessage("another-topic", "sending you another car") - m.ID = "m3" - m.SequenceID = "m3" - m.User = "u_BAsbaAa" - m.Sender = netip.MustParseAddr("5.6.7.8") - m.Attachment = &model.Attachment{ - Name: "another-car.jpg", - Type: "image/jpeg", - Size: 20000, - Expires: expires3, - URL: "https://ntfy.sh/file/zakaDHFW.jpg", - } - require.Nil(t, s.AddMessage(m)) - - messages, err := s.Messages("mytopic", model.SinceAllMessages, false) - require.Nil(t, err) - require.Equal(t, 2, len(messages)) - - require.Equal(t, "flower for you", messages[0].Message) - require.Equal(t, "flower.jpg", messages[0].Attachment.Name) - require.Equal(t, "image/jpeg", messages[0].Attachment.Type) - require.Equal(t, int64(5000), messages[0].Attachment.Size) - require.Equal(t, expires1, messages[0].Attachment.Expires) - require.Equal(t, "https://ntfy.sh/file/AbDeFgJhal.jpg", messages[0].Attachment.URL) - require.Equal(t, "1.2.3.4", messages[0].Sender.String()) - - require.Equal(t, "sending you a car", messages[1].Message) - require.Equal(t, "car.jpg", messages[1].Attachment.Name) - require.Equal(t, "image/jpeg", messages[1].Attachment.Type) - require.Equal(t, int64(10000), messages[1].Attachment.Size) - require.Equal(t, expires2, messages[1].Attachment.Expires) - require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL) - require.Equal(t, "1.2.3.4", messages[1].Sender.String()) - - size, err := s.AttachmentBytesUsedBySender("1.2.3.4") - require.Nil(t, err) - require.Equal(t, int64(10000), size) - - size, err = s.AttachmentBytesUsedBySender("5.6.7.8") - require.Nil(t, err) - require.Equal(t, int64(0), size) // Accounted to the user, not the IP! - - size, err = s.AttachmentBytesUsedByUser("u_BAsbaAa") - require.Nil(t, err) - require.Equal(t, int64(20000), size) + topics, err := s.Topics() + if err != nil { + t.Fatal(err) + } + require.Equal(t, 2, len(topics)) + require.Contains(t, topics, "topic1") + require.Contains(t, topics, "topic2") + }) } -func testCacheAttachmentsExpired(t *testing.T, s message.Store) { - m := model.NewDefaultMessage("mytopic", "flower for you") - m.ID = "m1" - m.SequenceID = "m1" - m.Expires = time.Now().Add(time.Hour).Unix() - require.Nil(t, s.AddMessage(m)) +func TestStore_MessagesTagsPrioAndTitle(t *testing.T) { + forEachBackend(t, func(t *testing.T, s message.Store) { + m := model.NewDefaultMessage("mytopic", "some message") + m.Tags = []string{"tag1", "tag2"} + m.Priority = 5 + m.Title = "some title" + require.Nil(t, s.AddMessage(m)) - m = model.NewDefaultMessage("mytopic", "message with attachment") - m.ID = "m2" - m.SequenceID = "m2" - m.Expires = time.Now().Add(2 * time.Hour).Unix() - m.Attachment = &model.Attachment{ - Name: "car.jpg", - Type: "image/jpeg", - Size: 10000, - Expires: time.Now().Add(2 * time.Hour).Unix(), - URL: "https://ntfy.sh/file/aCaRURL.jpg", - } - require.Nil(t, s.AddMessage(m)) - - m = model.NewDefaultMessage("mytopic", "message with external attachment") - m.ID = "m3" - m.SequenceID = "m3" - m.Expires = time.Now().Add(2 * time.Hour).Unix() - m.Attachment = &model.Attachment{ - Name: "car.jpg", - Type: "image/jpeg", - Expires: 0, // Unknown! - URL: "https://somedomain.com/car.jpg", - } - require.Nil(t, s.AddMessage(m)) - - m = model.NewDefaultMessage("mytopic2", "message with expired attachment") - m.ID = "m4" - m.SequenceID = "m4" - m.Expires = time.Now().Add(2 * time.Hour).Unix() - m.Attachment = &model.Attachment{ - Name: "expired-car.jpg", - Type: "image/jpeg", - Size: 20000, - Expires: time.Now().Add(-1 * time.Hour).Unix(), - URL: "https://ntfy.sh/file/aCaRURL.jpg", - } - require.Nil(t, s.AddMessage(m)) - - ids, err := s.AttachmentsExpired() - require.Nil(t, err) - require.Equal(t, 1, len(ids)) - require.Equal(t, "m4", ids[0]) + messages, _ := s.Messages("mytopic", model.SinceAllMessages, false) + require.Equal(t, []string{"tag1", "tag2"}, messages[0].Tags) + require.Equal(t, 5, messages[0].Priority) + require.Equal(t, "some title", messages[0].Title) + }) } -func testSender(t *testing.T, s message.Store) { - m1 := model.NewDefaultMessage("mytopic", "mymessage") - m1.Sender = netip.MustParseAddr("1.2.3.4") - require.Nil(t, s.AddMessage(m1)) +func TestStore_MessagesSinceID(t *testing.T) { + forEachBackend(t, func(t *testing.T, s message.Store) { + m1 := model.NewDefaultMessage("mytopic", "message 1") + m1.Time = 100 + m2 := model.NewDefaultMessage("mytopic", "message 2") + m2.Time = 200 + m3 := model.NewDefaultMessage("mytopic", "message 3") + m3.Time = time.Now().Add(time.Hour).Unix() // Scheduled, in the future, later than m7 and m5 + m4 := model.NewDefaultMessage("mytopic", "message 4") + m4.Time = 400 + m5 := model.NewDefaultMessage("mytopic", "message 5") + m5.Time = time.Now().Add(time.Minute).Unix() // Scheduled, in the future, later than m7 + m6 := model.NewDefaultMessage("mytopic", "message 6") + m6.Time = 600 + m7 := model.NewDefaultMessage("mytopic", "message 7") + m7.Time = 700 - m2 := model.NewDefaultMessage("mytopic", "mymessage without sender") - require.Nil(t, s.AddMessage(m2)) + require.Nil(t, s.AddMessage(m1)) + require.Nil(t, s.AddMessage(m2)) + require.Nil(t, s.AddMessage(m3)) + require.Nil(t, s.AddMessage(m4)) + require.Nil(t, s.AddMessage(m5)) + require.Nil(t, s.AddMessage(m6)) + require.Nil(t, s.AddMessage(m7)) - messages, err := s.Messages("mytopic", model.SinceAllMessages, false) - require.Nil(t, err) - require.Equal(t, 2, len(messages)) - require.Equal(t, messages[0].Sender, netip.MustParseAddr("1.2.3.4")) - require.Equal(t, messages[1].Sender, netip.Addr{}) + // Case 1: Since ID exists, exclude scheduled + messages, _ := s.Messages("mytopic", model.NewSinceID(m2.ID), false) + require.Equal(t, 3, len(messages)) + require.Equal(t, "message 4", messages[0].Message) + require.Equal(t, "message 6", messages[1].Message) // Not scheduled m3/m5! + require.Equal(t, "message 7", messages[2].Message) + + // Case 2: Since ID exists, include scheduled + messages, _ = s.Messages("mytopic", model.NewSinceID(m2.ID), true) + require.Equal(t, 5, len(messages)) + require.Equal(t, "message 4", messages[0].Message) + require.Equal(t, "message 6", messages[1].Message) + require.Equal(t, "message 7", messages[2].Message) + require.Equal(t, "message 5", messages[3].Message) // Order! + require.Equal(t, "message 3", messages[4].Message) // Order! + + // Case 3: Since ID does not exist (-> Return all messages), include scheduled + messages, _ = s.Messages("mytopic", model.NewSinceID("doesntexist"), true) + require.Equal(t, 7, len(messages)) + require.Equal(t, "message 1", messages[0].Message) + require.Equal(t, "message 2", messages[1].Message) + require.Equal(t, "message 4", messages[2].Message) + require.Equal(t, "message 6", messages[3].Message) + require.Equal(t, "message 7", messages[4].Message) + require.Equal(t, "message 5", messages[5].Message) // Order! + require.Equal(t, "message 3", messages[6].Message) // Order! + + // Case 4: Since ID exists and is last message (-> Return no messages), exclude scheduled + messages, _ = s.Messages("mytopic", model.NewSinceID(m7.ID), false) + require.Equal(t, 0, len(messages)) + + // Case 5: Since ID exists and is last message (-> Return no messages), include scheduled + messages, _ = s.Messages("mytopic", model.NewSinceID(m7.ID), true) + require.Equal(t, 2, len(messages)) + require.Equal(t, "message 5", messages[0].Message) + require.Equal(t, "message 3", messages[1].Message) + }) } -func testDeleteScheduledBySequenceID(t *testing.T, s message.Store) { - // Create a scheduled (unpublished) message - scheduledMsg := model.NewDefaultMessage("mytopic", "scheduled message") - scheduledMsg.ID = "scheduled1" - scheduledMsg.SequenceID = "seq123" - scheduledMsg.Time = time.Now().Add(time.Hour).Unix() // Future time makes it scheduled - require.Nil(t, s.AddMessage(scheduledMsg)) +func TestStore_Prune(t *testing.T) { + forEachBackend(t, func(t *testing.T, s message.Store) { + now := time.Now().Unix() - // Create a published message with different sequence ID - publishedMsg := model.NewDefaultMessage("mytopic", "published message") - publishedMsg.ID = "published1" - publishedMsg.SequenceID = "seq456" - publishedMsg.Time = time.Now().Add(-time.Hour).Unix() // Past time makes it published - require.Nil(t, s.AddMessage(publishedMsg)) + m1 := model.NewDefaultMessage("mytopic", "my message") + m1.Time = now - 10 + m1.Expires = now - 5 - // Create a scheduled message in a different topic - otherTopicMsg := model.NewDefaultMessage("othertopic", "other scheduled") - otherTopicMsg.ID = "other1" - otherTopicMsg.SequenceID = "seq123" // Same sequence ID as scheduledMsg - otherTopicMsg.Time = time.Now().Add(time.Hour).Unix() - require.Nil(t, s.AddMessage(otherTopicMsg)) + m2 := model.NewDefaultMessage("mytopic", "my other message") + m2.Time = now - 5 + m2.Expires = now + 5 // In the future - // Verify all messages exist (including scheduled) - messages, err := s.Messages("mytopic", model.SinceAllMessages, true) - require.Nil(t, err) - require.Equal(t, 2, len(messages)) + m3 := model.NewDefaultMessage("another_topic", "and another one") + m3.Time = now - 12 + m3.Expires = now - 2 - messages, err = s.Messages("othertopic", model.SinceAllMessages, true) - require.Nil(t, err) - require.Equal(t, 1, len(messages)) + require.Nil(t, s.AddMessage(m1)) + require.Nil(t, s.AddMessage(m2)) + require.Nil(t, s.AddMessage(m3)) - // Delete scheduled message by sequence ID and verify returned IDs - deletedIDs, err := s.DeleteScheduledBySequenceID("mytopic", "seq123") - require.Nil(t, err) - require.Equal(t, 1, len(deletedIDs)) - require.Equal(t, "scheduled1", deletedIDs[0]) + counts, err := s.MessageCounts() + require.Nil(t, err) + require.Equal(t, 2, counts["mytopic"]) + require.Equal(t, 1, counts["another_topic"]) - // Verify scheduled message is deleted - messages, err = s.Messages("mytopic", model.SinceAllMessages, true) - require.Nil(t, err) - require.Equal(t, 1, len(messages)) - require.Equal(t, "published message", messages[0].Message) + expiredMessageIDs, err := s.MessagesExpired() + require.Nil(t, err) + require.Nil(t, s.DeleteMessages(expiredMessageIDs...)) - // Verify other topic's message still exists (topic-scoped deletion) - messages, err = s.Messages("othertopic", model.SinceAllMessages, true) - require.Nil(t, err) - require.Equal(t, 1, len(messages)) - require.Equal(t, "other scheduled", messages[0].Message) + counts, err = s.MessageCounts() + require.Nil(t, err) + require.Equal(t, 1, counts["mytopic"]) + require.Equal(t, 0, counts["another_topic"]) - // Deleting non-existent sequence ID should return empty list - deletedIDs, err = s.DeleteScheduledBySequenceID("mytopic", "nonexistent") - require.Nil(t, err) - require.Empty(t, deletedIDs) - - // Deleting published message should not affect it (only deletes unpublished) - deletedIDs, err = s.DeleteScheduledBySequenceID("mytopic", "seq456") - require.Nil(t, err) - require.Empty(t, deletedIDs) - - messages, err = s.Messages("mytopic", model.SinceAllMessages, true) - require.Nil(t, err) - require.Equal(t, 1, len(messages)) - require.Equal(t, "published message", messages[0].Message) + messages, err := s.Messages("mytopic", model.SinceAllMessages, false) + require.Nil(t, err) + require.Equal(t, 1, len(messages)) + require.Equal(t, "my other message", messages[0].Message) + }) } -func testMessageByID(t *testing.T, s message.Store) { - // Add a message - m := model.NewDefaultMessage("mytopic", "some message") - m.Title = "some title" - m.Priority = 4 - m.Tags = []string{"tag1", "tag2"} - require.Nil(t, s.AddMessage(m)) +func TestStore_Attachments(t *testing.T) { + forEachBackend(t, func(t *testing.T, s message.Store) { + expires1 := time.Now().Add(-4 * time.Hour).Unix() // Expired + m := model.NewDefaultMessage("mytopic", "flower for you") + m.ID = "m1" + m.SequenceID = "m1" + m.Sender = netip.MustParseAddr("1.2.3.4") + m.Attachment = &model.Attachment{ + Name: "flower.jpg", + Type: "image/jpeg", + Size: 5000, + Expires: expires1, + URL: "https://ntfy.sh/file/AbDeFgJhal.jpg", + } + require.Nil(t, s.AddMessage(m)) - // Retrieve by ID - retrieved, err := s.Message(m.ID) - require.Nil(t, err) - require.Equal(t, m.ID, retrieved.ID) - require.Equal(t, "mytopic", retrieved.Topic) - require.Equal(t, "some message", retrieved.Message) - require.Equal(t, "some title", retrieved.Title) - require.Equal(t, 4, retrieved.Priority) - require.Equal(t, []string{"tag1", "tag2"}, retrieved.Tags) + expires2 := time.Now().Add(2 * time.Hour).Unix() // Future + m = model.NewDefaultMessage("mytopic", "sending you a car") + m.ID = "m2" + m.SequenceID = "m2" + m.Sender = netip.MustParseAddr("1.2.3.4") + m.Attachment = &model.Attachment{ + Name: "car.jpg", + Type: "image/jpeg", + Size: 10000, + Expires: expires2, + URL: "https://ntfy.sh/file/aCaRURL.jpg", + } + require.Nil(t, s.AddMessage(m)) - // Non-existent ID returns ErrMessageNotFound - _, err = s.Message("doesnotexist") - require.Equal(t, model.ErrMessageNotFound, err) + expires3 := time.Now().Add(1 * time.Hour).Unix() // Future + m = model.NewDefaultMessage("another-topic", "sending you another car") + m.ID = "m3" + m.SequenceID = "m3" + m.User = "u_BAsbaAa" + m.Sender = netip.MustParseAddr("5.6.7.8") + m.Attachment = &model.Attachment{ + Name: "another-car.jpg", + Type: "image/jpeg", + Size: 20000, + Expires: expires3, + URL: "https://ntfy.sh/file/zakaDHFW.jpg", + } + require.Nil(t, s.AddMessage(m)) + + messages, err := s.Messages("mytopic", model.SinceAllMessages, false) + require.Nil(t, err) + require.Equal(t, 2, len(messages)) + + require.Equal(t, "flower for you", messages[0].Message) + require.Equal(t, "flower.jpg", messages[0].Attachment.Name) + require.Equal(t, "image/jpeg", messages[0].Attachment.Type) + require.Equal(t, int64(5000), messages[0].Attachment.Size) + require.Equal(t, expires1, messages[0].Attachment.Expires) + require.Equal(t, "https://ntfy.sh/file/AbDeFgJhal.jpg", messages[0].Attachment.URL) + require.Equal(t, "1.2.3.4", messages[0].Sender.String()) + + require.Equal(t, "sending you a car", messages[1].Message) + require.Equal(t, "car.jpg", messages[1].Attachment.Name) + require.Equal(t, "image/jpeg", messages[1].Attachment.Type) + require.Equal(t, int64(10000), messages[1].Attachment.Size) + require.Equal(t, expires2, messages[1].Attachment.Expires) + require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL) + require.Equal(t, "1.2.3.4", messages[1].Sender.String()) + + size, err := s.AttachmentBytesUsedBySender("1.2.3.4") + require.Nil(t, err) + require.Equal(t, int64(10000), size) + + size, err = s.AttachmentBytesUsedBySender("5.6.7.8") + require.Nil(t, err) + require.Equal(t, int64(0), size) // Accounted to the user, not the IP! + + size, err = s.AttachmentBytesUsedByUser("u_BAsbaAa") + require.Nil(t, err) + require.Equal(t, int64(20000), size) + }) } -func testMarkPublished(t *testing.T, s message.Store) { - // Add a scheduled message (future time → unpublished) - m := model.NewDefaultMessage("mytopic", "scheduled message") - m.Time = time.Now().Add(time.Hour).Unix() - require.Nil(t, s.AddMessage(m)) +func TestStore_AttachmentsExpired(t *testing.T) { + forEachBackend(t, func(t *testing.T, s message.Store) { + m := model.NewDefaultMessage("mytopic", "flower for you") + m.ID = "m1" + m.SequenceID = "m1" + m.Expires = time.Now().Add(time.Hour).Unix() + require.Nil(t, s.AddMessage(m)) - // Verify it does not appear in non-scheduled queries - messages, err := s.Messages("mytopic", model.SinceAllMessages, false) - require.Nil(t, err) - require.Equal(t, 0, len(messages)) + m = model.NewDefaultMessage("mytopic", "message with attachment") + m.ID = "m2" + m.SequenceID = "m2" + m.Expires = time.Now().Add(2 * time.Hour).Unix() + m.Attachment = &model.Attachment{ + Name: "car.jpg", + Type: "image/jpeg", + Size: 10000, + Expires: time.Now().Add(2 * time.Hour).Unix(), + URL: "https://ntfy.sh/file/aCaRURL.jpg", + } + require.Nil(t, s.AddMessage(m)) - // Verify it does appear in scheduled queries - messages, err = s.Messages("mytopic", model.SinceAllMessages, true) - require.Nil(t, err) - require.Equal(t, 1, len(messages)) + m = model.NewDefaultMessage("mytopic", "message with external attachment") + m.ID = "m3" + m.SequenceID = "m3" + m.Expires = time.Now().Add(2 * time.Hour).Unix() + m.Attachment = &model.Attachment{ + Name: "car.jpg", + Type: "image/jpeg", + Expires: 0, // Unknown! + URL: "https://somedomain.com/car.jpg", + } + require.Nil(t, s.AddMessage(m)) - // Mark as published - require.Nil(t, s.MarkPublished(m)) + m = model.NewDefaultMessage("mytopic2", "message with expired attachment") + m.ID = "m4" + m.SequenceID = "m4" + m.Expires = time.Now().Add(2 * time.Hour).Unix() + m.Attachment = &model.Attachment{ + Name: "expired-car.jpg", + Type: "image/jpeg", + Size: 20000, + Expires: time.Now().Add(-1 * time.Hour).Unix(), + URL: "https://ntfy.sh/file/aCaRURL.jpg", + } + require.Nil(t, s.AddMessage(m)) - // Now it should appear in non-scheduled queries too - messages, err = s.Messages("mytopic", model.SinceAllMessages, false) - require.Nil(t, err) - require.Equal(t, 1, len(messages)) - require.Equal(t, "scheduled message", messages[0].Message) + ids, err := s.AttachmentsExpired() + require.Nil(t, err) + require.Equal(t, 1, len(ids)) + require.Equal(t, "m4", ids[0]) + }) } -func testExpireMessages(t *testing.T, s message.Store) { - // Add messages to two topics - m1 := model.NewDefaultMessage("topic1", "message 1") - m1.Expires = time.Now().Add(time.Hour).Unix() - m2 := model.NewDefaultMessage("topic1", "message 2") - m2.Expires = time.Now().Add(time.Hour).Unix() - m3 := model.NewDefaultMessage("topic2", "message 3") - m3.Expires = time.Now().Add(time.Hour).Unix() - require.Nil(t, s.AddMessage(m1)) - require.Nil(t, s.AddMessage(m2)) - require.Nil(t, s.AddMessage(m3)) +func TestStore_Sender(t *testing.T) { + forEachBackend(t, func(t *testing.T, s message.Store) { + m1 := model.NewDefaultMessage("mytopic", "mymessage") + m1.Sender = netip.MustParseAddr("1.2.3.4") + require.Nil(t, s.AddMessage(m1)) - // Verify all messages exist - messages, err := s.Messages("topic1", model.SinceAllMessages, false) - require.Nil(t, err) - require.Equal(t, 2, len(messages)) - messages, err = s.Messages("topic2", model.SinceAllMessages, false) - require.Nil(t, err) - require.Equal(t, 1, len(messages)) + m2 := model.NewDefaultMessage("mytopic", "mymessage without sender") + require.Nil(t, s.AddMessage(m2)) - // Expire topic1 messages - require.Nil(t, s.ExpireMessages("topic1")) - - // topic1 messages should now be expired (expires set to past) - expiredIDs, err := s.MessagesExpired() - require.Nil(t, err) - require.Equal(t, 2, len(expiredIDs)) - sort.Strings(expiredIDs) - expectedIDs := []string{m1.ID, m2.ID} - sort.Strings(expectedIDs) - require.Equal(t, expectedIDs, expiredIDs) - - // topic2 should be unaffected - messages, err = s.Messages("topic2", model.SinceAllMessages, false) - require.Nil(t, err) - require.Equal(t, 1, len(messages)) - require.Equal(t, "message 3", messages[0].Message) + messages, err := s.Messages("mytopic", model.SinceAllMessages, false) + require.Nil(t, err) + require.Equal(t, 2, len(messages)) + require.Equal(t, messages[0].Sender, netip.MustParseAddr("1.2.3.4")) + require.Equal(t, messages[1].Sender, netip.Addr{}) + }) } -func testMarkAttachmentsDeleted(t *testing.T, s message.Store) { - // Add a message with an expired attachment (file needs cleanup) - m1 := model.NewDefaultMessage("mytopic", "old file") - m1.ID = "msg1" - m1.SequenceID = "msg1" - m1.Expires = time.Now().Add(time.Hour).Unix() - m1.Attachment = &model.Attachment{ - Name: "old.pdf", - Type: "application/pdf", - Size: 50000, - Expires: time.Now().Add(-time.Hour).Unix(), // Expired - URL: "https://ntfy.sh/file/old.pdf", - } - require.Nil(t, s.AddMessage(m1)) +func TestStore_DeleteScheduledBySequenceID(t *testing.T) { + forEachBackend(t, func(t *testing.T, s message.Store) { + // Create a scheduled (unpublished) message + scheduledMsg := model.NewDefaultMessage("mytopic", "scheduled message") + scheduledMsg.ID = "scheduled1" + scheduledMsg.SequenceID = "seq123" + scheduledMsg.Time = time.Now().Add(time.Hour).Unix() // Future time makes it scheduled + require.Nil(t, s.AddMessage(scheduledMsg)) - // Add a message with another expired attachment - m2 := model.NewDefaultMessage("mytopic", "another old file") - m2.ID = "msg2" - m2.SequenceID = "msg2" - m2.Expires = time.Now().Add(time.Hour).Unix() - m2.Attachment = &model.Attachment{ - Name: "another.pdf", - Type: "application/pdf", - Size: 30000, - Expires: time.Now().Add(-time.Hour).Unix(), // Expired - URL: "https://ntfy.sh/file/another.pdf", - } - require.Nil(t, s.AddMessage(m2)) + // Create a published message with different sequence ID + publishedMsg := model.NewDefaultMessage("mytopic", "published message") + publishedMsg.ID = "published1" + publishedMsg.SequenceID = "seq456" + publishedMsg.Time = time.Now().Add(-time.Hour).Unix() // Past time makes it published + require.Nil(t, s.AddMessage(publishedMsg)) - // Both should show as expired attachments needing cleanup - ids, err := s.AttachmentsExpired() - require.Nil(t, err) - require.Equal(t, 2, len(ids)) + // Create a scheduled message in a different topic + otherTopicMsg := model.NewDefaultMessage("othertopic", "other scheduled") + otherTopicMsg.ID = "other1" + otherTopicMsg.SequenceID = "seq123" // Same sequence ID as scheduledMsg + otherTopicMsg.Time = time.Now().Add(time.Hour).Unix() + require.Nil(t, s.AddMessage(otherTopicMsg)) - // Mark msg1's attachment as deleted (file cleaned up) - require.Nil(t, s.MarkAttachmentsDeleted("msg1")) + // Verify all messages exist (including scheduled) + messages, err := s.Messages("mytopic", model.SinceAllMessages, true) + require.Nil(t, err) + require.Equal(t, 2, len(messages)) - // Now only msg2 should show as needing cleanup - ids, err = s.AttachmentsExpired() - require.Nil(t, err) - require.Equal(t, 1, len(ids)) - require.Equal(t, "msg2", ids[0]) + messages, err = s.Messages("othertopic", model.SinceAllMessages, true) + require.Nil(t, err) + require.Equal(t, 1, len(messages)) - // Mark msg2 too - require.Nil(t, s.MarkAttachmentsDeleted("msg2")) + // Delete scheduled message by sequence ID and verify returned IDs + deletedIDs, err := s.DeleteScheduledBySequenceID("mytopic", "seq123") + require.Nil(t, err) + require.Equal(t, 1, len(deletedIDs)) + require.Equal(t, "scheduled1", deletedIDs[0]) - // No more expired attachments to clean up - ids, err = s.AttachmentsExpired() - require.Nil(t, err) - require.Equal(t, 0, len(ids)) + // Verify scheduled message is deleted + messages, err = s.Messages("mytopic", model.SinceAllMessages, true) + require.Nil(t, err) + require.Equal(t, 1, len(messages)) + require.Equal(t, "published message", messages[0].Message) - // Messages themselves still exist - messages, err := s.Messages("mytopic", model.SinceAllMessages, false) - require.Nil(t, err) - require.Equal(t, 2, len(messages)) + // Verify other topic's message still exists (topic-scoped deletion) + messages, err = s.Messages("othertopic", model.SinceAllMessages, true) + require.Nil(t, err) + require.Equal(t, 1, len(messages)) + require.Equal(t, "other scheduled", messages[0].Message) + + // Deleting non-existent sequence ID should return empty list + deletedIDs, err = s.DeleteScheduledBySequenceID("mytopic", "nonexistent") + require.Nil(t, err) + require.Empty(t, deletedIDs) + + // Deleting published message should not affect it (only deletes unpublished) + deletedIDs, err = s.DeleteScheduledBySequenceID("mytopic", "seq456") + require.Nil(t, err) + require.Empty(t, deletedIDs) + + messages, err = s.Messages("mytopic", model.SinceAllMessages, true) + require.Nil(t, err) + require.Equal(t, 1, len(messages)) + require.Equal(t, "published message", messages[0].Message) + }) } -func testStats(t *testing.T, s message.Store) { - // Initial stats should be zero - messages, err := s.Stats() - require.Nil(t, err) - require.Equal(t, int64(0), messages) +func TestStore_MessageByID(t *testing.T) { + forEachBackend(t, func(t *testing.T, s message.Store) { + // Add a message + m := model.NewDefaultMessage("mytopic", "some message") + m.Title = "some title" + m.Priority = 4 + m.Tags = []string{"tag1", "tag2"} + require.Nil(t, s.AddMessage(m)) - // Update stats - require.Nil(t, s.UpdateStats(42)) - messages, err = s.Stats() - require.Nil(t, err) - require.Equal(t, int64(42), messages) + // Retrieve by ID + retrieved, err := s.Message(m.ID) + require.Nil(t, err) + require.Equal(t, m.ID, retrieved.ID) + require.Equal(t, "mytopic", retrieved.Topic) + require.Equal(t, "some message", retrieved.Message) + require.Equal(t, "some title", retrieved.Title) + require.Equal(t, 4, retrieved.Priority) + require.Equal(t, []string{"tag1", "tag2"}, retrieved.Tags) - // Update again (overwrites) - require.Nil(t, s.UpdateStats(100)) - messages, err = s.Stats() - require.Nil(t, err) - require.Equal(t, int64(100), messages) + // Non-existent ID returns ErrMessageNotFound + _, err = s.Message("doesnotexist") + require.Equal(t, model.ErrMessageNotFound, err) + }) } -func testAddMessages(t *testing.T, s message.Store) { - // Batch add multiple messages - msgs := []*model.Message{ - model.NewDefaultMessage("mytopic", "batch 1"), - model.NewDefaultMessage("mytopic", "batch 2"), - model.NewDefaultMessage("othertopic", "batch 3"), - } - require.Nil(t, s.AddMessages(msgs)) +func TestStore_MarkPublished(t *testing.T) { + forEachBackend(t, func(t *testing.T, s message.Store) { + // Add a scheduled message (future time -> unpublished) + m := model.NewDefaultMessage("mytopic", "scheduled message") + m.Time = time.Now().Add(time.Hour).Unix() + require.Nil(t, s.AddMessage(m)) - // Verify all were inserted - messages, err := s.Messages("mytopic", model.SinceAllMessages, false) - require.Nil(t, err) - require.Equal(t, 2, len(messages)) + // Verify it does not appear in non-scheduled queries + messages, err := s.Messages("mytopic", model.SinceAllMessages, false) + require.Nil(t, err) + require.Equal(t, 0, len(messages)) - messages, err = s.Messages("othertopic", model.SinceAllMessages, false) - require.Nil(t, err) - require.Equal(t, 1, len(messages)) - require.Equal(t, "batch 3", messages[0].Message) + // Verify it does appear in scheduled queries + messages, err = s.Messages("mytopic", model.SinceAllMessages, true) + require.Nil(t, err) + require.Equal(t, 1, len(messages)) - // Empty batch should succeed - require.Nil(t, s.AddMessages([]*model.Message{})) + // Mark as published + require.Nil(t, s.MarkPublished(m)) - // Batch with invalid event type should fail - badMsgs := []*model.Message{ - model.NewKeepaliveMessage("mytopic"), - } - require.NotNil(t, s.AddMessages(badMsgs)) + // Now it should appear in non-scheduled queries too + messages, err = s.Messages("mytopic", model.SinceAllMessages, false) + require.Nil(t, err) + require.Equal(t, 1, len(messages)) + require.Equal(t, "scheduled message", messages[0].Message) + }) } -func testMessagesDue(t *testing.T, s message.Store) { - // Add a message scheduled in the past (i.e. it's due now) - m1 := model.NewDefaultMessage("mytopic", "due message") - m1.Time = time.Now().Add(-time.Second).Unix() - // Set expires in the future so it doesn't get pruned - m1.Expires = time.Now().Add(time.Hour).Unix() - require.Nil(t, s.AddMessage(m1)) +func TestStore_ExpireMessages(t *testing.T) { + forEachBackend(t, func(t *testing.T, s message.Store) { + // Add messages to two topics + m1 := model.NewDefaultMessage("topic1", "message 1") + m1.Expires = time.Now().Add(time.Hour).Unix() + m2 := model.NewDefaultMessage("topic1", "message 2") + m2.Expires = time.Now().Add(time.Hour).Unix() + m3 := model.NewDefaultMessage("topic2", "message 3") + m3.Expires = time.Now().Add(time.Hour).Unix() + require.Nil(t, s.AddMessage(m1)) + require.Nil(t, s.AddMessage(m2)) + require.Nil(t, s.AddMessage(m3)) - // Add a message scheduled in the future (not due) - m2 := model.NewDefaultMessage("mytopic", "future message") - m2.Time = time.Now().Add(time.Hour).Unix() - require.Nil(t, s.AddMessage(m2)) + // Verify all messages exist + messages, err := s.Messages("topic1", model.SinceAllMessages, false) + require.Nil(t, err) + require.Equal(t, 2, len(messages)) + messages, err = s.Messages("topic2", model.SinceAllMessages, false) + require.Nil(t, err) + require.Equal(t, 1, len(messages)) - // Mark m1 as published so it won't be "due" - // (MessagesDue returns unpublished messages whose time <= now) - // m1 is auto-published (time <= now), so it should not be due - // m2 is unpublished (time in future), not due yet - due, err := s.MessagesDue() - require.Nil(t, err) - require.Equal(t, 0, len(due)) + // Expire topic1 messages + require.Nil(t, s.ExpireMessages("topic1")) - // Add a message that was explicitly scheduled in the past but time has "arrived" - // We need to manipulate the database to create a truly "due" message: - // a message with published=false and time <= now - m3 := model.NewDefaultMessage("mytopic", "truly due message") - m3.Time = time.Now().Add(2 * time.Second).Unix() // 2 seconds from now - require.Nil(t, s.AddMessage(m3)) + // topic1 messages should now be expired (expires set to past) + expiredIDs, err := s.MessagesExpired() + require.Nil(t, err) + require.Equal(t, 2, len(expiredIDs)) + sort.Strings(expiredIDs) + expectedIDs := []string{m1.ID, m2.ID} + sort.Strings(expectedIDs) + require.Equal(t, expectedIDs, expiredIDs) - // Not due yet - due, err = s.MessagesDue() - require.Nil(t, err) - require.Equal(t, 0, len(due)) - - // Wait for it to become due - time.Sleep(3 * time.Second) - - due, err = s.MessagesDue() - require.Nil(t, err) - require.Equal(t, 1, len(due)) - require.Equal(t, "truly due message", due[0].Message) + // topic2 should be unaffected + messages, err = s.Messages("topic2", model.SinceAllMessages, false) + require.Nil(t, err) + require.Equal(t, 1, len(messages)) + require.Equal(t, "message 3", messages[0].Message) + }) } -func testMessageFieldRoundTrip(t *testing.T, s message.Store) { - // Create a message with all fields populated - m := model.NewDefaultMessage("mytopic", "hello world") - m.SequenceID = "custom_seq_id" - m.Title = "A Title" - m.Priority = 4 - m.Tags = []string{"warning", "srv01"} - m.Click = "https://example.com/click" - m.Icon = "https://example.com/icon.png" - m.Actions = []*model.Action{ - { - ID: "action1", - Action: "view", - Label: "Open Site", - URL: "https://example.com", - Clear: true, - }, - { - ID: "action2", - Action: "http", - Label: "Call Webhook", - URL: "https://example.com/hook", - Method: "PUT", - Headers: map[string]string{"X-Token": "secret"}, - Body: `{"key":"value"}`, - }, - } - m.ContentType = "text/markdown" - m.Encoding = "base64" - m.Sender = netip.MustParseAddr("9.8.7.6") - m.User = "u_TestUser123" - require.Nil(t, s.AddMessage(m)) +func TestStore_MarkAttachmentsDeleted(t *testing.T) { + forEachBackend(t, func(t *testing.T, s message.Store) { + // Add a message with an expired attachment (file needs cleanup) + m1 := model.NewDefaultMessage("mytopic", "old file") + m1.ID = "msg1" + m1.SequenceID = "msg1" + m1.Expires = time.Now().Add(time.Hour).Unix() + m1.Attachment = &model.Attachment{ + Name: "old.pdf", + Type: "application/pdf", + Size: 50000, + Expires: time.Now().Add(-time.Hour).Unix(), // Expired + URL: "https://ntfy.sh/file/old.pdf", + } + require.Nil(t, s.AddMessage(m1)) - // Retrieve and verify every field - retrieved, err := s.Message(m.ID) - require.Nil(t, err) + // Add a message with another expired attachment + m2 := model.NewDefaultMessage("mytopic", "another old file") + m2.ID = "msg2" + m2.SequenceID = "msg2" + m2.Expires = time.Now().Add(time.Hour).Unix() + m2.Attachment = &model.Attachment{ + Name: "another.pdf", + Type: "application/pdf", + Size: 30000, + Expires: time.Now().Add(-time.Hour).Unix(), // Expired + URL: "https://ntfy.sh/file/another.pdf", + } + require.Nil(t, s.AddMessage(m2)) - require.Equal(t, m.ID, retrieved.ID) - require.Equal(t, "custom_seq_id", retrieved.SequenceID) - require.Equal(t, m.Time, retrieved.Time) - require.Equal(t, m.Expires, retrieved.Expires) - require.Equal(t, model.MessageEvent, retrieved.Event) - require.Equal(t, "mytopic", retrieved.Topic) - require.Equal(t, "hello world", retrieved.Message) - require.Equal(t, "A Title", retrieved.Title) - require.Equal(t, 4, retrieved.Priority) - require.Equal(t, []string{"warning", "srv01"}, retrieved.Tags) - require.Equal(t, "https://example.com/click", retrieved.Click) - require.Equal(t, "https://example.com/icon.png", retrieved.Icon) - require.Equal(t, "text/markdown", retrieved.ContentType) - require.Equal(t, "base64", retrieved.Encoding) - require.Equal(t, netip.MustParseAddr("9.8.7.6"), retrieved.Sender) - require.Equal(t, "u_TestUser123", retrieved.User) + // Both should show as expired attachments needing cleanup + ids, err := s.AttachmentsExpired() + require.Nil(t, err) + require.Equal(t, 2, len(ids)) - // Verify actions round-trip - require.Equal(t, 2, len(retrieved.Actions)) + // Mark msg1's attachment as deleted (file cleaned up) + require.Nil(t, s.MarkAttachmentsDeleted("msg1")) - require.Equal(t, "action1", retrieved.Actions[0].ID) - require.Equal(t, "view", retrieved.Actions[0].Action) - require.Equal(t, "Open Site", retrieved.Actions[0].Label) - require.Equal(t, "https://example.com", retrieved.Actions[0].URL) - require.Equal(t, true, retrieved.Actions[0].Clear) + // Now only msg2 should show as needing cleanup + ids, err = s.AttachmentsExpired() + require.Nil(t, err) + require.Equal(t, 1, len(ids)) + require.Equal(t, "msg2", ids[0]) - require.Equal(t, "action2", retrieved.Actions[1].ID) - require.Equal(t, "http", retrieved.Actions[1].Action) - require.Equal(t, "Call Webhook", retrieved.Actions[1].Label) - require.Equal(t, "https://example.com/hook", retrieved.Actions[1].URL) - require.Equal(t, "PUT", retrieved.Actions[1].Method) - require.Equal(t, "secret", retrieved.Actions[1].Headers["X-Token"]) - require.Equal(t, `{"key":"value"}`, retrieved.Actions[1].Body) + // Mark msg2 too + require.Nil(t, s.MarkAttachmentsDeleted("msg2")) + + // No more expired attachments to clean up + ids, err = s.AttachmentsExpired() + require.Nil(t, err) + require.Equal(t, 0, len(ids)) + + // Messages themselves still exist + messages, err := s.Messages("mytopic", model.SinceAllMessages, false) + require.Nil(t, err) + require.Equal(t, 2, len(messages)) + }) +} + +func TestStore_Stats(t *testing.T) { + forEachBackend(t, func(t *testing.T, s message.Store) { + // Initial stats should be zero + messages, err := s.Stats() + require.Nil(t, err) + require.Equal(t, int64(0), messages) + + // Update stats + require.Nil(t, s.UpdateStats(42)) + messages, err = s.Stats() + require.Nil(t, err) + require.Equal(t, int64(42), messages) + + // Update again (overwrites) + require.Nil(t, s.UpdateStats(100)) + messages, err = s.Stats() + require.Nil(t, err) + require.Equal(t, int64(100), messages) + }) +} + +func TestStore_AddMessages(t *testing.T) { + forEachBackend(t, func(t *testing.T, s message.Store) { + // Batch add multiple messages + msgs := []*model.Message{ + model.NewDefaultMessage("mytopic", "batch 1"), + model.NewDefaultMessage("mytopic", "batch 2"), + model.NewDefaultMessage("othertopic", "batch 3"), + } + require.Nil(t, s.AddMessages(msgs)) + + // Verify all were inserted + messages, err := s.Messages("mytopic", model.SinceAllMessages, false) + require.Nil(t, err) + require.Equal(t, 2, len(messages)) + + messages, err = s.Messages("othertopic", model.SinceAllMessages, false) + require.Nil(t, err) + require.Equal(t, 1, len(messages)) + require.Equal(t, "batch 3", messages[0].Message) + + // Empty batch should succeed + require.Nil(t, s.AddMessages([]*model.Message{})) + + // Batch with invalid event type should fail + badMsgs := []*model.Message{ + model.NewKeepaliveMessage("mytopic"), + } + require.NotNil(t, s.AddMessages(badMsgs)) + }) +} + +func TestStore_MessagesDue(t *testing.T) { + forEachBackend(t, func(t *testing.T, s message.Store) { + // Add a message scheduled in the past (i.e. it's due now) + m1 := model.NewDefaultMessage("mytopic", "due message") + m1.Time = time.Now().Add(-time.Second).Unix() + // Set expires in the future so it doesn't get pruned + m1.Expires = time.Now().Add(time.Hour).Unix() + require.Nil(t, s.AddMessage(m1)) + + // Add a message scheduled in the future (not due) + m2 := model.NewDefaultMessage("mytopic", "future message") + m2.Time = time.Now().Add(time.Hour).Unix() + require.Nil(t, s.AddMessage(m2)) + + // Mark m1 as published so it won't be "due" + // (MessagesDue returns unpublished messages whose time <= now) + // m1 is auto-published (time <= now), so it should not be due + // m2 is unpublished (time in future), not due yet + due, err := s.MessagesDue() + require.Nil(t, err) + require.Equal(t, 0, len(due)) + + // Add a message that was explicitly scheduled in the past but time has "arrived" + // We need to manipulate the database to create a truly "due" message: + // a message with published=false and time <= now + m3 := model.NewDefaultMessage("mytopic", "truly due message") + m3.Time = time.Now().Add(2 * time.Second).Unix() // 2 seconds from now + require.Nil(t, s.AddMessage(m3)) + + // Not due yet + due, err = s.MessagesDue() + require.Nil(t, err) + require.Equal(t, 0, len(due)) + + // Wait for it to become due + time.Sleep(3 * time.Second) + + due, err = s.MessagesDue() + require.Nil(t, err) + require.Equal(t, 1, len(due)) + require.Equal(t, "truly due message", due[0].Message) + }) +} + +func TestStore_MessageFieldRoundTrip(t *testing.T) { + forEachBackend(t, func(t *testing.T, s message.Store) { + // Create a message with all fields populated + m := model.NewDefaultMessage("mytopic", "hello world") + m.SequenceID = "custom_seq_id" + m.Title = "A Title" + m.Priority = 4 + m.Tags = []string{"warning", "srv01"} + m.Click = "https://example.com/click" + m.Icon = "https://example.com/icon.png" + m.Actions = []*model.Action{ + { + ID: "action1", + Action: "view", + Label: "Open Site", + URL: "https://example.com", + Clear: true, + }, + { + ID: "action2", + Action: "http", + Label: "Call Webhook", + URL: "https://example.com/hook", + Method: "PUT", + Headers: map[string]string{"X-Token": "secret"}, + Body: `{"key":"value"}`, + }, + } + m.ContentType = "text/markdown" + m.Encoding = "base64" + m.Sender = netip.MustParseAddr("9.8.7.6") + m.User = "u_TestUser123" + require.Nil(t, s.AddMessage(m)) + + // Retrieve and verify every field + retrieved, err := s.Message(m.ID) + require.Nil(t, err) + + require.Equal(t, m.ID, retrieved.ID) + require.Equal(t, "custom_seq_id", retrieved.SequenceID) + require.Equal(t, m.Time, retrieved.Time) + require.Equal(t, m.Expires, retrieved.Expires) + require.Equal(t, model.MessageEvent, retrieved.Event) + require.Equal(t, "mytopic", retrieved.Topic) + require.Equal(t, "hello world", retrieved.Message) + require.Equal(t, "A Title", retrieved.Title) + require.Equal(t, 4, retrieved.Priority) + require.Equal(t, []string{"warning", "srv01"}, retrieved.Tags) + require.Equal(t, "https://example.com/click", retrieved.Click) + require.Equal(t, "https://example.com/icon.png", retrieved.Icon) + require.Equal(t, "text/markdown", retrieved.ContentType) + require.Equal(t, "base64", retrieved.Encoding) + require.Equal(t, netip.MustParseAddr("9.8.7.6"), retrieved.Sender) + require.Equal(t, "u_TestUser123", retrieved.User) + + // Verify actions round-trip + require.Equal(t, 2, len(retrieved.Actions)) + + require.Equal(t, "action1", retrieved.Actions[0].ID) + require.Equal(t, "view", retrieved.Actions[0].Action) + require.Equal(t, "Open Site", retrieved.Actions[0].Label) + require.Equal(t, "https://example.com", retrieved.Actions[0].URL) + require.Equal(t, true, retrieved.Actions[0].Clear) + + require.Equal(t, "action2", retrieved.Actions[1].ID) + require.Equal(t, "http", retrieved.Actions[1].Action) + require.Equal(t, "Call Webhook", retrieved.Actions[1].Label) + require.Equal(t, "https://example.com/hook", retrieved.Actions[1].URL) + require.Equal(t, "PUT", retrieved.Actions[1].Method) + require.Equal(t, "secret", retrieved.Actions[1].Headers["X-Token"]) + require.Equal(t, `{"key":"value"}`, retrieved.Actions[1].Body) + }) } diff --git a/user/store_postgres_test.go b/user/store_postgres_test.go deleted file mode 100644 index e3902ea2..00000000 --- a/user/store_postgres_test.go +++ /dev/null @@ -1,180 +0,0 @@ -package user_test - -import ( - "testing" - - "github.com/stretchr/testify/require" - dbtest "heckel.io/ntfy/v2/db/test" - "heckel.io/ntfy/v2/user" -) - -func newTestPostgresStore(t *testing.T) user.Store { - testDB := dbtest.CreateTestPostgres(t) - store, err := user.NewPostgresStore(testDB) - require.Nil(t, err) - return store -} - -func TestPostgresStoreAddUser(t *testing.T) { - testStoreAddUser(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreAddUserAlreadyExists(t *testing.T) { - testStoreAddUserAlreadyExists(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreRemoveUser(t *testing.T) { - testStoreRemoveUser(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreUserByID(t *testing.T) { - testStoreUserByID(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreUserByToken(t *testing.T) { - testStoreUserByToken(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreUserByStripeCustomer(t *testing.T) { - testStoreUserByStripeCustomer(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreUsers(t *testing.T) { - testStoreUsers(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreUsersCount(t *testing.T) { - testStoreUsersCount(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreChangePassword(t *testing.T) { - testStoreChangePassword(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreChangeRole(t *testing.T) { - testStoreChangeRole(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreTokens(t *testing.T) { - testStoreTokens(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreTokenChangeLabel(t *testing.T) { - testStoreTokenChangeLabel(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreTokenRemove(t *testing.T) { - testStoreTokenRemove(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreTokenRemoveExpired(t *testing.T) { - testStoreTokenRemoveExpired(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreTokenRemoveExcess(t *testing.T) { - testStoreTokenRemoveExcess(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreTokenUpdateLastAccess(t *testing.T) { - testStoreTokenUpdateLastAccess(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreAllowAccess(t *testing.T) { - testStoreAllowAccess(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreAllowAccessReadOnly(t *testing.T) { - testStoreAllowAccessReadOnly(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreResetAccess(t *testing.T) { - testStoreResetAccess(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreResetAccessAll(t *testing.T) { - testStoreResetAccessAll(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreAuthorizeTopicAccess(t *testing.T) { - testStoreAuthorizeTopicAccess(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreAuthorizeTopicAccessNotFound(t *testing.T) { - testStoreAuthorizeTopicAccessNotFound(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreAuthorizeTopicAccessDenyAll(t *testing.T) { - testStoreAuthorizeTopicAccessDenyAll(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreReservations(t *testing.T) { - testStoreReservations(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreReservationsCount(t *testing.T) { - testStoreReservationsCount(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreHasReservation(t *testing.T) { - testStoreHasReservation(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreReservationOwner(t *testing.T) { - testStoreReservationOwner(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreTiers(t *testing.T) { - testStoreTiers(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreTierUpdate(t *testing.T) { - testStoreTierUpdate(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreTierRemove(t *testing.T) { - testStoreTierRemove(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreTierByStripePrice(t *testing.T) { - testStoreTierByStripePrice(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreChangeTier(t *testing.T) { - testStoreChangeTier(t, newTestPostgresStore(t)) -} - -func TestPostgresStorePhoneNumbers(t *testing.T) { - testStorePhoneNumbers(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreChangeSettings(t *testing.T) { - testStoreChangeSettings(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreChangeBilling(t *testing.T) { - testStoreChangeBilling(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreUpdateStats(t *testing.T) { - testStoreUpdateStats(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreResetStats(t *testing.T) { - testStoreResetStats(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreMarkUserRemoved(t *testing.T) { - testStoreMarkUserRemoved(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreRemoveDeletedUsers(t *testing.T) { - testStoreRemoveDeletedUsers(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreAllGrants(t *testing.T) { - testStoreAllGrants(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreOtherAccessCount(t *testing.T) { - testStoreOtherAccessCount(t, newTestPostgresStore(t)) -} diff --git a/user/store_sqlite_test.go b/user/store_sqlite_test.go deleted file mode 100644 index 9f9e5a71..00000000 --- a/user/store_sqlite_test.go +++ /dev/null @@ -1,180 +0,0 @@ -package user_test - -import ( - "path/filepath" - "testing" - - "github.com/stretchr/testify/require" - "heckel.io/ntfy/v2/user" -) - -func newTestSQLiteStore(t *testing.T) user.Store { - store, err := user.NewSQLiteStore(filepath.Join(t.TempDir(), "user.db"), "") - require.Nil(t, err) - t.Cleanup(func() { store.Close() }) - return store -} - -func TestSQLiteStoreAddUser(t *testing.T) { - testStoreAddUser(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreAddUserAlreadyExists(t *testing.T) { - testStoreAddUserAlreadyExists(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreRemoveUser(t *testing.T) { - testStoreRemoveUser(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreUserByID(t *testing.T) { - testStoreUserByID(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreUserByToken(t *testing.T) { - testStoreUserByToken(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreUserByStripeCustomer(t *testing.T) { - testStoreUserByStripeCustomer(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreUsers(t *testing.T) { - testStoreUsers(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreUsersCount(t *testing.T) { - testStoreUsersCount(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreChangePassword(t *testing.T) { - testStoreChangePassword(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreChangeRole(t *testing.T) { - testStoreChangeRole(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreTokens(t *testing.T) { - testStoreTokens(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreTokenChangeLabel(t *testing.T) { - testStoreTokenChangeLabel(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreTokenRemove(t *testing.T) { - testStoreTokenRemove(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreTokenRemoveExpired(t *testing.T) { - testStoreTokenRemoveExpired(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreTokenRemoveExcess(t *testing.T) { - testStoreTokenRemoveExcess(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreTokenUpdateLastAccess(t *testing.T) { - testStoreTokenUpdateLastAccess(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreAllowAccess(t *testing.T) { - testStoreAllowAccess(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreAllowAccessReadOnly(t *testing.T) { - testStoreAllowAccessReadOnly(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreResetAccess(t *testing.T) { - testStoreResetAccess(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreResetAccessAll(t *testing.T) { - testStoreResetAccessAll(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreAuthorizeTopicAccess(t *testing.T) { - testStoreAuthorizeTopicAccess(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreAuthorizeTopicAccessNotFound(t *testing.T) { - testStoreAuthorizeTopicAccessNotFound(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreAuthorizeTopicAccessDenyAll(t *testing.T) { - testStoreAuthorizeTopicAccessDenyAll(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreReservations(t *testing.T) { - testStoreReservations(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreReservationsCount(t *testing.T) { - testStoreReservationsCount(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreHasReservation(t *testing.T) { - testStoreHasReservation(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreReservationOwner(t *testing.T) { - testStoreReservationOwner(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreTiers(t *testing.T) { - testStoreTiers(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreTierUpdate(t *testing.T) { - testStoreTierUpdate(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreTierRemove(t *testing.T) { - testStoreTierRemove(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreTierByStripePrice(t *testing.T) { - testStoreTierByStripePrice(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreChangeTier(t *testing.T) { - testStoreChangeTier(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStorePhoneNumbers(t *testing.T) { - testStorePhoneNumbers(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreChangeSettings(t *testing.T) { - testStoreChangeSettings(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreChangeBilling(t *testing.T) { - testStoreChangeBilling(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreUpdateStats(t *testing.T) { - testStoreUpdateStats(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreResetStats(t *testing.T) { - testStoreResetStats(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreMarkUserRemoved(t *testing.T) { - testStoreMarkUserRemoved(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreRemoveDeletedUsers(t *testing.T) { - testStoreRemoveDeletedUsers(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreAllGrants(t *testing.T) { - testStoreAllGrants(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreOtherAccessCount(t *testing.T) { - testStoreOtherAccessCount(t, newTestSQLiteStore(t)) -} diff --git a/user/store_test.go b/user/store_test.go index da012bf2..7dc3ef38 100644 --- a/user/store_test.go +++ b/user/store_test.go @@ -2,618 +2,717 @@ package user_test import ( "net/netip" + "path/filepath" "testing" "time" "github.com/stretchr/testify/require" + dbtest "heckel.io/ntfy/v2/db/test" "heckel.io/ntfy/v2/user" ) -func testStoreAddUser(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - u, err := store.User("phil") - require.Nil(t, err) - require.Equal(t, "phil", u.Name) - require.Equal(t, user.RoleUser, u.Role) - require.False(t, u.Provisioned) - require.NotEmpty(t, u.ID) - require.NotEmpty(t, u.SyncTopic) -} - -func testStoreAddUserAlreadyExists(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - require.Equal(t, user.ErrUserExists, store.AddUser("phil", "philhash", user.RoleUser, false)) -} - -func testStoreRemoveUser(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - u, err := store.User("phil") - require.Nil(t, err) - require.Equal(t, "phil", u.Name) - - require.Nil(t, store.RemoveUser("phil")) - _, err = store.User("phil") - require.Equal(t, user.ErrUserNotFound, err) -} - -func testStoreUserByID(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleAdmin, false)) - u, err := store.User("phil") - require.Nil(t, err) - - u2, err := store.UserByID(u.ID) - require.Nil(t, err) - require.Equal(t, u.Name, u2.Name) - require.Equal(t, u.ID, u2.ID) -} - -func testStoreUserByToken(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - u, err := store.User("phil") - require.Nil(t, err) - - tk, err := store.CreateToken(u.ID, "tk_test123", "test token", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(24*time.Hour), false) - require.Nil(t, err) - require.Equal(t, "tk_test123", tk.Value) - - u2, err := store.UserByToken(tk.Value) - require.Nil(t, err) - require.Equal(t, "phil", u2.Name) -} - -func testStoreUserByStripeCustomer(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - require.Nil(t, store.ChangeBilling("phil", &user.Billing{ - StripeCustomerID: "cus_test123", - StripeSubscriptionID: "sub_test123", - })) - - u, err := store.UserByStripeCustomer("cus_test123") - require.Nil(t, err) - require.Equal(t, "phil", u.Name) - require.Equal(t, "cus_test123", u.Billing.StripeCustomerID) -} - -func testStoreUsers(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - require.Nil(t, store.AddUser("ben", "benhash", user.RoleAdmin, false)) - - users, err := store.Users() - require.Nil(t, err) - require.True(t, len(users) >= 3) // phil, ben, and the everyone user -} - -func testStoreUsersCount(t *testing.T, store user.Store) { - count, err := store.UsersCount() - require.Nil(t, err) - require.True(t, count >= 1) // At least the everyone user - - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - count2, err := store.UsersCount() - require.Nil(t, err) - require.Equal(t, count+1, count2) -} - -func testStoreChangePassword(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - u, err := store.User("phil") - require.Nil(t, err) - require.Equal(t, "philhash", u.Hash) - - require.Nil(t, store.ChangePassword("phil", "newhash")) - u, err = store.User("phil") - require.Nil(t, err) - require.Equal(t, "newhash", u.Hash) -} - -func testStoreChangeRole(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - u, err := store.User("phil") - require.Nil(t, err) - require.Equal(t, user.RoleUser, u.Role) - - require.Nil(t, store.ChangeRole("phil", user.RoleAdmin)) - u, err = store.User("phil") - require.Nil(t, err) - require.Equal(t, user.RoleAdmin, u.Role) -} - -func testStoreTokens(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - u, err := store.User("phil") - require.Nil(t, err) - - now := time.Now() - expires := now.Add(24 * time.Hour) - origin := netip.MustParseAddr("9.9.9.9") - - tk, err := store.CreateToken(u.ID, "tk_abc", "my token", now, origin, expires, false) - require.Nil(t, err) - require.Equal(t, "tk_abc", tk.Value) - require.Equal(t, "my token", tk.Label) - - // Get single token - tk2, err := store.Token(u.ID, "tk_abc") - require.Nil(t, err) - require.Equal(t, "tk_abc", tk2.Value) - require.Equal(t, "my token", tk2.Label) - - // Get all tokens - tokens, err := store.Tokens(u.ID) - require.Nil(t, err) - require.Len(t, tokens, 1) - require.Equal(t, "tk_abc", tokens[0].Value) - - // Token count - count, err := store.TokenCount(u.ID) - require.Nil(t, err) - require.Equal(t, 1, count) -} - -func testStoreTokenChangeLabel(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - u, err := store.User("phil") - require.Nil(t, err) - - _, err = store.CreateToken(u.ID, "tk_abc", "old label", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), false) - require.Nil(t, err) - - require.Nil(t, store.ChangeTokenLabel(u.ID, "tk_abc", "new label")) - tk, err := store.Token(u.ID, "tk_abc") - require.Nil(t, err) - require.Equal(t, "new label", tk.Label) -} - -func testStoreTokenRemove(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - u, err := store.User("phil") - require.Nil(t, err) - - _, err = store.CreateToken(u.ID, "tk_abc", "label", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), false) - require.Nil(t, err) - - require.Nil(t, store.RemoveToken(u.ID, "tk_abc")) - _, err = store.Token(u.ID, "tk_abc") - require.Equal(t, user.ErrTokenNotFound, err) -} - -func testStoreTokenRemoveExpired(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - u, err := store.User("phil") - require.Nil(t, err) - - // Create expired token and active token - _, err = store.CreateToken(u.ID, "tk_expired", "expired", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(-time.Hour), false) - require.Nil(t, err) - _, err = store.CreateToken(u.ID, "tk_active", "active", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), false) - require.Nil(t, err) - - require.Nil(t, store.RemoveExpiredTokens()) - - // Expired token should be gone - _, err = store.Token(u.ID, "tk_expired") - require.Equal(t, user.ErrTokenNotFound, err) - - // Active token should still exist - tk, err := store.Token(u.ID, "tk_active") - require.Nil(t, err) - require.Equal(t, "tk_active", tk.Value) -} - -func testStoreTokenRemoveExcess(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - u, err := store.User("phil") - require.Nil(t, err) - - // Create 3 tokens with increasing expiry - for i, name := range []string{"tk_a", "tk_b", "tk_c"} { - _, err = store.CreateToken(u.ID, name, name, time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Duration(i+1)*time.Hour), false) +func forEachStoreBackend(t *testing.T, f func(t *testing.T, store user.Store)) { + t.Run("sqlite", func(t *testing.T) { + store, err := user.NewSQLiteStore(filepath.Join(t.TempDir(), "user.db"), "") require.Nil(t, err) - } - - count, err := store.TokenCount(u.ID) - require.Nil(t, err) - require.Equal(t, 3, count) - - // Remove excess, keep only 2 (the ones with latest expiry: tk_b, tk_c) - require.Nil(t, store.RemoveExcessTokens(u.ID, 2)) - - count, err = store.TokenCount(u.ID) - require.Nil(t, err) - require.Equal(t, 2, count) - - // tk_a should be removed (earliest expiry) - _, err = store.Token(u.ID, "tk_a") - require.Equal(t, user.ErrTokenNotFound, err) - - // tk_b and tk_c should remain - _, err = store.Token(u.ID, "tk_b") - require.Nil(t, err) - _, err = store.Token(u.ID, "tk_c") - require.Nil(t, err) + t.Cleanup(func() { store.Close() }) + f(t, store) + }) + t.Run("postgres", func(t *testing.T) { + testDB := dbtest.CreateTestPostgres(t) + store, err := user.NewPostgresStore(testDB) + require.Nil(t, err) + f(t, store) + }) } -func testStoreTokenUpdateLastAccess(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - u, err := store.User("phil") - require.Nil(t, err) - - _, err = store.CreateToken(u.ID, "tk_abc", "label", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), false) - require.Nil(t, err) - - newTime := time.Now().Add(5 * time.Minute) - newOrigin := netip.MustParseAddr("5.5.5.5") - require.Nil(t, store.UpdateTokenLastAccess("tk_abc", newTime, newOrigin)) - - tk, err := store.Token(u.ID, "tk_abc") - require.Nil(t, err) - require.Equal(t, newTime.Unix(), tk.LastAccess.Unix()) - require.Equal(t, newOrigin, tk.LastOrigin) +func TestStoreAddUser(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) + u, err := store.User("phil") + require.Nil(t, err) + require.Equal(t, "phil", u.Name) + require.Equal(t, user.RoleUser, u.Role) + require.False(t, u.Provisioned) + require.NotEmpty(t, u.ID) + require.NotEmpty(t, u.SyncTopic) + }) } -func testStoreAllowAccess(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - - require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "", false)) - grants, err := store.Grants("phil") - require.Nil(t, err) - require.Len(t, grants, 1) - require.Equal(t, "mytopic", grants[0].TopicPattern) - require.True(t, grants[0].Permission.IsReadWrite()) +func TestStoreAddUserAlreadyExists(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) + require.Equal(t, user.ErrUserExists, store.AddUser("phil", "philhash", user.RoleUser, false)) + }) } -func testStoreAllowAccessReadOnly(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) +func TestStoreRemoveUser(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) + u, err := store.User("phil") + require.Nil(t, err) + require.Equal(t, "phil", u.Name) - require.Nil(t, store.AllowAccess("phil", "announcements", true, false, "", false)) - grants, err := store.Grants("phil") - require.Nil(t, err) - require.Len(t, grants, 1) - require.True(t, grants[0].Permission.IsRead()) - require.False(t, grants[0].Permission.IsWrite()) + require.Nil(t, store.RemoveUser("phil")) + _, err = store.User("phil") + require.Equal(t, user.ErrUserNotFound, err) + }) } -func testStoreResetAccess(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - require.Nil(t, store.AllowAccess("phil", "topic1", true, true, "", false)) - require.Nil(t, store.AllowAccess("phil", "topic2", true, false, "", false)) +func TestStoreUserByID(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + require.Nil(t, store.AddUser("phil", "philhash", user.RoleAdmin, false)) + u, err := store.User("phil") + require.Nil(t, err) - grants, err := store.Grants("phil") - require.Nil(t, err) - require.Len(t, grants, 2) - - require.Nil(t, store.ResetAccess("phil", "topic1")) - grants, err = store.Grants("phil") - require.Nil(t, err) - require.Len(t, grants, 1) - require.Equal(t, "topic2", grants[0].TopicPattern) + u2, err := store.UserByID(u.ID) + require.Nil(t, err) + require.Equal(t, u.Name, u2.Name) + require.Equal(t, u.ID, u2.ID) + }) } -func testStoreResetAccessAll(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - require.Nil(t, store.AllowAccess("phil", "topic1", true, true, "", false)) - require.Nil(t, store.AllowAccess("phil", "topic2", true, false, "", false)) +func TestStoreUserByToken(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) + u, err := store.User("phil") + require.Nil(t, err) - require.Nil(t, store.ResetAccess("phil", "")) - grants, err := store.Grants("phil") - require.Nil(t, err) - require.Len(t, grants, 0) + tk, err := store.CreateToken(u.ID, "tk_test123", "test token", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(24*time.Hour), false) + require.Nil(t, err) + require.Equal(t, "tk_test123", tk.Value) + + u2, err := store.UserByToken(tk.Value) + require.Nil(t, err) + require.Equal(t, "phil", u2.Name) + }) } -func testStoreAuthorizeTopicAccess(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "", false)) +func TestStoreUserByStripeCustomer(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) + require.Nil(t, store.ChangeBilling("phil", &user.Billing{ + StripeCustomerID: "cus_test123", + StripeSubscriptionID: "sub_test123", + })) - read, write, found, err := store.AuthorizeTopicAccess("phil", "mytopic") - require.Nil(t, err) - require.True(t, found) - require.True(t, read) - require.True(t, write) + u, err := store.UserByStripeCustomer("cus_test123") + require.Nil(t, err) + require.Equal(t, "phil", u.Name) + require.Equal(t, "cus_test123", u.Billing.StripeCustomerID) + }) } -func testStoreAuthorizeTopicAccessNotFound(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) +func TestStoreUsers(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) + require.Nil(t, store.AddUser("ben", "benhash", user.RoleAdmin, false)) - _, _, found, err := store.AuthorizeTopicAccess("phil", "other") - require.Nil(t, err) - require.False(t, found) + users, err := store.Users() + require.Nil(t, err) + require.True(t, len(users) >= 3) // phil, ben, and the everyone user + }) } -func testStoreAuthorizeTopicAccessDenyAll(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - require.Nil(t, store.AllowAccess("phil", "secret", false, false, "", false)) +func TestStoreUsersCount(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + count, err := store.UsersCount() + require.Nil(t, err) + require.True(t, count >= 1) // At least the everyone user - read, write, found, err := store.AuthorizeTopicAccess("phil", "secret") - require.Nil(t, err) - require.True(t, found) - require.False(t, read) - require.False(t, write) + require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) + count2, err := store.UsersCount() + require.Nil(t, err) + require.Equal(t, count+1, count2) + }) } -func testStoreReservations(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "phil", false)) - require.Nil(t, store.AllowAccess(user.Everyone, "mytopic", true, false, "phil", false)) +func TestStoreChangePassword(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) + u, err := store.User("phil") + require.Nil(t, err) + require.Equal(t, "philhash", u.Hash) - reservations, err := store.Reservations("phil") - require.Nil(t, err) - require.Len(t, reservations, 1) - require.Equal(t, "mytopic", reservations[0].Topic) - require.True(t, reservations[0].Owner.IsReadWrite()) - require.True(t, reservations[0].Everyone.IsRead()) - require.False(t, reservations[0].Everyone.IsWrite()) + require.Nil(t, store.ChangePassword("phil", "newhash")) + u, err = store.User("phil") + require.Nil(t, err) + require.Equal(t, "newhash", u.Hash) + }) } -func testStoreReservationsCount(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - require.Nil(t, store.AllowAccess("phil", "topic1", true, true, "phil", false)) - require.Nil(t, store.AllowAccess("phil", "topic2", true, true, "phil", false)) +func TestStoreChangeRole(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) + u, err := store.User("phil") + require.Nil(t, err) + require.Equal(t, user.RoleUser, u.Role) - count, err := store.ReservationsCount("phil") - require.Nil(t, err) - require.Equal(t, int64(2), count) + require.Nil(t, store.ChangeRole("phil", user.RoleAdmin)) + u, err = store.User("phil") + require.Nil(t, err) + require.Equal(t, user.RoleAdmin, u.Role) + }) } -func testStoreHasReservation(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "phil", false)) +func TestStoreTokens(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) + u, err := store.User("phil") + require.Nil(t, err) - has, err := store.HasReservation("phil", "mytopic") - require.Nil(t, err) - require.True(t, has) + now := time.Now() + expires := now.Add(24 * time.Hour) + origin := netip.MustParseAddr("9.9.9.9") - has, err = store.HasReservation("phil", "other") - require.Nil(t, err) - require.False(t, has) + tk, err := store.CreateToken(u.ID, "tk_abc", "my token", now, origin, expires, false) + require.Nil(t, err) + require.Equal(t, "tk_abc", tk.Value) + require.Equal(t, "my token", tk.Label) + + // Get single token + tk2, err := store.Token(u.ID, "tk_abc") + require.Nil(t, err) + require.Equal(t, "tk_abc", tk2.Value) + require.Equal(t, "my token", tk2.Label) + + // Get all tokens + tokens, err := store.Tokens(u.ID) + require.Nil(t, err) + require.Len(t, tokens, 1) + require.Equal(t, "tk_abc", tokens[0].Value) + + // Token count + count, err := store.TokenCount(u.ID) + require.Nil(t, err) + require.Equal(t, 1, count) + }) } -func testStoreReservationOwner(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "phil", false)) +func TestStoreTokenChangeLabel(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) + u, err := store.User("phil") + require.Nil(t, err) - owner, err := store.ReservationOwner("mytopic") - require.Nil(t, err) - require.NotEmpty(t, owner) // Returns the user ID + _, err = store.CreateToken(u.ID, "tk_abc", "old label", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), false) + require.Nil(t, err) - owner, err = store.ReservationOwner("unowned") - require.Nil(t, err) - require.Empty(t, owner) + require.Nil(t, store.ChangeTokenLabel(u.ID, "tk_abc", "new label")) + tk, err := store.Token(u.ID, "tk_abc") + require.Nil(t, err) + require.Equal(t, "new label", tk.Label) + }) } -func testStoreTiers(t *testing.T, store user.Store) { - tier := &user.Tier{ - ID: "ti_test", - Code: "pro", - Name: "Pro", - MessageLimit: 5000, - MessageExpiryDuration: 24 * time.Hour, - EmailLimit: 100, - CallLimit: 10, - ReservationLimit: 20, - AttachmentFileSizeLimit: 10 * 1024 * 1024, - AttachmentTotalSizeLimit: 100 * 1024 * 1024, - AttachmentExpiryDuration: 48 * time.Hour, - AttachmentBandwidthLimit: 500 * 1024 * 1024, - } - require.Nil(t, store.AddTier(tier)) +func TestStoreTokenRemove(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) + u, err := store.User("phil") + require.Nil(t, err) - // Get by code - t2, err := store.Tier("pro") - require.Nil(t, err) - require.Equal(t, "ti_test", t2.ID) - require.Equal(t, "pro", t2.Code) - require.Equal(t, "Pro", t2.Name) - require.Equal(t, int64(5000), t2.MessageLimit) - require.Equal(t, int64(100), t2.EmailLimit) - require.Equal(t, int64(10), t2.CallLimit) - require.Equal(t, int64(20), t2.ReservationLimit) + _, err = store.CreateToken(u.ID, "tk_abc", "label", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), false) + require.Nil(t, err) - // List all tiers - tiers, err := store.Tiers() - require.Nil(t, err) - require.Len(t, tiers, 1) - require.Equal(t, "pro", tiers[0].Code) + require.Nil(t, store.RemoveToken(u.ID, "tk_abc")) + _, err = store.Token(u.ID, "tk_abc") + require.Equal(t, user.ErrTokenNotFound, err) + }) } -func testStoreTierUpdate(t *testing.T, store user.Store) { - tier := &user.Tier{ - ID: "ti_test", - Code: "pro", - Name: "Pro", - } - require.Nil(t, store.AddTier(tier)) +func TestStoreTokenRemoveExpired(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) + u, err := store.User("phil") + require.Nil(t, err) - tier.Name = "Professional" - tier.MessageLimit = 9999 - require.Nil(t, store.UpdateTier(tier)) + // Create expired token and active token + _, err = store.CreateToken(u.ID, "tk_expired", "expired", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(-time.Hour), false) + require.Nil(t, err) + _, err = store.CreateToken(u.ID, "tk_active", "active", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), false) + require.Nil(t, err) - t2, err := store.Tier("pro") - require.Nil(t, err) - require.Equal(t, "Professional", t2.Name) - require.Equal(t, int64(9999), t2.MessageLimit) + require.Nil(t, store.RemoveExpiredTokens()) + + // Expired token should be gone + _, err = store.Token(u.ID, "tk_expired") + require.Equal(t, user.ErrTokenNotFound, err) + + // Active token should still exist + tk, err := store.Token(u.ID, "tk_active") + require.Nil(t, err) + require.Equal(t, "tk_active", tk.Value) + }) } -func testStoreTierRemove(t *testing.T, store user.Store) { - tier := &user.Tier{ - ID: "ti_test", - Code: "pro", - Name: "Pro", - } - require.Nil(t, store.AddTier(tier)) +func TestStoreTokenRemoveExcess(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) + u, err := store.User("phil") + require.Nil(t, err) - t2, err := store.Tier("pro") - require.Nil(t, err) - require.Equal(t, "pro", t2.Code) + // Create 3 tokens with increasing expiry + for i, name := range []string{"tk_a", "tk_b", "tk_c"} { + _, err = store.CreateToken(u.ID, name, name, time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Duration(i+1)*time.Hour), false) + require.Nil(t, err) + } - require.Nil(t, store.RemoveTier("pro")) - _, err = store.Tier("pro") - require.Equal(t, user.ErrTierNotFound, err) + count, err := store.TokenCount(u.ID) + require.Nil(t, err) + require.Equal(t, 3, count) + + // Remove excess, keep only 2 (the ones with latest expiry: tk_b, tk_c) + require.Nil(t, store.RemoveExcessTokens(u.ID, 2)) + + count, err = store.TokenCount(u.ID) + require.Nil(t, err) + require.Equal(t, 2, count) + + // tk_a should be removed (earliest expiry) + _, err = store.Token(u.ID, "tk_a") + require.Equal(t, user.ErrTokenNotFound, err) + + // tk_b and tk_c should remain + _, err = store.Token(u.ID, "tk_b") + require.Nil(t, err) + _, err = store.Token(u.ID, "tk_c") + require.Nil(t, err) + }) } -func testStoreTierByStripePrice(t *testing.T, store user.Store) { - tier := &user.Tier{ - ID: "ti_test", - Code: "pro", - Name: "Pro", - StripeMonthlyPriceID: "price_monthly", - StripeYearlyPriceID: "price_yearly", - } - require.Nil(t, store.AddTier(tier)) +func TestStoreTokenUpdateLastAccess(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) + u, err := store.User("phil") + require.Nil(t, err) - t2, err := store.TierByStripePrice("price_monthly") - require.Nil(t, err) - require.Equal(t, "pro", t2.Code) + _, err = store.CreateToken(u.ID, "tk_abc", "label", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), false) + require.Nil(t, err) - t3, err := store.TierByStripePrice("price_yearly") - require.Nil(t, err) - require.Equal(t, "pro", t3.Code) + newTime := time.Now().Add(5 * time.Minute) + newOrigin := netip.MustParseAddr("5.5.5.5") + require.Nil(t, store.UpdateTokenLastAccess("tk_abc", newTime, newOrigin)) + + tk, err := store.Token(u.ID, "tk_abc") + require.Nil(t, err) + require.Equal(t, newTime.Unix(), tk.LastAccess.Unix()) + require.Equal(t, newOrigin, tk.LastOrigin) + }) } -func testStoreChangeTier(t *testing.T, store user.Store) { - tier := &user.Tier{ - ID: "ti_test", - Code: "pro", - Name: "Pro", - } - require.Nil(t, store.AddTier(tier)) - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - require.Nil(t, store.ChangeTier("phil", "pro")) +func TestStoreAllowAccess(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - u, err := store.User("phil") - require.Nil(t, err) - require.NotNil(t, u.Tier) - require.Equal(t, "pro", u.Tier.Code) + require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "", false)) + grants, err := store.Grants("phil") + require.Nil(t, err) + require.Len(t, grants, 1) + require.Equal(t, "mytopic", grants[0].TopicPattern) + require.True(t, grants[0].Permission.IsReadWrite()) + }) } -func testStorePhoneNumbers(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - u, err := store.User("phil") - require.Nil(t, err) +func TestStoreAllowAccessReadOnly(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - require.Nil(t, store.AddPhoneNumber(u.ID, "+1234567890")) - require.Nil(t, store.AddPhoneNumber(u.ID, "+0987654321")) - - numbers, err := store.PhoneNumbers(u.ID) - require.Nil(t, err) - require.Len(t, numbers, 2) - - require.Nil(t, store.RemovePhoneNumber(u.ID, "+1234567890")) - numbers, err = store.PhoneNumbers(u.ID) - require.Nil(t, err) - require.Len(t, numbers, 1) - require.Equal(t, "+0987654321", numbers[0]) + require.Nil(t, store.AllowAccess("phil", "announcements", true, false, "", false)) + grants, err := store.Grants("phil") + require.Nil(t, err) + require.Len(t, grants, 1) + require.True(t, grants[0].Permission.IsRead()) + require.False(t, grants[0].Permission.IsWrite()) + }) } -func testStoreChangeSettings(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - u, err := store.User("phil") - require.Nil(t, err) +func TestStoreResetAccess(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) + require.Nil(t, store.AllowAccess("phil", "topic1", true, true, "", false)) + require.Nil(t, store.AllowAccess("phil", "topic2", true, false, "", false)) - lang := "de" - prefs := &user.Prefs{Language: &lang} - require.Nil(t, store.ChangeSettings(u.ID, prefs)) + grants, err := store.Grants("phil") + require.Nil(t, err) + require.Len(t, grants, 2) - u2, err := store.User("phil") - require.Nil(t, err) - require.NotNil(t, u2.Prefs) - require.Equal(t, "de", *u2.Prefs.Language) + require.Nil(t, store.ResetAccess("phil", "topic1")) + grants, err = store.Grants("phil") + require.Nil(t, err) + require.Len(t, grants, 1) + require.Equal(t, "topic2", grants[0].TopicPattern) + }) } -func testStoreChangeBilling(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) +func TestStoreResetAccessAll(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) + require.Nil(t, store.AllowAccess("phil", "topic1", true, true, "", false)) + require.Nil(t, store.AllowAccess("phil", "topic2", true, false, "", false)) - billing := &user.Billing{ - StripeCustomerID: "cus_123", - StripeSubscriptionID: "sub_456", - } - require.Nil(t, store.ChangeBilling("phil", billing)) - - u, err := store.User("phil") - require.Nil(t, err) - require.Equal(t, "cus_123", u.Billing.StripeCustomerID) - require.Equal(t, "sub_456", u.Billing.StripeSubscriptionID) + require.Nil(t, store.ResetAccess("phil", "")) + grants, err := store.Grants("phil") + require.Nil(t, err) + require.Len(t, grants, 0) + }) } -func testStoreUpdateStats(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - u, err := store.User("phil") - require.Nil(t, err) +func TestStoreAuthorizeTopicAccess(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) + require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "", false)) - stats := &user.Stats{Messages: 42, Emails: 3, Calls: 1} - require.Nil(t, store.UpdateStats(u.ID, stats)) - - u2, err := store.User("phil") - require.Nil(t, err) - require.Equal(t, int64(42), u2.Stats.Messages) - require.Equal(t, int64(3), u2.Stats.Emails) - require.Equal(t, int64(1), u2.Stats.Calls) + read, write, found, err := store.AuthorizeTopicAccess("phil", "mytopic") + require.Nil(t, err) + require.True(t, found) + require.True(t, read) + require.True(t, write) + }) } -func testStoreResetStats(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - u, err := store.User("phil") - require.Nil(t, err) +func TestStoreAuthorizeTopicAccessNotFound(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - require.Nil(t, store.UpdateStats(u.ID, &user.Stats{Messages: 42, Emails: 3, Calls: 1})) - require.Nil(t, store.ResetStats()) - - u2, err := store.User("phil") - require.Nil(t, err) - require.Equal(t, int64(0), u2.Stats.Messages) - require.Equal(t, int64(0), u2.Stats.Emails) - require.Equal(t, int64(0), u2.Stats.Calls) + _, _, found, err := store.AuthorizeTopicAccess("phil", "other") + require.Nil(t, err) + require.False(t, found) + }) } -func testStoreMarkUserRemoved(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - u, err := store.User("phil") - require.Nil(t, err) +func TestStoreAuthorizeTopicAccessDenyAll(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) + require.Nil(t, store.AllowAccess("phil", "secret", false, false, "", false)) - require.Nil(t, store.MarkUserRemoved(u.ID)) - - u2, err := store.User("phil") - require.Nil(t, err) - require.True(t, u2.Deleted) + read, write, found, err := store.AuthorizeTopicAccess("phil", "secret") + require.Nil(t, err) + require.True(t, found) + require.False(t, read) + require.False(t, write) + }) } -func testStoreRemoveDeletedUsers(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - u, err := store.User("phil") - require.Nil(t, err) +func TestStoreReservations(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) + require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "phil", false)) + require.Nil(t, store.AllowAccess(user.Everyone, "mytopic", true, false, "phil", false)) - require.Nil(t, store.MarkUserRemoved(u.ID)) - - // RemoveDeletedUsers only removes users past the hard-delete duration (7 days). - // Immediately after marking, the user should still exist. - require.Nil(t, store.RemoveDeletedUsers()) - u2, err := store.User("phil") - require.Nil(t, err) - require.True(t, u2.Deleted) + reservations, err := store.Reservations("phil") + require.Nil(t, err) + require.Len(t, reservations, 1) + require.Equal(t, "mytopic", reservations[0].Topic) + require.True(t, reservations[0].Owner.IsReadWrite()) + require.True(t, reservations[0].Everyone.IsRead()) + require.False(t, reservations[0].Everyone.IsWrite()) + }) } -func testStoreAllGrants(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - require.Nil(t, store.AddUser("ben", "benhash", user.RoleUser, false)) - phil, err := store.User("phil") - require.Nil(t, err) - ben, err := store.User("ben") - require.Nil(t, err) +func TestStoreReservationsCount(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) + require.Nil(t, store.AllowAccess("phil", "topic1", true, true, "phil", false)) + require.Nil(t, store.AllowAccess("phil", "topic2", true, true, "phil", false)) - require.Nil(t, store.AllowAccess("phil", "topic1", true, true, "", false)) - require.Nil(t, store.AllowAccess("ben", "topic2", true, false, "", false)) - - grants, err := store.AllGrants() - require.Nil(t, err) - require.Contains(t, grants, phil.ID) - require.Contains(t, grants, ben.ID) + count, err := store.ReservationsCount("phil") + require.Nil(t, err) + require.Equal(t, int64(2), count) + }) } -func testStoreOtherAccessCount(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - require.Nil(t, store.AddUser("ben", "benhash", user.RoleUser, false)) - require.Nil(t, store.AllowAccess("ben", "mytopic", true, true, "ben", false)) +func TestStoreHasReservation(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) + require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "phil", false)) - count, err := store.OtherAccessCount("phil", "mytopic") - require.Nil(t, err) - require.Equal(t, 1, count) + has, err := store.HasReservation("phil", "mytopic") + require.Nil(t, err) + require.True(t, has) + + has, err = store.HasReservation("phil", "other") + require.Nil(t, err) + require.False(t, has) + }) +} + +func TestStoreReservationOwner(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) + require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "phil", false)) + + owner, err := store.ReservationOwner("mytopic") + require.Nil(t, err) + require.NotEmpty(t, owner) // Returns the user ID + + owner, err = store.ReservationOwner("unowned") + require.Nil(t, err) + require.Empty(t, owner) + }) +} + +func TestStoreTiers(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + tier := &user.Tier{ + ID: "ti_test", + Code: "pro", + Name: "Pro", + MessageLimit: 5000, + MessageExpiryDuration: 24 * time.Hour, + EmailLimit: 100, + CallLimit: 10, + ReservationLimit: 20, + AttachmentFileSizeLimit: 10 * 1024 * 1024, + AttachmentTotalSizeLimit: 100 * 1024 * 1024, + AttachmentExpiryDuration: 48 * time.Hour, + AttachmentBandwidthLimit: 500 * 1024 * 1024, + } + require.Nil(t, store.AddTier(tier)) + + // Get by code + t2, err := store.Tier("pro") + require.Nil(t, err) + require.Equal(t, "ti_test", t2.ID) + require.Equal(t, "pro", t2.Code) + require.Equal(t, "Pro", t2.Name) + require.Equal(t, int64(5000), t2.MessageLimit) + require.Equal(t, int64(100), t2.EmailLimit) + require.Equal(t, int64(10), t2.CallLimit) + require.Equal(t, int64(20), t2.ReservationLimit) + + // List all tiers + tiers, err := store.Tiers() + require.Nil(t, err) + require.Len(t, tiers, 1) + require.Equal(t, "pro", tiers[0].Code) + }) +} + +func TestStoreTierUpdate(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + tier := &user.Tier{ + ID: "ti_test", + Code: "pro", + Name: "Pro", + } + require.Nil(t, store.AddTier(tier)) + + tier.Name = "Professional" + tier.MessageLimit = 9999 + require.Nil(t, store.UpdateTier(tier)) + + t2, err := store.Tier("pro") + require.Nil(t, err) + require.Equal(t, "Professional", t2.Name) + require.Equal(t, int64(9999), t2.MessageLimit) + }) +} + +func TestStoreTierRemove(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + tier := &user.Tier{ + ID: "ti_test", + Code: "pro", + Name: "Pro", + } + require.Nil(t, store.AddTier(tier)) + + t2, err := store.Tier("pro") + require.Nil(t, err) + require.Equal(t, "pro", t2.Code) + + require.Nil(t, store.RemoveTier("pro")) + _, err = store.Tier("pro") + require.Equal(t, user.ErrTierNotFound, err) + }) +} + +func TestStoreTierByStripePrice(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + tier := &user.Tier{ + ID: "ti_test", + Code: "pro", + Name: "Pro", + StripeMonthlyPriceID: "price_monthly", + StripeYearlyPriceID: "price_yearly", + } + require.Nil(t, store.AddTier(tier)) + + t2, err := store.TierByStripePrice("price_monthly") + require.Nil(t, err) + require.Equal(t, "pro", t2.Code) + + t3, err := store.TierByStripePrice("price_yearly") + require.Nil(t, err) + require.Equal(t, "pro", t3.Code) + }) +} + +func TestStoreChangeTier(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + tier := &user.Tier{ + ID: "ti_test", + Code: "pro", + Name: "Pro", + } + require.Nil(t, store.AddTier(tier)) + require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) + require.Nil(t, store.ChangeTier("phil", "pro")) + + u, err := store.User("phil") + require.Nil(t, err) + require.NotNil(t, u.Tier) + require.Equal(t, "pro", u.Tier.Code) + }) +} + +func TestStorePhoneNumbers(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) + u, err := store.User("phil") + require.Nil(t, err) + + require.Nil(t, store.AddPhoneNumber(u.ID, "+1234567890")) + require.Nil(t, store.AddPhoneNumber(u.ID, "+0987654321")) + + numbers, err := store.PhoneNumbers(u.ID) + require.Nil(t, err) + require.Len(t, numbers, 2) + + require.Nil(t, store.RemovePhoneNumber(u.ID, "+1234567890")) + numbers, err = store.PhoneNumbers(u.ID) + require.Nil(t, err) + require.Len(t, numbers, 1) + require.Equal(t, "+0987654321", numbers[0]) + }) +} + +func TestStoreChangeSettings(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) + u, err := store.User("phil") + require.Nil(t, err) + + lang := "de" + prefs := &user.Prefs{Language: &lang} + require.Nil(t, store.ChangeSettings(u.ID, prefs)) + + u2, err := store.User("phil") + require.Nil(t, err) + require.NotNil(t, u2.Prefs) + require.Equal(t, "de", *u2.Prefs.Language) + }) +} + +func TestStoreChangeBilling(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) + + billing := &user.Billing{ + StripeCustomerID: "cus_123", + StripeSubscriptionID: "sub_456", + } + require.Nil(t, store.ChangeBilling("phil", billing)) + + u, err := store.User("phil") + require.Nil(t, err) + require.Equal(t, "cus_123", u.Billing.StripeCustomerID) + require.Equal(t, "sub_456", u.Billing.StripeSubscriptionID) + }) +} + +func TestStoreUpdateStats(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) + u, err := store.User("phil") + require.Nil(t, err) + + stats := &user.Stats{Messages: 42, Emails: 3, Calls: 1} + require.Nil(t, store.UpdateStats(u.ID, stats)) + + u2, err := store.User("phil") + require.Nil(t, err) + require.Equal(t, int64(42), u2.Stats.Messages) + require.Equal(t, int64(3), u2.Stats.Emails) + require.Equal(t, int64(1), u2.Stats.Calls) + }) +} + +func TestStoreResetStats(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) + u, err := store.User("phil") + require.Nil(t, err) + + require.Nil(t, store.UpdateStats(u.ID, &user.Stats{Messages: 42, Emails: 3, Calls: 1})) + require.Nil(t, store.ResetStats()) + + u2, err := store.User("phil") + require.Nil(t, err) + require.Equal(t, int64(0), u2.Stats.Messages) + require.Equal(t, int64(0), u2.Stats.Emails) + require.Equal(t, int64(0), u2.Stats.Calls) + }) +} + +func TestStoreMarkUserRemoved(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) + u, err := store.User("phil") + require.Nil(t, err) + + require.Nil(t, store.MarkUserRemoved(u.ID)) + + u2, err := store.User("phil") + require.Nil(t, err) + require.True(t, u2.Deleted) + }) +} + +func TestStoreRemoveDeletedUsers(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) + u, err := store.User("phil") + require.Nil(t, err) + + require.Nil(t, store.MarkUserRemoved(u.ID)) + + // RemoveDeletedUsers only removes users past the hard-delete duration (7 days). + // Immediately after marking, the user should still exist. + require.Nil(t, store.RemoveDeletedUsers()) + u2, err := store.User("phil") + require.Nil(t, err) + require.True(t, u2.Deleted) + }) +} + +func TestStoreAllGrants(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) + require.Nil(t, store.AddUser("ben", "benhash", user.RoleUser, false)) + phil, err := store.User("phil") + require.Nil(t, err) + ben, err := store.User("ben") + require.Nil(t, err) + + require.Nil(t, store.AllowAccess("phil", "topic1", true, true, "", false)) + require.Nil(t, store.AllowAccess("ben", "topic2", true, false, "", false)) + + grants, err := store.AllGrants() + require.Nil(t, err) + require.Contains(t, grants, phil.ID) + require.Contains(t, grants, ben.ID) + }) +} + +func TestStoreOtherAccessCount(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, store user.Store) { + require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) + require.Nil(t, store.AddUser("ben", "benhash", user.RoleUser, false)) + require.Nil(t, store.AllowAccess("ben", "mytopic", true, true, "ben", false)) + + count, err := store.OtherAccessCount("phil", "mytopic") + require.Nil(t, err) + require.Equal(t, 1, count) + }) } diff --git a/webpush/store_postgres_test.go b/webpush/store_postgres_test.go deleted file mode 100644 index 1b2d43d2..00000000 --- a/webpush/store_postgres_test.go +++ /dev/null @@ -1,63 +0,0 @@ -package webpush_test - -import ( - "testing" - - "github.com/stretchr/testify/require" - dbtest "heckel.io/ntfy/v2/db/test" - "heckel.io/ntfy/v2/webpush" -) - -func newTestPostgresStore(t *testing.T) webpush.Store { - testDB := dbtest.CreateTestPostgres(t) - store, err := webpush.NewPostgresStore(testDB) - require.Nil(t, err) - return store -} - -func TestPostgresStoreUpsertSubscriptionSubscriptionsForTopic(t *testing.T) { - testStoreUpsertSubscriptionSubscriptionsForTopic(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreUpsertSubscriptionSubscriberIPLimitReached(t *testing.T) { - testStoreUpsertSubscriptionSubscriberIPLimitReached(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreUpsertSubscriptionUpdateTopics(t *testing.T) { - testStoreUpsertSubscriptionUpdateTopics(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreUpsertSubscriptionUpdateFields(t *testing.T) { - testStoreUpsertSubscriptionUpdateFields(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreRemoveByUserIDMultiple(t *testing.T) { - testStoreRemoveByUserIDMultiple(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreRemoveByEndpoint(t *testing.T) { - testStoreRemoveByEndpoint(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreRemoveByUserID(t *testing.T) { - testStoreRemoveByUserID(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreRemoveByUserIDEmpty(t *testing.T) { - testStoreRemoveByUserIDEmpty(t, newTestPostgresStore(t)) -} - -func TestPostgresStoreExpiryWarningSent(t *testing.T) { - store := newTestPostgresStore(t) - testStoreExpiryWarningSent(t, store, store.SetSubscriptionUpdatedAt) -} - -func TestPostgresStoreExpiring(t *testing.T) { - store := newTestPostgresStore(t) - testStoreExpiring(t, store, store.SetSubscriptionUpdatedAt) -} - -func TestPostgresStoreRemoveExpired(t *testing.T) { - store := newTestPostgresStore(t) - testStoreRemoveExpired(t, store, store.SetSubscriptionUpdatedAt) -} diff --git a/webpush/store_sqlite_test.go b/webpush/store_sqlite_test.go deleted file mode 100644 index 1d4087d1..00000000 --- a/webpush/store_sqlite_test.go +++ /dev/null @@ -1,63 +0,0 @@ -package webpush_test - -import ( - "path/filepath" - "testing" - - "github.com/stretchr/testify/require" - "heckel.io/ntfy/v2/webpush" -) - -func newTestSQLiteStore(t *testing.T) webpush.Store { - store, err := webpush.NewSQLiteStore(filepath.Join(t.TempDir(), "webpush.db"), "") - require.Nil(t, err) - t.Cleanup(func() { store.Close() }) - return store -} - -func TestSQLiteStoreUpsertSubscriptionSubscriptionsForTopic(t *testing.T) { - testStoreUpsertSubscriptionSubscriptionsForTopic(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreUpsertSubscriptionSubscriberIPLimitReached(t *testing.T) { - testStoreUpsertSubscriptionSubscriberIPLimitReached(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreUpsertSubscriptionUpdateTopics(t *testing.T) { - testStoreUpsertSubscriptionUpdateTopics(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreUpsertSubscriptionUpdateFields(t *testing.T) { - testStoreUpsertSubscriptionUpdateFields(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreRemoveByUserIDMultiple(t *testing.T) { - testStoreRemoveByUserIDMultiple(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreRemoveByEndpoint(t *testing.T) { - testStoreRemoveByEndpoint(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreRemoveByUserID(t *testing.T) { - testStoreRemoveByUserID(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreRemoveByUserIDEmpty(t *testing.T) { - testStoreRemoveByUserIDEmpty(t, newTestSQLiteStore(t)) -} - -func TestSQLiteStoreExpiryWarningSent(t *testing.T) { - store := newTestSQLiteStore(t) - testStoreExpiryWarningSent(t, store, store.SetSubscriptionUpdatedAt) -} - -func TestSQLiteStoreExpiring(t *testing.T) { - store := newTestSQLiteStore(t) - testStoreExpiring(t, store, store.SetSubscriptionUpdatedAt) -} - -func TestSQLiteStoreRemoveExpired(t *testing.T) { - store := newTestSQLiteStore(t) - testStoreRemoveExpired(t, store, store.SetSubscriptionUpdatedAt) -} diff --git a/webpush/store_test.go b/webpush/store_test.go index fe2e28f3..bb27bf5f 100644 --- a/webpush/store_test.go +++ b/webpush/store_test.go @@ -3,211 +3,250 @@ package webpush_test import ( "fmt" "net/netip" + "path/filepath" "testing" "time" "github.com/stretchr/testify/require" + dbtest "heckel.io/ntfy/v2/db/test" "heckel.io/ntfy/v2/webpush" ) const testWebPushEndpoint = "https://updates.push.services.mozilla.com/wpush/v1/AAABBCCCDDEEEFFF" -func testStoreUpsertSubscriptionSubscriptionsForTopic(t *testing.T, store webpush.Store) { - require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"})) - - subs, err := store.SubscriptionsForTopic("test-topic") - require.Nil(t, err) - require.Len(t, subs, 1) - require.Equal(t, subs[0].Endpoint, testWebPushEndpoint) - require.Equal(t, subs[0].P256dh, "p256dh-key") - require.Equal(t, subs[0].Auth, "auth-key") - require.Equal(t, subs[0].UserID, "u_1234") - - subs2, err := store.SubscriptionsForTopic("mytopic") - require.Nil(t, err) - require.Len(t, subs2, 1) - require.Equal(t, subs[0].Endpoint, subs2[0].Endpoint) +func forEachBackend(t *testing.T, f func(t *testing.T, store webpush.Store)) { + t.Run("sqlite", func(t *testing.T) { + store, err := webpush.NewSQLiteStore(filepath.Join(t.TempDir(), "webpush.db"), "") + require.Nil(t, err) + t.Cleanup(func() { store.Close() }) + f(t, store) + }) + t.Run("postgres", func(t *testing.T) { + testDB := dbtest.CreateTestPostgres(t) + store, err := webpush.NewPostgresStore(testDB) + require.Nil(t, err) + f(t, store) + }) } -func testStoreUpsertSubscriptionSubscriberIPLimitReached(t *testing.T, store webpush.Store) { - // Insert 10 subscriptions with the same IP address - for i := 0; i < 10; i++ { - endpoint := fmt.Sprintf(testWebPushEndpoint+"%d", i) - require.Nil(t, store.UpsertSubscription(endpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"})) - } +func TestStoreUpsertSubscriptionSubscriptionsForTopic(t *testing.T) { + forEachBackend(t, func(t *testing.T, store webpush.Store) { + require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"})) - // Another one for the same endpoint should be fine - require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"})) + subs, err := store.SubscriptionsForTopic("test-topic") + require.Nil(t, err) + require.Len(t, subs, 1) + require.Equal(t, subs[0].Endpoint, testWebPushEndpoint) + require.Equal(t, subs[0].P256dh, "p256dh-key") + require.Equal(t, subs[0].Auth, "auth-key") + require.Equal(t, subs[0].UserID, "u_1234") - // But with a different endpoint it should fail - require.Equal(t, webpush.ErrWebPushTooManySubscriptions, store.UpsertSubscription(testWebPushEndpoint+"11", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"})) - - // But with a different IP address it should be fine again - require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"99", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("9.9.9.9"), []string{"test-topic", "mytopic"})) + subs2, err := store.SubscriptionsForTopic("mytopic") + require.Nil(t, err) + require.Len(t, subs2, 1) + require.Equal(t, subs[0].Endpoint, subs2[0].Endpoint) + }) } -func testStoreUpsertSubscriptionUpdateTopics(t *testing.T, store webpush.Store) { - // Insert subscription with two topics, and another with one topic - require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"})) - require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"1", "auth-key", "p256dh-key", "", netip.MustParseAddr("9.9.9.9"), []string{"topic1"})) +func TestStoreUpsertSubscriptionSubscriberIPLimitReached(t *testing.T) { + forEachBackend(t, func(t *testing.T, store webpush.Store) { + // Insert 10 subscriptions with the same IP address + for i := 0; i < 10; i++ { + endpoint := fmt.Sprintf(testWebPushEndpoint+"%d", i) + require.Nil(t, store.UpsertSubscription(endpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"})) + } - subs, err := store.SubscriptionsForTopic("topic1") - require.Nil(t, err) - require.Len(t, subs, 2) - require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint) - require.Equal(t, testWebPushEndpoint+"1", subs[1].Endpoint) + // Another one for the same endpoint should be fine + require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"})) - subs, err = store.SubscriptionsForTopic("topic2") - require.Nil(t, err) - require.Len(t, subs, 1) - require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint) + // But with a different endpoint it should fail + require.Equal(t, webpush.ErrWebPushTooManySubscriptions, store.UpsertSubscription(testWebPushEndpoint+"11", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"})) - // Update the first subscription to have only one topic - require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"})) - - subs, err = store.SubscriptionsForTopic("topic1") - require.Nil(t, err) - require.Len(t, subs, 2) - require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint) - - subs, err = store.SubscriptionsForTopic("topic2") - require.Nil(t, err) - require.Len(t, subs, 0) + // But with a different IP address it should be fine again + require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"99", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("9.9.9.9"), []string{"test-topic", "mytopic"})) + }) } -func testStoreUpsertSubscriptionUpdateFields(t *testing.T, store webpush.Store) { - // Insert a subscription - require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"})) +func TestStoreUpsertSubscriptionUpdateTopics(t *testing.T) { + forEachBackend(t, func(t *testing.T, store webpush.Store) { + // Insert subscription with two topics, and another with one topic + require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"})) + require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"1", "auth-key", "p256dh-key", "", netip.MustParseAddr("9.9.9.9"), []string{"topic1"})) - subs, err := store.SubscriptionsForTopic("topic1") - require.Nil(t, err) - require.Len(t, subs, 1) - require.Equal(t, "auth-key", subs[0].Auth) - require.Equal(t, "p256dh-key", subs[0].P256dh) - require.Equal(t, "u_1234", subs[0].UserID) + subs, err := store.SubscriptionsForTopic("topic1") + require.Nil(t, err) + require.Len(t, subs, 2) + require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint) + require.Equal(t, testWebPushEndpoint+"1", subs[1].Endpoint) - // Re-upsert the same endpoint with different auth, p256dh, and userID - require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "new-auth", "new-p256dh", "u_5678", netip.MustParseAddr("1.2.3.4"), []string{"topic1"})) + subs, err = store.SubscriptionsForTopic("topic2") + require.Nil(t, err) + require.Len(t, subs, 1) + require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint) - subs, err = store.SubscriptionsForTopic("topic1") - require.Nil(t, err) - require.Len(t, subs, 1) - require.Equal(t, testWebPushEndpoint, subs[0].Endpoint) - require.Equal(t, "new-auth", subs[0].Auth) - require.Equal(t, "new-p256dh", subs[0].P256dh) - require.Equal(t, "u_5678", subs[0].UserID) + // Update the first subscription to have only one topic + require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"})) + + subs, err = store.SubscriptionsForTopic("topic1") + require.Nil(t, err) + require.Len(t, subs, 2) + require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint) + + subs, err = store.SubscriptionsForTopic("topic2") + require.Nil(t, err) + require.Len(t, subs, 0) + }) } -func testStoreRemoveByUserIDMultiple(t *testing.T, store webpush.Store) { - // Insert two subscriptions for u_1234 and one for u_5678 - require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"})) - require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"1", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"})) - require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"2", "auth-key", "p256dh-key", "u_5678", netip.MustParseAddr("9.9.9.9"), []string{"topic1"})) +func TestStoreUpsertSubscriptionUpdateFields(t *testing.T) { + forEachBackend(t, func(t *testing.T, store webpush.Store) { + // Insert a subscription + require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"})) - subs, err := store.SubscriptionsForTopic("topic1") - require.Nil(t, err) - require.Len(t, subs, 3) + subs, err := store.SubscriptionsForTopic("topic1") + require.Nil(t, err) + require.Len(t, subs, 1) + require.Equal(t, "auth-key", subs[0].Auth) + require.Equal(t, "p256dh-key", subs[0].P256dh) + require.Equal(t, "u_1234", subs[0].UserID) - // Remove all subscriptions for u_1234 - require.Nil(t, store.RemoveSubscriptionsByUserID("u_1234")) + // Re-upsert the same endpoint with different auth, p256dh, and userID + require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "new-auth", "new-p256dh", "u_5678", netip.MustParseAddr("1.2.3.4"), []string{"topic1"})) - // Only u_5678's subscription should remain - subs, err = store.SubscriptionsForTopic("topic1") - require.Nil(t, err) - require.Len(t, subs, 1) - require.Equal(t, testWebPushEndpoint+"2", subs[0].Endpoint) - require.Equal(t, "u_5678", subs[0].UserID) + subs, err = store.SubscriptionsForTopic("topic1") + require.Nil(t, err) + require.Len(t, subs, 1) + require.Equal(t, testWebPushEndpoint, subs[0].Endpoint) + require.Equal(t, "new-auth", subs[0].Auth) + require.Equal(t, "new-p256dh", subs[0].P256dh) + require.Equal(t, "u_5678", subs[0].UserID) + }) } -func testStoreRemoveByEndpoint(t *testing.T, store webpush.Store) { - // Insert subscription with two topics - require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"})) - subs, err := store.SubscriptionsForTopic("topic1") - require.Nil(t, err) - require.Len(t, subs, 1) +func TestStoreRemoveByUserIDMultiple(t *testing.T) { + forEachBackend(t, func(t *testing.T, store webpush.Store) { + // Insert two subscriptions for u_1234 and one for u_5678 + require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"})) + require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"1", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"})) + require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"2", "auth-key", "p256dh-key", "u_5678", netip.MustParseAddr("9.9.9.9"), []string{"topic1"})) - // And remove it again - require.Nil(t, store.RemoveSubscriptionsByEndpoint(testWebPushEndpoint)) - subs, err = store.SubscriptionsForTopic("topic1") - require.Nil(t, err) - require.Len(t, subs, 0) + subs, err := store.SubscriptionsForTopic("topic1") + require.Nil(t, err) + require.Len(t, subs, 3) + + // Remove all subscriptions for u_1234 + require.Nil(t, store.RemoveSubscriptionsByUserID("u_1234")) + + // Only u_5678's subscription should remain + subs, err = store.SubscriptionsForTopic("topic1") + require.Nil(t, err) + require.Len(t, subs, 1) + require.Equal(t, testWebPushEndpoint+"2", subs[0].Endpoint) + require.Equal(t, "u_5678", subs[0].UserID) + }) } -func testStoreRemoveByUserID(t *testing.T, store webpush.Store) { - // Insert subscription with two topics - require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"})) - subs, err := store.SubscriptionsForTopic("topic1") - require.Nil(t, err) - require.Len(t, subs, 1) +func TestStoreRemoveByEndpoint(t *testing.T) { + forEachBackend(t, func(t *testing.T, store webpush.Store) { + // Insert subscription with two topics + require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"})) + subs, err := store.SubscriptionsForTopic("topic1") + require.Nil(t, err) + require.Len(t, subs, 1) - // And remove it again - require.Nil(t, store.RemoveSubscriptionsByUserID("u_1234")) - subs, err = store.SubscriptionsForTopic("topic1") - require.Nil(t, err) - require.Len(t, subs, 0) + // And remove it again + require.Nil(t, store.RemoveSubscriptionsByEndpoint(testWebPushEndpoint)) + subs, err = store.SubscriptionsForTopic("topic1") + require.Nil(t, err) + require.Len(t, subs, 0) + }) } -func testStoreRemoveByUserIDEmpty(t *testing.T, store webpush.Store) { - require.Equal(t, webpush.ErrWebPushUserIDCannotBeEmpty, store.RemoveSubscriptionsByUserID("")) +func TestStoreRemoveByUserID(t *testing.T) { + forEachBackend(t, func(t *testing.T, store webpush.Store) { + // Insert subscription with two topics + require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"})) + subs, err := store.SubscriptionsForTopic("topic1") + require.Nil(t, err) + require.Len(t, subs, 1) + + // And remove it again + require.Nil(t, store.RemoveSubscriptionsByUserID("u_1234")) + subs, err = store.SubscriptionsForTopic("topic1") + require.Nil(t, err) + require.Len(t, subs, 0) + }) } -func testStoreExpiryWarningSent(t *testing.T, store webpush.Store, setUpdatedAt func(endpoint string, updatedAt int64) error) { - // Insert subscription with two topics - require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"})) - - // Set updated_at to the past so it shows up as expiring - require.Nil(t, setUpdatedAt(testWebPushEndpoint, time.Now().Add(-8*24*time.Hour).Unix())) - - // Verify subscription appears in expiring list (warned_at == 0) - subs, err := store.SubscriptionsExpiring(7 * 24 * time.Hour) - require.Nil(t, err) - require.Len(t, subs, 1) - require.Equal(t, testWebPushEndpoint, subs[0].Endpoint) - - // Mark them as warning sent - require.Nil(t, store.MarkExpiryWarningSent(subs)) - - // Verify subscription no longer appears in expiring list (warned_at > 0) - subs, err = store.SubscriptionsExpiring(7 * 24 * time.Hour) - require.Nil(t, err) - require.Len(t, subs, 0) +func TestStoreRemoveByUserIDEmpty(t *testing.T) { + forEachBackend(t, func(t *testing.T, store webpush.Store) { + require.Equal(t, webpush.ErrWebPushUserIDCannotBeEmpty, store.RemoveSubscriptionsByUserID("")) + }) } -func testStoreExpiring(t *testing.T, store webpush.Store, setUpdatedAt func(endpoint string, updatedAt int64) error) { - // Insert subscription with two topics - require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"})) - subs, err := store.SubscriptionsForTopic("topic1") - require.Nil(t, err) - require.Len(t, subs, 1) +func TestStoreExpiryWarningSent(t *testing.T) { + forEachBackend(t, func(t *testing.T, store webpush.Store) { + // Insert subscription with two topics + require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"})) - // Fake-mark them as soon-to-expire - require.Nil(t, setUpdatedAt(testWebPushEndpoint, time.Now().Add(-8*24*time.Hour).Unix())) + // Set updated_at to the past so it shows up as expiring + require.Nil(t, store.SetSubscriptionUpdatedAt(testWebPushEndpoint, time.Now().Add(-8*24*time.Hour).Unix())) - // Should not be cleaned up yet - require.Nil(t, store.RemoveExpiredSubscriptions(9*24*time.Hour)) + // Verify subscription appears in expiring list (warned_at == 0) + subs, err := store.SubscriptionsExpiring(7 * 24 * time.Hour) + require.Nil(t, err) + require.Len(t, subs, 1) + require.Equal(t, testWebPushEndpoint, subs[0].Endpoint) - // Run expiration - subs, err = store.SubscriptionsExpiring(7 * 24 * time.Hour) - require.Nil(t, err) - require.Len(t, subs, 1) - require.Equal(t, testWebPushEndpoint, subs[0].Endpoint) + // Mark them as warning sent + require.Nil(t, store.MarkExpiryWarningSent(subs)) + + // Verify subscription no longer appears in expiring list (warned_at > 0) + subs, err = store.SubscriptionsExpiring(7 * 24 * time.Hour) + require.Nil(t, err) + require.Len(t, subs, 0) + }) } -func testStoreRemoveExpired(t *testing.T, store webpush.Store, setUpdatedAt func(endpoint string, updatedAt int64) error) { - // Insert subscription with two topics - require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"})) - subs, err := store.SubscriptionsForTopic("topic1") - require.Nil(t, err) - require.Len(t, subs, 1) +func TestStoreExpiring(t *testing.T) { + forEachBackend(t, func(t *testing.T, store webpush.Store) { + // Insert subscription with two topics + require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"})) + subs, err := store.SubscriptionsForTopic("topic1") + require.Nil(t, err) + require.Len(t, subs, 1) - // Fake-mark them as expired - require.Nil(t, setUpdatedAt(testWebPushEndpoint, time.Now().Add(-10*24*time.Hour).Unix())) + // Fake-mark them as soon-to-expire + require.Nil(t, store.SetSubscriptionUpdatedAt(testWebPushEndpoint, time.Now().Add(-8*24*time.Hour).Unix())) - // Run expiration - require.Nil(t, store.RemoveExpiredSubscriptions(9*24*time.Hour)) + // Should not be cleaned up yet + require.Nil(t, store.RemoveExpiredSubscriptions(9*24*time.Hour)) - // List again, should be 0 - subs, err = store.SubscriptionsForTopic("topic1") - require.Nil(t, err) - require.Len(t, subs, 0) + // Run expiration + subs, err = store.SubscriptionsExpiring(7 * 24 * time.Hour) + require.Nil(t, err) + require.Len(t, subs, 1) + require.Equal(t, testWebPushEndpoint, subs[0].Endpoint) + }) +} + +func TestStoreRemoveExpired(t *testing.T) { + forEachBackend(t, func(t *testing.T, store webpush.Store) { + // Insert subscription with two topics + require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"})) + subs, err := store.SubscriptionsForTopic("topic1") + require.Nil(t, err) + require.Len(t, subs, 1) + + // Fake-mark them as expired + require.Nil(t, store.SetSubscriptionUpdatedAt(testWebPushEndpoint, time.Now().Add(-10*24*time.Hour).Unix())) + + // Run expiration + require.Nil(t, store.RemoveExpiredSubscriptions(9*24*time.Hour)) + + // List again, should be 0 + subs, err = store.SubscriptionsForTopic("topic1") + require.Nil(t, err) + require.Len(t, subs, 0) + }) }