Refine
This commit is contained in:
21
cmd/serve.go
21
cmd/serve.go
@@ -282,7 +282,9 @@ func execServe(c *cli.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check values
|
// Check values
|
||||||
if firebaseKeyFile != "" && !util.FileExists(firebaseKeyFile) {
|
if databaseURL != "" && (authFile != "" || cacheFile != "" || webPushFile != "") {
|
||||||
|
return errors.New("if database-url is set, auth-file, cache-file, and web-push-file must not be set")
|
||||||
|
} else if firebaseKeyFile != "" && !util.FileExists(firebaseKeyFile) {
|
||||||
return errors.New("if set, FCM key file must exist")
|
return errors.New("if set, FCM key file must exist")
|
||||||
} else if firebaseKeyFile != "" && !server.FirebaseAvailable {
|
} else if firebaseKeyFile != "" && !server.FirebaseAvailable {
|
||||||
return errors.New("cannot set firebase-key-file, support for Firebase is not available (nofirebase)")
|
return errors.New("cannot set firebase-key-file, support for Firebase is not available (nofirebase)")
|
||||||
@@ -414,6 +416,15 @@ func execServe(c *cli.Context) error {
|
|||||||
payments.Setup(stripeSecretKey)
|
payments.Setup(stripeSecretKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Parse Twilio template
|
||||||
|
var twilioCallFormatTemplate *template.Template
|
||||||
|
if twilioCallFormat != "" {
|
||||||
|
twilioCallFormatTemplate, err = template.New("").Parse(twilioCallFormat)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse twilio-call-format template: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Add default forbidden topics
|
// Add default forbidden topics
|
||||||
disallowedTopics = append(disallowedTopics, server.DefaultDisallowedTopics...)
|
disallowedTopics = append(disallowedTopics, server.DefaultDisallowedTopics...)
|
||||||
|
|
||||||
@@ -461,13 +472,7 @@ func execServe(c *cli.Context) error {
|
|||||||
conf.TwilioAuthToken = twilioAuthToken
|
conf.TwilioAuthToken = twilioAuthToken
|
||||||
conf.TwilioPhoneNumber = twilioPhoneNumber
|
conf.TwilioPhoneNumber = twilioPhoneNumber
|
||||||
conf.TwilioVerifyService = twilioVerifyService
|
conf.TwilioVerifyService = twilioVerifyService
|
||||||
if twilioCallFormat != "" {
|
conf.TwilioCallFormat = twilioCallFormatTemplate
|
||||||
tmpl, err := template.New("twiml").Parse(twilioCallFormat)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to parse twilio-call-format template: %w", err)
|
|
||||||
}
|
|
||||||
conf.TwilioCallFormat = tmpl
|
|
||||||
}
|
|
||||||
conf.MessageSizeLimit = int(messageSizeLimit)
|
conf.MessageSizeLimit = int(messageSizeLimit)
|
||||||
conf.MessageDelayMax = messageDelayLimit
|
conf.MessageDelayMax = messageDelayLimit
|
||||||
conf.TotalTopicLimit = totalTopicLimit
|
conf.TotalTopicLimit = totalTopicLimit
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import (
|
|||||||
|
|
||||||
"github.com/urfave/cli/v2"
|
"github.com/urfave/cli/v2"
|
||||||
"github.com/urfave/cli/v2/altsrc"
|
"github.com/urfave/cli/v2/altsrc"
|
||||||
"heckel.io/ntfy/v2/postgres"
|
"heckel.io/ntfy/v2/db"
|
||||||
"heckel.io/ntfy/v2/server"
|
"heckel.io/ntfy/v2/server"
|
||||||
"heckel.io/ntfy/v2/user"
|
"heckel.io/ntfy/v2/user"
|
||||||
"heckel.io/ntfy/v2/util"
|
"heckel.io/ntfy/v2/util"
|
||||||
@@ -380,11 +380,11 @@ func createUserManager(c *cli.Context) (*user.Manager, error) {
|
|||||||
}
|
}
|
||||||
var store user.Store
|
var store user.Store
|
||||||
if databaseURL != "" {
|
if databaseURL != "" {
|
||||||
db, dbErr := postgres.OpenDB(databaseURL)
|
pool, dbErr := db.Open(databaseURL)
|
||||||
if dbErr != nil {
|
if dbErr != nil {
|
||||||
return nil, dbErr
|
return nil, dbErr
|
||||||
}
|
}
|
||||||
store, err = user.NewPostgresStore(db)
|
store, err = user.NewPostgresStore(pool)
|
||||||
} else if authFile != "" {
|
} else if authFile != "" {
|
||||||
if !util.FileExists(authFile) {
|
if !util.FileExists(authFile) {
|
||||||
return nil, errors.New("auth-file does not exist; please start the server at least once to create it")
|
return nil, errors.New("auth-file does not exist; please start the server at least once to create it")
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package postgres
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
@@ -12,11 +12,11 @@ import (
|
|||||||
|
|
||||||
const defaultMaxOpenConns = 10
|
const defaultMaxOpenConns = 10
|
||||||
|
|
||||||
// OpenDB opens a PostgreSQL database connection pool from a DSN string. It supports custom
|
// Open opens a PostgreSQL database connection pool from a DSN string. It supports custom
|
||||||
// query parameters for pool configuration: pool_max_conns (default 10), pool_max_idle_conns,
|
// query parameters for pool configuration: pool_max_conns (default 10), pool_max_idle_conns,
|
||||||
// pool_conn_max_lifetime, and pool_conn_max_idle_time. These parameters are stripped from
|
// pool_conn_max_lifetime, and pool_conn_max_idle_time. These parameters are stripped from
|
||||||
// the DSN before passing it to the driver.
|
// the DSN before passing it to the driver.
|
||||||
func OpenDB(dsn string) (*sql.DB, error) {
|
func Open(dsn string) (*sql.DB, error) {
|
||||||
u, err := url.Parse(dsn)
|
u, err := url.Parse(dsn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid database URL: %w", err)
|
return nil, fmt.Errorf("invalid database URL: %w", err)
|
||||||
63
db/test/test.go
Normal file
63
db/test/test.go
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
package dbtest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"heckel.io/ntfy/v2/db"
|
||||||
|
"heckel.io/ntfy/v2/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
const testPoolMaxConns = "2"
|
||||||
|
|
||||||
|
// CreateTestSchema creates a temporary PostgreSQL schema and returns the DSN pointing to it.
|
||||||
|
// It registers a cleanup function to drop the schema when the test finishes.
|
||||||
|
// If NTFY_TEST_DATABASE_URL is not set, the test is skipped.
|
||||||
|
func CreateTestSchema(t *testing.T) string {
|
||||||
|
t.Helper()
|
||||||
|
dsn := os.Getenv("NTFY_TEST_DATABASE_URL")
|
||||||
|
if dsn == "" {
|
||||||
|
t.Skip("NTFY_TEST_DATABASE_URL not set")
|
||||||
|
}
|
||||||
|
schema := fmt.Sprintf("test_%s", util.RandomString(10))
|
||||||
|
u, err := url.Parse(dsn)
|
||||||
|
require.Nil(t, err)
|
||||||
|
q := u.Query()
|
||||||
|
q.Set("pool_max_conns", testPoolMaxConns)
|
||||||
|
u.RawQuery = q.Encode()
|
||||||
|
dsn = u.String()
|
||||||
|
setupDB, err := db.Open(dsn)
|
||||||
|
require.Nil(t, err)
|
||||||
|
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Nil(t, setupDB.Close())
|
||||||
|
q.Set("search_path", schema)
|
||||||
|
u.RawQuery = q.Encode()
|
||||||
|
schemaDSN := u.String()
|
||||||
|
t.Cleanup(func() {
|
||||||
|
cleanDB, err := db.Open(dsn)
|
||||||
|
if err == nil {
|
||||||
|
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
|
||||||
|
cleanDB.Close()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return schemaDSN
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateTestDB creates a temporary PostgreSQL schema and returns an open *sql.DB connection to it.
|
||||||
|
// It registers cleanup functions to close the DB and drop the schema when the test finishes.
|
||||||
|
// If NTFY_TEST_DATABASE_URL is not set, the test is skipped.
|
||||||
|
func CreateTestDB(t *testing.T) *sql.DB {
|
||||||
|
t.Helper()
|
||||||
|
schemaDSN := CreateTestSchema(t)
|
||||||
|
testDB, err := db.Open(schemaDSN)
|
||||||
|
require.Nil(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
testDB.Close()
|
||||||
|
})
|
||||||
|
return testDB
|
||||||
|
}
|
||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -614,9 +613,3 @@ func readMessage(rows *sql.Rows) (*model.Message, error) {
|
|||||||
Encoding: encoding,
|
Encoding: encoding,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure commonStore implements Store
|
|
||||||
var _ Store = (*commonStore)(nil)
|
|
||||||
|
|
||||||
// Needed by store.go but not part of Store interface; unused import guard
|
|
||||||
var _ = fmt.Sprintf
|
|
||||||
|
|||||||
@@ -1,49 +1,18 @@
|
|||||||
package message_test
|
package message_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"net/url"
|
|
||||||
"os"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
dbtest "heckel.io/ntfy/v2/db/test"
|
||||||
"heckel.io/ntfy/v2/message"
|
"heckel.io/ntfy/v2/message"
|
||||||
"heckel.io/ntfy/v2/postgres"
|
|
||||||
"heckel.io/ntfy/v2/util"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newTestPostgresStore(t *testing.T) message.Store {
|
func newTestPostgresStore(t *testing.T) message.Store {
|
||||||
dsn := os.Getenv("NTFY_TEST_DATABASE_URL")
|
testDB := dbtest.CreateTestDB(t)
|
||||||
if dsn == "" {
|
store, err := message.NewPostgresStore(testDB, 0, 0)
|
||||||
t.Skip("NTFY_TEST_DATABASE_URL not set, skipping PostgreSQL tests")
|
|
||||||
}
|
|
||||||
schema := fmt.Sprintf("test_%s", util.RandomString(10))
|
|
||||||
u, err := url.Parse(dsn)
|
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
q := u.Query()
|
|
||||||
q.Set("pool_max_conns", "2")
|
|
||||||
u.RawQuery = q.Encode()
|
|
||||||
dsn = u.String()
|
|
||||||
q.Set("search_path", schema)
|
|
||||||
u.RawQuery = q.Encode()
|
|
||||||
schemaDSN := u.String()
|
|
||||||
setupDB, err := postgres.OpenDB(dsn)
|
|
||||||
require.Nil(t, err)
|
|
||||||
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Nil(t, setupDB.Close())
|
|
||||||
db, err := postgres.OpenDB(schemaDSN)
|
|
||||||
require.Nil(t, err)
|
|
||||||
store, err := message.NewPostgresStore(db, 0, 0)
|
|
||||||
require.Nil(t, err)
|
|
||||||
t.Cleanup(func() {
|
|
||||||
store.Close()
|
|
||||||
cleanDB, err := postgres.OpenDB(dsn)
|
|
||||||
if err == nil {
|
|
||||||
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
|
|
||||||
cleanDB.Close()
|
|
||||||
}
|
|
||||||
})
|
|
||||||
return store
|
return store
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -33,11 +33,11 @@ import (
|
|||||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
"gopkg.in/yaml.v2"
|
"gopkg.in/yaml.v2"
|
||||||
|
"heckel.io/ntfy/v2/db"
|
||||||
"heckel.io/ntfy/v2/log"
|
"heckel.io/ntfy/v2/log"
|
||||||
"heckel.io/ntfy/v2/message"
|
"heckel.io/ntfy/v2/message"
|
||||||
"heckel.io/ntfy/v2/model"
|
"heckel.io/ntfy/v2/model"
|
||||||
"heckel.io/ntfy/v2/payments"
|
"heckel.io/ntfy/v2/payments"
|
||||||
"heckel.io/ntfy/v2/postgres"
|
|
||||||
"heckel.io/ntfy/v2/user"
|
"heckel.io/ntfy/v2/user"
|
||||||
"heckel.io/ntfy/v2/util"
|
"heckel.io/ntfy/v2/util"
|
||||||
"heckel.io/ntfy/v2/util/sprig"
|
"heckel.io/ntfy/v2/util/sprig"
|
||||||
@@ -179,22 +179,22 @@ func New(conf *Config) (*Server, error) {
|
|||||||
stripe = newStripeAPI()
|
stripe = newStripeAPI()
|
||||||
}
|
}
|
||||||
// Open shared PostgreSQL connection pool if configured
|
// Open shared PostgreSQL connection pool if configured
|
||||||
var db *sql.DB
|
var pool *sql.DB
|
||||||
if conf.DatabaseURL != "" {
|
if conf.DatabaseURL != "" {
|
||||||
var err error
|
var err error
|
||||||
db, err = postgres.OpenDB(conf.DatabaseURL)
|
pool, err = db.Open(conf.DatabaseURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
messageCache, err := createMessageCache(conf, db)
|
messageCache, err := createMessageCache(conf, pool)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var wp webpush.Store
|
var wp webpush.Store
|
||||||
if conf.WebPushPublicKey != "" {
|
if conf.WebPushPublicKey != "" {
|
||||||
if db != nil {
|
if pool != nil {
|
||||||
wp, err = webpush.NewPostgresStore(db)
|
wp, err = webpush.NewPostgresStore(pool)
|
||||||
} else {
|
} else {
|
||||||
wp, err = webpush.NewSQLiteStore(conf.WebPushFile, conf.WebPushStartupQueries)
|
wp, err = webpush.NewSQLiteStore(conf.WebPushFile, conf.WebPushStartupQueries)
|
||||||
}
|
}
|
||||||
@@ -222,7 +222,7 @@ func New(conf *Config) (*Server, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
var userManager *user.Manager
|
var userManager *user.Manager
|
||||||
if conf.AuthFile != "" || db != nil {
|
if conf.AuthFile != "" || pool != nil {
|
||||||
authConfig := &user.Config{
|
authConfig := &user.Config{
|
||||||
Filename: conf.AuthFile,
|
Filename: conf.AuthFile,
|
||||||
DatabaseURL: conf.DatabaseURL,
|
DatabaseURL: conf.DatabaseURL,
|
||||||
@@ -236,8 +236,8 @@ func New(conf *Config) (*Server, error) {
|
|||||||
QueueWriterInterval: conf.AuthStatsQueueWriterInterval,
|
QueueWriterInterval: conf.AuthStatsQueueWriterInterval,
|
||||||
}
|
}
|
||||||
var store user.Store
|
var store user.Store
|
||||||
if db != nil {
|
if pool != nil {
|
||||||
store, err = user.NewPostgresStore(db)
|
store, err = user.NewPostgresStore(pool)
|
||||||
} else {
|
} else {
|
||||||
store, err = user.NewSQLiteStore(conf.AuthFile, conf.AuthStartupQueries)
|
store, err = user.NewSQLiteStore(conf.AuthFile, conf.AuthStartupQueries)
|
||||||
}
|
}
|
||||||
@@ -265,7 +265,7 @@ func New(conf *Config) (*Server, error) {
|
|||||||
}
|
}
|
||||||
s := &Server{
|
s := &Server{
|
||||||
config: conf,
|
config: conf,
|
||||||
db: db,
|
db: pool,
|
||||||
messageCache: messageCache,
|
messageCache: messageCache,
|
||||||
webPush: wp,
|
webPush: wp,
|
||||||
fileCache: fileCache,
|
fileCache: fileCache,
|
||||||
@@ -282,11 +282,11 @@ func New(conf *Config) (*Server, error) {
|
|||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createMessageCache(conf *Config, db *sql.DB) (message.Store, error) {
|
func createMessageCache(conf *Config, pool *sql.DB) (message.Store, error) {
|
||||||
if conf.CacheDuration == 0 {
|
if conf.CacheDuration == 0 {
|
||||||
return message.NewNopStore()
|
return message.NewNopStore()
|
||||||
} else if db != nil {
|
} else if pool != nil {
|
||||||
return message.NewPostgresStore(db, conf.CacheBatchSize, conf.CacheBatchTimeout)
|
return message.NewPostgresStore(pool, conf.CacheBatchSize, conf.CacheBatchTimeout)
|
||||||
} else if conf.CacheFile != "" {
|
} else if conf.CacheFile != "" {
|
||||||
return message.NewSQLiteStore(conf.CacheFile, conf.CacheStartupQueries, conf.CacheDuration, conf.CacheBatchSize, conf.CacheBatchTimeout, false)
|
return message.NewSQLiteStore(conf.CacheFile, conf.CacheStartupQueries, conf.CacheDuration, conf.CacheBatchSize, conf.CacheBatchTimeout, false)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"net/url"
|
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
@@ -25,10 +24,10 @@ import (
|
|||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
dbtest "heckel.io/ntfy/v2/db/test"
|
||||||
"heckel.io/ntfy/v2/log"
|
"heckel.io/ntfy/v2/log"
|
||||||
"heckel.io/ntfy/v2/message"
|
"heckel.io/ntfy/v2/message"
|
||||||
"heckel.io/ntfy/v2/model"
|
"heckel.io/ntfy/v2/model"
|
||||||
"heckel.io/ntfy/v2/postgres"
|
|
||||||
"heckel.io/ntfy/v2/user"
|
"heckel.io/ntfy/v2/user"
|
||||||
"heckel.io/ntfy/v2/util"
|
"heckel.io/ntfy/v2/util"
|
||||||
)
|
)
|
||||||
@@ -4129,33 +4128,7 @@ func forEachBackend(t *testing.T, f func(t *testing.T, databaseURL string)) {
|
|||||||
f(t, "")
|
f(t, "")
|
||||||
})
|
})
|
||||||
t.Run("postgres", func(t *testing.T) {
|
t.Run("postgres", func(t *testing.T) {
|
||||||
dsn := os.Getenv("NTFY_TEST_DATABASE_URL")
|
f(t, dbtest.CreateTestSchema(t))
|
||||||
if dsn == "" {
|
|
||||||
t.Skip("NTFY_TEST_DATABASE_URL not set")
|
|
||||||
}
|
|
||||||
schema := fmt.Sprintf("test_%s", util.RandomString(10))
|
|
||||||
u, err := url.Parse(dsn)
|
|
||||||
require.Nil(t, err)
|
|
||||||
q := u.Query()
|
|
||||||
q.Set("pool_max_conns", "2")
|
|
||||||
u.RawQuery = q.Encode()
|
|
||||||
dsn = u.String()
|
|
||||||
setupDB, err := postgres.OpenDB(dsn)
|
|
||||||
require.Nil(t, err, "failed to open postgres: %s", err)
|
|
||||||
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Nil(t, setupDB.Close())
|
|
||||||
q.Set("search_path", schema)
|
|
||||||
u.RawQuery = q.Encode()
|
|
||||||
schemaDSN := u.String()
|
|
||||||
t.Cleanup(func() {
|
|
||||||
cleanDB, _ := postgres.OpenDB(dsn)
|
|
||||||
if cleanDB != nil {
|
|
||||||
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
|
|
||||||
cleanDB.Close()
|
|
||||||
}
|
|
||||||
})
|
|
||||||
f(t, schemaDSN)
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,8 +4,6 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"net/url"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -13,7 +11,8 @@ import (
|
|||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
"heckel.io/ntfy/v2/postgres"
|
"heckel.io/ntfy/v2/db"
|
||||||
|
dbtest "heckel.io/ntfy/v2/db/test"
|
||||||
"heckel.io/ntfy/v2/util"
|
"heckel.io/ntfy/v2/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -34,36 +33,11 @@ func forEachBackend(t *testing.T, f func(t *testing.T, newStore newStoreFunc)) {
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
t.Run("postgres", func(t *testing.T) {
|
t.Run("postgres", func(t *testing.T) {
|
||||||
dsn := os.Getenv("NTFY_TEST_DATABASE_URL")
|
schemaDSN := dbtest.CreateTestSchema(t)
|
||||||
if dsn == "" {
|
|
||||||
t.Skip("NTFY_TEST_DATABASE_URL not set")
|
|
||||||
}
|
|
||||||
schema := fmt.Sprintf("test_%s", util.RandomString(10))
|
|
||||||
u, err := url.Parse(dsn)
|
|
||||||
require.Nil(t, err)
|
|
||||||
q := u.Query()
|
|
||||||
q.Set("pool_max_conns", "2")
|
|
||||||
u.RawQuery = q.Encode()
|
|
||||||
dsn = u.String()
|
|
||||||
setupDB, err := postgres.OpenDB(dsn)
|
|
||||||
require.Nil(t, err)
|
|
||||||
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Nil(t, setupDB.Close())
|
|
||||||
q.Set("search_path", schema)
|
|
||||||
u.RawQuery = q.Encode()
|
|
||||||
schemaDSN := u.String()
|
|
||||||
t.Cleanup(func() {
|
|
||||||
cleanDB, _ := postgres.OpenDB(dsn)
|
|
||||||
if cleanDB != nil {
|
|
||||||
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
|
|
||||||
cleanDB.Close()
|
|
||||||
}
|
|
||||||
})
|
|
||||||
f(t, func() Store {
|
f(t, func() Store {
|
||||||
db, err := postgres.OpenDB(schemaDSN)
|
pool, err := db.Open(schemaDSN)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
store, err := NewPostgresStore(db)
|
store, err := NewPostgresStore(pool)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
return store
|
return store
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -1,49 +1,17 @@
|
|||||||
package user_test
|
package user_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"net/url"
|
|
||||||
"os"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"heckel.io/ntfy/v2/postgres"
|
dbtest "heckel.io/ntfy/v2/db/test"
|
||||||
"heckel.io/ntfy/v2/user"
|
"heckel.io/ntfy/v2/user"
|
||||||
"heckel.io/ntfy/v2/util"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func newTestPostgresStore(t *testing.T) user.Store {
|
func newTestPostgresStore(t *testing.T) user.Store {
|
||||||
dsn := os.Getenv("NTFY_TEST_DATABASE_URL")
|
testDB := dbtest.CreateTestDB(t)
|
||||||
if dsn == "" {
|
store, err := user.NewPostgresStore(testDB)
|
||||||
t.Skip("NTFY_TEST_DATABASE_URL not set, skipping PostgreSQL tests")
|
|
||||||
}
|
|
||||||
schema := fmt.Sprintf("test_%s", util.RandomString(10))
|
|
||||||
u, err := url.Parse(dsn)
|
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
q := u.Query()
|
|
||||||
q.Set("pool_max_conns", "2")
|
|
||||||
u.RawQuery = q.Encode()
|
|
||||||
dsn = u.String()
|
|
||||||
q.Set("search_path", schema)
|
|
||||||
u.RawQuery = q.Encode()
|
|
||||||
schemaDSN := u.String()
|
|
||||||
setupDB, err := postgres.OpenDB(dsn)
|
|
||||||
require.Nil(t, err)
|
|
||||||
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Nil(t, setupDB.Close())
|
|
||||||
db, err := postgres.OpenDB(schemaDSN)
|
|
||||||
require.Nil(t, err)
|
|
||||||
store, err := user.NewPostgresStore(db)
|
|
||||||
require.Nil(t, err)
|
|
||||||
t.Cleanup(func() {
|
|
||||||
store.Close()
|
|
||||||
cleanDB, err := postgres.OpenDB(dsn)
|
|
||||||
if err == nil {
|
|
||||||
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
|
|
||||||
cleanDB.Close()
|
|
||||||
}
|
|
||||||
})
|
|
||||||
return store
|
return store
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,49 +1,17 @@
|
|||||||
package webpush_test
|
package webpush_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"net/url"
|
|
||||||
"os"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"heckel.io/ntfy/v2/postgres"
|
dbtest "heckel.io/ntfy/v2/db/test"
|
||||||
"heckel.io/ntfy/v2/util"
|
|
||||||
"heckel.io/ntfy/v2/webpush"
|
"heckel.io/ntfy/v2/webpush"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newTestPostgresStore(t *testing.T) webpush.Store {
|
func newTestPostgresStore(t *testing.T) webpush.Store {
|
||||||
dsn := os.Getenv("NTFY_TEST_DATABASE_URL")
|
testDB := dbtest.CreateTestDB(t)
|
||||||
if dsn == "" {
|
store, err := webpush.NewPostgresStore(testDB)
|
||||||
t.Skip("NTFY_TEST_DATABASE_URL not set, skipping PostgreSQL tests")
|
|
||||||
}
|
|
||||||
schema := fmt.Sprintf("test_%s", util.RandomString(10))
|
|
||||||
u, err := url.Parse(dsn)
|
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
q := u.Query()
|
|
||||||
q.Set("pool_max_conns", "2")
|
|
||||||
u.RawQuery = q.Encode()
|
|
||||||
dsn = u.String()
|
|
||||||
q.Set("search_path", schema)
|
|
||||||
u.RawQuery = q.Encode()
|
|
||||||
schemaDSN := u.String()
|
|
||||||
setupDB, err := postgres.OpenDB(dsn)
|
|
||||||
require.Nil(t, err)
|
|
||||||
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Nil(t, setupDB.Close())
|
|
||||||
db, err := postgres.OpenDB(schemaDSN)
|
|
||||||
require.Nil(t, err)
|
|
||||||
store, err := webpush.NewPostgresStore(db)
|
|
||||||
require.Nil(t, err)
|
|
||||||
t.Cleanup(func() {
|
|
||||||
store.Close()
|
|
||||||
cleanDB, err := postgres.OpenDB(dsn)
|
|
||||||
if err == nil {
|
|
||||||
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
|
|
||||||
cleanDB.Close()
|
|
||||||
}
|
|
||||||
})
|
|
||||||
return store
|
return store
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user