Compare commits
52 Commits
postgres-w
...
postgres-s
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
11c79a6369 | ||
|
|
10a6939d8e | ||
|
|
9736973286 | ||
|
|
039566bcaf | ||
|
|
544ce112b5 | ||
|
|
c19377109e | ||
|
|
ccbd02331c | ||
|
|
542aa403d2 | ||
|
|
ebb48e217d | ||
|
|
7710ace184 | ||
|
|
5c26e70fe7 | ||
|
|
c66fa92341 | ||
|
|
a7d5a9c5d8 | ||
|
|
391cd2c920 | ||
|
|
9eec72adcc | ||
|
|
28a436c0d2 | ||
|
|
b02366b42b | ||
|
|
90d0eca14d | ||
|
|
811c7ae25a | ||
|
|
850a9d4cc4 | ||
|
|
43280fbc0a | ||
|
|
35a54407a8 | ||
|
|
f726cc768e | ||
|
|
5d301e7dce | ||
|
|
8b12bdeb3a | ||
|
|
6375c2ce60 | ||
|
|
459c80ef9b | ||
|
|
b1eb90addc | ||
|
|
4b6979aa89 | ||
|
|
c76e39bb0e | ||
|
|
b82e1c3915 | ||
|
|
07e60ba041 | ||
|
|
4e22e7f4c8 | ||
|
|
cf3ae187ce | ||
|
|
d19b825192 | ||
|
|
2e499389fc | ||
|
|
7c69a76345 | ||
|
|
a28d8e7924 | ||
|
|
13a3062a7f | ||
|
|
eb6e1ac44a | ||
|
|
a4c836b531 | ||
|
|
e818b063f7 | ||
|
|
039d555689 | ||
|
|
209d5a4c62 | ||
|
|
bf265449ac | ||
|
|
4cbd80c68e | ||
|
|
305e3fc9af | ||
|
|
93cd7f99f8 | ||
|
|
28e85df36e | ||
|
|
2bc7b5217b | ||
|
|
046c0e8c79 | ||
|
|
652b2097ad |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -7,6 +7,7 @@ build/
|
|||||||
server/docs/
|
server/docs/
|
||||||
server/site/
|
server/site/
|
||||||
tools/fbsend/fbsend
|
tools/fbsend/fbsend
|
||||||
|
tools/pgimport/pgimport
|
||||||
playground/
|
playground/
|
||||||
secrets/
|
secrets/
|
||||||
*.iml
|
*.iml
|
||||||
|
|||||||
4
Makefile
4
Makefile
@@ -268,10 +268,10 @@ check: test web-fmt-check fmt-check vet web-lint lint staticcheck
|
|||||||
checkv: testv web-fmt-check fmt-check vet web-lint lint staticcheck
|
checkv: testv web-fmt-check fmt-check vet web-lint lint staticcheck
|
||||||
|
|
||||||
test: .PHONY
|
test: .PHONY
|
||||||
go test $(shell go list ./... | grep -vE 'ntfy/(test|examples|tools)')
|
go test -parallel 3 $(shell go list ./... | grep -vE 'ntfy/(test|examples|tools)')
|
||||||
|
|
||||||
testv: .PHONY
|
testv: .PHONY
|
||||||
go test -v $(shell go list ./... | grep -vE 'ntfy/(test|examples|tools)')
|
go test -v -parallel 3 $(shell go list ./... | grep -vE 'ntfy/(test|examples|tools)')
|
||||||
|
|
||||||
race: .PHONY
|
race: .PHONY
|
||||||
go test -v -race $(shell go list ./... | grep -vE 'ntfy/(test|examples|tools)')
|
go test -v -race $(shell go list ./... | grep -vE 'ntfy/(test|examples|tools)')
|
||||||
|
|||||||
29
cmd/serve.go
29
cmd/serve.go
@@ -282,7 +282,9 @@ func execServe(c *cli.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check values
|
// Check values
|
||||||
if firebaseKeyFile != "" && !util.FileExists(firebaseKeyFile) {
|
if databaseURL != "" && (authFile != "" || cacheFile != "" || webPushFile != "") {
|
||||||
|
return errors.New("if database-url is set, auth-file, cache-file, and web-push-file must not be set")
|
||||||
|
} else if firebaseKeyFile != "" && !util.FileExists(firebaseKeyFile) {
|
||||||
return errors.New("if set, FCM key file must exist")
|
return errors.New("if set, FCM key file must exist")
|
||||||
} else if firebaseKeyFile != "" && !server.FirebaseAvailable {
|
} else if firebaseKeyFile != "" && !server.FirebaseAvailable {
|
||||||
return errors.New("cannot set firebase-key-file, support for Firebase is not available (nofirebase)")
|
return errors.New("cannot set firebase-key-file, support for Firebase is not available (nofirebase)")
|
||||||
@@ -323,8 +325,8 @@ func execServe(c *cli.Context) error {
|
|||||||
return errors.New("if upstream-base-url is set, base-url must also be set")
|
return errors.New("if upstream-base-url is set, base-url must also be set")
|
||||||
} else if upstreamBaseURL != "" && baseURL != "" && baseURL == upstreamBaseURL {
|
} else if upstreamBaseURL != "" && baseURL != "" && baseURL == upstreamBaseURL {
|
||||||
return errors.New("base-url and upstream-base-url cannot be identical, you'll likely want to set upstream-base-url to https://ntfy.sh, see https://ntfy.sh/docs/config/#ios-instant-notifications")
|
return errors.New("base-url and upstream-base-url cannot be identical, you'll likely want to set upstream-base-url to https://ntfy.sh, see https://ntfy.sh/docs/config/#ios-instant-notifications")
|
||||||
} else if authFile == "" && (enableSignup || enableLogin || requireLogin || enableReservations || stripeSecretKey != "") {
|
} else if authFile == "" && databaseURL == "" && (enableSignup || enableLogin || requireLogin || enableReservations || stripeSecretKey != "") {
|
||||||
return errors.New("cannot set enable-signup, enable-login, require-login, enable-reserve-topics, or stripe-secret-key if auth-file is not set")
|
return errors.New("cannot set enable-signup, enable-login, require-login, enable-reserve-topics, or stripe-secret-key if auth-file or database-url is not set")
|
||||||
} else if enableSignup && !enableLogin {
|
} else if enableSignup && !enableLogin {
|
||||||
return errors.New("cannot set enable-signup without also setting enable-login")
|
return errors.New("cannot set enable-signup without also setting enable-login")
|
||||||
} else if requireLogin && !enableLogin {
|
} else if requireLogin && !enableLogin {
|
||||||
@@ -333,8 +335,8 @@ func execServe(c *cli.Context) error {
|
|||||||
return errors.New("cannot set stripe-secret-key or stripe-webhook-key, support for payments is not available in this build (nopayments)")
|
return errors.New("cannot set stripe-secret-key or stripe-webhook-key, support for payments is not available in this build (nopayments)")
|
||||||
} else if stripeSecretKey != "" && (stripeWebhookKey == "" || baseURL == "") {
|
} else if stripeSecretKey != "" && (stripeWebhookKey == "" || baseURL == "") {
|
||||||
return errors.New("if stripe-secret-key is set, stripe-webhook-key and base-url must also be set")
|
return errors.New("if stripe-secret-key is set, stripe-webhook-key and base-url must also be set")
|
||||||
} else if twilioAccount != "" && (twilioAuthToken == "" || twilioPhoneNumber == "" || twilioVerifyService == "" || baseURL == "" || authFile == "") {
|
} else if twilioAccount != "" && (twilioAuthToken == "" || twilioPhoneNumber == "" || twilioVerifyService == "" || baseURL == "" || (authFile == "" && databaseURL == "")) {
|
||||||
return errors.New("if twilio-account is set, twilio-auth-token, twilio-phone-number, twilio-verify-service, base-url, and auth-file must also be set")
|
return errors.New("if twilio-account is set, twilio-auth-token, twilio-phone-number, twilio-verify-service, base-url, and auth-file (or database-url) must also be set")
|
||||||
} else if messageSizeLimit > server.DefaultMessageSizeLimit {
|
} else if messageSizeLimit > server.DefaultMessageSizeLimit {
|
||||||
log.Warn("message-size-limit is greater than 4K, this is not recommended and largely untested, and may lead to issues with some clients")
|
log.Warn("message-size-limit is greater than 4K, this is not recommended and largely untested, and may lead to issues with some clients")
|
||||||
if messageSizeLimit > 5*1024*1024 {
|
if messageSizeLimit > 5*1024*1024 {
|
||||||
@@ -414,6 +416,15 @@ func execServe(c *cli.Context) error {
|
|||||||
payments.Setup(stripeSecretKey)
|
payments.Setup(stripeSecretKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Parse Twilio template
|
||||||
|
var twilioCallFormatTemplate *template.Template
|
||||||
|
if twilioCallFormat != "" {
|
||||||
|
twilioCallFormatTemplate, err = template.New("").Parse(twilioCallFormat)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse twilio-call-format template: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Add default forbidden topics
|
// Add default forbidden topics
|
||||||
disallowedTopics = append(disallowedTopics, server.DefaultDisallowedTopics...)
|
disallowedTopics = append(disallowedTopics, server.DefaultDisallowedTopics...)
|
||||||
|
|
||||||
@@ -461,13 +472,7 @@ func execServe(c *cli.Context) error {
|
|||||||
conf.TwilioAuthToken = twilioAuthToken
|
conf.TwilioAuthToken = twilioAuthToken
|
||||||
conf.TwilioPhoneNumber = twilioPhoneNumber
|
conf.TwilioPhoneNumber = twilioPhoneNumber
|
||||||
conf.TwilioVerifyService = twilioVerifyService
|
conf.TwilioVerifyService = twilioVerifyService
|
||||||
if twilioCallFormat != "" {
|
conf.TwilioCallFormat = twilioCallFormatTemplate
|
||||||
tmpl, err := template.New("twiml").Parse(twilioCallFormat)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to parse twilio-call-format template: %w", err)
|
|
||||||
}
|
|
||||||
conf.TwilioCallFormat = tmpl
|
|
||||||
}
|
|
||||||
conf.MessageSizeLimit = int(messageSizeLimit)
|
conf.MessageSizeLimit = int(messageSizeLimit)
|
||||||
conf.MessageDelayMax = messageDelayLimit
|
conf.MessageDelayMax = messageDelayLimit
|
||||||
conf.TotalTopicLimit = totalTopicLimit
|
conf.TotalTopicLimit = totalTopicLimit
|
||||||
|
|||||||
21
cmd/user.go
21
cmd/user.go
@@ -6,13 +6,14 @@ import (
|
|||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"heckel.io/ntfy/v2/server"
|
|
||||||
"heckel.io/ntfy/v2/user"
|
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/urfave/cli/v2"
|
"github.com/urfave/cli/v2"
|
||||||
"github.com/urfave/cli/v2/altsrc"
|
"github.com/urfave/cli/v2/altsrc"
|
||||||
|
"heckel.io/ntfy/v2/db"
|
||||||
|
"heckel.io/ntfy/v2/server"
|
||||||
|
"heckel.io/ntfy/v2/user"
|
||||||
"heckel.io/ntfy/v2/util"
|
"heckel.io/ntfy/v2/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -377,22 +378,20 @@ func createUserManager(c *cli.Context) (*user.Manager, error) {
|
|||||||
BcryptCost: user.DefaultUserPasswordBcryptCost,
|
BcryptCost: user.DefaultUserPasswordBcryptCost,
|
||||||
QueueWriterInterval: user.DefaultUserStatsQueueWriterInterval,
|
QueueWriterInterval: user.DefaultUserStatsQueueWriterInterval,
|
||||||
}
|
}
|
||||||
var store user.Store
|
|
||||||
if databaseURL != "" {
|
if databaseURL != "" {
|
||||||
store, err = user.NewPostgresStore(databaseURL)
|
pool, dbErr := db.OpenPostgres(databaseURL)
|
||||||
|
if dbErr != nil {
|
||||||
|
return nil, dbErr
|
||||||
|
}
|
||||||
|
return user.NewPostgresManager(pool, authConfig)
|
||||||
} else if authFile != "" {
|
} else if authFile != "" {
|
||||||
if !util.FileExists(authFile) {
|
if !util.FileExists(authFile) {
|
||||||
return nil, errors.New("auth-file does not exist; please start the server at least once to create it")
|
return nil, errors.New("auth-file does not exist; please start the server at least once to create it")
|
||||||
}
|
}
|
||||||
store, err = user.NewSQLiteStore(authFile, authStartupQueries)
|
return user.NewSQLiteManager(authFile, authStartupQueries, authConfig)
|
||||||
} else {
|
}
|
||||||
return nil, errors.New("option database-url or auth-file not set; auth is unconfigured for this server")
|
return nil, errors.New("option database-url or auth-file not set; auth is unconfigured for this server")
|
||||||
}
|
}
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return user.NewManager(store, authConfig)
|
|
||||||
}
|
|
||||||
|
|
||||||
func readPasswordAndConfirm(c *cli.Context) (string, error) {
|
func readPasswordAndConfirm(c *cli.Context) (string, error) {
|
||||||
fmt.Fprint(c.App.ErrWriter, "password: ")
|
fmt.Fprint(c.App.ErrWriter, "password: ")
|
||||||
|
|||||||
93
db/db.go
Normal file
93
db/db.go
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
paramMaxOpenConns = "pool_max_conns"
|
||||||
|
paramMaxIdleConns = "pool_max_idle_conns"
|
||||||
|
paramConnMaxLifetime = "pool_conn_max_lifetime"
|
||||||
|
paramConnMaxIdleTime = "pool_conn_max_idle_time"
|
||||||
|
|
||||||
|
defaultMaxOpenConns = 10
|
||||||
|
)
|
||||||
|
|
||||||
|
// OpenPostgres opens a PostgreSQL database connection pool from a DSN string. It supports custom
|
||||||
|
// query parameters for pool configuration: pool_max_conns (default 10), pool_max_idle_conns,
|
||||||
|
// pool_conn_max_lifetime, and pool_conn_max_idle_time. These parameters are stripped from
|
||||||
|
// the DSN before passing it to the driver.
|
||||||
|
func OpenPostgres(dsn string) (*sql.DB, error) {
|
||||||
|
u, err := url.Parse(dsn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid database URL: %w", err)
|
||||||
|
}
|
||||||
|
q := u.Query()
|
||||||
|
maxOpenConns, err := extractIntParam(q, paramMaxOpenConns, defaultMaxOpenConns)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
maxIdleConns, err := extractIntParam(q, paramMaxIdleConns, 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
connMaxLifetime, err := extractDurationParam(q, paramConnMaxLifetime, 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
connMaxIdleTime, err := extractDurationParam(q, paramConnMaxIdleTime, 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
u.RawQuery = q.Encode()
|
||||||
|
db, err := sql.Open("pgx", u.String())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
db.SetMaxOpenConns(maxOpenConns)
|
||||||
|
if maxIdleConns > 0 {
|
||||||
|
db.SetMaxIdleConns(maxIdleConns)
|
||||||
|
}
|
||||||
|
if connMaxLifetime > 0 {
|
||||||
|
db.SetConnMaxLifetime(connMaxLifetime)
|
||||||
|
}
|
||||||
|
if connMaxIdleTime > 0 {
|
||||||
|
db.SetConnMaxIdleTime(connMaxIdleTime)
|
||||||
|
}
|
||||||
|
if err := db.Ping(); err != nil {
|
||||||
|
return nil, fmt.Errorf("ping failed: %w", err)
|
||||||
|
}
|
||||||
|
return db, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractIntParam(q url.Values, key string, defaultValue int) (int, error) {
|
||||||
|
s := q.Get(key)
|
||||||
|
if s == "" {
|
||||||
|
return defaultValue, nil
|
||||||
|
}
|
||||||
|
q.Del(key)
|
||||||
|
v, err := strconv.Atoi(s)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err)
|
||||||
|
}
|
||||||
|
return v, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractDurationParam(q url.Values, key string, defaultValue time.Duration) (time.Duration, error) {
|
||||||
|
s := q.Get(key)
|
||||||
|
if s == "" {
|
||||||
|
return defaultValue, nil
|
||||||
|
}
|
||||||
|
q.Del(key)
|
||||||
|
d, err := time.ParseDuration(s)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err)
|
||||||
|
}
|
||||||
|
return d, nil
|
||||||
|
}
|
||||||
63
db/test/test.go
Normal file
63
db/test/test.go
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
package dbtest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"heckel.io/ntfy/v2/db"
|
||||||
|
"heckel.io/ntfy/v2/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
const testPoolMaxConns = "2"
|
||||||
|
|
||||||
|
// CreateTestPostgresSchema creates a temporary PostgreSQL schema and returns the DSN pointing to it.
|
||||||
|
// It registers a cleanup function to drop the schema when the test finishes.
|
||||||
|
// If NTFY_TEST_DATABASE_URL is not set, the test is skipped.
|
||||||
|
func CreateTestPostgresSchema(t *testing.T) string {
|
||||||
|
t.Helper()
|
||||||
|
dsn := os.Getenv("NTFY_TEST_DATABASE_URL")
|
||||||
|
if dsn == "" {
|
||||||
|
t.Skip("NTFY_TEST_DATABASE_URL not set")
|
||||||
|
}
|
||||||
|
schema := fmt.Sprintf("test_%s", util.RandomString(10))
|
||||||
|
u, err := url.Parse(dsn)
|
||||||
|
require.Nil(t, err)
|
||||||
|
q := u.Query()
|
||||||
|
q.Set("pool_max_conns", testPoolMaxConns)
|
||||||
|
u.RawQuery = q.Encode()
|
||||||
|
dsn = u.String()
|
||||||
|
setupDB, err := db.OpenPostgres(dsn)
|
||||||
|
require.Nil(t, err)
|
||||||
|
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Nil(t, setupDB.Close())
|
||||||
|
q.Set("search_path", schema)
|
||||||
|
u.RawQuery = q.Encode()
|
||||||
|
schemaDSN := u.String()
|
||||||
|
t.Cleanup(func() {
|
||||||
|
cleanDB, err := db.OpenPostgres(dsn)
|
||||||
|
if err == nil {
|
||||||
|
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
|
||||||
|
cleanDB.Close()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return schemaDSN
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateTestPostgres creates a temporary PostgreSQL schema and returns an open *sql.DB connection to it.
|
||||||
|
// It registers cleanup functions to close the DB and drop the schema when the test finishes.
|
||||||
|
// If NTFY_TEST_DATABASE_URL is not set, the test is skipped.
|
||||||
|
func CreateTestPostgres(t *testing.T) *sql.DB {
|
||||||
|
t.Helper()
|
||||||
|
schemaDSN := CreateTestPostgresSchema(t)
|
||||||
|
testDB, err := db.OpenPostgres(schemaDSN)
|
||||||
|
require.Nil(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
testDB.Close()
|
||||||
|
})
|
||||||
|
return testDB
|
||||||
|
}
|
||||||
@@ -53,6 +53,16 @@ Here are a few working sample configs using a `/etc/ntfy/server.yml` file:
|
|||||||
behind-proxy: true
|
behind-proxy: true
|
||||||
```
|
```
|
||||||
|
|
||||||
|
=== "server.yml (PostgreSQL, behind proxy)"
|
||||||
|
``` yaml
|
||||||
|
base-url: "https://ntfy.example.com"
|
||||||
|
listen-http: ":2586"
|
||||||
|
database-url: "postgres://ntfy:mypassword@db.example.com:5432/ntfy?sslmode=require"
|
||||||
|
attachment-cache-dir: "/var/cache/ntfy/attachments"
|
||||||
|
behind-proxy: true
|
||||||
|
auth-default-access: "deny-all"
|
||||||
|
```
|
||||||
|
|
||||||
=== "server.yml (ntfy.sh config)"
|
=== "server.yml (ntfy.sh config)"
|
||||||
``` yaml
|
``` yaml
|
||||||
# All the things: Behind a proxy, Firebase, cache, attachments,
|
# All the things: Behind a proxy, Firebase, cache, attachments,
|
||||||
@@ -125,16 +135,63 @@ using Docker Compose (i.e. `docker-compose.yml`):
|
|||||||
command: serve
|
command: serve
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Database options
|
||||||
|
ntfy uses a database for storing messages ([message cache](#message-cache)), users and [access control](#access-control), and [web push](#web-push) subscriptions.
|
||||||
|
You can choose between **SQLite** and **PostgreSQL** as the database backend.
|
||||||
|
|
||||||
|
### SQLite
|
||||||
|
By default, ntfy uses SQLite with separate database files for each store. This is the simplest setup and requires
|
||||||
|
no external dependencies:
|
||||||
|
|
||||||
|
* `cache-file`: Database file for the [message cache](#message-cache).
|
||||||
|
* `auth-file`: Database file for authentication and [access control](#access-control). If set, enables auth.
|
||||||
|
* `web-push-file`: Database file for [web push](#web-push) subscriptions.
|
||||||
|
|
||||||
|
### PostgreSQL (EXPERIMENTAL)
|
||||||
|
As an alternative, you can configure ntfy to use PostgreSQL for **all** database-backed stores by setting the
|
||||||
|
`database-url` option to a PostgreSQL connection string:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
database-url: "postgres://user:pass@host:5432/ntfy"
|
||||||
|
```
|
||||||
|
|
||||||
|
When `database-url` is set, ntfy will use PostgreSQL for the [message cache](#message-cache),
|
||||||
|
[access control](#access-control), and [web push](#web-push) subscriptions instead of SQLite. The `cache-file`,
|
||||||
|
`auth-file`, and `web-push-file` options **must not** be set in this case.
|
||||||
|
|
||||||
|
Note that setting `database-url` implicitly enables authentication and access control (equivalent to setting
|
||||||
|
`auth-file` with SQLite). The default access is `read-write`, so anonymous users can still read and write to all
|
||||||
|
topics. To restrict access, set `auth-default-access` to `deny-all` (see [access control](#access-control)).
|
||||||
|
|
||||||
|
You can also set this via the environment variable `NTFY_DATABASE_URL` or the command line flag `--database-url`.
|
||||||
|
|
||||||
|
The database URL supports the standard [PostgreSQL connection parameters](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS)
|
||||||
|
as query parameters, such as `sslmode`, `connect_timeout`, `sslcert`, `sslkey`, `sslrootcert`, and `application_name`.
|
||||||
|
See the [pgx driver documentation](https://pkg.go.dev/github.com/jackc/pgx/v5) for the full list of supported parameters.
|
||||||
|
|
||||||
|
In addition, ntfy supports the following custom query parameters to tune the connection pool:
|
||||||
|
|
||||||
|
| Parameter | Default | Description |
|
||||||
|
|---------------------------|---------|----------------------------------------------------------------------------------|
|
||||||
|
| `pool_max_conns` | 10 | Maximum number of open connections to the database |
|
||||||
|
| `pool_max_idle_conns` | - | Maximum number of idle connections in the pool |
|
||||||
|
| `pool_conn_max_lifetime` | - | Maximum amount of time a connection may be reused (Go duration, e.g. `5m`, `1h`) |
|
||||||
|
| `pool_conn_max_idle_time` | - | Maximum amount of time a connection may be idle (Go duration, e.g. `30s`, `5m`) |
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
database-url: "postgres://user:pass@host:5432/ntfy?sslmode=require&pool_max_conns=50&pool_conn_max_idle_time=5m"
|
||||||
|
```
|
||||||
|
|
||||||
## Message cache
|
## Message cache
|
||||||
If desired, ntfy can temporarily keep notifications in an in-memory or an on-disk cache. Caching messages for a short period
|
If desired, ntfy can temporarily keep notifications in an in-memory or an on-disk cache. Caching messages for a short period
|
||||||
of time is important to allow [phones](subscribe/phone.md) and other devices with brittle Internet connections to be able to retrieve
|
of time is important to allow [phones](subscribe/phone.md) and other devices with brittle Internet connections to be able to retrieve
|
||||||
notifications that they may have missed.
|
notifications that they may have missed.
|
||||||
|
|
||||||
By default, ntfy keeps messages **in-memory for 12 hours**, which means that **cached messages do not survive an application
|
By default, ntfy keeps messages **in-memory for 12 hours**, which means that **cached messages do not survive an application
|
||||||
restart**. You can override this behavior using the following config settings:
|
restart**. You can override this behavior by setting `cache-file` (SQLite) or `database-url` (PostgreSQL).
|
||||||
|
|
||||||
* `cache-file`: if set, ntfy will store messages in a SQLite based cache (default is empty, which means in-memory cache).
|
|
||||||
**This is required if you'd like messages to be retained across restarts**.
|
|
||||||
* `cache-duration`: defines the duration for which messages are stored in the cache (default is `12h`).
|
* `cache-duration`: defines the duration for which messages are stored in the cache (default is `12h`).
|
||||||
|
|
||||||
You can also entirely disable the cache by setting `cache-duration` to `0`. When the cache is disabled, messages are only
|
You can also entirely disable the cache by setting `cache-duration` to `0`. When the cache is disabled, messages are only
|
||||||
@@ -144,20 +201,6 @@ the message to the subscribers.
|
|||||||
Subscribers can retrieve cached messaging using the [`poll=1` parameter](subscribe/api.md#poll-for-messages), as well as the
|
Subscribers can retrieve cached messaging using the [`poll=1` parameter](subscribe/api.md#poll-for-messages), as well as the
|
||||||
[`since=` parameter](subscribe/api.md#fetch-cached-messages).
|
[`since=` parameter](subscribe/api.md#fetch-cached-messages).
|
||||||
|
|
||||||
## PostgreSQL database
|
|
||||||
By default, ntfy uses SQLite for all database-backed stores. As an alternative, you can configure ntfy to use PostgreSQL
|
|
||||||
by setting the `database-url` option to a PostgreSQL connection string:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
database-url: "postgres://user:pass@host:5432/ntfy"
|
|
||||||
```
|
|
||||||
|
|
||||||
When `database-url` is set, ntfy will use PostgreSQL for the web push subscription store instead of SQLite. The
|
|
||||||
`web-push-file` option is not required in this case. Support for PostgreSQL for the message cache and user manager
|
|
||||||
will be added in future releases.
|
|
||||||
|
|
||||||
You can also set this via the environment variable `NTFY_DATABASE_URL` or the command line flag `--database-url`.
|
|
||||||
|
|
||||||
## Attachments
|
## Attachments
|
||||||
If desired, you may allow users to upload and [attach files to notifications](publish.md#attachments). To enable
|
If desired, you may allow users to upload and [attach files to notifications](publish.md#attachments). To enable
|
||||||
this feature, you have to simply configure an attachment cache directory and a base URL (`attachment-cache-dir`, `base-url`).
|
this feature, you have to simply configure an attachment cache directory and a base URL (`attachment-cache-dir`, `base-url`).
|
||||||
@@ -199,14 +242,15 @@ and `visitor-attachment-daily-bandwidth-limit`. Setting these conservatively is
|
|||||||
By default, the ntfy server is open for everyone, meaning **everyone can read and write to any topic** (this is how
|
By default, the ntfy server is open for everyone, meaning **everyone can read and write to any topic** (this is how
|
||||||
ntfy.sh is configured). To restrict access to your own server, you can optionally configure authentication and authorization.
|
ntfy.sh is configured). To restrict access to your own server, you can optionally configure authentication and authorization.
|
||||||
|
|
||||||
ntfy's auth is implemented with a simple [SQLite](https://www.sqlite.org/)-based backend. It implements two roles
|
ntfy's auth implements two roles (`user` and `admin`) and per-topic `read` and `write` permissions using an
|
||||||
(`user` and `admin`) and per-topic `read` and `write` permissions using an [access control list (ACL)](https://en.wikipedia.org/wiki/Access-control_list).
|
[access control list (ACL)](https://en.wikipedia.org/wiki/Access-control_list). Access control entries can be applied
|
||||||
Access control entries can be applied to users as well as the special everyone user (`*`), which represents anonymous API access.
|
to users as well as the special everyone user (`*`), which represents anonymous API access.
|
||||||
|
|
||||||
To set up auth, **configure the following options**:
|
To set up auth, **configure the following options**:
|
||||||
|
|
||||||
* `auth-file` is the user/access database; it is created automatically if it doesn't already exist; suggested
|
* `auth-file` is the user/access database (SQLite); it is created automatically if it doesn't already exist; suggested
|
||||||
location `/var/lib/ntfy/user.db` (easiest if deb/rpm package is used)
|
location `/var/lib/ntfy/user.db` (easiest if deb/rpm package is used). Alternatively, if `database-url` is set,
|
||||||
|
auth is automatically enabled using PostgreSQL (see [database options](#database-options)).
|
||||||
* `auth-default-access` defines the default/fallback access if no access control entry is found; it can be
|
* `auth-default-access` defines the default/fallback access if no access control entry is found; it can be
|
||||||
set to `read-write` (default), `read-only`, `write-only` or `deny-all`. **If you are setting up a private instance,
|
set to `read-write` (default), `read-only`, `write-only` or `deny-all`. **If you are setting up a private instance,
|
||||||
you'll want to set this to `deny-all`** (see [private instance example](#example-private-instance)).
|
you'll want to set this to `deny-all`** (see [private instance example](#example-private-instance)).
|
||||||
@@ -1161,8 +1205,8 @@ a database to keep track of the browser's subscriptions, and an admin email addr
|
|||||||
- `web-push-expiry-warning-duration` defines the duration after which unused subscriptions are sent a warning (default is `55d`)
|
- `web-push-expiry-warning-duration` defines the duration after which unused subscriptions are sent a warning (default is `55d`)
|
||||||
- `web-push-expiry-duration` defines the duration after which unused subscriptions will expire (default is `60d`)
|
- `web-push-expiry-duration` defines the duration after which unused subscriptions will expire (default is `60d`)
|
||||||
|
|
||||||
Alternatively, you can use PostgreSQL instead of SQLite for the web push subscription store by setting `database-url`
|
Alternatively, you can use PostgreSQL instead of SQLite by setting `database-url`
|
||||||
(see [PostgreSQL database](#postgresql-database)).
|
(see [PostgreSQL database](#postgresql-experimental)).
|
||||||
|
|
||||||
Limitations:
|
Limitations:
|
||||||
|
|
||||||
@@ -1773,13 +1817,13 @@ variable before running the `ntfy` command (e.g. `export NTFY_LISTEN_HTTP=:80`).
|
|||||||
| `key-file` | `NTFY_KEY_FILE` | *filename* | - | HTTPS/TLS private key file, only used if `listen-https` is set. |
|
| `key-file` | `NTFY_KEY_FILE` | *filename* | - | HTTPS/TLS private key file, only used if `listen-https` is set. |
|
||||||
| `cert-file` | `NTFY_CERT_FILE` | *filename* | - | HTTPS/TLS certificate file, only used if `listen-https` is set. |
|
| `cert-file` | `NTFY_CERT_FILE` | *filename* | - | HTTPS/TLS certificate file, only used if `listen-https` is set. |
|
||||||
| `firebase-key-file` | `NTFY_FIREBASE_KEY_FILE` | *filename* | - | If set, also publish messages to a Firebase Cloud Messaging (FCM) topic for your app. This is optional and only required to save battery when using the Android app. See [Firebase (FCM)](#firebase-fcm). |
|
| `firebase-key-file` | `NTFY_FIREBASE_KEY_FILE` | *filename* | - | If set, also publish messages to a Firebase Cloud Messaging (FCM) topic for your app. This is optional and only required to save battery when using the Android app. See [Firebase (FCM)](#firebase-fcm). |
|
||||||
| `database-url` | `NTFY_DATABASE_URL` | *string (connection URL)* | - | PostgreSQL connection string (e.g. `postgres://user:pass@host:5432/ntfy`). If set, uses PostgreSQL for database-backed stores instead of SQLite. Currently applies to the web push store. See [PostgreSQL database](#postgresql-database). |
|
| `database-url` | `NTFY_DATABASE_URL` | *string (connection URL)* | - | PostgreSQL connection string (e.g. `postgres://user:pass@host:5432/ntfy`). If set, uses PostgreSQL for all database-backed stores (message cache, user manager, web push) instead of SQLite. See [database options](#database-options). |
|
||||||
| `cache-file` | `NTFY_CACHE_FILE` | *filename* | - | If set, messages are cached in a local SQLite database instead of only in-memory. This allows for service restarts without losing messages in support of the since= parameter. See [message cache](#message-cache). |
|
| `cache-file` | `NTFY_CACHE_FILE` | *filename* | - | If set, messages are cached in a local SQLite database instead of only in-memory. This allows for service restarts without losing messages in support of the since= parameter. See [message cache](#message-cache). |
|
||||||
| `cache-duration` | `NTFY_CACHE_DURATION` | *duration* | 12h | Duration for which messages will be buffered before they are deleted. This is required to support the `since=...` and `poll=1` parameter. Set this to `0` to disable the cache entirely. |
|
| `cache-duration` | `NTFY_CACHE_DURATION` | *duration* | 12h | Duration for which messages will be buffered before they are deleted. This is required to support the `since=...` and `poll=1` parameter. Set this to `0` to disable the cache entirely. |
|
||||||
| `cache-startup-queries` | `NTFY_CACHE_STARTUP_QUERIES` | *string (SQL queries)* | - | SQL queries to run during database startup; this is useful for tuning and [enabling WAL mode](#message-cache) |
|
| `cache-startup-queries` | `NTFY_CACHE_STARTUP_QUERIES` | *string (SQL queries)* | - | SQL queries to run during database startup; this is useful for tuning and [enabling WAL mode](#message-cache) |
|
||||||
| `cache-batch-size` | `NTFY_CACHE_BATCH_SIZE` | *int* | 0 | Max size of messages to batch together when writing to message cache (if zero, writes are synchronous) |
|
| `cache-batch-size` | `NTFY_CACHE_BATCH_SIZE` | *int* | 0 | Max size of messages to batch together when writing to message cache (if zero, writes are synchronous) |
|
||||||
| `cache-batch-timeout` | `NTFY_CACHE_BATCH_TIMEOUT` | *duration* | 0s | Timeout for batched async writes to the message cache (if zero, writes are synchronous) |
|
| `cache-batch-timeout` | `NTFY_CACHE_BATCH_TIMEOUT` | *duration* | 0s | Timeout for batched async writes to the message cache (if zero, writes are synchronous) |
|
||||||
| `auth-file` | `NTFY_AUTH_FILE` | *filename* | - | Auth database file used for access control. If set, enables authentication and access control. See [access control](#access-control). |
|
| `auth-file` | `NTFY_AUTH_FILE` | *filename* | - | Auth database file used for access control (SQLite). If set, enables authentication and access control. Not required if `database-url` is set. See [access control](#access-control). |
|
||||||
| `auth-default-access` | `NTFY_AUTH_DEFAULT_ACCESS` | `read-write`, `read-only`, `write-only`, `deny-all` | `read-write` | Default permissions if no matching entries in the auth database are found. Default is `read-write`. |
|
| `auth-default-access` | `NTFY_AUTH_DEFAULT_ACCESS` | `read-write`, `read-only`, `write-only`, `deny-all` | `read-write` | Default permissions if no matching entries in the auth database are found. Default is `read-write`. |
|
||||||
| `behind-proxy` | `NTFY_BEHIND_PROXY` | *bool* | false | If set, use forwarded header (e.g. X-Forwarded-For, X-Client-IP) to determine visitor IP address (for rate limiting) |
|
| `behind-proxy` | `NTFY_BEHIND_PROXY` | *bool* | false | If set, use forwarded header (e.g. X-Forwarded-For, X-Client-IP) to determine visitor IP address (for rate limiting) |
|
||||||
| `proxy-forwarded-header` | `NTFY_PROXY_FORWARDED_HEADER` | *string* | `X-Forwarded-For` | Use specified header to determine visitor IP address (for rate limiting) |
|
| `proxy-forwarded-header` | `NTFY_PROXY_FORWARDED_HEADER` | *string* | `X-Forwarded-For` | Use specified header to determine visitor IP address (for rate limiting) |
|
||||||
|
|||||||
@@ -661,6 +661,8 @@ Add the following function and alias to your `.bashrc` or `.bash_profile`:
|
|||||||
local token=$(< ~/.ntfy_token) # Securely read the token
|
local token=$(< ~/.ntfy_token) # Securely read the token
|
||||||
local status_icon="$([ $exit_status -eq 0 ] && echo magic_wand || echo warning)"
|
local status_icon="$([ $exit_status -eq 0 ] && echo magic_wand || echo warning)"
|
||||||
local last_command=$(history | tail -n1 | sed -e 's/^[[:space:]]*[0-9]\{1,\}[[:space:]]*//' -e 's/[;&|][[:space:]]*alert$//')
|
local last_command=$(history | tail -n1 | sed -e 's/^[[:space:]]*[0-9]\{1,\}[[:space:]]*//' -e 's/[;&|][[:space:]]*alert$//')
|
||||||
|
# for zsh users, use the same sed pattern but get the history differently.
|
||||||
|
# local last_command=$(history "$HISTCMD" | sed -e 's/^[[:space:]]*[0-9]\{1,\}[[:space:]]*//' -e 's/[;&|][[:space:]]*alert$//')
|
||||||
|
|
||||||
curl -s -X POST "https://n.example.dev/alerts" \
|
curl -s -X POST "https://n.example.dev/alerts" \
|
||||||
-H "Authorization: Bearer $token" \
|
-H "Authorization: Bearer $token" \
|
||||||
|
|||||||
@@ -185,6 +185,7 @@ I've added a ⭐ to projects or posts that have a significant following, or had
|
|||||||
- [Uptime Monitor](https://uptime-monitor.org) - Self-hosted, enterprise-grade uptime monitoring and alerting system (TS)
|
- [Uptime Monitor](https://uptime-monitor.org) - Self-hosted, enterprise-grade uptime monitoring and alerting system (TS)
|
||||||
- [send_to_ntfy_extension](https://github.com/TheDuffman85/send_to_ntfy_extension/) ⭐ - A browser extension to send the notifications to ntfy (JS)
|
- [send_to_ntfy_extension](https://github.com/TheDuffman85/send_to_ntfy_extension/) ⭐ - A browser extension to send the notifications to ntfy (JS)
|
||||||
- [SIA-Server](https://github.com/ZebMcKayhan/SIA-Server) - A lightweight, self-hosted notification server for Honeywell Galaxy Flex alarm systems (Python)
|
- [SIA-Server](https://github.com/ZebMcKayhan/SIA-Server) - A lightweight, self-hosted notification server for Honeywell Galaxy Flex alarm systems (Python)
|
||||||
|
- [zabbix-ntfy](https://github.com/torgrimt/zabbix-ntfy) - Zabbix server Mediatype to add support for ntfy.sh services
|
||||||
|
|
||||||
## Blog + forum posts
|
## Blog + forum posts
|
||||||
|
|
||||||
|
|||||||
@@ -7,11 +7,32 @@ and the [ntfy Android app](https://github.com/binwiederhier/ntfy-android/release
|
|||||||
| Component | Version | Release date |
|
| Component | Version | Release date |
|
||||||
|------------------|---------|--------------|
|
|------------------|---------|--------------|
|
||||||
| ntfy server | v2.17.0 | Feb 8, 2026 |
|
| ntfy server | v2.17.0 | Feb 8, 2026 |
|
||||||
| ntfy Android app | v1.22.2 | Jan 25, 2026 |
|
| ntfy Android app | v1.23.0 | Feb 22, 2026 |
|
||||||
| ntfy iOS app | v1.3 | Nov 26, 2023 |
|
| ntfy iOS app | v1.3 | Nov 26, 2023 |
|
||||||
|
|
||||||
Please check out the release notes for [upcoming releases](#not-released-yet) below.
|
Please check out the release notes for [upcoming releases](#not-released-yet) below.
|
||||||
|
|
||||||
|
## ntfy Android v1.23.0
|
||||||
|
Released February 22, 2026
|
||||||
|
|
||||||
|
This release adds support for search within a topic, and adds [copy action](publish.md#copy-to-clipboard) support
|
||||||
|
to the Android app.
|
||||||
|
|
||||||
|
**Features:**
|
||||||
|
|
||||||
|
* Search within a topic ([#141](https://github.com/binwiederhier/ntfy/issues/141), [ntfy-android#153](https://github.com/binwiederhier/ntfy-android/pull/153), thanks to [@Copephobia](https://github.com/Copephobia) and [@StoyanYonkov](https://github.com/StoyanYonkov) for reporting and sponsoring)
|
||||||
|
* Add "reconnecting to N topics ..." to foreground notification ([#1101](https://github.com/binwiederhier/ntfy/issues/1101), thanks to [@milosivanovic](https://github.com/milosivanovic) for reporting)
|
||||||
|
* Improved default server dialog with full-screen UI and stricter URL validation ([#1582](https://github.com/binwiederhier/ntfy/issues/1582))
|
||||||
|
* Show last notification time for UnifiedPush subscriptions ([#1230](https://github.com/binwiederhier/ntfy/issues/1230), [#1454](https://github.com/binwiederhier/ntfy/issues/1454), thanks to [@Tealk](https://github.com/Tealk) and [@user4andre](https://github.com/user4andre) for reporting)
|
||||||
|
* Support "copy" action button to copy a value to the clipboard ([#1364](https://github.com/binwiederhier/ntfy/issues/1364), thanks to [@SudoWatson](https://github.com/SudoWatson) for reporting)
|
||||||
|
|
||||||
|
**Bug fixes + maintenance:**
|
||||||
|
|
||||||
|
* Fix `clear=true` on action buttons not marking notification as read ([#1029](https://github.com/binwiederhier/ntfy/issues/1029), thanks to [@ElFishi](https://github.com/ElFishi) for reporting)
|
||||||
|
* Fix crash when default server URL is missing scheme by auto-prepending `https://` ([#1582](https://github.com/binwiederhier/ntfy/issues/1582), thanks to [@hard-zero1](https://github.com/hard-zero1))
|
||||||
|
* Fix notification timestamp to use original send time instead of receive time ([#1112](https://github.com/binwiederhier/ntfy/issues/1112), thanks to [@voruti](https://github.com/voruti) for reporting)
|
||||||
|
* Fix notifications being missed after service restart by using persisted lastNotificationId ([#1591](https://github.com/binwiederhier/ntfy/issues/1591), thanks to @Epifeny for reporting)
|
||||||
|
|
||||||
## ntfy server v2.17.0
|
## ntfy server v2.17.0
|
||||||
Released February 8, 2026
|
Released February 8, 2026
|
||||||
|
|
||||||
@@ -1698,26 +1719,18 @@ and the [ntfy Android app](https://github.com/binwiederhier/ntfy-android/release
|
|||||||
|
|
||||||
## Not released yet
|
## Not released yet
|
||||||
|
|
||||||
### ntfy Android v1.23.x (UNRELEASED)
|
### ntfy server v2.18.x (UNRELEASED)
|
||||||
|
|
||||||
**Features:**
|
**Features:**
|
||||||
|
|
||||||
* Search within a topic ([#141](https://github.com/binwiederhier/ntfy/issues/141), [ntfy-android#153](https://github.com/binwiederhier/ntfy-android/pull/153), thanks to [@Copephobia](https://github.com/Copephobia) and [@StoyanYonkov](https://github.com/StoyanYonkov) for reporting and sponsoring)
|
* Add experimental [PostgreSQL support](config.md#postgresql-experimental) as an alternative database backend (message cache, user manager, web push subscriptions) via `database-url` config option ([#1114](https://github.com/binwiederhier/ntfy/issues/1114), thanks to [@brettinternet](https://github.com/brettinternet) for reporting)
|
||||||
* Add "reconnecting to N topics ..." to foreground notification ([#1101](https://github.com/binwiederhier/ntfy/issues/1101), thanks to [@milosivanovic](https://github.com/milosivanovic) for reporting)
|
|
||||||
* Improved default server dialog with full-screen UI and stricter URL validation ([#1582](https://github.com/binwiederhier/ntfy/issues/1582))
|
**Bug fixes + maintenance:**
|
||||||
* Show last notification time for UnifiedPush subscriptions ([#1230](https://github.com/binwiederhier/ntfy/issues/1230), [#1454](https://github.com/binwiederhier/ntfy/issues/1454), thanks to [@Tealk](https://github.com/Tealk) and [@user4andre](https://github.com/user4andre) for reporting)
|
|
||||||
* Support "copy" action button to copy a value to the clipboard ([#1364](https://github.com/binwiederhier/ntfy/issues/1364), thanks to [@SudoWatson](https://github.com/SudoWatson) for reporting)
|
* Preserve `<br>` line breaks in HTML-only emails received via SMTP ([#690](https://github.com/binwiederhier/ntfy/issues/690), [#1620](https://github.com/binwiederhier/ntfy/pull/1620), thanks to [@uzkikh](https://github.com/uzkikh) for the fix and to [@teastrainer](https://github.com/teastrainer) for reporting)
|
||||||
|
|
||||||
|
### ntfy Android v1.24.x (UNRELEASED)
|
||||||
|
|
||||||
**Bug fixes + maintenance:**
|
**Bug fixes + maintenance:**
|
||||||
|
|
||||||
* Fix `clear=true` on action buttons not marking notification as read ([#1029](https://github.com/binwiederhier/ntfy/issues/1029), thanks to [@ElFishi](https://github.com/ElFishi) for reporting)
|
|
||||||
* Fix crash when default server URL is missing scheme by auto-prepending `https://` ([#1582](https://github.com/binwiederhier/ntfy/issues/1582), thanks to [@hard-zero1](https://github.com/hard-zero1))
|
|
||||||
* Fix notification timestamp to use original send time instead of receive time ([#1112](https://github.com/binwiederhier/ntfy/issues/1112), thanks to [@voruti](https://github.com/voruti) for reporting)
|
|
||||||
* Fix notifications being missed after service restart by using persisted lastNotificationId ([#1591](https://github.com/binwiederhier/ntfy/issues/1591), thanks to @Epifeny for reporting)
|
|
||||||
* Fix crash in settings when fragment is detached during backup/restore or log operations
|
* Fix crash in settings when fragment is detached during backup/restore or log operations
|
||||||
|
|
||||||
### ntfy server v2.12.x (UNRELEASED)
|
|
||||||
|
|
||||||
**Features:**
|
|
||||||
|
|
||||||
* Add PostgreSQL as an alternative database backend for the web push subscription store via `database-url` config option
|
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -21,39 +20,13 @@ const (
|
|||||||
|
|
||||||
var errNoRows = errors.New("no rows found")
|
var errNoRows = errors.New("no rows found")
|
||||||
|
|
||||||
// Store is the interface for a message cache store
|
// queries holds the database-specific SQL queries
|
||||||
type Store interface {
|
type queries struct {
|
||||||
AddMessage(m *model.Message) error
|
|
||||||
AddMessages(ms []*model.Message) error
|
|
||||||
DB() *sql.DB
|
|
||||||
Message(id string) (*model.Message, error)
|
|
||||||
MessageCounts() (map[string]int, error)
|
|
||||||
Messages(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error)
|
|
||||||
MessagesDue() ([]*model.Message, error)
|
|
||||||
MessagesExpired() ([]string, error)
|
|
||||||
MarkPublished(m *model.Message) error
|
|
||||||
UpdateMessageTime(messageID string, timestamp int64) error
|
|
||||||
Topics() ([]string, error)
|
|
||||||
DeleteMessages(ids ...string) error
|
|
||||||
DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, error)
|
|
||||||
ExpireMessages(topics ...string) error
|
|
||||||
AttachmentsExpired() ([]string, error)
|
|
||||||
MarkAttachmentsDeleted(ids ...string) error
|
|
||||||
AttachmentBytesUsedBySender(sender string) (int64, error)
|
|
||||||
AttachmentBytesUsedByUser(userID string) (int64, error)
|
|
||||||
UpdateStats(messages int64) error
|
|
||||||
Stats() (int64, error)
|
|
||||||
Close() error
|
|
||||||
}
|
|
||||||
|
|
||||||
// storeQueries holds the database-specific SQL queries
|
|
||||||
type storeQueries struct {
|
|
||||||
insertMessage string
|
insertMessage string
|
||||||
deleteMessage string
|
deleteMessage string
|
||||||
selectScheduledMessageIDsBySeqID string
|
selectScheduledMessageIDsBySeqID string
|
||||||
deleteScheduledBySequenceID string
|
deleteScheduledBySequenceID string
|
||||||
updateMessagesForTopicExpiry string
|
updateMessagesForTopicExpiry string
|
||||||
selectRowIDFromMessageID string
|
|
||||||
selectMessagesByID string
|
selectMessagesByID string
|
||||||
selectMessagesSinceTime string
|
selectMessagesSinceTime string
|
||||||
selectMessagesSinceTimeScheduled string
|
selectMessagesSinceTimeScheduled string
|
||||||
@@ -64,7 +37,6 @@ type storeQueries struct {
|
|||||||
selectMessagesExpired string
|
selectMessagesExpired string
|
||||||
updateMessagePublished string
|
updateMessagePublished string
|
||||||
selectMessagesCount string
|
selectMessagesCount string
|
||||||
selectMessageCountPerTopic string
|
|
||||||
selectTopics string
|
selectTopics string
|
||||||
updateAttachmentDeleted string
|
updateAttachmentDeleted string
|
||||||
selectAttachmentsExpired string
|
selectAttachmentsExpired string
|
||||||
@@ -75,37 +47,46 @@ type storeQueries struct {
|
|||||||
updateMessageTime string
|
updateMessageTime string
|
||||||
}
|
}
|
||||||
|
|
||||||
// commonStore implements store operations that are identical across database backends
|
// Cache stores published messages
|
||||||
type commonStore struct {
|
type Cache struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
queue *util.BatchingQueue[*model.Message]
|
queue *util.BatchingQueue[*model.Message]
|
||||||
nop bool
|
nop bool
|
||||||
mu sync.Mutex
|
mu *sync.Mutex // nil for PostgreSQL (concurrent writes supported), set for SQLite (single writer)
|
||||||
queries storeQueries
|
queries queries
|
||||||
}
|
}
|
||||||
|
|
||||||
func newCommonStore(db *sql.DB, queries storeQueries, batchSize int, batchTimeout time.Duration, nop bool) *commonStore {
|
func newCache(db *sql.DB, queries queries, mu *sync.Mutex, batchSize int, batchTimeout time.Duration, nop bool) *Cache {
|
||||||
var queue *util.BatchingQueue[*model.Message]
|
var queue *util.BatchingQueue[*model.Message]
|
||||||
if batchSize > 0 || batchTimeout > 0 {
|
if batchSize > 0 || batchTimeout > 0 {
|
||||||
queue = util.NewBatchingQueue[*model.Message](batchSize, batchTimeout)
|
queue = util.NewBatchingQueue[*model.Message](batchSize, batchTimeout)
|
||||||
}
|
}
|
||||||
c := &commonStore{
|
c := &Cache{
|
||||||
db: db,
|
db: db,
|
||||||
queue: queue,
|
queue: queue,
|
||||||
nop: nop,
|
nop: nop,
|
||||||
|
mu: mu,
|
||||||
queries: queries,
|
queries: queries,
|
||||||
}
|
}
|
||||||
go c.processMessageBatches()
|
go c.processMessageBatches()
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// DB returns the underlying database connection
|
func (c *Cache) maybeLock() {
|
||||||
func (c *commonStore) DB() *sql.DB {
|
if c.mu != nil {
|
||||||
return c.db
|
c.mu.Lock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cache) maybeUnlock() {
|
||||||
|
if c.mu != nil {
|
||||||
|
c.mu.Unlock()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddMessage stores a message to the message cache synchronously, or queues it to be stored at a later date asynchronously.
|
// AddMessage stores a message to the message cache synchronously, or queues it to be stored at a later date asynchronously.
|
||||||
func (c *commonStore) AddMessage(m *model.Message) error {
|
// The message is queued only if "batchSize" or "batchTimeout" are passed to the constructor.
|
||||||
|
func (c *Cache) AddMessage(m *model.Message) error {
|
||||||
if c.queue != nil {
|
if c.queue != nil {
|
||||||
c.queue.Enqueue(m)
|
c.queue.Enqueue(m)
|
||||||
return nil
|
return nil
|
||||||
@@ -114,13 +95,13 @@ func (c *commonStore) AddMessage(m *model.Message) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddMessages synchronously stores a batch of messages to the message cache
|
// AddMessages synchronously stores a batch of messages to the message cache
|
||||||
func (c *commonStore) AddMessages(ms []*model.Message) error {
|
func (c *Cache) AddMessages(ms []*model.Message) error {
|
||||||
return c.addMessages(ms)
|
return c.addMessages(ms)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) addMessages(ms []*model.Message) error {
|
func (c *Cache) addMessages(ms []*model.Message) error {
|
||||||
c.mu.Lock()
|
c.maybeLock()
|
||||||
defer c.mu.Unlock()
|
defer c.maybeUnlock()
|
||||||
if c.nop {
|
if c.nop {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -204,7 +185,8 @@ func (c *commonStore) addMessages(ms []*model.Message) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) Messages(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
|
// Messages returns messages for a topic since the given marker, optionally including scheduled messages
|
||||||
|
func (c *Cache) Messages(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
|
||||||
if since.IsNone() {
|
if since.IsNone() {
|
||||||
return make([]*model.Message, 0), nil
|
return make([]*model.Message, 0), nil
|
||||||
} else if since.IsLatest() {
|
} else if since.IsLatest() {
|
||||||
@@ -215,7 +197,7 @@ func (c *commonStore) Messages(topic string, since model.SinceMarker, scheduled
|
|||||||
return c.messagesSinceTime(topic, since, scheduled)
|
return c.messagesSinceTime(topic, since, scheduled)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) messagesSinceTime(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
|
func (c *Cache) messagesSinceTime(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
|
||||||
var rows *sql.Rows
|
var rows *sql.Rows
|
||||||
var err error
|
var err error
|
||||||
if scheduled {
|
if scheduled {
|
||||||
@@ -229,25 +211,13 @@ func (c *commonStore) messagesSinceTime(topic string, since model.SinceMarker, s
|
|||||||
return readMessages(rows)
|
return readMessages(rows)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) messagesSinceID(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
|
func (c *Cache) messagesSinceID(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
|
||||||
idrows, err := c.db.Query(c.queries.selectRowIDFromMessageID, since.ID())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer idrows.Close()
|
|
||||||
if !idrows.Next() {
|
|
||||||
return c.messagesSinceTime(topic, model.SinceAllMessages, scheduled)
|
|
||||||
}
|
|
||||||
var rowID int64
|
|
||||||
if err := idrows.Scan(&rowID); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
idrows.Close()
|
|
||||||
var rows *sql.Rows
|
var rows *sql.Rows
|
||||||
|
var err error
|
||||||
if scheduled {
|
if scheduled {
|
||||||
rows, err = c.db.Query(c.queries.selectMessagesSinceIDScheduled, topic, rowID)
|
rows, err = c.db.Query(c.queries.selectMessagesSinceIDScheduled, topic, since.ID())
|
||||||
} else {
|
} else {
|
||||||
rows, err = c.db.Query(c.queries.selectMessagesSinceID, topic, rowID)
|
rows, err = c.db.Query(c.queries.selectMessagesSinceID, topic, since.ID())
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -255,7 +225,7 @@ func (c *commonStore) messagesSinceID(topic string, since model.SinceMarker, sch
|
|||||||
return readMessages(rows)
|
return readMessages(rows)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) messagesLatest(topic string) ([]*model.Message, error) {
|
func (c *Cache) messagesLatest(topic string) ([]*model.Message, error) {
|
||||||
rows, err := c.db.Query(c.queries.selectMessagesLatest, topic)
|
rows, err := c.db.Query(c.queries.selectMessagesLatest, topic)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -263,7 +233,8 @@ func (c *commonStore) messagesLatest(topic string) ([]*model.Message, error) {
|
|||||||
return readMessages(rows)
|
return readMessages(rows)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) MessagesDue() ([]*model.Message, error) {
|
// MessagesDue returns all messages that are due for publishing
|
||||||
|
func (c *Cache) MessagesDue() ([]*model.Message, error) {
|
||||||
rows, err := c.db.Query(c.queries.selectMessagesDue, time.Now().Unix())
|
rows, err := c.db.Query(c.queries.selectMessagesDue, time.Now().Unix())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -272,7 +243,7 @@ func (c *commonStore) MessagesDue() ([]*model.Message, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// MessagesExpired returns a list of IDs for messages that have expired (should be deleted)
|
// MessagesExpired returns a list of IDs for messages that have expired (should be deleted)
|
||||||
func (c *commonStore) MessagesExpired() ([]string, error) {
|
func (c *Cache) MessagesExpired() ([]string, error) {
|
||||||
rows, err := c.db.Query(c.queries.selectMessagesExpired, time.Now().Unix())
|
rows, err := c.db.Query(c.queries.selectMessagesExpired, time.Now().Unix())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -292,7 +263,8 @@ func (c *commonStore) MessagesExpired() ([]string, error) {
|
|||||||
return ids, nil
|
return ids, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) Message(id string) (*model.Message, error) {
|
// Message returns the message with the given ID, or ErrMessageNotFound if not found
|
||||||
|
func (c *Cache) Message(id string) (*model.Message, error) {
|
||||||
rows, err := c.db.Query(c.queries.selectMessagesByID, id)
|
rows, err := c.db.Query(c.queries.selectMessagesByID, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -305,39 +277,40 @@ func (c *commonStore) Message(id string) (*model.Message, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UpdateMessageTime updates the time column for a message by ID. This is only used for testing.
|
// UpdateMessageTime updates the time column for a message by ID. This is only used for testing.
|
||||||
func (c *commonStore) UpdateMessageTime(messageID string, timestamp int64) error {
|
func (c *Cache) UpdateMessageTime(messageID string, timestamp int64) error {
|
||||||
|
c.maybeLock()
|
||||||
|
defer c.maybeUnlock()
|
||||||
_, err := c.db.Exec(c.queries.updateMessageTime, timestamp, messageID)
|
_, err := c.db.Exec(c.queries.updateMessageTime, timestamp, messageID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) MarkPublished(m *model.Message) error {
|
// MarkPublished marks a message as published
|
||||||
c.mu.Lock()
|
func (c *Cache) MarkPublished(m *model.Message) error {
|
||||||
defer c.mu.Unlock()
|
c.maybeLock()
|
||||||
|
defer c.maybeUnlock()
|
||||||
_, err := c.db.Exec(c.queries.updateMessagePublished, m.ID)
|
_, err := c.db.Exec(c.queries.updateMessagePublished, m.ID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) MessageCounts() (map[string]int, error) {
|
// MessagesCount returns the total number of messages in the cache
|
||||||
rows, err := c.db.Query(c.queries.selectMessageCountPerTopic)
|
func (c *Cache) MessagesCount() (int, error) {
|
||||||
|
rows, err := c.db.Query(c.queries.selectMessagesCount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return 0, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
var topic string
|
if !rows.Next() {
|
||||||
|
return 0, errNoRows
|
||||||
|
}
|
||||||
var count int
|
var count int
|
||||||
counts := make(map[string]int)
|
if err := rows.Scan(&count); err != nil {
|
||||||
for rows.Next() {
|
return 0, err
|
||||||
if err := rows.Scan(&topic, &count); err != nil {
|
|
||||||
return nil, err
|
|
||||||
} else if err := rows.Err(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
counts[topic] = count
|
return count, nil
|
||||||
}
|
|
||||||
return counts, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) Topics() ([]string, error) {
|
// Topics returns a list of all topics with messages in the cache
|
||||||
|
func (c *Cache) Topics() ([]string, error) {
|
||||||
rows, err := c.db.Query(c.queries.selectTopics)
|
rows, err := c.db.Query(c.queries.selectTopics)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -357,9 +330,10 @@ func (c *commonStore) Topics() ([]string, error) {
|
|||||||
return topics, nil
|
return topics, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) DeleteMessages(ids ...string) error {
|
// DeleteMessages deletes the messages with the given IDs
|
||||||
c.mu.Lock()
|
func (c *Cache) DeleteMessages(ids ...string) error {
|
||||||
defer c.mu.Unlock()
|
c.maybeLock()
|
||||||
|
defer c.maybeUnlock()
|
||||||
tx, err := c.db.Begin()
|
tx, err := c.db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -375,14 +349,15 @@ func (c *commonStore) DeleteMessages(ids ...string) error {
|
|||||||
|
|
||||||
// DeleteScheduledBySequenceID deletes unpublished (scheduled) messages with the given topic and sequence ID.
|
// DeleteScheduledBySequenceID deletes unpublished (scheduled) messages with the given topic and sequence ID.
|
||||||
// It returns the message IDs of the deleted messages, which can be used to clean up attachment files.
|
// It returns the message IDs of the deleted messages, which can be used to clean up attachment files.
|
||||||
func (c *commonStore) DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, error) {
|
func (c *Cache) DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, error) {
|
||||||
c.mu.Lock()
|
c.maybeLock()
|
||||||
defer c.mu.Unlock()
|
defer c.maybeUnlock()
|
||||||
tx, err := c.db.Begin()
|
tx, err := c.db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
|
// First, get the message IDs of scheduled messages to be deleted
|
||||||
rows, err := tx.Query(c.queries.selectScheduledMessageIDsBySeqID, topic, sequenceID)
|
rows, err := tx.Query(c.queries.selectScheduledMessageIDsBySeqID, topic, sequenceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -399,7 +374,8 @@ func (c *commonStore) DeleteScheduledBySequenceID(topic, sequenceID string) ([]s
|
|||||||
if err := rows.Err(); err != nil {
|
if err := rows.Err(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
rows.Close()
|
rows.Close() // Close rows before executing delete in same transaction
|
||||||
|
// Then delete the messages
|
||||||
if _, err := tx.Exec(c.queries.deleteScheduledBySequenceID, topic, sequenceID); err != nil {
|
if _, err := tx.Exec(c.queries.deleteScheduledBySequenceID, topic, sequenceID); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -409,9 +385,10 @@ func (c *commonStore) DeleteScheduledBySequenceID(topic, sequenceID string) ([]s
|
|||||||
return ids, nil
|
return ids, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) ExpireMessages(topics ...string) error {
|
// ExpireMessages marks messages in the given topics as expired
|
||||||
c.mu.Lock()
|
func (c *Cache) ExpireMessages(topics ...string) error {
|
||||||
defer c.mu.Unlock()
|
c.maybeLock()
|
||||||
|
defer c.maybeUnlock()
|
||||||
tx, err := c.db.Begin()
|
tx, err := c.db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -425,7 +402,8 @@ func (c *commonStore) ExpireMessages(topics ...string) error {
|
|||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) AttachmentsExpired() ([]string, error) {
|
// AttachmentsExpired returns message IDs with expired attachments that have not been deleted
|
||||||
|
func (c *Cache) AttachmentsExpired() ([]string, error) {
|
||||||
rows, err := c.db.Query(c.queries.selectAttachmentsExpired, time.Now().Unix())
|
rows, err := c.db.Query(c.queries.selectAttachmentsExpired, time.Now().Unix())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -445,9 +423,10 @@ func (c *commonStore) AttachmentsExpired() ([]string, error) {
|
|||||||
return ids, nil
|
return ids, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) MarkAttachmentsDeleted(ids ...string) error {
|
// MarkAttachmentsDeleted marks the attachments for the given message IDs as deleted
|
||||||
c.mu.Lock()
|
func (c *Cache) MarkAttachmentsDeleted(ids ...string) error {
|
||||||
defer c.mu.Unlock()
|
c.maybeLock()
|
||||||
|
defer c.maybeUnlock()
|
||||||
tx, err := c.db.Begin()
|
tx, err := c.db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -461,7 +440,8 @@ func (c *commonStore) MarkAttachmentsDeleted(ids ...string) error {
|
|||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) AttachmentBytesUsedBySender(sender string) (int64, error) {
|
// AttachmentBytesUsedBySender returns the total size of active attachments sent by the given sender
|
||||||
|
func (c *Cache) AttachmentBytesUsedBySender(sender string) (int64, error) {
|
||||||
rows, err := c.db.Query(c.queries.selectAttachmentsSizeBySender, sender, time.Now().Unix())
|
rows, err := c.db.Query(c.queries.selectAttachmentsSizeBySender, sender, time.Now().Unix())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
@@ -469,7 +449,8 @@ func (c *commonStore) AttachmentBytesUsedBySender(sender string) (int64, error)
|
|||||||
return c.readAttachmentBytesUsed(rows)
|
return c.readAttachmentBytesUsed(rows)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) AttachmentBytesUsedByUser(userID string) (int64, error) {
|
// AttachmentBytesUsedByUser returns the total size of active attachments for the given user
|
||||||
|
func (c *Cache) AttachmentBytesUsedByUser(userID string) (int64, error) {
|
||||||
rows, err := c.db.Query(c.queries.selectAttachmentsSizeByUserID, userID, time.Now().Unix())
|
rows, err := c.db.Query(c.queries.selectAttachmentsSizeByUserID, userID, time.Now().Unix())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
@@ -477,7 +458,7 @@ func (c *commonStore) AttachmentBytesUsedByUser(userID string) (int64, error) {
|
|||||||
return c.readAttachmentBytesUsed(rows)
|
return c.readAttachmentBytesUsed(rows)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) readAttachmentBytesUsed(rows *sql.Rows) (int64, error) {
|
func (c *Cache) readAttachmentBytesUsed(rows *sql.Rows) (int64, error) {
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
var size int64
|
var size int64
|
||||||
if !rows.Next() {
|
if !rows.Next() {
|
||||||
@@ -491,14 +472,16 @@ func (c *commonStore) readAttachmentBytesUsed(rows *sql.Rows) (int64, error) {
|
|||||||
return size, nil
|
return size, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) UpdateStats(messages int64) error {
|
// UpdateStats updates the total message count statistic
|
||||||
c.mu.Lock()
|
func (c *Cache) UpdateStats(messages int64) error {
|
||||||
defer c.mu.Unlock()
|
c.maybeLock()
|
||||||
|
defer c.maybeUnlock()
|
||||||
_, err := c.db.Exec(c.queries.updateStats, messages)
|
_, err := c.db.Exec(c.queries.updateStats, messages)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) Stats() (messages int64, err error) {
|
// Stats returns the total message count statistic
|
||||||
|
func (c *Cache) Stats() (messages int64, err error) {
|
||||||
rows, err := c.db.Query(c.queries.selectStats)
|
rows, err := c.db.Query(c.queries.selectStats)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
@@ -513,11 +496,12 @@ func (c *commonStore) Stats() (messages int64, err error) {
|
|||||||
return messages, nil
|
return messages, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) Close() error {
|
// Close closes the underlying database connection
|
||||||
|
func (c *Cache) Close() error {
|
||||||
return c.db.Close()
|
return c.db.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *commonStore) processMessageBatches() {
|
func (c *Cache) processMessageBatches() {
|
||||||
if c.queue == nil {
|
if c.queue == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -620,9 +604,3 @@ func readMessage(rows *sql.Rows) (*model.Message, error) {
|
|||||||
Encoding: encoding,
|
Encoding: encoding,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure commonStore implements Store
|
|
||||||
var _ Store = (*commonStore)(nil)
|
|
||||||
|
|
||||||
// Needed by store.go but not part of Store interface; unused import guard
|
|
||||||
var _ = fmt.Sprintf
|
|
||||||
110
message/cache_postgres.go
Normal file
110
message/cache_postgres.go
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
package message
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PostgreSQL runtime query constants
|
||||||
|
const (
|
||||||
|
postgresInsertMessageQuery = `
|
||||||
|
INSERT INTO message (mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user_id, content_type, encoding, published)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24)
|
||||||
|
`
|
||||||
|
postgresDeleteMessageQuery = `DELETE FROM message WHERE mid = $1`
|
||||||
|
postgresSelectScheduledMessageIDsBySeqIDQuery = `SELECT mid FROM message WHERE topic = $1 AND sequence_id = $2 AND published = FALSE`
|
||||||
|
postgresDeleteScheduledBySequenceIDQuery = `DELETE FROM message WHERE topic = $1 AND sequence_id = $2 AND published = FALSE`
|
||||||
|
postgresUpdateMessagesForTopicExpiryQuery = `UPDATE message SET expires = $1 WHERE topic = $2`
|
||||||
|
postgresSelectMessagesByIDQuery = `
|
||||||
|
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
||||||
|
FROM message
|
||||||
|
WHERE mid = $1
|
||||||
|
`
|
||||||
|
postgresSelectMessagesSinceTimeQuery = `
|
||||||
|
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
||||||
|
FROM message
|
||||||
|
WHERE topic = $1 AND time >= $2 AND published = TRUE
|
||||||
|
ORDER BY time, id
|
||||||
|
`
|
||||||
|
postgresSelectMessagesSinceTimeIncludeScheduledQuery = `
|
||||||
|
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
||||||
|
FROM message
|
||||||
|
WHERE topic = $1 AND time >= $2
|
||||||
|
ORDER BY time, id
|
||||||
|
`
|
||||||
|
postgresSelectMessagesSinceIDQuery = `
|
||||||
|
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
||||||
|
FROM message
|
||||||
|
WHERE topic = $1
|
||||||
|
AND id > COALESCE((SELECT id FROM message WHERE mid = $2), 0)
|
||||||
|
AND published = TRUE
|
||||||
|
ORDER BY time, id
|
||||||
|
`
|
||||||
|
postgresSelectMessagesSinceIDIncludeScheduledQuery = `
|
||||||
|
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
||||||
|
FROM message
|
||||||
|
WHERE topic = $1
|
||||||
|
AND (id > COALESCE((SELECT id FROM message WHERE mid = $2), 0) OR published = FALSE)
|
||||||
|
ORDER BY time, id
|
||||||
|
`
|
||||||
|
postgresSelectMessagesLatestQuery = `
|
||||||
|
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
||||||
|
FROM message
|
||||||
|
WHERE topic = $1 AND published = TRUE
|
||||||
|
ORDER BY time DESC, id DESC
|
||||||
|
LIMIT 1
|
||||||
|
`
|
||||||
|
postgresSelectMessagesDueQuery = `
|
||||||
|
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
||||||
|
FROM message
|
||||||
|
WHERE time <= $1 AND published = FALSE
|
||||||
|
ORDER BY time, id
|
||||||
|
`
|
||||||
|
postgresSelectMessagesExpiredQuery = `SELECT mid FROM message WHERE expires <= $1 AND published = TRUE`
|
||||||
|
postgresUpdateMessagePublishedQuery = `UPDATE message SET published = TRUE WHERE mid = $1`
|
||||||
|
postgresSelectMessagesCountQuery = `SELECT COUNT(*) FROM message`
|
||||||
|
postgresSelectTopicsQuery = `SELECT topic FROM message GROUP BY topic`
|
||||||
|
|
||||||
|
postgresUpdateAttachmentDeletedQuery = `UPDATE message SET attachment_deleted = TRUE WHERE mid = $1`
|
||||||
|
postgresSelectAttachmentsExpiredQuery = `SELECT mid FROM message WHERE attachment_expires > 0 AND attachment_expires <= $1 AND attachment_deleted = FALSE`
|
||||||
|
postgresSelectAttachmentsSizeBySenderQuery = `SELECT COALESCE(SUM(attachment_size), 0) FROM message WHERE user_id = '' AND sender = $1 AND attachment_expires >= $2`
|
||||||
|
postgresSelectAttachmentsSizeByUserIDQuery = `SELECT COALESCE(SUM(attachment_size), 0) FROM message WHERE user_id = $1 AND attachment_expires >= $2`
|
||||||
|
|
||||||
|
postgresSelectStatsQuery = `SELECT value FROM message_stats WHERE key = 'messages'`
|
||||||
|
postgresUpdateStatsQuery = `UPDATE message_stats SET value = $1 WHERE key = 'messages'`
|
||||||
|
postgresUpdateMessageTimeQuery = `UPDATE message SET time = $1 WHERE mid = $2`
|
||||||
|
)
|
||||||
|
|
||||||
|
var pgQueries = queries{
|
||||||
|
insertMessage: postgresInsertMessageQuery,
|
||||||
|
deleteMessage: postgresDeleteMessageQuery,
|
||||||
|
selectScheduledMessageIDsBySeqID: postgresSelectScheduledMessageIDsBySeqIDQuery,
|
||||||
|
deleteScheduledBySequenceID: postgresDeleteScheduledBySequenceIDQuery,
|
||||||
|
updateMessagesForTopicExpiry: postgresUpdateMessagesForTopicExpiryQuery,
|
||||||
|
selectMessagesByID: postgresSelectMessagesByIDQuery,
|
||||||
|
selectMessagesSinceTime: postgresSelectMessagesSinceTimeQuery,
|
||||||
|
selectMessagesSinceTimeScheduled: postgresSelectMessagesSinceTimeIncludeScheduledQuery,
|
||||||
|
selectMessagesSinceID: postgresSelectMessagesSinceIDQuery,
|
||||||
|
selectMessagesSinceIDScheduled: postgresSelectMessagesSinceIDIncludeScheduledQuery,
|
||||||
|
selectMessagesLatest: postgresSelectMessagesLatestQuery,
|
||||||
|
selectMessagesDue: postgresSelectMessagesDueQuery,
|
||||||
|
selectMessagesExpired: postgresSelectMessagesExpiredQuery,
|
||||||
|
updateMessagePublished: postgresUpdateMessagePublishedQuery,
|
||||||
|
selectMessagesCount: postgresSelectMessagesCountQuery,
|
||||||
|
selectTopics: postgresSelectTopicsQuery,
|
||||||
|
updateAttachmentDeleted: postgresUpdateAttachmentDeletedQuery,
|
||||||
|
selectAttachmentsExpired: postgresSelectAttachmentsExpiredQuery,
|
||||||
|
selectAttachmentsSizeBySender: postgresSelectAttachmentsSizeBySenderQuery,
|
||||||
|
selectAttachmentsSizeByUserID: postgresSelectAttachmentsSizeByUserIDQuery,
|
||||||
|
selectStats: postgresSelectStatsQuery,
|
||||||
|
updateStats: postgresUpdateStatsQuery,
|
||||||
|
updateMessageTime: postgresUpdateMessageTimeQuery,
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPostgresStore creates a new PostgreSQL-backed message cache store using an existing database connection pool.
|
||||||
|
func NewPostgresStore(db *sql.DB, batchSize int, batchTimeout time.Duration) (*Cache, error) {
|
||||||
|
if err := setupPostgres(db); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return newCache(db, pgQueries, nil, batchSize, batchTimeout, false), nil
|
||||||
|
}
|
||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
|
|
||||||
// Initial PostgreSQL schema
|
// Initial PostgreSQL schema
|
||||||
const (
|
const (
|
||||||
pgCreateTablesQuery = `
|
postgresCreateTablesQuery = `
|
||||||
CREATE TABLE IF NOT EXISTS message (
|
CREATE TABLE IF NOT EXISTS message (
|
||||||
id BIGSERIAL PRIMARY KEY,
|
id BIGSERIAL PRIMARY KEY,
|
||||||
mid TEXT NOT NULL,
|
mid TEXT NOT NULL,
|
||||||
@@ -37,12 +37,10 @@ const (
|
|||||||
);
|
);
|
||||||
CREATE INDEX IF NOT EXISTS idx_message_mid ON message (mid);
|
CREATE INDEX IF NOT EXISTS idx_message_mid ON message (mid);
|
||||||
CREATE INDEX IF NOT EXISTS idx_message_sequence_id ON message (sequence_id);
|
CREATE INDEX IF NOT EXISTS idx_message_sequence_id ON message (sequence_id);
|
||||||
CREATE INDEX IF NOT EXISTS idx_message_time ON message (time);
|
CREATE INDEX IF NOT EXISTS idx_message_topic_published_time ON message (topic, published, time, id);
|
||||||
CREATE INDEX IF NOT EXISTS idx_message_topic ON message (topic);
|
CREATE INDEX IF NOT EXISTS idx_message_published_expires ON message (published, expires);
|
||||||
CREATE INDEX IF NOT EXISTS idx_message_expires ON message (expires);
|
CREATE INDEX IF NOT EXISTS idx_message_sender_attachment_expires ON message (sender, attachment_expires) WHERE user_id = '';
|
||||||
CREATE INDEX IF NOT EXISTS idx_message_sender ON message (sender);
|
CREATE INDEX IF NOT EXISTS idx_message_user_id_attachment_expires ON message (user_id, attachment_expires);
|
||||||
CREATE INDEX IF NOT EXISTS idx_message_user_id ON message (user_id);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_message_attachment_expires ON message (attachment_expires);
|
|
||||||
CREATE TABLE IF NOT EXISTS message_stats (
|
CREATE TABLE IF NOT EXISTS message_stats (
|
||||||
key TEXT PRIMARY KEY,
|
key TEXT PRIMARY KEY,
|
||||||
value BIGINT
|
value BIGINT
|
||||||
@@ -58,13 +56,13 @@ const (
|
|||||||
// PostgreSQL schema management queries
|
// PostgreSQL schema management queries
|
||||||
const (
|
const (
|
||||||
pgCurrentSchemaVersion = 14
|
pgCurrentSchemaVersion = 14
|
||||||
pgInsertSchemaVersion = `INSERT INTO schema_version (store, version) VALUES ('message', $1)`
|
postgresInsertSchemaVersionQuery = `INSERT INTO schema_version (store, version) VALUES ('message', $1)`
|
||||||
pgSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE store = 'message'`
|
postgresSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE store = 'message'`
|
||||||
)
|
)
|
||||||
|
|
||||||
func setupPostgresDB(db *sql.DB) error {
|
func setupPostgres(db *sql.DB) error {
|
||||||
var schemaVersion int
|
var schemaVersion int
|
||||||
err := db.QueryRow(pgSelectSchemaVersionQuery).Scan(&schemaVersion)
|
err := db.QueryRow(postgresSelectSchemaVersionQuery).Scan(&schemaVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return setupNewPostgresDB(db)
|
return setupNewPostgresDB(db)
|
||||||
}
|
}
|
||||||
@@ -80,10 +78,10 @@ func setupNewPostgresDB(db *sql.DB) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
if _, err := tx.Exec(pgCreateTablesQuery); err != nil {
|
if _, err := tx.Exec(postgresCreateTablesQuery); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := tx.Exec(pgInsertSchemaVersion, pgCurrentSchemaVersion); err != nil {
|
if _, err := tx.Exec(postgresInsertSchemaVersionQuery, pgCurrentSchemaVersion); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||||
@@ -20,7 +21,6 @@ const (
|
|||||||
sqliteSelectScheduledMessageIDsBySeqIDQuery = `SELECT mid FROM messages WHERE topic = ? AND sequence_id = ? AND published = 0`
|
sqliteSelectScheduledMessageIDsBySeqIDQuery = `SELECT mid FROM messages WHERE topic = ? AND sequence_id = ? AND published = 0`
|
||||||
sqliteDeleteScheduledBySequenceIDQuery = `DELETE FROM messages WHERE topic = ? AND sequence_id = ? AND published = 0`
|
sqliteDeleteScheduledBySequenceIDQuery = `DELETE FROM messages WHERE topic = ? AND sequence_id = ? AND published = 0`
|
||||||
sqliteUpdateMessagesForTopicExpiryQuery = `UPDATE messages SET expires = ? WHERE topic = ?`
|
sqliteUpdateMessagesForTopicExpiryQuery = `UPDATE messages SET expires = ? WHERE topic = ?`
|
||||||
sqliteSelectRowIDFromMessageID = `SELECT id FROM messages WHERE mid = ?`
|
|
||||||
sqliteSelectMessagesByIDQuery = `
|
sqliteSelectMessagesByIDQuery = `
|
||||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||||
FROM messages
|
FROM messages
|
||||||
@@ -41,13 +41,13 @@ const (
|
|||||||
sqliteSelectMessagesSinceIDQuery = `
|
sqliteSelectMessagesSinceIDQuery = `
|
||||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||||
FROM messages
|
FROM messages
|
||||||
WHERE topic = ? AND id > ? AND published = 1
|
WHERE topic = ? AND id > COALESCE((SELECT id FROM messages WHERE mid = ?), 0) AND published = 1
|
||||||
ORDER BY time, id
|
ORDER BY time, id
|
||||||
`
|
`
|
||||||
sqliteSelectMessagesSinceIDIncludeScheduledQuery = `
|
sqliteSelectMessagesSinceIDIncludeScheduledQuery = `
|
||||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||||
FROM messages
|
FROM messages
|
||||||
WHERE topic = ? AND (id > ? OR published = 0)
|
WHERE topic = ? AND (id > COALESCE((SELECT id FROM messages WHERE mid = ?), 0) OR published = 0)
|
||||||
ORDER BY time, id
|
ORDER BY time, id
|
||||||
`
|
`
|
||||||
sqliteSelectMessagesLatestQuery = `
|
sqliteSelectMessagesLatestQuery = `
|
||||||
@@ -66,10 +66,9 @@ const (
|
|||||||
sqliteSelectMessagesExpiredQuery = `SELECT mid FROM messages WHERE expires <= ? AND published = 1`
|
sqliteSelectMessagesExpiredQuery = `SELECT mid FROM messages WHERE expires <= ? AND published = 1`
|
||||||
sqliteUpdateMessagePublishedQuery = `UPDATE messages SET published = 1 WHERE mid = ?`
|
sqliteUpdateMessagePublishedQuery = `UPDATE messages SET published = 1 WHERE mid = ?`
|
||||||
sqliteSelectMessagesCountQuery = `SELECT COUNT(*) FROM messages`
|
sqliteSelectMessagesCountQuery = `SELECT COUNT(*) FROM messages`
|
||||||
sqliteSelectMessageCountPerTopicQuery = `SELECT topic, COUNT(*) FROM messages GROUP BY topic`
|
|
||||||
sqliteSelectTopicsQuery = `SELECT topic FROM messages GROUP BY topic`
|
sqliteSelectTopicsQuery = `SELECT topic FROM messages GROUP BY topic`
|
||||||
|
|
||||||
sqliteUpdateAttachmentDeleted = `UPDATE messages SET attachment_deleted = 1 WHERE mid = ?`
|
sqliteUpdateAttachmentDeletedQuery = `UPDATE messages SET attachment_deleted = 1 WHERE mid = ?`
|
||||||
sqliteSelectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires <= ? AND attachment_deleted = 0`
|
sqliteSelectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires <= ? AND attachment_deleted = 0`
|
||||||
sqliteSelectAttachmentsSizeBySenderQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = '' AND sender = ? AND attachment_expires >= ?`
|
sqliteSelectAttachmentsSizeBySenderQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = '' AND sender = ? AND attachment_expires >= ?`
|
||||||
sqliteSelectAttachmentsSizeByUserIDQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?`
|
sqliteSelectAttachmentsSizeByUserIDQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?`
|
||||||
@@ -79,13 +78,12 @@ const (
|
|||||||
sqliteUpdateMessageTimeQuery = `UPDATE messages SET time = ? WHERE mid = ?`
|
sqliteUpdateMessageTimeQuery = `UPDATE messages SET time = ? WHERE mid = ?`
|
||||||
)
|
)
|
||||||
|
|
||||||
var sqliteQueries = storeQueries{
|
var sqliteQueries = queries{
|
||||||
insertMessage: sqliteInsertMessageQuery,
|
insertMessage: sqliteInsertMessageQuery,
|
||||||
deleteMessage: sqliteDeleteMessageQuery,
|
deleteMessage: sqliteDeleteMessageQuery,
|
||||||
selectScheduledMessageIDsBySeqID: sqliteSelectScheduledMessageIDsBySeqIDQuery,
|
selectScheduledMessageIDsBySeqID: sqliteSelectScheduledMessageIDsBySeqIDQuery,
|
||||||
deleteScheduledBySequenceID: sqliteDeleteScheduledBySequenceIDQuery,
|
deleteScheduledBySequenceID: sqliteDeleteScheduledBySequenceIDQuery,
|
||||||
updateMessagesForTopicExpiry: sqliteUpdateMessagesForTopicExpiryQuery,
|
updateMessagesForTopicExpiry: sqliteUpdateMessagesForTopicExpiryQuery,
|
||||||
selectRowIDFromMessageID: sqliteSelectRowIDFromMessageID,
|
|
||||||
selectMessagesByID: sqliteSelectMessagesByIDQuery,
|
selectMessagesByID: sqliteSelectMessagesByIDQuery,
|
||||||
selectMessagesSinceTime: sqliteSelectMessagesSinceTimeQuery,
|
selectMessagesSinceTime: sqliteSelectMessagesSinceTimeQuery,
|
||||||
selectMessagesSinceTimeScheduled: sqliteSelectMessagesSinceTimeIncludeScheduledQuery,
|
selectMessagesSinceTimeScheduled: sqliteSelectMessagesSinceTimeIncludeScheduledQuery,
|
||||||
@@ -96,9 +94,8 @@ var sqliteQueries = storeQueries{
|
|||||||
selectMessagesExpired: sqliteSelectMessagesExpiredQuery,
|
selectMessagesExpired: sqliteSelectMessagesExpiredQuery,
|
||||||
updateMessagePublished: sqliteUpdateMessagePublishedQuery,
|
updateMessagePublished: sqliteUpdateMessagePublishedQuery,
|
||||||
selectMessagesCount: sqliteSelectMessagesCountQuery,
|
selectMessagesCount: sqliteSelectMessagesCountQuery,
|
||||||
selectMessageCountPerTopic: sqliteSelectMessageCountPerTopicQuery,
|
|
||||||
selectTopics: sqliteSelectTopicsQuery,
|
selectTopics: sqliteSelectTopicsQuery,
|
||||||
updateAttachmentDeleted: sqliteUpdateAttachmentDeleted,
|
updateAttachmentDeleted: sqliteUpdateAttachmentDeletedQuery,
|
||||||
selectAttachmentsExpired: sqliteSelectAttachmentsExpiredQuery,
|
selectAttachmentsExpired: sqliteSelectAttachmentsExpiredQuery,
|
||||||
selectAttachmentsSizeBySender: sqliteSelectAttachmentsSizeBySenderQuery,
|
selectAttachmentsSizeBySender: sqliteSelectAttachmentsSizeBySenderQuery,
|
||||||
selectAttachmentsSizeByUserID: sqliteSelectAttachmentsSizeByUserIDQuery,
|
selectAttachmentsSizeByUserID: sqliteSelectAttachmentsSizeByUserIDQuery,
|
||||||
@@ -108,7 +105,7 @@ var sqliteQueries = storeQueries{
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewSQLiteStore creates a SQLite file-backed cache
|
// NewSQLiteStore creates a SQLite file-backed cache
|
||||||
func NewSQLiteStore(filename, startupQueries string, cacheDuration time.Duration, batchSize int, batchTimeout time.Duration, nop bool) (Store, error) {
|
func NewSQLiteStore(filename, startupQueries string, cacheDuration time.Duration, batchSize int, batchTimeout time.Duration, nop bool) (*Cache, error) {
|
||||||
parentDir := filepath.Dir(filename)
|
parentDir := filepath.Dir(filename)
|
||||||
if !util.FileExists(parentDir) {
|
if !util.FileExists(parentDir) {
|
||||||
return nil, fmt.Errorf("cache database directory %s does not exist or is not accessible", parentDir)
|
return nil, fmt.Errorf("cache database directory %s does not exist or is not accessible", parentDir)
|
||||||
@@ -120,21 +117,26 @@ func NewSQLiteStore(filename, startupQueries string, cacheDuration time.Duration
|
|||||||
if err := setupSQLite(db, startupQueries, cacheDuration); err != nil {
|
if err := setupSQLite(db, startupQueries, cacheDuration); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return newCommonStore(db, sqliteQueries, batchSize, batchTimeout, nop), nil
|
return newCache(db, sqliteQueries, &sync.Mutex{}, batchSize, batchTimeout, nop), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewMemStore creates an in-memory cache
|
// NewMemStore creates an in-memory cache
|
||||||
func NewMemStore() (Store, error) {
|
func NewMemStore() (*Cache, error) {
|
||||||
return NewSQLiteStore(createMemoryFilename(), "", 0, 0, 0, false)
|
return NewSQLiteStore(createMemoryFilename(), "", 0, 0, 0, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewNopStore creates an in-memory cache that discards all messages;
|
// NewNopStore creates an in-memory cache that discards all messages;
|
||||||
// it is always empty and can be used if caching is entirely disabled
|
// it is always empty and can be used if caching is entirely disabled
|
||||||
func NewNopStore() (Store, error) {
|
func NewNopStore() (*Cache, error) {
|
||||||
return NewSQLiteStore(createMemoryFilename(), "", 0, 0, 0, true)
|
return NewSQLiteStore(createMemoryFilename(), "", 0, 0, 0, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
// createMemoryFilename creates a unique memory filename to use for the SQLite backend.
|
// createMemoryFilename creates a unique memory filename to use for the SQLite backend.
|
||||||
|
// From mattn/go-sqlite3: "Each connection to ":memory:" opens a brand new in-memory
|
||||||
|
// sql database, so if the stdlib's sql engine happens to open another connection and
|
||||||
|
// you've only specified ":memory:", that connection will see a brand new database.
|
||||||
|
// A workaround is to use "file::memory:?cache=shared" (or "file:foobar?mode=memory&cache=shared").
|
||||||
|
// Every connection to this string will point to the same in-memory database."
|
||||||
func createMemoryFilename() string {
|
func createMemoryFilename() string {
|
||||||
return fmt.Sprintf("file:%s?mode=memory&cache=shared", util.RandomString(10))
|
return fmt.Sprintf("file:%s?mode=memory&cache=shared", util.RandomString(10))
|
||||||
}
|
}
|
||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
|
|
||||||
// Initial SQLite schema
|
// Initial SQLite schema
|
||||||
const (
|
const (
|
||||||
sqliteCreateMessagesTableQuery = `
|
sqliteCreateTablesQuery = `
|
||||||
BEGIN;
|
BEGIN;
|
||||||
CREATE TABLE IF NOT EXISTS messages (
|
CREATE TABLE IF NOT EXISTS messages (
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
@@ -65,8 +65,8 @@ const (
|
|||||||
version INT NOT NULL
|
version INT NOT NULL
|
||||||
);
|
);
|
||||||
`
|
`
|
||||||
sqliteInsertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
|
sqliteInsertSchemaVersionQuery = `INSERT INTO schemaVersion VALUES (1, ?)`
|
||||||
sqliteUpdateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
|
sqliteUpdateSchemaVersionQuery = `UPDATE schemaVersion SET version = ? WHERE id = 1`
|
||||||
sqliteSelectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
|
sqliteSelectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -259,13 +259,13 @@ func setupSQLite(db *sql.DB, startupQueries string, cacheDuration time.Duration)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func setupNewSQLite(db *sql.DB) error {
|
func setupNewSQLite(db *sql.DB) error {
|
||||||
if _, err := db.Exec(sqliteCreateMessagesTableQuery); err != nil {
|
if _, err := db.Exec(sqliteCreateTablesQuery); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := db.Exec(sqliteCreateSchemaVersionTableQuery); err != nil {
|
if _, err := db.Exec(sqliteCreateSchemaVersionTableQuery); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := db.Exec(sqliteInsertSchemaVersion, sqliteCurrentSchemaVersion); err != nil {
|
if _, err := db.Exec(sqliteInsertSchemaVersionQuery, sqliteCurrentSchemaVersion); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -288,7 +288,7 @@ func sqliteMigrateFrom0(db *sql.DB, _ time.Duration) error {
|
|||||||
if _, err := db.Exec(sqliteCreateSchemaVersionTableQuery); err != nil {
|
if _, err := db.Exec(sqliteCreateSchemaVersionTableQuery); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := db.Exec(sqliteInsertSchemaVersion, 1); err != nil {
|
if _, err := db.Exec(sqliteInsertSchemaVersionQuery, 1); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -299,7 +299,7 @@ func sqliteMigrateFrom1(db *sql.DB, _ time.Duration) error {
|
|||||||
if _, err := db.Exec(sqliteMigrate1To2AlterMessagesTableQuery); err != nil {
|
if _, err := db.Exec(sqliteMigrate1To2AlterMessagesTableQuery); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := db.Exec(sqliteUpdateSchemaVersion, 2); err != nil {
|
if _, err := db.Exec(sqliteUpdateSchemaVersionQuery, 2); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -310,7 +310,7 @@ func sqliteMigrateFrom2(db *sql.DB, _ time.Duration) error {
|
|||||||
if _, err := db.Exec(sqliteMigrate2To3AlterMessagesTableQuery); err != nil {
|
if _, err := db.Exec(sqliteMigrate2To3AlterMessagesTableQuery); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := db.Exec(sqliteUpdateSchemaVersion, 3); err != nil {
|
if _, err := db.Exec(sqliteUpdateSchemaVersionQuery, 3); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -321,7 +321,7 @@ func sqliteMigrateFrom3(db *sql.DB, _ time.Duration) error {
|
|||||||
if _, err := db.Exec(sqliteMigrate3To4AlterMessagesTableQuery); err != nil {
|
if _, err := db.Exec(sqliteMigrate3To4AlterMessagesTableQuery); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := db.Exec(sqliteUpdateSchemaVersion, 4); err != nil {
|
if _, err := db.Exec(sqliteUpdateSchemaVersionQuery, 4); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -332,7 +332,7 @@ func sqliteMigrateFrom4(db *sql.DB, _ time.Duration) error {
|
|||||||
if _, err := db.Exec(sqliteMigrate4To5AlterMessagesTableQuery); err != nil {
|
if _, err := db.Exec(sqliteMigrate4To5AlterMessagesTableQuery); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := db.Exec(sqliteUpdateSchemaVersion, 5); err != nil {
|
if _, err := db.Exec(sqliteUpdateSchemaVersionQuery, 5); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -343,7 +343,7 @@ func sqliteMigrateFrom5(db *sql.DB, _ time.Duration) error {
|
|||||||
if _, err := db.Exec(sqliteMigrate5To6AlterMessagesTableQuery); err != nil {
|
if _, err := db.Exec(sqliteMigrate5To6AlterMessagesTableQuery); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := db.Exec(sqliteUpdateSchemaVersion, 6); err != nil {
|
if _, err := db.Exec(sqliteUpdateSchemaVersionQuery, 6); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -354,7 +354,7 @@ func sqliteMigrateFrom6(db *sql.DB, _ time.Duration) error {
|
|||||||
if _, err := db.Exec(sqliteMigrate6To7AlterMessagesTableQuery); err != nil {
|
if _, err := db.Exec(sqliteMigrate6To7AlterMessagesTableQuery); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := db.Exec(sqliteUpdateSchemaVersion, 7); err != nil {
|
if _, err := db.Exec(sqliteUpdateSchemaVersionQuery, 7); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -365,7 +365,7 @@ func sqliteMigrateFrom7(db *sql.DB, _ time.Duration) error {
|
|||||||
if _, err := db.Exec(sqliteMigrate7To8AlterMessagesTableQuery); err != nil {
|
if _, err := db.Exec(sqliteMigrate7To8AlterMessagesTableQuery); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := db.Exec(sqliteUpdateSchemaVersion, 8); err != nil {
|
if _, err := db.Exec(sqliteUpdateSchemaVersionQuery, 8); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -376,7 +376,7 @@ func sqliteMigrateFrom8(db *sql.DB, _ time.Duration) error {
|
|||||||
if _, err := db.Exec(sqliteMigrate8To9AlterMessagesTableQuery); err != nil {
|
if _, err := db.Exec(sqliteMigrate8To9AlterMessagesTableQuery); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := db.Exec(sqliteUpdateSchemaVersion, 9); err != nil {
|
if _, err := db.Exec(sqliteUpdateSchemaVersionQuery, 9); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -395,7 +395,7 @@ func sqliteMigrateFrom9(db *sql.DB, cacheDuration time.Duration) error {
|
|||||||
if _, err := tx.Exec(sqliteMigrate9To10UpdateMessageExpiryQuery, int64(cacheDuration.Seconds())); err != nil {
|
if _, err := tx.Exec(sqliteMigrate9To10UpdateMessageExpiryQuery, int64(cacheDuration.Seconds())); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 10); err != nil {
|
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 10); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
@@ -411,7 +411,7 @@ func sqliteMigrateFrom10(db *sql.DB, _ time.Duration) error {
|
|||||||
if _, err := tx.Exec(sqliteMigrate10To11AlterMessagesTableQuery); err != nil {
|
if _, err := tx.Exec(sqliteMigrate10To11AlterMessagesTableQuery); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 11); err != nil {
|
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 11); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
@@ -427,7 +427,7 @@ func sqliteMigrateFrom11(db *sql.DB, _ time.Duration) error {
|
|||||||
if _, err := tx.Exec(sqliteMigrate11To12AlterMessagesTableQuery); err != nil {
|
if _, err := tx.Exec(sqliteMigrate11To12AlterMessagesTableQuery); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 12); err != nil {
|
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 12); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
@@ -443,7 +443,7 @@ func sqliteMigrateFrom12(db *sql.DB, _ time.Duration) error {
|
|||||||
if _, err := tx.Exec(sqliteMigrate12To13AlterMessagesTableQuery); err != nil {
|
if _, err := tx.Exec(sqliteMigrate12To13AlterMessagesTableQuery); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 13); err != nil {
|
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 13); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
@@ -459,7 +459,7 @@ func sqliteMigrateFrom13(db *sql.DB, _ time.Duration) error {
|
|||||||
if _, err := tx.Exec(sqliteMigrate13To14AlterMessagesTableQuery); err != nil {
|
if _, err := tx.Exec(sqliteMigrate13To14AlterMessagesTableQuery); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 14); err != nil {
|
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 14); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
@@ -13,158 +13,6 @@ import (
|
|||||||
"heckel.io/ntfy/v2/model"
|
"heckel.io/ntfy/v2/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSqliteStore_Messages(t *testing.T) {
|
|
||||||
testCacheMessages(t, newSqliteTestStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMemStore_Messages(t *testing.T) {
|
|
||||||
testCacheMessages(t, newMemTestStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqliteStore_MessagesLock(t *testing.T) {
|
|
||||||
testCacheMessagesLock(t, newSqliteTestStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMemStore_MessagesLock(t *testing.T) {
|
|
||||||
testCacheMessagesLock(t, newMemTestStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqliteStore_MessagesScheduled(t *testing.T) {
|
|
||||||
testCacheMessagesScheduled(t, newSqliteTestStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMemStore_MessagesScheduled(t *testing.T) {
|
|
||||||
testCacheMessagesScheduled(t, newMemTestStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqliteStore_Topics(t *testing.T) {
|
|
||||||
testCacheTopics(t, newSqliteTestStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMemStore_Topics(t *testing.T) {
|
|
||||||
testCacheTopics(t, newMemTestStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqliteStore_MessagesTagsPrioAndTitle(t *testing.T) {
|
|
||||||
testCacheMessagesTagsPrioAndTitle(t, newSqliteTestStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMemStore_MessagesTagsPrioAndTitle(t *testing.T) {
|
|
||||||
testCacheMessagesTagsPrioAndTitle(t, newMemTestStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqliteStore_MessagesSinceID(t *testing.T) {
|
|
||||||
testCacheMessagesSinceID(t, newSqliteTestStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMemStore_MessagesSinceID(t *testing.T) {
|
|
||||||
testCacheMessagesSinceID(t, newMemTestStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqliteStore_Prune(t *testing.T) {
|
|
||||||
testCachePrune(t, newSqliteTestStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMemStore_Prune(t *testing.T) {
|
|
||||||
testCachePrune(t, newMemTestStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqliteStore_Attachments(t *testing.T) {
|
|
||||||
testCacheAttachments(t, newSqliteTestStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMemStore_Attachments(t *testing.T) {
|
|
||||||
testCacheAttachments(t, newMemTestStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqliteStore_AttachmentsExpired(t *testing.T) {
|
|
||||||
testCacheAttachmentsExpired(t, newSqliteTestStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMemStore_AttachmentsExpired(t *testing.T) {
|
|
||||||
testCacheAttachmentsExpired(t, newMemTestStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqliteStore_Sender(t *testing.T) {
|
|
||||||
testSender(t, newSqliteTestStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMemStore_Sender(t *testing.T) {
|
|
||||||
testSender(t, newMemTestStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqliteStore_DeleteScheduledBySequenceID(t *testing.T) {
|
|
||||||
testDeleteScheduledBySequenceID(t, newSqliteTestStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMemStore_DeleteScheduledBySequenceID(t *testing.T) {
|
|
||||||
testDeleteScheduledBySequenceID(t, newMemTestStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqliteStore_MessageByID(t *testing.T) {
|
|
||||||
testMessageByID(t, newSqliteTestStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMemStore_MessageByID(t *testing.T) {
|
|
||||||
testMessageByID(t, newMemTestStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqliteStore_MarkPublished(t *testing.T) {
|
|
||||||
testMarkPublished(t, newSqliteTestStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMemStore_MarkPublished(t *testing.T) {
|
|
||||||
testMarkPublished(t, newMemTestStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqliteStore_ExpireMessages(t *testing.T) {
|
|
||||||
testExpireMessages(t, newSqliteTestStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMemStore_ExpireMessages(t *testing.T) {
|
|
||||||
testExpireMessages(t, newMemTestStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqliteStore_MarkAttachmentsDeleted(t *testing.T) {
|
|
||||||
testMarkAttachmentsDeleted(t, newSqliteTestStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMemStore_MarkAttachmentsDeleted(t *testing.T) {
|
|
||||||
testMarkAttachmentsDeleted(t, newMemTestStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqliteStore_Stats(t *testing.T) {
|
|
||||||
testStats(t, newSqliteTestStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMemStore_Stats(t *testing.T) {
|
|
||||||
testStats(t, newMemTestStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqliteStore_AddMessages(t *testing.T) {
|
|
||||||
testAddMessages(t, newSqliteTestStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMemStore_AddMessages(t *testing.T) {
|
|
||||||
testAddMessages(t, newMemTestStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqliteStore_MessagesDue(t *testing.T) {
|
|
||||||
testMessagesDue(t, newSqliteTestStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMemStore_MessagesDue(t *testing.T) {
|
|
||||||
testMessagesDue(t, newMemTestStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqliteStore_MessageFieldRoundTrip(t *testing.T) {
|
|
||||||
testMessageFieldRoundTrip(t, newSqliteTestStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMemStore_MessageFieldRoundTrip(t *testing.T) {
|
|
||||||
testMessageFieldRoundTrip(t, newMemTestStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqliteStore_Migration_From0(t *testing.T) {
|
func TestSqliteStore_Migration_From0(t *testing.T) {
|
||||||
filename := newSqliteTestStoreFile(t)
|
filename := newSqliteTestStoreFile(t)
|
||||||
db, err := sql.Open("sqlite3", filename)
|
db, err := sql.Open("sqlite3", filename)
|
||||||
@@ -419,32 +267,17 @@ func TestNopStore(t *testing.T) {
|
|||||||
require.Empty(t, topics)
|
require.Empty(t, topics)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSqliteTestStore(t *testing.T) message.Store {
|
|
||||||
filename := filepath.Join(t.TempDir(), "cache.db")
|
|
||||||
s, err := message.NewSQLiteStore(filename, "", time.Hour, 0, 0, false)
|
|
||||||
require.Nil(t, err)
|
|
||||||
t.Cleanup(func() { s.Close() })
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
func newSqliteTestStoreFile(t *testing.T) string {
|
func newSqliteTestStoreFile(t *testing.T) string {
|
||||||
return filepath.Join(t.TempDir(), "cache.db")
|
return filepath.Join(t.TempDir(), "cache.db")
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSqliteTestStoreFromFile(t *testing.T, filename, startupQueries string) message.Store {
|
func newSqliteTestStoreFromFile(t *testing.T, filename, startupQueries string) *message.Cache {
|
||||||
s, err := message.NewSQLiteStore(filename, startupQueries, time.Hour, 0, 0, false)
|
s, err := message.NewSQLiteStore(filename, startupQueries, time.Hour, 0, 0, false)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
t.Cleanup(func() { s.Close() })
|
t.Cleanup(func() { s.Close() })
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func newMemTestStore(t *testing.T) message.Store {
|
|
||||||
s, err := message.NewMemStore()
|
|
||||||
require.Nil(t, err)
|
|
||||||
t.Cleanup(func() { s.Close() })
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkSqliteSchemaVersion(t *testing.T, filename string) {
|
func checkSqliteSchemaVersion(t *testing.T, filename string) {
|
||||||
db, err := sql.Open("sqlite3", filename)
|
db, err := sql.Open("sqlite3", filename)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
829
message/cache_test.go
Normal file
829
message/cache_test.go
Normal file
@@ -0,0 +1,829 @@
|
|||||||
|
package message_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"path/filepath"
|
||||||
|
"sort"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
dbtest "heckel.io/ntfy/v2/db/test"
|
||||||
|
"heckel.io/ntfy/v2/message"
|
||||||
|
"heckel.io/ntfy/v2/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newSqliteTestStore(t *testing.T) *message.Cache {
|
||||||
|
filename := filepath.Join(t.TempDir(), "cache.db")
|
||||||
|
s, err := message.NewSQLiteStore(filename, "", time.Hour, 0, 0, false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
t.Cleanup(func() { s.Close() })
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMemTestStore(t *testing.T) *message.Cache {
|
||||||
|
s, err := message.NewMemStore()
|
||||||
|
require.Nil(t, err)
|
||||||
|
t.Cleanup(func() { s.Close() })
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestPostgresStore(t *testing.T) *message.Cache {
|
||||||
|
testDB := dbtest.CreateTestPostgres(t)
|
||||||
|
store, err := message.NewPostgresStore(testDB, 0, 0)
|
||||||
|
require.Nil(t, err)
|
||||||
|
return store
|
||||||
|
}
|
||||||
|
|
||||||
|
func forEachBackend(t *testing.T, f func(t *testing.T, s *message.Cache)) {
|
||||||
|
t.Run("sqlite", func(t *testing.T) {
|
||||||
|
f(t, newSqliteTestStore(t))
|
||||||
|
})
|
||||||
|
t.Run("mem", func(t *testing.T) {
|
||||||
|
f(t, newMemTestStore(t))
|
||||||
|
})
|
||||||
|
t.Run("postgres", func(t *testing.T) {
|
||||||
|
f(t, newTestPostgresStore(t))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStore_Messages(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||||
|
m1 := model.NewDefaultMessage("mytopic", "my message")
|
||||||
|
m1.Time = 1
|
||||||
|
|
||||||
|
m2 := model.NewDefaultMessage("mytopic", "my other message")
|
||||||
|
m2.Time = 2
|
||||||
|
|
||||||
|
require.Nil(t, s.AddMessage(m1))
|
||||||
|
require.Nil(t, s.AddMessage(model.NewDefaultMessage("example", "my example message")))
|
||||||
|
require.Nil(t, s.AddMessage(m2))
|
||||||
|
|
||||||
|
// Adding invalid
|
||||||
|
require.Equal(t, model.ErrUnexpectedMessageType, s.AddMessage(model.NewKeepaliveMessage("mytopic"))) // These should not be added!
|
||||||
|
require.Equal(t, model.ErrUnexpectedMessageType, s.AddMessage(model.NewOpenMessage("example"))) // These should not be added!
|
||||||
|
|
||||||
|
// count
|
||||||
|
count, err := s.MessagesCount()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 3, count)
|
||||||
|
|
||||||
|
// mytopic: since all
|
||||||
|
messages, _ := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||||
|
require.Equal(t, 2, len(messages))
|
||||||
|
require.Equal(t, "my message", messages[0].Message)
|
||||||
|
require.Equal(t, "mytopic", messages[0].Topic)
|
||||||
|
require.Equal(t, model.MessageEvent, messages[0].Event)
|
||||||
|
require.Equal(t, "", messages[0].Title)
|
||||||
|
require.Equal(t, 0, messages[0].Priority)
|
||||||
|
require.Nil(t, messages[0].Tags)
|
||||||
|
require.Equal(t, "my other message", messages[1].Message)
|
||||||
|
|
||||||
|
// mytopic: since none
|
||||||
|
messages, _ = s.Messages("mytopic", model.SinceNoMessages, false)
|
||||||
|
require.Empty(t, messages)
|
||||||
|
|
||||||
|
// mytopic: since m1 (by ID)
|
||||||
|
messages, _ = s.Messages("mytopic", model.NewSinceID(m1.ID), false)
|
||||||
|
require.Equal(t, 1, len(messages))
|
||||||
|
require.Equal(t, m2.ID, messages[0].ID)
|
||||||
|
require.Equal(t, "my other message", messages[0].Message)
|
||||||
|
require.Equal(t, "mytopic", messages[0].Topic)
|
||||||
|
|
||||||
|
// mytopic: since 2
|
||||||
|
messages, _ = s.Messages("mytopic", model.NewSinceTime(2), false)
|
||||||
|
require.Equal(t, 1, len(messages))
|
||||||
|
require.Equal(t, "my other message", messages[0].Message)
|
||||||
|
|
||||||
|
// mytopic: latest
|
||||||
|
messages, _ = s.Messages("mytopic", model.SinceLatestMessage, false)
|
||||||
|
require.Equal(t, 1, len(messages))
|
||||||
|
require.Equal(t, "my other message", messages[0].Message)
|
||||||
|
|
||||||
|
// example: since all
|
||||||
|
messages, _ = s.Messages("example", model.SinceAllMessages, false)
|
||||||
|
require.Equal(t, "my example message", messages[0].Message)
|
||||||
|
|
||||||
|
// non-existing: since all
|
||||||
|
messages, _ = s.Messages("doesnotexist", model.SinceAllMessages, false)
|
||||||
|
require.Empty(t, messages)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStore_MessagesLock(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < 5000; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
assert.Nil(t, s.AddMessage(model.NewDefaultMessage("mytopic", "test message")))
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStore_MessagesScheduled(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||||
|
m1 := model.NewDefaultMessage("mytopic", "message 1")
|
||||||
|
m2 := model.NewDefaultMessage("mytopic", "message 2")
|
||||||
|
m2.Time = time.Now().Add(time.Hour).Unix()
|
||||||
|
m3 := model.NewDefaultMessage("mytopic", "message 3")
|
||||||
|
m3.Time = time.Now().Add(time.Minute).Unix() // earlier than m2!
|
||||||
|
m4 := model.NewDefaultMessage("mytopic2", "message 4")
|
||||||
|
m4.Time = time.Now().Add(time.Minute).Unix()
|
||||||
|
require.Nil(t, s.AddMessage(m1))
|
||||||
|
require.Nil(t, s.AddMessage(m2))
|
||||||
|
require.Nil(t, s.AddMessage(m3))
|
||||||
|
|
||||||
|
messages, _ := s.Messages("mytopic", model.SinceAllMessages, false) // exclude scheduled
|
||||||
|
require.Equal(t, 1, len(messages))
|
||||||
|
require.Equal(t, "message 1", messages[0].Message)
|
||||||
|
|
||||||
|
messages, _ = s.Messages("mytopic", model.SinceAllMessages, true) // include scheduled
|
||||||
|
require.Equal(t, 3, len(messages))
|
||||||
|
require.Equal(t, "message 1", messages[0].Message)
|
||||||
|
require.Equal(t, "message 3", messages[1].Message) // Order!
|
||||||
|
require.Equal(t, "message 2", messages[2].Message)
|
||||||
|
|
||||||
|
messages, _ = s.MessagesDue()
|
||||||
|
require.Empty(t, messages)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStore_Topics(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||||
|
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic1", "my example message")))
|
||||||
|
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 1")))
|
||||||
|
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 2")))
|
||||||
|
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 3")))
|
||||||
|
|
||||||
|
topics, err := s.Topics()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
require.Equal(t, 2, len(topics))
|
||||||
|
require.Contains(t, topics, "topic1")
|
||||||
|
require.Contains(t, topics, "topic2")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStore_MessagesTagsPrioAndTitle(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||||
|
m := model.NewDefaultMessage("mytopic", "some message")
|
||||||
|
m.Tags = []string{"tag1", "tag2"}
|
||||||
|
m.Priority = 5
|
||||||
|
m.Title = "some title"
|
||||||
|
require.Nil(t, s.AddMessage(m))
|
||||||
|
|
||||||
|
messages, _ := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||||
|
require.Equal(t, []string{"tag1", "tag2"}, messages[0].Tags)
|
||||||
|
require.Equal(t, 5, messages[0].Priority)
|
||||||
|
require.Equal(t, "some title", messages[0].Title)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStore_MessagesSinceID(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||||
|
m1 := model.NewDefaultMessage("mytopic", "message 1")
|
||||||
|
m1.Time = 100
|
||||||
|
m2 := model.NewDefaultMessage("mytopic", "message 2")
|
||||||
|
m2.Time = 200
|
||||||
|
m3 := model.NewDefaultMessage("mytopic", "message 3")
|
||||||
|
m3.Time = time.Now().Add(time.Hour).Unix() // Scheduled, in the future, later than m7 and m5
|
||||||
|
m4 := model.NewDefaultMessage("mytopic", "message 4")
|
||||||
|
m4.Time = 400
|
||||||
|
m5 := model.NewDefaultMessage("mytopic", "message 5")
|
||||||
|
m5.Time = time.Now().Add(time.Minute).Unix() // Scheduled, in the future, later than m7
|
||||||
|
m6 := model.NewDefaultMessage("mytopic", "message 6")
|
||||||
|
m6.Time = 600
|
||||||
|
m7 := model.NewDefaultMessage("mytopic", "message 7")
|
||||||
|
m7.Time = 700
|
||||||
|
|
||||||
|
require.Nil(t, s.AddMessage(m1))
|
||||||
|
require.Nil(t, s.AddMessage(m2))
|
||||||
|
require.Nil(t, s.AddMessage(m3))
|
||||||
|
require.Nil(t, s.AddMessage(m4))
|
||||||
|
require.Nil(t, s.AddMessage(m5))
|
||||||
|
require.Nil(t, s.AddMessage(m6))
|
||||||
|
require.Nil(t, s.AddMessage(m7))
|
||||||
|
|
||||||
|
// Case 1: Since ID exists, exclude scheduled
|
||||||
|
messages, _ := s.Messages("mytopic", model.NewSinceID(m2.ID), false)
|
||||||
|
require.Equal(t, 3, len(messages))
|
||||||
|
require.Equal(t, "message 4", messages[0].Message)
|
||||||
|
require.Equal(t, "message 6", messages[1].Message) // Not scheduled m3/m5!
|
||||||
|
require.Equal(t, "message 7", messages[2].Message)
|
||||||
|
|
||||||
|
// Case 2: Since ID exists, include scheduled
|
||||||
|
messages, _ = s.Messages("mytopic", model.NewSinceID(m2.ID), true)
|
||||||
|
require.Equal(t, 5, len(messages))
|
||||||
|
require.Equal(t, "message 4", messages[0].Message)
|
||||||
|
require.Equal(t, "message 6", messages[1].Message)
|
||||||
|
require.Equal(t, "message 7", messages[2].Message)
|
||||||
|
require.Equal(t, "message 5", messages[3].Message) // Order!
|
||||||
|
require.Equal(t, "message 3", messages[4].Message) // Order!
|
||||||
|
|
||||||
|
// Case 3: Since ID does not exist (-> Return all messages), include scheduled
|
||||||
|
messages, _ = s.Messages("mytopic", model.NewSinceID("doesntexist"), true)
|
||||||
|
require.Equal(t, 7, len(messages))
|
||||||
|
require.Equal(t, "message 1", messages[0].Message)
|
||||||
|
require.Equal(t, "message 2", messages[1].Message)
|
||||||
|
require.Equal(t, "message 4", messages[2].Message)
|
||||||
|
require.Equal(t, "message 6", messages[3].Message)
|
||||||
|
require.Equal(t, "message 7", messages[4].Message)
|
||||||
|
require.Equal(t, "message 5", messages[5].Message) // Order!
|
||||||
|
require.Equal(t, "message 3", messages[6].Message) // Order!
|
||||||
|
|
||||||
|
// Case 4: Since ID exists and is last message (-> Return no messages), exclude scheduled
|
||||||
|
messages, _ = s.Messages("mytopic", model.NewSinceID(m7.ID), false)
|
||||||
|
require.Equal(t, 0, len(messages))
|
||||||
|
|
||||||
|
// Case 5: Since ID exists and is last message (-> Return no messages), include scheduled
|
||||||
|
messages, _ = s.Messages("mytopic", model.NewSinceID(m7.ID), true)
|
||||||
|
require.Equal(t, 2, len(messages))
|
||||||
|
require.Equal(t, "message 5", messages[0].Message)
|
||||||
|
require.Equal(t, "message 3", messages[1].Message)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStore_Prune(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||||
|
now := time.Now().Unix()
|
||||||
|
|
||||||
|
m1 := model.NewDefaultMessage("mytopic", "my message")
|
||||||
|
m1.Time = now - 10
|
||||||
|
m1.Expires = now - 5
|
||||||
|
|
||||||
|
m2 := model.NewDefaultMessage("mytopic", "my other message")
|
||||||
|
m2.Time = now - 5
|
||||||
|
m2.Expires = now + 5 // In the future
|
||||||
|
|
||||||
|
m3 := model.NewDefaultMessage("another_topic", "and another one")
|
||||||
|
m3.Time = now - 12
|
||||||
|
m3.Expires = now - 2
|
||||||
|
|
||||||
|
require.Nil(t, s.AddMessage(m1))
|
||||||
|
require.Nil(t, s.AddMessage(m2))
|
||||||
|
require.Nil(t, s.AddMessage(m3))
|
||||||
|
|
||||||
|
count, err := s.MessagesCount()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 3, count)
|
||||||
|
|
||||||
|
expiredMessageIDs, err := s.MessagesExpired()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Nil(t, s.DeleteMessages(expiredMessageIDs...))
|
||||||
|
|
||||||
|
count, err = s.MessagesCount()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 1, count)
|
||||||
|
|
||||||
|
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 1, len(messages))
|
||||||
|
require.Equal(t, "my other message", messages[0].Message)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStore_Attachments(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||||
|
expires1 := time.Now().Add(-4 * time.Hour).Unix() // Expired
|
||||||
|
m := model.NewDefaultMessage("mytopic", "flower for you")
|
||||||
|
m.ID = "m1"
|
||||||
|
m.SequenceID = "m1"
|
||||||
|
m.Sender = netip.MustParseAddr("1.2.3.4")
|
||||||
|
m.Attachment = &model.Attachment{
|
||||||
|
Name: "flower.jpg",
|
||||||
|
Type: "image/jpeg",
|
||||||
|
Size: 5000,
|
||||||
|
Expires: expires1,
|
||||||
|
URL: "https://ntfy.sh/file/AbDeFgJhal.jpg",
|
||||||
|
}
|
||||||
|
require.Nil(t, s.AddMessage(m))
|
||||||
|
|
||||||
|
expires2 := time.Now().Add(2 * time.Hour).Unix() // Future
|
||||||
|
m = model.NewDefaultMessage("mytopic", "sending you a car")
|
||||||
|
m.ID = "m2"
|
||||||
|
m.SequenceID = "m2"
|
||||||
|
m.Sender = netip.MustParseAddr("1.2.3.4")
|
||||||
|
m.Attachment = &model.Attachment{
|
||||||
|
Name: "car.jpg",
|
||||||
|
Type: "image/jpeg",
|
||||||
|
Size: 10000,
|
||||||
|
Expires: expires2,
|
||||||
|
URL: "https://ntfy.sh/file/aCaRURL.jpg",
|
||||||
|
}
|
||||||
|
require.Nil(t, s.AddMessage(m))
|
||||||
|
|
||||||
|
expires3 := time.Now().Add(1 * time.Hour).Unix() // Future
|
||||||
|
m = model.NewDefaultMessage("another-topic", "sending you another car")
|
||||||
|
m.ID = "m3"
|
||||||
|
m.SequenceID = "m3"
|
||||||
|
m.User = "u_BAsbaAa"
|
||||||
|
m.Sender = netip.MustParseAddr("5.6.7.8")
|
||||||
|
m.Attachment = &model.Attachment{
|
||||||
|
Name: "another-car.jpg",
|
||||||
|
Type: "image/jpeg",
|
||||||
|
Size: 20000,
|
||||||
|
Expires: expires3,
|
||||||
|
URL: "https://ntfy.sh/file/zakaDHFW.jpg",
|
||||||
|
}
|
||||||
|
require.Nil(t, s.AddMessage(m))
|
||||||
|
|
||||||
|
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 2, len(messages))
|
||||||
|
|
||||||
|
require.Equal(t, "flower for you", messages[0].Message)
|
||||||
|
require.Equal(t, "flower.jpg", messages[0].Attachment.Name)
|
||||||
|
require.Equal(t, "image/jpeg", messages[0].Attachment.Type)
|
||||||
|
require.Equal(t, int64(5000), messages[0].Attachment.Size)
|
||||||
|
require.Equal(t, expires1, messages[0].Attachment.Expires)
|
||||||
|
require.Equal(t, "https://ntfy.sh/file/AbDeFgJhal.jpg", messages[0].Attachment.URL)
|
||||||
|
require.Equal(t, "1.2.3.4", messages[0].Sender.String())
|
||||||
|
|
||||||
|
require.Equal(t, "sending you a car", messages[1].Message)
|
||||||
|
require.Equal(t, "car.jpg", messages[1].Attachment.Name)
|
||||||
|
require.Equal(t, "image/jpeg", messages[1].Attachment.Type)
|
||||||
|
require.Equal(t, int64(10000), messages[1].Attachment.Size)
|
||||||
|
require.Equal(t, expires2, messages[1].Attachment.Expires)
|
||||||
|
require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL)
|
||||||
|
require.Equal(t, "1.2.3.4", messages[1].Sender.String())
|
||||||
|
|
||||||
|
size, err := s.AttachmentBytesUsedBySender("1.2.3.4")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, int64(10000), size)
|
||||||
|
|
||||||
|
size, err = s.AttachmentBytesUsedBySender("5.6.7.8")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, int64(0), size) // Accounted to the user, not the IP!
|
||||||
|
|
||||||
|
size, err = s.AttachmentBytesUsedByUser("u_BAsbaAa")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, int64(20000), size)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStore_AttachmentsExpired(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||||
|
m := model.NewDefaultMessage("mytopic", "flower for you")
|
||||||
|
m.ID = "m1"
|
||||||
|
m.SequenceID = "m1"
|
||||||
|
m.Expires = time.Now().Add(time.Hour).Unix()
|
||||||
|
require.Nil(t, s.AddMessage(m))
|
||||||
|
|
||||||
|
m = model.NewDefaultMessage("mytopic", "message with attachment")
|
||||||
|
m.ID = "m2"
|
||||||
|
m.SequenceID = "m2"
|
||||||
|
m.Expires = time.Now().Add(2 * time.Hour).Unix()
|
||||||
|
m.Attachment = &model.Attachment{
|
||||||
|
Name: "car.jpg",
|
||||||
|
Type: "image/jpeg",
|
||||||
|
Size: 10000,
|
||||||
|
Expires: time.Now().Add(2 * time.Hour).Unix(),
|
||||||
|
URL: "https://ntfy.sh/file/aCaRURL.jpg",
|
||||||
|
}
|
||||||
|
require.Nil(t, s.AddMessage(m))
|
||||||
|
|
||||||
|
m = model.NewDefaultMessage("mytopic", "message with external attachment")
|
||||||
|
m.ID = "m3"
|
||||||
|
m.SequenceID = "m3"
|
||||||
|
m.Expires = time.Now().Add(2 * time.Hour).Unix()
|
||||||
|
m.Attachment = &model.Attachment{
|
||||||
|
Name: "car.jpg",
|
||||||
|
Type: "image/jpeg",
|
||||||
|
Expires: 0, // Unknown!
|
||||||
|
URL: "https://somedomain.com/car.jpg",
|
||||||
|
}
|
||||||
|
require.Nil(t, s.AddMessage(m))
|
||||||
|
|
||||||
|
m = model.NewDefaultMessage("mytopic2", "message with expired attachment")
|
||||||
|
m.ID = "m4"
|
||||||
|
m.SequenceID = "m4"
|
||||||
|
m.Expires = time.Now().Add(2 * time.Hour).Unix()
|
||||||
|
m.Attachment = &model.Attachment{
|
||||||
|
Name: "expired-car.jpg",
|
||||||
|
Type: "image/jpeg",
|
||||||
|
Size: 20000,
|
||||||
|
Expires: time.Now().Add(-1 * time.Hour).Unix(),
|
||||||
|
URL: "https://ntfy.sh/file/aCaRURL.jpg",
|
||||||
|
}
|
||||||
|
require.Nil(t, s.AddMessage(m))
|
||||||
|
|
||||||
|
ids, err := s.AttachmentsExpired()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 1, len(ids))
|
||||||
|
require.Equal(t, "m4", ids[0])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStore_Sender(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||||
|
m1 := model.NewDefaultMessage("mytopic", "mymessage")
|
||||||
|
m1.Sender = netip.MustParseAddr("1.2.3.4")
|
||||||
|
require.Nil(t, s.AddMessage(m1))
|
||||||
|
|
||||||
|
m2 := model.NewDefaultMessage("mytopic", "mymessage without sender")
|
||||||
|
require.Nil(t, s.AddMessage(m2))
|
||||||
|
|
||||||
|
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 2, len(messages))
|
||||||
|
require.Equal(t, messages[0].Sender, netip.MustParseAddr("1.2.3.4"))
|
||||||
|
require.Equal(t, messages[1].Sender, netip.Addr{})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStore_DeleteScheduledBySequenceID(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||||
|
// Create a scheduled (unpublished) message
|
||||||
|
scheduledMsg := model.NewDefaultMessage("mytopic", "scheduled message")
|
||||||
|
scheduledMsg.ID = "scheduled1"
|
||||||
|
scheduledMsg.SequenceID = "seq123"
|
||||||
|
scheduledMsg.Time = time.Now().Add(time.Hour).Unix() // Future time makes it scheduled
|
||||||
|
require.Nil(t, s.AddMessage(scheduledMsg))
|
||||||
|
|
||||||
|
// Create a published message with different sequence ID
|
||||||
|
publishedMsg := model.NewDefaultMessage("mytopic", "published message")
|
||||||
|
publishedMsg.ID = "published1"
|
||||||
|
publishedMsg.SequenceID = "seq456"
|
||||||
|
publishedMsg.Time = time.Now().Add(-time.Hour).Unix() // Past time makes it published
|
||||||
|
require.Nil(t, s.AddMessage(publishedMsg))
|
||||||
|
|
||||||
|
// Create a scheduled message in a different topic
|
||||||
|
otherTopicMsg := model.NewDefaultMessage("othertopic", "other scheduled")
|
||||||
|
otherTopicMsg.ID = "other1"
|
||||||
|
otherTopicMsg.SequenceID = "seq123" // Same sequence ID as scheduledMsg
|
||||||
|
otherTopicMsg.Time = time.Now().Add(time.Hour).Unix()
|
||||||
|
require.Nil(t, s.AddMessage(otherTopicMsg))
|
||||||
|
|
||||||
|
// Verify all messages exist (including scheduled)
|
||||||
|
messages, err := s.Messages("mytopic", model.SinceAllMessages, true)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 2, len(messages))
|
||||||
|
|
||||||
|
messages, err = s.Messages("othertopic", model.SinceAllMessages, true)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 1, len(messages))
|
||||||
|
|
||||||
|
// Delete scheduled message by sequence ID and verify returned IDs
|
||||||
|
deletedIDs, err := s.DeleteScheduledBySequenceID("mytopic", "seq123")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 1, len(deletedIDs))
|
||||||
|
require.Equal(t, "scheduled1", deletedIDs[0])
|
||||||
|
|
||||||
|
// Verify scheduled message is deleted
|
||||||
|
messages, err = s.Messages("mytopic", model.SinceAllMessages, true)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 1, len(messages))
|
||||||
|
require.Equal(t, "published message", messages[0].Message)
|
||||||
|
|
||||||
|
// Verify other topic's message still exists (topic-scoped deletion)
|
||||||
|
messages, err = s.Messages("othertopic", model.SinceAllMessages, true)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 1, len(messages))
|
||||||
|
require.Equal(t, "other scheduled", messages[0].Message)
|
||||||
|
|
||||||
|
// Deleting non-existent sequence ID should return empty list
|
||||||
|
deletedIDs, err = s.DeleteScheduledBySequenceID("mytopic", "nonexistent")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Empty(t, deletedIDs)
|
||||||
|
|
||||||
|
// Deleting published message should not affect it (only deletes unpublished)
|
||||||
|
deletedIDs, err = s.DeleteScheduledBySequenceID("mytopic", "seq456")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Empty(t, deletedIDs)
|
||||||
|
|
||||||
|
messages, err = s.Messages("mytopic", model.SinceAllMessages, true)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 1, len(messages))
|
||||||
|
require.Equal(t, "published message", messages[0].Message)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStore_MessageByID(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||||
|
// Add a message
|
||||||
|
m := model.NewDefaultMessage("mytopic", "some message")
|
||||||
|
m.Title = "some title"
|
||||||
|
m.Priority = 4
|
||||||
|
m.Tags = []string{"tag1", "tag2"}
|
||||||
|
require.Nil(t, s.AddMessage(m))
|
||||||
|
|
||||||
|
// Retrieve by ID
|
||||||
|
retrieved, err := s.Message(m.ID)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, m.ID, retrieved.ID)
|
||||||
|
require.Equal(t, "mytopic", retrieved.Topic)
|
||||||
|
require.Equal(t, "some message", retrieved.Message)
|
||||||
|
require.Equal(t, "some title", retrieved.Title)
|
||||||
|
require.Equal(t, 4, retrieved.Priority)
|
||||||
|
require.Equal(t, []string{"tag1", "tag2"}, retrieved.Tags)
|
||||||
|
|
||||||
|
// Non-existent ID returns ErrMessageNotFound
|
||||||
|
_, err = s.Message("doesnotexist")
|
||||||
|
require.Equal(t, model.ErrMessageNotFound, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStore_MarkPublished(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||||
|
// Add a scheduled message (future time -> unpublished)
|
||||||
|
m := model.NewDefaultMessage("mytopic", "scheduled message")
|
||||||
|
m.Time = time.Now().Add(time.Hour).Unix()
|
||||||
|
require.Nil(t, s.AddMessage(m))
|
||||||
|
|
||||||
|
// Verify it does not appear in non-scheduled queries
|
||||||
|
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 0, len(messages))
|
||||||
|
|
||||||
|
// Verify it does appear in scheduled queries
|
||||||
|
messages, err = s.Messages("mytopic", model.SinceAllMessages, true)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 1, len(messages))
|
||||||
|
|
||||||
|
// Mark as published
|
||||||
|
require.Nil(t, s.MarkPublished(m))
|
||||||
|
|
||||||
|
// Now it should appear in non-scheduled queries too
|
||||||
|
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 1, len(messages))
|
||||||
|
require.Equal(t, "scheduled message", messages[0].Message)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStore_ExpireMessages(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||||
|
// Add messages to two topics
|
||||||
|
m1 := model.NewDefaultMessage("topic1", "message 1")
|
||||||
|
m1.Expires = time.Now().Add(time.Hour).Unix()
|
||||||
|
m2 := model.NewDefaultMessage("topic1", "message 2")
|
||||||
|
m2.Expires = time.Now().Add(time.Hour).Unix()
|
||||||
|
m3 := model.NewDefaultMessage("topic2", "message 3")
|
||||||
|
m3.Expires = time.Now().Add(time.Hour).Unix()
|
||||||
|
require.Nil(t, s.AddMessage(m1))
|
||||||
|
require.Nil(t, s.AddMessage(m2))
|
||||||
|
require.Nil(t, s.AddMessage(m3))
|
||||||
|
|
||||||
|
// Verify all messages exist
|
||||||
|
messages, err := s.Messages("topic1", model.SinceAllMessages, false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 2, len(messages))
|
||||||
|
messages, err = s.Messages("topic2", model.SinceAllMessages, false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 1, len(messages))
|
||||||
|
|
||||||
|
// Expire topic1 messages
|
||||||
|
require.Nil(t, s.ExpireMessages("topic1"))
|
||||||
|
|
||||||
|
// topic1 messages should now be expired (expires set to past)
|
||||||
|
expiredIDs, err := s.MessagesExpired()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 2, len(expiredIDs))
|
||||||
|
sort.Strings(expiredIDs)
|
||||||
|
expectedIDs := []string{m1.ID, m2.ID}
|
||||||
|
sort.Strings(expectedIDs)
|
||||||
|
require.Equal(t, expectedIDs, expiredIDs)
|
||||||
|
|
||||||
|
// topic2 should be unaffected
|
||||||
|
messages, err = s.Messages("topic2", model.SinceAllMessages, false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 1, len(messages))
|
||||||
|
require.Equal(t, "message 3", messages[0].Message)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStore_MarkAttachmentsDeleted(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||||
|
// Add a message with an expired attachment (file needs cleanup)
|
||||||
|
m1 := model.NewDefaultMessage("mytopic", "old file")
|
||||||
|
m1.ID = "msg1"
|
||||||
|
m1.SequenceID = "msg1"
|
||||||
|
m1.Expires = time.Now().Add(time.Hour).Unix()
|
||||||
|
m1.Attachment = &model.Attachment{
|
||||||
|
Name: "old.pdf",
|
||||||
|
Type: "application/pdf",
|
||||||
|
Size: 50000,
|
||||||
|
Expires: time.Now().Add(-time.Hour).Unix(), // Expired
|
||||||
|
URL: "https://ntfy.sh/file/old.pdf",
|
||||||
|
}
|
||||||
|
require.Nil(t, s.AddMessage(m1))
|
||||||
|
|
||||||
|
// Add a message with another expired attachment
|
||||||
|
m2 := model.NewDefaultMessage("mytopic", "another old file")
|
||||||
|
m2.ID = "msg2"
|
||||||
|
m2.SequenceID = "msg2"
|
||||||
|
m2.Expires = time.Now().Add(time.Hour).Unix()
|
||||||
|
m2.Attachment = &model.Attachment{
|
||||||
|
Name: "another.pdf",
|
||||||
|
Type: "application/pdf",
|
||||||
|
Size: 30000,
|
||||||
|
Expires: time.Now().Add(-time.Hour).Unix(), // Expired
|
||||||
|
URL: "https://ntfy.sh/file/another.pdf",
|
||||||
|
}
|
||||||
|
require.Nil(t, s.AddMessage(m2))
|
||||||
|
|
||||||
|
// Both should show as expired attachments needing cleanup
|
||||||
|
ids, err := s.AttachmentsExpired()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 2, len(ids))
|
||||||
|
|
||||||
|
// Mark msg1's attachment as deleted (file cleaned up)
|
||||||
|
require.Nil(t, s.MarkAttachmentsDeleted("msg1"))
|
||||||
|
|
||||||
|
// Now only msg2 should show as needing cleanup
|
||||||
|
ids, err = s.AttachmentsExpired()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 1, len(ids))
|
||||||
|
require.Equal(t, "msg2", ids[0])
|
||||||
|
|
||||||
|
// Mark msg2 too
|
||||||
|
require.Nil(t, s.MarkAttachmentsDeleted("msg2"))
|
||||||
|
|
||||||
|
// No more expired attachments to clean up
|
||||||
|
ids, err = s.AttachmentsExpired()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 0, len(ids))
|
||||||
|
|
||||||
|
// Messages themselves still exist
|
||||||
|
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 2, len(messages))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStore_Stats(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||||
|
// Initial stats should be zero
|
||||||
|
messages, err := s.Stats()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, int64(0), messages)
|
||||||
|
|
||||||
|
// Update stats
|
||||||
|
require.Nil(t, s.UpdateStats(42))
|
||||||
|
messages, err = s.Stats()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, int64(42), messages)
|
||||||
|
|
||||||
|
// Update again (overwrites)
|
||||||
|
require.Nil(t, s.UpdateStats(100))
|
||||||
|
messages, err = s.Stats()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, int64(100), messages)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStore_AddMessages(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||||
|
// Batch add multiple messages
|
||||||
|
msgs := []*model.Message{
|
||||||
|
model.NewDefaultMessage("mytopic", "batch 1"),
|
||||||
|
model.NewDefaultMessage("mytopic", "batch 2"),
|
||||||
|
model.NewDefaultMessage("othertopic", "batch 3"),
|
||||||
|
}
|
||||||
|
require.Nil(t, s.AddMessages(msgs))
|
||||||
|
|
||||||
|
// Verify all were inserted
|
||||||
|
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 2, len(messages))
|
||||||
|
|
||||||
|
messages, err = s.Messages("othertopic", model.SinceAllMessages, false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 1, len(messages))
|
||||||
|
require.Equal(t, "batch 3", messages[0].Message)
|
||||||
|
|
||||||
|
// Empty batch should succeed
|
||||||
|
require.Nil(t, s.AddMessages([]*model.Message{}))
|
||||||
|
|
||||||
|
// Batch with invalid event type should fail
|
||||||
|
badMsgs := []*model.Message{
|
||||||
|
model.NewKeepaliveMessage("mytopic"),
|
||||||
|
}
|
||||||
|
require.NotNil(t, s.AddMessages(badMsgs))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStore_MessagesDue(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||||
|
// Add a message scheduled in the past (i.e. it's due now)
|
||||||
|
m1 := model.NewDefaultMessage("mytopic", "due message")
|
||||||
|
m1.Time = time.Now().Add(-time.Second).Unix()
|
||||||
|
// Set expires in the future so it doesn't get pruned
|
||||||
|
m1.Expires = time.Now().Add(time.Hour).Unix()
|
||||||
|
require.Nil(t, s.AddMessage(m1))
|
||||||
|
|
||||||
|
// Add a message scheduled in the future (not due)
|
||||||
|
m2 := model.NewDefaultMessage("mytopic", "future message")
|
||||||
|
m2.Time = time.Now().Add(time.Hour).Unix()
|
||||||
|
require.Nil(t, s.AddMessage(m2))
|
||||||
|
|
||||||
|
// Mark m1 as published so it won't be "due"
|
||||||
|
// (MessagesDue returns unpublished messages whose time <= now)
|
||||||
|
// m1 is auto-published (time <= now), so it should not be due
|
||||||
|
// m2 is unpublished (time in future), not due yet
|
||||||
|
due, err := s.MessagesDue()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 0, len(due))
|
||||||
|
|
||||||
|
// Add a message that was explicitly scheduled in the past but time has "arrived"
|
||||||
|
// We need to manipulate the database to create a truly "due" message:
|
||||||
|
// a message with published=false and time <= now
|
||||||
|
m3 := model.NewDefaultMessage("mytopic", "truly due message")
|
||||||
|
m3.Time = time.Now().Add(2 * time.Second).Unix() // 2 seconds from now
|
||||||
|
require.Nil(t, s.AddMessage(m3))
|
||||||
|
|
||||||
|
// Not due yet
|
||||||
|
due, err = s.MessagesDue()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 0, len(due))
|
||||||
|
|
||||||
|
// Wait for it to become due
|
||||||
|
time.Sleep(3 * time.Second)
|
||||||
|
|
||||||
|
due, err = s.MessagesDue()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 1, len(due))
|
||||||
|
require.Equal(t, "truly due message", due[0].Message)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStore_MessageFieldRoundTrip(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T, s *message.Cache) {
|
||||||
|
// Create a message with all fields populated
|
||||||
|
m := model.NewDefaultMessage("mytopic", "hello world")
|
||||||
|
m.SequenceID = "custom_seq_id"
|
||||||
|
m.Title = "A Title"
|
||||||
|
m.Priority = 4
|
||||||
|
m.Tags = []string{"warning", "srv01"}
|
||||||
|
m.Click = "https://example.com/click"
|
||||||
|
m.Icon = "https://example.com/icon.png"
|
||||||
|
m.Actions = []*model.Action{
|
||||||
|
{
|
||||||
|
ID: "action1",
|
||||||
|
Action: "view",
|
||||||
|
Label: "Open Site",
|
||||||
|
URL: "https://example.com",
|
||||||
|
Clear: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "action2",
|
||||||
|
Action: "http",
|
||||||
|
Label: "Call Webhook",
|
||||||
|
URL: "https://example.com/hook",
|
||||||
|
Method: "PUT",
|
||||||
|
Headers: map[string]string{"X-Token": "secret"},
|
||||||
|
Body: `{"key":"value"}`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
m.ContentType = "text/markdown"
|
||||||
|
m.Encoding = "base64"
|
||||||
|
m.Sender = netip.MustParseAddr("9.8.7.6")
|
||||||
|
m.User = "u_TestUser123"
|
||||||
|
require.Nil(t, s.AddMessage(m))
|
||||||
|
|
||||||
|
// Retrieve and verify every field
|
||||||
|
retrieved, err := s.Message(m.ID)
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, m.ID, retrieved.ID)
|
||||||
|
require.Equal(t, "custom_seq_id", retrieved.SequenceID)
|
||||||
|
require.Equal(t, m.Time, retrieved.Time)
|
||||||
|
require.Equal(t, m.Expires, retrieved.Expires)
|
||||||
|
require.Equal(t, model.MessageEvent, retrieved.Event)
|
||||||
|
require.Equal(t, "mytopic", retrieved.Topic)
|
||||||
|
require.Equal(t, "hello world", retrieved.Message)
|
||||||
|
require.Equal(t, "A Title", retrieved.Title)
|
||||||
|
require.Equal(t, 4, retrieved.Priority)
|
||||||
|
require.Equal(t, []string{"warning", "srv01"}, retrieved.Tags)
|
||||||
|
require.Equal(t, "https://example.com/click", retrieved.Click)
|
||||||
|
require.Equal(t, "https://example.com/icon.png", retrieved.Icon)
|
||||||
|
require.Equal(t, "text/markdown", retrieved.ContentType)
|
||||||
|
require.Equal(t, "base64", retrieved.Encoding)
|
||||||
|
require.Equal(t, netip.MustParseAddr("9.8.7.6"), retrieved.Sender)
|
||||||
|
require.Equal(t, "u_TestUser123", retrieved.User)
|
||||||
|
|
||||||
|
// Verify actions round-trip
|
||||||
|
require.Equal(t, 2, len(retrieved.Actions))
|
||||||
|
|
||||||
|
require.Equal(t, "action1", retrieved.Actions[0].ID)
|
||||||
|
require.Equal(t, "view", retrieved.Actions[0].Action)
|
||||||
|
require.Equal(t, "Open Site", retrieved.Actions[0].Label)
|
||||||
|
require.Equal(t, "https://example.com", retrieved.Actions[0].URL)
|
||||||
|
require.Equal(t, true, retrieved.Actions[0].Clear)
|
||||||
|
|
||||||
|
require.Equal(t, "action2", retrieved.Actions[1].ID)
|
||||||
|
require.Equal(t, "http", retrieved.Actions[1].Action)
|
||||||
|
require.Equal(t, "Call Webhook", retrieved.Actions[1].Label)
|
||||||
|
require.Equal(t, "https://example.com/hook", retrieved.Actions[1].URL)
|
||||||
|
require.Equal(t, "PUT", retrieved.Actions[1].Method)
|
||||||
|
require.Equal(t, "secret", retrieved.Actions[1].Headers["X-Token"])
|
||||||
|
require.Equal(t, `{"key":"value"}`, retrieved.Actions[1].Body)
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -1,120 +0,0 @@
|
|||||||
package message
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
|
|
||||||
)
|
|
||||||
|
|
||||||
// PostgreSQL runtime query constants
|
|
||||||
const (
|
|
||||||
pgInsertMessageQuery = `
|
|
||||||
INSERT INTO message (mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user_id, content_type, encoding, published)
|
|
||||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24)
|
|
||||||
`
|
|
||||||
pgDeleteMessageQuery = `DELETE FROM message WHERE mid = $1`
|
|
||||||
pgSelectScheduledMessageIDsBySeqIDQuery = `SELECT mid FROM message WHERE topic = $1 AND sequence_id = $2 AND published = FALSE`
|
|
||||||
pgDeleteScheduledBySequenceIDQuery = `DELETE FROM message WHERE topic = $1 AND sequence_id = $2 AND published = FALSE`
|
|
||||||
pgUpdateMessagesForTopicExpiryQuery = `UPDATE message SET expires = $1 WHERE topic = $2`
|
|
||||||
pgSelectRowIDFromMessageID = `SELECT id FROM message WHERE mid = $1`
|
|
||||||
pgSelectMessagesByIDQuery = `
|
|
||||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
|
||||||
FROM message
|
|
||||||
WHERE mid = $1
|
|
||||||
`
|
|
||||||
pgSelectMessagesSinceTimeQuery = `
|
|
||||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
|
||||||
FROM message
|
|
||||||
WHERE topic = $1 AND time >= $2 AND published = TRUE
|
|
||||||
ORDER BY time, id
|
|
||||||
`
|
|
||||||
pgSelectMessagesSinceTimeIncludeScheduledQuery = `
|
|
||||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
|
||||||
FROM message
|
|
||||||
WHERE topic = $1 AND time >= $2
|
|
||||||
ORDER BY time, id
|
|
||||||
`
|
|
||||||
pgSelectMessagesSinceIDQuery = `
|
|
||||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
|
||||||
FROM message
|
|
||||||
WHERE topic = $1 AND id > $2 AND published = TRUE
|
|
||||||
ORDER BY time, id
|
|
||||||
`
|
|
||||||
pgSelectMessagesSinceIDIncludeScheduledQuery = `
|
|
||||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
|
||||||
FROM message
|
|
||||||
WHERE topic = $1 AND (id > $2 OR published = FALSE)
|
|
||||||
ORDER BY time, id
|
|
||||||
`
|
|
||||||
pgSelectMessagesLatestQuery = `
|
|
||||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
|
||||||
FROM message
|
|
||||||
WHERE topic = $1 AND published = TRUE
|
|
||||||
ORDER BY time DESC, id DESC
|
|
||||||
LIMIT 1
|
|
||||||
`
|
|
||||||
pgSelectMessagesDueQuery = `
|
|
||||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
|
||||||
FROM message
|
|
||||||
WHERE time <= $1 AND published = FALSE
|
|
||||||
ORDER BY time, id
|
|
||||||
`
|
|
||||||
pgSelectMessagesExpiredQuery = `SELECT mid FROM message WHERE expires <= $1 AND published = TRUE`
|
|
||||||
pgUpdateMessagePublishedQuery = `UPDATE message SET published = TRUE WHERE mid = $1`
|
|
||||||
pgSelectMessagesCountQuery = `SELECT COUNT(*) FROM message`
|
|
||||||
pgSelectMessageCountPerTopicQuery = `SELECT topic, COUNT(*) FROM message GROUP BY topic`
|
|
||||||
pgSelectTopicsQuery = `SELECT topic FROM message GROUP BY topic`
|
|
||||||
|
|
||||||
pgUpdateAttachmentDeleted = `UPDATE message SET attachment_deleted = TRUE WHERE mid = $1`
|
|
||||||
pgSelectAttachmentsExpiredQuery = `SELECT mid FROM message WHERE attachment_expires > 0 AND attachment_expires <= $1 AND attachment_deleted = FALSE`
|
|
||||||
pgSelectAttachmentsSizeBySenderQuery = `SELECT COALESCE(SUM(attachment_size), 0) FROM message WHERE user_id = '' AND sender = $1 AND attachment_expires >= $2`
|
|
||||||
pgSelectAttachmentsSizeByUserIDQuery = `SELECT COALESCE(SUM(attachment_size), 0) FROM message WHERE user_id = $1 AND attachment_expires >= $2`
|
|
||||||
|
|
||||||
pgSelectStatsQuery = `SELECT value FROM message_stats WHERE key = 'messages'`
|
|
||||||
pgUpdateStatsQuery = `UPDATE message_stats SET value = $1 WHERE key = 'messages'`
|
|
||||||
pgUpdateMessageTimesQuery = `UPDATE message SET time = $1 WHERE mid = $2`
|
|
||||||
)
|
|
||||||
|
|
||||||
var pgQueries = storeQueries{
|
|
||||||
insertMessage: pgInsertMessageQuery,
|
|
||||||
deleteMessage: pgDeleteMessageQuery,
|
|
||||||
selectScheduledMessageIDsBySeqID: pgSelectScheduledMessageIDsBySeqIDQuery,
|
|
||||||
deleteScheduledBySequenceID: pgDeleteScheduledBySequenceIDQuery,
|
|
||||||
updateMessagesForTopicExpiry: pgUpdateMessagesForTopicExpiryQuery,
|
|
||||||
selectRowIDFromMessageID: pgSelectRowIDFromMessageID,
|
|
||||||
selectMessagesByID: pgSelectMessagesByIDQuery,
|
|
||||||
selectMessagesSinceTime: pgSelectMessagesSinceTimeQuery,
|
|
||||||
selectMessagesSinceTimeScheduled: pgSelectMessagesSinceTimeIncludeScheduledQuery,
|
|
||||||
selectMessagesSinceID: pgSelectMessagesSinceIDQuery,
|
|
||||||
selectMessagesSinceIDScheduled: pgSelectMessagesSinceIDIncludeScheduledQuery,
|
|
||||||
selectMessagesLatest: pgSelectMessagesLatestQuery,
|
|
||||||
selectMessagesDue: pgSelectMessagesDueQuery,
|
|
||||||
selectMessagesExpired: pgSelectMessagesExpiredQuery,
|
|
||||||
updateMessagePublished: pgUpdateMessagePublishedQuery,
|
|
||||||
selectMessagesCount: pgSelectMessagesCountQuery,
|
|
||||||
selectMessageCountPerTopic: pgSelectMessageCountPerTopicQuery,
|
|
||||||
selectTopics: pgSelectTopicsQuery,
|
|
||||||
updateAttachmentDeleted: pgUpdateAttachmentDeleted,
|
|
||||||
selectAttachmentsExpired: pgSelectAttachmentsExpiredQuery,
|
|
||||||
selectAttachmentsSizeBySender: pgSelectAttachmentsSizeBySenderQuery,
|
|
||||||
selectAttachmentsSizeByUserID: pgSelectAttachmentsSizeByUserIDQuery,
|
|
||||||
selectStats: pgSelectStatsQuery,
|
|
||||||
updateStats: pgUpdateStatsQuery,
|
|
||||||
updateMessageTime: pgUpdateMessageTimesQuery,
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewPostgresStore creates a new PostgreSQL-backed message cache store.
|
|
||||||
func NewPostgresStore(dsn string, batchSize int, batchTimeout time.Duration) (Store, error) {
|
|
||||||
db, err := sql.Open("pgx", dsn)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := db.Ping(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := setupPostgresDB(db); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return newCommonStore(db, pgQueries, batchSize, batchTimeout, false), nil
|
|
||||||
}
|
|
||||||
@@ -1,120 +0,0 @@
|
|||||||
package message_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
|
||||||
"net/url"
|
|
||||||
"os"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"heckel.io/ntfy/v2/message"
|
|
||||||
"heckel.io/ntfy/v2/util"
|
|
||||||
)
|
|
||||||
|
|
||||||
func newTestPostgresStore(t *testing.T) message.Store {
|
|
||||||
dsn := os.Getenv("NTFY_TEST_DATABASE_URL")
|
|
||||||
if dsn == "" {
|
|
||||||
t.Skip("NTFY_TEST_DATABASE_URL not set, skipping PostgreSQL tests")
|
|
||||||
}
|
|
||||||
// Create a unique schema for this test
|
|
||||||
schema := fmt.Sprintf("test_%s", util.RandomString(10))
|
|
||||||
setupDB, err := sql.Open("pgx", dsn)
|
|
||||||
require.Nil(t, err)
|
|
||||||
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Nil(t, setupDB.Close())
|
|
||||||
// Open store with search_path set to the new schema
|
|
||||||
u, err := url.Parse(dsn)
|
|
||||||
require.Nil(t, err)
|
|
||||||
q := u.Query()
|
|
||||||
q.Set("search_path", schema)
|
|
||||||
u.RawQuery = q.Encode()
|
|
||||||
store, err := message.NewPostgresStore(u.String(), 0, 0)
|
|
||||||
require.Nil(t, err)
|
|
||||||
t.Cleanup(func() {
|
|
||||||
store.Close()
|
|
||||||
cleanDB, err := sql.Open("pgx", dsn)
|
|
||||||
if err == nil {
|
|
||||||
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
|
|
||||||
cleanDB.Close()
|
|
||||||
}
|
|
||||||
})
|
|
||||||
return store
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPostgresStore_Messages(t *testing.T) {
|
|
||||||
testCacheMessages(t, newTestPostgresStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPostgresStore_MessagesLock(t *testing.T) {
|
|
||||||
testCacheMessagesLock(t, newTestPostgresStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPostgresStore_MessagesScheduled(t *testing.T) {
|
|
||||||
testCacheMessagesScheduled(t, newTestPostgresStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPostgresStore_Topics(t *testing.T) {
|
|
||||||
testCacheTopics(t, newTestPostgresStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPostgresStore_MessagesTagsPrioAndTitle(t *testing.T) {
|
|
||||||
testCacheMessagesTagsPrioAndTitle(t, newTestPostgresStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPostgresStore_MessagesSinceID(t *testing.T) {
|
|
||||||
testCacheMessagesSinceID(t, newTestPostgresStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPostgresStore_Prune(t *testing.T) {
|
|
||||||
testCachePrune(t, newTestPostgresStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPostgresStore_Attachments(t *testing.T) {
|
|
||||||
testCacheAttachments(t, newTestPostgresStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPostgresStore_AttachmentsExpired(t *testing.T) {
|
|
||||||
testCacheAttachmentsExpired(t, newTestPostgresStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPostgresStore_Sender(t *testing.T) {
|
|
||||||
testSender(t, newTestPostgresStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPostgresStore_DeleteScheduledBySequenceID(t *testing.T) {
|
|
||||||
testDeleteScheduledBySequenceID(t, newTestPostgresStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPostgresStore_MessageByID(t *testing.T) {
|
|
||||||
testMessageByID(t, newTestPostgresStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPostgresStore_MarkPublished(t *testing.T) {
|
|
||||||
testMarkPublished(t, newTestPostgresStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPostgresStore_ExpireMessages(t *testing.T) {
|
|
||||||
testExpireMessages(t, newTestPostgresStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPostgresStore_MarkAttachmentsDeleted(t *testing.T) {
|
|
||||||
testMarkAttachmentsDeleted(t, newTestPostgresStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPostgresStore_Stats(t *testing.T) {
|
|
||||||
testStats(t, newTestPostgresStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPostgresStore_AddMessages(t *testing.T) {
|
|
||||||
testAddMessages(t, newTestPostgresStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPostgresStore_MessagesDue(t *testing.T) {
|
|
||||||
testMessagesDue(t, newTestPostgresStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPostgresStore_MessageFieldRoundTrip(t *testing.T) {
|
|
||||||
testMessageFieldRoundTrip(t, newTestPostgresStore(t))
|
|
||||||
}
|
|
||||||
@@ -1,767 +0,0 @@
|
|||||||
package message_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"sort"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"heckel.io/ntfy/v2/message"
|
|
||||||
"heckel.io/ntfy/v2/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
func testCacheMessages(t *testing.T, s message.Store) {
|
|
||||||
m1 := model.NewDefaultMessage("mytopic", "my message")
|
|
||||||
m1.Time = 1
|
|
||||||
|
|
||||||
m2 := model.NewDefaultMessage("mytopic", "my other message")
|
|
||||||
m2.Time = 2
|
|
||||||
|
|
||||||
require.Nil(t, s.AddMessage(m1))
|
|
||||||
require.Nil(t, s.AddMessage(model.NewDefaultMessage("example", "my example message")))
|
|
||||||
require.Nil(t, s.AddMessage(m2))
|
|
||||||
|
|
||||||
// Adding invalid
|
|
||||||
require.Equal(t, model.ErrUnexpectedMessageType, s.AddMessage(model.NewKeepaliveMessage("mytopic"))) // These should not be added!
|
|
||||||
require.Equal(t, model.ErrUnexpectedMessageType, s.AddMessage(model.NewOpenMessage("example"))) // These should not be added!
|
|
||||||
|
|
||||||
// mytopic: count
|
|
||||||
counts, err := s.MessageCounts()
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 2, counts["mytopic"])
|
|
||||||
|
|
||||||
// mytopic: since all
|
|
||||||
messages, _ := s.Messages("mytopic", model.SinceAllMessages, false)
|
|
||||||
require.Equal(t, 2, len(messages))
|
|
||||||
require.Equal(t, "my message", messages[0].Message)
|
|
||||||
require.Equal(t, "mytopic", messages[0].Topic)
|
|
||||||
require.Equal(t, model.MessageEvent, messages[0].Event)
|
|
||||||
require.Equal(t, "", messages[0].Title)
|
|
||||||
require.Equal(t, 0, messages[0].Priority)
|
|
||||||
require.Nil(t, messages[0].Tags)
|
|
||||||
require.Equal(t, "my other message", messages[1].Message)
|
|
||||||
|
|
||||||
// mytopic: since none
|
|
||||||
messages, _ = s.Messages("mytopic", model.SinceNoMessages, false)
|
|
||||||
require.Empty(t, messages)
|
|
||||||
|
|
||||||
// mytopic: since m1 (by ID)
|
|
||||||
messages, _ = s.Messages("mytopic", model.NewSinceID(m1.ID), false)
|
|
||||||
require.Equal(t, 1, len(messages))
|
|
||||||
require.Equal(t, m2.ID, messages[0].ID)
|
|
||||||
require.Equal(t, "my other message", messages[0].Message)
|
|
||||||
require.Equal(t, "mytopic", messages[0].Topic)
|
|
||||||
|
|
||||||
// mytopic: since 2
|
|
||||||
messages, _ = s.Messages("mytopic", model.NewSinceTime(2), false)
|
|
||||||
require.Equal(t, 1, len(messages))
|
|
||||||
require.Equal(t, "my other message", messages[0].Message)
|
|
||||||
|
|
||||||
// mytopic: latest
|
|
||||||
messages, _ = s.Messages("mytopic", model.SinceLatestMessage, false)
|
|
||||||
require.Equal(t, 1, len(messages))
|
|
||||||
require.Equal(t, "my other message", messages[0].Message)
|
|
||||||
|
|
||||||
// example: count
|
|
||||||
counts, err = s.MessageCounts()
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 1, counts["example"])
|
|
||||||
|
|
||||||
// example: since all
|
|
||||||
messages, _ = s.Messages("example", model.SinceAllMessages, false)
|
|
||||||
require.Equal(t, "my example message", messages[0].Message)
|
|
||||||
|
|
||||||
// non-existing: count
|
|
||||||
counts, err = s.MessageCounts()
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 0, counts["doesnotexist"])
|
|
||||||
|
|
||||||
// non-existing: since all
|
|
||||||
messages, _ = s.Messages("doesnotexist", model.SinceAllMessages, false)
|
|
||||||
require.Empty(t, messages)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testCacheMessagesLock(t *testing.T, s message.Store) {
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
for i := 0; i < 5000; i++ {
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
assert.Nil(t, s.AddMessage(model.NewDefaultMessage("mytopic", "test message")))
|
|
||||||
wg.Done()
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
func testCacheMessagesScheduled(t *testing.T, s message.Store) {
|
|
||||||
m1 := model.NewDefaultMessage("mytopic", "message 1")
|
|
||||||
m2 := model.NewDefaultMessage("mytopic", "message 2")
|
|
||||||
m2.Time = time.Now().Add(time.Hour).Unix()
|
|
||||||
m3 := model.NewDefaultMessage("mytopic", "message 3")
|
|
||||||
m3.Time = time.Now().Add(time.Minute).Unix() // earlier than m2!
|
|
||||||
m4 := model.NewDefaultMessage("mytopic2", "message 4")
|
|
||||||
m4.Time = time.Now().Add(time.Minute).Unix()
|
|
||||||
require.Nil(t, s.AddMessage(m1))
|
|
||||||
require.Nil(t, s.AddMessage(m2))
|
|
||||||
require.Nil(t, s.AddMessage(m3))
|
|
||||||
|
|
||||||
messages, _ := s.Messages("mytopic", model.SinceAllMessages, false) // exclude scheduled
|
|
||||||
require.Equal(t, 1, len(messages))
|
|
||||||
require.Equal(t, "message 1", messages[0].Message)
|
|
||||||
|
|
||||||
messages, _ = s.Messages("mytopic", model.SinceAllMessages, true) // include scheduled
|
|
||||||
require.Equal(t, 3, len(messages))
|
|
||||||
require.Equal(t, "message 1", messages[0].Message)
|
|
||||||
require.Equal(t, "message 3", messages[1].Message) // Order!
|
|
||||||
require.Equal(t, "message 2", messages[2].Message)
|
|
||||||
|
|
||||||
messages, _ = s.MessagesDue()
|
|
||||||
require.Empty(t, messages)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testCacheTopics(t *testing.T, s message.Store) {
|
|
||||||
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic1", "my example message")))
|
|
||||||
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 1")))
|
|
||||||
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 2")))
|
|
||||||
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 3")))
|
|
||||||
|
|
||||||
topics, err := s.Topics()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
require.Equal(t, 2, len(topics))
|
|
||||||
require.Contains(t, topics, "topic1")
|
|
||||||
require.Contains(t, topics, "topic2")
|
|
||||||
}
|
|
||||||
|
|
||||||
func testCacheMessagesTagsPrioAndTitle(t *testing.T, s message.Store) {
|
|
||||||
m := model.NewDefaultMessage("mytopic", "some message")
|
|
||||||
m.Tags = []string{"tag1", "tag2"}
|
|
||||||
m.Priority = 5
|
|
||||||
m.Title = "some title"
|
|
||||||
require.Nil(t, s.AddMessage(m))
|
|
||||||
|
|
||||||
messages, _ := s.Messages("mytopic", model.SinceAllMessages, false)
|
|
||||||
require.Equal(t, []string{"tag1", "tag2"}, messages[0].Tags)
|
|
||||||
require.Equal(t, 5, messages[0].Priority)
|
|
||||||
require.Equal(t, "some title", messages[0].Title)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testCacheMessagesSinceID(t *testing.T, s message.Store) {
|
|
||||||
m1 := model.NewDefaultMessage("mytopic", "message 1")
|
|
||||||
m1.Time = 100
|
|
||||||
m2 := model.NewDefaultMessage("mytopic", "message 2")
|
|
||||||
m2.Time = 200
|
|
||||||
m3 := model.NewDefaultMessage("mytopic", "message 3")
|
|
||||||
m3.Time = time.Now().Add(time.Hour).Unix() // Scheduled, in the future, later than m7 and m5
|
|
||||||
m4 := model.NewDefaultMessage("mytopic", "message 4")
|
|
||||||
m4.Time = 400
|
|
||||||
m5 := model.NewDefaultMessage("mytopic", "message 5")
|
|
||||||
m5.Time = time.Now().Add(time.Minute).Unix() // Scheduled, in the future, later than m7
|
|
||||||
m6 := model.NewDefaultMessage("mytopic", "message 6")
|
|
||||||
m6.Time = 600
|
|
||||||
m7 := model.NewDefaultMessage("mytopic", "message 7")
|
|
||||||
m7.Time = 700
|
|
||||||
|
|
||||||
require.Nil(t, s.AddMessage(m1))
|
|
||||||
require.Nil(t, s.AddMessage(m2))
|
|
||||||
require.Nil(t, s.AddMessage(m3))
|
|
||||||
require.Nil(t, s.AddMessage(m4))
|
|
||||||
require.Nil(t, s.AddMessage(m5))
|
|
||||||
require.Nil(t, s.AddMessage(m6))
|
|
||||||
require.Nil(t, s.AddMessage(m7))
|
|
||||||
|
|
||||||
// Case 1: Since ID exists, exclude scheduled
|
|
||||||
messages, _ := s.Messages("mytopic", model.NewSinceID(m2.ID), false)
|
|
||||||
require.Equal(t, 3, len(messages))
|
|
||||||
require.Equal(t, "message 4", messages[0].Message)
|
|
||||||
require.Equal(t, "message 6", messages[1].Message) // Not scheduled m3/m5!
|
|
||||||
require.Equal(t, "message 7", messages[2].Message)
|
|
||||||
|
|
||||||
// Case 2: Since ID exists, include scheduled
|
|
||||||
messages, _ = s.Messages("mytopic", model.NewSinceID(m2.ID), true)
|
|
||||||
require.Equal(t, 5, len(messages))
|
|
||||||
require.Equal(t, "message 4", messages[0].Message)
|
|
||||||
require.Equal(t, "message 6", messages[1].Message)
|
|
||||||
require.Equal(t, "message 7", messages[2].Message)
|
|
||||||
require.Equal(t, "message 5", messages[3].Message) // Order!
|
|
||||||
require.Equal(t, "message 3", messages[4].Message) // Order!
|
|
||||||
|
|
||||||
// Case 3: Since ID does not exist (-> Return all messages), include scheduled
|
|
||||||
messages, _ = s.Messages("mytopic", model.NewSinceID("doesntexist"), true)
|
|
||||||
require.Equal(t, 7, len(messages))
|
|
||||||
require.Equal(t, "message 1", messages[0].Message)
|
|
||||||
require.Equal(t, "message 2", messages[1].Message)
|
|
||||||
require.Equal(t, "message 4", messages[2].Message)
|
|
||||||
require.Equal(t, "message 6", messages[3].Message)
|
|
||||||
require.Equal(t, "message 7", messages[4].Message)
|
|
||||||
require.Equal(t, "message 5", messages[5].Message) // Order!
|
|
||||||
require.Equal(t, "message 3", messages[6].Message) // Order!
|
|
||||||
|
|
||||||
// Case 4: Since ID exists and is last message (-> Return no messages), exclude scheduled
|
|
||||||
messages, _ = s.Messages("mytopic", model.NewSinceID(m7.ID), false)
|
|
||||||
require.Equal(t, 0, len(messages))
|
|
||||||
|
|
||||||
// Case 5: Since ID exists and is last message (-> Return no messages), include scheduled
|
|
||||||
messages, _ = s.Messages("mytopic", model.NewSinceID(m7.ID), true)
|
|
||||||
require.Equal(t, 2, len(messages))
|
|
||||||
require.Equal(t, "message 5", messages[0].Message)
|
|
||||||
require.Equal(t, "message 3", messages[1].Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testCachePrune(t *testing.T, s message.Store) {
|
|
||||||
now := time.Now().Unix()
|
|
||||||
|
|
||||||
m1 := model.NewDefaultMessage("mytopic", "my message")
|
|
||||||
m1.Time = now - 10
|
|
||||||
m1.Expires = now - 5
|
|
||||||
|
|
||||||
m2 := model.NewDefaultMessage("mytopic", "my other message")
|
|
||||||
m2.Time = now - 5
|
|
||||||
m2.Expires = now + 5 // In the future
|
|
||||||
|
|
||||||
m3 := model.NewDefaultMessage("another_topic", "and another one")
|
|
||||||
m3.Time = now - 12
|
|
||||||
m3.Expires = now - 2
|
|
||||||
|
|
||||||
require.Nil(t, s.AddMessage(m1))
|
|
||||||
require.Nil(t, s.AddMessage(m2))
|
|
||||||
require.Nil(t, s.AddMessage(m3))
|
|
||||||
|
|
||||||
counts, err := s.MessageCounts()
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 2, counts["mytopic"])
|
|
||||||
require.Equal(t, 1, counts["another_topic"])
|
|
||||||
|
|
||||||
expiredMessageIDs, err := s.MessagesExpired()
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Nil(t, s.DeleteMessages(expiredMessageIDs...))
|
|
||||||
|
|
||||||
counts, err = s.MessageCounts()
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 1, counts["mytopic"])
|
|
||||||
require.Equal(t, 0, counts["another_topic"])
|
|
||||||
|
|
||||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 1, len(messages))
|
|
||||||
require.Equal(t, "my other message", messages[0].Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testCacheAttachments(t *testing.T, s message.Store) {
|
|
||||||
expires1 := time.Now().Add(-4 * time.Hour).Unix() // Expired
|
|
||||||
m := model.NewDefaultMessage("mytopic", "flower for you")
|
|
||||||
m.ID = "m1"
|
|
||||||
m.SequenceID = "m1"
|
|
||||||
m.Sender = netip.MustParseAddr("1.2.3.4")
|
|
||||||
m.Attachment = &model.Attachment{
|
|
||||||
Name: "flower.jpg",
|
|
||||||
Type: "image/jpeg",
|
|
||||||
Size: 5000,
|
|
||||||
Expires: expires1,
|
|
||||||
URL: "https://ntfy.sh/file/AbDeFgJhal.jpg",
|
|
||||||
}
|
|
||||||
require.Nil(t, s.AddMessage(m))
|
|
||||||
|
|
||||||
expires2 := time.Now().Add(2 * time.Hour).Unix() // Future
|
|
||||||
m = model.NewDefaultMessage("mytopic", "sending you a car")
|
|
||||||
m.ID = "m2"
|
|
||||||
m.SequenceID = "m2"
|
|
||||||
m.Sender = netip.MustParseAddr("1.2.3.4")
|
|
||||||
m.Attachment = &model.Attachment{
|
|
||||||
Name: "car.jpg",
|
|
||||||
Type: "image/jpeg",
|
|
||||||
Size: 10000,
|
|
||||||
Expires: expires2,
|
|
||||||
URL: "https://ntfy.sh/file/aCaRURL.jpg",
|
|
||||||
}
|
|
||||||
require.Nil(t, s.AddMessage(m))
|
|
||||||
|
|
||||||
expires3 := time.Now().Add(1 * time.Hour).Unix() // Future
|
|
||||||
m = model.NewDefaultMessage("another-topic", "sending you another car")
|
|
||||||
m.ID = "m3"
|
|
||||||
m.SequenceID = "m3"
|
|
||||||
m.User = "u_BAsbaAa"
|
|
||||||
m.Sender = netip.MustParseAddr("5.6.7.8")
|
|
||||||
m.Attachment = &model.Attachment{
|
|
||||||
Name: "another-car.jpg",
|
|
||||||
Type: "image/jpeg",
|
|
||||||
Size: 20000,
|
|
||||||
Expires: expires3,
|
|
||||||
URL: "https://ntfy.sh/file/zakaDHFW.jpg",
|
|
||||||
}
|
|
||||||
require.Nil(t, s.AddMessage(m))
|
|
||||||
|
|
||||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 2, len(messages))
|
|
||||||
|
|
||||||
require.Equal(t, "flower for you", messages[0].Message)
|
|
||||||
require.Equal(t, "flower.jpg", messages[0].Attachment.Name)
|
|
||||||
require.Equal(t, "image/jpeg", messages[0].Attachment.Type)
|
|
||||||
require.Equal(t, int64(5000), messages[0].Attachment.Size)
|
|
||||||
require.Equal(t, expires1, messages[0].Attachment.Expires)
|
|
||||||
require.Equal(t, "https://ntfy.sh/file/AbDeFgJhal.jpg", messages[0].Attachment.URL)
|
|
||||||
require.Equal(t, "1.2.3.4", messages[0].Sender.String())
|
|
||||||
|
|
||||||
require.Equal(t, "sending you a car", messages[1].Message)
|
|
||||||
require.Equal(t, "car.jpg", messages[1].Attachment.Name)
|
|
||||||
require.Equal(t, "image/jpeg", messages[1].Attachment.Type)
|
|
||||||
require.Equal(t, int64(10000), messages[1].Attachment.Size)
|
|
||||||
require.Equal(t, expires2, messages[1].Attachment.Expires)
|
|
||||||
require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL)
|
|
||||||
require.Equal(t, "1.2.3.4", messages[1].Sender.String())
|
|
||||||
|
|
||||||
size, err := s.AttachmentBytesUsedBySender("1.2.3.4")
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, int64(10000), size)
|
|
||||||
|
|
||||||
size, err = s.AttachmentBytesUsedBySender("5.6.7.8")
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, int64(0), size) // Accounted to the user, not the IP!
|
|
||||||
|
|
||||||
size, err = s.AttachmentBytesUsedByUser("u_BAsbaAa")
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, int64(20000), size)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testCacheAttachmentsExpired(t *testing.T, s message.Store) {
|
|
||||||
m := model.NewDefaultMessage("mytopic", "flower for you")
|
|
||||||
m.ID = "m1"
|
|
||||||
m.SequenceID = "m1"
|
|
||||||
m.Expires = time.Now().Add(time.Hour).Unix()
|
|
||||||
require.Nil(t, s.AddMessage(m))
|
|
||||||
|
|
||||||
m = model.NewDefaultMessage("mytopic", "message with attachment")
|
|
||||||
m.ID = "m2"
|
|
||||||
m.SequenceID = "m2"
|
|
||||||
m.Expires = time.Now().Add(2 * time.Hour).Unix()
|
|
||||||
m.Attachment = &model.Attachment{
|
|
||||||
Name: "car.jpg",
|
|
||||||
Type: "image/jpeg",
|
|
||||||
Size: 10000,
|
|
||||||
Expires: time.Now().Add(2 * time.Hour).Unix(),
|
|
||||||
URL: "https://ntfy.sh/file/aCaRURL.jpg",
|
|
||||||
}
|
|
||||||
require.Nil(t, s.AddMessage(m))
|
|
||||||
|
|
||||||
m = model.NewDefaultMessage("mytopic", "message with external attachment")
|
|
||||||
m.ID = "m3"
|
|
||||||
m.SequenceID = "m3"
|
|
||||||
m.Expires = time.Now().Add(2 * time.Hour).Unix()
|
|
||||||
m.Attachment = &model.Attachment{
|
|
||||||
Name: "car.jpg",
|
|
||||||
Type: "image/jpeg",
|
|
||||||
Expires: 0, // Unknown!
|
|
||||||
URL: "https://somedomain.com/car.jpg",
|
|
||||||
}
|
|
||||||
require.Nil(t, s.AddMessage(m))
|
|
||||||
|
|
||||||
m = model.NewDefaultMessage("mytopic2", "message with expired attachment")
|
|
||||||
m.ID = "m4"
|
|
||||||
m.SequenceID = "m4"
|
|
||||||
m.Expires = time.Now().Add(2 * time.Hour).Unix()
|
|
||||||
m.Attachment = &model.Attachment{
|
|
||||||
Name: "expired-car.jpg",
|
|
||||||
Type: "image/jpeg",
|
|
||||||
Size: 20000,
|
|
||||||
Expires: time.Now().Add(-1 * time.Hour).Unix(),
|
|
||||||
URL: "https://ntfy.sh/file/aCaRURL.jpg",
|
|
||||||
}
|
|
||||||
require.Nil(t, s.AddMessage(m))
|
|
||||||
|
|
||||||
ids, err := s.AttachmentsExpired()
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 1, len(ids))
|
|
||||||
require.Equal(t, "m4", ids[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
func testSender(t *testing.T, s message.Store) {
|
|
||||||
m1 := model.NewDefaultMessage("mytopic", "mymessage")
|
|
||||||
m1.Sender = netip.MustParseAddr("1.2.3.4")
|
|
||||||
require.Nil(t, s.AddMessage(m1))
|
|
||||||
|
|
||||||
m2 := model.NewDefaultMessage("mytopic", "mymessage without sender")
|
|
||||||
require.Nil(t, s.AddMessage(m2))
|
|
||||||
|
|
||||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 2, len(messages))
|
|
||||||
require.Equal(t, messages[0].Sender, netip.MustParseAddr("1.2.3.4"))
|
|
||||||
require.Equal(t, messages[1].Sender, netip.Addr{})
|
|
||||||
}
|
|
||||||
|
|
||||||
func testDeleteScheduledBySequenceID(t *testing.T, s message.Store) {
|
|
||||||
// Create a scheduled (unpublished) message
|
|
||||||
scheduledMsg := model.NewDefaultMessage("mytopic", "scheduled message")
|
|
||||||
scheduledMsg.ID = "scheduled1"
|
|
||||||
scheduledMsg.SequenceID = "seq123"
|
|
||||||
scheduledMsg.Time = time.Now().Add(time.Hour).Unix() // Future time makes it scheduled
|
|
||||||
require.Nil(t, s.AddMessage(scheduledMsg))
|
|
||||||
|
|
||||||
// Create a published message with different sequence ID
|
|
||||||
publishedMsg := model.NewDefaultMessage("mytopic", "published message")
|
|
||||||
publishedMsg.ID = "published1"
|
|
||||||
publishedMsg.SequenceID = "seq456"
|
|
||||||
publishedMsg.Time = time.Now().Add(-time.Hour).Unix() // Past time makes it published
|
|
||||||
require.Nil(t, s.AddMessage(publishedMsg))
|
|
||||||
|
|
||||||
// Create a scheduled message in a different topic
|
|
||||||
otherTopicMsg := model.NewDefaultMessage("othertopic", "other scheduled")
|
|
||||||
otherTopicMsg.ID = "other1"
|
|
||||||
otherTopicMsg.SequenceID = "seq123" // Same sequence ID as scheduledMsg
|
|
||||||
otherTopicMsg.Time = time.Now().Add(time.Hour).Unix()
|
|
||||||
require.Nil(t, s.AddMessage(otherTopicMsg))
|
|
||||||
|
|
||||||
// Verify all messages exist (including scheduled)
|
|
||||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, true)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 2, len(messages))
|
|
||||||
|
|
||||||
messages, err = s.Messages("othertopic", model.SinceAllMessages, true)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 1, len(messages))
|
|
||||||
|
|
||||||
// Delete scheduled message by sequence ID and verify returned IDs
|
|
||||||
deletedIDs, err := s.DeleteScheduledBySequenceID("mytopic", "seq123")
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 1, len(deletedIDs))
|
|
||||||
require.Equal(t, "scheduled1", deletedIDs[0])
|
|
||||||
|
|
||||||
// Verify scheduled message is deleted
|
|
||||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, true)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 1, len(messages))
|
|
||||||
require.Equal(t, "published message", messages[0].Message)
|
|
||||||
|
|
||||||
// Verify other topic's message still exists (topic-scoped deletion)
|
|
||||||
messages, err = s.Messages("othertopic", model.SinceAllMessages, true)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 1, len(messages))
|
|
||||||
require.Equal(t, "other scheduled", messages[0].Message)
|
|
||||||
|
|
||||||
// Deleting non-existent sequence ID should return empty list
|
|
||||||
deletedIDs, err = s.DeleteScheduledBySequenceID("mytopic", "nonexistent")
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Empty(t, deletedIDs)
|
|
||||||
|
|
||||||
// Deleting published message should not affect it (only deletes unpublished)
|
|
||||||
deletedIDs, err = s.DeleteScheduledBySequenceID("mytopic", "seq456")
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Empty(t, deletedIDs)
|
|
||||||
|
|
||||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, true)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 1, len(messages))
|
|
||||||
require.Equal(t, "published message", messages[0].Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testMessageByID(t *testing.T, s message.Store) {
|
|
||||||
// Add a message
|
|
||||||
m := model.NewDefaultMessage("mytopic", "some message")
|
|
||||||
m.Title = "some title"
|
|
||||||
m.Priority = 4
|
|
||||||
m.Tags = []string{"tag1", "tag2"}
|
|
||||||
require.Nil(t, s.AddMessage(m))
|
|
||||||
|
|
||||||
// Retrieve by ID
|
|
||||||
retrieved, err := s.Message(m.ID)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, m.ID, retrieved.ID)
|
|
||||||
require.Equal(t, "mytopic", retrieved.Topic)
|
|
||||||
require.Equal(t, "some message", retrieved.Message)
|
|
||||||
require.Equal(t, "some title", retrieved.Title)
|
|
||||||
require.Equal(t, 4, retrieved.Priority)
|
|
||||||
require.Equal(t, []string{"tag1", "tag2"}, retrieved.Tags)
|
|
||||||
|
|
||||||
// Non-existent ID returns ErrMessageNotFound
|
|
||||||
_, err = s.Message("doesnotexist")
|
|
||||||
require.Equal(t, model.ErrMessageNotFound, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testMarkPublished(t *testing.T, s message.Store) {
|
|
||||||
// Add a scheduled message (future time → unpublished)
|
|
||||||
m := model.NewDefaultMessage("mytopic", "scheduled message")
|
|
||||||
m.Time = time.Now().Add(time.Hour).Unix()
|
|
||||||
require.Nil(t, s.AddMessage(m))
|
|
||||||
|
|
||||||
// Verify it does not appear in non-scheduled queries
|
|
||||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 0, len(messages))
|
|
||||||
|
|
||||||
// Verify it does appear in scheduled queries
|
|
||||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, true)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 1, len(messages))
|
|
||||||
|
|
||||||
// Mark as published
|
|
||||||
require.Nil(t, s.MarkPublished(m))
|
|
||||||
|
|
||||||
// Now it should appear in non-scheduled queries too
|
|
||||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 1, len(messages))
|
|
||||||
require.Equal(t, "scheduled message", messages[0].Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testExpireMessages(t *testing.T, s message.Store) {
|
|
||||||
// Add messages to two topics
|
|
||||||
m1 := model.NewDefaultMessage("topic1", "message 1")
|
|
||||||
m1.Expires = time.Now().Add(time.Hour).Unix()
|
|
||||||
m2 := model.NewDefaultMessage("topic1", "message 2")
|
|
||||||
m2.Expires = time.Now().Add(time.Hour).Unix()
|
|
||||||
m3 := model.NewDefaultMessage("topic2", "message 3")
|
|
||||||
m3.Expires = time.Now().Add(time.Hour).Unix()
|
|
||||||
require.Nil(t, s.AddMessage(m1))
|
|
||||||
require.Nil(t, s.AddMessage(m2))
|
|
||||||
require.Nil(t, s.AddMessage(m3))
|
|
||||||
|
|
||||||
// Verify all messages exist
|
|
||||||
messages, err := s.Messages("topic1", model.SinceAllMessages, false)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 2, len(messages))
|
|
||||||
messages, err = s.Messages("topic2", model.SinceAllMessages, false)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 1, len(messages))
|
|
||||||
|
|
||||||
// Expire topic1 messages
|
|
||||||
require.Nil(t, s.ExpireMessages("topic1"))
|
|
||||||
|
|
||||||
// topic1 messages should now be expired (expires set to past)
|
|
||||||
expiredIDs, err := s.MessagesExpired()
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 2, len(expiredIDs))
|
|
||||||
sort.Strings(expiredIDs)
|
|
||||||
expectedIDs := []string{m1.ID, m2.ID}
|
|
||||||
sort.Strings(expectedIDs)
|
|
||||||
require.Equal(t, expectedIDs, expiredIDs)
|
|
||||||
|
|
||||||
// topic2 should be unaffected
|
|
||||||
messages, err = s.Messages("topic2", model.SinceAllMessages, false)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 1, len(messages))
|
|
||||||
require.Equal(t, "message 3", messages[0].Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testMarkAttachmentsDeleted(t *testing.T, s message.Store) {
|
|
||||||
// Add a message with an expired attachment (file needs cleanup)
|
|
||||||
m1 := model.NewDefaultMessage("mytopic", "old file")
|
|
||||||
m1.ID = "msg1"
|
|
||||||
m1.SequenceID = "msg1"
|
|
||||||
m1.Expires = time.Now().Add(time.Hour).Unix()
|
|
||||||
m1.Attachment = &model.Attachment{
|
|
||||||
Name: "old.pdf",
|
|
||||||
Type: "application/pdf",
|
|
||||||
Size: 50000,
|
|
||||||
Expires: time.Now().Add(-time.Hour).Unix(), // Expired
|
|
||||||
URL: "https://ntfy.sh/file/old.pdf",
|
|
||||||
}
|
|
||||||
require.Nil(t, s.AddMessage(m1))
|
|
||||||
|
|
||||||
// Add a message with another expired attachment
|
|
||||||
m2 := model.NewDefaultMessage("mytopic", "another old file")
|
|
||||||
m2.ID = "msg2"
|
|
||||||
m2.SequenceID = "msg2"
|
|
||||||
m2.Expires = time.Now().Add(time.Hour).Unix()
|
|
||||||
m2.Attachment = &model.Attachment{
|
|
||||||
Name: "another.pdf",
|
|
||||||
Type: "application/pdf",
|
|
||||||
Size: 30000,
|
|
||||||
Expires: time.Now().Add(-time.Hour).Unix(), // Expired
|
|
||||||
URL: "https://ntfy.sh/file/another.pdf",
|
|
||||||
}
|
|
||||||
require.Nil(t, s.AddMessage(m2))
|
|
||||||
|
|
||||||
// Both should show as expired attachments needing cleanup
|
|
||||||
ids, err := s.AttachmentsExpired()
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 2, len(ids))
|
|
||||||
|
|
||||||
// Mark msg1's attachment as deleted (file cleaned up)
|
|
||||||
require.Nil(t, s.MarkAttachmentsDeleted("msg1"))
|
|
||||||
|
|
||||||
// Now only msg2 should show as needing cleanup
|
|
||||||
ids, err = s.AttachmentsExpired()
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 1, len(ids))
|
|
||||||
require.Equal(t, "msg2", ids[0])
|
|
||||||
|
|
||||||
// Mark msg2 too
|
|
||||||
require.Nil(t, s.MarkAttachmentsDeleted("msg2"))
|
|
||||||
|
|
||||||
// No more expired attachments to clean up
|
|
||||||
ids, err = s.AttachmentsExpired()
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 0, len(ids))
|
|
||||||
|
|
||||||
// Messages themselves still exist
|
|
||||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 2, len(messages))
|
|
||||||
}
|
|
||||||
|
|
||||||
func testStats(t *testing.T, s message.Store) {
|
|
||||||
// Initial stats should be zero
|
|
||||||
messages, err := s.Stats()
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, int64(0), messages)
|
|
||||||
|
|
||||||
// Update stats
|
|
||||||
require.Nil(t, s.UpdateStats(42))
|
|
||||||
messages, err = s.Stats()
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, int64(42), messages)
|
|
||||||
|
|
||||||
// Update again (overwrites)
|
|
||||||
require.Nil(t, s.UpdateStats(100))
|
|
||||||
messages, err = s.Stats()
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, int64(100), messages)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testAddMessages(t *testing.T, s message.Store) {
|
|
||||||
// Batch add multiple messages
|
|
||||||
msgs := []*model.Message{
|
|
||||||
model.NewDefaultMessage("mytopic", "batch 1"),
|
|
||||||
model.NewDefaultMessage("mytopic", "batch 2"),
|
|
||||||
model.NewDefaultMessage("othertopic", "batch 3"),
|
|
||||||
}
|
|
||||||
require.Nil(t, s.AddMessages(msgs))
|
|
||||||
|
|
||||||
// Verify all were inserted
|
|
||||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 2, len(messages))
|
|
||||||
|
|
||||||
messages, err = s.Messages("othertopic", model.SinceAllMessages, false)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 1, len(messages))
|
|
||||||
require.Equal(t, "batch 3", messages[0].Message)
|
|
||||||
|
|
||||||
// Empty batch should succeed
|
|
||||||
require.Nil(t, s.AddMessages([]*model.Message{}))
|
|
||||||
|
|
||||||
// Batch with invalid event type should fail
|
|
||||||
badMsgs := []*model.Message{
|
|
||||||
model.NewKeepaliveMessage("mytopic"),
|
|
||||||
}
|
|
||||||
require.NotNil(t, s.AddMessages(badMsgs))
|
|
||||||
}
|
|
||||||
|
|
||||||
func testMessagesDue(t *testing.T, s message.Store) {
|
|
||||||
// Add a message scheduled in the past (i.e. it's due now)
|
|
||||||
m1 := model.NewDefaultMessage("mytopic", "due message")
|
|
||||||
m1.Time = time.Now().Add(-time.Second).Unix()
|
|
||||||
// Set expires in the future so it doesn't get pruned
|
|
||||||
m1.Expires = time.Now().Add(time.Hour).Unix()
|
|
||||||
require.Nil(t, s.AddMessage(m1))
|
|
||||||
|
|
||||||
// Add a message scheduled in the future (not due)
|
|
||||||
m2 := model.NewDefaultMessage("mytopic", "future message")
|
|
||||||
m2.Time = time.Now().Add(time.Hour).Unix()
|
|
||||||
require.Nil(t, s.AddMessage(m2))
|
|
||||||
|
|
||||||
// Mark m1 as published so it won't be "due"
|
|
||||||
// (MessagesDue returns unpublished messages whose time <= now)
|
|
||||||
// m1 is auto-published (time <= now), so it should not be due
|
|
||||||
// m2 is unpublished (time in future), not due yet
|
|
||||||
due, err := s.MessagesDue()
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 0, len(due))
|
|
||||||
|
|
||||||
// Add a message that was explicitly scheduled in the past but time has "arrived"
|
|
||||||
// We need to manipulate the database to create a truly "due" message:
|
|
||||||
// a message with published=false and time <= now
|
|
||||||
m3 := model.NewDefaultMessage("mytopic", "truly due message")
|
|
||||||
m3.Time = time.Now().Add(2 * time.Second).Unix() // 2 seconds from now
|
|
||||||
require.Nil(t, s.AddMessage(m3))
|
|
||||||
|
|
||||||
// Not due yet
|
|
||||||
due, err = s.MessagesDue()
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 0, len(due))
|
|
||||||
|
|
||||||
// Wait for it to become due
|
|
||||||
time.Sleep(3 * time.Second)
|
|
||||||
|
|
||||||
due, err = s.MessagesDue()
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 1, len(due))
|
|
||||||
require.Equal(t, "truly due message", due[0].Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testMessageFieldRoundTrip(t *testing.T, s message.Store) {
|
|
||||||
// Create a message with all fields populated
|
|
||||||
m := model.NewDefaultMessage("mytopic", "hello world")
|
|
||||||
m.SequenceID = "custom_seq_id"
|
|
||||||
m.Title = "A Title"
|
|
||||||
m.Priority = 4
|
|
||||||
m.Tags = []string{"warning", "srv01"}
|
|
||||||
m.Click = "https://example.com/click"
|
|
||||||
m.Icon = "https://example.com/icon.png"
|
|
||||||
m.Actions = []*model.Action{
|
|
||||||
{
|
|
||||||
ID: "action1",
|
|
||||||
Action: "view",
|
|
||||||
Label: "Open Site",
|
|
||||||
URL: "https://example.com",
|
|
||||||
Clear: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "action2",
|
|
||||||
Action: "http",
|
|
||||||
Label: "Call Webhook",
|
|
||||||
URL: "https://example.com/hook",
|
|
||||||
Method: "PUT",
|
|
||||||
Headers: map[string]string{"X-Token": "secret"},
|
|
||||||
Body: `{"key":"value"}`,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
m.ContentType = "text/markdown"
|
|
||||||
m.Encoding = "base64"
|
|
||||||
m.Sender = netip.MustParseAddr("9.8.7.6")
|
|
||||||
m.User = "u_TestUser123"
|
|
||||||
require.Nil(t, s.AddMessage(m))
|
|
||||||
|
|
||||||
// Retrieve and verify every field
|
|
||||||
retrieved, err := s.Message(m.ID)
|
|
||||||
require.Nil(t, err)
|
|
||||||
|
|
||||||
require.Equal(t, m.ID, retrieved.ID)
|
|
||||||
require.Equal(t, "custom_seq_id", retrieved.SequenceID)
|
|
||||||
require.Equal(t, m.Time, retrieved.Time)
|
|
||||||
require.Equal(t, m.Expires, retrieved.Expires)
|
|
||||||
require.Equal(t, model.MessageEvent, retrieved.Event)
|
|
||||||
require.Equal(t, "mytopic", retrieved.Topic)
|
|
||||||
require.Equal(t, "hello world", retrieved.Message)
|
|
||||||
require.Equal(t, "A Title", retrieved.Title)
|
|
||||||
require.Equal(t, 4, retrieved.Priority)
|
|
||||||
require.Equal(t, []string{"warning", "srv01"}, retrieved.Tags)
|
|
||||||
require.Equal(t, "https://example.com/click", retrieved.Click)
|
|
||||||
require.Equal(t, "https://example.com/icon.png", retrieved.Icon)
|
|
||||||
require.Equal(t, "text/markdown", retrieved.ContentType)
|
|
||||||
require.Equal(t, "base64", retrieved.Encoding)
|
|
||||||
require.Equal(t, netip.MustParseAddr("9.8.7.6"), retrieved.Sender)
|
|
||||||
require.Equal(t, "u_TestUser123", retrieved.User)
|
|
||||||
|
|
||||||
// Verify actions round-trip
|
|
||||||
require.Equal(t, 2, len(retrieved.Actions))
|
|
||||||
|
|
||||||
require.Equal(t, "action1", retrieved.Actions[0].ID)
|
|
||||||
require.Equal(t, "view", retrieved.Actions[0].Action)
|
|
||||||
require.Equal(t, "Open Site", retrieved.Actions[0].Label)
|
|
||||||
require.Equal(t, "https://example.com", retrieved.Actions[0].URL)
|
|
||||||
require.Equal(t, true, retrieved.Actions[0].Clear)
|
|
||||||
|
|
||||||
require.Equal(t, "action2", retrieved.Actions[1].ID)
|
|
||||||
require.Equal(t, "http", retrieved.Actions[1].Action)
|
|
||||||
require.Equal(t, "Call Webhook", retrieved.Actions[1].Label)
|
|
||||||
require.Equal(t, "https://example.com/hook", retrieved.Actions[1].URL)
|
|
||||||
require.Equal(t, "PUT", retrieved.Actions[1].Method)
|
|
||||||
require.Equal(t, "secret", retrieved.Actions[1].Headers["X-Token"])
|
|
||||||
require.Equal(t, `{"key":"value"}`, retrieved.Actions[1].Body)
|
|
||||||
}
|
|
||||||
@@ -146,6 +146,13 @@ func NewActionMessage(event, topic, sequenceID string) *Message {
|
|||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewPollRequestMessage is a convenience method to create a poll request message
|
||||||
|
func NewPollRequestMessage(topic, pollID string) *Message {
|
||||||
|
m := NewMessage(PollRequestEvent, topic, "New message")
|
||||||
|
m.PollID = pollID
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
// ValidMessageID returns true if the given string is a valid message ID
|
// ValidMessageID returns true if the given string is a valid message ID
|
||||||
func ValidMessageID(s string) bool {
|
func ValidMessageID(s string) bool {
|
||||||
return util.ValidRandomString(s, MessageIDLength)
|
return util.ValidRandomString(s, MessageIDLength)
|
||||||
@@ -184,7 +191,7 @@ func (t SinceMarker) IsLatest() bool {
|
|||||||
|
|
||||||
// IsID returns true if this marker references a specific message ID
|
// IsID returns true if this marker references a specific message ID
|
||||||
func (t SinceMarker) IsID() bool {
|
func (t SinceMarker) IsID() bool {
|
||||||
return t.id != "" && t.id != "latest"
|
return t.id != "" && t.id != SinceLatestMessage.id
|
||||||
}
|
}
|
||||||
|
|
||||||
// Time returns the time component of the marker
|
// Time returns the time component of the marker
|
||||||
|
|||||||
@@ -136,7 +136,7 @@ func (p *actionParser) Parse() ([]*model.Action, error) {
|
|||||||
// and then uses populateAction to interpret the keys/values. The function terminates
|
// and then uses populateAction to interpret the keys/values. The function terminates
|
||||||
// when EOF or ";" is reached.
|
// when EOF or ";" is reached.
|
||||||
func (p *actionParser) parseAction() (*model.Action, error) {
|
func (p *actionParser) parseAction() (*model.Action, error) {
|
||||||
a := newAction()
|
a := model.NewAction()
|
||||||
section := 0
|
section := 0
|
||||||
for {
|
for {
|
||||||
key, value, last, err := p.parseSection()
|
key, value, last, err := p.parseSection()
|
||||||
|
|||||||
@@ -88,7 +88,6 @@ var (
|
|||||||
// Config is the main config struct for the application. Use New to instantiate a default config struct.
|
// Config is the main config struct for the application. Use New to instantiate a default config struct.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
File string // Config file, only used for testing
|
File string // Config file, only used for testing
|
||||||
DatabaseURL string // PostgreSQL connection string (e.g. "postgres://user:pass@host:5432/ntfy")
|
|
||||||
BaseURL string
|
BaseURL string
|
||||||
ListenHTTP string
|
ListenHTTP string
|
||||||
ListenHTTPS string
|
ListenHTTPS string
|
||||||
@@ -96,6 +95,7 @@ type Config struct {
|
|||||||
ListenUnixMode fs.FileMode
|
ListenUnixMode fs.FileMode
|
||||||
KeyFile string
|
KeyFile string
|
||||||
CertFile string
|
CertFile string
|
||||||
|
DatabaseURL string // PostgreSQL connection string (e.g. "postgres://user:pass@host:5432/ntfy")
|
||||||
FirebaseKeyFile string
|
FirebaseKeyFile string
|
||||||
CacheFile string
|
CacheFile string
|
||||||
CacheDuration time.Duration
|
CacheDuration time.Duration
|
||||||
@@ -193,7 +193,6 @@ type Config struct {
|
|||||||
func NewConfig() *Config {
|
func NewConfig() *Config {
|
||||||
return &Config{
|
return &Config{
|
||||||
File: DefaultConfigFile, // Only used for testing
|
File: DefaultConfigFile, // Only used for testing
|
||||||
DatabaseURL: "",
|
|
||||||
BaseURL: "",
|
BaseURL: "",
|
||||||
ListenHTTP: DefaultListenHTTP,
|
ListenHTTP: DefaultListenHTTP,
|
||||||
ListenHTTPS: "",
|
ListenHTTPS: "",
|
||||||
@@ -201,6 +200,7 @@ func NewConfig() *Config {
|
|||||||
ListenUnixMode: 0,
|
ListenUnixMode: 0,
|
||||||
KeyFile: "",
|
KeyFile: "",
|
||||||
CertFile: "",
|
CertFile: "",
|
||||||
|
DatabaseURL: "",
|
||||||
FirebaseKeyFile: "",
|
FirebaseKeyFile: "",
|
||||||
CacheFile: "",
|
CacheFile: "",
|
||||||
CacheDuration: DefaultCacheDuration,
|
CacheDuration: DefaultCacheDuration,
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"heckel.io/ntfy/v2/log"
|
"heckel.io/ntfy/v2/log"
|
||||||
|
"heckel.io/ntfy/v2/model"
|
||||||
"heckel.io/ntfy/v2/util"
|
"heckel.io/ntfy/v2/util"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
@@ -13,7 +14,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
fileIDRegex = regexp.MustCompile(fmt.Sprintf(`^[-_A-Za-z0-9]{%d}$`, messageIDLength))
|
fileIDRegex = regexp.MustCompile(fmt.Sprintf(`^[-_A-Za-z0-9]{%d}$`, model.MessageIDLength))
|
||||||
errInvalidFileID = errors.New("invalid file ID")
|
errInvalidFileID = errors.New("invalid file ID")
|
||||||
errFileExists = errors.New("file exists")
|
errFileExists = errors.New("file exists")
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
|
"database/sql"
|
||||||
"embed"
|
"embed"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -32,6 +33,7 @@ import (
|
|||||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
"gopkg.in/yaml.v2"
|
"gopkg.in/yaml.v2"
|
||||||
|
"heckel.io/ntfy/v2/db"
|
||||||
"heckel.io/ntfy/v2/log"
|
"heckel.io/ntfy/v2/log"
|
||||||
"heckel.io/ntfy/v2/message"
|
"heckel.io/ntfy/v2/message"
|
||||||
"heckel.io/ntfy/v2/model"
|
"heckel.io/ntfy/v2/model"
|
||||||
@@ -45,6 +47,7 @@ import (
|
|||||||
// Server is the main server, providing the UI and API for ntfy
|
// Server is the main server, providing the UI and API for ntfy
|
||||||
type Server struct {
|
type Server struct {
|
||||||
config *Config
|
config *Config
|
||||||
|
db *sql.DB // Shared PostgreSQL connection pool, nil when using SQLite
|
||||||
httpServer *http.Server
|
httpServer *http.Server
|
||||||
httpsServer *http.Server
|
httpsServer *http.Server
|
||||||
httpMetricsServer *http.Server
|
httpMetricsServer *http.Server
|
||||||
@@ -59,8 +62,8 @@ type Server struct {
|
|||||||
messages int64 // Total number of messages (persisted if messageCache enabled)
|
messages int64 // Total number of messages (persisted if messageCache enabled)
|
||||||
messagesHistory []int64 // Last n values of the messages counter, used to determine rate
|
messagesHistory []int64 // Last n values of the messages counter, used to determine rate
|
||||||
userManager *user.Manager // Might be nil!
|
userManager *user.Manager // Might be nil!
|
||||||
messageCache message.Store // Database that stores the messages
|
messageCache *message.Cache // Database that stores the messages
|
||||||
webPush webpush.Store // Database that stores web push subscriptions
|
webPush *webpush.Store // Database that stores web push subscriptions
|
||||||
fileCache *fileCache // File system based cache that stores attachments
|
fileCache *fileCache // File system based cache that stores attachments
|
||||||
stripe stripeAPI // Stripe API, can be replaced with a mock
|
stripe stripeAPI // Stripe API, can be replaced with a mock
|
||||||
priceCache *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!)
|
priceCache *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!)
|
||||||
@@ -175,14 +178,23 @@ func New(conf *Config) (*Server, error) {
|
|||||||
if payments.Available && conf.StripeSecretKey != "" {
|
if payments.Available && conf.StripeSecretKey != "" {
|
||||||
stripe = newStripeAPI()
|
stripe = newStripeAPI()
|
||||||
}
|
}
|
||||||
messageCache, err := createMessageCache(conf)
|
// OpenPostgres shared PostgreSQL connection pool if configured
|
||||||
|
var pool *sql.DB
|
||||||
|
if conf.DatabaseURL != "" {
|
||||||
|
var err error
|
||||||
|
pool, err = db.OpenPostgres(conf.DatabaseURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var wp webpush.Store
|
}
|
||||||
|
messageCache, err := createMessageCache(conf, pool)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var wp *webpush.Store
|
||||||
if conf.WebPushPublicKey != "" {
|
if conf.WebPushPublicKey != "" {
|
||||||
if conf.DatabaseURL != "" {
|
if pool != nil {
|
||||||
wp, err = webpush.NewPostgresStore(conf.DatabaseURL)
|
wp, err = webpush.NewPostgresStore(pool)
|
||||||
} else {
|
} else {
|
||||||
wp, err = webpush.NewSQLiteStore(conf.WebPushFile, conf.WebPushStartupQueries)
|
wp, err = webpush.NewSQLiteStore(conf.WebPushFile, conf.WebPushStartupQueries)
|
||||||
}
|
}
|
||||||
@@ -210,7 +222,7 @@ func New(conf *Config) (*Server, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
var userManager *user.Manager
|
var userManager *user.Manager
|
||||||
if conf.AuthFile != "" || conf.DatabaseURL != "" {
|
if conf.AuthFile != "" || pool != nil {
|
||||||
authConfig := &user.Config{
|
authConfig := &user.Config{
|
||||||
Filename: conf.AuthFile,
|
Filename: conf.AuthFile,
|
||||||
DatabaseURL: conf.DatabaseURL,
|
DatabaseURL: conf.DatabaseURL,
|
||||||
@@ -223,19 +235,14 @@ func New(conf *Config) (*Server, error) {
|
|||||||
BcryptCost: conf.AuthBcryptCost,
|
BcryptCost: conf.AuthBcryptCost,
|
||||||
QueueWriterInterval: conf.AuthStatsQueueWriterInterval,
|
QueueWriterInterval: conf.AuthStatsQueueWriterInterval,
|
||||||
}
|
}
|
||||||
var store user.Store
|
if pool != nil {
|
||||||
if conf.DatabaseURL != "" {
|
userManager, err = user.NewPostgresManager(pool, authConfig)
|
||||||
store, err = user.NewPostgresStore(conf.DatabaseURL)
|
|
||||||
} else {
|
} else {
|
||||||
store, err = user.NewSQLiteStore(conf.AuthFile, conf.AuthStartupQueries)
|
userManager, err = user.NewSQLiteManager(conf.AuthFile, conf.AuthStartupQueries, authConfig)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
userManager, err = user.NewManager(store, authConfig)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
var firebaseClient *firebaseClient
|
var firebaseClient *firebaseClient
|
||||||
if conf.FirebaseKeyFile != "" {
|
if conf.FirebaseKeyFile != "" {
|
||||||
@@ -253,6 +260,7 @@ func New(conf *Config) (*Server, error) {
|
|||||||
}
|
}
|
||||||
s := &Server{
|
s := &Server{
|
||||||
config: conf,
|
config: conf,
|
||||||
|
db: pool,
|
||||||
messageCache: messageCache,
|
messageCache: messageCache,
|
||||||
webPush: wp,
|
webPush: wp,
|
||||||
fileCache: fileCache,
|
fileCache: fileCache,
|
||||||
@@ -269,11 +277,11 @@ func New(conf *Config) (*Server, error) {
|
|||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createMessageCache(conf *Config) (message.Store, error) {
|
func createMessageCache(conf *Config, pool *sql.DB) (*message.Cache, error) {
|
||||||
if conf.CacheDuration == 0 {
|
if conf.CacheDuration == 0 {
|
||||||
return message.NewNopStore()
|
return message.NewNopStore()
|
||||||
} else if conf.DatabaseURL != "" {
|
} else if pool != nil {
|
||||||
return message.NewPostgresStore(conf.DatabaseURL, conf.CacheBatchSize, conf.CacheBatchTimeout)
|
return message.NewPostgresStore(pool, conf.CacheBatchSize, conf.CacheBatchTimeout)
|
||||||
} else if conf.CacheFile != "" {
|
} else if conf.CacheFile != "" {
|
||||||
return message.NewSQLiteStore(conf.CacheFile, conf.CacheStartupQueries, conf.CacheDuration, conf.CacheBatchSize, conf.CacheBatchTimeout, false)
|
return message.NewSQLiteStore(conf.CacheFile, conf.CacheStartupQueries, conf.CacheDuration, conf.CacheBatchSize, conf.CacheBatchTimeout, false)
|
||||||
}
|
}
|
||||||
@@ -412,6 +420,9 @@ func (s *Server) closeDatabases() {
|
|||||||
if s.webPush != nil {
|
if s.webPush != nil {
|
||||||
s.webPush.Close()
|
s.webPush.Close()
|
||||||
}
|
}
|
||||||
|
if s.db != nil {
|
||||||
|
s.db.Close()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// handle is the main entry point for all HTTP requests
|
// handle is the main entry point for all HTTP requests
|
||||||
@@ -754,7 +765,7 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor)
|
|||||||
// - avoid abuse (e.g. 1 uploader, 1k downloaders)
|
// - avoid abuse (e.g. 1 uploader, 1k downloaders)
|
||||||
// - and also uses the higher bandwidth limits of a paying user
|
// - and also uses the higher bandwidth limits of a paying user
|
||||||
m, err := s.messageCache.Message(messageID)
|
m, err := s.messageCache.Message(messageID)
|
||||||
if errors.Is(err, errMessageNotFound) {
|
if errors.Is(err, model.ErrMessageNotFound) {
|
||||||
if s.config.CacheBatchTimeout > 0 {
|
if s.config.CacheBatchTimeout > 0 {
|
||||||
// Strange edge case: If we immediately after upload request the file (the web app does this for images),
|
// Strange edge case: If we immediately after upload request the file (the web app does this for images),
|
||||||
// and messages are persisted asynchronously, retry fetching from the database
|
// and messages are persisted asynchronously, retry fetching from the database
|
||||||
@@ -818,7 +829,7 @@ func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*model.Mess
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
m := newDefaultMessage(t.ID, "")
|
m := model.NewDefaultMessage(t.ID, "")
|
||||||
cache, firebase, email, call, template, unifiedpush, priorityStr, e := s.parsePublishParams(r, m)
|
cache, firebase, email, call, template, unifiedpush, priorityStr, e := s.parsePublishParams(r, m)
|
||||||
if e != nil {
|
if e != nil {
|
||||||
return nil, e.With(t)
|
return nil, e.With(t)
|
||||||
@@ -843,7 +854,7 @@ func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*model.Mess
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if m.PollID != "" {
|
if m.PollID != "" {
|
||||||
m = newPollRequestMessage(t.ID, m.PollID)
|
m = model.NewPollRequestMessage(t.ID, m.PollID)
|
||||||
}
|
}
|
||||||
m.Sender = v.IP()
|
m.Sender = v.IP()
|
||||||
m.User = v.MaybeUserID()
|
m.User = v.MaybeUserID()
|
||||||
@@ -961,11 +972,11 @@ func (s *Server) handlePublishMatrix(w http.ResponseWriter, r *http.Request, v *
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
func (s *Server) handleDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
return s.handleActionMessage(w, r, v, messageDeleteEvent)
|
return s.handleActionMessage(w, r, v, model.MessageDeleteEvent)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleClear(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
func (s *Server) handleClear(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
return s.handleActionMessage(w, r, v, messageClearEvent)
|
return s.handleActionMessage(w, r, v, model.MessageClearEvent)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleActionMessage(w http.ResponseWriter, r *http.Request, v *visitor, event string) error {
|
func (s *Server) handleActionMessage(w http.ResponseWriter, r *http.Request, v *visitor, event string) error {
|
||||||
@@ -985,7 +996,7 @@ func (s *Server) handleActionMessage(w http.ResponseWriter, r *http.Request, v *
|
|||||||
return e.With(t)
|
return e.With(t)
|
||||||
}
|
}
|
||||||
// Create an action message with the given event type
|
// Create an action message with the given event type
|
||||||
m := newActionMessage(event, t.ID, sequenceID)
|
m := model.NewActionMessage(event, t.ID, sequenceID)
|
||||||
m.Sender = v.IP()
|
m.Sender = v.IP()
|
||||||
m.User = v.MaybeUserID()
|
m.User = v.MaybeUserID()
|
||||||
m.Expires = time.Unix(m.Time, 0).Add(v.Limits().MessageExpiryDuration).Unix()
|
m.Expires = time.Unix(m.Time, 0).Add(v.Limits().MessageExpiryDuration).Unix()
|
||||||
@@ -1001,7 +1012,7 @@ func (s *Server) handleActionMessage(w http.ResponseWriter, r *http.Request, v *
|
|||||||
if s.config.WebPushPublicKey != "" {
|
if s.config.WebPushPublicKey != "" {
|
||||||
go s.publishToWebPushEndpoints(v, m)
|
go s.publishToWebPushEndpoints(v, m)
|
||||||
}
|
}
|
||||||
if event == messageDeleteEvent {
|
if event == model.MessageDeleteEvent {
|
||||||
// Delete any existing scheduled message with the same sequence ID
|
// Delete any existing scheduled message with the same sequence ID
|
||||||
deletedIDs, err := s.messageCache.DeleteScheduledBySequenceID(t.ID, sequenceID)
|
deletedIDs, err := s.messageCache.DeleteScheduledBySequenceID(t.ID, sequenceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1230,7 +1241,7 @@ func (s *Server) parsePublishParams(r *http.Request, m *model.Message) (cache bo
|
|||||||
// 7. curl -T file.txt ntfy.sh/mytopic
|
// 7. curl -T file.txt ntfy.sh/mytopic
|
||||||
// In all other cases, mostly if file.txt is > message limit, treat it as an attachment
|
// In all other cases, mostly if file.txt is > message limit, treat it as an attachment
|
||||||
func (s *Server) handlePublishBody(r *http.Request, v *visitor, m *model.Message, body *util.PeekedReadCloser, template templateMode, unifiedpush bool, priorityStr string) error {
|
func (s *Server) handlePublishBody(r *http.Request, v *visitor, m *model.Message, body *util.PeekedReadCloser, template templateMode, unifiedpush bool, priorityStr string) error {
|
||||||
if m.Event == pollRequestEvent { // Case 1
|
if m.Event == model.PollRequestEvent { // Case 1
|
||||||
return s.handleBodyDiscard(body)
|
return s.handleBodyDiscard(body)
|
||||||
} else if unifiedpush {
|
} else if unifiedpush {
|
||||||
return s.handleBodyAsMessageAutoDetect(m, body) // Case 2
|
return s.handleBodyAsMessageAutoDetect(m, body) // Case 2
|
||||||
@@ -1450,7 +1461,7 @@ func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request, v *v
|
|||||||
if err := json.NewEncoder(&buf).Encode(msg.ForJSON()); err != nil {
|
if err := json.NewEncoder(&buf).Encode(msg.ForJSON()); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
if msg.Event != messageEvent && msg.Event != messageDeleteEvent && msg.Event != messageClearEvent {
|
if msg.Event != model.MessageEvent && msg.Event != model.MessageDeleteEvent && msg.Event != model.MessageClearEvent {
|
||||||
return fmt.Sprintf("event: %s\ndata: %s\n", msg.Event, buf.String()), nil // Browser's .onmessage() does not fire on this!
|
return fmt.Sprintf("event: %s\ndata: %s\n", msg.Event, buf.String()), nil // Browser's .onmessage() does not fire on this!
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("data: %s\n", buf.String()), nil
|
return fmt.Sprintf("data: %s\n", buf.String()), nil
|
||||||
@@ -1460,7 +1471,7 @@ func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request, v *v
|
|||||||
|
|
||||||
func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
encoder := func(msg *model.Message) (string, error) {
|
encoder := func(msg *model.Message) (string, error) {
|
||||||
if msg.Event == messageEvent { // only handle default events
|
if msg.Event == model.MessageEvent { // only handle default events
|
||||||
return strings.ReplaceAll(msg.Message, "\n", " ") + "\n", nil
|
return strings.ReplaceAll(msg.Message, "\n", " ") + "\n", nil
|
||||||
}
|
}
|
||||||
return "\n", nil // "keepalive" and "open" events just send an empty line
|
return "\n", nil // "keepalive" and "open" events just send an empty line
|
||||||
@@ -1538,7 +1549,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
|||||||
topics[i].Unsubscribe(subscriberID) // Order!
|
topics[i].Unsubscribe(subscriberID) // Order!
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message
|
if err := sub(v, model.NewOpenMessage(topicsStr)); err != nil { // Send out open message
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
|
if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
|
||||||
@@ -1561,7 +1572,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
|||||||
for _, t := range topics {
|
for _, t := range topics {
|
||||||
t.Keepalive()
|
t.Keepalive()
|
||||||
}
|
}
|
||||||
if err := sub(v, newKeepaliveMessage(topicsStr)); err != nil { // Send keepalive message
|
if err := sub(v, model.NewKeepaliveMessage(topicsStr)); err != nil { // Send keepalive message
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1687,7 +1698,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
|||||||
topics[i].Unsubscribe(subscriberID) // Order!
|
topics[i].Unsubscribe(subscriberID) // Order!
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message
|
if err := sub(v, model.NewOpenMessage(topicsStr)); err != nil { // Send out open message
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
|
if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
|
||||||
@@ -1818,26 +1829,26 @@ func parseSince(r *http.Request, poll bool) (model.SinceMarker, error) {
|
|||||||
// Easy cases (empty, all, none)
|
// Easy cases (empty, all, none)
|
||||||
if since == "" {
|
if since == "" {
|
||||||
if poll {
|
if poll {
|
||||||
return sinceAllMessages, nil
|
return model.SinceAllMessages, nil
|
||||||
}
|
}
|
||||||
return sinceNoMessages, nil
|
return model.SinceNoMessages, nil
|
||||||
} else if since == "all" {
|
} else if since == "all" {
|
||||||
return sinceAllMessages, nil
|
return model.SinceAllMessages, nil
|
||||||
} else if since == "latest" {
|
} else if since == "latest" {
|
||||||
return sinceLatestMessage, nil
|
return model.SinceLatestMessage, nil
|
||||||
} else if since == "none" {
|
} else if since == "none" {
|
||||||
return sinceNoMessages, nil
|
return model.SinceNoMessages, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ID, timestamp, duration
|
// ID, timestamp, duration
|
||||||
if validMessageID(since) {
|
if model.ValidMessageID(since) {
|
||||||
return newSinceID(since), nil
|
return model.NewSinceID(since), nil
|
||||||
} else if s, err := strconv.ParseInt(since, 10, 64); err == nil {
|
} else if s, err := strconv.ParseInt(since, 10, 64); err == nil {
|
||||||
return newSinceTime(s), nil
|
return model.NewSinceTime(s), nil
|
||||||
} else if d, err := time.ParseDuration(since); err == nil {
|
} else if d, err := time.ParseDuration(since); err == nil {
|
||||||
return newSinceTime(time.Now().Add(-1 * d).Unix()), nil
|
return model.NewSinceTime(time.Now().Add(-1 * d).Unix()), nil
|
||||||
}
|
}
|
||||||
return sinceNoMessages, errHTTPBadRequestSinceInvalid
|
return model.SinceNoMessages, errHTTPBadRequestSinceInvalid
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleOptions(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
|
func (s *Server) handleOptions(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
|
||||||
@@ -1993,14 +2004,14 @@ func (s *Server) runFirebaseKeepaliver() {
|
|||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-time.After(s.config.FirebaseKeepaliveInterval):
|
case <-time.After(s.config.FirebaseKeepaliveInterval):
|
||||||
s.sendToFirebase(v, newKeepaliveMessage(firebaseControlTopic))
|
s.sendToFirebase(v, model.NewKeepaliveMessage(firebaseControlTopic))
|
||||||
/*
|
/*
|
||||||
FIXME: Disable iOS polling entirely for now due to thundering herd problem (see #677)
|
FIXME: Disable iOS polling entirely for now due to thundering herd problem (see #677)
|
||||||
To solve this, we'd have to shard the iOS poll topics to spread out the polling evenly.
|
To solve this, we'd have to shard the iOS poll topics to spread out the polling evenly.
|
||||||
Given that it's not really necessary to poll, turning it off for now should not have any impact.
|
Given that it's not really necessary to poll, turning it off for now should not have any impact.
|
||||||
|
|
||||||
case <-time.After(s.config.FirebasePollInterval):
|
case <-time.After(s.config.FirebasePollInterval):
|
||||||
s.sendToFirebase(v, newKeepaliveMessage(firebasePollTopic))
|
s.sendToFirebase(v, model.NewKeepaliveMessage(firebasePollTopic))
|
||||||
*/
|
*/
|
||||||
case <-s.closeChan:
|
case <-s.closeChan:
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -38,14 +38,32 @@
|
|||||||
#
|
#
|
||||||
# firebase-key-file: <filename>
|
# firebase-key-file: <filename>
|
||||||
|
|
||||||
# If "database-url" is set, ntfy will use PostgreSQL for database-backed stores instead of SQLite.
|
# If "database-url" is set, ntfy will use PostgreSQL for all database-backed stores (message cache,
|
||||||
# Currently this applies to the web push subscription store. Message cache and user manager support
|
# user manager, and web push subscriptions) instead of SQLite. When set, the "cache-file",
|
||||||
# will be added in future releases. When set, the "web-push-file" option is not required.
|
# "auth-file", and "web-push-file" options must not be set.
|
||||||
#
|
#
|
||||||
|
# Note: Setting "database-url" implicitly enables authentication and access control.
|
||||||
|
# The default access is "read-write" (see "auth-default-access").
|
||||||
|
#
|
||||||
|
# The URL supports standard PostgreSQL parameters (sslmode, connect_timeout, sslcert, etc.),
|
||||||
|
# as well as ntfy-specific connection pool parameters:
|
||||||
|
# pool_max_conns=10 - Maximum number of open connections (default: 10)
|
||||||
|
# pool_max_idle_conns=N - Maximum number of idle connections
|
||||||
|
# pool_conn_max_lifetime=5m - Maximum lifetime of a connection (Go duration)
|
||||||
|
# pool_conn_max_idle_time=1m - Maximum idle time of a connection (Go duration)
|
||||||
|
#
|
||||||
|
# See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS
|
||||||
|
# for the full list of supported PostgreSQL connection parameters.
|
||||||
|
#
|
||||||
|
# Examples:
|
||||||
# database-url: "postgres://user:pass@host:5432/ntfy"
|
# database-url: "postgres://user:pass@host:5432/ntfy"
|
||||||
|
# database-url: "postgres://user:pass@host:5432/ntfy?sslmode=require&pool_max_conns=50"
|
||||||
|
#
|
||||||
|
# database-url: <connection-string>
|
||||||
|
|
||||||
# If "cache-file" is set, messages are cached in a local SQLite database instead of only in-memory.
|
# If "cache-file" is set, messages are cached in a local SQLite database instead of only in-memory.
|
||||||
# This allows for service restarts without losing messages in support of the since= parameter.
|
# This allows for service restarts without losing messages in support of the since= parameter.
|
||||||
|
# Not required if "database-url" is set (messages are stored in PostgreSQL instead).
|
||||||
#
|
#
|
||||||
# The "cache-duration" parameter defines the duration for which messages will be buffered
|
# The "cache-duration" parameter defines the duration for which messages will be buffered
|
||||||
# before they are deleted. This is required to support the "since=..." and "poll=1" parameter.
|
# before they are deleted. This is required to support the "since=..." and "poll=1" parameter.
|
||||||
@@ -83,6 +101,8 @@
|
|||||||
# If set, access to the ntfy server and API can be controlled on a granular level using
|
# If set, access to the ntfy server and API can be controlled on a granular level using
|
||||||
# the 'ntfy user' and 'ntfy access' commands. See the --help pages for details, or check the docs.
|
# the 'ntfy user' and 'ntfy access' commands. See the --help pages for details, or check the docs.
|
||||||
#
|
#
|
||||||
|
# Note: If "database-url" is set, auth is implicitly enabled and "auth-file" must not be set.
|
||||||
|
#
|
||||||
# - auth-file is the SQLite user/access database; it is created automatically if it doesn't already exist
|
# - auth-file is the SQLite user/access database; it is created automatically if it doesn't already exist
|
||||||
# - auth-default-access defines the default/fallback access if no access control entry is found; it can be
|
# - auth-default-access defines the default/fallback access if no access control entry is found; it can be
|
||||||
# set to "read-write" (default), "read-only", "write-only" or "deny-all".
|
# set to "read-write" (default), "read-only", "write-only" or "deny-all".
|
||||||
@@ -203,6 +223,7 @@
|
|||||||
# - web-push-public-key is the generated VAPID public key, e.g. AA1234BBCCddvveekaabcdfqwertyuiopasdfghjklzxcvbnm1234567890
|
# - web-push-public-key is the generated VAPID public key, e.g. AA1234BBCCddvveekaabcdfqwertyuiopasdfghjklzxcvbnm1234567890
|
||||||
# - web-push-private-key is the generated VAPID private key, e.g. AA2BB1234567890abcdefzxcvbnm1234567890
|
# - web-push-private-key is the generated VAPID private key, e.g. AA2BB1234567890abcdefzxcvbnm1234567890
|
||||||
# - web-push-file is a database file to keep track of browser subscription endpoints, e.g. /var/cache/ntfy/webpush.db
|
# - web-push-file is a database file to keep track of browser subscription endpoints, e.g. /var/cache/ntfy/webpush.db
|
||||||
|
# Not required if "database-url" is set (subscriptions are stored in PostgreSQL instead).
|
||||||
# - web-push-email-address is the admin email address send to the push provider, e.g. sysadmin@example.com
|
# - web-push-email-address is the admin email address send to the push provider, e.g. sysadmin@example.com
|
||||||
# - web-push-startup-queries is an optional list of queries to run on startup
|
# - web-push-startup-queries is an optional list of queries to run on startup
|
||||||
# - web-push-expiry-warning-duration defines the duration after which unused subscriptions are sent a warning (default is 55d)
|
# - web-push-expiry-warning-duration defines the duration after which unused subscriptions are sent a warning (default is 55d)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"heckel.io/ntfy/v2/log"
|
"heckel.io/ntfy/v2/log"
|
||||||
|
"heckel.io/ntfy/v2/model"
|
||||||
"heckel.io/ntfy/v2/user"
|
"heckel.io/ntfy/v2/user"
|
||||||
"heckel.io/ntfy/v2/util"
|
"heckel.io/ntfy/v2/util"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -641,7 +642,7 @@ func (s *Server) publishSyncEvent(v *visitor) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
m := newDefaultMessage(syncTopic.ID, string(messageBytes))
|
m := model.NewDefaultMessage(syncTopic.ID, string(messageBytes))
|
||||||
if err := syncTopic.Publish(v, m); err != nil {
|
if err := syncTopic.Publish(v, m); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"heckel.io/ntfy/v2/log"
|
"heckel.io/ntfy/v2/log"
|
||||||
|
"heckel.io/ntfy/v2/model"
|
||||||
"heckel.io/ntfy/v2/user"
|
"heckel.io/ntfy/v2/user"
|
||||||
"heckel.io/ntfy/v2/util"
|
"heckel.io/ntfy/v2/util"
|
||||||
"io"
|
"io"
|
||||||
@@ -15,8 +16,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestAccount_Signup_Success(t *testing.T) {
|
func TestAccount_Signup_Success(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
conf := newTestConfigWithAuthFile(t)
|
conf := newTestConfigWithAuthFile(t, databaseURL)
|
||||||
conf.EnableSignup = true
|
conf.EnableSignup = true
|
||||||
s := newTestServer(t, conf)
|
s := newTestServer(t, conf)
|
||||||
defer s.closeDatabases()
|
defer s.closeDatabases()
|
||||||
@@ -54,8 +55,8 @@ func TestAccount_Signup_Success(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAccount_Signup_UserExists(t *testing.T) {
|
func TestAccount_Signup_UserExists(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
conf := newTestConfigWithAuthFile(t)
|
conf := newTestConfigWithAuthFile(t, databaseURL)
|
||||||
conf.EnableSignup = true
|
conf.EnableSignup = true
|
||||||
s := newTestServer(t, conf)
|
s := newTestServer(t, conf)
|
||||||
defer s.closeDatabases()
|
defer s.closeDatabases()
|
||||||
@@ -70,8 +71,8 @@ func TestAccount_Signup_UserExists(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAccount_Signup_LimitReached(t *testing.T) {
|
func TestAccount_Signup_LimitReached(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
conf := newTestConfigWithAuthFile(t)
|
conf := newTestConfigWithAuthFile(t, databaseURL)
|
||||||
conf.EnableSignup = true
|
conf.EnableSignup = true
|
||||||
s := newTestServer(t, conf)
|
s := newTestServer(t, conf)
|
||||||
defer s.closeDatabases()
|
defer s.closeDatabases()
|
||||||
@@ -87,8 +88,8 @@ func TestAccount_Signup_LimitReached(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAccount_Signup_AsUser(t *testing.T) {
|
func TestAccount_Signup_AsUser(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
conf := newTestConfigWithAuthFile(t)
|
conf := newTestConfigWithAuthFile(t, databaseURL)
|
||||||
conf.EnableSignup = true
|
conf.EnableSignup = true
|
||||||
s := newTestServer(t, conf)
|
s := newTestServer(t, conf)
|
||||||
defer s.closeDatabases()
|
defer s.closeDatabases()
|
||||||
@@ -111,8 +112,8 @@ func TestAccount_Signup_AsUser(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAccount_Signup_Disabled(t *testing.T) {
|
func TestAccount_Signup_Disabled(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
conf := newTestConfigWithAuthFile(t)
|
conf := newTestConfigWithAuthFile(t, databaseURL)
|
||||||
conf.EnableSignup = false
|
conf.EnableSignup = false
|
||||||
s := newTestServer(t, conf)
|
s := newTestServer(t, conf)
|
||||||
defer s.closeDatabases()
|
defer s.closeDatabases()
|
||||||
@@ -124,8 +125,8 @@ func TestAccount_Signup_Disabled(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAccount_Signup_Rate_Limit(t *testing.T) {
|
func TestAccount_Signup_Rate_Limit(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
conf := newTestConfigWithAuthFile(t)
|
conf := newTestConfigWithAuthFile(t, databaseURL)
|
||||||
conf.EnableSignup = true
|
conf.EnableSignup = true
|
||||||
s := newTestServer(t, conf)
|
s := newTestServer(t, conf)
|
||||||
|
|
||||||
@@ -140,8 +141,8 @@ func TestAccount_Signup_Rate_Limit(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAccount_Get_Anonymous(t *testing.T) {
|
func TestAccount_Get_Anonymous(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
conf := newTestConfigWithAuthFile(t)
|
conf := newTestConfigWithAuthFile(t, databaseURL)
|
||||||
conf.VisitorRequestLimitReplenish = 86 * time.Second
|
conf.VisitorRequestLimitReplenish = 86 * time.Second
|
||||||
conf.VisitorEmailLimitReplenish = time.Hour
|
conf.VisitorEmailLimitReplenish = time.Hour
|
||||||
conf.VisitorAttachmentTotalSizeLimit = 5123
|
conf.VisitorAttachmentTotalSizeLimit = 5123
|
||||||
@@ -185,8 +186,8 @@ func TestAccount_Get_Anonymous(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAccount_ChangeSettings(t *testing.T) {
|
func TestAccount_ChangeSettings(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
|
||||||
defer s.closeDatabases()
|
defer s.closeDatabases()
|
||||||
|
|
||||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
||||||
@@ -216,8 +217,8 @@ func TestAccount_ChangeSettings(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAccount_Subscription_AddUpdateDelete(t *testing.T) {
|
func TestAccount_Subscription_AddUpdateDelete(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
|
||||||
defer s.closeDatabases()
|
defer s.closeDatabases()
|
||||||
|
|
||||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
||||||
@@ -269,8 +270,8 @@ func TestAccount_Subscription_AddUpdateDelete(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAccount_ChangePassword(t *testing.T) {
|
func TestAccount_ChangePassword(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
conf := newTestConfigWithAuthFile(t)
|
conf := newTestConfigWithAuthFile(t, databaseURL)
|
||||||
conf.AuthUsers = []*user.User{
|
conf.AuthUsers = []*user.User{
|
||||||
{Name: "philuser", Hash: "$2a$10$U4WSIYY6evyGmZaraavM2e2JeVG6EMGUKN1uUwufUeeRd4Jpg6cGC", Role: user.RoleUser}, // philuser:philpass
|
{Name: "philuser", Hash: "$2a$10$U4WSIYY6evyGmZaraavM2e2JeVG6EMGUKN1uUwufUeeRd4Jpg6cGC", Role: user.RoleUser}, // philuser:philpass
|
||||||
}
|
}
|
||||||
@@ -314,8 +315,8 @@ func TestAccount_ChangePassword(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAccount_ChangePassword_NoAccount(t *testing.T) {
|
func TestAccount_ChangePassword_NoAccount(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
|
||||||
defer s.closeDatabases()
|
defer s.closeDatabases()
|
||||||
|
|
||||||
rr := request(t, s, "POST", "/v1/account/password", `{"password": "new password"}`, nil)
|
rr := request(t, s, "POST", "/v1/account/password", `{"password": "new password"}`, nil)
|
||||||
@@ -324,9 +325,9 @@ func TestAccount_ChangePassword_NoAccount(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAccount_ExtendToken(t *testing.T) {
|
func TestAccount_ExtendToken(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
|
||||||
defer s.closeDatabases()
|
defer s.closeDatabases()
|
||||||
|
|
||||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
||||||
@@ -363,8 +364,8 @@ func TestAccount_ExtendToken(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAccount_ExtendToken_NoTokenProvided(t *testing.T) {
|
func TestAccount_ExtendToken_NoTokenProvided(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
|
||||||
defer s.closeDatabases()
|
defer s.closeDatabases()
|
||||||
|
|
||||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
||||||
@@ -378,8 +379,8 @@ func TestAccount_ExtendToken_NoTokenProvided(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAccount_DeleteToken(t *testing.T) {
|
func TestAccount_DeleteToken(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
|
||||||
defer s.closeDatabases()
|
defer s.closeDatabases()
|
||||||
|
|
||||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
||||||
@@ -420,8 +421,8 @@ func TestAccount_DeleteToken(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAccount_Delete_Success(t *testing.T) {
|
func TestAccount_Delete_Success(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
conf := newTestConfigWithAuthFile(t)
|
conf := newTestConfigWithAuthFile(t, databaseURL)
|
||||||
conf.EnableSignup = true
|
conf.EnableSignup = true
|
||||||
s := newTestServer(t, conf)
|
s := newTestServer(t, conf)
|
||||||
|
|
||||||
@@ -451,8 +452,8 @@ func TestAccount_Delete_Success(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAccount_Delete_Not_Allowed(t *testing.T) {
|
func TestAccount_Delete_Not_Allowed(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
conf := newTestConfigWithAuthFile(t)
|
conf := newTestConfigWithAuthFile(t, databaseURL)
|
||||||
conf.EnableSignup = true
|
conf.EnableSignup = true
|
||||||
s := newTestServer(t, conf)
|
s := newTestServer(t, conf)
|
||||||
|
|
||||||
@@ -474,8 +475,8 @@ func TestAccount_Delete_Not_Allowed(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAccount_Reservation_AddWithoutTierFails(t *testing.T) {
|
func TestAccount_Reservation_AddWithoutTierFails(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
conf := newTestConfigWithAuthFile(t)
|
conf := newTestConfigWithAuthFile(t, databaseURL)
|
||||||
conf.EnableSignup = true
|
conf.EnableSignup = true
|
||||||
s := newTestServer(t, conf)
|
s := newTestServer(t, conf)
|
||||||
|
|
||||||
@@ -490,8 +491,8 @@ func TestAccount_Reservation_AddWithoutTierFails(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAccount_Reservation_AddAdminSuccess(t *testing.T) {
|
func TestAccount_Reservation_AddAdminSuccess(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
conf := newTestConfigWithAuthFile(t)
|
conf := newTestConfigWithAuthFile(t, databaseURL)
|
||||||
conf.EnableSignup = true
|
conf.EnableSignup = true
|
||||||
s := newTestServer(t, conf)
|
s := newTestServer(t, conf)
|
||||||
|
|
||||||
@@ -544,8 +545,8 @@ func TestAccount_Reservation_AddAdminSuccess(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAccount_Reservation_AddRemoveUserWithTierSuccess(t *testing.T) {
|
func TestAccount_Reservation_AddRemoveUserWithTierSuccess(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
conf := newTestConfigWithAuthFile(t)
|
conf := newTestConfigWithAuthFile(t, databaseURL)
|
||||||
conf.EnableSignup = true
|
conf.EnableSignup = true
|
||||||
conf.EnableReservations = true
|
conf.EnableReservations = true
|
||||||
conf.TwilioAccount = "dummy"
|
conf.TwilioAccount = "dummy"
|
||||||
@@ -632,8 +633,8 @@ func TestAccount_Reservation_AddRemoveUserWithTierSuccess(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAccount_Reservation_PublishByAnonymousFails(t *testing.T) {
|
func TestAccount_Reservation_PublishByAnonymousFails(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
conf := newTestConfigWithAuthFile(t)
|
conf := newTestConfigWithAuthFile(t, databaseURL)
|
||||||
conf.AuthDefault = user.PermissionReadWrite
|
conf.AuthDefault = user.PermissionReadWrite
|
||||||
conf.EnableSignup = true
|
conf.EnableSignup = true
|
||||||
s := newTestServer(t, conf)
|
s := newTestServer(t, conf)
|
||||||
@@ -668,9 +669,9 @@ func TestAccount_Reservation_PublishByAnonymousFails(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAccount_Reservation_Delete_Messages_And_Attachments(t *testing.T) {
|
func TestAccount_Reservation_Delete_Messages_And_Attachments(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
conf := newTestConfigWithAuthFile(t)
|
conf := newTestConfigWithAuthFile(t, databaseURL)
|
||||||
conf.AuthDefault = user.PermissionReadWrite
|
conf.AuthDefault = user.PermissionReadWrite
|
||||||
s := newTestServer(t, conf)
|
s := newTestServer(t, conf)
|
||||||
|
|
||||||
@@ -715,12 +716,12 @@ func TestAccount_Reservation_Delete_Messages_And_Attachments(t *testing.T) {
|
|||||||
require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, m2.ID))
|
require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, m2.ID))
|
||||||
|
|
||||||
// Pre-verify message count and file
|
// Pre-verify message count and file
|
||||||
ms, err := s.messageCache.Messages("mytopic1", sinceAllMessages, false)
|
ms, err := s.messageCache.Messages("mytopic1", model.SinceAllMessages, false)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, 1, len(ms))
|
require.Equal(t, 1, len(ms))
|
||||||
require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, m1.ID))
|
require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, m1.ID))
|
||||||
|
|
||||||
ms, err = s.messageCache.Messages("mytopic2", sinceAllMessages, false)
|
ms, err = s.messageCache.Messages("mytopic2", model.SinceAllMessages, false)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, 1, len(ms))
|
require.Equal(t, 1, len(ms))
|
||||||
require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, m2.ID))
|
require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, m2.ID))
|
||||||
@@ -741,17 +742,17 @@ func TestAccount_Reservation_Delete_Messages_And_Attachments(t *testing.T) {
|
|||||||
// Verify that messages and attachments were deleted
|
// Verify that messages and attachments were deleted
|
||||||
// This does not explicitly call the manager!
|
// This does not explicitly call the manager!
|
||||||
waitFor(t, func() bool {
|
waitFor(t, func() bool {
|
||||||
ms, err := s.messageCache.Messages("mytopic1", sinceAllMessages, false)
|
ms, err := s.messageCache.Messages("mytopic1", model.SinceAllMessages, false)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
return len(ms) == 0 && !util.FileExists(filepath.Join(s.config.AttachmentCacheDir, m1.ID))
|
return len(ms) == 0 && !util.FileExists(filepath.Join(s.config.AttachmentCacheDir, m1.ID))
|
||||||
})
|
})
|
||||||
|
|
||||||
ms, err = s.messageCache.Messages("mytopic1", sinceAllMessages, false)
|
ms, err = s.messageCache.Messages("mytopic1", model.SinceAllMessages, false)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, 0, len(ms))
|
require.Equal(t, 0, len(ms))
|
||||||
require.NoFileExists(t, filepath.Join(s.config.AttachmentCacheDir, m1.ID))
|
require.NoFileExists(t, filepath.Join(s.config.AttachmentCacheDir, m1.ID))
|
||||||
|
|
||||||
ms, err = s.messageCache.Messages("mytopic2", sinceAllMessages, false)
|
ms, err = s.messageCache.Messages("mytopic2", model.SinceAllMessages, false)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, 1, len(ms))
|
require.Equal(t, 1, len(ms))
|
||||||
require.Equal(t, m2.ID, ms[0].ID)
|
require.Equal(t, m2.ID, ms[0].ID)
|
||||||
@@ -760,7 +761,7 @@ func TestAccount_Reservation_Delete_Messages_And_Attachments(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/*func TestAccount_Persist_UserStats_After_Tier_Change(t *testing.T) {
|
/*func TestAccount_Persist_UserStats_After_Tier_Change(t *testing.T) {
|
||||||
conf := newTestConfigWithAuthFile(t)
|
conf := newTestConfigWithAuthFile(t, databaseURL)
|
||||||
conf.AuthDefault = user.PermissionReadWrite
|
conf.AuthDefault = user.PermissionReadWrite
|
||||||
conf.AuthStatsQueueWriterInterval = 300 * time.Millisecond
|
conf.AuthStatsQueueWriterInterval = 300 * time.Millisecond
|
||||||
s := newTestServer(t, conf)
|
s := newTestServer(t, conf)
|
||||||
|
|||||||
@@ -11,8 +11,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestVersion_Admin(t *testing.T) {
|
func TestVersion_Admin(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
c := newTestConfigWithAuthFile(t)
|
c := newTestConfigWithAuthFile(t, databaseURL)
|
||||||
c.BuildVersion = "1.2.3"
|
c.BuildVersion = "1.2.3"
|
||||||
c.BuildCommit = "abcdef0"
|
c.BuildCommit = "abcdef0"
|
||||||
c.BuildDate = "2026-02-08T00:00:00Z"
|
c.BuildDate = "2026-02-08T00:00:00Z"
|
||||||
@@ -48,8 +48,8 @@ func TestVersion_Admin(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUser_AddRemove(t *testing.T) {
|
func TestUser_AddRemove(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
|
||||||
defer s.closeDatabases()
|
defer s.closeDatabases()
|
||||||
|
|
||||||
// Create admin, tier
|
// Create admin, tier
|
||||||
@@ -106,8 +106,8 @@ func TestUser_AddRemove(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUser_AddWithPasswordHash(t *testing.T) {
|
func TestUser_AddWithPasswordHash(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
|
||||||
defer s.closeDatabases()
|
defer s.closeDatabases()
|
||||||
|
|
||||||
// Create admin
|
// Create admin
|
||||||
@@ -137,8 +137,8 @@ func TestUser_AddWithPasswordHash(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUser_ChangeUserPassword(t *testing.T) {
|
func TestUser_ChangeUserPassword(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
|
||||||
defer s.closeDatabases()
|
defer s.closeDatabases()
|
||||||
|
|
||||||
// Create admin
|
// Create admin
|
||||||
@@ -177,8 +177,8 @@ func TestUser_ChangeUserPassword(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUser_ChangeUserTier(t *testing.T) {
|
func TestUser_ChangeUserTier(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
|
||||||
defer s.closeDatabases()
|
defer s.closeDatabases()
|
||||||
|
|
||||||
// Create admin, tier
|
// Create admin, tier
|
||||||
@@ -219,8 +219,8 @@ func TestUser_ChangeUserTier(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUser_ChangeUserPasswordAndTier(t *testing.T) {
|
func TestUser_ChangeUserPasswordAndTier(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
|
||||||
defer s.closeDatabases()
|
defer s.closeDatabases()
|
||||||
|
|
||||||
// Create admin, tier
|
// Create admin, tier
|
||||||
@@ -273,8 +273,8 @@ func TestUser_ChangeUserPasswordAndTier(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUser_ChangeUserPasswordWithHash(t *testing.T) {
|
func TestUser_ChangeUserPasswordWithHash(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
|
||||||
defer s.closeDatabases()
|
defer s.closeDatabases()
|
||||||
|
|
||||||
// Create admin
|
// Create admin
|
||||||
@@ -307,8 +307,8 @@ func TestUser_ChangeUserPasswordWithHash(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUser_DontChangeAdminPassword(t *testing.T) {
|
func TestUser_DontChangeAdminPassword(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
|
||||||
defer s.closeDatabases()
|
defer s.closeDatabases()
|
||||||
|
|
||||||
// Create admin
|
// Create admin
|
||||||
@@ -324,8 +324,8 @@ func TestUser_DontChangeAdminPassword(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUser_AddRemove_Failures(t *testing.T) {
|
func TestUser_AddRemove_Failures(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
s := newTestServer(t, newTestConfigWithAuthFile(t, databaseURL))
|
||||||
defer s.closeDatabases()
|
defer s.closeDatabases()
|
||||||
|
|
||||||
// Create admin
|
// Create admin
|
||||||
@@ -365,8 +365,8 @@ func TestUser_AddRemove_Failures(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAccess_AllowReset(t *testing.T) {
|
func TestAccess_AllowReset(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
c := newTestConfigWithAuthFile(t)
|
c := newTestConfigWithAuthFile(t, databaseURL)
|
||||||
c.AuthDefault = user.PermissionDenyAll
|
c.AuthDefault = user.PermissionDenyAll
|
||||||
s := newTestServer(t, c)
|
s := newTestServer(t, c)
|
||||||
defer s.closeDatabases()
|
defer s.closeDatabases()
|
||||||
@@ -408,8 +408,8 @@ func TestAccess_AllowReset(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAccess_AllowReset_NonAdminAttempt(t *testing.T) {
|
func TestAccess_AllowReset_NonAdminAttempt(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
c := newTestConfigWithAuthFile(t)
|
c := newTestConfigWithAuthFile(t, databaseURL)
|
||||||
c.AuthDefault = user.PermissionDenyAll
|
c.AuthDefault = user.PermissionDenyAll
|
||||||
s := newTestServer(t, c)
|
s := newTestServer(t, c)
|
||||||
defer s.closeDatabases()
|
defer s.closeDatabases()
|
||||||
@@ -426,8 +426,8 @@ func TestAccess_AllowReset_NonAdminAttempt(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAccess_AllowReset_KillConnection(t *testing.T) {
|
func TestAccess_AllowReset_KillConnection(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
c := newTestConfigWithAuthFile(t)
|
c := newTestConfigWithAuthFile(t, databaseURL)
|
||||||
c.AuthDefault = user.PermissionDenyAll
|
c.AuthDefault = user.PermissionDenyAll
|
||||||
s := newTestServer(t, c)
|
s := newTestServer(t, c)
|
||||||
defer s.closeDatabases()
|
defer s.closeDatabases()
|
||||||
|
|||||||
@@ -126,7 +126,7 @@ func toFirebaseMessage(m *model.Message, auther user.Auther) (*messaging.Message
|
|||||||
var data map[string]string // Mostly matches https://ntfy.sh/docs/subscribe/api/#json-message-format
|
var data map[string]string // Mostly matches https://ntfy.sh/docs/subscribe/api/#json-message-format
|
||||||
var apnsConfig *messaging.APNSConfig
|
var apnsConfig *messaging.APNSConfig
|
||||||
switch m.Event {
|
switch m.Event {
|
||||||
case keepaliveEvent, openEvent:
|
case model.KeepaliveEvent, model.OpenEvent:
|
||||||
data = map[string]string{
|
data = map[string]string{
|
||||||
"id": m.ID,
|
"id": m.ID,
|
||||||
"time": fmt.Sprintf("%d", m.Time),
|
"time": fmt.Sprintf("%d", m.Time),
|
||||||
@@ -134,7 +134,7 @@ func toFirebaseMessage(m *model.Message, auther user.Auther) (*messaging.Message
|
|||||||
"topic": m.Topic,
|
"topic": m.Topic,
|
||||||
}
|
}
|
||||||
apnsConfig = createAPNSBackgroundConfig(data)
|
apnsConfig = createAPNSBackgroundConfig(data)
|
||||||
case pollRequestEvent:
|
case model.PollRequestEvent:
|
||||||
data = map[string]string{
|
data = map[string]string{
|
||||||
"id": m.ID,
|
"id": m.ID,
|
||||||
"time": fmt.Sprintf("%d", m.Time),
|
"time": fmt.Sprintf("%d", m.Time),
|
||||||
@@ -144,7 +144,7 @@ func toFirebaseMessage(m *model.Message, auther user.Auther) (*messaging.Message
|
|||||||
"poll_id": m.PollID,
|
"poll_id": m.PollID,
|
||||||
}
|
}
|
||||||
apnsConfig = createAPNSAlertConfig(m, data)
|
apnsConfig = createAPNSAlertConfig(m, data)
|
||||||
case messageDeleteEvent, messageClearEvent:
|
case model.MessageDeleteEvent, model.MessageClearEvent:
|
||||||
data = map[string]string{
|
data = map[string]string{
|
||||||
"id": m.ID,
|
"id": m.ID,
|
||||||
"time": fmt.Sprintf("%d", m.Time),
|
"time": fmt.Sprintf("%d", m.Time),
|
||||||
@@ -153,7 +153,7 @@ func toFirebaseMessage(m *model.Message, auther user.Auther) (*messaging.Message
|
|||||||
"sequence_id": m.SequenceID,
|
"sequence_id": m.SequenceID,
|
||||||
}
|
}
|
||||||
apnsConfig = createAPNSBackgroundConfig(data)
|
apnsConfig = createAPNSBackgroundConfig(data)
|
||||||
case messageEvent:
|
case model.MessageEvent:
|
||||||
if auther != nil {
|
if auther != nil {
|
||||||
// If "anonymous read" for a topic is not allowed, we cannot send the message along
|
// If "anonymous read" for a topic is not allowed, we cannot send the message along
|
||||||
// via Firebase. Instead, we send a "poll_request" message, asking the client to poll.
|
// via Firebase. Instead, we send a "poll_request" message, asking the client to poll.
|
||||||
@@ -298,7 +298,7 @@ func maybeTruncateAPNSBodyMessage(s string) string {
|
|||||||
// This empties all the fields that are not needed for a poll request and just sets the required fields,
|
// This empties all the fields that are not needed for a poll request and just sets the required fields,
|
||||||
// most importantly, the PollID.
|
// most importantly, the PollID.
|
||||||
func toPollRequest(m *model.Message) *model.Message {
|
func toPollRequest(m *model.Message) *model.Message {
|
||||||
pr := newPollRequestMessage(m.Topic, m.ID)
|
pr := model.NewPollRequestMessage(m.Topic, m.ID)
|
||||||
pr.ID = m.ID
|
pr.ID = m.ID
|
||||||
pr.Time = m.Time
|
pr.Time = m.Time
|
||||||
pr.Priority = m.Priority // Keep priority
|
pr.Priority = m.Priority // Keep priority
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ func (s *testFirebaseSender) Messages() []*messaging.Message {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestToFirebaseMessage_Keepalive(t *testing.T) {
|
func TestToFirebaseMessage_Keepalive(t *testing.T) {
|
||||||
m := newKeepaliveMessage("mytopic")
|
m := model.NewKeepaliveMessage("mytopic")
|
||||||
fbm, err := toFirebaseMessage(m, nil)
|
fbm, err := toFirebaseMessage(m, nil)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, "mytopic", fbm.Topic)
|
require.Equal(t, "mytopic", fbm.Topic)
|
||||||
@@ -95,7 +95,7 @@ func TestToFirebaseMessage_Keepalive(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestToFirebaseMessage_Open(t *testing.T) {
|
func TestToFirebaseMessage_Open(t *testing.T) {
|
||||||
m := newOpenMessage("mytopic")
|
m := model.NewOpenMessage("mytopic")
|
||||||
fbm, err := toFirebaseMessage(m, nil)
|
fbm, err := toFirebaseMessage(m, nil)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, "mytopic", fbm.Topic)
|
require.Equal(t, "mytopic", fbm.Topic)
|
||||||
@@ -126,7 +126,7 @@ func TestToFirebaseMessage_Open(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestToFirebaseMessage_Message_Normal_Allowed(t *testing.T) {
|
func TestToFirebaseMessage_Message_Normal_Allowed(t *testing.T) {
|
||||||
m := newDefaultMessage("mytopic", "this is a message")
|
m := model.NewDefaultMessage("mytopic", "this is a message")
|
||||||
m.Priority = 4
|
m.Priority = 4
|
||||||
m.Tags = []string{"tag 1", "tag2"}
|
m.Tags = []string{"tag 1", "tag2"}
|
||||||
m.Click = "https://google.com"
|
m.Click = "https://google.com"
|
||||||
@@ -220,7 +220,7 @@ func TestToFirebaseMessage_Message_Normal_Allowed(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestToFirebaseMessage_Message_Normal_Not_Allowed(t *testing.T) {
|
func TestToFirebaseMessage_Message_Normal_Not_Allowed(t *testing.T) {
|
||||||
m := newDefaultMessage("mytopic", "this is a message")
|
m := model.NewDefaultMessage("mytopic", "this is a message")
|
||||||
m.Priority = 5
|
m.Priority = 5
|
||||||
fbm, err := toFirebaseMessage(m, &testAuther{Allow: false}) // Not allowed!
|
fbm, err := toFirebaseMessage(m, &testAuther{Allow: false}) // Not allowed!
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
@@ -251,7 +251,7 @@ func TestToFirebaseMessage_Message_Normal_Not_Allowed(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestToFirebaseMessage_PollRequest(t *testing.T) {
|
func TestToFirebaseMessage_PollRequest(t *testing.T) {
|
||||||
m := newPollRequestMessage("mytopic", "fOv6k1QbCzo6")
|
m := model.NewPollRequestMessage("mytopic", "fOv6k1QbCzo6")
|
||||||
fbm, err := toFirebaseMessage(m, nil)
|
fbm, err := toFirebaseMessage(m, nil)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, "mytopic", fbm.Topic)
|
require.Equal(t, "mytopic", fbm.Topic)
|
||||||
@@ -345,7 +345,7 @@ func TestMaybeTruncateFCMMessage_NotTooLong(t *testing.T) {
|
|||||||
func TestToFirebaseSender_Abuse(t *testing.T) {
|
func TestToFirebaseSender_Abuse(t *testing.T) {
|
||||||
sender := &testFirebaseSender{allowed: 2}
|
sender := &testFirebaseSender{allowed: 2}
|
||||||
client := newFirebaseClient(sender, &testAuther{})
|
client := newFirebaseClient(sender, &testAuther{})
|
||||||
visitor := newVisitor(newTestConfig(t), newMemTestCache(t), nil, netip.MustParseAddr("1.2.3.4"), nil)
|
visitor := newVisitor(newTestConfig(t, ""), newMemTestCache(t), nil, netip.MustParseAddr("1.2.3.4"), nil)
|
||||||
|
|
||||||
require.Nil(t, client.Send(visitor, &model.Message{Topic: "mytopic"}))
|
require.Nil(t, client.Send(visitor, &model.Message{Topic: "mytopic"}))
|
||||||
require.Equal(t, 1, len(sender.Messages()))
|
require.Equal(t, 1, len(sender.Messages()))
|
||||||
|
|||||||
@@ -17,15 +17,10 @@ func (s *Server) execManager() {
|
|||||||
s.pruneMessages()
|
s.pruneMessages()
|
||||||
s.pruneAndNotifyWebPushSubscriptions()
|
s.pruneAndNotifyWebPushSubscriptions()
|
||||||
|
|
||||||
// Message count per topic
|
// Message count
|
||||||
var messagesCached int
|
messagesCached, err := s.messageCache.MessagesCount()
|
||||||
messageCounts, err := s.messageCache.MessageCounts()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Tag(tagManager).Err(err).Warn("Cannot get message counts")
|
log.Tag(tagManager).Err(err).Warn("Cannot get messages count")
|
||||||
messageCounts = make(map[string]int) // Empty, so we can continue
|
|
||||||
}
|
|
||||||
for _, count := range messageCounts {
|
|
||||||
messagesCached += count
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove subscriptions without subscribers
|
// Remove subscriptions without subscribers
|
||||||
|
|||||||
@@ -2,13 +2,14 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"heckel.io/ntfy/v2/model"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestServer_Manager_Prune_Messages_Without_Attachments_DoesNotPanic(t *testing.T) {
|
func TestServer_Manager_Prune_Messages_Without_Attachments_DoesNotPanic(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
// Tests that the manager runs without attachment-cache-dir set, see #617
|
// Tests that the manager runs without attachment-cache-dir set, see #617
|
||||||
c := newTestConfig(t)
|
c := newTestConfig(t, databaseURL)
|
||||||
c.AttachmentCacheDir = ""
|
c.AttachmentCacheDir = ""
|
||||||
s := newTestServer(t, c)
|
s := newTestServer(t, c)
|
||||||
|
|
||||||
@@ -25,6 +26,6 @@ func TestServer_Manager_Prune_Messages_Without_Attachments_DoesNotPanic(t *testi
|
|||||||
|
|
||||||
// Actually deleted
|
// Actually deleted
|
||||||
_, err := s.messageCache.Message(m.ID)
|
_, err := s.messageCache.Message(m.ID)
|
||||||
require.Equal(t, errMessageNotFound, err)
|
require.Equal(t, model.ErrMessageNotFound, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/stripe/stripe-go/v74"
|
"github.com/stripe/stripe-go/v74"
|
||||||
"golang.org/x/time/rate"
|
"golang.org/x/time/rate"
|
||||||
|
"heckel.io/ntfy/v2/model"
|
||||||
"heckel.io/ntfy/v2/payments"
|
"heckel.io/ntfy/v2/payments"
|
||||||
"heckel.io/ntfy/v2/user"
|
"heckel.io/ntfy/v2/user"
|
||||||
"heckel.io/ntfy/v2/util"
|
"heckel.io/ntfy/v2/util"
|
||||||
@@ -21,11 +22,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestPayments_Tiers(t *testing.T) {
|
func TestPayments_Tiers(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
stripeMock := &testStripeAPI{}
|
stripeMock := &testStripeAPI{}
|
||||||
defer stripeMock.AssertExpectations(t)
|
defer stripeMock.AssertExpectations(t)
|
||||||
|
|
||||||
c := newTestConfigWithAuthFile(t)
|
c := newTestConfigWithAuthFile(t, databaseURL)
|
||||||
c.StripeSecretKey = "secret key"
|
c.StripeSecretKey = "secret key"
|
||||||
c.StripeWebhookKey = "webhook key"
|
c.StripeWebhookKey = "webhook key"
|
||||||
c.VisitorRequestLimitReplenish = 12 * time.Hour
|
c.VisitorRequestLimitReplenish = 12 * time.Hour
|
||||||
@@ -133,11 +134,11 @@ func TestPayments_Tiers(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
|
func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
stripeMock := &testStripeAPI{}
|
stripeMock := &testStripeAPI{}
|
||||||
defer stripeMock.AssertExpectations(t)
|
defer stripeMock.AssertExpectations(t)
|
||||||
|
|
||||||
c := newTestConfigWithAuthFile(t)
|
c := newTestConfigWithAuthFile(t, databaseURL)
|
||||||
c.StripeSecretKey = "secret key"
|
c.StripeSecretKey = "secret key"
|
||||||
c.StripeWebhookKey = "webhook key"
|
c.StripeWebhookKey = "webhook key"
|
||||||
s := newTestServer(t, c)
|
s := newTestServer(t, c)
|
||||||
@@ -168,11 +169,11 @@ func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
|
func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
stripeMock := &testStripeAPI{}
|
stripeMock := &testStripeAPI{}
|
||||||
defer stripeMock.AssertExpectations(t)
|
defer stripeMock.AssertExpectations(t)
|
||||||
|
|
||||||
c := newTestConfigWithAuthFile(t)
|
c := newTestConfigWithAuthFile(t, databaseURL)
|
||||||
c.StripeSecretKey = "secret key"
|
c.StripeSecretKey = "secret key"
|
||||||
c.StripeWebhookKey = "webhook key"
|
c.StripeWebhookKey = "webhook key"
|
||||||
s := newTestServer(t, c)
|
s := newTestServer(t, c)
|
||||||
@@ -214,11 +215,11 @@ func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
|
func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
stripeMock := &testStripeAPI{}
|
stripeMock := &testStripeAPI{}
|
||||||
defer stripeMock.AssertExpectations(t)
|
defer stripeMock.AssertExpectations(t)
|
||||||
|
|
||||||
c := newTestConfigWithAuthFile(t)
|
c := newTestConfigWithAuthFile(t, databaseURL)
|
||||||
c.EnableSignup = true
|
c.EnableSignup = true
|
||||||
c.StripeSecretKey = "secret key"
|
c.StripeSecretKey = "secret key"
|
||||||
c.StripeWebhookKey = "webhook key"
|
c.StripeWebhookKey = "webhook key"
|
||||||
@@ -261,7 +262,7 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *testing.T) {
|
func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
// This test is too overloaded, but it's also a great end-to-end a test.
|
// This test is too overloaded, but it's also a great end-to-end a test.
|
||||||
//
|
//
|
||||||
// It tests:
|
// It tests:
|
||||||
@@ -273,7 +274,7 @@ func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *tes
|
|||||||
stripeMock := &testStripeAPI{}
|
stripeMock := &testStripeAPI{}
|
||||||
defer stripeMock.AssertExpectations(t)
|
defer stripeMock.AssertExpectations(t)
|
||||||
|
|
||||||
c := newTestConfigWithAuthFile(t)
|
c := newTestConfigWithAuthFile(t, databaseURL)
|
||||||
c.StripeSecretKey = "secret key"
|
c.StripeSecretKey = "secret key"
|
||||||
c.StripeWebhookKey = "webhook key"
|
c.StripeWebhookKey = "webhook key"
|
||||||
c.VisitorRequestLimitBurst = 5
|
c.VisitorRequestLimitBurst = 5
|
||||||
@@ -428,7 +429,7 @@ func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *tes
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(t *testing.T) {
|
func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
// This tests incoming webhooks from Stripe to update a subscription:
|
// This tests incoming webhooks from Stripe to update a subscription:
|
||||||
@@ -439,7 +440,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
|
|||||||
stripeMock := &testStripeAPI{}
|
stripeMock := &testStripeAPI{}
|
||||||
defer stripeMock.AssertExpectations(t)
|
defer stripeMock.AssertExpectations(t)
|
||||||
|
|
||||||
c := newTestConfigWithAuthFile(t)
|
c := newTestConfigWithAuthFile(t, databaseURL)
|
||||||
c.StripeSecretKey = "secret key"
|
c.StripeSecretKey = "secret key"
|
||||||
c.StripeWebhookKey = "webhook key"
|
c.StripeWebhookKey = "webhook key"
|
||||||
s := newTestServer(t, c)
|
s := newTestServer(t, c)
|
||||||
@@ -546,12 +547,12 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
|
|||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
s.execManager()
|
s.execManager()
|
||||||
|
|
||||||
ms, err := s.messageCache.Messages("atopic", sinceAllMessages, false)
|
ms, err := s.messageCache.Messages("atopic", model.SinceAllMessages, false)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, 2, len(ms))
|
require.Equal(t, 2, len(ms))
|
||||||
require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, a2.ID))
|
require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, a2.ID))
|
||||||
|
|
||||||
ms, err = s.messageCache.Messages("ztopic", sinceAllMessages, false)
|
ms, err = s.messageCache.Messages("ztopic", model.SinceAllMessages, false)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, 0, len(ms))
|
require.Equal(t, 0, len(ms))
|
||||||
require.NoFileExists(t, filepath.Join(s.config.AttachmentCacheDir, z2.ID))
|
require.NoFileExists(t, filepath.Join(s.config.AttachmentCacheDir, z2.ID))
|
||||||
@@ -559,7 +560,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPayments_Webhook_Subscription_Deleted(t *testing.T) {
|
func TestPayments_Webhook_Subscription_Deleted(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
// This tests incoming webhooks from Stripe to delete a subscription. It verifies that the database is
|
// This tests incoming webhooks from Stripe to delete a subscription. It verifies that the database is
|
||||||
// updated (all Stripe fields are deleted, and the tier is removed).
|
// updated (all Stripe fields are deleted, and the tier is removed).
|
||||||
//
|
//
|
||||||
@@ -568,7 +569,7 @@ func TestPayments_Webhook_Subscription_Deleted(t *testing.T) {
|
|||||||
stripeMock := &testStripeAPI{}
|
stripeMock := &testStripeAPI{}
|
||||||
defer stripeMock.AssertExpectations(t)
|
defer stripeMock.AssertExpectations(t)
|
||||||
|
|
||||||
c := newTestConfigWithAuthFile(t)
|
c := newTestConfigWithAuthFile(t, databaseURL)
|
||||||
c.StripeSecretKey = "secret key"
|
c.StripeSecretKey = "secret key"
|
||||||
c.StripeWebhookKey = "webhook key"
|
c.StripeWebhookKey = "webhook key"
|
||||||
s := newTestServer(t, c)
|
s := newTestServer(t, c)
|
||||||
@@ -626,11 +627,11 @@ func TestPayments_Webhook_Subscription_Deleted(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPayments_Subscription_Update_Different_Tier(t *testing.T) {
|
func TestPayments_Subscription_Update_Different_Tier(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
stripeMock := &testStripeAPI{}
|
stripeMock := &testStripeAPI{}
|
||||||
defer stripeMock.AssertExpectations(t)
|
defer stripeMock.AssertExpectations(t)
|
||||||
|
|
||||||
c := newTestConfigWithAuthFile(t)
|
c := newTestConfigWithAuthFile(t, databaseURL)
|
||||||
c.StripeSecretKey = "secret key"
|
c.StripeSecretKey = "secret key"
|
||||||
c.StripeWebhookKey = "webhook key"
|
c.StripeWebhookKey = "webhook key"
|
||||||
s := newTestServer(t, c)
|
s := newTestServer(t, c)
|
||||||
@@ -692,11 +693,11 @@ func TestPayments_Subscription_Update_Different_Tier(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPayments_Subscription_Delete_At_Period_End(t *testing.T) {
|
func TestPayments_Subscription_Delete_At_Period_End(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
stripeMock := &testStripeAPI{}
|
stripeMock := &testStripeAPI{}
|
||||||
defer stripeMock.AssertExpectations(t)
|
defer stripeMock.AssertExpectations(t)
|
||||||
|
|
||||||
c := newTestConfigWithAuthFile(t)
|
c := newTestConfigWithAuthFile(t, databaseURL)
|
||||||
c.StripeSecretKey = "secret key"
|
c.StripeSecretKey = "secret key"
|
||||||
c.StripeWebhookKey = "webhook key"
|
c.StripeWebhookKey = "webhook key"
|
||||||
s := newTestServer(t, c)
|
s := newTestServer(t, c)
|
||||||
@@ -725,11 +726,11 @@ func TestPayments_Subscription_Delete_At_Period_End(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPayments_CreatePortalSession(t *testing.T) {
|
func TestPayments_CreatePortalSession(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
stripeMock := &testStripeAPI{}
|
stripeMock := &testStripeAPI{}
|
||||||
defer stripeMock.AssertExpectations(t)
|
defer stripeMock.AssertExpectations(t)
|
||||||
|
|
||||||
c := newTestConfigWithAuthFile(t)
|
c := newTestConfigWithAuthFile(t, databaseURL)
|
||||||
c.StripeSecretKey = "secret key"
|
c.StripeSecretKey = "secret key"
|
||||||
c.StripeWebhookKey = "webhook key"
|
c.StripeWebhookKey = "webhook key"
|
||||||
s := newTestServer(t, c)
|
s := newTestServer(t, c)
|
||||||
|
|||||||
5
server/server_race_off_test.go
Normal file
5
server/server_race_off_test.go
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
//go:build !race
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
const raceEnabled = false
|
||||||
5
server/server_race_on_test.go
Normal file
5
server/server_race_on_test.go
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
//go:build race
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
const raceEnabled = true
|
||||||
File diff suppressed because one or more lines are too long
@@ -14,7 +14,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestServer_Twilio_Call_Add_Verify_Call_Delete_Success(t *testing.T) {
|
func TestServer_Twilio_Call_Add_Verify_Call_Delete_Success(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
var called, verified atomic.Bool
|
var called, verified atomic.Bool
|
||||||
var code atomic.Pointer[string]
|
var code atomic.Pointer[string]
|
||||||
twilioVerifyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
twilioVerifyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -51,7 +51,7 @@ func TestServer_Twilio_Call_Add_Verify_Call_Delete_Success(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer twilioCallsServer.Close()
|
defer twilioCallsServer.Close()
|
||||||
|
|
||||||
c := newTestConfigWithAuthFile(t)
|
c := newTestConfigWithAuthFile(t, databaseURL)
|
||||||
c.TwilioVerifyBaseURL = twilioVerifyServer.URL
|
c.TwilioVerifyBaseURL = twilioVerifyServer.URL
|
||||||
c.TwilioCallsBaseURL = twilioCallsServer.URL
|
c.TwilioCallsBaseURL = twilioCallsServer.URL
|
||||||
c.TwilioAccount = "AC1234567890"
|
c.TwilioAccount = "AC1234567890"
|
||||||
@@ -117,7 +117,7 @@ func TestServer_Twilio_Call_Add_Verify_Call_Delete_Success(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_Twilio_Call_Success(t *testing.T) {
|
func TestServer_Twilio_Call_Success(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
var called atomic.Bool
|
var called atomic.Bool
|
||||||
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if called.Load() {
|
if called.Load() {
|
||||||
@@ -132,7 +132,7 @@ func TestServer_Twilio_Call_Success(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer twilioServer.Close()
|
defer twilioServer.Close()
|
||||||
|
|
||||||
c := newTestConfigWithAuthFile(t)
|
c := newTestConfigWithAuthFile(t, databaseURL)
|
||||||
c.TwilioCallsBaseURL = twilioServer.URL
|
c.TwilioCallsBaseURL = twilioServer.URL
|
||||||
c.TwilioAccount = "AC1234567890"
|
c.TwilioAccount = "AC1234567890"
|
||||||
c.TwilioAuthToken = "AAEAA1234567890"
|
c.TwilioAuthToken = "AAEAA1234567890"
|
||||||
@@ -164,7 +164,7 @@ func TestServer_Twilio_Call_Success(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_Twilio_Call_Success_With_Yes(t *testing.T) {
|
func TestServer_Twilio_Call_Success_With_Yes(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
var called atomic.Bool
|
var called atomic.Bool
|
||||||
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if called.Load() {
|
if called.Load() {
|
||||||
@@ -179,7 +179,7 @@ func TestServer_Twilio_Call_Success_With_Yes(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer twilioServer.Close()
|
defer twilioServer.Close()
|
||||||
|
|
||||||
c := newTestConfigWithAuthFile(t)
|
c := newTestConfigWithAuthFile(t, databaseURL)
|
||||||
c.TwilioCallsBaseURL = twilioServer.URL
|
c.TwilioCallsBaseURL = twilioServer.URL
|
||||||
c.TwilioAccount = "AC1234567890"
|
c.TwilioAccount = "AC1234567890"
|
||||||
c.TwilioAuthToken = "AAEAA1234567890"
|
c.TwilioAuthToken = "AAEAA1234567890"
|
||||||
@@ -211,7 +211,7 @@ func TestServer_Twilio_Call_Success_With_Yes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_Twilio_Call_Success_with_custom_twiml(t *testing.T) {
|
func TestServer_Twilio_Call_Success_with_custom_twiml(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
var called atomic.Bool
|
var called atomic.Bool
|
||||||
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if called.Load() {
|
if called.Load() {
|
||||||
@@ -226,7 +226,7 @@ func TestServer_Twilio_Call_Success_with_custom_twiml(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer twilioServer.Close()
|
defer twilioServer.Close()
|
||||||
|
|
||||||
c := newTestConfigWithAuthFile(t)
|
c := newTestConfigWithAuthFile(t, databaseURL)
|
||||||
c.TwilioCallsBaseURL = twilioServer.URL
|
c.TwilioCallsBaseURL = twilioServer.URL
|
||||||
c.TwilioAccount = "AC1234567890"
|
c.TwilioAccount = "AC1234567890"
|
||||||
c.TwilioAuthToken = "AAEAA1234567890"
|
c.TwilioAuthToken = "AAEAA1234567890"
|
||||||
@@ -274,8 +274,8 @@ func TestServer_Twilio_Call_Success_with_custom_twiml(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_Twilio_Call_UnverifiedNumber(t *testing.T) {
|
func TestServer_Twilio_Call_UnverifiedNumber(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
c := newTestConfigWithAuthFile(t)
|
c := newTestConfigWithAuthFile(t, databaseURL)
|
||||||
c.TwilioCallsBaseURL = "http://dummy.invalid"
|
c.TwilioCallsBaseURL = "http://dummy.invalid"
|
||||||
c.TwilioAccount = "AC1234567890"
|
c.TwilioAccount = "AC1234567890"
|
||||||
c.TwilioAuthToken = "AAEAA1234567890"
|
c.TwilioAuthToken = "AAEAA1234567890"
|
||||||
@@ -301,8 +301,8 @@ func TestServer_Twilio_Call_UnverifiedNumber(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_Twilio_Call_InvalidNumber(t *testing.T) {
|
func TestServer_Twilio_Call_InvalidNumber(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
c := newTestConfigWithAuthFile(t)
|
c := newTestConfigWithAuthFile(t, databaseURL)
|
||||||
c.TwilioCallsBaseURL = "https://127.0.0.1"
|
c.TwilioCallsBaseURL = "https://127.0.0.1"
|
||||||
c.TwilioAccount = "AC1234567890"
|
c.TwilioAccount = "AC1234567890"
|
||||||
c.TwilioAuthToken = "AAEAA1234567890"
|
c.TwilioAuthToken = "AAEAA1234567890"
|
||||||
@@ -317,8 +317,8 @@ func TestServer_Twilio_Call_InvalidNumber(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_Twilio_Call_Anonymous(t *testing.T) {
|
func TestServer_Twilio_Call_Anonymous(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
c := newTestConfigWithAuthFile(t)
|
c := newTestConfigWithAuthFile(t, databaseURL)
|
||||||
c.TwilioCallsBaseURL = "https://127.0.0.1"
|
c.TwilioCallsBaseURL = "https://127.0.0.1"
|
||||||
c.TwilioAccount = "AC1234567890"
|
c.TwilioAccount = "AC1234567890"
|
||||||
c.TwilioAuthToken = "AAEAA1234567890"
|
c.TwilioAuthToken = "AAEAA1234567890"
|
||||||
@@ -333,8 +333,8 @@ func TestServer_Twilio_Call_Anonymous(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_Twilio_Call_Unconfigured(t *testing.T) {
|
func TestServer_Twilio_Call_Unconfigured(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
s := newTestServer(t, newTestConfig(t))
|
s := newTestServer(t, newTestConfig(t, databaseURL))
|
||||||
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
|
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
|
||||||
"x-call": "+1234",
|
"x-call": "+1234",
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -26,21 +26,21 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestServer_WebPush_Enabled(t *testing.T) {
|
func TestServer_WebPush_Enabled(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
conf := newTestConfig(t)
|
conf := newTestConfig(t, databaseURL)
|
||||||
conf.WebRoot = "" // Disable web app
|
conf.WebRoot = "" // Disable web app
|
||||||
s := newTestServer(t, conf)
|
s := newTestServer(t, conf)
|
||||||
|
|
||||||
rr := request(t, s, "GET", "/manifest.webmanifest", "", nil)
|
rr := request(t, s, "GET", "/manifest.webmanifest", "", nil)
|
||||||
require.Equal(t, 404, rr.Code)
|
require.Equal(t, 404, rr.Code)
|
||||||
|
|
||||||
conf2 := newTestConfig(t)
|
conf2 := newTestConfig(t, databaseURL)
|
||||||
s2 := newTestServer(t, conf2)
|
s2 := newTestServer(t, conf2)
|
||||||
|
|
||||||
rr = request(t, s2, "GET", "/manifest.webmanifest", "", nil)
|
rr = request(t, s2, "GET", "/manifest.webmanifest", "", nil)
|
||||||
require.Equal(t, 404, rr.Code)
|
require.Equal(t, 404, rr.Code)
|
||||||
|
|
||||||
conf3 := newTestConfigWithWebPush(t)
|
conf3 := newTestConfigWithWebPush(t, databaseURL)
|
||||||
s3 := newTestServer(t, conf3)
|
s3 := newTestServer(t, conf3)
|
||||||
|
|
||||||
rr = request(t, s3, "GET", "/manifest.webmanifest", "", nil)
|
rr = request(t, s3, "GET", "/manifest.webmanifest", "", nil)
|
||||||
@@ -50,8 +50,8 @@ func TestServer_WebPush_Enabled(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
func TestServer_WebPush_Disabled(t *testing.T) {
|
func TestServer_WebPush_Disabled(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
s := newTestServer(t, newTestConfig(t))
|
s := newTestServer(t, newTestConfig(t, databaseURL))
|
||||||
|
|
||||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
|
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
|
||||||
require.Equal(t, 404, response.Code)
|
require.Equal(t, 404, response.Code)
|
||||||
@@ -59,8 +59,8 @@ func TestServer_WebPush_Disabled(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_WebPush_TopicAdd(t *testing.T) {
|
func TestServer_WebPush_TopicAdd(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
s := newTestServer(t, newTestConfigWithWebPush(t, databaseURL))
|
||||||
|
|
||||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
|
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
|
||||||
require.Equal(t, 200, response.Code)
|
require.Equal(t, 200, response.Code)
|
||||||
@@ -78,8 +78,8 @@ func TestServer_WebPush_TopicAdd(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_WebPush_TopicAdd_InvalidEndpoint(t *testing.T) {
|
func TestServer_WebPush_TopicAdd_InvalidEndpoint(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
s := newTestServer(t, newTestConfigWithWebPush(t, databaseURL))
|
||||||
|
|
||||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, "https://ddos-target.example.com/webpush"), nil)
|
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, "https://ddos-target.example.com/webpush"), nil)
|
||||||
require.Equal(t, 400, response.Code)
|
require.Equal(t, 400, response.Code)
|
||||||
@@ -88,8 +88,8 @@ func TestServer_WebPush_TopicAdd_InvalidEndpoint(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_WebPush_TopicAdd_TooManyTopics(t *testing.T) {
|
func TestServer_WebPush_TopicAdd_TooManyTopics(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
s := newTestServer(t, newTestConfigWithWebPush(t, databaseURL))
|
||||||
|
|
||||||
topicList := make([]string, 51)
|
topicList := make([]string, 51)
|
||||||
for i := range topicList {
|
for i := range topicList {
|
||||||
@@ -103,8 +103,8 @@ func TestServer_WebPush_TopicAdd_TooManyTopics(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_WebPush_TopicUnsubscribe(t *testing.T) {
|
func TestServer_WebPush_TopicUnsubscribe(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
s := newTestServer(t, newTestConfigWithWebPush(t, databaseURL))
|
||||||
|
|
||||||
addSubscription(t, s, testWebPushEndpoint, "test-topic")
|
addSubscription(t, s, testWebPushEndpoint, "test-topic")
|
||||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||||
@@ -118,8 +118,8 @@ func TestServer_WebPush_TopicUnsubscribe(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_WebPush_Delete(t *testing.T) {
|
func TestServer_WebPush_Delete(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
s := newTestServer(t, newTestConfigWithWebPush(t, databaseURL))
|
||||||
|
|
||||||
addSubscription(t, s, testWebPushEndpoint, "test-topic")
|
addSubscription(t, s, testWebPushEndpoint, "test-topic")
|
||||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||||
@@ -133,8 +133,8 @@ func TestServer_WebPush_Delete(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_WebPush_TopicSubscribeProtected_Allowed(t *testing.T) {
|
func TestServer_WebPush_TopicSubscribeProtected_Allowed(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
config := configureAuth(t, newTestConfigWithWebPush(t))
|
config := configureAuth(t, newTestConfigWithWebPush(t, databaseURL))
|
||||||
config.AuthDefault = user.PermissionDenyAll
|
config.AuthDefault = user.PermissionDenyAll
|
||||||
s := newTestServer(t, config)
|
s := newTestServer(t, config)
|
||||||
|
|
||||||
@@ -155,8 +155,8 @@ func TestServer_WebPush_TopicSubscribeProtected_Allowed(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_WebPush_TopicSubscribeProtected_Denied(t *testing.T) {
|
func TestServer_WebPush_TopicSubscribeProtected_Denied(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
config := configureAuth(t, newTestConfigWithWebPush(t))
|
config := configureAuth(t, newTestConfigWithWebPush(t, databaseURL))
|
||||||
config.AuthDefault = user.PermissionDenyAll
|
config.AuthDefault = user.PermissionDenyAll
|
||||||
s := newTestServer(t, config)
|
s := newTestServer(t, config)
|
||||||
|
|
||||||
@@ -168,8 +168,8 @@ func TestServer_WebPush_TopicSubscribeProtected_Denied(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_WebPush_DeleteAccountUnsubscribe(t *testing.T) {
|
func TestServer_WebPush_DeleteAccountUnsubscribe(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
config := configureAuth(t, newTestConfigWithWebPush(t))
|
config := configureAuth(t, newTestConfigWithWebPush(t, databaseURL))
|
||||||
s := newTestServer(t, config)
|
s := newTestServer(t, config)
|
||||||
|
|
||||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
||||||
@@ -193,8 +193,8 @@ func TestServer_WebPush_DeleteAccountUnsubscribe(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_WebPush_Publish(t *testing.T) {
|
func TestServer_WebPush_Publish(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
s := newTestServer(t, newTestConfigWithWebPush(t, databaseURL))
|
||||||
|
|
||||||
var received atomic.Bool
|
var received atomic.Bool
|
||||||
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -217,8 +217,8 @@ func TestServer_WebPush_Publish(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_WebPush_Publish_RemoveOnError(t *testing.T) {
|
func TestServer_WebPush_Publish_RemoveOnError(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
s := newTestServer(t, newTestConfigWithWebPush(t, databaseURL))
|
||||||
|
|
||||||
var received atomic.Bool
|
var received atomic.Bool
|
||||||
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -247,8 +247,8 @@ func TestServer_WebPush_Publish_RemoveOnError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_WebPush_Expiry(t *testing.T) {
|
func TestServer_WebPush_Expiry(t *testing.T) {
|
||||||
forEachBackend(t, func(t *testing.T) {
|
forEachBackend(t, func(t *testing.T, databaseURL string) {
|
||||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
s := newTestServer(t, newTestConfigWithWebPush(t, databaseURL))
|
||||||
|
|
||||||
var received atomic.Bool
|
var received atomic.Bool
|
||||||
|
|
||||||
@@ -307,8 +307,8 @@ func requireSubscriptionCount(t *testing.T, s *Server, topic string, expectedLen
|
|||||||
require.Len(t, subs, expectedLength)
|
require.Len(t, subs, expectedLength)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTestConfigWithWebPush(t *testing.T) *Config {
|
func newTestConfigWithWebPush(t *testing.T, databaseURL string) *Config {
|
||||||
conf := newTestConfig(t)
|
conf := newTestConfig(t, databaseURL)
|
||||||
privateKey, publicKey, err := webpush.GenerateVAPIDKeys()
|
privateKey, publicKey, err := webpush.GenerateVAPIDKeys()
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
if conf.DatabaseURL == "" {
|
if conf.DatabaseURL == "" {
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ var (
|
|||||||
var (
|
var (
|
||||||
onlySpacesRegex = regexp.MustCompile(`(?m)^\s+$`)
|
onlySpacesRegex = regexp.MustCompile(`(?m)^\s+$`)
|
||||||
consecutiveNewLinesRegex = regexp.MustCompile(`\n{3,}`)
|
consecutiveNewLinesRegex = regexp.MustCompile(`\n{3,}`)
|
||||||
|
htmlLineBreakRegex = regexp.MustCompile(`(?i)<br\s*/?>`)
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -159,7 +160,7 @@ func (s *smtpSession) Data(r io.Reader) error {
|
|||||||
if len(body) > conf.MessageSizeLimit {
|
if len(body) > conf.MessageSizeLimit {
|
||||||
body = body[:conf.MessageSizeLimit]
|
body = body[:conf.MessageSizeLimit]
|
||||||
}
|
}
|
||||||
m := newDefaultMessage(s.topic, body)
|
m := model.NewDefaultMessage(s.topic, body)
|
||||||
subject := strings.TrimSpace(msg.Header.Get("Subject"))
|
subject := strings.TrimSpace(msg.Header.Get("Subject"))
|
||||||
if subject != "" {
|
if subject != "" {
|
||||||
dec := mime.WordDecoder{}
|
dec := mime.WordDecoder{}
|
||||||
@@ -328,6 +329,9 @@ func readHTMLMailBody(reader io.Reader, transferEncoding string) (string, error)
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
// Convert <br> tags to newlines before stripping HTML, so that line breaks
|
||||||
|
// in HTML emails (e.g. from Synology DSM, and other appliances) are preserved.
|
||||||
|
body = htmlLineBreakRegex.ReplaceAllString(body, "\n")
|
||||||
stripped := bluemonday.
|
stripped := bluemonday.
|
||||||
StrictPolicy().
|
StrictPolicy().
|
||||||
AddSpaceWhenStrippingTag(true).
|
AddSpaceWhenStrippingTag(true).
|
||||||
|
|||||||
@@ -695,6 +695,7 @@ Now the light is on
|
|||||||
|
|
||||||
If you don't want to receive this message anymore, stop the push
|
If you don't want to receive this message anymore, stop the push
|
||||||
services in your FRITZ!Box .
|
services in your FRITZ!Box .
|
||||||
|
|
||||||
Here you can see the active push services: "System > Push Service".
|
Here you can see the active push services: "System > Push Service".
|
||||||
|
|
||||||
This mail has ben sent by your FRITZ!Box automatically.`
|
This mail has ben sent by your FRITZ!Box automatically.`
|
||||||
@@ -1354,9 +1355,11 @@ Congratulations! You have successfully set up the email notification on Synology
|
|||||||
s, c, conf, scanner := newTestSMTPServer(t, func(w http.ResponseWriter, r *http.Request) {
|
s, c, conf, scanner := newTestSMTPServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
require.Equal(t, "/synology", r.URL.Path)
|
require.Equal(t, "/synology", r.URL.Path)
|
||||||
require.Equal(t, "[Synology NAS] Test Message from Litts_NAS", r.Header.Get("Title"))
|
require.Equal(t, "[Synology NAS] Test Message from Litts_NAS", r.Header.Get("Title"))
|
||||||
actual := readAll(t, r.Body)
|
expected := "Congratulations! You have successfully set up the email notification on Synology_NAS.\n" +
|
||||||
expected := `Congratulations! You have successfully set up the email notification on Synology_NAS. For further system configurations, please visit http://192.168.1.28:5000/, http://172.16.60.5:5000/. (If you cannot connect to the server, please contact the administrator.) From Synology_NAS`
|
"For further system configurations, please visit http://192.168.1.28:5000/, http://172.16.60.5:5000/.\n" +
|
||||||
require.Equal(t, expected, actual)
|
"(If you cannot connect to the server, please contact the administrator.)\n\n" +
|
||||||
|
"From Synology_NAS"
|
||||||
|
require.Equal(t, expected, readAll(t, r.Body))
|
||||||
})
|
})
|
||||||
conf.SMTPServerDomain = "mydomain.me"
|
conf.SMTPServerDomain = "mydomain.me"
|
||||||
conf.SMTPServerAddrPrefix = ""
|
conf.SMTPServerAddrPrefix = ""
|
||||||
@@ -1365,6 +1368,36 @@ Congratulations! You have successfully set up the email notification on Synology
|
|||||||
writeAndReadUntilLine(t, email, c, scanner, "250 2.0.0 OK: queued")
|
writeAndReadUntilLine(t, email, c, scanner, "250 2.0.0 OK: queued")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSmtpBackend_HTMLEmail_BrTagsPreserved(t *testing.T) {
|
||||||
|
email := `EHLO example.com
|
||||||
|
MAIL FROM: nas@example.com
|
||||||
|
RCPT TO: ntfy-alerts@ntfy.sh
|
||||||
|
DATA
|
||||||
|
Content-Type: text/html; charset=utf-8
|
||||||
|
Content-Transfer-Encoding: 8bit
|
||||||
|
Subject: Task Scheduler: daily-backup
|
||||||
|
|
||||||
|
Task Scheduler has completed a scheduled task.<BR><BR>Task: daily-backup<BR>Start time: Mon, 01 Jan 2026 02:00:00 +0000<BR>Stop time: Mon, 01 Jan 2024 02:03:00 +0000<BR>Current status: 0 (Normal)<BR>Standard output/error:<BR>OK<BR><BR>From MyNAS
|
||||||
|
.
|
||||||
|
`
|
||||||
|
s, c, _, scanner := newTestSMTPServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
require.Equal(t, "/alerts", r.URL.Path)
|
||||||
|
require.Equal(t, "Task Scheduler: daily-backup", r.Header.Get("Title"))
|
||||||
|
expected := "Task Scheduler has completed a scheduled task.\n\n" +
|
||||||
|
"Task: daily-backup\n" +
|
||||||
|
"Start time: Mon, 01 Jan 2026 02:00:00 +0000\n" +
|
||||||
|
"Stop time: Mon, 01 Jan 2024 02:03:00 +0000\n" +
|
||||||
|
"Current status: 0 (Normal)\n" +
|
||||||
|
"Standard output/error:\n" +
|
||||||
|
"OK\n\n" +
|
||||||
|
"From MyNAS"
|
||||||
|
require.Equal(t, expected, readAll(t, r.Body))
|
||||||
|
})
|
||||||
|
defer s.Close()
|
||||||
|
defer c.Close()
|
||||||
|
writeAndReadUntilLine(t, email, c, scanner, "250 2.0.0 OK: queued")
|
||||||
|
}
|
||||||
|
|
||||||
func TestSmtpBackend_PlaintextWithToken(t *testing.T) {
|
func TestSmtpBackend_PlaintextWithToken(t *testing.T) {
|
||||||
email := `EHLO example.com
|
email := `EHLO example.com
|
||||||
MAIL FROM: phil@example.com
|
MAIL FROM: phil@example.com
|
||||||
@@ -1411,7 +1444,7 @@ what's up
|
|||||||
type smtpHandlerFunc func(http.ResponseWriter, *http.Request)
|
type smtpHandlerFunc func(http.ResponseWriter, *http.Request)
|
||||||
|
|
||||||
func newTestSMTPServer(t *testing.T, handler smtpHandlerFunc) (s *smtp.Server, c net.Conn, conf *Config, scanner *bufio.Scanner) {
|
func newTestSMTPServer(t *testing.T, handler smtpHandlerFunc) (s *smtp.Server, c net.Conn, conf *Config, scanner *bufio.Scanner) {
|
||||||
conf = newTestConfig(t)
|
conf = newTestConfig(t, "")
|
||||||
conf.SMTPServerListen = ":25"
|
conf.SMTPServerListen = ":25"
|
||||||
conf.SMTPServerDomain = "ntfy.sh"
|
conf.SMTPServerDomain = "ntfy.sh"
|
||||||
conf.SMTPServerAddrPrefix = "ntfy-"
|
conf.SMTPServerAddrPrefix = "ntfy-"
|
||||||
|
|||||||
@@ -8,49 +8,6 @@ import (
|
|||||||
"heckel.io/ntfy/v2/util"
|
"heckel.io/ntfy/v2/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Event constants
|
|
||||||
const (
|
|
||||||
openEvent = model.OpenEvent
|
|
||||||
keepaliveEvent = model.KeepaliveEvent
|
|
||||||
messageEvent = model.MessageEvent
|
|
||||||
messageDeleteEvent = model.MessageDeleteEvent
|
|
||||||
messageClearEvent = model.MessageClearEvent
|
|
||||||
pollRequestEvent = model.PollRequestEvent
|
|
||||||
messageIDLength = model.MessageIDLength
|
|
||||||
)
|
|
||||||
|
|
||||||
// SinceMarker aliases
|
|
||||||
var (
|
|
||||||
sinceAllMessages = model.SinceAllMessages
|
|
||||||
sinceNoMessages = model.SinceNoMessages
|
|
||||||
sinceLatestMessage = model.SinceLatestMessage
|
|
||||||
)
|
|
||||||
|
|
||||||
// Error aliases
|
|
||||||
var (
|
|
||||||
errMessageNotFound = model.ErrMessageNotFound
|
|
||||||
)
|
|
||||||
|
|
||||||
// Constructors and helpers
|
|
||||||
var (
|
|
||||||
newMessage = model.NewMessage
|
|
||||||
newDefaultMessage = model.NewDefaultMessage
|
|
||||||
newOpenMessage = model.NewOpenMessage
|
|
||||||
newKeepaliveMessage = model.NewKeepaliveMessage
|
|
||||||
newActionMessage = model.NewActionMessage
|
|
||||||
newAction = model.NewAction
|
|
||||||
newSinceTime = model.NewSinceTime
|
|
||||||
newSinceID = model.NewSinceID
|
|
||||||
validMessageID = model.ValidMessageID
|
|
||||||
)
|
|
||||||
|
|
||||||
// newPollRequestMessage is a convenience method to create a poll request message
|
|
||||||
func newPollRequestMessage(topic, pollID string) *model.Message {
|
|
||||||
m := newMessage(pollRequestEvent, topic, newMessageBody)
|
|
||||||
m.PollID = pollID
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
// publishMessage is used as input when publishing as JSON
|
// publishMessage is used as input when publishing as JSON
|
||||||
type publishMessage struct {
|
type publishMessage struct {
|
||||||
Topic string `json:"topic"`
|
Topic string `json:"topic"`
|
||||||
@@ -106,7 +63,7 @@ func parseQueryFilters(r *http.Request) (*queryFilter, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (q *queryFilter) Pass(msg *model.Message) bool {
|
func (q *queryFilter) Pass(msg *model.Message) bool {
|
||||||
if msg.Event != messageEvent && msg.Event != messageDeleteEvent && msg.Event != messageClearEvent {
|
if msg.Event != model.MessageEvent && msg.Event != model.MessageDeleteEvent && msg.Event != model.MessageClearEvent {
|
||||||
return true // filters only apply to messages
|
return true // filters only apply to messages
|
||||||
} else if q.ID != "" && msg.ID != q.ID {
|
} else if q.ID != "" && msg.ID != q.ID {
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ const (
|
|||||||
// visitor represents an API user, and its associated rate.Limiter used for rate limiting
|
// visitor represents an API user, and its associated rate.Limiter used for rate limiting
|
||||||
type visitor struct {
|
type visitor struct {
|
||||||
config *Config
|
config *Config
|
||||||
messageCache message.Store
|
messageCache *message.Cache
|
||||||
userManager *user.Manager // May be nil
|
userManager *user.Manager // May be nil
|
||||||
ip netip.Addr // Visitor IP address
|
ip netip.Addr // Visitor IP address
|
||||||
user *user.User // Only set if authenticated user, otherwise nil
|
user *user.User // Only set if authenticated user, otherwise nil
|
||||||
@@ -115,7 +115,7 @@ const (
|
|||||||
visitorLimitBasisTier = visitorLimitBasis("tier")
|
visitorLimitBasisTier = visitorLimitBasis("tier")
|
||||||
)
|
)
|
||||||
|
|
||||||
func newVisitor(conf *Config, messageCache message.Store, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor {
|
func newVisitor(conf *Config, messageCache *message.Cache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor {
|
||||||
var messages, emails, calls int64
|
var messages, emails, calls int64
|
||||||
if user != nil {
|
if user != nil {
|
||||||
messages = user.Stats.Messages
|
messages = user.Stats.Messages
|
||||||
|
|||||||
35
tools/pgimport/README.md
Normal file
35
tools/pgimport/README.md
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
# pgimport
|
||||||
|
|
||||||
|
Migrates ntfy data from SQLite to PostgreSQL.
|
||||||
|
|
||||||
|
## Build
|
||||||
|
|
||||||
|
```bash
|
||||||
|
go build -o pgimport ./tools/pgimport/
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Using CLI flags
|
||||||
|
pgimport \
|
||||||
|
--database-url "postgres://user:pass@host:5432/ntfy?sslmode=require" \
|
||||||
|
--cache-file /var/cache/ntfy/cache.db \
|
||||||
|
--auth-file /var/lib/ntfy/user.db \
|
||||||
|
--web-push-file /var/lib/ntfy/webpush.db
|
||||||
|
|
||||||
|
# Using server.yml (flags override config values)
|
||||||
|
pgimport --config /etc/ntfy/server.yml
|
||||||
|
```
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
|
||||||
|
- PostgreSQL schema must already be set up (run ntfy with `database-url` once)
|
||||||
|
- ntfy must not be running during the import
|
||||||
|
- All three SQLite files are optional; only the ones specified will be imported
|
||||||
|
|
||||||
|
## Notes
|
||||||
|
|
||||||
|
- The tool is idempotent and safe to re-run
|
||||||
|
- After importing, row counts and content are verified against the SQLite sources
|
||||||
|
- Invalid UTF-8 in messages is replaced with the Unicode replacement character
|
||||||
888
tools/pgimport/main.go
Normal file
888
tools/pgimport/main.go
Normal file
@@ -0,0 +1,888 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
"github.com/urfave/cli/v2"
|
||||||
|
"github.com/urfave/cli/v2/altsrc"
|
||||||
|
"gopkg.in/yaml.v2"
|
||||||
|
"heckel.io/ntfy/v2/db"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
batchSize = 1000
|
||||||
|
|
||||||
|
expectedMessageSchemaVersion = 14
|
||||||
|
expectedUserSchemaVersion = 6
|
||||||
|
expectedWebPushSchemaVersion = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
var flags = []cli.Flag{
|
||||||
|
&cli.StringFlag{Name: "config", Aliases: []string{"c"}, Usage: "path to server.yml config file"},
|
||||||
|
altsrc.NewStringFlag(&cli.StringFlag{Name: "database-url", Aliases: []string{"database_url"}, Usage: "PostgreSQL connection string"}),
|
||||||
|
altsrc.NewStringFlag(&cli.StringFlag{Name: "cache-file", Aliases: []string{"cache_file"}, Usage: "SQLite message cache file path"}),
|
||||||
|
altsrc.NewStringFlag(&cli.StringFlag{Name: "auth-file", Aliases: []string{"auth_file"}, Usage: "SQLite user/auth database file path"}),
|
||||||
|
altsrc.NewStringFlag(&cli.StringFlag{Name: "web-push-file", Aliases: []string{"web_push_file"}, Usage: "SQLite web push database file path"}),
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
app := &cli.App{
|
||||||
|
Name: "pgimport",
|
||||||
|
Usage: "SQLite to PostgreSQL migration tool for ntfy",
|
||||||
|
UsageText: "pgimport [OPTIONS]",
|
||||||
|
Flags: flags,
|
||||||
|
Before: loadConfigFile("config", flags),
|
||||||
|
Action: execImport,
|
||||||
|
}
|
||||||
|
if err := app.Run(os.Args); err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func execImport(c *cli.Context) error {
|
||||||
|
databaseURL := c.String("database-url")
|
||||||
|
cacheFile := c.String("cache-file")
|
||||||
|
authFile := c.String("auth-file")
|
||||||
|
webPushFile := c.String("web-push-file")
|
||||||
|
|
||||||
|
if databaseURL == "" {
|
||||||
|
return fmt.Errorf("database-url must be set (via --database-url or config file)")
|
||||||
|
}
|
||||||
|
if cacheFile == "" && authFile == "" && webPushFile == "" {
|
||||||
|
return fmt.Errorf("at least one of --cache-file, --auth-file, or --web-push-file must be set")
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("pgimport - SQLite to PostgreSQL migration tool for ntfy")
|
||||||
|
fmt.Println()
|
||||||
|
fmt.Println("Sources:")
|
||||||
|
printSource(" Cache file: ", cacheFile)
|
||||||
|
printSource(" Auth file: ", authFile)
|
||||||
|
printSource(" Web push file: ", webPushFile)
|
||||||
|
fmt.Println()
|
||||||
|
fmt.Println("Target:")
|
||||||
|
fmt.Printf(" Database URL: %s\n", maskPassword(databaseURL))
|
||||||
|
fmt.Println()
|
||||||
|
fmt.Println("This will import data from the SQLite databases into PostgreSQL.")
|
||||||
|
fmt.Print("Make sure ntfy is not running. Continue? (y/n): ")
|
||||||
|
|
||||||
|
var answer string
|
||||||
|
fmt.Scanln(&answer)
|
||||||
|
if strings.TrimSpace(strings.ToLower(answer)) != "y" {
|
||||||
|
fmt.Println("Aborted.")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
fmt.Println()
|
||||||
|
|
||||||
|
pgDB, err := db.OpenPostgres(databaseURL)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot connect to PostgreSQL: %w", err)
|
||||||
|
}
|
||||||
|
defer pgDB.Close()
|
||||||
|
|
||||||
|
if authFile != "" {
|
||||||
|
if err := verifySchemaVersion(pgDB, "user", expectedUserSchemaVersion); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := importUsers(authFile, pgDB); err != nil {
|
||||||
|
return fmt.Errorf("cannot import users: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if cacheFile != "" {
|
||||||
|
if err := verifySchemaVersion(pgDB, "message", expectedMessageSchemaVersion); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := importMessages(cacheFile, pgDB); err != nil {
|
||||||
|
return fmt.Errorf("cannot import messages: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if webPushFile != "" {
|
||||||
|
if err := verifySchemaVersion(pgDB, "webpush", expectedWebPushSchemaVersion); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := importWebPush(webPushFile, pgDB); err != nil {
|
||||||
|
return fmt.Errorf("cannot import web push subscriptions: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println()
|
||||||
|
fmt.Println("Verifying migration ...")
|
||||||
|
failed := false
|
||||||
|
if authFile != "" {
|
||||||
|
if err := verifyUsers(authFile, pgDB, &failed); err != nil {
|
||||||
|
return fmt.Errorf("cannot verify users: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if cacheFile != "" {
|
||||||
|
if err := verifyMessages(cacheFile, pgDB, &failed); err != nil {
|
||||||
|
return fmt.Errorf("cannot verify messages: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if webPushFile != "" {
|
||||||
|
if err := verifyWebPush(webPushFile, pgDB, &failed); err != nil {
|
||||||
|
return fmt.Errorf("cannot verify web push: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fmt.Println()
|
||||||
|
if failed {
|
||||||
|
return fmt.Errorf("verification FAILED, see above for details")
|
||||||
|
}
|
||||||
|
fmt.Println("Verification successful. Migration complete.")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadConfigFile(configFlag string, flags []cli.Flag) cli.BeforeFunc {
|
||||||
|
return func(c *cli.Context) error {
|
||||||
|
configFile := c.String(configFlag)
|
||||||
|
if configFile == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(configFile); os.IsNotExist(err) {
|
||||||
|
return fmt.Errorf("config file %s does not exist", configFile)
|
||||||
|
}
|
||||||
|
inputSource, err := newYamlSourceFromFile(configFile, flags)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return altsrc.ApplyInputSourceValues(c, inputSource, flags)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newYamlSourceFromFile(file string, flags []cli.Flag) (altsrc.InputSourceContext, error) {
|
||||||
|
var rawConfig map[any]any
|
||||||
|
b, err := os.ReadFile(file)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := yaml.Unmarshal(b, &rawConfig); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, f := range flags {
|
||||||
|
flagName := f.Names()[0]
|
||||||
|
for _, flagAlias := range f.Names()[1:] {
|
||||||
|
if _, ok := rawConfig[flagAlias]; ok {
|
||||||
|
rawConfig[flagName] = rawConfig[flagAlias]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return altsrc.NewMapInputSource(file, rawConfig), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifySchemaVersion(pgDB *sql.DB, store string, expected int) error {
|
||||||
|
var version int
|
||||||
|
err := pgDB.QueryRow(`SELECT version FROM schema_version WHERE store = $1`, store).Scan(&version)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot read %s schema version from PostgreSQL (is the schema set up?): %w", store, err)
|
||||||
|
}
|
||||||
|
if version != expected {
|
||||||
|
return fmt.Errorf("%s schema version mismatch: expected %d, got %d", store, expected, version)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func printSource(label, path string) {
|
||||||
|
if path == "" {
|
||||||
|
fmt.Printf("%s(not set, skipping)\n", label)
|
||||||
|
} else if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||||
|
fmt.Printf("%s%s (NOT FOUND, skipping)\n", label, path)
|
||||||
|
} else {
|
||||||
|
fmt.Printf("%s%s\n", label, path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func maskPassword(databaseURL string) string {
|
||||||
|
u, err := url.Parse(databaseURL)
|
||||||
|
if err != nil {
|
||||||
|
return databaseURL
|
||||||
|
}
|
||||||
|
if u.User != nil {
|
||||||
|
if _, hasPass := u.User.Password(); hasPass {
|
||||||
|
masked := u.Scheme + "://" + u.User.Username() + ":****@" + u.Host + u.Path
|
||||||
|
if u.RawQuery != "" {
|
||||||
|
masked += "?" + u.RawQuery
|
||||||
|
}
|
||||||
|
return masked
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return u.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func openSQLite(filename string) (*sql.DB, error) {
|
||||||
|
if _, err := os.Stat(filename); os.IsNotExist(err) {
|
||||||
|
return nil, fmt.Errorf("file %s does not exist", filename)
|
||||||
|
}
|
||||||
|
return sql.Open("sqlite3", filename+"?mode=ro")
|
||||||
|
}
|
||||||
|
|
||||||
|
// User import
|
||||||
|
|
||||||
|
func importUsers(sqliteFile string, pgDB *sql.DB) error {
|
||||||
|
sqlDB, err := openSQLite(sqliteFile)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Skipping user import: %s\n", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
defer sqlDB.Close()
|
||||||
|
fmt.Printf("Importing users from %s ...\n", sqliteFile)
|
||||||
|
|
||||||
|
count, err := importTiers(sqlDB, pgDB)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("importing tiers: %w", err)
|
||||||
|
}
|
||||||
|
fmt.Printf(" Imported %d tiers\n", count)
|
||||||
|
|
||||||
|
count, err = importUserRows(sqlDB, pgDB)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("importing users: %w", err)
|
||||||
|
}
|
||||||
|
fmt.Printf(" Imported %d users\n", count)
|
||||||
|
|
||||||
|
count, err = importUserAccess(sqlDB, pgDB)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("importing user access: %w", err)
|
||||||
|
}
|
||||||
|
fmt.Printf(" Imported %d access entries\n", count)
|
||||||
|
|
||||||
|
count, err = importUserTokens(sqlDB, pgDB)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("importing user tokens: %w", err)
|
||||||
|
}
|
||||||
|
fmt.Printf(" Imported %d tokens\n", count)
|
||||||
|
|
||||||
|
count, err = importUserPhones(sqlDB, pgDB)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("importing user phones: %w", err)
|
||||||
|
}
|
||||||
|
fmt.Printf(" Imported %d phone numbers\n", count)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func importTiers(sqlDB, pgDB *sql.DB) (int, error) {
|
||||||
|
rows, err := sqlDB.Query(`SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id FROM tier`)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
tx, err := pgDB.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
stmt, err := tx.Prepare(`INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) ON CONFLICT (id) DO NOTHING`)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer stmt.Close()
|
||||||
|
|
||||||
|
count := 0
|
||||||
|
for rows.Next() {
|
||||||
|
var id, code, name string
|
||||||
|
var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit int64
|
||||||
|
var attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit int64
|
||||||
|
var stripeMonthlyPriceID, stripeYearlyPriceID sql.NullString
|
||||||
|
if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if _, err := stmt.Exec(id, code, name, messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeMonthlyPriceID, stripeYearlyPriceID); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
return count, tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
func importUserRows(sqlDB, pgDB *sql.DB) (int, error) {
|
||||||
|
rows, err := sqlDB.Query(`SELECT id, user, pass, role, prefs, sync_topic, provisioned, stats_messages, stats_emails, stats_calls, stripe_customer_id, stripe_subscription_id, stripe_subscription_status, stripe_subscription_interval, stripe_subscription_paid_until, stripe_subscription_cancel_at, created, deleted, tier_id FROM user`)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
tx, err := pgDB.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
stmt, err := tx.Prepare(`
|
||||||
|
INSERT INTO "user" (id, user_name, pass, role, prefs, sync_topic, provisioned, stats_messages, stats_emails, stats_calls, stripe_customer_id, stripe_subscription_id, stripe_subscription_status, stripe_subscription_interval, stripe_subscription_paid_until, stripe_subscription_cancel_at, created, deleted, tier_id)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19)
|
||||||
|
ON CONFLICT (id) DO NOTHING
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer stmt.Close()
|
||||||
|
|
||||||
|
count := 0
|
||||||
|
for rows.Next() {
|
||||||
|
var id, userName, pass, role, prefs, syncTopic string
|
||||||
|
var provisioned int
|
||||||
|
var statsMessages, statsEmails, statsCalls int64
|
||||||
|
var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripeSubscriptionInterval sql.NullString
|
||||||
|
var stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt sql.NullInt64
|
||||||
|
var created int64
|
||||||
|
var deleted sql.NullInt64
|
||||||
|
var tierID sql.NullString
|
||||||
|
if err := rows.Scan(&id, &userName, &pass, &role, &prefs, &syncTopic, &provisioned, &statsMessages, &statsEmails, &statsCalls, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionInterval, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &created, &deleted, &tierID); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
provisionedBool := provisioned != 0
|
||||||
|
if _, err := stmt.Exec(id, userName, pass, role, prefs, syncTopic, provisionedBool, statsMessages, statsEmails, statsCalls, stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripeSubscriptionInterval, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, created, deleted, tierID); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
return count, tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
func importUserAccess(sqlDB, pgDB *sql.DB) (int, error) {
|
||||||
|
rows, err := sqlDB.Query(`SELECT a.user_id, a.topic, a.read, a.write, a.owner_user_id, a.provisioned FROM user_access a JOIN user u ON u.id = a.user_id`)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
tx, err := pgDB.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
stmt, err := tx.Prepare(`INSERT INTO user_access (user_id, topic, read, write, owner_user_id, provisioned) VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (user_id, topic) DO NOTHING`)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer stmt.Close()
|
||||||
|
|
||||||
|
count := 0
|
||||||
|
for rows.Next() {
|
||||||
|
var userID, topic string
|
||||||
|
var read, write, provisioned int
|
||||||
|
var ownerUserID sql.NullString
|
||||||
|
if err := rows.Scan(&userID, &topic, &read, &write, &ownerUserID, &provisioned); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
readBool := read != 0
|
||||||
|
writeBool := write != 0
|
||||||
|
provisionedBool := provisioned != 0
|
||||||
|
if _, err := stmt.Exec(userID, topic, readBool, writeBool, ownerUserID, provisionedBool); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
return count, tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
func importUserTokens(sqlDB, pgDB *sql.DB) (int, error) {
|
||||||
|
rows, err := sqlDB.Query(`SELECT t.user_id, t.token, t.label, t.last_access, t.last_origin, t.expires, t.provisioned FROM user_token t JOIN user u ON u.id = t.user_id`)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
tx, err := pgDB.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
stmt, err := tx.Prepare(`INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires, provisioned) VALUES ($1, $2, $3, $4, $5, $6, $7) ON CONFLICT (user_id, token) DO NOTHING`)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer stmt.Close()
|
||||||
|
|
||||||
|
count := 0
|
||||||
|
for rows.Next() {
|
||||||
|
var userID, token, label, lastOrigin string
|
||||||
|
var lastAccess, expires int64
|
||||||
|
var provisioned int
|
||||||
|
if err := rows.Scan(&userID, &token, &label, &lastAccess, &lastOrigin, &expires, &provisioned); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
provisionedBool := provisioned != 0
|
||||||
|
if _, err := stmt.Exec(userID, token, label, lastAccess, lastOrigin, expires, provisionedBool); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
return count, tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
func importUserPhones(sqlDB, pgDB *sql.DB) (int, error) {
|
||||||
|
rows, err := sqlDB.Query(`SELECT p.user_id, p.phone_number FROM user_phone p JOIN user u ON u.id = p.user_id`)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
tx, err := pgDB.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
stmt, err := tx.Prepare(`INSERT INTO user_phone (user_id, phone_number) VALUES ($1, $2) ON CONFLICT (user_id, phone_number) DO NOTHING`)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer stmt.Close()
|
||||||
|
|
||||||
|
count := 0
|
||||||
|
for rows.Next() {
|
||||||
|
var userID, phoneNumber string
|
||||||
|
if err := rows.Scan(&userID, &phoneNumber); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if _, err := stmt.Exec(userID, phoneNumber); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
return count, tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Message import
|
||||||
|
|
||||||
|
func importMessages(sqliteFile string, pgDB *sql.DB) error {
|
||||||
|
sqlDB, err := openSQLite(sqliteFile)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Skipping message import: %s\n", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
defer sqlDB.Close()
|
||||||
|
fmt.Printf("Importing messages from %s ...\n", sqliteFile)
|
||||||
|
|
||||||
|
rows, err := sqlDB.Query(`SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user, content_type, encoding, published FROM messages`)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("querying messages: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
if _, err := pgDB.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS idx_message_mid_unique ON message (mid)`); err != nil {
|
||||||
|
return fmt.Errorf("creating unique index on mid: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
insertQuery := `INSERT INTO message (mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user_id, content_type, encoding, published) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24) ON CONFLICT (mid) DO NOTHING`
|
||||||
|
|
||||||
|
count := 0
|
||||||
|
batchCount := 0
|
||||||
|
tx, err := pgDB.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
stmt, err := tx.Prepare(insertQuery)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer stmt.Close()
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
var mid, sequenceID, event, topic, message, title, tags, click, icon, actions string
|
||||||
|
var attachmentName, attachmentType, attachmentURL, sender, userID, contentType, encoding string
|
||||||
|
var msgTime, expires, attachmentExpires int64
|
||||||
|
var priority int
|
||||||
|
var attachmentSize int64
|
||||||
|
var attachmentDeleted, published int
|
||||||
|
if err := rows.Scan(&mid, &sequenceID, &msgTime, &event, &expires, &topic, &message, &title, &priority, &tags, &click, &icon, &actions, &attachmentName, &attachmentType, &attachmentSize, &attachmentExpires, &attachmentURL, &attachmentDeleted, &sender, &userID, &contentType, &encoding, &published); err != nil {
|
||||||
|
return fmt.Errorf("scanning message: %w", err)
|
||||||
|
}
|
||||||
|
mid = toUTF8(mid)
|
||||||
|
sequenceID = toUTF8(sequenceID)
|
||||||
|
event = toUTF8(event)
|
||||||
|
topic = toUTF8(topic)
|
||||||
|
message = toUTF8(message)
|
||||||
|
title = toUTF8(title)
|
||||||
|
tags = toUTF8(tags)
|
||||||
|
click = toUTF8(click)
|
||||||
|
icon = toUTF8(icon)
|
||||||
|
actions = toUTF8(actions)
|
||||||
|
attachmentName = toUTF8(attachmentName)
|
||||||
|
attachmentType = toUTF8(attachmentType)
|
||||||
|
attachmentURL = toUTF8(attachmentURL)
|
||||||
|
sender = toUTF8(sender)
|
||||||
|
userID = toUTF8(userID)
|
||||||
|
contentType = toUTF8(contentType)
|
||||||
|
encoding = toUTF8(encoding)
|
||||||
|
attachmentDeletedBool := attachmentDeleted != 0
|
||||||
|
publishedBool := published != 0
|
||||||
|
if _, err := stmt.Exec(mid, sequenceID, msgTime, event, expires, topic, message, title, priority, tags, click, icon, actions, attachmentName, attachmentType, attachmentSize, attachmentExpires, attachmentURL, attachmentDeletedBool, sender, userID, contentType, encoding, publishedBool); err != nil {
|
||||||
|
return fmt.Errorf("inserting message: %w", err)
|
||||||
|
}
|
||||||
|
count++
|
||||||
|
batchCount++
|
||||||
|
if batchCount >= batchSize {
|
||||||
|
stmt.Close()
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return fmt.Errorf("committing message batch: %w", err)
|
||||||
|
}
|
||||||
|
fmt.Printf(" ... %d messages\n", count)
|
||||||
|
tx, err = pgDB.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
stmt, err = tx.Prepare(insertQuery)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
batchCount = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if batchCount > 0 {
|
||||||
|
stmt.Close()
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return fmt.Errorf("committing final message batch: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fmt.Printf(" Imported %d messages\n", count)
|
||||||
|
|
||||||
|
var statsValue int64
|
||||||
|
err = sqlDB.QueryRow(`SELECT value FROM stats WHERE key = 'messages'`).Scan(&statsValue)
|
||||||
|
if err == nil {
|
||||||
|
if _, err := pgDB.Exec(`UPDATE message_stats SET value = $1 WHERE key = 'messages'`, statsValue); err != nil {
|
||||||
|
return fmt.Errorf("updating message stats: %w", err)
|
||||||
|
}
|
||||||
|
fmt.Printf(" Updated message stats (count: %d)\n", statsValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Web push import
|
||||||
|
|
||||||
|
func importWebPush(sqliteFile string, pgDB *sql.DB) error {
|
||||||
|
sqlDB, err := openSQLite(sqliteFile)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Skipping web push import: %s\n", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
defer sqlDB.Close()
|
||||||
|
fmt.Printf("Importing web push subscriptions from %s ...\n", sqliteFile)
|
||||||
|
|
||||||
|
rows, err := sqlDB.Query(`SELECT id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at FROM subscription`)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("querying subscriptions: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
tx, err := pgDB.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
stmt, err := tx.Prepare(`INSERT INTO webpush_subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) ON CONFLICT (id) DO NOTHING`)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer stmt.Close()
|
||||||
|
|
||||||
|
count := 0
|
||||||
|
for rows.Next() {
|
||||||
|
var id, endpoint, keyAuth, keyP256dh, userID, subscriberIP string
|
||||||
|
var updatedAt, warnedAt int64
|
||||||
|
if err := rows.Scan(&id, &endpoint, &keyAuth, &keyP256dh, &userID, &subscriberIP, &updatedAt, &warnedAt); err != nil {
|
||||||
|
return fmt.Errorf("scanning subscription: %w", err)
|
||||||
|
}
|
||||||
|
if _, err := stmt.Exec(id, endpoint, keyAuth, keyP256dh, userID, subscriberIP, updatedAt, warnedAt); err != nil {
|
||||||
|
return fmt.Errorf("inserting subscription: %w", err)
|
||||||
|
}
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
stmt.Close()
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return fmt.Errorf("committing subscriptions: %w", err)
|
||||||
|
}
|
||||||
|
fmt.Printf(" Imported %d subscriptions\n", count)
|
||||||
|
|
||||||
|
topicRows, err := sqlDB.Query(`SELECT subscription_id, topic FROM subscription_topic`)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("querying subscription topics: %w", err)
|
||||||
|
}
|
||||||
|
defer topicRows.Close()
|
||||||
|
|
||||||
|
tx, err = pgDB.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
stmt, err = tx.Prepare(`INSERT INTO webpush_subscription_topic (subscription_id, topic) VALUES ($1, $2) ON CONFLICT (subscription_id, topic) DO NOTHING`)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer stmt.Close()
|
||||||
|
|
||||||
|
topicCount := 0
|
||||||
|
for topicRows.Next() {
|
||||||
|
var subscriptionID, topic string
|
||||||
|
if err := topicRows.Scan(&subscriptionID, &topic); err != nil {
|
||||||
|
return fmt.Errorf("scanning subscription topic: %w", err)
|
||||||
|
}
|
||||||
|
if _, err := stmt.Exec(subscriptionID, topic); err != nil {
|
||||||
|
return fmt.Errorf("inserting subscription topic: %w", err)
|
||||||
|
}
|
||||||
|
topicCount++
|
||||||
|
}
|
||||||
|
stmt.Close()
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return fmt.Errorf("committing subscription topics: %w", err)
|
||||||
|
}
|
||||||
|
fmt.Printf(" Imported %d subscription topics\n", topicCount)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// toUTF8 sanitizes s for PostgreSQL text columns by replacing every run of
// invalid UTF-8 bytes with the Unicode replacement character (U+FFFD).
// Valid strings are returned unchanged.
func toUTF8(s string) string {
	const replacement = "\uFFFD"
	return strings.ToValidUTF8(s, replacement)
}
|
||||||
|
|
||||||
|
// Verification
|
||||||
|
|
||||||
|
func verifyUsers(sqliteFile string, pgDB *sql.DB, failed *bool) error {
|
||||||
|
sqlDB, err := openSQLite(sqliteFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
defer sqlDB.Close()
|
||||||
|
|
||||||
|
verifyCount(sqlDB, pgDB, "tier", `SELECT COUNT(*) FROM tier`, `SELECT COUNT(*) FROM tier`, failed)
|
||||||
|
verifyContent(sqlDB, pgDB, "tier",
|
||||||
|
`SELECT id, code, name FROM tier ORDER BY id`,
|
||||||
|
`SELECT id, code, name FROM tier ORDER BY id COLLATE "C"`,
|
||||||
|
failed)
|
||||||
|
|
||||||
|
verifyCount(sqlDB, pgDB, "user", `SELECT COUNT(*) FROM user`, `SELECT COUNT(*) FROM "user"`, failed)
|
||||||
|
verifyContent(sqlDB, pgDB, "user",
|
||||||
|
`SELECT id, user, role, sync_topic FROM user ORDER BY id`,
|
||||||
|
`SELECT id, user_name, role, sync_topic FROM "user" ORDER BY id COLLATE "C"`,
|
||||||
|
failed)
|
||||||
|
|
||||||
|
verifyCount(sqlDB, pgDB, "user_access", `SELECT COUNT(*) FROM user_access a JOIN user u ON u.id = a.user_id`, `SELECT COUNT(*) FROM user_access`, failed)
|
||||||
|
verifyContent(sqlDB, pgDB, "user_access",
|
||||||
|
`SELECT a.user_id, a.topic FROM user_access a JOIN user u ON u.id = a.user_id ORDER BY a.user_id, a.topic`,
|
||||||
|
`SELECT user_id, topic FROM user_access ORDER BY user_id COLLATE "C", topic COLLATE "C"`,
|
||||||
|
failed)
|
||||||
|
|
||||||
|
verifyCount(sqlDB, pgDB, "user_token", `SELECT COUNT(*) FROM user_token t JOIN user u ON u.id = t.user_id`, `SELECT COUNT(*) FROM user_token`, failed)
|
||||||
|
verifyContent(sqlDB, pgDB, "user_token",
|
||||||
|
`SELECT t.user_id, t.token, t.label FROM user_token t JOIN user u ON u.id = t.user_id ORDER BY t.user_id, t.token`,
|
||||||
|
`SELECT user_id, token, label FROM user_token ORDER BY user_id COLLATE "C", token COLLATE "C"`,
|
||||||
|
failed)
|
||||||
|
|
||||||
|
verifyCount(sqlDB, pgDB, "user_phone", `SELECT COUNT(*) FROM user_phone p JOIN user u ON u.id = p.user_id`, `SELECT COUNT(*) FROM user_phone`, failed)
|
||||||
|
verifyContent(sqlDB, pgDB, "user_phone",
|
||||||
|
`SELECT p.user_id, p.phone_number FROM user_phone p JOIN user u ON u.id = p.user_id ORDER BY p.user_id, p.phone_number`,
|
||||||
|
`SELECT user_id, phone_number FROM user_phone ORDER BY user_id COLLATE "C", phone_number COLLATE "C"`,
|
||||||
|
failed)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyMessages(sqliteFile string, pgDB *sql.DB, failed *bool) error {
|
||||||
|
sqlDB, err := openSQLite(sqliteFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
defer sqlDB.Close()
|
||||||
|
|
||||||
|
verifyCount(sqlDB, pgDB, "messages", `SELECT COUNT(*) FROM messages`, `SELECT COUNT(*) FROM message`, failed)
|
||||||
|
verifySampledMessages(sqlDB, pgDB, failed)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyWebPush(sqliteFile string, pgDB *sql.DB, failed *bool) error {
|
||||||
|
sqlDB, err := openSQLite(sqliteFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
defer sqlDB.Close()
|
||||||
|
|
||||||
|
verifyCount(sqlDB, pgDB, "subscription", `SELECT COUNT(*) FROM subscription`, `SELECT COUNT(*) FROM webpush_subscription`, failed)
|
||||||
|
verifyContent(sqlDB, pgDB, "subscription",
|
||||||
|
`SELECT id, endpoint, key_auth, key_p256dh, user_id FROM subscription ORDER BY id`,
|
||||||
|
`SELECT id, endpoint, key_auth, key_p256dh, user_id FROM webpush_subscription ORDER BY id COLLATE "C"`,
|
||||||
|
failed)
|
||||||
|
|
||||||
|
verifyCount(sqlDB, pgDB, "subscription_topic", `SELECT COUNT(*) FROM subscription_topic`, `SELECT COUNT(*) FROM webpush_subscription_topic`, failed)
|
||||||
|
verifyContent(sqlDB, pgDB, "subscription_topic",
|
||||||
|
`SELECT subscription_id, topic FROM subscription_topic ORDER BY subscription_id, topic`,
|
||||||
|
`SELECT subscription_id, topic FROM webpush_subscription_topic ORDER BY subscription_id COLLATE "C", topic COLLATE "C"`,
|
||||||
|
failed)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// verifyCount runs a COUNT query against both databases and reports whether
// the results agree. Read errors and mismatches are printed and set *failed.
func verifyCount(sqlDB, pgDB *sql.DB, table, sqliteQuery, pgQuery string, failed *bool) {
	var fromSQLite int64
	if err := sqlDB.QueryRow(sqliteQuery).Scan(&fromSQLite); err != nil {
		fmt.Printf(" %-25s count ERROR reading SQLite: %s\n", table, err)
		*failed = true
		return
	}
	var fromPostgres int64
	if err := pgDB.QueryRow(pgQuery).Scan(&fromPostgres); err != nil {
		fmt.Printf(" %-25s count ERROR reading PostgreSQL: %s\n", table, err)
		*failed = true
		return
	}
	if fromSQLite != fromPostgres {
		fmt.Printf(" %-25s count MISMATCH: SQLite=%d, PostgreSQL=%d\n", table, fromSQLite, fromPostgres)
		*failed = true
		return
	}
	fmt.Printf(" %-25s count OK (%d rows)\n", table, fromPostgres)
}
|
||||||
|
|
||||||
|
func verifyContent(sqlDB, pgDB *sql.DB, table, sqliteQuery, pgQuery string, failed *bool) {
|
||||||
|
sqliteRows, err := sqlDB.Query(sqliteQuery)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf(" %-25s content ERROR reading SQLite: %s\n", table, err)
|
||||||
|
*failed = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer sqliteRows.Close()
|
||||||
|
|
||||||
|
pgRows, err := pgDB.Query(pgQuery)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf(" %-25s content ERROR reading PostgreSQL: %s\n", table, err)
|
||||||
|
*failed = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer pgRows.Close()
|
||||||
|
|
||||||
|
cols, err := sqliteRows.Columns()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf(" %-25s content ERROR reading columns: %s\n", table, err)
|
||||||
|
*failed = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
numCols := len(cols)
|
||||||
|
|
||||||
|
rowNum := 0
|
||||||
|
mismatches := 0
|
||||||
|
for sqliteRows.Next() {
|
||||||
|
rowNum++
|
||||||
|
if !pgRows.Next() {
|
||||||
|
fmt.Printf(" %-25s content MISMATCH: PostgreSQL has fewer rows (at row %d)\n", table, rowNum)
|
||||||
|
*failed = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sqliteVals := makeStringSlice(numCols)
|
||||||
|
pgVals := makeStringSlice(numCols)
|
||||||
|
if err := sqliteRows.Scan(sqliteVals...); err != nil {
|
||||||
|
fmt.Printf(" %-25s content ERROR scanning SQLite row %d: %s\n", table, rowNum, err)
|
||||||
|
*failed = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := pgRows.Scan(pgVals...); err != nil {
|
||||||
|
fmt.Printf(" %-25s content ERROR scanning PostgreSQL row %d: %s\n", table, rowNum, err)
|
||||||
|
*failed = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for i := 0; i < numCols; i++ {
|
||||||
|
sv := *(sqliteVals[i].(*sql.NullString))
|
||||||
|
pv := *(pgVals[i].(*sql.NullString))
|
||||||
|
if sv != pv {
|
||||||
|
mismatches++
|
||||||
|
if mismatches <= 3 {
|
||||||
|
fmt.Printf(" %-25s content MISMATCH at row %d, col %s: SQLite=%q, PostgreSQL=%q\n", table, rowNum, cols[i], sv.String, pv.String)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if pgRows.Next() {
|
||||||
|
fmt.Printf(" %-25s content MISMATCH: PostgreSQL has more rows than SQLite\n", table)
|
||||||
|
*failed = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if mismatches > 0 {
|
||||||
|
if mismatches > 3 {
|
||||||
|
fmt.Printf(" %-25s content ... and %d more mismatches\n", table, mismatches-3)
|
||||||
|
}
|
||||||
|
*failed = true
|
||||||
|
} else {
|
||||||
|
fmt.Printf(" %-25s content OK\n", table)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifySampledMessages(sqlDB, pgDB *sql.DB, failed *bool) {
|
||||||
|
rows, err := sqlDB.Query(`SELECT mid, topic, time, message, title, tags, priority FROM messages ORDER BY mid`)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf(" %-25s content ERROR reading SQLite: %s\n", "messages (sampled)", err)
|
||||||
|
*failed = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
rowNum := 0
|
||||||
|
checked := 0
|
||||||
|
mismatches := 0
|
||||||
|
for rows.Next() {
|
||||||
|
rowNum++
|
||||||
|
var mid, topic, message, title, tags string
|
||||||
|
var msgTime int64
|
||||||
|
var priority int
|
||||||
|
if err := rows.Scan(&mid, &topic, &msgTime, &message, &title, &tags, &priority); err != nil {
|
||||||
|
fmt.Printf(" %-25s content ERROR scanning SQLite row %d: %s\n", "messages (sampled)", rowNum, err)
|
||||||
|
*failed = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if rowNum%100 != 1 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
checked++
|
||||||
|
var pgTopic, pgMessage, pgTitle, pgTags string
|
||||||
|
var pgTime int64
|
||||||
|
var pgPriority int
|
||||||
|
err := pgDB.QueryRow(`SELECT topic, time, message, title, tags, priority FROM message WHERE mid = $1`, mid).
|
||||||
|
Scan(&pgTopic, &pgTime, &pgMessage, &pgTitle, &pgTags, &pgPriority)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
mismatches++
|
||||||
|
if mismatches <= 3 {
|
||||||
|
fmt.Printf(" %-25s content MISMATCH: mid=%s not found in PostgreSQL\n", "messages (sampled)", mid)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
} else if err != nil {
|
||||||
|
fmt.Printf(" %-25s content ERROR querying PostgreSQL for mid=%s: %s\n", "messages (sampled)", mid, err)
|
||||||
|
*failed = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
topic = toUTF8(topic)
|
||||||
|
message = toUTF8(message)
|
||||||
|
title = toUTF8(title)
|
||||||
|
tags = toUTF8(tags)
|
||||||
|
if topic != pgTopic || msgTime != pgTime || message != pgMessage || title != pgTitle || tags != pgTags || priority != pgPriority {
|
||||||
|
mismatches++
|
||||||
|
if mismatches <= 3 {
|
||||||
|
fmt.Printf(" %-25s content MISMATCH at mid=%s\n", "messages (sampled)", mid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if mismatches > 0 {
|
||||||
|
if mismatches > 3 {
|
||||||
|
fmt.Printf(" %-25s content ... and %d more mismatches\n", "messages (sampled)", mismatches-3)
|
||||||
|
}
|
||||||
|
*failed = true
|
||||||
|
} else {
|
||||||
|
fmt.Printf(" %-25s content OK (%d samples checked)\n", "messages (sampled)", checked)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeStringSlice(n int) []any {
|
||||||
|
vals := make([]any, n)
|
||||||
|
for i := range vals {
|
||||||
|
vals[i] = &sql.NullString{}
|
||||||
|
}
|
||||||
|
return vals
|
||||||
|
}
|
||||||
1327
user/manager.go
1327
user/manager.go
File diff suppressed because it is too large
Load Diff
270
user/manager_postgres.go
Normal file
270
user/manager_postgres.go
Normal file
@@ -0,0 +1,270 @@
|
|||||||
|
package user
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PostgreSQL queries
//
// These are the PostgreSQL variants of the user manager queries: positional
// $n placeholders, the reserved table name quoted as "user", and PostgreSQL
// upsert syntax (ON CONFLICT ... DO UPDATE / DO NOTHING).
const (
	// User queries
	//
	// The SELECT-user queries all return the same column list: the full user
	// row followed by the (possibly NULL) columns of its tier, LEFT JOINed
	// via tier_id.
	postgresSelectUserByIDQuery = `
		SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
		FROM "user" u
		LEFT JOIN tier t on t.id = u.tier_id
		WHERE u.id = $1
	`
	postgresSelectUserByNameQuery = `
		SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
		FROM "user" u
		LEFT JOIN tier t on t.id = u.tier_id
		WHERE user_name = $1
	`
	// $2 is the current time; a stored expiry of 0 means "never expires".
	postgresSelectUserByTokenQuery = `
		SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
		FROM "user" u
		JOIN user_token tk on u.id = tk.user_id
		LEFT JOIN tier t on t.id = u.tier_id
		WHERE tk.token = $1 AND (tk.expires = 0 OR tk.expires >= $2)
	`
	postgresSelectUserByStripeCustomerIDQuery = `
		SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
		FROM "user" u
		LEFT JOIN tier t on t.id = u.tier_id
		WHERE u.stripe_customer_id = $1
	`
	// Admins first, the anonymous user last, everyone else alphabetical in
	// between.
	postgresSelectUsernamesQuery = `
		SELECT user_name
		FROM "user"
		ORDER BY
			CASE role
				WHEN 'admin' THEN 1
				WHEN 'anonymous' THEN 3
				ELSE 2
			END, user_name
	`
	postgresSelectUserCountQuery          = `SELECT COUNT(*) FROM "user"`
	postgresSelectUserIDFromUsernameQuery = `SELECT id FROM "user" WHERE user_name = $1`
	postgresInsertUserQuery               = `INSERT INTO "user" (id, user_name, pass, role, sync_topic, provisioned, created) VALUES ($1, $2, $3, $4, $5, $6, $7)`
	postgresUpdateUserPassQuery           = `UPDATE "user" SET pass = $1 WHERE user_name = $2`
	postgresUpdateUserRoleQuery           = `UPDATE "user" SET role = $1 WHERE user_name = $2`
	postgresUpdateUserProvisionedQuery    = `UPDATE "user" SET provisioned = $1 WHERE user_name = $2`
	postgresUpdateUserPrefsQuery          = `UPDATE "user" SET prefs = $1 WHERE id = $2`
	postgresUpdateUserStatsQuery          = `UPDATE "user" SET stats_messages = $1, stats_emails = $2, stats_calls = $3 WHERE id = $4`
	postgresUpdateUserStatsResetAllQuery  = `UPDATE "user" SET stats_messages = 0, stats_emails = 0, stats_calls = 0`
	postgresUpdateUserTierQuery           = `UPDATE "user" SET tier_id = (SELECT id FROM tier WHERE code = $1) WHERE user_name = $2`
	postgresUpdateUserDeletedQuery        = `UPDATE "user" SET deleted = $1 WHERE id = $2`
	postgresDeleteUserQuery               = `DELETE FROM "user" WHERE user_name = $1`
	postgresDeleteUserTierQuery           = `UPDATE "user" SET tier_id = null WHERE user_name = $1`
	// Deletes users whose deletion marker is older than the given cutoff.
	postgresDeleteUsersMarkedQuery = `DELETE FROM "user" WHERE deleted < $1`

	// Access queries
	//
	// Stored topics are SQL LIKE patterns; lookups match the concrete topic
	// against the pattern with ESCAPE '\'. The ORDER BY clauses put the most
	// specific (longest) pattern and the most permissive grant first.
	// NOTE(review): $1 and $2 are two usernames checked together (presumably
	// the user and the everyone-user) — confirm against callers.
	postgresSelectTopicPermsQuery = `
		SELECT read, write
		FROM user_access a
		JOIN "user" u ON u.id = a.user_id
		WHERE (u.user_name = $1 OR u.user_name = $2) AND $3 LIKE a.topic ESCAPE '\'
		ORDER BY u.user_name DESC, LENGTH(a.topic) DESC, CASE WHEN a.write THEN 1 ELSE 0 END DESC
	`
	postgresSelectUserAllAccessQuery = `
		SELECT user_id, topic, read, write, provisioned
		FROM user_access
		ORDER BY LENGTH(topic) DESC, CASE WHEN write THEN 1 ELSE 0 END DESC, CASE WHEN read THEN 1 ELSE 0 END DESC, topic
	`
	postgresSelectUserAccessQuery = `
		SELECT topic, read, write, provisioned
		FROM user_access
		WHERE user_id = (SELECT id FROM "user" WHERE user_name = $1)
		ORDER BY LENGTH(topic) DESC, CASE WHEN write THEN 1 ELSE 0 END DESC, CASE WHEN read THEN 1 ELSE 0 END DESC, topic
	`
	// Reservations are rows where the user owns the topic (user_id =
	// owner_user_id); $1 is the everyone-user, $2 the owner.
	postgresSelectUserReservationsQuery = `
		SELECT a_user.topic, a_user.read, a_user.write, a_everyone.read AS everyone_read, a_everyone.write AS everyone_write
		FROM user_access a_user
		LEFT JOIN user_access a_everyone ON a_user.topic = a_everyone.topic AND a_everyone.user_id = (SELECT id FROM "user" WHERE user_name = $1)
		WHERE a_user.user_id = a_user.owner_user_id
		AND a_user.owner_user_id = (SELECT id FROM "user" WHERE user_name = $2)
		ORDER BY a_user.topic
	`
	postgresSelectUserReservationsCountQuery = `
		SELECT COUNT(*)
		FROM user_access
		WHERE user_id = owner_user_id
		AND owner_user_id = (SELECT id FROM "user" WHERE user_name = $1)
	`
	postgresSelectUserReservationsOwnerQuery = `
		SELECT owner_user_id
		FROM user_access
		WHERE topic = $1
		AND user_id = owner_user_id
	`
	postgresSelectUserHasReservationQuery = `
		SELECT COUNT(*)
		FROM user_access
		WHERE user_id = owner_user_id
		AND owner_user_id = (SELECT id FROM "user" WHERE user_name = $1)
		AND topic = $2
	`
	postgresSelectOtherAccessCountQuery = `
		SELECT COUNT(*)
		FROM user_access
		WHERE (topic = $1 OR $2 LIKE topic ESCAPE '\')
		AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM "user" WHERE user_name = $3))
	`
	// NOTE(review): $5 and $6 presumably carry the same owner username (empty
	// string meaning "no owner") — confirm against callers.
	postgresUpsertUserAccessQuery = `
		INSERT INTO user_access (user_id, topic, read, write, owner_user_id, provisioned)
		VALUES (
			(SELECT id FROM "user" WHERE user_name = $1),
			$2,
			$3,
			$4,
			CASE WHEN $5 = '' THEN NULL ELSE (SELECT id FROM "user" WHERE user_name = $6) END,
			$7
		)
		ON CONFLICT (user_id, topic)
		DO UPDATE SET read=excluded.read, write=excluded.write, owner_user_id=excluded.owner_user_id, provisioned=excluded.provisioned
	`
	postgresDeleteUserAccessQuery = `
		DELETE FROM user_access
		WHERE user_id = (SELECT id FROM "user" WHERE user_name = $1)
		OR owner_user_id = (SELECT id FROM "user" WHERE user_name = $2)
	`
	postgresDeleteUserAccessProvisionedQuery = `DELETE FROM user_access WHERE provisioned = true`
	postgresDeleteTopicAccessQuery           = `
		DELETE FROM user_access
		WHERE (user_id = (SELECT id FROM "user" WHERE user_name = $1) OR owner_user_id = (SELECT id FROM "user" WHERE user_name = $2))
		AND topic = $3
	`
	postgresDeleteAllAccessQuery = `DELETE FROM user_access`

	// Token queries
	postgresSelectTokenQuery                = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = $1 AND token = $2`
	postgresSelectTokensQuery               = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = $1`
	postgresSelectTokenCountQuery           = `SELECT COUNT(*) FROM user_token WHERE user_id = $1`
	postgresSelectAllProvisionedTokensQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE provisioned = true`
	postgresUpsertTokenQuery                = `
		INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires, provisioned)
		VALUES ($1, $2, $3, $4, $5, $6, $7)
		ON CONFLICT (user_id, token)
		DO UPDATE SET label = excluded.label, expires = excluded.expires, provisioned = excluded.provisioned
	`
	postgresUpdateTokenQuery           = `UPDATE user_token SET label = $1, expires = $2 WHERE user_id = $3 AND token = $4`
	postgresUpdateTokenLastAccessQuery = `UPDATE user_token SET last_access = $1, last_origin = $2 WHERE token = $3`
	postgresDeleteTokenQuery           = `DELETE FROM user_token WHERE user_id = $1 AND token = $2`
	postgresDeleteProvisionedTokenQuery = `DELETE FROM user_token WHERE token = $1`
	postgresDeleteAllTokenQuery         = `DELETE FROM user_token WHERE user_id = $1`
	// expires = 0 means "never expires" and is deliberately excluded.
	postgresDeleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < $1`
	// Keeps only the $3 tokens with the latest expiry for the user.
	// NOTE(review): $1 and $2 are presumably the same user ID — confirm.
	postgresDeleteExcessTokensQuery = `
		DELETE FROM user_token
		WHERE user_id = $1
		AND (user_id, token) NOT IN (
			SELECT user_id, token
			FROM user_token
			WHERE user_id = $2
			ORDER BY expires DESC
			LIMIT $3
		)
	`

	// Tier queries
	postgresInsertTierQuery = `
		INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id)
		VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
	`
	postgresUpdateTierQuery = `
		UPDATE tier
		SET name = $1, messages_limit = $2, messages_expiry_duration = $3, emails_limit = $4, calls_limit = $5, reservations_limit = $6, attachment_file_size_limit = $7, attachment_total_size_limit = $8, attachment_expiry_duration = $9, attachment_bandwidth_limit = $10, stripe_monthly_price_id = $11, stripe_yearly_price_id = $12
		WHERE code = $13
	`
	postgresSelectTiersQuery = `
		SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
		FROM tier
	`
	postgresSelectTierByCodeQuery = `
		SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
		FROM tier
		WHERE code = $1
	`
	postgresSelectTierByPriceIDQuery = `
		SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
		FROM tier
		WHERE (stripe_monthly_price_id = $1 OR stripe_yearly_price_id = $2)
	`
	postgresDeleteTierQuery = `DELETE FROM tier WHERE code = $1`

	// Phone queries
	postgresSelectPhoneNumbersQuery = `SELECT phone_number FROM user_phone WHERE user_id = $1`
	postgresInsertPhoneNumberQuery  = `INSERT INTO user_phone (user_id, phone_number) VALUES ($1, $2)`
	postgresDeletePhoneNumberQuery  = `DELETE FROM user_phone WHERE user_id = $1 AND phone_number = $2`

	// Billing queries
	postgresUpdateBillingQuery = `
		UPDATE "user"
		SET stripe_customer_id = $1, stripe_subscription_id = $2, stripe_subscription_status = $3, stripe_subscription_interval = $4, stripe_subscription_paid_until = $5, stripe_subscription_cancel_at = $6
		WHERE user_name = $7
	`
)
|
||||||
|
|
||||||
|
// postgresQueries wires the PostgreSQL query constants above into the
// backend-agnostic queries struct consumed by newManager.
var postgresQueries = queries{
	selectUserByID:               postgresSelectUserByIDQuery,
	selectUserByName:             postgresSelectUserByNameQuery,
	selectUserByToken:            postgresSelectUserByTokenQuery,
	selectUserByStripeCustomerID: postgresSelectUserByStripeCustomerIDQuery,
	selectUsernames:              postgresSelectUsernamesQuery,
	selectUserCount:              postgresSelectUserCountQuery,
	selectUserIDFromUsername:     postgresSelectUserIDFromUsernameQuery,
	insertUser:                   postgresInsertUserQuery,
	updateUserPass:               postgresUpdateUserPassQuery,
	updateUserRole:               postgresUpdateUserRoleQuery,
	updateUserProvisioned:        postgresUpdateUserProvisionedQuery,
	updateUserPrefs:              postgresUpdateUserPrefsQuery,
	updateUserStats:              postgresUpdateUserStatsQuery,
	updateUserStatsResetAll:      postgresUpdateUserStatsResetAllQuery,
	updateUserTier:               postgresUpdateUserTierQuery,
	updateUserDeleted:            postgresUpdateUserDeletedQuery,
	deleteUser:                   postgresDeleteUserQuery,
	deleteUserTier:               postgresDeleteUserTierQuery,
	deleteUsersMarked:            postgresDeleteUsersMarkedQuery,
	selectTopicPerms:             postgresSelectTopicPermsQuery,
	selectUserAllAccess:          postgresSelectUserAllAccessQuery,
	selectUserAccess:             postgresSelectUserAccessQuery,
	selectUserReservations:       postgresSelectUserReservationsQuery,
	selectUserReservationsCount:  postgresSelectUserReservationsCountQuery,
	selectUserReservationsOwner:  postgresSelectUserReservationsOwnerQuery,
	selectUserHasReservation:     postgresSelectUserHasReservationQuery,
	selectOtherAccessCount:       postgresSelectOtherAccessCountQuery,
	upsertUserAccess:             postgresUpsertUserAccessQuery,
	deleteUserAccess:             postgresDeleteUserAccessQuery,
	deleteUserAccessProvisioned:  postgresDeleteUserAccessProvisionedQuery,
	deleteTopicAccess:            postgresDeleteTopicAccessQuery,
	deleteAllAccess:              postgresDeleteAllAccessQuery,
	selectToken:                  postgresSelectTokenQuery,
	selectTokens:                 postgresSelectTokensQuery,
	selectTokenCount:             postgresSelectTokenCountQuery,
	selectAllProvisionedTokens:   postgresSelectAllProvisionedTokensQuery,
	upsertToken:                  postgresUpsertTokenQuery,
	updateToken:                  postgresUpdateTokenQuery,
	updateTokenLastAccess:        postgresUpdateTokenLastAccessQuery,
	deleteToken:                  postgresDeleteTokenQuery,
	deleteProvisionedToken:       postgresDeleteProvisionedTokenQuery,
	deleteAllToken:               postgresDeleteAllTokenQuery,
	deleteExpiredTokens:          postgresDeleteExpiredTokensQuery,
	deleteExcessTokens:           postgresDeleteExcessTokensQuery,
	insertTier:                   postgresInsertTierQuery,
	selectTiers:                  postgresSelectTiersQuery,
	selectTierByCode:             postgresSelectTierByCodeQuery,
	selectTierByPriceID:          postgresSelectTierByPriceIDQuery,
	updateTier:                   postgresUpdateTierQuery,
	deleteTier:                   postgresDeleteTierQuery,
	selectPhoneNumbers:           postgresSelectPhoneNumbersQuery,
	insertPhoneNumber:            postgresInsertPhoneNumberQuery,
	deletePhoneNumber:            postgresDeletePhoneNumberQuery,
	updateBilling:                postgresUpdateBillingQuery,
}
|
||||||
|
|
||||||
|
// NewPostgresManager creates a new Manager backed by a PostgreSQL database
|
||||||
|
func NewPostgresManager(db *sql.DB, config *Config) (*Manager, error) {
|
||||||
|
if err := setupPostgres(db); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return newManager(db, postgresQueries, config)
|
||||||
|
}
|
||||||
@@ -85,13 +85,13 @@ const (
|
|||||||
// Schema table management queries for Postgres
|
// Schema table management queries for Postgres
|
||||||
const (
|
const (
|
||||||
postgresCurrentSchemaVersion = 6
|
postgresCurrentSchemaVersion = 6
|
||||||
postgresSelectSchemaVersion = `SELECT version FROM schema_version WHERE store = 'user'`
|
postgresSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE store = 'user'`
|
||||||
postgresInsertSchemaVersion = `INSERT INTO schema_version (store, version) VALUES ('user', $1)`
|
postgresInsertSchemaVersionQuery = `INSERT INTO schema_version (store, version) VALUES ('user', $1)`
|
||||||
)
|
)
|
||||||
|
|
||||||
func setupPostgres(db *sql.DB) error {
|
func setupPostgres(db *sql.DB) error {
|
||||||
var schemaVersion int
|
var schemaVersion int
|
||||||
err := db.QueryRow(postgresSelectSchemaVersion).Scan(&schemaVersion)
|
err := db.QueryRow(postgresSelectSchemaVersionQuery).Scan(&schemaVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return setupNewPostgres(db)
|
return setupNewPostgres(db)
|
||||||
}
|
}
|
||||||
@@ -106,7 +106,7 @@ func setupNewPostgres(db *sql.DB) error {
|
|||||||
if _, err := db.Exec(postgresCreateTablesQueries); err != nil {
|
if _, err := db.Exec(postgresCreateTablesQueries); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := db.Exec(postgresInsertSchemaVersion, postgresCurrentSchemaVersion); err != nil {
|
if _, err := db.Exec(postgresInsertSchemaVersionQuery, postgresCurrentSchemaVersion); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -2,38 +2,42 @@ package user
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||||
|
|
||||||
|
"heckel.io/ntfy/v2/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// User queries
|
// User queries
|
||||||
sqliteSelectUserByID = `
|
sqliteSelectUserByIDQuery = `
|
||||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||||
FROM user u
|
FROM user u
|
||||||
LEFT JOIN tier t on t.id = u.tier_id
|
LEFT JOIN tier t on t.id = u.tier_id
|
||||||
WHERE u.id = ?
|
WHERE u.id = ?
|
||||||
`
|
`
|
||||||
sqliteSelectUserByName = `
|
sqliteSelectUserByNameQuery = `
|
||||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||||
FROM user u
|
FROM user u
|
||||||
LEFT JOIN tier t on t.id = u.tier_id
|
LEFT JOIN tier t on t.id = u.tier_id
|
||||||
WHERE user = ?
|
WHERE user = ?
|
||||||
`
|
`
|
||||||
sqliteSelectUserByToken = `
|
sqliteSelectUserByTokenQuery = `
|
||||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||||
FROM user u
|
FROM user u
|
||||||
JOIN user_token tk on u.id = tk.user_id
|
JOIN user_token tk on u.id = tk.user_id
|
||||||
LEFT JOIN tier t on t.id = u.tier_id
|
LEFT JOIN tier t on t.id = u.tier_id
|
||||||
WHERE tk.token = ? AND (tk.expires = 0 OR tk.expires >= ?)
|
WHERE tk.token = ? AND (tk.expires = 0 OR tk.expires >= ?)
|
||||||
`
|
`
|
||||||
sqliteSelectUserByStripeID = `
|
sqliteSelectUserByStripeCustomerIDQuery = `
|
||||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||||
FROM user u
|
FROM user u
|
||||||
LEFT JOIN tier t on t.id = u.tier_id
|
LEFT JOIN tier t on t.id = u.tier_id
|
||||||
WHERE u.stripe_customer_id = ?
|
WHERE u.stripe_customer_id = ?
|
||||||
`
|
`
|
||||||
sqliteSelectUsernames = `
|
sqliteSelectUsernamesQuery = `
|
||||||
SELECT user
|
SELECT user
|
||||||
FROM user
|
FROM user
|
||||||
ORDER BY
|
ORDER BY
|
||||||
@@ -43,41 +47,41 @@ const (
|
|||||||
ELSE 2
|
ELSE 2
|
||||||
END, user
|
END, user
|
||||||
`
|
`
|
||||||
sqliteSelectUserCount = `SELECT COUNT(*) FROM user`
|
sqliteSelectUserCountQuery = `SELECT COUNT(*) FROM user`
|
||||||
sqliteSelectUserIDFromUsername = `SELECT id FROM user WHERE user = ?`
|
sqliteSelectUserIDFromUsernameQuery = `SELECT id FROM user WHERE user = ?`
|
||||||
sqliteInsertUser = `INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created) VALUES (?, ?, ?, ?, ?, ?, ?)`
|
sqliteInsertUserQuery = `INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created) VALUES (?, ?, ?, ?, ?, ?, ?)`
|
||||||
sqliteUpdateUserPass = `UPDATE user SET pass = ? WHERE user = ?`
|
sqliteUpdateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?`
|
||||||
sqliteUpdateUserRole = `UPDATE user SET role = ? WHERE user = ?`
|
sqliteUpdateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?`
|
||||||
sqliteUpdateUserProvisioned = `UPDATE user SET provisioned = ? WHERE user = ?`
|
sqliteUpdateUserProvisionedQuery = `UPDATE user SET provisioned = ? WHERE user = ?`
|
||||||
sqliteUpdateUserPrefs = `UPDATE user SET prefs = ? WHERE id = ?`
|
sqliteUpdateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE id = ?`
|
||||||
sqliteUpdateUserStats = `UPDATE user SET stats_messages = ?, stats_emails = ?, stats_calls = ? WHERE id = ?`
|
sqliteUpdateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ?, stats_calls = ? WHERE id = ?`
|
||||||
sqliteUpdateUserStatsResetAll = `UPDATE user SET stats_messages = 0, stats_emails = 0, stats_calls = 0`
|
sqliteUpdateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0, stats_calls = 0`
|
||||||
sqliteUpdateUserTier = `UPDATE user SET tier_id = (SELECT id FROM tier WHERE code = ?) WHERE user = ?`
|
sqliteUpdateUserTierQuery = `UPDATE user SET tier_id = (SELECT id FROM tier WHERE code = ?) WHERE user = ?`
|
||||||
sqliteUpdateUserDeleted = `UPDATE user SET deleted = ? WHERE id = ?`
|
sqliteUpdateUserDeletedQuery = `UPDATE user SET deleted = ? WHERE id = ?`
|
||||||
sqliteDeleteUser = `DELETE FROM user WHERE user = ?`
|
sqliteDeleteUserQuery = `DELETE FROM user WHERE user = ?`
|
||||||
sqliteDeleteUserTier = `UPDATE user SET tier_id = null WHERE user = ?`
|
sqliteDeleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?`
|
||||||
sqliteDeleteUsersMarked = `DELETE FROM user WHERE deleted < ?`
|
sqliteDeleteUsersMarkedQuery = `DELETE FROM user WHERE deleted < ?`
|
||||||
|
|
||||||
// Access queries
|
// Access queries
|
||||||
sqliteSelectTopicPerms = `
|
sqliteSelectTopicPermsQuery = `
|
||||||
SELECT read, write
|
SELECT read, write
|
||||||
FROM user_access a
|
FROM user_access a
|
||||||
JOIN user u ON u.id = a.user_id
|
JOIN user u ON u.id = a.user_id
|
||||||
WHERE (u.user = ? OR u.user = ?) AND ? LIKE a.topic ESCAPE '\'
|
WHERE (u.user = ? OR u.user = ?) AND ? LIKE a.topic ESCAPE '\'
|
||||||
ORDER BY u.user DESC, LENGTH(a.topic) DESC, a.write DESC
|
ORDER BY u.user DESC, LENGTH(a.topic) DESC, a.write DESC
|
||||||
`
|
`
|
||||||
sqliteSelectUserAllAccess = `
|
sqliteSelectUserAllAccessQuery = `
|
||||||
SELECT user_id, topic, read, write, provisioned
|
SELECT user_id, topic, read, write, provisioned
|
||||||
FROM user_access
|
FROM user_access
|
||||||
ORDER BY LENGTH(topic) DESC, write DESC, read DESC, topic
|
ORDER BY LENGTH(topic) DESC, write DESC, read DESC, topic
|
||||||
`
|
`
|
||||||
sqliteSelectUserAccess = `
|
sqliteSelectUserAccessQuery = `
|
||||||
SELECT topic, read, write, provisioned
|
SELECT topic, read, write, provisioned
|
||||||
FROM user_access
|
FROM user_access
|
||||||
WHERE user_id = (SELECT id FROM user WHERE user = ?)
|
WHERE user_id = (SELECT id FROM user WHERE user = ?)
|
||||||
ORDER BY LENGTH(topic) DESC, write DESC, read DESC, topic
|
ORDER BY LENGTH(topic) DESC, write DESC, read DESC, topic
|
||||||
`
|
`
|
||||||
sqliteSelectUserReservations = `
|
sqliteSelectUserReservationsQuery = `
|
||||||
SELECT a_user.topic, a_user.read, a_user.write, a_everyone.read AS everyone_read, a_everyone.write AS everyone_write
|
SELECT a_user.topic, a_user.read, a_user.write, a_everyone.read AS everyone_read, a_everyone.write AS everyone_write
|
||||||
FROM user_access a_user
|
FROM user_access a_user
|
||||||
LEFT JOIN user_access a_everyone ON a_user.topic = a_everyone.topic AND a_everyone.user_id = (SELECT id FROM user WHERE user = ?)
|
LEFT JOIN user_access a_everyone ON a_user.topic = a_everyone.topic AND a_everyone.user_id = (SELECT id FROM user WHERE user = ?)
|
||||||
@@ -85,69 +89,68 @@ const (
|
|||||||
AND a_user.owner_user_id = (SELECT id FROM user WHERE user = ?)
|
AND a_user.owner_user_id = (SELECT id FROM user WHERE user = ?)
|
||||||
ORDER BY a_user.topic
|
ORDER BY a_user.topic
|
||||||
`
|
`
|
||||||
sqliteSelectUserReservationsCount = `
|
sqliteSelectUserReservationsCountQuery = `
|
||||||
SELECT COUNT(*)
|
SELECT COUNT(*)
|
||||||
FROM user_access
|
FROM user_access
|
||||||
WHERE user_id = owner_user_id
|
WHERE user_id = owner_user_id
|
||||||
AND owner_user_id = (SELECT id FROM user WHERE user = ?)
|
AND owner_user_id = (SELECT id FROM user WHERE user = ?)
|
||||||
`
|
`
|
||||||
sqliteSelectUserReservationsOwner = `
|
sqliteSelectUserReservationsOwnerQuery = `
|
||||||
SELECT owner_user_id
|
SELECT owner_user_id
|
||||||
FROM user_access
|
FROM user_access
|
||||||
WHERE topic = ?
|
WHERE topic = ?
|
||||||
AND user_id = owner_user_id
|
AND user_id = owner_user_id
|
||||||
`
|
`
|
||||||
sqliteSelectUserHasReservation = `
|
sqliteSelectUserHasReservationQuery = `
|
||||||
SELECT COUNT(*)
|
SELECT COUNT(*)
|
||||||
FROM user_access
|
FROM user_access
|
||||||
WHERE user_id = owner_user_id
|
WHERE user_id = owner_user_id
|
||||||
AND owner_user_id = (SELECT id FROM user WHERE user = ?)
|
AND owner_user_id = (SELECT id FROM user WHERE user = ?)
|
||||||
AND topic = ?
|
AND topic = ?
|
||||||
`
|
`
|
||||||
sqliteSelectOtherAccessCount = `
|
sqliteSelectOtherAccessCountQuery = `
|
||||||
SELECT COUNT(*)
|
SELECT COUNT(*)
|
||||||
FROM user_access
|
FROM user_access
|
||||||
WHERE (topic = ? OR ? LIKE topic ESCAPE '\')
|
WHERE (topic = ? OR ? LIKE topic ESCAPE '\')
|
||||||
AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM user WHERE user = ?))
|
AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM user WHERE user = ?))
|
||||||
`
|
`
|
||||||
sqliteUpsertUserAccess = `
|
sqliteUpsertUserAccessQuery = `
|
||||||
INSERT INTO user_access (user_id, topic, read, write, owner_user_id, provisioned)
|
INSERT INTO user_access (user_id, topic, read, write, owner_user_id, provisioned)
|
||||||
VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?, (SELECT IIF(?='',NULL,(SELECT id FROM user WHERE user=?))), ?)
|
VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?, (SELECT IIF(?='',NULL,(SELECT id FROM user WHERE user=?))), ?)
|
||||||
ON CONFLICT (user_id, topic)
|
ON CONFLICT (user_id, topic)
|
||||||
DO UPDATE SET read=excluded.read, write=excluded.write, owner_user_id=excluded.owner_user_id, provisioned=excluded.provisioned
|
DO UPDATE SET read=excluded.read, write=excluded.write, owner_user_id=excluded.owner_user_id, provisioned=excluded.provisioned
|
||||||
`
|
`
|
||||||
sqliteDeleteUserAccess = `
|
sqliteDeleteUserAccessQuery = `
|
||||||
DELETE FROM user_access
|
DELETE FROM user_access
|
||||||
WHERE user_id = (SELECT id FROM user WHERE user = ?)
|
WHERE user_id = (SELECT id FROM user WHERE user = ?)
|
||||||
OR owner_user_id = (SELECT id FROM user WHERE user = ?)
|
OR owner_user_id = (SELECT id FROM user WHERE user = ?)
|
||||||
`
|
`
|
||||||
sqliteDeleteUserAccessProvisioned = `DELETE FROM user_access WHERE provisioned = 1`
|
sqliteDeleteUserAccessProvisionedQuery = `DELETE FROM user_access WHERE provisioned = 1`
|
||||||
sqliteDeleteTopicAccess = `
|
sqliteDeleteTopicAccessQuery = `
|
||||||
DELETE FROM user_access
|
DELETE FROM user_access
|
||||||
WHERE (user_id = (SELECT id FROM user WHERE user = ?) OR owner_user_id = (SELECT id FROM user WHERE user = ?))
|
WHERE (user_id = (SELECT id FROM user WHERE user = ?) OR owner_user_id = (SELECT id FROM user WHERE user = ?))
|
||||||
AND topic = ?
|
AND topic = ?
|
||||||
`
|
`
|
||||||
sqliteDeleteAllAccess = `DELETE FROM user_access`
|
sqliteDeleteAllAccessQuery = `DELETE FROM user_access`
|
||||||
|
|
||||||
// Token queries
|
// Token queries
|
||||||
sqliteSelectToken = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ? AND token = ?`
|
sqliteSelectTokenQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ? AND token = ?`
|
||||||
sqliteSelectTokens = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ?`
|
sqliteSelectTokensQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ?`
|
||||||
sqliteSelectTokenCount = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
|
sqliteSelectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
|
||||||
sqliteSelectAllProvisionedTokens = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE provisioned = 1`
|
sqliteSelectAllProvisionedTokensQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE provisioned = 1`
|
||||||
sqliteUpsertToken = `
|
sqliteUpsertTokenQuery = `
|
||||||
INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires, provisioned)
|
INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires, provisioned)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||||
ON CONFLICT (user_id, token)
|
ON CONFLICT (user_id, token)
|
||||||
DO UPDATE SET label = excluded.label, expires = excluded.expires, provisioned = excluded.provisioned;
|
DO UPDATE SET label = excluded.label, expires = excluded.expires, provisioned = excluded.provisioned
|
||||||
`
|
`
|
||||||
sqliteUpdateTokenLabel = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?`
|
sqliteUpdateTokenQuery = `UPDATE user_token SET label = ?, expires = ? WHERE user_id = ? AND token = ?`
|
||||||
sqliteUpdateTokenExpiry = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?`
|
sqliteUpdateTokenLastAccessQuery = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?`
|
||||||
sqliteUpdateTokenLastAccess = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?`
|
sqliteDeleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
|
||||||
sqliteDeleteToken = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
|
sqliteDeleteProvisionedTokenQuery = `DELETE FROM user_token WHERE token = ?`
|
||||||
sqliteDeleteProvisionedToken = `DELETE FROM user_token WHERE token = ?`
|
sqliteDeleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?`
|
||||||
sqliteDeleteAllToken = `DELETE FROM user_token WHERE user_id = ?`
|
sqliteDeleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
|
||||||
sqliteDeleteExpiredTokens = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
|
sqliteDeleteExcessTokensQuery = `
|
||||||
sqliteDeleteExcessTokens = `
|
|
||||||
DELETE FROM user_token
|
DELETE FROM user_token
|
||||||
WHERE user_id = ?
|
WHERE user_id = ?
|
||||||
AND (user_id, token) NOT IN (
|
AND (user_id, token) NOT IN (
|
||||||
@@ -160,46 +163,107 @@ const (
|
|||||||
`
|
`
|
||||||
|
|
||||||
// Tier queries
|
// Tier queries
|
||||||
sqliteInsertTier = `
|
sqliteInsertTierQuery = `
|
||||||
INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id)
|
INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
`
|
`
|
||||||
sqliteUpdateTier = `
|
sqliteUpdateTierQuery = `
|
||||||
UPDATE tier
|
UPDATE tier
|
||||||
SET name = ?, messages_limit = ?, messages_expiry_duration = ?, emails_limit = ?, calls_limit = ?, reservations_limit = ?, attachment_file_size_limit = ?, attachment_total_size_limit = ?, attachment_expiry_duration = ?, attachment_bandwidth_limit = ?, stripe_monthly_price_id = ?, stripe_yearly_price_id = ?
|
SET name = ?, messages_limit = ?, messages_expiry_duration = ?, emails_limit = ?, calls_limit = ?, reservations_limit = ?, attachment_file_size_limit = ?, attachment_total_size_limit = ?, attachment_expiry_duration = ?, attachment_bandwidth_limit = ?, stripe_monthly_price_id = ?, stripe_yearly_price_id = ?
|
||||||
WHERE code = ?
|
WHERE code = ?
|
||||||
`
|
`
|
||||||
sqliteSelectTiers = `
|
sqliteSelectTiersQuery = `
|
||||||
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
|
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
|
||||||
FROM tier
|
FROM tier
|
||||||
`
|
`
|
||||||
sqliteSelectTierByCode = `
|
sqliteSelectTierByCodeQuery = `
|
||||||
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
|
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
|
||||||
FROM tier
|
FROM tier
|
||||||
WHERE code = ?
|
WHERE code = ?
|
||||||
`
|
`
|
||||||
sqliteSelectTierByPriceID = `
|
sqliteSelectTierByPriceIDQuery = `
|
||||||
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
|
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
|
||||||
FROM tier
|
FROM tier
|
||||||
WHERE (stripe_monthly_price_id = ? OR stripe_yearly_price_id = ?)
|
WHERE (stripe_monthly_price_id = ? OR stripe_yearly_price_id = ?)
|
||||||
`
|
`
|
||||||
sqliteDeleteTier = `DELETE FROM tier WHERE code = ?`
|
sqliteDeleteTierQuery = `DELETE FROM tier WHERE code = ?`
|
||||||
|
|
||||||
// Phone queries
|
// Phone queries
|
||||||
sqliteSelectPhoneNumbers = `SELECT phone_number FROM user_phone WHERE user_id = ?`
|
sqliteSelectPhoneNumbersQuery = `SELECT phone_number FROM user_phone WHERE user_id = ?`
|
||||||
sqliteInsertPhoneNumber = `INSERT INTO user_phone (user_id, phone_number) VALUES (?, ?)`
|
sqliteInsertPhoneNumberQuery = `INSERT INTO user_phone (user_id, phone_number) VALUES (?, ?)`
|
||||||
sqliteDeletePhoneNumber = `DELETE FROM user_phone WHERE user_id = ? AND phone_number = ?`
|
sqliteDeletePhoneNumberQuery = `DELETE FROM user_phone WHERE user_id = ? AND phone_number = ?`
|
||||||
|
|
||||||
// Billing queries
|
// Billing queries
|
||||||
sqliteUpdateBilling = `
|
sqliteUpdateBillingQuery = `
|
||||||
UPDATE user
|
UPDATE user
|
||||||
SET stripe_customer_id = ?, stripe_subscription_id = ?, stripe_subscription_status = ?, stripe_subscription_interval = ?, stripe_subscription_paid_until = ?, stripe_subscription_cancel_at = ?
|
SET stripe_customer_id = ?, stripe_subscription_id = ?, stripe_subscription_status = ?, stripe_subscription_interval = ?, stripe_subscription_paid_until = ?, stripe_subscription_cancel_at = ?
|
||||||
WHERE user = ?
|
WHERE user = ?
|
||||||
`
|
`
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewSQLiteStore creates a new SQLite-backed user store
|
var sqliteQueries = queries{
|
||||||
func NewSQLiteStore(filename, startupQueries string) (Store, error) {
|
selectUserByID: sqliteSelectUserByIDQuery,
|
||||||
|
selectUserByName: sqliteSelectUserByNameQuery,
|
||||||
|
selectUserByToken: sqliteSelectUserByTokenQuery,
|
||||||
|
selectUserByStripeCustomerID: sqliteSelectUserByStripeCustomerIDQuery,
|
||||||
|
selectUsernames: sqliteSelectUsernamesQuery,
|
||||||
|
selectUserCount: sqliteSelectUserCountQuery,
|
||||||
|
selectUserIDFromUsername: sqliteSelectUserIDFromUsernameQuery,
|
||||||
|
insertUser: sqliteInsertUserQuery,
|
||||||
|
updateUserPass: sqliteUpdateUserPassQuery,
|
||||||
|
updateUserRole: sqliteUpdateUserRoleQuery,
|
||||||
|
updateUserProvisioned: sqliteUpdateUserProvisionedQuery,
|
||||||
|
updateUserPrefs: sqliteUpdateUserPrefsQuery,
|
||||||
|
updateUserStats: sqliteUpdateUserStatsQuery,
|
||||||
|
updateUserStatsResetAll: sqliteUpdateUserStatsResetAllQuery,
|
||||||
|
updateUserTier: sqliteUpdateUserTierQuery,
|
||||||
|
updateUserDeleted: sqliteUpdateUserDeletedQuery,
|
||||||
|
deleteUser: sqliteDeleteUserQuery,
|
||||||
|
deleteUserTier: sqliteDeleteUserTierQuery,
|
||||||
|
deleteUsersMarked: sqliteDeleteUsersMarkedQuery,
|
||||||
|
selectTopicPerms: sqliteSelectTopicPermsQuery,
|
||||||
|
selectUserAllAccess: sqliteSelectUserAllAccessQuery,
|
||||||
|
selectUserAccess: sqliteSelectUserAccessQuery,
|
||||||
|
selectUserReservations: sqliteSelectUserReservationsQuery,
|
||||||
|
selectUserReservationsCount: sqliteSelectUserReservationsCountQuery,
|
||||||
|
selectUserReservationsOwner: sqliteSelectUserReservationsOwnerQuery,
|
||||||
|
selectUserHasReservation: sqliteSelectUserHasReservationQuery,
|
||||||
|
selectOtherAccessCount: sqliteSelectOtherAccessCountQuery,
|
||||||
|
upsertUserAccess: sqliteUpsertUserAccessQuery,
|
||||||
|
deleteUserAccess: sqliteDeleteUserAccessQuery,
|
||||||
|
deleteUserAccessProvisioned: sqliteDeleteUserAccessProvisionedQuery,
|
||||||
|
deleteTopicAccess: sqliteDeleteTopicAccessQuery,
|
||||||
|
deleteAllAccess: sqliteDeleteAllAccessQuery,
|
||||||
|
selectToken: sqliteSelectTokenQuery,
|
||||||
|
selectTokens: sqliteSelectTokensQuery,
|
||||||
|
selectTokenCount: sqliteSelectTokenCountQuery,
|
||||||
|
selectAllProvisionedTokens: sqliteSelectAllProvisionedTokensQuery,
|
||||||
|
upsertToken: sqliteUpsertTokenQuery,
|
||||||
|
updateToken: sqliteUpdateTokenQuery,
|
||||||
|
updateTokenLastAccess: sqliteUpdateTokenLastAccessQuery,
|
||||||
|
deleteToken: sqliteDeleteTokenQuery,
|
||||||
|
deleteProvisionedToken: sqliteDeleteProvisionedTokenQuery,
|
||||||
|
deleteAllToken: sqliteDeleteAllTokenQuery,
|
||||||
|
deleteExpiredTokens: sqliteDeleteExpiredTokensQuery,
|
||||||
|
deleteExcessTokens: sqliteDeleteExcessTokensQuery,
|
||||||
|
insertTier: sqliteInsertTierQuery,
|
||||||
|
selectTiers: sqliteSelectTiersQuery,
|
||||||
|
selectTierByCode: sqliteSelectTierByCodeQuery,
|
||||||
|
selectTierByPriceID: sqliteSelectTierByPriceIDQuery,
|
||||||
|
updateTier: sqliteUpdateTierQuery,
|
||||||
|
deleteTier: sqliteDeleteTierQuery,
|
||||||
|
selectPhoneNumbers: sqliteSelectPhoneNumbersQuery,
|
||||||
|
insertPhoneNumber: sqliteInsertPhoneNumberQuery,
|
||||||
|
deletePhoneNumber: sqliteDeletePhoneNumberQuery,
|
||||||
|
updateBilling: sqliteUpdateBillingQuery,
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSQLiteManager creates a new Manager backed by a SQLite database
|
||||||
|
func NewSQLiteManager(filename, startupQueries string, config *Config) (*Manager, error) {
|
||||||
|
parentDir := filepath.Dir(filename)
|
||||||
|
if !util.FileExists(parentDir) {
|
||||||
|
return nil, fmt.Errorf("user database directory %s does not exist or is not accessible", parentDir)
|
||||||
|
}
|
||||||
db, err := sql.Open("sqlite3", filename)
|
db, err := sql.Open("sqlite3", filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -210,64 +274,5 @@ func NewSQLiteStore(filename, startupQueries string) (Store, error) {
|
|||||||
if err := runSQLiteStartupQueries(db, startupQueries); err != nil {
|
if err := runSQLiteStartupQueries(db, startupQueries); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &commonStore{
|
return newManager(db, sqliteQueries, config)
|
||||||
db: db,
|
|
||||||
queries: storeQueries{
|
|
||||||
selectUserByID: sqliteSelectUserByID,
|
|
||||||
selectUserByName: sqliteSelectUserByName,
|
|
||||||
selectUserByToken: sqliteSelectUserByToken,
|
|
||||||
selectUserByStripeID: sqliteSelectUserByStripeID,
|
|
||||||
selectUsernames: sqliteSelectUsernames,
|
|
||||||
selectUserCount: sqliteSelectUserCount,
|
|
||||||
selectUserIDFromUsername: sqliteSelectUserIDFromUsername,
|
|
||||||
insertUser: sqliteInsertUser,
|
|
||||||
updateUserPass: sqliteUpdateUserPass,
|
|
||||||
updateUserRole: sqliteUpdateUserRole,
|
|
||||||
updateUserProvisioned: sqliteUpdateUserProvisioned,
|
|
||||||
updateUserPrefs: sqliteUpdateUserPrefs,
|
|
||||||
updateUserStats: sqliteUpdateUserStats,
|
|
||||||
updateUserStatsResetAll: sqliteUpdateUserStatsResetAll,
|
|
||||||
updateUserTier: sqliteUpdateUserTier,
|
|
||||||
updateUserDeleted: sqliteUpdateUserDeleted,
|
|
||||||
deleteUser: sqliteDeleteUser,
|
|
||||||
deleteUserTier: sqliteDeleteUserTier,
|
|
||||||
deleteUsersMarked: sqliteDeleteUsersMarked,
|
|
||||||
selectTopicPerms: sqliteSelectTopicPerms,
|
|
||||||
selectUserAllAccess: sqliteSelectUserAllAccess,
|
|
||||||
selectUserAccess: sqliteSelectUserAccess,
|
|
||||||
selectUserReservations: sqliteSelectUserReservations,
|
|
||||||
selectUserReservationsCount: sqliteSelectUserReservationsCount,
|
|
||||||
selectUserReservationsOwner: sqliteSelectUserReservationsOwner,
|
|
||||||
selectUserHasReservation: sqliteSelectUserHasReservation,
|
|
||||||
selectOtherAccessCount: sqliteSelectOtherAccessCount,
|
|
||||||
upsertUserAccess: sqliteUpsertUserAccess,
|
|
||||||
deleteUserAccess: sqliteDeleteUserAccess,
|
|
||||||
deleteUserAccessProvisioned: sqliteDeleteUserAccessProvisioned,
|
|
||||||
deleteTopicAccess: sqliteDeleteTopicAccess,
|
|
||||||
deleteAllAccess: sqliteDeleteAllAccess,
|
|
||||||
selectToken: sqliteSelectToken,
|
|
||||||
selectTokens: sqliteSelectTokens,
|
|
||||||
selectTokenCount: sqliteSelectTokenCount,
|
|
||||||
selectAllProvisionedTokens: sqliteSelectAllProvisionedTokens,
|
|
||||||
upsertToken: sqliteUpsertToken,
|
|
||||||
updateTokenLabel: sqliteUpdateTokenLabel,
|
|
||||||
updateTokenExpiry: sqliteUpdateTokenExpiry,
|
|
||||||
updateTokenLastAccess: sqliteUpdateTokenLastAccess,
|
|
||||||
deleteToken: sqliteDeleteToken,
|
|
||||||
deleteProvisionedToken: sqliteDeleteProvisionedToken,
|
|
||||||
deleteAllToken: sqliteDeleteAllToken,
|
|
||||||
deleteExpiredTokens: sqliteDeleteExpiredTokens,
|
|
||||||
deleteExcessTokens: sqliteDeleteExcessTokens,
|
|
||||||
insertTier: sqliteInsertTier,
|
|
||||||
selectTiers: sqliteSelectTiers,
|
|
||||||
selectTierByCode: sqliteSelectTierByCode,
|
|
||||||
selectTierByPriceID: sqliteSelectTierByPriceID,
|
|
||||||
updateTier: sqliteUpdateTier,
|
|
||||||
deleteTier: sqliteDeleteTier,
|
|
||||||
selectPhoneNumbers: sqliteSelectPhoneNumbers,
|
|
||||||
insertPhoneNumber: sqliteInsertPhoneNumber,
|
|
||||||
deletePhoneNumber: sqliteDeletePhoneNumber,
|
|
||||||
updateBilling: sqliteUpdateBilling,
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
@@ -103,8 +103,8 @@ const (
|
|||||||
// Schema version table management for SQLite
|
// Schema version table management for SQLite
|
||||||
const (
|
const (
|
||||||
sqliteCurrentSchemaVersion = 6
|
sqliteCurrentSchemaVersion = 6
|
||||||
sqliteInsertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
|
sqliteInsertSchemaVersionQuery = `INSERT INTO schemaVersion VALUES (1, ?)`
|
||||||
sqliteUpdateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
|
sqliteUpdateSchemaVersionQuery = `UPDATE schemaVersion SET version = ? WHERE id = 1`
|
||||||
sqliteSelectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
|
sqliteSelectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -179,12 +179,12 @@ const (
|
|||||||
VALUES ('u_everyone', '*', '', 'anonymous', '', UNIXEPOCH())
|
VALUES ('u_everyone', '*', '', 'anonymous', '', UNIXEPOCH())
|
||||||
ON CONFLICT (id) DO NOTHING;
|
ON CONFLICT (id) DO NOTHING;
|
||||||
`
|
`
|
||||||
sqliteMigrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old`
|
sqliteMigrate1To2SelectAllOldUsernamesNoTxQuery = `SELECT user FROM user_old`
|
||||||
sqliteMigrate1To2InsertUserNoTx = `
|
sqliteMigrate1To2InsertUserNoTxQuery = `
|
||||||
INSERT INTO user (id, user, pass, role, sync_topic, created)
|
INSERT INTO user (id, user, pass, role, sync_topic, created)
|
||||||
SELECT ?, user, pass, role, ?, UNIXEPOCH() FROM user_old WHERE user = ?
|
SELECT ?, user, pass, role, ?, UNIXEPOCH() FROM user_old WHERE user = ?
|
||||||
`
|
`
|
||||||
sqliteMigrate1To2InsertFromOldTablesAndDropNoTx = `
|
sqliteMigrate1To2InsertFromOldTablesAndDropNoTxQuery = `
|
||||||
INSERT INTO user_access (user_id, topic, read, write)
|
INSERT INTO user_access (user_id, topic, read, write)
|
||||||
SELECT u.id, a.topic, a.read, a.write
|
SELECT u.id, a.topic, a.read, a.write
|
||||||
FROM user u
|
FROM user u
|
||||||
@@ -352,7 +352,7 @@ func setupNewSQLite(db *sql.DB) error {
|
|||||||
if _, err := db.Exec(sqliteCreateTablesQueries); err != nil {
|
if _, err := db.Exec(sqliteCreateTablesQueries); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := db.Exec(sqliteInsertSchemaVersion, sqliteCurrentSchemaVersion); err != nil {
|
if _, err := db.Exec(sqliteInsertSchemaVersionQuery, sqliteCurrentSchemaVersion); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -382,7 +382,7 @@ func sqliteMigrateFrom1(db *sql.DB) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// Insert users from user_old into new user table, with ID and sync_topic
|
// Insert users from user_old into new user table, with ID and sync_topic
|
||||||
rows, err := tx.Query(sqliteMigrate1To2SelectAllOldUsernamesNoTx)
|
rows, err := tx.Query(sqliteMigrate1To2SelectAllOldUsernamesNoTxQuery)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -401,15 +401,15 @@ func sqliteMigrateFrom1(db *sql.DB) error {
|
|||||||
for _, username := range usernames {
|
for _, username := range usernames {
|
||||||
userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
|
userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
|
||||||
syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
|
syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
|
||||||
if _, err := tx.Exec(sqliteMigrate1To2InsertUserNoTx, userID, syncTopic, username); err != nil {
|
if _, err := tx.Exec(sqliteMigrate1To2InsertUserNoTxQuery, userID, syncTopic, username); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Migrate old "access" table to "user_access" and drop "access" and "user_old"
|
// Migrate old "access" table to "user_access" and drop "access" and "user_old"
|
||||||
if _, err := tx.Exec(sqliteMigrate1To2InsertFromOldTablesAndDropNoTx); err != nil {
|
if _, err := tx.Exec(sqliteMigrate1To2InsertFromOldTablesAndDropNoTxQuery); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 2); err != nil {
|
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 2); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := tx.Commit(); err != nil {
|
if err := tx.Commit(); err != nil {
|
||||||
@@ -428,7 +428,7 @@ func sqliteMigrateFrom2(db *sql.DB) error {
|
|||||||
if _, err := tx.Exec(sqliteMigrate2To3UpdateQueries); err != nil {
|
if _, err := tx.Exec(sqliteMigrate2To3UpdateQueries); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 3); err != nil {
|
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 3); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
@@ -444,7 +444,7 @@ func sqliteMigrateFrom3(db *sql.DB) error {
|
|||||||
if _, err := tx.Exec(sqliteMigrate3To4UpdateQueries); err != nil {
|
if _, err := tx.Exec(sqliteMigrate3To4UpdateQueries); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 4); err != nil {
|
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 4); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
@@ -460,7 +460,7 @@ func sqliteMigrateFrom4(db *sql.DB) error {
|
|||||||
if _, err := tx.Exec(sqliteMigrate4To5UpdateQueries); err != nil {
|
if _, err := tx.Exec(sqliteMigrate4To5UpdateQueries); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 5); err != nil {
|
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 5); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
@@ -476,7 +476,7 @@ func sqliteMigrateFrom5(db *sql.DB) error {
|
|||||||
if _, err := tx.Exec(sqliteMigrate5To6UpdateQueries); err != nil {
|
if _, err := tx.Exec(sqliteMigrate5To6UpdateQueries); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 6); err != nil {
|
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 6); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
File diff suppressed because it is too large
Load Diff
986
user/store.go
986
user/store.go
@@ -1,986 +0,0 @@
|
|||||||
package user
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"net/netip"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"heckel.io/ntfy/v2/payments"
|
|
||||||
"heckel.io/ntfy/v2/util"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Store is the interface for a user database store
|
|
||||||
type Store interface {
|
|
||||||
// User operations
|
|
||||||
UserByID(id string) (*User, error)
|
|
||||||
User(username string) (*User, error)
|
|
||||||
UserByToken(token string) (*User, error)
|
|
||||||
UserByStripeCustomer(customerID string) (*User, error)
|
|
||||||
UserIDByUsername(username string) (string, error)
|
|
||||||
Users() ([]*User, error)
|
|
||||||
UsersCount() (int64, error)
|
|
||||||
AddUser(username, hash string, role Role, provisioned bool) error
|
|
||||||
RemoveUser(username string) error
|
|
||||||
MarkUserRemoved(userID string) error
|
|
||||||
RemoveDeletedUsers() error
|
|
||||||
ChangePassword(username, hash string) error
|
|
||||||
ChangeRole(username string, role Role) error
|
|
||||||
ChangeProvisioned(username string, provisioned bool) error
|
|
||||||
ChangeSettings(userID string, prefs *Prefs) error
|
|
||||||
ChangeTier(username, tierCode string) error
|
|
||||||
ResetTier(username string) error
|
|
||||||
UpdateStats(userID string, stats *Stats) error
|
|
||||||
ResetStats() error
|
|
||||||
|
|
||||||
// Token operations
|
|
||||||
CreateToken(userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, provisioned bool) (*Token, error)
|
|
||||||
Token(userID, token string) (*Token, error)
|
|
||||||
Tokens(userID string) ([]*Token, error)
|
|
||||||
AllProvisionedTokens() ([]*Token, error)
|
|
||||||
ChangeTokenLabel(userID, token, label string) error
|
|
||||||
ChangeTokenExpiry(userID, token string, expires time.Time) error
|
|
||||||
UpdateTokenLastAccess(token string, lastAccess time.Time, lastOrigin netip.Addr) error
|
|
||||||
RemoveToken(userID, token string) error
|
|
||||||
RemoveExpiredTokens() error
|
|
||||||
TokenCount(userID string) (int, error)
|
|
||||||
RemoveExcessTokens(userID string, maxCount int) error
|
|
||||||
|
|
||||||
// Access operations
|
|
||||||
AuthorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error)
|
|
||||||
AllGrants() (map[string][]Grant, error)
|
|
||||||
Grants(username string) ([]Grant, error)
|
|
||||||
AllowAccess(username, topicPattern string, read, write bool, ownerUsername string, provisioned bool) error
|
|
||||||
ResetAccess(username, topicPattern string) error
|
|
||||||
ResetAllProvisionedAccess() error
|
|
||||||
Reservations(username string) ([]Reservation, error)
|
|
||||||
HasReservation(username, topic string) (bool, error)
|
|
||||||
ReservationsCount(username string) (int64, error)
|
|
||||||
ReservationOwner(topic string) (string, error)
|
|
||||||
OtherAccessCount(username, topic string) (int, error)
|
|
||||||
|
|
||||||
// Tier operations
|
|
||||||
AddTier(tier *Tier) error
|
|
||||||
UpdateTier(tier *Tier) error
|
|
||||||
RemoveTier(code string) error
|
|
||||||
Tiers() ([]*Tier, error)
|
|
||||||
Tier(code string) (*Tier, error)
|
|
||||||
TierByStripePrice(priceID string) (*Tier, error)
|
|
||||||
|
|
||||||
// Phone operations
|
|
||||||
PhoneNumbers(userID string) ([]string, error)
|
|
||||||
AddPhoneNumber(userID, phoneNumber string) error
|
|
||||||
RemovePhoneNumber(userID, phoneNumber string) error
|
|
||||||
|
|
||||||
// Other stuff
|
|
||||||
ChangeBilling(username string, billing *Billing) error
|
|
||||||
Close() error
|
|
||||||
}
|
|
||||||
|
|
||||||
// storeQueries holds the database-specific SQL queries
|
|
||||||
type storeQueries struct {
|
|
||||||
// User queries
|
|
||||||
selectUserByID string
|
|
||||||
selectUserByName string
|
|
||||||
selectUserByToken string
|
|
||||||
selectUserByStripeID string
|
|
||||||
selectUsernames string
|
|
||||||
selectUserCount string
|
|
||||||
selectUserIDFromUsername string
|
|
||||||
insertUser string
|
|
||||||
updateUserPass string
|
|
||||||
updateUserRole string
|
|
||||||
updateUserProvisioned string
|
|
||||||
updateUserPrefs string
|
|
||||||
updateUserStats string
|
|
||||||
updateUserStatsResetAll string
|
|
||||||
updateUserTier string
|
|
||||||
updateUserDeleted string
|
|
||||||
deleteUser string
|
|
||||||
deleteUserTier string
|
|
||||||
deleteUsersMarked string
|
|
||||||
// Access queries
|
|
||||||
selectTopicPerms string
|
|
||||||
selectUserAllAccess string
|
|
||||||
selectUserAccess string
|
|
||||||
selectUserReservations string
|
|
||||||
selectUserReservationsCount string
|
|
||||||
selectUserReservationsOwner string
|
|
||||||
selectUserHasReservation string
|
|
||||||
selectOtherAccessCount string
|
|
||||||
upsertUserAccess string
|
|
||||||
deleteUserAccess string
|
|
||||||
deleteUserAccessProvisioned string
|
|
||||||
deleteTopicAccess string
|
|
||||||
deleteAllAccess string
|
|
||||||
// Token queries
|
|
||||||
selectToken string
|
|
||||||
selectTokens string
|
|
||||||
selectTokenCount string
|
|
||||||
selectAllProvisionedTokens string
|
|
||||||
upsertToken string
|
|
||||||
updateTokenLabel string
|
|
||||||
updateTokenExpiry string
|
|
||||||
updateTokenLastAccess string
|
|
||||||
deleteToken string
|
|
||||||
deleteProvisionedToken string
|
|
||||||
deleteAllToken string
|
|
||||||
deleteExpiredTokens string
|
|
||||||
deleteExcessTokens string
|
|
||||||
// Tier queries
|
|
||||||
insertTier string
|
|
||||||
selectTiers string
|
|
||||||
selectTierByCode string
|
|
||||||
selectTierByPriceID string
|
|
||||||
updateTier string
|
|
||||||
deleteTier string
|
|
||||||
// Phone queries
|
|
||||||
selectPhoneNumbers string
|
|
||||||
insertPhoneNumber string
|
|
||||||
deletePhoneNumber string
|
|
||||||
// Billing queries
|
|
||||||
updateBilling string
|
|
||||||
}
|
|
||||||
|
|
||||||
// commonStore implements store operations that work across database backends
|
|
||||||
type commonStore struct {
|
|
||||||
db *sql.DB
|
|
||||||
queries storeQueries
|
|
||||||
}
|
|
||||||
|
|
||||||
// UserByID returns the user with the given ID if it exists, or ErrUserNotFound otherwise
|
|
||||||
func (s *commonStore) UserByID(id string) (*User, error) {
|
|
||||||
rows, err := s.db.Query(s.queries.selectUserByID, id)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return s.readUser(rows)
|
|
||||||
}
|
|
||||||
|
|
||||||
// User returns the user with the given username if it exists, or ErrUserNotFound otherwise
|
|
||||||
func (s *commonStore) User(username string) (*User, error) {
|
|
||||||
rows, err := s.db.Query(s.queries.selectUserByName, username)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return s.readUser(rows)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UserByToken returns the user with the given token if it exists and is not expired, or ErrUserNotFound otherwise
|
|
||||||
func (s *commonStore) UserByToken(token string) (*User, error) {
|
|
||||||
rows, err := s.db.Query(s.queries.selectUserByToken, token, time.Now().Unix())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return s.readUser(rows)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise
|
|
||||||
func (s *commonStore) UserByStripeCustomer(customerID string) (*User, error) {
|
|
||||||
rows, err := s.db.Query(s.queries.selectUserByStripeID, customerID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return s.readUser(rows)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Users returns a list of users
|
|
||||||
func (s *commonStore) Users() ([]*User, error) {
|
|
||||||
rows, err := s.db.Query(s.queries.selectUsernames)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
usernames := make([]string, 0)
|
|
||||||
for rows.Next() {
|
|
||||||
var username string
|
|
||||||
if err := rows.Scan(&username); err != nil {
|
|
||||||
return nil, err
|
|
||||||
} else if err := rows.Err(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
usernames = append(usernames, username)
|
|
||||||
}
|
|
||||||
rows.Close()
|
|
||||||
users := make([]*User, 0)
|
|
||||||
for _, username := range usernames {
|
|
||||||
user, err := s.User(username)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
users = append(users, user)
|
|
||||||
}
|
|
||||||
return users, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UsersCount returns the number of users in the database
|
|
||||||
func (s *commonStore) UsersCount() (int64, error) {
|
|
||||||
rows, err := s.db.Query(s.queries.selectUserCount)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
if !rows.Next() {
|
|
||||||
return 0, errNoRows
|
|
||||||
}
|
|
||||||
var count int64
|
|
||||||
if err := rows.Scan(&count); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return count, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddUser adds a user with the given username, password hash and role
|
|
||||||
func (s *commonStore) AddUser(username, hash string, role Role, provisioned bool) error {
|
|
||||||
if !AllowedUsername(username) || !AllowedRole(role) {
|
|
||||||
return ErrInvalidArgument
|
|
||||||
}
|
|
||||||
userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
|
|
||||||
syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
|
|
||||||
now := time.Now().Unix()
|
|
||||||
if _, err := s.db.Exec(s.queries.insertUser, userID, username, hash, string(role), syncTopic, provisioned, now); err != nil {
|
|
||||||
if isUniqueConstraintError(err) {
|
|
||||||
return ErrUserExists
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveUser deletes the user with the given username
|
|
||||||
func (s *commonStore) RemoveUser(username string) error {
|
|
||||||
if !AllowedUsername(username) {
|
|
||||||
return ErrInvalidArgument
|
|
||||||
}
|
|
||||||
// Rows in user_access, user_token, etc. are deleted via foreign keys
|
|
||||||
if _, err := s.db.Exec(s.queries.deleteUser, username); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// MarkUserRemoved sets the deleted flag on the user, and deletes all access tokens
|
|
||||||
func (s *commonStore) MarkUserRemoved(userID string) error {
|
|
||||||
tx, err := s.db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer tx.Rollback()
|
|
||||||
// Get username for deleteUserAccess query
|
|
||||||
user, err := s.UserByID(userID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if _, err := tx.Exec(s.queries.deleteUserAccess, user.Name, user.Name); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if _, err := tx.Exec(s.queries.deleteAllToken, userID); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
deletedTime := time.Now().Add(userHardDeleteAfterDuration).Unix()
|
|
||||||
if _, err := tx.Exec(s.queries.updateUserDeleted, deletedTime, userID); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return tx.Commit()
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveDeletedUsers deletes all users that have been marked deleted
|
|
||||||
func (s *commonStore) RemoveDeletedUsers() error {
|
|
||||||
if _, err := s.db.Exec(s.queries.deleteUsersMarked, time.Now().Unix()); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChangePassword changes a user's password
|
|
||||||
func (s *commonStore) ChangePassword(username, hash string) error {
|
|
||||||
if _, err := s.db.Exec(s.queries.updateUserPass, hash, username); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChangeRole changes a user's role
|
|
||||||
func (s *commonStore) ChangeRole(username string, role Role) error {
|
|
||||||
if !AllowedUsername(username) || !AllowedRole(role) {
|
|
||||||
return ErrInvalidArgument
|
|
||||||
}
|
|
||||||
tx, err := s.db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer tx.Rollback()
|
|
||||||
if _, err := tx.Exec(s.queries.updateUserRole, string(role), username); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// If changing to admin, remove all access entries
|
|
||||||
if role == RoleAdmin {
|
|
||||||
if _, err := tx.Exec(s.queries.deleteUserAccess, username, username); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return tx.Commit()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChangeProvisioned changes the provisioned status of a user
|
|
||||||
func (s *commonStore) ChangeProvisioned(username string, provisioned bool) error {
|
|
||||||
if _, err := s.db.Exec(s.queries.updateUserProvisioned, provisioned, username); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChangeSettings persists the user settings
|
|
||||||
func (s *commonStore) ChangeSettings(userID string, prefs *Prefs) error {
|
|
||||||
b, err := json.Marshal(prefs)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if _, err := s.db.Exec(s.queries.updateUserPrefs, string(b), userID); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChangeTier changes a user's tier using the tier code
|
|
||||||
func (s *commonStore) ChangeTier(username, tierCode string) error {
|
|
||||||
if _, err := s.db.Exec(s.queries.updateUserTier, tierCode, username); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResetTier removes the tier from the given user
|
|
||||||
func (s *commonStore) ResetTier(username string) error {
|
|
||||||
if !AllowedUsername(username) && username != Everyone && username != "" {
|
|
||||||
return ErrInvalidArgument
|
|
||||||
}
|
|
||||||
_, err := s.db.Exec(s.queries.deleteUserTier, username)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateStats updates the user statistics
|
|
||||||
func (s *commonStore) UpdateStats(userID string, stats *Stats) error {
|
|
||||||
if _, err := s.db.Exec(s.queries.updateUserStats, stats.Messages, stats.Emails, stats.Calls, userID); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResetStats resets all user stats in the user database
|
|
||||||
func (s *commonStore) ResetStats() error {
|
|
||||||
if _, err := s.db.Exec(s.queries.updateUserStatsResetAll); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
func (s *commonStore) readUser(rows *sql.Rows) (*User, error) {
|
|
||||||
defer rows.Close()
|
|
||||||
var id, username, hash, role, prefs, syncTopic string
|
|
||||||
var provisioned bool
|
|
||||||
var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripeSubscriptionInterval, stripeMonthlyPriceID, stripeYearlyPriceID, tierID, tierCode, tierName sql.NullString
|
|
||||||
var messages, emails, calls int64
|
|
||||||
var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
|
|
||||||
if !rows.Next() {
|
|
||||||
return nil, ErrUserNotFound
|
|
||||||
}
|
|
||||||
if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &provisioned, &messages, &emails, &calls, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionInterval, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
|
|
||||||
return nil, err
|
|
||||||
} else if err := rows.Err(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
user := &User{
|
|
||||||
ID: id,
|
|
||||||
Name: username,
|
|
||||||
Hash: hash,
|
|
||||||
Role: Role(role),
|
|
||||||
Prefs: &Prefs{},
|
|
||||||
SyncTopic: syncTopic,
|
|
||||||
Provisioned: provisioned,
|
|
||||||
Stats: &Stats{
|
|
||||||
Messages: messages,
|
|
||||||
Emails: emails,
|
|
||||||
Calls: calls,
|
|
||||||
},
|
|
||||||
Billing: &Billing{
|
|
||||||
StripeCustomerID: stripeCustomerID.String,
|
|
||||||
StripeSubscriptionID: stripeSubscriptionID.String,
|
|
||||||
StripeSubscriptionStatus: payments.SubscriptionStatus(stripeSubscriptionStatus.String),
|
|
||||||
StripeSubscriptionInterval: payments.PriceRecurringInterval(stripeSubscriptionInterval.String),
|
|
||||||
StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0),
|
|
||||||
StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0),
|
|
||||||
},
|
|
||||||
Deleted: deleted.Valid,
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if tierCode.Valid {
|
|
||||||
user.Tier = &Tier{
|
|
||||||
ID: tierID.String,
|
|
||||||
Code: tierCode.String,
|
|
||||||
Name: tierName.String,
|
|
||||||
MessageLimit: messagesLimit.Int64,
|
|
||||||
MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
|
|
||||||
EmailLimit: emailsLimit.Int64,
|
|
||||||
CallLimit: callsLimit.Int64,
|
|
||||||
ReservationLimit: reservationsLimit.Int64,
|
|
||||||
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
|
|
||||||
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
|
|
||||||
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
|
|
||||||
AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
|
|
||||||
StripeMonthlyPriceID: stripeMonthlyPriceID.String,
|
|
||||||
StripeYearlyPriceID: stripeYearlyPriceID.String,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return user, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateToken creates a new token
|
|
||||||
func (s *commonStore) CreateToken(userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, provisioned bool) (*Token, error) {
|
|
||||||
if _, err := s.db.Exec(s.queries.upsertToken, userID, token, label, lastAccess.Unix(), lastOrigin.String(), expires.Unix(), provisioned); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &Token{
|
|
||||||
Value: token,
|
|
||||||
Label: label,
|
|
||||||
LastAccess: lastAccess,
|
|
||||||
LastOrigin: lastOrigin,
|
|
||||||
Expires: expires,
|
|
||||||
Provisioned: provisioned,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Token returns a specific token for a user
|
|
||||||
func (s *commonStore) Token(userID, token string) (*Token, error) {
|
|
||||||
rows, err := s.db.Query(s.queries.selectToken, userID, token)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
return s.readToken(rows)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Tokens returns all existing tokens for the user with the given user ID
|
|
||||||
func (s *commonStore) Tokens(userID string) ([]*Token, error) {
|
|
||||||
rows, err := s.db.Query(s.queries.selectTokens, userID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
tokens := make([]*Token, 0)
|
|
||||||
for {
|
|
||||||
token, err := s.readToken(rows)
|
|
||||||
if errors.Is(err, ErrTokenNotFound) {
|
|
||||||
break
|
|
||||||
} else if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
tokens = append(tokens, token)
|
|
||||||
}
|
|
||||||
return tokens, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AllProvisionedTokens returns all provisioned tokens
|
|
||||||
func (s *commonStore) AllProvisionedTokens() ([]*Token, error) {
|
|
||||||
rows, err := s.db.Query(s.queries.selectAllProvisionedTokens)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
tokens := make([]*Token, 0)
|
|
||||||
for {
|
|
||||||
token, err := s.readToken(rows)
|
|
||||||
if errors.Is(err, ErrTokenNotFound) {
|
|
||||||
break
|
|
||||||
} else if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
tokens = append(tokens, token)
|
|
||||||
}
|
|
||||||
return tokens, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChangeTokenLabel updates a token's label
|
|
||||||
func (s *commonStore) ChangeTokenLabel(userID, token, label string) error {
|
|
||||||
if _, err := s.db.Exec(s.queries.updateTokenLabel, label, userID, token); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChangeTokenExpiry updates a token's expiry time
|
|
||||||
func (s *commonStore) ChangeTokenExpiry(userID, token string, expires time.Time) error {
|
|
||||||
if _, err := s.db.Exec(s.queries.updateTokenExpiry, expires.Unix(), userID, token); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateTokenLastAccess updates a token's last access time and origin
|
|
||||||
func (s *commonStore) UpdateTokenLastAccess(token string, lastAccess time.Time, lastOrigin netip.Addr) error {
|
|
||||||
if _, err := s.db.Exec(s.queries.updateTokenLastAccess, lastAccess.Unix(), lastOrigin.String(), token); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveToken deletes the token
|
|
||||||
func (s *commonStore) RemoveToken(userID, token string) error {
|
|
||||||
if token == "" {
|
|
||||||
return errNoTokenProvided
|
|
||||||
}
|
|
||||||
if _, err := s.db.Exec(s.queries.deleteToken, userID, token); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveExpiredTokens deletes all expired tokens from the database
|
|
||||||
func (s *commonStore) RemoveExpiredTokens() error {
|
|
||||||
if _, err := s.db.Exec(s.queries.deleteExpiredTokens, time.Now().Unix()); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// TokenCount returns the number of tokens for a user
|
|
||||||
func (s *commonStore) TokenCount(userID string) (int, error) {
|
|
||||||
rows, err := s.db.Query(s.queries.selectTokenCount, userID)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
if !rows.Next() {
|
|
||||||
return 0, errNoRows
|
|
||||||
}
|
|
||||||
var count int
|
|
||||||
if err := rows.Scan(&count); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return count, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveExcessTokens deletes excess tokens beyond the specified maximum
|
|
||||||
func (s *commonStore) RemoveExcessTokens(userID string, maxCount int) error {
|
|
||||||
if _, err := s.db.Exec(s.queries.deleteExcessTokens, userID, userID, maxCount); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
func (s *commonStore) readToken(rows *sql.Rows) (*Token, error) {
|
|
||||||
var token, label, lastOrigin string
|
|
||||||
var lastAccess, expires int64
|
|
||||||
var provisioned bool
|
|
||||||
if !rows.Next() {
|
|
||||||
return nil, ErrTokenNotFound
|
|
||||||
}
|
|
||||||
if err := rows.Scan(&token, &label, &lastAccess, &lastOrigin, &expires, &provisioned); err != nil {
|
|
||||||
return nil, err
|
|
||||||
} else if err := rows.Err(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
lastOriginIP, err := netip.ParseAddr(lastOrigin)
|
|
||||||
if err != nil {
|
|
||||||
lastOriginIP = netip.IPv4Unspecified()
|
|
||||||
}
|
|
||||||
return &Token{
|
|
||||||
Value: token,
|
|
||||||
Label: label,
|
|
||||||
LastAccess: time.Unix(lastAccess, 0),
|
|
||||||
LastOrigin: lastOriginIP,
|
|
||||||
Expires: time.Unix(expires, 0),
|
|
||||||
Provisioned: provisioned,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AuthorizeTopicAccess returns the read/write permissions for the given username and topic.
|
|
||||||
// The found return value indicates whether an ACL entry was found at all.
|
|
||||||
func (s *commonStore) AuthorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error) {
|
|
||||||
rows, err := s.db.Query(s.queries.selectTopicPerms, Everyone, usernameOrEveryone, topic)
|
|
||||||
if err != nil {
|
|
||||||
return false, false, false, err
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
if !rows.Next() {
|
|
||||||
return false, false, false, nil
|
|
||||||
}
|
|
||||||
if err := rows.Scan(&read, &write); err != nil {
|
|
||||||
return false, false, false, err
|
|
||||||
} else if err := rows.Err(); err != nil {
|
|
||||||
return false, false, false, err
|
|
||||||
}
|
|
||||||
return read, write, true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AllGrants returns all user-specific access control entries, mapped to their respective user IDs
|
|
||||||
func (s *commonStore) AllGrants() (map[string][]Grant, error) {
|
|
||||||
rows, err := s.db.Query(s.queries.selectUserAllAccess)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
grants := make(map[string][]Grant, 0)
|
|
||||||
for rows.Next() {
|
|
||||||
var userID, topic string
|
|
||||||
var read, write, provisioned bool
|
|
||||||
if err := rows.Scan(&userID, &topic, &read, &write, &provisioned); err != nil {
|
|
||||||
return nil, err
|
|
||||||
} else if err := rows.Err(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if _, ok := grants[userID]; !ok {
|
|
||||||
grants[userID] = make([]Grant, 0)
|
|
||||||
}
|
|
||||||
grants[userID] = append(grants[userID], Grant{
|
|
||||||
TopicPattern: fromSQLWildcard(topic),
|
|
||||||
Permission: NewPermission(read, write),
|
|
||||||
Provisioned: provisioned,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return grants, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Grants returns all user-specific access control entries
|
|
||||||
func (s *commonStore) Grants(username string) ([]Grant, error) {
|
|
||||||
rows, err := s.db.Query(s.queries.selectUserAccess, username)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
grants := make([]Grant, 0)
|
|
||||||
for rows.Next() {
|
|
||||||
var topic string
|
|
||||||
var read, write, provisioned bool
|
|
||||||
if err := rows.Scan(&topic, &read, &write, &provisioned); err != nil {
|
|
||||||
return nil, err
|
|
||||||
} else if err := rows.Err(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
grants = append(grants, Grant{
|
|
||||||
TopicPattern: fromSQLWildcard(topic),
|
|
||||||
Permission: NewPermission(read, write),
|
|
||||||
Provisioned: provisioned,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return grants, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AllowAccess adds or updates an entry in the access control list
|
|
||||||
func (s *commonStore) AllowAccess(username, topicPattern string, read, write bool, ownerUsername string, provisioned bool) error {
|
|
||||||
if !AllowedUsername(username) && username != Everyone {
|
|
||||||
return ErrInvalidArgument
|
|
||||||
} else if !AllowedTopicPattern(topicPattern) {
|
|
||||||
return ErrInvalidArgument
|
|
||||||
}
|
|
||||||
if _, err := s.db.Exec(s.queries.upsertUserAccess, username, toSQLWildcard(topicPattern), read, write, ownerUsername, ownerUsername, provisioned); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResetAccess removes an access control list entry. Empty arguments act as
// wildcards: an empty username and topic pattern removes ALL entries, an empty
// topic pattern removes all entries for the given user.
func (s *commonStore) ResetAccess(username, topicPattern string) error {
	// Empty values are explicitly allowed here (see wildcard semantics above)
	if !AllowedUsername(username) && username != Everyone && username != "" {
		return ErrInvalidArgument
	} else if !AllowedTopicPattern(topicPattern) && topicPattern != "" {
		return ErrInvalidArgument
	}
	if username == "" && topicPattern == "" {
		// Remove every access entry
		_, err := s.db.Exec(s.queries.deleteAllAccess)
		return err
	} else if topicPattern == "" {
		// Remove all entries where the user is either the subject or the owner
		// (username is passed twice for the two subselects in the query)
		_, err := s.db.Exec(s.queries.deleteUserAccess, username, username)
		return err
	}
	// Remove the entry for the given user and topic pattern
	_, err := s.db.Exec(s.queries.deleteTopicAccess, username, username, toSQLWildcard(topicPattern))
	return err
}
|
|
||||||
|
|
||||||
// ResetAllProvisionedAccess removes all provisioned access control entries
|
|
||||||
func (s *commonStore) ResetAllProvisionedAccess() error {
|
|
||||||
if _, err := s.db.Exec(s.queries.deleteUserAccessProvisioned); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reservations returns all user-owned topics, and the associated everyone-access
func (s *commonStore) Reservations(username string) ([]Reservation, error) {
	rows, err := s.db.Query(s.queries.selectUserReservations, Everyone, username)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	reservations := make([]Reservation, 0)
	for rows.Next() {
		var topic string
		var ownerRead, ownerWrite bool
		// NULL when no "everyone" entry exists for the topic (LEFT JOIN in the query)
		var everyoneRead, everyoneWrite sql.NullBool
		if err := rows.Scan(&topic, &ownerRead, &ownerWrite, &everyoneRead, &everyoneWrite); err != nil {
			return nil, err
		} else if err := rows.Err(); err != nil {
			return nil, err
		}
		reservations = append(reservations, Reservation{
			// Topics are stored with escaped underscores (SQL LIKE wildcard); undo that here
			Topic: unescapeUnderscore(topic),
			Owner: NewPermission(ownerRead, ownerWrite),
			// NullBool.Bool is false for NULL, i.e. no everyone-access
			Everyone: NewPermission(everyoneRead.Bool, everyoneWrite.Bool),
		})
	}
	return reservations, nil
}
|
|
||||||
|
|
||||||
// HasReservation returns true if the given topic access is owned by the user
func (s *commonStore) HasReservation(username, topic string) (bool, error) {
	rows, err := s.db.Query(s.queries.selectUserHasReservation, username, escapeUnderscore(topic))
	if err != nil {
		return false, err
	}
	defer rows.Close()
	// A COUNT(*) query always yields exactly one row; a missing row indicates
	// a lower-level query problem
	if !rows.Next() {
		return false, errNoRows
	}
	var count int64
	if err := rows.Scan(&count); err != nil {
		return false, err
	}
	return count > 0, nil
}
|
|
||||||
|
|
||||||
// ReservationsCount returns the number of reservations owned by this user
func (s *commonStore) ReservationsCount(username string) (int64, error) {
	rows, err := s.db.Query(s.queries.selectUserReservationsCount, username)
	if err != nil {
		return 0, err
	}
	defer rows.Close()
	// A COUNT(*) query always yields exactly one row; a missing row indicates
	// a lower-level query problem
	if !rows.Next() {
		return 0, errNoRows
	}
	var count int64
	if err := rows.Scan(&count); err != nil {
		return 0, err
	}
	return count, nil
}
|
|
||||||
|
|
||||||
// ReservationOwner returns user ID of the user that owns this topic, or an empty string if it's not owned by anyone
func (s *commonStore) ReservationOwner(topic string) (string, error) {
	rows, err := s.db.Query(s.queries.selectUserReservationsOwner, escapeUnderscore(topic))
	if err != nil {
		return "", err
	}
	defer rows.Close()
	if !rows.Next() {
		return "", nil // topic is not reserved by anyone; not an error
	}
	var ownerUserID string
	if err := rows.Scan(&ownerUserID); err != nil {
		return "", err
	}
	return ownerUserID, nil
}
|
|
||||||
|
|
||||||
// OtherAccessCount returns the number of access entries for the given topic that are not owned by the user
func (s *commonStore) OtherAccessCount(username, topic string) (int, error) {
	// The topic is passed twice: once for the literal match, once for the
	// LIKE-pattern match against wildcard entries (see selectOtherAccessCount)
	rows, err := s.db.Query(s.queries.selectOtherAccessCount, escapeUnderscore(topic), escapeUnderscore(topic), username)
	if err != nil {
		return 0, err
	}
	defer rows.Close()
	// A COUNT(*) query always yields exactly one row; a missing row indicates
	// a lower-level query problem
	if !rows.Next() {
		return 0, errNoRows
	}
	var count int
	if err := rows.Scan(&count); err != nil {
		return 0, err
	}
	return count, nil
}
|
|
||||||
|
|
||||||
// AddTier creates a new tier in the database. If tier.ID is empty, a random
// ID with the tier prefix is generated and written back to the struct.
func (s *commonStore) AddTier(tier *Tier) error {
	if tier.ID == "" {
		tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength)
	}
	// Durations are persisted as whole seconds; Stripe price IDs go through
	// nullString (presumably NULL for empty values — see helper)
	if _, err := s.db.Exec(s.queries.insertTier, tier.ID, tier.Code, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.CallLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID)); err != nil {
		return err
	}
	return nil
}
|
|
||||||
|
|
||||||
// UpdateTier updates a tier's properties in the database
|
|
||||||
func (s *commonStore) UpdateTier(tier *Tier) error {
|
|
||||||
if _, err := s.db.Exec(s.queries.updateTier, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.CallLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID), tier.Code); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveTier deletes the tier with the given code
|
|
||||||
func (s *commonStore) RemoveTier(code string) error {
|
|
||||||
if !AllowedTier(code) {
|
|
||||||
return ErrInvalidArgument
|
|
||||||
}
|
|
||||||
if _, err := s.db.Exec(s.queries.deleteTier, code); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Tiers returns a list of all Tier structs
func (s *commonStore) Tiers() ([]*Tier, error) {
	rows, err := s.db.Query(s.queries.selectTiers)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	tiers := make([]*Tier, 0)
	// readTier consumes one row per call and returns ErrTierNotFound once the
	// result set is exhausted, which terminates this loop
	for {
		tier, err := s.readTier(rows)
		if errors.Is(err, ErrTierNotFound) {
			break
		} else if err != nil {
			return nil, err
		}
		tiers = append(tiers, tier)
	}
	return tiers, nil
}
|
|
||||||
|
|
||||||
// Tier returns a Tier based on the code, or ErrTierNotFound if it does not exist
// (the not-found case is produced by readTier on an empty result set).
func (s *commonStore) Tier(code string) (*Tier, error) {
	rows, err := s.db.Query(s.queries.selectTierByCode, code)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	return s.readTier(rows)
}
|
|
||||||
|
|
||||||
// TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist.
// The price ID is matched against both the monthly and the yearly price column,
// which is why it is passed twice.
func (s *commonStore) TierByStripePrice(priceID string) (*Tier, error) {
	rows, err := s.db.Query(s.queries.selectTierByPriceID, priceID, priceID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	return s.readTier(rows)
}
|
|
||||||
// readTier reads a single tier row from the given result set, or returns
// ErrTierNotFound if the result set is empty. Nullable columns are scanned via
// sql.Null* wrappers; a NULL collapses to the zero value in the returned Tier.
func (s *commonStore) readTier(rows *sql.Rows) (*Tier, error) {
	var id, code, name string
	var stripeMonthlyPriceID, stripeYearlyPriceID sql.NullString
	var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit sql.NullInt64
	if !rows.Next() {
		return nil, ErrTierNotFound
	}
	if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
		return nil, err
	} else if err := rows.Err(); err != nil {
		return nil, err
	}
	// Durations are stored as seconds and converted back to time.Duration here
	return &Tier{
		ID:                       id,
		Code:                     code,
		Name:                     name,
		MessageLimit:             messagesLimit.Int64,
		MessageExpiryDuration:    time.Duration(messagesExpiryDuration.Int64) * time.Second,
		EmailLimit:               emailsLimit.Int64,
		CallLimit:                callsLimit.Int64,
		ReservationLimit:         reservationsLimit.Int64,
		AttachmentFileSizeLimit:  attachmentFileSizeLimit.Int64,
		AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
		AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
		AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
		StripeMonthlyPriceID:     stripeMonthlyPriceID.String,
		StripeYearlyPriceID:      stripeYearlyPriceID.String,
	}, nil
}
|
|
||||||
|
|
||||||
// PhoneNumbers returns all phone numbers for the user with the given user ID.
// An empty (non-nil) slice is returned if the user has none.
func (s *commonStore) PhoneNumbers(userID string) ([]string, error) {
	rows, err := s.db.Query(s.queries.selectPhoneNumbers, userID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	phoneNumbers := make([]string, 0)
	// readPhoneNumber consumes one row per call and returns ErrPhoneNumberNotFound
	// once the result set is exhausted, which terminates this loop
	for {
		phoneNumber, err := s.readPhoneNumber(rows)
		if errors.Is(err, ErrPhoneNumberNotFound) {
			break
		} else if err != nil {
			return nil, err
		}
		phoneNumbers = append(phoneNumbers, phoneNumber)
	}
	return phoneNumbers, nil
}
|
|
||||||
|
|
||||||
// AddPhoneNumber adds a phone number to the user with the given user ID
|
|
||||||
func (s *commonStore) AddPhoneNumber(userID, phoneNumber string) error {
|
|
||||||
if _, err := s.db.Exec(s.queries.insertPhoneNumber, userID, phoneNumber); err != nil {
|
|
||||||
if isUniqueConstraintError(err) {
|
|
||||||
return ErrPhoneNumberExists
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemovePhoneNumber deletes a phone number from the user with the given user ID
|
|
||||||
func (s *commonStore) RemovePhoneNumber(userID, phoneNumber string) error {
|
|
||||||
_, err := s.db.Exec(s.queries.deletePhoneNumber, userID, phoneNumber)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// readPhoneNumber reads a single phone number row from the given result set,
// or returns ErrPhoneNumberNotFound if the result set is empty.
func (s *commonStore) readPhoneNumber(rows *sql.Rows) (string, error) {
	var phoneNumber string
	if !rows.Next() {
		return "", ErrPhoneNumberNotFound
	}
	if err := rows.Scan(&phoneNumber); err != nil {
		return "", err
	} else if err := rows.Err(); err != nil {
		return "", err
	}
	return phoneNumber, nil
}
|
|
||||||
|
|
||||||
// ChangeBilling updates a user's billing fields (Stripe customer/subscription
// state) in a single statement, identified by the username.
func (s *commonStore) ChangeBilling(username string, billing *Billing) error {
	// Optional values are wrapped via nullString/nullInt64 before being written
	// (presumably NULL for empty/zero values — see helpers)
	if _, err := s.db.Exec(s.queries.updateBilling, nullString(billing.StripeCustomerID), nullString(billing.StripeSubscriptionID), nullString(string(billing.StripeSubscriptionStatus)), nullString(string(billing.StripeSubscriptionInterval)), nullInt64(billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(billing.StripeSubscriptionCancelAt.Unix()), username); err != nil {
		return err
	}
	return nil
}
|
|
||||||
|
|
||||||
// UserIDByUsername returns the user ID for the given username, or
// ErrUserNotFound if no such user exists.
func (s *commonStore) UserIDByUsername(username string) (string, error) {
	rows, err := s.db.Query(s.queries.selectUserIDFromUsername, username)
	if err != nil {
		return "", err
	}
	defer rows.Close()
	if !rows.Next() {
		return "", ErrUserNotFound
	}
	var userID string
	if err := rows.Scan(&userID); err != nil {
		return "", err
	}
	return userID, nil
}
|
|
||||||
|
|
||||||
// Close closes the underlying database
|
|
||||||
func (s *commonStore) Close() error {
|
|
||||||
return s.db.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// isUniqueConstraintError checks if the error is a unique constraint violation for both SQLite and PostgreSQL
|
|
||||||
func isUniqueConstraintError(err error) bool {
|
|
||||||
errStr := err.Error()
|
|
||||||
return strings.Contains(errStr, "UNIQUE constraint failed") || strings.Contains(errStr, "23505")
|
|
||||||
}
|
|
||||||
@@ -1,292 +0,0 @@
|
|||||||
package user
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
|
|
||||||
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
|
|
||||||
)
|
|
||||||
|
|
||||||
// PostgreSQL queries
//
// All queries use PostgreSQL-style $N positional placeholders. The reserved
// word "user" is quoted wherever the user table is referenced.
const (
	// User queries
	postgresSelectUserByID = `
		SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
		FROM "user" u
		LEFT JOIN tier t on t.id = u.tier_id
		WHERE u.id = $1
	`
	postgresSelectUserByName = `
		SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
		FROM "user" u
		LEFT JOIN tier t on t.id = u.tier_id
		WHERE user_name = $1
	`
	// Tokens with expires = 0 are treated as non-expiring (see the expires condition)
	postgresSelectUserByToken = `
		SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
		FROM "user" u
		JOIN user_token tk on u.id = tk.user_id
		LEFT JOIN tier t on t.id = u.tier_id
		WHERE tk.token = $1 AND (tk.expires = 0 OR tk.expires >= $2)
	`
	postgresSelectUserByStripeID = `
		SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
		FROM "user" u
		LEFT JOIN tier t on t.id = u.tier_id
		WHERE u.stripe_customer_id = $1
	`
	// Admins are listed first, the anonymous user last, everyone else alphabetically
	postgresSelectUsernames = `
		SELECT user_name
		FROM "user"
		ORDER BY
			CASE role
				WHEN 'admin' THEN 1
				WHEN 'anonymous' THEN 3
				ELSE 2
			END, user_name
	`
	postgresSelectUserCount          = `SELECT COUNT(*) FROM "user"`
	postgresSelectUserIDFromUsername = `SELECT id FROM "user" WHERE user_name = $1`
	postgresInsertUser               = `INSERT INTO "user" (id, user_name, pass, role, sync_topic, provisioned, created) VALUES ($1, $2, $3, $4, $5, $6, $7)`
	postgresUpdateUserPass           = `UPDATE "user" SET pass = $1 WHERE user_name = $2`
	postgresUpdateUserRole           = `UPDATE "user" SET role = $1 WHERE user_name = $2`
	postgresUpdateUserProvisioned    = `UPDATE "user" SET provisioned = $1 WHERE user_name = $2`
	postgresUpdateUserPrefs          = `UPDATE "user" SET prefs = $1 WHERE id = $2`
	postgresUpdateUserStats          = `UPDATE "user" SET stats_messages = $1, stats_emails = $2, stats_calls = $3 WHERE id = $4`
	postgresUpdateUserStatsResetAll  = `UPDATE "user" SET stats_messages = 0, stats_emails = 0, stats_calls = 0`
	postgresUpdateUserTier           = `UPDATE "user" SET tier_id = (SELECT id FROM tier WHERE code = $1) WHERE user_name = $2`
	postgresUpdateUserDeleted        = `UPDATE "user" SET deleted = $1 WHERE id = $2`
	postgresDeleteUser               = `DELETE FROM "user" WHERE user_name = $1`
	postgresDeleteUserTier           = `UPDATE "user" SET tier_id = null WHERE user_name = $1`
	postgresDeleteUsersMarked        = `DELETE FROM "user" WHERE deleted < $1`

	// Access queries
	//
	// Topic patterns are stored as SQL LIKE patterns; the ESCAPE '\' clause
	// allows literal underscores in topic names to be matched.
	postgresSelectTopicPerms = `
		SELECT read, write
		FROM user_access a
		JOIN "user" u ON u.id = a.user_id
		WHERE (u.user_name = $1 OR u.user_name = $2) AND $3 LIKE a.topic ESCAPE '\'
		ORDER BY u.user_name DESC, LENGTH(a.topic) DESC, CASE WHEN a.write THEN 1 ELSE 0 END DESC
	`
	postgresSelectUserAllAccess = `
		SELECT user_id, topic, read, write, provisioned
		FROM user_access
		ORDER BY LENGTH(topic) DESC, CASE WHEN write THEN 1 ELSE 0 END DESC, CASE WHEN read THEN 1 ELSE 0 END DESC, topic
	`
	postgresSelectUserAccess = `
		SELECT topic, read, write, provisioned
		FROM user_access
		WHERE user_id = (SELECT id FROM "user" WHERE user_name = $1)
		ORDER BY LENGTH(topic) DESC, CASE WHEN write THEN 1 ELSE 0 END DESC, CASE WHEN read THEN 1 ELSE 0 END DESC, topic
	`
	// A reservation is an access entry whose user is also its owner; the LEFT JOIN
	// pulls in the matching "everyone" entry (NULL columns if there is none)
	postgresSelectUserReservations = `
		SELECT a_user.topic, a_user.read, a_user.write, a_everyone.read AS everyone_read, a_everyone.write AS everyone_write
		FROM user_access a_user
		LEFT JOIN user_access a_everyone ON a_user.topic = a_everyone.topic AND a_everyone.user_id = (SELECT id FROM "user" WHERE user_name = $1)
		WHERE a_user.user_id = a_user.owner_user_id
			AND a_user.owner_user_id = (SELECT id FROM "user" WHERE user_name = $2)
		ORDER BY a_user.topic
	`
	postgresSelectUserReservationsCount = `
		SELECT COUNT(*)
		FROM user_access
		WHERE user_id = owner_user_id
			AND owner_user_id = (SELECT id FROM "user" WHERE user_name = $1)
	`
	postgresSelectUserReservationsOwner = `
		SELECT owner_user_id
		FROM user_access
		WHERE topic = $1
			AND user_id = owner_user_id
	`
	postgresSelectUserHasReservation = `
		SELECT COUNT(*)
		FROM user_access
		WHERE user_id = owner_user_id
			AND owner_user_id = (SELECT id FROM "user" WHERE user_name = $1)
			AND topic = $2
	`
	// Matches the topic both literally ($1) and against wildcard patterns ($2)
	postgresSelectOtherAccessCount = `
		SELECT COUNT(*)
		FROM user_access
		WHERE (topic = $1 OR $2 LIKE topic ESCAPE '\')
			AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM "user" WHERE user_name = $3))
	`
	// An empty owner username ($5) maps to a NULL owner_user_id
	postgresUpsertUserAccess = `
		INSERT INTO user_access (user_id, topic, read, write, owner_user_id, provisioned)
		VALUES (
			(SELECT id FROM "user" WHERE user_name = $1),
			$2,
			$3,
			$4,
			CASE WHEN $5 = '' THEN NULL ELSE (SELECT id FROM "user" WHERE user_name = $6) END,
			$7
		)
		ON CONFLICT (user_id, topic)
		DO UPDATE SET read=EXCLUDED.read, write=EXCLUDED.write, owner_user_id=EXCLUDED.owner_user_id, provisioned=EXCLUDED.provisioned
	`
	// Removes entries where the user is either the subject ($1) or the owner ($2)
	postgresDeleteUserAccess = `
		DELETE FROM user_access
		WHERE user_id = (SELECT id FROM "user" WHERE user_name = $1)
			OR owner_user_id = (SELECT id FROM "user" WHERE user_name = $2)
	`
	postgresDeleteUserAccessProvisioned = `DELETE FROM user_access WHERE provisioned = true`
	postgresDeleteTopicAccess           = `
		DELETE FROM user_access
		WHERE (user_id = (SELECT id FROM "user" WHERE user_name = $1) OR owner_user_id = (SELECT id FROM "user" WHERE user_name = $2))
			AND topic = $3
	`
	postgresDeleteAllAccess = `DELETE FROM user_access`

	// Token queries
	postgresSelectToken                = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = $1 AND token = $2`
	postgresSelectTokens               = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = $1`
	postgresSelectTokenCount           = `SELECT COUNT(*) FROM user_token WHERE user_id = $1`
	postgresSelectAllProvisionedTokens = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE provisioned = true`
	postgresUpsertToken                = `
		INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires, provisioned)
		VALUES ($1, $2, $3, $4, $5, $6, $7)
		ON CONFLICT (user_id, token)
		DO UPDATE SET label = EXCLUDED.label, expires = EXCLUDED.expires, provisioned = EXCLUDED.provisioned
	`
	postgresUpdateTokenLabel       = `UPDATE user_token SET label = $1 WHERE user_id = $2 AND token = $3`
	postgresUpdateTokenExpiry      = `UPDATE user_token SET expires = $1 WHERE user_id = $2 AND token = $3`
	postgresUpdateTokenLastAccess  = `UPDATE user_token SET last_access = $1, last_origin = $2 WHERE token = $3`
	postgresDeleteToken            = `DELETE FROM user_token WHERE user_id = $1 AND token = $2`
	postgresDeleteProvisionedToken = `DELETE FROM user_token WHERE token = $1`
	postgresDeleteAllToken         = `DELETE FROM user_token WHERE user_id = $1`
	postgresDeleteExpiredTokens    = `DELETE FROM user_token WHERE expires > 0 AND expires < $1`
	// Keeps only the $3 most recently expiring tokens for a user ($1 and $2 are
	// the same user ID, required twice by the statement)
	postgresDeleteExcessTokens = `
		DELETE FROM user_token
		WHERE user_id = $1
			AND (user_id, token) NOT IN (
				SELECT user_id, token
				FROM user_token
				WHERE user_id = $2
				ORDER BY expires DESC
				LIMIT $3
			)
	`

	// Tier queries
	postgresInsertTier = `
		INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id)
		VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
	`
	postgresUpdateTier = `
		UPDATE tier
		SET name = $1, messages_limit = $2, messages_expiry_duration = $3, emails_limit = $4, calls_limit = $5, reservations_limit = $6, attachment_file_size_limit = $7, attachment_total_size_limit = $8, attachment_expiry_duration = $9, attachment_bandwidth_limit = $10, stripe_monthly_price_id = $11, stripe_yearly_price_id = $12
		WHERE code = $13
	`
	postgresSelectTiers = `
		SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
		FROM tier
	`
	postgresSelectTierByCode = `
		SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
		FROM tier
		WHERE code = $1
	`
	// The price ID is matched against both the monthly and the yearly price column
	postgresSelectTierByPriceID = `
		SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
		FROM tier
		WHERE (stripe_monthly_price_id = $1 OR stripe_yearly_price_id = $2)
	`
	postgresDeleteTier = `DELETE FROM tier WHERE code = $1`

	// Phone queries
	postgresSelectPhoneNumbers = `SELECT phone_number FROM user_phone WHERE user_id = $1`
	postgresInsertPhoneNumber  = `INSERT INTO user_phone (user_id, phone_number) VALUES ($1, $2)`
	postgresDeletePhoneNumber  = `DELETE FROM user_phone WHERE user_id = $1 AND phone_number = $2`

	// Billing queries
	postgresUpdateBilling = `
		UPDATE "user"
		SET stripe_customer_id = $1, stripe_subscription_id = $2, stripe_subscription_status = $3, stripe_subscription_interval = $4, stripe_subscription_paid_until = $5, stripe_subscription_cancel_at = $6
		WHERE user_name = $7
	`
)
|
|
||||||
|
|
||||||
// NewPostgresStore creates a new PostgreSQL-backed user store
|
|
||||||
func NewPostgresStore(dsn string) (Store, error) {
|
|
||||||
db, err := sql.Open("pgx", dsn)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := db.Ping(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := setupPostgres(db); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &commonStore{
|
|
||||||
db: db,
|
|
||||||
queries: storeQueries{
|
|
||||||
// User queries
|
|
||||||
selectUserByID: postgresSelectUserByID,
|
|
||||||
selectUserByName: postgresSelectUserByName,
|
|
||||||
selectUserByToken: postgresSelectUserByToken,
|
|
||||||
selectUserByStripeID: postgresSelectUserByStripeID,
|
|
||||||
selectUsernames: postgresSelectUsernames,
|
|
||||||
selectUserCount: postgresSelectUserCount,
|
|
||||||
selectUserIDFromUsername: postgresSelectUserIDFromUsername,
|
|
||||||
insertUser: postgresInsertUser,
|
|
||||||
updateUserPass: postgresUpdateUserPass,
|
|
||||||
updateUserRole: postgresUpdateUserRole,
|
|
||||||
updateUserProvisioned: postgresUpdateUserProvisioned,
|
|
||||||
updateUserPrefs: postgresUpdateUserPrefs,
|
|
||||||
updateUserStats: postgresUpdateUserStats,
|
|
||||||
updateUserStatsResetAll: postgresUpdateUserStatsResetAll,
|
|
||||||
updateUserTier: postgresUpdateUserTier,
|
|
||||||
updateUserDeleted: postgresUpdateUserDeleted,
|
|
||||||
deleteUser: postgresDeleteUser,
|
|
||||||
deleteUserTier: postgresDeleteUserTier,
|
|
||||||
deleteUsersMarked: postgresDeleteUsersMarked,
|
|
||||||
|
|
||||||
// Access queries
|
|
||||||
selectTopicPerms: postgresSelectTopicPerms,
|
|
||||||
selectUserAllAccess: postgresSelectUserAllAccess,
|
|
||||||
selectUserAccess: postgresSelectUserAccess,
|
|
||||||
selectUserReservations: postgresSelectUserReservations,
|
|
||||||
selectUserReservationsCount: postgresSelectUserReservationsCount,
|
|
||||||
selectUserReservationsOwner: postgresSelectUserReservationsOwner,
|
|
||||||
selectUserHasReservation: postgresSelectUserHasReservation,
|
|
||||||
selectOtherAccessCount: postgresSelectOtherAccessCount,
|
|
||||||
upsertUserAccess: postgresUpsertUserAccess,
|
|
||||||
deleteUserAccess: postgresDeleteUserAccess,
|
|
||||||
deleteUserAccessProvisioned: postgresDeleteUserAccessProvisioned,
|
|
||||||
deleteTopicAccess: postgresDeleteTopicAccess,
|
|
||||||
deleteAllAccess: postgresDeleteAllAccess,
|
|
||||||
|
|
||||||
// Token queries
|
|
||||||
selectToken: postgresSelectToken,
|
|
||||||
selectTokens: postgresSelectTokens,
|
|
||||||
selectTokenCount: postgresSelectTokenCount,
|
|
||||||
selectAllProvisionedTokens: postgresSelectAllProvisionedTokens,
|
|
||||||
upsertToken: postgresUpsertToken,
|
|
||||||
updateTokenLabel: postgresUpdateTokenLabel,
|
|
||||||
updateTokenExpiry: postgresUpdateTokenExpiry,
|
|
||||||
updateTokenLastAccess: postgresUpdateTokenLastAccess,
|
|
||||||
deleteToken: postgresDeleteToken,
|
|
||||||
deleteProvisionedToken: postgresDeleteProvisionedToken,
|
|
||||||
deleteAllToken: postgresDeleteAllToken,
|
|
||||||
deleteExpiredTokens: postgresDeleteExpiredTokens,
|
|
||||||
deleteExcessTokens: postgresDeleteExcessTokens,
|
|
||||||
|
|
||||||
// Tier queries
|
|
||||||
insertTier: postgresInsertTier,
|
|
||||||
selectTiers: postgresSelectTiers,
|
|
||||||
selectTierByCode: postgresSelectTierByCode,
|
|
||||||
selectTierByPriceID: postgresSelectTierByPriceID,
|
|
||||||
updateTier: postgresUpdateTier,
|
|
||||||
deleteTier: postgresDeleteTier,
|
|
||||||
|
|
||||||
// Phone queries
|
|
||||||
selectPhoneNumbers: postgresSelectPhoneNumbers,
|
|
||||||
insertPhoneNumber: postgresInsertPhoneNumber,
|
|
||||||
deletePhoneNumber: postgresDeletePhoneNumber,
|
|
||||||
|
|
||||||
// Billing queries
|
|
||||||
updateBilling: postgresUpdateBilling,
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
@@ -1,208 +0,0 @@
|
|||||||
package user_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
|
||||||
"net/url"
|
|
||||||
"os"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"heckel.io/ntfy/v2/user"
|
|
||||||
"heckel.io/ntfy/v2/util"
|
|
||||||
)
|
|
||||||
|
|
||||||
func newTestPostgresStore(t *testing.T) user.Store {
|
|
||||||
dsn := os.Getenv("NTFY_TEST_DATABASE_URL")
|
|
||||||
if dsn == "" {
|
|
||||||
t.Skip("NTFY_TEST_DATABASE_URL not set, skipping PostgreSQL tests")
|
|
||||||
}
|
|
||||||
// Create a unique schema for this test
|
|
||||||
schema := fmt.Sprintf("test_%s", util.RandomString(10))
|
|
||||||
setupDB, err := sql.Open("pgx", dsn)
|
|
||||||
require.Nil(t, err)
|
|
||||||
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Nil(t, setupDB.Close())
|
|
||||||
// Open store with search_path set to the new schema
|
|
||||||
u, err := url.Parse(dsn)
|
|
||||||
require.Nil(t, err)
|
|
||||||
q := u.Query()
|
|
||||||
q.Set("search_path", schema)
|
|
||||||
u.RawQuery = q.Encode()
|
|
||||||
store, err := user.NewPostgresStore(u.String())
|
|
||||||
require.Nil(t, err)
|
|
||||||
t.Cleanup(func() {
|
|
||||||
store.Close()
|
|
||||||
cleanDB, err := sql.Open("pgx", dsn)
|
|
||||||
if err == nil {
|
|
||||||
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
|
|
||||||
cleanDB.Close()
|
|
||||||
}
|
|
||||||
})
|
|
||||||
return store
|
|
||||||
}
|
|
||||||
|
|
||||||
// The TestPostgresStore* functions below each run one of the shared,
// backend-agnostic testStore* scenarios (defined in this package) against a
// PostgreSQL-backed user.Store. They are skipped unless NTFY_TEST_DATABASE_URL
// is set (see newTestPostgresStore).

func TestPostgresStoreAddUser(t *testing.T) {
	testStoreAddUser(t, newTestPostgresStore(t))
}

func TestPostgresStoreAddUserAlreadyExists(t *testing.T) {
	testStoreAddUserAlreadyExists(t, newTestPostgresStore(t))
}

func TestPostgresStoreRemoveUser(t *testing.T) {
	testStoreRemoveUser(t, newTestPostgresStore(t))
}

func TestPostgresStoreUserByID(t *testing.T) {
	testStoreUserByID(t, newTestPostgresStore(t))
}

func TestPostgresStoreUserByToken(t *testing.T) {
	testStoreUserByToken(t, newTestPostgresStore(t))
}

func TestPostgresStoreUserByStripeCustomer(t *testing.T) {
	testStoreUserByStripeCustomer(t, newTestPostgresStore(t))
}

func TestPostgresStoreUsers(t *testing.T) {
	testStoreUsers(t, newTestPostgresStore(t))
}

func TestPostgresStoreUsersCount(t *testing.T) {
	testStoreUsersCount(t, newTestPostgresStore(t))
}

func TestPostgresStoreChangePassword(t *testing.T) {
	testStoreChangePassword(t, newTestPostgresStore(t))
}

func TestPostgresStoreChangeRole(t *testing.T) {
	testStoreChangeRole(t, newTestPostgresStore(t))
}

func TestPostgresStoreTokens(t *testing.T) {
	testStoreTokens(t, newTestPostgresStore(t))
}

func TestPostgresStoreTokenChangeLabel(t *testing.T) {
	testStoreTokenChangeLabel(t, newTestPostgresStore(t))
}

func TestPostgresStoreTokenRemove(t *testing.T) {
	testStoreTokenRemove(t, newTestPostgresStore(t))
}

func TestPostgresStoreTokenRemoveExpired(t *testing.T) {
	testStoreTokenRemoveExpired(t, newTestPostgresStore(t))
}

func TestPostgresStoreTokenRemoveExcess(t *testing.T) {
	testStoreTokenRemoveExcess(t, newTestPostgresStore(t))
}

func TestPostgresStoreTokenUpdateLastAccess(t *testing.T) {
	testStoreTokenUpdateLastAccess(t, newTestPostgresStore(t))
}

func TestPostgresStoreAllowAccess(t *testing.T) {
	testStoreAllowAccess(t, newTestPostgresStore(t))
}

func TestPostgresStoreAllowAccessReadOnly(t *testing.T) {
	testStoreAllowAccessReadOnly(t, newTestPostgresStore(t))
}

func TestPostgresStoreResetAccess(t *testing.T) {
	testStoreResetAccess(t, newTestPostgresStore(t))
}

func TestPostgresStoreResetAccessAll(t *testing.T) {
	testStoreResetAccessAll(t, newTestPostgresStore(t))
}

func TestPostgresStoreAuthorizeTopicAccess(t *testing.T) {
	testStoreAuthorizeTopicAccess(t, newTestPostgresStore(t))
}

func TestPostgresStoreAuthorizeTopicAccessNotFound(t *testing.T) {
	testStoreAuthorizeTopicAccessNotFound(t, newTestPostgresStore(t))
}

func TestPostgresStoreAuthorizeTopicAccessDenyAll(t *testing.T) {
	testStoreAuthorizeTopicAccessDenyAll(t, newTestPostgresStore(t))
}

func TestPostgresStoreReservations(t *testing.T) {
	testStoreReservations(t, newTestPostgresStore(t))
}

func TestPostgresStoreReservationsCount(t *testing.T) {
	testStoreReservationsCount(t, newTestPostgresStore(t))
}

func TestPostgresStoreHasReservation(t *testing.T) {
	testStoreHasReservation(t, newTestPostgresStore(t))
}

func TestPostgresStoreReservationOwner(t *testing.T) {
	testStoreReservationOwner(t, newTestPostgresStore(t))
}

func TestPostgresStoreTiers(t *testing.T) {
	testStoreTiers(t, newTestPostgresStore(t))
}

func TestPostgresStoreTierUpdate(t *testing.T) {
	testStoreTierUpdate(t, newTestPostgresStore(t))
}

func TestPostgresStoreTierRemove(t *testing.T) {
	testStoreTierRemove(t, newTestPostgresStore(t))
}

func TestPostgresStoreTierByStripePrice(t *testing.T) {
	testStoreTierByStripePrice(t, newTestPostgresStore(t))
}

func TestPostgresStoreChangeTier(t *testing.T) {
	testStoreChangeTier(t, newTestPostgresStore(t))
}

func TestPostgresStorePhoneNumbers(t *testing.T) {
	testStorePhoneNumbers(t, newTestPostgresStore(t))
}

func TestPostgresStoreChangeSettings(t *testing.T) {
	testStoreChangeSettings(t, newTestPostgresStore(t))
}

func TestPostgresStoreChangeBilling(t *testing.T) {
	testStoreChangeBilling(t, newTestPostgresStore(t))
}

func TestPostgresStoreUpdateStats(t *testing.T) {
	testStoreUpdateStats(t, newTestPostgresStore(t))
}

func TestPostgresStoreResetStats(t *testing.T) {
	testStoreResetStats(t, newTestPostgresStore(t))
}

func TestPostgresStoreMarkUserRemoved(t *testing.T) {
	testStoreMarkUserRemoved(t, newTestPostgresStore(t))
}

func TestPostgresStoreRemoveDeletedUsers(t *testing.T) {
	testStoreRemoveDeletedUsers(t, newTestPostgresStore(t))
}

func TestPostgresStoreAllGrants(t *testing.T) {
	testStoreAllGrants(t, newTestPostgresStore(t))
}

func TestPostgresStoreOtherAccessCount(t *testing.T) {
	testStoreOtherAccessCount(t, newTestPostgresStore(t))
}
|
|
||||||
@@ -1,180 +0,0 @@
|
|||||||
package user_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"path/filepath"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"heckel.io/ntfy/v2/user"
|
|
||||||
)
|
|
||||||
|
|
||||||
func newTestSQLiteStore(t *testing.T) user.Store {
|
|
||||||
store, err := user.NewSQLiteStore(filepath.Join(t.TempDir(), "user.db"), "")
|
|
||||||
require.Nil(t, err)
|
|
||||||
t.Cleanup(func() { store.Close() })
|
|
||||||
return store
|
|
||||||
}
|
|
||||||
|
|
||||||
// The TestSQLiteStore* functions below each run one of the shared,
// backend-agnostic testStore* scenarios (defined in this package) against a
// SQLite-backed user.Store created in a temporary directory (see
// newTestSQLiteStore).

func TestSQLiteStoreAddUser(t *testing.T) {
	testStoreAddUser(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreAddUserAlreadyExists(t *testing.T) {
	testStoreAddUserAlreadyExists(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreRemoveUser(t *testing.T) {
	testStoreRemoveUser(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreUserByID(t *testing.T) {
	testStoreUserByID(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreUserByToken(t *testing.T) {
	testStoreUserByToken(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreUserByStripeCustomer(t *testing.T) {
	testStoreUserByStripeCustomer(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreUsers(t *testing.T) {
	testStoreUsers(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreUsersCount(t *testing.T) {
	testStoreUsersCount(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreChangePassword(t *testing.T) {
	testStoreChangePassword(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreChangeRole(t *testing.T) {
	testStoreChangeRole(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreTokens(t *testing.T) {
	testStoreTokens(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreTokenChangeLabel(t *testing.T) {
	testStoreTokenChangeLabel(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreTokenRemove(t *testing.T) {
	testStoreTokenRemove(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreTokenRemoveExpired(t *testing.T) {
	testStoreTokenRemoveExpired(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreTokenRemoveExcess(t *testing.T) {
	testStoreTokenRemoveExcess(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreTokenUpdateLastAccess(t *testing.T) {
	testStoreTokenUpdateLastAccess(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreAllowAccess(t *testing.T) {
	testStoreAllowAccess(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreAllowAccessReadOnly(t *testing.T) {
	testStoreAllowAccessReadOnly(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreResetAccess(t *testing.T) {
	testStoreResetAccess(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreResetAccessAll(t *testing.T) {
	testStoreResetAccessAll(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreAuthorizeTopicAccess(t *testing.T) {
	testStoreAuthorizeTopicAccess(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreAuthorizeTopicAccessNotFound(t *testing.T) {
	testStoreAuthorizeTopicAccessNotFound(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreAuthorizeTopicAccessDenyAll(t *testing.T) {
	testStoreAuthorizeTopicAccessDenyAll(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreReservations(t *testing.T) {
	testStoreReservations(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreReservationsCount(t *testing.T) {
	testStoreReservationsCount(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreHasReservation(t *testing.T) {
	testStoreHasReservation(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreReservationOwner(t *testing.T) {
	testStoreReservationOwner(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreTiers(t *testing.T) {
	testStoreTiers(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreTierUpdate(t *testing.T) {
	testStoreTierUpdate(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreTierRemove(t *testing.T) {
	testStoreTierRemove(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreTierByStripePrice(t *testing.T) {
	testStoreTierByStripePrice(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreChangeTier(t *testing.T) {
	testStoreChangeTier(t, newTestSQLiteStore(t))
}

func TestSQLiteStorePhoneNumbers(t *testing.T) {
	testStorePhoneNumbers(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreChangeSettings(t *testing.T) {
	testStoreChangeSettings(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreChangeBilling(t *testing.T) {
	testStoreChangeBilling(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreUpdateStats(t *testing.T) {
	testStoreUpdateStats(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreResetStats(t *testing.T) {
	testStoreResetStats(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreMarkUserRemoved(t *testing.T) {
	testStoreMarkUserRemoved(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreRemoveDeletedUsers(t *testing.T) {
	testStoreRemoveDeletedUsers(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreAllGrants(t *testing.T) {
	testStoreAllGrants(t, newTestSQLiteStore(t))
}

func TestSQLiteStoreOtherAccessCount(t *testing.T) {
	testStoreOtherAccessCount(t, newTestSQLiteStore(t))
}
|
|
||||||
@@ -1,619 +0,0 @@
|
|||||||
package user_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"heckel.io/ntfy/v2/user"
|
|
||||||
)
|
|
||||||
|
|
||||||
// testStoreAddUser verifies that a newly added user can be retrieved with the
// expected name and role, and that an ID and sync topic were generated.
func testStoreAddUser(t *testing.T, store user.Store) {
	require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
	u, err := store.User("phil")
	require.Nil(t, err)
	require.Equal(t, "phil", u.Name)
	require.Equal(t, user.RoleUser, u.Role)
	require.False(t, u.Provisioned)
	require.NotEmpty(t, u.ID)
	require.NotEmpty(t, u.SyncTopic)
}

// testStoreAddUserAlreadyExists verifies that adding the same username twice
// yields user.ErrUserExists.
func testStoreAddUserAlreadyExists(t *testing.T, store user.Store) {
	require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
	require.Equal(t, user.ErrUserExists, store.AddUser("phil", "philhash", user.RoleUser, false))
}

// testStoreRemoveUser verifies that a removed user can no longer be looked up.
func testStoreRemoveUser(t *testing.T, store user.Store) {
	require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
	u, err := store.User("phil")
	require.Nil(t, err)
	require.Equal(t, "phil", u.Name)

	require.Nil(t, store.RemoveUser("phil"))
	_, err = store.User("phil")
	require.Equal(t, user.ErrUserNotFound, err)
}

// testStoreUserByID verifies lookup of a user by their generated user ID.
func testStoreUserByID(t *testing.T, store user.Store) {
	require.Nil(t, store.AddUser("phil", "philhash", user.RoleAdmin, false))
	u, err := store.User("phil")
	require.Nil(t, err)

	u2, err := store.UserByID(u.ID)
	require.Nil(t, err)
	require.Equal(t, u.Name, u2.Name)
	require.Equal(t, u.ID, u2.ID)
}

// testStoreUserByToken verifies lookup of a user via a freshly created access token.
func testStoreUserByToken(t *testing.T, store user.Store) {
	require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
	u, err := store.User("phil")
	require.Nil(t, err)

	tk, err := store.CreateToken(u.ID, "tk_test123", "test token", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(24*time.Hour), false)
	require.Nil(t, err)
	require.Equal(t, "tk_test123", tk.Value)

	u2, err := store.UserByToken(tk.Value)
	require.Nil(t, err)
	require.Equal(t, "phil", u2.Name)
}
|
|
||||||
|
|
||||||
// testStoreUserByStripeCustomer verifies lookup of a user by their Stripe
// customer ID after billing info has been set.
func testStoreUserByStripeCustomer(t *testing.T, store user.Store) {
	require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
	require.Nil(t, store.ChangeBilling("phil", &user.Billing{
		StripeCustomerID:     "cus_test123",
		StripeSubscriptionID: "sub_test123",
	}))

	u, err := store.UserByStripeCustomer("cus_test123")
	require.Nil(t, err)
	require.Equal(t, "phil", u.Name)
	require.Equal(t, "cus_test123", u.Billing.StripeCustomerID)
}

// testStoreUsers verifies that Users returns all added users (plus the
// built-in "everyone" user).
func testStoreUsers(t *testing.T, store user.Store) {
	require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
	require.Nil(t, store.AddUser("ben", "benhash", user.RoleAdmin, false))

	users, err := store.Users()
	require.Nil(t, err)
	require.True(t, len(users) >= 3) // phil, ben, and the everyone user
}

// testStoreUsersCount verifies that the user count increases by one when a
// user is added.
func testStoreUsersCount(t *testing.T, store user.Store) {
	count, err := store.UsersCount()
	require.Nil(t, err)
	require.True(t, count >= 1) // At least the everyone user

	require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
	count2, err := store.UsersCount()
	require.Nil(t, err)
	require.Equal(t, count+1, count2)
}

// testStoreChangePassword verifies that ChangePassword replaces the stored
// password hash.
func testStoreChangePassword(t *testing.T, store user.Store) {
	require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
	u, err := store.User("phil")
	require.Nil(t, err)
	require.Equal(t, "philhash", u.Hash)

	require.Nil(t, store.ChangePassword("phil", "newhash"))
	u, err = store.User("phil")
	require.Nil(t, err)
	require.Equal(t, "newhash", u.Hash)
}

// testStoreChangeRole verifies that ChangeRole updates a user's role.
func testStoreChangeRole(t *testing.T, store user.Store) {
	require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
	u, err := store.User("phil")
	require.Nil(t, err)
	require.Equal(t, user.RoleUser, u.Role)

	require.Nil(t, store.ChangeRole("phil", user.RoleAdmin))
	u, err = store.User("phil")
	require.Nil(t, err)
	require.Equal(t, user.RoleAdmin, u.Role)
}
|
|
||||||
|
|
||||||
// testStoreTokens verifies token creation, single/bulk retrieval, and the
// per-user token count.
func testStoreTokens(t *testing.T, store user.Store) {
	require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
	u, err := store.User("phil")
	require.Nil(t, err)

	now := time.Now()
	expires := now.Add(24 * time.Hour)
	origin := netip.MustParseAddr("9.9.9.9")

	tk, err := store.CreateToken(u.ID, "tk_abc", "my token", now, origin, expires, false)
	require.Nil(t, err)
	require.Equal(t, "tk_abc", tk.Value)
	require.Equal(t, "my token", tk.Label)

	// Get single token
	tk2, err := store.Token(u.ID, "tk_abc")
	require.Nil(t, err)
	require.Equal(t, "tk_abc", tk2.Value)
	require.Equal(t, "my token", tk2.Label)

	// Get all tokens
	tokens, err := store.Tokens(u.ID)
	require.Nil(t, err)
	require.Len(t, tokens, 1)
	require.Equal(t, "tk_abc", tokens[0].Value)

	// Token count
	count, err := store.TokenCount(u.ID)
	require.Nil(t, err)
	require.Equal(t, 1, count)
}

// testStoreTokenChangeLabel verifies that a token's label can be updated.
func testStoreTokenChangeLabel(t *testing.T, store user.Store) {
	require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
	u, err := store.User("phil")
	require.Nil(t, err)

	_, err = store.CreateToken(u.ID, "tk_abc", "old label", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), false)
	require.Nil(t, err)

	require.Nil(t, store.ChangeTokenLabel(u.ID, "tk_abc", "new label"))
	tk, err := store.Token(u.ID, "tk_abc")
	require.Nil(t, err)
	require.Equal(t, "new label", tk.Label)
}

// testStoreTokenRemove verifies that a removed token can no longer be fetched.
func testStoreTokenRemove(t *testing.T, store user.Store) {
	require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
	u, err := store.User("phil")
	require.Nil(t, err)

	_, err = store.CreateToken(u.ID, "tk_abc", "label", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), false)
	require.Nil(t, err)

	require.Nil(t, store.RemoveToken(u.ID, "tk_abc"))
	_, err = store.Token(u.ID, "tk_abc")
	require.Equal(t, user.ErrTokenNotFound, err)
}

// testStoreTokenRemoveExpired verifies that RemoveExpiredTokens deletes only
// tokens whose expiry is in the past.
func testStoreTokenRemoveExpired(t *testing.T, store user.Store) {
	require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
	u, err := store.User("phil")
	require.Nil(t, err)

	// Create expired token and active token
	_, err = store.CreateToken(u.ID, "tk_expired", "expired", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(-time.Hour), false)
	require.Nil(t, err)
	_, err = store.CreateToken(u.ID, "tk_active", "active", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), false)
	require.Nil(t, err)

	require.Nil(t, store.RemoveExpiredTokens())

	// Expired token should be gone
	_, err = store.Token(u.ID, "tk_expired")
	require.Equal(t, user.ErrTokenNotFound, err)

	// Active token should still exist
	tk, err := store.Token(u.ID, "tk_active")
	require.Nil(t, err)
	require.Equal(t, "tk_active", tk.Value)
}
|
|
||||||
|
|
||||||
// testStoreTokenRemoveExcess verifies that RemoveExcessTokens keeps only the
// given number of tokens, dropping those with the earliest expiry first.
func testStoreTokenRemoveExcess(t *testing.T, store user.Store) {
	require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
	u, err := store.User("phil")
	require.Nil(t, err)

	// Create 3 tokens with increasing expiry
	for i, name := range []string{"tk_a", "tk_b", "tk_c"} {
		_, err = store.CreateToken(u.ID, name, name, time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Duration(i+1)*time.Hour), false)
		require.Nil(t, err)
	}

	count, err := store.TokenCount(u.ID)
	require.Nil(t, err)
	require.Equal(t, 3, count)

	// Remove excess, keep only 2 (the ones with latest expiry: tk_b, tk_c)
	require.Nil(t, store.RemoveExcessTokens(u.ID, 2))

	count, err = store.TokenCount(u.ID)
	require.Nil(t, err)
	require.Equal(t, 2, count)

	// tk_a should be removed (earliest expiry)
	_, err = store.Token(u.ID, "tk_a")
	require.Equal(t, user.ErrTokenNotFound, err)

	// tk_b and tk_c should remain
	_, err = store.Token(u.ID, "tk_b")
	require.Nil(t, err)
	_, err = store.Token(u.ID, "tk_c")
	require.Nil(t, err)
}

// testStoreTokenUpdateLastAccess verifies that UpdateTokenLastAccess persists
// the new access time (compared at second precision) and origin address.
func testStoreTokenUpdateLastAccess(t *testing.T, store user.Store) {
	require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
	u, err := store.User("phil")
	require.Nil(t, err)

	_, err = store.CreateToken(u.ID, "tk_abc", "label", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), false)
	require.Nil(t, err)

	newTime := time.Now().Add(5 * time.Minute)
	newOrigin := netip.MustParseAddr("5.5.5.5")
	require.Nil(t, store.UpdateTokenLastAccess("tk_abc", newTime, newOrigin))

	tk, err := store.Token(u.ID, "tk_abc")
	require.Nil(t, err)
	require.Equal(t, newTime.Unix(), tk.LastAccess.Unix())
	require.Equal(t, newOrigin, tk.LastOrigin)
}
|
|
||||||
|
|
||||||
// testStoreAllowAccess verifies that granting read+write access creates a
// single read-write grant for the topic.
func testStoreAllowAccess(t *testing.T, store user.Store) {
	require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))

	require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "", false))
	grants, err := store.Grants("phil")
	require.Nil(t, err)
	require.Len(t, grants, 1)
	require.Equal(t, "mytopic", grants[0].TopicPattern)
	require.True(t, grants[0].Permission.IsReadWrite())
}

// testStoreAllowAccessReadOnly verifies that a read-only grant allows reading
// but not writing.
func testStoreAllowAccessReadOnly(t *testing.T, store user.Store) {
	require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))

	require.Nil(t, store.AllowAccess("phil", "announcements", true, false, "", false))
	grants, err := store.Grants("phil")
	require.Nil(t, err)
	require.Len(t, grants, 1)
	require.True(t, grants[0].Permission.IsRead())
	require.False(t, grants[0].Permission.IsWrite())
}

// testStoreResetAccess verifies that resetting access for one topic removes
// only that topic's grant.
func testStoreResetAccess(t *testing.T, store user.Store) {
	require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
	require.Nil(t, store.AllowAccess("phil", "topic1", true, true, "", false))
	require.Nil(t, store.AllowAccess("phil", "topic2", true, false, "", false))

	grants, err := store.Grants("phil")
	require.Nil(t, err)
	require.Len(t, grants, 2)

	require.Nil(t, store.ResetAccess("phil", "topic1"))
	grants, err = store.Grants("phil")
	require.Nil(t, err)
	require.Len(t, grants, 1)
	require.Equal(t, "topic2", grants[0].TopicPattern)
}

// testStoreResetAccessAll verifies that resetting access with an empty topic
// removes all of the user's grants.
func testStoreResetAccessAll(t *testing.T, store user.Store) {
	require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
	require.Nil(t, store.AllowAccess("phil", "topic1", true, true, "", false))
	require.Nil(t, store.AllowAccess("phil", "topic2", true, false, "", false))

	require.Nil(t, store.ResetAccess("phil", ""))
	grants, err := store.Grants("phil")
	require.Nil(t, err)
	require.Len(t, grants, 0)
}

// testStoreAuthorizeTopicAccess verifies a read+write grant is reported as
// found with both permissions set.
func testStoreAuthorizeTopicAccess(t *testing.T, store user.Store) {
	require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
	require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "", false))

	read, write, found, err := store.AuthorizeTopicAccess("phil", "mytopic")
	require.Nil(t, err)
	require.True(t, found)
	require.True(t, read)
	require.True(t, write)
}

// testStoreAuthorizeTopicAccessNotFound verifies that a topic without a grant
// is reported as not found (without error).
func testStoreAuthorizeTopicAccessNotFound(t *testing.T, store user.Store) {
	require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))

	_, _, found, err := store.AuthorizeTopicAccess("phil", "other")
	require.Nil(t, err)
	require.False(t, found)
}

// testStoreAuthorizeTopicAccessDenyAll verifies that a deny-all grant is found
// but carries neither read nor write permission.
func testStoreAuthorizeTopicAccessDenyAll(t *testing.T, store user.Store) {
	require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
	require.Nil(t, store.AllowAccess("phil", "secret", false, false, "", false))

	read, write, found, err := store.AuthorizeTopicAccess("phil", "secret")
	require.Nil(t, err)
	require.True(t, found)
	require.False(t, read)
	require.False(t, write)
}
|
|
||||||
|
|
||||||
// testStoreReservations verifies that reserving a topic (owner grant plus an
// Everyone grant) is reflected in the user's reservation list.
func testStoreReservations(t *testing.T, store user.Store) {
	require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
	require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "phil", false))
	require.Nil(t, store.AllowAccess(user.Everyone, "mytopic", true, false, "phil", false))

	reservations, err := store.Reservations("phil")
	require.Nil(t, err)
	require.Len(t, reservations, 1)
	require.Equal(t, "mytopic", reservations[0].Topic)
	require.True(t, reservations[0].Owner.IsReadWrite())
	require.True(t, reservations[0].Everyone.IsRead())
	require.False(t, reservations[0].Everyone.IsWrite())
}

// testStoreReservationsCount verifies the per-user reservation count.
func testStoreReservationsCount(t *testing.T, store user.Store) {
	require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
	require.Nil(t, store.AllowAccess("phil", "topic1", true, true, "phil", false))
	require.Nil(t, store.AllowAccess("phil", "topic2", true, true, "phil", false))

	count, err := store.ReservationsCount("phil")
	require.Nil(t, err)
	require.Equal(t, int64(2), count)
}

// testStoreHasReservation verifies that HasReservation reports true only for
// topics the user has reserved.
func testStoreHasReservation(t *testing.T, store user.Store) {
	require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
	require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "phil", false))

	has, err := store.HasReservation("phil", "mytopic")
	require.Nil(t, err)
	require.True(t, has)

	has, err = store.HasReservation("phil", "other")
	require.Nil(t, err)
	require.False(t, has)
}

// testStoreReservationOwner verifies that ReservationOwner returns the owning
// user's ID for a reserved topic, and an empty string for an unreserved one.
func testStoreReservationOwner(t *testing.T, store user.Store) {
	require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
	require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "phil", false))

	owner, err := store.ReservationOwner("mytopic")
	require.Nil(t, err)
	require.NotEmpty(t, owner) // Returns the user ID

	owner, err = store.ReservationOwner("unowned")
	require.Nil(t, err)
	require.Empty(t, owner)
}
|
|
||||||
|
|
||||||
func testStoreTiers(t *testing.T, store user.Store) {
|
|
||||||
tier := &user.Tier{
|
|
||||||
ID: "ti_test",
|
|
||||||
Code: "pro",
|
|
||||||
Name: "Pro",
|
|
||||||
MessageLimit: 5000,
|
|
||||||
MessageExpiryDuration: 24 * time.Hour,
|
|
||||||
EmailLimit: 100,
|
|
||||||
CallLimit: 10,
|
|
||||||
ReservationLimit: 20,
|
|
||||||
AttachmentFileSizeLimit: 10 * 1024 * 1024,
|
|
||||||
AttachmentTotalSizeLimit: 100 * 1024 * 1024,
|
|
||||||
AttachmentExpiryDuration: 48 * time.Hour,
|
|
||||||
AttachmentBandwidthLimit: 500 * 1024 * 1024,
|
|
||||||
}
|
|
||||||
require.Nil(t, store.AddTier(tier))
|
|
||||||
|
|
||||||
// Get by code
|
|
||||||
t2, err := store.Tier("pro")
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, "ti_test", t2.ID)
|
|
||||||
require.Equal(t, "pro", t2.Code)
|
|
||||||
require.Equal(t, "Pro", t2.Name)
|
|
||||||
require.Equal(t, int64(5000), t2.MessageLimit)
|
|
||||||
require.Equal(t, int64(100), t2.EmailLimit)
|
|
||||||
require.Equal(t, int64(10), t2.CallLimit)
|
|
||||||
require.Equal(t, int64(20), t2.ReservationLimit)
|
|
||||||
|
|
||||||
// List all tiers
|
|
||||||
tiers, err := store.Tiers()
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Len(t, tiers, 1)
|
|
||||||
require.Equal(t, "pro", tiers[0].Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testStoreTierUpdate(t *testing.T, store user.Store) {
|
|
||||||
tier := &user.Tier{
|
|
||||||
ID: "ti_test",
|
|
||||||
Code: "pro",
|
|
||||||
Name: "Pro",
|
|
||||||
}
|
|
||||||
require.Nil(t, store.AddTier(tier))
|
|
||||||
|
|
||||||
tier.Name = "Professional"
|
|
||||||
tier.MessageLimit = 9999
|
|
||||||
require.Nil(t, store.UpdateTier(tier))
|
|
||||||
|
|
||||||
t2, err := store.Tier("pro")
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, "Professional", t2.Name)
|
|
||||||
require.Equal(t, int64(9999), t2.MessageLimit)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testStoreTierRemove(t *testing.T, store user.Store) {
|
|
||||||
tier := &user.Tier{
|
|
||||||
ID: "ti_test",
|
|
||||||
Code: "pro",
|
|
||||||
Name: "Pro",
|
|
||||||
}
|
|
||||||
require.Nil(t, store.AddTier(tier))
|
|
||||||
|
|
||||||
t2, err := store.Tier("pro")
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, "pro", t2.Code)
|
|
||||||
|
|
||||||
require.Nil(t, store.RemoveTier("pro"))
|
|
||||||
_, err = store.Tier("pro")
|
|
||||||
require.Equal(t, user.ErrTierNotFound, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testStoreTierByStripePrice(t *testing.T, store user.Store) {
|
|
||||||
tier := &user.Tier{
|
|
||||||
ID: "ti_test",
|
|
||||||
Code: "pro",
|
|
||||||
Name: "Pro",
|
|
||||||
StripeMonthlyPriceID: "price_monthly",
|
|
||||||
StripeYearlyPriceID: "price_yearly",
|
|
||||||
}
|
|
||||||
require.Nil(t, store.AddTier(tier))
|
|
||||||
|
|
||||||
t2, err := store.TierByStripePrice("price_monthly")
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, "pro", t2.Code)
|
|
||||||
|
|
||||||
t3, err := store.TierByStripePrice("price_yearly")
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, "pro", t3.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testStoreChangeTier(t *testing.T, store user.Store) {
|
|
||||||
tier := &user.Tier{
|
|
||||||
ID: "ti_test",
|
|
||||||
Code: "pro",
|
|
||||||
Name: "Pro",
|
|
||||||
}
|
|
||||||
require.Nil(t, store.AddTier(tier))
|
|
||||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
|
||||||
require.Nil(t, store.ChangeTier("phil", "pro"))
|
|
||||||
|
|
||||||
u, err := store.User("phil")
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.NotNil(t, u.Tier)
|
|
||||||
require.Equal(t, "pro", u.Tier.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testStorePhoneNumbers(t *testing.T, store user.Store) {
|
|
||||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
|
||||||
u, err := store.User("phil")
|
|
||||||
require.Nil(t, err)
|
|
||||||
|
|
||||||
require.Nil(t, store.AddPhoneNumber(u.ID, "+1234567890"))
|
|
||||||
require.Nil(t, store.AddPhoneNumber(u.ID, "+0987654321"))
|
|
||||||
|
|
||||||
numbers, err := store.PhoneNumbers(u.ID)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Len(t, numbers, 2)
|
|
||||||
|
|
||||||
require.Nil(t, store.RemovePhoneNumber(u.ID, "+1234567890"))
|
|
||||||
numbers, err = store.PhoneNumbers(u.ID)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Len(t, numbers, 1)
|
|
||||||
require.Equal(t, "+0987654321", numbers[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
func testStoreChangeSettings(t *testing.T, store user.Store) {
|
|
||||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
|
||||||
u, err := store.User("phil")
|
|
||||||
require.Nil(t, err)
|
|
||||||
|
|
||||||
lang := "de"
|
|
||||||
prefs := &user.Prefs{Language: &lang}
|
|
||||||
require.Nil(t, store.ChangeSettings(u.ID, prefs))
|
|
||||||
|
|
||||||
u2, err := store.User("phil")
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.NotNil(t, u2.Prefs)
|
|
||||||
require.Equal(t, "de", *u2.Prefs.Language)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testStoreChangeBilling(t *testing.T, store user.Store) {
|
|
||||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
|
||||||
|
|
||||||
billing := &user.Billing{
|
|
||||||
StripeCustomerID: "cus_123",
|
|
||||||
StripeSubscriptionID: "sub_456",
|
|
||||||
}
|
|
||||||
require.Nil(t, store.ChangeBilling("phil", billing))
|
|
||||||
|
|
||||||
u, err := store.User("phil")
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, "cus_123", u.Billing.StripeCustomerID)
|
|
||||||
require.Equal(t, "sub_456", u.Billing.StripeSubscriptionID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testStoreUpdateStats(t *testing.T, store user.Store) {
|
|
||||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
|
||||||
u, err := store.User("phil")
|
|
||||||
require.Nil(t, err)
|
|
||||||
|
|
||||||
stats := &user.Stats{Messages: 42, Emails: 3, Calls: 1}
|
|
||||||
require.Nil(t, store.UpdateStats(u.ID, stats))
|
|
||||||
|
|
||||||
u2, err := store.User("phil")
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, int64(42), u2.Stats.Messages)
|
|
||||||
require.Equal(t, int64(3), u2.Stats.Emails)
|
|
||||||
require.Equal(t, int64(1), u2.Stats.Calls)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testStoreResetStats(t *testing.T, store user.Store) {
|
|
||||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
|
||||||
u, err := store.User("phil")
|
|
||||||
require.Nil(t, err)
|
|
||||||
|
|
||||||
require.Nil(t, store.UpdateStats(u.ID, &user.Stats{Messages: 42, Emails: 3, Calls: 1}))
|
|
||||||
require.Nil(t, store.ResetStats())
|
|
||||||
|
|
||||||
u2, err := store.User("phil")
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, int64(0), u2.Stats.Messages)
|
|
||||||
require.Equal(t, int64(0), u2.Stats.Emails)
|
|
||||||
require.Equal(t, int64(0), u2.Stats.Calls)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testStoreMarkUserRemoved(t *testing.T, store user.Store) {
|
|
||||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
|
||||||
u, err := store.User("phil")
|
|
||||||
require.Nil(t, err)
|
|
||||||
|
|
||||||
require.Nil(t, store.MarkUserRemoved(u.ID))
|
|
||||||
|
|
||||||
u2, err := store.User("phil")
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.True(t, u2.Deleted)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testStoreRemoveDeletedUsers(t *testing.T, store user.Store) {
|
|
||||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
|
||||||
u, err := store.User("phil")
|
|
||||||
require.Nil(t, err)
|
|
||||||
|
|
||||||
require.Nil(t, store.MarkUserRemoved(u.ID))
|
|
||||||
|
|
||||||
// RemoveDeletedUsers only removes users past the hard-delete duration (7 days).
|
|
||||||
// Immediately after marking, the user should still exist.
|
|
||||||
require.Nil(t, store.RemoveDeletedUsers())
|
|
||||||
u2, err := store.User("phil")
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.True(t, u2.Deleted)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testStoreAllGrants(t *testing.T, store user.Store) {
|
|
||||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
|
||||||
require.Nil(t, store.AddUser("ben", "benhash", user.RoleUser, false))
|
|
||||||
phil, err := store.User("phil")
|
|
||||||
require.Nil(t, err)
|
|
||||||
ben, err := store.User("ben")
|
|
||||||
require.Nil(t, err)
|
|
||||||
|
|
||||||
require.Nil(t, store.AllowAccess("phil", "topic1", true, true, "", false))
|
|
||||||
require.Nil(t, store.AllowAccess("ben", "topic2", true, false, "", false))
|
|
||||||
|
|
||||||
grants, err := store.AllGrants()
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Contains(t, grants, phil.ID)
|
|
||||||
require.Contains(t, grants, ben.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testStoreOtherAccessCount(t *testing.T, store user.Store) {
|
|
||||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
|
||||||
require.Nil(t, store.AddUser("ben", "benhash", user.RoleUser, false))
|
|
||||||
require.Nil(t, store.AllowAccess("ben", "mytopic", true, true, "ben", false))
|
|
||||||
|
|
||||||
count, err := store.OtherAccessCount("phil", "mytopic")
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 1, count)
|
|
||||||
}
|
|
||||||
@@ -2,11 +2,12 @@ package user
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"heckel.io/ntfy/v2/log"
|
|
||||||
"heckel.io/ntfy/v2/payments"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"heckel.io/ntfy/v2/log"
|
||||||
|
"heckel.io/ntfy/v2/payments"
|
||||||
)
|
)
|
||||||
|
|
||||||
// User is a struct that represents a user
|
// User is a struct that represents a user
|
||||||
@@ -273,3 +274,72 @@ var (
|
|||||||
ErrProvisionedUserChange = errors.New("cannot change or delete provisioned user")
|
ErrProvisionedUserChange = errors.New("cannot change or delete provisioned user")
|
||||||
ErrProvisionedTokenChange = errors.New("cannot change or delete provisioned token")
|
ErrProvisionedTokenChange = errors.New("cannot change or delete provisioned token")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// queries holds the database-specific SQL queries
|
||||||
|
type queries struct {
|
||||||
|
// User queries
|
||||||
|
selectUserByID string
|
||||||
|
selectUserByName string
|
||||||
|
selectUserByToken string
|
||||||
|
selectUserByStripeCustomerID string
|
||||||
|
selectUsernames string
|
||||||
|
selectUserCount string
|
||||||
|
selectUserIDFromUsername string
|
||||||
|
insertUser string
|
||||||
|
updateUserPass string
|
||||||
|
updateUserRole string
|
||||||
|
updateUserProvisioned string
|
||||||
|
updateUserPrefs string
|
||||||
|
updateUserStats string
|
||||||
|
updateUserStatsResetAll string
|
||||||
|
updateUserTier string
|
||||||
|
updateUserDeleted string
|
||||||
|
deleteUser string
|
||||||
|
deleteUserTier string
|
||||||
|
deleteUsersMarked string
|
||||||
|
|
||||||
|
// Access queries
|
||||||
|
selectTopicPerms string
|
||||||
|
selectUserAllAccess string
|
||||||
|
selectUserAccess string
|
||||||
|
selectUserReservations string
|
||||||
|
selectUserReservationsCount string
|
||||||
|
selectUserReservationsOwner string
|
||||||
|
selectUserHasReservation string
|
||||||
|
selectOtherAccessCount string
|
||||||
|
upsertUserAccess string
|
||||||
|
deleteUserAccess string
|
||||||
|
deleteUserAccessProvisioned string
|
||||||
|
deleteTopicAccess string
|
||||||
|
deleteAllAccess string
|
||||||
|
|
||||||
|
// Token queries
|
||||||
|
selectToken string
|
||||||
|
selectTokens string
|
||||||
|
selectTokenCount string
|
||||||
|
selectAllProvisionedTokens string
|
||||||
|
upsertToken string
|
||||||
|
updateToken string
|
||||||
|
updateTokenLastAccess string
|
||||||
|
deleteToken string
|
||||||
|
deleteProvisionedToken string
|
||||||
|
deleteAllToken string
|
||||||
|
deleteExpiredTokens string
|
||||||
|
deleteExcessTokens string
|
||||||
|
|
||||||
|
// Tier queries
|
||||||
|
insertTier string
|
||||||
|
selectTiers string
|
||||||
|
selectTierByCode string
|
||||||
|
selectTierByPriceID string
|
||||||
|
updateTier string
|
||||||
|
deleteTier string
|
||||||
|
|
||||||
|
// Phone queries
|
||||||
|
selectPhoneNumbers string
|
||||||
|
insertPhoneNumber string
|
||||||
|
deletePhoneNumber string
|
||||||
|
|
||||||
|
// Billing queries
|
||||||
|
updateBilling string
|
||||||
|
}
|
||||||
|
|||||||
32
user/util.go
32
user/util.go
@@ -113,3 +113,35 @@ func escapeUnderscore(s string) string {
|
|||||||
func unescapeUnderscore(s string) string {
|
func unescapeUnderscore(s string) string {
|
||||||
return strings.ReplaceAll(s, "\\_", "_")
|
return strings.ReplaceAll(s, "\\_", "_")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// execTx executes a function in a transaction. If the function returns an error, the transaction is rolled back.
|
||||||
|
func execTx(db *sql.DB, f func(tx *sql.Tx) error) error {
|
||||||
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
if err := f(tx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
// queryTx executes a function in a transaction and returns the result. If the function
|
||||||
|
// returns an error, the transaction is rolled back.
|
||||||
|
func queryTx[T any](db *sql.DB, f func(tx *sql.Tx) (T, error)) (T, error) {
|
||||||
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
var zero T
|
||||||
|
return zero, err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
t, err := f(tx)
|
||||||
|
if err != nil {
|
||||||
|
return t, err
|
||||||
|
}
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return t, err
|
||||||
|
}
|
||||||
|
return t, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -48,5 +48,26 @@
|
|||||||
"notifications_none_for_topic_title": "לא קיבלת התראות בנושא הזה עדיין.",
|
"notifications_none_for_topic_title": "לא קיבלת התראות בנושא הזה עדיין.",
|
||||||
"notifications_none_for_topic_description": "כדי לשלוח התראות לנושא הזה, צריך לשלוח PUT או POST לכתובת הנושא הזה.",
|
"notifications_none_for_topic_description": "כדי לשלוח התראות לנושא הזה, צריך לשלוח PUT או POST לכתובת הנושא הזה.",
|
||||||
"notifications_none_for_any_title": "לא קיבלת התראות כלל.",
|
"notifications_none_for_any_title": "לא קיבלת התראות כלל.",
|
||||||
"notifications_no_subscriptions_title": "נראה שלא נרשמת למינויים עדיין."
|
"notifications_no_subscriptions_title": "נראה שלא נרשמת למינויים עדיין.",
|
||||||
|
"action_bar_toggle_mute": "השתקת/הפעלת התראות",
|
||||||
|
"action_bar_toggle_action_menu": "פתיחת/סגירת תפריט הפעולות",
|
||||||
|
"action_bar_profile_title": "פרופיל",
|
||||||
|
"action_bar_profile_settings": "הגדרות",
|
||||||
|
"action_bar_profile_logout": "יציאה",
|
||||||
|
"action_bar_sign_in": "כניסה",
|
||||||
|
"action_bar_sign_up": "הרשמה",
|
||||||
|
"message_bar_type_message": "כאן ניתן להקליד הודעה",
|
||||||
|
"message_bar_error_publishing": "שגיאה בפרסום ההתראה",
|
||||||
|
"message_bar_show_dialog": "הצגת חלונית פרסום",
|
||||||
|
"message_bar_publish": "פרסום הודעה",
|
||||||
|
"nav_topics_title": "נושאים שנרשמת אליהם",
|
||||||
|
"nav_button_all_notifications": "כל ההתראות",
|
||||||
|
"nav_button_account": "חשבון",
|
||||||
|
"nav_button_settings": "הגדרות",
|
||||||
|
"nav_button_documentation": "תיעוד",
|
||||||
|
"nav_button_publish_message": "פרסום התראה",
|
||||||
|
"nav_button_subscribe": "הרשמה לנושא",
|
||||||
|
"nav_button_muted": "התראות הושתקו",
|
||||||
|
"nav_button_connecting": "מתחבר",
|
||||||
|
"nav_upgrade_banner_label": "שדרוג ל־ntfy Pro"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,26 +21,19 @@ var (
|
|||||||
ErrWebPushUserIDCannotBeEmpty = errors.New("user ID cannot be empty")
|
ErrWebPushUserIDCannotBeEmpty = errors.New("user ID cannot be empty")
|
||||||
)
|
)
|
||||||
|
|
||||||
// Store is the interface for a web push subscription store.
|
// Store holds the database connection and queries for web push subscriptions.
|
||||||
type Store interface {
|
type Store struct {
|
||||||
UpsertSubscription(endpoint, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error
|
db *sql.DB
|
||||||
SubscriptionsForTopic(topic string) ([]*Subscription, error)
|
queries queries
|
||||||
SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscription, error)
|
|
||||||
MarkExpiryWarningSent(subscriptions []*Subscription) error
|
|
||||||
RemoveSubscriptionsByEndpoint(endpoint string) error
|
|
||||||
RemoveSubscriptionsByUserID(userID string) error
|
|
||||||
RemoveExpiredSubscriptions(expireAfter time.Duration) error
|
|
||||||
SetSubscriptionUpdatedAt(endpoint string, updatedAt int64) error
|
|
||||||
Close() error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// storeQueries holds the database-specific SQL queries.
|
// queries holds the database-specific SQL queries.
|
||||||
type storeQueries struct {
|
type queries struct {
|
||||||
selectSubscriptionIDByEndpoint string
|
selectSubscriptionIDByEndpoint string
|
||||||
selectSubscriptionCountBySubscriberIP string
|
selectSubscriptionCountBySubscriberIP string
|
||||||
selectSubscriptionsForTopic string
|
selectSubscriptionsForTopic string
|
||||||
selectSubscriptionsExpiringSoon string
|
selectSubscriptionsExpiringSoon string
|
||||||
insertSubscription string
|
upsertSubscription string
|
||||||
updateSubscriptionWarningSent string
|
updateSubscriptionWarningSent string
|
||||||
updateSubscriptionUpdatedAt string
|
updateSubscriptionUpdatedAt string
|
||||||
deleteSubscriptionByEndpoint string
|
deleteSubscriptionByEndpoint string
|
||||||
@@ -51,14 +44,8 @@ type storeQueries struct {
|
|||||||
deleteSubscriptionTopicWithoutSubscription string
|
deleteSubscriptionTopicWithoutSubscription string
|
||||||
}
|
}
|
||||||
|
|
||||||
// commonStore implements store operations that are identical across database backends.
|
|
||||||
type commonStore struct {
|
|
||||||
db *sql.DB
|
|
||||||
queries storeQueries
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID.
|
// UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID.
|
||||||
func (s *commonStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error {
|
func (s *Store) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error {
|
||||||
tx, err := s.db.Begin()
|
tx, err := s.db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -71,8 +58,7 @@ func (s *commonStore) UpsertSubscription(endpoint string, auth, p256dh, userID s
|
|||||||
}
|
}
|
||||||
// Read existing subscription ID for endpoint (or create new ID)
|
// Read existing subscription ID for endpoint (or create new ID)
|
||||||
var subscriptionID string
|
var subscriptionID string
|
||||||
err = tx.QueryRow(s.queries.selectSubscriptionIDByEndpoint, endpoint).Scan(&subscriptionID)
|
if err := tx.QueryRow(s.queries.selectSubscriptionIDByEndpoint, endpoint).Scan(&subscriptionID); errors.Is(err, sql.ErrNoRows) {
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
|
||||||
if subscriptionCount >= subscriptionEndpointLimitPerSubscriberIP {
|
if subscriptionCount >= subscriptionEndpointLimitPerSubscriberIP {
|
||||||
return ErrWebPushTooManySubscriptions
|
return ErrWebPushTooManySubscriptions
|
||||||
}
|
}
|
||||||
@@ -82,7 +68,7 @@ func (s *commonStore) UpsertSubscription(endpoint string, auth, p256dh, userID s
|
|||||||
}
|
}
|
||||||
// Insert or update subscription
|
// Insert or update subscription
|
||||||
updatedAt, warnedAt := time.Now().Unix(), 0
|
updatedAt, warnedAt := time.Now().Unix(), 0
|
||||||
if _, err = tx.Exec(s.queries.insertSubscription, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil {
|
if _, err := tx.Exec(s.queries.upsertSubscription, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// Replace all subscription topics
|
// Replace all subscription topics
|
||||||
@@ -98,7 +84,7 @@ func (s *commonStore) UpsertSubscription(endpoint string, auth, p256dh, userID s
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SubscriptionsForTopic returns all subscriptions for the given topic.
|
// SubscriptionsForTopic returns all subscriptions for the given topic.
|
||||||
func (s *commonStore) SubscriptionsForTopic(topic string) ([]*Subscription, error) {
|
func (s *Store) SubscriptionsForTopic(topic string) ([]*Subscription, error) {
|
||||||
rows, err := s.db.Query(s.queries.selectSubscriptionsForTopic, topic)
|
rows, err := s.db.Query(s.queries.selectSubscriptionsForTopic, topic)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -108,7 +94,7 @@ func (s *commonStore) SubscriptionsForTopic(topic string) ([]*Subscription, erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period.
|
// SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period.
|
||||||
func (s *commonStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscription, error) {
|
func (s *Store) SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscription, error) {
|
||||||
rows, err := s.db.Query(s.queries.selectSubscriptionsExpiringSoon, time.Now().Add(-warnAfter).Unix())
|
rows, err := s.db.Query(s.queries.selectSubscriptionsExpiringSoon, time.Now().Add(-warnAfter).Unix())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -118,7 +104,7 @@ func (s *commonStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscri
|
|||||||
}
|
}
|
||||||
|
|
||||||
// MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon.
|
// MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon.
|
||||||
func (s *commonStore) MarkExpiryWarningSent(subscriptions []*Subscription) error {
|
func (s *Store) MarkExpiryWarningSent(subscriptions []*Subscription) error {
|
||||||
tx, err := s.db.Begin()
|
tx, err := s.db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -133,13 +119,13 @@ func (s *commonStore) MarkExpiryWarningSent(subscriptions []*Subscription) error
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint.
|
// RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint.
|
||||||
func (s *commonStore) RemoveSubscriptionsByEndpoint(endpoint string) error {
|
func (s *Store) RemoveSubscriptionsByEndpoint(endpoint string) error {
|
||||||
_, err := s.db.Exec(s.queries.deleteSubscriptionByEndpoint, endpoint)
|
_, err := s.db.Exec(s.queries.deleteSubscriptionByEndpoint, endpoint)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveSubscriptionsByUserID removes all subscriptions for the given user ID.
|
// RemoveSubscriptionsByUserID removes all subscriptions for the given user ID.
|
||||||
func (s *commonStore) RemoveSubscriptionsByUserID(userID string) error {
|
func (s *Store) RemoveSubscriptionsByUserID(userID string) error {
|
||||||
if userID == "" {
|
if userID == "" {
|
||||||
return ErrWebPushUserIDCannotBeEmpty
|
return ErrWebPushUserIDCannotBeEmpty
|
||||||
}
|
}
|
||||||
@@ -148,7 +134,7 @@ func (s *commonStore) RemoveSubscriptionsByUserID(userID string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RemoveExpiredSubscriptions removes all subscriptions that have not been updated for a given time period.
|
// RemoveExpiredSubscriptions removes all subscriptions that have not been updated for a given time period.
|
||||||
func (s *commonStore) RemoveExpiredSubscriptions(expireAfter time.Duration) error {
|
func (s *Store) RemoveExpiredSubscriptions(expireAfter time.Duration) error {
|
||||||
_, err := s.db.Exec(s.queries.deleteSubscriptionByAge, time.Now().Add(-expireAfter).Unix())
|
_, err := s.db.Exec(s.queries.deleteSubscriptionByAge, time.Now().Add(-expireAfter).Unix())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -158,14 +144,14 @@ func (s *commonStore) RemoveExpiredSubscriptions(expireAfter time.Duration) erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SetSubscriptionUpdatedAt updates the updated_at timestamp for a subscription by endpoint. This is
|
// SetSubscriptionUpdatedAt updates the updated_at timestamp for a subscription by endpoint. This is
|
||||||
// exported for testing purposes and is not part of the Store interface.
|
// exported for testing purposes.
|
||||||
func (s *commonStore) SetSubscriptionUpdatedAt(endpoint string, updatedAt int64) error {
|
func (s *Store) SetSubscriptionUpdatedAt(endpoint string, updatedAt int64) error {
|
||||||
_, err := s.db.Exec(s.queries.updateSubscriptionUpdatedAt, updatedAt, endpoint)
|
_, err := s.db.Exec(s.queries.updateSubscriptionUpdatedAt, updatedAt, endpoint)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close closes the underlying database connection.
|
// Close closes the underlying database connection.
|
||||||
func (s *commonStore) Close() error {
|
func (s *Store) Close() error {
|
||||||
return s.db.Close()
|
return s.db.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,12 +3,10 @@ package webpush
|
|||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
pgCreateTablesQuery = `
|
postgresCreateTablesQuery = `
|
||||||
CREATE TABLE IF NOT EXISTS webpush_subscription (
|
CREATE TABLE IF NOT EXISTS webpush_subscription (
|
||||||
id TEXT PRIMARY KEY,
|
id TEXT PRIMARY KEY,
|
||||||
endpoint TEXT NOT NULL UNIQUE,
|
endpoint TEXT NOT NULL UNIQUE,
|
||||||
@@ -20,6 +18,8 @@ const (
|
|||||||
warned_at BIGINT NOT NULL DEFAULT 0
|
warned_at BIGINT NOT NULL DEFAULT 0
|
||||||
);
|
);
|
||||||
CREATE INDEX IF NOT EXISTS idx_webpush_subscriber_ip ON webpush_subscription (subscriber_ip);
|
CREATE INDEX IF NOT EXISTS idx_webpush_subscriber_ip ON webpush_subscription (subscriber_ip);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_webpush_updated_at ON webpush_subscription (updated_at);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_webpush_user_id ON webpush_subscription (user_id);
|
||||||
CREATE TABLE IF NOT EXISTS webpush_subscription_topic (
|
CREATE TABLE IF NOT EXISTS webpush_subscription_topic (
|
||||||
subscription_id TEXT NOT NULL REFERENCES webpush_subscription (id) ON DELETE CASCADE,
|
subscription_id TEXT NOT NULL REFERENCES webpush_subscription (id) ON DELETE CASCADE,
|
||||||
topic TEXT NOT NULL,
|
topic TEXT NOT NULL,
|
||||||
@@ -32,79 +32,72 @@ const (
|
|||||||
);
|
);
|
||||||
`
|
`
|
||||||
|
|
||||||
pgSelectSubscriptionIDByEndpoint = `SELECT id FROM webpush_subscription WHERE endpoint = $1`
|
postgresSelectSubscriptionIDByEndpointQuery = `SELECT id FROM webpush_subscription WHERE endpoint = $1`
|
||||||
pgSelectSubscriptionCountBySubscriberIP = `SELECT COUNT(*) FROM webpush_subscription WHERE subscriber_ip = $1`
|
postgresSelectSubscriptionCountBySubscriberIPQuery = `SELECT COUNT(*) FROM webpush_subscription WHERE subscriber_ip = $1`
|
||||||
pgSelectSubscriptionsForTopicQuery = `
|
postgresSelectSubscriptionsForTopicQuery = `
|
||||||
SELECT s.id, s.endpoint, s.key_auth, s.key_p256dh, s.user_id
|
SELECT s.id, s.endpoint, s.key_auth, s.key_p256dh, s.user_id
|
||||||
FROM webpush_subscription_topic st
|
FROM webpush_subscription_topic st
|
||||||
JOIN webpush_subscription s ON s.id = st.subscription_id
|
JOIN webpush_subscription s ON s.id = st.subscription_id
|
||||||
WHERE st.topic = $1
|
WHERE st.topic = $1
|
||||||
ORDER BY s.endpoint
|
ORDER BY s.endpoint
|
||||||
`
|
`
|
||||||
pgSelectSubscriptionsExpiringSoonQuery = `
|
postgresSelectSubscriptionsExpiringSoonQuery = `
|
||||||
SELECT id, endpoint, key_auth, key_p256dh, user_id
|
SELECT id, endpoint, key_auth, key_p256dh, user_id
|
||||||
FROM webpush_subscription
|
FROM webpush_subscription
|
||||||
WHERE warned_at = 0 AND updated_at <= $1
|
WHERE warned_at = 0 AND updated_at <= $1
|
||||||
`
|
`
|
||||||
pgInsertSubscriptionQuery = `
|
postgresUpsertSubscriptionQuery = `
|
||||||
INSERT INTO webpush_subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at)
|
INSERT INTO webpush_subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at)
|
||||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||||
ON CONFLICT (endpoint)
|
ON CONFLICT (endpoint)
|
||||||
DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at
|
DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at
|
||||||
`
|
`
|
||||||
pgUpdateSubscriptionWarningSentQuery = `UPDATE webpush_subscription SET warned_at = $1 WHERE id = $2`
|
postgresUpdateSubscriptionWarningSentQuery = `UPDATE webpush_subscription SET warned_at = $1 WHERE id = $2`
|
||||||
pgUpdateSubscriptionUpdatedAtQuery = `UPDATE webpush_subscription SET updated_at = $1 WHERE endpoint = $2`
|
postgresUpdateSubscriptionUpdatedAtQuery = `UPDATE webpush_subscription SET updated_at = $1 WHERE endpoint = $2`
|
||||||
pgDeleteSubscriptionByEndpointQuery = `DELETE FROM webpush_subscription WHERE endpoint = $1`
|
postgresDeleteSubscriptionByEndpointQuery = `DELETE FROM webpush_subscription WHERE endpoint = $1`
|
||||||
pgDeleteSubscriptionByUserIDQuery = `DELETE FROM webpush_subscription WHERE user_id = $1`
|
postgresDeleteSubscriptionByUserIDQuery = `DELETE FROM webpush_subscription WHERE user_id = $1`
|
||||||
pgDeleteSubscriptionByAgeQuery = `DELETE FROM webpush_subscription WHERE updated_at <= $1`
|
postgresDeleteSubscriptionByAgeQuery = `DELETE FROM webpush_subscription WHERE updated_at <= $1`
|
||||||
|
|
||||||
pgInsertSubscriptionTopicQuery = `INSERT INTO webpush_subscription_topic (subscription_id, topic) VALUES ($1, $2)`
|
postgresInsertSubscriptionTopicQuery = `INSERT INTO webpush_subscription_topic (subscription_id, topic) VALUES ($1, $2)`
|
||||||
pgDeleteSubscriptionTopicAllQuery = `DELETE FROM webpush_subscription_topic WHERE subscription_id = $1`
|
postgresDeleteSubscriptionTopicAllQuery = `DELETE FROM webpush_subscription_topic WHERE subscription_id = $1`
|
||||||
pgDeleteSubscriptionTopicWithoutSubscription = `DELETE FROM webpush_subscription_topic WHERE subscription_id NOT IN (SELECT id FROM webpush_subscription)`
|
postgresDeleteSubscriptionTopicWithoutSubscriptionQuery = `DELETE FROM webpush_subscription_topic WHERE subscription_id NOT IN (SELECT id FROM webpush_subscription)`
|
||||||
)
|
)
|
||||||
|
|
||||||
// PostgreSQL schema management queries
|
// PostgreSQL schema management queries
|
||||||
const (
|
const (
|
||||||
pgCurrentSchemaVersion = 1
|
pgCurrentSchemaVersion = 1
|
||||||
pgInsertSchemaVersion = `INSERT INTO schema_version (store, version) VALUES ('webpush', $1)`
|
postgresInsertSchemaVersionQuery = `INSERT INTO schema_version (store, version) VALUES ('webpush', $1)`
|
||||||
pgSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE store = 'webpush'`
|
postgresSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE store = 'webpush'`
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewPostgresStore creates a new PostgreSQL-backed web push store.
|
// NewPostgresStore creates a new PostgreSQL-backed web push store using an existing database connection pool.
|
||||||
func NewPostgresStore(dsn string) (Store, error) {
|
func NewPostgresStore(db *sql.DB) (*Store, error) {
|
||||||
db, err := sql.Open("pgx", dsn)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := db.Ping(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := setupPostgresDB(db); err != nil {
|
if err := setupPostgresDB(db); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &commonStore{
|
return &Store{
|
||||||
db: db,
|
db: db,
|
||||||
queries: storeQueries{
|
queries: queries{
|
||||||
selectSubscriptionIDByEndpoint: pgSelectSubscriptionIDByEndpoint,
|
selectSubscriptionIDByEndpoint: postgresSelectSubscriptionIDByEndpointQuery,
|
||||||
selectSubscriptionCountBySubscriberIP: pgSelectSubscriptionCountBySubscriberIP,
|
selectSubscriptionCountBySubscriberIP: postgresSelectSubscriptionCountBySubscriberIPQuery,
|
||||||
selectSubscriptionsForTopic: pgSelectSubscriptionsForTopicQuery,
|
selectSubscriptionsForTopic: postgresSelectSubscriptionsForTopicQuery,
|
||||||
selectSubscriptionsExpiringSoon: pgSelectSubscriptionsExpiringSoonQuery,
|
selectSubscriptionsExpiringSoon: postgresSelectSubscriptionsExpiringSoonQuery,
|
||||||
insertSubscription: pgInsertSubscriptionQuery,
|
upsertSubscription: postgresUpsertSubscriptionQuery,
|
||||||
updateSubscriptionWarningSent: pgUpdateSubscriptionWarningSentQuery,
|
updateSubscriptionWarningSent: postgresUpdateSubscriptionWarningSentQuery,
|
||||||
updateSubscriptionUpdatedAt: pgUpdateSubscriptionUpdatedAtQuery,
|
updateSubscriptionUpdatedAt: postgresUpdateSubscriptionUpdatedAtQuery,
|
||||||
deleteSubscriptionByEndpoint: pgDeleteSubscriptionByEndpointQuery,
|
deleteSubscriptionByEndpoint: postgresDeleteSubscriptionByEndpointQuery,
|
||||||
deleteSubscriptionByUserID: pgDeleteSubscriptionByUserIDQuery,
|
deleteSubscriptionByUserID: postgresDeleteSubscriptionByUserIDQuery,
|
||||||
deleteSubscriptionByAge: pgDeleteSubscriptionByAgeQuery,
|
deleteSubscriptionByAge: postgresDeleteSubscriptionByAgeQuery,
|
||||||
insertSubscriptionTopic: pgInsertSubscriptionTopicQuery,
|
insertSubscriptionTopic: postgresInsertSubscriptionTopicQuery,
|
||||||
deleteSubscriptionTopicAll: pgDeleteSubscriptionTopicAllQuery,
|
deleteSubscriptionTopicAll: postgresDeleteSubscriptionTopicAllQuery,
|
||||||
deleteSubscriptionTopicWithoutSubscription: pgDeleteSubscriptionTopicWithoutSubscription,
|
deleteSubscriptionTopicWithoutSubscription: postgresDeleteSubscriptionTopicWithoutSubscriptionQuery,
|
||||||
},
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupPostgresDB(db *sql.DB) error {
|
func setupPostgresDB(db *sql.DB) error {
|
||||||
var schemaVersion int
|
var schemaVersion int
|
||||||
err := db.QueryRow(pgSelectSchemaVersionQuery).Scan(&schemaVersion)
|
err := db.QueryRow(postgresSelectSchemaVersionQuery).Scan(&schemaVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return setupNewPostgresDB(db)
|
return setupNewPostgresDB(db)
|
||||||
}
|
}
|
||||||
@@ -120,10 +113,10 @@ func setupNewPostgresDB(db *sql.DB) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
if _, err := tx.Exec(pgCreateTablesQuery); err != nil {
|
if _, err := tx.Exec(postgresCreateTablesQuery); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := tx.Exec(pgInsertSchemaVersion, pgCurrentSchemaVersion); err != nil {
|
if _, err := tx.Exec(postgresInsertSchemaVersionQuery, pgCurrentSchemaVersion); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
|
|||||||
@@ -1,91 +0,0 @@
|
|||||||
package webpush_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
|
||||||
"net/url"
|
|
||||||
"os"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"heckel.io/ntfy/v2/util"
|
|
||||||
"heckel.io/ntfy/v2/webpush"
|
|
||||||
)
|
|
||||||
|
|
||||||
func newTestPostgresStore(t *testing.T) webpush.Store {
|
|
||||||
dsn := os.Getenv("NTFY_TEST_DATABASE_URL")
|
|
||||||
if dsn == "" {
|
|
||||||
t.Skip("NTFY_TEST_DATABASE_URL not set, skipping PostgreSQL tests")
|
|
||||||
}
|
|
||||||
// Create a unique schema for this test
|
|
||||||
schema := fmt.Sprintf("test_%s", util.RandomString(10))
|
|
||||||
setupDB, err := sql.Open("pgx", dsn)
|
|
||||||
require.Nil(t, err)
|
|
||||||
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Nil(t, setupDB.Close())
|
|
||||||
// Open store with search_path set to the new schema
|
|
||||||
u, err := url.Parse(dsn)
|
|
||||||
require.Nil(t, err)
|
|
||||||
q := u.Query()
|
|
||||||
q.Set("search_path", schema)
|
|
||||||
u.RawQuery = q.Encode()
|
|
||||||
store, err := webpush.NewPostgresStore(u.String())
|
|
||||||
require.Nil(t, err)
|
|
||||||
t.Cleanup(func() {
|
|
||||||
store.Close()
|
|
||||||
cleanDB, err := sql.Open("pgx", dsn)
|
|
||||||
if err == nil {
|
|
||||||
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
|
|
||||||
cleanDB.Close()
|
|
||||||
}
|
|
||||||
})
|
|
||||||
return store
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPostgresStoreUpsertSubscriptionSubscriptionsForTopic(t *testing.T) {
|
|
||||||
testStoreUpsertSubscriptionSubscriptionsForTopic(t, newTestPostgresStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPostgresStoreUpsertSubscriptionSubscriberIPLimitReached(t *testing.T) {
|
|
||||||
testStoreUpsertSubscriptionSubscriberIPLimitReached(t, newTestPostgresStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPostgresStoreUpsertSubscriptionUpdateTopics(t *testing.T) {
|
|
||||||
testStoreUpsertSubscriptionUpdateTopics(t, newTestPostgresStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPostgresStoreUpsertSubscriptionUpdateFields(t *testing.T) {
|
|
||||||
testStoreUpsertSubscriptionUpdateFields(t, newTestPostgresStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPostgresStoreRemoveByUserIDMultiple(t *testing.T) {
|
|
||||||
testStoreRemoveByUserIDMultiple(t, newTestPostgresStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPostgresStoreRemoveByEndpoint(t *testing.T) {
|
|
||||||
testStoreRemoveByEndpoint(t, newTestPostgresStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPostgresStoreRemoveByUserID(t *testing.T) {
|
|
||||||
testStoreRemoveByUserID(t, newTestPostgresStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPostgresStoreRemoveByUserIDEmpty(t *testing.T) {
|
|
||||||
testStoreRemoveByUserIDEmpty(t, newTestPostgresStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPostgresStoreExpiryWarningSent(t *testing.T) {
|
|
||||||
store := newTestPostgresStore(t)
|
|
||||||
testStoreExpiryWarningSent(t, store, store.SetSubscriptionUpdatedAt)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPostgresStoreExpiring(t *testing.T) {
|
|
||||||
store := newTestPostgresStore(t)
|
|
||||||
testStoreExpiring(t, store, store.SetSubscriptionUpdatedAt)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPostgresStoreRemoveExpired(t *testing.T) {
|
|
||||||
store := newTestPostgresStore(t)
|
|
||||||
testStoreRemoveExpired(t, store, store.SetSubscriptionUpdatedAt)
|
|
||||||
}
|
|
||||||
@@ -8,8 +8,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
sqliteCreateWebPushSubscriptionsTableQuery = `
|
sqliteCreateTablesQuery = `
|
||||||
BEGIN;
|
|
||||||
CREATE TABLE IF NOT EXISTS subscription (
|
CREATE TABLE IF NOT EXISTS subscription (
|
||||||
id TEXT PRIMARY KEY,
|
id TEXT PRIMARY KEY,
|
||||||
endpoint TEXT NOT NULL,
|
endpoint TEXT NOT NULL,
|
||||||
@@ -33,52 +32,51 @@ const (
|
|||||||
id INT PRIMARY KEY,
|
id INT PRIMARY KEY,
|
||||||
version INT NOT NULL
|
version INT NOT NULL
|
||||||
);
|
);
|
||||||
COMMIT;
|
|
||||||
`
|
`
|
||||||
sqliteBuiltinStartupQueries = `
|
sqliteBuiltinStartupQueries = `
|
||||||
PRAGMA foreign_keys = ON;
|
PRAGMA foreign_keys = ON;
|
||||||
`
|
`
|
||||||
|
|
||||||
sqliteSelectWebPushSubscriptionIDByEndpoint = `SELECT id FROM subscription WHERE endpoint = ?`
|
sqliteSelectSubscriptionIDByEndpointQuery = `SELECT id FROM subscription WHERE endpoint = ?`
|
||||||
sqliteSelectWebPushSubscriptionCountBySubscriberIP = `SELECT COUNT(*) FROM subscription WHERE subscriber_ip = ?`
|
sqliteSelectSubscriptionCountBySubscriberIPQuery = `SELECT COUNT(*) FROM subscription WHERE subscriber_ip = ?`
|
||||||
sqliteSelectWebPushSubscriptionsForTopicQuery = `
|
sqliteSelectSubscriptionsForTopicQuery = `
|
||||||
SELECT id, endpoint, key_auth, key_p256dh, user_id
|
SELECT id, endpoint, key_auth, key_p256dh, user_id
|
||||||
FROM subscription_topic st
|
FROM subscription_topic st
|
||||||
JOIN subscription s ON s.id = st.subscription_id
|
JOIN subscription s ON s.id = st.subscription_id
|
||||||
WHERE st.topic = ?
|
WHERE st.topic = ?
|
||||||
ORDER BY endpoint
|
ORDER BY endpoint
|
||||||
`
|
`
|
||||||
sqliteSelectWebPushSubscriptionsExpiringSoonQuery = `
|
sqliteSelectSubscriptionsExpiringSoonQuery = `
|
||||||
SELECT id, endpoint, key_auth, key_p256dh, user_id
|
SELECT id, endpoint, key_auth, key_p256dh, user_id
|
||||||
FROM subscription
|
FROM subscription
|
||||||
WHERE warned_at = 0 AND updated_at <= ?
|
WHERE warned_at = 0 AND updated_at <= ?
|
||||||
`
|
`
|
||||||
sqliteInsertWebPushSubscriptionQuery = `
|
sqliteUpsertSubscriptionQuery = `
|
||||||
INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at)
|
INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
ON CONFLICT (endpoint)
|
ON CONFLICT (endpoint)
|
||||||
DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at
|
DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at
|
||||||
`
|
`
|
||||||
sqliteUpdateWebPushSubscriptionWarningSentQuery = `UPDATE subscription SET warned_at = ? WHERE id = ?`
|
sqliteUpdateSubscriptionWarningSentQuery = `UPDATE subscription SET warned_at = ? WHERE id = ?`
|
||||||
sqliteUpdateWebPushSubscriptionUpdatedAtQuery = `UPDATE subscription SET updated_at = ? WHERE endpoint = ?`
|
sqliteUpdateSubscriptionUpdatedAtQuery = `UPDATE subscription SET updated_at = ? WHERE endpoint = ?`
|
||||||
sqliteDeleteWebPushSubscriptionByEndpointQuery = `DELETE FROM subscription WHERE endpoint = ?`
|
sqliteDeleteSubscriptionByEndpointQuery = `DELETE FROM subscription WHERE endpoint = ?`
|
||||||
sqliteDeleteWebPushSubscriptionByUserIDQuery = `DELETE FROM subscription WHERE user_id = ?`
|
sqliteDeleteSubscriptionByUserIDQuery = `DELETE FROM subscription WHERE user_id = ?`
|
||||||
sqliteDeleteWebPushSubscriptionByAgeQuery = `DELETE FROM subscription WHERE updated_at <= ?` // Full table scan!
|
sqliteDeleteSubscriptionByAgeQuery = `DELETE FROM subscription WHERE updated_at <= ?` // Full table scan!
|
||||||
|
|
||||||
sqliteInsertWebPushSubscriptionTopicQuery = `INSERT INTO subscription_topic (subscription_id, topic) VALUES (?, ?)`
|
sqliteInsertSubscriptionTopicQuery = `INSERT INTO subscription_topic (subscription_id, topic) VALUES (?, ?)`
|
||||||
sqliteDeleteWebPushSubscriptionTopicAllQuery = `DELETE FROM subscription_topic WHERE subscription_id = ?`
|
sqliteDeleteSubscriptionTopicAllQuery = `DELETE FROM subscription_topic WHERE subscription_id = ?`
|
||||||
sqliteDeleteWebPushSubscriptionTopicWithoutSubscription = `DELETE FROM subscription_topic WHERE subscription_id NOT IN (SELECT id FROM subscription)`
|
sqliteDeleteSubscriptionTopicWithoutSubscriptionQuery = `DELETE FROM subscription_topic WHERE subscription_id NOT IN (SELECT id FROM subscription)`
|
||||||
)
|
)
|
||||||
|
|
||||||
// SQLite schema management queries
|
// SQLite schema management queries
|
||||||
const (
|
const (
|
||||||
sqliteCurrentWebPushSchemaVersion = 1
|
sqliteCurrentSchemaVersion = 1
|
||||||
sqliteInsertWebPushSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
|
sqliteInsertSchemaVersionQuery = `INSERT INTO schemaVersion VALUES (1, ?)`
|
||||||
sqliteSelectWebPushSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
|
sqliteSelectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewSQLiteStore creates a new SQLite-backed web push store.
|
// NewSQLiteStore creates a new SQLite-backed web push store.
|
||||||
func NewSQLiteStore(filename, startupQueries string) (Store, error) {
|
func NewSQLiteStore(filename, startupQueries string) (*Store, error) {
|
||||||
db, err := sql.Open("sqlite3", filename)
|
db, err := sql.Open("sqlite3", filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -89,46 +87,51 @@ func NewSQLiteStore(filename, startupQueries string) (Store, error) {
|
|||||||
if err := runSQLiteStartupQueries(db, startupQueries); err != nil {
|
if err := runSQLiteStartupQueries(db, startupQueries); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &commonStore{
|
return &Store{
|
||||||
db: db,
|
db: db,
|
||||||
queries: storeQueries{
|
queries: queries{
|
||||||
selectSubscriptionIDByEndpoint: sqliteSelectWebPushSubscriptionIDByEndpoint,
|
selectSubscriptionIDByEndpoint: sqliteSelectSubscriptionIDByEndpointQuery,
|
||||||
selectSubscriptionCountBySubscriberIP: sqliteSelectWebPushSubscriptionCountBySubscriberIP,
|
selectSubscriptionCountBySubscriberIP: sqliteSelectSubscriptionCountBySubscriberIPQuery,
|
||||||
selectSubscriptionsForTopic: sqliteSelectWebPushSubscriptionsForTopicQuery,
|
selectSubscriptionsForTopic: sqliteSelectSubscriptionsForTopicQuery,
|
||||||
selectSubscriptionsExpiringSoon: sqliteSelectWebPushSubscriptionsExpiringSoonQuery,
|
selectSubscriptionsExpiringSoon: sqliteSelectSubscriptionsExpiringSoonQuery,
|
||||||
insertSubscription: sqliteInsertWebPushSubscriptionQuery,
|
upsertSubscription: sqliteUpsertSubscriptionQuery,
|
||||||
updateSubscriptionWarningSent: sqliteUpdateWebPushSubscriptionWarningSentQuery,
|
updateSubscriptionWarningSent: sqliteUpdateSubscriptionWarningSentQuery,
|
||||||
updateSubscriptionUpdatedAt: sqliteUpdateWebPushSubscriptionUpdatedAtQuery,
|
updateSubscriptionUpdatedAt: sqliteUpdateSubscriptionUpdatedAtQuery,
|
||||||
deleteSubscriptionByEndpoint: sqliteDeleteWebPushSubscriptionByEndpointQuery,
|
deleteSubscriptionByEndpoint: sqliteDeleteSubscriptionByEndpointQuery,
|
||||||
deleteSubscriptionByUserID: sqliteDeleteWebPushSubscriptionByUserIDQuery,
|
deleteSubscriptionByUserID: sqliteDeleteSubscriptionByUserIDQuery,
|
||||||
deleteSubscriptionByAge: sqliteDeleteWebPushSubscriptionByAgeQuery,
|
deleteSubscriptionByAge: sqliteDeleteSubscriptionByAgeQuery,
|
||||||
insertSubscriptionTopic: sqliteInsertWebPushSubscriptionTopicQuery,
|
insertSubscriptionTopic: sqliteInsertSubscriptionTopicQuery,
|
||||||
deleteSubscriptionTopicAll: sqliteDeleteWebPushSubscriptionTopicAllQuery,
|
deleteSubscriptionTopicAll: sqliteDeleteSubscriptionTopicAllQuery,
|
||||||
deleteSubscriptionTopicWithoutSubscription: sqliteDeleteWebPushSubscriptionTopicWithoutSubscription,
|
deleteSubscriptionTopicWithoutSubscription: sqliteDeleteSubscriptionTopicWithoutSubscriptionQuery,
|
||||||
},
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupSQLite(db *sql.DB) error {
|
func setupSQLite(db *sql.DB) error {
|
||||||
var schemaVersion int
|
var schemaVersion int
|
||||||
err := db.QueryRow(sqliteSelectWebPushSchemaVersionQuery).Scan(&schemaVersion)
|
err := db.QueryRow(sqliteSelectSchemaVersionQuery).Scan(&schemaVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return setupNewSQLite(db)
|
return setupNewSQLite(db)
|
||||||
}
|
}
|
||||||
if schemaVersion > sqliteCurrentWebPushSchemaVersion {
|
if schemaVersion > sqliteCurrentSchemaVersion {
|
||||||
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, sqliteCurrentWebPushSchemaVersion)
|
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, sqliteCurrentSchemaVersion)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupNewSQLite(db *sql.DB) error {
|
func setupNewSQLite(db *sql.DB) error {
|
||||||
if _, err := db.Exec(sqliteCreateWebPushSubscriptionsTableQuery); err != nil {
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := db.Exec(sqliteInsertWebPushSchemaVersion, sqliteCurrentWebPushSchemaVersion); err != nil {
|
defer tx.Rollback()
|
||||||
|
if _, err := tx.Exec(sqliteCreateTablesQuery); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
if _, err := tx.Exec(sqliteInsertSchemaVersionQuery, sqliteCurrentSchemaVersion); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return tx.Commit()
|
||||||
}
|
}
|
||||||
|
|
||||||
func runSQLiteStartupQueries(db *sql.DB, startupQueries string) error {
|
func runSQLiteStartupQueries(db *sql.DB, startupQueries string) error {
|
||||||
|
|||||||
@@ -1,63 +0,0 @@
|
|||||||
package webpush_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"path/filepath"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"heckel.io/ntfy/v2/webpush"
|
|
||||||
)
|
|
||||||
|
|
||||||
func newTestSQLiteStore(t *testing.T) webpush.Store {
|
|
||||||
store, err := webpush.NewSQLiteStore(filepath.Join(t.TempDir(), "webpush.db"), "")
|
|
||||||
require.Nil(t, err)
|
|
||||||
t.Cleanup(func() { store.Close() })
|
|
||||||
return store
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSQLiteStoreUpsertSubscriptionSubscriptionsForTopic(t *testing.T) {
|
|
||||||
testStoreUpsertSubscriptionSubscriptionsForTopic(t, newTestSQLiteStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSQLiteStoreUpsertSubscriptionSubscriberIPLimitReached(t *testing.T) {
|
|
||||||
testStoreUpsertSubscriptionSubscriberIPLimitReached(t, newTestSQLiteStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSQLiteStoreUpsertSubscriptionUpdateTopics(t *testing.T) {
|
|
||||||
testStoreUpsertSubscriptionUpdateTopics(t, newTestSQLiteStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSQLiteStoreUpsertSubscriptionUpdateFields(t *testing.T) {
|
|
||||||
testStoreUpsertSubscriptionUpdateFields(t, newTestSQLiteStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSQLiteStoreRemoveByUserIDMultiple(t *testing.T) {
|
|
||||||
testStoreRemoveByUserIDMultiple(t, newTestSQLiteStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSQLiteStoreRemoveByEndpoint(t *testing.T) {
|
|
||||||
testStoreRemoveByEndpoint(t, newTestSQLiteStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSQLiteStoreRemoveByUserID(t *testing.T) {
|
|
||||||
testStoreRemoveByUserID(t, newTestSQLiteStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSQLiteStoreRemoveByUserIDEmpty(t *testing.T) {
|
|
||||||
testStoreRemoveByUserIDEmpty(t, newTestSQLiteStore(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSQLiteStoreExpiryWarningSent(t *testing.T) {
|
|
||||||
store := newTestSQLiteStore(t)
|
|
||||||
testStoreExpiryWarningSent(t, store, store.SetSubscriptionUpdatedAt)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSQLiteStoreExpiring(t *testing.T) {
|
|
||||||
store := newTestSQLiteStore(t)
|
|
||||||
testStoreExpiring(t, store, store.SetSubscriptionUpdatedAt)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSQLiteStoreRemoveExpired(t *testing.T) {
|
|
||||||
store := newTestSQLiteStore(t)
|
|
||||||
testStoreRemoveExpired(t, store, store.SetSubscriptionUpdatedAt)
|
|
||||||
}
|
|
||||||
@@ -3,16 +3,34 @@ package webpush_test
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
dbtest "heckel.io/ntfy/v2/db/test"
|
||||||
"heckel.io/ntfy/v2/webpush"
|
"heckel.io/ntfy/v2/webpush"
|
||||||
)
|
)
|
||||||
|
|
||||||
const testWebPushEndpoint = "https://updates.push.services.mozilla.com/wpush/v1/AAABBCCCDDEEEFFF"
|
const testWebPushEndpoint = "https://updates.push.services.mozilla.com/wpush/v1/AAABBCCCDDEEEFFF"
|
||||||
|
|
||||||
func testStoreUpsertSubscriptionSubscriptionsForTopic(t *testing.T, store webpush.Store) {
|
func forEachBackend(t *testing.T, f func(t *testing.T, store *webpush.Store)) {
|
||||||
|
t.Run("sqlite", func(t *testing.T) {
|
||||||
|
store, err := webpush.NewSQLiteStore(filepath.Join(t.TempDir(), "webpush.db"), "")
|
||||||
|
require.Nil(t, err)
|
||||||
|
t.Cleanup(func() { store.Close() })
|
||||||
|
f(t, store)
|
||||||
|
})
|
||||||
|
t.Run("postgres", func(t *testing.T) {
|
||||||
|
testDB := dbtest.CreateTestPostgres(t)
|
||||||
|
store, err := webpush.NewPostgresStore(testDB)
|
||||||
|
require.Nil(t, err)
|
||||||
|
f(t, store)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStoreUpsertSubscriptionSubscriptionsForTopic(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
|
||||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
|
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
|
||||||
|
|
||||||
subs, err := store.SubscriptionsForTopic("test-topic")
|
subs, err := store.SubscriptionsForTopic("test-topic")
|
||||||
@@ -27,9 +45,11 @@ func testStoreUpsertSubscriptionSubscriptionsForTopic(t *testing.T, store webpus
|
|||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Len(t, subs2, 1)
|
require.Len(t, subs2, 1)
|
||||||
require.Equal(t, subs[0].Endpoint, subs2[0].Endpoint)
|
require.Equal(t, subs[0].Endpoint, subs2[0].Endpoint)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func testStoreUpsertSubscriptionSubscriberIPLimitReached(t *testing.T, store webpush.Store) {
|
func TestStoreUpsertSubscriptionSubscriberIPLimitReached(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
|
||||||
// Insert 10 subscriptions with the same IP address
|
// Insert 10 subscriptions with the same IP address
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
endpoint := fmt.Sprintf(testWebPushEndpoint+"%d", i)
|
endpoint := fmt.Sprintf(testWebPushEndpoint+"%d", i)
|
||||||
@@ -44,9 +64,11 @@ func testStoreUpsertSubscriptionSubscriberIPLimitReached(t *testing.T, store web
|
|||||||
|
|
||||||
// But with a different IP address it should be fine again
|
// But with a different IP address it should be fine again
|
||||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"99", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("9.9.9.9"), []string{"test-topic", "mytopic"}))
|
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"99", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("9.9.9.9"), []string{"test-topic", "mytopic"}))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func testStoreUpsertSubscriptionUpdateTopics(t *testing.T, store webpush.Store) {
|
func TestStoreUpsertSubscriptionUpdateTopics(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
|
||||||
// Insert subscription with two topics, and another with one topic
|
// Insert subscription with two topics, and another with one topic
|
||||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"1", "auth-key", "p256dh-key", "", netip.MustParseAddr("9.9.9.9"), []string{"topic1"}))
|
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"1", "auth-key", "p256dh-key", "", netip.MustParseAddr("9.9.9.9"), []string{"topic1"}))
|
||||||
@@ -73,9 +95,11 @@ func testStoreUpsertSubscriptionUpdateTopics(t *testing.T, store webpush.Store)
|
|||||||
subs, err = store.SubscriptionsForTopic("topic2")
|
subs, err = store.SubscriptionsForTopic("topic2")
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Len(t, subs, 0)
|
require.Len(t, subs, 0)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func testStoreUpsertSubscriptionUpdateFields(t *testing.T, store webpush.Store) {
|
func TestStoreUpsertSubscriptionUpdateFields(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
|
||||||
// Insert a subscription
|
// Insert a subscription
|
||||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
|
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
|
||||||
|
|
||||||
@@ -96,9 +120,11 @@ func testStoreUpsertSubscriptionUpdateFields(t *testing.T, store webpush.Store)
|
|||||||
require.Equal(t, "new-auth", subs[0].Auth)
|
require.Equal(t, "new-auth", subs[0].Auth)
|
||||||
require.Equal(t, "new-p256dh", subs[0].P256dh)
|
require.Equal(t, "new-p256dh", subs[0].P256dh)
|
||||||
require.Equal(t, "u_5678", subs[0].UserID)
|
require.Equal(t, "u_5678", subs[0].UserID)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func testStoreRemoveByUserIDMultiple(t *testing.T, store webpush.Store) {
|
func TestStoreRemoveByUserIDMultiple(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
|
||||||
// Insert two subscriptions for u_1234 and one for u_5678
|
// Insert two subscriptions for u_1234 and one for u_5678
|
||||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
|
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
|
||||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"1", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
|
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint+"1", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
|
||||||
@@ -117,9 +143,11 @@ func testStoreRemoveByUserIDMultiple(t *testing.T, store webpush.Store) {
|
|||||||
require.Len(t, subs, 1)
|
require.Len(t, subs, 1)
|
||||||
require.Equal(t, testWebPushEndpoint+"2", subs[0].Endpoint)
|
require.Equal(t, testWebPushEndpoint+"2", subs[0].Endpoint)
|
||||||
require.Equal(t, "u_5678", subs[0].UserID)
|
require.Equal(t, "u_5678", subs[0].UserID)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func testStoreRemoveByEndpoint(t *testing.T, store webpush.Store) {
|
func TestStoreRemoveByEndpoint(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
|
||||||
// Insert subscription with two topics
|
// Insert subscription with two topics
|
||||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||||
subs, err := store.SubscriptionsForTopic("topic1")
|
subs, err := store.SubscriptionsForTopic("topic1")
|
||||||
@@ -131,9 +159,11 @@ func testStoreRemoveByEndpoint(t *testing.T, store webpush.Store) {
|
|||||||
subs, err = store.SubscriptionsForTopic("topic1")
|
subs, err = store.SubscriptionsForTopic("topic1")
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Len(t, subs, 0)
|
require.Len(t, subs, 0)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func testStoreRemoveByUserID(t *testing.T, store webpush.Store) {
|
func TestStoreRemoveByUserID(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
|
||||||
// Insert subscription with two topics
|
// Insert subscription with two topics
|
||||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||||
subs, err := store.SubscriptionsForTopic("topic1")
|
subs, err := store.SubscriptionsForTopic("topic1")
|
||||||
@@ -145,18 +175,22 @@ func testStoreRemoveByUserID(t *testing.T, store webpush.Store) {
|
|||||||
subs, err = store.SubscriptionsForTopic("topic1")
|
subs, err = store.SubscriptionsForTopic("topic1")
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Len(t, subs, 0)
|
require.Len(t, subs, 0)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func testStoreRemoveByUserIDEmpty(t *testing.T, store webpush.Store) {
|
func TestStoreRemoveByUserIDEmpty(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
|
||||||
require.Equal(t, webpush.ErrWebPushUserIDCannotBeEmpty, store.RemoveSubscriptionsByUserID(""))
|
require.Equal(t, webpush.ErrWebPushUserIDCannotBeEmpty, store.RemoveSubscriptionsByUserID(""))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func testStoreExpiryWarningSent(t *testing.T, store webpush.Store, setUpdatedAt func(endpoint string, updatedAt int64) error) {
|
func TestStoreExpiryWarningSent(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
|
||||||
// Insert subscription with two topics
|
// Insert subscription with two topics
|
||||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||||
|
|
||||||
// Set updated_at to the past so it shows up as expiring
|
// Set updated_at to the past so it shows up as expiring
|
||||||
require.Nil(t, setUpdatedAt(testWebPushEndpoint, time.Now().Add(-8*24*time.Hour).Unix()))
|
require.Nil(t, store.SetSubscriptionUpdatedAt(testWebPushEndpoint, time.Now().Add(-8*24*time.Hour).Unix()))
|
||||||
|
|
||||||
// Verify subscription appears in expiring list (warned_at == 0)
|
// Verify subscription appears in expiring list (warned_at == 0)
|
||||||
subs, err := store.SubscriptionsExpiring(7 * 24 * time.Hour)
|
subs, err := store.SubscriptionsExpiring(7 * 24 * time.Hour)
|
||||||
@@ -171,9 +205,11 @@ func testStoreExpiryWarningSent(t *testing.T, store webpush.Store, setUpdatedAt
|
|||||||
subs, err = store.SubscriptionsExpiring(7 * 24 * time.Hour)
|
subs, err = store.SubscriptionsExpiring(7 * 24 * time.Hour)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Len(t, subs, 0)
|
require.Len(t, subs, 0)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func testStoreExpiring(t *testing.T, store webpush.Store, setUpdatedAt func(endpoint string, updatedAt int64) error) {
|
func TestStoreExpiring(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
|
||||||
// Insert subscription with two topics
|
// Insert subscription with two topics
|
||||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||||
subs, err := store.SubscriptionsForTopic("topic1")
|
subs, err := store.SubscriptionsForTopic("topic1")
|
||||||
@@ -181,7 +217,7 @@ func testStoreExpiring(t *testing.T, store webpush.Store, setUpdatedAt func(endp
|
|||||||
require.Len(t, subs, 1)
|
require.Len(t, subs, 1)
|
||||||
|
|
||||||
// Fake-mark them as soon-to-expire
|
// Fake-mark them as soon-to-expire
|
||||||
require.Nil(t, setUpdatedAt(testWebPushEndpoint, time.Now().Add(-8*24*time.Hour).Unix()))
|
require.Nil(t, store.SetSubscriptionUpdatedAt(testWebPushEndpoint, time.Now().Add(-8*24*time.Hour).Unix()))
|
||||||
|
|
||||||
// Should not be cleaned up yet
|
// Should not be cleaned up yet
|
||||||
require.Nil(t, store.RemoveExpiredSubscriptions(9*24*time.Hour))
|
require.Nil(t, store.RemoveExpiredSubscriptions(9*24*time.Hour))
|
||||||
@@ -191,9 +227,11 @@ func testStoreExpiring(t *testing.T, store webpush.Store, setUpdatedAt func(endp
|
|||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Len(t, subs, 1)
|
require.Len(t, subs, 1)
|
||||||
require.Equal(t, testWebPushEndpoint, subs[0].Endpoint)
|
require.Equal(t, testWebPushEndpoint, subs[0].Endpoint)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func testStoreRemoveExpired(t *testing.T, store webpush.Store, setUpdatedAt func(endpoint string, updatedAt int64) error) {
|
func TestStoreRemoveExpired(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T, store *webpush.Store) {
|
||||||
// Insert subscription with two topics
|
// Insert subscription with two topics
|
||||||
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
require.Nil(t, store.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||||
subs, err := store.SubscriptionsForTopic("topic1")
|
subs, err := store.SubscriptionsForTopic("topic1")
|
||||||
@@ -201,7 +239,7 @@ func testStoreRemoveExpired(t *testing.T, store webpush.Store, setUpdatedAt func
|
|||||||
require.Len(t, subs, 1)
|
require.Len(t, subs, 1)
|
||||||
|
|
||||||
// Fake-mark them as expired
|
// Fake-mark them as expired
|
||||||
require.Nil(t, setUpdatedAt(testWebPushEndpoint, time.Now().Add(-10*24*time.Hour).Unix()))
|
require.Nil(t, store.SetSubscriptionUpdatedAt(testWebPushEndpoint, time.Now().Add(-10*24*time.Hour).Unix()))
|
||||||
|
|
||||||
// Run expiration
|
// Run expiration
|
||||||
require.Nil(t, store.RemoveExpiredSubscriptions(9*24*time.Hour))
|
require.Nil(t, store.RemoveExpiredSubscriptions(9*24*time.Hour))
|
||||||
@@ -210,4 +248,5 @@ func testStoreRemoveExpired(t *testing.T, store webpush.Store, setUpdatedAt func
|
|||||||
subs, err = store.SubscriptionsForTopic("topic1")
|
subs, err = store.SubscriptionsForTopic("topic1")
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Len(t, subs, 0)
|
require.Len(t, subs, 0)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user