REmove store interface

This commit is contained in:
binwiederhier
2026-03-01 13:19:53 -05:00
parent 039566bcaf
commit 9736973286
14 changed files with 122 additions and 144 deletions

View File

@@ -21,21 +21,14 @@ var (
ErrWebPushUserIDCannotBeEmpty = errors.New("user ID cannot be empty")
)
// Store is the interface for a web push subscription store.
type Store interface {
UpsertSubscription(endpoint, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error
SubscriptionsForTopic(topic string) ([]*Subscription, error)
SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscription, error)
MarkExpiryWarningSent(subscriptions []*Subscription) error
RemoveSubscriptionsByEndpoint(endpoint string) error
RemoveSubscriptionsByUserID(userID string) error
RemoveExpiredSubscriptions(expireAfter time.Duration) error
SetSubscriptionUpdatedAt(endpoint string, updatedAt int64) error
Close() error
// Store holds the database connection and queries for web push subscriptions.
type Store struct {
db *sql.DB
queries queries
}
// storeQueries holds the database-specific SQL queries.
type storeQueries struct {
// queries holds the database-specific SQL queries.
type queries struct {
selectSubscriptionIDByEndpoint string
selectSubscriptionCountBySubscriberIP string
selectSubscriptionsForTopic string
@@ -51,14 +44,8 @@ type storeQueries struct {
deleteSubscriptionTopicWithoutSubscription string
}
// commonStore implements store operations that are identical across database backends.
type commonStore struct {
db *sql.DB
queries storeQueries
}
// UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID.
func (s *commonStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error {
func (s *Store) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error {
tx, err := s.db.Begin()
if err != nil {
return err
@@ -97,7 +84,7 @@ func (s *commonStore) UpsertSubscription(endpoint string, auth, p256dh, userID s
}
// SubscriptionsForTopic returns all subscriptions for the given topic.
func (s *commonStore) SubscriptionsForTopic(topic string) ([]*Subscription, error) {
func (s *Store) SubscriptionsForTopic(topic string) ([]*Subscription, error) {
rows, err := s.db.Query(s.queries.selectSubscriptionsForTopic, topic)
if err != nil {
return nil, err
@@ -107,7 +94,7 @@ func (s *commonStore) SubscriptionsForTopic(topic string) ([]*Subscription, erro
}
// SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period.
func (s *commonStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscription, error) {
func (s *Store) SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscription, error) {
rows, err := s.db.Query(s.queries.selectSubscriptionsExpiringSoon, time.Now().Add(-warnAfter).Unix())
if err != nil {
return nil, err
@@ -117,7 +104,7 @@ func (s *commonStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscri
}
// MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon.
func (s *commonStore) MarkExpiryWarningSent(subscriptions []*Subscription) error {
func (s *Store) MarkExpiryWarningSent(subscriptions []*Subscription) error {
tx, err := s.db.Begin()
if err != nil {
return err
@@ -132,13 +119,13 @@ func (s *commonStore) MarkExpiryWarningSent(subscriptions []*Subscription) error
}
// RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint.
func (s *commonStore) RemoveSubscriptionsByEndpoint(endpoint string) error {
func (s *Store) RemoveSubscriptionsByEndpoint(endpoint string) error {
_, err := s.db.Exec(s.queries.deleteSubscriptionByEndpoint, endpoint)
return err
}
// RemoveSubscriptionsByUserID removes all subscriptions for the given user ID.
func (s *commonStore) RemoveSubscriptionsByUserID(userID string) error {
func (s *Store) RemoveSubscriptionsByUserID(userID string) error {
if userID == "" {
return ErrWebPushUserIDCannotBeEmpty
}
@@ -147,7 +134,7 @@ func (s *commonStore) RemoveSubscriptionsByUserID(userID string) error {
}
// RemoveExpiredSubscriptions removes all subscriptions that have not been updated for a given time period.
func (s *commonStore) RemoveExpiredSubscriptions(expireAfter time.Duration) error {
func (s *Store) RemoveExpiredSubscriptions(expireAfter time.Duration) error {
_, err := s.db.Exec(s.queries.deleteSubscriptionByAge, time.Now().Add(-expireAfter).Unix())
if err != nil {
return err
@@ -158,13 +145,13 @@ func (s *commonStore) RemoveExpiredSubscriptions(expireAfter time.Duration) erro
// SetSubscriptionUpdatedAt updates the updated_at timestamp for a subscription by endpoint. This is
// exported for testing purposes.
func (s *commonStore) SetSubscriptionUpdatedAt(endpoint string, updatedAt int64) error {
func (s *Store) SetSubscriptionUpdatedAt(endpoint string, updatedAt int64) error {
_, err := s.db.Exec(s.queries.updateSubscriptionUpdatedAt, updatedAt, endpoint)
return err
}
// Close closes the underlying database connection.
func (s *commonStore) Close() error {
func (s *Store) Close() error {
return s.db.Close()
}

View File

@@ -71,13 +71,13 @@ const (
)
// NewPostgresStore creates a new PostgreSQL-backed web push store using an existing database connection pool.
func NewPostgresStore(db *sql.DB) (Store, error) {
func NewPostgresStore(db *sql.DB) (*Store, error) {
if err := setupPostgresDB(db); err != nil {
return nil, err
}
return &commonStore{
return &Store{
db: db,
queries: storeQueries{
queries: queries{
selectSubscriptionIDByEndpoint: postgresSelectSubscriptionIDByEndpointQuery,
selectSubscriptionCountBySubscriberIP: postgresSelectSubscriptionCountBySubscriberIPQuery,
selectSubscriptionsForTopic: postgresSelectSubscriptionsForTopicQuery,

View File

@@ -76,7 +76,7 @@ const (
)
// NewSQLiteStore creates a new SQLite-backed web push store.
func NewSQLiteStore(filename, startupQueries string) (Store, error) {
func NewSQLiteStore(filename, startupQueries string) (*Store, error) {
db, err := sql.Open("sqlite3", filename)
if err != nil {
return nil, err
@@ -87,9 +87,9 @@ func NewSQLiteStore(filename, startupQueries string) (Store, error) {
if err := runSQLiteStartupQueries(db, startupQueries); err != nil {
return nil, err
}
return &commonStore{
return &Store{
db: db,
queries: storeQueries{
queries: queries{
selectSubscriptionIDByEndpoint: sqliteSelectWebPushSubscriptionIDByEndpointQuery,
selectSubscriptionCountBySubscriberIP: sqliteSelectWebPushSubscriptionCountBySubscriberIPQuery,
selectSubscriptionsForTopic: sqliteSelectWebPushSubscriptionsForTopicQuery,

View File

@@ -14,7 +14,7 @@ import (
const testWebPushEndpoint = "https://updates.push.services.mozilla.com/wpush/v1/AAABBCCCDDEEEFFF"
func forEachBackend(t *testing.T, f func(t *testing.T, store webpush.Store)) {
func forEachBackend(t *testing.T, f func(t *testing.T, store *webpush.Store)) {
t.Run("sqlite", func(t *testing.T) {
store, err := webpush.NewSQLiteStore(filepath.Join(t.TempDir(), "webpush.db"), "")
require.Nil(t, err)
@@ -30,7 +30,7 @@ func forEachBackend(t *testing.T, f func(t *testing.T, store webpush.Store)) {
}
func TestStoreUpsertSubscriptionSubscriptionsForTopic(t *testing.T) {
forEachBackend(t, func(t *testing.T, store webpush.Store) {
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
subs, err := store.SubscriptionsForTopic("test-topic")
@@ -49,7 +49,7 @@ func TestStoreUpsertSubscriptionSubscriptionsForTopic(t *testing.T) {
}
func TestStoreUpsertSubscriptionSubscriberIPLimitReached(t *testing.T) {
forEachBackend(t, func(t *testing.T, store webpush.Store) {
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
// Insert 10 subscriptions with the same IP address
for i := 0; i < 10; i++ {
endpoint := fmt.Sprintf(testWebPushEndpoint+"%d", i)
@@ -68,7 +68,7 @@ func TestStoreUpsertSubscriptionSubscriberIPLimitReached(t *testing.T) {
}
func TestStoreUpsertSubscriptionUpdateTopics(t *testing.T) {
forEachBackend(t, func(t *testing.T, store webpush.Store) {
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
// Insert subscription with two topics, and another with one topic
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"1", "auth-key", "p256dh-key", "", netip.MustParseAddr("9.9.9.9"), []string{"topic1"}))
@@ -99,7 +99,7 @@ func TestStoreUpsertSubscriptionUpdateTopics(t *testing.T) {
}
func TestStoreUpsertSubscriptionUpdateFields(t *testing.T) {
forEachBackend(t, func(t *testing.T, store webpush.Store) {
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
// Insert a subscription
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
@@ -124,7 +124,7 @@ func TestStoreUpsertSubscriptionUpdateFields(t *testing.T) {
}
func TestStoreRemoveByUserIDMultiple(t *testing.T) {
forEachBackend(t, func(t *testing.T, store webpush.Store) {
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
// Insert two subscriptions for u_1234 and one for u_5678
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"1", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
@@ -147,7 +147,7 @@ func TestStoreRemoveByUserIDMultiple(t *testing.T) {
}
func TestStoreRemoveByEndpoint(t *testing.T) {
forEachBackend(t, func(t *testing.T, store webpush.Store) {
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
// Insert subscription with two topics
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
subs, err := store.SubscriptionsForTopic("topic1")
@@ -163,7 +163,7 @@ func TestStoreRemoveByEndpoint(t *testing.T) {
}
func TestStoreRemoveByUserID(t *testing.T) {
forEachBackend(t, func(t *testing.T, store webpush.Store) {
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
// Insert subscription with two topics
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
subs, err := store.SubscriptionsForTopic("topic1")
@@ -179,13 +179,13 @@ func TestStoreRemoveByUserID(t *testing.T) {
}
func TestStoreRemoveByUserIDEmpty(t *testing.T) {
forEachBackend(t, func(t *testing.T, store webpush.Store) {
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
require.Equal(t, webpush.ErrWebPushUserIDCannotBeEmpty, store.RemoveSubscriptionsByUserID(""))
})
}
func TestStoreExpiryWarningSent(t *testing.T) {
forEachBackend(t, func(t *testing.T, store webpush.Store) {
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
// Insert subscription with two topics
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
@@ -209,7 +209,7 @@ func TestStoreExpiryWarningSent(t *testing.T) {
}
func TestStoreExpiring(t *testing.T) {
forEachBackend(t, func(t *testing.T, store webpush.Store) {
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
// Insert subscription with two topics
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
subs, err := store.SubscriptionsForTopic("topic1")
@@ -231,7 +231,7 @@ func TestStoreExpiring(t *testing.T) {
}
func TestStoreRemoveExpired(t *testing.T) {
forEachBackend(t, func(t *testing.T, store webpush.Store) {
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
// Insert subscription with two topics
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
subs, err := store.SubscriptionsForTopic("topic1")