From e432bf2886631654d2203df5d84aa818c581fbb2 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Mon, 16 Feb 2026 12:13:10 -0500 Subject: [PATCH] Rename PostgreSQL table prefix from wp_ to webpush_ --- cmd/serve.go | 7 +- docs/config.md | 27 +++- docs/releases.md | 7 + go.mod | 4 + go.sum | 9 ++ server/config.go | 2 + server/server.go | 13 +- server/server.yml | 6 + server/server_webpush.go | 5 +- server/server_webpush_test.go | 14 +- server/types.go | 16 -- server/webpush_store.go | 285 ---------------------------------- server/webpush_store_test.go | 199 ------------------------ webpush/postgres.go | 224 ++++++++++++++++++++++++++ webpush/postgres_test.go | 207 ++++++++++++++++++++++++ webpush/sqlite.go | 280 +++++++++++++++++++++++++++++++++ webpush/sqlite_test.go | 203 ++++++++++++++++++++++++ webpush/store.go | 51 ++++++ 18 files changed, 1041 insertions(+), 518 deletions(-) delete mode 100644 server/webpush_store.go delete mode 100644 server/webpush_store_test.go create mode 100644 webpush/postgres.go create mode 100644 webpush/postgres_test.go create mode 100644 webpush/sqlite.go create mode 100644 webpush/sqlite_test.go create mode 100644 webpush/store.go diff --git a/cmd/serve.go b/cmd/serve.go index b451a118..33dc838d 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -39,6 +39,7 @@ var flagsServe = append( altsrc.NewStringFlag(&cli.StringFlag{Name: "key-file", Aliases: []string{"key_file", "K"}, EnvVars: []string{"NTFY_KEY_FILE"}, Usage: "private key file, if listen-https is set"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "cert-file", Aliases: []string{"cert_file", "E"}, EnvVars: []string{"NTFY_CERT_FILE"}, Usage: "certificate file, if listen-https is set"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "firebase-key-file", Aliases: []string{"firebase_key_file", "F"}, EnvVars: []string{"NTFY_FIREBASE_KEY_FILE"}, Usage: "Firebase credentials file; if set additionally publish to FCM topic"}), + altsrc.NewStringFlag(&cli.StringFlag{Name: "database-url", Aliases: []string{"database_url"}, EnvVars: []string{"NTFY_DATABASE_URL"}, Usage: "PostgreSQL connection string for database-backed stores (e.g. postgres://user:pass@host:5432/ntfy)"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-file", Aliases: []string{"cache_file", "C"}, EnvVars: []string{"NTFY_CACHE_FILE"}, Usage: "cache file used for message caching"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-duration", Aliases: []string{"cache_duration", "b"}, EnvVars: []string{"NTFY_CACHE_DURATION"}, Value: util.FormatDuration(server.DefaultCacheDuration), Usage: "buffer messages for this time to allow `since` requests"}), altsrc.NewIntFlag(&cli.IntFlag{Name: "cache-batch-size", Aliases: []string{"cache_batch_size"}, EnvVars: []string{"NTFY_BATCH_SIZE"}, Usage: "max size of messages to batch together when writing to message cache (if zero, writes are synchronous)"}), @@ -143,6 +144,7 @@ func execServe(c *cli.Context) error { keyFile := c.String("key-file") certFile := c.String("cert-file") firebaseKeyFile := c.String("firebase-key-file") + databaseURL := c.String("database-url") webPushPrivateKey := c.String("web-push-private-key") webPushPublicKey := c.String("web-push-public-key") webPushFile := c.String("web-push-file") @@ -284,8 +286,8 @@ func execServe(c *cli.Context) error { return errors.New("if set, FCM key file must exist") } else if firebaseKeyFile != "" && !server.FirebaseAvailable { return errors.New("cannot set firebase-key-file, support for Firebase is not available (nofirebase)") - } else if webPushPublicKey != "" && (webPushPrivateKey == "" || webPushFile == "" || webPushEmailAddress == "" || baseURL == "") { - return errors.New("if web push is enabled, web-push-private-key, web-push-public-key, web-push-file, web-push-email-address, and base-url should be set. run 'ntfy webpush keys' to generate keys") + } else if webPushPublicKey != "" && (webPushPrivateKey == "" || (webPushFile == "" && databaseURL == "") || webPushEmailAddress == "" || baseURL == "") { + return errors.New("if web push is enabled, web-push-private-key, web-push-public-key, web-push-file (or database-url), web-push-email-address, and base-url should be set. run 'ntfy webpush keys' to generate keys") } else if keepaliveInterval < 5*time.Second { return errors.New("keepalive interval cannot be lower than five seconds") } else if managerInterval < 5*time.Second { @@ -494,6 +496,7 @@ func execServe(c *cli.Context) error { conf.EnableMetrics = enableMetrics conf.MetricsListenHTTP = metricsListenHTTP conf.ProfileListenHTTP = profileListenHTTP + conf.DatabaseURL = databaseURL conf.WebPushPrivateKey = webPushPrivateKey conf.WebPushPublicKey = webPushPublicKey conf.WebPushFile = webPushFile diff --git a/docs/config.md b/docs/config.md index 8a125146..f6e27f60 100644 --- a/docs/config.md +++ b/docs/config.md @@ -144,6 +144,20 @@ the message to the subscribers. Subscribers can retrieve cached messaging using the [`poll=1` parameter](subscribe/api.md#poll-for-messages), as well as the [`since=` parameter](subscribe/api.md#fetch-cached-messages). +## PostgreSQL database +By default, ntfy uses SQLite for all database-backed stores. As an alternative, you can configure ntfy to use PostgreSQL +by setting the `database-url` option to a PostgreSQL connection string: + +```yaml +database-url: "postgres://user:pass@host:5432/ntfy" +``` + +When `database-url` is set, ntfy will use PostgreSQL for the web push subscription store instead of SQLite. The +`web-push-file` option is not required in this case. Support for PostgreSQL for the message cache and user manager +will be added in future releases. + +You can also set this via the environment variable `NTFY_DATABASE_URL` or the command line flag `--database-url`. + ## Attachments If desired, you may allow users to upload and [attach files to notifications](publish.md#attachments). To enable this feature, you have to simply configure an attachment cache directory and a base URL (`attachment-cache-dir`, `base-url`). @@ -1141,12 +1155,15 @@ a database to keep track of the browser's subscriptions, and an admin email addr - `web-push-public-key` is the generated VAPID public key, e.g. AA1234BBCCddvveekaabcdfqwertyuiopasdfghjklzxcvbnm1234567890 - `web-push-private-key` is the generated VAPID private key, e.g. AA2BB1234567890abcdefzxcvbnm1234567890 -- `web-push-file` is a database file to keep track of browser subscription endpoints, e.g. `/var/cache/ntfy/webpush.db` +- `web-push-file` is a database file to keep track of browser subscription endpoints, e.g. `/var/cache/ntfy/webpush.db` (not required if `database-url` is set) - `web-push-email-address` is the admin email address send to the push provider, e.g. `sysadmin@example.com` - `web-push-startup-queries` is an optional list of queries to run on startup` - `web-push-expiry-warning-duration` defines the duration after which unused subscriptions are sent a warning (default is `55d`) - `web-push-expiry-duration` defines the duration after which unused subscriptions will expire (default is `60d`) +Alternatively, you can use PostgreSQL instead of SQLite for the web push subscription store by setting `database-url` +(see [PostgreSQL database](#postgresql-database)). + Limitations: - Like foreground browser notifications, background push notifications require the web app to be served over HTTPS. A _valid_ @@ -1172,9 +1189,10 @@ web-push-file: /var/cache/ntfy/webpush.db web-push-email-address: sysadmin@example.com ``` -The `web-push-file` is used to store the push subscriptions. Unused subscriptions will send out a warning after 55 days, -and will automatically expire after 60 days (default). If the gateway returns an error (e.g. 410 Gone when a user has unsubscribed), -subscriptions are also removed automatically. +The `web-push-file` is used to store the push subscriptions in a local SQLite database. Alternatively, if `database-url` +is set, subscriptions are stored in PostgreSQL and `web-push-file` is not required. Unused subscriptions will send out +a warning after 55 days, and will automatically expire after 60 days (default). If the gateway returns an error +(e.g. 410 Gone when a user has unsubscribed), subscriptions are also removed automatically. The web app refreshes subscriptions on start and regularly on an interval, but this file should be persisted across restarts. If the subscription file is deleted or lost, any web apps that aren't open will not receive new web push notifications until you open then. @@ -1755,6 +1773,7 @@ variable before running the `ntfy` command (e.g. `export NTFY_LISTEN_HTTP=:80`). | `key-file` | `NTFY_KEY_FILE` | *filename* | - | HTTPS/TLS private key file, only used if `listen-https` is set. | | `cert-file` | `NTFY_CERT_FILE` | *filename* | - | HTTPS/TLS certificate file, only used if `listen-https` is set. | | `firebase-key-file` | `NTFY_FIREBASE_KEY_FILE` | *filename* | - | If set, also publish messages to a Firebase Cloud Messaging (FCM) topic for your app. This is optional and only required to save battery when using the Android app. See [Firebase (FCM)](#firebase-fcm). | +| `database-url` | `NTFY_DATABASE_URL` | *string (connection URL)* | - | PostgreSQL connection string (e.g. `postgres://user:pass@host:5432/ntfy`). If set, uses PostgreSQL for database-backed stores instead of SQLite. Currently applies to the web push store. See [PostgreSQL database](#postgresql-database). | | `cache-file` | `NTFY_CACHE_FILE` | *filename* | - | If set, messages are cached in a local SQLite database instead of only in-memory. This allows for service restarts without losing messages in support of the since= parameter. See [message cache](#message-cache). | | `cache-duration` | `NTFY_CACHE_DURATION` | *duration* | 12h | Duration for which messages will be buffered before they are deleted. This is required to support the `since=...` and `poll=1` parameter. Set this to `0` to disable the cache entirely. | | `cache-startup-queries` | `NTFY_CACHE_STARTUP_QUERIES` | *string (SQL queries)* | - | SQL queries to run during database startup; this is useful for tuning and [enabling WAL mode](#message-cache) | diff --git a/docs/releases.md b/docs/releases.md index f291bea2..fd71cc0d 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -1714,3 +1714,10 @@ and the [ntfy Android app](https://github.com/binwiederhier/ntfy-android/release * Fix crash when default server URL is missing scheme by auto-prepending `https://` ([#1582](https://github.com/binwiederhier/ntfy/issues/1582), thanks to [@hard-zero1](https://github.com/hard-zero1)) * Fix notification timestamp to use original send time instead of receive time ([#1112](https://github.com/binwiederhier/ntfy/issues/1112), thanks to [@voruti](https://github.com/voruti) for reporting) * Fix notifications being missed after service restart by using persisted lastNotificationId ([#1591](https://github.com/binwiederhier/ntfy/issues/1591), thanks to @Epifeny for reporting) +* Fix crash in settings when fragment is detached during backup/restore or log operations + +### ntfy server v2.12.x (UNRELEASED) + +**Features:** + +* Add PostgreSQL as an alternative database backend for the web push subscription store via `database-url` config option diff --git a/go.mod b/go.mod index 992aced8..26c308c9 100644 --- a/go.mod +++ b/go.mod @@ -30,6 +30,7 @@ require github.com/pkg/errors v0.9.1 // indirect require ( firebase.google.com/go/v4 v4.19.0 github.com/SherClockHolmes/webpush-go v1.4.0 + github.com/jackc/pgx/v5 v5.8.0 github.com/microcosm-cc/bluemonday v1.0.27 github.com/prometheus/client_golang v1.23.2 github.com/stripe/stripe-go/v74 v74.30.0 @@ -71,6 +72,9 @@ require ( github.com/googleapis/enterprise-certificate-proxy v0.3.11 // indirect github.com/googleapis/gax-go/v2 v2.17.0 // indirect github.com/gorilla/css v1.0.1 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect diff --git a/go.sum b/go.sum index c42210b4..eb0e7d12 100644 --- a/go.sum +++ b/go.sum @@ -104,6 +104,14 @@ github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8= github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo= +github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= @@ -144,6 +152,7 @@ github.com/spiffe/go-spiffe/v2 v2.6.0/go.mod h1:gm2SeUoMZEtpnzPNs2Csc0D/gX33k1xI github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= diff --git a/server/config.go b/server/config.go index 278b6aed..85a2c2b2 100644 --- a/server/config.go +++ b/server/config.go @@ -88,6 +88,7 @@ var ( // Config is the main config struct for the application. Use New to instantiate a default config struct. type Config struct { File string // Config file, only used for testing + DatabaseURL string // PostgreSQL connection string (e.g. "postgres://user:pass@host:5432/ntfy") BaseURL string ListenHTTP string ListenHTTPS string @@ -192,6 +193,7 @@ type Config struct { func NewConfig() *Config { return &Config{ File: DefaultConfigFile, // Only used for testing + DatabaseURL: "", BaseURL: "", ListenHTTP: DefaultListenHTTP, ListenHTTPS: "", diff --git a/server/server.go b/server/server.go index bf982503..cba0179a 100644 --- a/server/server.go +++ b/server/server.go @@ -37,6 +37,7 @@ import ( "heckel.io/ntfy/v2/user" "heckel.io/ntfy/v2/util" "heckel.io/ntfy/v2/util/sprig" + "heckel.io/ntfy/v2/webpush" ) // Server is the main server, providing the UI and API for ntfy @@ -57,7 +58,7 @@ type Server struct { messagesHistory []int64 // Last n values of the messages counter, used to determine rate userManager *user.Manager // Might be nil! messageCache *messageCache // Database that stores the messages - webPush *webPushStore // Database that stores web push subscriptions + webPush webpush.Store // Database that stores web push subscriptions fileCache *fileCache // File system based cache that stores attachments stripe stripeAPI // Stripe API, can be replaced with a mock priceCache *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!) @@ -176,9 +177,13 @@ func New(conf *Config) (*Server, error) { if err != nil { return nil, err } - var webPush *webPushStore + var wp webpush.Store if conf.WebPushPublicKey != "" { - webPush, err = newWebPushStore(conf.WebPushFile, conf.WebPushStartupQueries) + if conf.DatabaseURL != "" { + wp, err = webpush.NewPostgresStore(conf.DatabaseURL) + } else { + wp, err = webpush.NewSQLiteStore(conf.WebPushFile, conf.WebPushStartupQueries) + } if err != nil { return nil, err } @@ -233,7 +238,7 @@ func New(conf *Config) (*Server, error) { s := &Server{ config: conf, messageCache: messageCache, - webPush: webPush, + webPush: wp, fileCache: fileCache, firebaseClient: firebaseClient, smtpSender: mailer, diff --git a/server/server.yml b/server/server.yml index 598c41bc..63728ae2 100644 --- a/server/server.yml +++ b/server/server.yml @@ -38,6 +38,12 @@ # # firebase-key-file: +# If "database-url" is set, ntfy will use PostgreSQL for database-backed stores instead of SQLite. +# Currently this applies to the web push subscription store. Message cache and user manager support +# will be added in future releases. When set, the "web-push-file" option is not required. +# +# database-url: "postgres://user:pass@host:5432/ntfy" + # If "cache-file" is set, messages are cached in a local SQLite database instead of only in-memory. # This allows for service restarts without losing messages in support of the since= parameter. # diff --git a/server/server_webpush.go b/server/server_webpush.go index d3f09bd9..11e37f66 100644 --- a/server/server_webpush.go +++ b/server/server_webpush.go @@ -12,6 +12,7 @@ import ( "github.com/SherClockHolmes/webpush-go" "heckel.io/ntfy/v2/log" "heckel.io/ntfy/v2/user" + wpush "heckel.io/ntfy/v2/webpush" ) const ( @@ -128,7 +129,7 @@ func (s *Server) pruneAndNotifyWebPushSubscriptionsInternal() error { if err != nil { return err } - warningSent := make([]*webPushSubscription, 0) + warningSent := make([]*wpush.Subscription, 0) for _, subscription := range subscriptions { if err := s.sendWebPushNotification(subscription, payload); err != nil { log.Tag(tagWebPush).Err(err).With(subscription).Warn("Unable to publish expiry imminent warning") @@ -143,7 +144,7 @@ func (s *Server) pruneAndNotifyWebPushSubscriptionsInternal() error { return nil } -func (s *Server) sendWebPushNotification(sub *webPushSubscription, message []byte, contexters ...log.Contexter) error { +func (s *Server) sendWebPushNotification(sub *wpush.Subscription, message []byte, contexters ...log.Contexter) error { log.Tag(tagWebPush).With(sub).With(contexters...).Debug("Sending web push message") payload := &webpush.Subscription{ Endpoint: sub.Endpoint, diff --git a/server/server_webpush_test.go b/server/server_webpush_test.go index f116103a..fe77787e 100644 --- a/server/server_webpush_test.go +++ b/server/server_webpush_test.go @@ -5,10 +5,6 @@ package server import ( "encoding/json" "fmt" - "github.com/SherClockHolmes/webpush-go" - "github.com/stretchr/testify/require" - "heckel.io/ntfy/v2/user" - "heckel.io/ntfy/v2/util" "io" "net/http" "net/http/httptest" @@ -18,6 +14,12 @@ import ( "sync/atomic" "testing" "time" + + "github.com/SherClockHolmes/webpush-go" + "github.com/stretchr/testify/require" + "heckel.io/ntfy/v2/user" + "heckel.io/ntfy/v2/util" + wpush "heckel.io/ntfy/v2/webpush" ) const ( @@ -238,7 +240,7 @@ func TestServer_WebPush_Expiry(t *testing.T) { addSubscription(t, s, pushService.URL+"/push-receive", "test-topic") requireSubscriptionCount(t, s, "test-topic", 1) - _, err := s.webPush.db.Exec("UPDATE subscription SET updated_at = ?", time.Now().Add(-55*24*time.Hour).Unix()) + _, err := s.webPush.(*wpush.SQLiteStore).DB().Exec("UPDATE subscription SET updated_at = ?", time.Now().Add(-55*24*time.Hour).Unix()) require.Nil(t, err) s.pruneAndNotifyWebPushSubscriptions() @@ -248,7 +250,7 @@ func TestServer_WebPush_Expiry(t *testing.T) { return received.Load() }) - _, err = s.webPush.db.Exec("UPDATE subscription SET updated_at = ?", time.Now().Add(-60*24*time.Hour).Unix()) + _, err = s.webPush.(*wpush.SQLiteStore).DB().Exec("UPDATE subscription SET updated_at = ?", time.Now().Add(-60*24*time.Hour).Unix()) require.Nil(t, err) s.pruneAndNotifyWebPushSubscriptions() diff --git a/server/types.go b/server/types.go index a710a183..dfd797e5 100644 --- a/server/types.go +++ b/server/types.go @@ -593,22 +593,6 @@ func newWebPushSubscriptionExpiringPayload() *webPushControlMessagePayload { } } -type webPushSubscription struct { - ID string - Endpoint string - Auth string - P256dh string - UserID string -} - -func (w *webPushSubscription) Context() log.Context { - return map[string]any{ - "web_push_subscription_id": w.ID, - "web_push_subscription_user_id": w.UserID, - "web_push_subscription_endpoint": w.Endpoint, - } -} - // https://developer.mozilla.org/en-US/docs/Web/Manifest type webManifestResponse struct { Name string `json:"name"` diff --git a/server/webpush_store.go b/server/webpush_store.go deleted file mode 100644 index db0304be..00000000 --- a/server/webpush_store.go +++ /dev/null @@ -1,285 +0,0 @@ -package server - -import ( - "database/sql" - "errors" - "heckel.io/ntfy/v2/util" - "net/netip" - "time" - - _ "github.com/mattn/go-sqlite3" // SQLite driver -) - -const ( - subscriptionIDPrefix = "wps_" - subscriptionIDLength = 10 - subscriptionEndpointLimitPerSubscriberIP = 10 -) - -var ( - errWebPushNoRows = errors.New("no rows found") - errWebPushTooManySubscriptions = errors.New("too many subscriptions") - errWebPushUserIDCannotBeEmpty = errors.New("user ID cannot be empty") -) - -const ( - createWebPushSubscriptionsTableQuery = ` - BEGIN; - CREATE TABLE IF NOT EXISTS subscription ( - id TEXT PRIMARY KEY, - endpoint TEXT NOT NULL, - key_auth TEXT NOT NULL, - key_p256dh TEXT NOT NULL, - user_id TEXT NOT NULL, - subscriber_ip TEXT NOT NULL, - updated_at INT NOT NULL, - warned_at INT NOT NULL DEFAULT 0 - ); - CREATE UNIQUE INDEX IF NOT EXISTS idx_endpoint ON subscription (endpoint); - CREATE INDEX IF NOT EXISTS idx_subscriber_ip ON subscription (subscriber_ip); - CREATE TABLE IF NOT EXISTS subscription_topic ( - subscription_id TEXT NOT NULL, - topic TEXT NOT NULL, - PRIMARY KEY (subscription_id, topic), - FOREIGN KEY (subscription_id) REFERENCES subscription (id) ON DELETE CASCADE - ); - CREATE INDEX IF NOT EXISTS idx_topic ON subscription_topic (topic); - CREATE TABLE IF NOT EXISTS schemaVersion ( - id INT PRIMARY KEY, - version INT NOT NULL - ); - COMMIT; - ` - builtinStartupQueries = ` - PRAGMA foreign_keys = ON; - ` - - selectWebPushSubscriptionIDByEndpoint = `SELECT id FROM subscription WHERE endpoint = ?` - selectWebPushSubscriptionCountBySubscriberIP = `SELECT COUNT(*) FROM subscription WHERE subscriber_ip = ?` - selectWebPushSubscriptionsForTopicQuery = ` - SELECT id, endpoint, key_auth, key_p256dh, user_id - FROM subscription_topic st - JOIN subscription s ON s.id = st.subscription_id - WHERE st.topic = ? - ORDER BY endpoint - ` - selectWebPushSubscriptionsExpiringSoonQuery = ` - SELECT id, endpoint, key_auth, key_p256dh, user_id - FROM subscription - WHERE warned_at = 0 AND updated_at <= ? - ` - insertWebPushSubscriptionQuery = ` - INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) - ON CONFLICT (endpoint) - DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at - ` - updateWebPushSubscriptionWarningSentQuery = `UPDATE subscription SET warned_at = ? WHERE id = ?` - deleteWebPushSubscriptionByEndpointQuery = `DELETE FROM subscription WHERE endpoint = ?` - deleteWebPushSubscriptionByUserIDQuery = `DELETE FROM subscription WHERE user_id = ?` - deleteWebPushSubscriptionByAgeQuery = `DELETE FROM subscription WHERE updated_at <= ?` // Full table scan! - - insertWebPushSubscriptionTopicQuery = `INSERT INTO subscription_topic (subscription_id, topic) VALUES (?, ?)` - deleteWebPushSubscriptionTopicAllQuery = `DELETE FROM subscription_topic WHERE subscription_id = ?` - deleteWebPushSubscriptionTopicWithoutSubscription = `DELETE FROM subscription_topic WHERE subscription_id NOT IN (SELECT id FROM subscription)` -) - -// Schema management queries -const ( - currentWebPushSchemaVersion = 1 - insertWebPushSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)` - selectWebPushSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1` -) - -type webPushStore struct { - db *sql.DB -} - -func newWebPushStore(filename, startupQueries string) (*webPushStore, error) { - db, err := sql.Open("sqlite3", filename) - if err != nil { - return nil, err - } - if err := setupWebPushDB(db); err != nil { - return nil, err - } - if err := runWebPushStartupQueries(db, startupQueries); err != nil { - return nil, err - } - return &webPushStore{ - db: db, - }, nil -} - -func setupWebPushDB(db *sql.DB) error { - // If 'schemaVersion' table does not exist, this must be a new database - rows, err := db.Query(selectWebPushSchemaVersionQuery) - if err != nil { - return setupNewWebPushDB(db) - } - return rows.Close() -} - -func setupNewWebPushDB(db *sql.DB) error { - if _, err := db.Exec(createWebPushSubscriptionsTableQuery); err != nil { - return err - } - if _, err := db.Exec(insertWebPushSchemaVersion, currentWebPushSchemaVersion); err != nil { - return err - } - return nil -} - -func runWebPushStartupQueries(db *sql.DB, startupQueries string) error { - if _, err := db.Exec(startupQueries); err != nil { - return err - } - if _, err := db.Exec(builtinStartupQueries); err != nil { - return err - } - return nil -} - -// UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID. It always first deletes all -// existing entries for a given endpoint. -func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error { - tx, err := c.db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - // Read number of subscriptions for subscriber IP address - rowsCount, err := tx.Query(selectWebPushSubscriptionCountBySubscriberIP, subscriberIP.String()) - if err != nil { - return err - } - defer rowsCount.Close() - var subscriptionCount int - if !rowsCount.Next() { - return errWebPushNoRows - } - if err := rowsCount.Scan(&subscriptionCount); err != nil { - return err - } - if err := rowsCount.Close(); err != nil { - return err - } - // Read existing subscription ID for endpoint (or create new ID) - rows, err := tx.Query(selectWebPushSubscriptionIDByEndpoint, endpoint) - if err != nil { - return err - } - defer rows.Close() - var subscriptionID string - if rows.Next() { - if err := rows.Scan(&subscriptionID); err != nil { - return err - } - } else { - if subscriptionCount >= subscriptionEndpointLimitPerSubscriberIP { - return errWebPushTooManySubscriptions - } - subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength) - } - if err := rows.Close(); err != nil { - return err - } - // Insert or update subscription - updatedAt, warnedAt := time.Now().Unix(), 0 - if _, err = tx.Exec(insertWebPushSubscriptionQuery, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil { - return err - } - // Replace all subscription topics - if _, err := tx.Exec(deleteWebPushSubscriptionTopicAllQuery, subscriptionID); err != nil { - return err - } - for _, topic := range topics { - if _, err = tx.Exec(insertWebPushSubscriptionTopicQuery, subscriptionID, topic); err != nil { - return err - } - } - return tx.Commit() -} - -// SubscriptionsForTopic returns all subscriptions for the given topic -func (c *webPushStore) SubscriptionsForTopic(topic string) ([]*webPushSubscription, error) { - rows, err := c.db.Query(selectWebPushSubscriptionsForTopicQuery, topic) - if err != nil { - return nil, err - } - defer rows.Close() - return c.subscriptionsFromRows(rows) -} - -// SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period -func (c *webPushStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*webPushSubscription, error) { - rows, err := c.db.Query(selectWebPushSubscriptionsExpiringSoonQuery, time.Now().Add(-warnAfter).Unix()) - if err != nil { - return nil, err - } - defer rows.Close() - return c.subscriptionsFromRows(rows) -} - -// MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon -func (c *webPushStore) MarkExpiryWarningSent(subscriptions []*webPushSubscription) error { - tx, err := c.db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - for _, subscription := range subscriptions { - if _, err := tx.Exec(updateWebPushSubscriptionWarningSentQuery, time.Now().Unix(), subscription.ID); err != nil { - return err - } - } - return tx.Commit() -} - -func (c *webPushStore) subscriptionsFromRows(rows *sql.Rows) ([]*webPushSubscription, error) { - subscriptions := make([]*webPushSubscription, 0) - for rows.Next() { - var id, endpoint, auth, p256dh, userID string - if err := rows.Scan(&id, &endpoint, &auth, &p256dh, &userID); err != nil { - return nil, err - } - subscriptions = append(subscriptions, &webPushSubscription{ - ID: id, - Endpoint: endpoint, - Auth: auth, - P256dh: p256dh, - UserID: userID, - }) - } - return subscriptions, nil -} - -// RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint -func (c *webPushStore) RemoveSubscriptionsByEndpoint(endpoint string) error { - _, err := c.db.Exec(deleteWebPushSubscriptionByEndpointQuery, endpoint) - return err -} - -// RemoveSubscriptionsByUserID removes all subscriptions for the given user ID -func (c *webPushStore) RemoveSubscriptionsByUserID(userID string) error { - if userID == "" { - return errWebPushUserIDCannotBeEmpty - } - _, err := c.db.Exec(deleteWebPushSubscriptionByUserIDQuery, userID) - return err -} - -// RemoveExpiredSubscriptions removes all subscriptions that have not been updated for a given time period -func (c *webPushStore) RemoveExpiredSubscriptions(expireAfter time.Duration) error { - _, err := c.db.Exec(deleteWebPushSubscriptionByAgeQuery, time.Now().Add(-expireAfter).Unix()) - if err != nil { - return err - } - _, err = c.db.Exec(deleteWebPushSubscriptionTopicWithoutSubscription) - return err -} - -// Close closes the underlying database connection -func (c *webPushStore) Close() error { - return c.db.Close() -} diff --git a/server/webpush_store_test.go b/server/webpush_store_test.go deleted file mode 100644 index ab5bc424..00000000 --- a/server/webpush_store_test.go +++ /dev/null @@ -1,199 +0,0 @@ -package server - -import ( - "fmt" - "github.com/stretchr/testify/require" - "net/netip" - "path/filepath" - "testing" - "time" -) - -func TestWebPushStore_UpsertSubscription_SubscriptionsForTopic(t *testing.T) { - webPush := newTestWebPushStore(t) - defer webPush.Close() - - require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"})) - - subs, err := webPush.SubscriptionsForTopic("test-topic") - require.Nil(t, err) - require.Len(t, subs, 1) - require.Equal(t, subs[0].Endpoint, testWebPushEndpoint) - require.Equal(t, subs[0].P256dh, "p256dh-key") - require.Equal(t, subs[0].Auth, "auth-key") - require.Equal(t, subs[0].UserID, "u_1234") - - subs2, err := webPush.SubscriptionsForTopic("mytopic") - require.Nil(t, err) - require.Len(t, subs2, 1) - require.Equal(t, subs[0].Endpoint, subs2[0].Endpoint) -} - -func TestWebPushStore_UpsertSubscription_SubscriberIPLimitReached(t *testing.T) { - webPush := newTestWebPushStore(t) - defer webPush.Close() - - // Insert 10 subscriptions with the same IP address - for i := 0; i < 10; i++ { - endpoint := fmt.Sprintf(testWebPushEndpoint+"%d", i) - require.Nil(t, webPush.UpsertSubscription(endpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"})) - } - - // Another one for the same endpoint should be fine - require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"})) - - // But with a different endpoint it should fail - require.Equal(t, errWebPushTooManySubscriptions, webPush.UpsertSubscription(testWebPushEndpoint+"11", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"})) - - // But with a different IP address it should be fine again - require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"99", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("9.9.9.9"), []string{"test-topic", "mytopic"})) -} - -func TestWebPushStore_UpsertSubscription_UpdateTopics(t *testing.T) { - webPush := newTestWebPushStore(t) - defer webPush.Close() - - // Insert subscription with two topics, and another with one topic - require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"})) - require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"1", "auth-key", "p256dh-key", "", netip.MustParseAddr("9.9.9.9"), []string{"topic1"})) - - subs, err := webPush.SubscriptionsForTopic("topic1") - require.Nil(t, err) - require.Len(t, subs, 2) - require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint) - require.Equal(t, testWebPushEndpoint+"1", subs[1].Endpoint) - - subs, err = webPush.SubscriptionsForTopic("topic2") - require.Nil(t, err) - require.Len(t, subs, 1) - require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint) - - // Update the first subscription to have only one topic - require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"})) - - subs, err = webPush.SubscriptionsForTopic("topic1") - require.Nil(t, err) - require.Len(t, subs, 2) - require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint) - - subs, err = webPush.SubscriptionsForTopic("topic2") - require.Nil(t, err) - require.Len(t, subs, 0) -} - -func TestWebPushStore_RemoveSubscriptionsByEndpoint(t *testing.T) { - webPush := newTestWebPushStore(t) - defer webPush.Close() - - // Insert subscription with two topics - require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"})) - subs, err := webPush.SubscriptionsForTopic("topic1") - require.Nil(t, err) - require.Len(t, subs, 1) - - // And remove it again - require.Nil(t, webPush.RemoveSubscriptionsByEndpoint(testWebPushEndpoint)) - subs, err = webPush.SubscriptionsForTopic("topic1") - require.Nil(t, err) - require.Len(t, subs, 0) -} - -func TestWebPushStore_RemoveSubscriptionsByUserID(t *testing.T) { - webPush := newTestWebPushStore(t) - defer webPush.Close() - - // Insert subscription with two topics - require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"})) - subs, err := webPush.SubscriptionsForTopic("topic1") - require.Nil(t, err) - require.Len(t, subs, 1) - - // And remove it again - require.Nil(t, webPush.RemoveSubscriptionsByUserID("u_1234")) - subs, err = webPush.SubscriptionsForTopic("topic1") - require.Nil(t, err) - require.Len(t, subs, 0) -} - -func TestWebPushStore_RemoveSubscriptionsByUserID_Empty(t *testing.T) { - webPush := newTestWebPushStore(t) - defer webPush.Close() - require.Equal(t, errWebPushUserIDCannotBeEmpty, webPush.RemoveSubscriptionsByUserID("")) -} - -func TestWebPushStore_MarkExpiryWarningSent(t *testing.T) { - webPush := newTestWebPushStore(t) - defer webPush.Close() - - // Insert subscription with two topics - require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"})) - subs, err := webPush.SubscriptionsForTopic("topic1") - require.Nil(t, err) - require.Len(t, subs, 1) - - // Mark them as warning sent - require.Nil(t, webPush.MarkExpiryWarningSent(subs)) - - rows, err := webPush.db.Query("SELECT endpoint FROM subscription WHERE warned_at > 0") - require.Nil(t, err) - defer rows.Close() - var endpoint string - require.True(t, rows.Next()) - require.Nil(t, rows.Scan(&endpoint)) - require.Nil(t, err) - require.Equal(t, testWebPushEndpoint, endpoint) - require.False(t, rows.Next()) -} - -func TestWebPushStore_SubscriptionsExpiring(t *testing.T) { - webPush := newTestWebPushStore(t) - defer webPush.Close() - - // Insert subscription with two topics - require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"})) - subs, err := webPush.SubscriptionsForTopic("topic1") - require.Nil(t, err) - require.Len(t, subs, 1) - - // Fake-mark them as soon-to-expire - _, err = webPush.db.Exec("UPDATE subscription SET updated_at = ? WHERE endpoint = ?", time.Now().Add(-8*24*time.Hour).Unix(), testWebPushEndpoint) - require.Nil(t, err) - - // Should not be cleaned up yet - require.Nil(t, webPush.RemoveExpiredSubscriptions(9*24*time.Hour)) - - // Run expiration - subs, err = webPush.SubscriptionsExpiring(7 * 24 * time.Hour) - require.Nil(t, err) - require.Len(t, subs, 1) - require.Equal(t, testWebPushEndpoint, subs[0].Endpoint) -} - -func TestWebPushStore_RemoveExpiredSubscriptions(t *testing.T) { - webPush := newTestWebPushStore(t) - defer webPush.Close() - - // Insert subscription with two topics - require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"})) - subs, err := webPush.SubscriptionsForTopic("topic1") - require.Nil(t, err) - require.Len(t, subs, 1) - - // Fake-mark them as expired - _, err = webPush.db.Exec("UPDATE subscription SET updated_at = ? WHERE endpoint = ?", time.Now().Add(-10*24*time.Hour).Unix(), testWebPushEndpoint) - require.Nil(t, err) - - // Run expiration - require.Nil(t, webPush.RemoveExpiredSubscriptions(9*24*time.Hour)) - - // List again, should be 0 - subs, err = webPush.SubscriptionsForTopic("topic1") - require.Nil(t, err) - require.Len(t, subs, 0) -} - -func newTestWebPushStore(t *testing.T) *webPushStore { - webPush, err := newWebPushStore(filepath.Join(t.TempDir(), "webpush.db"), "") - require.Nil(t, err) - return webPush -} diff --git a/webpush/postgres.go b/webpush/postgres.go new file mode 100644 index 00000000..bc6a984e --- /dev/null +++ b/webpush/postgres.go @@ -0,0 +1,224 @@ +package webpush + +import ( + "database/sql" + "net/netip" + "time" + + _ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver + + "heckel.io/ntfy/v2/util" +) + +const ( + pgCreateTablesQuery = ` + CREATE TABLE IF NOT EXISTS webpush_subscription ( + id TEXT PRIMARY KEY, + endpoint TEXT NOT NULL UNIQUE, + key_auth TEXT NOT NULL, + key_p256dh TEXT NOT NULL, + user_id TEXT NOT NULL, + subscriber_ip TEXT NOT NULL, + updated_at BIGINT NOT NULL, + warned_at BIGINT NOT NULL DEFAULT 0 + ); + CREATE INDEX IF NOT EXISTS idx_webpush_subscriber_ip ON webpush_subscription (subscriber_ip); + CREATE TABLE IF NOT EXISTS webpush_subscription_topic ( + subscription_id TEXT NOT NULL REFERENCES webpush_subscription (id) ON DELETE CASCADE, + topic TEXT NOT NULL, + PRIMARY KEY (subscription_id, topic) + ); + CREATE INDEX IF NOT EXISTS idx_webpush_topic ON webpush_subscription_topic (topic); + CREATE TABLE IF NOT EXISTS webpush_schema_version ( + id INT PRIMARY KEY, + version INT NOT NULL + ); + ` + + pgSelectSubscriptionIDByEndpoint = `SELECT id FROM webpush_subscription WHERE endpoint = $1` + pgSelectSubscriptionCountBySubscriberIP = `SELECT COUNT(*) FROM webpush_subscription WHERE subscriber_ip = $1` + pgSelectSubscriptionsForTopicQuery = ` + SELECT s.id, s.endpoint, s.key_auth, s.key_p256dh, s.user_id + FROM webpush_subscription_topic st + JOIN webpush_subscription s ON s.id = st.subscription_id + WHERE st.topic = $1 + ORDER BY s.endpoint + ` + pgSelectSubscriptionsExpiringSoonQuery = ` + SELECT id, endpoint, key_auth, key_p256dh, user_id + FROM webpush_subscription + WHERE warned_at = 0 AND updated_at <= $1 + ` + pgInsertSubscriptionQuery = ` + INSERT INTO webpush_subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + ON CONFLICT (endpoint) + DO UPDATE SET key_auth = EXCLUDED.key_auth, key_p256dh = EXCLUDED.key_p256dh, user_id = EXCLUDED.user_id, subscriber_ip = EXCLUDED.subscriber_ip, updated_at = EXCLUDED.updated_at, warned_at = EXCLUDED.warned_at + ` + pgUpdateSubscriptionWarningSentQuery = `UPDATE webpush_subscription SET warned_at = $1 WHERE id = $2` + pgDeleteSubscriptionByEndpointQuery = `DELETE FROM webpush_subscription WHERE endpoint = $1` + pgDeleteSubscriptionByUserIDQuery = `DELETE FROM webpush_subscription WHERE user_id = $1` + pgDeleteSubscriptionByAgeQuery = `DELETE FROM webpush_subscription WHERE updated_at <= $1` + + pgInsertSubscriptionTopicQuery = `INSERT INTO webpush_subscription_topic (subscription_id, topic) VALUES ($1, $2)` + pgDeleteSubscriptionTopicAllQuery = `DELETE FROM webpush_subscription_topic WHERE subscription_id = $1` + pgDeleteSubscriptionTopicWithoutSubscription = `DELETE FROM webpush_subscription_topic WHERE subscription_id NOT IN (SELECT id FROM webpush_subscription)` +) + +// PostgreSQL schema management queries +const ( + pgCurrentSchemaVersion = 1 + pgInsertSchemaVersion = `INSERT INTO webpush_schema_version VALUES (1, $1)` + pgSelectSchemaVersionQuery = `SELECT version FROM webpush_schema_version WHERE id = 1` +) + +// PostgresStore is a web push subscription store backed by PostgreSQL. +type PostgresStore struct { + db *sql.DB +} + +// NewPostgresStore creates a new PostgreSQL-backed web push store. +func NewPostgresStore(dsn string) (*PostgresStore, error) { + db, err := sql.Open("pgx", dsn) + if err != nil { + return nil, err + } + if err := db.Ping(); err != nil { + return nil, err + } + if err := setupPostgresDB(db); err != nil { + return nil, err + } + return &PostgresStore{ + db: db, + }, nil +} + +// DB returns the underlying database connection. This is exported for testing purposes. +func (c *PostgresStore) DB() *sql.DB { + return c.db +} + +func setupPostgresDB(db *sql.DB) error { + // If 'wp_schema_version' table does not exist, this must be a new database + rows, err := db.Query(pgSelectSchemaVersionQuery) + if err != nil { + return setupNewPostgresDB(db) + } + return rows.Close() +} + +func setupNewPostgresDB(db *sql.DB) error { + if _, err := db.Exec(pgCreateTablesQuery); err != nil { + return err + } + if _, err := db.Exec(pgInsertSchemaVersion, pgCurrentSchemaVersion); err != nil { + return err + } + return nil +} + +// UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID. +func (c *PostgresStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error { + tx, err := c.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + // Read number of subscriptions for subscriber IP address + var subscriptionCount int + if err := tx.QueryRow(pgSelectSubscriptionCountBySubscriberIP, subscriberIP.String()).Scan(&subscriptionCount); err != nil { + return err + } + // Read existing subscription ID for endpoint (or create new ID) + var subscriptionID string + err = tx.QueryRow(pgSelectSubscriptionIDByEndpoint, endpoint).Scan(&subscriptionID) + if err == sql.ErrNoRows { + if subscriptionCount >= SubscriptionEndpointLimitPerSubscriberIP { + return ErrWebPushTooManySubscriptions + } + subscriptionID = util.RandomStringPrefix(SubscriptionIDPrefix, SubscriptionIDLength) + } else if err != nil { + return err + } + // Insert or update subscription + updatedAt, warnedAt := time.Now().Unix(), 0 + if _, err = tx.Exec(pgInsertSubscriptionQuery, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil { + return err + } + // Replace all subscription topics + if _, err := tx.Exec(pgDeleteSubscriptionTopicAllQuery, subscriptionID); err != nil { + return err + } + for _, topic := range topics { + if _, err = tx.Exec(pgInsertSubscriptionTopicQuery, subscriptionID, topic); err != nil { + return err + } + } + return tx.Commit() +} + +// SubscriptionsForTopic returns all subscriptions for the given topic. +func (c *PostgresStore) SubscriptionsForTopic(topic string) ([]*Subscription, error) { + rows, err := c.db.Query(pgSelectSubscriptionsForTopicQuery, topic) + if err != nil { + return nil, err + } + defer rows.Close() + return subscriptionsFromRows(rows) +} + +// SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period. +func (c *PostgresStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscription, error) { + rows, err := c.db.Query(pgSelectSubscriptionsExpiringSoonQuery, time.Now().Add(-warnAfter).Unix()) + if err != nil { + return nil, err + } + defer rows.Close() + return subscriptionsFromRows(rows) +} + +// MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon. +func (c *PostgresStore) MarkExpiryWarningSent(subscriptions []*Subscription) error { + tx, err := c.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + for _, subscription := range subscriptions { + if _, err := tx.Exec(pgUpdateSubscriptionWarningSentQuery, time.Now().Unix(), subscription.ID); err != nil { + return err + } + } + return tx.Commit() +} + +// RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint. +func (c *PostgresStore) RemoveSubscriptionsByEndpoint(endpoint string) error { + _, err := c.db.Exec(pgDeleteSubscriptionByEndpointQuery, endpoint) + return err +} + +// RemoveSubscriptionsByUserID removes all subscriptions for the given user ID. +func (c *PostgresStore) RemoveSubscriptionsByUserID(userID string) error { + if userID == "" { + return ErrWebPushUserIDCannotBeEmpty + } + _, err := c.db.Exec(pgDeleteSubscriptionByUserIDQuery, userID) + return err +} + +// RemoveExpiredSubscriptions removes all subscriptions that have not been updated for a given time period. +func (c *PostgresStore) RemoveExpiredSubscriptions(expireAfter time.Duration) error { + _, err := c.db.Exec(pgDeleteSubscriptionByAgeQuery, time.Now().Add(-expireAfter).Unix()) + if err != nil { + return err + } + _, err = c.db.Exec(pgDeleteSubscriptionTopicWithoutSubscription) + return err +} + +// Close closes the underlying database connection. +func (c *PostgresStore) Close() error { + return c.db.Close() +} diff --git a/webpush/postgres_test.go b/webpush/postgres_test.go new file mode 100644 index 00000000..2cb77bc8 --- /dev/null +++ b/webpush/postgres_test.go @@ -0,0 +1,207 @@ +package webpush_test + +import ( + "fmt" + "net/netip" + "os" + "testing" + "time" + + "github.com/stretchr/testify/require" + "heckel.io/ntfy/v2/webpush" +) + +func newTestPostgresStore(t *testing.T) *webpush.PostgresStore { + dsn := os.Getenv("NTFY_TEST_DATABASE_URL") + if dsn == "" { + t.Skip("NTFY_TEST_DATABASE_URL not set, skipping PostgreSQL tests") + } + store, err := webpush.NewPostgresStore(dsn) + require.Nil(t, err) + t.Cleanup(func() { + // Clean up tables after each test + db := store.DB() + db.Exec("DELETE FROM webpush_subscription_topic") + db.Exec("DELETE FROM webpush_subscription") + store.Close() + }) + // Clean up tables before test + db := store.DB() + db.Exec("DELETE FROM webpush_subscription_topic") + db.Exec("DELETE FROM webpush_subscription") + return store +} + +func TestPostgresStore_UpsertSubscription_SubscriptionsForTopic(t *testing.T) { + store := newTestPostgresStore(t) + + require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"})) + + subs, err := store.SubscriptionsForTopic("test-topic") + require.Nil(t, err) + require.Len(t, subs, 1) + require.Equal(t, subs[0].Endpoint, testWebPushEndpoint) + require.Equal(t, subs[0].P256dh, "p256dh-key") + require.Equal(t, subs[0].Auth, "auth-key") + require.Equal(t, subs[0].UserID, "u_1234") + + subs2, err := store.SubscriptionsForTopic("mytopic") + require.Nil(t, err) + require.Len(t, subs2, 1) + require.Equal(t, subs[0].Endpoint, subs2[0].Endpoint) +} + +func TestPostgresStore_UpsertSubscription_SubscriberIPLimitReached(t *testing.T) { + store := newTestPostgresStore(t) + + // Insert 10 subscriptions with the same IP address + for i := 0; i < 10; i++ { + endpoint := fmt.Sprintf(testWebPushEndpoint+"%d", i) + require.Nil(t, store.UpsertSubscription(endpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"})) + } + + // Another one for the same endpoint should be fine + require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"})) + + // But with a different endpoint it should fail + require.Equal(t, webpush.ErrWebPushTooManySubscriptions, store.UpsertSubscription(testWebPushEndpoint+"11", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"})) + + // But with a different IP address it should be fine again + require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"99", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("9.9.9.9"), []string{"test-topic", "mytopic"})) +} + +func TestPostgresStore_UpsertSubscription_UpdateTopics(t *testing.T) { + store := newTestPostgresStore(t) + + // Insert subscription with two topics, and another with one topic + require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"})) + require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"1", "auth-key", "p256dh-key", "", netip.MustParseAddr("9.9.9.9"), []string{"topic1"})) + + subs, err := store.SubscriptionsForTopic("topic1") + require.Nil(t, err) + require.Len(t, subs, 2) + require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint) + require.Equal(t, testWebPushEndpoint+"1", subs[1].Endpoint) + + subs, err = store.SubscriptionsForTopic("topic2") + require.Nil(t, err) + require.Len(t, subs, 1) + require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint) + + // Update the first subscription to have only one topic + require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"})) + + subs, err = store.SubscriptionsForTopic("topic1") + require.Nil(t, err) + require.Len(t, subs, 2) + require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint) + + subs, err = store.SubscriptionsForTopic("topic2") + require.Nil(t, err) + require.Len(t, subs, 0) +} + +func TestPostgresStore_RemoveSubscriptionsByEndpoint(t *testing.T) { + store := newTestPostgresStore(t) + + // Insert subscription with two topics + require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"})) + subs, err := store.SubscriptionsForTopic("topic1") + require.Nil(t, err) + require.Len(t, subs, 1) + + // And remove it again + require.Nil(t, store.RemoveSubscriptionsByEndpoint(testWebPushEndpoint)) + subs, err = store.SubscriptionsForTopic("topic1") + require.Nil(t, err) + require.Len(t, subs, 0) +} + +func TestPostgresStore_RemoveSubscriptionsByUserID(t *testing.T) { + store := newTestPostgresStore(t) + + // Insert subscription with two topics + require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"})) + subs, err := store.SubscriptionsForTopic("topic1") + require.Nil(t, err) + require.Len(t, subs, 1) + + // And remove it again + require.Nil(t, store.RemoveSubscriptionsByUserID("u_1234")) + subs, err = store.SubscriptionsForTopic("topic1") + require.Nil(t, err) + require.Len(t, subs, 0) +} + +func TestPostgresStore_RemoveSubscriptionsByUserID_Empty(t *testing.T) { + store := newTestPostgresStore(t) + require.Equal(t, webpush.ErrWebPushUserIDCannotBeEmpty, store.RemoveSubscriptionsByUserID("")) +} + +func TestPostgresStore_MarkExpiryWarningSent(t *testing.T) { + store := newTestPostgresStore(t) + + // Insert subscription with two topics + require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"})) + subs, err := store.SubscriptionsForTopic("topic1") + require.Nil(t, err) + require.Len(t, subs, 1) + + // Mark them as warning sent + require.Nil(t, store.MarkExpiryWarningSent(subs)) + + rows, err := store.DB().Query("SELECT endpoint FROM webpush_subscription WHERE warned_at > 0") + require.Nil(t, err) + defer rows.Close() + var endpoint string + require.True(t, rows.Next()) + require.Nil(t, rows.Scan(&endpoint)) + require.Nil(t, err) + require.Equal(t, testWebPushEndpoint, endpoint) + require.False(t, rows.Next()) +} + +func TestPostgresStore_SubscriptionsExpiring(t *testing.T) { + store := newTestPostgresStore(t) + + // Insert subscription with two topics + require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"})) + subs, err := store.SubscriptionsForTopic("topic1") + require.Nil(t, err) + require.Len(t, subs, 1) + + // Fake-mark them as soon-to-expire + _, err = store.DB().Exec("UPDATE webpush_subscription SET updated_at = $1 WHERE endpoint = $2", time.Now().Add(-8*24*time.Hour).Unix(), testWebPushEndpoint) + require.Nil(t, err) + + // Should not be cleaned up yet + require.Nil(t, store.RemoveExpiredSubscriptions(9*24*time.Hour)) + + // Run expiration + subs, err = store.SubscriptionsExpiring(7 * 24 * time.Hour) + require.Nil(t, err) + require.Len(t, subs, 1) + require.Equal(t, testWebPushEndpoint, subs[0].Endpoint) +} + +func TestPostgresStore_RemoveExpiredSubscriptions(t *testing.T) { + store := newTestPostgresStore(t) + + // Insert subscription with two topics + require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"})) + subs, err := store.SubscriptionsForTopic("topic1") + require.Nil(t, err) + require.Len(t, subs, 1) + + // Fake-mark them as expired + _, err = store.DB().Exec("UPDATE webpush_subscription SET updated_at = $1 WHERE endpoint = $2", time.Now().Add(-10*24*time.Hour).Unix(), testWebPushEndpoint) + require.Nil(t, err) + + // Run expiration + require.Nil(t, store.RemoveExpiredSubscriptions(9*24*time.Hour)) + + // List again, should be 0 + subs, err = store.SubscriptionsForTopic("topic1") + require.Nil(t, err) + require.Len(t, subs, 0) +} diff --git a/webpush/sqlite.go b/webpush/sqlite.go new file mode 100644 index 00000000..a7871f97 --- /dev/null +++ b/webpush/sqlite.go @@ -0,0 +1,280 @@ +package webpush + +import ( + "database/sql" + "net/netip" + "time" + + _ "github.com/mattn/go-sqlite3" // SQLite driver + + "heckel.io/ntfy/v2/util" +) + +const ( + sqliteCreateWebPushSubscriptionsTableQuery = ` + BEGIN; + CREATE TABLE IF NOT EXISTS subscription ( + id TEXT PRIMARY KEY, + endpoint TEXT NOT NULL, + key_auth TEXT NOT NULL, + key_p256dh TEXT NOT NULL, + user_id TEXT NOT NULL, + subscriber_ip TEXT NOT NULL, + updated_at INT NOT NULL, + warned_at INT NOT NULL DEFAULT 0 + ); + CREATE UNIQUE INDEX IF NOT EXISTS idx_endpoint ON subscription (endpoint); + CREATE INDEX IF NOT EXISTS idx_subscriber_ip ON subscription (subscriber_ip); + CREATE TABLE IF NOT EXISTS subscription_topic ( + subscription_id TEXT NOT NULL, + topic TEXT NOT NULL, + PRIMARY KEY (subscription_id, topic), + FOREIGN KEY (subscription_id) REFERENCES subscription (id) ON DELETE CASCADE + ); + CREATE INDEX IF NOT EXISTS idx_topic ON subscription_topic (topic); + CREATE TABLE IF NOT EXISTS schemaVersion ( + id INT PRIMARY KEY, + version INT NOT NULL + ); + COMMIT; + ` + sqliteBuiltinStartupQueries = ` + PRAGMA foreign_keys = ON; + ` + + sqliteSelectWebPushSubscriptionIDByEndpoint = `SELECT id FROM subscription WHERE endpoint = ?` + sqliteSelectWebPushSubscriptionCountBySubscriberIP = `SELECT COUNT(*) FROM subscription WHERE subscriber_ip = ?` + sqliteSelectWebPushSubscriptionsForTopicQuery = ` + SELECT id, endpoint, key_auth, key_p256dh, user_id + FROM subscription_topic st + JOIN subscription s ON s.id = st.subscription_id + WHERE st.topic = ? + ORDER BY endpoint + ` + sqliteSelectWebPushSubscriptionsExpiringSoonQuery = ` + SELECT id, endpoint, key_auth, key_p256dh, user_id + FROM subscription + WHERE warned_at = 0 AND updated_at <= ? + ` + sqliteInsertWebPushSubscriptionQuery = ` + INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT (endpoint) + DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at + ` + sqliteUpdateWebPushSubscriptionWarningSentQuery = `UPDATE subscription SET warned_at = ? WHERE id = ?` + sqliteDeleteWebPushSubscriptionByEndpointQuery = `DELETE FROM subscription WHERE endpoint = ?` + sqliteDeleteWebPushSubscriptionByUserIDQuery = `DELETE FROM subscription WHERE user_id = ?` + sqliteDeleteWebPushSubscriptionByAgeQuery = `DELETE FROM subscription WHERE updated_at <= ?` // Full table scan! + + sqliteInsertWebPushSubscriptionTopicQuery = `INSERT INTO subscription_topic (subscription_id, topic) VALUES (?, ?)` + sqliteDeleteWebPushSubscriptionTopicAllQuery = `DELETE FROM subscription_topic WHERE subscription_id = ?` + sqliteDeleteWebPushSubscriptionTopicWithoutSubscription = `DELETE FROM subscription_topic WHERE subscription_id NOT IN (SELECT id FROM subscription)` +) + +// SQLite schema management queries +const ( + sqliteCurrentWebPushSchemaVersion = 1 + sqliteInsertWebPushSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)` + sqliteSelectWebPushSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1` +) + +// SQLiteStore is a web push subscription store backed by SQLite. +type SQLiteStore struct { + db *sql.DB +} + +// NewSQLiteStore creates a new SQLite-backed web push store. +func NewSQLiteStore(filename, startupQueries string) (*SQLiteStore, error) { + db, err := sql.Open("sqlite3", filename) + if err != nil { + return nil, err + } + if err := setupSQLiteWebPushDB(db); err != nil { + return nil, err + } + if err := runSQLiteWebPushStartupQueries(db, startupQueries); err != nil { + return nil, err + } + return &SQLiteStore{ + db: db, + }, nil +} + +// DB returns the underlying database connection. This is exported for testing purposes. +func (c *SQLiteStore) DB() *sql.DB { + return c.db +} + +func setupSQLiteWebPushDB(db *sql.DB) error { + // If 'schemaVersion' table does not exist, this must be a new database + rows, err := db.Query(sqliteSelectWebPushSchemaVersionQuery) + if err != nil { + return setupNewSQLiteWebPushDB(db) + } + return rows.Close() +} + +func setupNewSQLiteWebPushDB(db *sql.DB) error { + if _, err := db.Exec(sqliteCreateWebPushSubscriptionsTableQuery); err != nil { + return err + } + if _, err := db.Exec(sqliteInsertWebPushSchemaVersion, sqliteCurrentWebPushSchemaVersion); err != nil { + return err + } + return nil +} + +func runSQLiteWebPushStartupQueries(db *sql.DB, startupQueries string) error { + if _, err := db.Exec(startupQueries); err != nil { + return err + } + if _, err := db.Exec(sqliteBuiltinStartupQueries); err != nil { + return err + } + return nil +} + +// UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID. It always first deletes all +// existing entries for a given endpoint. +func (c *SQLiteStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error { + tx, err := c.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + // Read number of subscriptions for subscriber IP address + rowsCount, err := tx.Query(sqliteSelectWebPushSubscriptionCountBySubscriberIP, subscriberIP.String()) + if err != nil { + return err + } + defer rowsCount.Close() + var subscriptionCount int + if !rowsCount.Next() { + return ErrWebPushNoRows + } + if err := rowsCount.Scan(&subscriptionCount); err != nil { + return err + } + if err := rowsCount.Close(); err != nil { + return err + } + // Read existing subscription ID for endpoint (or create new ID) + rows, err := tx.Query(sqliteSelectWebPushSubscriptionIDByEndpoint, endpoint) + if err != nil { + return err + } + defer rows.Close() + var subscriptionID string + if rows.Next() { + if err := rows.Scan(&subscriptionID); err != nil { + return err + } + } else { + if subscriptionCount >= SubscriptionEndpointLimitPerSubscriberIP { + return ErrWebPushTooManySubscriptions + } + subscriptionID = util.RandomStringPrefix(SubscriptionIDPrefix, SubscriptionIDLength) + } + if err := rows.Close(); err != nil { + return err + } + // Insert or update subscription + updatedAt, warnedAt := time.Now().Unix(), 0 + if _, err = tx.Exec(sqliteInsertWebPushSubscriptionQuery, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil { + return err + } + // Replace all subscription topics + if _, err := tx.Exec(sqliteDeleteWebPushSubscriptionTopicAllQuery, subscriptionID); err != nil { + return err + } + for _, topic := range topics { + if _, err = tx.Exec(sqliteInsertWebPushSubscriptionTopicQuery, subscriptionID, topic); err != nil { + return err + } + } + return tx.Commit() +} + +// SubscriptionsForTopic returns all subscriptions for the given topic. +func (c *SQLiteStore) SubscriptionsForTopic(topic string) ([]*Subscription, error) { + rows, err := c.db.Query(sqliteSelectWebPushSubscriptionsForTopicQuery, topic) + if err != nil { + return nil, err + } + defer rows.Close() + return subscriptionsFromRows(rows) +} + +// SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period. +func (c *SQLiteStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscription, error) { + rows, err := c.db.Query(sqliteSelectWebPushSubscriptionsExpiringSoonQuery, time.Now().Add(-warnAfter).Unix()) + if err != nil { + return nil, err + } + defer rows.Close() + return subscriptionsFromRows(rows) +} + +// MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon. +func (c *SQLiteStore) MarkExpiryWarningSent(subscriptions []*Subscription) error { + tx, err := c.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + for _, subscription := range subscriptions { + if _, err := tx.Exec(sqliteUpdateWebPushSubscriptionWarningSentQuery, time.Now().Unix(), subscription.ID); err != nil { + return err + } + } + return tx.Commit() +} + +// RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint. +func (c *SQLiteStore) RemoveSubscriptionsByEndpoint(endpoint string) error { + _, err := c.db.Exec(sqliteDeleteWebPushSubscriptionByEndpointQuery, endpoint) + return err +} + +// RemoveSubscriptionsByUserID removes all subscriptions for the given user ID. +func (c *SQLiteStore) RemoveSubscriptionsByUserID(userID string) error { + if userID == "" { + return ErrWebPushUserIDCannotBeEmpty + } + _, err := c.db.Exec(sqliteDeleteWebPushSubscriptionByUserIDQuery, userID) + return err +} + +// RemoveExpiredSubscriptions removes all subscriptions that have not been updated for a given time period. +func (c *SQLiteStore) RemoveExpiredSubscriptions(expireAfter time.Duration) error { + _, err := c.db.Exec(sqliteDeleteWebPushSubscriptionByAgeQuery, time.Now().Add(-expireAfter).Unix()) + if err != nil { + return err + } + _, err = c.db.Exec(sqliteDeleteWebPushSubscriptionTopicWithoutSubscription) + return err +} + +// Close closes the underlying database connection. +func (c *SQLiteStore) Close() error { + return c.db.Close() +} + +func subscriptionsFromRows(rows *sql.Rows) ([]*Subscription, error) { + subscriptions := make([]*Subscription, 0) + for rows.Next() { + var id, endpoint, auth, p256dh, userID string + if err := rows.Scan(&id, &endpoint, &auth, &p256dh, &userID); err != nil { + return nil, err + } + subscriptions = append(subscriptions, &Subscription{ + ID: id, + Endpoint: endpoint, + Auth: auth, + P256dh: p256dh, + UserID: userID, + }) + } + return subscriptions, nil +} diff --git a/webpush/sqlite_test.go b/webpush/sqlite_test.go new file mode 100644 index 00000000..c43eca86 --- /dev/null +++ b/webpush/sqlite_test.go @@ -0,0 +1,203 @@ +package webpush_test + +import ( + "fmt" + "net/netip" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/require" + "heckel.io/ntfy/v2/webpush" +) + +const testWebPushEndpoint = "https://updates.push.services.mozilla.com/wpush/v1/AAABBCCCDDEEEFFF" + +func TestSQLiteStore_UpsertSubscription_SubscriptionsForTopic(t *testing.T) { + store := newTestSQLiteStore(t) + defer store.Close() + + require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"})) + + subs, err := store.SubscriptionsForTopic("test-topic") + require.Nil(t, err) + require.Len(t, subs, 1) + require.Equal(t, subs[0].Endpoint, testWebPushEndpoint) + require.Equal(t, subs[0].P256dh, "p256dh-key") + require.Equal(t, subs[0].Auth, "auth-key") + require.Equal(t, subs[0].UserID, "u_1234") + + subs2, err := store.SubscriptionsForTopic("mytopic") + require.Nil(t, err) + require.Len(t, subs2, 1) + require.Equal(t, subs[0].Endpoint, subs2[0].Endpoint) +} + +func TestSQLiteStore_UpsertSubscription_SubscriberIPLimitReached(t *testing.T) { + store := newTestSQLiteStore(t) + defer store.Close() + + // Insert 10 subscriptions with the same IP address + for i := 0; i < 10; i++ { + endpoint := fmt.Sprintf(testWebPushEndpoint+"%d", i) + require.Nil(t, store.UpsertSubscription(endpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"})) + } + + // Another one for the same endpoint should be fine + require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"})) + + // But with a different endpoint it should fail + require.Equal(t, webpush.ErrWebPushTooManySubscriptions, store.UpsertSubscription(testWebPushEndpoint+"11", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"})) + + // But with a different IP address it should be fine again + require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"99", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("9.9.9.9"), []string{"test-topic", "mytopic"})) +} + +func TestSQLiteStore_UpsertSubscription_UpdateTopics(t *testing.T) { + store := newTestSQLiteStore(t) + defer store.Close() + + // Insert subscription with two topics, and another with one topic + require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"})) + require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"1", "auth-key", "p256dh-key", "", netip.MustParseAddr("9.9.9.9"), []string{"topic1"})) + + subs, err := store.SubscriptionsForTopic("topic1") + require.Nil(t, err) + require.Len(t, subs, 2) + require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint) + require.Equal(t, testWebPushEndpoint+"1", subs[1].Endpoint) + + subs, err = store.SubscriptionsForTopic("topic2") + require.Nil(t, err) + require.Len(t, subs, 1) + require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint) + + // Update the first subscription to have only one topic + require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"})) + + subs, err = store.SubscriptionsForTopic("topic1") + require.Nil(t, err) + require.Len(t, subs, 2) + require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint) + + subs, err = store.SubscriptionsForTopic("topic2") + require.Nil(t, err) + require.Len(t, subs, 0) +} + +func TestSQLiteStore_RemoveSubscriptionsByEndpoint(t *testing.T) { + store := newTestSQLiteStore(t) + defer store.Close() + + // Insert subscription with two topics + require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"})) + subs, err := store.SubscriptionsForTopic("topic1") + require.Nil(t, err) + require.Len(t, subs, 1) + + // And remove it again + require.Nil(t, store.RemoveSubscriptionsByEndpoint(testWebPushEndpoint)) + subs, err = store.SubscriptionsForTopic("topic1") + require.Nil(t, err) + require.Len(t, subs, 0) +} + +func TestSQLiteStore_RemoveSubscriptionsByUserID(t *testing.T) { + store := newTestSQLiteStore(t) + defer store.Close() + + // Insert subscription with two topics + require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"})) + subs, err := store.SubscriptionsForTopic("topic1") + require.Nil(t, err) + require.Len(t, subs, 1) + + // And remove it again + require.Nil(t, store.RemoveSubscriptionsByUserID("u_1234")) + subs, err = store.SubscriptionsForTopic("topic1") + require.Nil(t, err) + require.Len(t, subs, 0) +} + +func TestSQLiteStore_RemoveSubscriptionsByUserID_Empty(t *testing.T) { + store := newTestSQLiteStore(t) + defer store.Close() + require.Equal(t, webpush.ErrWebPushUserIDCannotBeEmpty, store.RemoveSubscriptionsByUserID("")) +} + +func TestSQLiteStore_MarkExpiryWarningSent(t *testing.T) { + store := newTestSQLiteStore(t) + defer store.Close() + + // Insert subscription with two topics + require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"})) + subs, err := store.SubscriptionsForTopic("topic1") + require.Nil(t, err) + require.Len(t, subs, 1) + + // Mark them as warning sent + require.Nil(t, store.MarkExpiryWarningSent(subs)) + + rows, err := store.DB().Query("SELECT endpoint FROM subscription WHERE warned_at > 0") + require.Nil(t, err) + defer rows.Close() + var endpoint string + require.True(t, rows.Next()) + require.Nil(t, rows.Scan(&endpoint)) + require.Nil(t, err) + require.Equal(t, testWebPushEndpoint, endpoint) + require.False(t, rows.Next()) +} + +func TestSQLiteStore_SubscriptionsExpiring(t *testing.T) { + store := newTestSQLiteStore(t) + defer store.Close() + + // Insert subscription with two topics + require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"})) + subs, err := store.SubscriptionsForTopic("topic1") + require.Nil(t, err) + require.Len(t, subs, 1) + + // Fake-mark them as soon-to-expire + _, err = store.DB().Exec("UPDATE subscription SET updated_at = ? WHERE endpoint = ?", time.Now().Add(-8*24*time.Hour).Unix(), testWebPushEndpoint) + require.Nil(t, err) + + // Should not be cleaned up yet + require.Nil(t, store.RemoveExpiredSubscriptions(9*24*time.Hour)) + + // Run expiration + subs, err = store.SubscriptionsExpiring(7 * 24 * time.Hour) + require.Nil(t, err) + require.Len(t, subs, 1) + require.Equal(t, testWebPushEndpoint, subs[0].Endpoint) +} + +func TestSQLiteStore_RemoveExpiredSubscriptions(t *testing.T) { + store := newTestSQLiteStore(t) + defer store.Close() + + // Insert subscription with two topics + require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"})) + subs, err := store.SubscriptionsForTopic("topic1") + require.Nil(t, err) + require.Len(t, subs, 1) + + // Fake-mark them as expired + _, err = store.DB().Exec("UPDATE subscription SET updated_at = ? WHERE endpoint = ?", time.Now().Add(-10*24*time.Hour).Unix(), testWebPushEndpoint) + require.Nil(t, err) + + // Run expiration + require.Nil(t, store.RemoveExpiredSubscriptions(9*24*time.Hour)) + + // List again, should be 0 + subs, err = store.SubscriptionsForTopic("topic1") + require.Nil(t, err) + require.Len(t, subs, 0) +} + +func newTestSQLiteStore(t *testing.T) *webpush.SQLiteStore { + store, err := webpush.NewSQLiteStore(filepath.Join(t.TempDir(), "webpush.db"), "") + require.Nil(t, err) + return store +} diff --git a/webpush/store.go b/webpush/store.go new file mode 100644 index 00000000..a711b6bb --- /dev/null +++ b/webpush/store.go @@ -0,0 +1,51 @@ +package webpush + +import ( + "errors" + "net/netip" + "time" + + "heckel.io/ntfy/v2/log" +) + +const ( + SubscriptionIDPrefix = "wps_" + SubscriptionIDLength = 10 + SubscriptionEndpointLimitPerSubscriberIP = 10 +) + +var ( + ErrWebPushNoRows = errors.New("no rows found") + ErrWebPushTooManySubscriptions = errors.New("too many subscriptions") + ErrWebPushUserIDCannotBeEmpty = errors.New("user ID cannot be empty") +) + +// Store is the interface for a web push subscription store. +type Store interface { + UpsertSubscription(endpoint, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error + SubscriptionsForTopic(topic string) ([]*Subscription, error) + SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscription, error) + MarkExpiryWarningSent(subscriptions []*Subscription) error + RemoveSubscriptionsByEndpoint(endpoint string) error + RemoveSubscriptionsByUserID(userID string) error + RemoveExpiredSubscriptions(expireAfter time.Duration) error + Close() error +} + +// Subscription represents a web push subscription. +type Subscription struct { + ID string + Endpoint string + Auth string + P256dh string + UserID string +} + +// Context returns the logging context for the subscription. +func (w *Subscription) Context() log.Context { + return map[string]any{ + "web_push_subscription_id": w.ID, + "web_push_subscription_user_id": w.UserID, + "web_push_subscription_endpoint": w.Endpoint, + } +}