Compare commits

..

52 Commits

Author SHA1 Message Date
binwiederhier
11c79a6369 Refine again 2026-03-01 14:24:06 -05:00
binwiederhier
10a6939d8e Rename queries for consistency 2026-03-01 13:47:42 -05:00
binwiederhier
9736973286 REmove store interface 2026-03-01 13:19:53 -05:00
binwiederhier
039566bcaf REmove unused 2026-03-01 11:59:24 -05:00
binwiederhier
544ce112b5 Manual review 2026-03-01 11:44:22 -05:00
binwiederhier
c19377109e Small refinements 2026-02-28 21:15:51 -05:00
binwiederhier
ccbd02331c Re-add execTx 2026-02-28 19:49:01 -05:00
binwiederhier
542aa403d2 Re-order functions 2026-02-28 19:34:09 -05:00
binwiederhier
ebb48e217d Merge user store and manager 2026-02-28 17:35:35 -05:00
binwiederhier
7710ace184 Self-review 2026-02-28 16:13:58 -05:00
binwiederhier
5c26e70fe7 Manual review: Make resetUserAccessTx 2026-02-26 17:34:19 -05:00
binwiederhier
c66fa92341 Re-add tx integrity 2026-02-26 16:54:56 -05:00
binwiederhier
a7d5a9c5d8 Sample 2026-02-23 23:08:13 -05:00
binwiederhier
391cd2c920 Migration tool 2026-02-23 22:44:21 -05:00
binwiederhier
9eec72adcc Fix race tests 2026-02-23 21:39:07 -05:00
binwiederhier
28a436c0d2 Optimizations 2026-02-23 14:11:36 -05:00
binwiederhier
b02366b42b Make more consistent 2026-02-23 13:49:54 -05:00
binwiederhier
90d0eca14d Consistency 2026-02-23 11:17:57 -05:00
binwiederhier
811c7ae25a TX fixes 2026-02-22 20:39:41 -05:00
binwiederhier
850a9d4cc4 Fix bug with provisioned token removal 2026-02-22 20:28:37 -05:00
binwiederhier
43280fbc0a Performance improvements in pg 2026-02-22 16:21:27 -05:00
binwiederhier
35a54407a8 Manual review 2026-02-22 15:12:54 -05:00
binwiederhier
f726cc768e Manual refactor 2026-02-22 14:25:15 -05:00
binwiederhier
5d301e7dce Merge branch 'main' of github.com:binwiederhier/ntfy into postgres-support 2026-02-22 14:15:42 -05:00
binwiederhier
8b12bdeb3a Release notes 2026-02-22 09:07:48 -05:00
binwiederhier
6375c2ce60 Renaming for consistency 2026-02-21 21:29:29 -05:00
binwiederhier
459c80ef9b Refactor tests again: forEachBackend 2026-02-21 21:02:41 -05:00
binwiederhier
b1eb90addc Move NewPollRequestMessage 2026-02-21 18:01:36 -05:00
binwiederhier
4b6979aa89 Move model constants 2026-02-21 10:42:34 -05:00
binwiederhier
c76e39bb0e Manual updates 2026-02-21 10:36:09 -05:00
binwiederhier
b82e1c3915 Release notes 2026-02-21 10:02:27 -05:00
binwiederhier
07e60ba041 Merge branch 'main' into postgres-support 2026-02-21 09:59:55 -05:00
binwiederhier
4e22e7f4c8 Release notes 2026-02-21 09:58:52 -05:00
binwiederhier
cf3ae187ce Merge branch 'main' of https://hosted.weblate.org/git/ntfy/web 2026-02-21 09:56:33 -05:00
Philipp C. Heckel
d19b825192 Merge pull request #1620 from uzkikh/fix/smtp-html-br-linebreaks
fix(smtp): preserve <br> line breaks in HTML emails
2026-02-21 09:56:24 -05:00
Ivan Uzkikh
2e499389fc fix(smtp): preserve <br> line breaks in HTML emails
HTML-only emails (e.g. from Synology DSM 7.3 Task Scheduler) use <br>
tags for line breaks. The existing implementation passed the raw HTML
body to bluemonday with AddSpaceWhenStrippingTag(true), which replaced
every tag including <br> with a space, causing all content to appear
on a single unreadable line.

Fix: convert <br> tags to newlines before stripping HTML, so line break
semantics are preserved in the notification body.

Resolves the gap noted in #690 ("very sub-par" HTML support).
2026-02-21 14:44:06 +01:00
binwiederhier
7c69a76345 Docs 2026-02-20 16:56:47 -05:00
binwiederhier
a28d8e7924 Update docs, manual changes 2026-02-20 16:36:41 -05:00
binwiederhier
13a3062a7f Re-add comments 2026-02-20 16:15:07 -05:00
binwiederhier
eb6e1ac44a Docs 2026-02-20 15:57:47 -05:00
binwiederhier
a4c836b531 Refine 2026-02-20 15:36:12 -05:00
binwiederhier
e818b063f7 Fmt 2026-02-20 08:25:32 -05:00
binwiederhier
039d555689 Fix tests 2026-02-20 08:19:25 -05:00
binwiederhier
209d5a4c62 Update parallelism 2026-02-19 22:42:08 -05:00
binwiederhier
bf265449ac Docs 2026-02-19 22:38:27 -05:00
binwiederhier
4cbd80c68e Use one PG connection, add support for connection params 2026-02-19 22:34:53 -05:00
binwiederhier
305e3fc9af DB conn to 25 2026-02-19 21:36:09 -05:00
Philipp C. Heckel
93cd7f99f8 Merge pull request #1618 from PanderMusubi/doc_zabbix
add zabbix-ntfy to intergration doc
2026-02-19 11:00:21 -05:00
PanderMusubi
28e85df36e add zabbix-ntfy to intergration doc 2026-02-19 15:47:49 +01:00
Philipp C. Heckel
2bc7b5217b Merge pull request #1615 from benkenawell/example/zsh-terminal
Add comment for zsh users in the terminal example
2026-02-18 20:12:49 -05:00
Benjamin Kenawell
046c0e8c79 Add comment for zsh users in the terminal example
Bash and Zsh terminals' history works slightly different. In bash, the
current command is in the history already. In zsh, the
history command does not have the currently running command in it.  But
it does set the HISTCMD variable to the entry number for the current
command. So we can load the current command with `history "$HISTCMD"` in
zsh and apply the same sed filter to get the same result as the bash
`history | tail -n1` command.

I think it makes more sense to leave this as a comment in the current
example than to create a new example, since it otherwise works the exact
same.  I've added this to my workflow locally!
2026-02-18 08:07:08 -05:00
Yaron Shahrabani
652b2097ad Translated using Weblate (Hebrew)
Currently translated at 17.4% (71 of 407 strings)

Translation: ntfy/Web app
Translate-URL: https://hosted.weblate.org/projects/ntfy/web/he/
2026-02-18 13:09:47 +01:00
65 changed files with 5624 additions and 5454 deletions

1
.gitignore vendored
View File

@@ -7,6 +7,7 @@ build/
server/docs/
server/site/
tools/fbsend/fbsend
tools/pgimport/pgimport
playground/
secrets/
*.iml

View File

@@ -268,10 +268,10 @@ check: test web-fmt-check fmt-check vet web-lint lint staticcheck
checkv: testv web-fmt-check fmt-check vet web-lint lint staticcheck
test: .PHONY
go test $(shell go list ./... | grep -vE 'ntfy/(test|examples|tools)')
go test -parallel 3 $(shell go list ./... | grep -vE 'ntfy/(test|examples|tools)')
testv: .PHONY
go test -v $(shell go list ./... | grep -vE 'ntfy/(test|examples|tools)')
go test -v -parallel 3 $(shell go list ./... | grep -vE 'ntfy/(test|examples|tools)')
race: .PHONY
go test -v -race $(shell go list ./... | grep -vE 'ntfy/(test|examples|tools)')

View File

@@ -282,7 +282,9 @@ func execServe(c *cli.Context) error {
}
// Check values
if firebaseKeyFile != "" && !util.FileExists(firebaseKeyFile) {
if databaseURL != "" && (authFile != "" || cacheFile != "" || webPushFile != "") {
return errors.New("if database-url is set, auth-file, cache-file, and web-push-file must not be set")
} else if firebaseKeyFile != "" && !util.FileExists(firebaseKeyFile) {
return errors.New("if set, FCM key file must exist")
} else if firebaseKeyFile != "" && !server.FirebaseAvailable {
return errors.New("cannot set firebase-key-file, support for Firebase is not available (nofirebase)")
@@ -323,8 +325,8 @@ func execServe(c *cli.Context) error {
return errors.New("if upstream-base-url is set, base-url must also be set")
} else if upstreamBaseURL != "" && baseURL != "" && baseURL == upstreamBaseURL {
return errors.New("base-url and upstream-base-url cannot be identical, you'll likely want to set upstream-base-url to https://ntfy.sh, see https://ntfy.sh/docs/config/#ios-instant-notifications")
} else if authFile == "" && (enableSignup || enableLogin || requireLogin || enableReservations || stripeSecretKey != "") {
return errors.New("cannot set enable-signup, enable-login, require-login, enable-reserve-topics, or stripe-secret-key if auth-file is not set")
} else if authFile == "" && databaseURL == "" && (enableSignup || enableLogin || requireLogin || enableReservations || stripeSecretKey != "") {
return errors.New("cannot set enable-signup, enable-login, require-login, enable-reserve-topics, or stripe-secret-key if auth-file or database-url is not set")
} else if enableSignup && !enableLogin {
return errors.New("cannot set enable-signup without also setting enable-login")
} else if requireLogin && !enableLogin {
@@ -333,8 +335,8 @@ func execServe(c *cli.Context) error {
return errors.New("cannot set stripe-secret-key or stripe-webhook-key, support for payments is not available in this build (nopayments)")
} else if stripeSecretKey != "" && (stripeWebhookKey == "" || baseURL == "") {
return errors.New("if stripe-secret-key is set, stripe-webhook-key and base-url must also be set")
} else if twilioAccount != "" && (twilioAuthToken == "" || twilioPhoneNumber == "" || twilioVerifyService == "" || baseURL == "" || authFile == "") {
return errors.New("if twilio-account is set, twilio-auth-token, twilio-phone-number, twilio-verify-service, base-url, and auth-file must also be set")
} else if twilioAccount != "" && (twilioAuthToken == "" || twilioPhoneNumber == "" || twilioVerifyService == "" || baseURL == "" || (authFile == "" && databaseURL == "")) {
return errors.New("if twilio-account is set, twilio-auth-token, twilio-phone-number, twilio-verify-service, base-url, and auth-file (or database-url) must also be set")
} else if messageSizeLimit > server.DefaultMessageSizeLimit {
log.Warn("message-size-limit is greater than 4K, this is not recommended and largely untested, and may lead to issues with some clients")
if messageSizeLimit > 5*1024*1024 {
@@ -414,6 +416,15 @@ func execServe(c *cli.Context) error {
payments.Setup(stripeSecretKey)
}
// Parse Twilio template
var twilioCallFormatTemplate *template.Template
if twilioCallFormat != "" {
twilioCallFormatTemplate, err = template.New("").Parse(twilioCallFormat)
if err != nil {
return fmt.Errorf("failed to parse twilio-call-format template: %w", err)
}
}
// Add default forbidden topics
disallowedTopics = append(disallowedTopics, server.DefaultDisallowedTopics...)
@@ -461,13 +472,7 @@ func execServe(c *cli.Context) error {
conf.TwilioAuthToken = twilioAuthToken
conf.TwilioPhoneNumber = twilioPhoneNumber
conf.TwilioVerifyService = twilioVerifyService
if twilioCallFormat != "" {
tmpl, err := template.New("twiml").Parse(twilioCallFormat)
if err != nil {
return fmt.Errorf("failed to parse twilio-call-format template: %w", err)
}
conf.TwilioCallFormat = tmpl
}
conf.TwilioCallFormat = twilioCallFormatTemplate
conf.MessageSizeLimit = int(messageSizeLimit)
conf.MessageDelayMax = messageDelayLimit
conf.TotalTopicLimit = totalTopicLimit

View File

@@ -6,13 +6,14 @@ import (
"crypto/subtle"
"errors"
"fmt"
"heckel.io/ntfy/v2/server"
"heckel.io/ntfy/v2/user"
"os"
"strings"
"github.com/urfave/cli/v2"
"github.com/urfave/cli/v2/altsrc"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/server"
"heckel.io/ntfy/v2/user"
"heckel.io/ntfy/v2/util"
)
@@ -377,21 +378,19 @@ func createUserManager(c *cli.Context) (*user.Manager, error) {
BcryptCost: user.DefaultUserPasswordBcryptCost,
QueueWriterInterval: user.DefaultUserStatsQueueWriterInterval,
}
var store user.Store
if databaseURL != "" {
store, err = user.NewPostgresStore(databaseURL)
pool, dbErr := db.OpenPostgres(databaseURL)
if dbErr != nil {
return nil, dbErr
}
return user.NewPostgresManager(pool, authConfig)
} else if authFile != "" {
if !util.FileExists(authFile) {
return nil, errors.New("auth-file does not exist; please start the server at least once to create it")
}
store, err = user.NewSQLiteStore(authFile, authStartupQueries)
} else {
return nil, errors.New("option database-url or auth-file not set; auth is unconfigured for this server")
return user.NewSQLiteManager(authFile, authStartupQueries, authConfig)
}
if err != nil {
return nil, err
}
return user.NewManager(store, authConfig)
return nil, errors.New("option database-url or auth-file not set; auth is unconfigured for this server")
}
func readPasswordAndConfirm(c *cli.Context) (string, error) {

93
db/db.go Normal file
View File

@@ -0,0 +1,93 @@
package db
import (
"database/sql"
"fmt"
"net/url"
"strconv"
"time"
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
)
const (
paramMaxOpenConns = "pool_max_conns"
paramMaxIdleConns = "pool_max_idle_conns"
paramConnMaxLifetime = "pool_conn_max_lifetime"
paramConnMaxIdleTime = "pool_conn_max_idle_time"
defaultMaxOpenConns = 10
)
// OpenPostgres opens a PostgreSQL database connection pool from a DSN string. It supports custom
// query parameters for pool configuration: pool_max_conns (default 10), pool_max_idle_conns,
// pool_conn_max_lifetime, and pool_conn_max_idle_time. These parameters are stripped from
// the DSN before passing it to the driver.
func OpenPostgres(dsn string) (*sql.DB, error) {
u, err := url.Parse(dsn)
if err != nil {
return nil, fmt.Errorf("invalid database URL: %w", err)
}
q := u.Query()
maxOpenConns, err := extractIntParam(q, paramMaxOpenConns, defaultMaxOpenConns)
if err != nil {
return nil, err
}
maxIdleConns, err := extractIntParam(q, paramMaxIdleConns, 0)
if err != nil {
return nil, err
}
connMaxLifetime, err := extractDurationParam(q, paramConnMaxLifetime, 0)
if err != nil {
return nil, err
}
connMaxIdleTime, err := extractDurationParam(q, paramConnMaxIdleTime, 0)
if err != nil {
return nil, err
}
u.RawQuery = q.Encode()
db, err := sql.Open("pgx", u.String())
if err != nil {
return nil, err
}
db.SetMaxOpenConns(maxOpenConns)
if maxIdleConns > 0 {
db.SetMaxIdleConns(maxIdleConns)
}
if connMaxLifetime > 0 {
db.SetConnMaxLifetime(connMaxLifetime)
}
if connMaxIdleTime > 0 {
db.SetConnMaxIdleTime(connMaxIdleTime)
}
if err := db.Ping(); err != nil {
return nil, fmt.Errorf("ping failed: %w", err)
}
return db, nil
}
func extractIntParam(q url.Values, key string, defaultValue int) (int, error) {
s := q.Get(key)
if s == "" {
return defaultValue, nil
}
q.Del(key)
v, err := strconv.Atoi(s)
if err != nil {
return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err)
}
return v, nil
}
func extractDurationParam(q url.Values, key string, defaultValue time.Duration) (time.Duration, error) {
s := q.Get(key)
if s == "" {
return defaultValue, nil
}
q.Del(key)
d, err := time.ParseDuration(s)
if err != nil {
return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err)
}
return d, nil
}

63
db/test/test.go Normal file
View File

@@ -0,0 +1,63 @@
package dbtest
import (
"database/sql"
"fmt"
"net/url"
"os"
"testing"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/util"
)
const testPoolMaxConns = "2"
// CreateTestPostgresSchema creates a temporary PostgreSQL schema and returns the DSN pointing to it.
// It registers a cleanup function to drop the schema when the test finishes.
// If NTFY_TEST_DATABASE_URL is not set, the test is skipped.
func CreateTestPostgresSchema(t *testing.T) string {
t.Helper()
dsn := os.Getenv("NTFY_TEST_DATABASE_URL")
if dsn == "" {
t.Skip("NTFY_TEST_DATABASE_URL not set")
}
schema := fmt.Sprintf("test_%s", util.RandomString(10))
u, err := url.Parse(dsn)
require.Nil(t, err)
q := u.Query()
q.Set("pool_max_conns", testPoolMaxConns)
u.RawQuery = q.Encode()
dsn = u.String()
setupDB, err := db.OpenPostgres(dsn)
require.Nil(t, err)
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
require.Nil(t, err)
require.Nil(t, setupDB.Close())
q.Set("search_path", schema)
u.RawQuery = q.Encode()
schemaDSN := u.String()
t.Cleanup(func() {
cleanDB, err := db.OpenPostgres(dsn)
if err == nil {
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
cleanDB.Close()
}
})
return schemaDSN
}
// CreateTestPostgres creates a temporary PostgreSQL schema and returns an open *sql.DB connection to it.
// It registers cleanup functions to close the DB and drop the schema when the test finishes.
// If NTFY_TEST_DATABASE_URL is not set, the test is skipped.
func CreateTestPostgres(t *testing.T) *sql.DB {
t.Helper()
schemaDSN := CreateTestPostgresSchema(t)
testDB, err := db.OpenPostgres(schemaDSN)
require.Nil(t, err)
t.Cleanup(func() {
testDB.Close()
})
return testDB
}

View File

@@ -53,6 +53,16 @@ Here are a few working sample configs using a `/etc/ntfy/server.yml` file:
behind-proxy: true
```
=== "server.yml (PostgreSQL, behind proxy)"
``` yaml
base-url: "https://ntfy.example.com"
listen-http: ":2586"
database-url: "postgres://ntfy:mypassword@db.example.com:5432/ntfy?sslmode=require"
attachment-cache-dir: "/var/cache/ntfy/attachments"
behind-proxy: true
auth-default-access: "deny-all"
```
=== "server.yml (ntfy.sh config)"
``` yaml
# All the things: Behind a proxy, Firebase, cache, attachments,
@@ -125,16 +135,63 @@ using Docker Compose (i.e. `docker-compose.yml`):
command: serve
```
## Database options
ntfy uses a database for storing messages ([message cache](#message-cache)), users and [access control](#access-control), and [web push](#web-push) subscriptions.
You can choose between **SQLite** and **PostgreSQL** as the database backend.
### SQLite
By default, ntfy uses SQLite with separate database files for each store. This is the simplest setup and requires
no external dependencies:
* `cache-file`: Database file for the [message cache](#message-cache).
* `auth-file`: Database file for authentication and [access control](#access-control). If set, enables auth.
* `web-push-file`: Database file for [web push](#web-push) subscriptions.
### PostgreSQL (EXPERIMENTAL)
As an alternative, you can configure ntfy to use PostgreSQL for **all** database-backed stores by setting the
`database-url` option to a PostgreSQL connection string:
```yaml
database-url: "postgres://user:pass@host:5432/ntfy"
```
When `database-url` is set, ntfy will use PostgreSQL for the [message cache](#message-cache),
[access control](#access-control), and [web push](#web-push) subscriptions instead of SQLite. The `cache-file`,
`auth-file`, and `web-push-file` options **must not** be set in this case.
Note that setting `database-url` implicitly enables authentication and access control (equivalent to setting
`auth-file` with SQLite). The default access is `read-write`, so anonymous users can still read and write to all
topics. To restrict access, set `auth-default-access` to `deny-all` (see [access control](#access-control)).
You can also set this via the environment variable `NTFY_DATABASE_URL` or the command line flag `--database-url`.
The database URL supports the standard [PostgreSQL connection parameters](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS)
as query parameters, such as `sslmode`, `connect_timeout`, `sslcert`, `sslkey`, `sslrootcert`, and `application_name`.
See the [pgx driver documentation](https://pkg.go.dev/github.com/jackc/pgx/v5) for the full list of supported parameters.
In addition, ntfy supports the following custom query parameters to tune the connection pool:
| Parameter | Default | Description |
|---------------------------|---------|----------------------------------------------------------------------------------|
| `pool_max_conns` | 10 | Maximum number of open connections to the database |
| `pool_max_idle_conns` | - | Maximum number of idle connections in the pool |
| `pool_conn_max_lifetime` | - | Maximum amount of time a connection may be reused (Go duration, e.g. `5m`, `1h`) |
| `pool_conn_max_idle_time` | - | Maximum amount of time a connection may be idle (Go duration, e.g. `30s`, `5m`) |
Example:
```yaml
database-url: "postgres://user:pass@host:5432/ntfy?sslmode=require&pool_max_conns=50&pool_conn_max_idle_time=5m"
```
## Message cache
If desired, ntfy can temporarily keep notifications in an in-memory or an on-disk cache. Caching messages for a short period
of time is important to allow [phones](subscribe/phone.md) and other devices with brittle Internet connections to be able to retrieve
notifications that they may have missed.
By default, ntfy keeps messages **in-memory for 12 hours**, which means that **cached messages do not survive an application
restart**. You can override this behavior using the following config settings:
restart**. You can override this behavior by setting `cache-file` (SQLite) or `database-url` (PostgreSQL).
* `cache-file`: if set, ntfy will store messages in a SQLite based cache (default is empty, which means in-memory cache).
**This is required if you'd like messages to be retained across restarts**.
* `cache-duration`: defines the duration for which messages are stored in the cache (default is `12h`).
You can also entirely disable the cache by setting `cache-duration` to `0`. When the cache is disabled, messages are only
@@ -144,20 +201,6 @@ the message to the subscribers.
Subscribers can retrieve cached messaging using the [`poll=1` parameter](subscribe/api.md#poll-for-messages), as well as the
[`since=` parameter](subscribe/api.md#fetch-cached-messages).
## PostgreSQL database
By default, ntfy uses SQLite for all database-backed stores. As an alternative, you can configure ntfy to use PostgreSQL
by setting the `database-url` option to a PostgreSQL connection string:
```yaml
database-url: "postgres://user:pass@host:5432/ntfy"
```
When `database-url` is set, ntfy will use PostgreSQL for the web push subscription store instead of SQLite. The
`web-push-file` option is not required in this case. Support for PostgreSQL for the message cache and user manager
will be added in future releases.
You can also set this via the environment variable `NTFY_DATABASE_URL` or the command line flag `--database-url`.
## Attachments
If desired, you may allow users to upload and [attach files to notifications](publish.md#attachments). To enable
this feature, you have to simply configure an attachment cache directory and a base URL (`attachment-cache-dir`, `base-url`).
@@ -199,14 +242,15 @@ and `visitor-attachment-daily-bandwidth-limit`. Setting these conservatively is
By default, the ntfy server is open for everyone, meaning **everyone can read and write to any topic** (this is how
ntfy.sh is configured). To restrict access to your own server, you can optionally configure authentication and authorization.
ntfy's auth is implemented with a simple [SQLite](https://www.sqlite.org/)-based backend. It implements two roles
(`user` and `admin`) and per-topic `read` and `write` permissions using an [access control list (ACL)](https://en.wikipedia.org/wiki/Access-control_list).
Access control entries can be applied to users as well as the special everyone user (`*`), which represents anonymous API access.
ntfy's auth implements two roles (`user` and `admin`) and per-topic `read` and `write` permissions using an
[access control list (ACL)](https://en.wikipedia.org/wiki/Access-control_list). Access control entries can be applied
to users as well as the special everyone user (`*`), which represents anonymous API access.
To set up auth, **configure the following options**:
* `auth-file` is the user/access database; it is created automatically if it doesn't already exist; suggested
location `/var/lib/ntfy/user.db` (easiest if deb/rpm package is used)
* `auth-file` is the user/access database (SQLite); it is created automatically if it doesn't already exist; suggested
location `/var/lib/ntfy/user.db` (easiest if deb/rpm package is used). Alternatively, if `database-url` is set,
auth is automatically enabled using PostgreSQL (see [database options](#database-options)).
* `auth-default-access` defines the default/fallback access if no access control entry is found; it can be
set to `read-write` (default), `read-only`, `write-only` or `deny-all`. **If you are setting up a private instance,
you'll want to set this to `deny-all`** (see [private instance example](#example-private-instance)).
@@ -1161,8 +1205,8 @@ a database to keep track of the browser's subscriptions, and an admin email addr
- `web-push-expiry-warning-duration` defines the duration after which unused subscriptions are sent a warning (default is `55d`)
- `web-push-expiry-duration` defines the duration after which unused subscriptions will expire (default is `60d`)
Alternatively, you can use PostgreSQL instead of SQLite for the web push subscription store by setting `database-url`
(see [PostgreSQL database](#postgresql-database)).
Alternatively, you can use PostgreSQL instead of SQLite by setting `database-url`
(see [PostgreSQL database](#postgresql-experimental)).
Limitations:
@@ -1773,13 +1817,13 @@ variable before running the `ntfy` command (e.g. `export NTFY_LISTEN_HTTP=:80`).
| `key-file` | `NTFY_KEY_FILE` | *filename* | - | HTTPS/TLS private key file, only used if `listen-https` is set. |
| `cert-file` | `NTFY_CERT_FILE` | *filename* | - | HTTPS/TLS certificate file, only used if `listen-https` is set. |
| `firebase-key-file` | `NTFY_FIREBASE_KEY_FILE` | *filename* | - | If set, also publish messages to a Firebase Cloud Messaging (FCM) topic for your app. This is optional and only required to save battery when using the Android app. See [Firebase (FCM)](#firebase-fcm). |
| `database-url` | `NTFY_DATABASE_URL` | *string (connection URL)* | - | PostgreSQL connection string (e.g. `postgres://user:pass@host:5432/ntfy`). If set, uses PostgreSQL for database-backed stores instead of SQLite. Currently applies to the web push store. See [PostgreSQL database](#postgresql-database). |
| `database-url` | `NTFY_DATABASE_URL` | *string (connection URL)* | - | PostgreSQL connection string (e.g. `postgres://user:pass@host:5432/ntfy`). If set, uses PostgreSQL for all database-backed stores (message cache, user manager, web push) instead of SQLite. See [database options](#database-options). |
| `cache-file` | `NTFY_CACHE_FILE` | *filename* | - | If set, messages are cached in a local SQLite database instead of only in-memory. This allows for service restarts without losing messages in support of the since= parameter. See [message cache](#message-cache). |
| `cache-duration` | `NTFY_CACHE_DURATION` | *duration* | 12h | Duration for which messages will be buffered before they are deleted. This is required to support the `since=...` and `poll=1` parameter. Set this to `0` to disable the cache entirely. |
| `cache-startup-queries` | `NTFY_CACHE_STARTUP_QUERIES` | *string (SQL queries)* | - | SQL queries to run during database startup; this is useful for tuning and [enabling WAL mode](#message-cache) |
| `cache-batch-size` | `NTFY_CACHE_BATCH_SIZE` | *int* | 0 | Max size of messages to batch together when writing to message cache (if zero, writes are synchronous) |
| `cache-batch-timeout` | `NTFY_CACHE_BATCH_TIMEOUT` | *duration* | 0s | Timeout for batched async writes to the message cache (if zero, writes are synchronous) |
| `auth-file` | `NTFY_AUTH_FILE` | *filename* | - | Auth database file used for access control. If set, enables authentication and access control. See [access control](#access-control). |
| `auth-file` | `NTFY_AUTH_FILE` | *filename* | - | Auth database file used for access control (SQLite). If set, enables authentication and access control. Not required if `database-url` is set. See [access control](#access-control). |
| `auth-default-access` | `NTFY_AUTH_DEFAULT_ACCESS` | `read-write`, `read-only`, `write-only`, `deny-all` | `read-write` | Default permissions if no matching entries in the auth database are found. Default is `read-write`. |
| `behind-proxy` | `NTFY_BEHIND_PROXY` | *bool* | false | If set, use forwarded header (e.g. X-Forwarded-For, X-Client-IP) to determine visitor IP address (for rate limiting) |
| `proxy-forwarded-header` | `NTFY_PROXY_FORWARDED_HEADER` | *string* | `X-Forwarded-For` | Use specified header to determine visitor IP address (for rate limiting) |

View File

@@ -661,6 +661,8 @@ Add the following function and alias to your `.bashrc` or `.bash_profile`:
local token=$(< ~/.ntfy_token) # Securely read the token
local status_icon="$([ $exit_status -eq 0 ] && echo magic_wand || echo warning)"
local last_command=$(history | tail -n1 | sed -e 's/^[[:space:]]*[0-9]\{1,\}[[:space:]]*//' -e 's/[;&|][[:space:]]*alert$//')
# for zsh users, use the same sed pattern but get the history differently.
# local last_command=$(history "$HISTCMD" | sed -e 's/^[[:space:]]*[0-9]\{1,\}[[:space:]]*//' -e 's/[;&|][[:space:]]*alert$//')
curl -s -X POST "https://n.example.dev/alerts" \
-H "Authorization: Bearer $token" \
@@ -692,4 +694,4 @@ To test failure notifications:
false; alert # Always fails (exit 1)
ls --invalid; alert # Invalid option
cat nonexistent_file; alert # File not found
```
```

View File

@@ -185,6 +185,7 @@ I've added a ⭐ to projects or posts that have a significant following, or had
- [Uptime Monitor](https://uptime-monitor.org) - Self-hosted, enterprise-grade uptime monitoring and alerting system (TS)
- [send_to_ntfy_extension](https://github.com/TheDuffman85/send_to_ntfy_extension/) ⭐ - A browser extension to send the notifications to ntfy (JS)
- [SIA-Server](https://github.com/ZebMcKayhan/SIA-Server) - A light weight, self-hosted notification Server for Honywell Galaxy Flex alarm systems (Python)
- [zabbix-ntfy](https://github.com/torgrimt/zabbix-ntfy) - Zabbix server Mediatype to add support for ntfy.sh services
## Blog + forum posts

View File

@@ -7,11 +7,32 @@ and the [ntfy Android app](https://github.com/binwiederhier/ntfy-android/release
| Component | Version | Release date |
|------------------|---------|--------------|
| ntfy server | v2.17.0 | Feb 8, 2026 |
| ntfy Android app | v1.22.2 | Jan 25, 2026 |
| ntfy Android app | v1.23.0 | Deb 22, 2026 |
| ntfy iOS app | v1.3 | Nov 26, 2023 |
Please check out the release notes for [upcoming releases](#not-released-yet) below.
## ntfy Android v1.23.0
Released February 22, 2026
This release adds support for search within a topic, and adds [copy action](publish.md#copy-to-clipboard) support
to the Android app.
**Features:**
* Search within a topic ([#141](https://github.com/binwiederhier/ntfy/issues/141), [ntfy-android#153](https://github.com/binwiederhier/ntfy-android/pull/153), thanks to [@Copephobia](https://github.com/Copephobia) and [@StoyanYonkov](https://github.com/StoyanYonkov) for reporting and sponsoring)
* Add "reconnecting to N topics ..." to foreground notification ([#1101](https://github.com/binwiederhier/ntfy/issues/1101), thanks to [@milosivanovic](https://github.com/milosivanovic) for reporting)
* Improved default server dialog with full-screen UI and stricter URL validation ([#1582](https://github.com/binwiederhier/ntfy/issues/1582))
* Show last notification time for UnifiedPush subscriptions ([#1230](https://github.com/binwiederhier/ntfy/issues/1230), [#1454](https://github.com/binwiederhier/ntfy/issues/1454), thanks to [@Tealk](https://github.com/Tealk) and [@user4andre](https://github.com/user4andre) for reporting)
* Support "copy" action button to copy a value to the clipboard ([#1364](https://github.com/binwiederhier/ntfy/issues/1364), thanks to [@SudoWatson](https://github.com/SudoWatson) for reporting)
**Bug fixes + maintenance:**
* Fix `clear=true` on action buttons not marking notification as read ([#1029](https://github.com/binwiederhier/ntfy/issues/1029), thanks to [@ElFishi](https://github.com/ElFishi) for reporting)
* Fix crash when default server URL is missing scheme by auto-prepending `https://` ([#1582](https://github.com/binwiederhier/ntfy/issues/1582), thanks to [@hard-zero1](https://github.com/hard-zero1))
* Fix notification timestamp to use original send time instead of receive time ([#1112](https://github.com/binwiederhier/ntfy/issues/1112), thanks to [@voruti](https://github.com/voruti) for reporting)
* Fix notifications being missed after service restart by using persisted lastNotificationId ([#1591](https://github.com/binwiederhier/ntfy/issues/1591), thanks to @Epifeny for reporting)
## ntfy server v2.17.0
Released February 8, 2026
@@ -1698,26 +1719,18 @@ and the [ntfy Android app](https://github.com/binwiederhier/ntfy-android/release
## Not released yet
### ntfy Android v1.23.x (UNRELEASED)
### ntfy server v2.18.x (UNRELEASED)
**Features:**
* Search within a topic ([#141](https://github.com/binwiederhier/ntfy/issues/141), [ntfy-android#153](https://github.com/binwiederhier/ntfy-android/pull/153), thanks to [@Copephobia](https://github.com/Copephobia) and [@StoyanYonkov](https://github.com/StoyanYonkov) for reporting and sponsoring)
* Add "reconnecting to N topics ..." to foreground notification ([#1101](https://github.com/binwiederhier/ntfy/issues/1101), thanks to [@milosivanovic](https://github.com/milosivanovic) for reporting)
* Improved default server dialog with full-screen UI and stricter URL validation ([#1582](https://github.com/binwiederhier/ntfy/issues/1582))
* Show last notification time for UnifiedPush subscriptions ([#1230](https://github.com/binwiederhier/ntfy/issues/1230), [#1454](https://github.com/binwiederhier/ntfy/issues/1454), thanks to [@Tealk](https://github.com/Tealk) and [@user4andre](https://github.com/user4andre) for reporting)
* Support "copy" action button to copy a value to the clipboard ([#1364](https://github.com/binwiederhier/ntfy/issues/1364), thanks to [@SudoWatson](https://github.com/SudoWatson) for reporting)
* Add experimental [PostgreSQL support](config.md#postgresql-experimental) as an alternative database backend (message cache, user manager, web push subscriptions) via `database-url` config option ([#1114](https://github.com/binwiederhier/ntfy/issues/1114), thanks to [@brettinternet](https://github.com/brettinternet) for reporting)
**Bug fixes + maintenance:**
* Preserve `<br>` line breaks in HTML-only emails received via SMTP ([#690](https://github.com/binwiederhier/ntfy/issues/690), [#1620](https://github.com/binwiederhier/ntfy/pull/1620), thanks to [@uzkikh](https://github.com/uzkikh) for the fix and to [@teastrainer](https://github.com/teastrainer) for reporting)
### ntfy Android v1.24.x (UNRELEASED)
**Bug fixes + maintenance:**
* Fix `clear=true` on action buttons not marking notification as read ([#1029](https://github.com/binwiederhier/ntfy/issues/1029), thanks to [@ElFishi](https://github.com/ElFishi) for reporting)
* Fix crash when default server URL is missing scheme by auto-prepending `https://` ([#1582](https://github.com/binwiederhier/ntfy/issues/1582), thanks to [@hard-zero1](https://github.com/hard-zero1))
* Fix notification timestamp to use original send time instead of receive time ([#1112](https://github.com/binwiederhier/ntfy/issues/1112), thanks to [@voruti](https://github.com/voruti) for reporting)
* Fix notifications being missed after service restart by using persisted lastNotificationId ([#1591](https://github.com/binwiederhier/ntfy/issues/1591), thanks to @Epifeny for reporting)
* Fix crash in settings when fragment is detached during backup/restore or log operations
### ntfy server v2.12.x (UNRELEASED)
**Features:**
* Add PostgreSQL as an alternative database backend for the web push subscription store via `database-url` config option

View File

@@ -4,7 +4,6 @@ import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"net/netip"
"strings"
"sync"
@@ -21,39 +20,13 @@ const (
var errNoRows = errors.New("no rows found")
// Store is the interface for a message cache store
type Store interface {
AddMessage(m *model.Message) error
AddMessages(ms []*model.Message) error
DB() *sql.DB
Message(id string) (*model.Message, error)
MessageCounts() (map[string]int, error)
Messages(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error)
MessagesDue() ([]*model.Message, error)
MessagesExpired() ([]string, error)
MarkPublished(m *model.Message) error
UpdateMessageTime(messageID string, timestamp int64) error
Topics() ([]string, error)
DeleteMessages(ids ...string) error
DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, error)
ExpireMessages(topics ...string) error
AttachmentsExpired() ([]string, error)
MarkAttachmentsDeleted(ids ...string) error
AttachmentBytesUsedBySender(sender string) (int64, error)
AttachmentBytesUsedByUser(userID string) (int64, error)
UpdateStats(messages int64) error
Stats() (int64, error)
Close() error
}
// storeQueries holds the database-specific SQL queries
type storeQueries struct {
// queries holds the database-specific SQL queries
type queries struct {
insertMessage string
deleteMessage string
selectScheduledMessageIDsBySeqID string
deleteScheduledBySequenceID string
updateMessagesForTopicExpiry string
selectRowIDFromMessageID string
selectMessagesByID string
selectMessagesSinceTime string
selectMessagesSinceTimeScheduled string
@@ -64,7 +37,6 @@ type storeQueries struct {
selectMessagesExpired string
updateMessagePublished string
selectMessagesCount string
selectMessageCountPerTopic string
selectTopics string
updateAttachmentDeleted string
selectAttachmentsExpired string
@@ -75,37 +47,46 @@ type storeQueries struct {
updateMessageTime string
}
// commonStore implements store operations that are identical across database backends
type commonStore struct {
// Cache stores published messages
type Cache struct {
db *sql.DB
queue *util.BatchingQueue[*model.Message]
nop bool
mu sync.Mutex
queries storeQueries
mu *sync.Mutex // nil for PostgreSQL (concurrent writes supported), set for SQLite (single writer)
queries queries
}
func newCommonStore(db *sql.DB, queries storeQueries, batchSize int, batchTimeout time.Duration, nop bool) *commonStore {
func newCache(db *sql.DB, queries queries, mu *sync.Mutex, batchSize int, batchTimeout time.Duration, nop bool) *Cache {
var queue *util.BatchingQueue[*model.Message]
if batchSize > 0 || batchTimeout > 0 {
queue = util.NewBatchingQueue[*model.Message](batchSize, batchTimeout)
}
c := &commonStore{
c := &Cache{
db: db,
queue: queue,
nop: nop,
mu: mu,
queries: queries,
}
go c.processMessageBatches()
return c
}
// DB returns the underlying database connection
func (c *commonStore) DB() *sql.DB {
return c.db
func (c *Cache) maybeLock() {
if c.mu != nil {
c.mu.Lock()
}
}
func (c *Cache) maybeUnlock() {
if c.mu != nil {
c.mu.Unlock()
}
}
// AddMessage stores a message to the message cache synchronously, or queues it to be stored at a later date asynchronously.
func (c *commonStore) AddMessage(m *model.Message) error {
// The message is queued only if "batchSize" or "batchTimeout" are passed to the constructor.
func (c *Cache) AddMessage(m *model.Message) error {
if c.queue != nil {
c.queue.Enqueue(m)
return nil
@@ -114,13 +95,13 @@ func (c *commonStore) AddMessage(m *model.Message) error {
}
// AddMessages synchronously stores a batch of messages to the message cache
func (c *commonStore) AddMessages(ms []*model.Message) error {
func (c *Cache) AddMessages(ms []*model.Message) error {
return c.addMessages(ms)
}
func (c *commonStore) addMessages(ms []*model.Message) error {
c.mu.Lock()
defer c.mu.Unlock()
func (c *Cache) addMessages(ms []*model.Message) error {
c.maybeLock()
defer c.maybeUnlock()
if c.nop {
return nil
}
@@ -204,7 +185,8 @@ func (c *commonStore) addMessages(ms []*model.Message) error {
return nil
}
func (c *commonStore) Messages(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
// Messages returns messages for a topic since the given marker, optionally including scheduled messages
func (c *Cache) Messages(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
if since.IsNone() {
return make([]*model.Message, 0), nil
} else if since.IsLatest() {
@@ -215,7 +197,7 @@ func (c *commonStore) Messages(topic string, since model.SinceMarker, scheduled
return c.messagesSinceTime(topic, since, scheduled)
}
func (c *commonStore) messagesSinceTime(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
func (c *Cache) messagesSinceTime(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
var rows *sql.Rows
var err error
if scheduled {
@@ -229,25 +211,13 @@ func (c *commonStore) messagesSinceTime(topic string, since model.SinceMarker, s
return readMessages(rows)
}
func (c *commonStore) messagesSinceID(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
idrows, err := c.db.Query(c.queries.selectRowIDFromMessageID, since.ID())
if err != nil {
return nil, err
}
defer idrows.Close()
if !idrows.Next() {
return c.messagesSinceTime(topic, model.SinceAllMessages, scheduled)
}
var rowID int64
if err := idrows.Scan(&rowID); err != nil {
return nil, err
}
idrows.Close()
func (c *Cache) messagesSinceID(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
var rows *sql.Rows
var err error
if scheduled {
rows, err = c.db.Query(c.queries.selectMessagesSinceIDScheduled, topic, rowID)
rows, err = c.db.Query(c.queries.selectMessagesSinceIDScheduled, topic, since.ID())
} else {
rows, err = c.db.Query(c.queries.selectMessagesSinceID, topic, rowID)
rows, err = c.db.Query(c.queries.selectMessagesSinceID, topic, since.ID())
}
if err != nil {
return nil, err
@@ -255,7 +225,7 @@ func (c *commonStore) messagesSinceID(topic string, since model.SinceMarker, sch
return readMessages(rows)
}
func (c *commonStore) messagesLatest(topic string) ([]*model.Message, error) {
func (c *Cache) messagesLatest(topic string) ([]*model.Message, error) {
rows, err := c.db.Query(c.queries.selectMessagesLatest, topic)
if err != nil {
return nil, err
@@ -263,7 +233,8 @@ func (c *commonStore) messagesLatest(topic string) ([]*model.Message, error) {
return readMessages(rows)
}
func (c *commonStore) MessagesDue() ([]*model.Message, error) {
// MessagesDue returns all messages that are due for publishing
func (c *Cache) MessagesDue() ([]*model.Message, error) {
rows, err := c.db.Query(c.queries.selectMessagesDue, time.Now().Unix())
if err != nil {
return nil, err
@@ -272,7 +243,7 @@ func (c *commonStore) MessagesDue() ([]*model.Message, error) {
}
// MessagesExpired returns a list of IDs for messages that have expired (should be deleted)
func (c *commonStore) MessagesExpired() ([]string, error) {
func (c *Cache) MessagesExpired() ([]string, error) {
rows, err := c.db.Query(c.queries.selectMessagesExpired, time.Now().Unix())
if err != nil {
return nil, err
@@ -292,7 +263,8 @@ func (c *commonStore) MessagesExpired() ([]string, error) {
return ids, nil
}
func (c *commonStore) Message(id string) (*model.Message, error) {
// Message returns the message with the given ID, or ErrMessageNotFound if not found
func (c *Cache) Message(id string) (*model.Message, error) {
rows, err := c.db.Query(c.queries.selectMessagesByID, id)
if err != nil {
return nil, err
@@ -305,39 +277,40 @@ func (c *commonStore) Message(id string) (*model.Message, error) {
}
// UpdateMessageTime updates the time column for a message by ID. This is only used for testing.
func (c *commonStore) UpdateMessageTime(messageID string, timestamp int64) error {
func (c *Cache) UpdateMessageTime(messageID string, timestamp int64) error {
c.maybeLock()
defer c.maybeUnlock()
_, err := c.db.Exec(c.queries.updateMessageTime, timestamp, messageID)
return err
}
func (c *commonStore) MarkPublished(m *model.Message) error {
c.mu.Lock()
defer c.mu.Unlock()
// MarkPublished marks a message as published
func (c *Cache) MarkPublished(m *model.Message) error {
c.maybeLock()
defer c.maybeUnlock()
_, err := c.db.Exec(c.queries.updateMessagePublished, m.ID)
return err
}
func (c *commonStore) MessageCounts() (map[string]int, error) {
rows, err := c.db.Query(c.queries.selectMessageCountPerTopic)
// MessagesCount returns the total number of messages in the cache
func (c *Cache) MessagesCount() (int, error) {
rows, err := c.db.Query(c.queries.selectMessagesCount)
if err != nil {
return nil, err
return 0, err
}
defer rows.Close()
var topic string
var count int
counts := make(map[string]int)
for rows.Next() {
if err := rows.Scan(&topic, &count); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
}
counts[topic] = count
if !rows.Next() {
return 0, errNoRows
}
return counts, nil
var count int
if err := rows.Scan(&count); err != nil {
return 0, err
}
return count, nil
}
func (c *commonStore) Topics() ([]string, error) {
// Topics returns a list of all topics with messages in the cache
func (c *Cache) Topics() ([]string, error) {
rows, err := c.db.Query(c.queries.selectTopics)
if err != nil {
return nil, err
@@ -357,9 +330,10 @@ func (c *commonStore) Topics() ([]string, error) {
return topics, nil
}
func (c *commonStore) DeleteMessages(ids ...string) error {
c.mu.Lock()
defer c.mu.Unlock()
// DeleteMessages deletes the messages with the given IDs
func (c *Cache) DeleteMessages(ids ...string) error {
c.maybeLock()
defer c.maybeUnlock()
tx, err := c.db.Begin()
if err != nil {
return err
@@ -375,14 +349,15 @@ func (c *commonStore) DeleteMessages(ids ...string) error {
// DeleteScheduledBySequenceID deletes unpublished (scheduled) messages with the given topic and sequence ID.
// It returns the message IDs of the deleted messages, which can be used to clean up attachment files.
func (c *commonStore) DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, error) {
c.mu.Lock()
defer c.mu.Unlock()
func (c *Cache) DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, error) {
c.maybeLock()
defer c.maybeUnlock()
tx, err := c.db.Begin()
if err != nil {
return nil, err
}
defer tx.Rollback()
// First, get the message IDs of scheduled messages to be deleted
rows, err := tx.Query(c.queries.selectScheduledMessageIDsBySeqID, topic, sequenceID)
if err != nil {
return nil, err
@@ -399,7 +374,8 @@ func (c *commonStore) DeleteScheduledBySequenceID(topic, sequenceID string) ([]s
if err := rows.Err(); err != nil {
return nil, err
}
rows.Close()
rows.Close() // Close rows before executing delete in same transaction
// Then delete the messages
if _, err := tx.Exec(c.queries.deleteScheduledBySequenceID, topic, sequenceID); err != nil {
return nil, err
}
@@ -409,9 +385,10 @@ func (c *commonStore) DeleteScheduledBySequenceID(topic, sequenceID string) ([]s
return ids, nil
}
func (c *commonStore) ExpireMessages(topics ...string) error {
c.mu.Lock()
defer c.mu.Unlock()
// ExpireMessages marks messages in the given topics as expired
func (c *Cache) ExpireMessages(topics ...string) error {
c.maybeLock()
defer c.maybeUnlock()
tx, err := c.db.Begin()
if err != nil {
return err
@@ -425,7 +402,8 @@ func (c *commonStore) ExpireMessages(topics ...string) error {
return tx.Commit()
}
func (c *commonStore) AttachmentsExpired() ([]string, error) {
// AttachmentsExpired returns message IDs with expired attachments that have not been deleted
func (c *Cache) AttachmentsExpired() ([]string, error) {
rows, err := c.db.Query(c.queries.selectAttachmentsExpired, time.Now().Unix())
if err != nil {
return nil, err
@@ -445,9 +423,10 @@ func (c *commonStore) AttachmentsExpired() ([]string, error) {
return ids, nil
}
func (c *commonStore) MarkAttachmentsDeleted(ids ...string) error {
c.mu.Lock()
defer c.mu.Unlock()
// MarkAttachmentsDeleted marks the attachments for the given message IDs as deleted
func (c *Cache) MarkAttachmentsDeleted(ids ...string) error {
c.maybeLock()
defer c.maybeUnlock()
tx, err := c.db.Begin()
if err != nil {
return err
@@ -461,7 +440,8 @@ func (c *commonStore) MarkAttachmentsDeleted(ids ...string) error {
return tx.Commit()
}
func (c *commonStore) AttachmentBytesUsedBySender(sender string) (int64, error) {
// AttachmentBytesUsedBySender returns the total size of active attachments sent by the given sender
func (c *Cache) AttachmentBytesUsedBySender(sender string) (int64, error) {
rows, err := c.db.Query(c.queries.selectAttachmentsSizeBySender, sender, time.Now().Unix())
if err != nil {
return 0, err
@@ -469,7 +449,8 @@ func (c *commonStore) AttachmentBytesUsedBySender(sender string) (int64, error)
return c.readAttachmentBytesUsed(rows)
}
func (c *commonStore) AttachmentBytesUsedByUser(userID string) (int64, error) {
// AttachmentBytesUsedByUser returns the total size of active attachments for the given user
func (c *Cache) AttachmentBytesUsedByUser(userID string) (int64, error) {
rows, err := c.db.Query(c.queries.selectAttachmentsSizeByUserID, userID, time.Now().Unix())
if err != nil {
return 0, err
@@ -477,7 +458,7 @@ func (c *commonStore) AttachmentBytesUsedByUser(userID string) (int64, error) {
return c.readAttachmentBytesUsed(rows)
}
func (c *commonStore) readAttachmentBytesUsed(rows *sql.Rows) (int64, error) {
func (c *Cache) readAttachmentBytesUsed(rows *sql.Rows) (int64, error) {
defer rows.Close()
var size int64
if !rows.Next() {
@@ -491,14 +472,16 @@ func (c *commonStore) readAttachmentBytesUsed(rows *sql.Rows) (int64, error) {
return size, nil
}
func (c *commonStore) UpdateStats(messages int64) error {
c.mu.Lock()
defer c.mu.Unlock()
// UpdateStats updates the total message count statistic
func (c *Cache) UpdateStats(messages int64) error {
c.maybeLock()
defer c.maybeUnlock()
_, err := c.db.Exec(c.queries.updateStats, messages)
return err
}
func (c *commonStore) Stats() (messages int64, err error) {
// Stats returns the total message count statistic
func (c *Cache) Stats() (messages int64, err error) {
rows, err := c.db.Query(c.queries.selectStats)
if err != nil {
return 0, err
@@ -513,11 +496,12 @@ func (c *commonStore) Stats() (messages int64, err error) {
return messages, nil
}
func (c *commonStore) Close() error {
// Close closes the underlying database connection
func (c *Cache) Close() error {
return c.db.Close()
}
func (c *commonStore) processMessageBatches() {
func (c *Cache) processMessageBatches() {
if c.queue == nil {
return
}
@@ -620,9 +604,3 @@ func readMessage(rows *sql.Rows) (*model.Message, error) {
Encoding: encoding,
}, nil
}
// Ensure commonStore implements Store
var _ Store = (*commonStore)(nil)
// Needed by store.go but not part of Store interface; unused import guard
var _ = fmt.Sprintf

110
message/cache_postgres.go Normal file
View File

@@ -0,0 +1,110 @@
package message
import (
"database/sql"
"time"
)
// PostgreSQL runtime query constants
const (
postgresInsertMessageQuery = `
INSERT INTO message (mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user_id, content_type, encoding, published)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24)
`
postgresDeleteMessageQuery = `DELETE FROM message WHERE mid = $1`
postgresSelectScheduledMessageIDsBySeqIDQuery = `SELECT mid FROM message WHERE topic = $1 AND sequence_id = $2 AND published = FALSE`
postgresDeleteScheduledBySequenceIDQuery = `DELETE FROM message WHERE topic = $1 AND sequence_id = $2 AND published = FALSE`
postgresUpdateMessagesForTopicExpiryQuery = `UPDATE message SET expires = $1 WHERE topic = $2`
postgresSelectMessagesByIDQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
FROM message
WHERE mid = $1
`
postgresSelectMessagesSinceTimeQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
FROM message
WHERE topic = $1 AND time >= $2 AND published = TRUE
ORDER BY time, id
`
postgresSelectMessagesSinceTimeIncludeScheduledQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
FROM message
WHERE topic = $1 AND time >= $2
ORDER BY time, id
`
postgresSelectMessagesSinceIDQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
FROM message
WHERE topic = $1
AND id > COALESCE((SELECT id FROM message WHERE mid = $2), 0)
AND published = TRUE
ORDER BY time, id
`
postgresSelectMessagesSinceIDIncludeScheduledQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
FROM message
WHERE topic = $1
AND (id > COALESCE((SELECT id FROM message WHERE mid = $2), 0) OR published = FALSE)
ORDER BY time, id
`
postgresSelectMessagesLatestQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
FROM message
WHERE topic = $1 AND published = TRUE
ORDER BY time DESC, id DESC
LIMIT 1
`
postgresSelectMessagesDueQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
FROM message
WHERE time <= $1 AND published = FALSE
ORDER BY time, id
`
postgresSelectMessagesExpiredQuery = `SELECT mid FROM message WHERE expires <= $1 AND published = TRUE`
postgresUpdateMessagePublishedQuery = `UPDATE message SET published = TRUE WHERE mid = $1`
postgresSelectMessagesCountQuery = `SELECT COUNT(*) FROM message`
postgresSelectTopicsQuery = `SELECT topic FROM message GROUP BY topic`
postgresUpdateAttachmentDeletedQuery = `UPDATE message SET attachment_deleted = TRUE WHERE mid = $1`
postgresSelectAttachmentsExpiredQuery = `SELECT mid FROM message WHERE attachment_expires > 0 AND attachment_expires <= $1 AND attachment_deleted = FALSE`
postgresSelectAttachmentsSizeBySenderQuery = `SELECT COALESCE(SUM(attachment_size), 0) FROM message WHERE user_id = '' AND sender = $1 AND attachment_expires >= $2`
postgresSelectAttachmentsSizeByUserIDQuery = `SELECT COALESCE(SUM(attachment_size), 0) FROM message WHERE user_id = $1 AND attachment_expires >= $2`
postgresSelectStatsQuery = `SELECT value FROM message_stats WHERE key = 'messages'`
postgresUpdateStatsQuery = `UPDATE message_stats SET value = $1 WHERE key = 'messages'`
postgresUpdateMessageTimeQuery = `UPDATE message SET time = $1 WHERE mid = $2`
)
var pgQueries = queries{
insertMessage: postgresInsertMessageQuery,
deleteMessage: postgresDeleteMessageQuery,
selectScheduledMessageIDsBySeqID: postgresSelectScheduledMessageIDsBySeqIDQuery,
deleteScheduledBySequenceID: postgresDeleteScheduledBySequenceIDQuery,
updateMessagesForTopicExpiry: postgresUpdateMessagesForTopicExpiryQuery,
selectMessagesByID: postgresSelectMessagesByIDQuery,
selectMessagesSinceTime: postgresSelectMessagesSinceTimeQuery,
selectMessagesSinceTimeScheduled: postgresSelectMessagesSinceTimeIncludeScheduledQuery,
selectMessagesSinceID: postgresSelectMessagesSinceIDQuery,
selectMessagesSinceIDScheduled: postgresSelectMessagesSinceIDIncludeScheduledQuery,
selectMessagesLatest: postgresSelectMessagesLatestQuery,
selectMessagesDue: postgresSelectMessagesDueQuery,
selectMessagesExpired: postgresSelectMessagesExpiredQuery,
updateMessagePublished: postgresUpdateMessagePublishedQuery,
selectMessagesCount: postgresSelectMessagesCountQuery,
selectTopics: postgresSelectTopicsQuery,
updateAttachmentDeleted: postgresUpdateAttachmentDeletedQuery,
selectAttachmentsExpired: postgresSelectAttachmentsExpiredQuery,
selectAttachmentsSizeBySender: postgresSelectAttachmentsSizeBySenderQuery,
selectAttachmentsSizeByUserID: postgresSelectAttachmentsSizeByUserIDQuery,
selectStats: postgresSelectStatsQuery,
updateStats: postgresUpdateStatsQuery,
updateMessageTime: postgresUpdateMessageTimeQuery,
}
// NewPostgresStore creates a new PostgreSQL-backed message cache store using an existing database connection pool.
func NewPostgresStore(db *sql.DB, batchSize int, batchTimeout time.Duration) (*Cache, error) {
if err := setupPostgres(db); err != nil {
return nil, err
}
return newCache(db, pgQueries, nil, batchSize, batchTimeout, false), nil
}

View File

@@ -7,7 +7,7 @@ import (
// Initial PostgreSQL schema
const (
pgCreateTablesQuery = `
postgresCreateTablesQuery = `
CREATE TABLE IF NOT EXISTS message (
id BIGSERIAL PRIMARY KEY,
mid TEXT NOT NULL,
@@ -37,12 +37,10 @@ const (
);
CREATE INDEX IF NOT EXISTS idx_message_mid ON message (mid);
CREATE INDEX IF NOT EXISTS idx_message_sequence_id ON message (sequence_id);
CREATE INDEX IF NOT EXISTS idx_message_time ON message (time);
CREATE INDEX IF NOT EXISTS idx_message_topic ON message (topic);
CREATE INDEX IF NOT EXISTS idx_message_expires ON message (expires);
CREATE INDEX IF NOT EXISTS idx_message_sender ON message (sender);
CREATE INDEX IF NOT EXISTS idx_message_user_id ON message (user_id);
CREATE INDEX IF NOT EXISTS idx_message_attachment_expires ON message (attachment_expires);
CREATE INDEX IF NOT EXISTS idx_message_topic_published_time ON message (topic, published, time, id);
CREATE INDEX IF NOT EXISTS idx_message_published_expires ON message (published, expires);
CREATE INDEX IF NOT EXISTS idx_message_sender_attachment_expires ON message (sender, attachment_expires) WHERE user_id = '';
CREATE INDEX IF NOT EXISTS idx_message_user_id_attachment_expires ON message (user_id, attachment_expires);
CREATE TABLE IF NOT EXISTS message_stats (
key TEXT PRIMARY KEY,
value BIGINT
@@ -57,14 +55,14 @@ const (
// PostgreSQL schema management queries
const (
pgCurrentSchemaVersion = 14
pgInsertSchemaVersion = `INSERT INTO schema_version (store, version) VALUES ('message', $1)`
pgSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE store = 'message'`
pgCurrentSchemaVersion = 14
postgresInsertSchemaVersionQuery = `INSERT INTO schema_version (store, version) VALUES ('message', $1)`
postgresSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE store = 'message'`
)
func setupPostgresDB(db *sql.DB) error {
func setupPostgres(db *sql.DB) error {
var schemaVersion int
err := db.QueryRow(pgSelectSchemaVersionQuery).Scan(&schemaVersion)
err := db.QueryRow(postgresSelectSchemaVersionQuery).Scan(&schemaVersion)
if err != nil {
return setupNewPostgresDB(db)
}
@@ -80,10 +78,10 @@ func setupNewPostgresDB(db *sql.DB) error {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(pgCreateTablesQuery); err != nil {
if _, err := tx.Exec(postgresCreateTablesQuery); err != nil {
return err
}
if _, err := tx.Exec(pgInsertSchemaVersion, pgCurrentSchemaVersion); err != nil {
if _, err := tx.Exec(postgresInsertSchemaVersionQuery, pgCurrentSchemaVersion); err != nil {
return err
}
return tx.Commit()

View File

@@ -4,6 +4,7 @@ import (
"database/sql"
"fmt"
"path/filepath"
"sync"
"time"
_ "github.com/mattn/go-sqlite3" // SQLite driver
@@ -20,7 +21,6 @@ const (
sqliteSelectScheduledMessageIDsBySeqIDQuery = `SELECT mid FROM messages WHERE topic = ? AND sequence_id = ? AND published = 0`
sqliteDeleteScheduledBySequenceIDQuery = `DELETE FROM messages WHERE topic = ? AND sequence_id = ? AND published = 0`
sqliteUpdateMessagesForTopicExpiryQuery = `UPDATE messages SET expires = ? WHERE topic = ?`
sqliteSelectRowIDFromMessageID = `SELECT id FROM messages WHERE mid = ?`
sqliteSelectMessagesByIDQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
@@ -41,13 +41,13 @@ const (
sqliteSelectMessagesSinceIDQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE topic = ? AND id > ? AND published = 1
WHERE topic = ? AND id > COALESCE((SELECT id FROM messages WHERE mid = ?), 0) AND published = 1
ORDER BY time, id
`
sqliteSelectMessagesSinceIDIncludeScheduledQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE topic = ? AND (id > ? OR published = 0)
WHERE topic = ? AND (id > COALESCE((SELECT id FROM messages WHERE mid = ?), 0) OR published = 0)
ORDER BY time, id
`
sqliteSelectMessagesLatestQuery = `
@@ -63,13 +63,12 @@ const (
WHERE time <= ? AND published = 0
ORDER BY time, id
`
sqliteSelectMessagesExpiredQuery = `SELECT mid FROM messages WHERE expires <= ? AND published = 1`
sqliteUpdateMessagePublishedQuery = `UPDATE messages SET published = 1 WHERE mid = ?`
sqliteSelectMessagesCountQuery = `SELECT COUNT(*) FROM messages`
sqliteSelectMessageCountPerTopicQuery = `SELECT topic, COUNT(*) FROM messages GROUP BY topic`
sqliteSelectTopicsQuery = `SELECT topic FROM messages GROUP BY topic`
sqliteSelectMessagesExpiredQuery = `SELECT mid FROM messages WHERE expires <= ? AND published = 1`
sqliteUpdateMessagePublishedQuery = `UPDATE messages SET published = 1 WHERE mid = ?`
sqliteSelectMessagesCountQuery = `SELECT COUNT(*) FROM messages`
sqliteSelectTopicsQuery = `SELECT topic FROM messages GROUP BY topic`
sqliteUpdateAttachmentDeleted = `UPDATE messages SET attachment_deleted = 1 WHERE mid = ?`
sqliteUpdateAttachmentDeletedQuery = `UPDATE messages SET attachment_deleted = 1 WHERE mid = ?`
sqliteSelectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires <= ? AND attachment_deleted = 0`
sqliteSelectAttachmentsSizeBySenderQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = '' AND sender = ? AND attachment_expires >= ?`
sqliteSelectAttachmentsSizeByUserIDQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?`
@@ -79,13 +78,12 @@ const (
sqliteUpdateMessageTimeQuery = `UPDATE messages SET time = ? WHERE mid = ?`
)
var sqliteQueries = storeQueries{
var sqliteQueries = queries{
insertMessage: sqliteInsertMessageQuery,
deleteMessage: sqliteDeleteMessageQuery,
selectScheduledMessageIDsBySeqID: sqliteSelectScheduledMessageIDsBySeqIDQuery,
deleteScheduledBySequenceID: sqliteDeleteScheduledBySequenceIDQuery,
updateMessagesForTopicExpiry: sqliteUpdateMessagesForTopicExpiryQuery,
selectRowIDFromMessageID: sqliteSelectRowIDFromMessageID,
selectMessagesByID: sqliteSelectMessagesByIDQuery,
selectMessagesSinceTime: sqliteSelectMessagesSinceTimeQuery,
selectMessagesSinceTimeScheduled: sqliteSelectMessagesSinceTimeIncludeScheduledQuery,
@@ -96,9 +94,8 @@ var sqliteQueries = storeQueries{
selectMessagesExpired: sqliteSelectMessagesExpiredQuery,
updateMessagePublished: sqliteUpdateMessagePublishedQuery,
selectMessagesCount: sqliteSelectMessagesCountQuery,
selectMessageCountPerTopic: sqliteSelectMessageCountPerTopicQuery,
selectTopics: sqliteSelectTopicsQuery,
updateAttachmentDeleted: sqliteUpdateAttachmentDeleted,
updateAttachmentDeleted: sqliteUpdateAttachmentDeletedQuery,
selectAttachmentsExpired: sqliteSelectAttachmentsExpiredQuery,
selectAttachmentsSizeBySender: sqliteSelectAttachmentsSizeBySenderQuery,
selectAttachmentsSizeByUserID: sqliteSelectAttachmentsSizeByUserIDQuery,
@@ -108,7 +105,7 @@ var sqliteQueries = storeQueries{
}
// NewSQLiteStore creates a SQLite file-backed cache
func NewSQLiteStore(filename, startupQueries string, cacheDuration time.Duration, batchSize int, batchTimeout time.Duration, nop bool) (Store, error) {
func NewSQLiteStore(filename, startupQueries string, cacheDuration time.Duration, batchSize int, batchTimeout time.Duration, nop bool) (*Cache, error) {
parentDir := filepath.Dir(filename)
if !util.FileExists(parentDir) {
return nil, fmt.Errorf("cache database directory %s does not exist or is not accessible", parentDir)
@@ -120,21 +117,26 @@ func NewSQLiteStore(filename, startupQueries string, cacheDuration time.Duration
if err := setupSQLite(db, startupQueries, cacheDuration); err != nil {
return nil, err
}
return newCommonStore(db, sqliteQueries, batchSize, batchTimeout, nop), nil
return newCache(db, sqliteQueries, &sync.Mutex{}, batchSize, batchTimeout, nop), nil
}
// NewMemStore creates an in-memory cache
func NewMemStore() (Store, error) {
func NewMemStore() (*Cache, error) {
return NewSQLiteStore(createMemoryFilename(), "", 0, 0, 0, false)
}
// NewNopStore creates an in-memory cache that discards all messages;
// it is always empty and can be used if caching is entirely disabled
func NewNopStore() (Store, error) {
func NewNopStore() (*Cache, error) {
return NewSQLiteStore(createMemoryFilename(), "", 0, 0, 0, true)
}
// createMemoryFilename creates a unique memory filename to use for the SQLite backend.
// From mattn/go-sqlite3: "Each connection to ":memory:" opens a brand new in-memory
// sql database, so if the stdlib's sql engine happens to open another connection and
// you've only specified ":memory:", that connection will see a brand new database.
// A workaround is to use "file::memory:?cache=shared" (or "file:foobar?mode=memory&cache=shared").
// Every connection to this string will point to the same in-memory database."
func createMemoryFilename() string {
return fmt.Sprintf("file:%s?mode=memory&cache=shared", util.RandomString(10))
}

View File

@@ -10,7 +10,7 @@ import (
// Initial SQLite schema
const (
sqliteCreateMessagesTableQuery = `
sqliteCreateTablesQuery = `
BEGIN;
CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -65,8 +65,8 @@ const (
version INT NOT NULL
);
`
sqliteInsertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
sqliteUpdateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
sqliteInsertSchemaVersionQuery = `INSERT INTO schemaVersion VALUES (1, ?)`
sqliteUpdateSchemaVersionQuery = `UPDATE schemaVersion SET version = ? WHERE id = 1`
sqliteSelectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
)
@@ -259,13 +259,13 @@ func setupSQLite(db *sql.DB, startupQueries string, cacheDuration time.Duration)
}
func setupNewSQLite(db *sql.DB) error {
if _, err := db.Exec(sqliteCreateMessagesTableQuery); err != nil {
if _, err := db.Exec(sqliteCreateTablesQuery); err != nil {
return err
}
if _, err := db.Exec(sqliteCreateSchemaVersionTableQuery); err != nil {
return err
}
if _, err := db.Exec(sqliteInsertSchemaVersion, sqliteCurrentSchemaVersion); err != nil {
if _, err := db.Exec(sqliteInsertSchemaVersionQuery, sqliteCurrentSchemaVersion); err != nil {
return err
}
return nil
@@ -288,7 +288,7 @@ func sqliteMigrateFrom0(db *sql.DB, _ time.Duration) error {
if _, err := db.Exec(sqliteCreateSchemaVersionTableQuery); err != nil {
return err
}
if _, err := db.Exec(sqliteInsertSchemaVersion, 1); err != nil {
if _, err := db.Exec(sqliteInsertSchemaVersionQuery, 1); err != nil {
return err
}
return nil
@@ -299,7 +299,7 @@ func sqliteMigrateFrom1(db *sql.DB, _ time.Duration) error {
if _, err := db.Exec(sqliteMigrate1To2AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(sqliteUpdateSchemaVersion, 2); err != nil {
if _, err := db.Exec(sqliteUpdateSchemaVersionQuery, 2); err != nil {
return err
}
return nil
@@ -310,7 +310,7 @@ func sqliteMigrateFrom2(db *sql.DB, _ time.Duration) error {
if _, err := db.Exec(sqliteMigrate2To3AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(sqliteUpdateSchemaVersion, 3); err != nil {
if _, err := db.Exec(sqliteUpdateSchemaVersionQuery, 3); err != nil {
return err
}
return nil
@@ -321,7 +321,7 @@ func sqliteMigrateFrom3(db *sql.DB, _ time.Duration) error {
if _, err := db.Exec(sqliteMigrate3To4AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(sqliteUpdateSchemaVersion, 4); err != nil {
if _, err := db.Exec(sqliteUpdateSchemaVersionQuery, 4); err != nil {
return err
}
return nil
@@ -332,7 +332,7 @@ func sqliteMigrateFrom4(db *sql.DB, _ time.Duration) error {
if _, err := db.Exec(sqliteMigrate4To5AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(sqliteUpdateSchemaVersion, 5); err != nil {
if _, err := db.Exec(sqliteUpdateSchemaVersionQuery, 5); err != nil {
return err
}
return nil
@@ -343,7 +343,7 @@ func sqliteMigrateFrom5(db *sql.DB, _ time.Duration) error {
if _, err := db.Exec(sqliteMigrate5To6AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(sqliteUpdateSchemaVersion, 6); err != nil {
if _, err := db.Exec(sqliteUpdateSchemaVersionQuery, 6); err != nil {
return err
}
return nil
@@ -354,7 +354,7 @@ func sqliteMigrateFrom6(db *sql.DB, _ time.Duration) error {
if _, err := db.Exec(sqliteMigrate6To7AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(sqliteUpdateSchemaVersion, 7); err != nil {
if _, err := db.Exec(sqliteUpdateSchemaVersionQuery, 7); err != nil {
return err
}
return nil
@@ -365,7 +365,7 @@ func sqliteMigrateFrom7(db *sql.DB, _ time.Duration) error {
if _, err := db.Exec(sqliteMigrate7To8AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(sqliteUpdateSchemaVersion, 8); err != nil {
if _, err := db.Exec(sqliteUpdateSchemaVersionQuery, 8); err != nil {
return err
}
return nil
@@ -376,7 +376,7 @@ func sqliteMigrateFrom8(db *sql.DB, _ time.Duration) error {
if _, err := db.Exec(sqliteMigrate8To9AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(sqliteUpdateSchemaVersion, 9); err != nil {
if _, err := db.Exec(sqliteUpdateSchemaVersionQuery, 9); err != nil {
return err
}
return nil
@@ -395,7 +395,7 @@ func sqliteMigrateFrom9(db *sql.DB, cacheDuration time.Duration) error {
if _, err := tx.Exec(sqliteMigrate9To10UpdateMessageExpiryQuery, int64(cacheDuration.Seconds())); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 10); err != nil {
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 10); err != nil {
return err
}
return tx.Commit()
@@ -411,7 +411,7 @@ func sqliteMigrateFrom10(db *sql.DB, _ time.Duration) error {
if _, err := tx.Exec(sqliteMigrate10To11AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 11); err != nil {
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 11); err != nil {
return err
}
return tx.Commit()
@@ -427,7 +427,7 @@ func sqliteMigrateFrom11(db *sql.DB, _ time.Duration) error {
if _, err := tx.Exec(sqliteMigrate11To12AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 12); err != nil {
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 12); err != nil {
return err
}
return tx.Commit()
@@ -443,7 +443,7 @@ func sqliteMigrateFrom12(db *sql.DB, _ time.Duration) error {
if _, err := tx.Exec(sqliteMigrate12To13AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 13); err != nil {
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 13); err != nil {
return err
}
return tx.Commit()
@@ -459,7 +459,7 @@ func sqliteMigrateFrom13(db *sql.DB, _ time.Duration) error {
if _, err := tx.Exec(sqliteMigrate13To14AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 14); err != nil {
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 14); err != nil {
return err
}
return tx.Commit()

View File

@@ -13,158 +13,6 @@ import (
"heckel.io/ntfy/v2/model"
)
func TestSqliteStore_Messages(t *testing.T) {
testCacheMessages(t, newSqliteTestStore(t))
}
func TestMemStore_Messages(t *testing.T) {
testCacheMessages(t, newMemTestStore(t))
}
func TestSqliteStore_MessagesLock(t *testing.T) {
testCacheMessagesLock(t, newSqliteTestStore(t))
}
func TestMemStore_MessagesLock(t *testing.T) {
testCacheMessagesLock(t, newMemTestStore(t))
}
func TestSqliteStore_MessagesScheduled(t *testing.T) {
testCacheMessagesScheduled(t, newSqliteTestStore(t))
}
func TestMemStore_MessagesScheduled(t *testing.T) {
testCacheMessagesScheduled(t, newMemTestStore(t))
}
func TestSqliteStore_Topics(t *testing.T) {
testCacheTopics(t, newSqliteTestStore(t))
}
func TestMemStore_Topics(t *testing.T) {
testCacheTopics(t, newMemTestStore(t))
}
func TestSqliteStore_MessagesTagsPrioAndTitle(t *testing.T) {
testCacheMessagesTagsPrioAndTitle(t, newSqliteTestStore(t))
}
func TestMemStore_MessagesTagsPrioAndTitle(t *testing.T) {
testCacheMessagesTagsPrioAndTitle(t, newMemTestStore(t))
}
func TestSqliteStore_MessagesSinceID(t *testing.T) {
testCacheMessagesSinceID(t, newSqliteTestStore(t))
}
func TestMemStore_MessagesSinceID(t *testing.T) {
testCacheMessagesSinceID(t, newMemTestStore(t))
}
func TestSqliteStore_Prune(t *testing.T) {
testCachePrune(t, newSqliteTestStore(t))
}
func TestMemStore_Prune(t *testing.T) {
testCachePrune(t, newMemTestStore(t))
}
func TestSqliteStore_Attachments(t *testing.T) {
testCacheAttachments(t, newSqliteTestStore(t))
}
func TestMemStore_Attachments(t *testing.T) {
testCacheAttachments(t, newMemTestStore(t))
}
func TestSqliteStore_AttachmentsExpired(t *testing.T) {
testCacheAttachmentsExpired(t, newSqliteTestStore(t))
}
func TestMemStore_AttachmentsExpired(t *testing.T) {
testCacheAttachmentsExpired(t, newMemTestStore(t))
}
func TestSqliteStore_Sender(t *testing.T) {
testSender(t, newSqliteTestStore(t))
}
func TestMemStore_Sender(t *testing.T) {
testSender(t, newMemTestStore(t))
}
func TestSqliteStore_DeleteScheduledBySequenceID(t *testing.T) {
testDeleteScheduledBySequenceID(t, newSqliteTestStore(t))
}
func TestMemStore_DeleteScheduledBySequenceID(t *testing.T) {
testDeleteScheduledBySequenceID(t, newMemTestStore(t))
}
func TestSqliteStore_MessageByID(t *testing.T) {
testMessageByID(t, newSqliteTestStore(t))
}
func TestMemStore_MessageByID(t *testing.T) {
testMessageByID(t, newMemTestStore(t))
}
func TestSqliteStore_MarkPublished(t *testing.T) {
testMarkPublished(t, newSqliteTestStore(t))
}
func TestMemStore_MarkPublished(t *testing.T) {
testMarkPublished(t, newMemTestStore(t))
}
func TestSqliteStore_ExpireMessages(t *testing.T) {
testExpireMessages(t, newSqliteTestStore(t))
}
func TestMemStore_ExpireMessages(t *testing.T) {
testExpireMessages(t, newMemTestStore(t))
}
func TestSqliteStore_MarkAttachmentsDeleted(t *testing.T) {
testMarkAttachmentsDeleted(t, newSqliteTestStore(t))
}
func TestMemStore_MarkAttachmentsDeleted(t *testing.T) {
testMarkAttachmentsDeleted(t, newMemTestStore(t))
}
func TestSqliteStore_Stats(t *testing.T) {
testStats(t, newSqliteTestStore(t))
}
func TestMemStore_Stats(t *testing.T) {
testStats(t, newMemTestStore(t))
}
func TestSqliteStore_AddMessages(t *testing.T) {
testAddMessages(t, newSqliteTestStore(t))
}
func TestMemStore_AddMessages(t *testing.T) {
testAddMessages(t, newMemTestStore(t))
}
func TestSqliteStore_MessagesDue(t *testing.T) {
testMessagesDue(t, newSqliteTestStore(t))
}
func TestMemStore_MessagesDue(t *testing.T) {
testMessagesDue(t, newMemTestStore(t))
}
func TestSqliteStore_MessageFieldRoundTrip(t *testing.T) {
testMessageFieldRoundTrip(t, newSqliteTestStore(t))
}
func TestMemStore_MessageFieldRoundTrip(t *testing.T) {
testMessageFieldRoundTrip(t, newMemTestStore(t))
}
func TestSqliteStore_Migration_From0(t *testing.T) {
filename := newSqliteTestStoreFile(t)
db, err := sql.Open("sqlite3", filename)
@@ -419,32 +267,17 @@ func TestNopStore(t *testing.T) {
require.Empty(t, topics)
}
func newSqliteTestStore(t *testing.T) message.Store {
filename := filepath.Join(t.TempDir(), "cache.db")
s, err := message.NewSQLiteStore(filename, "", time.Hour, 0, 0, false)
require.Nil(t, err)
t.Cleanup(func() { s.Close() })
return s
}
func newSqliteTestStoreFile(t *testing.T) string {
return filepath.Join(t.TempDir(), "cache.db")
}
func newSqliteTestStoreFromFile(t *testing.T, filename, startupQueries string) message.Store {
func newSqliteTestStoreFromFile(t *testing.T, filename, startupQueries string) *message.Cache {
s, err := message.NewSQLiteStore(filename, startupQueries, time.Hour, 0, 0, false)
require.Nil(t, err)
t.Cleanup(func() { s.Close() })
return s
}
func newMemTestStore(t *testing.T) message.Store {
s, err := message.NewMemStore()
require.Nil(t, err)
t.Cleanup(func() { s.Close() })
return s
}
func checkSqliteSchemaVersion(t *testing.T, filename string) {
db, err := sql.Open("sqlite3", filename)
require.Nil(t, err)

829
message/cache_test.go Normal file
View File

@@ -0,0 +1,829 @@
package message_test
import (
"net/netip"
"path/filepath"
"sort"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
dbtest "heckel.io/ntfy/v2/db/test"
"heckel.io/ntfy/v2/message"
"heckel.io/ntfy/v2/model"
)
func newSqliteTestStore(t *testing.T) *message.Cache {
filename := filepath.Join(t.TempDir(), "cache.db")
s, err := message.NewSQLiteStore(filename, "", time.Hour, 0, 0, false)
require.Nil(t, err)
t.Cleanup(func() { s.Close() })
return s
}
func newMemTestStore(t *testing.T) *message.Cache {
s, err := message.NewMemStore()
require.Nil(t, err)
t.Cleanup(func() { s.Close() })
return s
}
func newTestPostgresStore(t *testing.T) *message.Cache {
testDB := dbtest.CreateTestPostgres(t)
store, err := message.NewPostgresStore(testDB, 0, 0)
require.Nil(t, err)
return store
}
func forEachBackend(t *testing.T, f func(t *testing.T, s *message.Cache)) {
t.Run("sqlite", func(t *testing.T) {
f(t, newSqliteTestStore(t))
})
t.Run("mem", func(t *testing.T) {
f(t, newMemTestStore(t))
})
t.Run("postgres", func(t *testing.T) {
f(t, newTestPostgresStore(t))
})
}
func TestStore_Messages(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
m1 := model.NewDefaultMessage("mytopic", "my message")
m1.Time = 1
m2 := model.NewDefaultMessage("mytopic", "my other message")
m2.Time = 2
require.Nil(t, s.AddMessage(m1))
require.Nil(t, s.AddMessage(model.NewDefaultMessage("example", "my example message")))
require.Nil(t, s.AddMessage(m2))
// Adding invalid
require.Equal(t, model.ErrUnexpectedMessageType, s.AddMessage(model.NewKeepaliveMessage("mytopic"))) // These should not be added!
require.Equal(t, model.ErrUnexpectedMessageType, s.AddMessage(model.NewOpenMessage("example"))) // These should not be added!
// count
count, err := s.MessagesCount()
require.Nil(t, err)
require.Equal(t, 3, count)
// mytopic: since all
messages, _ := s.Messages("mytopic", model.SinceAllMessages, false)
require.Equal(t, 2, len(messages))
require.Equal(t, "my message", messages[0].Message)
require.Equal(t, "mytopic", messages[0].Topic)
require.Equal(t, model.MessageEvent, messages[0].Event)
require.Equal(t, "", messages[0].Title)
require.Equal(t, 0, messages[0].Priority)
require.Nil(t, messages[0].Tags)
require.Equal(t, "my other message", messages[1].Message)
// mytopic: since none
messages, _ = s.Messages("mytopic", model.SinceNoMessages, false)
require.Empty(t, messages)
// mytopic: since m1 (by ID)
messages, _ = s.Messages("mytopic", model.NewSinceID(m1.ID), false)
require.Equal(t, 1, len(messages))
require.Equal(t, m2.ID, messages[0].ID)
require.Equal(t, "my other message", messages[0].Message)
require.Equal(t, "mytopic", messages[0].Topic)
// mytopic: since 2
messages, _ = s.Messages("mytopic", model.NewSinceTime(2), false)
require.Equal(t, 1, len(messages))
require.Equal(t, "my other message", messages[0].Message)
// mytopic: latest
messages, _ = s.Messages("mytopic", model.SinceLatestMessage, false)
require.Equal(t, 1, len(messages))
require.Equal(t, "my other message", messages[0].Message)
// example: since all
messages, _ = s.Messages("example", model.SinceAllMessages, false)
require.Equal(t, "my example message", messages[0].Message)
// non-existing: since all
messages, _ = s.Messages("doesnotexist", model.SinceAllMessages, false)
require.Empty(t, messages)
})
}
func TestStore_MessagesLock(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
var wg sync.WaitGroup
for i := 0; i < 5000; i++ {
wg.Add(1)
go func() {
assert.Nil(t, s.AddMessage(model.NewDefaultMessage("mytopic", "test message")))
wg.Done()
}()
}
wg.Wait()
})
}
func TestStore_MessagesScheduled(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
m1 := model.NewDefaultMessage("mytopic", "message 1")
m2 := model.NewDefaultMessage("mytopic", "message 2")
m2.Time = time.Now().Add(time.Hour).Unix()
m3 := model.NewDefaultMessage("mytopic", "message 3")
m3.Time = time.Now().Add(time.Minute).Unix() // earlier than m2!
m4 := model.NewDefaultMessage("mytopic2", "message 4")
m4.Time = time.Now().Add(time.Minute).Unix()
require.Nil(t, s.AddMessage(m1))
require.Nil(t, s.AddMessage(m2))
require.Nil(t, s.AddMessage(m3))
messages, _ := s.Messages("mytopic", model.SinceAllMessages, false) // exclude scheduled
require.Equal(t, 1, len(messages))
require.Equal(t, "message 1", messages[0].Message)
messages, _ = s.Messages("mytopic", model.SinceAllMessages, true) // include scheduled
require.Equal(t, 3, len(messages))
require.Equal(t, "message 1", messages[0].Message)
require.Equal(t, "message 3", messages[1].Message) // Order!
require.Equal(t, "message 2", messages[2].Message)
messages, _ = s.MessagesDue()
require.Empty(t, messages)
})
}
func TestStore_Topics(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic1", "my example message")))
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 1")))
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 2")))
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 3")))
topics, err := s.Topics()
if err != nil {
t.Fatal(err)
}
require.Equal(t, 2, len(topics))
require.Contains(t, topics, "topic1")
require.Contains(t, topics, "topic2")
})
}
func TestStore_MessagesTagsPrioAndTitle(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
m := model.NewDefaultMessage("mytopic", "some message")
m.Tags = []string{"tag1", "tag2"}
m.Priority = 5
m.Title = "some title"
require.Nil(t, s.AddMessage(m))
messages, _ := s.Messages("mytopic", model.SinceAllMessages, false)
require.Equal(t, []string{"tag1", "tag2"}, messages[0].Tags)
require.Equal(t, 5, messages[0].Priority)
require.Equal(t, "some title", messages[0].Title)
})
}
func TestStore_MessagesSinceID(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
m1 := model.NewDefaultMessage("mytopic", "message 1")
m1.Time = 100
m2 := model.NewDefaultMessage("mytopic", "message 2")
m2.Time = 200
m3 := model.NewDefaultMessage("mytopic", "message 3")
m3.Time = time.Now().Add(time.Hour).Unix() // Scheduled, in the future, later than m7 and m5
m4 := model.NewDefaultMessage("mytopic", "message 4")
m4.Time = 400
m5 := model.NewDefaultMessage("mytopic", "message 5")
m5.Time = time.Now().Add(time.Minute).Unix() // Scheduled, in the future, later than m7
m6 := model.NewDefaultMessage("mytopic", "message 6")
m6.Time = 600
m7 := model.NewDefaultMessage("mytopic", "message 7")
m7.Time = 700
require.Nil(t, s.AddMessage(m1))
require.Nil(t, s.AddMessage(m2))
require.Nil(t, s.AddMessage(m3))
require.Nil(t, s.AddMessage(m4))
require.Nil(t, s.AddMessage(m5))
require.Nil(t, s.AddMessage(m6))
require.Nil(t, s.AddMessage(m7))
// Case 1: Since ID exists, exclude scheduled
messages, _ := s.Messages("mytopic", model.NewSinceID(m2.ID), false)
require.Equal(t, 3, len(messages))
require.Equal(t, "message 4", messages[0].Message)
require.Equal(t, "message 6", messages[1].Message) // Not scheduled m3/m5!
require.Equal(t, "message 7", messages[2].Message)
// Case 2: Since ID exists, include scheduled
messages, _ = s.Messages("mytopic", model.NewSinceID(m2.ID), true)
require.Equal(t, 5, len(messages))
require.Equal(t, "message 4", messages[0].Message)
require.Equal(t, "message 6", messages[1].Message)
require.Equal(t, "message 7", messages[2].Message)
require.Equal(t, "message 5", messages[3].Message) // Order!
require.Equal(t, "message 3", messages[4].Message) // Order!
// Case 3: Since ID does not exist (-> Return all messages), include scheduled
messages, _ = s.Messages("mytopic", model.NewSinceID("doesntexist"), true)
require.Equal(t, 7, len(messages))
require.Equal(t, "message 1", messages[0].Message)
require.Equal(t, "message 2", messages[1].Message)
require.Equal(t, "message 4", messages[2].Message)
require.Equal(t, "message 6", messages[3].Message)
require.Equal(t, "message 7", messages[4].Message)
require.Equal(t, "message 5", messages[5].Message) // Order!
require.Equal(t, "message 3", messages[6].Message) // Order!
// Case 4: Since ID exists and is last message (-> Return no messages), exclude scheduled
messages, _ = s.Messages("mytopic", model.NewSinceID(m7.ID), false)
require.Equal(t, 0, len(messages))
// Case 5: Since ID exists and is last message (-> Return no messages), include scheduled
messages, _ = s.Messages("mytopic", model.NewSinceID(m7.ID), true)
require.Equal(t, 2, len(messages))
require.Equal(t, "message 5", messages[0].Message)
require.Equal(t, "message 3", messages[1].Message)
})
}
func TestStore_Prune(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
now := time.Now().Unix()
m1 := model.NewDefaultMessage("mytopic", "my message")
m1.Time = now - 10
m1.Expires = now - 5
m2 := model.NewDefaultMessage("mytopic", "my other message")
m2.Time = now - 5
m2.Expires = now + 5 // In the future
m3 := model.NewDefaultMessage("another_topic", "and another one")
m3.Time = now - 12
m3.Expires = now - 2
require.Nil(t, s.AddMessage(m1))
require.Nil(t, s.AddMessage(m2))
require.Nil(t, s.AddMessage(m3))
count, err := s.MessagesCount()
require.Nil(t, err)
require.Equal(t, 3, count)
expiredMessageIDs, err := s.MessagesExpired()
require.Nil(t, err)
require.Nil(t, s.DeleteMessages(expiredMessageIDs...))
count, err = s.MessagesCount()
require.Nil(t, err)
require.Equal(t, 1, count)
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "my other message", messages[0].Message)
})
}
func TestStore_Attachments(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
expires1 := time.Now().Add(-4 * time.Hour).Unix() // Expired
m := model.NewDefaultMessage("mytopic", "flower for you")
m.ID = "m1"
m.SequenceID = "m1"
m.Sender = netip.MustParseAddr("1.2.3.4")
m.Attachment = &model.Attachment{
Name: "flower.jpg",
Type: "image/jpeg",
Size: 5000,
Expires: expires1,
URL: "https://ntfy.sh/file/AbDeFgJhal.jpg",
}
require.Nil(t, s.AddMessage(m))
expires2 := time.Now().Add(2 * time.Hour).Unix() // Future
m = model.NewDefaultMessage("mytopic", "sending you a car")
m.ID = "m2"
m.SequenceID = "m2"
m.Sender = netip.MustParseAddr("1.2.3.4")
m.Attachment = &model.Attachment{
Name: "car.jpg",
Type: "image/jpeg",
Size: 10000,
Expires: expires2,
URL: "https://ntfy.sh/file/aCaRURL.jpg",
}
require.Nil(t, s.AddMessage(m))
expires3 := time.Now().Add(1 * time.Hour).Unix() // Future
m = model.NewDefaultMessage("another-topic", "sending you another car")
m.ID = "m3"
m.SequenceID = "m3"
m.User = "u_BAsbaAa"
m.Sender = netip.MustParseAddr("5.6.7.8")
m.Attachment = &model.Attachment{
Name: "another-car.jpg",
Type: "image/jpeg",
Size: 20000,
Expires: expires3,
URL: "https://ntfy.sh/file/zakaDHFW.jpg",
}
require.Nil(t, s.AddMessage(m))
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 2, len(messages))
require.Equal(t, "flower for you", messages[0].Message)
require.Equal(t, "flower.jpg", messages[0].Attachment.Name)
require.Equal(t, "image/jpeg", messages[0].Attachment.Type)
require.Equal(t, int64(5000), messages[0].Attachment.Size)
require.Equal(t, expires1, messages[0].Attachment.Expires)
require.Equal(t, "https://ntfy.sh/file/AbDeFgJhal.jpg", messages[0].Attachment.URL)
require.Equal(t, "1.2.3.4", messages[0].Sender.String())
require.Equal(t, "sending you a car", messages[1].Message)
require.Equal(t, "car.jpg", messages[1].Attachment.Name)
require.Equal(t, "image/jpeg", messages[1].Attachment.Type)
require.Equal(t, int64(10000), messages[1].Attachment.Size)
require.Equal(t, expires2, messages[1].Attachment.Expires)
require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL)
require.Equal(t, "1.2.3.4", messages[1].Sender.String())
size, err := s.AttachmentBytesUsedBySender("1.2.3.4")
require.Nil(t, err)
require.Equal(t, int64(10000), size)
size, err = s.AttachmentBytesUsedBySender("5.6.7.8")
require.Nil(t, err)
require.Equal(t, int64(0), size) // Accounted to the user, not the IP!
size, err = s.AttachmentBytesUsedByUser("u_BAsbaAa")
require.Nil(t, err)
require.Equal(t, int64(20000), size)
})
}
func TestStore_AttachmentsExpired(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
m := model.NewDefaultMessage("mytopic", "flower for you")
m.ID = "m1"
m.SequenceID = "m1"
m.Expires = time.Now().Add(time.Hour).Unix()
require.Nil(t, s.AddMessage(m))
m = model.NewDefaultMessage("mytopic", "message with attachment")
m.ID = "m2"
m.SequenceID = "m2"
m.Expires = time.Now().Add(2 * time.Hour).Unix()
m.Attachment = &model.Attachment{
Name: "car.jpg",
Type: "image/jpeg",
Size: 10000,
Expires: time.Now().Add(2 * time.Hour).Unix(),
URL: "https://ntfy.sh/file/aCaRURL.jpg",
}
require.Nil(t, s.AddMessage(m))
m = model.NewDefaultMessage("mytopic", "message with external attachment")
m.ID = "m3"
m.SequenceID = "m3"
m.Expires = time.Now().Add(2 * time.Hour).Unix()
m.Attachment = &model.Attachment{
Name: "car.jpg",
Type: "image/jpeg",
Expires: 0, // Unknown!
URL: "https://somedomain.com/car.jpg",
}
require.Nil(t, s.AddMessage(m))
m = model.NewDefaultMessage("mytopic2", "message with expired attachment")
m.ID = "m4"
m.SequenceID = "m4"
m.Expires = time.Now().Add(2 * time.Hour).Unix()
m.Attachment = &model.Attachment{
Name: "expired-car.jpg",
Type: "image/jpeg",
Size: 20000,
Expires: time.Now().Add(-1 * time.Hour).Unix(),
URL: "https://ntfy.sh/file/aCaRURL.jpg",
}
require.Nil(t, s.AddMessage(m))
ids, err := s.AttachmentsExpired()
require.Nil(t, err)
require.Equal(t, 1, len(ids))
require.Equal(t, "m4", ids[0])
})
}
func TestStore_Sender(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
m1 := model.NewDefaultMessage("mytopic", "mymessage")
m1.Sender = netip.MustParseAddr("1.2.3.4")
require.Nil(t, s.AddMessage(m1))
m2 := model.NewDefaultMessage("mytopic", "mymessage without sender")
require.Nil(t, s.AddMessage(m2))
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 2, len(messages))
require.Equal(t, messages[0].Sender, netip.MustParseAddr("1.2.3.4"))
require.Equal(t, messages[1].Sender, netip.Addr{})
})
}
func TestStore_DeleteScheduledBySequenceID(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
// Create a scheduled (unpublished) message
scheduledMsg := model.NewDefaultMessage("mytopic", "scheduled message")
scheduledMsg.ID = "scheduled1"
scheduledMsg.SequenceID = "seq123"
scheduledMsg.Time = time.Now().Add(time.Hour).Unix() // Future time makes it scheduled
require.Nil(t, s.AddMessage(scheduledMsg))
// Create a published message with different sequence ID
publishedMsg := model.NewDefaultMessage("mytopic", "published message")
publishedMsg.ID = "published1"
publishedMsg.SequenceID = "seq456"
publishedMsg.Time = time.Now().Add(-time.Hour).Unix() // Past time makes it published
require.Nil(t, s.AddMessage(publishedMsg))
// Create a scheduled message in a different topic
otherTopicMsg := model.NewDefaultMessage("othertopic", "other scheduled")
otherTopicMsg.ID = "other1"
otherTopicMsg.SequenceID = "seq123" // Same sequence ID as scheduledMsg
otherTopicMsg.Time = time.Now().Add(time.Hour).Unix()
require.Nil(t, s.AddMessage(otherTopicMsg))
// Verify all messages exist (including scheduled)
messages, err := s.Messages("mytopic", model.SinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 2, len(messages))
messages, err = s.Messages("othertopic", model.SinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
// Delete scheduled message by sequence ID and verify returned IDs
deletedIDs, err := s.DeleteScheduledBySequenceID("mytopic", "seq123")
require.Nil(t, err)
require.Equal(t, 1, len(deletedIDs))
require.Equal(t, "scheduled1", deletedIDs[0])
// Verify scheduled message is deleted
messages, err = s.Messages("mytopic", model.SinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "published message", messages[0].Message)
// Verify other topic's message still exists (topic-scoped deletion)
messages, err = s.Messages("othertopic", model.SinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "other scheduled", messages[0].Message)
// Deleting non-existent sequence ID should return empty list
deletedIDs, err = s.DeleteScheduledBySequenceID("mytopic", "nonexistent")
require.Nil(t, err)
require.Empty(t, deletedIDs)
// Deleting published message should not affect it (only deletes unpublished)
deletedIDs, err = s.DeleteScheduledBySequenceID("mytopic", "seq456")
require.Nil(t, err)
require.Empty(t, deletedIDs)
messages, err = s.Messages("mytopic", model.SinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "published message", messages[0].Message)
})
}
func TestStore_MessageByID(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
// Add a message
m := model.NewDefaultMessage("mytopic", "some message")
m.Title = "some title"
m.Priority = 4
m.Tags = []string{"tag1", "tag2"}
require.Nil(t, s.AddMessage(m))
// Retrieve by ID
retrieved, err := s.Message(m.ID)
require.Nil(t, err)
require.Equal(t, m.ID, retrieved.ID)
require.Equal(t, "mytopic", retrieved.Topic)
require.Equal(t, "some message", retrieved.Message)
require.Equal(t, "some title", retrieved.Title)
require.Equal(t, 4, retrieved.Priority)
require.Equal(t, []string{"tag1", "tag2"}, retrieved.Tags)
// Non-existent ID returns ErrMessageNotFound
_, err = s.Message("doesnotexist")
require.Equal(t, model.ErrMessageNotFound, err)
})
}
func TestStore_MarkPublished(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
// Add a scheduled message (future time -> unpublished)
m := model.NewDefaultMessage("mytopic", "scheduled message")
m.Time = time.Now().Add(time.Hour).Unix()
require.Nil(t, s.AddMessage(m))
// Verify it does not appear in non-scheduled queries
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 0, len(messages))
// Verify it does appear in scheduled queries
messages, err = s.Messages("mytopic", model.SinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
// Mark as published
require.Nil(t, s.MarkPublished(m))
// Now it should appear in non-scheduled queries too
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "scheduled message", messages[0].Message)
})
}
func TestStore_ExpireMessages(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
// Add messages to two topics
m1 := model.NewDefaultMessage("topic1", "message 1")
m1.Expires = time.Now().Add(time.Hour).Unix()
m2 := model.NewDefaultMessage("topic1", "message 2")
m2.Expires = time.Now().Add(time.Hour).Unix()
m3 := model.NewDefaultMessage("topic2", "message 3")
m3.Expires = time.Now().Add(time.Hour).Unix()
require.Nil(t, s.AddMessage(m1))
require.Nil(t, s.AddMessage(m2))
require.Nil(t, s.AddMessage(m3))
// Verify all messages exist
messages, err := s.Messages("topic1", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 2, len(messages))
messages, err = s.Messages("topic2", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
// Expire topic1 messages
require.Nil(t, s.ExpireMessages("topic1"))
// topic1 messages should now be expired (expires set to past)
expiredIDs, err := s.MessagesExpired()
require.Nil(t, err)
require.Equal(t, 2, len(expiredIDs))
sort.Strings(expiredIDs)
expectedIDs := []string{m1.ID, m2.ID}
sort.Strings(expectedIDs)
require.Equal(t, expectedIDs, expiredIDs)
// topic2 should be unaffected
messages, err = s.Messages("topic2", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "message 3", messages[0].Message)
})
}
func TestStore_MarkAttachmentsDeleted(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
// Add a message with an expired attachment (file needs cleanup)
m1 := model.NewDefaultMessage("mytopic", "old file")
m1.ID = "msg1"
m1.SequenceID = "msg1"
m1.Expires = time.Now().Add(time.Hour).Unix()
m1.Attachment = &model.Attachment{
Name: "old.pdf",
Type: "application/pdf",
Size: 50000,
Expires: time.Now().Add(-time.Hour).Unix(), // Expired
URL: "https://ntfy.sh/file/old.pdf",
}
require.Nil(t, s.AddMessage(m1))
// Add a message with another expired attachment
m2 := model.NewDefaultMessage("mytopic", "another old file")
m2.ID = "msg2"
m2.SequenceID = "msg2"
m2.Expires = time.Now().Add(time.Hour).Unix()
m2.Attachment = &model.Attachment{
Name: "another.pdf",
Type: "application/pdf",
Size: 30000,
Expires: time.Now().Add(-time.Hour).Unix(), // Expired
URL: "https://ntfy.sh/file/another.pdf",
}
require.Nil(t, s.AddMessage(m2))
// Both should show as expired attachments needing cleanup
ids, err := s.AttachmentsExpired()
require.Nil(t, err)
require.Equal(t, 2, len(ids))
// Mark msg1's attachment as deleted (file cleaned up)
require.Nil(t, s.MarkAttachmentsDeleted("msg1"))
// Now only msg2 should show as needing cleanup
ids, err = s.AttachmentsExpired()
require.Nil(t, err)
require.Equal(t, 1, len(ids))
require.Equal(t, "msg2", ids[0])
// Mark msg2 too
require.Nil(t, s.MarkAttachmentsDeleted("msg2"))
// No more expired attachments to clean up
ids, err = s.AttachmentsExpired()
require.Nil(t, err)
require.Equal(t, 0, len(ids))
// Messages themselves still exist
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 2, len(messages))
})
}
func TestStore_Stats(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
// Initial stats should be zero
messages, err := s.Stats()
require.Nil(t, err)
require.Equal(t, int64(0), messages)
// Update stats
require.Nil(t, s.UpdateStats(42))
messages, err = s.Stats()
require.Nil(t, err)
require.Equal(t, int64(42), messages)
// Update again (overwrites)
require.Nil(t, s.UpdateStats(100))
messages, err = s.Stats()
require.Nil(t, err)
require.Equal(t, int64(100), messages)
})
}
func TestStore_AddMessages(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
// Batch add multiple messages
msgs := []*model.Message{
model.NewDefaultMessage("mytopic", "batch 1"),
model.NewDefaultMessage("mytopic", "batch 2"),
model.NewDefaultMessage("othertopic", "batch 3"),
}
require.Nil(t, s.AddMessages(msgs))
// Verify all were inserted
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 2, len(messages))
messages, err = s.Messages("othertopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "batch 3", messages[0].Message)
// Empty batch should succeed
require.Nil(t, s.AddMessages([]*model.Message{}))
// Batch with invalid event type should fail
badMsgs := []*model.Message{
model.NewKeepaliveMessage("mytopic"),
}
require.NotNil(t, s.AddMessages(badMsgs))
})
}
func TestStore_MessagesDue(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
// Add a message scheduled in the past (i.e. it's due now)
m1 := model.NewDefaultMessage("mytopic", "due message")
m1.Time = time.Now().Add(-time.Second).Unix()
// Set expires in the future so it doesn't get pruned
m1.Expires = time.Now().Add(time.Hour).Unix()
require.Nil(t, s.AddMessage(m1))
// Add a message scheduled in the future (not due)
m2 := model.NewDefaultMessage("mytopic", "future message")
m2.Time = time.Now().Add(time.Hour).Unix()
require.Nil(t, s.AddMessage(m2))
// Mark m1 as published so it won't be "due"
// (MessagesDue returns unpublished messages whose time <= now)
// m1 is auto-published (time <= now), so it should not be due
// m2 is unpublished (time in future), not due yet
due, err := s.MessagesDue()
require.Nil(t, err)
require.Equal(t, 0, len(due))
// Add a message that was explicitly scheduled in the past but time has "arrived"
// We need to manipulate the database to create a truly "due" message:
// a message with published=false and time <= now
m3 := model.NewDefaultMessage("mytopic", "truly due message")
m3.Time = time.Now().Add(2 * time.Second).Unix() // 2 seconds from now
require.Nil(t, s.AddMessage(m3))
// Not due yet
due, err = s.MessagesDue()
require.Nil(t, err)
require.Equal(t, 0, len(due))
// Wait for it to become due
time.Sleep(3 * time.Second)
due, err = s.MessagesDue()
require.Nil(t, err)
require.Equal(t, 1, len(due))
require.Equal(t, "truly due message", due[0].Message)
})
}
func TestStore_MessageFieldRoundTrip(t *testing.T) {
forEachBackend(t, func(t *testing.T, s *message.Cache) {
// Create a message with all fields populated
m := model.NewDefaultMessage("mytopic", "hello world")
m.SequenceID = "custom_seq_id"
m.Title = "A Title"
m.Priority = 4
m.Tags = []string{"warning", "srv01"}
m.Click = "https://example.com/click"
m.Icon = "https://example.com/icon.png"
m.Actions = []*model.Action{
{
ID: "action1",
Action: "view",
Label: "Open Site",
URL: "https://example.com",
Clear: true,
},
{
ID: "action2",
Action: "http",
Label: "Call Webhook",
URL: "https://example.com/hook",
Method: "PUT",
Headers: map[string]string{"X-Token": "secret"},
Body: `{"key":"value"}`,
},
}
m.ContentType = "text/markdown"
m.Encoding = "base64"
m.Sender = netip.MustParseAddr("9.8.7.6")
m.User = "u_TestUser123"
require.Nil(t, s.AddMessage(m))
// Retrieve and verify every field
retrieved, err := s.Message(m.ID)
require.Nil(t, err)
require.Equal(t, m.ID, retrieved.ID)
require.Equal(t, "custom_seq_id", retrieved.SequenceID)
require.Equal(t, m.Time, retrieved.Time)
require.Equal(t, m.Expires, retrieved.Expires)
require.Equal(t, model.MessageEvent, retrieved.Event)
require.Equal(t, "mytopic", retrieved.Topic)
require.Equal(t, "hello world", retrieved.Message)
require.Equal(t, "A Title", retrieved.Title)
require.Equal(t, 4, retrieved.Priority)
require.Equal(t, []string{"warning", "srv01"}, retrieved.Tags)
require.Equal(t, "https://example.com/click", retrieved.Click)
require.Equal(t, "https://example.com/icon.png", retrieved.Icon)
require.Equal(t, "text/markdown", retrieved.ContentType)
require.Equal(t, "base64", retrieved.Encoding)
require.Equal(t, netip.MustParseAddr("9.8.7.6"), retrieved.Sender)
require.Equal(t, "u_TestUser123", retrieved.User)
// Verify actions round-trip
require.Equal(t, 2, len(retrieved.Actions))
require.Equal(t, "action1", retrieved.Actions[0].ID)
require.Equal(t, "view", retrieved.Actions[0].Action)
require.Equal(t, "Open Site", retrieved.Actions[0].Label)
require.Equal(t, "https://example.com", retrieved.Actions[0].URL)
require.Equal(t, true, retrieved.Actions[0].Clear)
require.Equal(t, "action2", retrieved.Actions[1].ID)
require.Equal(t, "http", retrieved.Actions[1].Action)
require.Equal(t, "Call Webhook", retrieved.Actions[1].Label)
require.Equal(t, "https://example.com/hook", retrieved.Actions[1].URL)
require.Equal(t, "PUT", retrieved.Actions[1].Method)
require.Equal(t, "secret", retrieved.Actions[1].Headers["X-Token"])
require.Equal(t, `{"key":"value"}`, retrieved.Actions[1].Body)
})
}

View File

@@ -1,120 +0,0 @@
package message
import (
"database/sql"
"time"
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
)
// PostgreSQL runtime query constants
const (
pgInsertMessageQuery = `
INSERT INTO message (mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user_id, content_type, encoding, published)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24)
`
pgDeleteMessageQuery = `DELETE FROM message WHERE mid = $1`
pgSelectScheduledMessageIDsBySeqIDQuery = `SELECT mid FROM message WHERE topic = $1 AND sequence_id = $2 AND published = FALSE`
pgDeleteScheduledBySequenceIDQuery = `DELETE FROM message WHERE topic = $1 AND sequence_id = $2 AND published = FALSE`
pgUpdateMessagesForTopicExpiryQuery = `UPDATE message SET expires = $1 WHERE topic = $2`
pgSelectRowIDFromMessageID = `SELECT id FROM message WHERE mid = $1`
pgSelectMessagesByIDQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
FROM message
WHERE mid = $1
`
pgSelectMessagesSinceTimeQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
FROM message
WHERE topic = $1 AND time >= $2 AND published = TRUE
ORDER BY time, id
`
pgSelectMessagesSinceTimeIncludeScheduledQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
FROM message
WHERE topic = $1 AND time >= $2
ORDER BY time, id
`
pgSelectMessagesSinceIDQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
FROM message
WHERE topic = $1 AND id > $2 AND published = TRUE
ORDER BY time, id
`
pgSelectMessagesSinceIDIncludeScheduledQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
FROM message
WHERE topic = $1 AND (id > $2 OR published = FALSE)
ORDER BY time, id
`
pgSelectMessagesLatestQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
FROM message
WHERE topic = $1 AND published = TRUE
ORDER BY time DESC, id DESC
LIMIT 1
`
pgSelectMessagesDueQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
FROM message
WHERE time <= $1 AND published = FALSE
ORDER BY time, id
`
pgSelectMessagesExpiredQuery = `SELECT mid FROM message WHERE expires <= $1 AND published = TRUE`
pgUpdateMessagePublishedQuery = `UPDATE message SET published = TRUE WHERE mid = $1`
pgSelectMessagesCountQuery = `SELECT COUNT(*) FROM message`
pgSelectMessageCountPerTopicQuery = `SELECT topic, COUNT(*) FROM message GROUP BY topic`
pgSelectTopicsQuery = `SELECT topic FROM message GROUP BY topic`
pgUpdateAttachmentDeleted = `UPDATE message SET attachment_deleted = TRUE WHERE mid = $1`
pgSelectAttachmentsExpiredQuery = `SELECT mid FROM message WHERE attachment_expires > 0 AND attachment_expires <= $1 AND attachment_deleted = FALSE`
pgSelectAttachmentsSizeBySenderQuery = `SELECT COALESCE(SUM(attachment_size), 0) FROM message WHERE user_id = '' AND sender = $1 AND attachment_expires >= $2`
pgSelectAttachmentsSizeByUserIDQuery = `SELECT COALESCE(SUM(attachment_size), 0) FROM message WHERE user_id = $1 AND attachment_expires >= $2`
pgSelectStatsQuery = `SELECT value FROM message_stats WHERE key = 'messages'`
pgUpdateStatsQuery = `UPDATE message_stats SET value = $1 WHERE key = 'messages'`
pgUpdateMessageTimesQuery = `UPDATE message SET time = $1 WHERE mid = $2`
)
var pgQueries = storeQueries{
insertMessage: pgInsertMessageQuery,
deleteMessage: pgDeleteMessageQuery,
selectScheduledMessageIDsBySeqID: pgSelectScheduledMessageIDsBySeqIDQuery,
deleteScheduledBySequenceID: pgDeleteScheduledBySequenceIDQuery,
updateMessagesForTopicExpiry: pgUpdateMessagesForTopicExpiryQuery,
selectRowIDFromMessageID: pgSelectRowIDFromMessageID,
selectMessagesByID: pgSelectMessagesByIDQuery,
selectMessagesSinceTime: pgSelectMessagesSinceTimeQuery,
selectMessagesSinceTimeScheduled: pgSelectMessagesSinceTimeIncludeScheduledQuery,
selectMessagesSinceID: pgSelectMessagesSinceIDQuery,
selectMessagesSinceIDScheduled: pgSelectMessagesSinceIDIncludeScheduledQuery,
selectMessagesLatest: pgSelectMessagesLatestQuery,
selectMessagesDue: pgSelectMessagesDueQuery,
selectMessagesExpired: pgSelectMessagesExpiredQuery,
updateMessagePublished: pgUpdateMessagePublishedQuery,
selectMessagesCount: pgSelectMessagesCountQuery,
selectMessageCountPerTopic: pgSelectMessageCountPerTopicQuery,
selectTopics: pgSelectTopicsQuery,
updateAttachmentDeleted: pgUpdateAttachmentDeleted,
selectAttachmentsExpired: pgSelectAttachmentsExpiredQuery,
selectAttachmentsSizeBySender: pgSelectAttachmentsSizeBySenderQuery,
selectAttachmentsSizeByUserID: pgSelectAttachmentsSizeByUserIDQuery,
selectStats: pgSelectStatsQuery,
updateStats: pgUpdateStatsQuery,
updateMessageTime: pgUpdateMessageTimesQuery,
}
// NewPostgresStore creates a new PostgreSQL-backed message cache store.
func NewPostgresStore(dsn string, batchSize int, batchTimeout time.Duration) (Store, error) {
db, err := sql.Open("pgx", dsn)
if err != nil {
return nil, err
}
if err := db.Ping(); err != nil {
return nil, err
}
if err := setupPostgresDB(db); err != nil {
return nil, err
}
return newCommonStore(db, pgQueries, batchSize, batchTimeout, false), nil
}

View File

@@ -1,120 +0,0 @@
package message_test
import (
"database/sql"
"fmt"
"net/url"
"os"
"testing"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/message"
"heckel.io/ntfy/v2/util"
)
func newTestPostgresStore(t *testing.T) message.Store {
dsn := os.Getenv("NTFY_TEST_DATABASE_URL")
if dsn == "" {
t.Skip("NTFY_TEST_DATABASE_URL not set, skipping PostgreSQL tests")
}
// Create a unique schema for this test
schema := fmt.Sprintf("test_%s", util.RandomString(10))
setupDB, err := sql.Open("pgx", dsn)
require.Nil(t, err)
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
require.Nil(t, err)
require.Nil(t, setupDB.Close())
// Open store with search_path set to the new schema
u, err := url.Parse(dsn)
require.Nil(t, err)
q := u.Query()
q.Set("search_path", schema)
u.RawQuery = q.Encode()
store, err := message.NewPostgresStore(u.String(), 0, 0)
require.Nil(t, err)
t.Cleanup(func() {
store.Close()
cleanDB, err := sql.Open("pgx", dsn)
if err == nil {
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
cleanDB.Close()
}
})
return store
}
func TestPostgresStore_Messages(t *testing.T) {
testCacheMessages(t, newTestPostgresStore(t))
}
func TestPostgresStore_MessagesLock(t *testing.T) {
testCacheMessagesLock(t, newTestPostgresStore(t))
}
func TestPostgresStore_MessagesScheduled(t *testing.T) {
testCacheMessagesScheduled(t, newTestPostgresStore(t))
}
func TestPostgresStore_Topics(t *testing.T) {
testCacheTopics(t, newTestPostgresStore(t))
}
func TestPostgresStore_MessagesTagsPrioAndTitle(t *testing.T) {
testCacheMessagesTagsPrioAndTitle(t, newTestPostgresStore(t))
}
func TestPostgresStore_MessagesSinceID(t *testing.T) {
testCacheMessagesSinceID(t, newTestPostgresStore(t))
}
func TestPostgresStore_Prune(t *testing.T) {
testCachePrune(t, newTestPostgresStore(t))
}
func TestPostgresStore_Attachments(t *testing.T) {
testCacheAttachments(t, newTestPostgresStore(t))
}
func TestPostgresStore_AttachmentsExpired(t *testing.T) {
testCacheAttachmentsExpired(t, newTestPostgresStore(t))
}
func TestPostgresStore_Sender(t *testing.T) {
testSender(t, newTestPostgresStore(t))
}
func TestPostgresStore_DeleteScheduledBySequenceID(t *testing.T) {
testDeleteScheduledBySequenceID(t, newTestPostgresStore(t))
}
func TestPostgresStore_MessageByID(t *testing.T) {
testMessageByID(t, newTestPostgresStore(t))
}
func TestPostgresStore_MarkPublished(t *testing.T) {
testMarkPublished(t, newTestPostgresStore(t))
}
func TestPostgresStore_ExpireMessages(t *testing.T) {
testExpireMessages(t, newTestPostgresStore(t))
}
func TestPostgresStore_MarkAttachmentsDeleted(t *testing.T) {
testMarkAttachmentsDeleted(t, newTestPostgresStore(t))
}
func TestPostgresStore_Stats(t *testing.T) {
testStats(t, newTestPostgresStore(t))
}
func TestPostgresStore_AddMessages(t *testing.T) {
testAddMessages(t, newTestPostgresStore(t))
}
func TestPostgresStore_MessagesDue(t *testing.T) {
testMessagesDue(t, newTestPostgresStore(t))
}
func TestPostgresStore_MessageFieldRoundTrip(t *testing.T) {
testMessageFieldRoundTrip(t, newTestPostgresStore(t))
}

View File

@@ -1,767 +0,0 @@
package message_test
import (
"net/netip"
"sort"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/message"
"heckel.io/ntfy/v2/model"
)
func testCacheMessages(t *testing.T, s message.Store) {
m1 := model.NewDefaultMessage("mytopic", "my message")
m1.Time = 1
m2 := model.NewDefaultMessage("mytopic", "my other message")
m2.Time = 2
require.Nil(t, s.AddMessage(m1))
require.Nil(t, s.AddMessage(model.NewDefaultMessage("example", "my example message")))
require.Nil(t, s.AddMessage(m2))
// Adding invalid
require.Equal(t, model.ErrUnexpectedMessageType, s.AddMessage(model.NewKeepaliveMessage("mytopic"))) // These should not be added!
require.Equal(t, model.ErrUnexpectedMessageType, s.AddMessage(model.NewOpenMessage("example"))) // These should not be added!
// mytopic: count
counts, err := s.MessageCounts()
require.Nil(t, err)
require.Equal(t, 2, counts["mytopic"])
// mytopic: since all
messages, _ := s.Messages("mytopic", model.SinceAllMessages, false)
require.Equal(t, 2, len(messages))
require.Equal(t, "my message", messages[0].Message)
require.Equal(t, "mytopic", messages[0].Topic)
require.Equal(t, model.MessageEvent, messages[0].Event)
require.Equal(t, "", messages[0].Title)
require.Equal(t, 0, messages[0].Priority)
require.Nil(t, messages[0].Tags)
require.Equal(t, "my other message", messages[1].Message)
// mytopic: since none
messages, _ = s.Messages("mytopic", model.SinceNoMessages, false)
require.Empty(t, messages)
// mytopic: since m1 (by ID)
messages, _ = s.Messages("mytopic", model.NewSinceID(m1.ID), false)
require.Equal(t, 1, len(messages))
require.Equal(t, m2.ID, messages[0].ID)
require.Equal(t, "my other message", messages[0].Message)
require.Equal(t, "mytopic", messages[0].Topic)
// mytopic: since 2
messages, _ = s.Messages("mytopic", model.NewSinceTime(2), false)
require.Equal(t, 1, len(messages))
require.Equal(t, "my other message", messages[0].Message)
// mytopic: latest
messages, _ = s.Messages("mytopic", model.SinceLatestMessage, false)
require.Equal(t, 1, len(messages))
require.Equal(t, "my other message", messages[0].Message)
// example: count
counts, err = s.MessageCounts()
require.Nil(t, err)
require.Equal(t, 1, counts["example"])
// example: since all
messages, _ = s.Messages("example", model.SinceAllMessages, false)
require.Equal(t, "my example message", messages[0].Message)
// non-existing: count
counts, err = s.MessageCounts()
require.Nil(t, err)
require.Equal(t, 0, counts["doesnotexist"])
// non-existing: since all
messages, _ = s.Messages("doesnotexist", model.SinceAllMessages, false)
require.Empty(t, messages)
}
func testCacheMessagesLock(t *testing.T, s message.Store) {
var wg sync.WaitGroup
for i := 0; i < 5000; i++ {
wg.Add(1)
go func() {
assert.Nil(t, s.AddMessage(model.NewDefaultMessage("mytopic", "test message")))
wg.Done()
}()
}
wg.Wait()
}
func testCacheMessagesScheduled(t *testing.T, s message.Store) {
m1 := model.NewDefaultMessage("mytopic", "message 1")
m2 := model.NewDefaultMessage("mytopic", "message 2")
m2.Time = time.Now().Add(time.Hour).Unix()
m3 := model.NewDefaultMessage("mytopic", "message 3")
m3.Time = time.Now().Add(time.Minute).Unix() // earlier than m2!
m4 := model.NewDefaultMessage("mytopic2", "message 4")
m4.Time = time.Now().Add(time.Minute).Unix()
require.Nil(t, s.AddMessage(m1))
require.Nil(t, s.AddMessage(m2))
require.Nil(t, s.AddMessage(m3))
messages, _ := s.Messages("mytopic", model.SinceAllMessages, false) // exclude scheduled
require.Equal(t, 1, len(messages))
require.Equal(t, "message 1", messages[0].Message)
messages, _ = s.Messages("mytopic", model.SinceAllMessages, true) // include scheduled
require.Equal(t, 3, len(messages))
require.Equal(t, "message 1", messages[0].Message)
require.Equal(t, "message 3", messages[1].Message) // Order!
require.Equal(t, "message 2", messages[2].Message)
messages, _ = s.MessagesDue()
require.Empty(t, messages)
}
func testCacheTopics(t *testing.T, s message.Store) {
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic1", "my example message")))
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 1")))
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 2")))
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 3")))
topics, err := s.Topics()
if err != nil {
t.Fatal(err)
}
require.Equal(t, 2, len(topics))
require.Contains(t, topics, "topic1")
require.Contains(t, topics, "topic2")
}
func testCacheMessagesTagsPrioAndTitle(t *testing.T, s message.Store) {
m := model.NewDefaultMessage("mytopic", "some message")
m.Tags = []string{"tag1", "tag2"}
m.Priority = 5
m.Title = "some title"
require.Nil(t, s.AddMessage(m))
messages, _ := s.Messages("mytopic", model.SinceAllMessages, false)
require.Equal(t, []string{"tag1", "tag2"}, messages[0].Tags)
require.Equal(t, 5, messages[0].Priority)
require.Equal(t, "some title", messages[0].Title)
}
func testCacheMessagesSinceID(t *testing.T, s message.Store) {
m1 := model.NewDefaultMessage("mytopic", "message 1")
m1.Time = 100
m2 := model.NewDefaultMessage("mytopic", "message 2")
m2.Time = 200
m3 := model.NewDefaultMessage("mytopic", "message 3")
m3.Time = time.Now().Add(time.Hour).Unix() // Scheduled, in the future, later than m7 and m5
m4 := model.NewDefaultMessage("mytopic", "message 4")
m4.Time = 400
m5 := model.NewDefaultMessage("mytopic", "message 5")
m5.Time = time.Now().Add(time.Minute).Unix() // Scheduled, in the future, later than m7
m6 := model.NewDefaultMessage("mytopic", "message 6")
m6.Time = 600
m7 := model.NewDefaultMessage("mytopic", "message 7")
m7.Time = 700
require.Nil(t, s.AddMessage(m1))
require.Nil(t, s.AddMessage(m2))
require.Nil(t, s.AddMessage(m3))
require.Nil(t, s.AddMessage(m4))
require.Nil(t, s.AddMessage(m5))
require.Nil(t, s.AddMessage(m6))
require.Nil(t, s.AddMessage(m7))
// Case 1: Since ID exists, exclude scheduled
messages, _ := s.Messages("mytopic", model.NewSinceID(m2.ID), false)
require.Equal(t, 3, len(messages))
require.Equal(t, "message 4", messages[0].Message)
require.Equal(t, "message 6", messages[1].Message) // Not scheduled m3/m5!
require.Equal(t, "message 7", messages[2].Message)
// Case 2: Since ID exists, include scheduled
messages, _ = s.Messages("mytopic", model.NewSinceID(m2.ID), true)
require.Equal(t, 5, len(messages))
require.Equal(t, "message 4", messages[0].Message)
require.Equal(t, "message 6", messages[1].Message)
require.Equal(t, "message 7", messages[2].Message)
require.Equal(t, "message 5", messages[3].Message) // Order!
require.Equal(t, "message 3", messages[4].Message) // Order!
// Case 3: Since ID does not exist (-> Return all messages), include scheduled
messages, _ = s.Messages("mytopic", model.NewSinceID("doesntexist"), true)
require.Equal(t, 7, len(messages))
require.Equal(t, "message 1", messages[0].Message)
require.Equal(t, "message 2", messages[1].Message)
require.Equal(t, "message 4", messages[2].Message)
require.Equal(t, "message 6", messages[3].Message)
require.Equal(t, "message 7", messages[4].Message)
require.Equal(t, "message 5", messages[5].Message) // Order!
require.Equal(t, "message 3", messages[6].Message) // Order!
// Case 4: Since ID exists and is last message (-> Return no messages), exclude scheduled
messages, _ = s.Messages("mytopic", model.NewSinceID(m7.ID), false)
require.Equal(t, 0, len(messages))
// Case 5: Since ID exists and is last message (-> Return no messages), include scheduled
messages, _ = s.Messages("mytopic", model.NewSinceID(m7.ID), true)
require.Equal(t, 2, len(messages))
require.Equal(t, "message 5", messages[0].Message)
require.Equal(t, "message 3", messages[1].Message)
}
func testCachePrune(t *testing.T, s message.Store) {
now := time.Now().Unix()
m1 := model.NewDefaultMessage("mytopic", "my message")
m1.Time = now - 10
m1.Expires = now - 5
m2 := model.NewDefaultMessage("mytopic", "my other message")
m2.Time = now - 5
m2.Expires = now + 5 // In the future
m3 := model.NewDefaultMessage("another_topic", "and another one")
m3.Time = now - 12
m3.Expires = now - 2
require.Nil(t, s.AddMessage(m1))
require.Nil(t, s.AddMessage(m2))
require.Nil(t, s.AddMessage(m3))
counts, err := s.MessageCounts()
require.Nil(t, err)
require.Equal(t, 2, counts["mytopic"])
require.Equal(t, 1, counts["another_topic"])
expiredMessageIDs, err := s.MessagesExpired()
require.Nil(t, err)
require.Nil(t, s.DeleteMessages(expiredMessageIDs...))
counts, err = s.MessageCounts()
require.Nil(t, err)
require.Equal(t, 1, counts["mytopic"])
require.Equal(t, 0, counts["another_topic"])
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "my other message", messages[0].Message)
}
func testCacheAttachments(t *testing.T, s message.Store) {
expires1 := time.Now().Add(-4 * time.Hour).Unix() // Expired
m := model.NewDefaultMessage("mytopic", "flower for you")
m.ID = "m1"
m.SequenceID = "m1"
m.Sender = netip.MustParseAddr("1.2.3.4")
m.Attachment = &model.Attachment{
Name: "flower.jpg",
Type: "image/jpeg",
Size: 5000,
Expires: expires1,
URL: "https://ntfy.sh/file/AbDeFgJhal.jpg",
}
require.Nil(t, s.AddMessage(m))
expires2 := time.Now().Add(2 * time.Hour).Unix() // Future
m = model.NewDefaultMessage("mytopic", "sending you a car")
m.ID = "m2"
m.SequenceID = "m2"
m.Sender = netip.MustParseAddr("1.2.3.4")
m.Attachment = &model.Attachment{
Name: "car.jpg",
Type: "image/jpeg",
Size: 10000,
Expires: expires2,
URL: "https://ntfy.sh/file/aCaRURL.jpg",
}
require.Nil(t, s.AddMessage(m))
expires3 := time.Now().Add(1 * time.Hour).Unix() // Future
m = model.NewDefaultMessage("another-topic", "sending you another car")
m.ID = "m3"
m.SequenceID = "m3"
m.User = "u_BAsbaAa"
m.Sender = netip.MustParseAddr("5.6.7.8")
m.Attachment = &model.Attachment{
Name: "another-car.jpg",
Type: "image/jpeg",
Size: 20000,
Expires: expires3,
URL: "https://ntfy.sh/file/zakaDHFW.jpg",
}
require.Nil(t, s.AddMessage(m))
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 2, len(messages))
require.Equal(t, "flower for you", messages[0].Message)
require.Equal(t, "flower.jpg", messages[0].Attachment.Name)
require.Equal(t, "image/jpeg", messages[0].Attachment.Type)
require.Equal(t, int64(5000), messages[0].Attachment.Size)
require.Equal(t, expires1, messages[0].Attachment.Expires)
require.Equal(t, "https://ntfy.sh/file/AbDeFgJhal.jpg", messages[0].Attachment.URL)
require.Equal(t, "1.2.3.4", messages[0].Sender.String())
require.Equal(t, "sending you a car", messages[1].Message)
require.Equal(t, "car.jpg", messages[1].Attachment.Name)
require.Equal(t, "image/jpeg", messages[1].Attachment.Type)
require.Equal(t, int64(10000), messages[1].Attachment.Size)
require.Equal(t, expires2, messages[1].Attachment.Expires)
require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL)
require.Equal(t, "1.2.3.4", messages[1].Sender.String())
size, err := s.AttachmentBytesUsedBySender("1.2.3.4")
require.Nil(t, err)
require.Equal(t, int64(10000), size)
size, err = s.AttachmentBytesUsedBySender("5.6.7.8")
require.Nil(t, err)
require.Equal(t, int64(0), size) // Accounted to the user, not the IP!
size, err = s.AttachmentBytesUsedByUser("u_BAsbaAa")
require.Nil(t, err)
require.Equal(t, int64(20000), size)
}
func testCacheAttachmentsExpired(t *testing.T, s message.Store) {
m := model.NewDefaultMessage("mytopic", "flower for you")
m.ID = "m1"
m.SequenceID = "m1"
m.Expires = time.Now().Add(time.Hour).Unix()
require.Nil(t, s.AddMessage(m))
m = model.NewDefaultMessage("mytopic", "message with attachment")
m.ID = "m2"
m.SequenceID = "m2"
m.Expires = time.Now().Add(2 * time.Hour).Unix()
m.Attachment = &model.Attachment{
Name: "car.jpg",
Type: "image/jpeg",
Size: 10000,
Expires: time.Now().Add(2 * time.Hour).Unix(),
URL: "https://ntfy.sh/file/aCaRURL.jpg",
}
require.Nil(t, s.AddMessage(m))
m = model.NewDefaultMessage("mytopic", "message with external attachment")
m.ID = "m3"
m.SequenceID = "m3"
m.Expires = time.Now().Add(2 * time.Hour).Unix()
m.Attachment = &model.Attachment{
Name: "car.jpg",
Type: "image/jpeg",
Expires: 0, // Unknown!
URL: "https://somedomain.com/car.jpg",
}
require.Nil(t, s.AddMessage(m))
m = model.NewDefaultMessage("mytopic2", "message with expired attachment")
m.ID = "m4"
m.SequenceID = "m4"
m.Expires = time.Now().Add(2 * time.Hour).Unix()
m.Attachment = &model.Attachment{
Name: "expired-car.jpg",
Type: "image/jpeg",
Size: 20000,
Expires: time.Now().Add(-1 * time.Hour).Unix(),
URL: "https://ntfy.sh/file/aCaRURL.jpg",
}
require.Nil(t, s.AddMessage(m))
ids, err := s.AttachmentsExpired()
require.Nil(t, err)
require.Equal(t, 1, len(ids))
require.Equal(t, "m4", ids[0])
}
func testSender(t *testing.T, s message.Store) {
m1 := model.NewDefaultMessage("mytopic", "mymessage")
m1.Sender = netip.MustParseAddr("1.2.3.4")
require.Nil(t, s.AddMessage(m1))
m2 := model.NewDefaultMessage("mytopic", "mymessage without sender")
require.Nil(t, s.AddMessage(m2))
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 2, len(messages))
require.Equal(t, messages[0].Sender, netip.MustParseAddr("1.2.3.4"))
require.Equal(t, messages[1].Sender, netip.Addr{})
}
func testDeleteScheduledBySequenceID(t *testing.T, s message.Store) {
// Create a scheduled (unpublished) message
scheduledMsg := model.NewDefaultMessage("mytopic", "scheduled message")
scheduledMsg.ID = "scheduled1"
scheduledMsg.SequenceID = "seq123"
scheduledMsg.Time = time.Now().Add(time.Hour).Unix() // Future time makes it scheduled
require.Nil(t, s.AddMessage(scheduledMsg))
// Create a published message with different sequence ID
publishedMsg := model.NewDefaultMessage("mytopic", "published message")
publishedMsg.ID = "published1"
publishedMsg.SequenceID = "seq456"
publishedMsg.Time = time.Now().Add(-time.Hour).Unix() // Past time makes it published
require.Nil(t, s.AddMessage(publishedMsg))
// Create a scheduled message in a different topic
otherTopicMsg := model.NewDefaultMessage("othertopic", "other scheduled")
otherTopicMsg.ID = "other1"
otherTopicMsg.SequenceID = "seq123" // Same sequence ID as scheduledMsg
otherTopicMsg.Time = time.Now().Add(time.Hour).Unix()
require.Nil(t, s.AddMessage(otherTopicMsg))
// Verify all messages exist (including scheduled)
messages, err := s.Messages("mytopic", model.SinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 2, len(messages))
messages, err = s.Messages("othertopic", model.SinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
// Delete scheduled message by sequence ID and verify returned IDs
deletedIDs, err := s.DeleteScheduledBySequenceID("mytopic", "seq123")
require.Nil(t, err)
require.Equal(t, 1, len(deletedIDs))
require.Equal(t, "scheduled1", deletedIDs[0])
// Verify scheduled message is deleted
messages, err = s.Messages("mytopic", model.SinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "published message", messages[0].Message)
// Verify other topic's message still exists (topic-scoped deletion)
messages, err = s.Messages("othertopic", model.SinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "other scheduled", messages[0].Message)
// Deleting non-existent sequence ID should return empty list
deletedIDs, err = s.DeleteScheduledBySequenceID("mytopic", "nonexistent")
require.Nil(t, err)
require.Empty(t, deletedIDs)
// Deleting published message should not affect it (only deletes unpublished)
deletedIDs, err = s.DeleteScheduledBySequenceID("mytopic", "seq456")
require.Nil(t, err)
require.Empty(t, deletedIDs)
messages, err = s.Messages("mytopic", model.SinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "published message", messages[0].Message)
}
func testMessageByID(t *testing.T, s message.Store) {
// Add a message
m := model.NewDefaultMessage("mytopic", "some message")
m.Title = "some title"
m.Priority = 4
m.Tags = []string{"tag1", "tag2"}
require.Nil(t, s.AddMessage(m))
// Retrieve by ID
retrieved, err := s.Message(m.ID)
require.Nil(t, err)
require.Equal(t, m.ID, retrieved.ID)
require.Equal(t, "mytopic", retrieved.Topic)
require.Equal(t, "some message", retrieved.Message)
require.Equal(t, "some title", retrieved.Title)
require.Equal(t, 4, retrieved.Priority)
require.Equal(t, []string{"tag1", "tag2"}, retrieved.Tags)
// Non-existent ID returns ErrMessageNotFound
_, err = s.Message("doesnotexist")
require.Equal(t, model.ErrMessageNotFound, err)
}
func testMarkPublished(t *testing.T, s message.Store) {
// Add a scheduled message (future time → unpublished)
m := model.NewDefaultMessage("mytopic", "scheduled message")
m.Time = time.Now().Add(time.Hour).Unix()
require.Nil(t, s.AddMessage(m))
// Verify it does not appear in non-scheduled queries
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 0, len(messages))
// Verify it does appear in scheduled queries
messages, err = s.Messages("mytopic", model.SinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
// Mark as published
require.Nil(t, s.MarkPublished(m))
// Now it should appear in non-scheduled queries too
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "scheduled message", messages[0].Message)
}
func testExpireMessages(t *testing.T, s message.Store) {
// Add messages to two topics
m1 := model.NewDefaultMessage("topic1", "message 1")
m1.Expires = time.Now().Add(time.Hour).Unix()
m2 := model.NewDefaultMessage("topic1", "message 2")
m2.Expires = time.Now().Add(time.Hour).Unix()
m3 := model.NewDefaultMessage("topic2", "message 3")
m3.Expires = time.Now().Add(time.Hour).Unix()
require.Nil(t, s.AddMessage(m1))
require.Nil(t, s.AddMessage(m2))
require.Nil(t, s.AddMessage(m3))
// Verify all messages exist
messages, err := s.Messages("topic1", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 2, len(messages))
messages, err = s.Messages("topic2", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
// Expire topic1 messages
require.Nil(t, s.ExpireMessages("topic1"))
// topic1 messages should now be expired (expires set to past)
expiredIDs, err := s.MessagesExpired()
require.Nil(t, err)
require.Equal(t, 2, len(expiredIDs))
sort.Strings(expiredIDs)
expectedIDs := []string{m1.ID, m2.ID}
sort.Strings(expectedIDs)
require.Equal(t, expectedIDs, expiredIDs)
// topic2 should be unaffected
messages, err = s.Messages("topic2", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "message 3", messages[0].Message)
}
func testMarkAttachmentsDeleted(t *testing.T, s message.Store) {
// Add a message with an expired attachment (file needs cleanup)
m1 := model.NewDefaultMessage("mytopic", "old file")
m1.ID = "msg1"
m1.SequenceID = "msg1"
m1.Expires = time.Now().Add(time.Hour).Unix()
m1.Attachment = &model.Attachment{
Name: "old.pdf",
Type: "application/pdf",
Size: 50000,
Expires: time.Now().Add(-time.Hour).Unix(), // Expired
URL: "https://ntfy.sh/file/old.pdf",
}
require.Nil(t, s.AddMessage(m1))
// Add a message with another expired attachment
m2 := model.NewDefaultMessage("mytopic", "another old file")
m2.ID = "msg2"
m2.SequenceID = "msg2"
m2.Expires = time.Now().Add(time.Hour).Unix()
m2.Attachment = &model.Attachment{
Name: "another.pdf",
Type: "application/pdf",
Size: 30000,
Expires: time.Now().Add(-time.Hour).Unix(), // Expired
URL: "https://ntfy.sh/file/another.pdf",
}
require.Nil(t, s.AddMessage(m2))
// Both should show as expired attachments needing cleanup
ids, err := s.AttachmentsExpired()
require.Nil(t, err)
require.Equal(t, 2, len(ids))
// Mark msg1's attachment as deleted (file cleaned up)
require.Nil(t, s.MarkAttachmentsDeleted("msg1"))
// Now only msg2 should show as needing cleanup
ids, err = s.AttachmentsExpired()
require.Nil(t, err)
require.Equal(t, 1, len(ids))
require.Equal(t, "msg2", ids[0])
// Mark msg2 too
require.Nil(t, s.MarkAttachmentsDeleted("msg2"))
// No more expired attachments to clean up
ids, err = s.AttachmentsExpired()
require.Nil(t, err)
require.Equal(t, 0, len(ids))
// Messages themselves still exist
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 2, len(messages))
}
func testStats(t *testing.T, s message.Store) {
// Initial stats should be zero
messages, err := s.Stats()
require.Nil(t, err)
require.Equal(t, int64(0), messages)
// Update stats
require.Nil(t, s.UpdateStats(42))
messages, err = s.Stats()
require.Nil(t, err)
require.Equal(t, int64(42), messages)
// Update again (overwrites)
require.Nil(t, s.UpdateStats(100))
messages, err = s.Stats()
require.Nil(t, err)
require.Equal(t, int64(100), messages)
}
func testAddMessages(t *testing.T, s message.Store) {
// Batch add multiple messages
msgs := []*model.Message{
model.NewDefaultMessage("mytopic", "batch 1"),
model.NewDefaultMessage("mytopic", "batch 2"),
model.NewDefaultMessage("othertopic", "batch 3"),
}
require.Nil(t, s.AddMessages(msgs))
// Verify all were inserted
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 2, len(messages))
messages, err = s.Messages("othertopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "batch 3", messages[0].Message)
// Empty batch should succeed
require.Nil(t, s.AddMessages([]*model.Message{}))
// Batch with invalid event type should fail
badMsgs := []*model.Message{
model.NewKeepaliveMessage("mytopic"),
}
require.NotNil(t, s.AddMessages(badMsgs))
}
func testMessagesDue(t *testing.T, s message.Store) {
// Add a message scheduled in the past (i.e. it's due now)
m1 := model.NewDefaultMessage("mytopic", "due message")
m1.Time = time.Now().Add(-time.Second).Unix()
// Set expires in the future so it doesn't get pruned
m1.Expires = time.Now().Add(time.Hour).Unix()
require.Nil(t, s.AddMessage(m1))
// Add a message scheduled in the future (not due)
m2 := model.NewDefaultMessage("mytopic", "future message")
m2.Time = time.Now().Add(time.Hour).Unix()
require.Nil(t, s.AddMessage(m2))
// Mark m1 as published so it won't be "due"
// (MessagesDue returns unpublished messages whose time <= now)
// m1 is auto-published (time <= now), so it should not be due
// m2 is unpublished (time in future), not due yet
due, err := s.MessagesDue()
require.Nil(t, err)
require.Equal(t, 0, len(due))
// Add a message that was explicitly scheduled in the past but time has "arrived"
// We need to manipulate the database to create a truly "due" message:
// a message with published=false and time <= now
m3 := model.NewDefaultMessage("mytopic", "truly due message")
m3.Time = time.Now().Add(2 * time.Second).Unix() // 2 seconds from now
require.Nil(t, s.AddMessage(m3))
// Not due yet
due, err = s.MessagesDue()
require.Nil(t, err)
require.Equal(t, 0, len(due))
// Wait for it to become due
time.Sleep(3 * time.Second)
due, err = s.MessagesDue()
require.Nil(t, err)
require.Equal(t, 1, len(due))
require.Equal(t, "truly due message", due[0].Message)
}
func testMessageFieldRoundTrip(t *testing.T, s message.Store) {
// Create a message with all fields populated
m := model.NewDefaultMessage("mytopic", "hello world")
m.SequenceID = "custom_seq_id"
m.Title = "A Title"
m.Priority = 4
m.Tags = []string{"warning", "srv01"}
m.Click = "https://example.com/click"
m.Icon = "https://example.com/icon.png"
m.Actions = []*model.Action{
{
ID: "action1",
Action: "view",
Label: "Open Site",
URL: "https://example.com",
Clear: true,
},
{
ID: "action2",
Action: "http",
Label: "Call Webhook",
URL: "https://example.com/hook",
Method: "PUT",
Headers: map[string]string{"X-Token": "secret"},
Body: `{"key":"value"}`,
},
}
m.ContentType = "text/markdown"
m.Encoding = "base64"
m.Sender = netip.MustParseAddr("9.8.7.6")
m.User = "u_TestUser123"
require.Nil(t, s.AddMessage(m))
// Retrieve and verify every field
retrieved, err := s.Message(m.ID)
require.Nil(t, err)
require.Equal(t, m.ID, retrieved.ID)
require.Equal(t, "custom_seq_id", retrieved.SequenceID)
require.Equal(t, m.Time, retrieved.Time)
require.Equal(t, m.Expires, retrieved.Expires)
require.Equal(t, model.MessageEvent, retrieved.Event)
require.Equal(t, "mytopic", retrieved.Topic)
require.Equal(t, "hello world", retrieved.Message)
require.Equal(t, "A Title", retrieved.Title)
require.Equal(t, 4, retrieved.Priority)
require.Equal(t, []string{"warning", "srv01"}, retrieved.Tags)
require.Equal(t, "https://example.com/click", retrieved.Click)
require.Equal(t, "https://example.com/icon.png", retrieved.Icon)
require.Equal(t, "text/markdown", retrieved.ContentType)
require.Equal(t, "base64", retrieved.Encoding)
require.Equal(t, netip.MustParseAddr("9.8.7.6"), retrieved.Sender)
require.Equal(t, "u_TestUser123", retrieved.User)
// Verify actions round-trip
require.Equal(t, 2, len(retrieved.Actions))
require.Equal(t, "action1", retrieved.Actions[0].ID)
require.Equal(t, "view", retrieved.Actions[0].Action)
require.Equal(t, "Open Site", retrieved.Actions[0].Label)
require.Equal(t, "https://example.com", retrieved.Actions[0].URL)
require.Equal(t, true, retrieved.Actions[0].Clear)
require.Equal(t, "action2", retrieved.Actions[1].ID)
require.Equal(t, "http", retrieved.Actions[1].Action)
require.Equal(t, "Call Webhook", retrieved.Actions[1].Label)
require.Equal(t, "https://example.com/hook", retrieved.Actions[1].URL)
require.Equal(t, "PUT", retrieved.Actions[1].Method)
require.Equal(t, "secret", retrieved.Actions[1].Headers["X-Token"])
require.Equal(t, `{"key":"value"}`, retrieved.Actions[1].Body)
}

View File

@@ -146,6 +146,13 @@ func NewActionMessage(event, topic, sequenceID string) *Message {
return m
}
// NewPollRequestMessage is a convenience method to create a poll request message
func NewPollRequestMessage(topic, pollID string) *Message {
m := NewMessage(PollRequestEvent, topic, "New message")
m.PollID = pollID
return m
}
// ValidMessageID returns true if the given string is a valid message ID
func ValidMessageID(s string) bool {
return util.ValidRandomString(s, MessageIDLength)
@@ -184,7 +191,7 @@ func (t SinceMarker) IsLatest() bool {
// IsID returns true if this marker references a specific message ID
func (t SinceMarker) IsID() bool {
return t.id != "" && t.id != "latest"
return t.id != "" && t.id != SinceLatestMessage.id
}
// Time returns the time component of the marker

View File

@@ -136,7 +136,7 @@ func (p *actionParser) Parse() ([]*model.Action, error) {
// and then uses populateAction to interpret the keys/values. The function terminates
// when EOF or ";" is reached.
func (p *actionParser) parseAction() (*model.Action, error) {
a := newAction()
a := model.NewAction()
section := 0
for {
key, value, last, err := p.parseSection()

View File

@@ -88,7 +88,6 @@ var (
// Config is the main config struct for the application. Use New to instantiate a default config struct.
type Config struct {
File string // Config file, only used for testing
DatabaseURL string // PostgreSQL connection string (e.g. "postgres://user:pass@host:5432/ntfy")
BaseURL string
ListenHTTP string
ListenHTTPS string
@@ -96,6 +95,7 @@ type Config struct {
ListenUnixMode fs.FileMode
KeyFile string
CertFile string
DatabaseURL string // PostgreSQL connection string (e.g. "postgres://user:pass@host:5432/ntfy")
FirebaseKeyFile string
CacheFile string
CacheDuration time.Duration
@@ -193,7 +193,6 @@ type Config struct {
func NewConfig() *Config {
return &Config{
File: DefaultConfigFile, // Only used for testing
DatabaseURL: "",
BaseURL: "",
ListenHTTP: DefaultListenHTTP,
ListenHTTPS: "",
@@ -201,6 +200,7 @@ func NewConfig() *Config {
ListenUnixMode: 0,
KeyFile: "",
CertFile: "",
DatabaseURL: "",
FirebaseKeyFile: "",
CacheFile: "",
CacheDuration: DefaultCacheDuration,

View File

@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/util"
"io"
"os"
@@ -13,7 +14,7 @@ import (
)
var (
fileIDRegex = regexp.MustCompile(fmt.Sprintf(`^[-_A-Za-z0-9]{%d}$`, messageIDLength))
fileIDRegex = regexp.MustCompile(fmt.Sprintf(`^[-_A-Za-z0-9]{%d}$`, model.MessageIDLength))
errInvalidFileID = errors.New("invalid file ID")
errFileExists = errors.New("file exists")
)

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"crypto/sha256"
"database/sql"
"embed"
"encoding/base64"
"encoding/json"
@@ -32,6 +33,7 @@ import (
"github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/sync/errgroup"
"gopkg.in/yaml.v2"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/message"
"heckel.io/ntfy/v2/model"
@@ -45,6 +47,7 @@ import (
// Server is the main server, providing the UI and API for ntfy
type Server struct {
config *Config
db *sql.DB // Shared PostgreSQL connection pool, nil when using SQLite
httpServer *http.Server
httpsServer *http.Server
httpMetricsServer *http.Server
@@ -59,8 +62,8 @@ type Server struct {
messages int64 // Total number of messages (persisted if messageCache enabled)
messagesHistory []int64 // Last n values of the messages counter, used to determine rate
userManager *user.Manager // Might be nil!
messageCache message.Store // Database that stores the messages
webPush webpush.Store // Database that stores web push subscriptions
messageCache *message.Cache // Database that stores the messages
webPush *webpush.Store // Database that stores web push subscriptions
fileCache *fileCache // File system based cache that stores attachments
stripe stripeAPI // Stripe API, can be replaced with a mock
priceCache *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!)
@@ -175,14 +178,23 @@ func New(conf *Config) (*Server, error) {
if payments.Available && conf.StripeSecretKey != "" {
stripe = newStripeAPI()
}
messageCache, err := createMessageCache(conf)
// OpenPostgres shared PostgreSQL connection pool if configured
var pool *sql.DB
if conf.DatabaseURL != "" {
var err error
pool, err = db.OpenPostgres(conf.DatabaseURL)
if err != nil {
return nil, err
}
}
messageCache, err := createMessageCache(conf, pool)
if err != nil {
return nil, err
}
var wp webpush.Store
var wp *webpush.Store
if conf.WebPushPublicKey != "" {
if conf.DatabaseURL != "" {
wp, err = webpush.NewPostgresStore(conf.DatabaseURL)
if pool != nil {
wp, err = webpush.NewPostgresStore(pool)
} else {
wp, err = webpush.NewSQLiteStore(conf.WebPushFile, conf.WebPushStartupQueries)
}
@@ -210,7 +222,7 @@ func New(conf *Config) (*Server, error) {
}
}
var userManager *user.Manager
if conf.AuthFile != "" || conf.DatabaseURL != "" {
if conf.AuthFile != "" || pool != nil {
authConfig := &user.Config{
Filename: conf.AuthFile,
DatabaseURL: conf.DatabaseURL,
@@ -223,19 +235,14 @@ func New(conf *Config) (*Server, error) {
BcryptCost: conf.AuthBcryptCost,
QueueWriterInterval: conf.AuthStatsQueueWriterInterval,
}
var store user.Store
if conf.DatabaseURL != "" {
store, err = user.NewPostgresStore(conf.DatabaseURL)
if pool != nil {
userManager, err = user.NewPostgresManager(pool, authConfig)
} else {
store, err = user.NewSQLiteStore(conf.AuthFile, conf.AuthStartupQueries)
userManager, err = user.NewSQLiteManager(conf.AuthFile, conf.AuthStartupQueries, authConfig)
}
if err != nil {
return nil, err
}
userManager, err = user.NewManager(store, authConfig)
if err != nil {
return nil, err
}
}
var firebaseClient *firebaseClient
if conf.FirebaseKeyFile != "" {
@@ -253,6 +260,7 @@ func New(conf *Config) (*Server, error) {
}
s := &Server{
config: conf,
db: pool,
messageCache: messageCache,
webPush: wp,
fileCache: fileCache,
@@ -269,11 +277,11 @@ func New(conf *Config) (*Server, error) {
return s, nil
}
func createMessageCache(conf *Config) (message.Store, error) {
func createMessageCache(conf *Config, pool *sql.DB) (*message.Cache, error) {
if conf.CacheDuration == 0 {
return message.NewNopStore()
} else if conf.DatabaseURL != "" {
return message.NewPostgresStore(conf.DatabaseURL, conf.CacheBatchSize, conf.CacheBatchTimeout)
} else if pool != nil {
return message.NewPostgresStore(pool, conf.CacheBatchSize, conf.CacheBatchTimeout)
} else if conf.CacheFile != "" {
return message.NewSQLiteStore(conf.CacheFile, conf.CacheStartupQueries, conf.CacheDuration, conf.CacheBatchSize, conf.CacheBatchTimeout, false)
}
@@ -412,6 +420,9 @@ func (s *Server) closeDatabases() {
if s.webPush != nil {
s.webPush.Close()
}
if s.db != nil {
s.db.Close()
}
}
// handle is the main entry point for all HTTP requests
@@ -754,7 +765,7 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor)
// - avoid abuse (e.g. 1 uploader, 1k downloaders)
// - and also uses the higher bandwidth limits of a paying user
m, err := s.messageCache.Message(messageID)
if errors.Is(err, errMessageNotFound) {
if errors.Is(err, model.ErrMessageNotFound) {
if s.config.CacheBatchTimeout > 0 {
// Strange edge case: If we immediately after upload request the file (the web app does this for images),
// and messages are persisted asynchronously, retry fetching from the database
@@ -818,7 +829,7 @@ func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*model.Mess
if err != nil {
return nil, err
}
m := newDefaultMessage(t.ID, "")
m := model.NewDefaultMessage(t.ID, "")
cache, firebase, email, call, template, unifiedpush, priorityStr, e := s.parsePublishParams(r, m)
if e != nil {
return nil, e.With(t)
@@ -843,7 +854,7 @@ func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*model.Mess
}
}
if m.PollID != "" {
m = newPollRequestMessage(t.ID, m.PollID)
m = model.NewPollRequestMessage(t.ID, m.PollID)
}
m.Sender = v.IP()
m.User = v.MaybeUserID()
@@ -961,11 +972,11 @@ func (s *Server) handlePublishMatrix(w http.ResponseWriter, r *http.Request, v *
}
func (s *Server) handleDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
return s.handleActionMessage(w, r, v, messageDeleteEvent)
return s.handleActionMessage(w, r, v, model.MessageDeleteEvent)
}
func (s *Server) handleClear(w http.ResponseWriter, r *http.Request, v *visitor) error {
return s.handleActionMessage(w, r, v, messageClearEvent)
return s.handleActionMessage(w, r, v, model.MessageClearEvent)
}
func (s *Server) handleActionMessage(w http.ResponseWriter, r *http.Request, v *visitor, event string) error {
@@ -985,7 +996,7 @@ func (s *Server) handleActionMessage(w http.ResponseWriter, r *http.Request, v *
return e.With(t)
}
// Create an action message with the given event type
m := newActionMessage(event, t.ID, sequenceID)
m := model.NewActionMessage(event, t.ID, sequenceID)
m.Sender = v.IP()
m.User = v.MaybeUserID()
m.Expires = time.Unix(m.Time, 0).Add(v.Limits().MessageExpiryDuration).Unix()
@@ -1001,7 +1012,7 @@ func (s *Server) handleActionMessage(w http.ResponseWriter, r *http.Request, v *
if s.config.WebPushPublicKey != "" {
go s.publishToWebPushEndpoints(v, m)
}
if event == messageDeleteEvent {
if event == model.MessageDeleteEvent {
// Delete any existing scheduled message with the same sequence ID
deletedIDs, err := s.messageCache.DeleteScheduledBySequenceID(t.ID, sequenceID)
if err != nil {
@@ -1230,7 +1241,7 @@ func (s *Server) parsePublishParams(r *http.Request, m *model.Message) (cache bo
// 7. curl -T file.txt ntfy.sh/mytopic
// In all other cases, mostly if file.txt is > message limit, treat it as an attachment
func (s *Server) handlePublishBody(r *http.Request, v *visitor, m *model.Message, body *util.PeekedReadCloser, template templateMode, unifiedpush bool, priorityStr string) error {
if m.Event == pollRequestEvent { // Case 1
if m.Event == model.PollRequestEvent { // Case 1
return s.handleBodyDiscard(body)
} else if unifiedpush {
return s.handleBodyAsMessageAutoDetect(m, body) // Case 2
@@ -1450,7 +1461,7 @@ func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request, v *v
if err := json.NewEncoder(&buf).Encode(msg.ForJSON()); err != nil {
return "", err
}
if msg.Event != messageEvent && msg.Event != messageDeleteEvent && msg.Event != messageClearEvent {
if msg.Event != model.MessageEvent && msg.Event != model.MessageDeleteEvent && msg.Event != model.MessageClearEvent {
return fmt.Sprintf("event: %s\ndata: %s\n", msg.Event, buf.String()), nil // Browser's .onmessage() does not fire on this!
}
return fmt.Sprintf("data: %s\n", buf.String()), nil
@@ -1460,7 +1471,7 @@ func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request, v *v
func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request, v *visitor) error {
encoder := func(msg *model.Message) (string, error) {
if msg.Event == messageEvent { // only handle default events
if msg.Event == model.MessageEvent { // only handle default events
return strings.ReplaceAll(msg.Message, "\n", " ") + "\n", nil
}
return "\n", nil // "keepalive" and "open" events just send an empty line
@@ -1538,7 +1549,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
topics[i].Unsubscribe(subscriberID) // Order!
}
}()
if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message
if err := sub(v, model.NewOpenMessage(topicsStr)); err != nil { // Send out open message
return err
}
if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
@@ -1561,7 +1572,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
for _, t := range topics {
t.Keepalive()
}
if err := sub(v, newKeepaliveMessage(topicsStr)); err != nil { // Send keepalive message
if err := sub(v, model.NewKeepaliveMessage(topicsStr)); err != nil { // Send keepalive message
return err
}
}
@@ -1687,7 +1698,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
topics[i].Unsubscribe(subscriberID) // Order!
}
}()
if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message
if err := sub(v, model.NewOpenMessage(topicsStr)); err != nil { // Send out open message
return err
}
if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
@@ -1818,26 +1829,26 @@ func parseSince(r *http.Request, poll bool) (model.SinceMarker, error) {
// Easy cases (empty, all, none)
if since == "" {
if poll {
return sinceAllMessages, nil
return model.SinceAllMessages, nil
}
return sinceNoMessages, nil
return model.SinceNoMessages, nil
} else if since == "all" {
return sinceAllMessages, nil
return model.SinceAllMessages, nil
} else if since == "latest" {
return sinceLatestMessage, nil
return model.SinceLatestMessage, nil
} else if since == "none" {
return sinceNoMessages, nil
return model.SinceNoMessages, nil
}
// ID, timestamp, duration
if validMessageID(since) {
return newSinceID(since), nil
if model.ValidMessageID(since) {
return model.NewSinceID(since), nil
} else if s, err := strconv.ParseInt(since, 10, 64); err == nil {
return newSinceTime(s), nil
return model.NewSinceTime(s), nil
} else if d, err := time.ParseDuration(since); err == nil {
return newSinceTime(time.Now().Add(-1 * d).Unix()), nil
return model.NewSinceTime(time.Now().Add(-1 * d).Unix()), nil
}
return sinceNoMessages, errHTTPBadRequestSinceInvalid
return model.SinceNoMessages, errHTTPBadRequestSinceInvalid
}
func (s *Server) handleOptions(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
@@ -1993,14 +2004,14 @@ func (s *Server) runFirebaseKeepaliver() {
for {
select {
case <-time.After(s.config.FirebaseKeepaliveInterval):
s.sendToFirebase(v, newKeepaliveMessage(firebaseControlTopic))
s.sendToFirebase(v, model.NewKeepaliveMessage(firebaseControlTopic))
/*
FIXME: Disable iOS polling entirely for now due to thundering herd problem (see #677)
To solve this, we'd have to shard the iOS poll topics to spread out the polling evenly.
Given that it's not really necessary to poll, turning it off for now should not have any impact.
case <-time.After(s.config.FirebasePollInterval):
s.sendToFirebase(v, newKeepaliveMessage(firebasePollTopic))
s.sendToFirebase(v, model.NewKeepaliveMessage(firebasePollTopic))
*/
case <-s.closeChan:
return

View File

@@ -38,14 +38,32 @@
#
# firebase-key-file: <filename>
# If "database-url" is set, ntfy will use PostgreSQL for database-backed stores instead of SQLite.
# Currently this applies to the web push subscription store. Message cache and user manager support
# will be added in future releases. When set, the "web-push-file" option is not required.
# If "database-url" is set, ntfy will use PostgreSQL for all database-backed stores (message cache,
# user manager, and web push subscriptions) instead of SQLite. When set, the "cache-file",
# "auth-file", and "web-push-file" options must not be set.
#
# database-url: "postgres://user:pass@host:5432/ntfy"
# Note: Setting "database-url" implicitly enables authentication and access control.
# The default access is "read-write" (see "auth-default-access").
#
# The URL supports standard PostgreSQL parameters (sslmode, connect_timeout, sslcert, etc.),
# as well as ntfy-specific connection pool parameters:
# pool_max_conns=10 - Maximum number of open connections (default: 10)
# pool_max_idle_conns=N - Maximum number of idle connections
# pool_conn_max_lifetime=5m - Maximum lifetime of a connection (Go duration)
# pool_conn_max_idle_time=1m - Maximum idle time of a connection (Go duration)
#
# See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS
# for the full list of supported PostgreSQL connection parameters.
#
# Examples:
# database-url: "postgres://user:pass@host:5432/ntfy"
# database-url: "postgres://user:pass@host:5432/ntfy?sslmode=require&pool_max_conns=50"
#
# database-url: <connection-string>
# If "cache-file" is set, messages are cached in a local SQLite database instead of only in-memory.
# This allows for service restarts without losing messages in support of the since= parameter.
# Not required if "database-url" is set (messages are stored in PostgreSQL instead).
#
# The "cache-duration" parameter defines the duration for which messages will be buffered
# before they are deleted. This is required to support the "since=..." and "poll=1" parameter.
@@ -83,6 +101,8 @@
# If set, access to the ntfy server and API can be controlled on a granular level using
# the 'ntfy user' and 'ntfy access' commands. See the --help pages for details, or check the docs.
#
# Note: If "database-url" is set, auth is implicitly enabled and "auth-file" must not be set.
#
# - auth-file is the SQLite user/access database; it is created automatically if it doesn't already exist
# - auth-default-access defines the default/fallback access if no access control entry is found; it can be
# set to "read-write" (default), "read-only", "write-only" or "deny-all".
@@ -203,6 +223,7 @@
# - web-push-public-key is the generated VAPID public key, e.g. AA1234BBCCddvveekaabcdfqwertyuiopasdfghjklzxcvbnm1234567890
# - web-push-private-key is the generated VAPID private key, e.g. AA2BB1234567890abcdefzxcvbnm1234567890
# - web-push-file is a database file to keep track of browser subscription endpoints, e.g. /var/cache/ntfy/webpush.db
# Not required if "database-url" is set (subscriptions are stored in PostgreSQL instead).
# - web-push-email-address is the admin email address send to the push provider, e.g. sysadmin@example.com
# - web-push-startup-queries is an optional list of queries to run on startup
# - web-push-expiry-warning-duration defines the duration after which unused subscriptions are sent a warning (default is 55d)

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"errors"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/user"
"heckel.io/ntfy/v2/util"
"net/http"
@@ -641,7 +642,7 @@ func (s *Server) publishSyncEvent(v *visitor) error {
if err != nil {
return err
}
m := newDefaultMessage(syncTopic.ID, string(messageBytes))
m := model.NewDefaultMessage(syncTopic.ID, string(messageBytes))
if err := syncTopic.Publish(v, m); err != nil {
return err
}

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/user"
"heckel.io/ntfy/v2/util"
"io"
@@ -15,8 +16,8 @@ import (
)
func TestAccount_Signup_Success(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
conf := newTestConfigWithAuthFile(t)
forEachBackend(t, func(t *testing.T, databaseURL string) {
conf := newTestConfigWithAuthFile(t, databaseURL)
conf.EnableSignup = true
s := newTestServer(t, conf)
defer s.closeDatabases()
@@ -54,8 +55,8 @@ func TestAccount_Signup_Success(t *testing.T) {
}
func TestAccount_Signup_UserExists(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
conf := newTestConfigWithAuthFile(t)
forEachBackend(t, func(t *testing.T, databaseURL string) {
conf := newTestConfigWithAuthFile(t, databaseURL)
conf.EnableSignup = true
s := newTestServer(t, conf)
defer s.closeDatabases()
@@ -70,8 +71,8 @@ func TestAccount_Signup_UserExists(t *testing.T) {
}
func TestAccount_Signup_LimitReached(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
conf := newTestConfigWithAuthFile(t)
forEachBackend(t, func(t *testing.T, databaseURL string) {
conf := newTestConfigWithAuthFile(t, databaseURL)
conf.EnableSignup = true
s := newTestServer(t, conf)
defer s.closeDatabases()
@@ -87,8 +88,8 @@ func TestAccount_Signup_LimitReached(t *testing.T) {
}
func TestAccount_Signup_AsUser(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
conf := newTestConfigWithAuthFile(t)
forEachBackend(t, func(t *testing.T, databaseURL string) {
conf := newTestConfigWithAuthFile(t, databaseURL)
conf.EnableSignup = true
s := newTestServer(t, conf)
defer s.closeDatabases()
@@ -111,8 +112,8 @@ func TestAccount_Signup_AsUser(t *testing.T) {
}
func TestAccount_Signup_Disabled(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
conf := newTestConfigWithAuthFile(t)
forEachBackend(t, func(t *testing.T, databaseURL string) {
conf := newTestConfigWithAuthFile(t, databaseURL)
conf.EnableSignup = false
s := newTestServer(t, conf)
defer s.closeDatabases()
@@ -124,8 +125,8 @@ func TestAccount_Signup_Disabled(t *testing.T) {
}
func TestAccount_Signup_Rate_Limit(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
conf := newTestConfigWithAuthFile(t)
forEachBackend(t, func(t *testing.T, databaseURL string) {
conf := newTestConfigWithAuthFile(t, databaseURL)
conf.EnableSignup = true
s := newTestServer(t, conf)
@@ -140,8 +141,8 @@ func TestAccount_Signup_Rate_Limit(t *testing.T) {
}
func TestAccount_Get_Anonymous(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
conf := newTestConfigWithAuthFile(t)
forEachBackend(t, func(t *testing.T, databaseURL string) {
conf := newTestConfigWithAuthFile(t, databaseURL)
conf.VisitorRequestLimitReplenish = 86 * time.Second
conf.VisitorEmailLimitReplenish = time.Hour
conf.VisitorAttachmentTotalSizeLimit = 5123
@@ -185,8 +186,8 @@ func TestAccount_Get_Anonymous(t *testing.T) {
}
func TestAccount_ChangeSettings(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
defer s.closeDatabases()
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
@@ -216,8 +217,8 @@ func TestAccount_ChangeSettings(t *testing.T) {
}
func TestAccount_Subscription_AddUpdateDelete(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
defer s.closeDatabases()
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
@@ -269,8 +270,8 @@ func TestAccount_Subscription_AddUpdateDelete(t *testing.T) {
}
func TestAccount_ChangePassword(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
conf := newTestConfigWithAuthFile(t)
forEachBackend(t, func(t *testing.T, databaseURL string) {
conf := newTestConfigWithAuthFile(t, databaseURL)
conf.AuthUsers = []*user.User{
{Name: "philuser", Hash: "$2a$10$U4WSIYY6evyGmZaraavM2e2JeVG6EMGUKN1uUwufUeeRd4Jpg6cGC", Role: user.RoleUser}, // philuser:philpass
}
@@ -314,8 +315,8 @@ func TestAccount_ChangePassword(t *testing.T) {
}
func TestAccount_ChangePassword_NoAccount(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
defer s.closeDatabases()
rr := request(t, s, "POST", "/v1/account/password", `{"password": "new password"}`, nil)
@@ -324,9 +325,9 @@ func TestAccount_ChangePassword_NoAccount(t *testing.T) {
}
func TestAccount_ExtendToken(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
forEachBackend(t, func(t *testing.T, databaseURL string) {
t.Parallel()
s := newTestServer(t, newTestConfigWithAuthFile(t))
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
defer s.closeDatabases()
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
@@ -363,8 +364,8 @@ func TestAccount_ExtendToken(t *testing.T) {
}
func TestAccount_ExtendToken_NoTokenProvided(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
defer s.closeDatabases()
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
@@ -378,8 +379,8 @@ func TestAccount_ExtendToken_NoTokenProvided(t *testing.T) {
}
func TestAccount_DeleteToken(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
defer s.closeDatabases()
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
@@ -420,8 +421,8 @@ func TestAccount_DeleteToken(t *testing.T) {
}
func TestAccount_Delete_Success(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
conf := newTestConfigWithAuthFile(t)
forEachBackend(t, func(t *testing.T, databaseURL string) {
conf := newTestConfigWithAuthFile(t, databaseURL)
conf.EnableSignup = true
s := newTestServer(t, conf)
@@ -451,8 +452,8 @@ func TestAccount_Delete_Success(t *testing.T) {
}
func TestAccount_Delete_Not_Allowed(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
conf := newTestConfigWithAuthFile(t)
forEachBackend(t, func(t *testing.T, databaseURL string) {
conf := newTestConfigWithAuthFile(t, databaseURL)
conf.EnableSignup = true
s := newTestServer(t, conf)
@@ -474,8 +475,8 @@ func TestAccount_Delete_Not_Allowed(t *testing.T) {
}
func TestAccount_Reservation_AddWithoutTierFails(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
conf := newTestConfigWithAuthFile(t)
forEachBackend(t, func(t *testing.T, databaseURL string) {
conf := newTestConfigWithAuthFile(t, databaseURL)
conf.EnableSignup = true
s := newTestServer(t, conf)
@@ -490,8 +491,8 @@ func TestAccount_Reservation_AddWithoutTierFails(t *testing.T) {
}
func TestAccount_Reservation_AddAdminSuccess(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
conf := newTestConfigWithAuthFile(t)
forEachBackend(t, func(t *testing.T, databaseURL string) {
conf := newTestConfigWithAuthFile(t, databaseURL)
conf.EnableSignup = true
s := newTestServer(t, conf)
@@ -544,8 +545,8 @@ func TestAccount_Reservation_AddAdminSuccess(t *testing.T) {
}
func TestAccount_Reservation_AddRemoveUserWithTierSuccess(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
conf := newTestConfigWithAuthFile(t)
forEachBackend(t, func(t *testing.T, databaseURL string) {
conf := newTestConfigWithAuthFile(t, databaseURL)
conf.EnableSignup = true
conf.EnableReservations = true
conf.TwilioAccount = "dummy"
@@ -632,8 +633,8 @@ func TestAccount_Reservation_AddRemoveUserWithTierSuccess(t *testing.T) {
}
func TestAccount_Reservation_PublishByAnonymousFails(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
conf := newTestConfigWithAuthFile(t)
forEachBackend(t, func(t *testing.T, databaseURL string) {
conf := newTestConfigWithAuthFile(t, databaseURL)
conf.AuthDefault = user.PermissionReadWrite
conf.EnableSignup = true
s := newTestServer(t, conf)
@@ -668,9 +669,9 @@ func TestAccount_Reservation_PublishByAnonymousFails(t *testing.T) {
}
func TestAccount_Reservation_Delete_Messages_And_Attachments(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
forEachBackend(t, func(t *testing.T, databaseURL string) {
t.Parallel()
conf := newTestConfigWithAuthFile(t)
conf := newTestConfigWithAuthFile(t, databaseURL)
conf.AuthDefault = user.PermissionReadWrite
s := newTestServer(t, conf)
@@ -715,12 +716,12 @@ func TestAccount_Reservation_Delete_Messages_And_Attachments(t *testing.T) {
require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, m2.ID))
// Pre-verify message count and file
ms, err := s.messageCache.Messages("mytopic1", sinceAllMessages, false)
ms, err := s.messageCache.Messages("mytopic1", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 1, len(ms))
require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, m1.ID))
ms, err = s.messageCache.Messages("mytopic2", sinceAllMessages, false)
ms, err = s.messageCache.Messages("mytopic2", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 1, len(ms))
require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, m2.ID))
@@ -741,17 +742,17 @@ func TestAccount_Reservation_Delete_Messages_And_Attachments(t *testing.T) {
// Verify that messages and attachments were deleted
// This does not explicitly call the manager!
waitFor(t, func() bool {
ms, err := s.messageCache.Messages("mytopic1", sinceAllMessages, false)
ms, err := s.messageCache.Messages("mytopic1", model.SinceAllMessages, false)
require.Nil(t, err)
return len(ms) == 0 && !util.FileExists(filepath.Join(s.config.AttachmentCacheDir, m1.ID))
})
ms, err = s.messageCache.Messages("mytopic1", sinceAllMessages, false)
ms, err = s.messageCache.Messages("mytopic1", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 0, len(ms))
require.NoFileExists(t, filepath.Join(s.config.AttachmentCacheDir, m1.ID))
ms, err = s.messageCache.Messages("mytopic2", sinceAllMessages, false)
ms, err = s.messageCache.Messages("mytopic2", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 1, len(ms))
require.Equal(t, m2.ID, ms[0].ID)
@@ -760,7 +761,7 @@ func TestAccount_Reservation_Delete_Messages_And_Attachments(t *testing.T) {
}
/*func TestAccount_Persist_UserStats_After_Tier_Change(t *testing.T) {
conf := newTestConfigWithAuthFile(t)
conf := newTestConfigWithAuthFile(t, databaseURL)
conf.AuthDefault = user.PermissionReadWrite
conf.AuthStatsQueueWriterInterval = 300 * time.Millisecond
s := newTestServer(t, conf)

View File

@@ -11,8 +11,8 @@ import (
)
func TestVersion_Admin(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
c := newTestConfigWithAuthFile(t)
forEachBackend(t, func(t *testing.T, databaseURL string) {
c := newTestConfigWithAuthFile(t, databaseURL)
c.BuildVersion = "1.2.3"
c.BuildCommit = "abcdef0"
c.BuildDate = "2026-02-08T00:00:00Z"
@@ -48,8 +48,8 @@ func TestVersion_Admin(t *testing.T) {
}
func TestUser_AddRemove(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
defer s.closeDatabases()
// Create admin, tier
@@ -106,8 +106,8 @@ func TestUser_AddRemove(t *testing.T) {
}
func TestUser_AddWithPasswordHash(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
defer s.closeDatabases()
// Create admin
@@ -137,8 +137,8 @@ func TestUser_AddWithPasswordHash(t *testing.T) {
}
func TestUser_ChangeUserPassword(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
defer s.closeDatabases()
// Create admin
@@ -177,8 +177,8 @@ func TestUser_ChangeUserPassword(t *testing.T) {
}
func TestUser_ChangeUserTier(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
defer s.closeDatabases()
// Create admin, tier
@@ -219,8 +219,8 @@ func TestUser_ChangeUserTier(t *testing.T) {
}
func TestUser_ChangeUserPasswordAndTier(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
defer s.closeDatabases()
// Create admin, tier
@@ -273,8 +273,8 @@ func TestUser_ChangeUserPasswordAndTier(t *testing.T) {
}
func TestUser_ChangeUserPasswordWithHash(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
defer s.closeDatabases()
// Create admin
@@ -307,8 +307,8 @@ func TestUser_ChangeUserPasswordWithHash(t *testing.T) {
}
func TestUser_DontChangeAdminPassword(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
defer s.closeDatabases()
// Create admin
@@ -324,8 +324,8 @@ func TestUser_DontChangeAdminPassword(t *testing.T) {
}
func TestUser_AddRemove_Failures(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
defer s.closeDatabases()
// Create admin
@@ -365,8 +365,8 @@ func TestUser_AddRemove_Failures(t *testing.T) {
}
func TestAccess_AllowReset(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
c := newTestConfigWithAuthFile(t)
forEachBackend(t, func(t *testing.T, databaseURL string) {
c := newTestConfigWithAuthFile(t, databaseURL)
c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c)
defer s.closeDatabases()
@@ -408,8 +408,8 @@ func TestAccess_AllowReset(t *testing.T) {
}
func TestAccess_AllowReset_NonAdminAttempt(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
c := newTestConfigWithAuthFile(t)
forEachBackend(t, func(t *testing.T, databaseURL string) {
c := newTestConfigWithAuthFile(t, databaseURL)
c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c)
defer s.closeDatabases()
@@ -426,8 +426,8 @@ func TestAccess_AllowReset_NonAdminAttempt(t *testing.T) {
}
func TestAccess_AllowReset_KillConnection(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
c := newTestConfigWithAuthFile(t)
forEachBackend(t, func(t *testing.T, databaseURL string) {
c := newTestConfigWithAuthFile(t, databaseURL)
c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c)
defer s.closeDatabases()

View File

@@ -126,7 +126,7 @@ func toFirebaseMessage(m *model.Message, auther user.Auther) (*messaging.Message
var data map[string]string // Mostly matches https://ntfy.sh/docs/subscribe/api/#json-message-format
var apnsConfig *messaging.APNSConfig
switch m.Event {
case keepaliveEvent, openEvent:
case model.KeepaliveEvent, model.OpenEvent:
data = map[string]string{
"id": m.ID,
"time": fmt.Sprintf("%d", m.Time),
@@ -134,7 +134,7 @@ func toFirebaseMessage(m *model.Message, auther user.Auther) (*messaging.Message
"topic": m.Topic,
}
apnsConfig = createAPNSBackgroundConfig(data)
case pollRequestEvent:
case model.PollRequestEvent:
data = map[string]string{
"id": m.ID,
"time": fmt.Sprintf("%d", m.Time),
@@ -144,7 +144,7 @@ func toFirebaseMessage(m *model.Message, auther user.Auther) (*messaging.Message
"poll_id": m.PollID,
}
apnsConfig = createAPNSAlertConfig(m, data)
case messageDeleteEvent, messageClearEvent:
case model.MessageDeleteEvent, model.MessageClearEvent:
data = map[string]string{
"id": m.ID,
"time": fmt.Sprintf("%d", m.Time),
@@ -153,7 +153,7 @@ func toFirebaseMessage(m *model.Message, auther user.Auther) (*messaging.Message
"sequence_id": m.SequenceID,
}
apnsConfig = createAPNSBackgroundConfig(data)
case messageEvent:
case model.MessageEvent:
if auther != nil {
// If "anonymous read" for a topic is not allowed, we cannot send the message along
// via Firebase. Instead, we send a "poll_request" message, asking the client to poll.
@@ -298,7 +298,7 @@ func maybeTruncateAPNSBodyMessage(s string) string {
// This empties all the fields that are not needed for a poll request and just sets the required fields,
// most importantly, the PollID.
func toPollRequest(m *model.Message) *model.Message {
pr := newPollRequestMessage(m.Topic, m.ID)
pr := model.NewPollRequestMessage(m.Topic, m.ID)
pr.ID = m.ID
pr.Time = m.Time
pr.Priority = m.Priority // Keep priority

View File

@@ -64,7 +64,7 @@ func (s *testFirebaseSender) Messages() []*messaging.Message {
}
func TestToFirebaseMessage_Keepalive(t *testing.T) {
m := newKeepaliveMessage("mytopic")
m := model.NewKeepaliveMessage("mytopic")
fbm, err := toFirebaseMessage(m, nil)
require.Nil(t, err)
require.Equal(t, "mytopic", fbm.Topic)
@@ -95,7 +95,7 @@ func TestToFirebaseMessage_Keepalive(t *testing.T) {
}
func TestToFirebaseMessage_Open(t *testing.T) {
m := newOpenMessage("mytopic")
m := model.NewOpenMessage("mytopic")
fbm, err := toFirebaseMessage(m, nil)
require.Nil(t, err)
require.Equal(t, "mytopic", fbm.Topic)
@@ -126,7 +126,7 @@ func TestToFirebaseMessage_Open(t *testing.T) {
}
func TestToFirebaseMessage_Message_Normal_Allowed(t *testing.T) {
m := newDefaultMessage("mytopic", "this is a message")
m := model.NewDefaultMessage("mytopic", "this is a message")
m.Priority = 4
m.Tags = []string{"tag 1", "tag2"}
m.Click = "https://google.com"
@@ -220,7 +220,7 @@ func TestToFirebaseMessage_Message_Normal_Allowed(t *testing.T) {
}
func TestToFirebaseMessage_Message_Normal_Not_Allowed(t *testing.T) {
m := newDefaultMessage("mytopic", "this is a message")
m := model.NewDefaultMessage("mytopic", "this is a message")
m.Priority = 5
fbm, err := toFirebaseMessage(m, &testAuther{Allow: false}) // Not allowed!
require.Nil(t, err)
@@ -251,7 +251,7 @@ func TestToFirebaseMessage_Message_Normal_Not_Allowed(t *testing.T) {
}
func TestToFirebaseMessage_PollRequest(t *testing.T) {
m := newPollRequestMessage("mytopic", "fOv6k1QbCzo6")
m := model.NewPollRequestMessage("mytopic", "fOv6k1QbCzo6")
fbm, err := toFirebaseMessage(m, nil)
require.Nil(t, err)
require.Equal(t, "mytopic", fbm.Topic)
@@ -345,7 +345,7 @@ func TestMaybeTruncateFCMMessage_NotTooLong(t *testing.T) {
func TestToFirebaseSender_Abuse(t *testing.T) {
sender := &testFirebaseSender{allowed: 2}
client := newFirebaseClient(sender, &testAuther{})
visitor := newVisitor(newTestConfig(t), newMemTestCache(t), nil, netip.MustParseAddr("1.2.3.4"), nil)
visitor := newVisitor(newTestConfig(t, ""), newMemTestCache(t), nil, netip.MustParseAddr("1.2.3.4"), nil)
require.Nil(t, client.Send(visitor, &model.Message{Topic: "mytopic"}))
require.Equal(t, 1, len(sender.Messages()))

View File

@@ -17,15 +17,10 @@ func (s *Server) execManager() {
s.pruneMessages()
s.pruneAndNotifyWebPushSubscriptions()
// Message count per topic
var messagesCached int
messageCounts, err := s.messageCache.MessageCounts()
// Message count
messagesCached, err := s.messageCache.MessagesCount()
if err != nil {
log.Tag(tagManager).Err(err).Warn("Cannot get message counts")
messageCounts = make(map[string]int) // Empty, so we can continue
}
for _, count := range messageCounts {
messagesCached += count
log.Tag(tagManager).Err(err).Warn("Cannot get messages count")
}
// Remove subscriptions without subscribers

View File

@@ -2,13 +2,14 @@ package server
import (
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/model"
"testing"
)
func TestServer_Manager_Prune_Messages_Without_Attachments_DoesNotPanic(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
forEachBackend(t, func(t *testing.T, databaseURL string) {
// Tests that the manager runs without attachment-cache-dir set, see #617
c := newTestConfig(t)
c := newTestConfig(t, databaseURL)
c.AttachmentCacheDir = ""
s := newTestServer(t, c)
@@ -25,6 +26,6 @@ func TestServer_Manager_Prune_Messages_Without_Attachments_DoesNotPanic(t *testi
// Actually deleted
_, err := s.messageCache.Message(m.ID)
require.Equal(t, errMessageNotFound, err)
require.Equal(t, model.ErrMessageNotFound, err)
})
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/stripe/stripe-go/v74"
"golang.org/x/time/rate"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/payments"
"heckel.io/ntfy/v2/user"
"heckel.io/ntfy/v2/util"
@@ -21,11 +22,11 @@ import (
)
func TestPayments_Tiers(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
forEachBackend(t, func(t *testing.T, databaseURL string) {
stripeMock := &testStripeAPI{}
defer stripeMock.AssertExpectations(t)
c := newTestConfigWithAuthFile(t)
c := newTestConfigWithAuthFile(t, databaseURL)
c.StripeSecretKey = "secret key"
c.StripeWebhookKey = "webhook key"
c.VisitorRequestLimitReplenish = 12 * time.Hour
@@ -133,11 +134,11 @@ func TestPayments_Tiers(t *testing.T) {
}
func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
forEachBackend(t, func(t *testing.T, databaseURL string) {
stripeMock := &testStripeAPI{}
defer stripeMock.AssertExpectations(t)
c := newTestConfigWithAuthFile(t)
c := newTestConfigWithAuthFile(t, databaseURL)
c.StripeSecretKey = "secret key"
c.StripeWebhookKey = "webhook key"
s := newTestServer(t, c)
@@ -168,11 +169,11 @@ func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
}
func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
forEachBackend(t, func(t *testing.T, databaseURL string) {
stripeMock := &testStripeAPI{}
defer stripeMock.AssertExpectations(t)
c := newTestConfigWithAuthFile(t)
c := newTestConfigWithAuthFile(t, databaseURL)
c.StripeSecretKey = "secret key"
c.StripeWebhookKey = "webhook key"
s := newTestServer(t, c)
@@ -214,11 +215,11 @@ func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
}
func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
forEachBackend(t, func(t *testing.T, databaseURL string) {
stripeMock := &testStripeAPI{}
defer stripeMock.AssertExpectations(t)
c := newTestConfigWithAuthFile(t)
c := newTestConfigWithAuthFile(t, databaseURL)
c.EnableSignup = true
c.StripeSecretKey = "secret key"
c.StripeWebhookKey = "webhook key"
@@ -261,7 +262,7 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
}
func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
forEachBackend(t, func(t *testing.T, databaseURL string) {
// This test is too overloaded, but it's also a great end-to-end a test.
//
// It tests:
@@ -273,7 +274,7 @@ func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *tes
stripeMock := &testStripeAPI{}
defer stripeMock.AssertExpectations(t)
c := newTestConfigWithAuthFile(t)
c := newTestConfigWithAuthFile(t, databaseURL)
c.StripeSecretKey = "secret key"
c.StripeWebhookKey = "webhook key"
c.VisitorRequestLimitBurst = 5
@@ -428,7 +429,7 @@ func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *tes
}
func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
forEachBackend(t, func(t *testing.T, databaseURL string) {
t.Parallel()
// This tests incoming webhooks from Stripe to update a subscription:
@@ -439,7 +440,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
stripeMock := &testStripeAPI{}
defer stripeMock.AssertExpectations(t)
c := newTestConfigWithAuthFile(t)
c := newTestConfigWithAuthFile(t, databaseURL)
c.StripeSecretKey = "secret key"
c.StripeWebhookKey = "webhook key"
s := newTestServer(t, c)
@@ -546,12 +547,12 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
time.Sleep(time.Second)
s.execManager()
ms, err := s.messageCache.Messages("atopic", sinceAllMessages, false)
ms, err := s.messageCache.Messages("atopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 2, len(ms))
require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, a2.ID))
ms, err = s.messageCache.Messages("ztopic", sinceAllMessages, false)
ms, err = s.messageCache.Messages("ztopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 0, len(ms))
require.NoFileExists(t, filepath.Join(s.config.AttachmentCacheDir, z2.ID))
@@ -559,7 +560,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
}
func TestPayments_Webhook_Subscription_Deleted(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
forEachBackend(t, func(t *testing.T, databaseURL string) {
// This tests incoming webhooks from Stripe to delete a subscription. It verifies that the database is
// updated (all Stripe fields are deleted, and the tier is removed).
//
@@ -568,7 +569,7 @@ func TestPayments_Webhook_Subscription_Deleted(t *testing.T) {
stripeMock := &testStripeAPI{}
defer stripeMock.AssertExpectations(t)
c := newTestConfigWithAuthFile(t)
c := newTestConfigWithAuthFile(t, databaseURL)
c.StripeSecretKey = "secret key"
c.StripeWebhookKey = "webhook key"
s := newTestServer(t, c)
@@ -626,11 +627,11 @@ func TestPayments_Webhook_Subscription_Deleted(t *testing.T) {
}
func TestPayments_Subscription_Update_Different_Tier(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
forEachBackend(t, func(t *testing.T, databaseURL string) {
stripeMock := &testStripeAPI{}
defer stripeMock.AssertExpectations(t)
c := newTestConfigWithAuthFile(t)
c := newTestConfigWithAuthFile(t, databaseURL)
c.StripeSecretKey = "secret key"
c.StripeWebhookKey = "webhook key"
s := newTestServer(t, c)
@@ -692,11 +693,11 @@ func TestPayments_Subscription_Update_Different_Tier(t *testing.T) {
}
func TestPayments_Subscription_Delete_At_Period_End(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
forEachBackend(t, func(t *testing.T, databaseURL string) {
stripeMock := &testStripeAPI{}
defer stripeMock.AssertExpectations(t)
c := newTestConfigWithAuthFile(t)
c := newTestConfigWithAuthFile(t, databaseURL)
c.StripeSecretKey = "secret key"
c.StripeWebhookKey = "webhook key"
s := newTestServer(t, c)
@@ -725,11 +726,11 @@ func TestPayments_Subscription_Delete_At_Period_End(t *testing.T) {
}
func TestPayments_CreatePortalSession(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
forEachBackend(t, func(t *testing.T, databaseURL string) {
stripeMock := &testStripeAPI{}
defer stripeMock.AssertExpectations(t)
c := newTestConfigWithAuthFile(t)
c := newTestConfigWithAuthFile(t, databaseURL)
c.StripeSecretKey = "secret key"
c.StripeWebhookKey = "webhook key"
s := newTestServer(t, c)

View File

@@ -0,0 +1,5 @@
//go:build !race
package server
const raceEnabled = false

View File

@@ -0,0 +1,5 @@
//go:build race
package server
const raceEnabled = true

File diff suppressed because one or more lines are too long

View File

@@ -14,7 +14,7 @@ import (
)
func TestServer_Twilio_Call_Add_Verify_Call_Delete_Success(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
forEachBackend(t, func(t *testing.T, databaseURL string) {
var called, verified atomic.Bool
var code atomic.Pointer[string]
twilioVerifyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -51,7 +51,7 @@ func TestServer_Twilio_Call_Add_Verify_Call_Delete_Success(t *testing.T) {
}))
defer twilioCallsServer.Close()
c := newTestConfigWithAuthFile(t)
c := newTestConfigWithAuthFile(t, databaseURL)
c.TwilioVerifyBaseURL = twilioVerifyServer.URL
c.TwilioCallsBaseURL = twilioCallsServer.URL
c.TwilioAccount = "AC1234567890"
@@ -117,7 +117,7 @@ func TestServer_Twilio_Call_Add_Verify_Call_Delete_Success(t *testing.T) {
}
func TestServer_Twilio_Call_Success(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
forEachBackend(t, func(t *testing.T, databaseURL string) {
var called atomic.Bool
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if called.Load() {
@@ -132,7 +132,7 @@ func TestServer_Twilio_Call_Success(t *testing.T) {
}))
defer twilioServer.Close()
c := newTestConfigWithAuthFile(t)
c := newTestConfigWithAuthFile(t, databaseURL)
c.TwilioCallsBaseURL = twilioServer.URL
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
@@ -164,7 +164,7 @@ func TestServer_Twilio_Call_Success(t *testing.T) {
}
func TestServer_Twilio_Call_Success_With_Yes(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
forEachBackend(t, func(t *testing.T, databaseURL string) {
var called atomic.Bool
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if called.Load() {
@@ -179,7 +179,7 @@ func TestServer_Twilio_Call_Success_With_Yes(t *testing.T) {
}))
defer twilioServer.Close()
c := newTestConfigWithAuthFile(t)
c := newTestConfigWithAuthFile(t, databaseURL)
c.TwilioCallsBaseURL = twilioServer.URL
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
@@ -211,7 +211,7 @@ func TestServer_Twilio_Call_Success_With_Yes(t *testing.T) {
}
func TestServer_Twilio_Call_Success_with_custom_twiml(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
forEachBackend(t, func(t *testing.T, databaseURL string) {
var called atomic.Bool
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if called.Load() {
@@ -226,7 +226,7 @@ func TestServer_Twilio_Call_Success_with_custom_twiml(t *testing.T) {
}))
defer twilioServer.Close()
c := newTestConfigWithAuthFile(t)
c := newTestConfigWithAuthFile(t, databaseURL)
c.TwilioCallsBaseURL = twilioServer.URL
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
@@ -274,8 +274,8 @@ func TestServer_Twilio_Call_Success_with_custom_twiml(t *testing.T) {
}
func TestServer_Twilio_Call_UnverifiedNumber(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
c := newTestConfigWithAuthFile(t)
forEachBackend(t, func(t *testing.T, databaseURL string) {
c := newTestConfigWithAuthFile(t, databaseURL)
c.TwilioCallsBaseURL = "http://dummy.invalid"
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
@@ -301,8 +301,8 @@ func TestServer_Twilio_Call_UnverifiedNumber(t *testing.T) {
}
func TestServer_Twilio_Call_InvalidNumber(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
c := newTestConfigWithAuthFile(t)
forEachBackend(t, func(t *testing.T, databaseURL string) {
c := newTestConfigWithAuthFile(t, databaseURL)
c.TwilioCallsBaseURL = "https://127.0.0.1"
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
@@ -317,8 +317,8 @@ func TestServer_Twilio_Call_InvalidNumber(t *testing.T) {
}
func TestServer_Twilio_Call_Anonymous(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
c := newTestConfigWithAuthFile(t)
forEachBackend(t, func(t *testing.T, databaseURL string) {
c := newTestConfigWithAuthFile(t, databaseURL)
c.TwilioCallsBaseURL = "https://127.0.0.1"
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
@@ -333,8 +333,8 @@ func TestServer_Twilio_Call_Anonymous(t *testing.T) {
}
func TestServer_Twilio_Call_Unconfigured(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfig(t, databaseURL))
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
"x-call": "+1234",
})

View File

@@ -26,21 +26,21 @@ const (
)
func TestServer_WebPush_Enabled(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
conf := newTestConfig(t)
forEachBackend(t, func(t *testing.T, databaseURL string) {
conf := newTestConfig(t, databaseURL)
conf.WebRoot = "" // Disable web app
s := newTestServer(t, conf)
rr := request(t, s, "GET", "/manifest.webmanifest", "", nil)
require.Equal(t, 404, rr.Code)
conf2 := newTestConfig(t)
conf2 := newTestConfig(t, databaseURL)
s2 := newTestServer(t, conf2)
rr = request(t, s2, "GET", "/manifest.webmanifest", "", nil)
require.Equal(t, 404, rr.Code)
conf3 := newTestConfigWithWebPush(t)
conf3 := newTestConfigWithWebPush(t, databaseURL)
s3 := newTestServer(t, conf3)
rr = request(t, s3, "GET", "/manifest.webmanifest", "", nil)
@@ -50,8 +50,8 @@ func TestServer_WebPush_Enabled(t *testing.T) {
})
}
func TestServer_WebPush_Disabled(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfig(t, databaseURL))
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
require.Equal(t, 404, response.Code)
@@ -59,8 +59,8 @@ func TestServer_WebPush_Disabled(t *testing.T) {
}
func TestServer_WebPush_TopicAdd(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfigWithWebPush(t, databaseURL))
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
require.Equal(t, 200, response.Code)
@@ -78,8 +78,8 @@ func TestServer_WebPush_TopicAdd(t *testing.T) {
}
func TestServer_WebPush_TopicAdd_InvalidEndpoint(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfigWithWebPush(t, databaseURL))
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, "https://ddos-target.example.com/webpush"), nil)
require.Equal(t, 400, response.Code)
@@ -88,8 +88,8 @@ func TestServer_WebPush_TopicAdd_InvalidEndpoint(t *testing.T) {
}
func TestServer_WebPush_TopicAdd_TooManyTopics(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfigWithWebPush(t, databaseURL))
topicList := make([]string, 51)
for i := range topicList {
@@ -103,8 +103,8 @@ func TestServer_WebPush_TopicAdd_TooManyTopics(t *testing.T) {
}
func TestServer_WebPush_TopicUnsubscribe(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfigWithWebPush(t, databaseURL))
addSubscription(t, s, testWebPushEndpoint, "test-topic")
requireSubscriptionCount(t, s, "test-topic", 1)
@@ -118,8 +118,8 @@ func TestServer_WebPush_TopicUnsubscribe(t *testing.T) {
}
func TestServer_WebPush_Delete(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfigWithWebPush(t, databaseURL))
addSubscription(t, s, testWebPushEndpoint, "test-topic")
requireSubscriptionCount(t, s, "test-topic", 1)
@@ -133,8 +133,8 @@ func TestServer_WebPush_Delete(t *testing.T) {
}
func TestServer_WebPush_TopicSubscribeProtected_Allowed(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
config := configureAuth(t, newTestConfigWithWebPush(t))
forEachBackend(t, func(t *testing.T, databaseURL string) {
config := configureAuth(t, newTestConfigWithWebPush(t, databaseURL))
config.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, config)
@@ -155,8 +155,8 @@ func TestServer_WebPush_TopicSubscribeProtected_Allowed(t *testing.T) {
}
func TestServer_WebPush_TopicSubscribeProtected_Denied(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
config := configureAuth(t, newTestConfigWithWebPush(t))
forEachBackend(t, func(t *testing.T, databaseURL string) {
config := configureAuth(t, newTestConfigWithWebPush(t, databaseURL))
config.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, config)
@@ -168,8 +168,8 @@ func TestServer_WebPush_TopicSubscribeProtected_Denied(t *testing.T) {
}
func TestServer_WebPush_DeleteAccountUnsubscribe(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
config := configureAuth(t, newTestConfigWithWebPush(t))
forEachBackend(t, func(t *testing.T, databaseURL string) {
config := configureAuth(t, newTestConfigWithWebPush(t, databaseURL))
s := newTestServer(t, config)
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
@@ -193,8 +193,8 @@ func TestServer_WebPush_DeleteAccountUnsubscribe(t *testing.T) {
}
func TestServer_WebPush_Publish(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfigWithWebPush(t, databaseURL))
var received atomic.Bool
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -217,8 +217,8 @@ func TestServer_WebPush_Publish(t *testing.T) {
}
func TestServer_WebPush_Publish_RemoveOnError(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfigWithWebPush(t, databaseURL))
var received atomic.Bool
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -247,8 +247,8 @@ func TestServer_WebPush_Publish_RemoveOnError(t *testing.T) {
}
func TestServer_WebPush_Expiry(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
forEachBackend(t, func(t *testing.T, databaseURL string) {
s := newTestServer(t, newTestConfigWithWebPush(t, databaseURL))
var received atomic.Bool
@@ -307,8 +307,8 @@ func requireSubscriptionCount(t *testing.T, s *Server, topic string, expectedLen
require.Len(t, subs, expectedLength)
}
func newTestConfigWithWebPush(t *testing.T) *Config {
conf := newTestConfig(t)
func newTestConfigWithWebPush(t *testing.T, databaseURL string) *Config {
conf := newTestConfig(t, databaseURL)
privateKey, publicKey, err := webpush.GenerateVAPIDKeys()
require.Nil(t, err)
if conf.DatabaseURL == "" {

View File

@@ -34,6 +34,7 @@ var (
var (
onlySpacesRegex = regexp.MustCompile(`(?m)^\s+$`)
consecutiveNewLinesRegex = regexp.MustCompile(`\n{3,}`)
htmlLineBreakRegex = regexp.MustCompile(`(?i)<br\s*/?>`)
)
const (
@@ -159,7 +160,7 @@ func (s *smtpSession) Data(r io.Reader) error {
if len(body) > conf.MessageSizeLimit {
body = body[:conf.MessageSizeLimit]
}
m := newDefaultMessage(s.topic, body)
m := model.NewDefaultMessage(s.topic, body)
subject := strings.TrimSpace(msg.Header.Get("Subject"))
if subject != "" {
dec := mime.WordDecoder{}
@@ -328,6 +329,9 @@ func readHTMLMailBody(reader io.Reader, transferEncoding string) (string, error)
if err != nil {
return "", err
}
// Convert <br> tags to newlines before stripping HTML, so that line breaks
// in HTML emails (e.g. from Synology DSM, and other appliances) are preserved.
body = htmlLineBreakRegex.ReplaceAllString(body, "\n")
stripped := bluemonday.
StrictPolicy().
AddSpaceWhenStrippingTag(true).

View File

@@ -694,7 +694,8 @@ home automation setup
Now the light is on
If you don&#39;t want to receive this message anymore, stop the push
services in your FRITZ!Box .
services in your FRITZ!Box .
Here you can see the active push services: &#34;System &gt; Push Service&#34;.
This mail has ben sent by your FRITZ!Box automatically.`
@@ -1354,9 +1355,11 @@ Congratulations! You have successfully set up the email notification on Synology
s, c, conf, scanner := newTestSMTPServer(t, func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/synology", r.URL.Path)
require.Equal(t, "[Synology NAS] Test Message from Litts_NAS", r.Header.Get("Title"))
actual := readAll(t, r.Body)
expected := `Congratulations! You have successfully set up the email notification on Synology_NAS. For further system configurations, please visit http://192.168.1.28:5000/, http://172.16.60.5:5000/. (If you cannot connect to the server, please contact the administrator.) From Synology_NAS`
require.Equal(t, expected, actual)
expected := "Congratulations! You have successfully set up the email notification on Synology_NAS.\n" +
"For further system configurations, please visit http://192.168.1.28:5000/, http://172.16.60.5:5000/.\n" +
"(If you cannot connect to the server, please contact the administrator.)\n\n" +
"From Synology_NAS"
require.Equal(t, expected, readAll(t, r.Body))
})
conf.SMTPServerDomain = "mydomain.me"
conf.SMTPServerAddrPrefix = ""
@@ -1365,6 +1368,36 @@ Congratulations! You have successfully set up the email notification on Synology
writeAndReadUntilLine(t, email, c, scanner, "250 2.0.0 OK: queued")
}
func TestSmtpBackend_HTMLEmail_BrTagsPreserved(t *testing.T) {
email := `EHLO example.com
MAIL FROM: nas@example.com
RCPT TO: ntfy-alerts@ntfy.sh
DATA
Content-Type: text/html; charset=utf-8
Content-Transfer-Encoding: 8bit
Subject: Task Scheduler: daily-backup
Task Scheduler has completed a scheduled task.<BR><BR>Task: daily-backup<BR>Start time: Mon, 01 Jan 2026 02:00:00 +0000<BR>Stop time: Mon, 01 Jan 2024 02:03:00 +0000<BR>Current status: 0 (Normal)<BR>Standard output/error:<BR>OK<BR><BR>From MyNAS
.
`
s, c, _, scanner := newTestSMTPServer(t, func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/alerts", r.URL.Path)
require.Equal(t, "Task Scheduler: daily-backup", r.Header.Get("Title"))
expected := "Task Scheduler has completed a scheduled task.\n\n" +
"Task: daily-backup\n" +
"Start time: Mon, 01 Jan 2026 02:00:00 +0000\n" +
"Stop time: Mon, 01 Jan 2024 02:03:00 +0000\n" +
"Current status: 0 (Normal)\n" +
"Standard output/error:\n" +
"OK\n\n" +
"From MyNAS"
require.Equal(t, expected, readAll(t, r.Body))
})
defer s.Close()
defer c.Close()
writeAndReadUntilLine(t, email, c, scanner, "250 2.0.0 OK: queued")
}
func TestSmtpBackend_PlaintextWithToken(t *testing.T) {
email := `EHLO example.com
MAIL FROM: phil@example.com
@@ -1411,7 +1444,7 @@ what's up
type smtpHandlerFunc func(http.ResponseWriter, *http.Request)
func newTestSMTPServer(t *testing.T, handler smtpHandlerFunc) (s *smtp.Server, c net.Conn, conf *Config, scanner *bufio.Scanner) {
conf = newTestConfig(t)
conf = newTestConfig(t, "")
conf.SMTPServerListen = ":25"
conf.SMTPServerDomain = "ntfy.sh"
conf.SMTPServerAddrPrefix = "ntfy-"

View File

@@ -8,49 +8,6 @@ import (
"heckel.io/ntfy/v2/util"
)
// Event constants
const (
openEvent = model.OpenEvent
keepaliveEvent = model.KeepaliveEvent
messageEvent = model.MessageEvent
messageDeleteEvent = model.MessageDeleteEvent
messageClearEvent = model.MessageClearEvent
pollRequestEvent = model.PollRequestEvent
messageIDLength = model.MessageIDLength
)
// SinceMarker aliases
var (
sinceAllMessages = model.SinceAllMessages
sinceNoMessages = model.SinceNoMessages
sinceLatestMessage = model.SinceLatestMessage
)
// Error aliases
var (
errMessageNotFound = model.ErrMessageNotFound
)
// Constructors and helpers
var (
newMessage = model.NewMessage
newDefaultMessage = model.NewDefaultMessage
newOpenMessage = model.NewOpenMessage
newKeepaliveMessage = model.NewKeepaliveMessage
newActionMessage = model.NewActionMessage
newAction = model.NewAction
newSinceTime = model.NewSinceTime
newSinceID = model.NewSinceID
validMessageID = model.ValidMessageID
)
// newPollRequestMessage is a convenience method to create a poll request message
func newPollRequestMessage(topic, pollID string) *model.Message {
m := newMessage(pollRequestEvent, topic, newMessageBody)
m.PollID = pollID
return m
}
// publishMessage is used as input when publishing as JSON
type publishMessage struct {
Topic string `json:"topic"`
@@ -106,7 +63,7 @@ func parseQueryFilters(r *http.Request) (*queryFilter, error) {
}
func (q *queryFilter) Pass(msg *model.Message) bool {
if msg.Event != messageEvent && msg.Event != messageDeleteEvent && msg.Event != messageClearEvent {
if msg.Event != model.MessageEvent && msg.Event != model.MessageDeleteEvent && msg.Event != model.MessageClearEvent {
return true // filters only apply to messages
} else if q.ID != "" && msg.ID != q.ID {
return false

View File

@@ -54,7 +54,7 @@ const (
// visitor represents an API user, and its associated rate.Limiter used for rate limiting
type visitor struct {
config *Config
messageCache message.Store
messageCache *message.Cache
userManager *user.Manager // May be nil
ip netip.Addr // Visitor IP address
user *user.User // Only set if authenticated user, otherwise nil
@@ -115,7 +115,7 @@ const (
visitorLimitBasisTier = visitorLimitBasis("tier")
)
func newVisitor(conf *Config, messageCache message.Store, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor {
func newVisitor(conf *Config, messageCache *message.Cache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor {
var messages, emails, calls int64
if user != nil {
messages = user.Stats.Messages

35
tools/pgimport/README.md Normal file
View File

@@ -0,0 +1,35 @@
# pgimport
Migrates ntfy data from SQLite to PostgreSQL.
## Build
```bash
go build -o pgimport ./tools/pgimport/
```
## Usage
```bash
# Using CLI flags
pgimport \
--database-url "postgres://user:pass@host:5432/ntfy?sslmode=require" \
--cache-file /var/cache/ntfy/cache.db \
--auth-file /var/lib/ntfy/user.db \
--web-push-file /var/lib/ntfy/webpush.db
# Using server.yml (flags override config values)
pgimport --config /etc/ntfy/server.yml
```
## Prerequisites
- PostgreSQL schema must already be set up (run ntfy with `database-url` once)
- ntfy must not be running during the import
- All three SQLite files are optional; only the ones specified will be imported
## Notes
- The tool is idempotent and safe to re-run
- After importing, row counts and content are verified against the SQLite sources
- Invalid UTF-8 in messages is replaced with the Unicode replacement character

888
tools/pgimport/main.go Normal file
View File

@@ -0,0 +1,888 @@
package main
import (
"database/sql"
"fmt"
"net/url"
"os"
"strings"
_ "github.com/mattn/go-sqlite3"
"github.com/urfave/cli/v2"
"github.com/urfave/cli/v2/altsrc"
"gopkg.in/yaml.v2"
"heckel.io/ntfy/v2/db"
)
const (
batchSize = 1000
expectedMessageSchemaVersion = 14
expectedUserSchemaVersion = 6
expectedWebPushSchemaVersion = 1
)
var flags = []cli.Flag{
&cli.StringFlag{Name: "config", Aliases: []string{"c"}, Usage: "path to server.yml config file"},
altsrc.NewStringFlag(&cli.StringFlag{Name: "database-url", Aliases: []string{"database_url"}, Usage: "PostgreSQL connection string"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-file", Aliases: []string{"cache_file"}, Usage: "SQLite message cache file path"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "auth-file", Aliases: []string{"auth_file"}, Usage: "SQLite user/auth database file path"}),
altsrc.NewStringFlag(&cli.StringFlag{Name: "web-push-file", Aliases: []string{"web_push_file"}, Usage: "SQLite web push database file path"}),
}
func main() {
app := &cli.App{
Name: "pgimport",
Usage: "SQLite to PostgreSQL migration tool for ntfy",
UsageText: "pgimport [OPTIONS]",
Flags: flags,
Before: loadConfigFile("config", flags),
Action: execImport,
}
if err := app.Run(os.Args); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
}
func execImport(c *cli.Context) error {
databaseURL := c.String("database-url")
cacheFile := c.String("cache-file")
authFile := c.String("auth-file")
webPushFile := c.String("web-push-file")
if databaseURL == "" {
return fmt.Errorf("database-url must be set (via --database-url or config file)")
}
if cacheFile == "" && authFile == "" && webPushFile == "" {
return fmt.Errorf("at least one of --cache-file, --auth-file, or --web-push-file must be set")
}
fmt.Println("pgimport - SQLite to PostgreSQL migration tool for ntfy")
fmt.Println()
fmt.Println("Sources:")
printSource(" Cache file: ", cacheFile)
printSource(" Auth file: ", authFile)
printSource(" Web push file: ", webPushFile)
fmt.Println()
fmt.Println("Target:")
fmt.Printf(" Database URL: %s\n", maskPassword(databaseURL))
fmt.Println()
fmt.Println("This will import data from the SQLite databases into PostgreSQL.")
fmt.Print("Make sure ntfy is not running. Continue? (y/n): ")
var answer string
fmt.Scanln(&answer)
if strings.TrimSpace(strings.ToLower(answer)) != "y" {
fmt.Println("Aborted.")
return nil
}
fmt.Println()
pgDB, err := db.OpenPostgres(databaseURL)
if err != nil {
return fmt.Errorf("cannot connect to PostgreSQL: %w", err)
}
defer pgDB.Close()
if authFile != "" {
if err := verifySchemaVersion(pgDB, "user", expectedUserSchemaVersion); err != nil {
return err
}
if err := importUsers(authFile, pgDB); err != nil {
return fmt.Errorf("cannot import users: %w", err)
}
}
if cacheFile != "" {
if err := verifySchemaVersion(pgDB, "message", expectedMessageSchemaVersion); err != nil {
return err
}
if err := importMessages(cacheFile, pgDB); err != nil {
return fmt.Errorf("cannot import messages: %w", err)
}
}
if webPushFile != "" {
if err := verifySchemaVersion(pgDB, "webpush", expectedWebPushSchemaVersion); err != nil {
return err
}
if err := importWebPush(webPushFile, pgDB); err != nil {
return fmt.Errorf("cannot import web push subscriptions: %w", err)
}
}
fmt.Println()
fmt.Println("Verifying migration ...")
failed := false
if authFile != "" {
if err := verifyUsers(authFile, pgDB, &failed); err != nil {
return fmt.Errorf("cannot verify users: %w", err)
}
}
if cacheFile != "" {
if err := verifyMessages(cacheFile, pgDB, &failed); err != nil {
return fmt.Errorf("cannot verify messages: %w", err)
}
}
if webPushFile != "" {
if err := verifyWebPush(webPushFile, pgDB, &failed); err != nil {
return fmt.Errorf("cannot verify web push: %w", err)
}
}
fmt.Println()
if failed {
return fmt.Errorf("verification FAILED, see above for details")
}
fmt.Println("Verification successful. Migration complete.")
return nil
}
func loadConfigFile(configFlag string, flags []cli.Flag) cli.BeforeFunc {
return func(c *cli.Context) error {
configFile := c.String(configFlag)
if configFile == "" {
return nil
}
if _, err := os.Stat(configFile); os.IsNotExist(err) {
return fmt.Errorf("config file %s does not exist", configFile)
}
inputSource, err := newYamlSourceFromFile(configFile, flags)
if err != nil {
return err
}
return altsrc.ApplyInputSourceValues(c, inputSource, flags)
}
}
func newYamlSourceFromFile(file string, flags []cli.Flag) (altsrc.InputSourceContext, error) {
var rawConfig map[any]any
b, err := os.ReadFile(file)
if err != nil {
return nil, err
}
if err := yaml.Unmarshal(b, &rawConfig); err != nil {
return nil, err
}
for _, f := range flags {
flagName := f.Names()[0]
for _, flagAlias := range f.Names()[1:] {
if _, ok := rawConfig[flagAlias]; ok {
rawConfig[flagName] = rawConfig[flagAlias]
}
}
}
return altsrc.NewMapInputSource(file, rawConfig), nil
}
func verifySchemaVersion(pgDB *sql.DB, store string, expected int) error {
var version int
err := pgDB.QueryRow(`SELECT version FROM schema_version WHERE store = $1`, store).Scan(&version)
if err != nil {
return fmt.Errorf("cannot read %s schema version from PostgreSQL (is the schema set up?): %w", store, err)
}
if version != expected {
return fmt.Errorf("%s schema version mismatch: expected %d, got %d", store, expected, version)
}
return nil
}
func printSource(label, path string) {
if path == "" {
fmt.Printf("%s(not set, skipping)\n", label)
} else if _, err := os.Stat(path); os.IsNotExist(err) {
fmt.Printf("%s%s (NOT FOUND, skipping)\n", label, path)
} else {
fmt.Printf("%s%s\n", label, path)
}
}
func maskPassword(databaseURL string) string {
u, err := url.Parse(databaseURL)
if err != nil {
return databaseURL
}
if u.User != nil {
if _, hasPass := u.User.Password(); hasPass {
masked := u.Scheme + "://" + u.User.Username() + ":****@" + u.Host + u.Path
if u.RawQuery != "" {
masked += "?" + u.RawQuery
}
return masked
}
}
return u.String()
}
func openSQLite(filename string) (*sql.DB, error) {
if _, err := os.Stat(filename); os.IsNotExist(err) {
return nil, fmt.Errorf("file %s does not exist", filename)
}
return sql.Open("sqlite3", filename+"?mode=ro")
}
// User import
func importUsers(sqliteFile string, pgDB *sql.DB) error {
sqlDB, err := openSQLite(sqliteFile)
if err != nil {
fmt.Printf("Skipping user import: %s\n", err)
return nil
}
defer sqlDB.Close()
fmt.Printf("Importing users from %s ...\n", sqliteFile)
count, err := importTiers(sqlDB, pgDB)
if err != nil {
return fmt.Errorf("importing tiers: %w", err)
}
fmt.Printf(" Imported %d tiers\n", count)
count, err = importUserRows(sqlDB, pgDB)
if err != nil {
return fmt.Errorf("importing users: %w", err)
}
fmt.Printf(" Imported %d users\n", count)
count, err = importUserAccess(sqlDB, pgDB)
if err != nil {
return fmt.Errorf("importing user access: %w", err)
}
fmt.Printf(" Imported %d access entries\n", count)
count, err = importUserTokens(sqlDB, pgDB)
if err != nil {
return fmt.Errorf("importing user tokens: %w", err)
}
fmt.Printf(" Imported %d tokens\n", count)
count, err = importUserPhones(sqlDB, pgDB)
if err != nil {
return fmt.Errorf("importing user phones: %w", err)
}
fmt.Printf(" Imported %d phone numbers\n", count)
return nil
}
func importTiers(sqlDB, pgDB *sql.DB) (int, error) {
rows, err := sqlDB.Query(`SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id FROM tier`)
if err != nil {
return 0, err
}
defer rows.Close()
tx, err := pgDB.Begin()
if err != nil {
return 0, err
}
defer tx.Rollback()
stmt, err := tx.Prepare(`INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) ON CONFLICT (id) DO NOTHING`)
if err != nil {
return 0, err
}
defer stmt.Close()
count := 0
for rows.Next() {
var id, code, name string
var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit int64
var attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit int64
var stripeMonthlyPriceID, stripeYearlyPriceID sql.NullString
if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
return 0, err
}
if _, err := stmt.Exec(id, code, name, messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeMonthlyPriceID, stripeYearlyPriceID); err != nil {
return 0, err
}
count++
}
return count, tx.Commit()
}
func importUserRows(sqlDB, pgDB *sql.DB) (int, error) {
rows, err := sqlDB.Query(`SELECT id, user, pass, role, prefs, sync_topic, provisioned, stats_messages, stats_emails, stats_calls, stripe_customer_id, stripe_subscription_id, stripe_subscription_status, stripe_subscription_interval, stripe_subscription_paid_until, stripe_subscription_cancel_at, created, deleted, tier_id FROM user`)
if err != nil {
return 0, err
}
defer rows.Close()
tx, err := pgDB.Begin()
if err != nil {
return 0, err
}
defer tx.Rollback()
stmt, err := tx.Prepare(`
INSERT INTO "user" (id, user_name, pass, role, prefs, sync_topic, provisioned, stats_messages, stats_emails, stats_calls, stripe_customer_id, stripe_subscription_id, stripe_subscription_status, stripe_subscription_interval, stripe_subscription_paid_until, stripe_subscription_cancel_at, created, deleted, tier_id)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19)
ON CONFLICT (id) DO NOTHING
`)
if err != nil {
return 0, err
}
defer stmt.Close()
count := 0
for rows.Next() {
var id, userName, pass, role, prefs, syncTopic string
var provisioned int
var statsMessages, statsEmails, statsCalls int64
var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripeSubscriptionInterval sql.NullString
var stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt sql.NullInt64
var created int64
var deleted sql.NullInt64
var tierID sql.NullString
if err := rows.Scan(&id, &userName, &pass, &role, &prefs, &syncTopic, &provisioned, &statsMessages, &statsEmails, &statsCalls, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionInterval, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &created, &deleted, &tierID); err != nil {
return 0, err
}
provisionedBool := provisioned != 0
if _, err := stmt.Exec(id, userName, pass, role, prefs, syncTopic, provisionedBool, statsMessages, statsEmails, statsCalls, stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripeSubscriptionInterval, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, created, deleted, tierID); err != nil {
return 0, err
}
count++
}
return count, tx.Commit()
}
func importUserAccess(sqlDB, pgDB *sql.DB) (int, error) {
rows, err := sqlDB.Query(`SELECT a.user_id, a.topic, a.read, a.write, a.owner_user_id, a.provisioned FROM user_access a JOIN user u ON u.id = a.user_id`)
if err != nil {
return 0, err
}
defer rows.Close()
tx, err := pgDB.Begin()
if err != nil {
return 0, err
}
defer tx.Rollback()
stmt, err := tx.Prepare(`INSERT INTO user_access (user_id, topic, read, write, owner_user_id, provisioned) VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (user_id, topic) DO NOTHING`)
if err != nil {
return 0, err
}
defer stmt.Close()
count := 0
for rows.Next() {
var userID, topic string
var read, write, provisioned int
var ownerUserID sql.NullString
if err := rows.Scan(&userID, &topic, &read, &write, &ownerUserID, &provisioned); err != nil {
return 0, err
}
readBool := read != 0
writeBool := write != 0
provisionedBool := provisioned != 0
if _, err := stmt.Exec(userID, topic, readBool, writeBool, ownerUserID, provisionedBool); err != nil {
return 0, err
}
count++
}
return count, tx.Commit()
}
func importUserTokens(sqlDB, pgDB *sql.DB) (int, error) {
rows, err := sqlDB.Query(`SELECT t.user_id, t.token, t.label, t.last_access, t.last_origin, t.expires, t.provisioned FROM user_token t JOIN user u ON u.id = t.user_id`)
if err != nil {
return 0, err
}
defer rows.Close()
tx, err := pgDB.Begin()
if err != nil {
return 0, err
}
defer tx.Rollback()
stmt, err := tx.Prepare(`INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires, provisioned) VALUES ($1, $2, $3, $4, $5, $6, $7) ON CONFLICT (user_id, token) DO NOTHING`)
if err != nil {
return 0, err
}
defer stmt.Close()
count := 0
for rows.Next() {
var userID, token, label, lastOrigin string
var lastAccess, expires int64
var provisioned int
if err := rows.Scan(&userID, &token, &label, &lastAccess, &lastOrigin, &expires, &provisioned); err != nil {
return 0, err
}
provisionedBool := provisioned != 0
if _, err := stmt.Exec(userID, token, label, lastAccess, lastOrigin, expires, provisionedBool); err != nil {
return 0, err
}
count++
}
return count, tx.Commit()
}
func importUserPhones(sqlDB, pgDB *sql.DB) (int, error) {
rows, err := sqlDB.Query(`SELECT p.user_id, p.phone_number FROM user_phone p JOIN user u ON u.id = p.user_id`)
if err != nil {
return 0, err
}
defer rows.Close()
tx, err := pgDB.Begin()
if err != nil {
return 0, err
}
defer tx.Rollback()
stmt, err := tx.Prepare(`INSERT INTO user_phone (user_id, phone_number) VALUES ($1, $2) ON CONFLICT (user_id, phone_number) DO NOTHING`)
if err != nil {
return 0, err
}
defer stmt.Close()
count := 0
for rows.Next() {
var userID, phoneNumber string
if err := rows.Scan(&userID, &phoneNumber); err != nil {
return 0, err
}
if _, err := stmt.Exec(userID, phoneNumber); err != nil {
return 0, err
}
count++
}
return count, tx.Commit()
}
// Message import
func importMessages(sqliteFile string, pgDB *sql.DB) error {
sqlDB, err := openSQLite(sqliteFile)
if err != nil {
fmt.Printf("Skipping message import: %s\n", err)
return nil
}
defer sqlDB.Close()
fmt.Printf("Importing messages from %s ...\n", sqliteFile)
rows, err := sqlDB.Query(`SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user, content_type, encoding, published FROM messages`)
if err != nil {
return fmt.Errorf("querying messages: %w", err)
}
defer rows.Close()
if _, err := pgDB.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS idx_message_mid_unique ON message (mid)`); err != nil {
return fmt.Errorf("creating unique index on mid: %w", err)
}
insertQuery := `INSERT INTO message (mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user_id, content_type, encoding, published) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24) ON CONFLICT (mid) DO NOTHING`
count := 0
batchCount := 0
tx, err := pgDB.Begin()
if err != nil {
return err
}
defer tx.Rollback()
stmt, err := tx.Prepare(insertQuery)
if err != nil {
return err
}
defer stmt.Close()
for rows.Next() {
var mid, sequenceID, event, topic, message, title, tags, click, icon, actions string
var attachmentName, attachmentType, attachmentURL, sender, userID, contentType, encoding string
var msgTime, expires, attachmentExpires int64
var priority int
var attachmentSize int64
var attachmentDeleted, published int
if err := rows.Scan(&mid, &sequenceID, &msgTime, &event, &expires, &topic, &message, &title, &priority, &tags, &click, &icon, &actions, &attachmentName, &attachmentType, &attachmentSize, &attachmentExpires, &attachmentURL, &attachmentDeleted, &sender, &userID, &contentType, &encoding, &published); err != nil {
return fmt.Errorf("scanning message: %w", err)
}
mid = toUTF8(mid)
sequenceID = toUTF8(sequenceID)
event = toUTF8(event)
topic = toUTF8(topic)
message = toUTF8(message)
title = toUTF8(title)
tags = toUTF8(tags)
click = toUTF8(click)
icon = toUTF8(icon)
actions = toUTF8(actions)
attachmentName = toUTF8(attachmentName)
attachmentType = toUTF8(attachmentType)
attachmentURL = toUTF8(attachmentURL)
sender = toUTF8(sender)
userID = toUTF8(userID)
contentType = toUTF8(contentType)
encoding = toUTF8(encoding)
attachmentDeletedBool := attachmentDeleted != 0
publishedBool := published != 0
if _, err := stmt.Exec(mid, sequenceID, msgTime, event, expires, topic, message, title, priority, tags, click, icon, actions, attachmentName, attachmentType, attachmentSize, attachmentExpires, attachmentURL, attachmentDeletedBool, sender, userID, contentType, encoding, publishedBool); err != nil {
return fmt.Errorf("inserting message: %w", err)
}
count++
batchCount++
if batchCount >= batchSize {
stmt.Close()
if err := tx.Commit(); err != nil {
return fmt.Errorf("committing message batch: %w", err)
}
fmt.Printf(" ... %d messages\n", count)
tx, err = pgDB.Begin()
if err != nil {
return err
}
stmt, err = tx.Prepare(insertQuery)
if err != nil {
return err
}
batchCount = 0
}
}
if batchCount > 0 {
stmt.Close()
if err := tx.Commit(); err != nil {
return fmt.Errorf("committing final message batch: %w", err)
}
}
fmt.Printf(" Imported %d messages\n", count)
var statsValue int64
err = sqlDB.QueryRow(`SELECT value FROM stats WHERE key = 'messages'`).Scan(&statsValue)
if err == nil {
if _, err := pgDB.Exec(`UPDATE message_stats SET value = $1 WHERE key = 'messages'`, statsValue); err != nil {
return fmt.Errorf("updating message stats: %w", err)
}
fmt.Printf(" Updated message stats (count: %d)\n", statsValue)
}
return nil
}
// Web push import
func importWebPush(sqliteFile string, pgDB *sql.DB) error {
sqlDB, err := openSQLite(sqliteFile)
if err != nil {
fmt.Printf("Skipping web push import: %s\n", err)
return nil
}
defer sqlDB.Close()
fmt.Printf("Importing web push subscriptions from %s ...\n", sqliteFile)
rows, err := sqlDB.Query(`SELECT id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at FROM subscription`)
if err != nil {
return fmt.Errorf("querying subscriptions: %w", err)
}
defer rows.Close()
tx, err := pgDB.Begin()
if err != nil {
return err
}
defer tx.Rollback()
stmt, err := tx.Prepare(`INSERT INTO webpush_subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) ON CONFLICT (id) DO NOTHING`)
if err != nil {
return err
}
defer stmt.Close()
count := 0
for rows.Next() {
var id, endpoint, keyAuth, keyP256dh, userID, subscriberIP string
var updatedAt, warnedAt int64
if err := rows.Scan(&id, &endpoint, &keyAuth, &keyP256dh, &userID, &subscriberIP, &updatedAt, &warnedAt); err != nil {
return fmt.Errorf("scanning subscription: %w", err)
}
if _, err := stmt.Exec(id, endpoint, keyAuth, keyP256dh, userID, subscriberIP, updatedAt, warnedAt); err != nil {
return fmt.Errorf("inserting subscription: %w", err)
}
count++
}
stmt.Close()
if err := tx.Commit(); err != nil {
return fmt.Errorf("committing subscriptions: %w", err)
}
fmt.Printf(" Imported %d subscriptions\n", count)
topicRows, err := sqlDB.Query(`SELECT subscription_id, topic FROM subscription_topic`)
if err != nil {
return fmt.Errorf("querying subscription topics: %w", err)
}
defer topicRows.Close()
tx, err = pgDB.Begin()
if err != nil {
return err
}
defer tx.Rollback()
stmt, err = tx.Prepare(`INSERT INTO webpush_subscription_topic (subscription_id, topic) VALUES ($1, $2) ON CONFLICT (subscription_id, topic) DO NOTHING`)
if err != nil {
return err
}
defer stmt.Close()
topicCount := 0
for topicRows.Next() {
var subscriptionID, topic string
if err := topicRows.Scan(&subscriptionID, &topic); err != nil {
return fmt.Errorf("scanning subscription topic: %w", err)
}
if _, err := stmt.Exec(subscriptionID, topic); err != nil {
return fmt.Errorf("inserting subscription topic: %w", err)
}
topicCount++
}
stmt.Close()
if err := tx.Commit(); err != nil {
return fmt.Errorf("committing subscription topics: %w", err)
}
fmt.Printf(" Imported %d subscription topics\n", topicCount)
return nil
}
func toUTF8(s string) string {
return strings.ToValidUTF8(s, "\uFFFD")
}
// Verification
func verifyUsers(sqliteFile string, pgDB *sql.DB, failed *bool) error {
sqlDB, err := openSQLite(sqliteFile)
if err != nil {
return nil
}
defer sqlDB.Close()
verifyCount(sqlDB, pgDB, "tier", `SELECT COUNT(*) FROM tier`, `SELECT COUNT(*) FROM tier`, failed)
verifyContent(sqlDB, pgDB, "tier",
`SELECT id, code, name FROM tier ORDER BY id`,
`SELECT id, code, name FROM tier ORDER BY id COLLATE "C"`,
failed)
verifyCount(sqlDB, pgDB, "user", `SELECT COUNT(*) FROM user`, `SELECT COUNT(*) FROM "user"`, failed)
verifyContent(sqlDB, pgDB, "user",
`SELECT id, user, role, sync_topic FROM user ORDER BY id`,
`SELECT id, user_name, role, sync_topic FROM "user" ORDER BY id COLLATE "C"`,
failed)
verifyCount(sqlDB, pgDB, "user_access", `SELECT COUNT(*) FROM user_access a JOIN user u ON u.id = a.user_id`, `SELECT COUNT(*) FROM user_access`, failed)
verifyContent(sqlDB, pgDB, "user_access",
`SELECT a.user_id, a.topic FROM user_access a JOIN user u ON u.id = a.user_id ORDER BY a.user_id, a.topic`,
`SELECT user_id, topic FROM user_access ORDER BY user_id COLLATE "C", topic COLLATE "C"`,
failed)
verifyCount(sqlDB, pgDB, "user_token", `SELECT COUNT(*) FROM user_token t JOIN user u ON u.id = t.user_id`, `SELECT COUNT(*) FROM user_token`, failed)
verifyContent(sqlDB, pgDB, "user_token",
`SELECT t.user_id, t.token, t.label FROM user_token t JOIN user u ON u.id = t.user_id ORDER BY t.user_id, t.token`,
`SELECT user_id, token, label FROM user_token ORDER BY user_id COLLATE "C", token COLLATE "C"`,
failed)
verifyCount(sqlDB, pgDB, "user_phone", `SELECT COUNT(*) FROM user_phone p JOIN user u ON u.id = p.user_id`, `SELECT COUNT(*) FROM user_phone`, failed)
verifyContent(sqlDB, pgDB, "user_phone",
`SELECT p.user_id, p.phone_number FROM user_phone p JOIN user u ON u.id = p.user_id ORDER BY p.user_id, p.phone_number`,
`SELECT user_id, phone_number FROM user_phone ORDER BY user_id COLLATE "C", phone_number COLLATE "C"`,
failed)
return nil
}
func verifyMessages(sqliteFile string, pgDB *sql.DB, failed *bool) error {
sqlDB, err := openSQLite(sqliteFile)
if err != nil {
return nil
}
defer sqlDB.Close()
verifyCount(sqlDB, pgDB, "messages", `SELECT COUNT(*) FROM messages`, `SELECT COUNT(*) FROM message`, failed)
verifySampledMessages(sqlDB, pgDB, failed)
return nil
}
func verifyWebPush(sqliteFile string, pgDB *sql.DB, failed *bool) error {
sqlDB, err := openSQLite(sqliteFile)
if err != nil {
return nil
}
defer sqlDB.Close()
verifyCount(sqlDB, pgDB, "subscription", `SELECT COUNT(*) FROM subscription`, `SELECT COUNT(*) FROM webpush_subscription`, failed)
verifyContent(sqlDB, pgDB, "subscription",
`SELECT id, endpoint, key_auth, key_p256dh, user_id FROM subscription ORDER BY id`,
`SELECT id, endpoint, key_auth, key_p256dh, user_id FROM webpush_subscription ORDER BY id COLLATE "C"`,
failed)
verifyCount(sqlDB, pgDB, "subscription_topic", `SELECT COUNT(*) FROM subscription_topic`, `SELECT COUNT(*) FROM webpush_subscription_topic`, failed)
verifyContent(sqlDB, pgDB, "subscription_topic",
`SELECT subscription_id, topic FROM subscription_topic ORDER BY subscription_id, topic`,
`SELECT subscription_id, topic FROM webpush_subscription_topic ORDER BY subscription_id COLLATE "C", topic COLLATE "C"`,
failed)
return nil
}
func verifyCount(sqlDB, pgDB *sql.DB, table, sqliteQuery, pgQuery string, failed *bool) {
var sqliteCount, pgCount int64
if err := sqlDB.QueryRow(sqliteQuery).Scan(&sqliteCount); err != nil {
fmt.Printf(" %-25s count ERROR reading SQLite: %s\n", table, err)
*failed = true
return
}
if err := pgDB.QueryRow(pgQuery).Scan(&pgCount); err != nil {
fmt.Printf(" %-25s count ERROR reading PostgreSQL: %s\n", table, err)
*failed = true
return
}
if sqliteCount == pgCount {
fmt.Printf(" %-25s count OK (%d rows)\n", table, pgCount)
} else {
fmt.Printf(" %-25s count MISMATCH: SQLite=%d, PostgreSQL=%d\n", table, sqliteCount, pgCount)
*failed = true
}
}
func verifyContent(sqlDB, pgDB *sql.DB, table, sqliteQuery, pgQuery string, failed *bool) {
sqliteRows, err := sqlDB.Query(sqliteQuery)
if err != nil {
fmt.Printf(" %-25s content ERROR reading SQLite: %s\n", table, err)
*failed = true
return
}
defer sqliteRows.Close()
pgRows, err := pgDB.Query(pgQuery)
if err != nil {
fmt.Printf(" %-25s content ERROR reading PostgreSQL: %s\n", table, err)
*failed = true
return
}
defer pgRows.Close()
cols, err := sqliteRows.Columns()
if err != nil {
fmt.Printf(" %-25s content ERROR reading columns: %s\n", table, err)
*failed = true
return
}
numCols := len(cols)
rowNum := 0
mismatches := 0
for sqliteRows.Next() {
rowNum++
if !pgRows.Next() {
fmt.Printf(" %-25s content MISMATCH: PostgreSQL has fewer rows (at row %d)\n", table, rowNum)
*failed = true
return
}
sqliteVals := makeStringSlice(numCols)
pgVals := makeStringSlice(numCols)
if err := sqliteRows.Scan(sqliteVals...); err != nil {
fmt.Printf(" %-25s content ERROR scanning SQLite row %d: %s\n", table, rowNum, err)
*failed = true
return
}
if err := pgRows.Scan(pgVals...); err != nil {
fmt.Printf(" %-25s content ERROR scanning PostgreSQL row %d: %s\n", table, rowNum, err)
*failed = true
return
}
for i := 0; i < numCols; i++ {
sv := *(sqliteVals[i].(*sql.NullString))
pv := *(pgVals[i].(*sql.NullString))
if sv != pv {
mismatches++
if mismatches <= 3 {
fmt.Printf(" %-25s content MISMATCH at row %d, col %s: SQLite=%q, PostgreSQL=%q\n", table, rowNum, cols[i], sv.String, pv.String)
}
}
}
}
if pgRows.Next() {
fmt.Printf(" %-25s content MISMATCH: PostgreSQL has more rows than SQLite\n", table)
*failed = true
return
}
if mismatches > 0 {
if mismatches > 3 {
fmt.Printf(" %-25s content ... and %d more mismatches\n", table, mismatches-3)
}
*failed = true
} else {
fmt.Printf(" %-25s content OK\n", table)
}
}
func verifySampledMessages(sqlDB, pgDB *sql.DB, failed *bool) {
rows, err := sqlDB.Query(`SELECT mid, topic, time, message, title, tags, priority FROM messages ORDER BY mid`)
if err != nil {
fmt.Printf(" %-25s content ERROR reading SQLite: %s\n", "messages (sampled)", err)
*failed = true
return
}
defer rows.Close()
rowNum := 0
checked := 0
mismatches := 0
for rows.Next() {
rowNum++
var mid, topic, message, title, tags string
var msgTime int64
var priority int
if err := rows.Scan(&mid, &topic, &msgTime, &message, &title, &tags, &priority); err != nil {
fmt.Printf(" %-25s content ERROR scanning SQLite row %d: %s\n", "messages (sampled)", rowNum, err)
*failed = true
return
}
if rowNum%100 != 1 {
continue
}
checked++
var pgTopic, pgMessage, pgTitle, pgTags string
var pgTime int64
var pgPriority int
err := pgDB.QueryRow(`SELECT topic, time, message, title, tags, priority FROM message WHERE mid = $1`, mid).
Scan(&pgTopic, &pgTime, &pgMessage, &pgTitle, &pgTags, &pgPriority)
if err == sql.ErrNoRows {
mismatches++
if mismatches <= 3 {
fmt.Printf(" %-25s content MISMATCH: mid=%s not found in PostgreSQL\n", "messages (sampled)", mid)
}
continue
} else if err != nil {
fmt.Printf(" %-25s content ERROR querying PostgreSQL for mid=%s: %s\n", "messages (sampled)", mid, err)
*failed = true
return
}
topic = toUTF8(topic)
message = toUTF8(message)
title = toUTF8(title)
tags = toUTF8(tags)
if topic != pgTopic || msgTime != pgTime || message != pgMessage || title != pgTitle || tags != pgTags || priority != pgPriority {
mismatches++
if mismatches <= 3 {
fmt.Printf(" %-25s content MISMATCH at mid=%s\n", "messages (sampled)", mid)
}
}
}
if mismatches > 0 {
if mismatches > 3 {
fmt.Printf(" %-25s content ... and %d more mismatches\n", "messages (sampled)", mismatches-3)
}
*failed = true
} else {
fmt.Printf(" %-25s content OK (%d samples checked)\n", "messages (sampled)", checked)
}
}
func makeStringSlice(n int) []any {
vals := make([]any, n)
for i := range vals {
vals[i] = &sql.NullString{}
}
return vals
}

File diff suppressed because it is too large Load Diff

270
user/manager_postgres.go Normal file
View File

@@ -0,0 +1,270 @@
package user
import (
"database/sql"
)
// PostgreSQL queries
const (
// User queries
postgresSelectUserByIDQuery = `
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM "user" u
LEFT JOIN tier t on t.id = u.tier_id
WHERE u.id = $1
`
postgresSelectUserByNameQuery = `
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM "user" u
LEFT JOIN tier t on t.id = u.tier_id
WHERE user_name = $1
`
postgresSelectUserByTokenQuery = `
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM "user" u
JOIN user_token tk on u.id = tk.user_id
LEFT JOIN tier t on t.id = u.tier_id
WHERE tk.token = $1 AND (tk.expires = 0 OR tk.expires >= $2)
`
postgresSelectUserByStripeCustomerIDQuery = `
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM "user" u
LEFT JOIN tier t on t.id = u.tier_id
WHERE u.stripe_customer_id = $1
`
postgresSelectUsernamesQuery = `
SELECT user_name
FROM "user"
ORDER BY
CASE role
WHEN 'admin' THEN 1
WHEN 'anonymous' THEN 3
ELSE 2
END, user_name
`
postgresSelectUserCountQuery = `SELECT COUNT(*) FROM "user"`
postgresSelectUserIDFromUsernameQuery = `SELECT id FROM "user" WHERE user_name = $1`
postgresInsertUserQuery = `INSERT INTO "user" (id, user_name, pass, role, sync_topic, provisioned, created) VALUES ($1, $2, $3, $4, $5, $6, $7)`
postgresUpdateUserPassQuery = `UPDATE "user" SET pass = $1 WHERE user_name = $2`
postgresUpdateUserRoleQuery = `UPDATE "user" SET role = $1 WHERE user_name = $2`
postgresUpdateUserProvisionedQuery = `UPDATE "user" SET provisioned = $1 WHERE user_name = $2`
postgresUpdateUserPrefsQuery = `UPDATE "user" SET prefs = $1 WHERE id = $2`
postgresUpdateUserStatsQuery = `UPDATE "user" SET stats_messages = $1, stats_emails = $2, stats_calls = $3 WHERE id = $4`
postgresUpdateUserStatsResetAllQuery = `UPDATE "user" SET stats_messages = 0, stats_emails = 0, stats_calls = 0`
postgresUpdateUserTierQuery = `UPDATE "user" SET tier_id = (SELECT id FROM tier WHERE code = $1) WHERE user_name = $2`
postgresUpdateUserDeletedQuery = `UPDATE "user" SET deleted = $1 WHERE id = $2`
postgresDeleteUserQuery = `DELETE FROM "user" WHERE user_name = $1`
postgresDeleteUserTierQuery = `UPDATE "user" SET tier_id = null WHERE user_name = $1`
postgresDeleteUsersMarkedQuery = `DELETE FROM "user" WHERE deleted < $1`
// Access queries
postgresSelectTopicPermsQuery = `
SELECT read, write
FROM user_access a
JOIN "user" u ON u.id = a.user_id
WHERE (u.user_name = $1 OR u.user_name = $2) AND $3 LIKE a.topic ESCAPE '\'
ORDER BY u.user_name DESC, LENGTH(a.topic) DESC, CASE WHEN a.write THEN 1 ELSE 0 END DESC
`
postgresSelectUserAllAccessQuery = `
SELECT user_id, topic, read, write, provisioned
FROM user_access
ORDER BY LENGTH(topic) DESC, CASE WHEN write THEN 1 ELSE 0 END DESC, CASE WHEN read THEN 1 ELSE 0 END DESC, topic
`
postgresSelectUserAccessQuery = `
SELECT topic, read, write, provisioned
FROM user_access
WHERE user_id = (SELECT id FROM "user" WHERE user_name = $1)
ORDER BY LENGTH(topic) DESC, CASE WHEN write THEN 1 ELSE 0 END DESC, CASE WHEN read THEN 1 ELSE 0 END DESC, topic
`
postgresSelectUserReservationsQuery = `
SELECT a_user.topic, a_user.read, a_user.write, a_everyone.read AS everyone_read, a_everyone.write AS everyone_write
FROM user_access a_user
LEFT JOIN user_access a_everyone ON a_user.topic = a_everyone.topic AND a_everyone.user_id = (SELECT id FROM "user" WHERE user_name = $1)
WHERE a_user.user_id = a_user.owner_user_id
AND a_user.owner_user_id = (SELECT id FROM "user" WHERE user_name = $2)
ORDER BY a_user.topic
`
postgresSelectUserReservationsCountQuery = `
SELECT COUNT(*)
FROM user_access
WHERE user_id = owner_user_id
AND owner_user_id = (SELECT id FROM "user" WHERE user_name = $1)
`
postgresSelectUserReservationsOwnerQuery = `
SELECT owner_user_id
FROM user_access
WHERE topic = $1
AND user_id = owner_user_id
`
postgresSelectUserHasReservationQuery = `
SELECT COUNT(*)
FROM user_access
WHERE user_id = owner_user_id
AND owner_user_id = (SELECT id FROM "user" WHERE user_name = $1)
AND topic = $2
`
postgresSelectOtherAccessCountQuery = `
SELECT COUNT(*)
FROM user_access
WHERE (topic = $1 OR $2 LIKE topic ESCAPE '\')
AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM "user" WHERE user_name = $3))
`
postgresUpsertUserAccessQuery = `
INSERT INTO user_access (user_id, topic, read, write, owner_user_id, provisioned)
VALUES (
(SELECT id FROM "user" WHERE user_name = $1),
$2,
$3,
$4,
CASE WHEN $5 = '' THEN NULL ELSE (SELECT id FROM "user" WHERE user_name = $6) END,
$7
)
ON CONFLICT (user_id, topic)
DO UPDATE SET read=excluded.read, write=excluded.write, owner_user_id=excluded.owner_user_id, provisioned=excluded.provisioned
`
postgresDeleteUserAccessQuery = `
DELETE FROM user_access
WHERE user_id = (SELECT id FROM "user" WHERE user_name = $1)
OR owner_user_id = (SELECT id FROM "user" WHERE user_name = $2)
`
postgresDeleteUserAccessProvisionedQuery = `DELETE FROM user_access WHERE provisioned = true`
postgresDeleteTopicAccessQuery = `
DELETE FROM user_access
WHERE (user_id = (SELECT id FROM "user" WHERE user_name = $1) OR owner_user_id = (SELECT id FROM "user" WHERE user_name = $2))
AND topic = $3
`
postgresDeleteAllAccessQuery = `DELETE FROM user_access`
// Token queries
postgresSelectTokenQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = $1 AND token = $2`
postgresSelectTokensQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = $1`
postgresSelectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = $1`
postgresSelectAllProvisionedTokensQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE provisioned = true`
postgresUpsertTokenQuery = `
INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires, provisioned)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT (user_id, token)
DO UPDATE SET label = excluded.label, expires = excluded.expires, provisioned = excluded.provisioned
`
postgresUpdateTokenQuery = `UPDATE user_token SET label = $1, expires = $2 WHERE user_id = $3 AND token = $4`
postgresUpdateTokenLastAccessQuery = `UPDATE user_token SET last_access = $1, last_origin = $2 WHERE token = $3`
postgresDeleteTokenQuery = `DELETE FROM user_token WHERE user_id = $1 AND token = $2`
postgresDeleteProvisionedTokenQuery = `DELETE FROM user_token WHERE token = $1`
postgresDeleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = $1`
postgresDeleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < $1`
postgresDeleteExcessTokensQuery = `
DELETE FROM user_token
WHERE user_id = $1
AND (user_id, token) NOT IN (
SELECT user_id, token
FROM user_token
WHERE user_id = $2
ORDER BY expires DESC
LIMIT $3
)
`
// Tier queries
postgresInsertTierQuery = `
INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
`
postgresUpdateTierQuery = `
UPDATE tier
SET name = $1, messages_limit = $2, messages_expiry_duration = $3, emails_limit = $4, calls_limit = $5, reservations_limit = $6, attachment_file_size_limit = $7, attachment_total_size_limit = $8, attachment_expiry_duration = $9, attachment_bandwidth_limit = $10, stripe_monthly_price_id = $11, stripe_yearly_price_id = $12
WHERE code = $13
`
postgresSelectTiersQuery = `
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
FROM tier
`
postgresSelectTierByCodeQuery = `
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
FROM tier
WHERE code = $1
`
postgresSelectTierByPriceIDQuery = `
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
FROM tier
WHERE (stripe_monthly_price_id = $1 OR stripe_yearly_price_id = $2)
`
postgresDeleteTierQuery = `DELETE FROM tier WHERE code = $1`
// Phone queries
postgresSelectPhoneNumbersQuery = `SELECT phone_number FROM user_phone WHERE user_id = $1`
postgresInsertPhoneNumberQuery = `INSERT INTO user_phone (user_id, phone_number) VALUES ($1, $2)`
postgresDeletePhoneNumberQuery = `DELETE FROM user_phone WHERE user_id = $1 AND phone_number = $2`
// Billing queries
postgresUpdateBillingQuery = `
UPDATE "user"
SET stripe_customer_id = $1, stripe_subscription_id = $2, stripe_subscription_status = $3, stripe_subscription_interval = $4, stripe_subscription_paid_until = $5, stripe_subscription_cancel_at = $6
WHERE user_name = $7
`
)
// NewPostgresManager creates a new Manager backed by a PostgreSQL database using an existing connection pool.
var postgresQueries = queries{
selectUserByID: postgresSelectUserByIDQuery,
selectUserByName: postgresSelectUserByNameQuery,
selectUserByToken: postgresSelectUserByTokenQuery,
selectUserByStripeCustomerID: postgresSelectUserByStripeCustomerIDQuery,
selectUsernames: postgresSelectUsernamesQuery,
selectUserCount: postgresSelectUserCountQuery,
selectUserIDFromUsername: postgresSelectUserIDFromUsernameQuery,
insertUser: postgresInsertUserQuery,
updateUserPass: postgresUpdateUserPassQuery,
updateUserRole: postgresUpdateUserRoleQuery,
updateUserProvisioned: postgresUpdateUserProvisionedQuery,
updateUserPrefs: postgresUpdateUserPrefsQuery,
updateUserStats: postgresUpdateUserStatsQuery,
updateUserStatsResetAll: postgresUpdateUserStatsResetAllQuery,
updateUserTier: postgresUpdateUserTierQuery,
updateUserDeleted: postgresUpdateUserDeletedQuery,
deleteUser: postgresDeleteUserQuery,
deleteUserTier: postgresDeleteUserTierQuery,
deleteUsersMarked: postgresDeleteUsersMarkedQuery,
selectTopicPerms: postgresSelectTopicPermsQuery,
selectUserAllAccess: postgresSelectUserAllAccessQuery,
selectUserAccess: postgresSelectUserAccessQuery,
selectUserReservations: postgresSelectUserReservationsQuery,
selectUserReservationsCount: postgresSelectUserReservationsCountQuery,
selectUserReservationsOwner: postgresSelectUserReservationsOwnerQuery,
selectUserHasReservation: postgresSelectUserHasReservationQuery,
selectOtherAccessCount: postgresSelectOtherAccessCountQuery,
upsertUserAccess: postgresUpsertUserAccessQuery,
deleteUserAccess: postgresDeleteUserAccessQuery,
deleteUserAccessProvisioned: postgresDeleteUserAccessProvisionedQuery,
deleteTopicAccess: postgresDeleteTopicAccessQuery,
deleteAllAccess: postgresDeleteAllAccessQuery,
selectToken: postgresSelectTokenQuery,
selectTokens: postgresSelectTokensQuery,
selectTokenCount: postgresSelectTokenCountQuery,
selectAllProvisionedTokens: postgresSelectAllProvisionedTokensQuery,
upsertToken: postgresUpsertTokenQuery,
updateToken: postgresUpdateTokenQuery,
updateTokenLastAccess: postgresUpdateTokenLastAccessQuery,
deleteToken: postgresDeleteTokenQuery,
deleteProvisionedToken: postgresDeleteProvisionedTokenQuery,
deleteAllToken: postgresDeleteAllTokenQuery,
deleteExpiredTokens: postgresDeleteExpiredTokensQuery,
deleteExcessTokens: postgresDeleteExcessTokensQuery,
insertTier: postgresInsertTierQuery,
selectTiers: postgresSelectTiersQuery,
selectTierByCode: postgresSelectTierByCodeQuery,
selectTierByPriceID: postgresSelectTierByPriceIDQuery,
updateTier: postgresUpdateTierQuery,
deleteTier: postgresDeleteTierQuery,
selectPhoneNumbers: postgresSelectPhoneNumbersQuery,
insertPhoneNumber: postgresInsertPhoneNumberQuery,
deletePhoneNumber: postgresDeletePhoneNumberQuery,
updateBilling: postgresUpdateBillingQuery,
}
// NewPostgresManager creates a new Manager backed by a PostgreSQL database
func NewPostgresManager(db *sql.DB, config *Config) (*Manager, error) {
if err := setupPostgres(db); err != nil {
return nil, err
}
return newManager(db, postgresQueries, config)
}

View File

@@ -84,14 +84,14 @@ const (
// Schema table management queries for Postgres
const (
postgresCurrentSchemaVersion = 6
postgresSelectSchemaVersion = `SELECT version FROM schema_version WHERE store = 'user'`
postgresInsertSchemaVersion = `INSERT INTO schema_version (store, version) VALUES ('user', $1)`
postgresCurrentSchemaVersion = 6
postgresSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE store = 'user'`
postgresInsertSchemaVersionQuery = `INSERT INTO schema_version (store, version) VALUES ('user', $1)`
)
func setupPostgres(db *sql.DB) error {
var schemaVersion int
err := db.QueryRow(postgresSelectSchemaVersion).Scan(&schemaVersion)
err := db.QueryRow(postgresSelectSchemaVersionQuery).Scan(&schemaVersion)
if err != nil {
return setupNewPostgres(db)
}
@@ -106,7 +106,7 @@ func setupNewPostgres(db *sql.DB) error {
if _, err := db.Exec(postgresCreateTablesQueries); err != nil {
return err
}
if _, err := db.Exec(postgresInsertSchemaVersion, postgresCurrentSchemaVersion); err != nil {
if _, err := db.Exec(postgresInsertSchemaVersionQuery, postgresCurrentSchemaVersion); err != nil {
return err
}
return nil

View File

@@ -2,38 +2,42 @@ package user
import (
"database/sql"
"fmt"
"path/filepath"
_ "github.com/mattn/go-sqlite3" // SQLite driver
"heckel.io/ntfy/v2/util"
)
const (
// User queries
sqliteSelectUserByID = `
sqliteSelectUserByIDQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u
LEFT JOIN tier t on t.id = u.tier_id
WHERE u.id = ?
`
sqliteSelectUserByName = `
sqliteSelectUserByNameQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u
LEFT JOIN tier t on t.id = u.tier_id
WHERE user = ?
`
sqliteSelectUserByToken = `
sqliteSelectUserByTokenQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u
JOIN user_token tk on u.id = tk.user_id
LEFT JOIN tier t on t.id = u.tier_id
WHERE tk.token = ? AND (tk.expires = 0 OR tk.expires >= ?)
`
sqliteSelectUserByStripeID = `
sqliteSelectUserByStripeCustomerIDQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u
LEFT JOIN tier t on t.id = u.tier_id
WHERE u.stripe_customer_id = ?
`
sqliteSelectUsernames = `
sqliteSelectUsernamesQuery = `
SELECT user
FROM user
ORDER BY
@@ -43,41 +47,41 @@ const (
ELSE 2
END, user
`
sqliteSelectUserCount = `SELECT COUNT(*) FROM user`
sqliteSelectUserIDFromUsername = `SELECT id FROM user WHERE user = ?`
sqliteInsertUser = `INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created) VALUES (?, ?, ?, ?, ?, ?, ?)`
sqliteUpdateUserPass = `UPDATE user SET pass = ? WHERE user = ?`
sqliteUpdateUserRole = `UPDATE user SET role = ? WHERE user = ?`
sqliteUpdateUserProvisioned = `UPDATE user SET provisioned = ? WHERE user = ?`
sqliteUpdateUserPrefs = `UPDATE user SET prefs = ? WHERE id = ?`
sqliteUpdateUserStats = `UPDATE user SET stats_messages = ?, stats_emails = ?, stats_calls = ? WHERE id = ?`
sqliteUpdateUserStatsResetAll = `UPDATE user SET stats_messages = 0, stats_emails = 0, stats_calls = 0`
sqliteUpdateUserTier = `UPDATE user SET tier_id = (SELECT id FROM tier WHERE code = ?) WHERE user = ?`
sqliteUpdateUserDeleted = `UPDATE user SET deleted = ? WHERE id = ?`
sqliteDeleteUser = `DELETE FROM user WHERE user = ?`
sqliteDeleteUserTier = `UPDATE user SET tier_id = null WHERE user = ?`
sqliteDeleteUsersMarked = `DELETE FROM user WHERE deleted < ?`
sqliteSelectUserCountQuery = `SELECT COUNT(*) FROM user`
sqliteSelectUserIDFromUsernameQuery = `SELECT id FROM user WHERE user = ?`
sqliteInsertUserQuery = `INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created) VALUES (?, ?, ?, ?, ?, ?, ?)`
sqliteUpdateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?`
sqliteUpdateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?`
sqliteUpdateUserProvisionedQuery = `UPDATE user SET provisioned = ? WHERE user = ?`
sqliteUpdateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE id = ?`
sqliteUpdateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ?, stats_calls = ? WHERE id = ?`
sqliteUpdateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0, stats_calls = 0`
sqliteUpdateUserTierQuery = `UPDATE user SET tier_id = (SELECT id FROM tier WHERE code = ?) WHERE user = ?`
sqliteUpdateUserDeletedQuery = `UPDATE user SET deleted = ? WHERE id = ?`
sqliteDeleteUserQuery = `DELETE FROM user WHERE user = ?`
sqliteDeleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?`
sqliteDeleteUsersMarkedQuery = `DELETE FROM user WHERE deleted < ?`
// Access queries
sqliteSelectTopicPerms = `
sqliteSelectTopicPermsQuery = `
SELECT read, write
FROM user_access a
JOIN user u ON u.id = a.user_id
WHERE (u.user = ? OR u.user = ?) AND ? LIKE a.topic ESCAPE '\'
ORDER BY u.user DESC, LENGTH(a.topic) DESC, a.write DESC
`
sqliteSelectUserAllAccess = `
sqliteSelectUserAllAccessQuery = `
SELECT user_id, topic, read, write, provisioned
FROM user_access
ORDER BY LENGTH(topic) DESC, write DESC, read DESC, topic
`
sqliteSelectUserAccess = `
sqliteSelectUserAccessQuery = `
SELECT topic, read, write, provisioned
FROM user_access
WHERE user_id = (SELECT id FROM user WHERE user = ?)
ORDER BY LENGTH(topic) DESC, write DESC, read DESC, topic
`
sqliteSelectUserReservations = `
sqliteSelectUserReservationsQuery = `
SELECT a_user.topic, a_user.read, a_user.write, a_everyone.read AS everyone_read, a_everyone.write AS everyone_write
FROM user_access a_user
LEFT JOIN user_access a_everyone ON a_user.topic = a_everyone.topic AND a_everyone.user_id = (SELECT id FROM user WHERE user = ?)
@@ -85,69 +89,68 @@ const (
AND a_user.owner_user_id = (SELECT id FROM user WHERE user = ?)
ORDER BY a_user.topic
`
sqliteSelectUserReservationsCount = `
sqliteSelectUserReservationsCountQuery = `
SELECT COUNT(*)
FROM user_access
WHERE user_id = owner_user_id
AND owner_user_id = (SELECT id FROM user WHERE user = ?)
`
sqliteSelectUserReservationsOwner = `
sqliteSelectUserReservationsOwnerQuery = `
SELECT owner_user_id
FROM user_access
WHERE topic = ?
AND user_id = owner_user_id
`
sqliteSelectUserHasReservation = `
sqliteSelectUserHasReservationQuery = `
SELECT COUNT(*)
FROM user_access
WHERE user_id = owner_user_id
AND owner_user_id = (SELECT id FROM user WHERE user = ?)
AND topic = ?
`
sqliteSelectOtherAccessCount = `
sqliteSelectOtherAccessCountQuery = `
SELECT COUNT(*)
FROM user_access
WHERE (topic = ? OR ? LIKE topic ESCAPE '\')
AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM user WHERE user = ?))
`
sqliteUpsertUserAccess = `
sqliteUpsertUserAccessQuery = `
INSERT INTO user_access (user_id, topic, read, write, owner_user_id, provisioned)
VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?, (SELECT IIF(?='',NULL,(SELECT id FROM user WHERE user=?))), ?)
ON CONFLICT (user_id, topic)
DO UPDATE SET read=excluded.read, write=excluded.write, owner_user_id=excluded.owner_user_id, provisioned=excluded.provisioned
`
sqliteDeleteUserAccess = `
sqliteDeleteUserAccessQuery = `
DELETE FROM user_access
WHERE user_id = (SELECT id FROM user WHERE user = ?)
OR owner_user_id = (SELECT id FROM user WHERE user = ?)
`
sqliteDeleteUserAccessProvisioned = `DELETE FROM user_access WHERE provisioned = 1`
sqliteDeleteTopicAccess = `
sqliteDeleteUserAccessProvisionedQuery = `DELETE FROM user_access WHERE provisioned = 1`
sqliteDeleteTopicAccessQuery = `
DELETE FROM user_access
WHERE (user_id = (SELECT id FROM user WHERE user = ?) OR owner_user_id = (SELECT id FROM user WHERE user = ?))
AND topic = ?
`
sqliteDeleteAllAccess = `DELETE FROM user_access`
sqliteDeleteAllAccessQuery = `DELETE FROM user_access`
// Token queries
sqliteSelectToken = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ? AND token = ?`
sqliteSelectTokens = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ?`
sqliteSelectTokenCount = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
sqliteSelectAllProvisionedTokens = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE provisioned = 1`
sqliteUpsertToken = `
sqliteSelectTokenQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ? AND token = ?`
sqliteSelectTokensQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ?`
sqliteSelectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
sqliteSelectAllProvisionedTokensQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE provisioned = 1`
sqliteUpsertTokenQuery = `
INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires, provisioned)
VALUES (?, ?, ?, ?, ?, ?, ?)
ON CONFLICT (user_id, token)
DO UPDATE SET label = excluded.label, expires = excluded.expires, provisioned = excluded.provisioned;
DO UPDATE SET label = excluded.label, expires = excluded.expires, provisioned = excluded.provisioned
`
sqliteUpdateTokenLabel = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?`
sqliteUpdateTokenExpiry = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?`
sqliteUpdateTokenLastAccess = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?`
sqliteDeleteToken = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
sqliteDeleteProvisionedToken = `DELETE FROM user_token WHERE token = ?`
sqliteDeleteAllToken = `DELETE FROM user_token WHERE user_id = ?`
sqliteDeleteExpiredTokens = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
sqliteDeleteExcessTokens = `
sqliteUpdateTokenQuery = `UPDATE user_token SET label = ?, expires = ? WHERE user_id = ? AND token = ?`
sqliteUpdateTokenLastAccessQuery = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?`
sqliteDeleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
sqliteDeleteProvisionedTokenQuery = `DELETE FROM user_token WHERE token = ?`
sqliteDeleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?`
sqliteDeleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
sqliteDeleteExcessTokensQuery = `
DELETE FROM user_token
WHERE user_id = ?
AND (user_id, token) NOT IN (
@@ -160,46 +163,107 @@ const (
`
// Tier queries
sqliteInsertTier = `
sqliteInsertTierQuery = `
INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
sqliteUpdateTier = `
sqliteUpdateTierQuery = `
UPDATE tier
SET name = ?, messages_limit = ?, messages_expiry_duration = ?, emails_limit = ?, calls_limit = ?, reservations_limit = ?, attachment_file_size_limit = ?, attachment_total_size_limit = ?, attachment_expiry_duration = ?, attachment_bandwidth_limit = ?, stripe_monthly_price_id = ?, stripe_yearly_price_id = ?
WHERE code = ?
`
sqliteSelectTiers = `
sqliteSelectTiersQuery = `
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
FROM tier
`
sqliteSelectTierByCode = `
sqliteSelectTierByCodeQuery = `
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
FROM tier
WHERE code = ?
`
sqliteSelectTierByPriceID = `
sqliteSelectTierByPriceIDQuery = `
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
FROM tier
WHERE (stripe_monthly_price_id = ? OR stripe_yearly_price_id = ?)
`
sqliteDeleteTier = `DELETE FROM tier WHERE code = ?`
sqliteDeleteTierQuery = `DELETE FROM tier WHERE code = ?`
// Phone queries
sqliteSelectPhoneNumbers = `SELECT phone_number FROM user_phone WHERE user_id = ?`
sqliteInsertPhoneNumber = `INSERT INTO user_phone (user_id, phone_number) VALUES (?, ?)`
sqliteDeletePhoneNumber = `DELETE FROM user_phone WHERE user_id = ? AND phone_number = ?`
sqliteSelectPhoneNumbersQuery = `SELECT phone_number FROM user_phone WHERE user_id = ?`
sqliteInsertPhoneNumberQuery = `INSERT INTO user_phone (user_id, phone_number) VALUES (?, ?)`
sqliteDeletePhoneNumberQuery = `DELETE FROM user_phone WHERE user_id = ? AND phone_number = ?`
// Billing queries
sqliteUpdateBilling = `
sqliteUpdateBillingQuery = `
UPDATE user
SET stripe_customer_id = ?, stripe_subscription_id = ?, stripe_subscription_status = ?, stripe_subscription_interval = ?, stripe_subscription_paid_until = ?, stripe_subscription_cancel_at = ?
WHERE user = ?
`
)
// NewSQLiteStore creates a new SQLite-backed user store
func NewSQLiteStore(filename, startupQueries string) (Store, error) {
var sqliteQueries = queries{
selectUserByID: sqliteSelectUserByIDQuery,
selectUserByName: sqliteSelectUserByNameQuery,
selectUserByToken: sqliteSelectUserByTokenQuery,
selectUserByStripeCustomerID: sqliteSelectUserByStripeCustomerIDQuery,
selectUsernames: sqliteSelectUsernamesQuery,
selectUserCount: sqliteSelectUserCountQuery,
selectUserIDFromUsername: sqliteSelectUserIDFromUsernameQuery,
insertUser: sqliteInsertUserQuery,
updateUserPass: sqliteUpdateUserPassQuery,
updateUserRole: sqliteUpdateUserRoleQuery,
updateUserProvisioned: sqliteUpdateUserProvisionedQuery,
updateUserPrefs: sqliteUpdateUserPrefsQuery,
updateUserStats: sqliteUpdateUserStatsQuery,
updateUserStatsResetAll: sqliteUpdateUserStatsResetAllQuery,
updateUserTier: sqliteUpdateUserTierQuery,
updateUserDeleted: sqliteUpdateUserDeletedQuery,
deleteUser: sqliteDeleteUserQuery,
deleteUserTier: sqliteDeleteUserTierQuery,
deleteUsersMarked: sqliteDeleteUsersMarkedQuery,
selectTopicPerms: sqliteSelectTopicPermsQuery,
selectUserAllAccess: sqliteSelectUserAllAccessQuery,
selectUserAccess: sqliteSelectUserAccessQuery,
selectUserReservations: sqliteSelectUserReservationsQuery,
selectUserReservationsCount: sqliteSelectUserReservationsCountQuery,
selectUserReservationsOwner: sqliteSelectUserReservationsOwnerQuery,
selectUserHasReservation: sqliteSelectUserHasReservationQuery,
selectOtherAccessCount: sqliteSelectOtherAccessCountQuery,
upsertUserAccess: sqliteUpsertUserAccessQuery,
deleteUserAccess: sqliteDeleteUserAccessQuery,
deleteUserAccessProvisioned: sqliteDeleteUserAccessProvisionedQuery,
deleteTopicAccess: sqliteDeleteTopicAccessQuery,
deleteAllAccess: sqliteDeleteAllAccessQuery,
selectToken: sqliteSelectTokenQuery,
selectTokens: sqliteSelectTokensQuery,
selectTokenCount: sqliteSelectTokenCountQuery,
selectAllProvisionedTokens: sqliteSelectAllProvisionedTokensQuery,
upsertToken: sqliteUpsertTokenQuery,
updateToken: sqliteUpdateTokenQuery,
updateTokenLastAccess: sqliteUpdateTokenLastAccessQuery,
deleteToken: sqliteDeleteTokenQuery,
deleteProvisionedToken: sqliteDeleteProvisionedTokenQuery,
deleteAllToken: sqliteDeleteAllTokenQuery,
deleteExpiredTokens: sqliteDeleteExpiredTokensQuery,
deleteExcessTokens: sqliteDeleteExcessTokensQuery,
insertTier: sqliteInsertTierQuery,
selectTiers: sqliteSelectTiersQuery,
selectTierByCode: sqliteSelectTierByCodeQuery,
selectTierByPriceID: sqliteSelectTierByPriceIDQuery,
updateTier: sqliteUpdateTierQuery,
deleteTier: sqliteDeleteTierQuery,
selectPhoneNumbers: sqliteSelectPhoneNumbersQuery,
insertPhoneNumber: sqliteInsertPhoneNumberQuery,
deletePhoneNumber: sqliteDeletePhoneNumberQuery,
updateBilling: sqliteUpdateBillingQuery,
}
// NewSQLiteManager creates a new Manager backed by a SQLite database
func NewSQLiteManager(filename, startupQueries string, config *Config) (*Manager, error) {
parentDir := filepath.Dir(filename)
if !util.FileExists(parentDir) {
return nil, fmt.Errorf("user database directory %s does not exist or is not accessible", parentDir)
}
db, err := sql.Open("sqlite3", filename)
if err != nil {
return nil, err
@@ -210,64 +274,5 @@ func NewSQLiteStore(filename, startupQueries string) (Store, error) {
if err := runSQLiteStartupQueries(db, startupQueries); err != nil {
return nil, err
}
return &commonStore{
db: db,
queries: storeQueries{
selectUserByID: sqliteSelectUserByID,
selectUserByName: sqliteSelectUserByName,
selectUserByToken: sqliteSelectUserByToken,
selectUserByStripeID: sqliteSelectUserByStripeID,
selectUsernames: sqliteSelectUsernames,
selectUserCount: sqliteSelectUserCount,
selectUserIDFromUsername: sqliteSelectUserIDFromUsername,
insertUser: sqliteInsertUser,
updateUserPass: sqliteUpdateUserPass,
updateUserRole: sqliteUpdateUserRole,
updateUserProvisioned: sqliteUpdateUserProvisioned,
updateUserPrefs: sqliteUpdateUserPrefs,
updateUserStats: sqliteUpdateUserStats,
updateUserStatsResetAll: sqliteUpdateUserStatsResetAll,
updateUserTier: sqliteUpdateUserTier,
updateUserDeleted: sqliteUpdateUserDeleted,
deleteUser: sqliteDeleteUser,
deleteUserTier: sqliteDeleteUserTier,
deleteUsersMarked: sqliteDeleteUsersMarked,
selectTopicPerms: sqliteSelectTopicPerms,
selectUserAllAccess: sqliteSelectUserAllAccess,
selectUserAccess: sqliteSelectUserAccess,
selectUserReservations: sqliteSelectUserReservations,
selectUserReservationsCount: sqliteSelectUserReservationsCount,
selectUserReservationsOwner: sqliteSelectUserReservationsOwner,
selectUserHasReservation: sqliteSelectUserHasReservation,
selectOtherAccessCount: sqliteSelectOtherAccessCount,
upsertUserAccess: sqliteUpsertUserAccess,
deleteUserAccess: sqliteDeleteUserAccess,
deleteUserAccessProvisioned: sqliteDeleteUserAccessProvisioned,
deleteTopicAccess: sqliteDeleteTopicAccess,
deleteAllAccess: sqliteDeleteAllAccess,
selectToken: sqliteSelectToken,
selectTokens: sqliteSelectTokens,
selectTokenCount: sqliteSelectTokenCount,
selectAllProvisionedTokens: sqliteSelectAllProvisionedTokens,
upsertToken: sqliteUpsertToken,
updateTokenLabel: sqliteUpdateTokenLabel,
updateTokenExpiry: sqliteUpdateTokenExpiry,
updateTokenLastAccess: sqliteUpdateTokenLastAccess,
deleteToken: sqliteDeleteToken,
deleteProvisionedToken: sqliteDeleteProvisionedToken,
deleteAllToken: sqliteDeleteAllToken,
deleteExpiredTokens: sqliteDeleteExpiredTokens,
deleteExcessTokens: sqliteDeleteExcessTokens,
insertTier: sqliteInsertTier,
selectTiers: sqliteSelectTiers,
selectTierByCode: sqliteSelectTierByCode,
selectTierByPriceID: sqliteSelectTierByPriceID,
updateTier: sqliteUpdateTier,
deleteTier: sqliteDeleteTier,
selectPhoneNumbers: sqliteSelectPhoneNumbers,
insertPhoneNumber: sqliteInsertPhoneNumber,
deletePhoneNumber: sqliteDeletePhoneNumber,
updateBilling: sqliteUpdateBilling,
},
}, nil
return newManager(db, sqliteQueries, config)
}

View File

@@ -103,8 +103,8 @@ const (
// Schema version table management for SQLite
const (
sqliteCurrentSchemaVersion = 6
sqliteInsertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
sqliteUpdateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
sqliteInsertSchemaVersionQuery = `INSERT INTO schemaVersion VALUES (1, ?)`
sqliteUpdateSchemaVersionQuery = `UPDATE schemaVersion SET version = ? WHERE id = 1`
sqliteSelectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
)
@@ -179,12 +179,12 @@ const (
VALUES ('u_everyone', '*', '', 'anonymous', '', UNIXEPOCH())
ON CONFLICT (id) DO NOTHING;
`
sqliteMigrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old`
sqliteMigrate1To2InsertUserNoTx = `
sqliteMigrate1To2SelectAllOldUsernamesNoTxQuery = `SELECT user FROM user_old`
sqliteMigrate1To2InsertUserNoTxQuery = `
INSERT INTO user (id, user, pass, role, sync_topic, created)
SELECT ?, user, pass, role, ?, UNIXEPOCH() FROM user_old WHERE user = ?
`
sqliteMigrate1To2InsertFromOldTablesAndDropNoTx = `
sqliteMigrate1To2InsertFromOldTablesAndDropNoTxQuery = `
INSERT INTO user_access (user_id, topic, read, write)
SELECT u.id, a.topic, a.read, a.write
FROM user u
@@ -352,7 +352,7 @@ func setupNewSQLite(db *sql.DB) error {
if _, err := db.Exec(sqliteCreateTablesQueries); err != nil {
return err
}
if _, err := db.Exec(sqliteInsertSchemaVersion, sqliteCurrentSchemaVersion); err != nil {
if _, err := db.Exec(sqliteInsertSchemaVersionQuery, sqliteCurrentSchemaVersion); err != nil {
return err
}
return nil
@@ -382,7 +382,7 @@ func sqliteMigrateFrom1(db *sql.DB) error {
return err
}
// Insert users from user_old into new user table, with ID and sync_topic
rows, err := tx.Query(sqliteMigrate1To2SelectAllOldUsernamesNoTx)
rows, err := tx.Query(sqliteMigrate1To2SelectAllOldUsernamesNoTxQuery)
if err != nil {
return err
}
@@ -401,15 +401,15 @@ func sqliteMigrateFrom1(db *sql.DB) error {
for _, username := range usernames {
userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
if _, err := tx.Exec(sqliteMigrate1To2InsertUserNoTx, userID, syncTopic, username); err != nil {
if _, err := tx.Exec(sqliteMigrate1To2InsertUserNoTxQuery, userID, syncTopic, username); err != nil {
return err
}
}
// Migrate old "access" table to "user_access" and drop "access" and "user_old"
if _, err := tx.Exec(sqliteMigrate1To2InsertFromOldTablesAndDropNoTx); err != nil {
if _, err := tx.Exec(sqliteMigrate1To2InsertFromOldTablesAndDropNoTxQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 2); err != nil {
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 2); err != nil {
return err
}
if err := tx.Commit(); err != nil {
@@ -428,7 +428,7 @@ func sqliteMigrateFrom2(db *sql.DB) error {
if _, err := tx.Exec(sqliteMigrate2To3UpdateQueries); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 3); err != nil {
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 3); err != nil {
return err
}
return tx.Commit()
@@ -444,7 +444,7 @@ func sqliteMigrateFrom3(db *sql.DB) error {
if _, err := tx.Exec(sqliteMigrate3To4UpdateQueries); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 4); err != nil {
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 4); err != nil {
return err
}
return tx.Commit()
@@ -460,7 +460,7 @@ func sqliteMigrateFrom4(db *sql.DB) error {
if _, err := tx.Exec(sqliteMigrate4To5UpdateQueries); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 5); err != nil {
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 5); err != nil {
return err
}
return tx.Commit()
@@ -476,7 +476,7 @@ func sqliteMigrateFrom5(db *sql.DB) error {
if _, err := tx.Exec(sqliteMigrate5To6UpdateQueries); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 6); err != nil {
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 6); err != nil {
return err
}
return tx.Commit()

File diff suppressed because it is too large Load Diff

View File

@@ -1,986 +0,0 @@
package user
import (
"database/sql"
"encoding/json"
"errors"
"net/netip"
"strings"
"time"
"heckel.io/ntfy/v2/payments"
"heckel.io/ntfy/v2/util"
)
// Store is the interface for a user database store
type Store interface {
// User operations
UserByID(id string) (*User, error)
User(username string) (*User, error)
UserByToken(token string) (*User, error)
UserByStripeCustomer(customerID string) (*User, error)
UserIDByUsername(username string) (string, error)
Users() ([]*User, error)
UsersCount() (int64, error)
AddUser(username, hash string, role Role, provisioned bool) error
RemoveUser(username string) error
MarkUserRemoved(userID string) error
RemoveDeletedUsers() error
ChangePassword(username, hash string) error
ChangeRole(username string, role Role) error
ChangeProvisioned(username string, provisioned bool) error
ChangeSettings(userID string, prefs *Prefs) error
ChangeTier(username, tierCode string) error
ResetTier(username string) error
UpdateStats(userID string, stats *Stats) error
ResetStats() error
// Token operations
CreateToken(userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, provisioned bool) (*Token, error)
Token(userID, token string) (*Token, error)
Tokens(userID string) ([]*Token, error)
AllProvisionedTokens() ([]*Token, error)
ChangeTokenLabel(userID, token, label string) error
ChangeTokenExpiry(userID, token string, expires time.Time) error
UpdateTokenLastAccess(token string, lastAccess time.Time, lastOrigin netip.Addr) error
RemoveToken(userID, token string) error
RemoveExpiredTokens() error
TokenCount(userID string) (int, error)
RemoveExcessTokens(userID string, maxCount int) error
// Access operations
AuthorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error)
AllGrants() (map[string][]Grant, error)
Grants(username string) ([]Grant, error)
AllowAccess(username, topicPattern string, read, write bool, ownerUsername string, provisioned bool) error
ResetAccess(username, topicPattern string) error
ResetAllProvisionedAccess() error
Reservations(username string) ([]Reservation, error)
HasReservation(username, topic string) (bool, error)
ReservationsCount(username string) (int64, error)
ReservationOwner(topic string) (string, error)
OtherAccessCount(username, topic string) (int, error)
// Tier operations
AddTier(tier *Tier) error
UpdateTier(tier *Tier) error
RemoveTier(code string) error
Tiers() ([]*Tier, error)
Tier(code string) (*Tier, error)
TierByStripePrice(priceID string) (*Tier, error)
// Phone operations
PhoneNumbers(userID string) ([]string, error)
AddPhoneNumber(userID, phoneNumber string) error
RemovePhoneNumber(userID, phoneNumber string) error
// Other stuff
ChangeBilling(username string, billing *Billing) error
Close() error
}
// storeQueries holds the database-specific SQL queries
type storeQueries struct {
// User queries
selectUserByID string
selectUserByName string
selectUserByToken string
selectUserByStripeID string
selectUsernames string
selectUserCount string
selectUserIDFromUsername string
insertUser string
updateUserPass string
updateUserRole string
updateUserProvisioned string
updateUserPrefs string
updateUserStats string
updateUserStatsResetAll string
updateUserTier string
updateUserDeleted string
deleteUser string
deleteUserTier string
deleteUsersMarked string
// Access queries
selectTopicPerms string
selectUserAllAccess string
selectUserAccess string
selectUserReservations string
selectUserReservationsCount string
selectUserReservationsOwner string
selectUserHasReservation string
selectOtherAccessCount string
upsertUserAccess string
deleteUserAccess string
deleteUserAccessProvisioned string
deleteTopicAccess string
deleteAllAccess string
// Token queries
selectToken string
selectTokens string
selectTokenCount string
selectAllProvisionedTokens string
upsertToken string
updateTokenLabel string
updateTokenExpiry string
updateTokenLastAccess string
deleteToken string
deleteProvisionedToken string
deleteAllToken string
deleteExpiredTokens string
deleteExcessTokens string
// Tier queries
insertTier string
selectTiers string
selectTierByCode string
selectTierByPriceID string
updateTier string
deleteTier string
// Phone queries
selectPhoneNumbers string
insertPhoneNumber string
deletePhoneNumber string
// Billing queries
updateBilling string
}
// commonStore implements store operations that work across database backends
type commonStore struct {
db *sql.DB
queries storeQueries
}
// UserByID returns the user with the given ID if it exists, or ErrUserNotFound otherwise
func (s *commonStore) UserByID(id string) (*User, error) {
rows, err := s.db.Query(s.queries.selectUserByID, id)
if err != nil {
return nil, err
}
return s.readUser(rows)
}
// User returns the user with the given username if it exists, or ErrUserNotFound otherwise
func (s *commonStore) User(username string) (*User, error) {
rows, err := s.db.Query(s.queries.selectUserByName, username)
if err != nil {
return nil, err
}
return s.readUser(rows)
}
// UserByToken returns the user with the given token if it exists and is not expired, or ErrUserNotFound otherwise
func (s *commonStore) UserByToken(token string) (*User, error) {
rows, err := s.db.Query(s.queries.selectUserByToken, token, time.Now().Unix())
if err != nil {
return nil, err
}
return s.readUser(rows)
}
// UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise
func (s *commonStore) UserByStripeCustomer(customerID string) (*User, error) {
rows, err := s.db.Query(s.queries.selectUserByStripeID, customerID)
if err != nil {
return nil, err
}
return s.readUser(rows)
}
// Users returns a list of users
func (s *commonStore) Users() ([]*User, error) {
rows, err := s.db.Query(s.queries.selectUsernames)
if err != nil {
return nil, err
}
defer rows.Close()
usernames := make([]string, 0)
for rows.Next() {
var username string
if err := rows.Scan(&username); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
}
usernames = append(usernames, username)
}
rows.Close()
users := make([]*User, 0)
for _, username := range usernames {
user, err := s.User(username)
if err != nil {
return nil, err
}
users = append(users, user)
}
return users, nil
}
// UsersCount returns the number of users in the database
func (s *commonStore) UsersCount() (int64, error) {
rows, err := s.db.Query(s.queries.selectUserCount)
if err != nil {
return 0, err
}
defer rows.Close()
if !rows.Next() {
return 0, errNoRows
}
var count int64
if err := rows.Scan(&count); err != nil {
return 0, err
}
return count, nil
}
// AddUser adds a user with the given username, password hash and role
func (s *commonStore) AddUser(username, hash string, role Role, provisioned bool) error {
if !AllowedUsername(username) || !AllowedRole(role) {
return ErrInvalidArgument
}
userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
now := time.Now().Unix()
if _, err := s.db.Exec(s.queries.insertUser, userID, username, hash, string(role), syncTopic, provisioned, now); err != nil {
if isUniqueConstraintError(err) {
return ErrUserExists
}
return err
}
return nil
}
// RemoveUser deletes the user with the given username
func (s *commonStore) RemoveUser(username string) error {
if !AllowedUsername(username) {
return ErrInvalidArgument
}
// Rows in user_access, user_token, etc. are deleted via foreign keys
if _, err := s.db.Exec(s.queries.deleteUser, username); err != nil {
return err
}
return nil
}
// MarkUserRemoved sets the deleted flag on the user, and deletes all access tokens
func (s *commonStore) MarkUserRemoved(userID string) error {
tx, err := s.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
// Get username for deleteUserAccess query
user, err := s.UserByID(userID)
if err != nil {
return err
}
if _, err := tx.Exec(s.queries.deleteUserAccess, user.Name, user.Name); err != nil {
return err
}
if _, err := tx.Exec(s.queries.deleteAllToken, userID); err != nil {
return err
}
deletedTime := time.Now().Add(userHardDeleteAfterDuration).Unix()
if _, err := tx.Exec(s.queries.updateUserDeleted, deletedTime, userID); err != nil {
return err
}
return tx.Commit()
}
// RemoveDeletedUsers deletes all users that have been marked deleted
func (s *commonStore) RemoveDeletedUsers() error {
if _, err := s.db.Exec(s.queries.deleteUsersMarked, time.Now().Unix()); err != nil {
return err
}
return nil
}
// ChangePassword changes a user's password
func (s *commonStore) ChangePassword(username, hash string) error {
if _, err := s.db.Exec(s.queries.updateUserPass, hash, username); err != nil {
return err
}
return nil
}
// ChangeRole changes a user's role
func (s *commonStore) ChangeRole(username string, role Role) error {
if !AllowedUsername(username) || !AllowedRole(role) {
return ErrInvalidArgument
}
tx, err := s.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(s.queries.updateUserRole, string(role), username); err != nil {
return err
}
// If changing to admin, remove all access entries
if role == RoleAdmin {
if _, err := tx.Exec(s.queries.deleteUserAccess, username, username); err != nil {
return err
}
}
return tx.Commit()
}
// ChangeProvisioned changes the provisioned status of a user
func (s *commonStore) ChangeProvisioned(username string, provisioned bool) error {
if _, err := s.db.Exec(s.queries.updateUserProvisioned, provisioned, username); err != nil {
return err
}
return nil
}
// ChangeSettings persists the user settings
func (s *commonStore) ChangeSettings(userID string, prefs *Prefs) error {
b, err := json.Marshal(prefs)
if err != nil {
return err
}
if _, err := s.db.Exec(s.queries.updateUserPrefs, string(b), userID); err != nil {
return err
}
return nil
}
// ChangeTier changes a user's tier using the tier code
func (s *commonStore) ChangeTier(username, tierCode string) error {
if _, err := s.db.Exec(s.queries.updateUserTier, tierCode, username); err != nil {
return err
}
return nil
}
// ResetTier removes the tier from the given user
func (s *commonStore) ResetTier(username string) error {
if !AllowedUsername(username) && username != Everyone && username != "" {
return ErrInvalidArgument
}
_, err := s.db.Exec(s.queries.deleteUserTier, username)
return err
}
// UpdateStats updates the user statistics
func (s *commonStore) UpdateStats(userID string, stats *Stats) error {
if _, err := s.db.Exec(s.queries.updateUserStats, stats.Messages, stats.Emails, stats.Calls, userID); err != nil {
return err
}
return nil
}
// ResetStats resets all user stats in the user database
func (s *commonStore) ResetStats() error {
if _, err := s.db.Exec(s.queries.updateUserStatsResetAll); err != nil {
return err
}
return nil
}
func (s *commonStore) readUser(rows *sql.Rows) (*User, error) {
defer rows.Close()
var id, username, hash, role, prefs, syncTopic string
var provisioned bool
var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripeSubscriptionInterval, stripeMonthlyPriceID, stripeYearlyPriceID, tierID, tierCode, tierName sql.NullString
var messages, emails, calls int64
var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
if !rows.Next() {
return nil, ErrUserNotFound
}
if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &provisioned, &messages, &emails, &calls, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionInterval, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
}
user := &User{
ID: id,
Name: username,
Hash: hash,
Role: Role(role),
Prefs: &Prefs{},
SyncTopic: syncTopic,
Provisioned: provisioned,
Stats: &Stats{
Messages: messages,
Emails: emails,
Calls: calls,
},
Billing: &Billing{
StripeCustomerID: stripeCustomerID.String,
StripeSubscriptionID: stripeSubscriptionID.String,
StripeSubscriptionStatus: payments.SubscriptionStatus(stripeSubscriptionStatus.String),
StripeSubscriptionInterval: payments.PriceRecurringInterval(stripeSubscriptionInterval.String),
StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0),
StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0),
},
Deleted: deleted.Valid,
}
if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil {
return nil, err
}
if tierCode.Valid {
user.Tier = &Tier{
ID: tierID.String,
Code: tierCode.String,
Name: tierName.String,
MessageLimit: messagesLimit.Int64,
MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
EmailLimit: emailsLimit.Int64,
CallLimit: callsLimit.Int64,
ReservationLimit: reservationsLimit.Int64,
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
StripeMonthlyPriceID: stripeMonthlyPriceID.String,
StripeYearlyPriceID: stripeYearlyPriceID.String,
}
}
return user, nil
}
// CreateToken creates a new token
func (s *commonStore) CreateToken(userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, provisioned bool) (*Token, error) {
if _, err := s.db.Exec(s.queries.upsertToken, userID, token, label, lastAccess.Unix(), lastOrigin.String(), expires.Unix(), provisioned); err != nil {
return nil, err
}
return &Token{
Value: token,
Label: label,
LastAccess: lastAccess,
LastOrigin: lastOrigin,
Expires: expires,
Provisioned: provisioned,
}, nil
}
// Token returns a specific token for a user
func (s *commonStore) Token(userID, token string) (*Token, error) {
rows, err := s.db.Query(s.queries.selectToken, userID, token)
if err != nil {
return nil, err
}
defer rows.Close()
return s.readToken(rows)
}
// Tokens returns all existing tokens for the user with the given user ID
func (s *commonStore) Tokens(userID string) ([]*Token, error) {
rows, err := s.db.Query(s.queries.selectTokens, userID)
if err != nil {
return nil, err
}
defer rows.Close()
tokens := make([]*Token, 0)
for {
token, err := s.readToken(rows)
if errors.Is(err, ErrTokenNotFound) {
break
} else if err != nil {
return nil, err
}
tokens = append(tokens, token)
}
return tokens, nil
}
// AllProvisionedTokens returns all provisioned tokens
func (s *commonStore) AllProvisionedTokens() ([]*Token, error) {
rows, err := s.db.Query(s.queries.selectAllProvisionedTokens)
if err != nil {
return nil, err
}
defer rows.Close()
tokens := make([]*Token, 0)
for {
token, err := s.readToken(rows)
if errors.Is(err, ErrTokenNotFound) {
break
} else if err != nil {
return nil, err
}
tokens = append(tokens, token)
}
return tokens, nil
}
// ChangeTokenLabel updates a token's label
func (s *commonStore) ChangeTokenLabel(userID, token, label string) error {
if _, err := s.db.Exec(s.queries.updateTokenLabel, label, userID, token); err != nil {
return err
}
return nil
}
// ChangeTokenExpiry updates a token's expiry time
func (s *commonStore) ChangeTokenExpiry(userID, token string, expires time.Time) error {
if _, err := s.db.Exec(s.queries.updateTokenExpiry, expires.Unix(), userID, token); err != nil {
return err
}
return nil
}
// UpdateTokenLastAccess updates a token's last access time and origin
func (s *commonStore) UpdateTokenLastAccess(token string, lastAccess time.Time, lastOrigin netip.Addr) error {
if _, err := s.db.Exec(s.queries.updateTokenLastAccess, lastAccess.Unix(), lastOrigin.String(), token); err != nil {
return err
}
return nil
}
// RemoveToken deletes the token
func (s *commonStore) RemoveToken(userID, token string) error {
if token == "" {
return errNoTokenProvided
}
if _, err := s.db.Exec(s.queries.deleteToken, userID, token); err != nil {
return err
}
return nil
}
// RemoveExpiredTokens deletes all expired tokens from the database
func (s *commonStore) RemoveExpiredTokens() error {
if _, err := s.db.Exec(s.queries.deleteExpiredTokens, time.Now().Unix()); err != nil {
return err
}
return nil
}
// TokenCount returns the number of tokens for a user
func (s *commonStore) TokenCount(userID string) (int, error) {
rows, err := s.db.Query(s.queries.selectTokenCount, userID)
if err != nil {
return 0, err
}
defer rows.Close()
if !rows.Next() {
return 0, errNoRows
}
var count int
if err := rows.Scan(&count); err != nil {
return 0, err
}
return count, nil
}
// RemoveExcessTokens deletes excess tokens beyond the specified maximum
func (s *commonStore) RemoveExcessTokens(userID string, maxCount int) error {
if _, err := s.db.Exec(s.queries.deleteExcessTokens, userID, userID, maxCount); err != nil {
return err
}
return nil
}
func (s *commonStore) readToken(rows *sql.Rows) (*Token, error) {
var token, label, lastOrigin string
var lastAccess, expires int64
var provisioned bool
if !rows.Next() {
return nil, ErrTokenNotFound
}
if err := rows.Scan(&token, &label, &lastAccess, &lastOrigin, &expires, &provisioned); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
}
lastOriginIP, err := netip.ParseAddr(lastOrigin)
if err != nil {
lastOriginIP = netip.IPv4Unspecified()
}
return &Token{
Value: token,
Label: label,
LastAccess: time.Unix(lastAccess, 0),
LastOrigin: lastOriginIP,
Expires: time.Unix(expires, 0),
Provisioned: provisioned,
}, nil
}
// AuthorizeTopicAccess returns the read/write permissions for the given username and topic.
// The found return value indicates whether an ACL entry was found at all.
func (s *commonStore) AuthorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error) {
rows, err := s.db.Query(s.queries.selectTopicPerms, Everyone, usernameOrEveryone, topic)
if err != nil {
return false, false, false, err
}
defer rows.Close()
if !rows.Next() {
return false, false, false, nil
}
if err := rows.Scan(&read, &write); err != nil {
return false, false, false, err
} else if err := rows.Err(); err != nil {
return false, false, false, err
}
return read, write, true, nil
}
// AllGrants returns all user-specific access control entries, mapped to their respective user IDs
func (s *commonStore) AllGrants() (map[string][]Grant, error) {
rows, err := s.db.Query(s.queries.selectUserAllAccess)
if err != nil {
return nil, err
}
defer rows.Close()
grants := make(map[string][]Grant, 0)
for rows.Next() {
var userID, topic string
var read, write, provisioned bool
if err := rows.Scan(&userID, &topic, &read, &write, &provisioned); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
}
if _, ok := grants[userID]; !ok {
grants[userID] = make([]Grant, 0)
}
grants[userID] = append(grants[userID], Grant{
TopicPattern: fromSQLWildcard(topic),
Permission: NewPermission(read, write),
Provisioned: provisioned,
})
}
return grants, nil
}
// Grants returns all user-specific access control entries
func (s *commonStore) Grants(username string) ([]Grant, error) {
rows, err := s.db.Query(s.queries.selectUserAccess, username)
if err != nil {
return nil, err
}
defer rows.Close()
grants := make([]Grant, 0)
for rows.Next() {
var topic string
var read, write, provisioned bool
if err := rows.Scan(&topic, &read, &write, &provisioned); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
}
grants = append(grants, Grant{
TopicPattern: fromSQLWildcard(topic),
Permission: NewPermission(read, write),
Provisioned: provisioned,
})
}
return grants, nil
}
// AllowAccess adds or updates an entry in the access control list
func (s *commonStore) AllowAccess(username, topicPattern string, read, write bool, ownerUsername string, provisioned bool) error {
if !AllowedUsername(username) && username != Everyone {
return ErrInvalidArgument
} else if !AllowedTopicPattern(topicPattern) {
return ErrInvalidArgument
}
if _, err := s.db.Exec(s.queries.upsertUserAccess, username, toSQLWildcard(topicPattern), read, write, ownerUsername, ownerUsername, provisioned); err != nil {
return err
}
return nil
}
// ResetAccess removes an access control list entry
func (s *commonStore) ResetAccess(username, topicPattern string) error {
if !AllowedUsername(username) && username != Everyone && username != "" {
return ErrInvalidArgument
} else if !AllowedTopicPattern(topicPattern) && topicPattern != "" {
return ErrInvalidArgument
}
if username == "" && topicPattern == "" {
_, err := s.db.Exec(s.queries.deleteAllAccess)
return err
} else if topicPattern == "" {
_, err := s.db.Exec(s.queries.deleteUserAccess, username, username)
return err
}
_, err := s.db.Exec(s.queries.deleteTopicAccess, username, username, toSQLWildcard(topicPattern))
return err
}
// ResetAllProvisionedAccess removes all provisioned access control entries
func (s *commonStore) ResetAllProvisionedAccess() error {
if _, err := s.db.Exec(s.queries.deleteUserAccessProvisioned); err != nil {
return err
}
return nil
}
// Reservations returns all user-owned topics, and the associated everyone-access
func (s *commonStore) Reservations(username string) ([]Reservation, error) {
rows, err := s.db.Query(s.queries.selectUserReservations, Everyone, username)
if err != nil {
return nil, err
}
defer rows.Close()
reservations := make([]Reservation, 0)
for rows.Next() {
var topic string
var ownerRead, ownerWrite bool
var everyoneRead, everyoneWrite sql.NullBool
if err := rows.Scan(&topic, &ownerRead, &ownerWrite, &everyoneRead, &everyoneWrite); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
}
reservations = append(reservations, Reservation{
Topic: unescapeUnderscore(topic),
Owner: NewPermission(ownerRead, ownerWrite),
Everyone: NewPermission(everyoneRead.Bool, everyoneWrite.Bool),
})
}
return reservations, nil
}
// HasReservation returns true if the given topic access is owned by the user
func (s *commonStore) HasReservation(username, topic string) (bool, error) {
rows, err := s.db.Query(s.queries.selectUserHasReservation, username, escapeUnderscore(topic))
if err != nil {
return false, err
}
defer rows.Close()
if !rows.Next() {
return false, errNoRows
}
var count int64
if err := rows.Scan(&count); err != nil {
return false, err
}
return count > 0, nil
}
// ReservationsCount returns the number of reservations owned by this user
func (s *commonStore) ReservationsCount(username string) (int64, error) {
rows, err := s.db.Query(s.queries.selectUserReservationsCount, username)
if err != nil {
return 0, err
}
defer rows.Close()
if !rows.Next() {
return 0, errNoRows
}
var count int64
if err := rows.Scan(&count); err != nil {
return 0, err
}
return count, nil
}
// ReservationOwner returns user ID of the user that owns this topic, or an empty string if it's not owned by anyone
func (s *commonStore) ReservationOwner(topic string) (string, error) {
rows, err := s.db.Query(s.queries.selectUserReservationsOwner, escapeUnderscore(topic))
if err != nil {
return "", err
}
defer rows.Close()
if !rows.Next() {
return "", nil
}
var ownerUserID string
if err := rows.Scan(&ownerUserID); err != nil {
return "", err
}
return ownerUserID, nil
}
// OtherAccessCount returns the number of access entries for the given topic that are not owned by the user
func (s *commonStore) OtherAccessCount(username, topic string) (int, error) {
rows, err := s.db.Query(s.queries.selectOtherAccessCount, escapeUnderscore(topic), escapeUnderscore(topic), username)
if err != nil {
return 0, err
}
defer rows.Close()
if !rows.Next() {
return 0, errNoRows
}
var count int
if err := rows.Scan(&count); err != nil {
return 0, err
}
return count, nil
}
// AddTier creates a new tier in the database
func (s *commonStore) AddTier(tier *Tier) error {
if tier.ID == "" {
tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength)
}
if _, err := s.db.Exec(s.queries.insertTier, tier.ID, tier.Code, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.CallLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID)); err != nil {
return err
}
return nil
}
// UpdateTier updates a tier's properties in the database
func (s *commonStore) UpdateTier(tier *Tier) error {
if _, err := s.db.Exec(s.queries.updateTier, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.CallLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID), tier.Code); err != nil {
return err
}
return nil
}
// RemoveTier deletes the tier with the given code
func (s *commonStore) RemoveTier(code string) error {
if !AllowedTier(code) {
return ErrInvalidArgument
}
if _, err := s.db.Exec(s.queries.deleteTier, code); err != nil {
return err
}
return nil
}
// Tiers returns a list of all Tier structs
func (s *commonStore) Tiers() ([]*Tier, error) {
rows, err := s.db.Query(s.queries.selectTiers)
if err != nil {
return nil, err
}
defer rows.Close()
tiers := make([]*Tier, 0)
for {
tier, err := s.readTier(rows)
if errors.Is(err, ErrTierNotFound) {
break
} else if err != nil {
return nil, err
}
tiers = append(tiers, tier)
}
return tiers, nil
}
// Tier returns a Tier based on the code, or ErrTierNotFound if it does not exist
func (s *commonStore) Tier(code string) (*Tier, error) {
rows, err := s.db.Query(s.queries.selectTierByCode, code)
if err != nil {
return nil, err
}
defer rows.Close()
return s.readTier(rows)
}
// TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist
func (s *commonStore) TierByStripePrice(priceID string) (*Tier, error) {
rows, err := s.db.Query(s.queries.selectTierByPriceID, priceID, priceID)
if err != nil {
return nil, err
}
defer rows.Close()
return s.readTier(rows)
}
func (s *commonStore) readTier(rows *sql.Rows) (*Tier, error) {
var id, code, name string
var stripeMonthlyPriceID, stripeYearlyPriceID sql.NullString
var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit sql.NullInt64
if !rows.Next() {
return nil, ErrTierNotFound
}
if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
}
return &Tier{
ID: id,
Code: code,
Name: name,
MessageLimit: messagesLimit.Int64,
MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
EmailLimit: emailsLimit.Int64,
CallLimit: callsLimit.Int64,
ReservationLimit: reservationsLimit.Int64,
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
StripeMonthlyPriceID: stripeMonthlyPriceID.String,
StripeYearlyPriceID: stripeYearlyPriceID.String,
}, nil
}
// PhoneNumbers returns all phone numbers for the user with the given user ID
func (s *commonStore) PhoneNumbers(userID string) ([]string, error) {
rows, err := s.db.Query(s.queries.selectPhoneNumbers, userID)
if err != nil {
return nil, err
}
defer rows.Close()
phoneNumbers := make([]string, 0)
for {
phoneNumber, err := s.readPhoneNumber(rows)
if errors.Is(err, ErrPhoneNumberNotFound) {
break
} else if err != nil {
return nil, err
}
phoneNumbers = append(phoneNumbers, phoneNumber)
}
return phoneNumbers, nil
}
// AddPhoneNumber adds a phone number to the user with the given user ID
func (s *commonStore) AddPhoneNumber(userID, phoneNumber string) error {
if _, err := s.db.Exec(s.queries.insertPhoneNumber, userID, phoneNumber); err != nil {
if isUniqueConstraintError(err) {
return ErrPhoneNumberExists
}
return err
}
return nil
}
// RemovePhoneNumber deletes a phone number from the user with the given user ID
func (s *commonStore) RemovePhoneNumber(userID, phoneNumber string) error {
_, err := s.db.Exec(s.queries.deletePhoneNumber, userID, phoneNumber)
return err
}
func (s *commonStore) readPhoneNumber(rows *sql.Rows) (string, error) {
var phoneNumber string
if !rows.Next() {
return "", ErrPhoneNumberNotFound
}
if err := rows.Scan(&phoneNumber); err != nil {
return "", err
} else if err := rows.Err(); err != nil {
return "", err
}
return phoneNumber, nil
}
// ChangeBilling updates a user's billing fields
func (s *commonStore) ChangeBilling(username string, billing *Billing) error {
if _, err := s.db.Exec(s.queries.updateBilling, nullString(billing.StripeCustomerID), nullString(billing.StripeSubscriptionID), nullString(string(billing.StripeSubscriptionStatus)), nullString(string(billing.StripeSubscriptionInterval)), nullInt64(billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(billing.StripeSubscriptionCancelAt.Unix()), username); err != nil {
return err
}
return nil
}
// UserIDByUsername returns the user ID for the given username
func (s *commonStore) UserIDByUsername(username string) (string, error) {
rows, err := s.db.Query(s.queries.selectUserIDFromUsername, username)
if err != nil {
return "", err
}
defer rows.Close()
if !rows.Next() {
return "", ErrUserNotFound
}
var userID string
if err := rows.Scan(&userID); err != nil {
return "", err
}
return userID, nil
}
// Close closes the underlying database
func (s *commonStore) Close() error {
return s.db.Close()
}
// isUniqueConstraintError checks if the error is a unique constraint violation for both SQLite and PostgreSQL
func isUniqueConstraintError(err error) bool {
errStr := err.Error()
return strings.Contains(errStr, "UNIQUE constraint failed") || strings.Contains(errStr, "23505")
}

View File

@@ -1,292 +0,0 @@
package user
import (
"database/sql"
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
)
// PostgreSQL queries
const (
// User queries
postgresSelectUserByID = `
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM "user" u
LEFT JOIN tier t on t.id = u.tier_id
WHERE u.id = $1
`
postgresSelectUserByName = `
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM "user" u
LEFT JOIN tier t on t.id = u.tier_id
WHERE user_name = $1
`
postgresSelectUserByToken = `
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM "user" u
JOIN user_token tk on u.id = tk.user_id
LEFT JOIN tier t on t.id = u.tier_id
WHERE tk.token = $1 AND (tk.expires = 0 OR tk.expires >= $2)
`
postgresSelectUserByStripeID = `
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM "user" u
LEFT JOIN tier t on t.id = u.tier_id
WHERE u.stripe_customer_id = $1
`
postgresSelectUsernames = `
SELECT user_name
FROM "user"
ORDER BY
CASE role
WHEN 'admin' THEN 1
WHEN 'anonymous' THEN 3
ELSE 2
END, user_name
`
postgresSelectUserCount = `SELECT COUNT(*) FROM "user"`
postgresSelectUserIDFromUsername = `SELECT id FROM "user" WHERE user_name = $1`
postgresInsertUser = `INSERT INTO "user" (id, user_name, pass, role, sync_topic, provisioned, created) VALUES ($1, $2, $3, $4, $5, $6, $7)`
postgresUpdateUserPass = `UPDATE "user" SET pass = $1 WHERE user_name = $2`
postgresUpdateUserRole = `UPDATE "user" SET role = $1 WHERE user_name = $2`
postgresUpdateUserProvisioned = `UPDATE "user" SET provisioned = $1 WHERE user_name = $2`
postgresUpdateUserPrefs = `UPDATE "user" SET prefs = $1 WHERE id = $2`
postgresUpdateUserStats = `UPDATE "user" SET stats_messages = $1, stats_emails = $2, stats_calls = $3 WHERE id = $4`
postgresUpdateUserStatsResetAll = `UPDATE "user" SET stats_messages = 0, stats_emails = 0, stats_calls = 0`
postgresUpdateUserTier = `UPDATE "user" SET tier_id = (SELECT id FROM tier WHERE code = $1) WHERE user_name = $2`
postgresUpdateUserDeleted = `UPDATE "user" SET deleted = $1 WHERE id = $2`
postgresDeleteUser = `DELETE FROM "user" WHERE user_name = $1`
postgresDeleteUserTier = `UPDATE "user" SET tier_id = null WHERE user_name = $1`
postgresDeleteUsersMarked = `DELETE FROM "user" WHERE deleted < $1`
// Access queries
postgresSelectTopicPerms = `
SELECT read, write
FROM user_access a
JOIN "user" u ON u.id = a.user_id
WHERE (u.user_name = $1 OR u.user_name = $2) AND $3 LIKE a.topic ESCAPE '\'
ORDER BY u.user_name DESC, LENGTH(a.topic) DESC, CASE WHEN a.write THEN 1 ELSE 0 END DESC
`
postgresSelectUserAllAccess = `
SELECT user_id, topic, read, write, provisioned
FROM user_access
ORDER BY LENGTH(topic) DESC, CASE WHEN write THEN 1 ELSE 0 END DESC, CASE WHEN read THEN 1 ELSE 0 END DESC, topic
`
postgresSelectUserAccess = `
SELECT topic, read, write, provisioned
FROM user_access
WHERE user_id = (SELECT id FROM "user" WHERE user_name = $1)
ORDER BY LENGTH(topic) DESC, CASE WHEN write THEN 1 ELSE 0 END DESC, CASE WHEN read THEN 1 ELSE 0 END DESC, topic
`
postgresSelectUserReservations = `
SELECT a_user.topic, a_user.read, a_user.write, a_everyone.read AS everyone_read, a_everyone.write AS everyone_write
FROM user_access a_user
LEFT JOIN user_access a_everyone ON a_user.topic = a_everyone.topic AND a_everyone.user_id = (SELECT id FROM "user" WHERE user_name = $1)
WHERE a_user.user_id = a_user.owner_user_id
AND a_user.owner_user_id = (SELECT id FROM "user" WHERE user_name = $2)
ORDER BY a_user.topic
`
postgresSelectUserReservationsCount = `
SELECT COUNT(*)
FROM user_access
WHERE user_id = owner_user_id
AND owner_user_id = (SELECT id FROM "user" WHERE user_name = $1)
`
postgresSelectUserReservationsOwner = `
SELECT owner_user_id
FROM user_access
WHERE topic = $1
AND user_id = owner_user_id
`
postgresSelectUserHasReservation = `
SELECT COUNT(*)
FROM user_access
WHERE user_id = owner_user_id
AND owner_user_id = (SELECT id FROM "user" WHERE user_name = $1)
AND topic = $2
`
postgresSelectOtherAccessCount = `
SELECT COUNT(*)
FROM user_access
WHERE (topic = $1 OR $2 LIKE topic ESCAPE '\')
AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM "user" WHERE user_name = $3))
`
postgresUpsertUserAccess = `
INSERT INTO user_access (user_id, topic, read, write, owner_user_id, provisioned)
VALUES (
(SELECT id FROM "user" WHERE user_name = $1),
$2,
$3,
$4,
CASE WHEN $5 = '' THEN NULL ELSE (SELECT id FROM "user" WHERE user_name = $6) END,
$7
)
ON CONFLICT (user_id, topic)
DO UPDATE SET read=EXCLUDED.read, write=EXCLUDED.write, owner_user_id=EXCLUDED.owner_user_id, provisioned=EXCLUDED.provisioned
`
postgresDeleteUserAccess = `
DELETE FROM user_access
WHERE user_id = (SELECT id FROM "user" WHERE user_name = $1)
OR owner_user_id = (SELECT id FROM "user" WHERE user_name = $2)
`
postgresDeleteUserAccessProvisioned = `DELETE FROM user_access WHERE provisioned = true`
postgresDeleteTopicAccess = `
DELETE FROM user_access
WHERE (user_id = (SELECT id FROM "user" WHERE user_name = $1) OR owner_user_id = (SELECT id FROM "user" WHERE user_name = $2))
AND topic = $3
`
postgresDeleteAllAccess = `DELETE FROM user_access`
// Token queries
postgresSelectToken = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = $1 AND token = $2`
postgresSelectTokens = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = $1`
postgresSelectTokenCount = `SELECT COUNT(*) FROM user_token WHERE user_id = $1`
postgresSelectAllProvisionedTokens = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE provisioned = true`
postgresUpsertToken = `
INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires, provisioned)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT (user_id, token)
DO UPDATE SET label = EXCLUDED.label, expires = EXCLUDED.expires, provisioned = EXCLUDED.provisioned
`
postgresUpdateTokenLabel = `UPDATE user_token SET label = $1 WHERE user_id = $2 AND token = $3`
postgresUpdateTokenExpiry = `UPDATE user_token SET expires = $1 WHERE user_id = $2 AND token = $3`
postgresUpdateTokenLastAccess = `UPDATE user_token SET last_access = $1, last_origin = $2 WHERE token = $3`
postgresDeleteToken = `DELETE FROM user_token WHERE user_id = $1 AND token = $2`
postgresDeleteProvisionedToken = `DELETE FROM user_token WHERE token = $1`
postgresDeleteAllToken = `DELETE FROM user_token WHERE user_id = $1`
postgresDeleteExpiredTokens = `DELETE FROM user_token WHERE expires > 0 AND expires < $1`
postgresDeleteExcessTokens = `
DELETE FROM user_token
WHERE user_id = $1
AND (user_id, token) NOT IN (
SELECT user_id, token
FROM user_token
WHERE user_id = $2
ORDER BY expires DESC
LIMIT $3
)
`
// Tier queries
postgresInsertTier = `
INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
`
postgresUpdateTier = `
UPDATE tier
SET name = $1, messages_limit = $2, messages_expiry_duration = $3, emails_limit = $4, calls_limit = $5, reservations_limit = $6, attachment_file_size_limit = $7, attachment_total_size_limit = $8, attachment_expiry_duration = $9, attachment_bandwidth_limit = $10, stripe_monthly_price_id = $11, stripe_yearly_price_id = $12
WHERE code = $13
`
postgresSelectTiers = `
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
FROM tier
`
postgresSelectTierByCode = `
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
FROM tier
WHERE code = $1
`
postgresSelectTierByPriceID = `
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
FROM tier
WHERE (stripe_monthly_price_id = $1 OR stripe_yearly_price_id = $2)
`
postgresDeleteTier = `DELETE FROM tier WHERE code = $1`
// Phone queries
postgresSelectPhoneNumbers = `SELECT phone_number FROM user_phone WHERE user_id = $1`
postgresInsertPhoneNumber = `INSERT INTO user_phone (user_id, phone_number) VALUES ($1, $2)`
postgresDeletePhoneNumber = `DELETE FROM user_phone WHERE user_id = $1 AND phone_number = $2`
// Billing queries
postgresUpdateBilling = `
UPDATE "user"
SET stripe_customer_id = $1, stripe_subscription_id = $2, stripe_subscription_status = $3, stripe_subscription_interval = $4, stripe_subscription_paid_until = $5, stripe_subscription_cancel_at = $6
WHERE user_name = $7
`
)
// NewPostgresStore creates a new PostgreSQL-backed user store
func NewPostgresStore(dsn string) (Store, error) {
db, err := sql.Open("pgx", dsn)
if err != nil {
return nil, err
}
if err := db.Ping(); err != nil {
return nil, err
}
if err := setupPostgres(db); err != nil {
return nil, err
}
return &commonStore{
db: db,
queries: storeQueries{
// User queries
selectUserByID: postgresSelectUserByID,
selectUserByName: postgresSelectUserByName,
selectUserByToken: postgresSelectUserByToken,
selectUserByStripeID: postgresSelectUserByStripeID,
selectUsernames: postgresSelectUsernames,
selectUserCount: postgresSelectUserCount,
selectUserIDFromUsername: postgresSelectUserIDFromUsername,
insertUser: postgresInsertUser,
updateUserPass: postgresUpdateUserPass,
updateUserRole: postgresUpdateUserRole,
updateUserProvisioned: postgresUpdateUserProvisioned,
updateUserPrefs: postgresUpdateUserPrefs,
updateUserStats: postgresUpdateUserStats,
updateUserStatsResetAll: postgresUpdateUserStatsResetAll,
updateUserTier: postgresUpdateUserTier,
updateUserDeleted: postgresUpdateUserDeleted,
deleteUser: postgresDeleteUser,
deleteUserTier: postgresDeleteUserTier,
deleteUsersMarked: postgresDeleteUsersMarked,
// Access queries
selectTopicPerms: postgresSelectTopicPerms,
selectUserAllAccess: postgresSelectUserAllAccess,
selectUserAccess: postgresSelectUserAccess,
selectUserReservations: postgresSelectUserReservations,
selectUserReservationsCount: postgresSelectUserReservationsCount,
selectUserReservationsOwner: postgresSelectUserReservationsOwner,
selectUserHasReservation: postgresSelectUserHasReservation,
selectOtherAccessCount: postgresSelectOtherAccessCount,
upsertUserAccess: postgresUpsertUserAccess,
deleteUserAccess: postgresDeleteUserAccess,
deleteUserAccessProvisioned: postgresDeleteUserAccessProvisioned,
deleteTopicAccess: postgresDeleteTopicAccess,
deleteAllAccess: postgresDeleteAllAccess,
// Token queries
selectToken: postgresSelectToken,
selectTokens: postgresSelectTokens,
selectTokenCount: postgresSelectTokenCount,
selectAllProvisionedTokens: postgresSelectAllProvisionedTokens,
upsertToken: postgresUpsertToken,
updateTokenLabel: postgresUpdateTokenLabel,
updateTokenExpiry: postgresUpdateTokenExpiry,
updateTokenLastAccess: postgresUpdateTokenLastAccess,
deleteToken: postgresDeleteToken,
deleteProvisionedToken: postgresDeleteProvisionedToken,
deleteAllToken: postgresDeleteAllToken,
deleteExpiredTokens: postgresDeleteExpiredTokens,
deleteExcessTokens: postgresDeleteExcessTokens,
// Tier queries
insertTier: postgresInsertTier,
selectTiers: postgresSelectTiers,
selectTierByCode: postgresSelectTierByCode,
selectTierByPriceID: postgresSelectTierByPriceID,
updateTier: postgresUpdateTier,
deleteTier: postgresDeleteTier,
// Phone queries
selectPhoneNumbers: postgresSelectPhoneNumbers,
insertPhoneNumber: postgresInsertPhoneNumber,
deletePhoneNumber: postgresDeletePhoneNumber,
// Billing queries
updateBilling: postgresUpdateBilling,
},
}, nil
}

View File

@@ -1,208 +0,0 @@
package user_test
import (
"database/sql"
"fmt"
"net/url"
"os"
"testing"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/user"
"heckel.io/ntfy/v2/util"
)
func newTestPostgresStore(t *testing.T) user.Store {
dsn := os.Getenv("NTFY_TEST_DATABASE_URL")
if dsn == "" {
t.Skip("NTFY_TEST_DATABASE_URL not set, skipping PostgreSQL tests")
}
// Create a unique schema for this test
schema := fmt.Sprintf("test_%s", util.RandomString(10))
setupDB, err := sql.Open("pgx", dsn)
require.Nil(t, err)
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
require.Nil(t, err)
require.Nil(t, setupDB.Close())
// Open store with search_path set to the new schema
u, err := url.Parse(dsn)
require.Nil(t, err)
q := u.Query()
q.Set("search_path", schema)
u.RawQuery = q.Encode()
store, err := user.NewPostgresStore(u.String())
require.Nil(t, err)
t.Cleanup(func() {
store.Close()
cleanDB, err := sql.Open("pgx", dsn)
if err == nil {
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
cleanDB.Close()
}
})
return store
}
func TestPostgresStoreAddUser(t *testing.T) {
testStoreAddUser(t, newTestPostgresStore(t))
}
func TestPostgresStoreAddUserAlreadyExists(t *testing.T) {
testStoreAddUserAlreadyExists(t, newTestPostgresStore(t))
}
func TestPostgresStoreRemoveUser(t *testing.T) {
testStoreRemoveUser(t, newTestPostgresStore(t))
}
func TestPostgresStoreUserByID(t *testing.T) {
testStoreUserByID(t, newTestPostgresStore(t))
}
func TestPostgresStoreUserByToken(t *testing.T) {
testStoreUserByToken(t, newTestPostgresStore(t))
}
func TestPostgresStoreUserByStripeCustomer(t *testing.T) {
testStoreUserByStripeCustomer(t, newTestPostgresStore(t))
}
func TestPostgresStoreUsers(t *testing.T) {
testStoreUsers(t, newTestPostgresStore(t))
}
func TestPostgresStoreUsersCount(t *testing.T) {
testStoreUsersCount(t, newTestPostgresStore(t))
}
func TestPostgresStoreChangePassword(t *testing.T) {
testStoreChangePassword(t, newTestPostgresStore(t))
}
func TestPostgresStoreChangeRole(t *testing.T) {
testStoreChangeRole(t, newTestPostgresStore(t))
}
func TestPostgresStoreTokens(t *testing.T) {
testStoreTokens(t, newTestPostgresStore(t))
}
func TestPostgresStoreTokenChangeLabel(t *testing.T) {
testStoreTokenChangeLabel(t, newTestPostgresStore(t))
}
func TestPostgresStoreTokenRemove(t *testing.T) {
testStoreTokenRemove(t, newTestPostgresStore(t))
}
func TestPostgresStoreTokenRemoveExpired(t *testing.T) {
testStoreTokenRemoveExpired(t, newTestPostgresStore(t))
}
func TestPostgresStoreTokenRemoveExcess(t *testing.T) {
testStoreTokenRemoveExcess(t, newTestPostgresStore(t))
}
func TestPostgresStoreTokenUpdateLastAccess(t *testing.T) {
testStoreTokenUpdateLastAccess(t, newTestPostgresStore(t))
}
func TestPostgresStoreAllowAccess(t *testing.T) {
testStoreAllowAccess(t, newTestPostgresStore(t))
}
func TestPostgresStoreAllowAccessReadOnly(t *testing.T) {
testStoreAllowAccessReadOnly(t, newTestPostgresStore(t))
}
func TestPostgresStoreResetAccess(t *testing.T) {
testStoreResetAccess(t, newTestPostgresStore(t))
}
func TestPostgresStoreResetAccessAll(t *testing.T) {
testStoreResetAccessAll(t, newTestPostgresStore(t))
}
func TestPostgresStoreAuthorizeTopicAccess(t *testing.T) {
testStoreAuthorizeTopicAccess(t, newTestPostgresStore(t))
}
func TestPostgresStoreAuthorizeTopicAccessNotFound(t *testing.T) {
testStoreAuthorizeTopicAccessNotFound(t, newTestPostgresStore(t))
}
func TestPostgresStoreAuthorizeTopicAccessDenyAll(t *testing.T) {
testStoreAuthorizeTopicAccessDenyAll(t, newTestPostgresStore(t))
}
func TestPostgresStoreReservations(t *testing.T) {
testStoreReservations(t, newTestPostgresStore(t))
}
func TestPostgresStoreReservationsCount(t *testing.T) {
testStoreReservationsCount(t, newTestPostgresStore(t))
}
func TestPostgresStoreHasReservation(t *testing.T) {
testStoreHasReservation(t, newTestPostgresStore(t))
}
func TestPostgresStoreReservationOwner(t *testing.T) {
testStoreReservationOwner(t, newTestPostgresStore(t))
}
func TestPostgresStoreTiers(t *testing.T) {
testStoreTiers(t, newTestPostgresStore(t))
}
func TestPostgresStoreTierUpdate(t *testing.T) {
testStoreTierUpdate(t, newTestPostgresStore(t))
}
func TestPostgresStoreTierRemove(t *testing.T) {
testStoreTierRemove(t, newTestPostgresStore(t))
}
func TestPostgresStoreTierByStripePrice(t *testing.T) {
testStoreTierByStripePrice(t, newTestPostgresStore(t))
}
func TestPostgresStoreChangeTier(t *testing.T) {
testStoreChangeTier(t, newTestPostgresStore(t))
}
func TestPostgresStorePhoneNumbers(t *testing.T) {
testStorePhoneNumbers(t, newTestPostgresStore(t))
}
func TestPostgresStoreChangeSettings(t *testing.T) {
testStoreChangeSettings(t, newTestPostgresStore(t))
}
func TestPostgresStoreChangeBilling(t *testing.T) {
testStoreChangeBilling(t, newTestPostgresStore(t))
}
func TestPostgresStoreUpdateStats(t *testing.T) {
testStoreUpdateStats(t, newTestPostgresStore(t))
}
func TestPostgresStoreResetStats(t *testing.T) {
testStoreResetStats(t, newTestPostgresStore(t))
}
func TestPostgresStoreMarkUserRemoved(t *testing.T) {
testStoreMarkUserRemoved(t, newTestPostgresStore(t))
}
func TestPostgresStoreRemoveDeletedUsers(t *testing.T) {
testStoreRemoveDeletedUsers(t, newTestPostgresStore(t))
}
func TestPostgresStoreAllGrants(t *testing.T) {
testStoreAllGrants(t, newTestPostgresStore(t))
}
func TestPostgresStoreOtherAccessCount(t *testing.T) {
testStoreOtherAccessCount(t, newTestPostgresStore(t))
}

View File

@@ -1,180 +0,0 @@
package user_test
import (
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/user"
)
func newTestSQLiteStore(t *testing.T) user.Store {
store, err := user.NewSQLiteStore(filepath.Join(t.TempDir(), "user.db"), "")
require.Nil(t, err)
t.Cleanup(func() { store.Close() })
return store
}
func TestSQLiteStoreAddUser(t *testing.T) {
testStoreAddUser(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreAddUserAlreadyExists(t *testing.T) {
testStoreAddUserAlreadyExists(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreRemoveUser(t *testing.T) {
testStoreRemoveUser(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreUserByID(t *testing.T) {
testStoreUserByID(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreUserByToken(t *testing.T) {
testStoreUserByToken(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreUserByStripeCustomer(t *testing.T) {
testStoreUserByStripeCustomer(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreUsers(t *testing.T) {
testStoreUsers(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreUsersCount(t *testing.T) {
testStoreUsersCount(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreChangePassword(t *testing.T) {
testStoreChangePassword(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreChangeRole(t *testing.T) {
testStoreChangeRole(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreTokens(t *testing.T) {
testStoreTokens(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreTokenChangeLabel(t *testing.T) {
testStoreTokenChangeLabel(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreTokenRemove(t *testing.T) {
testStoreTokenRemove(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreTokenRemoveExpired(t *testing.T) {
testStoreTokenRemoveExpired(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreTokenRemoveExcess(t *testing.T) {
testStoreTokenRemoveExcess(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreTokenUpdateLastAccess(t *testing.T) {
testStoreTokenUpdateLastAccess(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreAllowAccess(t *testing.T) {
testStoreAllowAccess(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreAllowAccessReadOnly(t *testing.T) {
testStoreAllowAccessReadOnly(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreResetAccess(t *testing.T) {
testStoreResetAccess(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreResetAccessAll(t *testing.T) {
testStoreResetAccessAll(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreAuthorizeTopicAccess(t *testing.T) {
testStoreAuthorizeTopicAccess(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreAuthorizeTopicAccessNotFound(t *testing.T) {
testStoreAuthorizeTopicAccessNotFound(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreAuthorizeTopicAccessDenyAll(t *testing.T) {
testStoreAuthorizeTopicAccessDenyAll(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreReservations(t *testing.T) {
testStoreReservations(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreReservationsCount(t *testing.T) {
testStoreReservationsCount(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreHasReservation(t *testing.T) {
testStoreHasReservation(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreReservationOwner(t *testing.T) {
testStoreReservationOwner(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreTiers(t *testing.T) {
testStoreTiers(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreTierUpdate(t *testing.T) {
testStoreTierUpdate(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreTierRemove(t *testing.T) {
testStoreTierRemove(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreTierByStripePrice(t *testing.T) {
testStoreTierByStripePrice(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreChangeTier(t *testing.T) {
testStoreChangeTier(t, newTestSQLiteStore(t))
}
func TestSQLiteStorePhoneNumbers(t *testing.T) {
testStorePhoneNumbers(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreChangeSettings(t *testing.T) {
testStoreChangeSettings(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreChangeBilling(t *testing.T) {
testStoreChangeBilling(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreUpdateStats(t *testing.T) {
testStoreUpdateStats(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreResetStats(t *testing.T) {
testStoreResetStats(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreMarkUserRemoved(t *testing.T) {
testStoreMarkUserRemoved(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreRemoveDeletedUsers(t *testing.T) {
testStoreRemoveDeletedUsers(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreAllGrants(t *testing.T) {
testStoreAllGrants(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreOtherAccessCount(t *testing.T) {
testStoreOtherAccessCount(t, newTestSQLiteStore(t))
}

View File

@@ -1,619 +0,0 @@
package user_test
import (
"net/netip"
"testing"
"time"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/user"
)
func testStoreAddUser(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
u, err := store.User("phil")
require.Nil(t, err)
require.Equal(t, "phil", u.Name)
require.Equal(t, user.RoleUser, u.Role)
require.False(t, u.Provisioned)
require.NotEmpty(t, u.ID)
require.NotEmpty(t, u.SyncTopic)
}
func testStoreAddUserAlreadyExists(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
require.Equal(t, user.ErrUserExists, store.AddUser("phil", "philhash", user.RoleUser, false))
}
func testStoreRemoveUser(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
u, err := store.User("phil")
require.Nil(t, err)
require.Equal(t, "phil", u.Name)
require.Nil(t, store.RemoveUser("phil"))
_, err = store.User("phil")
require.Equal(t, user.ErrUserNotFound, err)
}
func testStoreUserByID(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleAdmin, false))
u, err := store.User("phil")
require.Nil(t, err)
u2, err := store.UserByID(u.ID)
require.Nil(t, err)
require.Equal(t, u.Name, u2.Name)
require.Equal(t, u.ID, u2.ID)
}
func testStoreUserByToken(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
u, err := store.User("phil")
require.Nil(t, err)
tk, err := store.CreateToken(u.ID, "tk_test123", "test token", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(24*time.Hour), false)
require.Nil(t, err)
require.Equal(t, "tk_test123", tk.Value)
u2, err := store.UserByToken(tk.Value)
require.Nil(t, err)
require.Equal(t, "phil", u2.Name)
}
func testStoreUserByStripeCustomer(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
require.Nil(t, store.ChangeBilling("phil", &user.Billing{
StripeCustomerID: "cus_test123",
StripeSubscriptionID: "sub_test123",
}))
u, err := store.UserByStripeCustomer("cus_test123")
require.Nil(t, err)
require.Equal(t, "phil", u.Name)
require.Equal(t, "cus_test123", u.Billing.StripeCustomerID)
}
func testStoreUsers(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
require.Nil(t, store.AddUser("ben", "benhash", user.RoleAdmin, false))
users, err := store.Users()
require.Nil(t, err)
require.True(t, len(users) >= 3) // phil, ben, and the everyone user
}
func testStoreUsersCount(t *testing.T, store user.Store) {
count, err := store.UsersCount()
require.Nil(t, err)
require.True(t, count >= 1) // At least the everyone user
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
count2, err := store.UsersCount()
require.Nil(t, err)
require.Equal(t, count+1, count2)
}
func testStoreChangePassword(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
u, err := store.User("phil")
require.Nil(t, err)
require.Equal(t, "philhash", u.Hash)
require.Nil(t, store.ChangePassword("phil", "newhash"))
u, err = store.User("phil")
require.Nil(t, err)
require.Equal(t, "newhash", u.Hash)
}
func testStoreChangeRole(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
u, err := store.User("phil")
require.Nil(t, err)
require.Equal(t, user.RoleUser, u.Role)
require.Nil(t, store.ChangeRole("phil", user.RoleAdmin))
u, err = store.User("phil")
require.Nil(t, err)
require.Equal(t, user.RoleAdmin, u.Role)
}
func testStoreTokens(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
u, err := store.User("phil")
require.Nil(t, err)
now := time.Now()
expires := now.Add(24 * time.Hour)
origin := netip.MustParseAddr("9.9.9.9")
tk, err := store.CreateToken(u.ID, "tk_abc", "my token", now, origin, expires, false)
require.Nil(t, err)
require.Equal(t, "tk_abc", tk.Value)
require.Equal(t, "my token", tk.Label)
// Get single token
tk2, err := store.Token(u.ID, "tk_abc")
require.Nil(t, err)
require.Equal(t, "tk_abc", tk2.Value)
require.Equal(t, "my token", tk2.Label)
// Get all tokens
tokens, err := store.Tokens(u.ID)
require.Nil(t, err)
require.Len(t, tokens, 1)
require.Equal(t, "tk_abc", tokens[0].Value)
// Token count
count, err := store.TokenCount(u.ID)
require.Nil(t, err)
require.Equal(t, 1, count)
}
func testStoreTokenChangeLabel(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
u, err := store.User("phil")
require.Nil(t, err)
_, err = store.CreateToken(u.ID, "tk_abc", "old label", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), false)
require.Nil(t, err)
require.Nil(t, store.ChangeTokenLabel(u.ID, "tk_abc", "new label"))
tk, err := store.Token(u.ID, "tk_abc")
require.Nil(t, err)
require.Equal(t, "new label", tk.Label)
}
func testStoreTokenRemove(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
u, err := store.User("phil")
require.Nil(t, err)
_, err = store.CreateToken(u.ID, "tk_abc", "label", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), false)
require.Nil(t, err)
require.Nil(t, store.RemoveToken(u.ID, "tk_abc"))
_, err = store.Token(u.ID, "tk_abc")
require.Equal(t, user.ErrTokenNotFound, err)
}
func testStoreTokenRemoveExpired(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
u, err := store.User("phil")
require.Nil(t, err)
// Create expired token and active token
_, err = store.CreateToken(u.ID, "tk_expired", "expired", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(-time.Hour), false)
require.Nil(t, err)
_, err = store.CreateToken(u.ID, "tk_active", "active", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), false)
require.Nil(t, err)
require.Nil(t, store.RemoveExpiredTokens())
// Expired token should be gone
_, err = store.Token(u.ID, "tk_expired")
require.Equal(t, user.ErrTokenNotFound, err)
// Active token should still exist
tk, err := store.Token(u.ID, "tk_active")
require.Nil(t, err)
require.Equal(t, "tk_active", tk.Value)
}
func testStoreTokenRemoveExcess(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
u, err := store.User("phil")
require.Nil(t, err)
// Create 3 tokens with increasing expiry
for i, name := range []string{"tk_a", "tk_b", "tk_c"} {
_, err = store.CreateToken(u.ID, name, name, time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Duration(i+1)*time.Hour), false)
require.Nil(t, err)
}
count, err := store.TokenCount(u.ID)
require.Nil(t, err)
require.Equal(t, 3, count)
// Remove excess, keep only 2 (the ones with latest expiry: tk_b, tk_c)
require.Nil(t, store.RemoveExcessTokens(u.ID, 2))
count, err = store.TokenCount(u.ID)
require.Nil(t, err)
require.Equal(t, 2, count)
// tk_a should be removed (earliest expiry)
_, err = store.Token(u.ID, "tk_a")
require.Equal(t, user.ErrTokenNotFound, err)
// tk_b and tk_c should remain
_, err = store.Token(u.ID, "tk_b")
require.Nil(t, err)
_, err = store.Token(u.ID, "tk_c")
require.Nil(t, err)
}
func testStoreTokenUpdateLastAccess(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
u, err := store.User("phil")
require.Nil(t, err)
_, err = store.CreateToken(u.ID, "tk_abc", "label", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), false)
require.Nil(t, err)
newTime := time.Now().Add(5 * time.Minute)
newOrigin := netip.MustParseAddr("5.5.5.5")
require.Nil(t, store.UpdateTokenLastAccess("tk_abc", newTime, newOrigin))
tk, err := store.Token(u.ID, "tk_abc")
require.Nil(t, err)
require.Equal(t, newTime.Unix(), tk.LastAccess.Unix())
require.Equal(t, newOrigin, tk.LastOrigin)
}
func testStoreAllowAccess(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "", false))
grants, err := store.Grants("phil")
require.Nil(t, err)
require.Len(t, grants, 1)
require.Equal(t, "mytopic", grants[0].TopicPattern)
require.True(t, grants[0].Permission.IsReadWrite())
}
func testStoreAllowAccessReadOnly(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
require.Nil(t, store.AllowAccess("phil", "announcements", true, false, "", false))
grants, err := store.Grants("phil")
require.Nil(t, err)
require.Len(t, grants, 1)
require.True(t, grants[0].Permission.IsRead())
require.False(t, grants[0].Permission.IsWrite())
}
func testStoreResetAccess(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
require.Nil(t, store.AllowAccess("phil", "topic1", true, true, "", false))
require.Nil(t, store.AllowAccess("phil", "topic2", true, false, "", false))
grants, err := store.Grants("phil")
require.Nil(t, err)
require.Len(t, grants, 2)
require.Nil(t, store.ResetAccess("phil", "topic1"))
grants, err = store.Grants("phil")
require.Nil(t, err)
require.Len(t, grants, 1)
require.Equal(t, "topic2", grants[0].TopicPattern)
}
func testStoreResetAccessAll(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
require.Nil(t, store.AllowAccess("phil", "topic1", true, true, "", false))
require.Nil(t, store.AllowAccess("phil", "topic2", true, false, "", false))
require.Nil(t, store.ResetAccess("phil", ""))
grants, err := store.Grants("phil")
require.Nil(t, err)
require.Len(t, grants, 0)
}
func testStoreAuthorizeTopicAccess(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "", false))
read, write, found, err := store.AuthorizeTopicAccess("phil", "mytopic")
require.Nil(t, err)
require.True(t, found)
require.True(t, read)
require.True(t, write)
}
func testStoreAuthorizeTopicAccessNotFound(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
_, _, found, err := store.AuthorizeTopicAccess("phil", "other")
require.Nil(t, err)
require.False(t, found)
}
func testStoreAuthorizeTopicAccessDenyAll(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
require.Nil(t, store.AllowAccess("phil", "secret", false, false, "", false))
read, write, found, err := store.AuthorizeTopicAccess("phil", "secret")
require.Nil(t, err)
require.True(t, found)
require.False(t, read)
require.False(t, write)
}
func testStoreReservations(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "phil", false))
require.Nil(t, store.AllowAccess(user.Everyone, "mytopic", true, false, "phil", false))
reservations, err := store.Reservations("phil")
require.Nil(t, err)
require.Len(t, reservations, 1)
require.Equal(t, "mytopic", reservations[0].Topic)
require.True(t, reservations[0].Owner.IsReadWrite())
require.True(t, reservations[0].Everyone.IsRead())
require.False(t, reservations[0].Everyone.IsWrite())
}
func testStoreReservationsCount(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
require.Nil(t, store.AllowAccess("phil", "topic1", true, true, "phil", false))
require.Nil(t, store.AllowAccess("phil", "topic2", true, true, "phil", false))
count, err := store.ReservationsCount("phil")
require.Nil(t, err)
require.Equal(t, int64(2), count)
}
func testStoreHasReservation(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "phil", false))
has, err := store.HasReservation("phil", "mytopic")
require.Nil(t, err)
require.True(t, has)
has, err = store.HasReservation("phil", "other")
require.Nil(t, err)
require.False(t, has)
}
func testStoreReservationOwner(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "phil", false))
owner, err := store.ReservationOwner("mytopic")
require.Nil(t, err)
require.NotEmpty(t, owner) // Returns the user ID
owner, err = store.ReservationOwner("unowned")
require.Nil(t, err)
require.Empty(t, owner)
}
func testStoreTiers(t *testing.T, store user.Store) {
tier := &user.Tier{
ID: "ti_test",
Code: "pro",
Name: "Pro",
MessageLimit: 5000,
MessageExpiryDuration: 24 * time.Hour,
EmailLimit: 100,
CallLimit: 10,
ReservationLimit: 20,
AttachmentFileSizeLimit: 10 * 1024 * 1024,
AttachmentTotalSizeLimit: 100 * 1024 * 1024,
AttachmentExpiryDuration: 48 * time.Hour,
AttachmentBandwidthLimit: 500 * 1024 * 1024,
}
require.Nil(t, store.AddTier(tier))
// Get by code
t2, err := store.Tier("pro")
require.Nil(t, err)
require.Equal(t, "ti_test", t2.ID)
require.Equal(t, "pro", t2.Code)
require.Equal(t, "Pro", t2.Name)
require.Equal(t, int64(5000), t2.MessageLimit)
require.Equal(t, int64(100), t2.EmailLimit)
require.Equal(t, int64(10), t2.CallLimit)
require.Equal(t, int64(20), t2.ReservationLimit)
// List all tiers
tiers, err := store.Tiers()
require.Nil(t, err)
require.Len(t, tiers, 1)
require.Equal(t, "pro", tiers[0].Code)
}
func testStoreTierUpdate(t *testing.T, store user.Store) {
tier := &user.Tier{
ID: "ti_test",
Code: "pro",
Name: "Pro",
}
require.Nil(t, store.AddTier(tier))
tier.Name = "Professional"
tier.MessageLimit = 9999
require.Nil(t, store.UpdateTier(tier))
t2, err := store.Tier("pro")
require.Nil(t, err)
require.Equal(t, "Professional", t2.Name)
require.Equal(t, int64(9999), t2.MessageLimit)
}
func testStoreTierRemove(t *testing.T, store user.Store) {
tier := &user.Tier{
ID: "ti_test",
Code: "pro",
Name: "Pro",
}
require.Nil(t, store.AddTier(tier))
t2, err := store.Tier("pro")
require.Nil(t, err)
require.Equal(t, "pro", t2.Code)
require.Nil(t, store.RemoveTier("pro"))
_, err = store.Tier("pro")
require.Equal(t, user.ErrTierNotFound, err)
}
func testStoreTierByStripePrice(t *testing.T, store user.Store) {
tier := &user.Tier{
ID: "ti_test",
Code: "pro",
Name: "Pro",
StripeMonthlyPriceID: "price_monthly",
StripeYearlyPriceID: "price_yearly",
}
require.Nil(t, store.AddTier(tier))
t2, err := store.TierByStripePrice("price_monthly")
require.Nil(t, err)
require.Equal(t, "pro", t2.Code)
t3, err := store.TierByStripePrice("price_yearly")
require.Nil(t, err)
require.Equal(t, "pro", t3.Code)
}
func testStoreChangeTier(t *testing.T, store user.Store) {
tier := &user.Tier{
ID: "ti_test",
Code: "pro",
Name: "Pro",
}
require.Nil(t, store.AddTier(tier))
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
require.Nil(t, store.ChangeTier("phil", "pro"))
u, err := store.User("phil")
require.Nil(t, err)
require.NotNil(t, u.Tier)
require.Equal(t, "pro", u.Tier.Code)
}
func testStorePhoneNumbers(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
u, err := store.User("phil")
require.Nil(t, err)
require.Nil(t, store.AddPhoneNumber(u.ID, "+1234567890"))
require.Nil(t, store.AddPhoneNumber(u.ID, "+0987654321"))
numbers, err := store.PhoneNumbers(u.ID)
require.Nil(t, err)
require.Len(t, numbers, 2)
require.Nil(t, store.RemovePhoneNumber(u.ID, "+1234567890"))
numbers, err = store.PhoneNumbers(u.ID)
require.Nil(t, err)
require.Len(t, numbers, 1)
require.Equal(t, "+0987654321", numbers[0])
}
func testStoreChangeSettings(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
u, err := store.User("phil")
require.Nil(t, err)
lang := "de"
prefs := &user.Prefs{Language: &lang}
require.Nil(t, store.ChangeSettings(u.ID, prefs))
u2, err := store.User("phil")
require.Nil(t, err)
require.NotNil(t, u2.Prefs)
require.Equal(t, "de", *u2.Prefs.Language)
}
func testStoreChangeBilling(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
billing := &user.Billing{
StripeCustomerID: "cus_123",
StripeSubscriptionID: "sub_456",
}
require.Nil(t, store.ChangeBilling("phil", billing))
u, err := store.User("phil")
require.Nil(t, err)
require.Equal(t, "cus_123", u.Billing.StripeCustomerID)
require.Equal(t, "sub_456", u.Billing.StripeSubscriptionID)
}
func testStoreUpdateStats(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
u, err := store.User("phil")
require.Nil(t, err)
stats := &user.Stats{Messages: 42, Emails: 3, Calls: 1}
require.Nil(t, store.UpdateStats(u.ID, stats))
u2, err := store.User("phil")
require.Nil(t, err)
require.Equal(t, int64(42), u2.Stats.Messages)
require.Equal(t, int64(3), u2.Stats.Emails)
require.Equal(t, int64(1), u2.Stats.Calls)
}
func testStoreResetStats(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
u, err := store.User("phil")
require.Nil(t, err)
require.Nil(t, store.UpdateStats(u.ID, &user.Stats{Messages: 42, Emails: 3, Calls: 1}))
require.Nil(t, store.ResetStats())
u2, err := store.User("phil")
require.Nil(t, err)
require.Equal(t, int64(0), u2.Stats.Messages)
require.Equal(t, int64(0), u2.Stats.Emails)
require.Equal(t, int64(0), u2.Stats.Calls)
}
func testStoreMarkUserRemoved(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
u, err := store.User("phil")
require.Nil(t, err)
require.Nil(t, store.MarkUserRemoved(u.ID))
u2, err := store.User("phil")
require.Nil(t, err)
require.True(t, u2.Deleted)
}
func testStoreRemoveDeletedUsers(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
u, err := store.User("phil")
require.Nil(t, err)
require.Nil(t, store.MarkUserRemoved(u.ID))
// RemoveDeletedUsers only removes users past the hard-delete duration (7 days).
// Immediately after marking, the user should still exist.
require.Nil(t, store.RemoveDeletedUsers())
u2, err := store.User("phil")
require.Nil(t, err)
require.True(t, u2.Deleted)
}
func testStoreAllGrants(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
require.Nil(t, store.AddUser("ben", "benhash", user.RoleUser, false))
phil, err := store.User("phil")
require.Nil(t, err)
ben, err := store.User("ben")
require.Nil(t, err)
require.Nil(t, store.AllowAccess("phil", "topic1", true, true, "", false))
require.Nil(t, store.AllowAccess("ben", "topic2", true, false, "", false))
grants, err := store.AllGrants()
require.Nil(t, err)
require.Contains(t, grants, phil.ID)
require.Contains(t, grants, ben.ID)
}
func testStoreOtherAccessCount(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
require.Nil(t, store.AddUser("ben", "benhash", user.RoleUser, false))
require.Nil(t, store.AllowAccess("ben", "mytopic", true, true, "ben", false))
count, err := store.OtherAccessCount("phil", "mytopic")
require.Nil(t, err)
require.Equal(t, 1, count)
}

View File

@@ -2,11 +2,12 @@ package user
import (
"errors"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/payments"
"net/netip"
"strings"
"time"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/payments"
)
// User is a struct that represents a user
@@ -273,3 +274,72 @@ var (
ErrProvisionedUserChange = errors.New("cannot change or delete provisioned user")
ErrProvisionedTokenChange = errors.New("cannot change or delete provisioned token")
)
// queries holds the database-specific SQL queries
type queries struct {
// User queries
selectUserByID string
selectUserByName string
selectUserByToken string
selectUserByStripeCustomerID string
selectUsernames string
selectUserCount string
selectUserIDFromUsername string
insertUser string
updateUserPass string
updateUserRole string
updateUserProvisioned string
updateUserPrefs string
updateUserStats string
updateUserStatsResetAll string
updateUserTier string
updateUserDeleted string
deleteUser string
deleteUserTier string
deleteUsersMarked string
// Access queries
selectTopicPerms string
selectUserAllAccess string
selectUserAccess string
selectUserReservations string
selectUserReservationsCount string
selectUserReservationsOwner string
selectUserHasReservation string
selectOtherAccessCount string
upsertUserAccess string
deleteUserAccess string
deleteUserAccessProvisioned string
deleteTopicAccess string
deleteAllAccess string
// Token queries
selectToken string
selectTokens string
selectTokenCount string
selectAllProvisionedTokens string
upsertToken string
updateToken string
updateTokenLastAccess string
deleteToken string
deleteProvisionedToken string
deleteAllToken string
deleteExpiredTokens string
deleteExcessTokens string
// Tier queries
insertTier string
selectTiers string
selectTierByCode string
selectTierByPriceID string
updateTier string
deleteTier string
// Phone queries
selectPhoneNumbers string
insertPhoneNumber string
deletePhoneNumber string
// Billing queries
updateBilling string
}

View File

@@ -113,3 +113,35 @@ func escapeUnderscore(s string) string {
func unescapeUnderscore(s string) string {
return strings.ReplaceAll(s, "\\_", "_")
}
// execTx executes a function in a transaction. If the function returns an error, the transaction is rolled back.
func execTx(db *sql.DB, f func(tx *sql.Tx) error) error {
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if err := f(tx); err != nil {
return err
}
return tx.Commit()
}
// queryTx executes a function in a transaction and returns the result. If the function
// returns an error, the transaction is rolled back.
func queryTx[T any](db *sql.DB, f func(tx *sql.Tx) (T, error)) (T, error) {
tx, err := db.Begin()
if err != nil {
var zero T
return zero, err
}
defer tx.Rollback()
t, err := f(tx)
if err != nil {
return t, err
}
if err := tx.Commit(); err != nil {
return t, err
}
return t, nil
}

View File

@@ -48,5 +48,26 @@
"notifications_none_for_topic_title": "לא קיבלת התראות בנושא הזה עדיין.",
"notifications_none_for_topic_description": "כדי לשלוח התראות לנושא הזה, צריך לשלוח PUT או POST לכתובת הנושא הזה.",
"notifications_none_for_any_title": "לא קיבלת התראות כלל.",
"notifications_no_subscriptions_title": "נראה שלא נרשמת למינויים עדיין."
"notifications_no_subscriptions_title": "נראה שלא נרשמת למינויים עדיין.",
"action_bar_toggle_mute": "השתקת/הפעלת התראות",
"action_bar_toggle_action_menu": "פתיחת/סגירת תפריט הפעולות",
"action_bar_profile_title": "פרופיל",
"action_bar_profile_settings": "הגדרות",
"action_bar_profile_logout": "יציאה",
"action_bar_sign_in": "כניסה",
"action_bar_sign_up": "הרשמה",
"message_bar_type_message": "כאן ניתן להקליד הודעה",
"message_bar_error_publishing": "שגיאה בפרסום ההתראה",
"message_bar_show_dialog": "הצגת חלונית פרסום",
"message_bar_publish": "פרסום הודעה",
"nav_topics_title": "נושאים שנרשמת אליהם",
"nav_button_all_notifications": "כל ההתראות",
"nav_button_account": "חשבון",
"nav_button_settings": "הגדרות",
"nav_button_documentation": "תיעוד",
"nav_button_publish_message": "פרסום התראה",
"nav_button_subscribe": "הרשמה לנושא",
"nav_button_muted": "התראות הושתקו",
"nav_button_connecting": "מתחבר",
"nav_upgrade_banner_label": "שדרוג ל־ntfy Pro"
}

View File

@@ -21,26 +21,19 @@ var (
ErrWebPushUserIDCannotBeEmpty = errors.New("user ID cannot be empty")
)
// Store is the interface for a web push subscription store.
type Store interface {
UpsertSubscription(endpoint, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error
SubscriptionsForTopic(topic string) ([]*Subscription, error)
SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscription, error)
MarkExpiryWarningSent(subscriptions []*Subscription) error
RemoveSubscriptionsByEndpoint(endpoint string) error
RemoveSubscriptionsByUserID(userID string) error
RemoveExpiredSubscriptions(expireAfter time.Duration) error
SetSubscriptionUpdatedAt(endpoint string, updatedAt int64) error
Close() error
// Store holds the database connection and queries for web push subscriptions.
type Store struct {
db *sql.DB
queries queries
}
// storeQueries holds the database-specific SQL queries.
type storeQueries struct {
// queries holds the database-specific SQL queries.
type queries struct {
selectSubscriptionIDByEndpoint string
selectSubscriptionCountBySubscriberIP string
selectSubscriptionsForTopic string
selectSubscriptionsExpiringSoon string
insertSubscription string
upsertSubscription string
updateSubscriptionWarningSent string
updateSubscriptionUpdatedAt string
deleteSubscriptionByEndpoint string
@@ -51,14 +44,8 @@ type storeQueries struct {
deleteSubscriptionTopicWithoutSubscription string
}
// commonStore implements store operations that are identical across database backends.
type commonStore struct {
db *sql.DB
queries storeQueries
}
// UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID.
func (s *commonStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error {
func (s *Store) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error {
tx, err := s.db.Begin()
if err != nil {
return err
@@ -71,8 +58,7 @@ func (s *commonStore) UpsertSubscription(endpoint string, auth, p256dh, userID s
}
// Read existing subscription ID for endpoint (or create new ID)
var subscriptionID string
err = tx.QueryRow(s.queries.selectSubscriptionIDByEndpoint, endpoint).Scan(&subscriptionID)
if errors.Is(err, sql.ErrNoRows) {
if err := tx.QueryRow(s.queries.selectSubscriptionIDByEndpoint, endpoint).Scan(&subscriptionID); errors.Is(err, sql.ErrNoRows) {
if subscriptionCount >= subscriptionEndpointLimitPerSubscriberIP {
return ErrWebPushTooManySubscriptions
}
@@ -82,7 +68,7 @@ func (s *commonStore) UpsertSubscription(endpoint string, auth, p256dh, userID s
}
// Insert or update subscription
updatedAt, warnedAt := time.Now().Unix(), 0
if _, err = tx.Exec(s.queries.insertSubscription, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil {
if _, err := tx.Exec(s.queries.upsertSubscription, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil {
return err
}
// Replace all subscription topics
@@ -98,7 +84,7 @@ func (s *commonStore) UpsertSubscription(endpoint string, auth, p256dh, userID s
}
// SubscriptionsForTopic returns all subscriptions for the given topic.
func (s *commonStore) SubscriptionsForTopic(topic string) ([]*Subscription, error) {
func (s *Store) SubscriptionsForTopic(topic string) ([]*Subscription, error) {
rows, err := s.db.Query(s.queries.selectSubscriptionsForTopic, topic)
if err != nil {
return nil, err
@@ -108,7 +94,7 @@ func (s *commonStore) SubscriptionsForTopic(topic string) ([]*Subscription, erro
}
// SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period.
func (s *commonStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscription, error) {
func (s *Store) SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscription, error) {
rows, err := s.db.Query(s.queries.selectSubscriptionsExpiringSoon, time.Now().Add(-warnAfter).Unix())
if err != nil {
return nil, err
@@ -118,7 +104,7 @@ func (s *commonStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscri
}
// MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon.
func (s *commonStore) MarkExpiryWarningSent(subscriptions []*Subscription) error {
func (s *Store) MarkExpiryWarningSent(subscriptions []*Subscription) error {
tx, err := s.db.Begin()
if err != nil {
return err
@@ -133,13 +119,13 @@ func (s *commonStore) MarkExpiryWarningSent(subscriptions []*Subscription) error
}
// RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint.
func (s *commonStore) RemoveSubscriptionsByEndpoint(endpoint string) error {
func (s *Store) RemoveSubscriptionsByEndpoint(endpoint string) error {
_, err := s.db.Exec(s.queries.deleteSubscriptionByEndpoint, endpoint)
return err
}
// RemoveSubscriptionsByUserID removes all subscriptions for the given user ID.
func (s *commonStore) RemoveSubscriptionsByUserID(userID string) error {
func (s *Store) RemoveSubscriptionsByUserID(userID string) error {
if userID == "" {
return ErrWebPushUserIDCannotBeEmpty
}
@@ -148,7 +134,7 @@ func (s *commonStore) RemoveSubscriptionsByUserID(userID string) error {
}
// RemoveExpiredSubscriptions removes all subscriptions that have not been updated for a given time period.
func (s *commonStore) RemoveExpiredSubscriptions(expireAfter time.Duration) error {
func (s *Store) RemoveExpiredSubscriptions(expireAfter time.Duration) error {
_, err := s.db.Exec(s.queries.deleteSubscriptionByAge, time.Now().Add(-expireAfter).Unix())
if err != nil {
return err
@@ -158,14 +144,14 @@ func (s *commonStore) RemoveExpiredSubscriptions(expireAfter time.Duration) erro
}
// SetSubscriptionUpdatedAt updates the updated_at timestamp for a subscription by endpoint. This is
// exported for testing purposes and is not part of the Store interface.
func (s *commonStore) SetSubscriptionUpdatedAt(endpoint string, updatedAt int64) error {
// exported for testing purposes.
func (s *Store) SetSubscriptionUpdatedAt(endpoint string, updatedAt int64) error {
_, err := s.db.Exec(s.queries.updateSubscriptionUpdatedAt, updatedAt, endpoint)
return err
}
// Close closes the underlying database connection.
func (s *commonStore) Close() error {
func (s *Store) Close() error {
return s.db.Close()
}

View File

@@ -3,12 +3,10 @@ package webpush
import (
"database/sql"
"fmt"
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
)
const (
pgCreateTablesQuery = `
postgresCreateTablesQuery = `
CREATE TABLE IF NOT EXISTS webpush_subscription (
id TEXT PRIMARY KEY,
endpoint TEXT NOT NULL UNIQUE,
@@ -20,6 +18,8 @@ const (
warned_at BIGINT NOT NULL DEFAULT 0
);
CREATE INDEX IF NOT EXISTS idx_webpush_subscriber_ip ON webpush_subscription (subscriber_ip);
CREATE INDEX IF NOT EXISTS idx_webpush_updated_at ON webpush_subscription (updated_at);
CREATE INDEX IF NOT EXISTS idx_webpush_user_id ON webpush_subscription (user_id);
CREATE TABLE IF NOT EXISTS webpush_subscription_topic (
subscription_id TEXT NOT NULL REFERENCES webpush_subscription (id) ON DELETE CASCADE,
topic TEXT NOT NULL,
@@ -32,79 +32,72 @@ const (
);
`
pgSelectSubscriptionIDByEndpoint = `SELECT id FROM webpush_subscription WHERE endpoint = $1`
pgSelectSubscriptionCountBySubscriberIP = `SELECT COUNT(*) FROM webpush_subscription WHERE subscriber_ip = $1`
pgSelectSubscriptionsForTopicQuery = `
postgresSelectSubscriptionIDByEndpointQuery = `SELECT id FROM webpush_subscription WHERE endpoint = $1`
postgresSelectSubscriptionCountBySubscriberIPQuery = `SELECT COUNT(*) FROM webpush_subscription WHERE subscriber_ip = $1`
postgresSelectSubscriptionsForTopicQuery = `
SELECT s.id, s.endpoint, s.key_auth, s.key_p256dh, s.user_id
FROM webpush_subscription_topic st
JOIN webpush_subscription s ON s.id = st.subscription_id
WHERE st.topic = $1
ORDER BY s.endpoint
`
pgSelectSubscriptionsExpiringSoonQuery = `
postgresSelectSubscriptionsExpiringSoonQuery = `
SELECT id, endpoint, key_auth, key_p256dh, user_id
FROM webpush_subscription
WHERE warned_at = 0 AND updated_at <= $1
`
pgInsertSubscriptionQuery = `
postgresUpsertSubscriptionQuery = `
INSERT INTO webpush_subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
ON CONFLICT (endpoint)
DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at
`
pgUpdateSubscriptionWarningSentQuery = `UPDATE webpush_subscription SET warned_at = $1 WHERE id = $2`
pgUpdateSubscriptionUpdatedAtQuery = `UPDATE webpush_subscription SET updated_at = $1 WHERE endpoint = $2`
pgDeleteSubscriptionByEndpointQuery = `DELETE FROM webpush_subscription WHERE endpoint = $1`
pgDeleteSubscriptionByUserIDQuery = `DELETE FROM webpush_subscription WHERE user_id = $1`
pgDeleteSubscriptionByAgeQuery = `DELETE FROM webpush_subscription WHERE updated_at <= $1`
postgresUpdateSubscriptionWarningSentQuery = `UPDATE webpush_subscription SET warned_at = $1 WHERE id = $2`
postgresUpdateSubscriptionUpdatedAtQuery = `UPDATE webpush_subscription SET updated_at = $1 WHERE endpoint = $2`
postgresDeleteSubscriptionByEndpointQuery = `DELETE FROM webpush_subscription WHERE endpoint = $1`
postgresDeleteSubscriptionByUserIDQuery = `DELETE FROM webpush_subscription WHERE user_id = $1`
postgresDeleteSubscriptionByAgeQuery = `DELETE FROM webpush_subscription WHERE updated_at <= $1`
pgInsertSubscriptionTopicQuery = `INSERT INTO webpush_subscription_topic (subscription_id, topic) VALUES ($1, $2)`
pgDeleteSubscriptionTopicAllQuery = `DELETE FROM webpush_subscription_topic WHERE subscription_id = $1`
pgDeleteSubscriptionTopicWithoutSubscription = `DELETE FROM webpush_subscription_topic WHERE subscription_id NOT IN (SELECT id FROM webpush_subscription)`
postgresInsertSubscriptionTopicQuery = `INSERT INTO webpush_subscription_topic (subscription_id, topic) VALUES ($1, $2)`
postgresDeleteSubscriptionTopicAllQuery = `DELETE FROM webpush_subscription_topic WHERE subscription_id = $1`
postgresDeleteSubscriptionTopicWithoutSubscriptionQuery = `DELETE FROM webpush_subscription_topic WHERE subscription_id NOT IN (SELECT id FROM webpush_subscription)`
)
// PostgreSQL schema management queries
const (
pgCurrentSchemaVersion = 1
pgInsertSchemaVersion = `INSERT INTO schema_version (store, version) VALUES ('webpush', $1)`
pgSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE store = 'webpush'`
pgCurrentSchemaVersion = 1
postgresInsertSchemaVersionQuery = `INSERT INTO schema_version (store, version) VALUES ('webpush', $1)`
postgresSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE store = 'webpush'`
)
// NewPostgresStore creates a new PostgreSQL-backed web push store.
func NewPostgresStore(dsn string) (Store, error) {
db, err := sql.Open("pgx", dsn)
if err != nil {
return nil, err
}
if err := db.Ping(); err != nil {
return nil, err
}
// NewPostgresStore creates a new PostgreSQL-backed web push store using an existing database connection pool.
func NewPostgresStore(db *sql.DB) (*Store, error) {
if err := setupPostgresDB(db); err != nil {
return nil, err
}
return &commonStore{
return &Store{
db: db,
queries: storeQueries{
selectSubscriptionIDByEndpoint: pgSelectSubscriptionIDByEndpoint,
selectSubscriptionCountBySubscriberIP: pgSelectSubscriptionCountBySubscriberIP,
selectSubscriptionsForTopic: pgSelectSubscriptionsForTopicQuery,
selectSubscriptionsExpiringSoon: pgSelectSubscriptionsExpiringSoonQuery,
insertSubscription: pgInsertSubscriptionQuery,
updateSubscriptionWarningSent: pgUpdateSubscriptionWarningSentQuery,
updateSubscriptionUpdatedAt: pgUpdateSubscriptionUpdatedAtQuery,
deleteSubscriptionByEndpoint: pgDeleteSubscriptionByEndpointQuery,
deleteSubscriptionByUserID: pgDeleteSubscriptionByUserIDQuery,
deleteSubscriptionByAge: pgDeleteSubscriptionByAgeQuery,
insertSubscriptionTopic: pgInsertSubscriptionTopicQuery,
deleteSubscriptionTopicAll: pgDeleteSubscriptionTopicAllQuery,
deleteSubscriptionTopicWithoutSubscription: pgDeleteSubscriptionTopicWithoutSubscription,
queries: queries{
selectSubscriptionIDByEndpoint: postgresSelectSubscriptionIDByEndpointQuery,
selectSubscriptionCountBySubscriberIP: postgresSelectSubscriptionCountBySubscriberIPQuery,
selectSubscriptionsForTopic: postgresSelectSubscriptionsForTopicQuery,
selectSubscriptionsExpiringSoon: postgresSelectSubscriptionsExpiringSoonQuery,
upsertSubscription: postgresUpsertSubscriptionQuery,
updateSubscriptionWarningSent: postgresUpdateSubscriptionWarningSentQuery,
updateSubscriptionUpdatedAt: postgresUpdateSubscriptionUpdatedAtQuery,
deleteSubscriptionByEndpoint: postgresDeleteSubscriptionByEndpointQuery,
deleteSubscriptionByUserID: postgresDeleteSubscriptionByUserIDQuery,
deleteSubscriptionByAge: postgresDeleteSubscriptionByAgeQuery,
insertSubscriptionTopic: postgresInsertSubscriptionTopicQuery,
deleteSubscriptionTopicAll: postgresDeleteSubscriptionTopicAllQuery,
deleteSubscriptionTopicWithoutSubscription: postgresDeleteSubscriptionTopicWithoutSubscriptionQuery,
},
}, nil
}
func setupPostgresDB(db *sql.DB) error {
var schemaVersion int
err := db.QueryRow(pgSelectSchemaVersionQuery).Scan(&schemaVersion)
err := db.QueryRow(postgresSelectSchemaVersionQuery).Scan(&schemaVersion)
if err != nil {
return setupNewPostgresDB(db)
}
@@ -120,10 +113,10 @@ func setupNewPostgresDB(db *sql.DB) error {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(pgCreateTablesQuery); err != nil {
if _, err := tx.Exec(postgresCreateTablesQuery); err != nil {
return err
}
if _, err := tx.Exec(pgInsertSchemaVersion, pgCurrentSchemaVersion); err != nil {
if _, err := tx.Exec(postgresInsertSchemaVersionQuery, pgCurrentSchemaVersion); err != nil {
return err
}
return tx.Commit()

View File

@@ -1,91 +0,0 @@
package webpush_test
import (
"database/sql"
"fmt"
"net/url"
"os"
"testing"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/util"
"heckel.io/ntfy/v2/webpush"
)
func newTestPostgresStore(t *testing.T) webpush.Store {
dsn := os.Getenv("NTFY_TEST_DATABASE_URL")
if dsn == "" {
t.Skip("NTFY_TEST_DATABASE_URL not set, skipping PostgreSQL tests")
}
// Create a unique schema for this test
schema := fmt.Sprintf("test_%s", util.RandomString(10))
setupDB, err := sql.Open("pgx", dsn)
require.Nil(t, err)
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
require.Nil(t, err)
require.Nil(t, setupDB.Close())
// Open store with search_path set to the new schema
u, err := url.Parse(dsn)
require.Nil(t, err)
q := u.Query()
q.Set("search_path", schema)
u.RawQuery = q.Encode()
store, err := webpush.NewPostgresStore(u.String())
require.Nil(t, err)
t.Cleanup(func() {
store.Close()
cleanDB, err := sql.Open("pgx", dsn)
if err == nil {
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
cleanDB.Close()
}
})
return store
}
func TestPostgresStoreUpsertSubscriptionSubscriptionsForTopic(t *testing.T) {
testStoreUpsertSubscriptionSubscriptionsForTopic(t, newTestPostgresStore(t))
}
func TestPostgresStoreUpsertSubscriptionSubscriberIPLimitReached(t *testing.T) {
testStoreUpsertSubscriptionSubscriberIPLimitReached(t, newTestPostgresStore(t))
}
func TestPostgresStoreUpsertSubscriptionUpdateTopics(t *testing.T) {
testStoreUpsertSubscriptionUpdateTopics(t, newTestPostgresStore(t))
}
func TestPostgresStoreUpsertSubscriptionUpdateFields(t *testing.T) {
testStoreUpsertSubscriptionUpdateFields(t, newTestPostgresStore(t))
}
func TestPostgresStoreRemoveByUserIDMultiple(t *testing.T) {
testStoreRemoveByUserIDMultiple(t, newTestPostgresStore(t))
}
func TestPostgresStoreRemoveByEndpoint(t *testing.T) {
testStoreRemoveByEndpoint(t, newTestPostgresStore(t))
}
func TestPostgresStoreRemoveByUserID(t *testing.T) {
testStoreRemoveByUserID(t, newTestPostgresStore(t))
}
func TestPostgresStoreRemoveByUserIDEmpty(t *testing.T) {
testStoreRemoveByUserIDEmpty(t, newTestPostgresStore(t))
}
func TestPostgresStoreExpiryWarningSent(t *testing.T) {
store := newTestPostgresStore(t)
testStoreExpiryWarningSent(t, store, store.SetSubscriptionUpdatedAt)
}
func TestPostgresStoreExpiring(t *testing.T) {
store := newTestPostgresStore(t)
testStoreExpiring(t, store, store.SetSubscriptionUpdatedAt)
}
func TestPostgresStoreRemoveExpired(t *testing.T) {
store := newTestPostgresStore(t)
testStoreRemoveExpired(t, store, store.SetSubscriptionUpdatedAt)
}

View File

@@ -8,8 +8,7 @@ import (
)
const (
sqliteCreateWebPushSubscriptionsTableQuery = `
BEGIN;
sqliteCreateTablesQuery = `
CREATE TABLE IF NOT EXISTS subscription (
id TEXT PRIMARY KEY,
endpoint TEXT NOT NULL,
@@ -32,53 +31,52 @@ const (
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
COMMIT;
);
`
sqliteBuiltinStartupQueries = `
PRAGMA foreign_keys = ON;
`
sqliteSelectWebPushSubscriptionIDByEndpoint = `SELECT id FROM subscription WHERE endpoint = ?`
sqliteSelectWebPushSubscriptionCountBySubscriberIP = `SELECT COUNT(*) FROM subscription WHERE subscriber_ip = ?`
sqliteSelectWebPushSubscriptionsForTopicQuery = `
sqliteSelectSubscriptionIDByEndpointQuery = `SELECT id FROM subscription WHERE endpoint = ?`
sqliteSelectSubscriptionCountBySubscriberIPQuery = `SELECT COUNT(*) FROM subscription WHERE subscriber_ip = ?`
sqliteSelectSubscriptionsForTopicQuery = `
SELECT id, endpoint, key_auth, key_p256dh, user_id
FROM subscription_topic st
JOIN subscription s ON s.id = st.subscription_id
WHERE st.topic = ?
ORDER BY endpoint
`
sqliteSelectWebPushSubscriptionsExpiringSoonQuery = `
sqliteSelectSubscriptionsExpiringSoonQuery = `
SELECT id, endpoint, key_auth, key_p256dh, user_id
FROM subscription
WHERE warned_at = 0 AND updated_at <= ?
`
sqliteInsertWebPushSubscriptionQuery = `
sqliteUpsertSubscriptionQuery = `
INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT (endpoint)
DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at
`
sqliteUpdateWebPushSubscriptionWarningSentQuery = `UPDATE subscription SET warned_at = ? WHERE id = ?`
sqliteUpdateWebPushSubscriptionUpdatedAtQuery = `UPDATE subscription SET updated_at = ? WHERE endpoint = ?`
sqliteDeleteWebPushSubscriptionByEndpointQuery = `DELETE FROM subscription WHERE endpoint = ?`
sqliteDeleteWebPushSubscriptionByUserIDQuery = `DELETE FROM subscription WHERE user_id = ?`
sqliteDeleteWebPushSubscriptionByAgeQuery = `DELETE FROM subscription WHERE updated_at <= ?` // Full table scan!
sqliteUpdateSubscriptionWarningSentQuery = `UPDATE subscription SET warned_at = ? WHERE id = ?`
sqliteUpdateSubscriptionUpdatedAtQuery = `UPDATE subscription SET updated_at = ? WHERE endpoint = ?`
sqliteDeleteSubscriptionByEndpointQuery = `DELETE FROM subscription WHERE endpoint = ?`
sqliteDeleteSubscriptionByUserIDQuery = `DELETE FROM subscription WHERE user_id = ?`
sqliteDeleteSubscriptionByAgeQuery = `DELETE FROM subscription WHERE updated_at <= ?` // Full table scan!
sqliteInsertWebPushSubscriptionTopicQuery = `INSERT INTO subscription_topic (subscription_id, topic) VALUES (?, ?)`
sqliteDeleteWebPushSubscriptionTopicAllQuery = `DELETE FROM subscription_topic WHERE subscription_id = ?`
sqliteDeleteWebPushSubscriptionTopicWithoutSubscription = `DELETE FROM subscription_topic WHERE subscription_id NOT IN (SELECT id FROM subscription)`
sqliteInsertSubscriptionTopicQuery = `INSERT INTO subscription_topic (subscription_id, topic) VALUES (?, ?)`
sqliteDeleteSubscriptionTopicAllQuery = `DELETE FROM subscription_topic WHERE subscription_id = ?`
sqliteDeleteSubscriptionTopicWithoutSubscriptionQuery = `DELETE FROM subscription_topic WHERE subscription_id NOT IN (SELECT id FROM subscription)`
)
// SQLite schema management queries
const (
sqliteCurrentWebPushSchemaVersion = 1
sqliteInsertWebPushSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
sqliteSelectWebPushSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
sqliteCurrentSchemaVersion = 1
sqliteInsertSchemaVersionQuery = `INSERT INTO schemaVersion VALUES (1, ?)`
sqliteSelectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
)
// NewSQLiteStore creates a new SQLite-backed web push store.
func NewSQLiteStore(filename, startupQueries string) (Store, error) {
func NewSQLiteStore(filename, startupQueries string) (*Store, error) {
db, err := sql.Open("sqlite3", filename)
if err != nil {
return nil, err
@@ -89,46 +87,51 @@ func NewSQLiteStore(filename, startupQueries string) (Store, error) {
if err := runSQLiteStartupQueries(db, startupQueries); err != nil {
return nil, err
}
return &commonStore{
return &Store{
db: db,
queries: storeQueries{
selectSubscriptionIDByEndpoint: sqliteSelectWebPushSubscriptionIDByEndpoint,
selectSubscriptionCountBySubscriberIP: sqliteSelectWebPushSubscriptionCountBySubscriberIP,
selectSubscriptionsForTopic: sqliteSelectWebPushSubscriptionsForTopicQuery,
selectSubscriptionsExpiringSoon: sqliteSelectWebPushSubscriptionsExpiringSoonQuery,
insertSubscription: sqliteInsertWebPushSubscriptionQuery,
updateSubscriptionWarningSent: sqliteUpdateWebPushSubscriptionWarningSentQuery,
updateSubscriptionUpdatedAt: sqliteUpdateWebPushSubscriptionUpdatedAtQuery,
deleteSubscriptionByEndpoint: sqliteDeleteWebPushSubscriptionByEndpointQuery,
deleteSubscriptionByUserID: sqliteDeleteWebPushSubscriptionByUserIDQuery,
deleteSubscriptionByAge: sqliteDeleteWebPushSubscriptionByAgeQuery,
insertSubscriptionTopic: sqliteInsertWebPushSubscriptionTopicQuery,
deleteSubscriptionTopicAll: sqliteDeleteWebPushSubscriptionTopicAllQuery,
deleteSubscriptionTopicWithoutSubscription: sqliteDeleteWebPushSubscriptionTopicWithoutSubscription,
queries: queries{
selectSubscriptionIDByEndpoint: sqliteSelectSubscriptionIDByEndpointQuery,
selectSubscriptionCountBySubscriberIP: sqliteSelectSubscriptionCountBySubscriberIPQuery,
selectSubscriptionsForTopic: sqliteSelectSubscriptionsForTopicQuery,
selectSubscriptionsExpiringSoon: sqliteSelectSubscriptionsExpiringSoonQuery,
upsertSubscription: sqliteUpsertSubscriptionQuery,
updateSubscriptionWarningSent: sqliteUpdateSubscriptionWarningSentQuery,
updateSubscriptionUpdatedAt: sqliteUpdateSubscriptionUpdatedAtQuery,
deleteSubscriptionByEndpoint: sqliteDeleteSubscriptionByEndpointQuery,
deleteSubscriptionByUserID: sqliteDeleteSubscriptionByUserIDQuery,
deleteSubscriptionByAge: sqliteDeleteSubscriptionByAgeQuery,
insertSubscriptionTopic: sqliteInsertSubscriptionTopicQuery,
deleteSubscriptionTopicAll: sqliteDeleteSubscriptionTopicAllQuery,
deleteSubscriptionTopicWithoutSubscription: sqliteDeleteSubscriptionTopicWithoutSubscriptionQuery,
},
}, nil
}
func setupSQLite(db *sql.DB) error {
var schemaVersion int
err := db.QueryRow(sqliteSelectWebPushSchemaVersionQuery).Scan(&schemaVersion)
err := db.QueryRow(sqliteSelectSchemaVersionQuery).Scan(&schemaVersion)
if err != nil {
return setupNewSQLite(db)
}
if schemaVersion > sqliteCurrentWebPushSchemaVersion {
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, sqliteCurrentWebPushSchemaVersion)
if schemaVersion > sqliteCurrentSchemaVersion {
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, sqliteCurrentSchemaVersion)
}
return nil
}
func setupNewSQLite(db *sql.DB) error {
if _, err := db.Exec(sqliteCreateWebPushSubscriptionsTableQuery); err != nil {
tx, err := db.Begin()
if err != nil {
return err
}
if _, err := db.Exec(sqliteInsertWebPushSchemaVersion, sqliteCurrentWebPushSchemaVersion); err != nil {
defer tx.Rollback()
if _, err := tx.Exec(sqliteCreateTablesQuery); err != nil {
return err
}
return nil
if _, err := tx.Exec(sqliteInsertSchemaVersionQuery, sqliteCurrentSchemaVersion); err != nil {
return err
}
return tx.Commit()
}
func runSQLiteStartupQueries(db *sql.DB, startupQueries string) error {

View File

@@ -1,63 +0,0 @@
package webpush_test
import (
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/webpush"
)
func newTestSQLiteStore(t *testing.T) webpush.Store {
store, err := webpush.NewSQLiteStore(filepath.Join(t.TempDir(), "webpush.db"), "")
require.Nil(t, err)
t.Cleanup(func() { store.Close() })
return store
}
func TestSQLiteStoreUpsertSubscriptionSubscriptionsForTopic(t *testing.T) {
testStoreUpsertSubscriptionSubscriptionsForTopic(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreUpsertSubscriptionSubscriberIPLimitReached(t *testing.T) {
testStoreUpsertSubscriptionSubscriberIPLimitReached(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreUpsertSubscriptionUpdateTopics(t *testing.T) {
testStoreUpsertSubscriptionUpdateTopics(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreUpsertSubscriptionUpdateFields(t *testing.T) {
testStoreUpsertSubscriptionUpdateFields(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreRemoveByUserIDMultiple(t *testing.T) {
testStoreRemoveByUserIDMultiple(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreRemoveByEndpoint(t *testing.T) {
testStoreRemoveByEndpoint(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreRemoveByUserID(t *testing.T) {
testStoreRemoveByUserID(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreRemoveByUserIDEmpty(t *testing.T) {
testStoreRemoveByUserIDEmpty(t, newTestSQLiteStore(t))
}
func TestSQLiteStoreExpiryWarningSent(t *testing.T) {
store := newTestSQLiteStore(t)
testStoreExpiryWarningSent(t, store, store.SetSubscriptionUpdatedAt)
}
func TestSQLiteStoreExpiring(t *testing.T) {
store := newTestSQLiteStore(t)
testStoreExpiring(t, store, store.SetSubscriptionUpdatedAt)
}
func TestSQLiteStoreRemoveExpired(t *testing.T) {
store := newTestSQLiteStore(t)
testStoreRemoveExpired(t, store, store.SetSubscriptionUpdatedAt)
}

View File

@@ -3,211 +3,250 @@ package webpush_test
import (
"fmt"
"net/netip"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/require"
dbtest "heckel.io/ntfy/v2/db/test"
"heckel.io/ntfy/v2/webpush"
)
const testWebPushEndpoint = "https://updates.push.services.mozilla.com/wpush/v1/AAABBCCCDDEEEFFF"
func testStoreUpsertSubscriptionSubscriptionsForTopic(t *testing.T, store webpush.Store) {
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
subs, err := store.SubscriptionsForTopic("test-topic")
require.Nil(t, err)
require.Len(t, subs, 1)
require.Equal(t, subs[0].Endpoint, testWebPushEndpoint)
require.Equal(t, subs[0].P256dh, "p256dh-key")
require.Equal(t, subs[0].Auth, "auth-key")
require.Equal(t, subs[0].UserID, "u_1234")
subs2, err := store.SubscriptionsForTopic("mytopic")
require.Nil(t, err)
require.Len(t, subs2, 1)
require.Equal(t, subs[0].Endpoint, subs2[0].Endpoint)
func forEachBackend(t *testing.T, f func(t *testing.T, store *webpush.Store)) {
t.Run("sqlite", func(t *testing.T) {
store, err := webpush.NewSQLiteStore(filepath.Join(t.TempDir(), "webpush.db"), "")
require.Nil(t, err)
t.Cleanup(func() { store.Close() })
f(t, store)
})
t.Run("postgres", func(t *testing.T) {
testDB := dbtest.CreateTestPostgres(t)
store, err := webpush.NewPostgresStore(testDB)
require.Nil(t, err)
f(t, store)
})
}
func testStoreUpsertSubscriptionSubscriberIPLimitReached(t *testing.T, store webpush.Store) {
// Insert 10 subscriptions with the same IP address
for i := 0; i < 10; i++ {
endpoint := fmt.Sprintf(testWebPushEndpoint+"%d", i)
require.Nil(t, store.UpsertSubscription(endpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
}
func TestStoreUpsertSubscriptionSubscriptionsForTopic(t *testing.T) {
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
// Another one for the same endpoint should be fine
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
subs, err := store.SubscriptionsForTopic("test-topic")
require.Nil(t, err)
require.Len(t, subs, 1)
require.Equal(t, subs[0].Endpoint, testWebPushEndpoint)
require.Equal(t, subs[0].P256dh, "p256dh-key")
require.Equal(t, subs[0].Auth, "auth-key")
require.Equal(t, subs[0].UserID, "u_1234")
// But with a different endpoint it should fail
require.Equal(t, webpush.ErrWebPushTooManySubscriptions, store.UpsertSubscription(testWebPushEndpoint+"11", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
// But with a different IP address it should be fine again
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"99", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("9.9.9.9"), []string{"test-topic", "mytopic"}))
subs2, err := store.SubscriptionsForTopic("mytopic")
require.Nil(t, err)
require.Len(t, subs2, 1)
require.Equal(t, subs[0].Endpoint, subs2[0].Endpoint)
})
}
func testStoreUpsertSubscriptionUpdateTopics(t *testing.T, store webpush.Store) {
// Insert subscription with two topics, and another with one topic
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"1", "auth-key", "p256dh-key", "", netip.MustParseAddr("9.9.9.9"), []string{"topic1"}))
func TestStoreUpsertSubscriptionSubscriberIPLimitReached(t *testing.T) {
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
// Insert 10 subscriptions with the same IP address
for i := 0; i < 10; i++ {
endpoint := fmt.Sprintf(testWebPushEndpoint+"%d", i)
require.Nil(t, store.UpsertSubscription(endpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
}
subs, err := store.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 2)
require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint)
require.Equal(t, testWebPushEndpoint+"1", subs[1].Endpoint)
// Another one for the same endpoint should be fine
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
subs, err = store.SubscriptionsForTopic("topic2")
require.Nil(t, err)
require.Len(t, subs, 1)
require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint)
// But with a different endpoint it should fail
require.Equal(t, webpush.ErrWebPushTooManySubscriptions, store.UpsertSubscription(testWebPushEndpoint+"11", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
// Update the first subscription to have only one topic
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
subs, err = store.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 2)
require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint)
subs, err = store.SubscriptionsForTopic("topic2")
require.Nil(t, err)
require.Len(t, subs, 0)
// But with a different IP address it should be fine again
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"99", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("9.9.9.9"), []string{"test-topic", "mytopic"}))
})
}
func testStoreUpsertSubscriptionUpdateFields(t *testing.T, store webpush.Store) {
// Insert a subscription
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
func TestStoreUpsertSubscriptionUpdateTopics(t *testing.T) {
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
// Insert subscription with two topics, and another with one topic
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"1", "auth-key", "p256dh-key", "", netip.MustParseAddr("9.9.9.9"), []string{"topic1"}))
subs, err := store.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 1)
require.Equal(t, "auth-key", subs[0].Auth)
require.Equal(t, "p256dh-key", subs[0].P256dh)
require.Equal(t, "u_1234", subs[0].UserID)
subs, err := store.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 2)
require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint)
require.Equal(t, testWebPushEndpoint+"1", subs[1].Endpoint)
// Re-upsert the same endpoint with different auth, p256dh, and userID
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "new-auth", "new-p256dh", "u_5678", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
subs, err = store.SubscriptionsForTopic("topic2")
require.Nil(t, err)
require.Len(t, subs, 1)
require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint)
subs, err = store.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 1)
require.Equal(t, testWebPushEndpoint, subs[0].Endpoint)
require.Equal(t, "new-auth", subs[0].Auth)
require.Equal(t, "new-p256dh", subs[0].P256dh)
require.Equal(t, "u_5678", subs[0].UserID)
// Update the first subscription to have only one topic
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
subs, err = store.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 2)
require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint)
subs, err = store.SubscriptionsForTopic("topic2")
require.Nil(t, err)
require.Len(t, subs, 0)
})
}
func testStoreRemoveByUserIDMultiple(t *testing.T, store webpush.Store) {
// Insert two subscriptions for u_1234 and one for u_5678
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"1", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"2", "auth-key", "p256dh-key", "u_5678", netip.MustParseAddr("9.9.9.9"), []string{"topic1"}))
func TestStoreUpsertSubscriptionUpdateFields(t *testing.T) {
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
// Insert a subscription
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
subs, err := store.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 3)
subs, err := store.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 1)
require.Equal(t, "auth-key", subs[0].Auth)
require.Equal(t, "p256dh-key", subs[0].P256dh)
require.Equal(t, "u_1234", subs[0].UserID)
// Remove all subscriptions for u_1234
require.Nil(t, store.RemoveSubscriptionsByUserID("u_1234"))
// Re-upsert the same endpoint with different auth, p256dh, and userID
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "new-auth", "new-p256dh", "u_5678", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
// Only u_5678's subscription should remain
subs, err = store.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 1)
require.Equal(t, testWebPushEndpoint+"2", subs[0].Endpoint)
require.Equal(t, "u_5678", subs[0].UserID)
subs, err = store.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 1)
require.Equal(t, testWebPushEndpoint, subs[0].Endpoint)
require.Equal(t, "new-auth", subs[0].Auth)
require.Equal(t, "new-p256dh", subs[0].P256dh)
require.Equal(t, "u_5678", subs[0].UserID)
})
}
func testStoreRemoveByEndpoint(t *testing.T, store webpush.Store) {
// Insert subscription with two topics
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
subs, err := store.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 1)
func TestStoreRemoveByUserIDMultiple(t *testing.T) {
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
// Insert two subscriptions for u_1234 and one for u_5678
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"1", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"2", "auth-key", "p256dh-key", "u_5678", netip.MustParseAddr("9.9.9.9"), []string{"topic1"}))
// And remove it again
require.Nil(t, store.RemoveSubscriptionsByEndpoint(testWebPushEndpoint))
subs, err = store.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 0)
subs, err := store.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 3)
// Remove all subscriptions for u_1234
require.Nil(t, store.RemoveSubscriptionsByUserID("u_1234"))
// Only u_5678's subscription should remain
subs, err = store.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 1)
require.Equal(t, testWebPushEndpoint+"2", subs[0].Endpoint)
require.Equal(t, "u_5678", subs[0].UserID)
})
}
func testStoreRemoveByUserID(t *testing.T, store webpush.Store) {
// Insert subscription with two topics
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
subs, err := store.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 1)
func TestStoreRemoveByEndpoint(t *testing.T) {
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
// Insert subscription with two topics
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
subs, err := store.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 1)
// And remove it again
require.Nil(t, store.RemoveSubscriptionsByUserID("u_1234"))
subs, err = store.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 0)
// And remove it again
require.Nil(t, store.RemoveSubscriptionsByEndpoint(testWebPushEndpoint))
subs, err = store.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 0)
})
}
func testStoreRemoveByUserIDEmpty(t *testing.T, store webpush.Store) {
require.Equal(t, webpush.ErrWebPushUserIDCannotBeEmpty, store.RemoveSubscriptionsByUserID(""))
func TestStoreRemoveByUserID(t *testing.T) {
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
// Insert subscription with two topics
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
subs, err := store.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 1)
// And remove it again
require.Nil(t, store.RemoveSubscriptionsByUserID("u_1234"))
subs, err = store.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 0)
})
}
func testStoreExpiryWarningSent(t *testing.T, store webpush.Store, setUpdatedAt func(endpoint string, updatedAt int64) error) {
// Insert subscription with two topics
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
// Set updated_at to the past so it shows up as expiring
require.Nil(t, setUpdatedAt(testWebPushEndpoint, time.Now().Add(-8*24*time.Hour).Unix()))
// Verify subscription appears in expiring list (warned_at == 0)
subs, err := store.SubscriptionsExpiring(7 * 24 * time.Hour)
require.Nil(t, err)
require.Len(t, subs, 1)
require.Equal(t, testWebPushEndpoint, subs[0].Endpoint)
// Mark them as warning sent
require.Nil(t, store.MarkExpiryWarningSent(subs))
// Verify subscription no longer appears in expiring list (warned_at > 0)
subs, err = store.SubscriptionsExpiring(7 * 24 * time.Hour)
require.Nil(t, err)
require.Len(t, subs, 0)
func TestStoreRemoveByUserIDEmpty(t *testing.T) {
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
require.Equal(t, webpush.ErrWebPushUserIDCannotBeEmpty, store.RemoveSubscriptionsByUserID(""))
})
}
func testStoreExpiring(t *testing.T, store webpush.Store, setUpdatedAt func(endpoint string, updatedAt int64) error) {
// Insert subscription with two topics
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
subs, err := store.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 1)
func TestStoreExpiryWarningSent(t *testing.T) {
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
// Insert subscription with two topics
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
// Fake-mark them as soon-to-expire
require.Nil(t, setUpdatedAt(testWebPushEndpoint, time.Now().Add(-8*24*time.Hour).Unix()))
// Set updated_at to the past so it shows up as expiring
require.Nil(t, store.SetSubscriptionUpdatedAt(testWebPushEndpoint, time.Now().Add(-8*24*time.Hour).Unix()))
// Should not be cleaned up yet
require.Nil(t, store.RemoveExpiredSubscriptions(9*24*time.Hour))
// Verify subscription appears in expiring list (warned_at == 0)
subs, err := store.SubscriptionsExpiring(7 * 24 * time.Hour)
require.Nil(t, err)
require.Len(t, subs, 1)
require.Equal(t, testWebPushEndpoint, subs[0].Endpoint)
// Run expiration
subs, err = store.SubscriptionsExpiring(7 * 24 * time.Hour)
require.Nil(t, err)
require.Len(t, subs, 1)
require.Equal(t, testWebPushEndpoint, subs[0].Endpoint)
// Mark them as warning sent
require.Nil(t, store.MarkExpiryWarningSent(subs))
// Verify subscription no longer appears in expiring list (warned_at > 0)
subs, err = store.SubscriptionsExpiring(7 * 24 * time.Hour)
require.Nil(t, err)
require.Len(t, subs, 0)
})
}
func testStoreRemoveExpired(t *testing.T, store webpush.Store, setUpdatedAt func(endpoint string, updatedAt int64) error) {
// Insert subscription with two topics
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
subs, err := store.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 1)
func TestStoreExpiring(t *testing.T) {
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
// Insert subscription with two topics
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
subs, err := store.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 1)
// Fake-mark them as expired
require.Nil(t, setUpdatedAt(testWebPushEndpoint, time.Now().Add(-10*24*time.Hour).Unix()))
// Fake-mark them as soon-to-expire
require.Nil(t, store.SetSubscriptionUpdatedAt(testWebPushEndpoint, time.Now().Add(-8*24*time.Hour).Unix()))
// Run expiration
require.Nil(t, store.RemoveExpiredSubscriptions(9*24*time.Hour))
// Should not be cleaned up yet
require.Nil(t, store.RemoveExpiredSubscriptions(9*24*time.Hour))
// List again, should be 0
subs, err = store.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 0)
// Run expiration
subs, err = store.SubscriptionsExpiring(7 * 24 * time.Hour)
require.Nil(t, err)
require.Len(t, subs, 1)
require.Equal(t, testWebPushEndpoint, subs[0].Endpoint)
})
}
func TestStoreRemoveExpired(t *testing.T) {
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
// Insert subscription with two topics
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
subs, err := store.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 1)
// Fake-mark them as expired
require.Nil(t, store.SetSubscriptionUpdatedAt(testWebPushEndpoint, time.Now().Add(-10*24*time.Hour).Unix()))
// Run expiration
require.Nil(t, store.RemoveExpiredSubscriptions(9*24*time.Hour))
// List again, should be 0
subs, err = store.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 0)
})
}