Compare commits
38 Commits
1364-copy-
...
postgres-w
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9e4a48b058 | ||
|
|
939b3d1117 | ||
|
|
9cc9891f49 | ||
|
|
0d1f3444f2 | ||
|
|
2716ede6e1 | ||
|
|
ae5e1fe8d8 | ||
|
|
e3a402ed95 | ||
|
|
1abc1005d0 | ||
|
|
909c3fe17b | ||
|
|
07c3e280bf | ||
|
|
b567b4e904 | ||
|
|
60fa50f0d5 | ||
|
|
ceda5ec3d8 | ||
|
|
3d72845c81 | ||
|
|
82e15d84bd | ||
|
|
4e5f95ba0c | ||
|
|
869b972a50 | ||
|
|
bdd20197b3 | ||
|
|
a8dcecdb6d | ||
|
|
5331437664 | ||
|
|
e432bf2886 | ||
|
|
0edad84d86 | ||
|
|
ddf728acd1 | ||
|
|
b1d3671dbb | ||
|
|
3e6b46ec0c | ||
|
|
b16d381626 | ||
|
|
3bd1a1ea03 | ||
|
|
7adb37b94b | ||
|
|
bc08819525 | ||
|
|
a03a37feb1 | ||
|
|
4cd556f5aa | ||
|
|
90aeb811ff | ||
|
|
c6ab380ea4 | ||
|
|
7860f2142c | ||
|
|
18d5d31bd2 | ||
|
|
cfdc364e3f | ||
|
|
763215ecfa | ||
|
|
49991d5aa7 |
16
.github/workflows/release.yaml
vendored
16
.github/workflows/release.yaml
vendored
@@ -6,6 +6,22 @@ on:
|
||||
jobs:
|
||||
release:
|
||||
runs-on: ubuntu-latest
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:17
|
||||
env:
|
||||
POSTGRES_USER: ntfy
|
||||
POSTGRES_PASSWORD: ntfy
|
||||
POSTGRES_DB: ntfy_test
|
||||
ports:
|
||||
- 5432:5432
|
||||
options: >-
|
||||
--health-cmd "pg_isready -U ntfy"
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
env:
|
||||
NTFY_TEST_DATABASE_URL: "postgres://ntfy:ntfy@localhost:5432/ntfy_test?sslmode=disable"
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
18
.github/workflows/test.yaml
vendored
18
.github/workflows/test.yaml
vendored
@@ -3,6 +3,22 @@ on: [ push, pull_request ]
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:17
|
||||
env:
|
||||
POSTGRES_USER: ntfy
|
||||
POSTGRES_PASSWORD: ntfy
|
||||
POSTGRES_DB: ntfy_test
|
||||
ports:
|
||||
- 5432:5432
|
||||
options: >-
|
||||
--health-cmd "pg_isready -U ntfy"
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
env:
|
||||
NTFY_TEST_DATABASE_URL: "postgres://ntfy:ntfy@localhost:5432/ntfy_test?sslmode=disable"
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
@@ -23,7 +39,7 @@ jobs:
|
||||
- name: Build web app (required for tests)
|
||||
run: make web
|
||||
- name: Run tests, formatting, vetting and linting
|
||||
run: make check
|
||||
run: make checkv
|
||||
- name: Run coverage
|
||||
run: make coverage
|
||||
- name: Upload coverage to codecov.io
|
||||
|
||||
@@ -40,6 +40,7 @@ ADD ./log ./log
|
||||
ADD ./server ./server
|
||||
ADD ./user ./user
|
||||
ADD ./util ./util
|
||||
ADD ./payments ./payments
|
||||
RUN make VERSION=$VERSION COMMIT=$COMMIT cli-linux-server
|
||||
|
||||
FROM alpine
|
||||
|
||||
15
Makefile
15
Makefile
@@ -1,4 +1,5 @@
|
||||
MAKEFLAGS := --jobs=1
|
||||
NPM := npm
|
||||
PYTHON := python3
|
||||
PIP := pip3
|
||||
VERSION := $(shell git describe --tag)
|
||||
@@ -137,7 +138,7 @@ web: web-deps web-build
|
||||
|
||||
web-build:
|
||||
cd web \
|
||||
&& npm run build \
|
||||
&& $(NPM) run build \
|
||||
&& mv build/index.html build/app.html \
|
||||
&& rm -rf ../server/site \
|
||||
&& mv build ../server/site \
|
||||
@@ -145,20 +146,20 @@ web-build:
|
||||
../server/site/config.js
|
||||
|
||||
web-deps:
|
||||
cd web && npm install
|
||||
cd web && $(NPM) install
|
||||
# If this fails for .svg files, optimize them with svgo
|
||||
|
||||
web-deps-update:
|
||||
cd web && npm update
|
||||
cd web && $(NPM) update
|
||||
|
||||
web-fmt:
|
||||
cd web && npm run format
|
||||
cd web && $(NPM) run format
|
||||
|
||||
web-fmt-check:
|
||||
cd web && npm run format:check
|
||||
cd web && $(NPM) run format:check
|
||||
|
||||
web-lint:
|
||||
cd web && npm run lint
|
||||
cd web && $(NPM) run lint
|
||||
|
||||
# Main server/client build
|
||||
|
||||
@@ -264,6 +265,8 @@ cli-build-results:
|
||||
|
||||
check: test web-fmt-check fmt-check vet web-lint lint staticcheck
|
||||
|
||||
checkv: testv web-fmt-check fmt-check vet web-lint lint staticcheck
|
||||
|
||||
test: .PHONY
|
||||
go test $(shell go list ./... | grep -vE 'ntfy/(test|examples|tools)')
|
||||
|
||||
|
||||
@@ -39,6 +39,7 @@ var flagsServe = append(
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "key-file", Aliases: []string{"key_file", "K"}, EnvVars: []string{"NTFY_KEY_FILE"}, Usage: "private key file, if listen-https is set"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "cert-file", Aliases: []string{"cert_file", "E"}, EnvVars: []string{"NTFY_CERT_FILE"}, Usage: "certificate file, if listen-https is set"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "firebase-key-file", Aliases: []string{"firebase_key_file", "F"}, EnvVars: []string{"NTFY_FIREBASE_KEY_FILE"}, Usage: "Firebase credentials file; if set additionally publish to FCM topic"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "database-url", Aliases: []string{"database_url"}, EnvVars: []string{"NTFY_DATABASE_URL"}, Usage: "PostgreSQL connection string for database-backed stores (e.g. postgres://user:pass@host:5432/ntfy)"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-file", Aliases: []string{"cache_file", "C"}, EnvVars: []string{"NTFY_CACHE_FILE"}, Usage: "cache file used for message caching"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-duration", Aliases: []string{"cache_duration", "b"}, EnvVars: []string{"NTFY_CACHE_DURATION"}, Value: util.FormatDuration(server.DefaultCacheDuration), Usage: "buffer messages for this time to allow `since` requests"}),
|
||||
altsrc.NewIntFlag(&cli.IntFlag{Name: "cache-batch-size", Aliases: []string{"cache_batch_size"}, EnvVars: []string{"NTFY_BATCH_SIZE"}, Usage: "max size of messages to batch together when writing to message cache (if zero, writes are synchronous)"}),
|
||||
@@ -143,6 +144,7 @@ func execServe(c *cli.Context) error {
|
||||
keyFile := c.String("key-file")
|
||||
certFile := c.String("cert-file")
|
||||
firebaseKeyFile := c.String("firebase-key-file")
|
||||
databaseURL := c.String("database-url")
|
||||
webPushPrivateKey := c.String("web-push-private-key")
|
||||
webPushPublicKey := c.String("web-push-public-key")
|
||||
webPushFile := c.String("web-push-file")
|
||||
@@ -284,8 +286,8 @@ func execServe(c *cli.Context) error {
|
||||
return errors.New("if set, FCM key file must exist")
|
||||
} else if firebaseKeyFile != "" && !server.FirebaseAvailable {
|
||||
return errors.New("cannot set firebase-key-file, support for Firebase is not available (nofirebase)")
|
||||
} else if webPushPublicKey != "" && (webPushPrivateKey == "" || webPushFile == "" || webPushEmailAddress == "" || baseURL == "") {
|
||||
return errors.New("if web push is enabled, web-push-private-key, web-push-public-key, web-push-file, web-push-email-address, and base-url should be set. run 'ntfy webpush keys' to generate keys")
|
||||
} else if webPushPublicKey != "" && (webPushPrivateKey == "" || (webPushFile == "" && databaseURL == "") || webPushEmailAddress == "" || baseURL == "") {
|
||||
return errors.New("if web push is enabled, web-push-private-key, web-push-public-key, web-push-file (or database-url), web-push-email-address, and base-url should be set. run 'ntfy webpush keys' to generate keys")
|
||||
} else if keepaliveInterval < 5*time.Second {
|
||||
return errors.New("keepalive interval cannot be lower than five seconds")
|
||||
} else if managerInterval < 5*time.Second {
|
||||
@@ -494,6 +496,7 @@ func execServe(c *cli.Context) error {
|
||||
conf.EnableMetrics = enableMetrics
|
||||
conf.MetricsListenHTTP = metricsListenHTTP
|
||||
conf.ProfileListenHTTP = profileListenHTTP
|
||||
conf.DatabaseURL = databaseURL
|
||||
conf.WebPushPrivateKey = webPushPrivateKey
|
||||
conf.WebPushPublicKey = webPushPublicKey
|
||||
conf.WebPushFile = webPushFile
|
||||
|
||||
25
cmd/user.go
25
cmd/user.go
@@ -29,6 +29,7 @@ var flagsUser = append(
|
||||
&cli.StringFlag{Name: "config", Aliases: []string{"c"}, EnvVars: []string{"NTFY_CONFIG_FILE"}, Value: server.DefaultConfigFile, DefaultText: server.DefaultConfigFile, Usage: "config file"},
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "auth-file", Aliases: []string{"auth_file", "H"}, EnvVars: []string{"NTFY_AUTH_FILE"}, Usage: "auth database file used for access control"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "auth-default-access", Aliases: []string{"auth_default_access", "p"}, EnvVars: []string{"NTFY_AUTH_DEFAULT_ACCESS"}, Value: "read-write", Usage: "default permissions if no matching entries in the auth database are found"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "database-url", Aliases: []string{"database_url"}, EnvVars: []string{"NTFY_DATABASE_URL"}, Usage: "PostgreSQL connection string for database-backed stores"}),
|
||||
)
|
||||
|
||||
var cmdUser = &cli.Command{
|
||||
@@ -365,24 +366,32 @@ func createUserManager(c *cli.Context) (*user.Manager, error) {
|
||||
authFile := c.String("auth-file")
|
||||
authStartupQueries := c.String("auth-startup-queries")
|
||||
authDefaultAccess := c.String("auth-default-access")
|
||||
if authFile == "" {
|
||||
return nil, errors.New("option auth-file not set; auth is unconfigured for this server")
|
||||
} else if !util.FileExists(authFile) {
|
||||
return nil, errors.New("auth-file does not exist; please start the server at least once to create it")
|
||||
}
|
||||
databaseURL := c.String("database-url")
|
||||
authDefault, err := user.ParsePermission(authDefaultAccess)
|
||||
if err != nil {
|
||||
return nil, errors.New("if set, auth-default-access must start set to 'read-write', 'read-only', 'write-only' or 'deny-all'")
|
||||
}
|
||||
authConfig := &user.Config{
|
||||
Filename: authFile,
|
||||
StartupQueries: authStartupQueries,
|
||||
DefaultAccess: authDefault,
|
||||
ProvisionEnabled: false, // Hack: Do not re-provision users on manager initialization
|
||||
BcryptCost: user.DefaultUserPasswordBcryptCost,
|
||||
QueueWriterInterval: user.DefaultUserStatsQueueWriterInterval,
|
||||
}
|
||||
return user.NewManager(authConfig)
|
||||
var store user.Store
|
||||
if databaseURL != "" {
|
||||
store, err = user.NewPostgresStore(databaseURL)
|
||||
} else if authFile != "" {
|
||||
if !util.FileExists(authFile) {
|
||||
return nil, errors.New("auth-file does not exist; please start the server at least once to create it")
|
||||
}
|
||||
store, err = user.NewSQLiteStore(authFile, authStartupQueries)
|
||||
} else {
|
||||
return nil, errors.New("option database-url or auth-file not set; auth is unconfigured for this server")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return user.NewManager(store, authConfig)
|
||||
}
|
||||
|
||||
func readPasswordAndConfirm(c *cli.Context) (string, error) {
|
||||
|
||||
@@ -144,6 +144,20 @@ the message to the subscribers.
|
||||
Subscribers can retrieve cached messaging using the [`poll=1` parameter](subscribe/api.md#poll-for-messages), as well as the
|
||||
[`since=` parameter](subscribe/api.md#fetch-cached-messages).
|
||||
|
||||
## PostgreSQL database
|
||||
By default, ntfy uses SQLite for all database-backed stores. As an alternative, you can configure ntfy to use PostgreSQL
|
||||
by setting the `database-url` option to a PostgreSQL connection string:
|
||||
|
||||
```yaml
|
||||
database-url: "postgres://user:pass@host:5432/ntfy"
|
||||
```
|
||||
|
||||
When `database-url` is set, ntfy will use PostgreSQL for the web push subscription store instead of SQLite. The
|
||||
`web-push-file` option is not required in this case. Support for PostgreSQL for the message cache and user manager
|
||||
will be added in future releases.
|
||||
|
||||
You can also set this via the environment variable `NTFY_DATABASE_URL` or the command line flag `--database-url`.
|
||||
|
||||
## Attachments
|
||||
If desired, you may allow users to upload and [attach files to notifications](publish.md#attachments). To enable
|
||||
this feature, you have to simply configure an attachment cache directory and a base URL (`attachment-cache-dir`, `base-url`).
|
||||
@@ -1141,12 +1155,15 @@ a database to keep track of the browser's subscriptions, and an admin email addr
|
||||
|
||||
- `web-push-public-key` is the generated VAPID public key, e.g. AA1234BBCCddvveekaabcdfqwertyuiopasdfghjklzxcvbnm1234567890
|
||||
- `web-push-private-key` is the generated VAPID private key, e.g. AA2BB1234567890abcdefzxcvbnm1234567890
|
||||
- `web-push-file` is a database file to keep track of browser subscription endpoints, e.g. `/var/cache/ntfy/webpush.db`
|
||||
- `web-push-file` is a database file to keep track of browser subscription endpoints, e.g. `/var/cache/ntfy/webpush.db` (not required if `database-url` is set)
|
||||
- `web-push-email-address` is the admin email address send to the push provider, e.g. `sysadmin@example.com`
|
||||
- `web-push-startup-queries` is an optional list of queries to run on startup`
|
||||
- `web-push-expiry-warning-duration` defines the duration after which unused subscriptions are sent a warning (default is `55d`)
|
||||
- `web-push-expiry-duration` defines the duration after which unused subscriptions will expire (default is `60d`)
|
||||
|
||||
Alternatively, you can use PostgreSQL instead of SQLite for the web push subscription store by setting `database-url`
|
||||
(see [PostgreSQL database](#postgresql-database)).
|
||||
|
||||
Limitations:
|
||||
|
||||
- Like foreground browser notifications, background push notifications require the web app to be served over HTTPS. A _valid_
|
||||
@@ -1172,9 +1189,10 @@ web-push-file: /var/cache/ntfy/webpush.db
|
||||
web-push-email-address: sysadmin@example.com
|
||||
```
|
||||
|
||||
The `web-push-file` is used to store the push subscriptions. Unused subscriptions will send out a warning after 55 days,
|
||||
and will automatically expire after 60 days (default). If the gateway returns an error (e.g. 410 Gone when a user has unsubscribed),
|
||||
subscriptions are also removed automatically.
|
||||
The `web-push-file` is used to store the push subscriptions in a local SQLite database. Alternatively, if `database-url`
|
||||
is set, subscriptions are stored in PostgreSQL and `web-push-file` is not required. Unused subscriptions will send out
|
||||
a warning after 55 days, and will automatically expire after 60 days (default). If the gateway returns an error
|
||||
(e.g. 410 Gone when a user has unsubscribed), subscriptions are also removed automatically.
|
||||
|
||||
The web app refreshes subscriptions on start and regularly on an interval, but this file should be persisted across restarts. If the subscription
|
||||
file is deleted or lost, any web apps that aren't open will not receive new web push notifications until you open then.
|
||||
@@ -1755,6 +1773,7 @@ variable before running the `ntfy` command (e.g. `export NTFY_LISTEN_HTTP=:80`).
|
||||
| `key-file` | `NTFY_KEY_FILE` | *filename* | - | HTTPS/TLS private key file, only used if `listen-https` is set. |
|
||||
| `cert-file` | `NTFY_CERT_FILE` | *filename* | - | HTTPS/TLS certificate file, only used if `listen-https` is set. |
|
||||
| `firebase-key-file` | `NTFY_FIREBASE_KEY_FILE` | *filename* | - | If set, also publish messages to a Firebase Cloud Messaging (FCM) topic for your app. This is optional and only required to save battery when using the Android app. See [Firebase (FCM)](#firebase-fcm). |
|
||||
| `database-url` | `NTFY_DATABASE_URL` | *string (connection URL)* | - | PostgreSQL connection string (e.g. `postgres://user:pass@host:5432/ntfy`). If set, uses PostgreSQL for database-backed stores instead of SQLite. Currently applies to the web push store. See [PostgreSQL database](#postgresql-database). |
|
||||
| `cache-file` | `NTFY_CACHE_FILE` | *filename* | - | If set, messages are cached in a local SQLite database instead of only in-memory. This allows for service restarts without losing messages in support of the since= parameter. See [message cache](#message-cache). |
|
||||
| `cache-duration` | `NTFY_CACHE_DURATION` | *duration* | 12h | Duration for which messages will be buffered before they are deleted. This is required to support the `since=...` and `poll=1` parameter. Set this to `0` to disable the cache entirely. |
|
||||
| `cache-startup-queries` | `NTFY_CACHE_STARTUP_QUERIES` | *string (SQL queries)* | - | SQL queries to run during database startup; this is useful for tuning and [enabling WAL mode](#message-cache) |
|
||||
|
||||
@@ -340,10 +340,6 @@ Then either follow the steps for building with or without Firebase.
|
||||
Without Firebase, you may want to still change the default `app_base_url` in [values.xml](https://github.com/binwiederhier/ntfy-android/blob/main/app/src/main/res/values/values.xml)
|
||||
if you're self-hosting the server. Then run:
|
||||
```
|
||||
# Remove Google dependencies (FCM)
|
||||
sed -i -e '/google-services/d' build.gradle
|
||||
sed -i -e '/google-services/d' app/build.gradle
|
||||
|
||||
# To build an unsigned .apk (app/build/outputs/apk/fdroid/*.apk)
|
||||
./gradlew assembleFdroidRelease
|
||||
|
||||
@@ -351,6 +347,8 @@ sed -i -e '/google-services/d' app/build.gradle
|
||||
./gradlew bundleFdroidRelease
|
||||
```
|
||||
|
||||
The F-Droid flavor automatically excludes Google Services dependencies.
|
||||
|
||||
### Build Play flavor (FCM)
|
||||
!!! info
|
||||
I do build the ntfy Android app using IntelliJ IDEA (Android Studio), so I don't know if these Gradle commands will
|
||||
|
||||
@@ -30,37 +30,37 @@ deb/rpm packages.
|
||||
|
||||
=== "x86_64/amd64"
|
||||
```bash
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_linux_amd64.tar.gz
|
||||
tar zxvf ntfy_2.16.0_linux_amd64.tar.gz
|
||||
sudo cp -a ntfy_2.16.0_linux_amd64/ntfy /usr/local/bin/ntfy
|
||||
sudo mkdir /etc/ntfy && sudo cp ntfy_2.16.0_linux_amd64/{client,server}/*.yml /etc/ntfy
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.17.0/ntfy_2.17.0_linux_amd64.tar.gz
|
||||
tar zxvf ntfy_2.17.0_linux_amd64.tar.gz
|
||||
sudo cp -a ntfy_2.17.0_linux_amd64/ntfy /usr/local/bin/ntfy
|
||||
sudo mkdir /etc/ntfy && sudo cp ntfy_2.17.0_linux_amd64/{client,server}/*.yml /etc/ntfy
|
||||
sudo ntfy serve
|
||||
```
|
||||
|
||||
=== "armv6"
|
||||
```bash
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_linux_armv6.tar.gz
|
||||
tar zxvf ntfy_2.16.0_linux_armv6.tar.gz
|
||||
sudo cp -a ntfy_2.16.0_linux_armv6/ntfy /usr/bin/ntfy
|
||||
sudo mkdir /etc/ntfy && sudo cp ntfy_2.16.0_linux_armv6/{client,server}/*.yml /etc/ntfy
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.17.0/ntfy_2.17.0_linux_armv6.tar.gz
|
||||
tar zxvf ntfy_2.17.0_linux_armv6.tar.gz
|
||||
sudo cp -a ntfy_2.17.0_linux_armv6/ntfy /usr/bin/ntfy
|
||||
sudo mkdir /etc/ntfy && sudo cp ntfy_2.17.0_linux_armv6/{client,server}/*.yml /etc/ntfy
|
||||
sudo ntfy serve
|
||||
```
|
||||
|
||||
=== "armv7/armhf"
|
||||
```bash
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_linux_armv7.tar.gz
|
||||
tar zxvf ntfy_2.16.0_linux_armv7.tar.gz
|
||||
sudo cp -a ntfy_2.16.0_linux_armv7/ntfy /usr/bin/ntfy
|
||||
sudo mkdir /etc/ntfy && sudo cp ntfy_2.16.0_linux_armv7/{client,server}/*.yml /etc/ntfy
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.17.0/ntfy_2.17.0_linux_armv7.tar.gz
|
||||
tar zxvf ntfy_2.17.0_linux_armv7.tar.gz
|
||||
sudo cp -a ntfy_2.17.0_linux_armv7/ntfy /usr/bin/ntfy
|
||||
sudo mkdir /etc/ntfy && sudo cp ntfy_2.17.0_linux_armv7/{client,server}/*.yml /etc/ntfy
|
||||
sudo ntfy serve
|
||||
```
|
||||
|
||||
=== "arm64"
|
||||
```bash
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_linux_arm64.tar.gz
|
||||
tar zxvf ntfy_2.16.0_linux_arm64.tar.gz
|
||||
sudo cp -a ntfy_2.16.0_linux_arm64/ntfy /usr/bin/ntfy
|
||||
sudo mkdir /etc/ntfy && sudo cp ntfy_2.16.0_linux_arm64/{client,server}/*.yml /etc/ntfy
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.17.0/ntfy_2.17.0_linux_arm64.tar.gz
|
||||
tar zxvf ntfy_2.17.0_linux_arm64.tar.gz
|
||||
sudo cp -a ntfy_2.17.0_linux_arm64/ntfy /usr/bin/ntfy
|
||||
sudo mkdir /etc/ntfy && sudo cp ntfy_2.17.0_linux_arm64/{client,server}/*.yml /etc/ntfy
|
||||
sudo ntfy serve
|
||||
```
|
||||
|
||||
@@ -116,7 +116,7 @@ Manually installing the .deb file:
|
||||
|
||||
=== "x86_64/amd64"
|
||||
```bash
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_linux_amd64.deb
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.17.0/ntfy_2.17.0_linux_amd64.deb
|
||||
sudo dpkg -i ntfy_*.deb
|
||||
sudo systemctl enable ntfy
|
||||
sudo systemctl start ntfy
|
||||
@@ -124,7 +124,7 @@ Manually installing the .deb file:
|
||||
|
||||
=== "armv6"
|
||||
```bash
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_linux_armv6.deb
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.17.0/ntfy_2.17.0_linux_armv6.deb
|
||||
sudo dpkg -i ntfy_*.deb
|
||||
sudo systemctl enable ntfy
|
||||
sudo systemctl start ntfy
|
||||
@@ -132,7 +132,7 @@ Manually installing the .deb file:
|
||||
|
||||
=== "armv7/armhf"
|
||||
```bash
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_linux_armv7.deb
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.17.0/ntfy_2.17.0_linux_armv7.deb
|
||||
sudo dpkg -i ntfy_*.deb
|
||||
sudo systemctl enable ntfy
|
||||
sudo systemctl start ntfy
|
||||
@@ -140,7 +140,7 @@ Manually installing the .deb file:
|
||||
|
||||
=== "arm64"
|
||||
```bash
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_linux_arm64.deb
|
||||
wget https://github.com/binwiederhier/ntfy/releases/download/v2.17.0/ntfy_2.17.0_linux_arm64.deb
|
||||
sudo dpkg -i ntfy_*.deb
|
||||
sudo systemctl enable ntfy
|
||||
sudo systemctl start ntfy
|
||||
@@ -150,28 +150,28 @@ Manually installing the .deb file:
|
||||
|
||||
=== "x86_64/amd64"
|
||||
```bash
|
||||
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_linux_amd64.rpm
|
||||
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.17.0/ntfy_2.17.0_linux_amd64.rpm
|
||||
sudo systemctl enable ntfy
|
||||
sudo systemctl start ntfy
|
||||
```
|
||||
|
||||
=== "armv6"
|
||||
```bash
|
||||
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_linux_armv6.rpm
|
||||
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.17.0/ntfy_2.17.0_linux_armv6.rpm
|
||||
sudo systemctl enable ntfy
|
||||
sudo systemctl start ntfy
|
||||
```
|
||||
|
||||
=== "armv7/armhf"
|
||||
```bash
|
||||
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_linux_armv7.rpm
|
||||
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.17.0/ntfy_2.17.0_linux_armv7.rpm
|
||||
sudo systemctl enable ntfy
|
||||
sudo systemctl start ntfy
|
||||
```
|
||||
|
||||
=== "arm64"
|
||||
```bash
|
||||
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_linux_arm64.rpm
|
||||
sudo rpm -ivh https://github.com/binwiederhier/ntfy/releases/download/v2.17.0/ntfy_2.17.0_linux_arm64.rpm
|
||||
sudo systemctl enable ntfy
|
||||
sudo systemctl start ntfy
|
||||
```
|
||||
@@ -213,18 +213,18 @@ pkg install go-ntfy
|
||||
|
||||
## macOS
|
||||
The [ntfy CLI](subscribe/cli.md) (`ntfy publish` and `ntfy subscribe` only) is supported on macOS as well.
|
||||
To install, please [download the tarball](https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_darwin_all.tar.gz),
|
||||
To install, please [download the tarball](https://github.com/binwiederhier/ntfy/releases/download/v2.17.0/ntfy_2.17.0_darwin_all.tar.gz),
|
||||
extract it and place it somewhere in your `PATH` (e.g. `/usr/local/bin/ntfy`).
|
||||
|
||||
If run as `root`, ntfy will look for its config at `/etc/ntfy/client.yml`. For all other users, it'll look for it at
|
||||
`~/Library/Application Support/ntfy/client.yml` (sample included in the tarball).
|
||||
|
||||
```bash
|
||||
curl -L https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_darwin_all.tar.gz > ntfy_2.16.0_darwin_all.tar.gz
|
||||
tar zxvf ntfy_2.16.0_darwin_all.tar.gz
|
||||
sudo cp -a ntfy_2.16.0_darwin_all/ntfy /usr/local/bin/ntfy
|
||||
curl -L https://github.com/binwiederhier/ntfy/releases/download/v2.17.0/ntfy_2.17.0_darwin_all.tar.gz > ntfy_2.17.0_darwin_all.tar.gz
|
||||
tar zxvf ntfy_2.17.0_darwin_all.tar.gz
|
||||
sudo cp -a ntfy_2.17.0_darwin_all/ntfy /usr/local/bin/ntfy
|
||||
mkdir ~/Library/Application\ Support/ntfy
|
||||
cp ntfy_2.16.0_darwin_all/client/client.yml ~/Library/Application\ Support/ntfy/client.yml
|
||||
cp ntfy_2.17.0_darwin_all/client/client.yml ~/Library/Application\ Support/ntfy/client.yml
|
||||
ntfy --help
|
||||
```
|
||||
|
||||
@@ -245,7 +245,7 @@ brew install ntfy
|
||||
The ntfy server and CLI are fully supported on Windows. You can run the ntfy server directly or as a Windows service.
|
||||
To install, you can either
|
||||
|
||||
* [Download the latest ZIP](https://github.com/binwiederhier/ntfy/releases/download/v2.16.0/ntfy_2.16.0_windows_amd64.zip),
|
||||
* [Download the latest ZIP](https://github.com/binwiederhier/ntfy/releases/download/v2.17.0/ntfy_2.17.0_windows_amd64.zip),
|
||||
extract it and place the `ntfy.exe` binary somewhere in your `%Path%`.
|
||||
* Or install ntfy from the [Scoop](https://scoop.sh) main repository via `scoop install ntfy`
|
||||
|
||||
|
||||
@@ -184,6 +184,7 @@ I've added a ⭐ to projects or posts that have a significant following, or had
|
||||
- [BRun](https://github.com/cbrake/brun) - Native Linux automation platform connecting triggers to actions without containers (Go)
|
||||
- [Uptime Monitor](https://uptime-monitor.org) - Self-hosted, enterprise-grade uptime monitoring and alerting system (TS)
|
||||
- [send_to_ntfy_extension](https://github.com/TheDuffman85/send_to_ntfy_extension/) ⭐ - A browser extension to send the notifications to ntfy (JS)
|
||||
- [SIA-Server](https://github.com/ZebMcKayhan/SIA-Server) - A light weight, self-hosted notification Server for Honywell Galaxy Flex alarm systems (Python)
|
||||
|
||||
## Blog + forum posts
|
||||
|
||||
|
||||
@@ -2304,6 +2304,11 @@ _Supported on:_ :material-android: :material-firefox:
|
||||
The `copy` action **copies a given value to the clipboard when the action button is tapped**. This is useful for
|
||||
one-time passcodes, tokens, or any other value you want to quickly copy without opening the full notification.
|
||||
|
||||
!!! info
|
||||
The copy action button is only shown in the web app and Android app notification list, **not** in browser desktop
|
||||
notifications. This is because browsers do not allow clipboard access from notification actions without direct
|
||||
user interaction with the page.
|
||||
|
||||
Here's an example using the [`X-Actions` header](#using-a-header):
|
||||
|
||||
=== "Command line (curl)"
|
||||
|
||||
@@ -6,12 +6,45 @@ and the [ntfy Android app](https://github.com/binwiederhier/ntfy-android/release
|
||||
|
||||
| Component | Version | Release date |
|
||||
|------------------|---------|--------------|
|
||||
| ntfy server | v2.16.0 | Jan 19, 2026 |
|
||||
| ntfy server | v2.17.0 | Feb 8, 2026 |
|
||||
| ntfy Android app | v1.22.2 | Jan 25, 2026 |
|
||||
| ntfy iOS app | v1.3 | Nov 26, 2023 |
|
||||
|
||||
Please check out the release notes for [upcoming releases](#not-released-yet) below.
|
||||
|
||||
## ntfy server v2.17.0
|
||||
Released February 8, 2026
|
||||
|
||||
This release adds support for templating in the priority field, a new "copy" action button to copy values to the clipboard,
|
||||
a red notification dot on the favicon for unread messages, and an admin-only version endpoint. It also includes several
|
||||
crash fixes, web app improvements, and documentation updates.
|
||||
|
||||
❤️ If you like ntfy, **please consider sponsoring me** via [GitHub Sponsors](https://github.com/sponsors/binwiederhier), [Liberapay](https://en.liberapay.com/ntfy/), Bitcoin (`1626wjrw3uWk9adyjCfYwafw4sQWujyjn8`),
|
||||
or by buying a [paid plan via the web app](https://ntfy.sh/app). ntfy will always remain open source.
|
||||
|
||||
**Features:**
|
||||
|
||||
* Server: Support templating in the priority field ([#1426](https://github.com/binwiederhier/ntfy/issues/1426), thanks to [@seantomburke](https://github.com/seantomburke) for reporting)
|
||||
* Server: Add admin-only `GET /v1/version` endpoint returning server version, build commit, and date ([#1599](https://github.com/binwiederhier/ntfy/issues/1599), thanks to [@crivchri](https://github.com/crivchri) for reporting)
|
||||
* Server/Web: [Support "copy" action](publish.md#copy-to-clipboard) button to copy a value to the clipboard ([#1364](https://github.com/binwiederhier/ntfy/issues/1364), thanks to [@SudoWatson](https://github.com/SudoWatson) for reporting)
|
||||
* Web: Show red notification dot on favicon when there are unread messages ([#1017](https://github.com/binwiederhier/ntfy/issues/1017), thanks to [@ad-si](https://github.com/ad-si) for reporting)
|
||||
|
||||
**Bug fixes + maintenance:**
|
||||
|
||||
* Server: Fix crash when commit string is shorter than 7 characters in non-GitHub-Action builds ([#1493](https://github.com/binwiederhier/ntfy/issues/1493), thanks to [@cyrinux](https://github.com/cyrinux) for reporting)
|
||||
* Server: Fix server crash (nil pointer panic) when subscriber disconnects during publish ([#1598](https://github.com/binwiederhier/ntfy/pull/1598))
|
||||
* Server: Fix log spam from `http: response.WriteHeader on hijacked connection` for WebSocket errors ([#1362](https://github.com/binwiederhier/ntfy/issues/1362), thanks to [@bonfiresh](https://github.com/bonfiresh) for reporting)
|
||||
* Server: Use `slices.Contains` from stdlib to simplify code ([#1406](https://github.com/binwiederhier/ntfy/pull/1406), thanks to [@tanhuaan](https://github.com/tanhuaan))
|
||||
* Web: Fix `clear=true` on action buttons not clearing the notification ([#1029](https://github.com/binwiederhier/ntfy/issues/1029), thanks to [@ElFishi](https://github.com/ElFishi) for reporting)
|
||||
* Web: Fix Markdown message line height to match plain text (1.5 instead of 1.2) ([#1139](https://github.com/binwiederhier/ntfy/issues/1139), thanks to [@etfz](https://github.com/etfz) for reporting)
|
||||
* Web: Fix long lines (e.g. JSON) being truncated by adding horizontal scroll ([#1363](https://github.com/binwiederhier/ntfy/issues/1363), thanks to [@v3DJG6GL](https://github.com/v3DJG6GL) for reporting)
|
||||
* Web: Fix Windows notification icon being cut off ([#884](https://github.com/binwiederhier/ntfy/issues/884), thanks to [@ZhangTianrong](https://github.com/ZhangTianrong) for reporting)
|
||||
* Web: Use full URL in curl example on empty topic pages ([#1435](https://github.com/binwiederhier/ntfy/issues/1435), [#1535](https://github.com/binwiederhier/ntfy/pull/1535), thanks to [@elmatadoor](https://github.com/elmatadoor) for reporting and [@jjasghar](https://github.com/jjasghar) for the PR)
|
||||
* Web: Add validation feedback for service URL when adding user ([#1566](https://github.com/binwiederhier/ntfy/issues/1566), thanks to [@jermanuts](https://github.com/jermanuts))
|
||||
* Docs: Remove obsolete `version` field from docker-compose examples ([#1333](https://github.com/binwiederhier/ntfy/issues/1333), thanks to [@seals187](https://github.com/seals187) for reporting and [@cyb3rko](https://github.com/cyb3rko) for fixing)
|
||||
* Docs: Fix Kustomize config in installation docs ([#1367](https://github.com/binwiederhier/ntfy/issues/1367), thanks to [@toby-griffiths](https://github.com/toby-griffiths))
|
||||
* Docs: Use SVG F-Droid badge and add app store badges to README ([#1170](https://github.com/binwiederhier/ntfy/issues/1170), thanks to [@PanderMusubi](https://github.com/PanderMusubi) for reporting)
|
||||
|
||||
## ntfy Android app v1.22.2
|
||||
Released January 20, 2026
|
||||
|
||||
@@ -1681,27 +1714,10 @@ and the [ntfy Android app](https://github.com/binwiederhier/ntfy-android/release
|
||||
* Fix crash when default server URL is missing scheme by auto-prepending `https://` ([#1582](https://github.com/binwiederhier/ntfy/issues/1582), thanks to [@hard-zero1](https://github.com/hard-zero1))
|
||||
* Fix notification timestamp to use original send time instead of receive time ([#1112](https://github.com/binwiederhier/ntfy/issues/1112), thanks to [@voruti](https://github.com/voruti) for reporting)
|
||||
* Fix notifications being missed after service restart by using persisted lastNotificationId ([#1591](https://github.com/binwiederhier/ntfy/issues/1591), thanks to @Epifeny for reporting)
|
||||
* Fix crash in settings when fragment is detached during backup/restore or log operations
|
||||
|
||||
### ntfy server v2.17.x (UNRELEASED)
|
||||
### ntfy server v2.12.x (UNRELEASED)
|
||||
|
||||
**Features:**
|
||||
|
||||
* Server: Support templating in the priority field ([#1426](https://github.com/binwiederhier/ntfy/issues/1426), thanks to [@seantomburke](https://github.com/seantomburke) for reporting)
|
||||
* Server/Web: Support "copy" action button to copy a value to the clipboard ([#1364](https://github.com/binwiederhier/ntfy/issues/1364), thanks to [@SudoWatson](https://github.com/SudoWatson) for reporting)
|
||||
* Web: Show red notification dot on favicon when there are unread messages ([#1017](https://github.com/binwiederhier/ntfy/issues/1017), thanks to [@ad-si](https://github.com/ad-si) for reporting)
|
||||
|
||||
**Bug fixes + maintenance:**
|
||||
|
||||
* Server: Fix crash when commit string is shorter than 7 characters in non-GitHub-Action builds ([#1493](https://github.com/binwiederhier/ntfy/issues/1493), thanks to [@cyrinux](https://github.com/cyrinux) for reporting)
|
||||
* Server: Fix server crash (nil pointer panic) when subscriber disconnects during publish ([#1598](https://github.com/binwiederhier/ntfy/pull/1598))
|
||||
* Server: Fix log spam from `http: response.WriteHeader on hijacked connection` for WebSocket errors ([#1362](https://github.com/binwiederhier/ntfy/issues/1362), thanks to [@bonfiresh](https://github.com/bonfiresh) for reporting)
|
||||
* Server: Use `slices.Contains` from stdlib to simplify code ([#1406](https://github.com/binwiederhier/ntfy/pull/1406), thanks to [@tanhuaan](https://github.com/tanhuaan))
|
||||
* Web: Fix `clear=true` on action buttons not clearing the notification ([#1029](https://github.com/binwiederhier/ntfy/issues/1029), thanks to [@ElFishi](https://github.com/ElFishi) for reporting)
|
||||
* Web: Fix Markdown message line height to match plain text (1.5 instead of 1.2) ([#1139](https://github.com/binwiederhier/ntfy/issues/1139), thanks to [@etfz](https://github.com/etfz) for reporting)
|
||||
* Web: Fix long lines (e.g. JSON) being truncated by adding horizontal scroll ([#1363](https://github.com/binwiederhier/ntfy/issues/1363), thanks to [@v3DJG6GL](https://github.com/v3DJG6GL) for reporting)
|
||||
* Web: Fix Windows notification icon being cut off ([#884](https://github.com/binwiederhier/ntfy/issues/884), thanks to [@ZhangTianrong](https://github.com/ZhangTianrong) for reporting)
|
||||
* Web: Use full URL in curl example on empty topic pages ([#1435](https://github.com/binwiederhier/ntfy/issues/1435), [#1535](https://github.com/binwiederhier/ntfy/pull/1535), thanks to [@elmatadoor](https://github.com/elmatadoor) for reporting and [@jjasghar](https://github.com/jjasghar) for the PR)
|
||||
* Web: Add validation feedback for service URL when adding user ([#1566](https://github.com/binwiederhier/ntfy/issues/1566), thanks to [@jermanuts](https://github.com/jermanuts))
|
||||
* Docs: Remove obsolete `version` field from docker-compose examples ([#1333](https://github.com/binwiederhier/ntfy/issues/1333), thanks to [@seals187](https://github.com/seals187) for reporting and [@cyb3rko](https://github.com/cyb3rko) for fixing)
|
||||
* Docs: Fix Kustomize config in installation docs ([#1367](https://github.com/binwiederhier/ntfy/issues/1367), thanks to [@toby-griffiths](https://github.com/toby-griffiths))
|
||||
* Docs: Use SVG F-Droid badge and add app store badges to README ([#1170](https://github.com/binwiederhier/ntfy/issues/1170), thanks to [@PanderMusubi](https://github.com/PanderMusubi) for reporting)
|
||||
* Add PostgreSQL as an alternative database backend for the web push subscription store via `database-url` config option
|
||||
|
||||
4
go.mod
4
go.mod
@@ -30,6 +30,7 @@ require github.com/pkg/errors v0.9.1 // indirect
|
||||
require (
|
||||
firebase.google.com/go/v4 v4.19.0
|
||||
github.com/SherClockHolmes/webpush-go v1.4.0
|
||||
github.com/jackc/pgx/v5 v5.8.0
|
||||
github.com/microcosm-cc/bluemonday v1.0.27
|
||||
github.com/prometheus/client_golang v1.23.2
|
||||
github.com/stripe/stripe-go/v74 v74.30.0
|
||||
@@ -71,6 +72,9 @@ require (
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.11 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.17.0 // indirect
|
||||
github.com/gorilla/css v1.0.1 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
|
||||
9
go.sum
9
go.sum
@@ -104,6 +104,14 @@ github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
|
||||
github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo=
|
||||
github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
@@ -144,6 +152,7 @@ github.com/spiffe/go-spiffe/v2 v2.6.0/go.mod h1:gm2SeUoMZEtpnzPNs2Csc0D/gX33k1xI
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
|
||||
628
message/store.go
Normal file
628
message/store.go
Normal file
@@ -0,0 +1,628 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
const (
|
||||
tagMessageCache = "message_cache"
|
||||
)
|
||||
|
||||
var errNoRows = errors.New("no rows found")
|
||||
|
||||
// Store is the interface for a message cache store
|
||||
type Store interface {
|
||||
AddMessage(m *model.Message) error
|
||||
AddMessages(ms []*model.Message) error
|
||||
DB() *sql.DB
|
||||
Message(id string) (*model.Message, error)
|
||||
MessageCounts() (map[string]int, error)
|
||||
Messages(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error)
|
||||
MessagesDue() ([]*model.Message, error)
|
||||
MessagesExpired() ([]string, error)
|
||||
MarkPublished(m *model.Message) error
|
||||
UpdateMessageTime(messageID string, timestamp int64) error
|
||||
Topics() ([]string, error)
|
||||
DeleteMessages(ids ...string) error
|
||||
DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, error)
|
||||
ExpireMessages(topics ...string) error
|
||||
AttachmentsExpired() ([]string, error)
|
||||
MarkAttachmentsDeleted(ids ...string) error
|
||||
AttachmentBytesUsedBySender(sender string) (int64, error)
|
||||
AttachmentBytesUsedByUser(userID string) (int64, error)
|
||||
UpdateStats(messages int64) error
|
||||
Stats() (int64, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
// storeQueries holds the database-specific SQL queries
|
||||
type storeQueries struct {
|
||||
insertMessage string
|
||||
deleteMessage string
|
||||
selectScheduledMessageIDsBySeqID string
|
||||
deleteScheduledBySequenceID string
|
||||
updateMessagesForTopicExpiry string
|
||||
selectRowIDFromMessageID string
|
||||
selectMessagesByID string
|
||||
selectMessagesSinceTime string
|
||||
selectMessagesSinceTimeScheduled string
|
||||
selectMessagesSinceID string
|
||||
selectMessagesSinceIDScheduled string
|
||||
selectMessagesLatest string
|
||||
selectMessagesDue string
|
||||
selectMessagesExpired string
|
||||
updateMessagePublished string
|
||||
selectMessagesCount string
|
||||
selectMessageCountPerTopic string
|
||||
selectTopics string
|
||||
updateAttachmentDeleted string
|
||||
selectAttachmentsExpired string
|
||||
selectAttachmentsSizeBySender string
|
||||
selectAttachmentsSizeByUserID string
|
||||
selectStats string
|
||||
updateStats string
|
||||
updateMessageTime string
|
||||
}
|
||||
|
||||
// commonStore implements store operations that are identical across database backends
|
||||
type commonStore struct {
|
||||
db *sql.DB
|
||||
queue *util.BatchingQueue[*model.Message]
|
||||
nop bool
|
||||
mu sync.Mutex
|
||||
queries storeQueries
|
||||
}
|
||||
|
||||
func newCommonStore(db *sql.DB, queries storeQueries, batchSize int, batchTimeout time.Duration, nop bool) *commonStore {
|
||||
var queue *util.BatchingQueue[*model.Message]
|
||||
if batchSize > 0 || batchTimeout > 0 {
|
||||
queue = util.NewBatchingQueue[*model.Message](batchSize, batchTimeout)
|
||||
}
|
||||
c := &commonStore{
|
||||
db: db,
|
||||
queue: queue,
|
||||
nop: nop,
|
||||
queries: queries,
|
||||
}
|
||||
go c.processMessageBatches()
|
||||
return c
|
||||
}
|
||||
|
||||
// DB returns the underlying database connection
|
||||
func (c *commonStore) DB() *sql.DB {
|
||||
return c.db
|
||||
}
|
||||
|
||||
// AddMessage stores a message to the message cache synchronously, or queues it to be stored at a later date asynchronously.
|
||||
func (c *commonStore) AddMessage(m *model.Message) error {
|
||||
if c.queue != nil {
|
||||
c.queue.Enqueue(m)
|
||||
return nil
|
||||
}
|
||||
return c.addMessages([]*model.Message{m})
|
||||
}
|
||||
|
||||
// AddMessages synchronously stores a batch of messages to the message cache
|
||||
func (c *commonStore) AddMessages(ms []*model.Message) error {
|
||||
return c.addMessages(ms)
|
||||
}
|
||||
|
||||
func (c *commonStore) addMessages(ms []*model.Message) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.nop {
|
||||
return nil
|
||||
}
|
||||
if len(ms) == 0 {
|
||||
return nil
|
||||
}
|
||||
start := time.Now()
|
||||
tx, err := c.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
stmt, err := tx.Prepare(c.queries.insertMessage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
for _, m := range ms {
|
||||
if m.Event != model.MessageEvent && m.Event != model.MessageDeleteEvent && m.Event != model.MessageClearEvent {
|
||||
return model.ErrUnexpectedMessageType
|
||||
}
|
||||
published := m.Time <= time.Now().Unix()
|
||||
tags := strings.Join(m.Tags, ",")
|
||||
var attachmentName, attachmentType, attachmentURL string
|
||||
var attachmentSize, attachmentExpires int64
|
||||
var attachmentDeleted bool
|
||||
if m.Attachment != nil {
|
||||
attachmentName = m.Attachment.Name
|
||||
attachmentType = m.Attachment.Type
|
||||
attachmentSize = m.Attachment.Size
|
||||
attachmentExpires = m.Attachment.Expires
|
||||
attachmentURL = m.Attachment.URL
|
||||
}
|
||||
var actionsStr string
|
||||
if len(m.Actions) > 0 {
|
||||
actionsBytes, err := json.Marshal(m.Actions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
actionsStr = string(actionsBytes)
|
||||
}
|
||||
var sender string
|
||||
if m.Sender.IsValid() {
|
||||
sender = m.Sender.String()
|
||||
}
|
||||
_, err := stmt.Exec(
|
||||
m.ID,
|
||||
m.SequenceID,
|
||||
m.Time,
|
||||
m.Event,
|
||||
m.Expires,
|
||||
m.Topic,
|
||||
m.Message,
|
||||
m.Title,
|
||||
m.Priority,
|
||||
tags,
|
||||
m.Click,
|
||||
m.Icon,
|
||||
actionsStr,
|
||||
attachmentName,
|
||||
attachmentType,
|
||||
attachmentSize,
|
||||
attachmentExpires,
|
||||
attachmentURL,
|
||||
attachmentDeleted, // Always zero
|
||||
sender,
|
||||
m.User,
|
||||
m.ContentType,
|
||||
m.Encoding,
|
||||
published,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
log.Tag(tagMessageCache).Err(err).Error("Writing %d message(s) failed (took %v)", len(ms), time.Since(start))
|
||||
return err
|
||||
}
|
||||
log.Tag(tagMessageCache).Debug("Wrote %d message(s) in %v", len(ms), time.Since(start))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *commonStore) Messages(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
|
||||
if since.IsNone() {
|
||||
return make([]*model.Message, 0), nil
|
||||
} else if since.IsLatest() {
|
||||
return c.messagesLatest(topic)
|
||||
} else if since.IsID() {
|
||||
return c.messagesSinceID(topic, since, scheduled)
|
||||
}
|
||||
return c.messagesSinceTime(topic, since, scheduled)
|
||||
}
|
||||
|
||||
func (c *commonStore) messagesSinceTime(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
if scheduled {
|
||||
rows, err = c.db.Query(c.queries.selectMessagesSinceTimeScheduled, topic, since.Time().Unix())
|
||||
} else {
|
||||
rows, err = c.db.Query(c.queries.selectMessagesSinceTime, topic, since.Time().Unix())
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return readMessages(rows)
|
||||
}
|
||||
|
||||
func (c *commonStore) messagesSinceID(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
|
||||
idrows, err := c.db.Query(c.queries.selectRowIDFromMessageID, since.ID())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer idrows.Close()
|
||||
if !idrows.Next() {
|
||||
return c.messagesSinceTime(topic, model.SinceAllMessages, scheduled)
|
||||
}
|
||||
var rowID int64
|
||||
if err := idrows.Scan(&rowID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
idrows.Close()
|
||||
var rows *sql.Rows
|
||||
if scheduled {
|
||||
rows, err = c.db.Query(c.queries.selectMessagesSinceIDScheduled, topic, rowID)
|
||||
} else {
|
||||
rows, err = c.db.Query(c.queries.selectMessagesSinceID, topic, rowID)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return readMessages(rows)
|
||||
}
|
||||
|
||||
func (c *commonStore) messagesLatest(topic string) ([]*model.Message, error) {
|
||||
rows, err := c.db.Query(c.queries.selectMessagesLatest, topic)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return readMessages(rows)
|
||||
}
|
||||
|
||||
func (c *commonStore) MessagesDue() ([]*model.Message, error) {
|
||||
rows, err := c.db.Query(c.queries.selectMessagesDue, time.Now().Unix())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return readMessages(rows)
|
||||
}
|
||||
|
||||
// MessagesExpired returns a list of IDs for messages that have expired (should be deleted)
|
||||
func (c *commonStore) MessagesExpired() ([]string, error) {
|
||||
rows, err := c.db.Query(c.queries.selectMessagesExpired, time.Now().Unix())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
ids := make([]string, 0)
|
||||
for rows.Next() {
|
||||
var id string
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (c *commonStore) Message(id string) (*model.Message, error) {
|
||||
rows, err := c.db.Query(c.queries.selectMessagesByID, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !rows.Next() {
|
||||
return nil, model.ErrMessageNotFound
|
||||
}
|
||||
defer rows.Close()
|
||||
return readMessage(rows)
|
||||
}
|
||||
|
||||
// UpdateMessageTime updates the time column for a message by ID. This is only used for testing.
|
||||
func (c *commonStore) UpdateMessageTime(messageID string, timestamp int64) error {
|
||||
_, err := c.db.Exec(c.queries.updateMessageTime, timestamp, messageID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *commonStore) MarkPublished(m *model.Message) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
_, err := c.db.Exec(c.queries.updateMessagePublished, m.ID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *commonStore) MessageCounts() (map[string]int, error) {
|
||||
rows, err := c.db.Query(c.queries.selectMessageCountPerTopic)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var topic string
|
||||
var count int
|
||||
counts := make(map[string]int)
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(&topic, &count); err != nil {
|
||||
return nil, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts[topic] = count
|
||||
}
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
func (c *commonStore) Topics() ([]string, error) {
|
||||
rows, err := c.db.Query(c.queries.selectTopics)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
topics := make([]string, 0)
|
||||
for rows.Next() {
|
||||
var id string
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
topics = append(topics, id)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return topics, nil
|
||||
}
|
||||
|
||||
func (c *commonStore) DeleteMessages(ids ...string) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
tx, err := c.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
for _, id := range ids {
|
||||
if _, err := tx.Exec(c.queries.deleteMessage, id); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// DeleteScheduledBySequenceID deletes unpublished (scheduled) messages with the given topic and sequence ID.
|
||||
// It returns the message IDs of the deleted messages, which can be used to clean up attachment files.
|
||||
func (c *commonStore) DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
tx, err := c.db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
rows, err := tx.Query(c.queries.selectScheduledMessageIDsBySeqID, topic, sequenceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
ids := make([]string, 0)
|
||||
for rows.Next() {
|
||||
var id string
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rows.Close()
|
||||
if _, err := tx.Exec(c.queries.deleteScheduledBySequenceID, topic, sequenceID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (c *commonStore) ExpireMessages(topics ...string) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
tx, err := c.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
for _, t := range topics {
|
||||
if _, err := tx.Exec(c.queries.updateMessagesForTopicExpiry, time.Now().Unix()-1, t); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (c *commonStore) AttachmentsExpired() ([]string, error) {
|
||||
rows, err := c.db.Query(c.queries.selectAttachmentsExpired, time.Now().Unix())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
ids := make([]string, 0)
|
||||
for rows.Next() {
|
||||
var id string
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (c *commonStore) MarkAttachmentsDeleted(ids ...string) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
tx, err := c.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
for _, id := range ids {
|
||||
if _, err := tx.Exec(c.queries.updateAttachmentDeleted, id); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (c *commonStore) AttachmentBytesUsedBySender(sender string) (int64, error) {
|
||||
rows, err := c.db.Query(c.queries.selectAttachmentsSizeBySender, sender, time.Now().Unix())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return c.readAttachmentBytesUsed(rows)
|
||||
}
|
||||
|
||||
func (c *commonStore) AttachmentBytesUsedByUser(userID string) (int64, error) {
|
||||
rows, err := c.db.Query(c.queries.selectAttachmentsSizeByUserID, userID, time.Now().Unix())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return c.readAttachmentBytesUsed(rows)
|
||||
}
|
||||
|
||||
func (c *commonStore) readAttachmentBytesUsed(rows *sql.Rows) (int64, error) {
|
||||
defer rows.Close()
|
||||
var size int64
|
||||
if !rows.Next() {
|
||||
return 0, errors.New("no rows found")
|
||||
}
|
||||
if err := rows.Scan(&size); err != nil {
|
||||
return 0, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return size, nil
|
||||
}
|
||||
|
||||
func (c *commonStore) UpdateStats(messages int64) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
_, err := c.db.Exec(c.queries.updateStats, messages)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *commonStore) Stats() (messages int64, err error) {
|
||||
rows, err := c.db.Query(c.queries.selectStats)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return 0, errNoRows
|
||||
}
|
||||
if err := rows.Scan(&messages); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func (c *commonStore) Close() error {
|
||||
return c.db.Close()
|
||||
}
|
||||
|
||||
func (c *commonStore) processMessageBatches() {
|
||||
if c.queue == nil {
|
||||
return
|
||||
}
|
||||
for messages := range c.queue.Dequeue() {
|
||||
if err := c.addMessages(messages); err != nil {
|
||||
log.Tag(tagMessageCache).Err(err).Error("Cannot write message batch")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func readMessages(rows *sql.Rows) ([]*model.Message, error) {
|
||||
defer rows.Close()
|
||||
messages := make([]*model.Message, 0)
|
||||
for rows.Next() {
|
||||
m, err := readMessage(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
messages = append(messages, m)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func readMessage(rows *sql.Rows) (*model.Message, error) {
|
||||
var timestamp, expires, attachmentSize, attachmentExpires int64
|
||||
var priority int
|
||||
var id, sequenceID, event, topic, msg, title, tagsStr, click, icon, actionsStr, attachmentName, attachmentType, attachmentURL, sender, user, contentType, encoding string
|
||||
err := rows.Scan(
|
||||
&id,
|
||||
&sequenceID,
|
||||
×tamp,
|
||||
&event,
|
||||
&expires,
|
||||
&topic,
|
||||
&msg,
|
||||
&title,
|
||||
&priority,
|
||||
&tagsStr,
|
||||
&click,
|
||||
&icon,
|
||||
&actionsStr,
|
||||
&attachmentName,
|
||||
&attachmentType,
|
||||
&attachmentSize,
|
||||
&attachmentExpires,
|
||||
&attachmentURL,
|
||||
&sender,
|
||||
&user,
|
||||
&contentType,
|
||||
&encoding,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var tags []string
|
||||
if tagsStr != "" {
|
||||
tags = strings.Split(tagsStr, ",")
|
||||
}
|
||||
var actions []*model.Action
|
||||
if actionsStr != "" {
|
||||
if err := json.Unmarshal([]byte(actionsStr), &actions); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
senderIP, err := netip.ParseAddr(sender)
|
||||
if err != nil {
|
||||
senderIP = netip.Addr{} // if no IP stored in database, return invalid address
|
||||
}
|
||||
var att *model.Attachment
|
||||
if attachmentName != "" && attachmentURL != "" {
|
||||
att = &model.Attachment{
|
||||
Name: attachmentName,
|
||||
Type: attachmentType,
|
||||
Size: attachmentSize,
|
||||
Expires: attachmentExpires,
|
||||
URL: attachmentURL,
|
||||
}
|
||||
}
|
||||
return &model.Message{
|
||||
ID: id,
|
||||
SequenceID: sequenceID,
|
||||
Time: timestamp,
|
||||
Expires: expires,
|
||||
Event: event,
|
||||
Topic: topic,
|
||||
Message: msg,
|
||||
Title: title,
|
||||
Priority: priority,
|
||||
Tags: tags,
|
||||
Click: click,
|
||||
Icon: icon,
|
||||
Actions: actions,
|
||||
Attachment: att,
|
||||
Sender: senderIP,
|
||||
User: user,
|
||||
ContentType: contentType,
|
||||
Encoding: encoding,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Ensure commonStore implements Store
|
||||
var _ Store = (*commonStore)(nil)
|
||||
|
||||
// Needed by store.go but not part of Store interface; unused import guard
|
||||
var _ = fmt.Sprintf
|
||||
120
message/store_postgres.go
Normal file
120
message/store_postgres.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
|
||||
)
|
||||
|
||||
// PostgreSQL runtime query constants
|
||||
const (
|
||||
pgInsertMessageQuery = `
|
||||
INSERT INTO message (mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user_id, content_type, encoding, published)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24)
|
||||
`
|
||||
pgDeleteMessageQuery = `DELETE FROM message WHERE mid = $1`
|
||||
pgSelectScheduledMessageIDsBySeqIDQuery = `SELECT mid FROM message WHERE topic = $1 AND sequence_id = $2 AND published = FALSE`
|
||||
pgDeleteScheduledBySequenceIDQuery = `DELETE FROM message WHERE topic = $1 AND sequence_id = $2 AND published = FALSE`
|
||||
pgUpdateMessagesForTopicExpiryQuery = `UPDATE message SET expires = $1 WHERE topic = $2`
|
||||
pgSelectRowIDFromMessageID = `SELECT id FROM message WHERE mid = $1`
|
||||
pgSelectMessagesByIDQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
||||
FROM message
|
||||
WHERE mid = $1
|
||||
`
|
||||
pgSelectMessagesSinceTimeQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
||||
FROM message
|
||||
WHERE topic = $1 AND time >= $2 AND published = TRUE
|
||||
ORDER BY time, id
|
||||
`
|
||||
pgSelectMessagesSinceTimeIncludeScheduledQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
||||
FROM message
|
||||
WHERE topic = $1 AND time >= $2
|
||||
ORDER BY time, id
|
||||
`
|
||||
pgSelectMessagesSinceIDQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
||||
FROM message
|
||||
WHERE topic = $1 AND id > $2 AND published = TRUE
|
||||
ORDER BY time, id
|
||||
`
|
||||
pgSelectMessagesSinceIDIncludeScheduledQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
||||
FROM message
|
||||
WHERE topic = $1 AND (id > $2 OR published = FALSE)
|
||||
ORDER BY time, id
|
||||
`
|
||||
pgSelectMessagesLatestQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
||||
FROM message
|
||||
WHERE topic = $1 AND published = TRUE
|
||||
ORDER BY time DESC, id DESC
|
||||
LIMIT 1
|
||||
`
|
||||
pgSelectMessagesDueQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
||||
FROM message
|
||||
WHERE time <= $1 AND published = FALSE
|
||||
ORDER BY time, id
|
||||
`
|
||||
pgSelectMessagesExpiredQuery = `SELECT mid FROM message WHERE expires <= $1 AND published = TRUE`
|
||||
pgUpdateMessagePublishedQuery = `UPDATE message SET published = TRUE WHERE mid = $1`
|
||||
pgSelectMessagesCountQuery = `SELECT COUNT(*) FROM message`
|
||||
pgSelectMessageCountPerTopicQuery = `SELECT topic, COUNT(*) FROM message GROUP BY topic`
|
||||
pgSelectTopicsQuery = `SELECT topic FROM message GROUP BY topic`
|
||||
|
||||
pgUpdateAttachmentDeleted = `UPDATE message SET attachment_deleted = TRUE WHERE mid = $1`
|
||||
pgSelectAttachmentsExpiredQuery = `SELECT mid FROM message WHERE attachment_expires > 0 AND attachment_expires <= $1 AND attachment_deleted = FALSE`
|
||||
pgSelectAttachmentsSizeBySenderQuery = `SELECT COALESCE(SUM(attachment_size), 0) FROM message WHERE user_id = '' AND sender = $1 AND attachment_expires >= $2`
|
||||
pgSelectAttachmentsSizeByUserIDQuery = `SELECT COALESCE(SUM(attachment_size), 0) FROM message WHERE user_id = $1 AND attachment_expires >= $2`
|
||||
|
||||
pgSelectStatsQuery = `SELECT value FROM message_stats WHERE key = 'messages'`
|
||||
pgUpdateStatsQuery = `UPDATE message_stats SET value = $1 WHERE key = 'messages'`
|
||||
pgUpdateMessageTimesQuery = `UPDATE message SET time = $1 WHERE mid = $2`
|
||||
)
|
||||
|
||||
var pgQueries = storeQueries{
|
||||
insertMessage: pgInsertMessageQuery,
|
||||
deleteMessage: pgDeleteMessageQuery,
|
||||
selectScheduledMessageIDsBySeqID: pgSelectScheduledMessageIDsBySeqIDQuery,
|
||||
deleteScheduledBySequenceID: pgDeleteScheduledBySequenceIDQuery,
|
||||
updateMessagesForTopicExpiry: pgUpdateMessagesForTopicExpiryQuery,
|
||||
selectRowIDFromMessageID: pgSelectRowIDFromMessageID,
|
||||
selectMessagesByID: pgSelectMessagesByIDQuery,
|
||||
selectMessagesSinceTime: pgSelectMessagesSinceTimeQuery,
|
||||
selectMessagesSinceTimeScheduled: pgSelectMessagesSinceTimeIncludeScheduledQuery,
|
||||
selectMessagesSinceID: pgSelectMessagesSinceIDQuery,
|
||||
selectMessagesSinceIDScheduled: pgSelectMessagesSinceIDIncludeScheduledQuery,
|
||||
selectMessagesLatest: pgSelectMessagesLatestQuery,
|
||||
selectMessagesDue: pgSelectMessagesDueQuery,
|
||||
selectMessagesExpired: pgSelectMessagesExpiredQuery,
|
||||
updateMessagePublished: pgUpdateMessagePublishedQuery,
|
||||
selectMessagesCount: pgSelectMessagesCountQuery,
|
||||
selectMessageCountPerTopic: pgSelectMessageCountPerTopicQuery,
|
||||
selectTopics: pgSelectTopicsQuery,
|
||||
updateAttachmentDeleted: pgUpdateAttachmentDeleted,
|
||||
selectAttachmentsExpired: pgSelectAttachmentsExpiredQuery,
|
||||
selectAttachmentsSizeBySender: pgSelectAttachmentsSizeBySenderQuery,
|
||||
selectAttachmentsSizeByUserID: pgSelectAttachmentsSizeByUserIDQuery,
|
||||
selectStats: pgSelectStatsQuery,
|
||||
updateStats: pgUpdateStatsQuery,
|
||||
updateMessageTime: pgUpdateMessageTimesQuery,
|
||||
}
|
||||
|
||||
// NewPostgresStore creates a new PostgreSQL-backed message cache store.
|
||||
func NewPostgresStore(dsn string, batchSize int, batchTimeout time.Duration) (Store, error) {
|
||||
db, err := sql.Open("pgx", dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := db.Ping(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := setupPostgresDB(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newCommonStore(db, pgQueries, batchSize, batchTimeout, false), nil
|
||||
}
|
||||
90
message/store_postgres_schema.go
Normal file
90
message/store_postgres_schema.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Initial PostgreSQL schema
|
||||
const (
|
||||
pgCreateTablesQuery = `
|
||||
CREATE TABLE IF NOT EXISTS message (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
mid TEXT NOT NULL,
|
||||
sequence_id TEXT NOT NULL,
|
||||
time BIGINT NOT NULL,
|
||||
event TEXT NOT NULL,
|
||||
expires BIGINT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
title TEXT NOT NULL,
|
||||
priority INT NOT NULL,
|
||||
tags TEXT NOT NULL,
|
||||
click TEXT NOT NULL,
|
||||
icon TEXT NOT NULL,
|
||||
actions TEXT NOT NULL,
|
||||
attachment_name TEXT NOT NULL,
|
||||
attachment_type TEXT NOT NULL,
|
||||
attachment_size BIGINT NOT NULL,
|
||||
attachment_expires BIGINT NOT NULL,
|
||||
attachment_url TEXT NOT NULL,
|
||||
attachment_deleted BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
sender TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
content_type TEXT NOT NULL,
|
||||
encoding TEXT NOT NULL,
|
||||
published BOOLEAN NOT NULL DEFAULT FALSE
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_message_mid ON message (mid);
|
||||
CREATE INDEX IF NOT EXISTS idx_message_sequence_id ON message (sequence_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_message_time ON message (time);
|
||||
CREATE INDEX IF NOT EXISTS idx_message_topic ON message (topic);
|
||||
CREATE INDEX IF NOT EXISTS idx_message_expires ON message (expires);
|
||||
CREATE INDEX IF NOT EXISTS idx_message_sender ON message (sender);
|
||||
CREATE INDEX IF NOT EXISTS idx_message_user_id ON message (user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_message_attachment_expires ON message (attachment_expires);
|
||||
CREATE TABLE IF NOT EXISTS message_stats (
|
||||
key TEXT PRIMARY KEY,
|
||||
value BIGINT
|
||||
);
|
||||
INSERT INTO message_stats (key, value) VALUES ('messages', 0);
|
||||
CREATE TABLE IF NOT EXISTS schema_version (
|
||||
store TEXT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
`
|
||||
)
|
||||
|
||||
// PostgreSQL schema management queries
|
||||
const (
|
||||
pgCurrentSchemaVersion = 14
|
||||
pgInsertSchemaVersion = `INSERT INTO schema_version (store, version) VALUES ('message', $1)`
|
||||
pgSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE store = 'message'`
|
||||
)
|
||||
|
||||
func setupPostgresDB(db *sql.DB) error {
|
||||
var schemaVersion int
|
||||
err := db.QueryRow(pgSelectSchemaVersionQuery).Scan(&schemaVersion)
|
||||
if err != nil {
|
||||
return setupNewPostgresDB(db)
|
||||
}
|
||||
if schemaVersion > pgCurrentSchemaVersion {
|
||||
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, pgCurrentSchemaVersion)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupNewPostgresDB(db *sql.DB) error {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(pgCreateTablesQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(pgInsertSchemaVersion, pgCurrentSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
120
message/store_postgres_test.go
Normal file
120
message/store_postgres_test.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package message_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/message"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
func newTestPostgresStore(t *testing.T) message.Store {
|
||||
dsn := os.Getenv("NTFY_TEST_DATABASE_URL")
|
||||
if dsn == "" {
|
||||
t.Skip("NTFY_TEST_DATABASE_URL not set, skipping PostgreSQL tests")
|
||||
}
|
||||
// Create a unique schema for this test
|
||||
schema := fmt.Sprintf("test_%s", util.RandomString(10))
|
||||
setupDB, err := sql.Open("pgx", dsn)
|
||||
require.Nil(t, err)
|
||||
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, setupDB.Close())
|
||||
// Open store with search_path set to the new schema
|
||||
u, err := url.Parse(dsn)
|
||||
require.Nil(t, err)
|
||||
q := u.Query()
|
||||
q.Set("search_path", schema)
|
||||
u.RawQuery = q.Encode()
|
||||
store, err := message.NewPostgresStore(u.String(), 0, 0)
|
||||
require.Nil(t, err)
|
||||
t.Cleanup(func() {
|
||||
store.Close()
|
||||
cleanDB, err := sql.Open("pgx", dsn)
|
||||
if err == nil {
|
||||
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
|
||||
cleanDB.Close()
|
||||
}
|
||||
})
|
||||
return store
|
||||
}
|
||||
|
||||
func TestPostgresStore_Messages(t *testing.T) {
|
||||
testCacheMessages(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStore_MessagesLock(t *testing.T) {
|
||||
testCacheMessagesLock(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStore_MessagesScheduled(t *testing.T) {
|
||||
testCacheMessagesScheduled(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStore_Topics(t *testing.T) {
|
||||
testCacheTopics(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStore_MessagesTagsPrioAndTitle(t *testing.T) {
|
||||
testCacheMessagesTagsPrioAndTitle(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStore_MessagesSinceID(t *testing.T) {
|
||||
testCacheMessagesSinceID(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStore_Prune(t *testing.T) {
|
||||
testCachePrune(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStore_Attachments(t *testing.T) {
|
||||
testCacheAttachments(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStore_AttachmentsExpired(t *testing.T) {
|
||||
testCacheAttachmentsExpired(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStore_Sender(t *testing.T) {
|
||||
testSender(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStore_DeleteScheduledBySequenceID(t *testing.T) {
|
||||
testDeleteScheduledBySequenceID(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStore_MessageByID(t *testing.T) {
|
||||
testMessageByID(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStore_MarkPublished(t *testing.T) {
|
||||
testMarkPublished(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStore_ExpireMessages(t *testing.T) {
|
||||
testExpireMessages(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStore_MarkAttachmentsDeleted(t *testing.T) {
|
||||
testMarkAttachmentsDeleted(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStore_Stats(t *testing.T) {
|
||||
testStats(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStore_AddMessages(t *testing.T) {
|
||||
testAddMessages(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStore_MessagesDue(t *testing.T) {
|
||||
testMessagesDue(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStore_MessageFieldRoundTrip(t *testing.T) {
|
||||
testMessageFieldRoundTrip(t, newTestPostgresStore(t))
|
||||
}
|
||||
140
message/store_sqlite.go
Normal file
140
message/store_sqlite.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
// SQLite runtime query constants
|
||||
const (
|
||||
sqliteInsertMessageQuery = `
|
||||
INSERT INTO messages (mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user, content_type, encoding, published)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
sqliteDeleteMessageQuery = `DELETE FROM messages WHERE mid = ?`
|
||||
sqliteSelectScheduledMessageIDsBySeqIDQuery = `SELECT mid FROM messages WHERE topic = ? AND sequence_id = ? AND published = 0`
|
||||
sqliteDeleteScheduledBySequenceIDQuery = `DELETE FROM messages WHERE topic = ? AND sequence_id = ? AND published = 0`
|
||||
sqliteUpdateMessagesForTopicExpiryQuery = `UPDATE messages SET expires = ? WHERE topic = ?`
|
||||
sqliteSelectRowIDFromMessageID = `SELECT id FROM messages WHERE mid = ?`
|
||||
sqliteSelectMessagesByIDQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE mid = ?
|
||||
`
|
||||
sqliteSelectMessagesSinceTimeQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE topic = ? AND time >= ? AND published = 1
|
||||
ORDER BY time, id
|
||||
`
|
||||
sqliteSelectMessagesSinceTimeIncludeScheduledQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE topic = ? AND time >= ?
|
||||
ORDER BY time, id
|
||||
`
|
||||
sqliteSelectMessagesSinceIDQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE topic = ? AND id > ? AND published = 1
|
||||
ORDER BY time, id
|
||||
`
|
||||
sqliteSelectMessagesSinceIDIncludeScheduledQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE topic = ? AND (id > ? OR published = 0)
|
||||
ORDER BY time, id
|
||||
`
|
||||
sqliteSelectMessagesLatestQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE topic = ? AND published = 1
|
||||
ORDER BY time DESC, id DESC
|
||||
LIMIT 1
|
||||
`
|
||||
sqliteSelectMessagesDueQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE time <= ? AND published = 0
|
||||
ORDER BY time, id
|
||||
`
|
||||
sqliteSelectMessagesExpiredQuery = `SELECT mid FROM messages WHERE expires <= ? AND published = 1`
|
||||
sqliteUpdateMessagePublishedQuery = `UPDATE messages SET published = 1 WHERE mid = ?`
|
||||
sqliteSelectMessagesCountQuery = `SELECT COUNT(*) FROM messages`
|
||||
sqliteSelectMessageCountPerTopicQuery = `SELECT topic, COUNT(*) FROM messages GROUP BY topic`
|
||||
sqliteSelectTopicsQuery = `SELECT topic FROM messages GROUP BY topic`
|
||||
|
||||
sqliteUpdateAttachmentDeleted = `UPDATE messages SET attachment_deleted = 1 WHERE mid = ?`
|
||||
sqliteSelectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires <= ? AND attachment_deleted = 0`
|
||||
sqliteSelectAttachmentsSizeBySenderQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = '' AND sender = ? AND attachment_expires >= ?`
|
||||
sqliteSelectAttachmentsSizeByUserIDQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?`
|
||||
|
||||
sqliteSelectStatsQuery = `SELECT value FROM stats WHERE key = 'messages'`
|
||||
sqliteUpdateStatsQuery = `UPDATE stats SET value = ? WHERE key = 'messages'`
|
||||
sqliteUpdateMessageTimeQuery = `UPDATE messages SET time = ? WHERE mid = ?`
|
||||
)
|
||||
|
||||
var sqliteQueries = storeQueries{
|
||||
insertMessage: sqliteInsertMessageQuery,
|
||||
deleteMessage: sqliteDeleteMessageQuery,
|
||||
selectScheduledMessageIDsBySeqID: sqliteSelectScheduledMessageIDsBySeqIDQuery,
|
||||
deleteScheduledBySequenceID: sqliteDeleteScheduledBySequenceIDQuery,
|
||||
updateMessagesForTopicExpiry: sqliteUpdateMessagesForTopicExpiryQuery,
|
||||
selectRowIDFromMessageID: sqliteSelectRowIDFromMessageID,
|
||||
selectMessagesByID: sqliteSelectMessagesByIDQuery,
|
||||
selectMessagesSinceTime: sqliteSelectMessagesSinceTimeQuery,
|
||||
selectMessagesSinceTimeScheduled: sqliteSelectMessagesSinceTimeIncludeScheduledQuery,
|
||||
selectMessagesSinceID: sqliteSelectMessagesSinceIDQuery,
|
||||
selectMessagesSinceIDScheduled: sqliteSelectMessagesSinceIDIncludeScheduledQuery,
|
||||
selectMessagesLatest: sqliteSelectMessagesLatestQuery,
|
||||
selectMessagesDue: sqliteSelectMessagesDueQuery,
|
||||
selectMessagesExpired: sqliteSelectMessagesExpiredQuery,
|
||||
updateMessagePublished: sqliteUpdateMessagePublishedQuery,
|
||||
selectMessagesCount: sqliteSelectMessagesCountQuery,
|
||||
selectMessageCountPerTopic: sqliteSelectMessageCountPerTopicQuery,
|
||||
selectTopics: sqliteSelectTopicsQuery,
|
||||
updateAttachmentDeleted: sqliteUpdateAttachmentDeleted,
|
||||
selectAttachmentsExpired: sqliteSelectAttachmentsExpiredQuery,
|
||||
selectAttachmentsSizeBySender: sqliteSelectAttachmentsSizeBySenderQuery,
|
||||
selectAttachmentsSizeByUserID: sqliteSelectAttachmentsSizeByUserIDQuery,
|
||||
selectStats: sqliteSelectStatsQuery,
|
||||
updateStats: sqliteUpdateStatsQuery,
|
||||
updateMessageTime: sqliteUpdateMessageTimeQuery,
|
||||
}
|
||||
|
||||
// NewSQLiteStore creates a SQLite file-backed cache
|
||||
func NewSQLiteStore(filename, startupQueries string, cacheDuration time.Duration, batchSize int, batchTimeout time.Duration, nop bool) (Store, error) {
|
||||
parentDir := filepath.Dir(filename)
|
||||
if !util.FileExists(parentDir) {
|
||||
return nil, fmt.Errorf("cache database directory %s does not exist or is not accessible", parentDir)
|
||||
}
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := setupSQLite(db, startupQueries, cacheDuration); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newCommonStore(db, sqliteQueries, batchSize, batchTimeout, nop), nil
|
||||
}
|
||||
|
||||
// NewMemStore creates an in-memory cache
|
||||
func NewMemStore() (Store, error) {
|
||||
return NewSQLiteStore(createMemoryFilename(), "", 0, 0, 0, false)
|
||||
}
|
||||
|
||||
// NewNopStore creates an in-memory cache that discards all messages;
|
||||
// it is always empty and can be used if caching is entirely disabled
|
||||
func NewNopStore() (Store, error) {
|
||||
return NewSQLiteStore(createMemoryFilename(), "", 0, 0, 0, true)
|
||||
}
|
||||
|
||||
// createMemoryFilename creates a unique memory filename to use for the SQLite backend.
|
||||
func createMemoryFilename() string {
|
||||
return fmt.Sprintf("file:%s?mode=memory&cache=shared", util.RandomString(10))
|
||||
}
|
||||
466
message/store_sqlite_schema.go
Normal file
466
message/store_sqlite_schema.go
Normal file
@@ -0,0 +1,466 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"heckel.io/ntfy/v2/log"
|
||||
)
|
||||
|
||||
// Initial SQLite schema
|
||||
const (
|
||||
sqliteCreateMessagesTableQuery = `
|
||||
BEGIN;
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
mid TEXT NOT NULL,
|
||||
sequence_id TEXT NOT NULL,
|
||||
time INT NOT NULL,
|
||||
event TEXT NOT NULL,
|
||||
expires INT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
title TEXT NOT NULL,
|
||||
priority INT NOT NULL,
|
||||
tags TEXT NOT NULL,
|
||||
click TEXT NOT NULL,
|
||||
icon TEXT NOT NULL,
|
||||
actions TEXT NOT NULL,
|
||||
attachment_name TEXT NOT NULL,
|
||||
attachment_type TEXT NOT NULL,
|
||||
attachment_size INT NOT NULL,
|
||||
attachment_expires INT NOT NULL,
|
||||
attachment_url TEXT NOT NULL,
|
||||
attachment_deleted INT NOT NULL,
|
||||
sender TEXT NOT NULL,
|
||||
user TEXT NOT NULL,
|
||||
content_type TEXT NOT NULL,
|
||||
encoding TEXT NOT NULL,
|
||||
published INT NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_mid ON messages (mid);
|
||||
CREATE INDEX IF NOT EXISTS idx_sequence_id ON messages (sequence_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
|
||||
CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires);
|
||||
CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender);
|
||||
CREATE INDEX IF NOT EXISTS idx_user ON messages (user);
|
||||
CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
|
||||
CREATE TABLE IF NOT EXISTS stats (
|
||||
key TEXT PRIMARY KEY,
|
||||
value INT
|
||||
);
|
||||
INSERT INTO stats (key, value) VALUES ('messages', 0);
|
||||
COMMIT;
|
||||
`
|
||||
)
|
||||
|
||||
// Schema version management for SQLite
|
||||
const (
|
||||
sqliteCurrentSchemaVersion = 14
|
||||
sqliteCreateSchemaVersionTableQuery = `
|
||||
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||
id INT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
`
|
||||
sqliteInsertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
|
||||
sqliteUpdateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
|
||||
sqliteSelectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
|
||||
)
|
||||
|
||||
// Schema migrations for SQLite
|
||||
const (
|
||||
// 0 -> 1
|
||||
sqliteMigrate0To1AlterMessagesTableQuery = `
|
||||
BEGIN;
|
||||
ALTER TABLE messages ADD COLUMN title TEXT NOT NULL DEFAULT('');
|
||||
ALTER TABLE messages ADD COLUMN priority INT NOT NULL DEFAULT(0);
|
||||
ALTER TABLE messages ADD COLUMN tags TEXT NOT NULL DEFAULT('');
|
||||
COMMIT;
|
||||
`
|
||||
|
||||
// 1 -> 2
|
||||
sqliteMigrate1To2AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages ADD COLUMN published INT NOT NULL DEFAULT(1);
|
||||
`
|
||||
|
||||
// 2 -> 3
|
||||
sqliteMigrate2To3AlterMessagesTableQuery = `
|
||||
BEGIN;
|
||||
ALTER TABLE messages ADD COLUMN click TEXT NOT NULL DEFAULT('');
|
||||
ALTER TABLE messages ADD COLUMN attachment_name TEXT NOT NULL DEFAULT('');
|
||||
ALTER TABLE messages ADD COLUMN attachment_type TEXT NOT NULL DEFAULT('');
|
||||
ALTER TABLE messages ADD COLUMN attachment_size INT NOT NULL DEFAULT('0');
|
||||
ALTER TABLE messages ADD COLUMN attachment_expires INT NOT NULL DEFAULT('0');
|
||||
ALTER TABLE messages ADD COLUMN attachment_owner TEXT NOT NULL DEFAULT('');
|
||||
ALTER TABLE messages ADD COLUMN attachment_url TEXT NOT NULL DEFAULT('');
|
||||
COMMIT;
|
||||
`
|
||||
// 3 -> 4
|
||||
sqliteMigrate3To4AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages ADD COLUMN encoding TEXT NOT NULL DEFAULT('');
|
||||
`
|
||||
|
||||
// 4 -> 5
|
||||
sqliteMigrate4To5AlterMessagesTableQuery = `
|
||||
BEGIN;
|
||||
CREATE TABLE IF NOT EXISTS messages_new (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
mid TEXT NOT NULL,
|
||||
time INT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
title TEXT NOT NULL,
|
||||
priority INT NOT NULL,
|
||||
tags TEXT NOT NULL,
|
||||
click TEXT NOT NULL,
|
||||
attachment_name TEXT NOT NULL,
|
||||
attachment_type TEXT NOT NULL,
|
||||
attachment_size INT NOT NULL,
|
||||
attachment_expires INT NOT NULL,
|
||||
attachment_url TEXT NOT NULL,
|
||||
attachment_owner TEXT NOT NULL,
|
||||
encoding TEXT NOT NULL,
|
||||
published INT NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_mid ON messages_new (mid);
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON messages_new (topic);
|
||||
INSERT
|
||||
INTO messages_new (
|
||||
mid, time, topic, message, title, priority, tags, click, attachment_name, attachment_type,
|
||||
attachment_size, attachment_expires, attachment_url, attachment_owner, encoding, published)
|
||||
SELECT
|
||||
id, time, topic, message, title, priority, tags, click, attachment_name, attachment_type,
|
||||
attachment_size, attachment_expires, attachment_url, attachment_owner, encoding, published
|
||||
FROM messages;
|
||||
DROP TABLE messages;
|
||||
ALTER TABLE messages_new RENAME TO messages;
|
||||
COMMIT;
|
||||
`
|
||||
|
||||
// 5 -> 6
|
||||
sqliteMigrate5To6AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages ADD COLUMN actions TEXT NOT NULL DEFAULT('');
|
||||
`
|
||||
|
||||
// 6 -> 7
|
||||
sqliteMigrate6To7AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages RENAME COLUMN attachment_owner TO sender;
|
||||
`
|
||||
|
||||
// 7 -> 8
|
||||
sqliteMigrate7To8AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages ADD COLUMN icon TEXT NOT NULL DEFAULT('');
|
||||
`
|
||||
|
||||
// 8 -> 9
|
||||
sqliteMigrate8To9AlterMessagesTableQuery = `
|
||||
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
|
||||
`
|
||||
|
||||
// 9 -> 10
|
||||
sqliteMigrate9To10AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages ADD COLUMN user TEXT NOT NULL DEFAULT('');
|
||||
ALTER TABLE messages ADD COLUMN attachment_deleted INT NOT NULL DEFAULT('0');
|
||||
ALTER TABLE messages ADD COLUMN expires INT NOT NULL DEFAULT('0');
|
||||
CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires);
|
||||
CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender);
|
||||
CREATE INDEX IF NOT EXISTS idx_user ON messages (user);
|
||||
CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
|
||||
`
|
||||
sqliteMigrate9To10UpdateMessageExpiryQuery = `UPDATE messages SET expires = time + ?`
|
||||
|
||||
// 10 -> 11
|
||||
sqliteMigrate10To11AlterMessagesTableQuery = `
|
||||
CREATE TABLE IF NOT EXISTS stats (
|
||||
key TEXT PRIMARY KEY,
|
||||
value INT
|
||||
);
|
||||
INSERT INTO stats (key, value) VALUES ('messages', 0);
|
||||
`
|
||||
|
||||
// 11 -> 12
|
||||
sqliteMigrate11To12AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages ADD COLUMN content_type TEXT NOT NULL DEFAULT('');
|
||||
`
|
||||
|
||||
// 12 -> 13
|
||||
sqliteMigrate12To13AlterMessagesTableQuery = `
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
|
||||
`
|
||||
|
||||
// 13 -> 14
|
||||
sqliteMigrate13To14AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages ADD COLUMN sequence_id TEXT NOT NULL DEFAULT('');
|
||||
ALTER TABLE messages ADD COLUMN event TEXT NOT NULL DEFAULT('message');
|
||||
CREATE INDEX IF NOT EXISTS idx_sequence_id ON messages (sequence_id);
|
||||
`
|
||||
)
|
||||
|
||||
var (
|
||||
sqliteMigrations = map[int]func(db *sql.DB, cacheDuration time.Duration) error{
|
||||
0: sqliteMigrateFrom0,
|
||||
1: sqliteMigrateFrom1,
|
||||
2: sqliteMigrateFrom2,
|
||||
3: sqliteMigrateFrom3,
|
||||
4: sqliteMigrateFrom4,
|
||||
5: sqliteMigrateFrom5,
|
||||
6: sqliteMigrateFrom6,
|
||||
7: sqliteMigrateFrom7,
|
||||
8: sqliteMigrateFrom8,
|
||||
9: sqliteMigrateFrom9,
|
||||
10: sqliteMigrateFrom10,
|
||||
11: sqliteMigrateFrom11,
|
||||
12: sqliteMigrateFrom12,
|
||||
13: sqliteMigrateFrom13,
|
||||
}
|
||||
)
|
||||
|
||||
func setupSQLite(db *sql.DB, startupQueries string, cacheDuration time.Duration) error {
|
||||
if err := runSQLiteStartupQueries(db, startupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
// If 'messages' table does not exist, this must be a new database
|
||||
rowsMC, err := db.Query(sqliteSelectMessagesCountQuery)
|
||||
if err != nil {
|
||||
return setupNewSQLite(db)
|
||||
}
|
||||
rowsMC.Close()
|
||||
// If 'messages' table exists, check 'schemaVersion' table
|
||||
schemaVersion := 0
|
||||
rowsSV, err := db.Query(sqliteSelectSchemaVersionQuery)
|
||||
if err == nil {
|
||||
defer rowsSV.Close()
|
||||
if !rowsSV.Next() {
|
||||
return fmt.Errorf("cannot determine schema version: cache file may be corrupt")
|
||||
}
|
||||
if err := rowsSV.Scan(&schemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
rowsSV.Close()
|
||||
}
|
||||
// Do migrations
|
||||
if schemaVersion == sqliteCurrentSchemaVersion {
|
||||
return nil
|
||||
} else if schemaVersion > sqliteCurrentSchemaVersion {
|
||||
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, sqliteCurrentSchemaVersion)
|
||||
}
|
||||
for i := schemaVersion; i < sqliteCurrentSchemaVersion; i++ {
|
||||
fn, ok := sqliteMigrations[i]
|
||||
if !ok {
|
||||
return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1)
|
||||
} else if err := fn(db, cacheDuration); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupNewSQLite(db *sql.DB) error {
|
||||
if _, err := db.Exec(sqliteCreateMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(sqliteCreateSchemaVersionTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(sqliteInsertSchemaVersion, sqliteCurrentSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func runSQLiteStartupQueries(db *sql.DB, startupQueries string) error {
|
||||
if startupQueries != "" {
|
||||
if _, err := db.Exec(startupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom0(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 0 to 1")
|
||||
if _, err := db.Exec(sqliteMigrate0To1AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(sqliteCreateSchemaVersionTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(sqliteInsertSchemaVersion, 1); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom1(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 1 to 2")
|
||||
if _, err := db.Exec(sqliteMigrate1To2AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(sqliteUpdateSchemaVersion, 2); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom2(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 2 to 3")
|
||||
if _, err := db.Exec(sqliteMigrate2To3AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(sqliteUpdateSchemaVersion, 3); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom3(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 3 to 4")
|
||||
if _, err := db.Exec(sqliteMigrate3To4AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(sqliteUpdateSchemaVersion, 4); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom4(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 4 to 5")
|
||||
if _, err := db.Exec(sqliteMigrate4To5AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(sqliteUpdateSchemaVersion, 5); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom5(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 5 to 6")
|
||||
if _, err := db.Exec(sqliteMigrate5To6AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(sqliteUpdateSchemaVersion, 6); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom6(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 6 to 7")
|
||||
if _, err := db.Exec(sqliteMigrate6To7AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(sqliteUpdateSchemaVersion, 7); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom7(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 7 to 8")
|
||||
if _, err := db.Exec(sqliteMigrate7To8AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(sqliteUpdateSchemaVersion, 8); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom8(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 8 to 9")
|
||||
if _, err := db.Exec(sqliteMigrate8To9AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(sqliteUpdateSchemaVersion, 9); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom9(db *sql.DB, cacheDuration time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 9 to 10")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(sqliteMigrate9To10AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteMigrate9To10UpdateMessageExpiryQuery, int64(cacheDuration.Seconds())); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 10); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom10(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 10 to 11")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(sqliteMigrate10To11AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 11); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom11(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 11 to 12")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(sqliteMigrate11To12AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 12); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom12(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 12 to 13")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(sqliteMigrate12To13AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 13); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom13(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 13 to 14")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(sqliteMigrate13To14AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 14); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
459
message/store_sqlite_test.go
Normal file
459
message/store_sqlite_test.go
Normal file
@@ -0,0 +1,459 @@
|
||||
package message_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/message"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
)
|
||||
|
||||
func TestSqliteStore_Messages(t *testing.T) {
|
||||
testCacheMessages(t, newSqliteTestStore(t))
|
||||
}
|
||||
|
||||
func TestMemStore_Messages(t *testing.T) {
|
||||
testCacheMessages(t, newMemTestStore(t))
|
||||
}
|
||||
|
||||
func TestSqliteStore_MessagesLock(t *testing.T) {
|
||||
testCacheMessagesLock(t, newSqliteTestStore(t))
|
||||
}
|
||||
|
||||
func TestMemStore_MessagesLock(t *testing.T) {
|
||||
testCacheMessagesLock(t, newMemTestStore(t))
|
||||
}
|
||||
|
||||
func TestSqliteStore_MessagesScheduled(t *testing.T) {
|
||||
testCacheMessagesScheduled(t, newSqliteTestStore(t))
|
||||
}
|
||||
|
||||
func TestMemStore_MessagesScheduled(t *testing.T) {
|
||||
testCacheMessagesScheduled(t, newMemTestStore(t))
|
||||
}
|
||||
|
||||
func TestSqliteStore_Topics(t *testing.T) {
|
||||
testCacheTopics(t, newSqliteTestStore(t))
|
||||
}
|
||||
|
||||
func TestMemStore_Topics(t *testing.T) {
|
||||
testCacheTopics(t, newMemTestStore(t))
|
||||
}
|
||||
|
||||
func TestSqliteStore_MessagesTagsPrioAndTitle(t *testing.T) {
|
||||
testCacheMessagesTagsPrioAndTitle(t, newSqliteTestStore(t))
|
||||
}
|
||||
|
||||
func TestMemStore_MessagesTagsPrioAndTitle(t *testing.T) {
|
||||
testCacheMessagesTagsPrioAndTitle(t, newMemTestStore(t))
|
||||
}
|
||||
|
||||
func TestSqliteStore_MessagesSinceID(t *testing.T) {
|
||||
testCacheMessagesSinceID(t, newSqliteTestStore(t))
|
||||
}
|
||||
|
||||
func TestMemStore_MessagesSinceID(t *testing.T) {
|
||||
testCacheMessagesSinceID(t, newMemTestStore(t))
|
||||
}
|
||||
|
||||
func TestSqliteStore_Prune(t *testing.T) {
|
||||
testCachePrune(t, newSqliteTestStore(t))
|
||||
}
|
||||
|
||||
func TestMemStore_Prune(t *testing.T) {
|
||||
testCachePrune(t, newMemTestStore(t))
|
||||
}
|
||||
|
||||
func TestSqliteStore_Attachments(t *testing.T) {
|
||||
testCacheAttachments(t, newSqliteTestStore(t))
|
||||
}
|
||||
|
||||
func TestMemStore_Attachments(t *testing.T) {
|
||||
testCacheAttachments(t, newMemTestStore(t))
|
||||
}
|
||||
|
||||
func TestSqliteStore_AttachmentsExpired(t *testing.T) {
|
||||
testCacheAttachmentsExpired(t, newSqliteTestStore(t))
|
||||
}
|
||||
|
||||
func TestMemStore_AttachmentsExpired(t *testing.T) {
|
||||
testCacheAttachmentsExpired(t, newMemTestStore(t))
|
||||
}
|
||||
|
||||
func TestSqliteStore_Sender(t *testing.T) {
|
||||
testSender(t, newSqliteTestStore(t))
|
||||
}
|
||||
|
||||
func TestMemStore_Sender(t *testing.T) {
|
||||
testSender(t, newMemTestStore(t))
|
||||
}
|
||||
|
||||
func TestSqliteStore_DeleteScheduledBySequenceID(t *testing.T) {
|
||||
testDeleteScheduledBySequenceID(t, newSqliteTestStore(t))
|
||||
}
|
||||
|
||||
func TestMemStore_DeleteScheduledBySequenceID(t *testing.T) {
|
||||
testDeleteScheduledBySequenceID(t, newMemTestStore(t))
|
||||
}
|
||||
|
||||
func TestSqliteStore_MessageByID(t *testing.T) {
|
||||
testMessageByID(t, newSqliteTestStore(t))
|
||||
}
|
||||
|
||||
func TestMemStore_MessageByID(t *testing.T) {
|
||||
testMessageByID(t, newMemTestStore(t))
|
||||
}
|
||||
|
||||
func TestSqliteStore_MarkPublished(t *testing.T) {
|
||||
testMarkPublished(t, newSqliteTestStore(t))
|
||||
}
|
||||
|
||||
func TestMemStore_MarkPublished(t *testing.T) {
|
||||
testMarkPublished(t, newMemTestStore(t))
|
||||
}
|
||||
|
||||
func TestSqliteStore_ExpireMessages(t *testing.T) {
|
||||
testExpireMessages(t, newSqliteTestStore(t))
|
||||
}
|
||||
|
||||
func TestMemStore_ExpireMessages(t *testing.T) {
|
||||
testExpireMessages(t, newMemTestStore(t))
|
||||
}
|
||||
|
||||
func TestSqliteStore_MarkAttachmentsDeleted(t *testing.T) {
|
||||
testMarkAttachmentsDeleted(t, newSqliteTestStore(t))
|
||||
}
|
||||
|
||||
func TestMemStore_MarkAttachmentsDeleted(t *testing.T) {
|
||||
testMarkAttachmentsDeleted(t, newMemTestStore(t))
|
||||
}
|
||||
|
||||
func TestSqliteStore_Stats(t *testing.T) {
|
||||
testStats(t, newSqliteTestStore(t))
|
||||
}
|
||||
|
||||
func TestMemStore_Stats(t *testing.T) {
|
||||
testStats(t, newMemTestStore(t))
|
||||
}
|
||||
|
||||
func TestSqliteStore_AddMessages(t *testing.T) {
|
||||
testAddMessages(t, newSqliteTestStore(t))
|
||||
}
|
||||
|
||||
func TestMemStore_AddMessages(t *testing.T) {
|
||||
testAddMessages(t, newMemTestStore(t))
|
||||
}
|
||||
|
||||
func TestSqliteStore_MessagesDue(t *testing.T) {
|
||||
testMessagesDue(t, newSqliteTestStore(t))
|
||||
}
|
||||
|
||||
func TestMemStore_MessagesDue(t *testing.T) {
|
||||
testMessagesDue(t, newMemTestStore(t))
|
||||
}
|
||||
|
||||
func TestSqliteStore_MessageFieldRoundTrip(t *testing.T) {
|
||||
testMessageFieldRoundTrip(t, newSqliteTestStore(t))
|
||||
}
|
||||
|
||||
func TestMemStore_MessageFieldRoundTrip(t *testing.T) {
|
||||
testMessageFieldRoundTrip(t, newMemTestStore(t))
|
||||
}
|
||||
|
||||
func TestSqliteStore_Migration_From0(t *testing.T) {
|
||||
filename := newSqliteTestStoreFile(t)
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Create "version 0" schema
|
||||
_, err = db.Exec(`
|
||||
BEGIN;
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id VARCHAR(20) PRIMARY KEY,
|
||||
time INT NOT NULL,
|
||||
topic VARCHAR(64) NOT NULL,
|
||||
message VARCHAR(1024) NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
|
||||
COMMIT;
|
||||
`)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Insert a bunch of messages
|
||||
for i := 0; i < 10; i++ {
|
||||
_, err = db.Exec(`INSERT INTO messages (id, time, topic, message) VALUES (?, ?, ?, ?)`,
|
||||
fmt.Sprintf("abcd%d", i), time.Now().Unix(), "mytopic", fmt.Sprintf("some message %d", i))
|
||||
require.Nil(t, err)
|
||||
}
|
||||
require.Nil(t, db.Close())
|
||||
|
||||
// Create store to trigger migration
|
||||
s := newSqliteTestStoreFromFile(t, filename, "")
|
||||
checkSqliteSchemaVersion(t, filename)
|
||||
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 10, len(messages))
|
||||
require.Equal(t, "some message 5", messages[5].Message)
|
||||
require.Equal(t, "", messages[5].Title)
|
||||
require.Nil(t, messages[5].Tags)
|
||||
require.Equal(t, 0, messages[5].Priority)
|
||||
}
|
||||
|
||||
func TestSqliteStore_Migration_From1(t *testing.T) {
|
||||
filename := newSqliteTestStoreFile(t)
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Create "version 1" schema
|
||||
_, err = db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id VARCHAR(20) PRIMARY KEY,
|
||||
time INT NOT NULL,
|
||||
topic VARCHAR(64) NOT NULL,
|
||||
message VARCHAR(512) NOT NULL,
|
||||
title VARCHAR(256) NOT NULL,
|
||||
priority INT NOT NULL,
|
||||
tags VARCHAR(256) NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
|
||||
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||
id INT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
INSERT INTO schemaVersion (id, version) VALUES (1, 1);
|
||||
`)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Insert a bunch of messages
|
||||
for i := 0; i < 10; i++ {
|
||||
_, err = db.Exec(`INSERT INTO messages (id, time, topic, message, title, priority, tags) VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
||||
fmt.Sprintf("abcd%d", i), time.Now().Unix(), "mytopic", fmt.Sprintf("some message %d", i), "", 0, "")
|
||||
require.Nil(t, err)
|
||||
}
|
||||
require.Nil(t, db.Close())
|
||||
|
||||
// Create store to trigger migration
|
||||
s := newSqliteTestStoreFromFile(t, filename, "")
|
||||
checkSqliteSchemaVersion(t, filename)
|
||||
|
||||
// Add delayed message
|
||||
delayedMessage := model.NewDefaultMessage("mytopic", "some delayed message")
|
||||
delayedMessage.Time = time.Now().Add(time.Minute).Unix()
|
||||
require.Nil(t, s.AddMessage(delayedMessage))
|
||||
|
||||
// 10, not 11!
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 10, len(messages))
|
||||
|
||||
// 11!
|
||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, true)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 11, len(messages))
|
||||
|
||||
// Check that index "idx_topic" exists
|
||||
verifyDB, err := sql.Open("sqlite3", filename)
|
||||
require.Nil(t, err)
|
||||
defer verifyDB.Close()
|
||||
rows, err := verifyDB.Query(`SELECT name FROM sqlite_master WHERE type='index' AND name='idx_topic'`)
|
||||
require.Nil(t, err)
|
||||
require.True(t, rows.Next())
|
||||
var indexName string
|
||||
require.Nil(t, rows.Scan(&indexName))
|
||||
require.Equal(t, "idx_topic", indexName)
|
||||
require.Nil(t, rows.Close())
|
||||
}
|
||||
|
||||
func TestSqliteStore_Migration_From9(t *testing.T) {
|
||||
// This primarily tests the awkward migration that introduces the "expires" column.
|
||||
// The migration logic has to update the column, using the existing "cache-duration" value.
|
||||
|
||||
filename := newSqliteTestStoreFile(t)
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Create "version 9" schema
|
||||
_, err = db.Exec(`
|
||||
BEGIN;
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
mid TEXT NOT NULL,
|
||||
time INT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
title TEXT NOT NULL,
|
||||
priority INT NOT NULL,
|
||||
tags TEXT NOT NULL,
|
||||
click TEXT NOT NULL,
|
||||
icon TEXT NOT NULL,
|
||||
actions TEXT NOT NULL,
|
||||
attachment_name TEXT NOT NULL,
|
||||
attachment_type TEXT NOT NULL,
|
||||
attachment_size INT NOT NULL,
|
||||
attachment_expires INT NOT NULL,
|
||||
attachment_url TEXT NOT NULL,
|
||||
sender TEXT NOT NULL,
|
||||
encoding TEXT NOT NULL,
|
||||
published INT NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_mid ON messages (mid);
|
||||
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
|
||||
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||
id INT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
INSERT INTO schemaVersion (id, version) VALUES (1, 9);
|
||||
COMMIT;
|
||||
`)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Insert a bunch of messages
|
||||
insertQuery := `
|
||||
INSERT INTO messages (mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding, published)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
for i := 0; i < 10; i++ {
|
||||
_, err = db.Exec(
|
||||
insertQuery,
|
||||
fmt.Sprintf("abcd%d", i),
|
||||
time.Now().Unix(),
|
||||
"mytopic",
|
||||
fmt.Sprintf("some message %d", i),
|
||||
"", // title
|
||||
0, // priority
|
||||
"", // tags
|
||||
"", // click
|
||||
"", // icon
|
||||
"", // actions
|
||||
"", // attachment_name
|
||||
"", // attachment_type
|
||||
0, // attachment_size
|
||||
0, // attachment_expires
|
||||
"", // attachment_url
|
||||
"9.9.9.9", // sender
|
||||
"", // encoding
|
||||
1, // published
|
||||
)
|
||||
require.Nil(t, err)
|
||||
}
|
||||
require.Nil(t, db.Close())
|
||||
|
||||
// Create store to trigger migration
|
||||
cacheDuration := 17 * time.Hour
|
||||
s, err := message.NewSQLiteStore(filename, "", cacheDuration, 0, 0, false)
|
||||
require.Nil(t, err)
|
||||
t.Cleanup(func() { s.Close() })
|
||||
checkSqliteSchemaVersion(t, filename)
|
||||
|
||||
// Check version
|
||||
verifyDB, err := sql.Open("sqlite3", filename)
|
||||
require.Nil(t, err)
|
||||
defer verifyDB.Close()
|
||||
rows, err := verifyDB.Query(`SELECT version FROM schemaVersion WHERE id = 1`)
|
||||
require.Nil(t, err)
|
||||
require.True(t, rows.Next())
|
||||
var version int
|
||||
require.Nil(t, rows.Scan(&version))
|
||||
require.Equal(t, 14, version)
|
||||
require.Nil(t, rows.Close())
|
||||
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 10, len(messages))
|
||||
for _, m := range messages {
|
||||
require.True(t, m.Expires > time.Now().Add(cacheDuration-5*time.Second).Unix())
|
||||
require.True(t, m.Expires < time.Now().Add(cacheDuration+5*time.Second).Unix())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqliteStore_StartupQueries_WAL(t *testing.T) {
|
||||
filename := newSqliteTestStoreFile(t)
|
||||
startupQueries := `pragma journal_mode = WAL;
|
||||
pragma synchronous = normal;
|
||||
pragma temp_store = memory;`
|
||||
s, err := message.NewSQLiteStore(filename, startupQueries, time.Hour, 0, 0, false)
|
||||
require.Nil(t, err)
|
||||
t.Cleanup(func() { s.Close() })
|
||||
require.Nil(t, s.AddMessage(model.NewDefaultMessage("mytopic", "some message")))
|
||||
require.FileExists(t, filename)
|
||||
require.FileExists(t, filename+"-wal")
|
||||
require.FileExists(t, filename+"-shm")
|
||||
}
|
||||
|
||||
func TestSqliteStore_StartupQueries_None(t *testing.T) {
|
||||
filename := newSqliteTestStoreFile(t)
|
||||
s, err := message.NewSQLiteStore(filename, "", time.Hour, 0, 0, false)
|
||||
require.Nil(t, err)
|
||||
t.Cleanup(func() { s.Close() })
|
||||
require.Nil(t, s.AddMessage(model.NewDefaultMessage("mytopic", "some message")))
|
||||
require.FileExists(t, filename)
|
||||
require.NoFileExists(t, filename+"-wal")
|
||||
require.NoFileExists(t, filename+"-shm")
|
||||
}
|
||||
|
||||
func TestSqliteStore_StartupQueries_Fail(t *testing.T) {
|
||||
filename := newSqliteTestStoreFile(t)
|
||||
_, err := message.NewSQLiteStore(filename, `xx error`, time.Hour, 0, 0, false)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestNopStore(t *testing.T) {
|
||||
s, err := message.NewNopStore()
|
||||
require.Nil(t, err)
|
||||
t.Cleanup(func() { s.Close() })
|
||||
require.Nil(t, s.AddMessage(model.NewDefaultMessage("mytopic", "my message")))
|
||||
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Empty(t, messages)
|
||||
|
||||
topics, err := s.Topics()
|
||||
require.Nil(t, err)
|
||||
require.Empty(t, topics)
|
||||
}
|
||||
|
||||
func newSqliteTestStore(t *testing.T) message.Store {
|
||||
filename := filepath.Join(t.TempDir(), "cache.db")
|
||||
s, err := message.NewSQLiteStore(filename, "", time.Hour, 0, 0, false)
|
||||
require.Nil(t, err)
|
||||
t.Cleanup(func() { s.Close() })
|
||||
return s
|
||||
}
|
||||
|
||||
func newSqliteTestStoreFile(t *testing.T) string {
|
||||
return filepath.Join(t.TempDir(), "cache.db")
|
||||
}
|
||||
|
||||
func newSqliteTestStoreFromFile(t *testing.T, filename, startupQueries string) message.Store {
|
||||
s, err := message.NewSQLiteStore(filename, startupQueries, time.Hour, 0, 0, false)
|
||||
require.Nil(t, err)
|
||||
t.Cleanup(func() { s.Close() })
|
||||
return s
|
||||
}
|
||||
|
||||
func newMemTestStore(t *testing.T) message.Store {
|
||||
s, err := message.NewMemStore()
|
||||
require.Nil(t, err)
|
||||
t.Cleanup(func() { s.Close() })
|
||||
return s
|
||||
}
|
||||
|
||||
func checkSqliteSchemaVersion(t *testing.T, filename string) {
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
require.Nil(t, err)
|
||||
defer db.Close()
|
||||
rows, err := db.Query(`SELECT version FROM schemaVersion`)
|
||||
require.Nil(t, err)
|
||||
require.True(t, rows.Next())
|
||||
var schemaVersion int
|
||||
require.Nil(t, rows.Scan(&schemaVersion))
|
||||
require.Equal(t, 14, schemaVersion)
|
||||
require.Nil(t, rows.Close())
|
||||
}
|
||||
767
message/store_test.go
Normal file
767
message/store_test.go
Normal file
@@ -0,0 +1,767 @@
|
||||
package message_test
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"sort"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/message"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
)
|
||||
|
||||
func testCacheMessages(t *testing.T, s message.Store) {
|
||||
m1 := model.NewDefaultMessage("mytopic", "my message")
|
||||
m1.Time = 1
|
||||
|
||||
m2 := model.NewDefaultMessage("mytopic", "my other message")
|
||||
m2.Time = 2
|
||||
|
||||
require.Nil(t, s.AddMessage(m1))
|
||||
require.Nil(t, s.AddMessage(model.NewDefaultMessage("example", "my example message")))
|
||||
require.Nil(t, s.AddMessage(m2))
|
||||
|
||||
// Adding invalid
|
||||
require.Equal(t, model.ErrUnexpectedMessageType, s.AddMessage(model.NewKeepaliveMessage("mytopic"))) // These should not be added!
|
||||
require.Equal(t, model.ErrUnexpectedMessageType, s.AddMessage(model.NewOpenMessage("example"))) // These should not be added!
|
||||
|
||||
// mytopic: count
|
||||
counts, err := s.MessageCounts()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, counts["mytopic"])
|
||||
|
||||
// mytopic: since all
|
||||
messages, _ := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Equal(t, 2, len(messages))
|
||||
require.Equal(t, "my message", messages[0].Message)
|
||||
require.Equal(t, "mytopic", messages[0].Topic)
|
||||
require.Equal(t, model.MessageEvent, messages[0].Event)
|
||||
require.Equal(t, "", messages[0].Title)
|
||||
require.Equal(t, 0, messages[0].Priority)
|
||||
require.Nil(t, messages[0].Tags)
|
||||
require.Equal(t, "my other message", messages[1].Message)
|
||||
|
||||
// mytopic: since none
|
||||
messages, _ = s.Messages("mytopic", model.SinceNoMessages, false)
|
||||
require.Empty(t, messages)
|
||||
|
||||
// mytopic: since m1 (by ID)
|
||||
messages, _ = s.Messages("mytopic", model.NewSinceID(m1.ID), false)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, m2.ID, messages[0].ID)
|
||||
require.Equal(t, "my other message", messages[0].Message)
|
||||
require.Equal(t, "mytopic", messages[0].Topic)
|
||||
|
||||
// mytopic: since 2
|
||||
messages, _ = s.Messages("mytopic", model.NewSinceTime(2), false)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "my other message", messages[0].Message)
|
||||
|
||||
// mytopic: latest
|
||||
messages, _ = s.Messages("mytopic", model.SinceLatestMessage, false)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "my other message", messages[0].Message)
|
||||
|
||||
// example: count
|
||||
counts, err = s.MessageCounts()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, counts["example"])
|
||||
|
||||
// example: since all
|
||||
messages, _ = s.Messages("example", model.SinceAllMessages, false)
|
||||
require.Equal(t, "my example message", messages[0].Message)
|
||||
|
||||
// non-existing: count
|
||||
counts, err = s.MessageCounts()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 0, counts["doesnotexist"])
|
||||
|
||||
// non-existing: since all
|
||||
messages, _ = s.Messages("doesnotexist", model.SinceAllMessages, false)
|
||||
require.Empty(t, messages)
|
||||
}
|
||||
|
||||
func testCacheMessagesLock(t *testing.T, s message.Store) {
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 5000; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
assert.Nil(t, s.AddMessage(model.NewDefaultMessage("mytopic", "test message")))
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func testCacheMessagesScheduled(t *testing.T, s message.Store) {
|
||||
m1 := model.NewDefaultMessage("mytopic", "message 1")
|
||||
m2 := model.NewDefaultMessage("mytopic", "message 2")
|
||||
m2.Time = time.Now().Add(time.Hour).Unix()
|
||||
m3 := model.NewDefaultMessage("mytopic", "message 3")
|
||||
m3.Time = time.Now().Add(time.Minute).Unix() // earlier than m2!
|
||||
m4 := model.NewDefaultMessage("mytopic2", "message 4")
|
||||
m4.Time = time.Now().Add(time.Minute).Unix()
|
||||
require.Nil(t, s.AddMessage(m1))
|
||||
require.Nil(t, s.AddMessage(m2))
|
||||
require.Nil(t, s.AddMessage(m3))
|
||||
|
||||
messages, _ := s.Messages("mytopic", model.SinceAllMessages, false) // exclude scheduled
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "message 1", messages[0].Message)
|
||||
|
||||
messages, _ = s.Messages("mytopic", model.SinceAllMessages, true) // include scheduled
|
||||
require.Equal(t, 3, len(messages))
|
||||
require.Equal(t, "message 1", messages[0].Message)
|
||||
require.Equal(t, "message 3", messages[1].Message) // Order!
|
||||
require.Equal(t, "message 2", messages[2].Message)
|
||||
|
||||
messages, _ = s.MessagesDue()
|
||||
require.Empty(t, messages)
|
||||
}
|
||||
|
||||
func testCacheTopics(t *testing.T, s message.Store) {
|
||||
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic1", "my example message")))
|
||||
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 1")))
|
||||
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 2")))
|
||||
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 3")))
|
||||
|
||||
topics, err := s.Topics()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.Equal(t, 2, len(topics))
|
||||
require.Contains(t, topics, "topic1")
|
||||
require.Contains(t, topics, "topic2")
|
||||
}
|
||||
|
||||
func testCacheMessagesTagsPrioAndTitle(t *testing.T, s message.Store) {
|
||||
m := model.NewDefaultMessage("mytopic", "some message")
|
||||
m.Tags = []string{"tag1", "tag2"}
|
||||
m.Priority = 5
|
||||
m.Title = "some title"
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
|
||||
messages, _ := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Equal(t, []string{"tag1", "tag2"}, messages[0].Tags)
|
||||
require.Equal(t, 5, messages[0].Priority)
|
||||
require.Equal(t, "some title", messages[0].Title)
|
||||
}
|
||||
|
||||
func testCacheMessagesSinceID(t *testing.T, s message.Store) {
|
||||
m1 := model.NewDefaultMessage("mytopic", "message 1")
|
||||
m1.Time = 100
|
||||
m2 := model.NewDefaultMessage("mytopic", "message 2")
|
||||
m2.Time = 200
|
||||
m3 := model.NewDefaultMessage("mytopic", "message 3")
|
||||
m3.Time = time.Now().Add(time.Hour).Unix() // Scheduled, in the future, later than m7 and m5
|
||||
m4 := model.NewDefaultMessage("mytopic", "message 4")
|
||||
m4.Time = 400
|
||||
m5 := model.NewDefaultMessage("mytopic", "message 5")
|
||||
m5.Time = time.Now().Add(time.Minute).Unix() // Scheduled, in the future, later than m7
|
||||
m6 := model.NewDefaultMessage("mytopic", "message 6")
|
||||
m6.Time = 600
|
||||
m7 := model.NewDefaultMessage("mytopic", "message 7")
|
||||
m7.Time = 700
|
||||
|
||||
require.Nil(t, s.AddMessage(m1))
|
||||
require.Nil(t, s.AddMessage(m2))
|
||||
require.Nil(t, s.AddMessage(m3))
|
||||
require.Nil(t, s.AddMessage(m4))
|
||||
require.Nil(t, s.AddMessage(m5))
|
||||
require.Nil(t, s.AddMessage(m6))
|
||||
require.Nil(t, s.AddMessage(m7))
|
||||
|
||||
// Case 1: Since ID exists, exclude scheduled
|
||||
messages, _ := s.Messages("mytopic", model.NewSinceID(m2.ID), false)
|
||||
require.Equal(t, 3, len(messages))
|
||||
require.Equal(t, "message 4", messages[0].Message)
|
||||
require.Equal(t, "message 6", messages[1].Message) // Not scheduled m3/m5!
|
||||
require.Equal(t, "message 7", messages[2].Message)
|
||||
|
||||
// Case 2: Since ID exists, include scheduled
|
||||
messages, _ = s.Messages("mytopic", model.NewSinceID(m2.ID), true)
|
||||
require.Equal(t, 5, len(messages))
|
||||
require.Equal(t, "message 4", messages[0].Message)
|
||||
require.Equal(t, "message 6", messages[1].Message)
|
||||
require.Equal(t, "message 7", messages[2].Message)
|
||||
require.Equal(t, "message 5", messages[3].Message) // Order!
|
||||
require.Equal(t, "message 3", messages[4].Message) // Order!
|
||||
|
||||
// Case 3: Since ID does not exist (-> Return all messages), include scheduled
|
||||
messages, _ = s.Messages("mytopic", model.NewSinceID("doesntexist"), true)
|
||||
require.Equal(t, 7, len(messages))
|
||||
require.Equal(t, "message 1", messages[0].Message)
|
||||
require.Equal(t, "message 2", messages[1].Message)
|
||||
require.Equal(t, "message 4", messages[2].Message)
|
||||
require.Equal(t, "message 6", messages[3].Message)
|
||||
require.Equal(t, "message 7", messages[4].Message)
|
||||
require.Equal(t, "message 5", messages[5].Message) // Order!
|
||||
require.Equal(t, "message 3", messages[6].Message) // Order!
|
||||
|
||||
// Case 4: Since ID exists and is last message (-> Return no messages), exclude scheduled
|
||||
messages, _ = s.Messages("mytopic", model.NewSinceID(m7.ID), false)
|
||||
require.Equal(t, 0, len(messages))
|
||||
|
||||
// Case 5: Since ID exists and is last message (-> Return no messages), include scheduled
|
||||
messages, _ = s.Messages("mytopic", model.NewSinceID(m7.ID), true)
|
||||
require.Equal(t, 2, len(messages))
|
||||
require.Equal(t, "message 5", messages[0].Message)
|
||||
require.Equal(t, "message 3", messages[1].Message)
|
||||
}
|
||||
|
||||
func testCachePrune(t *testing.T, s message.Store) {
|
||||
now := time.Now().Unix()
|
||||
|
||||
m1 := model.NewDefaultMessage("mytopic", "my message")
|
||||
m1.Time = now - 10
|
||||
m1.Expires = now - 5
|
||||
|
||||
m2 := model.NewDefaultMessage("mytopic", "my other message")
|
||||
m2.Time = now - 5
|
||||
m2.Expires = now + 5 // In the future
|
||||
|
||||
m3 := model.NewDefaultMessage("another_topic", "and another one")
|
||||
m3.Time = now - 12
|
||||
m3.Expires = now - 2
|
||||
|
||||
require.Nil(t, s.AddMessage(m1))
|
||||
require.Nil(t, s.AddMessage(m2))
|
||||
require.Nil(t, s.AddMessage(m3))
|
||||
|
||||
counts, err := s.MessageCounts()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, counts["mytopic"])
|
||||
require.Equal(t, 1, counts["another_topic"])
|
||||
|
||||
expiredMessageIDs, err := s.MessagesExpired()
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, s.DeleteMessages(expiredMessageIDs...))
|
||||
|
||||
counts, err = s.MessageCounts()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, counts["mytopic"])
|
||||
require.Equal(t, 0, counts["another_topic"])
|
||||
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "my other message", messages[0].Message)
|
||||
}
|
||||
|
||||
func testCacheAttachments(t *testing.T, s message.Store) {
|
||||
expires1 := time.Now().Add(-4 * time.Hour).Unix() // Expired
|
||||
m := model.NewDefaultMessage("mytopic", "flower for you")
|
||||
m.ID = "m1"
|
||||
m.SequenceID = "m1"
|
||||
m.Sender = netip.MustParseAddr("1.2.3.4")
|
||||
m.Attachment = &model.Attachment{
|
||||
Name: "flower.jpg",
|
||||
Type: "image/jpeg",
|
||||
Size: 5000,
|
||||
Expires: expires1,
|
||||
URL: "https://ntfy.sh/file/AbDeFgJhal.jpg",
|
||||
}
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
|
||||
expires2 := time.Now().Add(2 * time.Hour).Unix() // Future
|
||||
m = model.NewDefaultMessage("mytopic", "sending you a car")
|
||||
m.ID = "m2"
|
||||
m.SequenceID = "m2"
|
||||
m.Sender = netip.MustParseAddr("1.2.3.4")
|
||||
m.Attachment = &model.Attachment{
|
||||
Name: "car.jpg",
|
||||
Type: "image/jpeg",
|
||||
Size: 10000,
|
||||
Expires: expires2,
|
||||
URL: "https://ntfy.sh/file/aCaRURL.jpg",
|
||||
}
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
|
||||
expires3 := time.Now().Add(1 * time.Hour).Unix() // Future
|
||||
m = model.NewDefaultMessage("another-topic", "sending you another car")
|
||||
m.ID = "m3"
|
||||
m.SequenceID = "m3"
|
||||
m.User = "u_BAsbaAa"
|
||||
m.Sender = netip.MustParseAddr("5.6.7.8")
|
||||
m.Attachment = &model.Attachment{
|
||||
Name: "another-car.jpg",
|
||||
Type: "image/jpeg",
|
||||
Size: 20000,
|
||||
Expires: expires3,
|
||||
URL: "https://ntfy.sh/file/zakaDHFW.jpg",
|
||||
}
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, len(messages))
|
||||
|
||||
require.Equal(t, "flower for you", messages[0].Message)
|
||||
require.Equal(t, "flower.jpg", messages[0].Attachment.Name)
|
||||
require.Equal(t, "image/jpeg", messages[0].Attachment.Type)
|
||||
require.Equal(t, int64(5000), messages[0].Attachment.Size)
|
||||
require.Equal(t, expires1, messages[0].Attachment.Expires)
|
||||
require.Equal(t, "https://ntfy.sh/file/AbDeFgJhal.jpg", messages[0].Attachment.URL)
|
||||
require.Equal(t, "1.2.3.4", messages[0].Sender.String())
|
||||
|
||||
require.Equal(t, "sending you a car", messages[1].Message)
|
||||
require.Equal(t, "car.jpg", messages[1].Attachment.Name)
|
||||
require.Equal(t, "image/jpeg", messages[1].Attachment.Type)
|
||||
require.Equal(t, int64(10000), messages[1].Attachment.Size)
|
||||
require.Equal(t, expires2, messages[1].Attachment.Expires)
|
||||
require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL)
|
||||
require.Equal(t, "1.2.3.4", messages[1].Sender.String())
|
||||
|
||||
size, err := s.AttachmentBytesUsedBySender("1.2.3.4")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(10000), size)
|
||||
|
||||
size, err = s.AttachmentBytesUsedBySender("5.6.7.8")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(0), size) // Accounted to the user, not the IP!
|
||||
|
||||
size, err = s.AttachmentBytesUsedByUser("u_BAsbaAa")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(20000), size)
|
||||
}
|
||||
|
||||
func testCacheAttachmentsExpired(t *testing.T, s message.Store) {
|
||||
m := model.NewDefaultMessage("mytopic", "flower for you")
|
||||
m.ID = "m1"
|
||||
m.SequenceID = "m1"
|
||||
m.Expires = time.Now().Add(time.Hour).Unix()
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
|
||||
m = model.NewDefaultMessage("mytopic", "message with attachment")
|
||||
m.ID = "m2"
|
||||
m.SequenceID = "m2"
|
||||
m.Expires = time.Now().Add(2 * time.Hour).Unix()
|
||||
m.Attachment = &model.Attachment{
|
||||
Name: "car.jpg",
|
||||
Type: "image/jpeg",
|
||||
Size: 10000,
|
||||
Expires: time.Now().Add(2 * time.Hour).Unix(),
|
||||
URL: "https://ntfy.sh/file/aCaRURL.jpg",
|
||||
}
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
|
||||
m = model.NewDefaultMessage("mytopic", "message with external attachment")
|
||||
m.ID = "m3"
|
||||
m.SequenceID = "m3"
|
||||
m.Expires = time.Now().Add(2 * time.Hour).Unix()
|
||||
m.Attachment = &model.Attachment{
|
||||
Name: "car.jpg",
|
||||
Type: "image/jpeg",
|
||||
Expires: 0, // Unknown!
|
||||
URL: "https://somedomain.com/car.jpg",
|
||||
}
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
|
||||
m = model.NewDefaultMessage("mytopic2", "message with expired attachment")
|
||||
m.ID = "m4"
|
||||
m.SequenceID = "m4"
|
||||
m.Expires = time.Now().Add(2 * time.Hour).Unix()
|
||||
m.Attachment = &model.Attachment{
|
||||
Name: "expired-car.jpg",
|
||||
Type: "image/jpeg",
|
||||
Size: 20000,
|
||||
Expires: time.Now().Add(-1 * time.Hour).Unix(),
|
||||
URL: "https://ntfy.sh/file/aCaRURL.jpg",
|
||||
}
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
|
||||
ids, err := s.AttachmentsExpired()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(ids))
|
||||
require.Equal(t, "m4", ids[0])
|
||||
}
|
||||
|
||||
func testSender(t *testing.T, s message.Store) {
|
||||
m1 := model.NewDefaultMessage("mytopic", "mymessage")
|
||||
m1.Sender = netip.MustParseAddr("1.2.3.4")
|
||||
require.Nil(t, s.AddMessage(m1))
|
||||
|
||||
m2 := model.NewDefaultMessage("mytopic", "mymessage without sender")
|
||||
require.Nil(t, s.AddMessage(m2))
|
||||
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, len(messages))
|
||||
require.Equal(t, messages[0].Sender, netip.MustParseAddr("1.2.3.4"))
|
||||
require.Equal(t, messages[1].Sender, netip.Addr{})
|
||||
}
|
||||
|
||||
func testDeleteScheduledBySequenceID(t *testing.T, s message.Store) {
|
||||
// Create a scheduled (unpublished) message
|
||||
scheduledMsg := model.NewDefaultMessage("mytopic", "scheduled message")
|
||||
scheduledMsg.ID = "scheduled1"
|
||||
scheduledMsg.SequenceID = "seq123"
|
||||
scheduledMsg.Time = time.Now().Add(time.Hour).Unix() // Future time makes it scheduled
|
||||
require.Nil(t, s.AddMessage(scheduledMsg))
|
||||
|
||||
// Create a published message with different sequence ID
|
||||
publishedMsg := model.NewDefaultMessage("mytopic", "published message")
|
||||
publishedMsg.ID = "published1"
|
||||
publishedMsg.SequenceID = "seq456"
|
||||
publishedMsg.Time = time.Now().Add(-time.Hour).Unix() // Past time makes it published
|
||||
require.Nil(t, s.AddMessage(publishedMsg))
|
||||
|
||||
// Create a scheduled message in a different topic
|
||||
otherTopicMsg := model.NewDefaultMessage("othertopic", "other scheduled")
|
||||
otherTopicMsg.ID = "other1"
|
||||
otherTopicMsg.SequenceID = "seq123" // Same sequence ID as scheduledMsg
|
||||
otherTopicMsg.Time = time.Now().Add(time.Hour).Unix()
|
||||
require.Nil(t, s.AddMessage(otherTopicMsg))
|
||||
|
||||
// Verify all messages exist (including scheduled)
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, true)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, len(messages))
|
||||
|
||||
messages, err = s.Messages("othertopic", model.SinceAllMessages, true)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
|
||||
// Delete scheduled message by sequence ID and verify returned IDs
|
||||
deletedIDs, err := s.DeleteScheduledBySequenceID("mytopic", "seq123")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(deletedIDs))
|
||||
require.Equal(t, "scheduled1", deletedIDs[0])
|
||||
|
||||
// Verify scheduled message is deleted
|
||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, true)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "published message", messages[0].Message)
|
||||
|
||||
// Verify other topic's message still exists (topic-scoped deletion)
|
||||
messages, err = s.Messages("othertopic", model.SinceAllMessages, true)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "other scheduled", messages[0].Message)
|
||||
|
||||
// Deleting non-existent sequence ID should return empty list
|
||||
deletedIDs, err = s.DeleteScheduledBySequenceID("mytopic", "nonexistent")
|
||||
require.Nil(t, err)
|
||||
require.Empty(t, deletedIDs)
|
||||
|
||||
// Deleting published message should not affect it (only deletes unpublished)
|
||||
deletedIDs, err = s.DeleteScheduledBySequenceID("mytopic", "seq456")
|
||||
require.Nil(t, err)
|
||||
require.Empty(t, deletedIDs)
|
||||
|
||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, true)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "published message", messages[0].Message)
|
||||
}
|
||||
|
||||
func testMessageByID(t *testing.T, s message.Store) {
|
||||
// Add a message
|
||||
m := model.NewDefaultMessage("mytopic", "some message")
|
||||
m.Title = "some title"
|
||||
m.Priority = 4
|
||||
m.Tags = []string{"tag1", "tag2"}
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
|
||||
// Retrieve by ID
|
||||
retrieved, err := s.Message(m.ID)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, m.ID, retrieved.ID)
|
||||
require.Equal(t, "mytopic", retrieved.Topic)
|
||||
require.Equal(t, "some message", retrieved.Message)
|
||||
require.Equal(t, "some title", retrieved.Title)
|
||||
require.Equal(t, 4, retrieved.Priority)
|
||||
require.Equal(t, []string{"tag1", "tag2"}, retrieved.Tags)
|
||||
|
||||
// Non-existent ID returns ErrMessageNotFound
|
||||
_, err = s.Message("doesnotexist")
|
||||
require.Equal(t, model.ErrMessageNotFound, err)
|
||||
}
|
||||
|
||||
func testMarkPublished(t *testing.T, s message.Store) {
|
||||
// Add a scheduled message (future time → unpublished)
|
||||
m := model.NewDefaultMessage("mytopic", "scheduled message")
|
||||
m.Time = time.Now().Add(time.Hour).Unix()
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
|
||||
// Verify it does not appear in non-scheduled queries
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 0, len(messages))
|
||||
|
||||
// Verify it does appear in scheduled queries
|
||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, true)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
|
||||
// Mark as published
|
||||
require.Nil(t, s.MarkPublished(m))
|
||||
|
||||
// Now it should appear in non-scheduled queries too
|
||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "scheduled message", messages[0].Message)
|
||||
}
|
||||
|
||||
func testExpireMessages(t *testing.T, s message.Store) {
|
||||
// Add messages to two topics
|
||||
m1 := model.NewDefaultMessage("topic1", "message 1")
|
||||
m1.Expires = time.Now().Add(time.Hour).Unix()
|
||||
m2 := model.NewDefaultMessage("topic1", "message 2")
|
||||
m2.Expires = time.Now().Add(time.Hour).Unix()
|
||||
m3 := model.NewDefaultMessage("topic2", "message 3")
|
||||
m3.Expires = time.Now().Add(time.Hour).Unix()
|
||||
require.Nil(t, s.AddMessage(m1))
|
||||
require.Nil(t, s.AddMessage(m2))
|
||||
require.Nil(t, s.AddMessage(m3))
|
||||
|
||||
// Verify all messages exist
|
||||
messages, err := s.Messages("topic1", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, len(messages))
|
||||
messages, err = s.Messages("topic2", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
|
||||
// Expire topic1 messages
|
||||
require.Nil(t, s.ExpireMessages("topic1"))
|
||||
|
||||
// topic1 messages should now be expired (expires set to past)
|
||||
expiredIDs, err := s.MessagesExpired()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, len(expiredIDs))
|
||||
sort.Strings(expiredIDs)
|
||||
expectedIDs := []string{m1.ID, m2.ID}
|
||||
sort.Strings(expectedIDs)
|
||||
require.Equal(t, expectedIDs, expiredIDs)
|
||||
|
||||
// topic2 should be unaffected
|
||||
messages, err = s.Messages("topic2", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "message 3", messages[0].Message)
|
||||
}
|
||||
|
||||
func testMarkAttachmentsDeleted(t *testing.T, s message.Store) {
|
||||
// Add a message with an expired attachment (file needs cleanup)
|
||||
m1 := model.NewDefaultMessage("mytopic", "old file")
|
||||
m1.ID = "msg1"
|
||||
m1.SequenceID = "msg1"
|
||||
m1.Expires = time.Now().Add(time.Hour).Unix()
|
||||
m1.Attachment = &model.Attachment{
|
||||
Name: "old.pdf",
|
||||
Type: "application/pdf",
|
||||
Size: 50000,
|
||||
Expires: time.Now().Add(-time.Hour).Unix(), // Expired
|
||||
URL: "https://ntfy.sh/file/old.pdf",
|
||||
}
|
||||
require.Nil(t, s.AddMessage(m1))
|
||||
|
||||
// Add a message with another expired attachment
|
||||
m2 := model.NewDefaultMessage("mytopic", "another old file")
|
||||
m2.ID = "msg2"
|
||||
m2.SequenceID = "msg2"
|
||||
m2.Expires = time.Now().Add(time.Hour).Unix()
|
||||
m2.Attachment = &model.Attachment{
|
||||
Name: "another.pdf",
|
||||
Type: "application/pdf",
|
||||
Size: 30000,
|
||||
Expires: time.Now().Add(-time.Hour).Unix(), // Expired
|
||||
URL: "https://ntfy.sh/file/another.pdf",
|
||||
}
|
||||
require.Nil(t, s.AddMessage(m2))
|
||||
|
||||
// Both should show as expired attachments needing cleanup
|
||||
ids, err := s.AttachmentsExpired()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, len(ids))
|
||||
|
||||
// Mark msg1's attachment as deleted (file cleaned up)
|
||||
require.Nil(t, s.MarkAttachmentsDeleted("msg1"))
|
||||
|
||||
// Now only msg2 should show as needing cleanup
|
||||
ids, err = s.AttachmentsExpired()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(ids))
|
||||
require.Equal(t, "msg2", ids[0])
|
||||
|
||||
// Mark msg2 too
|
||||
require.Nil(t, s.MarkAttachmentsDeleted("msg2"))
|
||||
|
||||
// No more expired attachments to clean up
|
||||
ids, err = s.AttachmentsExpired()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 0, len(ids))
|
||||
|
||||
// Messages themselves still exist
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, len(messages))
|
||||
}
|
||||
|
||||
func testStats(t *testing.T, s message.Store) {
|
||||
// Initial stats should be zero
|
||||
messages, err := s.Stats()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(0), messages)
|
||||
|
||||
// Update stats
|
||||
require.Nil(t, s.UpdateStats(42))
|
||||
messages, err = s.Stats()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(42), messages)
|
||||
|
||||
// Update again (overwrites)
|
||||
require.Nil(t, s.UpdateStats(100))
|
||||
messages, err = s.Stats()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(100), messages)
|
||||
}
|
||||
|
||||
func testAddMessages(t *testing.T, s message.Store) {
|
||||
// Batch add multiple messages
|
||||
msgs := []*model.Message{
|
||||
model.NewDefaultMessage("mytopic", "batch 1"),
|
||||
model.NewDefaultMessage("mytopic", "batch 2"),
|
||||
model.NewDefaultMessage("othertopic", "batch 3"),
|
||||
}
|
||||
require.Nil(t, s.AddMessages(msgs))
|
||||
|
||||
// Verify all were inserted
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, len(messages))
|
||||
|
||||
messages, err = s.Messages("othertopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "batch 3", messages[0].Message)
|
||||
|
||||
// Empty batch should succeed
|
||||
require.Nil(t, s.AddMessages([]*model.Message{}))
|
||||
|
||||
// Batch with invalid event type should fail
|
||||
badMsgs := []*model.Message{
|
||||
model.NewKeepaliveMessage("mytopic"),
|
||||
}
|
||||
require.NotNil(t, s.AddMessages(badMsgs))
|
||||
}
|
||||
|
||||
func testMessagesDue(t *testing.T, s message.Store) {
|
||||
// Add a message scheduled in the past (i.e. it's due now)
|
||||
m1 := model.NewDefaultMessage("mytopic", "due message")
|
||||
m1.Time = time.Now().Add(-time.Second).Unix()
|
||||
// Set expires in the future so it doesn't get pruned
|
||||
m1.Expires = time.Now().Add(time.Hour).Unix()
|
||||
require.Nil(t, s.AddMessage(m1))
|
||||
|
||||
// Add a message scheduled in the future (not due)
|
||||
m2 := model.NewDefaultMessage("mytopic", "future message")
|
||||
m2.Time = time.Now().Add(time.Hour).Unix()
|
||||
require.Nil(t, s.AddMessage(m2))
|
||||
|
||||
// Mark m1 as published so it won't be "due"
|
||||
// (MessagesDue returns unpublished messages whose time <= now)
|
||||
// m1 is auto-published (time <= now), so it should not be due
|
||||
// m2 is unpublished (time in future), not due yet
|
||||
due, err := s.MessagesDue()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 0, len(due))
|
||||
|
||||
// Add a message that was explicitly scheduled in the past but time has "arrived"
|
||||
// We need to manipulate the database to create a truly "due" message:
|
||||
// a message with published=false and time <= now
|
||||
m3 := model.NewDefaultMessage("mytopic", "truly due message")
|
||||
m3.Time = time.Now().Add(2 * time.Second).Unix() // 2 seconds from now
|
||||
require.Nil(t, s.AddMessage(m3))
|
||||
|
||||
// Not due yet
|
||||
due, err = s.MessagesDue()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 0, len(due))
|
||||
|
||||
// Wait for it to become due
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
due, err = s.MessagesDue()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(due))
|
||||
require.Equal(t, "truly due message", due[0].Message)
|
||||
}
|
||||
|
||||
func testMessageFieldRoundTrip(t *testing.T, s message.Store) {
|
||||
// Create a message with all fields populated
|
||||
m := model.NewDefaultMessage("mytopic", "hello world")
|
||||
m.SequenceID = "custom_seq_id"
|
||||
m.Title = "A Title"
|
||||
m.Priority = 4
|
||||
m.Tags = []string{"warning", "srv01"}
|
||||
m.Click = "https://example.com/click"
|
||||
m.Icon = "https://example.com/icon.png"
|
||||
m.Actions = []*model.Action{
|
||||
{
|
||||
ID: "action1",
|
||||
Action: "view",
|
||||
Label: "Open Site",
|
||||
URL: "https://example.com",
|
||||
Clear: true,
|
||||
},
|
||||
{
|
||||
ID: "action2",
|
||||
Action: "http",
|
||||
Label: "Call Webhook",
|
||||
URL: "https://example.com/hook",
|
||||
Method: "PUT",
|
||||
Headers: map[string]string{"X-Token": "secret"},
|
||||
Body: `{"key":"value"}`,
|
||||
},
|
||||
}
|
||||
m.ContentType = "text/markdown"
|
||||
m.Encoding = "base64"
|
||||
m.Sender = netip.MustParseAddr("9.8.7.6")
|
||||
m.User = "u_TestUser123"
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
|
||||
// Retrieve and verify every field
|
||||
retrieved, err := s.Message(m.ID)
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Equal(t, m.ID, retrieved.ID)
|
||||
require.Equal(t, "custom_seq_id", retrieved.SequenceID)
|
||||
require.Equal(t, m.Time, retrieved.Time)
|
||||
require.Equal(t, m.Expires, retrieved.Expires)
|
||||
require.Equal(t, model.MessageEvent, retrieved.Event)
|
||||
require.Equal(t, "mytopic", retrieved.Topic)
|
||||
require.Equal(t, "hello world", retrieved.Message)
|
||||
require.Equal(t, "A Title", retrieved.Title)
|
||||
require.Equal(t, 4, retrieved.Priority)
|
||||
require.Equal(t, []string{"warning", "srv01"}, retrieved.Tags)
|
||||
require.Equal(t, "https://example.com/click", retrieved.Click)
|
||||
require.Equal(t, "https://example.com/icon.png", retrieved.Icon)
|
||||
require.Equal(t, "text/markdown", retrieved.ContentType)
|
||||
require.Equal(t, "base64", retrieved.Encoding)
|
||||
require.Equal(t, netip.MustParseAddr("9.8.7.6"), retrieved.Sender)
|
||||
require.Equal(t, "u_TestUser123", retrieved.User)
|
||||
|
||||
// Verify actions round-trip
|
||||
require.Equal(t, 2, len(retrieved.Actions))
|
||||
|
||||
require.Equal(t, "action1", retrieved.Actions[0].ID)
|
||||
require.Equal(t, "view", retrieved.Actions[0].Action)
|
||||
require.Equal(t, "Open Site", retrieved.Actions[0].Label)
|
||||
require.Equal(t, "https://example.com", retrieved.Actions[0].URL)
|
||||
require.Equal(t, true, retrieved.Actions[0].Clear)
|
||||
|
||||
require.Equal(t, "action2", retrieved.Actions[1].ID)
|
||||
require.Equal(t, "http", retrieved.Actions[1].Action)
|
||||
require.Equal(t, "Call Webhook", retrieved.Actions[1].Label)
|
||||
require.Equal(t, "https://example.com/hook", retrieved.Actions[1].URL)
|
||||
require.Equal(t, "PUT", retrieved.Actions[1].Method)
|
||||
require.Equal(t, "secret", retrieved.Actions[1].Headers["X-Token"])
|
||||
require.Equal(t, `{"key":"value"}`, retrieved.Actions[1].Body)
|
||||
}
|
||||
205
model/model.go
Normal file
205
model/model.go
Normal file
@@ -0,0 +1,205 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
// List of possible events
|
||||
const (
|
||||
OpenEvent = "open"
|
||||
KeepaliveEvent = "keepalive"
|
||||
MessageEvent = "message"
|
||||
MessageDeleteEvent = "message_delete"
|
||||
MessageClearEvent = "message_clear"
|
||||
PollRequestEvent = "poll_request"
|
||||
)
|
||||
|
||||
// MessageIDLength is the length of a randomly generated message ID
|
||||
const MessageIDLength = 12
|
||||
|
||||
// Errors for message operations
|
||||
var (
|
||||
ErrUnexpectedMessageType = errors.New("unexpected message type")
|
||||
ErrMessageNotFound = errors.New("message not found")
|
||||
)
|
||||
|
||||
// Message represents a message published to a topic
|
||||
type Message struct {
|
||||
ID string `json:"id"` // Random message ID
|
||||
SequenceID string `json:"sequence_id,omitempty"` // Message sequence ID for updating message contents (omitted if same as ID)
|
||||
Time int64 `json:"time"` // Unix time in seconds
|
||||
Expires int64 `json:"expires,omitempty"` // Unix time in seconds (not required for open/keepalive)
|
||||
Event string `json:"event"` // One of the above
|
||||
Topic string `json:"topic"`
|
||||
Title string `json:"title,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Priority int `json:"priority,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
Click string `json:"click,omitempty"`
|
||||
Icon string `json:"icon,omitempty"`
|
||||
Actions []*Action `json:"actions,omitempty"`
|
||||
Attachment *Attachment `json:"attachment,omitempty"`
|
||||
PollID string `json:"poll_id,omitempty"`
|
||||
ContentType string `json:"content_type,omitempty"` // text/plain by default (if empty), or text/markdown
|
||||
Encoding string `json:"encoding,omitempty"` // Empty for raw UTF-8, or "base64" for encoded bytes
|
||||
Sender netip.Addr `json:"-"` // IP address of uploader, used for rate limiting
|
||||
User string `json:"-"` // UserID of the uploader, used to associated attachments
|
||||
}
|
||||
|
||||
// Context returns a log context for the message
|
||||
func (m *Message) Context() log.Context {
|
||||
fields := map[string]any{
|
||||
"topic": m.Topic,
|
||||
"message_id": m.ID,
|
||||
"message_sequence_id": m.SequenceID,
|
||||
"message_time": m.Time,
|
||||
"message_event": m.Event,
|
||||
"message_body_size": len(m.Message),
|
||||
}
|
||||
if m.Sender.IsValid() {
|
||||
fields["message_sender"] = m.Sender.String()
|
||||
}
|
||||
if m.User != "" {
|
||||
fields["message_user"] = m.User
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
// ForJSON returns a copy of the message suitable for JSON output.
|
||||
// It clears the SequenceID if it equals the ID to reduce redundancy.
|
||||
func (m *Message) ForJSON() *Message {
|
||||
if m.SequenceID == m.ID {
|
||||
clone := *m
|
||||
clone.SequenceID = ""
|
||||
return &clone
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// Attachment represents a file attachment on a message
|
||||
type Attachment struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Size int64 `json:"size,omitempty"`
|
||||
Expires int64 `json:"expires,omitempty"`
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
// Action represents a user-defined action on a message
|
||||
type Action struct {
|
||||
ID string `json:"id"`
|
||||
Action string `json:"action"` // "view", "broadcast", "http", or "copy"
|
||||
Label string `json:"label"` // action button label
|
||||
Clear bool `json:"clear"` // clear notification after successful execution
|
||||
URL string `json:"url,omitempty"` // used in "view" and "http" actions
|
||||
Method string `json:"method,omitempty"` // used in "http" action, default is POST (!)
|
||||
Headers map[string]string `json:"headers,omitempty"` // used in "http" action
|
||||
Body string `json:"body,omitempty"` // used in "http" action
|
||||
Intent string `json:"intent,omitempty"` // used in "broadcast" action
|
||||
Extras map[string]string `json:"extras,omitempty"` // used in "broadcast" action
|
||||
Value string `json:"value,omitempty"` // used in "copy" action
|
||||
}
|
||||
|
||||
// NewAction creates a new action with initialized maps
|
||||
func NewAction() *Action {
|
||||
return &Action{
|
||||
Headers: make(map[string]string),
|
||||
Extras: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// NewMessage creates a new message with the current timestamp
|
||||
func NewMessage(event, topic, msg string) *Message {
|
||||
return &Message{
|
||||
ID: util.RandomString(MessageIDLength),
|
||||
Time: time.Now().Unix(),
|
||||
Event: event,
|
||||
Topic: topic,
|
||||
Message: msg,
|
||||
}
|
||||
}
|
||||
|
||||
// NewOpenMessage is a convenience method to create an open message
|
||||
func NewOpenMessage(topic string) *Message {
|
||||
return NewMessage(OpenEvent, topic, "")
|
||||
}
|
||||
|
||||
// NewKeepaliveMessage is a convenience method to create a keepalive message
|
||||
func NewKeepaliveMessage(topic string) *Message {
|
||||
return NewMessage(KeepaliveEvent, topic, "")
|
||||
}
|
||||
|
||||
// NewDefaultMessage is a convenience method to create a notification message
|
||||
func NewDefaultMessage(topic, msg string) *Message {
|
||||
return NewMessage(MessageEvent, topic, msg)
|
||||
}
|
||||
|
||||
// NewActionMessage creates a new action message (message_delete or message_clear)
|
||||
func NewActionMessage(event, topic, sequenceID string) *Message {
|
||||
m := NewMessage(event, topic, "")
|
||||
m.SequenceID = sequenceID
|
||||
return m
|
||||
}
|
||||
|
||||
// ValidMessageID returns true if the given string is a valid message ID
|
||||
func ValidMessageID(s string) bool {
|
||||
return util.ValidRandomString(s, MessageIDLength)
|
||||
}
|
||||
|
||||
// SinceMarker represents a point in time or message ID from which to retrieve messages
|
||||
type SinceMarker struct {
|
||||
time time.Time
|
||||
id string
|
||||
}
|
||||
|
||||
// NewSinceTime creates a new SinceMarker from a Unix timestamp
|
||||
func NewSinceTime(timestamp int64) SinceMarker {
|
||||
return SinceMarker{time.Unix(timestamp, 0), ""}
|
||||
}
|
||||
|
||||
// NewSinceID creates a new SinceMarker from a message ID
|
||||
func NewSinceID(id string) SinceMarker {
|
||||
return SinceMarker{time.Unix(0, 0), id}
|
||||
}
|
||||
|
||||
// IsAll returns true if this is the "all messages" marker
|
||||
func (t SinceMarker) IsAll() bool {
|
||||
return t == SinceAllMessages
|
||||
}
|
||||
|
||||
// IsNone returns true if this is the "no messages" marker
|
||||
func (t SinceMarker) IsNone() bool {
|
||||
return t == SinceNoMessages
|
||||
}
|
||||
|
||||
// IsLatest returns true if this is the "latest message" marker
|
||||
func (t SinceMarker) IsLatest() bool {
|
||||
return t == SinceLatestMessage
|
||||
}
|
||||
|
||||
// IsID returns true if this marker references a specific message ID
|
||||
func (t SinceMarker) IsID() bool {
|
||||
return t.id != "" && t.id != "latest"
|
||||
}
|
||||
|
||||
// Time returns the time component of the marker
|
||||
func (t SinceMarker) Time() time.Time {
|
||||
return t.time
|
||||
}
|
||||
|
||||
// ID returns the message ID component of the marker
|
||||
func (t SinceMarker) ID() string {
|
||||
return t.id
|
||||
}
|
||||
|
||||
// Common SinceMarker values for subscribing to messages
|
||||
var (
|
||||
SinceAllMessages = SinceMarker{time.Unix(0, 0), ""}
|
||||
SinceNoMessages = SinceMarker{time.Unix(1, 0), ""}
|
||||
SinceLatestMessage = SinceMarker{time.Unix(0, 0), "latest"}
|
||||
)
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"heckel.io/ntfy/v2/model"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
@@ -39,7 +40,7 @@ type actionParser struct {
|
||||
// parseActions parses the actions string as described in https://ntfy.sh/docs/publish/#action-buttons.
|
||||
// It supports both a JSON representation (if the string begins with "[", see parseActionsFromJSON),
|
||||
// and the "simple" format, which is more human-readable, but harder to parse (see parseActionsFromSimple).
|
||||
func parseActions(s string) (actions []*action, err error) {
|
||||
func parseActions(s string) (actions []*model.Action, err error) {
|
||||
// Parse JSON or simple format
|
||||
s = strings.TrimSpace(s)
|
||||
if strings.HasPrefix(s, "[") {
|
||||
@@ -80,8 +81,8 @@ func parseActions(s string) (actions []*action, err error) {
|
||||
}
|
||||
|
||||
// parseActionsFromJSON converts a JSON array into an array of actions
|
||||
func parseActionsFromJSON(s string) ([]*action, error) {
|
||||
actions := make([]*action, 0)
|
||||
func parseActionsFromJSON(s string) ([]*model.Action, error) {
|
||||
actions := make([]*model.Action, 0)
|
||||
if err := json.Unmarshal([]byte(s), &actions); err != nil {
|
||||
return nil, fmt.Errorf("JSON error: %w", err)
|
||||
}
|
||||
@@ -107,7 +108,7 @@ func parseActionsFromJSON(s string) ([]*action, error) {
|
||||
// https://github.com/adampresley/sample-ini-parser/blob/master/services/lexer/lexer/Lexer.go
|
||||
// https://github.com/benbjohnson/sql-parser/blob/master/scanner.go
|
||||
// https://blog.gopheracademy.com/advent-2014/parsers-lexers/
|
||||
func parseActionsFromSimple(s string) ([]*action, error) {
|
||||
func parseActionsFromSimple(s string) ([]*model.Action, error) {
|
||||
if !utf8.ValidString(s) {
|
||||
return nil, errors.New("invalid utf-8 string")
|
||||
}
|
||||
@@ -119,8 +120,8 @@ func parseActionsFromSimple(s string) ([]*action, error) {
|
||||
}
|
||||
|
||||
// Parse loops trough parseAction() until the end of the string is reached
|
||||
func (p *actionParser) Parse() ([]*action, error) {
|
||||
actions := make([]*action, 0)
|
||||
func (p *actionParser) Parse() ([]*model.Action, error) {
|
||||
actions := make([]*model.Action, 0)
|
||||
for !p.eof() {
|
||||
a, err := p.parseAction()
|
||||
if err != nil {
|
||||
@@ -134,7 +135,7 @@ func (p *actionParser) Parse() ([]*action, error) {
|
||||
// parseAction parses the individual sections of an action using parseSection into key/value pairs,
|
||||
// and then uses populateAction to interpret the keys/values. The function terminates
|
||||
// when EOF or ";" is reached.
|
||||
func (p *actionParser) parseAction() (*action, error) {
|
||||
func (p *actionParser) parseAction() (*model.Action, error) {
|
||||
a := newAction()
|
||||
section := 0
|
||||
for {
|
||||
@@ -155,7 +156,7 @@ func (p *actionParser) parseAction() (*action, error) {
|
||||
|
||||
// populateAction is the "business logic" of the parser. It applies the key/value
|
||||
// pair to the action instance.
|
||||
func populateAction(newAction *action, section int, key, value string) error {
|
||||
func populateAction(newAction *model.Action, section int, key, value string) error {
|
||||
// Auto-expand keys based on their index
|
||||
if key == "" && section == 0 {
|
||||
key = "action"
|
||||
|
||||
@@ -88,6 +88,7 @@ var (
|
||||
// Config is the main config struct for the application. Use New to instantiate a default config struct.
|
||||
type Config struct {
|
||||
File string // Config file, only used for testing
|
||||
DatabaseURL string // PostgreSQL connection string (e.g. "postgres://user:pass@host:5432/ntfy")
|
||||
BaseURL string
|
||||
ListenHTTP string
|
||||
ListenHTTPS string
|
||||
@@ -192,6 +193,7 @@ type Config struct {
|
||||
func NewConfig() *Config {
|
||||
return &Config{
|
||||
File: DefaultConfigFile, // Only used for testing
|
||||
DatabaseURL: "",
|
||||
BaseURL: "",
|
||||
ListenHTTP: DefaultListenHTTP,
|
||||
ListenHTTPS: "",
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/emersion/go-smtp"
|
||||
"github.com/gorilla/websocket"
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
@@ -55,12 +56,12 @@ func logvr(v *visitor, r *http.Request) *log.Event {
|
||||
}
|
||||
|
||||
// logvrm creates a new log event with HTTP request, visitor fields and message fields
|
||||
func logvrm(v *visitor, r *http.Request, m *message) *log.Event {
|
||||
func logvrm(v *visitor, r *http.Request, m *model.Message) *log.Event {
|
||||
return logvr(v, r).With(m)
|
||||
}
|
||||
|
||||
// logvrm creates a new log event with visitor fields and message fields
|
||||
func logvm(v *visitor, m *message) *log.Event {
|
||||
func logvm(v *visitor, m *model.Message) *log.Event {
|
||||
return logv(v).With(m)
|
||||
}
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,825 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/netip"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSqliteCache_Messages(t *testing.T) {
|
||||
testCacheMessages(t, newSqliteTestCache(t))
|
||||
}
|
||||
|
||||
func TestMemCache_Messages(t *testing.T) {
|
||||
testCacheMessages(t, newMemTestCache(t))
|
||||
}
|
||||
|
||||
func testCacheMessages(t *testing.T, c *messageCache) {
|
||||
m1 := newDefaultMessage("mytopic", "my message")
|
||||
m1.Time = 1
|
||||
|
||||
m2 := newDefaultMessage("mytopic", "my other message")
|
||||
m2.Time = 2
|
||||
|
||||
require.Nil(t, c.AddMessage(m1))
|
||||
require.Nil(t, c.AddMessage(newDefaultMessage("example", "my example message")))
|
||||
require.Nil(t, c.AddMessage(m2))
|
||||
|
||||
// Adding invalid
|
||||
require.Equal(t, errUnexpectedMessageType, c.AddMessage(newKeepaliveMessage("mytopic"))) // These should not be added!
|
||||
require.Equal(t, errUnexpectedMessageType, c.AddMessage(newOpenMessage("example"))) // These should not be added!
|
||||
|
||||
// mytopic: count
|
||||
counts, err := c.MessageCounts()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, counts["mytopic"])
|
||||
|
||||
// mytopic: since all
|
||||
messages, _ := c.Messages("mytopic", sinceAllMessages, false)
|
||||
require.Equal(t, 2, len(messages))
|
||||
require.Equal(t, "my message", messages[0].Message)
|
||||
require.Equal(t, "mytopic", messages[0].Topic)
|
||||
require.Equal(t, messageEvent, messages[0].Event)
|
||||
require.Equal(t, "", messages[0].Title)
|
||||
require.Equal(t, 0, messages[0].Priority)
|
||||
require.Nil(t, messages[0].Tags)
|
||||
require.Equal(t, "my other message", messages[1].Message)
|
||||
|
||||
// mytopic: since none
|
||||
messages, _ = c.Messages("mytopic", sinceNoMessages, false)
|
||||
require.Empty(t, messages)
|
||||
|
||||
// mytopic: since m1 (by ID)
|
||||
messages, _ = c.Messages("mytopic", newSinceID(m1.ID), false)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, m2.ID, messages[0].ID)
|
||||
require.Equal(t, "my other message", messages[0].Message)
|
||||
require.Equal(t, "mytopic", messages[0].Topic)
|
||||
|
||||
// mytopic: since 2
|
||||
messages, _ = c.Messages("mytopic", newSinceTime(2), false)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "my other message", messages[0].Message)
|
||||
|
||||
// mytopic: latest
|
||||
messages, _ = c.Messages("mytopic", sinceLatestMessage, false)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "my other message", messages[0].Message)
|
||||
|
||||
// example: count
|
||||
counts, err = c.MessageCounts()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, counts["example"])
|
||||
|
||||
// example: since all
|
||||
messages, _ = c.Messages("example", sinceAllMessages, false)
|
||||
require.Equal(t, "my example message", messages[0].Message)
|
||||
|
||||
// non-existing: count
|
||||
counts, err = c.MessageCounts()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 0, counts["doesnotexist"])
|
||||
|
||||
// non-existing: since all
|
||||
messages, _ = c.Messages("doesnotexist", sinceAllMessages, false)
|
||||
require.Empty(t, messages)
|
||||
}
|
||||
|
||||
func TestSqliteCache_MessagesLock(t *testing.T) {
|
||||
testCacheMessagesLock(t, newSqliteTestCache(t))
|
||||
}
|
||||
|
||||
func TestMemCache_MessagesLock(t *testing.T) {
|
||||
testCacheMessagesLock(t, newMemTestCache(t))
|
||||
}
|
||||
|
||||
func testCacheMessagesLock(t *testing.T, c *messageCache) {
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 5000; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
assert.Nil(t, c.AddMessage(newDefaultMessage("mytopic", "test message")))
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestSqliteCache_MessagesScheduled(t *testing.T) {
|
||||
testCacheMessagesScheduled(t, newSqliteTestCache(t))
|
||||
}
|
||||
|
||||
func TestMemCache_MessagesScheduled(t *testing.T) {
|
||||
testCacheMessagesScheduled(t, newMemTestCache(t))
|
||||
}
|
||||
|
||||
func testCacheMessagesScheduled(t *testing.T, c *messageCache) {
|
||||
m1 := newDefaultMessage("mytopic", "message 1")
|
||||
m2 := newDefaultMessage("mytopic", "message 2")
|
||||
m2.Time = time.Now().Add(time.Hour).Unix()
|
||||
m3 := newDefaultMessage("mytopic", "message 3")
|
||||
m3.Time = time.Now().Add(time.Minute).Unix() // earlier than m2!
|
||||
m4 := newDefaultMessage("mytopic2", "message 4")
|
||||
m4.Time = time.Now().Add(time.Minute).Unix()
|
||||
require.Nil(t, c.AddMessage(m1))
|
||||
require.Nil(t, c.AddMessage(m2))
|
||||
require.Nil(t, c.AddMessage(m3))
|
||||
|
||||
messages, _ := c.Messages("mytopic", sinceAllMessages, false) // exclude scheduled
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "message 1", messages[0].Message)
|
||||
|
||||
messages, _ = c.Messages("mytopic", sinceAllMessages, true) // include scheduled
|
||||
require.Equal(t, 3, len(messages))
|
||||
require.Equal(t, "message 1", messages[0].Message)
|
||||
require.Equal(t, "message 3", messages[1].Message) // Order!
|
||||
require.Equal(t, "message 2", messages[2].Message)
|
||||
|
||||
messages, _ = c.MessagesDue()
|
||||
require.Empty(t, messages)
|
||||
}
|
||||
|
||||
func TestSqliteCache_Topics(t *testing.T) {
|
||||
testCacheTopics(t, newSqliteTestCache(t))
|
||||
}
|
||||
|
||||
func TestMemCache_Topics(t *testing.T) {
|
||||
testCacheTopics(t, newMemTestCache(t))
|
||||
}
|
||||
|
||||
func testCacheTopics(t *testing.T, c *messageCache) {
|
||||
require.Nil(t, c.AddMessage(newDefaultMessage("topic1", "my example message")))
|
||||
require.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 1")))
|
||||
require.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 2")))
|
||||
require.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 3")))
|
||||
|
||||
topics, err := c.Topics()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.Equal(t, 2, len(topics))
|
||||
require.Equal(t, "topic1", topics["topic1"].ID)
|
||||
require.Equal(t, "topic2", topics["topic2"].ID)
|
||||
}
|
||||
|
||||
func TestSqliteCache_MessagesTagsPrioAndTitle(t *testing.T) {
|
||||
testCacheMessagesTagsPrioAndTitle(t, newSqliteTestCache(t))
|
||||
}
|
||||
|
||||
func TestMemCache_MessagesTagsPrioAndTitle(t *testing.T) {
|
||||
testCacheMessagesTagsPrioAndTitle(t, newMemTestCache(t))
|
||||
}
|
||||
|
||||
func testCacheMessagesTagsPrioAndTitle(t *testing.T, c *messageCache) {
|
||||
m := newDefaultMessage("mytopic", "some message")
|
||||
m.Tags = []string{"tag1", "tag2"}
|
||||
m.Priority = 5
|
||||
m.Title = "some title"
|
||||
require.Nil(t, c.AddMessage(m))
|
||||
|
||||
messages, _ := c.Messages("mytopic", sinceAllMessages, false)
|
||||
require.Equal(t, []string{"tag1", "tag2"}, messages[0].Tags)
|
||||
require.Equal(t, 5, messages[0].Priority)
|
||||
require.Equal(t, "some title", messages[0].Title)
|
||||
}
|
||||
|
||||
func TestSqliteCache_MessagesSinceID(t *testing.T) {
|
||||
testCacheMessagesSinceID(t, newSqliteTestCache(t))
|
||||
}
|
||||
|
||||
func TestMemCache_MessagesSinceID(t *testing.T) {
|
||||
testCacheMessagesSinceID(t, newMemTestCache(t))
|
||||
}
|
||||
|
||||
func testCacheMessagesSinceID(t *testing.T, c *messageCache) {
|
||||
m1 := newDefaultMessage("mytopic", "message 1")
|
||||
m1.Time = 100
|
||||
m2 := newDefaultMessage("mytopic", "message 2")
|
||||
m2.Time = 200
|
||||
m3 := newDefaultMessage("mytopic", "message 3")
|
||||
m3.Time = time.Now().Add(time.Hour).Unix() // Scheduled, in the future, later than m7 and m5
|
||||
m4 := newDefaultMessage("mytopic", "message 4")
|
||||
m4.Time = 400
|
||||
m5 := newDefaultMessage("mytopic", "message 5")
|
||||
m5.Time = time.Now().Add(time.Minute).Unix() // Scheduled, in the future, later than m7
|
||||
m6 := newDefaultMessage("mytopic", "message 6")
|
||||
m6.Time = 600
|
||||
m7 := newDefaultMessage("mytopic", "message 7")
|
||||
m7.Time = 700
|
||||
|
||||
require.Nil(t, c.AddMessage(m1))
|
||||
require.Nil(t, c.AddMessage(m2))
|
||||
require.Nil(t, c.AddMessage(m3))
|
||||
require.Nil(t, c.AddMessage(m4))
|
||||
require.Nil(t, c.AddMessage(m5))
|
||||
require.Nil(t, c.AddMessage(m6))
|
||||
require.Nil(t, c.AddMessage(m7))
|
||||
|
||||
// Case 1: Since ID exists, exclude scheduled
|
||||
messages, _ := c.Messages("mytopic", newSinceID(m2.ID), false)
|
||||
require.Equal(t, 3, len(messages))
|
||||
require.Equal(t, "message 4", messages[0].Message)
|
||||
require.Equal(t, "message 6", messages[1].Message) // Not scheduled m3/m5!
|
||||
require.Equal(t, "message 7", messages[2].Message)
|
||||
|
||||
// Case 2: Since ID exists, include scheduled
|
||||
messages, _ = c.Messages("mytopic", newSinceID(m2.ID), true)
|
||||
require.Equal(t, 5, len(messages))
|
||||
require.Equal(t, "message 4", messages[0].Message)
|
||||
require.Equal(t, "message 6", messages[1].Message)
|
||||
require.Equal(t, "message 7", messages[2].Message)
|
||||
require.Equal(t, "message 5", messages[3].Message) // Order!
|
||||
require.Equal(t, "message 3", messages[4].Message) // Order!
|
||||
|
||||
// Case 3: Since ID does not exist (-> Return all messages), include scheduled
|
||||
messages, _ = c.Messages("mytopic", newSinceID("doesntexist"), true)
|
||||
require.Equal(t, 7, len(messages))
|
||||
require.Equal(t, "message 1", messages[0].Message)
|
||||
require.Equal(t, "message 2", messages[1].Message)
|
||||
require.Equal(t, "message 4", messages[2].Message)
|
||||
require.Equal(t, "message 6", messages[3].Message)
|
||||
require.Equal(t, "message 7", messages[4].Message)
|
||||
require.Equal(t, "message 5", messages[5].Message) // Order!
|
||||
require.Equal(t, "message 3", messages[6].Message) // Order!
|
||||
|
||||
// Case 4: Since ID exists and is last message (-> Return no messages), exclude scheduled
|
||||
messages, _ = c.Messages("mytopic", newSinceID(m7.ID), false)
|
||||
require.Equal(t, 0, len(messages))
|
||||
|
||||
// Case 5: Since ID exists and is last message (-> Return no messages), include scheduled
|
||||
messages, _ = c.Messages("mytopic", newSinceID(m7.ID), true)
|
||||
require.Equal(t, 2, len(messages))
|
||||
require.Equal(t, "message 5", messages[0].Message)
|
||||
require.Equal(t, "message 3", messages[1].Message)
|
||||
}
|
||||
|
||||
func TestSqliteCache_Prune(t *testing.T) {
|
||||
testCachePrune(t, newSqliteTestCache(t))
|
||||
}
|
||||
|
||||
func TestMemCache_Prune(t *testing.T) {
|
||||
testCachePrune(t, newMemTestCache(t))
|
||||
}
|
||||
|
||||
func testCachePrune(t *testing.T, c *messageCache) {
|
||||
now := time.Now().Unix()
|
||||
|
||||
m1 := newDefaultMessage("mytopic", "my message")
|
||||
m1.Time = now - 10
|
||||
m1.Expires = now - 5
|
||||
|
||||
m2 := newDefaultMessage("mytopic", "my other message")
|
||||
m2.Time = now - 5
|
||||
m2.Expires = now + 5 // In the future
|
||||
|
||||
m3 := newDefaultMessage("another_topic", "and another one")
|
||||
m3.Time = now - 12
|
||||
m3.Expires = now - 2
|
||||
|
||||
require.Nil(t, c.AddMessage(m1))
|
||||
require.Nil(t, c.AddMessage(m2))
|
||||
require.Nil(t, c.AddMessage(m3))
|
||||
|
||||
counts, err := c.MessageCounts()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, counts["mytopic"])
|
||||
require.Equal(t, 1, counts["another_topic"])
|
||||
|
||||
expiredMessageIDs, err := c.MessagesExpired()
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, c.DeleteMessages(expiredMessageIDs...))
|
||||
|
||||
counts, err = c.MessageCounts()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, counts["mytopic"])
|
||||
require.Equal(t, 0, counts["another_topic"])
|
||||
|
||||
messages, err := c.Messages("mytopic", sinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "my other message", messages[0].Message)
|
||||
}
|
||||
|
||||
func TestSqliteCache_Attachments(t *testing.T) {
|
||||
testCacheAttachments(t, newSqliteTestCache(t))
|
||||
}
|
||||
|
||||
func TestMemCache_Attachments(t *testing.T) {
|
||||
testCacheAttachments(t, newMemTestCache(t))
|
||||
}
|
||||
|
||||
func testCacheAttachments(t *testing.T, c *messageCache) {
|
||||
expires1 := time.Now().Add(-4 * time.Hour).Unix() // Expired
|
||||
m := newDefaultMessage("mytopic", "flower for you")
|
||||
m.ID = "m1"
|
||||
m.SequenceID = "m1"
|
||||
m.Sender = netip.MustParseAddr("1.2.3.4")
|
||||
m.Attachment = &attachment{
|
||||
Name: "flower.jpg",
|
||||
Type: "image/jpeg",
|
||||
Size: 5000,
|
||||
Expires: expires1,
|
||||
URL: "https://ntfy.sh/file/AbDeFgJhal.jpg",
|
||||
}
|
||||
require.Nil(t, c.AddMessage(m))
|
||||
|
||||
expires2 := time.Now().Add(2 * time.Hour).Unix() // Future
|
||||
m = newDefaultMessage("mytopic", "sending you a car")
|
||||
m.ID = "m2"
|
||||
m.SequenceID = "m2"
|
||||
m.Sender = netip.MustParseAddr("1.2.3.4")
|
||||
m.Attachment = &attachment{
|
||||
Name: "car.jpg",
|
||||
Type: "image/jpeg",
|
||||
Size: 10000,
|
||||
Expires: expires2,
|
||||
URL: "https://ntfy.sh/file/aCaRURL.jpg",
|
||||
}
|
||||
require.Nil(t, c.AddMessage(m))
|
||||
|
||||
expires3 := time.Now().Add(1 * time.Hour).Unix() // Future
|
||||
m = newDefaultMessage("another-topic", "sending you another car")
|
||||
m.ID = "m3"
|
||||
m.SequenceID = "m3"
|
||||
m.User = "u_BAsbaAa"
|
||||
m.Sender = netip.MustParseAddr("5.6.7.8")
|
||||
m.Attachment = &attachment{
|
||||
Name: "another-car.jpg",
|
||||
Type: "image/jpeg",
|
||||
Size: 20000,
|
||||
Expires: expires3,
|
||||
URL: "https://ntfy.sh/file/zakaDHFW.jpg",
|
||||
}
|
||||
require.Nil(t, c.AddMessage(m))
|
||||
|
||||
messages, err := c.Messages("mytopic", sinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, len(messages))
|
||||
|
||||
require.Equal(t, "flower for you", messages[0].Message)
|
||||
require.Equal(t, "flower.jpg", messages[0].Attachment.Name)
|
||||
require.Equal(t, "image/jpeg", messages[0].Attachment.Type)
|
||||
require.Equal(t, int64(5000), messages[0].Attachment.Size)
|
||||
require.Equal(t, expires1, messages[0].Attachment.Expires)
|
||||
require.Equal(t, "https://ntfy.sh/file/AbDeFgJhal.jpg", messages[0].Attachment.URL)
|
||||
require.Equal(t, "1.2.3.4", messages[0].Sender.String())
|
||||
|
||||
require.Equal(t, "sending you a car", messages[1].Message)
|
||||
require.Equal(t, "car.jpg", messages[1].Attachment.Name)
|
||||
require.Equal(t, "image/jpeg", messages[1].Attachment.Type)
|
||||
require.Equal(t, int64(10000), messages[1].Attachment.Size)
|
||||
require.Equal(t, expires2, messages[1].Attachment.Expires)
|
||||
require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL)
|
||||
require.Equal(t, "1.2.3.4", messages[1].Sender.String())
|
||||
|
||||
size, err := c.AttachmentBytesUsedBySender("1.2.3.4")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(10000), size)
|
||||
|
||||
size, err = c.AttachmentBytesUsedBySender("5.6.7.8")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(0), size) // Accounted to the user, not the IP!
|
||||
|
||||
size, err = c.AttachmentBytesUsedByUser("u_BAsbaAa")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(20000), size)
|
||||
}
|
||||
|
||||
func TestSqliteCache_Attachments_Expired(t *testing.T) {
|
||||
testCacheAttachmentsExpired(t, newSqliteTestCache(t))
|
||||
}
|
||||
|
||||
func TestMemCache_Attachments_Expired(t *testing.T) {
|
||||
testCacheAttachmentsExpired(t, newMemTestCache(t))
|
||||
}
|
||||
|
||||
func testCacheAttachmentsExpired(t *testing.T, c *messageCache) {
|
||||
m := newDefaultMessage("mytopic", "flower for you")
|
||||
m.ID = "m1"
|
||||
m.SequenceID = "m1"
|
||||
m.Expires = time.Now().Add(time.Hour).Unix()
|
||||
require.Nil(t, c.AddMessage(m))
|
||||
|
||||
m = newDefaultMessage("mytopic", "message with attachment")
|
||||
m.ID = "m2"
|
||||
m.SequenceID = "m2"
|
||||
m.Expires = time.Now().Add(2 * time.Hour).Unix()
|
||||
m.Attachment = &attachment{
|
||||
Name: "car.jpg",
|
||||
Type: "image/jpeg",
|
||||
Size: 10000,
|
||||
Expires: time.Now().Add(2 * time.Hour).Unix(),
|
||||
URL: "https://ntfy.sh/file/aCaRURL.jpg",
|
||||
}
|
||||
require.Nil(t, c.AddMessage(m))
|
||||
|
||||
m = newDefaultMessage("mytopic", "message with external attachment")
|
||||
m.ID = "m3"
|
||||
m.SequenceID = "m3"
|
||||
m.Expires = time.Now().Add(2 * time.Hour).Unix()
|
||||
m.Attachment = &attachment{
|
||||
Name: "car.jpg",
|
||||
Type: "image/jpeg",
|
||||
Expires: 0, // Unknown!
|
||||
URL: "https://somedomain.com/car.jpg",
|
||||
}
|
||||
require.Nil(t, c.AddMessage(m))
|
||||
|
||||
m = newDefaultMessage("mytopic2", "message with expired attachment")
|
||||
m.ID = "m4"
|
||||
m.SequenceID = "m4"
|
||||
m.Expires = time.Now().Add(2 * time.Hour).Unix()
|
||||
m.Attachment = &attachment{
|
||||
Name: "expired-car.jpg",
|
||||
Type: "image/jpeg",
|
||||
Size: 20000,
|
||||
Expires: time.Now().Add(-1 * time.Hour).Unix(),
|
||||
URL: "https://ntfy.sh/file/aCaRURL.jpg",
|
||||
}
|
||||
require.Nil(t, c.AddMessage(m))
|
||||
|
||||
ids, err := c.AttachmentsExpired()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(ids))
|
||||
require.Equal(t, "m4", ids[0])
|
||||
}
|
||||
|
||||
func TestSqliteCache_Migration_From0(t *testing.T) {
|
||||
filename := newSqliteTestCacheFile(t)
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Create "version 0" schema
|
||||
_, err = db.Exec(`
|
||||
BEGIN;
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id VARCHAR(20) PRIMARY KEY,
|
||||
time INT NOT NULL,
|
||||
topic VARCHAR(64) NOT NULL,
|
||||
message VARCHAR(1024) NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
|
||||
COMMIT;
|
||||
`)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Insert a bunch of messages
|
||||
for i := 0; i < 10; i++ {
|
||||
_, err = db.Exec(`INSERT INTO messages (id, time, topic, message) VALUES (?, ?, ?, ?)`,
|
||||
fmt.Sprintf("abcd%d", i), time.Now().Unix(), "mytopic", fmt.Sprintf("some message %d", i))
|
||||
require.Nil(t, err)
|
||||
}
|
||||
require.Nil(t, db.Close())
|
||||
|
||||
// Create cache to trigger migration
|
||||
c := newSqliteTestCacheFromFile(t, filename, "")
|
||||
checkSchemaVersion(t, c.db)
|
||||
|
||||
messages, err := c.Messages("mytopic", sinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 10, len(messages))
|
||||
require.Equal(t, "some message 5", messages[5].Message)
|
||||
require.Equal(t, "", messages[5].Title)
|
||||
require.Nil(t, messages[5].Tags)
|
||||
require.Equal(t, 0, messages[5].Priority)
|
||||
}
|
||||
|
||||
func TestSqliteCache_Migration_From1(t *testing.T) {
|
||||
filename := newSqliteTestCacheFile(t)
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Create "version 1" schema
|
||||
_, err = db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id VARCHAR(20) PRIMARY KEY,
|
||||
time INT NOT NULL,
|
||||
topic VARCHAR(64) NOT NULL,
|
||||
message VARCHAR(512) NOT NULL,
|
||||
title VARCHAR(256) NOT NULL,
|
||||
priority INT NOT NULL,
|
||||
tags VARCHAR(256) NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
|
||||
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||
id INT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
INSERT INTO schemaVersion (id, version) VALUES (1, 1);
|
||||
`)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Insert a bunch of messages
|
||||
for i := 0; i < 10; i++ {
|
||||
_, err = db.Exec(`INSERT INTO messages (id, time, topic, message, title, priority, tags) VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
||||
fmt.Sprintf("abcd%d", i), time.Now().Unix(), "mytopic", fmt.Sprintf("some message %d", i), "", 0, "")
|
||||
require.Nil(t, err)
|
||||
}
|
||||
require.Nil(t, db.Close())
|
||||
|
||||
// Create cache to trigger migration
|
||||
c := newSqliteTestCacheFromFile(t, filename, "")
|
||||
checkSchemaVersion(t, c.db)
|
||||
|
||||
// Add delayed message
|
||||
delayedMessage := newDefaultMessage("mytopic", "some delayed message")
|
||||
delayedMessage.Time = time.Now().Add(time.Minute).Unix()
|
||||
require.Nil(t, c.AddMessage(delayedMessage))
|
||||
|
||||
// 10, not 11!
|
||||
messages, err := c.Messages("mytopic", sinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 10, len(messages))
|
||||
|
||||
// 11!
|
||||
messages, err = c.Messages("mytopic", sinceAllMessages, true)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 11, len(messages))
|
||||
|
||||
// Check that index "idx_topic" exists
|
||||
rows, err := c.db.Query(`SELECT name FROM sqlite_master WHERE type='index' AND name='idx_topic'`)
|
||||
require.Nil(t, err)
|
||||
require.True(t, rows.Next())
|
||||
var indexName string
|
||||
require.Nil(t, rows.Scan(&indexName))
|
||||
require.Equal(t, "idx_topic", indexName)
|
||||
}
|
||||
|
||||
func TestSqliteCache_Migration_From9(t *testing.T) {
|
||||
// This primarily tests the awkward migration that introduces the "expires" column.
|
||||
// The migration logic has to update the column, using the existing "cache-duration" value.
|
||||
|
||||
filename := newSqliteTestCacheFile(t)
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Create "version 8" schema
|
||||
_, err = db.Exec(`
|
||||
BEGIN;
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
mid TEXT NOT NULL,
|
||||
time INT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
title TEXT NOT NULL,
|
||||
priority INT NOT NULL,
|
||||
tags TEXT NOT NULL,
|
||||
click TEXT NOT NULL,
|
||||
icon TEXT NOT NULL,
|
||||
actions TEXT NOT NULL,
|
||||
attachment_name TEXT NOT NULL,
|
||||
attachment_type TEXT NOT NULL,
|
||||
attachment_size INT NOT NULL,
|
||||
attachment_expires INT NOT NULL,
|
||||
attachment_url TEXT NOT NULL,
|
||||
sender TEXT NOT NULL,
|
||||
encoding TEXT NOT NULL,
|
||||
published INT NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_mid ON messages (mid);
|
||||
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
|
||||
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||
id INT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
INSERT INTO schemaVersion (id, version) VALUES (1, 9);
|
||||
COMMIT;
|
||||
`)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Insert a bunch of messages
|
||||
insertQuery := `
|
||||
INSERT INTO messages (mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding, published)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
for i := 0; i < 10; i++ {
|
||||
_, err = db.Exec(
|
||||
insertQuery,
|
||||
fmt.Sprintf("abcd%d", i),
|
||||
time.Now().Unix(),
|
||||
"mytopic",
|
||||
fmt.Sprintf("some message %d", i),
|
||||
"", // title
|
||||
0, // priority
|
||||
"", // tags
|
||||
"", // click
|
||||
"", // icon
|
||||
"", // actions
|
||||
"", // attachment_name
|
||||
"", // attachment_type
|
||||
0, // attachment_size
|
||||
0, // attachment_type
|
||||
"", // attachment_url
|
||||
"9.9.9.9", // sender
|
||||
"", // encoding
|
||||
1, // published
|
||||
)
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
// Create cache to trigger migration
|
||||
cacheDuration := 17 * time.Hour
|
||||
c, err := newSqliteCache(filename, "", cacheDuration, 0, 0, false)
|
||||
require.Nil(t, err)
|
||||
checkSchemaVersion(t, c.db)
|
||||
|
||||
// Check version
|
||||
rows, err := db.Query(`SELECT version FROM main.schemaVersion WHERE id = 1`)
|
||||
require.Nil(t, err)
|
||||
require.True(t, rows.Next())
|
||||
var version int
|
||||
require.Nil(t, rows.Scan(&version))
|
||||
require.Equal(t, currentSchemaVersion, version)
|
||||
|
||||
messages, err := c.Messages("mytopic", sinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 10, len(messages))
|
||||
for _, m := range messages {
|
||||
require.True(t, m.Expires > time.Now().Add(cacheDuration-5*time.Second).Unix())
|
||||
require.True(t, m.Expires < time.Now().Add(cacheDuration+5*time.Second).Unix())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqliteCache_StartupQueries_WAL(t *testing.T) {
|
||||
filename := newSqliteTestCacheFile(t)
|
||||
startupQueries := `pragma journal_mode = WAL;
|
||||
pragma synchronous = normal;
|
||||
pragma temp_store = memory;`
|
||||
db, err := newSqliteCache(filename, startupQueries, time.Hour, 0, 0, false)
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, db.AddMessage(newDefaultMessage("mytopic", "some message")))
|
||||
require.FileExists(t, filename)
|
||||
require.FileExists(t, filename+"-wal")
|
||||
require.FileExists(t, filename+"-shm")
|
||||
}
|
||||
|
||||
func TestSqliteCache_StartupQueries_None(t *testing.T) {
|
||||
filename := newSqliteTestCacheFile(t)
|
||||
startupQueries := ""
|
||||
db, err := newSqliteCache(filename, startupQueries, time.Hour, 0, 0, false)
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, db.AddMessage(newDefaultMessage("mytopic", "some message")))
|
||||
require.FileExists(t, filename)
|
||||
require.NoFileExists(t, filename+"-wal")
|
||||
require.NoFileExists(t, filename+"-shm")
|
||||
}
|
||||
|
||||
func TestSqliteCache_StartupQueries_Fail(t *testing.T) {
|
||||
filename := newSqliteTestCacheFile(t)
|
||||
startupQueries := `xx error`
|
||||
_, err := newSqliteCache(filename, startupQueries, time.Hour, 0, 0, false)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestSqliteCache_Sender(t *testing.T) {
|
||||
testSender(t, newSqliteTestCache(t))
|
||||
}
|
||||
|
||||
func TestMemCache_Sender(t *testing.T) {
|
||||
testSender(t, newMemTestCache(t))
|
||||
}
|
||||
|
||||
func testSender(t *testing.T, c *messageCache) {
|
||||
m1 := newDefaultMessage("mytopic", "mymessage")
|
||||
m1.Sender = netip.MustParseAddr("1.2.3.4")
|
||||
require.Nil(t, c.AddMessage(m1))
|
||||
|
||||
m2 := newDefaultMessage("mytopic", "mymessage without sender")
|
||||
require.Nil(t, c.AddMessage(m2))
|
||||
|
||||
messages, err := c.Messages("mytopic", sinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, len(messages))
|
||||
require.Equal(t, messages[0].Sender, netip.MustParseAddr("1.2.3.4"))
|
||||
require.Equal(t, messages[1].Sender, netip.Addr{})
|
||||
}
|
||||
|
||||
func TestSqliteCache_DeleteScheduledBySequenceID(t *testing.T) {
|
||||
testDeleteScheduledBySequenceID(t, newSqliteTestCache(t))
|
||||
}
|
||||
|
||||
func TestMemCache_DeleteScheduledBySequenceID(t *testing.T) {
|
||||
testDeleteScheduledBySequenceID(t, newMemTestCache(t))
|
||||
}
|
||||
|
||||
func testDeleteScheduledBySequenceID(t *testing.T, c *messageCache) {
|
||||
// Create a scheduled (unpublished) message
|
||||
scheduledMsg := newDefaultMessage("mytopic", "scheduled message")
|
||||
scheduledMsg.ID = "scheduled1"
|
||||
scheduledMsg.SequenceID = "seq123"
|
||||
scheduledMsg.Time = time.Now().Add(time.Hour).Unix() // Future time makes it scheduled
|
||||
require.Nil(t, c.AddMessage(scheduledMsg))
|
||||
|
||||
// Create a published message with different sequence ID
|
||||
publishedMsg := newDefaultMessage("mytopic", "published message")
|
||||
publishedMsg.ID = "published1"
|
||||
publishedMsg.SequenceID = "seq456"
|
||||
publishedMsg.Time = time.Now().Add(-time.Hour).Unix() // Past time makes it published
|
||||
require.Nil(t, c.AddMessage(publishedMsg))
|
||||
|
||||
// Create a scheduled message in a different topic
|
||||
otherTopicMsg := newDefaultMessage("othertopic", "other scheduled")
|
||||
otherTopicMsg.ID = "other1"
|
||||
otherTopicMsg.SequenceID = "seq123" // Same sequence ID as scheduledMsg
|
||||
otherTopicMsg.Time = time.Now().Add(time.Hour).Unix()
|
||||
require.Nil(t, c.AddMessage(otherTopicMsg))
|
||||
|
||||
// Verify all messages exist (including scheduled)
|
||||
messages, err := c.Messages("mytopic", sinceAllMessages, true)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, len(messages))
|
||||
|
||||
messages, err = c.Messages("othertopic", sinceAllMessages, true)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
|
||||
// Delete scheduled message by sequence ID and verify returned IDs
|
||||
deletedIDs, err := c.DeleteScheduledBySequenceID("mytopic", "seq123")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(deletedIDs))
|
||||
require.Equal(t, "scheduled1", deletedIDs[0])
|
||||
|
||||
// Verify scheduled message is deleted
|
||||
messages, err = c.Messages("mytopic", sinceAllMessages, true)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "published message", messages[0].Message)
|
||||
|
||||
// Verify other topic's message still exists (topic-scoped deletion)
|
||||
messages, err = c.Messages("othertopic", sinceAllMessages, true)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "other scheduled", messages[0].Message)
|
||||
|
||||
// Deleting non-existent sequence ID should return empty list
|
||||
deletedIDs, err = c.DeleteScheduledBySequenceID("mytopic", "nonexistent")
|
||||
require.Nil(t, err)
|
||||
require.Empty(t, deletedIDs)
|
||||
|
||||
// Deleting published message should not affect it (only deletes unpublished)
|
||||
deletedIDs, err = c.DeleteScheduledBySequenceID("mytopic", "seq456")
|
||||
require.Nil(t, err)
|
||||
require.Empty(t, deletedIDs)
|
||||
|
||||
messages, err = c.Messages("mytopic", sinceAllMessages, true)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "published message", messages[0].Message)
|
||||
}
|
||||
|
||||
func checkSchemaVersion(t *testing.T, db *sql.DB) {
|
||||
rows, err := db.Query(`SELECT version FROM schemaVersion`)
|
||||
require.Nil(t, err)
|
||||
require.True(t, rows.Next())
|
||||
|
||||
var schemaVersion int
|
||||
require.Nil(t, rows.Scan(&schemaVersion))
|
||||
require.Equal(t, currentSchemaVersion, schemaVersion)
|
||||
require.Nil(t, rows.Close())
|
||||
}
|
||||
|
||||
func TestMemCache_NopCache(t *testing.T) {
|
||||
c, _ := newNopCache()
|
||||
require.Nil(t, c.AddMessage(newDefaultMessage("mytopic", "my message")))
|
||||
|
||||
messages, err := c.Messages("mytopic", sinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Empty(t, messages)
|
||||
|
||||
topics, err := c.Topics()
|
||||
require.Nil(t, err)
|
||||
require.Empty(t, topics)
|
||||
}
|
||||
|
||||
func newSqliteTestCache(t *testing.T) *messageCache {
|
||||
c, err := newSqliteCache(newSqliteTestCacheFile(t), "", time.Hour, 0, 0, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func newSqliteTestCacheFile(t *testing.T) string {
|
||||
return filepath.Join(t.TempDir(), "cache.db")
|
||||
}
|
||||
|
||||
func newSqliteTestCacheFromFile(t *testing.T, filename, startupQueries string) *messageCache {
|
||||
c, err := newSqliteCache(filename, startupQueries, time.Hour, 0, 0, false)
|
||||
require.Nil(t, err)
|
||||
return c
|
||||
}
|
||||
|
||||
func newMemTestCache(t *testing.T) *messageCache {
|
||||
c, err := newMemCache()
|
||||
require.Nil(t, err)
|
||||
return c
|
||||
}
|
||||
108
server/server.go
108
server/server.go
@@ -33,10 +33,13 @@ import (
|
||||
"golang.org/x/sync/errgroup"
|
||||
"gopkg.in/yaml.v2"
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/message"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
"heckel.io/ntfy/v2/payments"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
"heckel.io/ntfy/v2/util/sprig"
|
||||
"heckel.io/ntfy/v2/webpush"
|
||||
)
|
||||
|
||||
// Server is the main server, providing the UI and API for ntfy
|
||||
@@ -56,8 +59,8 @@ type Server struct {
|
||||
messages int64 // Total number of messages (persisted if messageCache enabled)
|
||||
messagesHistory []int64 // Last n values of the messages counter, used to determine rate
|
||||
userManager *user.Manager // Might be nil!
|
||||
messageCache *messageCache // Database that stores the messages
|
||||
webPush *webPushStore // Database that stores web push subscriptions
|
||||
messageCache message.Store // Database that stores the messages
|
||||
webPush webpush.Store // Database that stores web push subscriptions
|
||||
fileCache *fileCache // File system based cache that stores attachments
|
||||
stripe stripeAPI // Stripe API, can be replaced with a mock
|
||||
priceCache *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!)
|
||||
@@ -90,6 +93,7 @@ var (
|
||||
matrixPushPath = "/_matrix/push/v1/notify"
|
||||
metricsPath = "/metrics"
|
||||
apiHealthPath = "/v1/health"
|
||||
apiVersionPath = "/v1/version"
|
||||
apiConfigPath = "/v1/config"
|
||||
apiStatsPath = "/v1/stats"
|
||||
apiWebPushPath = "/v1/webpush"
|
||||
@@ -175,17 +179,25 @@ func New(conf *Config) (*Server, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var webPush *webPushStore
|
||||
var wp webpush.Store
|
||||
if conf.WebPushPublicKey != "" {
|
||||
webPush, err = newWebPushStore(conf.WebPushFile, conf.WebPushStartupQueries)
|
||||
if conf.DatabaseURL != "" {
|
||||
wp, err = webpush.NewPostgresStore(conf.DatabaseURL)
|
||||
} else {
|
||||
wp, err = webpush.NewSQLiteStore(conf.WebPushFile, conf.WebPushStartupQueries)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
topics, err := messageCache.Topics()
|
||||
topicIDs, err := messageCache.Topics()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
topics := make(map[string]*topic, len(topicIDs))
|
||||
for _, id := range topicIDs {
|
||||
topics[id] = newTopic(id)
|
||||
}
|
||||
messages, err := messageCache.Stats()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -198,9 +210,10 @@ func New(conf *Config) (*Server, error) {
|
||||
}
|
||||
}
|
||||
var userManager *user.Manager
|
||||
if conf.AuthFile != "" {
|
||||
if conf.AuthFile != "" || conf.DatabaseURL != "" {
|
||||
authConfig := &user.Config{
|
||||
Filename: conf.AuthFile,
|
||||
DatabaseURL: conf.DatabaseURL,
|
||||
StartupQueries: conf.AuthStartupQueries,
|
||||
DefaultAccess: conf.AuthDefault,
|
||||
ProvisionEnabled: true, // Enable provisioning of users and access
|
||||
@@ -210,7 +223,16 @@ func New(conf *Config) (*Server, error) {
|
||||
BcryptCost: conf.AuthBcryptCost,
|
||||
QueueWriterInterval: conf.AuthStatsQueueWriterInterval,
|
||||
}
|
||||
userManager, err = user.NewManager(authConfig)
|
||||
var store user.Store
|
||||
if conf.DatabaseURL != "" {
|
||||
store, err = user.NewPostgresStore(conf.DatabaseURL)
|
||||
} else {
|
||||
store, err = user.NewSQLiteStore(conf.AuthFile, conf.AuthStartupQueries)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userManager, err = user.NewManager(store, authConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -232,7 +254,7 @@ func New(conf *Config) (*Server, error) {
|
||||
s := &Server{
|
||||
config: conf,
|
||||
messageCache: messageCache,
|
||||
webPush: webPush,
|
||||
webPush: wp,
|
||||
fileCache: fileCache,
|
||||
firebaseClient: firebaseClient,
|
||||
smtpSender: mailer,
|
||||
@@ -247,13 +269,15 @@ func New(conf *Config) (*Server, error) {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func createMessageCache(conf *Config) (*messageCache, error) {
|
||||
func createMessageCache(conf *Config) (message.Store, error) {
|
||||
if conf.CacheDuration == 0 {
|
||||
return newNopCache()
|
||||
return message.NewNopStore()
|
||||
} else if conf.DatabaseURL != "" {
|
||||
return message.NewPostgresStore(conf.DatabaseURL, conf.CacheBatchSize, conf.CacheBatchTimeout)
|
||||
} else if conf.CacheFile != "" {
|
||||
return newSqliteCache(conf.CacheFile, conf.CacheStartupQueries, conf.CacheDuration, conf.CacheBatchSize, conf.CacheBatchTimeout, false)
|
||||
return message.NewSQLiteStore(conf.CacheFile, conf.CacheStartupQueries, conf.CacheDuration, conf.CacheBatchSize, conf.CacheBatchTimeout, false)
|
||||
}
|
||||
return newMemCache()
|
||||
return message.NewMemStore()
|
||||
}
|
||||
|
||||
// Run executes the main server. It listens on HTTP (+ HTTPS, if configured), and starts
|
||||
@@ -467,6 +491,8 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
|
||||
return s.ensureWebEnabled(s.handleEmpty)(w, r, v)
|
||||
} else if r.Method == http.MethodGet && r.URL.Path == apiHealthPath {
|
||||
return s.handleHealth(w, r, v)
|
||||
} else if r.Method == http.MethodGet && r.URL.Path == apiVersionPath {
|
||||
return s.ensureAdmin(s.handleVersion)(w, r, v)
|
||||
} else if r.Method == http.MethodGet && r.URL.Path == apiConfigPath {
|
||||
return s.handleConfig(w, r, v)
|
||||
} else if r.Method == http.MethodGet && r.URL.Path == webConfigPath {
|
||||
@@ -732,7 +758,7 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor)
|
||||
if s.config.CacheBatchTimeout > 0 {
|
||||
// Strange edge case: If we immediately after upload request the file (the web app does this for images),
|
||||
// and messages are persisted asynchronously, retry fetching from the database
|
||||
m, err = util.Retry(func() (*message, error) {
|
||||
m, err = util.Retry(func() (*model.Message, error) {
|
||||
return s.messageCache.Message(messageID)
|
||||
}, s.config.CacheBatchTimeout, 100*time.Millisecond, 300*time.Millisecond, 600*time.Millisecond)
|
||||
}
|
||||
@@ -778,7 +804,7 @@ func (s *Server) handleMatrixDiscovery(w http.ResponseWriter) error {
|
||||
return writeMatrixDiscoveryResponse(w)
|
||||
}
|
||||
|
||||
func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*message, error) {
|
||||
func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*model.Message, error) {
|
||||
start := time.Now()
|
||||
t, err := fromContext[*topic](r, contextTopic)
|
||||
if err != nil {
|
||||
@@ -906,7 +932,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
|
||||
return err
|
||||
}
|
||||
minc(metricMessagesPublishedSuccess)
|
||||
return s.writeJSON(w, m.forJSON())
|
||||
return s.writeJSON(w, m.ForJSON())
|
||||
}
|
||||
|
||||
func (s *Server) handlePublishMatrix(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
@@ -996,10 +1022,10 @@ func (s *Server) handleActionMessage(w http.ResponseWriter, r *http.Request, v *
|
||||
s.mu.Lock()
|
||||
s.messages++
|
||||
s.mu.Unlock()
|
||||
return s.writeJSON(w, m.forJSON())
|
||||
return s.writeJSON(w, m.ForJSON())
|
||||
}
|
||||
|
||||
func (s *Server) sendToFirebase(v *visitor, m *message) {
|
||||
func (s *Server) sendToFirebase(v *visitor, m *model.Message) {
|
||||
logvm(v, m).Tag(tagFirebase).Debug("Publishing to Firebase")
|
||||
if err := s.firebaseClient.Send(v, m); err != nil {
|
||||
minc(metricFirebasePublishedFailure)
|
||||
@@ -1013,7 +1039,7 @@ func (s *Server) sendToFirebase(v *visitor, m *message) {
|
||||
minc(metricFirebasePublishedSuccess)
|
||||
}
|
||||
|
||||
func (s *Server) sendEmail(v *visitor, m *message, email string) {
|
||||
func (s *Server) sendEmail(v *visitor, m *model.Message, email string) {
|
||||
logvm(v, m).Tag(tagEmail).Field("email", email).Debug("Sending email to %s", email)
|
||||
if err := s.smtpSender.Send(v, m, email); err != nil {
|
||||
logvm(v, m).Tag(tagEmail).Field("email", email).Err(err).Warn("Unable to send email to %s: %v", email, err.Error())
|
||||
@@ -1023,7 +1049,7 @@ func (s *Server) sendEmail(v *visitor, m *message, email string) {
|
||||
minc(metricEmailsPublishedSuccess)
|
||||
}
|
||||
|
||||
func (s *Server) forwardPollRequest(v *visitor, m *message) {
|
||||
func (s *Server) forwardPollRequest(v *visitor, m *model.Message) {
|
||||
topicURL := fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic)
|
||||
topicHash := fmt.Sprintf("%x", sha256.Sum256([]byte(topicURL)))
|
||||
forwardURL := fmt.Sprintf("%s/%s", s.config.UpstreamBaseURL, topicHash)
|
||||
@@ -1055,7 +1081,7 @@ func (s *Server) forwardPollRequest(v *visitor, m *message) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, firebase bool, email, call string, template templateMode, unifiedpush bool, priorityStr string, err *errHTTP) {
|
||||
func (s *Server) parsePublishParams(r *http.Request, m *model.Message) (cache bool, firebase bool, email, call string, template templateMode, unifiedpush bool, priorityStr string, err *errHTTP) {
|
||||
if r.Method != http.MethodGet && updatePathRegex.MatchString(r.URL.Path) {
|
||||
pathSequenceID, err := s.sequenceIDFromPath(r.URL.Path)
|
||||
if err != nil {
|
||||
@@ -1082,7 +1108,7 @@ func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, fi
|
||||
filename := readParam(r, "x-filename", "filename", "file", "f")
|
||||
attach := readParam(r, "x-attach", "attach", "a")
|
||||
if attach != "" || filename != "" {
|
||||
m.Attachment = &attachment{}
|
||||
m.Attachment = &model.Attachment{}
|
||||
}
|
||||
if filename != "" {
|
||||
m.Attachment.Name = filename
|
||||
@@ -1203,7 +1229,7 @@ func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, fi
|
||||
// If file.txt is <= 4096 (message limit) and valid UTF-8, treat it as a message
|
||||
// 7. curl -T file.txt ntfy.sh/mytopic
|
||||
// In all other cases, mostly if file.txt is > message limit, treat it as an attachment
|
||||
func (s *Server) handlePublishBody(r *http.Request, v *visitor, m *message, body *util.PeekedReadCloser, template templateMode, unifiedpush bool, priorityStr string) error {
|
||||
func (s *Server) handlePublishBody(r *http.Request, v *visitor, m *model.Message, body *util.PeekedReadCloser, template templateMode, unifiedpush bool, priorityStr string) error {
|
||||
if m.Event == pollRequestEvent { // Case 1
|
||||
return s.handleBodyDiscard(body)
|
||||
} else if unifiedpush {
|
||||
@@ -1226,7 +1252,7 @@ func (s *Server) handleBodyDiscard(body *util.PeekedReadCloser) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Server) handleBodyAsMessageAutoDetect(m *message, body *util.PeekedReadCloser) error {
|
||||
func (s *Server) handleBodyAsMessageAutoDetect(m *model.Message, body *util.PeekedReadCloser) error {
|
||||
if utf8.Valid(body.PeekedBytes) {
|
||||
m.Message = string(body.PeekedBytes) // Do not trim
|
||||
} else {
|
||||
@@ -1236,7 +1262,7 @@ func (s *Server) handleBodyAsMessageAutoDetect(m *message, body *util.PeekedRead
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) handleBodyAsTextMessage(m *message, body *util.PeekedReadCloser) error {
|
||||
func (s *Server) handleBodyAsTextMessage(m *model.Message, body *util.PeekedReadCloser) error {
|
||||
if !utf8.Valid(body.PeekedBytes) {
|
||||
return errHTTPBadRequestMessageNotUTF8.With(m)
|
||||
}
|
||||
@@ -1249,7 +1275,7 @@ func (s *Server) handleBodyAsTextMessage(m *message, body *util.PeekedReadCloser
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) handleBodyAsTemplatedTextMessage(m *message, template templateMode, body *util.PeekedReadCloser, priorityStr string) error {
|
||||
func (s *Server) handleBodyAsTemplatedTextMessage(m *model.Message, template templateMode, body *util.PeekedReadCloser, priorityStr string) error {
|
||||
body, err := util.Peek(body, max(s.config.MessageSizeLimit, jsonBodyBytesLimit))
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -1274,7 +1300,7 @@ func (s *Server) handleBodyAsTemplatedTextMessage(m *message, template templateM
|
||||
|
||||
// renderTemplateFromFile transforms the JSON message body according to a template from the filesystem.
|
||||
// The template file must be in the templates directory, or in the configured template directory.
|
||||
func (s *Server) renderTemplateFromFile(m *message, templateName, peekedBody string) error {
|
||||
func (s *Server) renderTemplateFromFile(m *model.Message, templateName, peekedBody string) error {
|
||||
if !templateNameRegex.MatchString(templateName) {
|
||||
return errHTTPBadRequestTemplateFileNotFound
|
||||
}
|
||||
@@ -1316,7 +1342,7 @@ func (s *Server) renderTemplateFromFile(m *message, templateName, peekedBody str
|
||||
|
||||
// renderTemplateFromParams transforms the JSON message body according to the inline template in the
|
||||
// message, title, and priority parameters.
|
||||
func (s *Server) renderTemplateFromParams(m *message, peekedBody string, priorityStr string) error {
|
||||
func (s *Server) renderTemplateFromParams(m *model.Message, peekedBody string, priorityStr string) error {
|
||||
var err error
|
||||
if m.Message, err = s.renderTemplate("priority query parameter", m.Message, peekedBody); err != nil {
|
||||
return err
|
||||
@@ -1357,7 +1383,7 @@ func (s *Server) renderTemplate(name, tpl, source string) (string, error) {
|
||||
return strings.TrimSpace(strings.ReplaceAll(buf.String(), "\\n", "\n")), nil // replace any remaining "\n" (those outside of template curly braces) with newlines
|
||||
}
|
||||
|
||||
func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message, body *util.PeekedReadCloser) error {
|
||||
func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *model.Message, body *util.PeekedReadCloser) error {
|
||||
if s.fileCache == nil || s.config.BaseURL == "" || s.config.AttachmentCacheDir == "" {
|
||||
return errHTTPBadRequestAttachmentsDisallowed.With(m)
|
||||
}
|
||||
@@ -1381,7 +1407,7 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message,
|
||||
}
|
||||
}
|
||||
if m.Attachment == nil {
|
||||
m.Attachment = &attachment{}
|
||||
m.Attachment = &model.Attachment{}
|
||||
}
|
||||
var ext string
|
||||
m.Attachment.Expires = attachmentExpiry
|
||||
@@ -1408,9 +1434,9 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message,
|
||||
}
|
||||
|
||||
func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
encoder := func(msg *message) (string, error) {
|
||||
encoder := func(msg *model.Message) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
if err := json.NewEncoder(&buf).Encode(msg.forJSON()); err != nil {
|
||||
if err := json.NewEncoder(&buf).Encode(msg.ForJSON()); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return buf.String(), nil
|
||||
@@ -1419,9 +1445,9 @@ func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request, v *
|
||||
}
|
||||
|
||||
func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
encoder := func(msg *message) (string, error) {
|
||||
encoder := func(msg *model.Message) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
if err := json.NewEncoder(&buf).Encode(msg.forJSON()); err != nil {
|
||||
if err := json.NewEncoder(&buf).Encode(msg.ForJSON()); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if msg.Event != messageEvent && msg.Event != messageDeleteEvent && msg.Event != messageClearEvent {
|
||||
@@ -1433,7 +1459,7 @@ func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request, v *v
|
||||
}
|
||||
|
||||
func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
encoder := func(msg *message) (string, error) {
|
||||
encoder := func(msg *model.Message) (string, error) {
|
||||
if msg.Event == messageEvent { // only handle default events
|
||||
return strings.ReplaceAll(msg.Message, "\n", " ") + "\n", nil
|
||||
}
|
||||
@@ -1469,7 +1495,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
||||
closed = true
|
||||
wlock.Unlock()
|
||||
}()
|
||||
sub := func(v *visitor, msg *message) error {
|
||||
sub := func(v *visitor, msg *model.Message) error {
|
||||
if !filters.Pass(msg) {
|
||||
return nil
|
||||
}
|
||||
@@ -1631,7 +1657,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
||||
}
|
||||
}
|
||||
})
|
||||
sub := func(v *visitor, msg *message) error {
|
||||
sub := func(v *visitor, msg *model.Message) error {
|
||||
if !filters.Pass(msg) {
|
||||
return nil
|
||||
}
|
||||
@@ -1678,7 +1704,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, scheduled bool, filters *queryFilter, err error) {
|
||||
func parseSubscribeParams(r *http.Request) (poll bool, since model.SinceMarker, scheduled bool, filters *queryFilter, err error) {
|
||||
poll = readBoolParam(r, false, "x-poll", "poll", "po")
|
||||
scheduled = readBoolParam(r, false, "x-scheduled", "scheduled", "sched")
|
||||
since, err = parseSince(r, poll)
|
||||
@@ -1759,11 +1785,11 @@ func (s *Server) setRateVisitors(r *http.Request, v *visitor, rateTopics []*topi
|
||||
|
||||
// sendOldMessages selects old messages from the messageCache and calls sub for each of them. It uses since as the
|
||||
// marker, returning only messages that are newer than the marker.
|
||||
func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, v *visitor, sub subscriber) error {
|
||||
func (s *Server) sendOldMessages(topics []*topic, since model.SinceMarker, scheduled bool, v *visitor, sub subscriber) error {
|
||||
if since.IsNone() {
|
||||
return nil
|
||||
}
|
||||
messages := make([]*message, 0)
|
||||
messages := make([]*model.Message, 0)
|
||||
for _, t := range topics {
|
||||
topicMessages, err := s.messageCache.Messages(t.ID, since, scheduled)
|
||||
if err != nil {
|
||||
@@ -1786,7 +1812,7 @@ func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled b
|
||||
//
|
||||
// Values in the "since=..." parameter can be either a unix timestamp or a duration (e.g. 12h),
|
||||
// "all" for all messages, or "latest" for the most recent message for a topic
|
||||
func parseSince(r *http.Request, poll bool) (sinceMarker, error) {
|
||||
func parseSince(r *http.Request, poll bool) (model.SinceMarker, error) {
|
||||
since := readParam(r, "x-since", "since", "si")
|
||||
|
||||
// Easy cases (empty, all, none)
|
||||
@@ -2017,7 +2043,7 @@ func (s *Server) sendDelayedMessages() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
|
||||
func (s *Server) sendDelayedMessage(v *visitor, m *model.Message) error {
|
||||
logvm(v, m).Debug("Sending delayed message")
|
||||
s.mu.RLock()
|
||||
t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published
|
||||
|
||||
@@ -38,6 +38,12 @@
|
||||
#
|
||||
# firebase-key-file: <filename>
|
||||
|
||||
# If "database-url" is set, ntfy will use PostgreSQL for database-backed stores instead of SQLite.
|
||||
# Currently this applies to the web push subscription store. Message cache and user manager support
|
||||
# will be added in future releases. When set, the "web-push-file" option is not required.
|
||||
#
|
||||
# database-url: "postgres://user:pass@host:5432/ntfy"
|
||||
|
||||
# If "cache-file" is set, messages are cached in a local SQLite database instead of only in-memory.
|
||||
# This allows for service restarts without losing messages in support of the since= parameter.
|
||||
#
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -6,6 +6,14 @@ import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func (s *Server) handleVersion(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
return s.writeJSON(w, &apiVersionResponse{
|
||||
Version: s.config.BuildVersion,
|
||||
Commit: s.config.BuildCommit,
|
||||
Date: s.config.BuildDate,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleUsersGet(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
users, err := s.userManager.Users()
|
||||
if err != nil {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
@@ -9,393 +10,452 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestVersion_Admin(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T) {
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.BuildVersion = "1.2.3"
|
||||
c.BuildCommit = "abcdef0"
|
||||
c.BuildDate = "2026-02-08T00:00:00Z"
|
||||
s := newTestServer(t, c)
|
||||
defer s.closeDatabases()
|
||||
|
||||
// Create admin and regular user
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
||||
|
||||
// Admin can access /v1/version
|
||||
rr := request(t, s, "GET", "/v1/version", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
var versionResponse apiVersionResponse
|
||||
require.Nil(t, json.NewDecoder(rr.Body).Decode(&versionResponse))
|
||||
require.Equal(t, "1.2.3", versionResponse.Version)
|
||||
require.Equal(t, "abcdef0", versionResponse.Commit)
|
||||
require.Equal(t, "2026-02-08T00:00:00Z", versionResponse.Date)
|
||||
|
||||
// Non-admin user cannot access /v1/version
|
||||
rr = request(t, s, "GET", "/v1/version", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 401, rr.Code)
|
||||
|
||||
// Unauthenticated user cannot access /v1/version
|
||||
rr = request(t, s, "GET", "/v1/version", "", nil)
|
||||
require.Equal(t, 401, rr.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUser_AddRemove(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||
defer s.closeDatabases()
|
||||
forEachBackend(t, func(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||
defer s.closeDatabases()
|
||||
|
||||
// Create admin, tier
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "tier1",
|
||||
}))
|
||||
// Create admin, tier
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "tier1",
|
||||
}))
|
||||
|
||||
// Create user via API
|
||||
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
// Create user via API
|
||||
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Create user with tier via API
|
||||
rr = request(t, s, "PUT", "/v1/users", `{"username": "emma", "password":"emma", "tier": "tier1"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Check users
|
||||
users, err := s.userManager.Users()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 4, len(users))
|
||||
require.Equal(t, "phil", users[0].Name)
|
||||
require.Equal(t, "ben", users[1].Name)
|
||||
require.Equal(t, user.RoleUser, users[1].Role)
|
||||
require.Nil(t, users[1].Tier)
|
||||
require.Equal(t, "emma", users[2].Name)
|
||||
require.Equal(t, user.RoleUser, users[2].Role)
|
||||
require.Equal(t, "tier1", users[2].Tier.Code)
|
||||
require.Equal(t, user.Everyone, users[3].Name)
|
||||
|
||||
// Delete user via API
|
||||
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Check user was deleted
|
||||
users, err = s.userManager.Users()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 3, len(users))
|
||||
require.Equal(t, "phil", users[0].Name)
|
||||
require.Equal(t, "emma", users[1].Name)
|
||||
require.Equal(t, user.Everyone, users[2].Name)
|
||||
|
||||
// Reject invalid user change
|
||||
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 400, rr.Code)
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Create user with tier via API
|
||||
rr = request(t, s, "PUT", "/v1/users", `{"username": "emma", "password":"emma", "tier": "tier1"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Check users
|
||||
users, err := s.userManager.Users()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 4, len(users))
|
||||
require.Equal(t, "phil", users[0].Name)
|
||||
require.Equal(t, "ben", users[1].Name)
|
||||
require.Equal(t, user.RoleUser, users[1].Role)
|
||||
require.Nil(t, users[1].Tier)
|
||||
require.Equal(t, "emma", users[2].Name)
|
||||
require.Equal(t, user.RoleUser, users[2].Role)
|
||||
require.Equal(t, "tier1", users[2].Tier.Code)
|
||||
require.Equal(t, user.Everyone, users[3].Name)
|
||||
|
||||
// Delete user via API
|
||||
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Check user was deleted
|
||||
users, err = s.userManager.Users()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 3, len(users))
|
||||
require.Equal(t, "phil", users[0].Name)
|
||||
require.Equal(t, "emma", users[1].Name)
|
||||
require.Equal(t, user.Everyone, users[2].Name)
|
||||
|
||||
// Reject invalid user change
|
||||
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 400, rr.Code)
|
||||
}
|
||||
|
||||
func TestUser_AddWithPasswordHash(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||
defer s.closeDatabases()
|
||||
forEachBackend(t, func(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||
defer s.closeDatabases()
|
||||
|
||||
// Create admin
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
// Create admin
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
|
||||
// Create user via API
|
||||
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "hash":"$2a$04$2aPIIqPXQU16OfkSUZH1XOzpu1gsPRKkrfVdFLgWQ.tqb.vtTCuVe"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
// Create user via API
|
||||
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "hash":"$2a$04$2aPIIqPXQU16OfkSUZH1XOzpu1gsPRKkrfVdFLgWQ.tqb.vtTCuVe"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Check that user can login with password
|
||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Check users
|
||||
users, err := s.userManager.Users()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 3, len(users))
|
||||
require.Equal(t, "phil", users[0].Name)
|
||||
require.Equal(t, user.RoleAdmin, users[0].Role)
|
||||
require.Equal(t, "ben", users[1].Name)
|
||||
require.Equal(t, user.RoleUser, users[1].Role)
|
||||
}
|
||||
|
||||
func TestUser_ChangeUserPassword(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||
defer s.closeDatabases()
|
||||
|
||||
// Create admin
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
|
||||
// Create user via API
|
||||
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password": "ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Try to login with first password
|
||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Change password via API
|
||||
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "password": "ben-two"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Make sure first password fails
|
||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 401, rr.Code)
|
||||
|
||||
// Try to login with second password
|
||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben-two"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
}
|
||||
|
||||
func TestUser_ChangeUserTier(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||
defer s.closeDatabases()
|
||||
|
||||
// Create admin, tier
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "tier1",
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "tier2",
|
||||
}))
|
||||
|
||||
// Create user with tier via API
|
||||
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"ben", "tier": "tier1"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Check users
|
||||
users, err := s.userManager.Users()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 3, len(users))
|
||||
require.Equal(t, "phil", users[0].Name)
|
||||
require.Equal(t, "ben", users[1].Name)
|
||||
require.Equal(t, user.RoleUser, users[1].Role)
|
||||
require.Equal(t, "tier1", users[1].Tier.Code)
|
||||
|
||||
// Change user tier via API
|
||||
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "tier": "tier2"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Check users again
|
||||
users, err = s.userManager.Users()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "tier2", users[1].Tier.Code)
|
||||
}
|
||||
|
||||
func TestUser_ChangeUserPasswordAndTier(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||
defer s.closeDatabases()
|
||||
|
||||
// Create admin, tier
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "tier1",
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "tier2",
|
||||
}))
|
||||
|
||||
// Create user with tier via API
|
||||
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"ben", "tier": "tier1"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Check users
|
||||
users, err := s.userManager.Users()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 3, len(users))
|
||||
require.Equal(t, "phil", users[0].Name)
|
||||
require.Equal(t, "ben", users[1].Name)
|
||||
require.Equal(t, user.RoleUser, users[1].Role)
|
||||
require.Equal(t, "tier1", users[1].Tier.Code)
|
||||
|
||||
// Change user password and tier via API
|
||||
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "password":"ben-two", "tier": "tier2"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Make sure first password fails
|
||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 401, rr.Code)
|
||||
|
||||
// Try to login with second password
|
||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben-two"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Check new tier
|
||||
users, err = s.userManager.Users()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "tier2", users[1].Tier.Code)
|
||||
}
|
||||
|
||||
func TestUser_ChangeUserPasswordWithHash(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||
defer s.closeDatabases()
|
||||
|
||||
// Create admin
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
|
||||
// Create user with tier via API
|
||||
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"not-ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Try to login with first password
|
||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "not-ben"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Change user password and tier via API
|
||||
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "hash":"$2a$04$2aPIIqPXQU16OfkSUZH1XOzpu1gsPRKkrfVdFLgWQ.tqb.vtTCuVe"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Try to login with second password
|
||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
}
|
||||
|
||||
func TestUser_DontChangeAdminPassword(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||
defer s.closeDatabases()
|
||||
|
||||
// Create admin
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
require.Nil(t, s.userManager.AddUser("admin", "admin", user.RoleAdmin, false))
|
||||
|
||||
// Try to change password via API
|
||||
rr := request(t, s, "PUT", "/v1/users", `{"username": "admin", "password": "admin-new"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 403, rr.Code)
|
||||
}
|
||||
|
||||
func TestUser_AddRemove_Failures(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||
defer s.closeDatabases()
|
||||
|
||||
// Create admin
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
||||
|
||||
// Cannot create user with invalid username
|
||||
rr := request(t, s, "POST", "/v1/users", `{"username": "not valid", "password":"ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 400, rr.Code)
|
||||
|
||||
// Cannot create user if user already exists
|
||||
rr = request(t, s, "POST", "/v1/users", `{"username": "phil", "password":"phil"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 40901, toHTTPError(t, rr.Body.String()).Code)
|
||||
|
||||
// Cannot create user with invalid tier
|
||||
rr = request(t, s, "POST", "/v1/users", `{"username": "emma", "password":"emma", "tier": "invalid"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 40030, toHTTPError(t, rr.Body.String()).Code)
|
||||
|
||||
// Cannot delete user as non-admin
|
||||
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 401, rr.Code)
|
||||
|
||||
// Delete user via API
|
||||
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
}
|
||||
|
||||
func TestAccess_AllowReset(t *testing.T) {
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.AuthDefault = user.PermissionDenyAll
|
||||
s := newTestServer(t, c)
|
||||
defer s.closeDatabases()
|
||||
|
||||
// User and admin
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
||||
|
||||
// Subscribing not allowed
|
||||
rr := request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 403, rr.Code)
|
||||
|
||||
// Grant access
|
||||
rr = request(t, s, "POST", "/v1/users/access", `{"username": "ben", "topic":"gold", "permission":"ro"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Now subscribing is allowed
|
||||
rr = request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Reset access
|
||||
rr = request(t, s, "DELETE", "/v1/users/access", `{"username": "ben", "topic":"gold"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Subscribing not allowed (again)
|
||||
rr = request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 403, rr.Code)
|
||||
}
|
||||
|
||||
func TestAccess_AllowReset_NonAdminAttempt(t *testing.T) {
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.AuthDefault = user.PermissionDenyAll
|
||||
s := newTestServer(t, c)
|
||||
defer s.closeDatabases()
|
||||
|
||||
// User
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
||||
|
||||
// Grant access fails, because non-admin
|
||||
rr := request(t, s, "POST", "/v1/users/access", `{"username": "ben", "topic":"gold", "permission":"ro"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 401, rr.Code)
|
||||
}
|
||||
|
||||
func TestAccess_AllowReset_KillConnection(t *testing.T) {
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.AuthDefault = user.PermissionDenyAll
|
||||
s := newTestServer(t, c)
|
||||
defer s.closeDatabases()
|
||||
|
||||
// User and admin, grant access to "gol*" topics
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.AllowAccess("ben", "gol*", user.PermissionRead)) // Wildcard!
|
||||
|
||||
start, timeTaken := time.Now(), atomic.Int64{}
|
||||
go func() {
|
||||
rr := request(t, s, "GET", "/gold/json", "", map[string]string{
|
||||
// Check that user can login with password
|
||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
timeTaken.Store(time.Since(start).Milliseconds())
|
||||
}()
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Reset access
|
||||
rr := request(t, s, "DELETE", "/v1/users/access", `{"username": "ben", "topic":"gol*"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Wait for connection to be killed; this will fail if the connection is never killed
|
||||
waitFor(t, func() bool {
|
||||
return timeTaken.Load() >= 500
|
||||
// Check users
|
||||
users, err := s.userManager.Users()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 3, len(users))
|
||||
require.Equal(t, "phil", users[0].Name)
|
||||
require.Equal(t, user.RoleAdmin, users[0].Role)
|
||||
require.Equal(t, "ben", users[1].Name)
|
||||
require.Equal(t, user.RoleUser, users[1].Role)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUser_ChangeUserPassword(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||
defer s.closeDatabases()
|
||||
|
||||
// Create admin
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
|
||||
// Create user via API
|
||||
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password": "ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Try to login with first password
|
||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Change password via API
|
||||
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "password": "ben-two"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Make sure first password fails
|
||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 401, rr.Code)
|
||||
|
||||
// Try to login with second password
|
||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben-two"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUser_ChangeUserTier(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||
defer s.closeDatabases()
|
||||
|
||||
// Create admin, tier
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "tier1",
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "tier2",
|
||||
}))
|
||||
|
||||
// Create user with tier via API
|
||||
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"ben", "tier": "tier1"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Check users
|
||||
users, err := s.userManager.Users()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 3, len(users))
|
||||
require.Equal(t, "phil", users[0].Name)
|
||||
require.Equal(t, "ben", users[1].Name)
|
||||
require.Equal(t, user.RoleUser, users[1].Role)
|
||||
require.Equal(t, "tier1", users[1].Tier.Code)
|
||||
|
||||
// Change user tier via API
|
||||
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "tier": "tier2"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Check users again
|
||||
users, err = s.userManager.Users()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "tier2", users[1].Tier.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUser_ChangeUserPasswordAndTier(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||
defer s.closeDatabases()
|
||||
|
||||
// Create admin, tier
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "tier1",
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "tier2",
|
||||
}))
|
||||
|
||||
// Create user with tier via API
|
||||
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"ben", "tier": "tier1"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Check users
|
||||
users, err := s.userManager.Users()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 3, len(users))
|
||||
require.Equal(t, "phil", users[0].Name)
|
||||
require.Equal(t, "ben", users[1].Name)
|
||||
require.Equal(t, user.RoleUser, users[1].Role)
|
||||
require.Equal(t, "tier1", users[1].Tier.Code)
|
||||
|
||||
// Change user password and tier via API
|
||||
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "password":"ben-two", "tier": "tier2"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Make sure first password fails
|
||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 401, rr.Code)
|
||||
|
||||
// Try to login with second password
|
||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben-two"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Check new tier
|
||||
users, err = s.userManager.Users()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "tier2", users[1].Tier.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUser_ChangeUserPasswordWithHash(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||
defer s.closeDatabases()
|
||||
|
||||
// Create admin
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
|
||||
// Create user with tier via API
|
||||
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"not-ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Try to login with first password
|
||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "not-ben"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Change user password and tier via API
|
||||
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "hash":"$2a$04$2aPIIqPXQU16OfkSUZH1XOzpu1gsPRKkrfVdFLgWQ.tqb.vtTCuVe"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Try to login with second password
|
||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUser_DontChangeAdminPassword(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||
defer s.closeDatabases()
|
||||
|
||||
// Create admin
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
require.Nil(t, s.userManager.AddUser("admin", "admin", user.RoleAdmin, false))
|
||||
|
||||
// Try to change password via API
|
||||
rr := request(t, s, "PUT", "/v1/users", `{"username": "admin", "password": "admin-new"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 403, rr.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUser_AddRemove_Failures(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||
defer s.closeDatabases()
|
||||
|
||||
// Create admin
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
||||
|
||||
// Cannot create user with invalid username
|
||||
rr := request(t, s, "POST", "/v1/users", `{"username": "not valid", "password":"ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 400, rr.Code)
|
||||
|
||||
// Cannot create user if user already exists
|
||||
rr = request(t, s, "POST", "/v1/users", `{"username": "phil", "password":"phil"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 40901, toHTTPError(t, rr.Body.String()).Code)
|
||||
|
||||
// Cannot create user with invalid tier
|
||||
rr = request(t, s, "POST", "/v1/users", `{"username": "emma", "password":"emma", "tier": "invalid"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 40030, toHTTPError(t, rr.Body.String()).Code)
|
||||
|
||||
// Cannot delete user as non-admin
|
||||
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 401, rr.Code)
|
||||
|
||||
// Delete user via API
|
||||
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccess_AllowReset(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T) {
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.AuthDefault = user.PermissionDenyAll
|
||||
s := newTestServer(t, c)
|
||||
defer s.closeDatabases()
|
||||
|
||||
// User and admin
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
||||
|
||||
// Subscribing not allowed
|
||||
rr := request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 403, rr.Code)
|
||||
|
||||
// Grant access
|
||||
rr = request(t, s, "POST", "/v1/users/access", `{"username": "ben", "topic":"gold", "permission":"ro"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Now subscribing is allowed
|
||||
rr = request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Reset access
|
||||
rr = request(t, s, "DELETE", "/v1/users/access", `{"username": "ben", "topic":"gold"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Subscribing not allowed (again)
|
||||
rr = request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 403, rr.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccess_AllowReset_NonAdminAttempt(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T) {
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.AuthDefault = user.PermissionDenyAll
|
||||
s := newTestServer(t, c)
|
||||
defer s.closeDatabases()
|
||||
|
||||
// User
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
||||
|
||||
// Grant access fails, because non-admin
|
||||
rr := request(t, s, "POST", "/v1/users/access", `{"username": "ben", "topic":"gold", "permission":"ro"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 401, rr.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccess_AllowReset_KillConnection(t *testing.T) {
|
||||
forEachBackend(t, func(t *testing.T) {
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.AuthDefault = user.PermissionDenyAll
|
||||
s := newTestServer(t, c)
|
||||
defer s.closeDatabases()
|
||||
|
||||
// User and admin, grant access to "gol*" topics
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.AllowAccess("ben", "gol*", user.PermissionRead)) // Wildcard!
|
||||
|
||||
start, timeTaken := time.Now(), atomic.Int64{}
|
||||
go func() {
|
||||
rr := request(t, s, "GET", "/gold/json", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
timeTaken.Store(time.Since(start).Milliseconds())
|
||||
}()
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Reset access
|
||||
rr := request(t, s, "DELETE", "/v1/users/access", `{"username": "ben", "topic":"gol*"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Wait for connection to be killed; this will fail if the connection is never killed
|
||||
waitFor(t, func() bool {
|
||||
return timeTaken.Load() >= 500
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"firebase.google.com/go/v4/messaging"
|
||||
"fmt"
|
||||
"google.golang.org/api/option"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
"strings"
|
||||
@@ -43,7 +44,7 @@ func newFirebaseClient(sender firebaseSender, auther user.Auther) *firebaseClien
|
||||
}
|
||||
}
|
||||
|
||||
func (c *firebaseClient) Send(v *visitor, m *message) error {
|
||||
func (c *firebaseClient) Send(v *visitor, m *model.Message) error {
|
||||
if !v.FirebaseAllowed() {
|
||||
return errFirebaseTemporarilyBanned
|
||||
}
|
||||
@@ -121,7 +122,7 @@ func (c *firebaseSenderImpl) Send(m *messaging.Message) error {
|
||||
// On Android, this will trigger the app to poll the topic and thereby displaying new messages.
|
||||
// - If UpstreamBaseURL is set, messages are forwarded as poll requests to an upstream server and then forwarded
|
||||
// to Firebase here. This is mainly for iOS to support self-hosted servers.
|
||||
func toFirebaseMessage(m *message, auther user.Auther) (*messaging.Message, error) {
|
||||
func toFirebaseMessage(m *model.Message, auther user.Auther) (*messaging.Message, error) {
|
||||
var data map[string]string // Mostly matches https://ntfy.sh/docs/subscribe/api/#json-message-format
|
||||
var apnsConfig *messaging.APNSConfig
|
||||
switch m.Event {
|
||||
@@ -235,7 +236,7 @@ func maybeTruncateFCMMessage(m *messaging.Message) *messaging.Message {
|
||||
// createAPNSAlertConfig creates an APNS config for iOS notifications that show up as an alert (only relevant for iOS).
|
||||
// We must set the Alert struct ("alert"), and we need to set MutableContent ("mutable-content"), so the Notification Service
|
||||
// Extension in iOS can modify the message.
|
||||
func createAPNSAlertConfig(m *message, data map[string]string) *messaging.APNSConfig {
|
||||
func createAPNSAlertConfig(m *model.Message, data map[string]string) *messaging.APNSConfig {
|
||||
apnsData := make(map[string]any)
|
||||
for k, v := range data {
|
||||
apnsData[k] = v
|
||||
@@ -296,7 +297,7 @@ func maybeTruncateAPNSBodyMessage(s string) string {
|
||||
//
|
||||
// This empties all the fields that are not needed for a poll request and just sets the required fields,
|
||||
// most importantly, the PollID.
|
||||
func toPollRequest(m *message) *message {
|
||||
func toPollRequest(m *model.Message) *model.Message {
|
||||
pr := newPollRequestMessage(m.Topic, m.ID)
|
||||
pr.ID = m.ID
|
||||
pr.Time = m.Time
|
||||
|
||||
@@ -4,6 +4,7 @@ package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
)
|
||||
|
||||
@@ -21,7 +22,7 @@ var (
|
||||
type firebaseClient struct {
|
||||
}
|
||||
|
||||
func (c *firebaseClient) Send(v *visitor, m *message) error {
|
||||
func (c *firebaseClient) Send(v *visitor, m *model.Message) error {
|
||||
return errFirebaseNotAvailable
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
"net/netip"
|
||||
"strings"
|
||||
@@ -131,7 +132,7 @@ func TestToFirebaseMessage_Message_Normal_Allowed(t *testing.T) {
|
||||
m.Click = "https://google.com"
|
||||
m.Icon = "https://ntfy.sh/static/img/ntfy.png"
|
||||
m.Title = "some title"
|
||||
m.Actions = []*action{
|
||||
m.Actions = []*model.Action{
|
||||
{
|
||||
ID: "123",
|
||||
Action: "view",
|
||||
@@ -150,7 +151,7 @@ func TestToFirebaseMessage_Message_Normal_Allowed(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
m.Attachment = &attachment{
|
||||
m.Attachment = &model.Attachment{
|
||||
Name: "some file.jpg",
|
||||
Type: "image/jpeg",
|
||||
Size: 12345,
|
||||
@@ -346,16 +347,16 @@ func TestToFirebaseSender_Abuse(t *testing.T) {
|
||||
client := newFirebaseClient(sender, &testAuther{})
|
||||
visitor := newVisitor(newTestConfig(t), newMemTestCache(t), nil, netip.MustParseAddr("1.2.3.4"), nil)
|
||||
|
||||
require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"}))
|
||||
require.Nil(t, client.Send(visitor, &model.Message{Topic: "mytopic"}))
|
||||
require.Equal(t, 1, len(sender.Messages()))
|
||||
|
||||
require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"}))
|
||||
require.Nil(t, client.Send(visitor, &model.Message{Topic: "mytopic"}))
|
||||
require.Equal(t, 2, len(sender.Messages()))
|
||||
|
||||
require.Equal(t, errFirebaseQuotaExceeded, client.Send(visitor, &message{Topic: "mytopic"}))
|
||||
require.Equal(t, errFirebaseQuotaExceeded, client.Send(visitor, &model.Message{Topic: "mytopic"}))
|
||||
require.Equal(t, 2, len(sender.Messages()))
|
||||
|
||||
sender.messages = make([]*messaging.Message, 0) // Reset to test that time limit is working
|
||||
require.Equal(t, errFirebaseTemporarilyBanned, client.Send(visitor, &message{Topic: "mytopic"}))
|
||||
require.Equal(t, errFirebaseTemporarilyBanned, client.Send(visitor, &model.Message{Topic: "mytopic"}))
|
||||
require.Equal(t, 0, len(sender.Messages()))
|
||||
}
|
||||
|
||||
@@ -6,23 +6,25 @@ import (
|
||||
)
|
||||
|
||||
func TestServer_Manager_Prune_Messages_Without_Attachments_DoesNotPanic(t *testing.T) {
|
||||
// Tests that the manager runs without attachment-cache-dir set, see #617
|
||||
c := newTestConfig(t)
|
||||
c.AttachmentCacheDir = ""
|
||||
s := newTestServer(t, c)
|
||||
forEachBackend(t, func(t *testing.T) {
|
||||
// Tests that the manager runs without attachment-cache-dir set, see #617
|
||||
c := newTestConfig(t)
|
||||
c.AttachmentCacheDir = ""
|
||||
s := newTestServer(t, c)
|
||||
|
||||
// Publish a message
|
||||
rr := request(t, s, "POST", "/mytopic", "hi", nil)
|
||||
require.Equal(t, 200, rr.Code)
|
||||
m := toMessage(t, rr.Body.String())
|
||||
// Publish a message
|
||||
rr := request(t, s, "POST", "/mytopic", "hi", nil)
|
||||
require.Equal(t, 200, rr.Code)
|
||||
m := toMessage(t, rr.Body.String())
|
||||
|
||||
// Expire message
|
||||
require.Nil(t, s.messageCache.ExpireMessages("mytopic"))
|
||||
// Expire message
|
||||
require.Nil(t, s.messageCache.ExpireMessages("mytopic"))
|
||||
|
||||
// Does not panic
|
||||
s.pruneMessages()
|
||||
// Does not panic
|
||||
s.pruneMessages()
|
||||
|
||||
// Actually deleted
|
||||
_, err := s.messageCache.Message(m.ID)
|
||||
require.Equal(t, errMessageNotFound, err)
|
||||
// Actually deleted
|
||||
_, err := s.messageCache.Message(m.ID)
|
||||
require.Equal(t, errMessageNotFound, err)
|
||||
})
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
@@ -11,6 +11,7 @@ import (
|
||||
"text/template"
|
||||
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
@@ -76,7 +77,7 @@ func (s *Server) convertPhoneNumber(u *user.User, phoneNumber string) (string, *
|
||||
|
||||
// callPhone calls the Twilio API to make a phone call to the given phone number, using the given message.
|
||||
// Failures will be logged, but not returned to the caller.
|
||||
func (s *Server) callPhone(v *visitor, r *http.Request, m *message, to string) {
|
||||
func (s *Server) callPhone(v *visitor, r *http.Request, m *model.Message, to string) {
|
||||
u, sender := v.User(), m.Sender.String()
|
||||
if u != nil {
|
||||
sender = u.Name
|
||||
|
||||
@@ -14,217 +14,224 @@ import (
|
||||
)
|
||||
|
||||
func TestServer_Twilio_Call_Add_Verify_Call_Delete_Success(t *testing.T) {
|
||||
var called, verified atomic.Bool
|
||||
var code atomic.Pointer[string]
|
||||
twilioVerifyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
|
||||
if r.URL.Path == "/v2/Services/VA1234567890/Verifications" {
|
||||
if code.Load() != nil {
|
||||
forEachBackend(t, func(t *testing.T) {
|
||||
var called, verified atomic.Bool
|
||||
var code atomic.Pointer[string]
|
||||
twilioVerifyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
|
||||
if r.URL.Path == "/v2/Services/VA1234567890/Verifications" {
|
||||
if code.Load() != nil {
|
||||
t.Fatal("Should be only called once")
|
||||
}
|
||||
require.Equal(t, "Channel=sms&To=%2B12223334444", string(body))
|
||||
code.Store(util.String("123456"))
|
||||
} else if r.URL.Path == "/v2/Services/VA1234567890/VerificationCheck" {
|
||||
if verified.Load() {
|
||||
t.Fatal("Should be only called once")
|
||||
}
|
||||
require.Equal(t, "Code=123456&To=%2B12223334444", string(body))
|
||||
verified.Store(true)
|
||||
} else {
|
||||
t.Fatal("Unexpected path:", r.URL.Path)
|
||||
}
|
||||
}))
|
||||
defer twilioVerifyServer.Close()
|
||||
twilioCallsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if called.Load() {
|
||||
t.Fatal("Should be only called once")
|
||||
}
|
||||
require.Equal(t, "Channel=sms&To=%2B12223334444", string(body))
|
||||
code.Store(util.String("123456"))
|
||||
} else if r.URL.Path == "/v2/Services/VA1234567890/VerificationCheck" {
|
||||
if verified.Load() {
|
||||
t.Fatal("Should be only called once")
|
||||
}
|
||||
require.Equal(t, "Code=123456&To=%2B12223334444", string(body))
|
||||
verified.Store(true)
|
||||
} else {
|
||||
t.Fatal("Unexpected path:", r.URL.Path)
|
||||
}
|
||||
}))
|
||||
defer twilioVerifyServer.Close()
|
||||
twilioCallsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if called.Load() {
|
||||
t.Fatal("Should be only called once")
|
||||
}
|
||||
body, err := io.ReadAll(r.Body)
|
||||
body, err := io.ReadAll(r.Body)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
|
||||
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
|
||||
require.Equal(t, "From=%2B1234567890&To=%2B12223334444&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+loop%3D%223%22%3E%0A%09%09You+have+a+message+from+notify+on+topic+mytopic.+Message%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09End+of+message.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09This+message+was+sent+by+user+phil.+It+will+be+repeated+three+times.%0A%09%09To+unsubscribe+from+calls+like+this%2C+remove+your+phone+number+in+the+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay%3EGoodbye.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
|
||||
called.Store(true)
|
||||
}))
|
||||
defer twilioCallsServer.Close()
|
||||
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.TwilioVerifyBaseURL = twilioVerifyServer.URL
|
||||
c.TwilioCallsBaseURL = twilioCallsServer.URL
|
||||
c.TwilioAccount = "AC1234567890"
|
||||
c.TwilioAuthToken = "AAEAA1234567890"
|
||||
c.TwilioPhoneNumber = "+1234567890"
|
||||
c.TwilioVerifyService = "VA1234567890"
|
||||
s := newTestServer(t, c)
|
||||
|
||||
// Add tier and user
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "pro",
|
||||
MessageLimit: 10,
|
||||
CallLimit: 1,
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
u, err := s.userManager.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
|
||||
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
|
||||
require.Equal(t, "From=%2B1234567890&To=%2B12223334444&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+loop%3D%223%22%3E%0A%09%09You+have+a+message+from+notify+on+topic+mytopic.+Message%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09End+of+message.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09This+message+was+sent+by+user+phil.+It+will+be+repeated+three+times.%0A%09%09To+unsubscribe+from+calls+like+this%2C+remove+your+phone+number+in+the+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay%3EGoodbye.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
|
||||
called.Store(true)
|
||||
}))
|
||||
defer twilioCallsServer.Close()
|
||||
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.TwilioVerifyBaseURL = twilioVerifyServer.URL
|
||||
c.TwilioCallsBaseURL = twilioCallsServer.URL
|
||||
c.TwilioAccount = "AC1234567890"
|
||||
c.TwilioAuthToken = "AAEAA1234567890"
|
||||
c.TwilioPhoneNumber = "+1234567890"
|
||||
c.TwilioVerifyService = "VA1234567890"
|
||||
s := newTestServer(t, c)
|
||||
// Send verification code for phone number
|
||||
response := request(t, s, "PUT", "/v1/account/phone/verify", `{"number":"+12223334444","channel":"sms"}`, map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
waitFor(t, func() bool {
|
||||
return *code.Load() == "123456"
|
||||
})
|
||||
|
||||
// Add tier and user
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "pro",
|
||||
MessageLimit: 10,
|
||||
CallLimit: 1,
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
u, err := s.userManager.User("phil")
|
||||
require.Nil(t, err)
|
||||
// Add phone number with code
|
||||
response = request(t, s, "PUT", "/v1/account/phone", `{"number":"+12223334444","code":"123456"}`, map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
waitFor(t, func() bool {
|
||||
return verified.Load()
|
||||
})
|
||||
phoneNumbers, err := s.userManager.PhoneNumbers(u.ID)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(phoneNumbers))
|
||||
require.Equal(t, "+12223334444", phoneNumbers[0])
|
||||
|
||||
// Send verification code for phone number
|
||||
response := request(t, s, "PUT", "/v1/account/phone/verify", `{"number":"+12223334444","channel":"sms"}`, map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
waitFor(t, func() bool {
|
||||
return *code.Load() == "123456"
|
||||
})
|
||||
// Do the thing
|
||||
response = request(t, s, "POST", "/mytopic", "hi there", map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
"x-call": "yes",
|
||||
})
|
||||
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
|
||||
waitFor(t, func() bool {
|
||||
return called.Load()
|
||||
})
|
||||
|
||||
// Add phone number with code
|
||||
response = request(t, s, "PUT", "/v1/account/phone", `{"number":"+12223334444","code":"123456"}`, map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
waitFor(t, func() bool {
|
||||
return verified.Load()
|
||||
})
|
||||
phoneNumbers, err := s.userManager.PhoneNumbers(u.ID)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(phoneNumbers))
|
||||
require.Equal(t, "+12223334444", phoneNumbers[0])
|
||||
// Remove the phone number
|
||||
response = request(t, s, "DELETE", "/v1/account/phone", `{"number":"+12223334444"}`, map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
|
||||
// Do the thing
|
||||
response = request(t, s, "POST", "/mytopic", "hi there", map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
"x-call": "yes",
|
||||
// Verify the phone number is gone from the DB
|
||||
phoneNumbers, err = s.userManager.PhoneNumbers(u.ID)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 0, len(phoneNumbers))
|
||||
})
|
||||
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
|
||||
waitFor(t, func() bool {
|
||||
return called.Load()
|
||||
})
|
||||
|
||||
// Remove the phone number
|
||||
response = request(t, s, "DELETE", "/v1/account/phone", `{"number":"+12223334444"}`, map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
|
||||
// Verify the phone number is gone from the DB
|
||||
phoneNumbers, err = s.userManager.PhoneNumbers(u.ID)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 0, len(phoneNumbers))
|
||||
}
|
||||
|
||||
func TestServer_Twilio_Call_Success(t *testing.T) {
|
||||
var called atomic.Bool
|
||||
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if called.Load() {
|
||||
t.Fatal("Should be only called once")
|
||||
}
|
||||
body, err := io.ReadAll(r.Body)
|
||||
forEachBackend(t, func(t *testing.T) {
|
||||
var called atomic.Bool
|
||||
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if called.Load() {
|
||||
t.Fatal("Should be only called once")
|
||||
}
|
||||
body, err := io.ReadAll(r.Body)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
|
||||
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
|
||||
require.Equal(t, "From=%2B1234567890&To=%2B11122233344&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+loop%3D%223%22%3E%0A%09%09You+have+a+message+from+notify+on+topic+mytopic.+Message%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09End+of+message.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09This+message+was+sent+by+user+phil.+It+will+be+repeated+three+times.%0A%09%09To+unsubscribe+from+calls+like+this%2C+remove+your+phone+number+in+the+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay%3EGoodbye.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
|
||||
called.Store(true)
|
||||
}))
|
||||
defer twilioServer.Close()
|
||||
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.TwilioCallsBaseURL = twilioServer.URL
|
||||
c.TwilioAccount = "AC1234567890"
|
||||
c.TwilioAuthToken = "AAEAA1234567890"
|
||||
c.TwilioPhoneNumber = "+1234567890"
|
||||
s := newTestServer(t, c)
|
||||
|
||||
// Add tier and user
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "pro",
|
||||
MessageLimit: 10,
|
||||
CallLimit: 1,
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
u, err := s.userManager.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
|
||||
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
|
||||
require.Equal(t, "From=%2B1234567890&To=%2B11122233344&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+loop%3D%223%22%3E%0A%09%09You+have+a+message+from+notify+on+topic+mytopic.+Message%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09End+of+message.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09This+message+was+sent+by+user+phil.+It+will+be+repeated+three+times.%0A%09%09To+unsubscribe+from+calls+like+this%2C+remove+your+phone+number+in+the+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay%3EGoodbye.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
|
||||
called.Store(true)
|
||||
}))
|
||||
defer twilioServer.Close()
|
||||
require.Nil(t, s.userManager.AddPhoneNumber(u.ID, "+11122233344"))
|
||||
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.TwilioCallsBaseURL = twilioServer.URL
|
||||
c.TwilioAccount = "AC1234567890"
|
||||
c.TwilioAuthToken = "AAEAA1234567890"
|
||||
c.TwilioPhoneNumber = "+1234567890"
|
||||
s := newTestServer(t, c)
|
||||
|
||||
// Add tier and user
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "pro",
|
||||
MessageLimit: 10,
|
||||
CallLimit: 1,
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
u, err := s.userManager.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, s.userManager.AddPhoneNumber(u.ID, "+11122233344"))
|
||||
|
||||
// Do the thing
|
||||
response := request(t, s, "POST", "/mytopic", "hi there", map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
"x-call": "+11122233344",
|
||||
})
|
||||
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
|
||||
waitFor(t, func() bool {
|
||||
return called.Load()
|
||||
// Do the thing
|
||||
response := request(t, s, "POST", "/mytopic", "hi there", map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
"x-call": "+11122233344",
|
||||
})
|
||||
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
|
||||
waitFor(t, func() bool {
|
||||
return called.Load()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_Twilio_Call_Success_With_Yes(t *testing.T) {
|
||||
var called atomic.Bool
|
||||
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if called.Load() {
|
||||
t.Fatal("Should be only called once")
|
||||
}
|
||||
body, err := io.ReadAll(r.Body)
|
||||
forEachBackend(t, func(t *testing.T) {
|
||||
var called atomic.Bool
|
||||
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if called.Load() {
|
||||
t.Fatal("Should be only called once")
|
||||
}
|
||||
body, err := io.ReadAll(r.Body)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
|
||||
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
|
||||
require.Equal(t, "From=%2B1234567890&To=%2B11122233344&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+loop%3D%223%22%3E%0A%09%09You+have+a+message+from+notify+on+topic+mytopic.+Message%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09End+of+message.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09This+message+was+sent+by+user+phil.+It+will+be+repeated+three+times.%0A%09%09To+unsubscribe+from+calls+like+this%2C+remove+your+phone+number+in+the+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay%3EGoodbye.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
|
||||
called.Store(true)
|
||||
}))
|
||||
defer twilioServer.Close()
|
||||
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.TwilioCallsBaseURL = twilioServer.URL
|
||||
c.TwilioAccount = "AC1234567890"
|
||||
c.TwilioAuthToken = "AAEAA1234567890"
|
||||
c.TwilioPhoneNumber = "+1234567890"
|
||||
s := newTestServer(t, c)
|
||||
|
||||
// Add tier and user
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "pro",
|
||||
MessageLimit: 10,
|
||||
CallLimit: 1,
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
u, err := s.userManager.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
|
||||
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
|
||||
require.Equal(t, "From=%2B1234567890&To=%2B11122233344&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+loop%3D%223%22%3E%0A%09%09You+have+a+message+from+notify+on+topic+mytopic.+Message%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09End+of+message.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09This+message+was+sent+by+user+phil.+It+will+be+repeated+three+times.%0A%09%09To+unsubscribe+from+calls+like+this%2C+remove+your+phone+number+in+the+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay%3EGoodbye.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
|
||||
called.Store(true)
|
||||
}))
|
||||
defer twilioServer.Close()
|
||||
require.Nil(t, s.userManager.AddPhoneNumber(u.ID, "+11122233344"))
|
||||
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.TwilioCallsBaseURL = twilioServer.URL
|
||||
c.TwilioAccount = "AC1234567890"
|
||||
c.TwilioAuthToken = "AAEAA1234567890"
|
||||
c.TwilioPhoneNumber = "+1234567890"
|
||||
s := newTestServer(t, c)
|
||||
|
||||
// Add tier and user
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "pro",
|
||||
MessageLimit: 10,
|
||||
CallLimit: 1,
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
u, err := s.userManager.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, s.userManager.AddPhoneNumber(u.ID, "+11122233344"))
|
||||
|
||||
// Do the thing
|
||||
response := request(t, s, "POST", "/mytopic", "hi there", map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
"x-call": "yes", // <<<------
|
||||
})
|
||||
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
|
||||
waitFor(t, func() bool {
|
||||
return called.Load()
|
||||
// Do the thing
|
||||
response := request(t, s, "POST", "/mytopic", "hi there", map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
"x-call": "yes", // <<<------
|
||||
})
|
||||
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
|
||||
waitFor(t, func() bool {
|
||||
return called.Load()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_Twilio_Call_Success_with_custom_twiml(t *testing.T) {
|
||||
var called atomic.Bool
|
||||
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if called.Load() {
|
||||
t.Fatal("Should be only called once")
|
||||
}
|
||||
body, err := io.ReadAll(r.Body)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
|
||||
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
|
||||
require.Equal(t, "From=%2B1234567890&To=%2B11122233344&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+language%3D%22de-DE%22+loop%3D%223%22%3E%0A%09%09Du+hast+eine+Nachricht+von+notify+im+Thema+mytopic.+Nachricht%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09Ende+der+Nachricht.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09Diese+Nachricht+wurde+von+Benutzer+phil+gesendet.+Sie+wird+drei+Mal+wiederholt.%0A%09%09Um+dich+von+Anrufen+wie+diesen+abzumelden%2C+entferne+deine+Telefonnummer+in+der+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay+language%3D%22de-DE%22%3EAuf+Wiederh%C3%B6ren.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
|
||||
called.Store(true)
|
||||
}))
|
||||
defer twilioServer.Close()
|
||||
forEachBackend(t, func(t *testing.T) {
|
||||
var called atomic.Bool
|
||||
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if called.Load() {
|
||||
t.Fatal("Should be only called once")
|
||||
}
|
||||
body, err := io.ReadAll(r.Body)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
|
||||
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
|
||||
require.Equal(t, "From=%2B1234567890&To=%2B11122233344&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+language%3D%22de-DE%22+loop%3D%223%22%3E%0A%09%09Du+hast+eine+Nachricht+von+notify+im+Thema+mytopic.+Nachricht%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09Ende+der+Nachricht.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09Diese+Nachricht+wurde+von+Benutzer+phil+gesendet.+Sie+wird+drei+Mal+wiederholt.%0A%09%09Um+dich+von+Anrufen+wie+diesen+abzumelden%2C+entferne+deine+Telefonnummer+in+der+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay+language%3D%22de-DE%22%3EAuf+Wiederh%C3%B6ren.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
|
||||
called.Store(true)
|
||||
}))
|
||||
defer twilioServer.Close()
|
||||
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.TwilioCallsBaseURL = twilioServer.URL
|
||||
c.TwilioAccount = "AC1234567890"
|
||||
c.TwilioAuthToken = "AAEAA1234567890"
|
||||
c.TwilioPhoneNumber = "+1234567890"
|
||||
c.TwilioCallFormat = template.Must(template.New("twiml").Parse(`
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.TwilioCallsBaseURL = twilioServer.URL
|
||||
c.TwilioAccount = "AC1234567890"
|
||||
c.TwilioAuthToken = "AAEAA1234567890"
|
||||
c.TwilioPhoneNumber = "+1234567890"
|
||||
c.TwilioCallFormat = template.Must(template.New("twiml").Parse(`
|
||||
<Response>
|
||||
<Pause length="1"/>
|
||||
<Say language="de-DE" loop="3">
|
||||
@@ -240,88 +247,97 @@ func TestServer_Twilio_Call_Success_with_custom_twiml(t *testing.T) {
|
||||
</Say>
|
||||
<Say language="de-DE">Auf Wiederhören.</Say>
|
||||
</Response>`))
|
||||
s := newTestServer(t, c)
|
||||
s := newTestServer(t, c)
|
||||
|
||||
// Add tier and user
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "pro",
|
||||
MessageLimit: 10,
|
||||
CallLimit: 1,
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
u, err := s.userManager.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, s.userManager.AddPhoneNumber(u.ID, "+11122233344"))
|
||||
// Add tier and user
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "pro",
|
||||
MessageLimit: 10,
|
||||
CallLimit: 1,
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
u, err := s.userManager.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, s.userManager.AddPhoneNumber(u.ID, "+11122233344"))
|
||||
|
||||
// Do the thing
|
||||
response := request(t, s, "POST", "/mytopic", "hi there", map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
"x-call": "+11122233344",
|
||||
})
|
||||
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
|
||||
waitFor(t, func() bool {
|
||||
return called.Load()
|
||||
// Do the thing
|
||||
response := request(t, s, "POST", "/mytopic", "hi there", map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
"x-call": "+11122233344",
|
||||
})
|
||||
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
|
||||
waitFor(t, func() bool {
|
||||
return called.Load()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_Twilio_Call_UnverifiedNumber(t *testing.T) {
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.TwilioCallsBaseURL = "http://dummy.invalid"
|
||||
c.TwilioAccount = "AC1234567890"
|
||||
c.TwilioAuthToken = "AAEAA1234567890"
|
||||
c.TwilioPhoneNumber = "+1234567890"
|
||||
s := newTestServer(t, c)
|
||||
forEachBackend(t, func(t *testing.T) {
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.TwilioCallsBaseURL = "http://dummy.invalid"
|
||||
c.TwilioAccount = "AC1234567890"
|
||||
c.TwilioAuthToken = "AAEAA1234567890"
|
||||
c.TwilioPhoneNumber = "+1234567890"
|
||||
s := newTestServer(t, c)
|
||||
|
||||
// Add tier and user
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "pro",
|
||||
MessageLimit: 10,
|
||||
CallLimit: 1,
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
// Add tier and user
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "pro",
|
||||
MessageLimit: 10,
|
||||
CallLimit: 1,
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
|
||||
// Do the thing
|
||||
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
"x-call": "+11122233344",
|
||||
// Do the thing
|
||||
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
"x-call": "+11122233344",
|
||||
})
|
||||
require.Equal(t, 40034, toHTTPError(t, response.Body.String()).Code)
|
||||
})
|
||||
require.Equal(t, 40034, toHTTPError(t, response.Body.String()).Code)
|
||||
}
|
||||
|
||||
func TestServer_Twilio_Call_InvalidNumber(t *testing.T) {
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.TwilioCallsBaseURL = "https://127.0.0.1"
|
||||
c.TwilioAccount = "AC1234567890"
|
||||
c.TwilioAuthToken = "AAEAA1234567890"
|
||||
c.TwilioPhoneNumber = "+1234567890"
|
||||
s := newTestServer(t, c)
|
||||
forEachBackend(t, func(t *testing.T) {
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.TwilioCallsBaseURL = "https://127.0.0.1"
|
||||
c.TwilioAccount = "AC1234567890"
|
||||
c.TwilioAuthToken = "AAEAA1234567890"
|
||||
c.TwilioPhoneNumber = "+1234567890"
|
||||
s := newTestServer(t, c)
|
||||
|
||||
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
|
||||
"x-call": "+invalid",
|
||||
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
|
||||
"x-call": "+invalid",
|
||||
})
|
||||
require.Equal(t, 40033, toHTTPError(t, response.Body.String()).Code)
|
||||
})
|
||||
require.Equal(t, 40033, toHTTPError(t, response.Body.String()).Code)
|
||||
}
|
||||
|
||||
func TestServer_Twilio_Call_Anonymous(t *testing.T) {
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.TwilioCallsBaseURL = "https://127.0.0.1"
|
||||
c.TwilioAccount = "AC1234567890"
|
||||
c.TwilioAuthToken = "AAEAA1234567890"
|
||||
c.TwilioPhoneNumber = "+1234567890"
|
||||
s := newTestServer(t, c)
|
||||
forEachBackend(t, func(t *testing.T) {
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.TwilioCallsBaseURL = "https://127.0.0.1"
|
||||
c.TwilioAccount = "AC1234567890"
|
||||
c.TwilioAuthToken = "AAEAA1234567890"
|
||||
c.TwilioPhoneNumber = "+1234567890"
|
||||
s := newTestServer(t, c)
|
||||
|
||||
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
|
||||
"x-call": "+123123",
|
||||
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
|
||||
"x-call": "+123123",
|
||||
})
|
||||
require.Equal(t, 40035, toHTTPError(t, response.Body.String()).Code)
|
||||
})
|
||||
require.Equal(t, 40035, toHTTPError(t, response.Body.String()).Code)
|
||||
}
|
||||
|
||||
func TestServer_Twilio_Call_Unconfigured(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfig(t))
|
||||
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
|
||||
"x-call": "+1234",
|
||||
forEachBackend(t, func(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfig(t))
|
||||
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
|
||||
"x-call": "+1234",
|
||||
})
|
||||
require.Equal(t, 40032, toHTTPError(t, response.Body.String()).Code)
|
||||
})
|
||||
require.Equal(t, 40032, toHTTPError(t, response.Body.String()).Code)
|
||||
}
|
||||
|
||||
@@ -11,7 +11,9 @@ import (
|
||||
|
||||
"github.com/SherClockHolmes/webpush-go"
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
wpush "heckel.io/ntfy/v2/webpush"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -82,14 +84,14 @@ func (s *Server) handleWebPushDelete(w http.ResponseWriter, r *http.Request, _ *
|
||||
return s.writeJSON(w, newSuccessResponse())
|
||||
}
|
||||
|
||||
func (s *Server) publishToWebPushEndpoints(v *visitor, m *message) {
|
||||
func (s *Server) publishToWebPushEndpoints(v *visitor, m *model.Message) {
|
||||
subscriptions, err := s.webPush.SubscriptionsForTopic(m.Topic)
|
||||
if err != nil {
|
||||
logvm(v, m).Err(err).With(v, m).Warn("Unable to publish web push messages")
|
||||
return
|
||||
}
|
||||
log.Tag(tagWebPush).With(v, m).Debug("Publishing web push message to %d subscribers", len(subscriptions))
|
||||
payload, err := json.Marshal(newWebPushPayload(fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic), m.forJSON()))
|
||||
payload, err := json.Marshal(newWebPushPayload(fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic), m.ForJSON()))
|
||||
if err != nil {
|
||||
log.Tag(tagWebPush).Err(err).With(v, m).Warn("Unable to marshal expiring payload")
|
||||
return
|
||||
@@ -128,7 +130,7 @@ func (s *Server) pruneAndNotifyWebPushSubscriptionsInternal() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
warningSent := make([]*webPushSubscription, 0)
|
||||
warningSent := make([]*wpush.Subscription, 0)
|
||||
for _, subscription := range subscriptions {
|
||||
if err := s.sendWebPushNotification(subscription, payload); err != nil {
|
||||
log.Tag(tagWebPush).Err(err).With(subscription).Warn("Unable to publish expiry imminent warning")
|
||||
@@ -143,7 +145,7 @@ func (s *Server) pruneAndNotifyWebPushSubscriptionsInternal() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) sendWebPushNotification(sub *webPushSubscription, message []byte, contexters ...log.Contexter) error {
|
||||
func (s *Server) sendWebPushNotification(sub *wpush.Subscription, message []byte, contexters ...log.Contexter) error {
|
||||
log.Tag(tagWebPush).With(sub).With(contexters...).Debug("Sending web push message")
|
||||
payload := &webpush.Subscription{
|
||||
Endpoint: sub.Endpoint,
|
||||
|
||||
@@ -4,6 +4,8 @@ package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"heckel.io/ntfy/v2/model"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -20,7 +22,7 @@ func (s *Server) handleWebPushDelete(w http.ResponseWriter, r *http.Request, _ *
|
||||
return errHTTPNotFound
|
||||
}
|
||||
|
||||
func (s *Server) publishToWebPushEndpoints(v *visitor, m *message) {
|
||||
func (s *Server) publishToWebPushEndpoints(v *visitor, m *model.Message) {
|
||||
// Nothing to see here
|
||||
}
|
||||
|
||||
|
||||
@@ -5,10 +5,6 @@ package server
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/SherClockHolmes/webpush-go"
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -18,6 +14,11 @@ import (
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/SherClockHolmes/webpush-go"
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -25,237 +26,262 @@ const (
|
||||
)
|
||||
|
||||
func TestServer_WebPush_Enabled(t *testing.T) {
|
||||
conf := newTestConfig(t)
|
||||
conf.WebRoot = "" // Disable web app
|
||||
s := newTestServer(t, conf)
|
||||
forEachBackend(t, func(t *testing.T) {
|
||||
conf := newTestConfig(t)
|
||||
conf.WebRoot = "" // Disable web app
|
||||
s := newTestServer(t, conf)
|
||||
|
||||
rr := request(t, s, "GET", "/manifest.webmanifest", "", nil)
|
||||
require.Equal(t, 404, rr.Code)
|
||||
rr := request(t, s, "GET", "/manifest.webmanifest", "", nil)
|
||||
require.Equal(t, 404, rr.Code)
|
||||
|
||||
conf2 := newTestConfig(t)
|
||||
s2 := newTestServer(t, conf2)
|
||||
conf2 := newTestConfig(t)
|
||||
s2 := newTestServer(t, conf2)
|
||||
|
||||
rr = request(t, s2, "GET", "/manifest.webmanifest", "", nil)
|
||||
require.Equal(t, 404, rr.Code)
|
||||
rr = request(t, s2, "GET", "/manifest.webmanifest", "", nil)
|
||||
require.Equal(t, 404, rr.Code)
|
||||
|
||||
conf3 := newTestConfigWithWebPush(t)
|
||||
s3 := newTestServer(t, conf3)
|
||||
conf3 := newTestConfigWithWebPush(t)
|
||||
s3 := newTestServer(t, conf3)
|
||||
|
||||
rr = request(t, s3, "GET", "/manifest.webmanifest", "", nil)
|
||||
require.Equal(t, 200, rr.Code)
|
||||
require.Equal(t, "application/manifest+json", rr.Header().Get("Content-Type"))
|
||||
rr = request(t, s3, "GET", "/manifest.webmanifest", "", nil)
|
||||
require.Equal(t, 200, rr.Code)
|
||||
require.Equal(t, "application/manifest+json", rr.Header().Get("Content-Type"))
|
||||
|
||||
})
|
||||
}
|
||||
func TestServer_WebPush_Disabled(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfig(t))
|
||||
forEachBackend(t, func(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfig(t))
|
||||
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
|
||||
require.Equal(t, 404, response.Code)
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
|
||||
require.Equal(t, 404, response.Code)
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_WebPush_TopicAdd(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
forEachBackend(t, func(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
||||
|
||||
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
|
||||
require.Nil(t, err)
|
||||
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Len(t, subs, 1)
|
||||
require.Equal(t, subs[0].Endpoint, testWebPushEndpoint)
|
||||
require.Equal(t, subs[0].P256dh, "p256dh-key")
|
||||
require.Equal(t, subs[0].Auth, "auth-key")
|
||||
require.Equal(t, subs[0].UserID, "")
|
||||
require.Len(t, subs, 1)
|
||||
require.Equal(t, subs[0].Endpoint, testWebPushEndpoint)
|
||||
require.Equal(t, subs[0].P256dh, "p256dh-key")
|
||||
require.Equal(t, subs[0].Auth, "auth-key")
|
||||
require.Equal(t, subs[0].UserID, "")
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_WebPush_TopicAdd_InvalidEndpoint(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
forEachBackend(t, func(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, "https://ddos-target.example.com/webpush"), nil)
|
||||
require.Equal(t, 400, response.Code)
|
||||
require.Equal(t, `{"code":40039,"http":400,"error":"invalid request: web push endpoint unknown"}`+"\n", response.Body.String())
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, "https://ddos-target.example.com/webpush"), nil)
|
||||
require.Equal(t, 400, response.Code)
|
||||
require.Equal(t, `{"code":40039,"http":400,"error":"invalid request: web push endpoint unknown"}`+"\n", response.Body.String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_WebPush_TopicAdd_TooManyTopics(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
forEachBackend(t, func(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
|
||||
topicList := make([]string, 51)
|
||||
for i := range topicList {
|
||||
topicList[i] = util.RandomString(5)
|
||||
}
|
||||
topicList := make([]string, 51)
|
||||
for i := range topicList {
|
||||
topicList[i] = util.RandomString(5)
|
||||
}
|
||||
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, topicList, testWebPushEndpoint), nil)
|
||||
require.Equal(t, 400, response.Code)
|
||||
require.Equal(t, `{"code":40040,"http":400,"error":"invalid request: too many web push topic subscriptions"}`+"\n", response.Body.String())
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, topicList, testWebPushEndpoint), nil)
|
||||
require.Equal(t, 400, response.Code)
|
||||
require.Equal(t, `{"code":40040,"http":400,"error":"invalid request: too many web push topic subscriptions"}`+"\n", response.Body.String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_WebPush_TopicUnsubscribe(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
forEachBackend(t, func(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
|
||||
addSubscription(t, s, testWebPushEndpoint, "test-topic")
|
||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||
addSubscription(t, s, testWebPushEndpoint, "test-topic")
|
||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{}, testWebPushEndpoint), nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{}, testWebPushEndpoint), nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
||||
|
||||
requireSubscriptionCount(t, s, "test-topic", 0)
|
||||
requireSubscriptionCount(t, s, "test-topic", 0)
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_WebPush_Delete(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
forEachBackend(t, func(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
|
||||
addSubscription(t, s, testWebPushEndpoint, "test-topic")
|
||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||
addSubscription(t, s, testWebPushEndpoint, "test-topic")
|
||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||
|
||||
response := request(t, s, "DELETE", "/v1/webpush", fmt.Sprintf(`{"endpoint":"%s"}`, testWebPushEndpoint), nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
||||
response := request(t, s, "DELETE", "/v1/webpush", fmt.Sprintf(`{"endpoint":"%s"}`, testWebPushEndpoint), nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
||||
|
||||
requireSubscriptionCount(t, s, "test-topic", 0)
|
||||
requireSubscriptionCount(t, s, "test-topic", 0)
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_WebPush_TopicSubscribeProtected_Allowed(t *testing.T) {
|
||||
config := configureAuth(t, newTestConfigWithWebPush(t))
|
||||
config.AuthDefault = user.PermissionDenyAll
|
||||
s := newTestServer(t, config)
|
||||
forEachBackend(t, func(t *testing.T) {
|
||||
config := configureAuth(t, newTestConfigWithWebPush(t))
|
||||
config.AuthDefault = user.PermissionDenyAll
|
||||
s := newTestServer(t, config)
|
||||
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.AllowAccess("ben", "test-topic", user.PermissionReadWrite))
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.AllowAccess("ben", "test-topic", user.PermissionReadWrite))
|
||||
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
||||
|
||||
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
require.True(t, strings.HasPrefix(subs[0].UserID, "u_"))
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
||||
|
||||
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
require.True(t, strings.HasPrefix(subs[0].UserID, "u_"))
|
||||
}
|
||||
|
||||
func TestServer_WebPush_TopicSubscribeProtected_Denied(t *testing.T) {
|
||||
config := configureAuth(t, newTestConfigWithWebPush(t))
|
||||
config.AuthDefault = user.PermissionDenyAll
|
||||
s := newTestServer(t, config)
|
||||
forEachBackend(t, func(t *testing.T) {
|
||||
config := configureAuth(t, newTestConfigWithWebPush(t))
|
||||
config.AuthDefault = user.PermissionDenyAll
|
||||
s := newTestServer(t, config)
|
||||
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
|
||||
require.Equal(t, 403, response.Code)
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
|
||||
require.Equal(t, 403, response.Code)
|
||||
|
||||
requireSubscriptionCount(t, s, "test-topic", 0)
|
||||
requireSubscriptionCount(t, s, "test-topic", 0)
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_WebPush_DeleteAccountUnsubscribe(t *testing.T) {
|
||||
config := configureAuth(t, newTestConfigWithWebPush(t))
|
||||
s := newTestServer(t, config)
|
||||
forEachBackend(t, func(t *testing.T) {
|
||||
config := configureAuth(t, newTestConfigWithWebPush(t))
|
||||
s := newTestServer(t, config)
|
||||
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.AllowAccess("ben", "test-topic", user.PermissionReadWrite))
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
||||
require.Nil(t, s.userManager.AllowAccess("ben", "test-topic", user.PermissionReadWrite))
|
||||
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
||||
|
||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||
|
||||
request(t, s, "DELETE", "/v1/account", `{"password":"ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
// should've been deleted with the account
|
||||
requireSubscriptionCount(t, s, "test-topic", 0)
|
||||
})
|
||||
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
||||
|
||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||
|
||||
request(t, s, "DELETE", "/v1/account", `{"password":"ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
// should've been deleted with the account
|
||||
requireSubscriptionCount(t, s, "test-topic", 0)
|
||||
}
|
||||
|
||||
func TestServer_WebPush_Publish(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
forEachBackend(t, func(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
|
||||
var received atomic.Bool
|
||||
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := io.ReadAll(r.Body)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "/push-receive", r.URL.Path)
|
||||
require.Equal(t, "high", r.Header.Get("Urgency"))
|
||||
require.Equal(t, "", r.Header.Get("Topic"))
|
||||
received.Store(true)
|
||||
}))
|
||||
defer pushService.Close()
|
||||
var received atomic.Bool
|
||||
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := io.ReadAll(r.Body)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "/push-receive", r.URL.Path)
|
||||
require.Equal(t, "high", r.Header.Get("Urgency"))
|
||||
require.Equal(t, "", r.Header.Get("Topic"))
|
||||
received.Store(true)
|
||||
}))
|
||||
defer pushService.Close()
|
||||
|
||||
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic")
|
||||
request(t, s, "POST", "/test-topic", "web push test", nil)
|
||||
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic")
|
||||
request(t, s, "POST", "/test-topic", "web push test", nil)
|
||||
|
||||
waitFor(t, func() bool {
|
||||
return received.Load()
|
||||
waitFor(t, func() bool {
|
||||
return received.Load()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_WebPush_Publish_RemoveOnError(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
forEachBackend(t, func(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
|
||||
var received atomic.Bool
|
||||
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := io.ReadAll(r.Body)
|
||||
require.Nil(t, err)
|
||||
w.WriteHeader(http.StatusGone)
|
||||
received.Store(true)
|
||||
}))
|
||||
defer pushService.Close()
|
||||
var received atomic.Bool
|
||||
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := io.ReadAll(r.Body)
|
||||
require.Nil(t, err)
|
||||
w.WriteHeader(http.StatusGone)
|
||||
received.Store(true)
|
||||
}))
|
||||
defer pushService.Close()
|
||||
|
||||
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic", "test-topic-abc")
|
||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||
requireSubscriptionCount(t, s, "test-topic-abc", 1)
|
||||
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic", "test-topic-abc")
|
||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||
requireSubscriptionCount(t, s, "test-topic-abc", 1)
|
||||
|
||||
request(t, s, "POST", "/test-topic", "web push test", nil)
|
||||
request(t, s, "POST", "/test-topic", "web push test", nil)
|
||||
|
||||
waitFor(t, func() bool {
|
||||
return received.Load()
|
||||
waitFor(t, func() bool {
|
||||
return received.Load()
|
||||
})
|
||||
|
||||
// Receiving the 410 should've caused the publisher to expire all subscriptions on the endpoint
|
||||
|
||||
requireSubscriptionCount(t, s, "test-topic", 0)
|
||||
requireSubscriptionCount(t, s, "test-topic-abc", 0)
|
||||
})
|
||||
|
||||
// Receiving the 410 should've caused the publisher to expire all subscriptions on the endpoint
|
||||
|
||||
requireSubscriptionCount(t, s, "test-topic", 0)
|
||||
requireSubscriptionCount(t, s, "test-topic-abc", 0)
|
||||
}
|
||||
|
||||
func TestServer_WebPush_Expiry(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
forEachBackend(t, func(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
|
||||
var received atomic.Bool
|
||||
var received atomic.Bool
|
||||
|
||||
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := io.ReadAll(r.Body)
|
||||
require.Nil(t, err)
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(``))
|
||||
received.Store(true)
|
||||
}))
|
||||
defer pushService.Close()
|
||||
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := io.ReadAll(r.Body)
|
||||
require.Nil(t, err)
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(``))
|
||||
received.Store(true)
|
||||
}))
|
||||
defer pushService.Close()
|
||||
|
||||
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic")
|
||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||
endpoint := pushService.URL + "/push-receive"
|
||||
addSubscription(t, s, endpoint, "test-topic")
|
||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||
|
||||
_, err := s.webPush.db.Exec("UPDATE subscription SET updated_at = ?", time.Now().Add(-55*24*time.Hour).Unix())
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, s.webPush.SetSubscriptionUpdatedAt(endpoint, time.Now().Add(-55*24*time.Hour).Unix()))
|
||||
|
||||
s.pruneAndNotifyWebPushSubscriptions()
|
||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||
s.pruneAndNotifyWebPushSubscriptions()
|
||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||
|
||||
waitFor(t, func() bool {
|
||||
return received.Load()
|
||||
})
|
||||
waitFor(t, func() bool {
|
||||
return received.Load()
|
||||
})
|
||||
|
||||
_, err = s.webPush.db.Exec("UPDATE subscription SET updated_at = ?", time.Now().Add(-60*24*time.Hour).Unix())
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, s.webPush.SetSubscriptionUpdatedAt(endpoint, time.Now().Add(-60*24*time.Hour).Unix()))
|
||||
|
||||
s.pruneAndNotifyWebPushSubscriptions()
|
||||
waitFor(t, func() bool {
|
||||
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
|
||||
require.Nil(t, err)
|
||||
return len(subs) == 0
|
||||
s.pruneAndNotifyWebPushSubscriptions()
|
||||
waitFor(t, func() bool {
|
||||
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
|
||||
require.Nil(t, err)
|
||||
return len(subs) == 0
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -285,7 +311,9 @@ func newTestConfigWithWebPush(t *testing.T) *Config {
|
||||
conf := newTestConfig(t)
|
||||
privateKey, publicKey, err := webpush.GenerateVAPIDKeys()
|
||||
require.Nil(t, err)
|
||||
conf.WebPushFile = filepath.Join(t.TempDir(), "webpush.db")
|
||||
if conf.DatabaseURL == "" {
|
||||
conf.WebPushFile = filepath.Join(t.TempDir(), "webpush.db")
|
||||
}
|
||||
conf.WebPushEmailAddress = "testing@example.com"
|
||||
conf.WebPushPrivateKey = privateKey
|
||||
conf.WebPushPublicKey = publicKey
|
||||
|
||||
@@ -12,11 +12,12 @@ import (
|
||||
"time"
|
||||
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
type mailer interface {
|
||||
Send(v *visitor, m *message, to string) error
|
||||
Send(v *visitor, m *model.Message, to string) error
|
||||
Counts() (total int64, success int64, failure int64)
|
||||
}
|
||||
|
||||
@@ -27,7 +28,7 @@ type smtpSender struct {
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (s *smtpSender) Send(v *visitor, m *message, to string) error {
|
||||
func (s *smtpSender) Send(v *visitor, m *model.Message, to string) error {
|
||||
return s.withCount(v, m, func() error {
|
||||
host, _, err := net.SplitHostPort(s.config.SMTPSenderAddr)
|
||||
if err != nil {
|
||||
@@ -63,7 +64,7 @@ func (s *smtpSender) Counts() (total int64, success int64, failure int64) {
|
||||
return s.success + s.failure, s.success, s.failure
|
||||
}
|
||||
|
||||
func (s *smtpSender) withCount(v *visitor, m *message, fn func() error) error {
|
||||
func (s *smtpSender) withCount(v *visitor, m *model.Message, fn func() error) error {
|
||||
err := fn()
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
@@ -76,7 +77,7 @@ func (s *smtpSender) withCount(v *visitor, m *message, fn func() error) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func formatMail(baseURL, senderIP, from, to string, m *message) (string, error) {
|
||||
func formatMail(baseURL, senderIP, from, to string, m *model.Message) (string, error) {
|
||||
topicURL := baseURL + "/" + m.Topic
|
||||
subject := m.Title
|
||||
if subject == "" {
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
)
|
||||
|
||||
func TestFormatMail_Basic(t *testing.T) {
|
||||
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &message{
|
||||
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &model.Message{
|
||||
ID: "abc",
|
||||
Time: 1640382204,
|
||||
Event: "message",
|
||||
@@ -27,7 +29,7 @@ This message was sent by 1.2.3.4 at Fri, 24 Dec 2021 21:43:24 UTC via https://nt
|
||||
}
|
||||
|
||||
func TestFormatMail_JustEmojis(t *testing.T) {
|
||||
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &message{
|
||||
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &model.Message{
|
||||
ID: "abc",
|
||||
Time: 1640382204,
|
||||
Event: "message",
|
||||
@@ -49,7 +51,7 @@ This message was sent by 1.2.3.4 at Fri, 24 Dec 2021 21:43:24 UTC via https://nt
|
||||
}
|
||||
|
||||
func TestFormatMail_JustOtherTags(t *testing.T) {
|
||||
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &message{
|
||||
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &model.Message{
|
||||
ID: "abc",
|
||||
Time: 1640382204,
|
||||
Event: "message",
|
||||
@@ -73,7 +75,7 @@ This message was sent by 1.2.3.4 at Fri, 24 Dec 2021 21:43:24 UTC via https://nt
|
||||
}
|
||||
|
||||
func TestFormatMail_JustPriority(t *testing.T) {
|
||||
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &message{
|
||||
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &model.Message{
|
||||
ID: "abc",
|
||||
Time: 1640382204,
|
||||
Event: "message",
|
||||
@@ -97,7 +99,7 @@ This message was sent by 1.2.3.4 at Fri, 24 Dec 2021 21:43:24 UTC via https://nt
|
||||
}
|
||||
|
||||
func TestFormatMail_UTF8Subject(t *testing.T) {
|
||||
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &message{
|
||||
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &model.Message{
|
||||
ID: "abc",
|
||||
Time: 1640382204,
|
||||
Event: "message",
|
||||
@@ -119,7 +121,7 @@ This message was sent by 1.2.3.4 at Fri, 24 Dec 2021 21:43:24 UTC via https://nt
|
||||
}
|
||||
|
||||
func TestFormatMail_WithAllTheThings(t *testing.T) {
|
||||
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &message{
|
||||
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &model.Message{
|
||||
ID: "abc",
|
||||
Time: 1640382204,
|
||||
Event: "message",
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
|
||||
"github.com/emersion/go-smtp"
|
||||
"github.com/microcosm-cc/bluemonday"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -183,7 +184,7 @@ func (s *smtpSession) Data(r io.Reader) error {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *smtpSession) publishMessage(m *message) error {
|
||||
func (s *smtpSession) publishMessage(m *model.Message) error {
|
||||
// Extract remote address (for rate limiting)
|
||||
remoteAddr, _, err := net.SplitHostPort(s.conn.Conn().RemoteAddr().String())
|
||||
if err != nil {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
@@ -33,7 +34,7 @@ type topicSubscriber struct {
|
||||
}
|
||||
|
||||
// subscriber is a function that is called for every new message on a topic
|
||||
type subscriber func(v *visitor, msg *message) error
|
||||
type subscriber func(v *visitor, msg *model.Message) error
|
||||
|
||||
// newTopic creates a new topic
|
||||
func newTopic(id string) *topic {
|
||||
@@ -103,7 +104,7 @@ func (t *topic) Unsubscribe(id int) {
|
||||
}
|
||||
|
||||
// Publish asynchronously publishes to all subscribers
|
||||
func (t *topic) Publish(v *visitor, m *message) error {
|
||||
func (t *topic) Publish(v *visitor, m *model.Message) error {
|
||||
go func() {
|
||||
// We want to lock the topic as short as possible, so we make a shallow copy of the
|
||||
// subscribers map here. Actually sending out the messages then doesn't have to lock.
|
||||
|
||||
@@ -7,10 +7,11 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
)
|
||||
|
||||
func TestTopic_CancelSubscribersExceptUser(t *testing.T) {
|
||||
subFn := func(v *visitor, msg *message) error {
|
||||
subFn := func(v *visitor, msg *model.Message) error {
|
||||
return nil
|
||||
}
|
||||
canceled1 := atomic.Bool{}
|
||||
@@ -33,7 +34,7 @@ func TestTopic_CancelSubscribersExceptUser(t *testing.T) {
|
||||
func TestTopic_CancelSubscribersUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
subFn := func(v *visitor, msg *message) error {
|
||||
subFn := func(v *visitor, msg *model.Message) error {
|
||||
return nil
|
||||
}
|
||||
canceled1 := atomic.Bool{}
|
||||
@@ -76,7 +77,7 @@ func TestTopic_Subscribe_DuplicateID(t *testing.T) {
|
||||
cancel: func() {},
|
||||
}
|
||||
|
||||
subFn := func(v *visitor, msg *message) error {
|
||||
subFn := func(v *visitor, msg *model.Message) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
277
server/types.go
277
server/types.go
@@ -2,219 +2,78 @@ package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
// List of possible events
|
||||
// Event constants
|
||||
const (
|
||||
openEvent = "open"
|
||||
keepaliveEvent = "keepalive"
|
||||
messageEvent = "message"
|
||||
messageDeleteEvent = "message_delete"
|
||||
messageClearEvent = "message_clear"
|
||||
pollRequestEvent = "poll_request"
|
||||
openEvent = model.OpenEvent
|
||||
keepaliveEvent = model.KeepaliveEvent
|
||||
messageEvent = model.MessageEvent
|
||||
messageDeleteEvent = model.MessageDeleteEvent
|
||||
messageClearEvent = model.MessageClearEvent
|
||||
pollRequestEvent = model.PollRequestEvent
|
||||
messageIDLength = model.MessageIDLength
|
||||
)
|
||||
|
||||
const (
|
||||
messageIDLength = 12
|
||||
// SinceMarker aliases
|
||||
var (
|
||||
sinceAllMessages = model.SinceAllMessages
|
||||
sinceNoMessages = model.SinceNoMessages
|
||||
sinceLatestMessage = model.SinceLatestMessage
|
||||
)
|
||||
|
||||
// message represents a message published to a topic
|
||||
type message struct {
|
||||
ID string `json:"id"` // Random message ID
|
||||
SequenceID string `json:"sequence_id,omitempty"` // Message sequence ID for updating message contents (omitted if same as ID)
|
||||
Time int64 `json:"time"` // Unix time in seconds
|
||||
Expires int64 `json:"expires,omitempty"` // Unix time in seconds (not required for open/keepalive)
|
||||
Event string `json:"event"` // One of the above
|
||||
Topic string `json:"topic"`
|
||||
Title string `json:"title,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Priority int `json:"priority,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
Click string `json:"click,omitempty"`
|
||||
Icon string `json:"icon,omitempty"`
|
||||
Actions []*action `json:"actions,omitempty"`
|
||||
Attachment *attachment `json:"attachment,omitempty"`
|
||||
PollID string `json:"poll_id,omitempty"`
|
||||
ContentType string `json:"content_type,omitempty"` // text/plain by default (if empty), or text/markdown
|
||||
Encoding string `json:"encoding,omitempty"` // Empty for raw UTF-8, or "base64" for encoded bytes
|
||||
Sender netip.Addr `json:"-"` // IP address of uploader, used for rate limiting
|
||||
User string `json:"-"` // UserID of the uploader, used to associated attachments
|
||||
}
|
||||
// Error aliases
|
||||
var (
|
||||
errMessageNotFound = model.ErrMessageNotFound
|
||||
)
|
||||
|
||||
func (m *message) Context() log.Context {
|
||||
fields := map[string]any{
|
||||
"topic": m.Topic,
|
||||
"message_id": m.ID,
|
||||
"message_sequence_id": m.SequenceID,
|
||||
"message_time": m.Time,
|
||||
"message_event": m.Event,
|
||||
"message_body_size": len(m.Message),
|
||||
}
|
||||
if m.Sender.IsValid() {
|
||||
fields["message_sender"] = m.Sender.String()
|
||||
}
|
||||
if m.User != "" {
|
||||
fields["message_user"] = m.User
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
// forJSON returns a copy of the message suitable for JSON output.
|
||||
// It clears the SequenceID if it equals the ID to reduce redundancy.
|
||||
func (m *message) forJSON() *message {
|
||||
if m.SequenceID == m.ID {
|
||||
clone := *m
|
||||
clone.SequenceID = ""
|
||||
return &clone
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
type attachment struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Size int64 `json:"size,omitempty"`
|
||||
Expires int64 `json:"expires,omitempty"`
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
type action struct {
|
||||
ID string `json:"id"`
|
||||
Action string `json:"action"` // "view", "broadcast", "http", or "copy"
|
||||
Label string `json:"label"` // action button label
|
||||
Clear bool `json:"clear"` // clear notification after successful execution
|
||||
URL string `json:"url,omitempty"` // used in "view" and "http" actions
|
||||
Method string `json:"method,omitempty"` // used in "http" action, default is POST (!)
|
||||
Headers map[string]string `json:"headers,omitempty"` // used in "http" action
|
||||
Body string `json:"body,omitempty"` // used in "http" action
|
||||
Intent string `json:"intent,omitempty"` // used in "broadcast" action
|
||||
Extras map[string]string `json:"extras,omitempty"` // used in "broadcast" action
|
||||
Value string `json:"value,omitempty"` // used in "copy" action
|
||||
}
|
||||
|
||||
func newAction() *action {
|
||||
return &action{
|
||||
Headers: make(map[string]string),
|
||||
Extras: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// publishMessage is used as input when publishing as JSON
|
||||
type publishMessage struct {
|
||||
Topic string `json:"topic"`
|
||||
SequenceID string `json:"sequence_id"`
|
||||
Title string `json:"title"`
|
||||
Message string `json:"message"`
|
||||
Priority int `json:"priority"`
|
||||
Tags []string `json:"tags"`
|
||||
Click string `json:"click"`
|
||||
Icon string `json:"icon"`
|
||||
Actions []action `json:"actions"`
|
||||
Attach string `json:"attach"`
|
||||
Markdown bool `json:"markdown"`
|
||||
Filename string `json:"filename"`
|
||||
Email string `json:"email"`
|
||||
Call string `json:"call"`
|
||||
Cache string `json:"cache"` // use string as it defaults to true (or use &bool instead)
|
||||
Firebase string `json:"firebase"` // use string as it defaults to true (or use &bool instead)
|
||||
Delay string `json:"delay"`
|
||||
}
|
||||
|
||||
// messageEncoder is a function that knows how to encode a message
|
||||
type messageEncoder func(msg *message) (string, error)
|
||||
|
||||
// newMessage creates a new message with the current timestamp
|
||||
func newMessage(event, topic, msg string) *message {
|
||||
return &message{
|
||||
ID: util.RandomString(messageIDLength),
|
||||
Time: time.Now().Unix(),
|
||||
Event: event,
|
||||
Topic: topic,
|
||||
Message: msg,
|
||||
}
|
||||
}
|
||||
|
||||
// newOpenMessage is a convenience method to create an open message
|
||||
func newOpenMessage(topic string) *message {
|
||||
return newMessage(openEvent, topic, "")
|
||||
}
|
||||
|
||||
// newKeepaliveMessage is a convenience method to create a keepalive message
|
||||
func newKeepaliveMessage(topic string) *message {
|
||||
return newMessage(keepaliveEvent, topic, "")
|
||||
}
|
||||
|
||||
// newDefaultMessage is a convenience method to create a notification message
|
||||
func newDefaultMessage(topic, msg string) *message {
|
||||
return newMessage(messageEvent, topic, msg)
|
||||
}
|
||||
// Constructors and helpers
|
||||
var (
|
||||
newMessage = model.NewMessage
|
||||
newDefaultMessage = model.NewDefaultMessage
|
||||
newOpenMessage = model.NewOpenMessage
|
||||
newKeepaliveMessage = model.NewKeepaliveMessage
|
||||
newActionMessage = model.NewActionMessage
|
||||
newAction = model.NewAction
|
||||
newSinceTime = model.NewSinceTime
|
||||
newSinceID = model.NewSinceID
|
||||
validMessageID = model.ValidMessageID
|
||||
)
|
||||
|
||||
// newPollRequestMessage is a convenience method to create a poll request message
|
||||
func newPollRequestMessage(topic, pollID string) *message {
|
||||
func newPollRequestMessage(topic, pollID string) *model.Message {
|
||||
m := newMessage(pollRequestEvent, topic, newMessageBody)
|
||||
m.PollID = pollID
|
||||
return m
|
||||
}
|
||||
|
||||
// newActionMessage creates a new action message (message_delete or message_clear)
|
||||
func newActionMessage(event, topic, sequenceID string) *message {
|
||||
m := newMessage(event, topic, "")
|
||||
m.SequenceID = sequenceID
|
||||
return m
|
||||
// publishMessage is used as input when publishing as JSON
|
||||
type publishMessage struct {
|
||||
Topic string `json:"topic"`
|
||||
SequenceID string `json:"sequence_id"`
|
||||
Title string `json:"title"`
|
||||
Message string `json:"message"`
|
||||
Priority int `json:"priority"`
|
||||
Tags []string `json:"tags"`
|
||||
Click string `json:"click"`
|
||||
Icon string `json:"icon"`
|
||||
Actions []model.Action `json:"actions"`
|
||||
Attach string `json:"attach"`
|
||||
Markdown bool `json:"markdown"`
|
||||
Filename string `json:"filename"`
|
||||
Email string `json:"email"`
|
||||
Call string `json:"call"`
|
||||
Cache string `json:"cache"` // use string as it defaults to true (or use &bool instead)
|
||||
Firebase string `json:"firebase"` // use string as it defaults to true (or use &bool instead)
|
||||
Delay string `json:"delay"`
|
||||
}
|
||||
|
||||
func validMessageID(s string) bool {
|
||||
return util.ValidRandomString(s, messageIDLength)
|
||||
}
|
||||
|
||||
type sinceMarker struct {
|
||||
time time.Time
|
||||
id string
|
||||
}
|
||||
|
||||
func newSinceTime(timestamp int64) sinceMarker {
|
||||
return sinceMarker{time.Unix(timestamp, 0), ""}
|
||||
}
|
||||
|
||||
func newSinceID(id string) sinceMarker {
|
||||
return sinceMarker{time.Unix(0, 0), id}
|
||||
}
|
||||
|
||||
func (t sinceMarker) IsAll() bool {
|
||||
return t == sinceAllMessages
|
||||
}
|
||||
|
||||
func (t sinceMarker) IsNone() bool {
|
||||
return t == sinceNoMessages
|
||||
}
|
||||
|
||||
func (t sinceMarker) IsLatest() bool {
|
||||
return t == sinceLatestMessage
|
||||
}
|
||||
|
||||
func (t sinceMarker) IsID() bool {
|
||||
return t.id != "" && t.id != "latest"
|
||||
}
|
||||
|
||||
func (t sinceMarker) Time() time.Time {
|
||||
return t.time
|
||||
}
|
||||
|
||||
func (t sinceMarker) ID() string {
|
||||
return t.id
|
||||
}
|
||||
|
||||
var (
|
||||
sinceAllMessages = sinceMarker{time.Unix(0, 0), ""}
|
||||
sinceNoMessages = sinceMarker{time.Unix(1, 0), ""}
|
||||
sinceLatestMessage = sinceMarker{time.Unix(0, 0), "latest"}
|
||||
)
|
||||
// messageEncoder is a function that knows how to encode a message
|
||||
type messageEncoder func(msg *model.Message) (string, error)
|
||||
|
||||
type queryFilter struct {
|
||||
ID string
|
||||
@@ -246,7 +105,7 @@ func parseQueryFilters(r *http.Request) (*queryFilter, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (q *queryFilter) Pass(msg *message) bool {
|
||||
func (q *queryFilter) Pass(msg *model.Message) bool {
|
||||
if msg.Event != messageEvent && msg.Event != messageDeleteEvent && msg.Event != messageClearEvent {
|
||||
return true // filters only apply to messages
|
||||
} else if q.ID != "" && msg.ID != q.ID {
|
||||
@@ -320,6 +179,12 @@ type apiHealthResponse struct {
|
||||
Healthy bool `json:"healthy"`
|
||||
}
|
||||
|
||||
type apiVersionResponse struct {
|
||||
Version string `json:"version"`
|
||||
Commit string `json:"commit"`
|
||||
Date string `json:"date"`
|
||||
}
|
||||
|
||||
type apiStatsResponse struct {
|
||||
Messages int64 `json:"messages"`
|
||||
MessagesRate float64 `json:"messages_rate"` // Average number of messages per second
|
||||
@@ -564,12 +429,12 @@ const (
|
||||
)
|
||||
|
||||
type webPushPayload struct {
|
||||
Event string `json:"event"`
|
||||
SubscriptionID string `json:"subscription_id"`
|
||||
Message *message `json:"message"`
|
||||
Event string `json:"event"`
|
||||
SubscriptionID string `json:"subscription_id"`
|
||||
Message *model.Message `json:"message"`
|
||||
}
|
||||
|
||||
func newWebPushPayload(subscriptionID string, message *message) *webPushPayload {
|
||||
func newWebPushPayload(subscriptionID string, message *model.Message) *webPushPayload {
|
||||
return &webPushPayload{
|
||||
Event: webPushMessageEvent,
|
||||
SubscriptionID: subscriptionID,
|
||||
@@ -587,22 +452,6 @@ func newWebPushSubscriptionExpiringPayload() *webPushControlMessagePayload {
|
||||
}
|
||||
}
|
||||
|
||||
type webPushSubscription struct {
|
||||
ID string
|
||||
Endpoint string
|
||||
Auth string
|
||||
P256dh string
|
||||
UserID string
|
||||
}
|
||||
|
||||
func (w *webPushSubscription) Context() log.Context {
|
||||
return map[string]any{
|
||||
"web_push_subscription_id": w.ID,
|
||||
"web_push_subscription_user_id": w.UserID,
|
||||
"web_push_subscription_endpoint": w.Endpoint,
|
||||
}
|
||||
}
|
||||
|
||||
// https://developer.mozilla.org/en-US/docs/Web/Manifest
|
||||
type webManifestResponse struct {
|
||||
Name string `json:"name"`
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/message"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
@@ -53,7 +54,7 @@ const (
|
||||
// visitor represents an API user, and its associated rate.Limiter used for rate limiting
|
||||
type visitor struct {
|
||||
config *Config
|
||||
messageCache *messageCache
|
||||
messageCache message.Store
|
||||
userManager *user.Manager // May be nil
|
||||
ip netip.Addr // Visitor IP address
|
||||
user *user.User // Only set if authenticated user, otherwise nil
|
||||
@@ -114,7 +115,7 @@ const (
|
||||
visitorLimitBasisTier = visitorLimitBasis("tier")
|
||||
)
|
||||
|
||||
func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor {
|
||||
func newVisitor(conf *Config, messageCache message.Store, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor {
|
||||
var messages, emails, calls int64
|
||||
if user != nil {
|
||||
messages = user.Stats.Messages
|
||||
|
||||
@@ -1,285 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
)
|
||||
|
||||
const (
|
||||
subscriptionIDPrefix = "wps_"
|
||||
subscriptionIDLength = 10
|
||||
subscriptionEndpointLimitPerSubscriberIP = 10
|
||||
)
|
||||
|
||||
var (
|
||||
errWebPushNoRows = errors.New("no rows found")
|
||||
errWebPushTooManySubscriptions = errors.New("too many subscriptions")
|
||||
errWebPushUserIDCannotBeEmpty = errors.New("user ID cannot be empty")
|
||||
)
|
||||
|
||||
const (
|
||||
createWebPushSubscriptionsTableQuery = `
|
||||
BEGIN;
|
||||
CREATE TABLE IF NOT EXISTS subscription (
|
||||
id TEXT PRIMARY KEY,
|
||||
endpoint TEXT NOT NULL,
|
||||
key_auth TEXT NOT NULL,
|
||||
key_p256dh TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
subscriber_ip TEXT NOT NULL,
|
||||
updated_at INT NOT NULL,
|
||||
warned_at INT NOT NULL DEFAULT 0
|
||||
);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_endpoint ON subscription (endpoint);
|
||||
CREATE INDEX IF NOT EXISTS idx_subscriber_ip ON subscription (subscriber_ip);
|
||||
CREATE TABLE IF NOT EXISTS subscription_topic (
|
||||
subscription_id TEXT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
PRIMARY KEY (subscription_id, topic),
|
||||
FOREIGN KEY (subscription_id) REFERENCES subscription (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON subscription_topic (topic);
|
||||
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||
id INT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
COMMIT;
|
||||
`
|
||||
builtinStartupQueries = `
|
||||
PRAGMA foreign_keys = ON;
|
||||
`
|
||||
|
||||
selectWebPushSubscriptionIDByEndpoint = `SELECT id FROM subscription WHERE endpoint = ?`
|
||||
selectWebPushSubscriptionCountBySubscriberIP = `SELECT COUNT(*) FROM subscription WHERE subscriber_ip = ?`
|
||||
selectWebPushSubscriptionsForTopicQuery = `
|
||||
SELECT id, endpoint, key_auth, key_p256dh, user_id
|
||||
FROM subscription_topic st
|
||||
JOIN subscription s ON s.id = st.subscription_id
|
||||
WHERE st.topic = ?
|
||||
ORDER BY endpoint
|
||||
`
|
||||
selectWebPushSubscriptionsExpiringSoonQuery = `
|
||||
SELECT id, endpoint, key_auth, key_p256dh, user_id
|
||||
FROM subscription
|
||||
WHERE warned_at = 0 AND updated_at <= ?
|
||||
`
|
||||
insertWebPushSubscriptionQuery = `
|
||||
INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT (endpoint)
|
||||
DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at
|
||||
`
|
||||
updateWebPushSubscriptionWarningSentQuery = `UPDATE subscription SET warned_at = ? WHERE id = ?`
|
||||
deleteWebPushSubscriptionByEndpointQuery = `DELETE FROM subscription WHERE endpoint = ?`
|
||||
deleteWebPushSubscriptionByUserIDQuery = `DELETE FROM subscription WHERE user_id = ?`
|
||||
deleteWebPushSubscriptionByAgeQuery = `DELETE FROM subscription WHERE updated_at <= ?` // Full table scan!
|
||||
|
||||
insertWebPushSubscriptionTopicQuery = `INSERT INTO subscription_topic (subscription_id, topic) VALUES (?, ?)`
|
||||
deleteWebPushSubscriptionTopicAllQuery = `DELETE FROM subscription_topic WHERE subscription_id = ?`
|
||||
deleteWebPushSubscriptionTopicWithoutSubscription = `DELETE FROM subscription_topic WHERE subscription_id NOT IN (SELECT id FROM subscription)`
|
||||
)
|
||||
|
||||
// Schema management queries
|
||||
const (
|
||||
currentWebPushSchemaVersion = 1
|
||||
insertWebPushSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
|
||||
selectWebPushSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
|
||||
)
|
||||
|
||||
type webPushStore struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func newWebPushStore(filename, startupQueries string) (*webPushStore, error) {
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := setupWebPushDB(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := runWebPushStartupQueries(db, startupQueries); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &webPushStore{
|
||||
db: db,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func setupWebPushDB(db *sql.DB) error {
|
||||
// If 'schemaVersion' table does not exist, this must be a new database
|
||||
rows, err := db.Query(selectWebPushSchemaVersionQuery)
|
||||
if err != nil {
|
||||
return setupNewWebPushDB(db)
|
||||
}
|
||||
return rows.Close()
|
||||
}
|
||||
|
||||
func setupNewWebPushDB(db *sql.DB) error {
|
||||
if _, err := db.Exec(createWebPushSubscriptionsTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(insertWebPushSchemaVersion, currentWebPushSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func runWebPushStartupQueries(db *sql.DB, startupQueries string) error {
|
||||
if _, err := db.Exec(startupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(builtinStartupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID. It always first deletes all
|
||||
// existing entries for a given endpoint.
|
||||
func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error {
|
||||
tx, err := c.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
// Read number of subscriptions for subscriber IP address
|
||||
rowsCount, err := tx.Query(selectWebPushSubscriptionCountBySubscriberIP, subscriberIP.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rowsCount.Close()
|
||||
var subscriptionCount int
|
||||
if !rowsCount.Next() {
|
||||
return errWebPushNoRows
|
||||
}
|
||||
if err := rowsCount.Scan(&subscriptionCount); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := rowsCount.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
// Read existing subscription ID for endpoint (or create new ID)
|
||||
rows, err := tx.Query(selectWebPushSubscriptionIDByEndpoint, endpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
var subscriptionID string
|
||||
if rows.Next() {
|
||||
if err := rows.Scan(&subscriptionID); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if subscriptionCount >= subscriptionEndpointLimitPerSubscriberIP {
|
||||
return errWebPushTooManySubscriptions
|
||||
}
|
||||
subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
// Insert or update subscription
|
||||
updatedAt, warnedAt := time.Now().Unix(), 0
|
||||
if _, err = tx.Exec(insertWebPushSubscriptionQuery, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil {
|
||||
return err
|
||||
}
|
||||
// Replace all subscription topics
|
||||
if _, err := tx.Exec(deleteWebPushSubscriptionTopicAllQuery, subscriptionID); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, topic := range topics {
|
||||
if _, err = tx.Exec(insertWebPushSubscriptionTopicQuery, subscriptionID, topic); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// SubscriptionsForTopic returns all subscriptions for the given topic
|
||||
func (c *webPushStore) SubscriptionsForTopic(topic string) ([]*webPushSubscription, error) {
|
||||
rows, err := c.db.Query(selectWebPushSubscriptionsForTopicQuery, topic)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return c.subscriptionsFromRows(rows)
|
||||
}
|
||||
|
||||
// SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period
|
||||
func (c *webPushStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*webPushSubscription, error) {
|
||||
rows, err := c.db.Query(selectWebPushSubscriptionsExpiringSoonQuery, time.Now().Add(-warnAfter).Unix())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return c.subscriptionsFromRows(rows)
|
||||
}
|
||||
|
||||
// MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon
|
||||
func (c *webPushStore) MarkExpiryWarningSent(subscriptions []*webPushSubscription) error {
|
||||
tx, err := c.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
for _, subscription := range subscriptions {
|
||||
if _, err := tx.Exec(updateWebPushSubscriptionWarningSentQuery, time.Now().Unix(), subscription.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (c *webPushStore) subscriptionsFromRows(rows *sql.Rows) ([]*webPushSubscription, error) {
|
||||
subscriptions := make([]*webPushSubscription, 0)
|
||||
for rows.Next() {
|
||||
var id, endpoint, auth, p256dh, userID string
|
||||
if err := rows.Scan(&id, &endpoint, &auth, &p256dh, &userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
subscriptions = append(subscriptions, &webPushSubscription{
|
||||
ID: id,
|
||||
Endpoint: endpoint,
|
||||
Auth: auth,
|
||||
P256dh: p256dh,
|
||||
UserID: userID,
|
||||
})
|
||||
}
|
||||
return subscriptions, nil
|
||||
}
|
||||
|
||||
// RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint
|
||||
func (c *webPushStore) RemoveSubscriptionsByEndpoint(endpoint string) error {
|
||||
_, err := c.db.Exec(deleteWebPushSubscriptionByEndpointQuery, endpoint)
|
||||
return err
|
||||
}
|
||||
|
||||
// RemoveSubscriptionsByUserID removes all subscriptions for the given user ID
|
||||
func (c *webPushStore) RemoveSubscriptionsByUserID(userID string) error {
|
||||
if userID == "" {
|
||||
return errWebPushUserIDCannotBeEmpty
|
||||
}
|
||||
_, err := c.db.Exec(deleteWebPushSubscriptionByUserIDQuery, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// RemoveExpiredSubscriptions removes all subscriptions that have not been updated for a given time period
|
||||
func (c *webPushStore) RemoveExpiredSubscriptions(expireAfter time.Duration) error {
|
||||
_, err := c.db.Exec(deleteWebPushSubscriptionByAgeQuery, time.Now().Add(-expireAfter).Unix())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = c.db.Exec(deleteWebPushSubscriptionTopicWithoutSubscription)
|
||||
return err
|
||||
}
|
||||
|
||||
// Close closes the underlying database connection
|
||||
func (c *webPushStore) Close() error {
|
||||
return c.db.Close()
|
||||
}
|
||||
@@ -1,199 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/require"
|
||||
"net/netip"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestWebPushStore_UpsertSubscription_SubscriptionsForTopic(t *testing.T) {
|
||||
webPush := newTestWebPushStore(t)
|
||||
defer webPush.Close()
|
||||
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
|
||||
|
||||
subs, err := webPush.SubscriptionsForTopic("test-topic")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
require.Equal(t, subs[0].Endpoint, testWebPushEndpoint)
|
||||
require.Equal(t, subs[0].P256dh, "p256dh-key")
|
||||
require.Equal(t, subs[0].Auth, "auth-key")
|
||||
require.Equal(t, subs[0].UserID, "u_1234")
|
||||
|
||||
subs2, err := webPush.SubscriptionsForTopic("mytopic")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs2, 1)
|
||||
require.Equal(t, subs[0].Endpoint, subs2[0].Endpoint)
|
||||
}
|
||||
|
||||
func TestWebPushStore_UpsertSubscription_SubscriberIPLimitReached(t *testing.T) {
|
||||
webPush := newTestWebPushStore(t)
|
||||
defer webPush.Close()
|
||||
|
||||
// Insert 10 subscriptions with the same IP address
|
||||
for i := 0; i < 10; i++ {
|
||||
endpoint := fmt.Sprintf(testWebPushEndpoint+"%d", i)
|
||||
require.Nil(t, webPush.UpsertSubscription(endpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
|
||||
}
|
||||
|
||||
// Another one for the same endpoint should be fine
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
|
||||
|
||||
// But with a different endpoint it should fail
|
||||
require.Equal(t, errWebPushTooManySubscriptions, webPush.UpsertSubscription(testWebPushEndpoint+"11", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
|
||||
|
||||
// But with a different IP address it should be fine again
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"99", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("9.9.9.9"), []string{"test-topic", "mytopic"}))
|
||||
}
|
||||
|
||||
func TestWebPushStore_UpsertSubscription_UpdateTopics(t *testing.T) {
|
||||
webPush := newTestWebPushStore(t)
|
||||
defer webPush.Close()
|
||||
|
||||
// Insert subscription with two topics, and another with one topic
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"1", "auth-key", "p256dh-key", "", netip.MustParseAddr("9.9.9.9"), []string{"topic1"}))
|
||||
|
||||
subs, err := webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 2)
|
||||
require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint)
|
||||
require.Equal(t, testWebPushEndpoint+"1", subs[1].Endpoint)
|
||||
|
||||
subs, err = webPush.SubscriptionsForTopic("topic2")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint)
|
||||
|
||||
// Update the first subscription to have only one topic
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
|
||||
|
||||
subs, err = webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 2)
|
||||
require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint)
|
||||
|
||||
subs, err = webPush.SubscriptionsForTopic("topic2")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 0)
|
||||
}
|
||||
|
||||
func TestWebPushStore_RemoveSubscriptionsByEndpoint(t *testing.T) {
|
||||
webPush := newTestWebPushStore(t)
|
||||
defer webPush.Close()
|
||||
|
||||
// Insert subscription with two topics
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||
subs, err := webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
|
||||
// And remove it again
|
||||
require.Nil(t, webPush.RemoveSubscriptionsByEndpoint(testWebPushEndpoint))
|
||||
subs, err = webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 0)
|
||||
}
|
||||
|
||||
func TestWebPushStore_RemoveSubscriptionsByUserID(t *testing.T) {
|
||||
webPush := newTestWebPushStore(t)
|
||||
defer webPush.Close()
|
||||
|
||||
// Insert subscription with two topics
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||
subs, err := webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
|
||||
// And remove it again
|
||||
require.Nil(t, webPush.RemoveSubscriptionsByUserID("u_1234"))
|
||||
subs, err = webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 0)
|
||||
}
|
||||
|
||||
func TestWebPushStore_RemoveSubscriptionsByUserID_Empty(t *testing.T) {
|
||||
webPush := newTestWebPushStore(t)
|
||||
defer webPush.Close()
|
||||
require.Equal(t, errWebPushUserIDCannotBeEmpty, webPush.RemoveSubscriptionsByUserID(""))
|
||||
}
|
||||
|
||||
func TestWebPushStore_MarkExpiryWarningSent(t *testing.T) {
|
||||
webPush := newTestWebPushStore(t)
|
||||
defer webPush.Close()
|
||||
|
||||
// Insert subscription with two topics
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||
subs, err := webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
|
||||
// Mark them as warning sent
|
||||
require.Nil(t, webPush.MarkExpiryWarningSent(subs))
|
||||
|
||||
rows, err := webPush.db.Query("SELECT endpoint FROM subscription WHERE warned_at > 0")
|
||||
require.Nil(t, err)
|
||||
defer rows.Close()
|
||||
var endpoint string
|
||||
require.True(t, rows.Next())
|
||||
require.Nil(t, rows.Scan(&endpoint))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, testWebPushEndpoint, endpoint)
|
||||
require.False(t, rows.Next())
|
||||
}
|
||||
|
||||
func TestWebPushStore_SubscriptionsExpiring(t *testing.T) {
|
||||
webPush := newTestWebPushStore(t)
|
||||
defer webPush.Close()
|
||||
|
||||
// Insert subscription with two topics
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||
subs, err := webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
|
||||
// Fake-mark them as soon-to-expire
|
||||
_, err = webPush.db.Exec("UPDATE subscription SET updated_at = ? WHERE endpoint = ?", time.Now().Add(-8*24*time.Hour).Unix(), testWebPushEndpoint)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Should not be cleaned up yet
|
||||
require.Nil(t, webPush.RemoveExpiredSubscriptions(9*24*time.Hour))
|
||||
|
||||
// Run expiration
|
||||
subs, err = webPush.SubscriptionsExpiring(7 * 24 * time.Hour)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
require.Equal(t, testWebPushEndpoint, subs[0].Endpoint)
|
||||
}
|
||||
|
||||
func TestWebPushStore_RemoveExpiredSubscriptions(t *testing.T) {
|
||||
webPush := newTestWebPushStore(t)
|
||||
defer webPush.Close()
|
||||
|
||||
// Insert subscription with two topics
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||
subs, err := webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
|
||||
// Fake-mark them as expired
|
||||
_, err = webPush.db.Exec("UPDATE subscription SET updated_at = ? WHERE endpoint = ?", time.Now().Add(-10*24*time.Hour).Unix(), testWebPushEndpoint)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Run expiration
|
||||
require.Nil(t, webPush.RemoveExpiredSubscriptions(9*24*time.Hour))
|
||||
|
||||
// List again, should be 0
|
||||
subs, err = webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 0)
|
||||
}
|
||||
|
||||
func newTestWebPushStore(t *testing.T) *webPushStore {
|
||||
webPush, err := newWebPushStore(filepath.Join(t.TempDir(), "webpush.db"), "")
|
||||
require.Nil(t, err)
|
||||
return webPush
|
||||
}
|
||||
1604
user/manager.go
1604
user/manager.go
File diff suppressed because it is too large
Load Diff
2239
user/manager_test.go
2239
user/manager_test.go
File diff suppressed because it is too large
Load Diff
986
user/store.go
Normal file
986
user/store.go
Normal file
@@ -0,0 +1,986 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"heckel.io/ntfy/v2/payments"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
// Store is the interface for a user database store
|
||||
type Store interface {
|
||||
// User operations
|
||||
UserByID(id string) (*User, error)
|
||||
User(username string) (*User, error)
|
||||
UserByToken(token string) (*User, error)
|
||||
UserByStripeCustomer(customerID string) (*User, error)
|
||||
UserIDByUsername(username string) (string, error)
|
||||
Users() ([]*User, error)
|
||||
UsersCount() (int64, error)
|
||||
AddUser(username, hash string, role Role, provisioned bool) error
|
||||
RemoveUser(username string) error
|
||||
MarkUserRemoved(userID string) error
|
||||
RemoveDeletedUsers() error
|
||||
ChangePassword(username, hash string) error
|
||||
ChangeRole(username string, role Role) error
|
||||
ChangeProvisioned(username string, provisioned bool) error
|
||||
ChangeSettings(userID string, prefs *Prefs) error
|
||||
ChangeTier(username, tierCode string) error
|
||||
ResetTier(username string) error
|
||||
UpdateStats(userID string, stats *Stats) error
|
||||
ResetStats() error
|
||||
|
||||
// Token operations
|
||||
CreateToken(userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, provisioned bool) (*Token, error)
|
||||
Token(userID, token string) (*Token, error)
|
||||
Tokens(userID string) ([]*Token, error)
|
||||
AllProvisionedTokens() ([]*Token, error)
|
||||
ChangeTokenLabel(userID, token, label string) error
|
||||
ChangeTokenExpiry(userID, token string, expires time.Time) error
|
||||
UpdateTokenLastAccess(token string, lastAccess time.Time, lastOrigin netip.Addr) error
|
||||
RemoveToken(userID, token string) error
|
||||
RemoveExpiredTokens() error
|
||||
TokenCount(userID string) (int, error)
|
||||
RemoveExcessTokens(userID string, maxCount int) error
|
||||
|
||||
// Access operations
|
||||
AuthorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error)
|
||||
AllGrants() (map[string][]Grant, error)
|
||||
Grants(username string) ([]Grant, error)
|
||||
AllowAccess(username, topicPattern string, read, write bool, ownerUsername string, provisioned bool) error
|
||||
ResetAccess(username, topicPattern string) error
|
||||
ResetAllProvisionedAccess() error
|
||||
Reservations(username string) ([]Reservation, error)
|
||||
HasReservation(username, topic string) (bool, error)
|
||||
ReservationsCount(username string) (int64, error)
|
||||
ReservationOwner(topic string) (string, error)
|
||||
OtherAccessCount(username, topic string) (int, error)
|
||||
|
||||
// Tier operations
|
||||
AddTier(tier *Tier) error
|
||||
UpdateTier(tier *Tier) error
|
||||
RemoveTier(code string) error
|
||||
Tiers() ([]*Tier, error)
|
||||
Tier(code string) (*Tier, error)
|
||||
TierByStripePrice(priceID string) (*Tier, error)
|
||||
|
||||
// Phone operations
|
||||
PhoneNumbers(userID string) ([]string, error)
|
||||
AddPhoneNumber(userID, phoneNumber string) error
|
||||
RemovePhoneNumber(userID, phoneNumber string) error
|
||||
|
||||
// Other stuff
|
||||
ChangeBilling(username string, billing *Billing) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
// storeQueries holds the database-specific SQL queries
|
||||
type storeQueries struct {
|
||||
// User queries
|
||||
selectUserByID string
|
||||
selectUserByName string
|
||||
selectUserByToken string
|
||||
selectUserByStripeID string
|
||||
selectUsernames string
|
||||
selectUserCount string
|
||||
selectUserIDFromUsername string
|
||||
insertUser string
|
||||
updateUserPass string
|
||||
updateUserRole string
|
||||
updateUserProvisioned string
|
||||
updateUserPrefs string
|
||||
updateUserStats string
|
||||
updateUserStatsResetAll string
|
||||
updateUserTier string
|
||||
updateUserDeleted string
|
||||
deleteUser string
|
||||
deleteUserTier string
|
||||
deleteUsersMarked string
|
||||
// Access queries
|
||||
selectTopicPerms string
|
||||
selectUserAllAccess string
|
||||
selectUserAccess string
|
||||
selectUserReservations string
|
||||
selectUserReservationsCount string
|
||||
selectUserReservationsOwner string
|
||||
selectUserHasReservation string
|
||||
selectOtherAccessCount string
|
||||
upsertUserAccess string
|
||||
deleteUserAccess string
|
||||
deleteUserAccessProvisioned string
|
||||
deleteTopicAccess string
|
||||
deleteAllAccess string
|
||||
// Token queries
|
||||
selectToken string
|
||||
selectTokens string
|
||||
selectTokenCount string
|
||||
selectAllProvisionedTokens string
|
||||
upsertToken string
|
||||
updateTokenLabel string
|
||||
updateTokenExpiry string
|
||||
updateTokenLastAccess string
|
||||
deleteToken string
|
||||
deleteProvisionedToken string
|
||||
deleteAllToken string
|
||||
deleteExpiredTokens string
|
||||
deleteExcessTokens string
|
||||
// Tier queries
|
||||
insertTier string
|
||||
selectTiers string
|
||||
selectTierByCode string
|
||||
selectTierByPriceID string
|
||||
updateTier string
|
||||
deleteTier string
|
||||
// Phone queries
|
||||
selectPhoneNumbers string
|
||||
insertPhoneNumber string
|
||||
deletePhoneNumber string
|
||||
// Billing queries
|
||||
updateBilling string
|
||||
}
|
||||
|
||||
// commonStore implements store operations that work across database backends
|
||||
type commonStore struct {
|
||||
db *sql.DB
|
||||
queries storeQueries
|
||||
}
|
||||
|
||||
// UserByID returns the user with the given ID if it exists, or ErrUserNotFound otherwise
|
||||
func (s *commonStore) UserByID(id string) (*User, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUserByID, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.readUser(rows)
|
||||
}
|
||||
|
||||
// User returns the user with the given username if it exists, or ErrUserNotFound otherwise
|
||||
func (s *commonStore) User(username string) (*User, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUserByName, username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.readUser(rows)
|
||||
}
|
||||
|
||||
// UserByToken returns the user with the given token if it exists and is not expired, or ErrUserNotFound otherwise
|
||||
func (s *commonStore) UserByToken(token string) (*User, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUserByToken, token, time.Now().Unix())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.readUser(rows)
|
||||
}
|
||||
|
||||
// UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise
|
||||
func (s *commonStore) UserByStripeCustomer(customerID string) (*User, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUserByStripeID, customerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.readUser(rows)
|
||||
}
|
||||
|
||||
// Users returns a list of users
|
||||
func (s *commonStore) Users() ([]*User, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUsernames)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
usernames := make([]string, 0)
|
||||
for rows.Next() {
|
||||
var username string
|
||||
if err := rows.Scan(&username); err != nil {
|
||||
return nil, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
usernames = append(usernames, username)
|
||||
}
|
||||
rows.Close()
|
||||
users := make([]*User, 0)
|
||||
for _, username := range usernames {
|
||||
user, err := s.User(username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
users = append(users, user)
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// UsersCount returns the number of users in the database
|
||||
func (s *commonStore) UsersCount() (int64, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUserCount)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return 0, errNoRows
|
||||
}
|
||||
var count int64
|
||||
if err := rows.Scan(&count); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// AddUser adds a user with the given username, password hash and role
|
||||
func (s *commonStore) AddUser(username, hash string, role Role, provisioned bool) error {
|
||||
if !AllowedUsername(username) || !AllowedRole(role) {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
|
||||
syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
|
||||
now := time.Now().Unix()
|
||||
if _, err := s.db.Exec(s.queries.insertUser, userID, username, hash, string(role), syncTopic, provisioned, now); err != nil {
|
||||
if isUniqueConstraintError(err) {
|
||||
return ErrUserExists
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveUser deletes the user with the given username
|
||||
func (s *commonStore) RemoveUser(username string) error {
|
||||
if !AllowedUsername(username) {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
// Rows in user_access, user_token, etc. are deleted via foreign keys
|
||||
if _, err := s.db.Exec(s.queries.deleteUser, username); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkUserRemoved sets the deleted flag on the user, and deletes all access tokens
|
||||
func (s *commonStore) MarkUserRemoved(userID string) error {
|
||||
tx, err := s.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
// Get username for deleteUserAccess query
|
||||
user, err := s.UserByID(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(s.queries.deleteUserAccess, user.Name, user.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(s.queries.deleteAllToken, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
deletedTime := time.Now().Add(userHardDeleteAfterDuration).Unix()
|
||||
if _, err := tx.Exec(s.queries.updateUserDeleted, deletedTime, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// RemoveDeletedUsers deletes all users that have been marked deleted
|
||||
func (s *commonStore) RemoveDeletedUsers() error {
|
||||
if _, err := s.db.Exec(s.queries.deleteUsersMarked, time.Now().Unix()); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ChangePassword changes a user's password
|
||||
func (s *commonStore) ChangePassword(username, hash string) error {
|
||||
if _, err := s.db.Exec(s.queries.updateUserPass, hash, username); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ChangeRole changes a user's role
|
||||
func (s *commonStore) ChangeRole(username string, role Role) error {
|
||||
if !AllowedUsername(username) || !AllowedRole(role) {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
tx, err := s.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(s.queries.updateUserRole, string(role), username); err != nil {
|
||||
return err
|
||||
}
|
||||
// If changing to admin, remove all access entries
|
||||
if role == RoleAdmin {
|
||||
if _, err := tx.Exec(s.queries.deleteUserAccess, username, username); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// ChangeProvisioned changes the provisioned status of a user
|
||||
func (s *commonStore) ChangeProvisioned(username string, provisioned bool) error {
|
||||
if _, err := s.db.Exec(s.queries.updateUserProvisioned, provisioned, username); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ChangeSettings persists the user settings
|
||||
func (s *commonStore) ChangeSettings(userID string, prefs *Prefs) error {
|
||||
b, err := json.Marshal(prefs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := s.db.Exec(s.queries.updateUserPrefs, string(b), userID); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ChangeTier changes a user's tier using the tier code
|
||||
func (s *commonStore) ChangeTier(username, tierCode string) error {
|
||||
if _, err := s.db.Exec(s.queries.updateUserTier, tierCode, username); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetTier removes the tier from the given user
|
||||
func (s *commonStore) ResetTier(username string) error {
|
||||
if !AllowedUsername(username) && username != Everyone && username != "" {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
_, err := s.db.Exec(s.queries.deleteUserTier, username)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateStats updates the user statistics
|
||||
func (s *commonStore) UpdateStats(userID string, stats *Stats) error {
|
||||
if _, err := s.db.Exec(s.queries.updateUserStats, stats.Messages, stats.Emails, stats.Calls, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetStats resets all user stats in the user database
|
||||
func (s *commonStore) ResetStats() error {
|
||||
if _, err := s.db.Exec(s.queries.updateUserStatsResetAll); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (s *commonStore) readUser(rows *sql.Rows) (*User, error) {
|
||||
defer rows.Close()
|
||||
var id, username, hash, role, prefs, syncTopic string
|
||||
var provisioned bool
|
||||
var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripeSubscriptionInterval, stripeMonthlyPriceID, stripeYearlyPriceID, tierID, tierCode, tierName sql.NullString
|
||||
var messages, emails, calls int64
|
||||
var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
|
||||
if !rows.Next() {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &provisioned, &messages, &emails, &calls, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionInterval, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
|
||||
return nil, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
user := &User{
|
||||
ID: id,
|
||||
Name: username,
|
||||
Hash: hash,
|
||||
Role: Role(role),
|
||||
Prefs: &Prefs{},
|
||||
SyncTopic: syncTopic,
|
||||
Provisioned: provisioned,
|
||||
Stats: &Stats{
|
||||
Messages: messages,
|
||||
Emails: emails,
|
||||
Calls: calls,
|
||||
},
|
||||
Billing: &Billing{
|
||||
StripeCustomerID: stripeCustomerID.String,
|
||||
StripeSubscriptionID: stripeSubscriptionID.String,
|
||||
StripeSubscriptionStatus: payments.SubscriptionStatus(stripeSubscriptionStatus.String),
|
||||
StripeSubscriptionInterval: payments.PriceRecurringInterval(stripeSubscriptionInterval.String),
|
||||
StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0),
|
||||
StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0),
|
||||
},
|
||||
Deleted: deleted.Valid,
|
||||
}
|
||||
if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tierCode.Valid {
|
||||
user.Tier = &Tier{
|
||||
ID: tierID.String,
|
||||
Code: tierCode.String,
|
||||
Name: tierName.String,
|
||||
MessageLimit: messagesLimit.Int64,
|
||||
MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
|
||||
EmailLimit: emailsLimit.Int64,
|
||||
CallLimit: callsLimit.Int64,
|
||||
ReservationLimit: reservationsLimit.Int64,
|
||||
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
|
||||
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
|
||||
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
|
||||
AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
|
||||
StripeMonthlyPriceID: stripeMonthlyPriceID.String,
|
||||
StripeYearlyPriceID: stripeYearlyPriceID.String,
|
||||
}
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// CreateToken creates a new token
|
||||
func (s *commonStore) CreateToken(userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, provisioned bool) (*Token, error) {
|
||||
if _, err := s.db.Exec(s.queries.upsertToken, userID, token, label, lastAccess.Unix(), lastOrigin.String(), expires.Unix(), provisioned); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Token{
|
||||
Value: token,
|
||||
Label: label,
|
||||
LastAccess: lastAccess,
|
||||
LastOrigin: lastOrigin,
|
||||
Expires: expires,
|
||||
Provisioned: provisioned,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Token returns a specific token for a user
|
||||
func (s *commonStore) Token(userID, token string) (*Token, error) {
|
||||
rows, err := s.db.Query(s.queries.selectToken, userID, token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return s.readToken(rows)
|
||||
}
|
||||
|
||||
// Tokens returns all existing tokens for the user with the given user ID
|
||||
func (s *commonStore) Tokens(userID string) ([]*Token, error) {
|
||||
rows, err := s.db.Query(s.queries.selectTokens, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
tokens := make([]*Token, 0)
|
||||
for {
|
||||
token, err := s.readToken(rows)
|
||||
if errors.Is(err, ErrTokenNotFound) {
|
||||
break
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokens = append(tokens, token)
|
||||
}
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// AllProvisionedTokens returns all provisioned tokens
|
||||
func (s *commonStore) AllProvisionedTokens() ([]*Token, error) {
|
||||
rows, err := s.db.Query(s.queries.selectAllProvisionedTokens)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
tokens := make([]*Token, 0)
|
||||
for {
|
||||
token, err := s.readToken(rows)
|
||||
if errors.Is(err, ErrTokenNotFound) {
|
||||
break
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokens = append(tokens, token)
|
||||
}
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// ChangeTokenLabel updates a token's label
|
||||
func (s *commonStore) ChangeTokenLabel(userID, token, label string) error {
|
||||
if _, err := s.db.Exec(s.queries.updateTokenLabel, label, userID, token); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ChangeTokenExpiry updates a token's expiry time
|
||||
func (s *commonStore) ChangeTokenExpiry(userID, token string, expires time.Time) error {
|
||||
if _, err := s.db.Exec(s.queries.updateTokenExpiry, expires.Unix(), userID, token); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateTokenLastAccess updates a token's last access time and origin
|
||||
func (s *commonStore) UpdateTokenLastAccess(token string, lastAccess time.Time, lastOrigin netip.Addr) error {
|
||||
if _, err := s.db.Exec(s.queries.updateTokenLastAccess, lastAccess.Unix(), lastOrigin.String(), token); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveToken deletes the token
|
||||
func (s *commonStore) RemoveToken(userID, token string) error {
|
||||
if token == "" {
|
||||
return errNoTokenProvided
|
||||
}
|
||||
if _, err := s.db.Exec(s.queries.deleteToken, userID, token); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveExpiredTokens deletes all expired tokens from the database
|
||||
func (s *commonStore) RemoveExpiredTokens() error {
|
||||
if _, err := s.db.Exec(s.queries.deleteExpiredTokens, time.Now().Unix()); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TokenCount returns the number of tokens for a user
|
||||
func (s *commonStore) TokenCount(userID string) (int, error) {
|
||||
rows, err := s.db.Query(s.queries.selectTokenCount, userID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return 0, errNoRows
|
||||
}
|
||||
var count int
|
||||
if err := rows.Scan(&count); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// RemoveExcessTokens deletes excess tokens beyond the specified maximum
|
||||
func (s *commonStore) RemoveExcessTokens(userID string, maxCount int) error {
|
||||
if _, err := s.db.Exec(s.queries.deleteExcessTokens, userID, userID, maxCount); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (s *commonStore) readToken(rows *sql.Rows) (*Token, error) {
|
||||
var token, label, lastOrigin string
|
||||
var lastAccess, expires int64
|
||||
var provisioned bool
|
||||
if !rows.Next() {
|
||||
return nil, ErrTokenNotFound
|
||||
}
|
||||
if err := rows.Scan(&token, &label, &lastAccess, &lastOrigin, &expires, &provisioned); err != nil {
|
||||
return nil, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lastOriginIP, err := netip.ParseAddr(lastOrigin)
|
||||
if err != nil {
|
||||
lastOriginIP = netip.IPv4Unspecified()
|
||||
}
|
||||
return &Token{
|
||||
Value: token,
|
||||
Label: label,
|
||||
LastAccess: time.Unix(lastAccess, 0),
|
||||
LastOrigin: lastOriginIP,
|
||||
Expires: time.Unix(expires, 0),
|
||||
Provisioned: provisioned,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AuthorizeTopicAccess returns the read/write permissions for the given username and topic.
|
||||
// The found return value indicates whether an ACL entry was found at all.
|
||||
func (s *commonStore) AuthorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error) {
|
||||
rows, err := s.db.Query(s.queries.selectTopicPerms, Everyone, usernameOrEveryone, topic)
|
||||
if err != nil {
|
||||
return false, false, false, err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return false, false, false, nil
|
||||
}
|
||||
if err := rows.Scan(&read, &write); err != nil {
|
||||
return false, false, false, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return false, false, false, err
|
||||
}
|
||||
return read, write, true, nil
|
||||
}
|
||||
|
||||
// AllGrants returns all user-specific access control entries, mapped to their respective user IDs
|
||||
func (s *commonStore) AllGrants() (map[string][]Grant, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUserAllAccess)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
grants := make(map[string][]Grant, 0)
|
||||
for rows.Next() {
|
||||
var userID, topic string
|
||||
var read, write, provisioned bool
|
||||
if err := rows.Scan(&userID, &topic, &read, &write, &provisioned); err != nil {
|
||||
return nil, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, ok := grants[userID]; !ok {
|
||||
grants[userID] = make([]Grant, 0)
|
||||
}
|
||||
grants[userID] = append(grants[userID], Grant{
|
||||
TopicPattern: fromSQLWildcard(topic),
|
||||
Permission: NewPermission(read, write),
|
||||
Provisioned: provisioned,
|
||||
})
|
||||
}
|
||||
return grants, nil
|
||||
}
|
||||
|
||||
// Grants returns all user-specific access control entries
|
||||
func (s *commonStore) Grants(username string) ([]Grant, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUserAccess, username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
grants := make([]Grant, 0)
|
||||
for rows.Next() {
|
||||
var topic string
|
||||
var read, write, provisioned bool
|
||||
if err := rows.Scan(&topic, &read, &write, &provisioned); err != nil {
|
||||
return nil, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
grants = append(grants, Grant{
|
||||
TopicPattern: fromSQLWildcard(topic),
|
||||
Permission: NewPermission(read, write),
|
||||
Provisioned: provisioned,
|
||||
})
|
||||
}
|
||||
return grants, nil
|
||||
}
|
||||
|
||||
// AllowAccess adds or updates an entry in the access control list
|
||||
func (s *commonStore) AllowAccess(username, topicPattern string, read, write bool, ownerUsername string, provisioned bool) error {
|
||||
if !AllowedUsername(username) && username != Everyone {
|
||||
return ErrInvalidArgument
|
||||
} else if !AllowedTopicPattern(topicPattern) {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
if _, err := s.db.Exec(s.queries.upsertUserAccess, username, toSQLWildcard(topicPattern), read, write, ownerUsername, ownerUsername, provisioned); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetAccess removes an access control list entry
|
||||
func (s *commonStore) ResetAccess(username, topicPattern string) error {
|
||||
if !AllowedUsername(username) && username != Everyone && username != "" {
|
||||
return ErrInvalidArgument
|
||||
} else if !AllowedTopicPattern(topicPattern) && topicPattern != "" {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
if username == "" && topicPattern == "" {
|
||||
_, err := s.db.Exec(s.queries.deleteAllAccess)
|
||||
return err
|
||||
} else if topicPattern == "" {
|
||||
_, err := s.db.Exec(s.queries.deleteUserAccess, username, username)
|
||||
return err
|
||||
}
|
||||
_, err := s.db.Exec(s.queries.deleteTopicAccess, username, username, toSQLWildcard(topicPattern))
|
||||
return err
|
||||
}
|
||||
|
||||
// ResetAllProvisionedAccess removes all provisioned access control entries
|
||||
func (s *commonStore) ResetAllProvisionedAccess() error {
|
||||
if _, err := s.db.Exec(s.queries.deleteUserAccessProvisioned); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reservations returns all user-owned topics, and the associated everyone-access
|
||||
func (s *commonStore) Reservations(username string) ([]Reservation, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUserReservations, Everyone, username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
reservations := make([]Reservation, 0)
|
||||
for rows.Next() {
|
||||
var topic string
|
||||
var ownerRead, ownerWrite bool
|
||||
var everyoneRead, everyoneWrite sql.NullBool
|
||||
if err := rows.Scan(&topic, &ownerRead, &ownerWrite, &everyoneRead, &everyoneWrite); err != nil {
|
||||
return nil, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reservations = append(reservations, Reservation{
|
||||
Topic: unescapeUnderscore(topic),
|
||||
Owner: NewPermission(ownerRead, ownerWrite),
|
||||
Everyone: NewPermission(everyoneRead.Bool, everyoneWrite.Bool),
|
||||
})
|
||||
}
|
||||
return reservations, nil
|
||||
}
|
||||
|
||||
// HasReservation returns true if the given topic access is owned by the user
|
||||
func (s *commonStore) HasReservation(username, topic string) (bool, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUserHasReservation, username, escapeUnderscore(topic))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return false, errNoRows
|
||||
}
|
||||
var count int64
|
||||
if err := rows.Scan(&count); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// ReservationsCount returns the number of reservations owned by this user
|
||||
func (s *commonStore) ReservationsCount(username string) (int64, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUserReservationsCount, username)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return 0, errNoRows
|
||||
}
|
||||
var count int64
|
||||
if err := rows.Scan(&count); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// ReservationOwner returns user ID of the user that owns this topic, or an empty string if it's not owned by anyone
|
||||
func (s *commonStore) ReservationOwner(topic string) (string, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUserReservationsOwner, escapeUnderscore(topic))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return "", nil
|
||||
}
|
||||
var ownerUserID string
|
||||
if err := rows.Scan(&ownerUserID); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return ownerUserID, nil
|
||||
}
|
||||
|
||||
// OtherAccessCount returns the number of access entries for the given topic that are not owned by the user
|
||||
func (s *commonStore) OtherAccessCount(username, topic string) (int, error) {
|
||||
rows, err := s.db.Query(s.queries.selectOtherAccessCount, escapeUnderscore(topic), escapeUnderscore(topic), username)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return 0, errNoRows
|
||||
}
|
||||
var count int
|
||||
if err := rows.Scan(&count); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// AddTier creates a new tier in the database
|
||||
func (s *commonStore) AddTier(tier *Tier) error {
|
||||
if tier.ID == "" {
|
||||
tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength)
|
||||
}
|
||||
if _, err := s.db.Exec(s.queries.insertTier, tier.ID, tier.Code, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.CallLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateTier updates a tier's properties in the database
|
||||
func (s *commonStore) UpdateTier(tier *Tier) error {
|
||||
if _, err := s.db.Exec(s.queries.updateTier, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.CallLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID), tier.Code); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveTier deletes the tier with the given code
|
||||
func (s *commonStore) RemoveTier(code string) error {
|
||||
if !AllowedTier(code) {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
if _, err := s.db.Exec(s.queries.deleteTier, code); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Tiers returns a list of all Tier structs
|
||||
func (s *commonStore) Tiers() ([]*Tier, error) {
|
||||
rows, err := s.db.Query(s.queries.selectTiers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
tiers := make([]*Tier, 0)
|
||||
for {
|
||||
tier, err := s.readTier(rows)
|
||||
if errors.Is(err, ErrTierNotFound) {
|
||||
break
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tiers = append(tiers, tier)
|
||||
}
|
||||
return tiers, nil
|
||||
}
|
||||
|
||||
// Tier returns a Tier based on the code, or ErrTierNotFound if it does not exist
|
||||
func (s *commonStore) Tier(code string) (*Tier, error) {
|
||||
rows, err := s.db.Query(s.queries.selectTierByCode, code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return s.readTier(rows)
|
||||
}
|
||||
|
||||
// TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist
|
||||
func (s *commonStore) TierByStripePrice(priceID string) (*Tier, error) {
|
||||
rows, err := s.db.Query(s.queries.selectTierByPriceID, priceID, priceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return s.readTier(rows)
|
||||
}
|
||||
func (s *commonStore) readTier(rows *sql.Rows) (*Tier, error) {
|
||||
var id, code, name string
|
||||
var stripeMonthlyPriceID, stripeYearlyPriceID sql.NullString
|
||||
var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit sql.NullInt64
|
||||
if !rows.Next() {
|
||||
return nil, ErrTierNotFound
|
||||
}
|
||||
if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
|
||||
return nil, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Tier{
|
||||
ID: id,
|
||||
Code: code,
|
||||
Name: name,
|
||||
MessageLimit: messagesLimit.Int64,
|
||||
MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
|
||||
EmailLimit: emailsLimit.Int64,
|
||||
CallLimit: callsLimit.Int64,
|
||||
ReservationLimit: reservationsLimit.Int64,
|
||||
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
|
||||
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
|
||||
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
|
||||
AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
|
||||
StripeMonthlyPriceID: stripeMonthlyPriceID.String,
|
||||
StripeYearlyPriceID: stripeYearlyPriceID.String,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// PhoneNumbers returns all phone numbers for the user with the given user ID
|
||||
func (s *commonStore) PhoneNumbers(userID string) ([]string, error) {
|
||||
rows, err := s.db.Query(s.queries.selectPhoneNumbers, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
phoneNumbers := make([]string, 0)
|
||||
for {
|
||||
phoneNumber, err := s.readPhoneNumber(rows)
|
||||
if errors.Is(err, ErrPhoneNumberNotFound) {
|
||||
break
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
phoneNumbers = append(phoneNumbers, phoneNumber)
|
||||
}
|
||||
return phoneNumbers, nil
|
||||
}
|
||||
|
||||
// AddPhoneNumber adds a phone number to the user with the given user ID
|
||||
func (s *commonStore) AddPhoneNumber(userID, phoneNumber string) error {
|
||||
if _, err := s.db.Exec(s.queries.insertPhoneNumber, userID, phoneNumber); err != nil {
|
||||
if isUniqueConstraintError(err) {
|
||||
return ErrPhoneNumberExists
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemovePhoneNumber deletes a phone number from the user with the given user ID
|
||||
func (s *commonStore) RemovePhoneNumber(userID, phoneNumber string) error {
|
||||
_, err := s.db.Exec(s.queries.deletePhoneNumber, userID, phoneNumber)
|
||||
return err
|
||||
}
|
||||
func (s *commonStore) readPhoneNumber(rows *sql.Rows) (string, error) {
|
||||
var phoneNumber string
|
||||
if !rows.Next() {
|
||||
return "", ErrPhoneNumberNotFound
|
||||
}
|
||||
if err := rows.Scan(&phoneNumber); err != nil {
|
||||
return "", err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return phoneNumber, nil
|
||||
}
|
||||
|
||||
// ChangeBilling updates a user's billing fields
|
||||
func (s *commonStore) ChangeBilling(username string, billing *Billing) error {
|
||||
if _, err := s.db.Exec(s.queries.updateBilling, nullString(billing.StripeCustomerID), nullString(billing.StripeSubscriptionID), nullString(string(billing.StripeSubscriptionStatus)), nullString(string(billing.StripeSubscriptionInterval)), nullInt64(billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(billing.StripeSubscriptionCancelAt.Unix()), username); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UserIDByUsername returns the user ID for the given username
|
||||
func (s *commonStore) UserIDByUsername(username string) (string, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUserIDFromUsername, username)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return "", ErrUserNotFound
|
||||
}
|
||||
var userID string
|
||||
if err := rows.Scan(&userID); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
// Close closes the underlying database
|
||||
func (s *commonStore) Close() error {
|
||||
return s.db.Close()
|
||||
}
|
||||
|
||||
// isUniqueConstraintError checks if the error is a unique constraint violation for both SQLite and PostgreSQL
|
||||
func isUniqueConstraintError(err error) bool {
|
||||
errStr := err.Error()
|
||||
return strings.Contains(errStr, "UNIQUE constraint failed") || strings.Contains(errStr, "23505")
|
||||
}
|
||||
292
user/store_postgres.go
Normal file
292
user/store_postgres.go
Normal file
@@ -0,0 +1,292 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
|
||||
)
|
||||
|
||||
// PostgreSQL queries
|
||||
const (
|
||||
// User queries
|
||||
postgresSelectUserByID = `
|
||||
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||
FROM "user" u
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE u.id = $1
|
||||
`
|
||||
postgresSelectUserByName = `
|
||||
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||
FROM "user" u
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE user_name = $1
|
||||
`
|
||||
postgresSelectUserByToken = `
|
||||
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||
FROM "user" u
|
||||
JOIN user_token tk on u.id = tk.user_id
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE tk.token = $1 AND (tk.expires = 0 OR tk.expires >= $2)
|
||||
`
|
||||
postgresSelectUserByStripeID = `
|
||||
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||
FROM "user" u
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE u.stripe_customer_id = $1
|
||||
`
|
||||
postgresSelectUsernames = `
|
||||
SELECT user_name
|
||||
FROM "user"
|
||||
ORDER BY
|
||||
CASE role
|
||||
WHEN 'admin' THEN 1
|
||||
WHEN 'anonymous' THEN 3
|
||||
ELSE 2
|
||||
END, user_name
|
||||
`
|
||||
postgresSelectUserCount = `SELECT COUNT(*) FROM "user"`
|
||||
postgresSelectUserIDFromUsername = `SELECT id FROM "user" WHERE user_name = $1`
|
||||
postgresInsertUser = `INSERT INTO "user" (id, user_name, pass, role, sync_topic, provisioned, created) VALUES ($1, $2, $3, $4, $5, $6, $7)`
|
||||
postgresUpdateUserPass = `UPDATE "user" SET pass = $1 WHERE user_name = $2`
|
||||
postgresUpdateUserRole = `UPDATE "user" SET role = $1 WHERE user_name = $2`
|
||||
postgresUpdateUserProvisioned = `UPDATE "user" SET provisioned = $1 WHERE user_name = $2`
|
||||
postgresUpdateUserPrefs = `UPDATE "user" SET prefs = $1 WHERE id = $2`
|
||||
postgresUpdateUserStats = `UPDATE "user" SET stats_messages = $1, stats_emails = $2, stats_calls = $3 WHERE id = $4`
|
||||
postgresUpdateUserStatsResetAll = `UPDATE "user" SET stats_messages = 0, stats_emails = 0, stats_calls = 0`
|
||||
postgresUpdateUserTier = `UPDATE "user" SET tier_id = (SELECT id FROM tier WHERE code = $1) WHERE user_name = $2`
|
||||
postgresUpdateUserDeleted = `UPDATE "user" SET deleted = $1 WHERE id = $2`
|
||||
postgresDeleteUser = `DELETE FROM "user" WHERE user_name = $1`
|
||||
postgresDeleteUserTier = `UPDATE "user" SET tier_id = null WHERE user_name = $1`
|
||||
postgresDeleteUsersMarked = `DELETE FROM "user" WHERE deleted < $1`
|
||||
|
||||
// Access queries
|
||||
postgresSelectTopicPerms = `
|
||||
SELECT read, write
|
||||
FROM user_access a
|
||||
JOIN "user" u ON u.id = a.user_id
|
||||
WHERE (u.user_name = $1 OR u.user_name = $2) AND $3 LIKE a.topic ESCAPE '\'
|
||||
ORDER BY u.user_name DESC, LENGTH(a.topic) DESC, CASE WHEN a.write THEN 1 ELSE 0 END DESC
|
||||
`
|
||||
postgresSelectUserAllAccess = `
|
||||
SELECT user_id, topic, read, write, provisioned
|
||||
FROM user_access
|
||||
ORDER BY LENGTH(topic) DESC, CASE WHEN write THEN 1 ELSE 0 END DESC, CASE WHEN read THEN 1 ELSE 0 END DESC, topic
|
||||
`
|
||||
postgresSelectUserAccess = `
|
||||
SELECT topic, read, write, provisioned
|
||||
FROM user_access
|
||||
WHERE user_id = (SELECT id FROM "user" WHERE user_name = $1)
|
||||
ORDER BY LENGTH(topic) DESC, CASE WHEN write THEN 1 ELSE 0 END DESC, CASE WHEN read THEN 1 ELSE 0 END DESC, topic
|
||||
`
|
||||
postgresSelectUserReservations = `
|
||||
SELECT a_user.topic, a_user.read, a_user.write, a_everyone.read AS everyone_read, a_everyone.write AS everyone_write
|
||||
FROM user_access a_user
|
||||
LEFT JOIN user_access a_everyone ON a_user.topic = a_everyone.topic AND a_everyone.user_id = (SELECT id FROM "user" WHERE user_name = $1)
|
||||
WHERE a_user.user_id = a_user.owner_user_id
|
||||
AND a_user.owner_user_id = (SELECT id FROM "user" WHERE user_name = $2)
|
||||
ORDER BY a_user.topic
|
||||
`
|
||||
postgresSelectUserReservationsCount = `
|
||||
SELECT COUNT(*)
|
||||
FROM user_access
|
||||
WHERE user_id = owner_user_id
|
||||
AND owner_user_id = (SELECT id FROM "user" WHERE user_name = $1)
|
||||
`
|
||||
postgresSelectUserReservationsOwner = `
|
||||
SELECT owner_user_id
|
||||
FROM user_access
|
||||
WHERE topic = $1
|
||||
AND user_id = owner_user_id
|
||||
`
|
||||
postgresSelectUserHasReservation = `
|
||||
SELECT COUNT(*)
|
||||
FROM user_access
|
||||
WHERE user_id = owner_user_id
|
||||
AND owner_user_id = (SELECT id FROM "user" WHERE user_name = $1)
|
||||
AND topic = $2
|
||||
`
|
||||
postgresSelectOtherAccessCount = `
|
||||
SELECT COUNT(*)
|
||||
FROM user_access
|
||||
WHERE (topic = $1 OR $2 LIKE topic ESCAPE '\')
|
||||
AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM "user" WHERE user_name = $3))
|
||||
`
|
||||
postgresUpsertUserAccess = `
|
||||
INSERT INTO user_access (user_id, topic, read, write, owner_user_id, provisioned)
|
||||
VALUES (
|
||||
(SELECT id FROM "user" WHERE user_name = $1),
|
||||
$2,
|
||||
$3,
|
||||
$4,
|
||||
CASE WHEN $5 = '' THEN NULL ELSE (SELECT id FROM "user" WHERE user_name = $6) END,
|
||||
$7
|
||||
)
|
||||
ON CONFLICT (user_id, topic)
|
||||
DO UPDATE SET read=EXCLUDED.read, write=EXCLUDED.write, owner_user_id=EXCLUDED.owner_user_id, provisioned=EXCLUDED.provisioned
|
||||
`
|
||||
postgresDeleteUserAccess = `
|
||||
DELETE FROM user_access
|
||||
WHERE user_id = (SELECT id FROM "user" WHERE user_name = $1)
|
||||
OR owner_user_id = (SELECT id FROM "user" WHERE user_name = $2)
|
||||
`
|
||||
postgresDeleteUserAccessProvisioned = `DELETE FROM user_access WHERE provisioned = true`
|
||||
postgresDeleteTopicAccess = `
|
||||
DELETE FROM user_access
|
||||
WHERE (user_id = (SELECT id FROM "user" WHERE user_name = $1) OR owner_user_id = (SELECT id FROM "user" WHERE user_name = $2))
|
||||
AND topic = $3
|
||||
`
|
||||
postgresDeleteAllAccess = `DELETE FROM user_access`
|
||||
|
||||
// Token queries
|
||||
postgresSelectToken = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = $1 AND token = $2`
|
||||
postgresSelectTokens = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = $1`
|
||||
postgresSelectTokenCount = `SELECT COUNT(*) FROM user_token WHERE user_id = $1`
|
||||
postgresSelectAllProvisionedTokens = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE provisioned = true`
|
||||
postgresUpsertToken = `
|
||||
INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires, provisioned)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
ON CONFLICT (user_id, token)
|
||||
DO UPDATE SET label = EXCLUDED.label, expires = EXCLUDED.expires, provisioned = EXCLUDED.provisioned
|
||||
`
|
||||
postgresUpdateTokenLabel = `UPDATE user_token SET label = $1 WHERE user_id = $2 AND token = $3`
|
||||
postgresUpdateTokenExpiry = `UPDATE user_token SET expires = $1 WHERE user_id = $2 AND token = $3`
|
||||
postgresUpdateTokenLastAccess = `UPDATE user_token SET last_access = $1, last_origin = $2 WHERE token = $3`
|
||||
postgresDeleteToken = `DELETE FROM user_token WHERE user_id = $1 AND token = $2`
|
||||
postgresDeleteProvisionedToken = `DELETE FROM user_token WHERE token = $1`
|
||||
postgresDeleteAllToken = `DELETE FROM user_token WHERE user_id = $1`
|
||||
postgresDeleteExpiredTokens = `DELETE FROM user_token WHERE expires > 0 AND expires < $1`
|
||||
postgresDeleteExcessTokens = `
|
||||
DELETE FROM user_token
|
||||
WHERE user_id = $1
|
||||
AND (user_id, token) NOT IN (
|
||||
SELECT user_id, token
|
||||
FROM user_token
|
||||
WHERE user_id = $2
|
||||
ORDER BY expires DESC
|
||||
LIMIT $3
|
||||
)
|
||||
`
|
||||
|
||||
// Tier queries
|
||||
postgresInsertTier = `
|
||||
INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
|
||||
`
|
||||
postgresUpdateTier = `
|
||||
UPDATE tier
|
||||
SET name = $1, messages_limit = $2, messages_expiry_duration = $3, emails_limit = $4, calls_limit = $5, reservations_limit = $6, attachment_file_size_limit = $7, attachment_total_size_limit = $8, attachment_expiry_duration = $9, attachment_bandwidth_limit = $10, stripe_monthly_price_id = $11, stripe_yearly_price_id = $12
|
||||
WHERE code = $13
|
||||
`
|
||||
postgresSelectTiers = `
|
||||
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
|
||||
FROM tier
|
||||
`
|
||||
postgresSelectTierByCode = `
|
||||
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
|
||||
FROM tier
|
||||
WHERE code = $1
|
||||
`
|
||||
postgresSelectTierByPriceID = `
|
||||
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
|
||||
FROM tier
|
||||
WHERE (stripe_monthly_price_id = $1 OR stripe_yearly_price_id = $2)
|
||||
`
|
||||
postgresDeleteTier = `DELETE FROM tier WHERE code = $1`
|
||||
|
||||
// Phone queries
|
||||
postgresSelectPhoneNumbers = `SELECT phone_number FROM user_phone WHERE user_id = $1`
|
||||
postgresInsertPhoneNumber = `INSERT INTO user_phone (user_id, phone_number) VALUES ($1, $2)`
|
||||
postgresDeletePhoneNumber = `DELETE FROM user_phone WHERE user_id = $1 AND phone_number = $2`
|
||||
|
||||
// Billing queries
|
||||
postgresUpdateBilling = `
|
||||
UPDATE "user"
|
||||
SET stripe_customer_id = $1, stripe_subscription_id = $2, stripe_subscription_status = $3, stripe_subscription_interval = $4, stripe_subscription_paid_until = $5, stripe_subscription_cancel_at = $6
|
||||
WHERE user_name = $7
|
||||
`
|
||||
)
|
||||
|
||||
// NewPostgresStore creates a new PostgreSQL-backed user store
|
||||
func NewPostgresStore(dsn string) (Store, error) {
|
||||
db, err := sql.Open("pgx", dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := db.Ping(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := setupPostgres(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &commonStore{
|
||||
db: db,
|
||||
queries: storeQueries{
|
||||
// User queries
|
||||
selectUserByID: postgresSelectUserByID,
|
||||
selectUserByName: postgresSelectUserByName,
|
||||
selectUserByToken: postgresSelectUserByToken,
|
||||
selectUserByStripeID: postgresSelectUserByStripeID,
|
||||
selectUsernames: postgresSelectUsernames,
|
||||
selectUserCount: postgresSelectUserCount,
|
||||
selectUserIDFromUsername: postgresSelectUserIDFromUsername,
|
||||
insertUser: postgresInsertUser,
|
||||
updateUserPass: postgresUpdateUserPass,
|
||||
updateUserRole: postgresUpdateUserRole,
|
||||
updateUserProvisioned: postgresUpdateUserProvisioned,
|
||||
updateUserPrefs: postgresUpdateUserPrefs,
|
||||
updateUserStats: postgresUpdateUserStats,
|
||||
updateUserStatsResetAll: postgresUpdateUserStatsResetAll,
|
||||
updateUserTier: postgresUpdateUserTier,
|
||||
updateUserDeleted: postgresUpdateUserDeleted,
|
||||
deleteUser: postgresDeleteUser,
|
||||
deleteUserTier: postgresDeleteUserTier,
|
||||
deleteUsersMarked: postgresDeleteUsersMarked,
|
||||
|
||||
// Access queries
|
||||
selectTopicPerms: postgresSelectTopicPerms,
|
||||
selectUserAllAccess: postgresSelectUserAllAccess,
|
||||
selectUserAccess: postgresSelectUserAccess,
|
||||
selectUserReservations: postgresSelectUserReservations,
|
||||
selectUserReservationsCount: postgresSelectUserReservationsCount,
|
||||
selectUserReservationsOwner: postgresSelectUserReservationsOwner,
|
||||
selectUserHasReservation: postgresSelectUserHasReservation,
|
||||
selectOtherAccessCount: postgresSelectOtherAccessCount,
|
||||
upsertUserAccess: postgresUpsertUserAccess,
|
||||
deleteUserAccess: postgresDeleteUserAccess,
|
||||
deleteUserAccessProvisioned: postgresDeleteUserAccessProvisioned,
|
||||
deleteTopicAccess: postgresDeleteTopicAccess,
|
||||
deleteAllAccess: postgresDeleteAllAccess,
|
||||
|
||||
// Token queries
|
||||
selectToken: postgresSelectToken,
|
||||
selectTokens: postgresSelectTokens,
|
||||
selectTokenCount: postgresSelectTokenCount,
|
||||
selectAllProvisionedTokens: postgresSelectAllProvisionedTokens,
|
||||
upsertToken: postgresUpsertToken,
|
||||
updateTokenLabel: postgresUpdateTokenLabel,
|
||||
updateTokenExpiry: postgresUpdateTokenExpiry,
|
||||
updateTokenLastAccess: postgresUpdateTokenLastAccess,
|
||||
deleteToken: postgresDeleteToken,
|
||||
deleteProvisionedToken: postgresDeleteProvisionedToken,
|
||||
deleteAllToken: postgresDeleteAllToken,
|
||||
deleteExpiredTokens: postgresDeleteExpiredTokens,
|
||||
deleteExcessTokens: postgresDeleteExcessTokens,
|
||||
|
||||
// Tier queries
|
||||
insertTier: postgresInsertTier,
|
||||
selectTiers: postgresSelectTiers,
|
||||
selectTierByCode: postgresSelectTierByCode,
|
||||
selectTierByPriceID: postgresSelectTierByPriceID,
|
||||
updateTier: postgresUpdateTier,
|
||||
deleteTier: postgresDeleteTier,
|
||||
|
||||
// Phone queries
|
||||
selectPhoneNumbers: postgresSelectPhoneNumbers,
|
||||
insertPhoneNumber: postgresInsertPhoneNumber,
|
||||
deletePhoneNumber: postgresDeletePhoneNumber,
|
||||
|
||||
// Billing queries
|
||||
updateBilling: postgresUpdateBilling,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
113
user/store_postgres_schema.go
Normal file
113
user/store_postgres_schema.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Initial PostgreSQL schema
|
||||
const (
|
||||
postgresCreateTablesQueries = `
|
||||
CREATE TABLE IF NOT EXISTS tier (
|
||||
id TEXT PRIMARY KEY,
|
||||
code TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
messages_limit BIGINT NOT NULL,
|
||||
messages_expiry_duration BIGINT NOT NULL,
|
||||
emails_limit BIGINT NOT NULL,
|
||||
calls_limit BIGINT NOT NULL,
|
||||
reservations_limit BIGINT NOT NULL,
|
||||
attachment_file_size_limit BIGINT NOT NULL,
|
||||
attachment_total_size_limit BIGINT NOT NULL,
|
||||
attachment_expiry_duration BIGINT NOT NULL,
|
||||
attachment_bandwidth_limit BIGINT NOT NULL,
|
||||
stripe_monthly_price_id TEXT,
|
||||
stripe_yearly_price_id TEXT,
|
||||
UNIQUE(code),
|
||||
UNIQUE(stripe_monthly_price_id),
|
||||
UNIQUE(stripe_yearly_price_id)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS "user" (
|
||||
id TEXT PRIMARY KEY,
|
||||
tier_id TEXT REFERENCES tier(id),
|
||||
user_name TEXT NOT NULL UNIQUE,
|
||||
pass TEXT NOT NULL,
|
||||
role TEXT NOT NULL CHECK (role IN ('anonymous', 'admin', 'user')),
|
||||
prefs JSONB NOT NULL DEFAULT '{}',
|
||||
sync_topic TEXT NOT NULL,
|
||||
provisioned BOOLEAN NOT NULL,
|
||||
stats_messages BIGINT NOT NULL DEFAULT 0,
|
||||
stats_emails BIGINT NOT NULL DEFAULT 0,
|
||||
stats_calls BIGINT NOT NULL DEFAULT 0,
|
||||
stripe_customer_id TEXT UNIQUE,
|
||||
stripe_subscription_id TEXT UNIQUE,
|
||||
stripe_subscription_status TEXT,
|
||||
stripe_subscription_interval TEXT,
|
||||
stripe_subscription_paid_until BIGINT,
|
||||
stripe_subscription_cancel_at BIGINT,
|
||||
created BIGINT NOT NULL,
|
||||
deleted BIGINT
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS user_access (
|
||||
user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
|
||||
topic TEXT NOT NULL,
|
||||
read BOOLEAN NOT NULL,
|
||||
write BOOLEAN NOT NULL,
|
||||
owner_user_id TEXT REFERENCES "user"(id) ON DELETE CASCADE,
|
||||
provisioned BOOLEAN NOT NULL,
|
||||
PRIMARY KEY (user_id, topic)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS user_token (
|
||||
user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
|
||||
token TEXT NOT NULL UNIQUE,
|
||||
label TEXT NOT NULL,
|
||||
last_access BIGINT NOT NULL,
|
||||
last_origin TEXT NOT NULL,
|
||||
expires BIGINT NOT NULL,
|
||||
provisioned BOOLEAN NOT NULL,
|
||||
PRIMARY KEY (user_id, token)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS user_phone (
|
||||
user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
|
||||
phone_number TEXT NOT NULL,
|
||||
PRIMARY KEY (user_id, phone_number)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS schema_version (
|
||||
store TEXT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
INSERT INTO "user" (id, user_name, pass, role, sync_topic, provisioned, created)
|
||||
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, EXTRACT(EPOCH FROM NOW())::BIGINT)
|
||||
ON CONFLICT (id) DO NOTHING;
|
||||
`
|
||||
)
|
||||
|
||||
// Schema table management queries for Postgres
|
||||
const (
|
||||
postgresCurrentSchemaVersion = 6
|
||||
postgresSelectSchemaVersion = `SELECT version FROM schema_version WHERE store = 'user'`
|
||||
postgresInsertSchemaVersion = `INSERT INTO schema_version (store, version) VALUES ('user', $1)`
|
||||
)
|
||||
|
||||
func setupPostgres(db *sql.DB) error {
|
||||
var schemaVersion int
|
||||
err := db.QueryRow(postgresSelectSchemaVersion).Scan(&schemaVersion)
|
||||
if err != nil {
|
||||
return setupNewPostgres(db)
|
||||
}
|
||||
if schemaVersion > postgresCurrentSchemaVersion {
|
||||
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, postgresCurrentSchemaVersion)
|
||||
}
|
||||
// Note: PostgreSQL migrations will be added when needed
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupNewPostgres(db *sql.DB) error {
|
||||
if _, err := db.Exec(postgresCreateTablesQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(postgresInsertSchemaVersion, postgresCurrentSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
208
user/store_postgres_test.go
Normal file
208
user/store_postgres_test.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package user_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
func newTestPostgresStore(t *testing.T) user.Store {
|
||||
dsn := os.Getenv("NTFY_TEST_DATABASE_URL")
|
||||
if dsn == "" {
|
||||
t.Skip("NTFY_TEST_DATABASE_URL not set, skipping PostgreSQL tests")
|
||||
}
|
||||
// Create a unique schema for this test
|
||||
schema := fmt.Sprintf("test_%s", util.RandomString(10))
|
||||
setupDB, err := sql.Open("pgx", dsn)
|
||||
require.Nil(t, err)
|
||||
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, setupDB.Close())
|
||||
// Open store with search_path set to the new schema
|
||||
u, err := url.Parse(dsn)
|
||||
require.Nil(t, err)
|
||||
q := u.Query()
|
||||
q.Set("search_path", schema)
|
||||
u.RawQuery = q.Encode()
|
||||
store, err := user.NewPostgresStore(u.String())
|
||||
require.Nil(t, err)
|
||||
t.Cleanup(func() {
|
||||
store.Close()
|
||||
cleanDB, err := sql.Open("pgx", dsn)
|
||||
if err == nil {
|
||||
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
|
||||
cleanDB.Close()
|
||||
}
|
||||
})
|
||||
return store
|
||||
}
|
||||
|
||||
func TestPostgresStoreAddUser(t *testing.T) {
|
||||
testStoreAddUser(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreAddUserAlreadyExists(t *testing.T) {
|
||||
testStoreAddUserAlreadyExists(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreRemoveUser(t *testing.T) {
|
||||
testStoreRemoveUser(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreUserByID(t *testing.T) {
|
||||
testStoreUserByID(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreUserByToken(t *testing.T) {
|
||||
testStoreUserByToken(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreUserByStripeCustomer(t *testing.T) {
|
||||
testStoreUserByStripeCustomer(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreUsers(t *testing.T) {
|
||||
testStoreUsers(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreUsersCount(t *testing.T) {
|
||||
testStoreUsersCount(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreChangePassword(t *testing.T) {
|
||||
testStoreChangePassword(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreChangeRole(t *testing.T) {
|
||||
testStoreChangeRole(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreTokens(t *testing.T) {
|
||||
testStoreTokens(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreTokenChangeLabel(t *testing.T) {
|
||||
testStoreTokenChangeLabel(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreTokenRemove(t *testing.T) {
|
||||
testStoreTokenRemove(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreTokenRemoveExpired(t *testing.T) {
|
||||
testStoreTokenRemoveExpired(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreTokenRemoveExcess(t *testing.T) {
|
||||
testStoreTokenRemoveExcess(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreTokenUpdateLastAccess(t *testing.T) {
|
||||
testStoreTokenUpdateLastAccess(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreAllowAccess(t *testing.T) {
|
||||
testStoreAllowAccess(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreAllowAccessReadOnly(t *testing.T) {
|
||||
testStoreAllowAccessReadOnly(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreResetAccess(t *testing.T) {
|
||||
testStoreResetAccess(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreResetAccessAll(t *testing.T) {
|
||||
testStoreResetAccessAll(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreAuthorizeTopicAccess(t *testing.T) {
|
||||
testStoreAuthorizeTopicAccess(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreAuthorizeTopicAccessNotFound(t *testing.T) {
|
||||
testStoreAuthorizeTopicAccessNotFound(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreAuthorizeTopicAccessDenyAll(t *testing.T) {
|
||||
testStoreAuthorizeTopicAccessDenyAll(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreReservations(t *testing.T) {
|
||||
testStoreReservations(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreReservationsCount(t *testing.T) {
|
||||
testStoreReservationsCount(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreHasReservation(t *testing.T) {
|
||||
testStoreHasReservation(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreReservationOwner(t *testing.T) {
|
||||
testStoreReservationOwner(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreTiers(t *testing.T) {
|
||||
testStoreTiers(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreTierUpdate(t *testing.T) {
|
||||
testStoreTierUpdate(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreTierRemove(t *testing.T) {
|
||||
testStoreTierRemove(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreTierByStripePrice(t *testing.T) {
|
||||
testStoreTierByStripePrice(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreChangeTier(t *testing.T) {
|
||||
testStoreChangeTier(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStorePhoneNumbers(t *testing.T) {
|
||||
testStorePhoneNumbers(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreChangeSettings(t *testing.T) {
|
||||
testStoreChangeSettings(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreChangeBilling(t *testing.T) {
|
||||
testStoreChangeBilling(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreUpdateStats(t *testing.T) {
|
||||
testStoreUpdateStats(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreResetStats(t *testing.T) {
|
||||
testStoreResetStats(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreMarkUserRemoved(t *testing.T) {
|
||||
testStoreMarkUserRemoved(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreRemoveDeletedUsers(t *testing.T) {
|
||||
testStoreRemoveDeletedUsers(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreAllGrants(t *testing.T) {
|
||||
testStoreAllGrants(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreOtherAccessCount(t *testing.T) {
|
||||
testStoreOtherAccessCount(t, newTestPostgresStore(t))
|
||||
}
|
||||
273
user/store_sqlite.go
Normal file
273
user/store_sqlite.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
)
|
||||
|
||||
const (
|
||||
// User queries
|
||||
sqliteSelectUserByID = `
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||
FROM user u
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE u.id = ?
|
||||
`
|
||||
sqliteSelectUserByName = `
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||
FROM user u
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE user = ?
|
||||
`
|
||||
sqliteSelectUserByToken = `
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||
FROM user u
|
||||
JOIN user_token tk on u.id = tk.user_id
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE tk.token = ? AND (tk.expires = 0 OR tk.expires >= ?)
|
||||
`
|
||||
sqliteSelectUserByStripeID = `
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||
FROM user u
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE u.stripe_customer_id = ?
|
||||
`
|
||||
sqliteSelectUsernames = `
|
||||
SELECT user
|
||||
FROM user
|
||||
ORDER BY
|
||||
CASE role
|
||||
WHEN 'admin' THEN 1
|
||||
WHEN 'anonymous' THEN 3
|
||||
ELSE 2
|
||||
END, user
|
||||
`
|
||||
sqliteSelectUserCount = `SELECT COUNT(*) FROM user`
|
||||
sqliteSelectUserIDFromUsername = `SELECT id FROM user WHERE user = ?`
|
||||
sqliteInsertUser = `INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created) VALUES (?, ?, ?, ?, ?, ?, ?)`
|
||||
sqliteUpdateUserPass = `UPDATE user SET pass = ? WHERE user = ?`
|
||||
sqliteUpdateUserRole = `UPDATE user SET role = ? WHERE user = ?`
|
||||
sqliteUpdateUserProvisioned = `UPDATE user SET provisioned = ? WHERE user = ?`
|
||||
sqliteUpdateUserPrefs = `UPDATE user SET prefs = ? WHERE id = ?`
|
||||
sqliteUpdateUserStats = `UPDATE user SET stats_messages = ?, stats_emails = ?, stats_calls = ? WHERE id = ?`
|
||||
sqliteUpdateUserStatsResetAll = `UPDATE user SET stats_messages = 0, stats_emails = 0, stats_calls = 0`
|
||||
sqliteUpdateUserTier = `UPDATE user SET tier_id = (SELECT id FROM tier WHERE code = ?) WHERE user = ?`
|
||||
sqliteUpdateUserDeleted = `UPDATE user SET deleted = ? WHERE id = ?`
|
||||
sqliteDeleteUser = `DELETE FROM user WHERE user = ?`
|
||||
sqliteDeleteUserTier = `UPDATE user SET tier_id = null WHERE user = ?`
|
||||
sqliteDeleteUsersMarked = `DELETE FROM user WHERE deleted < ?`
|
||||
|
||||
// Access queries
|
||||
sqliteSelectTopicPerms = `
|
||||
SELECT read, write
|
||||
FROM user_access a
|
||||
JOIN user u ON u.id = a.user_id
|
||||
WHERE (u.user = ? OR u.user = ?) AND ? LIKE a.topic ESCAPE '\'
|
||||
ORDER BY u.user DESC, LENGTH(a.topic) DESC, a.write DESC
|
||||
`
|
||||
sqliteSelectUserAllAccess = `
|
||||
SELECT user_id, topic, read, write, provisioned
|
||||
FROM user_access
|
||||
ORDER BY LENGTH(topic) DESC, write DESC, read DESC, topic
|
||||
`
|
||||
sqliteSelectUserAccess = `
|
||||
SELECT topic, read, write, provisioned
|
||||
FROM user_access
|
||||
WHERE user_id = (SELECT id FROM user WHERE user = ?)
|
||||
ORDER BY LENGTH(topic) DESC, write DESC, read DESC, topic
|
||||
`
|
||||
sqliteSelectUserReservations = `
|
||||
SELECT a_user.topic, a_user.read, a_user.write, a_everyone.read AS everyone_read, a_everyone.write AS everyone_write
|
||||
FROM user_access a_user
|
||||
LEFT JOIN user_access a_everyone ON a_user.topic = a_everyone.topic AND a_everyone.user_id = (SELECT id FROM user WHERE user = ?)
|
||||
WHERE a_user.user_id = a_user.owner_user_id
|
||||
AND a_user.owner_user_id = (SELECT id FROM user WHERE user = ?)
|
||||
ORDER BY a_user.topic
|
||||
`
|
||||
sqliteSelectUserReservationsCount = `
|
||||
SELECT COUNT(*)
|
||||
FROM user_access
|
||||
WHERE user_id = owner_user_id
|
||||
AND owner_user_id = (SELECT id FROM user WHERE user = ?)
|
||||
`
|
||||
sqliteSelectUserReservationsOwner = `
|
||||
SELECT owner_user_id
|
||||
FROM user_access
|
||||
WHERE topic = ?
|
||||
AND user_id = owner_user_id
|
||||
`
|
||||
sqliteSelectUserHasReservation = `
|
||||
SELECT COUNT(*)
|
||||
FROM user_access
|
||||
WHERE user_id = owner_user_id
|
||||
AND owner_user_id = (SELECT id FROM user WHERE user = ?)
|
||||
AND topic = ?
|
||||
`
|
||||
sqliteSelectOtherAccessCount = `
|
||||
SELECT COUNT(*)
|
||||
FROM user_access
|
||||
WHERE (topic = ? OR ? LIKE topic ESCAPE '\')
|
||||
AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM user WHERE user = ?))
|
||||
`
|
||||
sqliteUpsertUserAccess = `
|
||||
INSERT INTO user_access (user_id, topic, read, write, owner_user_id, provisioned)
|
||||
VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?, (SELECT IIF(?='',NULL,(SELECT id FROM user WHERE user=?))), ?)
|
||||
ON CONFLICT (user_id, topic)
|
||||
DO UPDATE SET read=excluded.read, write=excluded.write, owner_user_id=excluded.owner_user_id, provisioned=excluded.provisioned
|
||||
`
|
||||
sqliteDeleteUserAccess = `
|
||||
DELETE FROM user_access
|
||||
WHERE user_id = (SELECT id FROM user WHERE user = ?)
|
||||
OR owner_user_id = (SELECT id FROM user WHERE user = ?)
|
||||
`
|
||||
sqliteDeleteUserAccessProvisioned = `DELETE FROM user_access WHERE provisioned = 1`
|
||||
sqliteDeleteTopicAccess = `
|
||||
DELETE FROM user_access
|
||||
WHERE (user_id = (SELECT id FROM user WHERE user = ?) OR owner_user_id = (SELECT id FROM user WHERE user = ?))
|
||||
AND topic = ?
|
||||
`
|
||||
sqliteDeleteAllAccess = `DELETE FROM user_access`
|
||||
|
||||
// Token queries
|
||||
sqliteSelectToken = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ? AND token = ?`
|
||||
sqliteSelectTokens = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ?`
|
||||
sqliteSelectTokenCount = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
|
||||
sqliteSelectAllProvisionedTokens = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE provisioned = 1`
|
||||
sqliteUpsertToken = `
|
||||
INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires, provisioned)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT (user_id, token)
|
||||
DO UPDATE SET label = excluded.label, expires = excluded.expires, provisioned = excluded.provisioned;
|
||||
`
|
||||
sqliteUpdateTokenLabel = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?`
|
||||
sqliteUpdateTokenExpiry = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?`
|
||||
sqliteUpdateTokenLastAccess = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?`
|
||||
sqliteDeleteToken = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
|
||||
sqliteDeleteProvisionedToken = `DELETE FROM user_token WHERE token = ?`
|
||||
sqliteDeleteAllToken = `DELETE FROM user_token WHERE user_id = ?`
|
||||
sqliteDeleteExpiredTokens = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
|
||||
sqliteDeleteExcessTokens = `
|
||||
DELETE FROM user_token
|
||||
WHERE user_id = ?
|
||||
AND (user_id, token) NOT IN (
|
||||
SELECT user_id, token
|
||||
FROM user_token
|
||||
WHERE user_id = ?
|
||||
ORDER BY expires DESC
|
||||
LIMIT ?
|
||||
)
|
||||
`
|
||||
|
||||
// Tier queries
|
||||
sqliteInsertTier = `
|
||||
INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
sqliteUpdateTier = `
|
||||
UPDATE tier
|
||||
SET name = ?, messages_limit = ?, messages_expiry_duration = ?, emails_limit = ?, calls_limit = ?, reservations_limit = ?, attachment_file_size_limit = ?, attachment_total_size_limit = ?, attachment_expiry_duration = ?, attachment_bandwidth_limit = ?, stripe_monthly_price_id = ?, stripe_yearly_price_id = ?
|
||||
WHERE code = ?
|
||||
`
|
||||
sqliteSelectTiers = `
|
||||
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
|
||||
FROM tier
|
||||
`
|
||||
sqliteSelectTierByCode = `
|
||||
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
|
||||
FROM tier
|
||||
WHERE code = ?
|
||||
`
|
||||
sqliteSelectTierByPriceID = `
|
||||
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
|
||||
FROM tier
|
||||
WHERE (stripe_monthly_price_id = ? OR stripe_yearly_price_id = ?)
|
||||
`
|
||||
sqliteDeleteTier = `DELETE FROM tier WHERE code = ?`
|
||||
|
||||
// Phone queries
|
||||
sqliteSelectPhoneNumbers = `SELECT phone_number FROM user_phone WHERE user_id = ?`
|
||||
sqliteInsertPhoneNumber = `INSERT INTO user_phone (user_id, phone_number) VALUES (?, ?)`
|
||||
sqliteDeletePhoneNumber = `DELETE FROM user_phone WHERE user_id = ? AND phone_number = ?`
|
||||
|
||||
// Billing queries
|
||||
sqliteUpdateBilling = `
|
||||
UPDATE user
|
||||
SET stripe_customer_id = ?, stripe_subscription_id = ?, stripe_subscription_status = ?, stripe_subscription_interval = ?, stripe_subscription_paid_until = ?, stripe_subscription_cancel_at = ?
|
||||
WHERE user = ?
|
||||
`
|
||||
)
|
||||
|
||||
// NewSQLiteStore creates a new SQLite-backed user store
|
||||
func NewSQLiteStore(filename, startupQueries string) (Store, error) {
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := setupSQLite(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := runSQLiteStartupQueries(db, startupQueries); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &commonStore{
|
||||
db: db,
|
||||
queries: storeQueries{
|
||||
selectUserByID: sqliteSelectUserByID,
|
||||
selectUserByName: sqliteSelectUserByName,
|
||||
selectUserByToken: sqliteSelectUserByToken,
|
||||
selectUserByStripeID: sqliteSelectUserByStripeID,
|
||||
selectUsernames: sqliteSelectUsernames,
|
||||
selectUserCount: sqliteSelectUserCount,
|
||||
selectUserIDFromUsername: sqliteSelectUserIDFromUsername,
|
||||
insertUser: sqliteInsertUser,
|
||||
updateUserPass: sqliteUpdateUserPass,
|
||||
updateUserRole: sqliteUpdateUserRole,
|
||||
updateUserProvisioned: sqliteUpdateUserProvisioned,
|
||||
updateUserPrefs: sqliteUpdateUserPrefs,
|
||||
updateUserStats: sqliteUpdateUserStats,
|
||||
updateUserStatsResetAll: sqliteUpdateUserStatsResetAll,
|
||||
updateUserTier: sqliteUpdateUserTier,
|
||||
updateUserDeleted: sqliteUpdateUserDeleted,
|
||||
deleteUser: sqliteDeleteUser,
|
||||
deleteUserTier: sqliteDeleteUserTier,
|
||||
deleteUsersMarked: sqliteDeleteUsersMarked,
|
||||
selectTopicPerms: sqliteSelectTopicPerms,
|
||||
selectUserAllAccess: sqliteSelectUserAllAccess,
|
||||
selectUserAccess: sqliteSelectUserAccess,
|
||||
selectUserReservations: sqliteSelectUserReservations,
|
||||
selectUserReservationsCount: sqliteSelectUserReservationsCount,
|
||||
selectUserReservationsOwner: sqliteSelectUserReservationsOwner,
|
||||
selectUserHasReservation: sqliteSelectUserHasReservation,
|
||||
selectOtherAccessCount: sqliteSelectOtherAccessCount,
|
||||
upsertUserAccess: sqliteUpsertUserAccess,
|
||||
deleteUserAccess: sqliteDeleteUserAccess,
|
||||
deleteUserAccessProvisioned: sqliteDeleteUserAccessProvisioned,
|
||||
deleteTopicAccess: sqliteDeleteTopicAccess,
|
||||
deleteAllAccess: sqliteDeleteAllAccess,
|
||||
selectToken: sqliteSelectToken,
|
||||
selectTokens: sqliteSelectTokens,
|
||||
selectTokenCount: sqliteSelectTokenCount,
|
||||
selectAllProvisionedTokens: sqliteSelectAllProvisionedTokens,
|
||||
upsertToken: sqliteUpsertToken,
|
||||
updateTokenLabel: sqliteUpdateTokenLabel,
|
||||
updateTokenExpiry: sqliteUpdateTokenExpiry,
|
||||
updateTokenLastAccess: sqliteUpdateTokenLastAccess,
|
||||
deleteToken: sqliteDeleteToken,
|
||||
deleteProvisionedToken: sqliteDeleteProvisionedToken,
|
||||
deleteAllToken: sqliteDeleteAllToken,
|
||||
deleteExpiredTokens: sqliteDeleteExpiredTokens,
|
||||
deleteExcessTokens: sqliteDeleteExcessTokens,
|
||||
insertTier: sqliteInsertTier,
|
||||
selectTiers: sqliteSelectTiers,
|
||||
selectTierByCode: sqliteSelectTierByCode,
|
||||
selectTierByPriceID: sqliteSelectTierByPriceID,
|
||||
updateTier: sqliteUpdateTier,
|
||||
deleteTier: sqliteDeleteTier,
|
||||
selectPhoneNumbers: sqliteSelectPhoneNumbers,
|
||||
insertPhoneNumber: sqliteInsertPhoneNumber,
|
||||
deletePhoneNumber: sqliteDeletePhoneNumber,
|
||||
updateBilling: sqliteUpdateBilling,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
483
user/store_sqlite_schema.go
Normal file
483
user/store_sqlite_schema.go
Normal file
@@ -0,0 +1,483 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
// Initial SQLite schema
|
||||
const (
|
||||
sqliteCreateTablesQueries = `
|
||||
BEGIN;
|
||||
CREATE TABLE IF NOT EXISTS tier (
|
||||
id TEXT PRIMARY KEY,
|
||||
code TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
messages_limit INT NOT NULL,
|
||||
messages_expiry_duration INT NOT NULL,
|
||||
emails_limit INT NOT NULL,
|
||||
calls_limit INT NOT NULL,
|
||||
reservations_limit INT NOT NULL,
|
||||
attachment_file_size_limit INT NOT NULL,
|
||||
attachment_total_size_limit INT NOT NULL,
|
||||
attachment_expiry_duration INT NOT NULL,
|
||||
attachment_bandwidth_limit INT NOT NULL,
|
||||
stripe_monthly_price_id TEXT,
|
||||
stripe_yearly_price_id TEXT
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_tier_code ON tier (code);
|
||||
CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
|
||||
CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
|
||||
CREATE TABLE IF NOT EXISTS user (
|
||||
id TEXT PRIMARY KEY,
|
||||
tier_id TEXT,
|
||||
user TEXT NOT NULL,
|
||||
pass TEXT NOT NULL,
|
||||
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
|
||||
prefs JSON NOT NULL DEFAULT '{}',
|
||||
sync_topic TEXT NOT NULL,
|
||||
provisioned INT NOT NULL,
|
||||
stats_messages INT NOT NULL DEFAULT (0),
|
||||
stats_emails INT NOT NULL DEFAULT (0),
|
||||
stats_calls INT NOT NULL DEFAULT (0),
|
||||
stripe_customer_id TEXT,
|
||||
stripe_subscription_id TEXT,
|
||||
stripe_subscription_status TEXT,
|
||||
stripe_subscription_interval TEXT,
|
||||
stripe_subscription_paid_until INT,
|
||||
stripe_subscription_cancel_at INT,
|
||||
created INT NOT NULL,
|
||||
deleted INT,
|
||||
FOREIGN KEY (tier_id) REFERENCES tier (id)
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_user ON user (user);
|
||||
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
|
||||
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
|
||||
CREATE TABLE IF NOT EXISTS user_access (
|
||||
user_id TEXT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
read INT NOT NULL,
|
||||
write INT NOT NULL,
|
||||
owner_user_id INT,
|
||||
provisioned INT NOT NULL,
|
||||
PRIMARY KEY (user_id, topic),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS user_token (
|
||||
user_id TEXT NOT NULL,
|
||||
token TEXT NOT NULL,
|
||||
label TEXT NOT NULL,
|
||||
last_access INT NOT NULL,
|
||||
last_origin TEXT NOT NULL,
|
||||
expires INT NOT NULL,
|
||||
provisioned INT NOT NULL,
|
||||
PRIMARY KEY (user_id, token),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_user_token ON user_token (token);
|
||||
CREATE TABLE IF NOT EXISTS user_phone (
|
||||
user_id TEXT NOT NULL,
|
||||
phone_number TEXT NOT NULL,
|
||||
PRIMARY KEY (user_id, phone_number),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||
id INT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created)
|
||||
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, UNIXEPOCH())
|
||||
ON CONFLICT (id) DO NOTHING;
|
||||
COMMIT;
|
||||
`
|
||||
)
|
||||
|
||||
const (
|
||||
sqliteBuiltinStartupQueries = `PRAGMA foreign_keys = ON;`
|
||||
)
|
||||
|
||||
// Schema version table management for SQLite
|
||||
const (
|
||||
sqliteCurrentSchemaVersion = 6
|
||||
sqliteInsertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
|
||||
sqliteUpdateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
|
||||
sqliteSelectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
|
||||
)
|
||||
|
||||
// Schema migrations for SQLite
|
||||
const (
|
||||
// 1 -> 2 (complex migration!)
|
||||
sqliteMigrate1To2CreateTablesQueries = `
|
||||
ALTER TABLE user RENAME TO user_old;
|
||||
CREATE TABLE IF NOT EXISTS tier (
|
||||
id TEXT PRIMARY KEY,
|
||||
code TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
messages_limit INT NOT NULL,
|
||||
messages_expiry_duration INT NOT NULL,
|
||||
emails_limit INT NOT NULL,
|
||||
reservations_limit INT NOT NULL,
|
||||
attachment_file_size_limit INT NOT NULL,
|
||||
attachment_total_size_limit INT NOT NULL,
|
||||
attachment_expiry_duration INT NOT NULL,
|
||||
attachment_bandwidth_limit INT NOT NULL,
|
||||
stripe_price_id TEXT
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_tier_code ON tier (code);
|
||||
CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id);
|
||||
CREATE TABLE IF NOT EXISTS user (
|
||||
id TEXT PRIMARY KEY,
|
||||
tier_id TEXT,
|
||||
user TEXT NOT NULL,
|
||||
pass TEXT NOT NULL,
|
||||
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
|
||||
prefs JSON NOT NULL DEFAULT '{}',
|
||||
sync_topic TEXT NOT NULL,
|
||||
stats_messages INT NOT NULL DEFAULT (0),
|
||||
stats_emails INT NOT NULL DEFAULT (0),
|
||||
stripe_customer_id TEXT,
|
||||
stripe_subscription_id TEXT,
|
||||
stripe_subscription_status TEXT,
|
||||
stripe_subscription_paid_until INT,
|
||||
stripe_subscription_cancel_at INT,
|
||||
created INT NOT NULL,
|
||||
deleted INT,
|
||||
FOREIGN KEY (tier_id) REFERENCES tier (id)
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_user ON user (user);
|
||||
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
|
||||
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
|
||||
CREATE TABLE IF NOT EXISTS user_access (
|
||||
user_id TEXT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
read INT NOT NULL,
|
||||
write INT NOT NULL,
|
||||
owner_user_id INT,
|
||||
PRIMARY KEY (user_id, topic),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS user_token (
|
||||
user_id TEXT NOT NULL,
|
||||
token TEXT NOT NULL,
|
||||
label TEXT NOT NULL,
|
||||
last_access INT NOT NULL,
|
||||
last_origin TEXT NOT NULL,
|
||||
expires INT NOT NULL,
|
||||
PRIMARY KEY (user_id, token),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||
id INT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
INSERT INTO user (id, user, pass, role, sync_topic, created)
|
||||
VALUES ('u_everyone', '*', '', 'anonymous', '', UNIXEPOCH())
|
||||
ON CONFLICT (id) DO NOTHING;
|
||||
`
|
||||
sqliteMigrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old`
|
||||
sqliteMigrate1To2InsertUserNoTx = `
|
||||
INSERT INTO user (id, user, pass, role, sync_topic, created)
|
||||
SELECT ?, user, pass, role, ?, UNIXEPOCH() FROM user_old WHERE user = ?
|
||||
`
|
||||
sqliteMigrate1To2InsertFromOldTablesAndDropNoTx = `
|
||||
INSERT INTO user_access (user_id, topic, read, write)
|
||||
SELECT u.id, a.topic, a.read, a.write
|
||||
FROM user u
|
||||
JOIN access a ON u.user = a.user;
|
||||
|
||||
DROP TABLE access;
|
||||
DROP TABLE user_old;
|
||||
`
|
||||
|
||||
// 2 -> 3
|
||||
sqliteMigrate2To3UpdateQueries = `
|
||||
ALTER TABLE user ADD COLUMN stripe_subscription_interval TEXT;
|
||||
ALTER TABLE tier RENAME COLUMN stripe_price_id TO stripe_monthly_price_id;
|
||||
ALTER TABLE tier ADD COLUMN stripe_yearly_price_id TEXT;
|
||||
DROP INDEX IF EXISTS idx_tier_price_id;
|
||||
CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
|
||||
CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
|
||||
`
|
||||
|
||||
// 3 -> 4
|
||||
sqliteMigrate3To4UpdateQueries = `
|
||||
ALTER TABLE tier ADD COLUMN calls_limit INT NOT NULL DEFAULT (0);
|
||||
ALTER TABLE user ADD COLUMN stats_calls INT NOT NULL DEFAULT (0);
|
||||
CREATE TABLE IF NOT EXISTS user_phone (
|
||||
user_id TEXT NOT NULL,
|
||||
phone_number TEXT NOT NULL,
|
||||
PRIMARY KEY (user_id, phone_number),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
`
|
||||
|
||||
// 4 -> 5
|
||||
sqliteMigrate4To5UpdateQueries = `
|
||||
UPDATE user_access SET topic = REPLACE(topic, '_', '\_');
|
||||
`
|
||||
|
||||
// 5 -> 6
|
||||
sqliteMigrate5To6UpdateQueries = `
|
||||
PRAGMA foreign_keys=off;
|
||||
|
||||
-- Alter user table: Add provisioned column
|
||||
ALTER TABLE user RENAME TO user_old;
|
||||
CREATE TABLE IF NOT EXISTS user (
|
||||
id TEXT PRIMARY KEY,
|
||||
tier_id TEXT,
|
||||
user TEXT NOT NULL,
|
||||
pass TEXT NOT NULL,
|
||||
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
|
||||
prefs JSON NOT NULL DEFAULT '{}',
|
||||
sync_topic TEXT NOT NULL,
|
||||
provisioned INT NOT NULL,
|
||||
stats_messages INT NOT NULL DEFAULT (0),
|
||||
stats_emails INT NOT NULL DEFAULT (0),
|
||||
stats_calls INT NOT NULL DEFAULT (0),
|
||||
stripe_customer_id TEXT,
|
||||
stripe_subscription_id TEXT,
|
||||
stripe_subscription_status TEXT,
|
||||
stripe_subscription_interval TEXT,
|
||||
stripe_subscription_paid_until INT,
|
||||
stripe_subscription_cancel_at INT,
|
||||
created INT NOT NULL,
|
||||
deleted INT,
|
||||
FOREIGN KEY (tier_id) REFERENCES tier (id)
|
||||
);
|
||||
INSERT INTO user
|
||||
SELECT
|
||||
id,
|
||||
tier_id,
|
||||
user,
|
||||
pass,
|
||||
role,
|
||||
prefs,
|
||||
sync_topic,
|
||||
0, -- provisioned
|
||||
stats_messages,
|
||||
stats_emails,
|
||||
stats_calls,
|
||||
stripe_customer_id,
|
||||
stripe_subscription_id,
|
||||
stripe_subscription_status,
|
||||
stripe_subscription_interval,
|
||||
stripe_subscription_paid_until,
|
||||
stripe_subscription_cancel_at,
|
||||
created,
|
||||
deleted
|
||||
FROM user_old;
|
||||
DROP TABLE user_old;
|
||||
|
||||
-- Alter user_access table: Add provisioned column
|
||||
ALTER TABLE user_access RENAME TO user_access_old;
|
||||
CREATE TABLE user_access (
|
||||
user_id TEXT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
read INT NOT NULL,
|
||||
write INT NOT NULL,
|
||||
owner_user_id INT,
|
||||
provisioned INTEGER NOT NULL,
|
||||
PRIMARY KEY (user_id, topic),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
INSERT INTO user_access SELECT *, 0 FROM user_access_old;
|
||||
DROP TABLE user_access_old;
|
||||
|
||||
-- Alter user_token table: Add provisioned column
|
||||
ALTER TABLE user_token RENAME TO user_token_old;
|
||||
CREATE TABLE IF NOT EXISTS user_token (
|
||||
user_id TEXT NOT NULL,
|
||||
token TEXT NOT NULL,
|
||||
label TEXT NOT NULL,
|
||||
last_access INT NOT NULL,
|
||||
last_origin TEXT NOT NULL,
|
||||
expires INT NOT NULL,
|
||||
provisioned INT NOT NULL,
|
||||
PRIMARY KEY (user_id, token),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
INSERT INTO user_token SELECT *, 0 FROM user_token_old;
|
||||
DROP TABLE user_token_old;
|
||||
|
||||
-- Recreate indices
|
||||
CREATE UNIQUE INDEX idx_user ON user (user);
|
||||
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
|
||||
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
|
||||
CREATE UNIQUE INDEX idx_user_token ON user_token (token);
|
||||
|
||||
-- Re-enable foreign keys
|
||||
PRAGMA foreign_keys=on;
|
||||
`
|
||||
)
|
||||
|
||||
var (
|
||||
sqliteMigrations = map[int]func(db *sql.DB) error{
|
||||
1: sqliteMigrateFrom1,
|
||||
2: sqliteMigrateFrom2,
|
||||
3: sqliteMigrateFrom3,
|
||||
4: sqliteMigrateFrom4,
|
||||
5: sqliteMigrateFrom5,
|
||||
}
|
||||
)
|
||||
|
||||
func setupSQLite(db *sql.DB) error {
|
||||
var schemaVersion int
|
||||
err := db.QueryRow(sqliteSelectSchemaVersionQuery).Scan(&schemaVersion)
|
||||
if err != nil {
|
||||
return setupNewSQLite(db)
|
||||
}
|
||||
if schemaVersion == sqliteCurrentSchemaVersion {
|
||||
return nil
|
||||
} else if schemaVersion > sqliteCurrentSchemaVersion {
|
||||
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, sqliteCurrentSchemaVersion)
|
||||
}
|
||||
for i := schemaVersion; i < sqliteCurrentSchemaVersion; i++ {
|
||||
fn, ok := sqliteMigrations[i]
|
||||
if !ok {
|
||||
return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1)
|
||||
} else if err := fn(db); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupNewSQLite(db *sql.DB) error {
|
||||
if _, err := db.Exec(sqliteCreateTablesQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(sqliteInsertSchemaVersion, sqliteCurrentSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func runSQLiteStartupQueries(db *sql.DB, startupQueries string) error {
|
||||
if _, err := db.Exec(sqliteBuiltinStartupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if startupQueries != "" {
|
||||
if _, err := db.Exec(startupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom1(db *sql.DB) error {
|
||||
log.Tag(tag).Info("Migrating user database schema: from 1 to 2")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
// Rename user -> user_old, and create new tables
|
||||
if _, err := tx.Exec(sqliteMigrate1To2CreateTablesQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
// Insert users from user_old into new user table, with ID and sync_topic
|
||||
rows, err := tx.Query(sqliteMigrate1To2SelectAllOldUsernamesNoTx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
usernames := make([]string, 0)
|
||||
for rows.Next() {
|
||||
var username string
|
||||
if err := rows.Scan(&username); err != nil {
|
||||
return err
|
||||
}
|
||||
usernames = append(usernames, username)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, username := range usernames {
|
||||
userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
|
||||
syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
|
||||
if _, err := tx.Exec(sqliteMigrate1To2InsertUserNoTx, userID, syncTopic, username); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// Migrate old "access" table to "user_access" and drop "access" and "user_old"
|
||||
if _, err := tx.Exec(sqliteMigrate1To2InsertFromOldTablesAndDropNoTx); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 2); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom2(db *sql.DB) error {
|
||||
log.Tag(tag).Info("Migrating user database schema: from 2 to 3")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(sqliteMigrate2To3UpdateQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 3); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom3(db *sql.DB) error {
|
||||
log.Tag(tag).Info("Migrating user database schema: from 3 to 4")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(sqliteMigrate3To4UpdateQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 4); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom4(db *sql.DB) error {
|
||||
log.Tag(tag).Info("Migrating user database schema: from 4 to 5")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(sqliteMigrate4To5UpdateQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 5); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom5(db *sql.DB) error {
|
||||
log.Tag(tag).Info("Migrating user database schema: from 5 to 6")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(sqliteMigrate5To6UpdateQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 6); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
180
user/store_sqlite_test.go
Normal file
180
user/store_sqlite_test.go
Normal file
@@ -0,0 +1,180 @@
|
||||
package user_test
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
)
|
||||
|
||||
func newTestSQLiteStore(t *testing.T) user.Store {
|
||||
store, err := user.NewSQLiteStore(filepath.Join(t.TempDir(), "user.db"), "")
|
||||
require.Nil(t, err)
|
||||
t.Cleanup(func() { store.Close() })
|
||||
return store
|
||||
}
|
||||
|
||||
func TestSQLiteStoreAddUser(t *testing.T) {
|
||||
testStoreAddUser(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreAddUserAlreadyExists(t *testing.T) {
|
||||
testStoreAddUserAlreadyExists(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreRemoveUser(t *testing.T) {
|
||||
testStoreRemoveUser(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreUserByID(t *testing.T) {
|
||||
testStoreUserByID(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreUserByToken(t *testing.T) {
|
||||
testStoreUserByToken(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreUserByStripeCustomer(t *testing.T) {
|
||||
testStoreUserByStripeCustomer(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreUsers(t *testing.T) {
|
||||
testStoreUsers(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreUsersCount(t *testing.T) {
|
||||
testStoreUsersCount(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreChangePassword(t *testing.T) {
|
||||
testStoreChangePassword(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreChangeRole(t *testing.T) {
|
||||
testStoreChangeRole(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreTokens(t *testing.T) {
|
||||
testStoreTokens(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreTokenChangeLabel(t *testing.T) {
|
||||
testStoreTokenChangeLabel(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreTokenRemove(t *testing.T) {
|
||||
testStoreTokenRemove(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreTokenRemoveExpired(t *testing.T) {
|
||||
testStoreTokenRemoveExpired(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreTokenRemoveExcess(t *testing.T) {
|
||||
testStoreTokenRemoveExcess(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreTokenUpdateLastAccess(t *testing.T) {
|
||||
testStoreTokenUpdateLastAccess(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreAllowAccess(t *testing.T) {
|
||||
testStoreAllowAccess(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreAllowAccessReadOnly(t *testing.T) {
|
||||
testStoreAllowAccessReadOnly(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreResetAccess(t *testing.T) {
|
||||
testStoreResetAccess(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreResetAccessAll(t *testing.T) {
|
||||
testStoreResetAccessAll(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreAuthorizeTopicAccess(t *testing.T) {
|
||||
testStoreAuthorizeTopicAccess(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreAuthorizeTopicAccessNotFound(t *testing.T) {
|
||||
testStoreAuthorizeTopicAccessNotFound(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreAuthorizeTopicAccessDenyAll(t *testing.T) {
|
||||
testStoreAuthorizeTopicAccessDenyAll(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreReservations(t *testing.T) {
|
||||
testStoreReservations(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreReservationsCount(t *testing.T) {
|
||||
testStoreReservationsCount(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreHasReservation(t *testing.T) {
|
||||
testStoreHasReservation(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreReservationOwner(t *testing.T) {
|
||||
testStoreReservationOwner(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreTiers(t *testing.T) {
|
||||
testStoreTiers(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreTierUpdate(t *testing.T) {
|
||||
testStoreTierUpdate(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreTierRemove(t *testing.T) {
|
||||
testStoreTierRemove(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreTierByStripePrice(t *testing.T) {
|
||||
testStoreTierByStripePrice(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreChangeTier(t *testing.T) {
|
||||
testStoreChangeTier(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStorePhoneNumbers(t *testing.T) {
|
||||
testStorePhoneNumbers(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreChangeSettings(t *testing.T) {
|
||||
testStoreChangeSettings(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreChangeBilling(t *testing.T) {
|
||||
testStoreChangeBilling(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreUpdateStats(t *testing.T) {
|
||||
testStoreUpdateStats(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreResetStats(t *testing.T) {
|
||||
testStoreResetStats(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreMarkUserRemoved(t *testing.T) {
|
||||
testStoreMarkUserRemoved(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreRemoveDeletedUsers(t *testing.T) {
|
||||
testStoreRemoveDeletedUsers(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreAllGrants(t *testing.T) {
|
||||
testStoreAllGrants(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreOtherAccessCount(t *testing.T) {
|
||||
testStoreOtherAccessCount(t, newTestSQLiteStore(t))
|
||||
}
|
||||
619
user/store_test.go
Normal file
619
user/store_test.go
Normal file
@@ -0,0 +1,619 @@
|
||||
package user_test
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
)
|
||||
|
||||
func testStoreAddUser(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "phil", u.Name)
|
||||
require.Equal(t, user.RoleUser, u.Role)
|
||||
require.False(t, u.Provisioned)
|
||||
require.NotEmpty(t, u.ID)
|
||||
require.NotEmpty(t, u.SyncTopic)
|
||||
}
|
||||
|
||||
func testStoreAddUserAlreadyExists(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
require.Equal(t, user.ErrUserExists, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
}
|
||||
|
||||
func testStoreRemoveUser(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "phil", u.Name)
|
||||
|
||||
require.Nil(t, store.RemoveUser("phil"))
|
||||
_, err = store.User("phil")
|
||||
require.Equal(t, user.ErrUserNotFound, err)
|
||||
}
|
||||
|
||||
func testStoreUserByID(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleAdmin, false))
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
||||
u2, err := store.UserByID(u.ID)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, u.Name, u2.Name)
|
||||
require.Equal(t, u.ID, u2.ID)
|
||||
}
|
||||
|
||||
func testStoreUserByToken(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
||||
tk, err := store.CreateToken(u.ID, "tk_test123", "test token", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(24*time.Hour), false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "tk_test123", tk.Value)
|
||||
|
||||
u2, err := store.UserByToken(tk.Value)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "phil", u2.Name)
|
||||
}
|
||||
|
||||
func testStoreUserByStripeCustomer(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
require.Nil(t, store.ChangeBilling("phil", &user.Billing{
|
||||
StripeCustomerID: "cus_test123",
|
||||
StripeSubscriptionID: "sub_test123",
|
||||
}))
|
||||
|
||||
u, err := store.UserByStripeCustomer("cus_test123")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "phil", u.Name)
|
||||
require.Equal(t, "cus_test123", u.Billing.StripeCustomerID)
|
||||
}
|
||||
|
||||
func testStoreUsers(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
require.Nil(t, store.AddUser("ben", "benhash", user.RoleAdmin, false))
|
||||
|
||||
users, err := store.Users()
|
||||
require.Nil(t, err)
|
||||
require.True(t, len(users) >= 3) // phil, ben, and the everyone user
|
||||
}
|
||||
|
||||
func testStoreUsersCount(t *testing.T, store user.Store) {
|
||||
count, err := store.UsersCount()
|
||||
require.Nil(t, err)
|
||||
require.True(t, count >= 1) // At least the everyone user
|
||||
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
count2, err := store.UsersCount()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, count+1, count2)
|
||||
}
|
||||
|
||||
func testStoreChangePassword(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "philhash", u.Hash)
|
||||
|
||||
require.Nil(t, store.ChangePassword("phil", "newhash"))
|
||||
u, err = store.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "newhash", u.Hash)
|
||||
}
|
||||
|
||||
func testStoreChangeRole(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, user.RoleUser, u.Role)
|
||||
|
||||
require.Nil(t, store.ChangeRole("phil", user.RoleAdmin))
|
||||
u, err = store.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, user.RoleAdmin, u.Role)
|
||||
}
|
||||
|
||||
func testStoreTokens(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
||||
now := time.Now()
|
||||
expires := now.Add(24 * time.Hour)
|
||||
origin := netip.MustParseAddr("9.9.9.9")
|
||||
|
||||
tk, err := store.CreateToken(u.ID, "tk_abc", "my token", now, origin, expires, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "tk_abc", tk.Value)
|
||||
require.Equal(t, "my token", tk.Label)
|
||||
|
||||
// Get single token
|
||||
tk2, err := store.Token(u.ID, "tk_abc")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "tk_abc", tk2.Value)
|
||||
require.Equal(t, "my token", tk2.Label)
|
||||
|
||||
// Get all tokens
|
||||
tokens, err := store.Tokens(u.ID)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, tokens, 1)
|
||||
require.Equal(t, "tk_abc", tokens[0].Value)
|
||||
|
||||
// Token count
|
||||
count, err := store.TokenCount(u.ID)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func testStoreTokenChangeLabel(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
||||
_, err = store.CreateToken(u.ID, "tk_abc", "old label", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), false)
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Nil(t, store.ChangeTokenLabel(u.ID, "tk_abc", "new label"))
|
||||
tk, err := store.Token(u.ID, "tk_abc")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "new label", tk.Label)
|
||||
}
|
||||
|
||||
func testStoreTokenRemove(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
||||
_, err = store.CreateToken(u.ID, "tk_abc", "label", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), false)
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Nil(t, store.RemoveToken(u.ID, "tk_abc"))
|
||||
_, err = store.Token(u.ID, "tk_abc")
|
||||
require.Equal(t, user.ErrTokenNotFound, err)
|
||||
}
|
||||
|
||||
func testStoreTokenRemoveExpired(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
||||
// Create expired token and active token
|
||||
_, err = store.CreateToken(u.ID, "tk_expired", "expired", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(-time.Hour), false)
|
||||
require.Nil(t, err)
|
||||
_, err = store.CreateToken(u.ID, "tk_active", "active", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), false)
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Nil(t, store.RemoveExpiredTokens())
|
||||
|
||||
// Expired token should be gone
|
||||
_, err = store.Token(u.ID, "tk_expired")
|
||||
require.Equal(t, user.ErrTokenNotFound, err)
|
||||
|
||||
// Active token should still exist
|
||||
tk, err := store.Token(u.ID, "tk_active")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "tk_active", tk.Value)
|
||||
}
|
||||
|
||||
func testStoreTokenRemoveExcess(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
||||
// Create 3 tokens with increasing expiry
|
||||
for i, name := range []string{"tk_a", "tk_b", "tk_c"} {
|
||||
_, err = store.CreateToken(u.ID, name, name, time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Duration(i+1)*time.Hour), false)
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
count, err := store.TokenCount(u.ID)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 3, count)
|
||||
|
||||
// Remove excess, keep only 2 (the ones with latest expiry: tk_b, tk_c)
|
||||
require.Nil(t, store.RemoveExcessTokens(u.ID, 2))
|
||||
|
||||
count, err = store.TokenCount(u.ID)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, count)
|
||||
|
||||
// tk_a should be removed (earliest expiry)
|
||||
_, err = store.Token(u.ID, "tk_a")
|
||||
require.Equal(t, user.ErrTokenNotFound, err)
|
||||
|
||||
// tk_b and tk_c should remain
|
||||
_, err = store.Token(u.ID, "tk_b")
|
||||
require.Nil(t, err)
|
||||
_, err = store.Token(u.ID, "tk_c")
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
func testStoreTokenUpdateLastAccess(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
||||
_, err = store.CreateToken(u.ID, "tk_abc", "label", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), false)
|
||||
require.Nil(t, err)
|
||||
|
||||
newTime := time.Now().Add(5 * time.Minute)
|
||||
newOrigin := netip.MustParseAddr("5.5.5.5")
|
||||
require.Nil(t, store.UpdateTokenLastAccess("tk_abc", newTime, newOrigin))
|
||||
|
||||
tk, err := store.Token(u.ID, "tk_abc")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, newTime.Unix(), tk.LastAccess.Unix())
|
||||
require.Equal(t, newOrigin, tk.LastOrigin)
|
||||
}
|
||||
|
||||
func testStoreAllowAccess(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
|
||||
require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "", false))
|
||||
grants, err := store.Grants("phil")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, grants, 1)
|
||||
require.Equal(t, "mytopic", grants[0].TopicPattern)
|
||||
require.True(t, grants[0].Permission.IsReadWrite())
|
||||
}
|
||||
|
||||
func testStoreAllowAccessReadOnly(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
|
||||
require.Nil(t, store.AllowAccess("phil", "announcements", true, false, "", false))
|
||||
grants, err := store.Grants("phil")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, grants, 1)
|
||||
require.True(t, grants[0].Permission.IsRead())
|
||||
require.False(t, grants[0].Permission.IsWrite())
|
||||
}
|
||||
|
||||
func testStoreResetAccess(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
require.Nil(t, store.AllowAccess("phil", "topic1", true, true, "", false))
|
||||
require.Nil(t, store.AllowAccess("phil", "topic2", true, false, "", false))
|
||||
|
||||
grants, err := store.Grants("phil")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, grants, 2)
|
||||
|
||||
require.Nil(t, store.ResetAccess("phil", "topic1"))
|
||||
grants, err = store.Grants("phil")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, grants, 1)
|
||||
require.Equal(t, "topic2", grants[0].TopicPattern)
|
||||
}
|
||||
|
||||
func testStoreResetAccessAll(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
require.Nil(t, store.AllowAccess("phil", "topic1", true, true, "", false))
|
||||
require.Nil(t, store.AllowAccess("phil", "topic2", true, false, "", false))
|
||||
|
||||
require.Nil(t, store.ResetAccess("phil", ""))
|
||||
grants, err := store.Grants("phil")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, grants, 0)
|
||||
}
|
||||
|
||||
func testStoreAuthorizeTopicAccess(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "", false))
|
||||
|
||||
read, write, found, err := store.AuthorizeTopicAccess("phil", "mytopic")
|
||||
require.Nil(t, err)
|
||||
require.True(t, found)
|
||||
require.True(t, read)
|
||||
require.True(t, write)
|
||||
}
|
||||
|
||||
func testStoreAuthorizeTopicAccessNotFound(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
|
||||
_, _, found, err := store.AuthorizeTopicAccess("phil", "other")
|
||||
require.Nil(t, err)
|
||||
require.False(t, found)
|
||||
}
|
||||
|
||||
func testStoreAuthorizeTopicAccessDenyAll(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
require.Nil(t, store.AllowAccess("phil", "secret", false, false, "", false))
|
||||
|
||||
read, write, found, err := store.AuthorizeTopicAccess("phil", "secret")
|
||||
require.Nil(t, err)
|
||||
require.True(t, found)
|
||||
require.False(t, read)
|
||||
require.False(t, write)
|
||||
}
|
||||
|
||||
func testStoreReservations(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "phil", false))
|
||||
require.Nil(t, store.AllowAccess(user.Everyone, "mytopic", true, false, "phil", false))
|
||||
|
||||
reservations, err := store.Reservations("phil")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, reservations, 1)
|
||||
require.Equal(t, "mytopic", reservations[0].Topic)
|
||||
require.True(t, reservations[0].Owner.IsReadWrite())
|
||||
require.True(t, reservations[0].Everyone.IsRead())
|
||||
require.False(t, reservations[0].Everyone.IsWrite())
|
||||
}
|
||||
|
||||
func testStoreReservationsCount(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
require.Nil(t, store.AllowAccess("phil", "topic1", true, true, "phil", false))
|
||||
require.Nil(t, store.AllowAccess("phil", "topic2", true, true, "phil", false))
|
||||
|
||||
count, err := store.ReservationsCount("phil")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(2), count)
|
||||
}
|
||||
|
||||
func testStoreHasReservation(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "phil", false))
|
||||
|
||||
has, err := store.HasReservation("phil", "mytopic")
|
||||
require.Nil(t, err)
|
||||
require.True(t, has)
|
||||
|
||||
has, err = store.HasReservation("phil", "other")
|
||||
require.Nil(t, err)
|
||||
require.False(t, has)
|
||||
}
|
||||
|
||||
func testStoreReservationOwner(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "phil", false))
|
||||
|
||||
owner, err := store.ReservationOwner("mytopic")
|
||||
require.Nil(t, err)
|
||||
require.NotEmpty(t, owner) // Returns the user ID
|
||||
|
||||
owner, err = store.ReservationOwner("unowned")
|
||||
require.Nil(t, err)
|
||||
require.Empty(t, owner)
|
||||
}
|
||||
|
||||
func testStoreTiers(t *testing.T, store user.Store) {
|
||||
tier := &user.Tier{
|
||||
ID: "ti_test",
|
||||
Code: "pro",
|
||||
Name: "Pro",
|
||||
MessageLimit: 5000,
|
||||
MessageExpiryDuration: 24 * time.Hour,
|
||||
EmailLimit: 100,
|
||||
CallLimit: 10,
|
||||
ReservationLimit: 20,
|
||||
AttachmentFileSizeLimit: 10 * 1024 * 1024,
|
||||
AttachmentTotalSizeLimit: 100 * 1024 * 1024,
|
||||
AttachmentExpiryDuration: 48 * time.Hour,
|
||||
AttachmentBandwidthLimit: 500 * 1024 * 1024,
|
||||
}
|
||||
require.Nil(t, store.AddTier(tier))
|
||||
|
||||
// Get by code
|
||||
t2, err := store.Tier("pro")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "ti_test", t2.ID)
|
||||
require.Equal(t, "pro", t2.Code)
|
||||
require.Equal(t, "Pro", t2.Name)
|
||||
require.Equal(t, int64(5000), t2.MessageLimit)
|
||||
require.Equal(t, int64(100), t2.EmailLimit)
|
||||
require.Equal(t, int64(10), t2.CallLimit)
|
||||
require.Equal(t, int64(20), t2.ReservationLimit)
|
||||
|
||||
// List all tiers
|
||||
tiers, err := store.Tiers()
|
||||
require.Nil(t, err)
|
||||
require.Len(t, tiers, 1)
|
||||
require.Equal(t, "pro", tiers[0].Code)
|
||||
}
|
||||
|
||||
func testStoreTierUpdate(t *testing.T, store user.Store) {
|
||||
tier := &user.Tier{
|
||||
ID: "ti_test",
|
||||
Code: "pro",
|
||||
Name: "Pro",
|
||||
}
|
||||
require.Nil(t, store.AddTier(tier))
|
||||
|
||||
tier.Name = "Professional"
|
||||
tier.MessageLimit = 9999
|
||||
require.Nil(t, store.UpdateTier(tier))
|
||||
|
||||
t2, err := store.Tier("pro")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "Professional", t2.Name)
|
||||
require.Equal(t, int64(9999), t2.MessageLimit)
|
||||
}
|
||||
|
||||
func testStoreTierRemove(t *testing.T, store user.Store) {
|
||||
tier := &user.Tier{
|
||||
ID: "ti_test",
|
||||
Code: "pro",
|
||||
Name: "Pro",
|
||||
}
|
||||
require.Nil(t, store.AddTier(tier))
|
||||
|
||||
t2, err := store.Tier("pro")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "pro", t2.Code)
|
||||
|
||||
require.Nil(t, store.RemoveTier("pro"))
|
||||
_, err = store.Tier("pro")
|
||||
require.Equal(t, user.ErrTierNotFound, err)
|
||||
}
|
||||
|
||||
func testStoreTierByStripePrice(t *testing.T, store user.Store) {
|
||||
tier := &user.Tier{
|
||||
ID: "ti_test",
|
||||
Code: "pro",
|
||||
Name: "Pro",
|
||||
StripeMonthlyPriceID: "price_monthly",
|
||||
StripeYearlyPriceID: "price_yearly",
|
||||
}
|
||||
require.Nil(t, store.AddTier(tier))
|
||||
|
||||
t2, err := store.TierByStripePrice("price_monthly")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "pro", t2.Code)
|
||||
|
||||
t3, err := store.TierByStripePrice("price_yearly")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "pro", t3.Code)
|
||||
}
|
||||
|
||||
func testStoreChangeTier(t *testing.T, store user.Store) {
|
||||
tier := &user.Tier{
|
||||
ID: "ti_test",
|
||||
Code: "pro",
|
||||
Name: "Pro",
|
||||
}
|
||||
require.Nil(t, store.AddTier(tier))
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
require.Nil(t, store.ChangeTier("phil", "pro"))
|
||||
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, u.Tier)
|
||||
require.Equal(t, "pro", u.Tier.Code)
|
||||
}
|
||||
|
||||
func testStorePhoneNumbers(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Nil(t, store.AddPhoneNumber(u.ID, "+1234567890"))
|
||||
require.Nil(t, store.AddPhoneNumber(u.ID, "+0987654321"))
|
||||
|
||||
numbers, err := store.PhoneNumbers(u.ID)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, numbers, 2)
|
||||
|
||||
require.Nil(t, store.RemovePhoneNumber(u.ID, "+1234567890"))
|
||||
numbers, err = store.PhoneNumbers(u.ID)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, numbers, 1)
|
||||
require.Equal(t, "+0987654321", numbers[0])
|
||||
}
|
||||
|
||||
func testStoreChangeSettings(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
||||
lang := "de"
|
||||
prefs := &user.Prefs{Language: &lang}
|
||||
require.Nil(t, store.ChangeSettings(u.ID, prefs))
|
||||
|
||||
u2, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, u2.Prefs)
|
||||
require.Equal(t, "de", *u2.Prefs.Language)
|
||||
}
|
||||
|
||||
func testStoreChangeBilling(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
|
||||
billing := &user.Billing{
|
||||
StripeCustomerID: "cus_123",
|
||||
StripeSubscriptionID: "sub_456",
|
||||
}
|
||||
require.Nil(t, store.ChangeBilling("phil", billing))
|
||||
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "cus_123", u.Billing.StripeCustomerID)
|
||||
require.Equal(t, "sub_456", u.Billing.StripeSubscriptionID)
|
||||
}
|
||||
|
||||
func testStoreUpdateStats(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
||||
stats := &user.Stats{Messages: 42, Emails: 3, Calls: 1}
|
||||
require.Nil(t, store.UpdateStats(u.ID, stats))
|
||||
|
||||
u2, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(42), u2.Stats.Messages)
|
||||
require.Equal(t, int64(3), u2.Stats.Emails)
|
||||
require.Equal(t, int64(1), u2.Stats.Calls)
|
||||
}
|
||||
|
||||
func testStoreResetStats(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Nil(t, store.UpdateStats(u.ID, &user.Stats{Messages: 42, Emails: 3, Calls: 1}))
|
||||
require.Nil(t, store.ResetStats())
|
||||
|
||||
u2, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(0), u2.Stats.Messages)
|
||||
require.Equal(t, int64(0), u2.Stats.Emails)
|
||||
require.Equal(t, int64(0), u2.Stats.Calls)
|
||||
}
|
||||
|
||||
func testStoreMarkUserRemoved(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Nil(t, store.MarkUserRemoved(u.ID))
|
||||
|
||||
u2, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.True(t, u2.Deleted)
|
||||
}
|
||||
|
||||
func testStoreRemoveDeletedUsers(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Nil(t, store.MarkUserRemoved(u.ID))
|
||||
|
||||
// RemoveDeletedUsers only removes users past the hard-delete duration (7 days).
|
||||
// Immediately after marking, the user should still exist.
|
||||
require.Nil(t, store.RemoveDeletedUsers())
|
||||
u2, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.True(t, u2.Deleted)
|
||||
}
|
||||
|
||||
func testStoreAllGrants(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
require.Nil(t, store.AddUser("ben", "benhash", user.RoleUser, false))
|
||||
phil, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
ben, err := store.User("ben")
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Nil(t, store.AllowAccess("phil", "topic1", true, true, "", false))
|
||||
require.Nil(t, store.AllowAccess("ben", "topic2", true, false, "", false))
|
||||
|
||||
grants, err := store.AllGrants()
|
||||
require.Nil(t, err)
|
||||
require.Contains(t, grants, phil.ID)
|
||||
require.Contains(t, grants, ben.ID)
|
||||
}
|
||||
|
||||
func testStoreOtherAccessCount(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
require.Nil(t, store.AddUser("ben", "benhash", user.RoleUser, false))
|
||||
require.Nil(t, store.AllowAccess("ben", "mytopic", true, true, "ben", false))
|
||||
|
||||
count, err := store.OtherAccessCount("phil", "mytopic")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
@@ -242,6 +242,20 @@ const (
|
||||
everyoneID = "u_everyone"
|
||||
)
|
||||
|
||||
// Config holds the configuration for the user Manager
|
||||
type Config struct {
|
||||
Filename string // Database filename, e.g. "/var/lib/ntfy/user.db" (SQLite)
|
||||
DatabaseURL string // Database connection string (PostgreSQL)
|
||||
StartupQueries string // Queries to run on startup, e.g. to create initial users or tiers (SQLite only)
|
||||
DefaultAccess Permission // Default permission if no ACL matches
|
||||
ProvisionEnabled bool // Hack: Enable auto-provisioning of users and access grants, disabled for "ntfy user" commands
|
||||
Users []*User // Predefined users to create on startup
|
||||
Access map[string][]*Grant // Predefined access grants to create on startup (username -> []*Grant)
|
||||
Tokens map[string][]*Token // Predefined users to create on startup (username -> []*Token)
|
||||
QueueWriterInterval time.Duration // Interval for the async queue writer to flush stats and token updates to the database
|
||||
BcryptCost int // Cost of generated passwords; lowering makes testing faster
|
||||
}
|
||||
|
||||
// Error constants used by the package
|
||||
var (
|
||||
ErrUnauthenticated = errors.New("unauthenticated")
|
||||
|
||||
40
user/util.go
40
user/util.go
@@ -1,10 +1,12 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
"database/sql"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -77,3 +79,37 @@ func hashPassword(password string, cost int) (string, error) {
|
||||
}
|
||||
return string(hash), nil
|
||||
}
|
||||
|
||||
func nullString(s string) sql.NullString {
|
||||
if s == "" {
|
||||
return sql.NullString{}
|
||||
}
|
||||
return sql.NullString{String: s, Valid: true}
|
||||
}
|
||||
|
||||
func nullInt64(v int64) sql.NullInt64 {
|
||||
if v == 0 {
|
||||
return sql.NullInt64{}
|
||||
}
|
||||
return sql.NullInt64{Int64: v, Valid: true}
|
||||
}
|
||||
|
||||
// toSQLWildcard converts a wildcard string to a SQL wildcard string. It only allows '*' as wildcards,
|
||||
// and escapes '_', assuming '\' as escape character.
|
||||
func toSQLWildcard(s string) string {
|
||||
return escapeUnderscore(strings.ReplaceAll(s, "*", "%"))
|
||||
}
|
||||
|
||||
// fromSQLWildcard converts a SQL wildcard string to a wildcard string. It converts '%' to '*',
|
||||
// and removes the '\_' escape character.
|
||||
func fromSQLWildcard(s string) string {
|
||||
return strings.ReplaceAll(unescapeUnderscore(s), "%", "*")
|
||||
}
|
||||
|
||||
func escapeUnderscore(s string) string {
|
||||
return strings.ReplaceAll(s, "_", "\\_")
|
||||
}
|
||||
|
||||
func unescapeUnderscore(s string) string {
|
||||
return strings.ReplaceAll(s, "\\_", "_")
|
||||
}
|
||||
|
||||
281
user/util_test.go
Normal file
281
user/util_test.go
Normal file
@@ -0,0 +1,281 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAllowedRole(t *testing.T) {
|
||||
require.True(t, AllowedRole(RoleUser))
|
||||
require.True(t, AllowedRole(RoleAdmin))
|
||||
require.False(t, AllowedRole(RoleAnonymous))
|
||||
require.False(t, AllowedRole(Role("invalid")))
|
||||
require.False(t, AllowedRole(Role("")))
|
||||
require.False(t, AllowedRole(Role("superadmin")))
|
||||
}
|
||||
|
||||
func TestAllowedTopic(t *testing.T) {
|
||||
// Valid topics
|
||||
require.True(t, AllowedTopic("test"))
|
||||
require.True(t, AllowedTopic("mytopic"))
|
||||
require.True(t, AllowedTopic("topic123"))
|
||||
require.True(t, AllowedTopic("my-topic"))
|
||||
require.True(t, AllowedTopic("my_topic"))
|
||||
require.True(t, AllowedTopic("Topic123"))
|
||||
require.True(t, AllowedTopic("a"))
|
||||
require.True(t, AllowedTopic(strings.Repeat("a", 64))) // Max length
|
||||
|
||||
// Invalid topics - wildcards not allowed
|
||||
require.False(t, AllowedTopic("topic*"))
|
||||
require.False(t, AllowedTopic("*"))
|
||||
require.False(t, AllowedTopic("my*topic"))
|
||||
|
||||
// Invalid topics - special characters
|
||||
require.False(t, AllowedTopic("my topic")) // Space
|
||||
require.False(t, AllowedTopic("my.topic")) // Dot
|
||||
require.False(t, AllowedTopic("my/topic")) // Slash
|
||||
require.False(t, AllowedTopic("my@topic")) // At sign
|
||||
require.False(t, AllowedTopic("my+topic")) // Plus
|
||||
require.False(t, AllowedTopic("topic!")) // Exclamation
|
||||
require.False(t, AllowedTopic("topic#")) // Hash
|
||||
require.False(t, AllowedTopic("topic$")) // Dollar
|
||||
require.False(t, AllowedTopic("topic%")) // Percent
|
||||
require.False(t, AllowedTopic("topic&")) // Ampersand
|
||||
require.False(t, AllowedTopic("my\\topic")) // Backslash
|
||||
|
||||
// Invalid topics - length
|
||||
require.False(t, AllowedTopic("")) // Empty
|
||||
require.False(t, AllowedTopic(strings.Repeat("a", 65))) // Too long
|
||||
}
|
||||
|
||||
func TestAllowedTopicPattern(t *testing.T) {
|
||||
// Valid patterns - same as AllowedTopic
|
||||
require.True(t, AllowedTopicPattern("test"))
|
||||
require.True(t, AllowedTopicPattern("mytopic"))
|
||||
require.True(t, AllowedTopicPattern("topic123"))
|
||||
require.True(t, AllowedTopicPattern("my-topic"))
|
||||
require.True(t, AllowedTopicPattern("my_topic"))
|
||||
require.True(t, AllowedTopicPattern("a"))
|
||||
require.True(t, AllowedTopicPattern(strings.Repeat("a", 64))) // Max length
|
||||
|
||||
// Valid patterns - with wildcards
|
||||
require.True(t, AllowedTopicPattern("*"))
|
||||
require.True(t, AllowedTopicPattern("topic*"))
|
||||
require.True(t, AllowedTopicPattern("*topic"))
|
||||
require.True(t, AllowedTopicPattern("my*topic"))
|
||||
require.True(t, AllowedTopicPattern("***"))
|
||||
require.True(t, AllowedTopicPattern("test_*"))
|
||||
require.True(t, AllowedTopicPattern("my-*-topic"))
|
||||
require.True(t, AllowedTopicPattern(strings.Repeat("*", 64))) // Max length with wildcards
|
||||
|
||||
// Invalid patterns - special characters (other than wildcard)
|
||||
require.False(t, AllowedTopicPattern("my topic")) // Space
|
||||
require.False(t, AllowedTopicPattern("my.topic")) // Dot
|
||||
require.False(t, AllowedTopicPattern("my/topic")) // Slash
|
||||
require.False(t, AllowedTopicPattern("my@topic")) // At sign
|
||||
require.False(t, AllowedTopicPattern("my+topic")) // Plus
|
||||
require.False(t, AllowedTopicPattern("topic!")) // Exclamation
|
||||
require.False(t, AllowedTopicPattern("topic#")) // Hash
|
||||
require.False(t, AllowedTopicPattern("topic$")) // Dollar
|
||||
require.False(t, AllowedTopicPattern("topic%")) // Percent
|
||||
require.False(t, AllowedTopicPattern("topic&")) // Ampersand
|
||||
require.False(t, AllowedTopicPattern("my\\topic")) // Backslash
|
||||
|
||||
// Invalid patterns - length
|
||||
require.False(t, AllowedTopicPattern("")) // Empty
|
||||
require.False(t, AllowedTopicPattern(strings.Repeat("a", 65))) // Too long
|
||||
}
|
||||
|
||||
func TestValidPasswordHash(t *testing.T) {
|
||||
// Valid bcrypt hashes with different versions
|
||||
require.Nil(t, ValidPasswordHash("$2a$10$EEp6gBheOsqEFsXlo523E.gBVoeg1ytphXiEvTPlNzkenBlHZBPQy", 10))
|
||||
require.Nil(t, ValidPasswordHash("$2b$10$YLiO8U21sX1uhZamTLJXHuxgVC0Z/GKISibrKCLohPgtG7yIxSk4C", 10))
|
||||
require.Nil(t, ValidPasswordHash("$2y$12$1234567890123456789012u1234567890123456789012345678901", 10))
|
||||
|
||||
// Valid hash with minimum cost
|
||||
require.Nil(t, ValidPasswordHash("$2a$04$1234567890123456789012u1234567890123456789012345678901", 4))
|
||||
|
||||
// Invalid - wrong prefix
|
||||
require.Equal(t, ErrPasswordHashInvalid, ValidPasswordHash("$2c$10$EEp6gBheOsqEFsXlo523E.gBVoeg1ytphXiEvTPlNzkenBlHZBPQy", 10))
|
||||
require.Equal(t, ErrPasswordHashInvalid, ValidPasswordHash("$3a$10$EEp6gBheOsqEFsXlo523E.gBVoeg1ytphXiEvTPlNzkenBlHZBPQy", 10))
|
||||
require.Equal(t, ErrPasswordHashInvalid, ValidPasswordHash("bcrypt$10$hash", 10))
|
||||
require.Equal(t, ErrPasswordHashInvalid, ValidPasswordHash("nothash", 10))
|
||||
require.Equal(t, ErrPasswordHashInvalid, ValidPasswordHash("", 10))
|
||||
|
||||
// Invalid - malformed hash
|
||||
require.NotNil(t, ValidPasswordHash("$2a$10$tooshort", 10))
|
||||
require.NotNil(t, ValidPasswordHash("$2a$10", 10))
|
||||
require.NotNil(t, ValidPasswordHash("$2a$", 10))
|
||||
|
||||
// Invalid - cost too low
|
||||
require.Equal(t, ErrPasswordHashWeak, ValidPasswordHash("$2a$04$1234567890123456789012u1234567890123456789012345678901", 10))
|
||||
require.Equal(t, ErrPasswordHashWeak, ValidPasswordHash("$2a$09$EEp6gBheOsqEFsXlo523E.gBVoeg1ytphXiEvTPlNzkenBlHZBPQy", 10))
|
||||
|
||||
// Edge case - cost exactly at minimum
|
||||
require.Nil(t, ValidPasswordHash("$2a$10$EEp6gBheOsqEFsXlo523E.gBVoeg1ytphXiEvTPlNzkenBlHZBPQy", 10))
|
||||
}
|
||||
|
||||
func TestValidToken(t *testing.T) {
|
||||
// Valid tokens
|
||||
require.True(t, ValidToken("tk_1234567890123456789012345678x"))
|
||||
require.True(t, ValidToken("tk_abcdefghijklmnopqrstuvwxyzabc"))
|
||||
require.True(t, ValidToken("tk_ABCDEFGHIJKLMNOPQRSTUVWXYZABC"))
|
||||
require.True(t, ValidToken("tk_012345678901234567890123456ab"))
|
||||
require.True(t, ValidToken("tk_-----------------------------"))
|
||||
require.True(t, ValidToken("tk______________________________"))
|
||||
|
||||
// Invalid tokens - wrong prefix
|
||||
require.False(t, ValidToken("tx_1234567890123456789012345678x"))
|
||||
require.False(t, ValidToken("tk1234567890123456789012345678xy"))
|
||||
require.False(t, ValidToken("token_1234567890123456789012345"))
|
||||
|
||||
// Invalid tokens - wrong length
|
||||
require.False(t, ValidToken("tk_")) // Too short
|
||||
require.False(t, ValidToken("tk_123")) // Too short
|
||||
require.False(t, ValidToken("tk_123456789012345678901234567890")) // Too long (30 chars after prefix)
|
||||
require.False(t, ValidToken("tk_123456789012345678901234567")) // Too short (28 chars)
|
||||
|
||||
// Invalid tokens - invalid characters
|
||||
require.False(t, ValidToken("tk_123456789012345678901234567!@"))
|
||||
require.False(t, ValidToken("tk_12345678901234567890123456 8x"))
|
||||
require.False(t, ValidToken("tk_123456789012345678901234567.x"))
|
||||
require.False(t, ValidToken("tk_123456789012345678901234567*x"))
|
||||
|
||||
// Invalid tokens - no prefix
|
||||
require.False(t, ValidToken("1234567890123456789012345678901x"))
|
||||
require.False(t, ValidToken(""))
|
||||
}
|
||||
|
||||
func TestGenerateToken(t *testing.T) {
|
||||
// Generate multiple tokens
|
||||
tokens := make(map[string]bool)
|
||||
for i := 0; i < 100; i++ {
|
||||
token := GenerateToken()
|
||||
|
||||
// Check format
|
||||
require.True(t, strings.HasPrefix(token, "tk_"), "Token should start with tk_")
|
||||
require.Equal(t, 32, len(token), "Token should be 32 characters long")
|
||||
|
||||
// Check it's valid
|
||||
require.True(t, ValidToken(token), "Generated token should be valid")
|
||||
|
||||
// Check it's lowercase
|
||||
require.Equal(t, strings.ToLower(token), token, "Token should be lowercase")
|
||||
|
||||
// Check uniqueness
|
||||
require.False(t, tokens[token], "Token should be unique")
|
||||
tokens[token] = true
|
||||
}
|
||||
|
||||
// Verify we got 100 unique tokens
|
||||
require.Equal(t, 100, len(tokens))
|
||||
}
|
||||
|
||||
func TestHashPassword(t *testing.T) {
|
||||
password := "test-password-123"
|
||||
|
||||
// Hash the password
|
||||
hash, err := HashPassword(password)
|
||||
require.Nil(t, err)
|
||||
require.NotEmpty(t, hash)
|
||||
|
||||
// Check it's a valid bcrypt hash
|
||||
require.Nil(t, ValidPasswordHash(hash, DefaultUserPasswordBcryptCost))
|
||||
|
||||
// Check it starts with correct prefix
|
||||
require.True(t, strings.HasPrefix(hash, "$2a$"))
|
||||
|
||||
// Hash the same password again - should produce different hash
|
||||
hash2, err := HashPassword(password)
|
||||
require.Nil(t, err)
|
||||
require.NotEqual(t, hash, hash2, "Same password should produce different hashes (salt)")
|
||||
|
||||
// Empty password should still work
|
||||
emptyHash, err := HashPassword("")
|
||||
require.Nil(t, err)
|
||||
require.NotEmpty(t, emptyHash)
|
||||
require.Nil(t, ValidPasswordHash(emptyHash, DefaultUserPasswordBcryptCost))
|
||||
}
|
||||
|
||||
func TestHashPassword_WithCost(t *testing.T) {
|
||||
password := "test-password"
|
||||
|
||||
// Test with different costs
|
||||
hash4, err := hashPassword(password, 4)
|
||||
require.Nil(t, err)
|
||||
require.True(t, strings.HasPrefix(hash4, "$2a$04$"))
|
||||
|
||||
hash10, err := hashPassword(password, 10)
|
||||
require.Nil(t, err)
|
||||
require.True(t, strings.HasPrefix(hash10, "$2a$10$"))
|
||||
|
||||
hash12, err := hashPassword(password, 12)
|
||||
require.Nil(t, err)
|
||||
require.True(t, strings.HasPrefix(hash12, "$2a$12$"))
|
||||
|
||||
// All should be valid
|
||||
require.Nil(t, ValidPasswordHash(hash4, 4))
|
||||
require.Nil(t, ValidPasswordHash(hash10, 10))
|
||||
require.Nil(t, ValidPasswordHash(hash12, 12))
|
||||
}
|
||||
|
||||
func TestUser_TierID(t *testing.T) {
|
||||
// User with tier
|
||||
u := &User{
|
||||
Tier: &Tier{
|
||||
ID: "ti_123",
|
||||
Code: "pro",
|
||||
},
|
||||
}
|
||||
require.Equal(t, "ti_123", u.TierID())
|
||||
|
||||
// User without tier
|
||||
u2 := &User{
|
||||
Tier: nil,
|
||||
}
|
||||
require.Equal(t, "", u2.TierID())
|
||||
|
||||
// Nil user
|
||||
var u3 *User
|
||||
require.Equal(t, "", u3.TierID())
|
||||
}
|
||||
|
||||
func TestUser_IsAdmin(t *testing.T) {
|
||||
admin := &User{Role: RoleAdmin}
|
||||
require.True(t, admin.IsAdmin())
|
||||
require.False(t, admin.IsUser())
|
||||
|
||||
user := &User{Role: RoleUser}
|
||||
require.False(t, user.IsAdmin())
|
||||
|
||||
anonymous := &User{Role: RoleAnonymous}
|
||||
require.False(t, anonymous.IsAdmin())
|
||||
|
||||
// Nil user
|
||||
var nilUser *User
|
||||
require.False(t, nilUser.IsAdmin())
|
||||
}
|
||||
|
||||
func TestUser_IsUser(t *testing.T) {
|
||||
user := &User{Role: RoleUser}
|
||||
require.True(t, user.IsUser())
|
||||
require.False(t, user.IsAdmin())
|
||||
|
||||
admin := &User{Role: RoleAdmin}
|
||||
require.False(t, admin.IsUser())
|
||||
|
||||
anonymous := &User{Role: RoleAnonymous}
|
||||
require.False(t, anonymous.IsUser())
|
||||
|
||||
// Nil user
|
||||
var nilUser *User
|
||||
require.False(t, nilUser.IsUser())
|
||||
}
|
||||
|
||||
func TestPermission_String(t *testing.T) {
|
||||
require.Equal(t, "read-write", PermissionReadWrite.String())
|
||||
require.Equal(t, "read-only", PermissionRead.String())
|
||||
require.Equal(t, "write-only", PermissionWrite.String())
|
||||
require.Equal(t, "deny-all", PermissionDenyAll.String())
|
||||
}
|
||||
@@ -406,5 +406,7 @@
|
||||
"web_push_unknown_notification_title": "Neznáme oznámenie prijaté zo servera",
|
||||
"web_push_unknown_notification_body": "Možno budete musieť aktualizovať ntfy otvorením webovej aplikácie",
|
||||
"alert_notification_permission_required_title": "Oznámenia sú vypnuté",
|
||||
"alert_notification_ios_install_required_description": "Kliknutím na Zdieľať a Pridať na domovskú obrazovku povolíte oznámenia v systéme iOS"
|
||||
"alert_notification_ios_install_required_description": "Kliknutím na Zdieľať a Pridať na domovskú obrazovku povolíte oznámenia v systéme iOS",
|
||||
"account_basics_cannot_edit_or_delete_provisioned_user": "Prideleného používateľa nemožno upraviť ani odstrániť",
|
||||
"account_tokens_table_cannot_delete_or_edit_provisioned_token": "Pridelený token nemožno upraviť ani odstrániť"
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ import { NavigationRoute, registerRoute } from "workbox-routing";
|
||||
import { NetworkFirst } from "workbox-strategies";
|
||||
import { clientsClaim } from "workbox-core";
|
||||
import { dbAsync } from "../src/app/db";
|
||||
import { ACTION_COPY, ACTION_HTTP, ACTION_VIEW } from "../src/app/actions";
|
||||
import { ACTION_HTTP, ACTION_VIEW } from "../src/app/actions";
|
||||
import { badge, icon, messageWithSequenceId, notificationTag, toNotificationParams } from "../src/app/notificationUtils";
|
||||
import initI18n from "../src/app/i18n";
|
||||
import {
|
||||
@@ -256,26 +256,6 @@ const handleClick = async (event) => {
|
||||
if (action.clear) {
|
||||
await clearNotification();
|
||||
}
|
||||
} else if (action.action === ACTION_COPY) {
|
||||
try {
|
||||
// Service worker can't access the clipboard API directly, so we try to
|
||||
// open a focused client and use it, or fall back to opening a window
|
||||
const allClients = await self.clients.matchAll({ type: "window" });
|
||||
const focusedClient = allClients.find((c) => c.focused) || allClients[0];
|
||||
if (focusedClient) {
|
||||
focusedClient.postMessage({ type: "copy", value: action.value });
|
||||
}
|
||||
if (action.clear) {
|
||||
await clearNotification();
|
||||
}
|
||||
} catch (e) {
|
||||
console.error("[ServiceWorker] Error performing copy action", e);
|
||||
self.registration.showNotification(`${t("notifications_actions_failed_notification")}: ${action.label} (${action.action})`, {
|
||||
body: e.message,
|
||||
icon,
|
||||
badge,
|
||||
});
|
||||
}
|
||||
} else if (action.action === ACTION_HTTP) {
|
||||
try {
|
||||
const response = await fetch(action.url, {
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
// and cannot be used in the service worker
|
||||
|
||||
import emojisMapped from "./emojisMapped";
|
||||
import { ACTION_COPY, ACTION_HTTP, ACTION_VIEW } from "./actions";
|
||||
import { ACTION_HTTP, ACTION_VIEW } from "./actions";
|
||||
|
||||
const toEmojis = (tags) => {
|
||||
if (!tags) return [];
|
||||
@@ -82,7 +82,7 @@ export const toNotificationParams = ({ message, defaultTitle, topicRoute, baseUr
|
||||
topicRoute,
|
||||
},
|
||||
actions: message.actions
|
||||
?.filter(({ action }) => action === ACTION_VIEW || action === ACTION_HTTP || action === ACTION_COPY)
|
||||
?.filter(({ action }) => action === ACTION_VIEW || action === ACTION_HTTP)
|
||||
.map(({ label }) => ({
|
||||
action: label,
|
||||
title: label,
|
||||
|
||||
@@ -12,15 +12,6 @@ const registerSW = () => {
|
||||
return;
|
||||
}
|
||||
|
||||
// Listen for messages from the service worker (e.g., "copy" action)
|
||||
navigator.serviceWorker.addEventListener("message", (event) => {
|
||||
if (event.data?.type === "copy" && event.data?.value) {
|
||||
navigator.clipboard?.writeText(event.data.value).catch((e) => {
|
||||
console.error("[ServiceWorker] Failed to copy to clipboard", e);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
viteRegisterSW({
|
||||
onRegisteredSW(swUrl, registration) {
|
||||
console.log("[ServiceWorker] Registered:", { swUrl, registration });
|
||||
|
||||
188
webpush/store.go
Normal file
188
webpush/store.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package webpush
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
const (
|
||||
subscriptionIDPrefix = "wps_"
|
||||
subscriptionIDLength = 10
|
||||
subscriptionEndpointLimitPerSubscriberIP = 10
|
||||
)
|
||||
|
||||
// Errors returned by the store
|
||||
var (
|
||||
ErrWebPushTooManySubscriptions = errors.New("too many subscriptions")
|
||||
ErrWebPushUserIDCannotBeEmpty = errors.New("user ID cannot be empty")
|
||||
)
|
||||
|
||||
// Store is the interface for a web push subscription store.
|
||||
type Store interface {
|
||||
UpsertSubscription(endpoint, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error
|
||||
SubscriptionsForTopic(topic string) ([]*Subscription, error)
|
||||
SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscription, error)
|
||||
MarkExpiryWarningSent(subscriptions []*Subscription) error
|
||||
RemoveSubscriptionsByEndpoint(endpoint string) error
|
||||
RemoveSubscriptionsByUserID(userID string) error
|
||||
RemoveExpiredSubscriptions(expireAfter time.Duration) error
|
||||
SetSubscriptionUpdatedAt(endpoint string, updatedAt int64) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
// storeQueries holds the database-specific SQL queries.
|
||||
type storeQueries struct {
|
||||
selectSubscriptionIDByEndpoint string
|
||||
selectSubscriptionCountBySubscriberIP string
|
||||
selectSubscriptionsForTopic string
|
||||
selectSubscriptionsExpiringSoon string
|
||||
insertSubscription string
|
||||
updateSubscriptionWarningSent string
|
||||
updateSubscriptionUpdatedAt string
|
||||
deleteSubscriptionByEndpoint string
|
||||
deleteSubscriptionByUserID string
|
||||
deleteSubscriptionByAge string
|
||||
insertSubscriptionTopic string
|
||||
deleteSubscriptionTopicAll string
|
||||
deleteSubscriptionTopicWithoutSubscription string
|
||||
}
|
||||
|
||||
// commonStore implements store operations that are identical across database backends.
|
||||
type commonStore struct {
|
||||
db *sql.DB
|
||||
queries storeQueries
|
||||
}
|
||||
|
||||
// UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID.
|
||||
func (s *commonStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error {
|
||||
tx, err := s.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
// Read number of subscriptions for subscriber IP address
|
||||
var subscriptionCount int
|
||||
if err := tx.QueryRow(s.queries.selectSubscriptionCountBySubscriberIP, subscriberIP.String()).Scan(&subscriptionCount); err != nil {
|
||||
return err
|
||||
}
|
||||
// Read existing subscription ID for endpoint (or create new ID)
|
||||
var subscriptionID string
|
||||
err = tx.QueryRow(s.queries.selectSubscriptionIDByEndpoint, endpoint).Scan(&subscriptionID)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
if subscriptionCount >= subscriptionEndpointLimitPerSubscriberIP {
|
||||
return ErrWebPushTooManySubscriptions
|
||||
}
|
||||
subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
// Insert or update subscription
|
||||
updatedAt, warnedAt := time.Now().Unix(), 0
|
||||
if _, err = tx.Exec(s.queries.insertSubscription, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil {
|
||||
return err
|
||||
}
|
||||
// Replace all subscription topics
|
||||
if _, err := tx.Exec(s.queries.deleteSubscriptionTopicAll, subscriptionID); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, topic := range topics {
|
||||
if _, err = tx.Exec(s.queries.insertSubscriptionTopic, subscriptionID, topic); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// SubscriptionsForTopic returns all subscriptions for the given topic.
|
||||
func (s *commonStore) SubscriptionsForTopic(topic string) ([]*Subscription, error) {
|
||||
rows, err := s.db.Query(s.queries.selectSubscriptionsForTopic, topic)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return subscriptionsFromRows(rows)
|
||||
}
|
||||
|
||||
// SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period.
|
||||
func (s *commonStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscription, error) {
|
||||
rows, err := s.db.Query(s.queries.selectSubscriptionsExpiringSoon, time.Now().Add(-warnAfter).Unix())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return subscriptionsFromRows(rows)
|
||||
}
|
||||
|
||||
// MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon.
|
||||
func (s *commonStore) MarkExpiryWarningSent(subscriptions []*Subscription) error {
|
||||
tx, err := s.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
for _, subscription := range subscriptions {
|
||||
if _, err := tx.Exec(s.queries.updateSubscriptionWarningSent, time.Now().Unix(), subscription.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint.
|
||||
func (s *commonStore) RemoveSubscriptionsByEndpoint(endpoint string) error {
|
||||
_, err := s.db.Exec(s.queries.deleteSubscriptionByEndpoint, endpoint)
|
||||
return err
|
||||
}
|
||||
|
||||
// RemoveSubscriptionsByUserID removes all subscriptions for the given user ID.
|
||||
func (s *commonStore) RemoveSubscriptionsByUserID(userID string) error {
|
||||
if userID == "" {
|
||||
return ErrWebPushUserIDCannotBeEmpty
|
||||
}
|
||||
_, err := s.db.Exec(s.queries.deleteSubscriptionByUserID, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// RemoveExpiredSubscriptions removes all subscriptions that have not been updated for a given time period.
|
||||
func (s *commonStore) RemoveExpiredSubscriptions(expireAfter time.Duration) error {
|
||||
_, err := s.db.Exec(s.queries.deleteSubscriptionByAge, time.Now().Add(-expireAfter).Unix())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = s.db.Exec(s.queries.deleteSubscriptionTopicWithoutSubscription)
|
||||
return err
|
||||
}
|
||||
|
||||
// SetSubscriptionUpdatedAt updates the updated_at timestamp for a subscription by endpoint. This is
|
||||
// exported for testing purposes and is not part of the Store interface.
|
||||
func (s *commonStore) SetSubscriptionUpdatedAt(endpoint string, updatedAt int64) error {
|
||||
_, err := s.db.Exec(s.queries.updateSubscriptionUpdatedAt, updatedAt, endpoint)
|
||||
return err
|
||||
}
|
||||
|
||||
// Close closes the underlying database connection.
|
||||
func (s *commonStore) Close() error {
|
||||
return s.db.Close()
|
||||
}
|
||||
|
||||
func subscriptionsFromRows(rows *sql.Rows) ([]*Subscription, error) {
|
||||
subscriptions := make([]*Subscription, 0)
|
||||
for rows.Next() {
|
||||
var id, endpoint, auth, p256dh, userID string
|
||||
if err := rows.Scan(&id, &endpoint, &auth, &p256dh, &userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
subscriptions = append(subscriptions, &Subscription{
|
||||
ID: id,
|
||||
Endpoint: endpoint,
|
||||
Auth: auth,
|
||||
P256dh: p256dh,
|
||||
UserID: userID,
|
||||
})
|
||||
}
|
||||
return subscriptions, nil
|
||||
}
|
||||
130
webpush/store_postgres.go
Normal file
130
webpush/store_postgres.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package webpush
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
|
||||
)
|
||||
|
||||
const (
|
||||
pgCreateTablesQuery = `
|
||||
CREATE TABLE IF NOT EXISTS webpush_subscription (
|
||||
id TEXT PRIMARY KEY,
|
||||
endpoint TEXT NOT NULL UNIQUE,
|
||||
key_auth TEXT NOT NULL,
|
||||
key_p256dh TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
subscriber_ip TEXT NOT NULL,
|
||||
updated_at BIGINT NOT NULL,
|
||||
warned_at BIGINT NOT NULL DEFAULT 0
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_webpush_subscriber_ip ON webpush_subscription (subscriber_ip);
|
||||
CREATE TABLE IF NOT EXISTS webpush_subscription_topic (
|
||||
subscription_id TEXT NOT NULL REFERENCES webpush_subscription (id) ON DELETE CASCADE,
|
||||
topic TEXT NOT NULL,
|
||||
PRIMARY KEY (subscription_id, topic)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_webpush_topic ON webpush_subscription_topic (topic);
|
||||
CREATE TABLE IF NOT EXISTS schema_version (
|
||||
store TEXT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
`
|
||||
|
||||
pgSelectSubscriptionIDByEndpoint = `SELECT id FROM webpush_subscription WHERE endpoint = $1`
|
||||
pgSelectSubscriptionCountBySubscriberIP = `SELECT COUNT(*) FROM webpush_subscription WHERE subscriber_ip = $1`
|
||||
pgSelectSubscriptionsForTopicQuery = `
|
||||
SELECT s.id, s.endpoint, s.key_auth, s.key_p256dh, s.user_id
|
||||
FROM webpush_subscription_topic st
|
||||
JOIN webpush_subscription s ON s.id = st.subscription_id
|
||||
WHERE st.topic = $1
|
||||
ORDER BY s.endpoint
|
||||
`
|
||||
pgSelectSubscriptionsExpiringSoonQuery = `
|
||||
SELECT id, endpoint, key_auth, key_p256dh, user_id
|
||||
FROM webpush_subscription
|
||||
WHERE warned_at = 0 AND updated_at <= $1
|
||||
`
|
||||
pgInsertSubscriptionQuery = `
|
||||
INSERT INTO webpush_subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
ON CONFLICT (endpoint)
|
||||
DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at
|
||||
`
|
||||
pgUpdateSubscriptionWarningSentQuery = `UPDATE webpush_subscription SET warned_at = $1 WHERE id = $2`
|
||||
pgUpdateSubscriptionUpdatedAtQuery = `UPDATE webpush_subscription SET updated_at = $1 WHERE endpoint = $2`
|
||||
pgDeleteSubscriptionByEndpointQuery = `DELETE FROM webpush_subscription WHERE endpoint = $1`
|
||||
pgDeleteSubscriptionByUserIDQuery = `DELETE FROM webpush_subscription WHERE user_id = $1`
|
||||
pgDeleteSubscriptionByAgeQuery = `DELETE FROM webpush_subscription WHERE updated_at <= $1`
|
||||
|
||||
pgInsertSubscriptionTopicQuery = `INSERT INTO webpush_subscription_topic (subscription_id, topic) VALUES ($1, $2)`
|
||||
pgDeleteSubscriptionTopicAllQuery = `DELETE FROM webpush_subscription_topic WHERE subscription_id = $1`
|
||||
pgDeleteSubscriptionTopicWithoutSubscription = `DELETE FROM webpush_subscription_topic WHERE subscription_id NOT IN (SELECT id FROM webpush_subscription)`
|
||||
)
|
||||
|
||||
// PostgreSQL schema management queries
|
||||
const (
|
||||
pgCurrentSchemaVersion = 1
|
||||
pgInsertSchemaVersion = `INSERT INTO schema_version (store, version) VALUES ('webpush', $1)`
|
||||
pgSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE store = 'webpush'`
|
||||
)
|
||||
|
||||
// NewPostgresStore creates a new PostgreSQL-backed web push store.
|
||||
func NewPostgresStore(dsn string) (Store, error) {
|
||||
db, err := sql.Open("pgx", dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := db.Ping(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := setupPostgresDB(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &commonStore{
|
||||
db: db,
|
||||
queries: storeQueries{
|
||||
selectSubscriptionIDByEndpoint: pgSelectSubscriptionIDByEndpoint,
|
||||
selectSubscriptionCountBySubscriberIP: pgSelectSubscriptionCountBySubscriberIP,
|
||||
selectSubscriptionsForTopic: pgSelectSubscriptionsForTopicQuery,
|
||||
selectSubscriptionsExpiringSoon: pgSelectSubscriptionsExpiringSoonQuery,
|
||||
insertSubscription: pgInsertSubscriptionQuery,
|
||||
updateSubscriptionWarningSent: pgUpdateSubscriptionWarningSentQuery,
|
||||
updateSubscriptionUpdatedAt: pgUpdateSubscriptionUpdatedAtQuery,
|
||||
deleteSubscriptionByEndpoint: pgDeleteSubscriptionByEndpointQuery,
|
||||
deleteSubscriptionByUserID: pgDeleteSubscriptionByUserIDQuery,
|
||||
deleteSubscriptionByAge: pgDeleteSubscriptionByAgeQuery,
|
||||
insertSubscriptionTopic: pgInsertSubscriptionTopicQuery,
|
||||
deleteSubscriptionTopicAll: pgDeleteSubscriptionTopicAllQuery,
|
||||
deleteSubscriptionTopicWithoutSubscription: pgDeleteSubscriptionTopicWithoutSubscription,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func setupPostgresDB(db *sql.DB) error {
|
||||
var schemaVersion int
|
||||
err := db.QueryRow(pgSelectSchemaVersionQuery).Scan(&schemaVersion)
|
||||
if err != nil {
|
||||
return setupNewPostgresDB(db)
|
||||
}
|
||||
if schemaVersion > pgCurrentSchemaVersion {
|
||||
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, pgCurrentSchemaVersion)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupNewPostgresDB(db *sql.DB) error {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(pgCreateTablesQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(pgInsertSchemaVersion, pgCurrentSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
91
webpush/store_postgres_test.go
Normal file
91
webpush/store_postgres_test.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package webpush_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
"heckel.io/ntfy/v2/webpush"
|
||||
)
|
||||
|
||||
func newTestPostgresStore(t *testing.T) webpush.Store {
|
||||
dsn := os.Getenv("NTFY_TEST_DATABASE_URL")
|
||||
if dsn == "" {
|
||||
t.Skip("NTFY_TEST_DATABASE_URL not set, skipping PostgreSQL tests")
|
||||
}
|
||||
// Create a unique schema for this test
|
||||
schema := fmt.Sprintf("test_%s", util.RandomString(10))
|
||||
setupDB, err := sql.Open("pgx", dsn)
|
||||
require.Nil(t, err)
|
||||
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, setupDB.Close())
|
||||
// Open store with search_path set to the new schema
|
||||
u, err := url.Parse(dsn)
|
||||
require.Nil(t, err)
|
||||
q := u.Query()
|
||||
q.Set("search_path", schema)
|
||||
u.RawQuery = q.Encode()
|
||||
store, err := webpush.NewPostgresStore(u.String())
|
||||
require.Nil(t, err)
|
||||
t.Cleanup(func() {
|
||||
store.Close()
|
||||
cleanDB, err := sql.Open("pgx", dsn)
|
||||
if err == nil {
|
||||
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
|
||||
cleanDB.Close()
|
||||
}
|
||||
})
|
||||
return store
|
||||
}
|
||||
|
||||
func TestPostgresStoreUpsertSubscriptionSubscriptionsForTopic(t *testing.T) {
|
||||
testStoreUpsertSubscriptionSubscriptionsForTopic(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreUpsertSubscriptionSubscriberIPLimitReached(t *testing.T) {
|
||||
testStoreUpsertSubscriptionSubscriberIPLimitReached(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreUpsertSubscriptionUpdateTopics(t *testing.T) {
|
||||
testStoreUpsertSubscriptionUpdateTopics(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreUpsertSubscriptionUpdateFields(t *testing.T) {
|
||||
testStoreUpsertSubscriptionUpdateFields(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreRemoveByUserIDMultiple(t *testing.T) {
|
||||
testStoreRemoveByUserIDMultiple(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreRemoveByEndpoint(t *testing.T) {
|
||||
testStoreRemoveByEndpoint(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreRemoveByUserID(t *testing.T) {
|
||||
testStoreRemoveByUserID(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreRemoveByUserIDEmpty(t *testing.T) {
|
||||
testStoreRemoveByUserIDEmpty(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreExpiryWarningSent(t *testing.T) {
|
||||
store := newTestPostgresStore(t)
|
||||
testStoreExpiryWarningSent(t, store, store.SetSubscriptionUpdatedAt)
|
||||
}
|
||||
|
||||
func TestPostgresStoreExpiring(t *testing.T) {
|
||||
store := newTestPostgresStore(t)
|
||||
testStoreExpiring(t, store, store.SetSubscriptionUpdatedAt)
|
||||
}
|
||||
|
||||
func TestPostgresStoreRemoveExpired(t *testing.T) {
|
||||
store := newTestPostgresStore(t)
|
||||
testStoreRemoveExpired(t, store, store.SetSubscriptionUpdatedAt)
|
||||
}
|
||||
142
webpush/store_sqlite.go
Normal file
142
webpush/store_sqlite.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package webpush
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
)
|
||||
|
||||
const (
|
||||
sqliteCreateWebPushSubscriptionsTableQuery = `
|
||||
BEGIN;
|
||||
CREATE TABLE IF NOT EXISTS subscription (
|
||||
id TEXT PRIMARY KEY,
|
||||
endpoint TEXT NOT NULL,
|
||||
key_auth TEXT NOT NULL,
|
||||
key_p256dh TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
subscriber_ip TEXT NOT NULL,
|
||||
updated_at INT NOT NULL,
|
||||
warned_at INT NOT NULL DEFAULT 0
|
||||
);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_endpoint ON subscription (endpoint);
|
||||
CREATE INDEX IF NOT EXISTS idx_subscriber_ip ON subscription (subscriber_ip);
|
||||
CREATE TABLE IF NOT EXISTS subscription_topic (
|
||||
subscription_id TEXT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
PRIMARY KEY (subscription_id, topic),
|
||||
FOREIGN KEY (subscription_id) REFERENCES subscription (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON subscription_topic (topic);
|
||||
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||
id INT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
COMMIT;
|
||||
`
|
||||
sqliteBuiltinStartupQueries = `
|
||||
PRAGMA foreign_keys = ON;
|
||||
`
|
||||
|
||||
sqliteSelectWebPushSubscriptionIDByEndpoint = `SELECT id FROM subscription WHERE endpoint = ?`
|
||||
sqliteSelectWebPushSubscriptionCountBySubscriberIP = `SELECT COUNT(*) FROM subscription WHERE subscriber_ip = ?`
|
||||
sqliteSelectWebPushSubscriptionsForTopicQuery = `
|
||||
SELECT id, endpoint, key_auth, key_p256dh, user_id
|
||||
FROM subscription_topic st
|
||||
JOIN subscription s ON s.id = st.subscription_id
|
||||
WHERE st.topic = ?
|
||||
ORDER BY endpoint
|
||||
`
|
||||
sqliteSelectWebPushSubscriptionsExpiringSoonQuery = `
|
||||
SELECT id, endpoint, key_auth, key_p256dh, user_id
|
||||
FROM subscription
|
||||
WHERE warned_at = 0 AND updated_at <= ?
|
||||
`
|
||||
sqliteInsertWebPushSubscriptionQuery = `
|
||||
INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT (endpoint)
|
||||
DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at
|
||||
`
|
||||
sqliteUpdateWebPushSubscriptionWarningSentQuery = `UPDATE subscription SET warned_at = ? WHERE id = ?`
|
||||
sqliteUpdateWebPushSubscriptionUpdatedAtQuery = `UPDATE subscription SET updated_at = ? WHERE endpoint = ?`
|
||||
sqliteDeleteWebPushSubscriptionByEndpointQuery = `DELETE FROM subscription WHERE endpoint = ?`
|
||||
sqliteDeleteWebPushSubscriptionByUserIDQuery = `DELETE FROM subscription WHERE user_id = ?`
|
||||
sqliteDeleteWebPushSubscriptionByAgeQuery = `DELETE FROM subscription WHERE updated_at <= ?` // Full table scan!
|
||||
|
||||
sqliteInsertWebPushSubscriptionTopicQuery = `INSERT INTO subscription_topic (subscription_id, topic) VALUES (?, ?)`
|
||||
sqliteDeleteWebPushSubscriptionTopicAllQuery = `DELETE FROM subscription_topic WHERE subscription_id = ?`
|
||||
sqliteDeleteWebPushSubscriptionTopicWithoutSubscription = `DELETE FROM subscription_topic WHERE subscription_id NOT IN (SELECT id FROM subscription)`
|
||||
)
|
||||
|
||||
// SQLite schema management queries
|
||||
const (
|
||||
sqliteCurrentWebPushSchemaVersion = 1
|
||||
sqliteInsertWebPushSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
|
||||
sqliteSelectWebPushSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
|
||||
)
|
||||
|
||||
// NewSQLiteStore creates a new SQLite-backed web push store.
|
||||
func NewSQLiteStore(filename, startupQueries string) (Store, error) {
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := setupSQLite(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := runSQLiteStartupQueries(db, startupQueries); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &commonStore{
|
||||
db: db,
|
||||
queries: storeQueries{
|
||||
selectSubscriptionIDByEndpoint: sqliteSelectWebPushSubscriptionIDByEndpoint,
|
||||
selectSubscriptionCountBySubscriberIP: sqliteSelectWebPushSubscriptionCountBySubscriberIP,
|
||||
selectSubscriptionsForTopic: sqliteSelectWebPushSubscriptionsForTopicQuery,
|
||||
selectSubscriptionsExpiringSoon: sqliteSelectWebPushSubscriptionsExpiringSoonQuery,
|
||||
insertSubscription: sqliteInsertWebPushSubscriptionQuery,
|
||||
updateSubscriptionWarningSent: sqliteUpdateWebPushSubscriptionWarningSentQuery,
|
||||
updateSubscriptionUpdatedAt: sqliteUpdateWebPushSubscriptionUpdatedAtQuery,
|
||||
deleteSubscriptionByEndpoint: sqliteDeleteWebPushSubscriptionByEndpointQuery,
|
||||
deleteSubscriptionByUserID: sqliteDeleteWebPushSubscriptionByUserIDQuery,
|
||||
deleteSubscriptionByAge: sqliteDeleteWebPushSubscriptionByAgeQuery,
|
||||
insertSubscriptionTopic: sqliteInsertWebPushSubscriptionTopicQuery,
|
||||
deleteSubscriptionTopicAll: sqliteDeleteWebPushSubscriptionTopicAllQuery,
|
||||
deleteSubscriptionTopicWithoutSubscription: sqliteDeleteWebPushSubscriptionTopicWithoutSubscription,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func setupSQLite(db *sql.DB) error {
|
||||
var schemaVersion int
|
||||
err := db.QueryRow(sqliteSelectWebPushSchemaVersionQuery).Scan(&schemaVersion)
|
||||
if err != nil {
|
||||
return setupNewSQLite(db)
|
||||
}
|
||||
if schemaVersion > sqliteCurrentWebPushSchemaVersion {
|
||||
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, sqliteCurrentWebPushSchemaVersion)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupNewSQLite(db *sql.DB) error {
|
||||
if _, err := db.Exec(sqliteCreateWebPushSubscriptionsTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(sqliteInsertWebPushSchemaVersion, sqliteCurrentWebPushSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func runSQLiteStartupQueries(db *sql.DB, startupQueries string) error {
|
||||
if _, err := db.Exec(startupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(sqliteBuiltinStartupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
63
webpush/store_sqlite_test.go
Normal file
63
webpush/store_sqlite_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package webpush_test
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/webpush"
|
||||
)
|
||||
|
||||
func newTestSQLiteStore(t *testing.T) webpush.Store {
|
||||
store, err := webpush.NewSQLiteStore(filepath.Join(t.TempDir(), "webpush.db"), "")
|
||||
require.Nil(t, err)
|
||||
t.Cleanup(func() { store.Close() })
|
||||
return store
|
||||
}
|
||||
|
||||
func TestSQLiteStoreUpsertSubscriptionSubscriptionsForTopic(t *testing.T) {
|
||||
testStoreUpsertSubscriptionSubscriptionsForTopic(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreUpsertSubscriptionSubscriberIPLimitReached(t *testing.T) {
|
||||
testStoreUpsertSubscriptionSubscriberIPLimitReached(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreUpsertSubscriptionUpdateTopics(t *testing.T) {
|
||||
testStoreUpsertSubscriptionUpdateTopics(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreUpsertSubscriptionUpdateFields(t *testing.T) {
|
||||
testStoreUpsertSubscriptionUpdateFields(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreRemoveByUserIDMultiple(t *testing.T) {
|
||||
testStoreRemoveByUserIDMultiple(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreRemoveByEndpoint(t *testing.T) {
|
||||
testStoreRemoveByEndpoint(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreRemoveByUserID(t *testing.T) {
|
||||
testStoreRemoveByUserID(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreRemoveByUserIDEmpty(t *testing.T) {
|
||||
testStoreRemoveByUserIDEmpty(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreExpiryWarningSent(t *testing.T) {
|
||||
store := newTestSQLiteStore(t)
|
||||
testStoreExpiryWarningSent(t, store, store.SetSubscriptionUpdatedAt)
|
||||
}
|
||||
|
||||
func TestSQLiteStoreExpiring(t *testing.T) {
|
||||
store := newTestSQLiteStore(t)
|
||||
testStoreExpiring(t, store, store.SetSubscriptionUpdatedAt)
|
||||
}
|
||||
|
||||
func TestSQLiteStoreRemoveExpired(t *testing.T) {
|
||||
store := newTestSQLiteStore(t)
|
||||
testStoreRemoveExpired(t, store, store.SetSubscriptionUpdatedAt)
|
||||
}
|
||||
213
webpush/store_test.go
Normal file
213
webpush/store_test.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package webpush_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/webpush"
|
||||
)
|
||||
|
||||
const testWebPushEndpoint = "https://updates.push.services.mozilla.com/wpush/v1/AAABBCCCDDEEEFFF"
|
||||
|
||||
func testStoreUpsertSubscriptionSubscriptionsForTopic(t *testing.T, store webpush.Store) {
|
||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
|
||||
|
||||
subs, err := store.SubscriptionsForTopic("test-topic")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
require.Equal(t, subs[0].Endpoint, testWebPushEndpoint)
|
||||
require.Equal(t, subs[0].P256dh, "p256dh-key")
|
||||
require.Equal(t, subs[0].Auth, "auth-key")
|
||||
require.Equal(t, subs[0].UserID, "u_1234")
|
||||
|
||||
subs2, err := store.SubscriptionsForTopic("mytopic")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs2, 1)
|
||||
require.Equal(t, subs[0].Endpoint, subs2[0].Endpoint)
|
||||
}
|
||||
|
||||
func testStoreUpsertSubscriptionSubscriberIPLimitReached(t *testing.T, store webpush.Store) {
|
||||
// Insert 10 subscriptions with the same IP address
|
||||
for i := 0; i < 10; i++ {
|
||||
endpoint := fmt.Sprintf(testWebPushEndpoint+"%d", i)
|
||||
require.Nil(t, store.UpsertSubscription(endpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
|
||||
}
|
||||
|
||||
// Another one for the same endpoint should be fine
|
||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
|
||||
|
||||
// But with a different endpoint it should fail
|
||||
require.Equal(t, webpush.ErrWebPushTooManySubscriptions, store.UpsertSubscription(testWebPushEndpoint+"11", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
|
||||
|
||||
// But with a different IP address it should be fine again
|
||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"99", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("9.9.9.9"), []string{"test-topic", "mytopic"}))
|
||||
}
|
||||
|
||||
func testStoreUpsertSubscriptionUpdateTopics(t *testing.T, store webpush.Store) {
|
||||
// Insert subscription with two topics, and another with one topic
|
||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"1", "auth-key", "p256dh-key", "", netip.MustParseAddr("9.9.9.9"), []string{"topic1"}))
|
||||
|
||||
subs, err := store.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 2)
|
||||
require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint)
|
||||
require.Equal(t, testWebPushEndpoint+"1", subs[1].Endpoint)
|
||||
|
||||
subs, err = store.SubscriptionsForTopic("topic2")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint)
|
||||
|
||||
// Update the first subscription to have only one topic
|
||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
|
||||
|
||||
subs, err = store.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 2)
|
||||
require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint)
|
||||
|
||||
subs, err = store.SubscriptionsForTopic("topic2")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 0)
|
||||
}
|
||||
|
||||
func testStoreUpsertSubscriptionUpdateFields(t *testing.T, store webpush.Store) {
|
||||
// Insert a subscription
|
||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
|
||||
|
||||
subs, err := store.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
require.Equal(t, "auth-key", subs[0].Auth)
|
||||
require.Equal(t, "p256dh-key", subs[0].P256dh)
|
||||
require.Equal(t, "u_1234", subs[0].UserID)
|
||||
|
||||
// Re-upsert the same endpoint with different auth, p256dh, and userID
|
||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "new-auth", "new-p256dh", "u_5678", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
|
||||
|
||||
subs, err = store.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
require.Equal(t, testWebPushEndpoint, subs[0].Endpoint)
|
||||
require.Equal(t, "new-auth", subs[0].Auth)
|
||||
require.Equal(t, "new-p256dh", subs[0].P256dh)
|
||||
require.Equal(t, "u_5678", subs[0].UserID)
|
||||
}
|
||||
|
||||
func testStoreRemoveByUserIDMultiple(t *testing.T, store webpush.Store) {
|
||||
// Insert two subscriptions for u_1234 and one for u_5678
|
||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
|
||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"1", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
|
||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"2", "auth-key", "p256dh-key", "u_5678", netip.MustParseAddr("9.9.9.9"), []string{"topic1"}))
|
||||
|
||||
subs, err := store.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 3)
|
||||
|
||||
// Remove all subscriptions for u_1234
|
||||
require.Nil(t, store.RemoveSubscriptionsByUserID("u_1234"))
|
||||
|
||||
// Only u_5678's subscription should remain
|
||||
subs, err = store.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
require.Equal(t, testWebPushEndpoint+"2", subs[0].Endpoint)
|
||||
require.Equal(t, "u_5678", subs[0].UserID)
|
||||
}
|
||||
|
||||
func testStoreRemoveByEndpoint(t *testing.T, store webpush.Store) {
|
||||
// Insert subscription with two topics
|
||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||
subs, err := store.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
|
||||
// And remove it again
|
||||
require.Nil(t, store.RemoveSubscriptionsByEndpoint(testWebPushEndpoint))
|
||||
subs, err = store.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 0)
|
||||
}
|
||||
|
||||
func testStoreRemoveByUserID(t *testing.T, store webpush.Store) {
|
||||
// Insert subscription with two topics
|
||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||
subs, err := store.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
|
||||
// And remove it again
|
||||
require.Nil(t, store.RemoveSubscriptionsByUserID("u_1234"))
|
||||
subs, err = store.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 0)
|
||||
}
|
||||
|
||||
func testStoreRemoveByUserIDEmpty(t *testing.T, store webpush.Store) {
|
||||
require.Equal(t, webpush.ErrWebPushUserIDCannotBeEmpty, store.RemoveSubscriptionsByUserID(""))
|
||||
}
|
||||
|
||||
func testStoreExpiryWarningSent(t *testing.T, store webpush.Store, setUpdatedAt func(endpoint string, updatedAt int64) error) {
|
||||
// Insert subscription with two topics
|
||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||
|
||||
// Set updated_at to the past so it shows up as expiring
|
||||
require.Nil(t, setUpdatedAt(testWebPushEndpoint, time.Now().Add(-8*24*time.Hour).Unix()))
|
||||
|
||||
// Verify subscription appears in expiring list (warned_at == 0)
|
||||
subs, err := store.SubscriptionsExpiring(7 * 24 * time.Hour)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
require.Equal(t, testWebPushEndpoint, subs[0].Endpoint)
|
||||
|
||||
// Mark them as warning sent
|
||||
require.Nil(t, store.MarkExpiryWarningSent(subs))
|
||||
|
||||
// Verify subscription no longer appears in expiring list (warned_at > 0)
|
||||
subs, err = store.SubscriptionsExpiring(7 * 24 * time.Hour)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 0)
|
||||
}
|
||||
|
||||
func testStoreExpiring(t *testing.T, store webpush.Store, setUpdatedAt func(endpoint string, updatedAt int64) error) {
|
||||
// Insert subscription with two topics
|
||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||
subs, err := store.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
|
||||
// Fake-mark them as soon-to-expire
|
||||
require.Nil(t, setUpdatedAt(testWebPushEndpoint, time.Now().Add(-8*24*time.Hour).Unix()))
|
||||
|
||||
// Should not be cleaned up yet
|
||||
require.Nil(t, store.RemoveExpiredSubscriptions(9*24*time.Hour))
|
||||
|
||||
// Run expiration
|
||||
subs, err = store.SubscriptionsExpiring(7 * 24 * time.Hour)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
require.Equal(t, testWebPushEndpoint, subs[0].Endpoint)
|
||||
}
|
||||
|
||||
func testStoreRemoveExpired(t *testing.T, store webpush.Store, setUpdatedAt func(endpoint string, updatedAt int64) error) {
|
||||
// Insert subscription with two topics
|
||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||
subs, err := store.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
|
||||
// Fake-mark them as expired
|
||||
require.Nil(t, setUpdatedAt(testWebPushEndpoint, time.Now().Add(-10*24*time.Hour).Unix()))
|
||||
|
||||
// Run expiration
|
||||
require.Nil(t, store.RemoveExpiredSubscriptions(9*24*time.Hour))
|
||||
|
||||
// List again, should be 0
|
||||
subs, err = store.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 0)
|
||||
}
|
||||
21
webpush/types.go
Normal file
21
webpush/types.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package webpush
|
||||
|
||||
import "heckel.io/ntfy/v2/log"
|
||||
|
||||
// Subscription represents a web push subscription.
|
||||
type Subscription struct {
|
||||
ID string
|
||||
Endpoint string
|
||||
Auth string
|
||||
P256dh string
|
||||
UserID string
|
||||
}
|
||||
|
||||
// Context returns the logging context for the subscription.
|
||||
func (w *Subscription) Context() log.Context {
|
||||
return map[string]any{
|
||||
"web_push_subscription_id": w.ID,
|
||||
"web_push_subscription_user_id": w.UserID,
|
||||
"web_push_subscription_endpoint": w.Endpoint,
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user