diff --git a/cmd/serve.go b/cmd/serve.go index 33dc838d..62c88b96 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -282,7 +282,9 @@ func execServe(c *cli.Context) error { } // Check values - if firebaseKeyFile != "" && !util.FileExists(firebaseKeyFile) { + if databaseURL != "" && (authFile != "" || cacheFile != "" || webPushFile != "") { + return errors.New("if database-url is set, auth-file, cache-file, and web-push-file must not be set") + } else if firebaseKeyFile != "" && !util.FileExists(firebaseKeyFile) { return errors.New("if set, FCM key file must exist") } else if firebaseKeyFile != "" && !server.FirebaseAvailable { return errors.New("cannot set firebase-key-file, support for Firebase is not available (nofirebase)") @@ -414,6 +416,15 @@ func execServe(c *cli.Context) error { payments.Setup(stripeSecretKey) } + // Parse Twilio template + var twilioCallFormatTemplate *template.Template + if twilioCallFormat != "" { + twilioCallFormatTemplate, err = template.New("").Parse(twilioCallFormat) + if err != nil { + return fmt.Errorf("failed to parse twilio-call-format template: %w", err) + } + } + // Add default forbidden topics disallowedTopics = append(disallowedTopics, server.DefaultDisallowedTopics...) @@ -461,13 +472,7 @@ func execServe(c *cli.Context) error { conf.TwilioAuthToken = twilioAuthToken conf.TwilioPhoneNumber = twilioPhoneNumber conf.TwilioVerifyService = twilioVerifyService - if twilioCallFormat != "" { - tmpl, err := template.New("twiml").Parse(twilioCallFormat) - if err != nil { - return fmt.Errorf("failed to parse twilio-call-format template: %w", err) - } - conf.TwilioCallFormat = tmpl - } + conf.TwilioCallFormat = twilioCallFormatTemplate conf.MessageSizeLimit = int(messageSizeLimit) conf.MessageDelayMax = messageDelayLimit conf.TotalTopicLimit = totalTopicLimit diff --git a/cmd/user.go b/cmd/user.go index e3d91cbe..f2c82a6c 100644 --- a/cmd/user.go +++ b/cmd/user.go @@ -11,7 +11,7 @@ import ( "github.com/urfave/cli/v2" "github.com/urfave/cli/v2/altsrc" - "heckel.io/ntfy/v2/postgres" + "heckel.io/ntfy/v2/db" "heckel.io/ntfy/v2/server" "heckel.io/ntfy/v2/user" "heckel.io/ntfy/v2/util" @@ -380,11 +380,11 @@ func createUserManager(c *cli.Context) (*user.Manager, error) { } var store user.Store if databaseURL != "" { - db, dbErr := postgres.OpenDB(databaseURL) + pool, dbErr := db.Open(databaseURL) if dbErr != nil { return nil, dbErr } - store, err = user.NewPostgresStore(db) + store, err = user.NewPostgresStore(pool) } else if authFile != "" { if !util.FileExists(authFile) { return nil, errors.New("auth-file does not exist; please start the server at least once to create it") diff --git a/postgres/postgres.go b/db/db.go similarity index 92% rename from postgres/postgres.go rename to db/db.go index 437b413e..60807b4a 100644 --- a/postgres/postgres.go +++ b/db/db.go @@ -1,4 +1,4 @@ -package postgres +package db import ( "database/sql" @@ -12,11 +12,11 @@ import ( const defaultMaxOpenConns = 10 -// OpenDB opens a PostgreSQL database connection pool from a DSN string. It supports custom +// Open opens a PostgreSQL database connection pool from a DSN string. It supports custom // query parameters for pool configuration: pool_max_conns (default 10), pool_max_idle_conns, // pool_conn_max_lifetime, and pool_conn_max_idle_time. These parameters are stripped from // the DSN before passing it to the driver. -func OpenDB(dsn string) (*sql.DB, error) { +func Open(dsn string) (*sql.DB, error) { u, err := url.Parse(dsn) if err != nil { return nil, fmt.Errorf("invalid database URL: %w", err) diff --git a/db/test/test.go b/db/test/test.go new file mode 100644 index 00000000..07a2f9a9 --- /dev/null +++ b/db/test/test.go @@ -0,0 +1,63 @@ +package dbtest + +import ( + "database/sql" + "fmt" + "net/url" + "os" + "testing" + + "github.com/stretchr/testify/require" + "heckel.io/ntfy/v2/db" + "heckel.io/ntfy/v2/util" +) + +const testPoolMaxConns = "2" + +// CreateTestSchema creates a temporary PostgreSQL schema and returns the DSN pointing to it. +// It registers a cleanup function to drop the schema when the test finishes. +// If NTFY_TEST_DATABASE_URL is not set, the test is skipped. +func CreateTestSchema(t *testing.T) string { + t.Helper() + dsn := os.Getenv("NTFY_TEST_DATABASE_URL") + if dsn == "" { + t.Skip("NTFY_TEST_DATABASE_URL not set") + } + schema := fmt.Sprintf("test_%s", util.RandomString(10)) + u, err := url.Parse(dsn) + require.Nil(t, err) + q := u.Query() + q.Set("pool_max_conns", testPoolMaxConns) + u.RawQuery = q.Encode() + dsn = u.String() + setupDB, err := db.Open(dsn) + require.Nil(t, err) + _, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema)) + require.Nil(t, err) + require.Nil(t, setupDB.Close()) + q.Set("search_path", schema) + u.RawQuery = q.Encode() + schemaDSN := u.String() + t.Cleanup(func() { + cleanDB, err := db.Open(dsn) + if err == nil { + cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema)) + cleanDB.Close() + } + }) + return schemaDSN +} + +// CreateTestDB creates a temporary PostgreSQL schema and returns an open *sql.DB connection to it. +// It registers cleanup functions to close the DB and drop the schema when the test finishes. +// If NTFY_TEST_DATABASE_URL is not set, the test is skipped. +func CreateTestDB(t *testing.T) *sql.DB { + t.Helper() + schemaDSN := CreateTestSchema(t) + testDB, err := db.Open(schemaDSN) + require.Nil(t, err) + t.Cleanup(func() { + testDB.Close() + }) + return testDB +} diff --git a/message/store.go b/message/store.go index eeb34739..1351b823 100644 --- a/message/store.go +++ b/message/store.go @@ -4,7 +4,6 @@ import ( "database/sql" "encoding/json" "errors" - "fmt" "net/netip" "strings" "sync" @@ -614,9 +613,3 @@ func readMessage(rows *sql.Rows) (*model.Message, error) { Encoding: encoding, }, nil } - -// Ensure commonStore implements Store -var _ Store = (*commonStore)(nil) - -// Needed by store.go but not part of Store interface; unused import guard -var _ = fmt.Sprintf diff --git a/message/store_postgres_test.go b/message/store_postgres_test.go index 951d571f..3dec51a2 100644 --- a/message/store_postgres_test.go +++ b/message/store_postgres_test.go @@ -1,49 +1,18 @@ package message_test import ( - "fmt" - "net/url" - "os" "testing" - "github.com/stretchr/testify/require" + dbtest "heckel.io/ntfy/v2/db/test" "heckel.io/ntfy/v2/message" - "heckel.io/ntfy/v2/postgres" - "heckel.io/ntfy/v2/util" + + "github.com/stretchr/testify/require" ) func newTestPostgresStore(t *testing.T) message.Store { - dsn := os.Getenv("NTFY_TEST_DATABASE_URL") - if dsn == "" { - t.Skip("NTFY_TEST_DATABASE_URL not set, skipping PostgreSQL tests") - } - schema := fmt.Sprintf("test_%s", util.RandomString(10)) - u, err := url.Parse(dsn) + testDB := dbtest.CreateTestDB(t) + store, err := message.NewPostgresStore(testDB, 0, 0) require.Nil(t, err) - q := u.Query() - q.Set("pool_max_conns", "2") - u.RawQuery = q.Encode() - dsn = u.String() - q.Set("search_path", schema) - u.RawQuery = q.Encode() - schemaDSN := u.String() - setupDB, err := postgres.OpenDB(dsn) - require.Nil(t, err) - _, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema)) - require.Nil(t, err) - require.Nil(t, setupDB.Close()) - db, err := postgres.OpenDB(schemaDSN) - require.Nil(t, err) - store, err := message.NewPostgresStore(db, 0, 0) - require.Nil(t, err) - t.Cleanup(func() { - store.Close() - cleanDB, err := postgres.OpenDB(dsn) - if err == nil { - cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema)) - cleanDB.Close() - } - }) return store } diff --git a/server/server.go b/server/server.go index bf19b3ba..b8d3a18b 100644 --- a/server/server.go +++ b/server/server.go @@ -33,11 +33,11 @@ import ( "github.com/prometheus/client_golang/prometheus/promhttp" "golang.org/x/sync/errgroup" "gopkg.in/yaml.v2" + "heckel.io/ntfy/v2/db" "heckel.io/ntfy/v2/log" "heckel.io/ntfy/v2/message" "heckel.io/ntfy/v2/model" "heckel.io/ntfy/v2/payments" - "heckel.io/ntfy/v2/postgres" "heckel.io/ntfy/v2/user" "heckel.io/ntfy/v2/util" "heckel.io/ntfy/v2/util/sprig" @@ -179,22 +179,22 @@ func New(conf *Config) (*Server, error) { stripe = newStripeAPI() } // Open shared PostgreSQL connection pool if configured - var db *sql.DB + var pool *sql.DB if conf.DatabaseURL != "" { var err error - db, err = postgres.OpenDB(conf.DatabaseURL) + pool, err = db.Open(conf.DatabaseURL) if err != nil { return nil, err } } - messageCache, err := createMessageCache(conf, db) + messageCache, err := createMessageCache(conf, pool) if err != nil { return nil, err } var wp webpush.Store if conf.WebPushPublicKey != "" { - if db != nil { - wp, err = webpush.NewPostgresStore(db) + if pool != nil { + wp, err = webpush.NewPostgresStore(pool) } else { wp, err = webpush.NewSQLiteStore(conf.WebPushFile, conf.WebPushStartupQueries) } @@ -222,7 +222,7 @@ func New(conf *Config) (*Server, error) { } } var userManager *user.Manager - if conf.AuthFile != "" || db != nil { + if conf.AuthFile != "" || pool != nil { authConfig := &user.Config{ Filename: conf.AuthFile, DatabaseURL: conf.DatabaseURL, @@ -236,8 +236,8 @@ func New(conf *Config) (*Server, error) { QueueWriterInterval: conf.AuthStatsQueueWriterInterval, } var store user.Store - if db != nil { - store, err = user.NewPostgresStore(db) + if pool != nil { + store, err = user.NewPostgresStore(pool) } else { store, err = user.NewSQLiteStore(conf.AuthFile, conf.AuthStartupQueries) } @@ -265,7 +265,7 @@ func New(conf *Config) (*Server, error) { } s := &Server{ config: conf, - db: db, + db: pool, messageCache: messageCache, webPush: wp, fileCache: fileCache, @@ -282,11 +282,11 @@ func New(conf *Config) (*Server, error) { return s, nil } -func createMessageCache(conf *Config, db *sql.DB) (message.Store, error) { +func createMessageCache(conf *Config, pool *sql.DB) (message.Store, error) { if conf.CacheDuration == 0 { return message.NewNopStore() - } else if db != nil { - return message.NewPostgresStore(db, conf.CacheBatchSize, conf.CacheBatchTimeout) + } else if pool != nil { + return message.NewPostgresStore(pool, conf.CacheBatchSize, conf.CacheBatchTimeout) } else if conf.CacheFile != "" { return message.NewSQLiteStore(conf.CacheFile, conf.CacheStartupQueries, conf.CacheDuration, conf.CacheBatchSize, conf.CacheBatchTimeout, false) } diff --git a/server/server_test.go b/server/server_test.go index b1eacc8f..2acdbac1 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -13,7 +13,6 @@ import ( "net/http" "net/http/httptest" "net/netip" - "net/url" "os" "path/filepath" "runtime/debug" @@ -25,10 +24,10 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/crypto/bcrypt" + dbtest "heckel.io/ntfy/v2/db/test" "heckel.io/ntfy/v2/log" "heckel.io/ntfy/v2/message" "heckel.io/ntfy/v2/model" - "heckel.io/ntfy/v2/postgres" "heckel.io/ntfy/v2/user" "heckel.io/ntfy/v2/util" ) @@ -4129,33 +4128,7 @@ func forEachBackend(t *testing.T, f func(t *testing.T, databaseURL string)) { f(t, "") }) t.Run("postgres", func(t *testing.T) { - dsn := os.Getenv("NTFY_TEST_DATABASE_URL") - if dsn == "" { - t.Skip("NTFY_TEST_DATABASE_URL not set") - } - schema := fmt.Sprintf("test_%s", util.RandomString(10)) - u, err := url.Parse(dsn) - require.Nil(t, err) - q := u.Query() - q.Set("pool_max_conns", "2") - u.RawQuery = q.Encode() - dsn = u.String() - setupDB, err := postgres.OpenDB(dsn) - require.Nil(t, err, "failed to open postgres: %s", err) - _, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema)) - require.Nil(t, err) - require.Nil(t, setupDB.Close()) - q.Set("search_path", schema) - u.RawQuery = q.Encode() - schemaDSN := u.String() - t.Cleanup(func() { - cleanDB, _ := postgres.OpenDB(dsn) - if cleanDB != nil { - cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema)) - cleanDB.Close() - } - }) - f(t, schemaDSN) + f(t, dbtest.CreateTestSchema(t)) }) } diff --git a/user/manager_test.go b/user/manager_test.go index 139eecb3..83f21511 100644 --- a/user/manager_test.go +++ b/user/manager_test.go @@ -4,8 +4,6 @@ import ( "database/sql" "fmt" "net/netip" - "net/url" - "os" "path/filepath" "strings" "testing" @@ -13,7 +11,8 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/crypto/bcrypt" - "heckel.io/ntfy/v2/postgres" + "heckel.io/ntfy/v2/db" + dbtest "heckel.io/ntfy/v2/db/test" "heckel.io/ntfy/v2/util" ) @@ -34,36 +33,11 @@ func forEachBackend(t *testing.T, f func(t *testing.T, newStore newStoreFunc)) { }) }) t.Run("postgres", func(t *testing.T) { - dsn := os.Getenv("NTFY_TEST_DATABASE_URL") - if dsn == "" { - t.Skip("NTFY_TEST_DATABASE_URL not set") - } - schema := fmt.Sprintf("test_%s", util.RandomString(10)) - u, err := url.Parse(dsn) - require.Nil(t, err) - q := u.Query() - q.Set("pool_max_conns", "2") - u.RawQuery = q.Encode() - dsn = u.String() - setupDB, err := postgres.OpenDB(dsn) - require.Nil(t, err) - _, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema)) - require.Nil(t, err) - require.Nil(t, setupDB.Close()) - q.Set("search_path", schema) - u.RawQuery = q.Encode() - schemaDSN := u.String() - t.Cleanup(func() { - cleanDB, _ := postgres.OpenDB(dsn) - if cleanDB != nil { - cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema)) - cleanDB.Close() - } - }) + schemaDSN := dbtest.CreateTestSchema(t) f(t, func() Store { - db, err := postgres.OpenDB(schemaDSN) + pool, err := db.Open(schemaDSN) require.Nil(t, err) - store, err := NewPostgresStore(db) + store, err := NewPostgresStore(pool) require.Nil(t, err) return store }) diff --git a/user/store_postgres_test.go b/user/store_postgres_test.go index fafa62e9..319e9aa8 100644 --- a/user/store_postgres_test.go +++ b/user/store_postgres_test.go @@ -1,49 +1,17 @@ package user_test import ( - "fmt" - "net/url" - "os" "testing" "github.com/stretchr/testify/require" - "heckel.io/ntfy/v2/postgres" + dbtest "heckel.io/ntfy/v2/db/test" "heckel.io/ntfy/v2/user" - "heckel.io/ntfy/v2/util" ) func newTestPostgresStore(t *testing.T) user.Store { - dsn := os.Getenv("NTFY_TEST_DATABASE_URL") - if dsn == "" { - t.Skip("NTFY_TEST_DATABASE_URL not set, skipping PostgreSQL tests") - } - schema := fmt.Sprintf("test_%s", util.RandomString(10)) - u, err := url.Parse(dsn) + testDB := dbtest.CreateTestDB(t) + store, err := user.NewPostgresStore(testDB) require.Nil(t, err) - q := u.Query() - q.Set("pool_max_conns", "2") - u.RawQuery = q.Encode() - dsn = u.String() - q.Set("search_path", schema) - u.RawQuery = q.Encode() - schemaDSN := u.String() - setupDB, err := postgres.OpenDB(dsn) - require.Nil(t, err) - _, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema)) - require.Nil(t, err) - require.Nil(t, setupDB.Close()) - db, err := postgres.OpenDB(schemaDSN) - require.Nil(t, err) - store, err := user.NewPostgresStore(db) - require.Nil(t, err) - t.Cleanup(func() { - store.Close() - cleanDB, err := postgres.OpenDB(dsn) - if err == nil { - cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema)) - cleanDB.Close() - } - }) return store } diff --git a/webpush/store_postgres_test.go b/webpush/store_postgres_test.go index bdcc7a8e..4c675c82 100644 --- a/webpush/store_postgres_test.go +++ b/webpush/store_postgres_test.go @@ -1,49 +1,17 @@ package webpush_test import ( - "fmt" - "net/url" - "os" "testing" "github.com/stretchr/testify/require" - "heckel.io/ntfy/v2/postgres" - "heckel.io/ntfy/v2/util" + dbtest "heckel.io/ntfy/v2/db/test" "heckel.io/ntfy/v2/webpush" ) func newTestPostgresStore(t *testing.T) webpush.Store { - dsn := os.Getenv("NTFY_TEST_DATABASE_URL") - if dsn == "" { - t.Skip("NTFY_TEST_DATABASE_URL not set, skipping PostgreSQL tests") - } - schema := fmt.Sprintf("test_%s", util.RandomString(10)) - u, err := url.Parse(dsn) + testDB := dbtest.CreateTestDB(t) + store, err := webpush.NewPostgresStore(testDB) require.Nil(t, err) - q := u.Query() - q.Set("pool_max_conns", "2") - u.RawQuery = q.Encode() - dsn = u.String() - q.Set("search_path", schema) - u.RawQuery = q.Encode() - schemaDSN := u.String() - setupDB, err := postgres.OpenDB(dsn) - require.Nil(t, err) - _, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema)) - require.Nil(t, err) - require.Nil(t, setupDB.Close()) - db, err := postgres.OpenDB(schemaDSN) - require.Nil(t, err) - store, err := webpush.NewPostgresStore(db) - require.Nil(t, err) - t.Cleanup(func() { - store.Close() - cleanDB, err := postgres.OpenDB(dsn) - if err == nil { - cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema)) - cleanDB.Close() - } - }) return store }