diff --git a/cmd/user.go b/cmd/user.go index bf822970..e3d91cbe 100644 --- a/cmd/user.go +++ b/cmd/user.go @@ -6,13 +6,14 @@ import ( "crypto/subtle" "errors" "fmt" - "heckel.io/ntfy/v2/server" - "heckel.io/ntfy/v2/user" "os" "strings" "github.com/urfave/cli/v2" "github.com/urfave/cli/v2/altsrc" + "heckel.io/ntfy/v2/postgres" + "heckel.io/ntfy/v2/server" + "heckel.io/ntfy/v2/user" "heckel.io/ntfy/v2/util" ) @@ -379,7 +380,11 @@ func createUserManager(c *cli.Context) (*user.Manager, error) { } var store user.Store if databaseURL != "" { - store, err = user.NewPostgresStore(databaseURL) + db, dbErr := postgres.OpenDB(databaseURL) + if dbErr != nil { + return nil, dbErr + } + store, err = user.NewPostgresStore(db) } else if authFile != "" { if !util.FileExists(authFile) { return nil, errors.New("auth-file does not exist; please start the server at least once to create it") diff --git a/message/store.go b/message/store.go index 05ab31d0..eeb34739 100644 --- a/message/store.go +++ b/message/store.go @@ -25,7 +25,6 @@ var errNoRows = errors.New("no rows found") type Store interface { AddMessage(m *model.Message) error AddMessages(ms []*model.Message) error - DB() *sql.DB Message(id string) (*model.Message, error) MessageCounts() (map[string]int, error) Messages(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) @@ -99,11 +98,6 @@ func newCommonStore(db *sql.DB, queries storeQueries, batchSize int, batchTimeou return c } -// DB returns the underlying database connection -func (c *commonStore) DB() *sql.DB { - return c.db -} - // AddMessage stores a message to the message cache synchronously, or queues it to be stored at a later date asynchronously. func (c *commonStore) AddMessage(m *model.Message) error { if c.queue != nil { diff --git a/message/store_postgres.go b/message/store_postgres.go index 7ddb2464..32150e2d 100644 --- a/message/store_postgres.go +++ b/message/store_postgres.go @@ -3,8 +3,6 @@ package message import ( "database/sql" "time" - - _ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver ) // PostgreSQL runtime query constants @@ -104,16 +102,8 @@ var pgQueries = storeQueries{ updateMessageTime: pgUpdateMessageTimesQuery, } -// NewPostgresStore creates a new PostgreSQL-backed message cache store. -func NewPostgresStore(dsn string, batchSize int, batchTimeout time.Duration) (Store, error) { - db, err := sql.Open("pgx", dsn) - if err != nil { - return nil, err - } - db.SetMaxOpenConns(25) - if err := db.Ping(); err != nil { - return nil, err - } +// NewPostgresStore creates a new PostgreSQL-backed message cache store using an existing database connection pool. +func NewPostgresStore(db *sql.DB, batchSize int, batchTimeout time.Duration) (Store, error) { if err := setupPostgresDB(db); err != nil { return nil, err } diff --git a/message/store_postgres_test.go b/message/store_postgres_test.go index 930d700d..ebe60c70 100644 --- a/message/store_postgres_test.go +++ b/message/store_postgres_test.go @@ -1,7 +1,6 @@ package message_test import ( - "database/sql" "fmt" "net/url" "os" @@ -9,6 +8,7 @@ import ( "github.com/stretchr/testify/require" "heckel.io/ntfy/v2/message" + "heckel.io/ntfy/v2/postgres" "heckel.io/ntfy/v2/util" ) @@ -17,24 +17,25 @@ func newTestPostgresStore(t *testing.T) message.Store { if dsn == "" { t.Skip("NTFY_TEST_DATABASE_URL not set, skipping PostgreSQL tests") } - // Create a unique schema for this test schema := fmt.Sprintf("test_%s", util.RandomString(10)) - setupDB, err := sql.Open("pgx", dsn) - require.Nil(t, err) - _, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema)) - require.Nil(t, err) - require.Nil(t, setupDB.Close()) - // Open store with search_path set to the new schema u, err := url.Parse(dsn) require.Nil(t, err) q := u.Query() q.Set("search_path", schema) u.RawQuery = q.Encode() - store, err := message.NewPostgresStore(u.String(), 0, 0) + schemaDSN := u.String() + setupDB, err := postgres.OpenDB(dsn) + require.Nil(t, err) + _, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema)) + require.Nil(t, err) + require.Nil(t, setupDB.Close()) + db, err := postgres.OpenDB(schemaDSN) + require.Nil(t, err) + store, err := message.NewPostgresStore(db, 0, 0) require.Nil(t, err) t.Cleanup(func() { store.Close() - cleanDB, err := sql.Open("pgx", dsn) + cleanDB, err := postgres.OpenDB(dsn) if err == nil { cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema)) cleanDB.Close() diff --git a/postgres/postgres.go b/postgres/postgres.go new file mode 100644 index 00000000..94913f58 --- /dev/null +++ b/postgres/postgres.go @@ -0,0 +1,86 @@ +package postgres + +import ( + "database/sql" + "fmt" + "net/url" + "strconv" + "time" + + _ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver +) + +const defaultMaxOpenConns = 25 + +// OpenDB opens a PostgreSQL database connection pool from a DSN string. It supports custom +// query parameters for pool configuration: pool_max_conns (default 25), pool_max_idle_conns, +// pool_conn_max_lifetime, and pool_conn_max_idle_time. These parameters are stripped from +// the DSN before passing it to the driver. +func OpenDB(dsn string) (*sql.DB, error) { + u, err := url.Parse(dsn) + if err != nil { + return nil, fmt.Errorf("invalid database URL: %w", err) + } + q := u.Query() + maxOpenConns, err := extractIntParam(q, "pool_max_conns", defaultMaxOpenConns) + if err != nil { + return nil, err + } + maxIdleConns, err := extractIntParam(q, "pool_max_idle_conns", 0) + if err != nil { + return nil, err + } + connMaxLifetime, err := extractDurationParam(q, "pool_conn_max_lifetime", 0) + if err != nil { + return nil, err + } + connMaxIdleTime, err := extractDurationParam(q, "pool_conn_max_idle_time", 0) + if err != nil { + return nil, err + } + u.RawQuery = q.Encode() + db, err := sql.Open("pgx", u.String()) + if err != nil { + return nil, err + } + db.SetMaxOpenConns(maxOpenConns) + if maxIdleConns > 0 { + db.SetMaxIdleConns(maxIdleConns) + } + if connMaxLifetime > 0 { + db.SetConnMaxLifetime(connMaxLifetime) + } + if connMaxIdleTime > 0 { + db.SetConnMaxIdleTime(connMaxIdleTime) + } + if err := db.Ping(); err != nil { + return nil, err + } + return db, nil +} + +func extractIntParam(q url.Values, key string, defaultValue int) (int, error) { + s := q.Get(key) + if s == "" { + return defaultValue, nil + } + q.Del(key) + v, err := strconv.Atoi(s) + if err != nil { + return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err) + } + return v, nil +} + +func extractDurationParam(q url.Values, key string, defaultValue time.Duration) (time.Duration, error) { + s := q.Get(key) + if s == "" { + return defaultValue, nil + } + q.Del(key) + d, err := time.ParseDuration(s) + if err != nil { + return 0, fmt.Errorf("invalid %s value %q: %w", key, s, err) + } + return d, nil +} diff --git a/server/server.go b/server/server.go index 80f367c2..243a5094 100644 --- a/server/server.go +++ b/server/server.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "crypto/sha256" + "database/sql" "embed" "encoding/base64" "encoding/json" @@ -36,6 +37,7 @@ import ( "heckel.io/ntfy/v2/message" "heckel.io/ntfy/v2/model" "heckel.io/ntfy/v2/payments" + "heckel.io/ntfy/v2/postgres" "heckel.io/ntfy/v2/user" "heckel.io/ntfy/v2/util" "heckel.io/ntfy/v2/util/sprig" @@ -45,6 +47,7 @@ import ( // Server is the main server, providing the UI and API for ntfy type Server struct { config *Config + db *sql.DB // Shared PostgreSQL connection pool, nil when using SQLite httpServer *http.Server httpsServer *http.Server httpMetricsServer *http.Server @@ -175,14 +178,23 @@ func New(conf *Config) (*Server, error) { if payments.Available && conf.StripeSecretKey != "" { stripe = newStripeAPI() } - messageCache, err := createMessageCache(conf) + // Open shared PostgreSQL connection pool if configured + var db *sql.DB + if conf.DatabaseURL != "" { + var err error + db, err = postgres.OpenDB(conf.DatabaseURL) + if err != nil { + return nil, err + } + } + messageCache, err := createMessageCache(conf, db) if err != nil { return nil, err } var wp webpush.Store if conf.WebPushPublicKey != "" { - if conf.DatabaseURL != "" { - wp, err = webpush.NewPostgresStore(conf.DatabaseURL) + if db != nil { + wp, err = webpush.NewPostgresStore(db) } else { wp, err = webpush.NewSQLiteStore(conf.WebPushFile, conf.WebPushStartupQueries) } @@ -210,7 +222,7 @@ func New(conf *Config) (*Server, error) { } } var userManager *user.Manager - if conf.AuthFile != "" || conf.DatabaseURL != "" { + if conf.AuthFile != "" || db != nil { authConfig := &user.Config{ Filename: conf.AuthFile, DatabaseURL: conf.DatabaseURL, @@ -224,8 +236,8 @@ func New(conf *Config) (*Server, error) { QueueWriterInterval: conf.AuthStatsQueueWriterInterval, } var store user.Store - if conf.DatabaseURL != "" { - store, err = user.NewPostgresStore(conf.DatabaseURL) + if db != nil { + store, err = user.NewPostgresStore(db) } else { store, err = user.NewSQLiteStore(conf.AuthFile, conf.AuthStartupQueries) } @@ -253,6 +265,7 @@ func New(conf *Config) (*Server, error) { } s := &Server{ config: conf, + db: db, messageCache: messageCache, webPush: wp, fileCache: fileCache, @@ -269,11 +282,11 @@ func New(conf *Config) (*Server, error) { return s, nil } -func createMessageCache(conf *Config) (message.Store, error) { +func createMessageCache(conf *Config, db *sql.DB) (message.Store, error) { if conf.CacheDuration == 0 { return message.NewNopStore() - } else if conf.DatabaseURL != "" { - return message.NewPostgresStore(conf.DatabaseURL, conf.CacheBatchSize, conf.CacheBatchTimeout) + } else if db != nil { + return message.NewPostgresStore(db, conf.CacheBatchSize, conf.CacheBatchTimeout) } else if conf.CacheFile != "" { return message.NewSQLiteStore(conf.CacheFile, conf.CacheStartupQueries, conf.CacheDuration, conf.CacheBatchSize, conf.CacheBatchTimeout, false) } diff --git a/server/server_test.go b/server/server_test.go index 9d27f361..88bc1d5a 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -4,7 +4,6 @@ import ( "bufio" "context" "crypto/rand" - "database/sql" _ "embed" "encoding/base64" "encoding/json" @@ -24,12 +23,12 @@ import ( "testing" "time" - _ "github.com/jackc/pgx/v5/stdlib" "github.com/stretchr/testify/require" "golang.org/x/crypto/bcrypt" "heckel.io/ntfy/v2/log" "heckel.io/ntfy/v2/message" "heckel.io/ntfy/v2/model" + "heckel.io/ntfy/v2/postgres" "heckel.io/ntfy/v2/user" "heckel.io/ntfy/v2/util" ) @@ -4135,7 +4134,7 @@ func forEachBackend(t *testing.T, f func(t *testing.T, databaseURL string)) { t.Skip("NTFY_TEST_DATABASE_URL not set") } schema := fmt.Sprintf("test_%s", util.RandomString(10)) - setupDB, err := sql.Open("pgx", dsn) + setupDB, err := postgres.OpenDB(dsn) require.Nil(t, err) _, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema)) require.Nil(t, err) @@ -4147,9 +4146,11 @@ func forEachBackend(t *testing.T, f func(t *testing.T, databaseURL string)) { u.RawQuery = q.Encode() schemaDSN := u.String() t.Cleanup(func() { - cleanDB, _ := sql.Open("pgx", dsn) - cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema)) - cleanDB.Close() + cleanDB, _ := postgres.OpenDB(dsn) + if cleanDB != nil { + cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema)) + cleanDB.Close() + } }) f(t, schemaDSN) }) diff --git a/user/manager_test.go b/user/manager_test.go index 13cff861..133c7b05 100644 --- a/user/manager_test.go +++ b/user/manager_test.go @@ -13,6 +13,7 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/crypto/bcrypt" + "heckel.io/ntfy/v2/postgres" "heckel.io/ntfy/v2/util" ) @@ -38,7 +39,7 @@ func forEachBackend(t *testing.T, f func(t *testing.T, newStore newStoreFunc)) { t.Skip("NTFY_TEST_DATABASE_URL not set") } schema := fmt.Sprintf("test_%s", util.RandomString(10)) - setupDB, err := sql.Open("pgx", dsn) + setupDB, err := postgres.OpenDB(dsn) require.Nil(t, err) _, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema)) require.Nil(t, err) @@ -50,12 +51,16 @@ func forEachBackend(t *testing.T, f func(t *testing.T, newStore newStoreFunc)) { u.RawQuery = q.Encode() schemaDSN := u.String() t.Cleanup(func() { - cleanDB, _ := sql.Open("pgx", dsn) - cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema)) - cleanDB.Close() + cleanDB, _ := postgres.OpenDB(dsn) + if cleanDB != nil { + cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema)) + cleanDB.Close() + } }) f(t, func() Store { - store, err := NewPostgresStore(schemaDSN) + db, err := postgres.OpenDB(schemaDSN) + require.Nil(t, err) + store, err := NewPostgresStore(db) require.Nil(t, err) return store }) diff --git a/user/store_postgres.go b/user/store_postgres.go index fb3b221a..be7f998f 100644 --- a/user/store_postgres.go +++ b/user/store_postgres.go @@ -2,8 +2,6 @@ package user import ( "database/sql" - - _ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver ) // PostgreSQL queries @@ -206,16 +204,8 @@ const ( ` ) -// NewPostgresStore creates a new PostgreSQL-backed user store -func NewPostgresStore(dsn string) (Store, error) { - db, err := sql.Open("pgx", dsn) - if err != nil { - return nil, err - } - db.SetMaxOpenConns(25) - if err := db.Ping(); err != nil { - return nil, err - } +// NewPostgresStore creates a new PostgreSQL-backed user store using an existing database connection pool. +func NewPostgresStore(db *sql.DB) (Store, error) { if err := setupPostgres(db); err != nil { return nil, err } diff --git a/user/store_postgres_test.go b/user/store_postgres_test.go index bc3fab68..ea222539 100644 --- a/user/store_postgres_test.go +++ b/user/store_postgres_test.go @@ -1,13 +1,13 @@ package user_test import ( - "database/sql" "fmt" "net/url" "os" "testing" "github.com/stretchr/testify/require" + "heckel.io/ntfy/v2/postgres" "heckel.io/ntfy/v2/user" "heckel.io/ntfy/v2/util" ) @@ -17,24 +17,25 @@ func newTestPostgresStore(t *testing.T) user.Store { if dsn == "" { t.Skip("NTFY_TEST_DATABASE_URL not set, skipping PostgreSQL tests") } - // Create a unique schema for this test schema := fmt.Sprintf("test_%s", util.RandomString(10)) - setupDB, err := sql.Open("pgx", dsn) - require.Nil(t, err) - _, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema)) - require.Nil(t, err) - require.Nil(t, setupDB.Close()) - // Open store with search_path set to the new schema u, err := url.Parse(dsn) require.Nil(t, err) q := u.Query() q.Set("search_path", schema) u.RawQuery = q.Encode() - store, err := user.NewPostgresStore(u.String()) + schemaDSN := u.String() + setupDB, err := postgres.OpenDB(dsn) + require.Nil(t, err) + _, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema)) + require.Nil(t, err) + require.Nil(t, setupDB.Close()) + db, err := postgres.OpenDB(schemaDSN) + require.Nil(t, err) + store, err := user.NewPostgresStore(db) require.Nil(t, err) t.Cleanup(func() { store.Close() - cleanDB, err := sql.Open("pgx", dsn) + cleanDB, err := postgres.OpenDB(dsn) if err == nil { cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema)) cleanDB.Close() diff --git a/webpush/store_postgres.go b/webpush/store_postgres.go index 1403af68..c6366367 100644 --- a/webpush/store_postgres.go +++ b/webpush/store_postgres.go @@ -3,8 +3,6 @@ package webpush import ( "database/sql" "fmt" - - _ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver ) const ( @@ -70,16 +68,8 @@ const ( pgSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE store = 'webpush'` ) -// NewPostgresStore creates a new PostgreSQL-backed web push store. -func NewPostgresStore(dsn string) (Store, error) { - db, err := sql.Open("pgx", dsn) - if err != nil { - return nil, err - } - db.SetMaxOpenConns(25) - if err := db.Ping(); err != nil { - return nil, err - } +// NewPostgresStore creates a new PostgreSQL-backed web push store using an existing database connection pool. +func NewPostgresStore(db *sql.DB) (Store, error) { if err := setupPostgresDB(db); err != nil { return nil, err } diff --git a/webpush/store_postgres_test.go b/webpush/store_postgres_test.go index 0124441c..bacd8163 100644 --- a/webpush/store_postgres_test.go +++ b/webpush/store_postgres_test.go @@ -1,13 +1,13 @@ package webpush_test import ( - "database/sql" "fmt" "net/url" "os" "testing" "github.com/stretchr/testify/require" + "heckel.io/ntfy/v2/postgres" "heckel.io/ntfy/v2/util" "heckel.io/ntfy/v2/webpush" ) @@ -17,24 +17,25 @@ func newTestPostgresStore(t *testing.T) webpush.Store { if dsn == "" { t.Skip("NTFY_TEST_DATABASE_URL not set, skipping PostgreSQL tests") } - // Create a unique schema for this test schema := fmt.Sprintf("test_%s", util.RandomString(10)) - setupDB, err := sql.Open("pgx", dsn) - require.Nil(t, err) - _, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema)) - require.Nil(t, err) - require.Nil(t, setupDB.Close()) - // Open store with search_path set to the new schema u, err := url.Parse(dsn) require.Nil(t, err) q := u.Query() q.Set("search_path", schema) u.RawQuery = q.Encode() - store, err := webpush.NewPostgresStore(u.String()) + schemaDSN := u.String() + setupDB, err := postgres.OpenDB(dsn) + require.Nil(t, err) + _, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema)) + require.Nil(t, err) + require.Nil(t, setupDB.Close()) + db, err := postgres.OpenDB(schemaDSN) + require.Nil(t, err) + store, err := webpush.NewPostgresStore(db) require.Nil(t, err) t.Cleanup(func() { store.Close() - cleanDB, err := sql.Open("pgx", dsn) + cleanDB, err := postgres.OpenDB(dsn) if err == nil { cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema)) cleanDB.Close()