This commit is contained in:
binwiederhier
2026-02-20 15:36:12 -05:00
parent e818b063f7
commit a4c836b531
11 changed files with 113 additions and 200 deletions

View File

@@ -33,11 +33,11 @@ import (
"github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/sync/errgroup"
"gopkg.in/yaml.v2"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/message"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/payments"
"heckel.io/ntfy/v2/postgres"
"heckel.io/ntfy/v2/user"
"heckel.io/ntfy/v2/util"
"heckel.io/ntfy/v2/util/sprig"
@@ -179,22 +179,22 @@ func New(conf *Config) (*Server, error) {
stripe = newStripeAPI()
}
// Open shared PostgreSQL connection pool if configured
var db *sql.DB
var pool *sql.DB
if conf.DatabaseURL != "" {
var err error
db, err = postgres.OpenDB(conf.DatabaseURL)
pool, err = db.Open(conf.DatabaseURL)
if err != nil {
return nil, err
}
}
messageCache, err := createMessageCache(conf, db)
messageCache, err := createMessageCache(conf, pool)
if err != nil {
return nil, err
}
var wp webpush.Store
if conf.WebPushPublicKey != "" {
if db != nil {
wp, err = webpush.NewPostgresStore(db)
if pool != nil {
wp, err = webpush.NewPostgresStore(pool)
} else {
wp, err = webpush.NewSQLiteStore(conf.WebPushFile, conf.WebPushStartupQueries)
}
@@ -222,7 +222,7 @@ func New(conf *Config) (*Server, error) {
}
}
var userManager *user.Manager
if conf.AuthFile != "" || db != nil {
if conf.AuthFile != "" || pool != nil {
authConfig := &user.Config{
Filename: conf.AuthFile,
DatabaseURL: conf.DatabaseURL,
@@ -236,8 +236,8 @@ func New(conf *Config) (*Server, error) {
QueueWriterInterval: conf.AuthStatsQueueWriterInterval,
}
var store user.Store
if db != nil {
store, err = user.NewPostgresStore(db)
if pool != nil {
store, err = user.NewPostgresStore(pool)
} else {
store, err = user.NewSQLiteStore(conf.AuthFile, conf.AuthStartupQueries)
}
@@ -265,7 +265,7 @@ func New(conf *Config) (*Server, error) {
}
s := &Server{
config: conf,
db: db,
db: pool,
messageCache: messageCache,
webPush: wp,
fileCache: fileCache,
@@ -282,11 +282,11 @@ func New(conf *Config) (*Server, error) {
return s, nil
}
func createMessageCache(conf *Config, db *sql.DB) (message.Store, error) {
func createMessageCache(conf *Config, pool *sql.DB) (message.Store, error) {
if conf.CacheDuration == 0 {
return message.NewNopStore()
} else if db != nil {
return message.NewPostgresStore(db, conf.CacheBatchSize, conf.CacheBatchTimeout)
} else if pool != nil {
return message.NewPostgresStore(pool, conf.CacheBatchSize, conf.CacheBatchTimeout)
} else if conf.CacheFile != "" {
return message.NewSQLiteStore(conf.CacheFile, conf.CacheStartupQueries, conf.CacheDuration, conf.CacheBatchSize, conf.CacheBatchTimeout, false)
}

View File

@@ -13,7 +13,6 @@ import (
"net/http"
"net/http/httptest"
"net/netip"
"net/url"
"os"
"path/filepath"
"runtime/debug"
@@ -25,10 +24,10 @@ import (
"github.com/stretchr/testify/require"
"golang.org/x/crypto/bcrypt"
dbtest "heckel.io/ntfy/v2/db/test"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/message"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/postgres"
"heckel.io/ntfy/v2/user"
"heckel.io/ntfy/v2/util"
)
@@ -4129,33 +4128,7 @@ func forEachBackend(t *testing.T, f func(t *testing.T, databaseURL string)) {
f(t, "")
})
t.Run("postgres", func(t *testing.T) {
dsn := os.Getenv("NTFY_TEST_DATABASE_URL")
if dsn == "" {
t.Skip("NTFY_TEST_DATABASE_URL not set")
}
schema := fmt.Sprintf("test_%s", util.RandomString(10))
u, err := url.Parse(dsn)
require.Nil(t, err)
q := u.Query()
q.Set("pool_max_conns", "2")
u.RawQuery = q.Encode()
dsn = u.String()
setupDB, err := postgres.OpenDB(dsn)
require.Nil(t, err, "failed to open postgres: %s", err)
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
require.Nil(t, err)
require.Nil(t, setupDB.Close())
q.Set("search_path", schema)
u.RawQuery = q.Encode()
schemaDSN := u.String()
t.Cleanup(func() {
cleanDB, _ := postgres.OpenDB(dsn)
if cleanDB != nil {
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
cleanDB.Close()
}
})
f(t, schemaDSN)
f(t, dbtest.CreateTestSchema(t))
})
}