From 4e5f95ba0c073bff35e2c0a748000718ff56d224 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Mon, 16 Feb 2026 19:53:34 -0500 Subject: [PATCH] Refactor webpush store to eliminate code duplication Consolidate SQLite and Postgres store implementations into a single commonStore with database-specific SQL queries passed via configuration. This eliminates ~100 lines of duplicate code while maintaining full functionality for both backends. Key changes: - Move all store methods to commonStore in store.go - Remove sqliteStore and postgresStore wrapper structs - Refactor SQLite to use QueryRow() pattern like Postgres - Pass database-specific queries via storeQueries struct - Make store types unexported, only expose Store interface All tests pass for both SQLite and PostgreSQL backends. --- server/server.go | 2 +- webpush/store.go | 137 +++++++++++++++++++++++++++- webpush/store_postgres.go | 141 ++++------------------------- webpush/store_postgres_test.go | 2 +- webpush/store_sqlite.go | 161 ++++----------------------------- webpush/store_sqlite_test.go | 2 +- 6 files changed, 173 insertions(+), 272 deletions(-) diff --git a/server/server.go b/server/server.go index cba0179a..de8af35f 100644 --- a/server/server.go +++ b/server/server.go @@ -58,7 +58,7 @@ type Server struct { messagesHistory []int64 // Last n values of the messages counter, used to determine rate userManager *user.Manager // Might be nil! messageCache *messageCache // Database that stores the messages - webPush webpush.Store // Database that stores web push subscriptions + webPush webpush.Store // Database that stores web push subscriptions fileCache *fileCache // File system based cache that stores attachments stripe stripeAPI // Stripe API, can be replaced with a mock priceCache *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!) diff --git a/webpush/store.go b/webpush/store.go index 70dfc8fa..048fe083 100644 --- a/webpush/store.go +++ b/webpush/store.go @@ -7,6 +7,7 @@ import ( "time" "heckel.io/ntfy/v2/log" + "heckel.io/ntfy/v2/util" ) const ( @@ -16,7 +17,6 @@ const ( ) var ( - ErrWebPushNoRows = errors.New("no rows found") ErrWebPushTooManySubscriptions = errors.New("too many subscriptions") ErrWebPushUserIDCannotBeEmpty = errors.New("user ID cannot be empty") ) @@ -34,6 +34,29 @@ type Store interface { Close() error } +// storeQueries holds the database-specific SQL queries. +type storeQueries struct { + selectSubscriptionIDByEndpoint string + selectSubscriptionCountBySubscriberIP string + selectSubscriptionsForTopic string + selectSubscriptionsExpiringSoon string + insertSubscription string + updateSubscriptionWarningSent string + updateSubscriptionUpdatedAt string + deleteSubscriptionByEndpoint string + deleteSubscriptionByUserID string + deleteSubscriptionByAge string + insertSubscriptionTopic string + deleteSubscriptionTopicAll string + deleteSubscriptionTopicWithoutSubscription string +} + +// commonStore implements store operations that are identical across database backends. +type commonStore struct { + db *sql.DB + queries storeQueries +} + // Subscription represents a web push subscription. type Subscription struct { ID string @@ -69,3 +92,115 @@ func subscriptionsFromRows(rows *sql.Rows) ([]*Subscription, error) { } return subscriptions, nil } + +// UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID. +func (s *commonStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error { + tx, err := s.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + // Read number of subscriptions for subscriber IP address + var subscriptionCount int + if err := tx.QueryRow(s.queries.selectSubscriptionCountBySubscriberIP, subscriberIP.String()).Scan(&subscriptionCount); err != nil { + return err + } + // Read existing subscription ID for endpoint (or create new ID) + var subscriptionID string + err = tx.QueryRow(s.queries.selectSubscriptionIDByEndpoint, endpoint).Scan(&subscriptionID) + if errors.Is(err, sql.ErrNoRows) { + if subscriptionCount >= subscriptionEndpointLimitPerSubscriberIP { + return ErrWebPushTooManySubscriptions + } + subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength) + } else if err != nil { + return err + } + // Insert or update subscription + updatedAt, warnedAt := time.Now().Unix(), 0 + if _, err = tx.Exec(s.queries.insertSubscription, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil { + return err + } + // Replace all subscription topics + if _, err := tx.Exec(s.queries.deleteSubscriptionTopicAll, subscriptionID); err != nil { + return err + } + for _, topic := range topics { + if _, err = tx.Exec(s.queries.insertSubscriptionTopic, subscriptionID, topic); err != nil { + return err + } + } + return tx.Commit() +} + +// SubscriptionsForTopic returns all subscriptions for the given topic. +func (s *commonStore) SubscriptionsForTopic(topic string) ([]*Subscription, error) { + rows, err := s.db.Query(s.queries.selectSubscriptionsForTopic, topic) + if err != nil { + return nil, err + } + defer rows.Close() + return subscriptionsFromRows(rows) +} + +// SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period. +func (s *commonStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscription, error) { + rows, err := s.db.Query(s.queries.selectSubscriptionsExpiringSoon, time.Now().Add(-warnAfter).Unix()) + if err != nil { + return nil, err + } + defer rows.Close() + return subscriptionsFromRows(rows) +} + +// MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon. +func (s *commonStore) MarkExpiryWarningSent(subscriptions []*Subscription) error { + tx, err := s.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + for _, subscription := range subscriptions { + if _, err := tx.Exec(s.queries.updateSubscriptionWarningSent, time.Now().Unix(), subscription.ID); err != nil { + return err + } + } + return tx.Commit() +} + +// RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint. +func (s *commonStore) RemoveSubscriptionsByEndpoint(endpoint string) error { + _, err := s.db.Exec(s.queries.deleteSubscriptionByEndpoint, endpoint) + return err +} + +// RemoveSubscriptionsByUserID removes all subscriptions for the given user ID. +func (s *commonStore) RemoveSubscriptionsByUserID(userID string) error { + if userID == "" { + return ErrWebPushUserIDCannotBeEmpty + } + _, err := s.db.Exec(s.queries.deleteSubscriptionByUserID, userID) + return err +} + +// RemoveExpiredSubscriptions removes all subscriptions that have not been updated for a given time period. +func (s *commonStore) RemoveExpiredSubscriptions(expireAfter time.Duration) error { + _, err := s.db.Exec(s.queries.deleteSubscriptionByAge, time.Now().Add(-expireAfter).Unix()) + if err != nil { + return err + } + _, err = s.db.Exec(s.queries.deleteSubscriptionTopicWithoutSubscription) + return err +} + +// SetSubscriptionUpdatedAt updates the updated_at timestamp for a subscription by endpoint. This is +// exported for testing purposes and is not part of the Store interface. +func (s *commonStore) SetSubscriptionUpdatedAt(endpoint string, updatedAt int64) error { + _, err := s.db.Exec(s.queries.updateSubscriptionUpdatedAt, updatedAt, endpoint) + return err +} + +// Close closes the underlying database connection. +func (s *commonStore) Close() error { + return s.db.Close() +} diff --git a/webpush/store_postgres.go b/webpush/store_postgres.go index 025e11c8..da69c6e1 100644 --- a/webpush/store_postgres.go +++ b/webpush/store_postgres.go @@ -2,13 +2,8 @@ package webpush import ( "database/sql" - "errors" - "net/netip" - "time" _ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver - - "heckel.io/ntfy/v2/util" ) const ( @@ -74,13 +69,8 @@ const ( pgSelectSchemaVersionQuery = `SELECT version FROM webpush_schema_version WHERE id = 1` ) -// PostgresStore is a web push subscription store backed by PostgreSQL. -type PostgresStore struct { - db *sql.DB -} - // NewPostgresStore creates a new PostgreSQL-backed web push store. -func NewPostgresStore(dsn string) (*PostgresStore, error) { +func NewPostgresStore(dsn string) (Store, error) { db, err := sql.Open("pgx", dsn) if err != nil { return nil, err @@ -91,8 +81,23 @@ func NewPostgresStore(dsn string) (*PostgresStore, error) { if err := setupPostgresDB(db); err != nil { return nil, err } - return &PostgresStore{ + return &commonStore{ db: db, + queries: storeQueries{ + selectSubscriptionIDByEndpoint: pgSelectSubscriptionIDByEndpoint, + selectSubscriptionCountBySubscriberIP: pgSelectSubscriptionCountBySubscriberIP, + selectSubscriptionsForTopic: pgSelectSubscriptionsForTopicQuery, + selectSubscriptionsExpiringSoon: pgSelectSubscriptionsExpiringSoonQuery, + insertSubscription: pgInsertSubscriptionQuery, + updateSubscriptionWarningSent: pgUpdateSubscriptionWarningSentQuery, + updateSubscriptionUpdatedAt: pgUpdateSubscriptionUpdatedAtQuery, + deleteSubscriptionByEndpoint: pgDeleteSubscriptionByEndpointQuery, + deleteSubscriptionByUserID: pgDeleteSubscriptionByUserIDQuery, + deleteSubscriptionByAge: pgDeleteSubscriptionByAgeQuery, + insertSubscriptionTopic: pgInsertSubscriptionTopicQuery, + deleteSubscriptionTopicAll: pgDeleteSubscriptionTopicAllQuery, + deleteSubscriptionTopicWithoutSubscription: pgDeleteSubscriptionTopicWithoutSubscription, + }, }, nil } @@ -119,115 +124,3 @@ func setupNewPostgresDB(db *sql.DB) error { } return tx.Commit() } - -// UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID. -func (c *PostgresStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error { - tx, err := c.db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - // Read number of subscriptions for subscriber IP address - var subscriptionCount int - if err := tx.QueryRow(pgSelectSubscriptionCountBySubscriberIP, subscriberIP.String()).Scan(&subscriptionCount); err != nil { - return err - } - // Read existing subscription ID for endpoint (or create new ID) - var subscriptionID string - err = tx.QueryRow(pgSelectSubscriptionIDByEndpoint, endpoint).Scan(&subscriptionID) - if errors.Is(err, sql.ErrNoRows) { - if subscriptionCount >= subscriptionEndpointLimitPerSubscriberIP { - return ErrWebPushTooManySubscriptions - } - subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength) - } else if err != nil { - return err - } - // Insert or update subscription - updatedAt, warnedAt := time.Now().Unix(), 0 - if _, err = tx.Exec(pgInsertSubscriptionQuery, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil { - return err - } - // Replace all subscription topics - if _, err := tx.Exec(pgDeleteSubscriptionTopicAllQuery, subscriptionID); err != nil { - return err - } - for _, topic := range topics { - if _, err = tx.Exec(pgInsertSubscriptionTopicQuery, subscriptionID, topic); err != nil { - return err - } - } - return tx.Commit() -} - -// SubscriptionsForTopic returns all subscriptions for the given topic. -func (c *PostgresStore) SubscriptionsForTopic(topic string) ([]*Subscription, error) { - rows, err := c.db.Query(pgSelectSubscriptionsForTopicQuery, topic) - if err != nil { - return nil, err - } - defer rows.Close() - return subscriptionsFromRows(rows) -} - -// SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period. -func (c *PostgresStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscription, error) { - rows, err := c.db.Query(pgSelectSubscriptionsExpiringSoonQuery, time.Now().Add(-warnAfter).Unix()) - if err != nil { - return nil, err - } - defer rows.Close() - return subscriptionsFromRows(rows) -} - -// MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon. -func (c *PostgresStore) MarkExpiryWarningSent(subscriptions []*Subscription) error { - tx, err := c.db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - for _, subscription := range subscriptions { - if _, err := tx.Exec(pgUpdateSubscriptionWarningSentQuery, time.Now().Unix(), subscription.ID); err != nil { - return err - } - } - return tx.Commit() -} - -// RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint. -func (c *PostgresStore) RemoveSubscriptionsByEndpoint(endpoint string) error { - _, err := c.db.Exec(pgDeleteSubscriptionByEndpointQuery, endpoint) - return err -} - -// RemoveSubscriptionsByUserID removes all subscriptions for the given user ID. -func (c *PostgresStore) RemoveSubscriptionsByUserID(userID string) error { - if userID == "" { - return ErrWebPushUserIDCannotBeEmpty - } - _, err := c.db.Exec(pgDeleteSubscriptionByUserIDQuery, userID) - return err -} - -// RemoveExpiredSubscriptions removes all subscriptions that have not been updated for a given time period. -func (c *PostgresStore) RemoveExpiredSubscriptions(expireAfter time.Duration) error { - _, err := c.db.Exec(pgDeleteSubscriptionByAgeQuery, time.Now().Add(-expireAfter).Unix()) - if err != nil { - return err - } - _, err = c.db.Exec(pgDeleteSubscriptionTopicWithoutSubscription) - return err -} - -// SetSubscriptionUpdatedAt updates the updated_at timestamp for a subscription by endpoint. This is -// exported for testing purposes and is not part of the Store interface. -func (c *PostgresStore) SetSubscriptionUpdatedAt(endpoint string, updatedAt int64) error { - _, err := c.db.Exec(pgUpdateSubscriptionUpdatedAtQuery, updatedAt, endpoint) - return err -} - -// Close closes the underlying database connection. -func (c *PostgresStore) Close() error { - return c.db.Close() -} diff --git a/webpush/store_postgres_test.go b/webpush/store_postgres_test.go index 07a3c6be..0124441c 100644 --- a/webpush/store_postgres_test.go +++ b/webpush/store_postgres_test.go @@ -12,7 +12,7 @@ import ( "heckel.io/ntfy/v2/webpush" ) -func newTestPostgresStore(t *testing.T) *webpush.PostgresStore { +func newTestPostgresStore(t *testing.T) webpush.Store { dsn := os.Getenv("NTFY_TEST_DATABASE_URL") if dsn == "" { t.Skip("NTFY_TEST_DATABASE_URL not set, skipping PostgreSQL tests") diff --git a/webpush/store_sqlite.go b/webpush/store_sqlite.go index bbd9be48..97bacf74 100644 --- a/webpush/store_sqlite.go +++ b/webpush/store_sqlite.go @@ -2,12 +2,8 @@ package webpush import ( "database/sql" - "net/netip" - "time" _ "github.com/mattn/go-sqlite3" // SQLite driver - - "heckel.io/ntfy/v2/util" ) const ( @@ -80,13 +76,8 @@ const ( sqliteSelectWebPushSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1` ) -// SQLiteStore is a web push subscription store backed by SQLite. -type SQLiteStore struct { - db *sql.DB -} - // NewSQLiteStore creates a new SQLite-backed web push store. -func NewSQLiteStore(filename, startupQueries string) (*SQLiteStore, error) { +func NewSQLiteStore(filename, startupQueries string) (Store, error) { db, err := sql.Open("sqlite3", filename) if err != nil { return nil, err @@ -97,8 +88,23 @@ func NewSQLiteStore(filename, startupQueries string) (*SQLiteStore, error) { if err := runSQLiteWebPushStartupQueries(db, startupQueries); err != nil { return nil, err } - return &SQLiteStore{ + return &commonStore{ db: db, + queries: storeQueries{ + selectSubscriptionIDByEndpoint: sqliteSelectWebPushSubscriptionIDByEndpoint, + selectSubscriptionCountBySubscriberIP: sqliteSelectWebPushSubscriptionCountBySubscriberIP, + selectSubscriptionsForTopic: sqliteSelectWebPushSubscriptionsForTopicQuery, + selectSubscriptionsExpiringSoon: sqliteSelectWebPushSubscriptionsExpiringSoonQuery, + insertSubscription: sqliteInsertWebPushSubscriptionQuery, + updateSubscriptionWarningSent: sqliteUpdateWebPushSubscriptionWarningSentQuery, + updateSubscriptionUpdatedAt: sqliteUpdateWebPushSubscriptionUpdatedAtQuery, + deleteSubscriptionByEndpoint: sqliteDeleteWebPushSubscriptionByEndpointQuery, + deleteSubscriptionByUserID: sqliteDeleteWebPushSubscriptionByUserIDQuery, + deleteSubscriptionByAge: sqliteDeleteWebPushSubscriptionByAgeQuery, + insertSubscriptionTopic: sqliteInsertWebPushSubscriptionTopicQuery, + deleteSubscriptionTopicAll: sqliteDeleteWebPushSubscriptionTopicAllQuery, + deleteSubscriptionTopicWithoutSubscription: sqliteDeleteWebPushSubscriptionTopicWithoutSubscription, + }, }, nil } @@ -130,136 +136,3 @@ func runSQLiteWebPushStartupQueries(db *sql.DB, startupQueries string) error { } return nil } - -// UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID. It always first deletes all -// existing entries for a given endpoint. -func (c *SQLiteStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error { - tx, err := c.db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - // Read number of subscriptions for subscriber IP address - rowsCount, err := tx.Query(sqliteSelectWebPushSubscriptionCountBySubscriberIP, subscriberIP.String()) - if err != nil { - return err - } - defer rowsCount.Close() - var subscriptionCount int - if !rowsCount.Next() { - return ErrWebPushNoRows - } - if err := rowsCount.Scan(&subscriptionCount); err != nil { - return err - } - if err := rowsCount.Close(); err != nil { - return err - } - // Read existing subscription ID for endpoint (or create new ID) - rows, err := tx.Query(sqliteSelectWebPushSubscriptionIDByEndpoint, endpoint) - if err != nil { - return err - } - defer rows.Close() - var subscriptionID string - if rows.Next() { - if err := rows.Scan(&subscriptionID); err != nil { - return err - } - } else { - if subscriptionCount >= subscriptionEndpointLimitPerSubscriberIP { - return ErrWebPushTooManySubscriptions - } - subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength) - } - if err := rows.Close(); err != nil { - return err - } - // Insert or update subscription - updatedAt, warnedAt := time.Now().Unix(), 0 - if _, err = tx.Exec(sqliteInsertWebPushSubscriptionQuery, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil { - return err - } - // Replace all subscription topics - if _, err := tx.Exec(sqliteDeleteWebPushSubscriptionTopicAllQuery, subscriptionID); err != nil { - return err - } - for _, topic := range topics { - if _, err = tx.Exec(sqliteInsertWebPushSubscriptionTopicQuery, subscriptionID, topic); err != nil { - return err - } - } - return tx.Commit() -} - -// SubscriptionsForTopic returns all subscriptions for the given topic. -func (c *SQLiteStore) SubscriptionsForTopic(topic string) ([]*Subscription, error) { - rows, err := c.db.Query(sqliteSelectWebPushSubscriptionsForTopicQuery, topic) - if err != nil { - return nil, err - } - defer rows.Close() - return subscriptionsFromRows(rows) -} - -// SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period. -func (c *SQLiteStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscription, error) { - rows, err := c.db.Query(sqliteSelectWebPushSubscriptionsExpiringSoonQuery, time.Now().Add(-warnAfter).Unix()) - if err != nil { - return nil, err - } - defer rows.Close() - return subscriptionsFromRows(rows) -} - -// MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon. -func (c *SQLiteStore) MarkExpiryWarningSent(subscriptions []*Subscription) error { - tx, err := c.db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - for _, subscription := range subscriptions { - if _, err := tx.Exec(sqliteUpdateWebPushSubscriptionWarningSentQuery, time.Now().Unix(), subscription.ID); err != nil { - return err - } - } - return tx.Commit() -} - -// RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint. -func (c *SQLiteStore) RemoveSubscriptionsByEndpoint(endpoint string) error { - _, err := c.db.Exec(sqliteDeleteWebPushSubscriptionByEndpointQuery, endpoint) - return err -} - -// RemoveSubscriptionsByUserID removes all subscriptions for the given user ID. -func (c *SQLiteStore) RemoveSubscriptionsByUserID(userID string) error { - if userID == "" { - return ErrWebPushUserIDCannotBeEmpty - } - _, err := c.db.Exec(sqliteDeleteWebPushSubscriptionByUserIDQuery, userID) - return err -} - -// RemoveExpiredSubscriptions removes all subscriptions that have not been updated for a given time period. -func (c *SQLiteStore) RemoveExpiredSubscriptions(expireAfter time.Duration) error { - _, err := c.db.Exec(sqliteDeleteWebPushSubscriptionByAgeQuery, time.Now().Add(-expireAfter).Unix()) - if err != nil { - return err - } - _, err = c.db.Exec(sqliteDeleteWebPushSubscriptionTopicWithoutSubscription) - return err -} - -// SetSubscriptionUpdatedAt updates the updated_at timestamp for a subscription by endpoint. This is -// exported for testing purposes and is not part of the Store interface. -func (c *SQLiteStore) SetSubscriptionUpdatedAt(endpoint string, updatedAt int64) error { - _, err := c.db.Exec(sqliteUpdateWebPushSubscriptionUpdatedAtQuery, updatedAt, endpoint) - return err -} - -// Close closes the underlying database connection. -func (c *SQLiteStore) Close() error { - return c.db.Close() -} diff --git a/webpush/store_sqlite_test.go b/webpush/store_sqlite_test.go index b01ca55d..1d4087d1 100644 --- a/webpush/store_sqlite_test.go +++ b/webpush/store_sqlite_test.go @@ -8,7 +8,7 @@ import ( "heckel.io/ntfy/v2/webpush" ) -func newTestSQLiteStore(t *testing.T) *webpush.SQLiteStore { +func newTestSQLiteStore(t *testing.T) webpush.Store { store, err := webpush.NewSQLiteStore(filepath.Join(t.TempDir(), "webpush.db"), "") require.Nil(t, err) t.Cleanup(func() { store.Close() })