Use one PG connection, add support for connection params
This commit is contained in:
11
cmd/user.go
11
cmd/user.go
@@ -6,13 +6,14 @@ import (
|
|||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"heckel.io/ntfy/v2/server"
|
|
||||||
"heckel.io/ntfy/v2/user"
|
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/urfave/cli/v2"
|
"github.com/urfave/cli/v2"
|
||||||
"github.com/urfave/cli/v2/altsrc"
|
"github.com/urfave/cli/v2/altsrc"
|
||||||
|
"heckel.io/ntfy/v2/postgres"
|
||||||
|
"heckel.io/ntfy/v2/server"
|
||||||
|
"heckel.io/ntfy/v2/user"
|
||||||
"heckel.io/ntfy/v2/util"
|
"heckel.io/ntfy/v2/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -379,7 +380,11 @@ func createUserManager(c *cli.Context) (*user.Manager, error) {
|
|||||||
}
|
}
|
||||||
var store user.Store
|
var store user.Store
|
||||||
if databaseURL != "" {
|
if databaseURL != "" {
|
||||||
store, err = user.NewPostgresStore(databaseURL)
|
db, dbErr := postgres.OpenDB(databaseURL)
|
||||||
|
if dbErr != nil {
|
||||||
|
return nil, dbErr
|
||||||
|
}
|
||||||
|
store, err = user.NewPostgresStore(db)
|
||||||
} else if authFile != "" {
|
} else if authFile != "" {
|
||||||
if !util.FileExists(authFile) {
|
if !util.FileExists(authFile) {
|
||||||
return nil, errors.New("auth-file does not exist; please start the server at least once to create it")
|
return nil, errors.New("auth-file does not exist; please start the server at least once to create it")
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ var errNoRows = errors.New("no rows found")
|
|||||||
type Store interface {
|
type Store interface {
|
||||||
AddMessage(m *model.Message) error
|
AddMessage(m *model.Message) error
|
||||||
AddMessages(ms []*model.Message) error
|
AddMessages(ms []*model.Message) error
|
||||||
DB() *sql.DB
|
|
||||||
Message(id string) (*model.Message, error)
|
Message(id string) (*model.Message, error)
|
||||||
MessageCounts() (map[string]int, error)
|
MessageCounts() (map[string]int, error)
|
||||||
Messages(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error)
|
Messages(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error)
|
||||||
@@ -99,11 +98,6 @@ func newCommonStore(db *sql.DB, queries storeQueries, batchSize int, batchTimeou
|
|||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// DB returns the underlying database connection
|
|
||||||
func (c *commonStore) DB() *sql.DB {
|
|
||||||
return c.db
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddMessage stores a message to the message cache synchronously, or queues it to be stored at a later date asynchronously.
|
// AddMessage stores a message to the message cache synchronously, or queues it to be stored at a later date asynchronously.
|
||||||
func (c *commonStore) AddMessage(m *model.Message) error {
|
func (c *commonStore) AddMessage(m *model.Message) error {
|
||||||
if c.queue != nil {
|
if c.queue != nil {
|
||||||
|
|||||||
@@ -3,8 +3,6 @@ package message
|
|||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// PostgreSQL runtime query constants
|
// PostgreSQL runtime query constants
|
||||||
@@ -104,16 +102,8 @@ var pgQueries = storeQueries{
|
|||||||
updateMessageTime: pgUpdateMessageTimesQuery,
|
updateMessageTime: pgUpdateMessageTimesQuery,
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPostgresStore creates a new PostgreSQL-backed message cache store.
|
// NewPostgresStore creates a new PostgreSQL-backed message cache store using an existing database connection pool.
|
||||||
func NewPostgresStore(dsn string, batchSize int, batchTimeout time.Duration) (Store, error) {
|
func NewPostgresStore(db *sql.DB, batchSize int, batchTimeout time.Duration) (Store, error) {
|
||||||
db, err := sql.Open("pgx", dsn)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
db.SetMaxOpenConns(25)
|
|
||||||
if err := db.Ping(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := setupPostgresDB(db); err != nil {
|
if err := setupPostgresDB(db); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package message_test
|
package message_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
@@ -9,6 +8,7 @@ import (
|
|||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"heckel.io/ntfy/v2/message"
|
"heckel.io/ntfy/v2/message"
|
||||||
|
"heckel.io/ntfy/v2/postgres"
|
||||||
"heckel.io/ntfy/v2/util"
|
"heckel.io/ntfy/v2/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -17,24 +17,25 @@ func newTestPostgresStore(t *testing.T) message.Store {
|
|||||||
if dsn == "" {
|
if dsn == "" {
|
||||||
t.Skip("NTFY_TEST_DATABASE_URL not set, skipping PostgreSQL tests")
|
t.Skip("NTFY_TEST_DATABASE_URL not set, skipping PostgreSQL tests")
|
||||||
}
|
}
|
||||||
// Create a unique schema for this test
|
|
||||||
schema := fmt.Sprintf("test_%s", util.RandomString(10))
|
schema := fmt.Sprintf("test_%s", util.RandomString(10))
|
||||||
setupDB, err := sql.Open("pgx", dsn)
|
|
||||||
require.Nil(t, err)
|
|
||||||
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Nil(t, setupDB.Close())
|
|
||||||
// Open store with search_path set to the new schema
|
|
||||||
u, err := url.Parse(dsn)
|
u, err := url.Parse(dsn)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
q := u.Query()
|
q := u.Query()
|
||||||
q.Set("search_path", schema)
|
q.Set("search_path", schema)
|
||||||
u.RawQuery = q.Encode()
|
u.RawQuery = q.Encode()
|
||||||
store, err := message.NewPostgresStore(u.String(), 0, 0)
|
schemaDSN := u.String()
|
||||||
|
setupDB, err := postgres.OpenDB(dsn)
|
||||||
|
require.Nil(t, err)
|
||||||
|
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Nil(t, setupDB.Close())
|
||||||
|
db, err := postgres.OpenDB(schemaDSN)
|
||||||
|
require.Nil(t, err)
|
||||||
|
store, err := message.NewPostgresStore(db, 0, 0)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
store.Close()
|
store.Close()
|
||||||
cleanDB, err := sql.Open("pgx", dsn)
|
cleanDB, err := postgres.OpenDB(dsn)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
|
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
|
||||||
cleanDB.Close()
|
cleanDB.Close()
|
||||||
|
|||||||
86
postgres/postgres.go
Normal file
86
postgres/postgres.go
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
|
||||||
|
)
|
||||||
|
|
||||||
|
const defaultMaxOpenConns = 25
|
||||||
|
|
||||||
|
// OpenDB opens a PostgreSQL database connection pool from a DSN string. It supports custom
|
||||||
|
// query parameters for pool configuration: pool_max_conns (default 25), pool_max_idle_conns,
|
||||||
|
// pool_conn_max_lifetime, and pool_conn_max_idle_time. These parameters are stripped from
|
||||||
|
// the DSN before passing it to the driver.
|
||||||
|
func OpenDB(dsn string) (*sql.DB, error) {
|
||||||
|
u, err := url.Parse(dsn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid database URL: %w", err)
|
||||||
|
}
|
||||||
|
q := u.Query()
|
||||||
|
maxOpenConns, err := extractIntParam(q, "pool_max_conns", defaultMaxOpenConns)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
maxIdleConns, err := extractIntParam(q, "pool_max_idle_conns", 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
connMaxLifetime, err := extractDurationParam(q, "pool_conn_max_lifetime", 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
connMaxIdleTime, err := extractDurationParam(q, "pool_conn_max_idle_time", 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
u.RawQuery = q.Encode()
|
||||||
|
db, err := sql.Open("pgx", u.String())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
db.SetMaxOpenConns(maxOpenConns)
|
||||||
|
if maxIdleConns > 0 {
|
||||||
|
db.SetMaxIdleConns(maxIdleConns)
|
||||||
|
}
|
||||||
|
if connMaxLifetime > 0 {
|
||||||
|
db.SetConnMaxLifetime(connMaxLifetime)
|
||||||
|
}
|
||||||
|
if connMaxIdleTime > 0 {
|
||||||
|
db.SetConnMaxIdleTime(connMaxIdleTime)
|
||||||
|
}
|
||||||
|
if err := db.Ping(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return db, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractIntParam(q url.Values, key string, defaultValue int) (int, error) {
|
||||||
|
s := q.Get(key)
|
||||||
|
if s == "" {
|
||||||
|
return defaultValue, nil
|
||||||
|
}
|
||||||
|
q.Del(key)
|
||||||
|
v, err := strconv.Atoi(s)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err)
|
||||||
|
}
|
||||||
|
return v, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractDurationParam(q url.Values, key string, defaultValue time.Duration) (time.Duration, error) {
|
||||||
|
s := q.Get(key)
|
||||||
|
if s == "" {
|
||||||
|
return defaultValue, nil
|
||||||
|
}
|
||||||
|
q.Del(key)
|
||||||
|
d, err := time.ParseDuration(s)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err)
|
||||||
|
}
|
||||||
|
return d, nil
|
||||||
|
}
|
||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
|
"database/sql"
|
||||||
"embed"
|
"embed"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -36,6 +37,7 @@ import (
|
|||||||
"heckel.io/ntfy/v2/message"
|
"heckel.io/ntfy/v2/message"
|
||||||
"heckel.io/ntfy/v2/model"
|
"heckel.io/ntfy/v2/model"
|
||||||
"heckel.io/ntfy/v2/payments"
|
"heckel.io/ntfy/v2/payments"
|
||||||
|
"heckel.io/ntfy/v2/postgres"
|
||||||
"heckel.io/ntfy/v2/user"
|
"heckel.io/ntfy/v2/user"
|
||||||
"heckel.io/ntfy/v2/util"
|
"heckel.io/ntfy/v2/util"
|
||||||
"heckel.io/ntfy/v2/util/sprig"
|
"heckel.io/ntfy/v2/util/sprig"
|
||||||
@@ -45,6 +47,7 @@ import (
|
|||||||
// Server is the main server, providing the UI and API for ntfy
|
// Server is the main server, providing the UI and API for ntfy
|
||||||
type Server struct {
|
type Server struct {
|
||||||
config *Config
|
config *Config
|
||||||
|
db *sql.DB // Shared PostgreSQL connection pool, nil when using SQLite
|
||||||
httpServer *http.Server
|
httpServer *http.Server
|
||||||
httpsServer *http.Server
|
httpsServer *http.Server
|
||||||
httpMetricsServer *http.Server
|
httpMetricsServer *http.Server
|
||||||
@@ -175,14 +178,23 @@ func New(conf *Config) (*Server, error) {
|
|||||||
if payments.Available && conf.StripeSecretKey != "" {
|
if payments.Available && conf.StripeSecretKey != "" {
|
||||||
stripe = newStripeAPI()
|
stripe = newStripeAPI()
|
||||||
}
|
}
|
||||||
messageCache, err := createMessageCache(conf)
|
// Open shared PostgreSQL connection pool if configured
|
||||||
|
var db *sql.DB
|
||||||
|
if conf.DatabaseURL != "" {
|
||||||
|
var err error
|
||||||
|
db, err = postgres.OpenDB(conf.DatabaseURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
messageCache, err := createMessageCache(conf, db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var wp webpush.Store
|
var wp webpush.Store
|
||||||
if conf.WebPushPublicKey != "" {
|
if conf.WebPushPublicKey != "" {
|
||||||
if conf.DatabaseURL != "" {
|
if db != nil {
|
||||||
wp, err = webpush.NewPostgresStore(conf.DatabaseURL)
|
wp, err = webpush.NewPostgresStore(db)
|
||||||
} else {
|
} else {
|
||||||
wp, err = webpush.NewSQLiteStore(conf.WebPushFile, conf.WebPushStartupQueries)
|
wp, err = webpush.NewSQLiteStore(conf.WebPushFile, conf.WebPushStartupQueries)
|
||||||
}
|
}
|
||||||
@@ -210,7 +222,7 @@ func New(conf *Config) (*Server, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
var userManager *user.Manager
|
var userManager *user.Manager
|
||||||
if conf.AuthFile != "" || conf.DatabaseURL != "" {
|
if conf.AuthFile != "" || db != nil {
|
||||||
authConfig := &user.Config{
|
authConfig := &user.Config{
|
||||||
Filename: conf.AuthFile,
|
Filename: conf.AuthFile,
|
||||||
DatabaseURL: conf.DatabaseURL,
|
DatabaseURL: conf.DatabaseURL,
|
||||||
@@ -224,8 +236,8 @@ func New(conf *Config) (*Server, error) {
|
|||||||
QueueWriterInterval: conf.AuthStatsQueueWriterInterval,
|
QueueWriterInterval: conf.AuthStatsQueueWriterInterval,
|
||||||
}
|
}
|
||||||
var store user.Store
|
var store user.Store
|
||||||
if conf.DatabaseURL != "" {
|
if db != nil {
|
||||||
store, err = user.NewPostgresStore(conf.DatabaseURL)
|
store, err = user.NewPostgresStore(db)
|
||||||
} else {
|
} else {
|
||||||
store, err = user.NewSQLiteStore(conf.AuthFile, conf.AuthStartupQueries)
|
store, err = user.NewSQLiteStore(conf.AuthFile, conf.AuthStartupQueries)
|
||||||
}
|
}
|
||||||
@@ -253,6 +265,7 @@ func New(conf *Config) (*Server, error) {
|
|||||||
}
|
}
|
||||||
s := &Server{
|
s := &Server{
|
||||||
config: conf,
|
config: conf,
|
||||||
|
db: db,
|
||||||
messageCache: messageCache,
|
messageCache: messageCache,
|
||||||
webPush: wp,
|
webPush: wp,
|
||||||
fileCache: fileCache,
|
fileCache: fileCache,
|
||||||
@@ -269,11 +282,11 @@ func New(conf *Config) (*Server, error) {
|
|||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createMessageCache(conf *Config) (message.Store, error) {
|
func createMessageCache(conf *Config, db *sql.DB) (message.Store, error) {
|
||||||
if conf.CacheDuration == 0 {
|
if conf.CacheDuration == 0 {
|
||||||
return message.NewNopStore()
|
return message.NewNopStore()
|
||||||
} else if conf.DatabaseURL != "" {
|
} else if db != nil {
|
||||||
return message.NewPostgresStore(conf.DatabaseURL, conf.CacheBatchSize, conf.CacheBatchTimeout)
|
return message.NewPostgresStore(db, conf.CacheBatchSize, conf.CacheBatchTimeout)
|
||||||
} else if conf.CacheFile != "" {
|
} else if conf.CacheFile != "" {
|
||||||
return message.NewSQLiteStore(conf.CacheFile, conf.CacheStartupQueries, conf.CacheDuration, conf.CacheBatchSize, conf.CacheBatchTimeout, false)
|
return message.NewSQLiteStore(conf.CacheFile, conf.CacheStartupQueries, conf.CacheDuration, conf.CacheBatchSize, conf.CacheBatchTimeout, false)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"bufio"
|
"bufio"
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"database/sql"
|
|
||||||
_ "embed"
|
_ "embed"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -24,12 +23,12 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
_ "github.com/jackc/pgx/v5/stdlib"
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
"heckel.io/ntfy/v2/log"
|
"heckel.io/ntfy/v2/log"
|
||||||
"heckel.io/ntfy/v2/message"
|
"heckel.io/ntfy/v2/message"
|
||||||
"heckel.io/ntfy/v2/model"
|
"heckel.io/ntfy/v2/model"
|
||||||
|
"heckel.io/ntfy/v2/postgres"
|
||||||
"heckel.io/ntfy/v2/user"
|
"heckel.io/ntfy/v2/user"
|
||||||
"heckel.io/ntfy/v2/util"
|
"heckel.io/ntfy/v2/util"
|
||||||
)
|
)
|
||||||
@@ -4135,7 +4134,7 @@ func forEachBackend(t *testing.T, f func(t *testing.T, databaseURL string)) {
|
|||||||
t.Skip("NTFY_TEST_DATABASE_URL not set")
|
t.Skip("NTFY_TEST_DATABASE_URL not set")
|
||||||
}
|
}
|
||||||
schema := fmt.Sprintf("test_%s", util.RandomString(10))
|
schema := fmt.Sprintf("test_%s", util.RandomString(10))
|
||||||
setupDB, err := sql.Open("pgx", dsn)
|
setupDB, err := postgres.OpenDB(dsn)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
|
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
@@ -4147,9 +4146,11 @@ func forEachBackend(t *testing.T, f func(t *testing.T, databaseURL string)) {
|
|||||||
u.RawQuery = q.Encode()
|
u.RawQuery = q.Encode()
|
||||||
schemaDSN := u.String()
|
schemaDSN := u.String()
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
cleanDB, _ := sql.Open("pgx", dsn)
|
cleanDB, _ := postgres.OpenDB(dsn)
|
||||||
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
|
if cleanDB != nil {
|
||||||
cleanDB.Close()
|
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
|
||||||
|
cleanDB.Close()
|
||||||
|
}
|
||||||
})
|
})
|
||||||
f(t, schemaDSN)
|
f(t, schemaDSN)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
"heckel.io/ntfy/v2/postgres"
|
||||||
"heckel.io/ntfy/v2/util"
|
"heckel.io/ntfy/v2/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -38,7 +39,7 @@ func forEachBackend(t *testing.T, f func(t *testing.T, newStore newStoreFunc)) {
|
|||||||
t.Skip("NTFY_TEST_DATABASE_URL not set")
|
t.Skip("NTFY_TEST_DATABASE_URL not set")
|
||||||
}
|
}
|
||||||
schema := fmt.Sprintf("test_%s", util.RandomString(10))
|
schema := fmt.Sprintf("test_%s", util.RandomString(10))
|
||||||
setupDB, err := sql.Open("pgx", dsn)
|
setupDB, err := postgres.OpenDB(dsn)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
|
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
@@ -50,12 +51,16 @@ func forEachBackend(t *testing.T, f func(t *testing.T, newStore newStoreFunc)) {
|
|||||||
u.RawQuery = q.Encode()
|
u.RawQuery = q.Encode()
|
||||||
schemaDSN := u.String()
|
schemaDSN := u.String()
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
cleanDB, _ := sql.Open("pgx", dsn)
|
cleanDB, _ := postgres.OpenDB(dsn)
|
||||||
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
|
if cleanDB != nil {
|
||||||
cleanDB.Close()
|
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
|
||||||
|
cleanDB.Close()
|
||||||
|
}
|
||||||
})
|
})
|
||||||
f(t, func() Store {
|
f(t, func() Store {
|
||||||
store, err := NewPostgresStore(schemaDSN)
|
db, err := postgres.OpenDB(schemaDSN)
|
||||||
|
require.Nil(t, err)
|
||||||
|
store, err := NewPostgresStore(db)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
return store
|
return store
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -2,8 +2,6 @@ package user
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// PostgreSQL queries
|
// PostgreSQL queries
|
||||||
@@ -206,16 +204,8 @@ const (
|
|||||||
`
|
`
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewPostgresStore creates a new PostgreSQL-backed user store
|
// NewPostgresStore creates a new PostgreSQL-backed user store using an existing database connection pool.
|
||||||
func NewPostgresStore(dsn string) (Store, error) {
|
func NewPostgresStore(db *sql.DB) (Store, error) {
|
||||||
db, err := sql.Open("pgx", dsn)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
db.SetMaxOpenConns(25)
|
|
||||||
if err := db.Ping(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := setupPostgres(db); err != nil {
|
if err := setupPostgres(db); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,13 +1,13 @@
|
|||||||
package user_test
|
package user_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"heckel.io/ntfy/v2/postgres"
|
||||||
"heckel.io/ntfy/v2/user"
|
"heckel.io/ntfy/v2/user"
|
||||||
"heckel.io/ntfy/v2/util"
|
"heckel.io/ntfy/v2/util"
|
||||||
)
|
)
|
||||||
@@ -17,24 +17,25 @@ func newTestPostgresStore(t *testing.T) user.Store {
|
|||||||
if dsn == "" {
|
if dsn == "" {
|
||||||
t.Skip("NTFY_TEST_DATABASE_URL not set, skipping PostgreSQL tests")
|
t.Skip("NTFY_TEST_DATABASE_URL not set, skipping PostgreSQL tests")
|
||||||
}
|
}
|
||||||
// Create a unique schema for this test
|
|
||||||
schema := fmt.Sprintf("test_%s", util.RandomString(10))
|
schema := fmt.Sprintf("test_%s", util.RandomString(10))
|
||||||
setupDB, err := sql.Open("pgx", dsn)
|
|
||||||
require.Nil(t, err)
|
|
||||||
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Nil(t, setupDB.Close())
|
|
||||||
// Open store with search_path set to the new schema
|
|
||||||
u, err := url.Parse(dsn)
|
u, err := url.Parse(dsn)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
q := u.Query()
|
q := u.Query()
|
||||||
q.Set("search_path", schema)
|
q.Set("search_path", schema)
|
||||||
u.RawQuery = q.Encode()
|
u.RawQuery = q.Encode()
|
||||||
store, err := user.NewPostgresStore(u.String())
|
schemaDSN := u.String()
|
||||||
|
setupDB, err := postgres.OpenDB(dsn)
|
||||||
|
require.Nil(t, err)
|
||||||
|
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Nil(t, setupDB.Close())
|
||||||
|
db, err := postgres.OpenDB(schemaDSN)
|
||||||
|
require.Nil(t, err)
|
||||||
|
store, err := user.NewPostgresStore(db)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
store.Close()
|
store.Close()
|
||||||
cleanDB, err := sql.Open("pgx", dsn)
|
cleanDB, err := postgres.OpenDB(dsn)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
|
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
|
||||||
cleanDB.Close()
|
cleanDB.Close()
|
||||||
|
|||||||
@@ -3,8 +3,6 @@ package webpush
|
|||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -70,16 +68,8 @@ const (
|
|||||||
pgSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE store = 'webpush'`
|
pgSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE store = 'webpush'`
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewPostgresStore creates a new PostgreSQL-backed web push store.
|
// NewPostgresStore creates a new PostgreSQL-backed web push store using an existing database connection pool.
|
||||||
func NewPostgresStore(dsn string) (Store, error) {
|
func NewPostgresStore(db *sql.DB) (Store, error) {
|
||||||
db, err := sql.Open("pgx", dsn)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
db.SetMaxOpenConns(25)
|
|
||||||
if err := db.Ping(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := setupPostgresDB(db); err != nil {
|
if err := setupPostgresDB(db); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,13 +1,13 @@
|
|||||||
package webpush_test
|
package webpush_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"heckel.io/ntfy/v2/postgres"
|
||||||
"heckel.io/ntfy/v2/util"
|
"heckel.io/ntfy/v2/util"
|
||||||
"heckel.io/ntfy/v2/webpush"
|
"heckel.io/ntfy/v2/webpush"
|
||||||
)
|
)
|
||||||
@@ -17,24 +17,25 @@ func newTestPostgresStore(t *testing.T) webpush.Store {
|
|||||||
if dsn == "" {
|
if dsn == "" {
|
||||||
t.Skip("NTFY_TEST_DATABASE_URL not set, skipping PostgreSQL tests")
|
t.Skip("NTFY_TEST_DATABASE_URL not set, skipping PostgreSQL tests")
|
||||||
}
|
}
|
||||||
// Create a unique schema for this test
|
|
||||||
schema := fmt.Sprintf("test_%s", util.RandomString(10))
|
schema := fmt.Sprintf("test_%s", util.RandomString(10))
|
||||||
setupDB, err := sql.Open("pgx", dsn)
|
|
||||||
require.Nil(t, err)
|
|
||||||
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Nil(t, setupDB.Close())
|
|
||||||
// Open store with search_path set to the new schema
|
|
||||||
u, err := url.Parse(dsn)
|
u, err := url.Parse(dsn)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
q := u.Query()
|
q := u.Query()
|
||||||
q.Set("search_path", schema)
|
q.Set("search_path", schema)
|
||||||
u.RawQuery = q.Encode()
|
u.RawQuery = q.Encode()
|
||||||
store, err := webpush.NewPostgresStore(u.String())
|
schemaDSN := u.String()
|
||||||
|
setupDB, err := postgres.OpenDB(dsn)
|
||||||
|
require.Nil(t, err)
|
||||||
|
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Nil(t, setupDB.Close())
|
||||||
|
db, err := postgres.OpenDB(schemaDSN)
|
||||||
|
require.Nil(t, err)
|
||||||
|
store, err := webpush.NewPostgresStore(db)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
store.Close()
|
store.Close()
|
||||||
cleanDB, err := sql.Open("pgx", dsn)
|
cleanDB, err := postgres.OpenDB(dsn)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
|
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
|
||||||
cleanDB.Close()
|
cleanDB.Close()
|
||||||
|
|||||||
Reference in New Issue
Block a user