diff --git a/user/manager_test.go b/user/manager_test.go index a27f18f1..e59e3cff 100644 --- a/user/manager_test.go +++ b/user/manager_test.go @@ -3,14 +3,15 @@ package user import ( "database/sql" "fmt" - "github.com/stretchr/testify/require" - "golang.org/x/crypto/bcrypt" - "heckel.io/ntfy/v2/util" "net/netip" "path/filepath" "strings" "testing" "time" + + "github.com/stretchr/testify/require" + "golang.org/x/crypto/bcrypt" + "heckel.io/ntfy/v2/util" ) const minBcryptTimingMillis = int64(40) // Ideally should be >100ms, but this should also run on a Raspberry Pi without massive resources @@ -1563,7 +1564,7 @@ func checkSchemaVersion(t *testing.T, db *sql.DB) { var schemaVersion int require.Nil(t, rows.Scan(&schemaVersion)) - require.Equal(t, currentSchemaVersion, schemaVersion) + require.Equal(t, sqliteCurrentSchemaVersion, schemaVersion) require.Nil(t, rows.Close()) } diff --git a/user/store_postgres.go b/user/store_postgres.go index bfcf4b54..a2a78e3f 100644 --- a/user/store_postgres.go +++ b/user/store_postgres.go @@ -2,87 +2,12 @@ package user import ( "database/sql" - "fmt" _ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver ) -// PostgreSQL schema and queries +// PostgreSQL queries const ( - postgresCreateTablesQueries = ` - CREATE TABLE IF NOT EXISTS tier ( - id TEXT PRIMARY KEY, - code TEXT NOT NULL, - name TEXT NOT NULL, - messages_limit BIGINT NOT NULL, - messages_expiry_duration BIGINT NOT NULL, - emails_limit BIGINT NOT NULL, - calls_limit BIGINT NOT NULL, - reservations_limit BIGINT NOT NULL, - attachment_file_size_limit BIGINT NOT NULL, - attachment_total_size_limit BIGINT NOT NULL, - attachment_expiry_duration BIGINT NOT NULL, - attachment_bandwidth_limit BIGINT NOT NULL, - stripe_monthly_price_id TEXT, - stripe_yearly_price_id TEXT, - UNIQUE(code), - UNIQUE(stripe_monthly_price_id), - UNIQUE(stripe_yearly_price_id) - ); - CREATE TABLE IF NOT EXISTS "user" ( - id TEXT PRIMARY KEY, - tier_id TEXT REFERENCES tier(id), - user_name TEXT NOT NULL UNIQUE, - pass TEXT NOT NULL, - role TEXT NOT NULL CHECK (role IN ('anonymous', 'admin', 'user')), - prefs JSONB NOT NULL DEFAULT '{}', - sync_topic TEXT NOT NULL, - provisioned BOOLEAN NOT NULL, - stats_messages BIGINT NOT NULL DEFAULT 0, - stats_emails BIGINT NOT NULL DEFAULT 0, - stats_calls BIGINT NOT NULL DEFAULT 0, - stripe_customer_id TEXT UNIQUE, - stripe_subscription_id TEXT UNIQUE, - stripe_subscription_status TEXT, - stripe_subscription_interval TEXT, - stripe_subscription_paid_until BIGINT, - stripe_subscription_cancel_at BIGINT, - created BIGINT NOT NULL, - deleted BIGINT - ); - CREATE TABLE IF NOT EXISTS user_access ( - user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE, - topic TEXT NOT NULL, - read BOOLEAN NOT NULL, - write BOOLEAN NOT NULL, - owner_user_id TEXT REFERENCES "user"(id) ON DELETE CASCADE, - provisioned BOOLEAN NOT NULL, - PRIMARY KEY (user_id, topic) - ); - CREATE TABLE IF NOT EXISTS user_token ( - user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE, - token TEXT NOT NULL UNIQUE, - label TEXT NOT NULL, - last_access BIGINT NOT NULL, - last_origin TEXT NOT NULL, - expires BIGINT NOT NULL, - provisioned BOOLEAN NOT NULL, - PRIMARY KEY (user_id, token) - ); - CREATE TABLE IF NOT EXISTS user_phone ( - user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE, - phone_number TEXT NOT NULL, - PRIMARY KEY (user_id, phone_number) - ); - CREATE TABLE IF NOT EXISTS schema_version ( - store TEXT PRIMARY KEY, - version INT NOT NULL - ); - INSERT INTO "user" (id, user_name, pass, role, sync_topic, provisioned, created) - VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, EXTRACT(EPOCH FROM NOW())::BIGINT) - ON CONFLICT (id) DO NOTHING; - ` - // User queries postgresSelectUserByID = ` SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id @@ -279,15 +204,6 @@ const ( SET stripe_customer_id = $1, stripe_subscription_id = $2, stripe_subscription_status = $3, stripe_subscription_interval = $4, stripe_subscription_paid_until = $5, stripe_subscription_cancel_at = $6 WHERE user_name = $7 ` - - // Schema version queries - postgresSelectSchemaVersionExists = `SELECT EXISTS ( - SELECT FROM information_schema.tables - WHERE table_schema = current_schema() - AND table_name = 'schema_version' - )` - postgresSelectSchemaVersion = `SELECT version FROM schema_version WHERE store = 'user'` - postgresInsertSchemaVersion = `INSERT INTO schema_version (store, version) VALUES ('user', $1)` ) // NewPostgresStore creates a new PostgreSQL-backed user store @@ -299,126 +215,78 @@ func NewPostgresStore(dsn string) (Store, error) { if err := db.Ping(); err != nil { return nil, err } - if err := setupPostgresDB(db); err != nil { + if err := setupPostgres(db); err != nil { return nil, err } return &commonStore{ - db: db, - queries: postgresQueries(), + db: db, + queries: storeQueries{ + // User queries + selectUserByID: postgresSelectUserByID, + selectUserByName: postgresSelectUserByName, + selectUserByToken: postgresSelectUserByToken, + selectUserByStripeID: postgresSelectUserByStripeID, + selectUsernames: postgresSelectUsernames, + selectUserCount: postgresSelectUserCount, + selectUserIDFromUsername: postgresSelectUserIDFromUsername, + insertUser: postgresInsertUser, + updateUserPass: postgresUpdateUserPass, + updateUserRole: postgresUpdateUserRole, + updateUserProvisioned: postgresUpdateUserProvisioned, + updateUserPrefs: postgresUpdateUserPrefs, + updateUserStats: postgresUpdateUserStats, + updateUserStatsResetAll: postgresUpdateUserStatsResetAll, + updateUserTier: postgresUpdateUserTier, + updateUserDeleted: postgresUpdateUserDeleted, + deleteUser: postgresDeleteUser, + deleteUserTier: postgresDeleteUserTier, + deleteUsersMarked: postgresDeleteUsersMarked, + + // Access queries + selectTopicPerms: postgresSelectTopicPerms, + selectUserAllAccess: postgresSelectUserAllAccess, + selectUserAccess: postgresSelectUserAccess, + selectUserReservations: postgresSelectUserReservations, + selectUserReservationsCount: postgresSelectUserReservationsCount, + selectUserReservationsOwner: postgresSelectUserReservationsOwner, + selectUserHasReservation: postgresSelectUserHasReservation, + selectOtherAccessCount: postgresSelectOtherAccessCount, + upsertUserAccess: postgresUpsertUserAccess, + deleteUserAccess: postgresDeleteUserAccess, + deleteUserAccessProvisioned: postgresDeleteUserAccessProvisioned, + deleteTopicAccess: postgresDeleteTopicAccess, + deleteAllAccess: postgresDeleteAllAccess, + + // Token queries + selectToken: postgresSelectToken, + selectTokens: postgresSelectTokens, + selectTokenCount: postgresSelectTokenCount, + selectAllProvisionedTokens: postgresSelectAllProvisionedTokens, + upsertToken: postgresUpsertToken, + updateTokenLabel: postgresUpdateTokenLabel, + updateTokenExpiry: postgresUpdateTokenExpiry, + updateTokenLastAccess: postgresUpdateTokenLastAccess, + deleteToken: postgresDeleteToken, + deleteProvisionedToken: postgresDeleteProvisionedToken, + deleteAllToken: postgresDeleteAllToken, + deleteExpiredTokens: postgresDeleteExpiredTokens, + deleteExcessTokens: postgresDeleteExcessTokens, + + // Tier queries + insertTier: postgresInsertTier, + selectTiers: postgresSelectTiers, + selectTierByCode: postgresSelectTierByCode, + selectTierByPriceID: postgresSelectTierByPriceID, + updateTier: postgresUpdateTier, + deleteTier: postgresDeleteTier, + + // Phone queries + selectPhoneNumbers: postgresSelectPhoneNumbers, + insertPhoneNumber: postgresInsertPhoneNumber, + deletePhoneNumber: postgresDeletePhoneNumber, + + // Billing queries + updateBilling: postgresUpdateBilling, + }, }, nil } - -func postgresQueries() storeQueries { - return storeQueries{ - // User queries - selectUserByID: postgresSelectUserByID, - selectUserByName: postgresSelectUserByName, - selectUserByToken: postgresSelectUserByToken, - selectUserByStripeID: postgresSelectUserByStripeID, - selectUsernames: postgresSelectUsernames, - selectUserCount: postgresSelectUserCount, - selectUserIDFromUsername: postgresSelectUserIDFromUsername, - insertUser: postgresInsertUser, - updateUserPass: postgresUpdateUserPass, - updateUserRole: postgresUpdateUserRole, - updateUserProvisioned: postgresUpdateUserProvisioned, - updateUserPrefs: postgresUpdateUserPrefs, - updateUserStats: postgresUpdateUserStats, - updateUserStatsResetAll: postgresUpdateUserStatsResetAll, - updateUserTier: postgresUpdateUserTier, - updateUserDeleted: postgresUpdateUserDeleted, - deleteUser: postgresDeleteUser, - deleteUserTier: postgresDeleteUserTier, - deleteUsersMarked: postgresDeleteUsersMarked, - - // Access queries - selectTopicPerms: postgresSelectTopicPerms, - selectUserAllAccess: postgresSelectUserAllAccess, - selectUserAccess: postgresSelectUserAccess, - selectUserReservations: postgresSelectUserReservations, - selectUserReservationsCount: postgresSelectUserReservationsCount, - selectUserReservationsOwner: postgresSelectUserReservationsOwner, - selectUserHasReservation: postgresSelectUserHasReservation, - selectOtherAccessCount: postgresSelectOtherAccessCount, - upsertUserAccess: postgresUpsertUserAccess, - deleteUserAccess: postgresDeleteUserAccess, - deleteUserAccessProvisioned: postgresDeleteUserAccessProvisioned, - deleteTopicAccess: postgresDeleteTopicAccess, - deleteAllAccess: postgresDeleteAllAccess, - - // Token queries - selectToken: postgresSelectToken, - selectTokens: postgresSelectTokens, - selectTokenCount: postgresSelectTokenCount, - selectAllProvisionedTokens: postgresSelectAllProvisionedTokens, - upsertToken: postgresUpsertToken, - updateTokenLabel: postgresUpdateTokenLabel, - updateTokenExpiry: postgresUpdateTokenExpiry, - updateTokenLastAccess: postgresUpdateTokenLastAccess, - deleteToken: postgresDeleteToken, - deleteProvisionedToken: postgresDeleteProvisionedToken, - deleteAllToken: postgresDeleteAllToken, - deleteExpiredTokens: postgresDeleteExpiredTokens, - deleteExcessTokens: postgresDeleteExcessTokens, - - // Tier queries - insertTier: postgresInsertTier, - selectTiers: postgresSelectTiers, - selectTierByCode: postgresSelectTierByCode, - selectTierByPriceID: postgresSelectTierByPriceID, - updateTier: postgresUpdateTier, - deleteTier: postgresDeleteTier, - - // Phone queries - selectPhoneNumbers: postgresSelectPhoneNumbers, - insertPhoneNumber: postgresInsertPhoneNumber, - deletePhoneNumber: postgresDeletePhoneNumber, - - // Billing queries - updateBilling: postgresUpdateBilling, - } -} - -func setupPostgresDB(db *sql.DB) error { - // Check if schema version table exists - var exists bool - if err := db.QueryRow(postgresSelectSchemaVersionExists).Scan(&exists); err != nil { - return err - } - - if !exists { - // New database, create all tables - if _, err := db.Exec(postgresCreateTablesQueries); err != nil { - return err - } - if _, err := db.Exec(postgresInsertSchemaVersion, currentSchemaVersion); err != nil { - return err - } - return nil - } - - // Table exists, check if user store has a row - var schemaVersion int - err := db.QueryRow(postgresSelectSchemaVersion).Scan(&schemaVersion) - if err == sql.ErrNoRows { - // schema_version table exists (e.g. created by webpush) but no user row yet - if _, err := db.Exec(postgresCreateTablesQueries); err != nil { - return err - } - if _, err := db.Exec(postgresInsertSchemaVersion, currentSchemaVersion); err != nil { - return err - } - return nil - } else if err != nil { - return fmt.Errorf("cannot determine schema version: %v", err) - } - - if schemaVersion > currentSchemaVersion { - return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, currentSchemaVersion) - } - - // Note: PostgreSQL migrations will be added when needed - // For now, we only support new installations at the latest schema version - - return nil -} diff --git a/user/store_postgres_schema.go b/user/store_postgres_schema.go new file mode 100644 index 00000000..c5824970 --- /dev/null +++ b/user/store_postgres_schema.go @@ -0,0 +1,113 @@ +package user + +import ( + "database/sql" + "fmt" +) + +// Initial PostgreSQL schema +const ( + postgresCreateTablesQueries = ` + CREATE TABLE IF NOT EXISTS tier ( + id TEXT PRIMARY KEY, + code TEXT NOT NULL, + name TEXT NOT NULL, + messages_limit BIGINT NOT NULL, + messages_expiry_duration BIGINT NOT NULL, + emails_limit BIGINT NOT NULL, + calls_limit BIGINT NOT NULL, + reservations_limit BIGINT NOT NULL, + attachment_file_size_limit BIGINT NOT NULL, + attachment_total_size_limit BIGINT NOT NULL, + attachment_expiry_duration BIGINT NOT NULL, + attachment_bandwidth_limit BIGINT NOT NULL, + stripe_monthly_price_id TEXT, + stripe_yearly_price_id TEXT, + UNIQUE(code), + UNIQUE(stripe_monthly_price_id), + UNIQUE(stripe_yearly_price_id) + ); + CREATE TABLE IF NOT EXISTS "user" ( + id TEXT PRIMARY KEY, + tier_id TEXT REFERENCES tier(id), + user_name TEXT NOT NULL UNIQUE, + pass TEXT NOT NULL, + role TEXT NOT NULL CHECK (role IN ('anonymous', 'admin', 'user')), + prefs JSONB NOT NULL DEFAULT '{}', + sync_topic TEXT NOT NULL, + provisioned BOOLEAN NOT NULL, + stats_messages BIGINT NOT NULL DEFAULT 0, + stats_emails BIGINT NOT NULL DEFAULT 0, + stats_calls BIGINT NOT NULL DEFAULT 0, + stripe_customer_id TEXT UNIQUE, + stripe_subscription_id TEXT UNIQUE, + stripe_subscription_status TEXT, + stripe_subscription_interval TEXT, + stripe_subscription_paid_until BIGINT, + stripe_subscription_cancel_at BIGINT, + created BIGINT NOT NULL, + deleted BIGINT + ); + CREATE TABLE IF NOT EXISTS user_access ( + user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE, + topic TEXT NOT NULL, + read BOOLEAN NOT NULL, + write BOOLEAN NOT NULL, + owner_user_id TEXT REFERENCES "user"(id) ON DELETE CASCADE, + provisioned BOOLEAN NOT NULL, + PRIMARY KEY (user_id, topic) + ); + CREATE TABLE IF NOT EXISTS user_token ( + user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE, + token TEXT NOT NULL UNIQUE, + label TEXT NOT NULL, + last_access BIGINT NOT NULL, + last_origin TEXT NOT NULL, + expires BIGINT NOT NULL, + provisioned BOOLEAN NOT NULL, + PRIMARY KEY (user_id, token) + ); + CREATE TABLE IF NOT EXISTS user_phone ( + user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE, + phone_number TEXT NOT NULL, + PRIMARY KEY (user_id, phone_number) + ); + CREATE TABLE IF NOT EXISTS schema_version ( + store TEXT PRIMARY KEY, + version INT NOT NULL + ); + INSERT INTO "user" (id, user_name, pass, role, sync_topic, provisioned, created) + VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, EXTRACT(EPOCH FROM NOW())::BIGINT) + ON CONFLICT (id) DO NOTHING; + ` +) + +// Schema table management queries for Postgres +const ( + postgresCurrentSchemaVersion = 6 + postgresSelectSchemaVersion = `SELECT version FROM schema_version WHERE store = 'user'` + postgresInsertSchemaVersion = `INSERT INTO schema_version (store, version) VALUES ('user', $1)` +) + +func setupPostgres(db *sql.DB) error { + var schemaVersion int + err := db.QueryRow(postgresSelectSchemaVersion).Scan(&schemaVersion) + if err != nil { + return setupNewPostgres(db) + } + if schemaVersion > postgresCurrentSchemaVersion { + return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, postgresCurrentSchemaVersion) + } + // Note: PostgreSQL migrations will be added when needed + return nil +} + +func setupNewPostgres(db *sql.DB) error { + if _, err := db.Exec(postgresCreateTablesQueries); err != nil { + return err + } + if _, err := db.Exec(postgresInsertSchemaVersion, postgresCurrentSchemaVersion); err != nil { + return err + } + return nil +} diff --git a/user/store_sqlite.go b/user/store_sqlite.go index b2c26d37..fc5bc7eb 100644 --- a/user/store_sqlite.go +++ b/user/store_sqlite.go @@ -2,102 +2,11 @@ package user import ( "database/sql" - "fmt" _ "github.com/mattn/go-sqlite3" // SQLite driver ) -// SQLite schema and queries const ( - sqliteCreateTablesQueries = ` - BEGIN; - CREATE TABLE IF NOT EXISTS tier ( - id TEXT PRIMARY KEY, - code TEXT NOT NULL, - name TEXT NOT NULL, - messages_limit INT NOT NULL, - messages_expiry_duration INT NOT NULL, - emails_limit INT NOT NULL, - calls_limit INT NOT NULL, - reservations_limit INT NOT NULL, - attachment_file_size_limit INT NOT NULL, - attachment_total_size_limit INT NOT NULL, - attachment_expiry_duration INT NOT NULL, - attachment_bandwidth_limit INT NOT NULL, - stripe_monthly_price_id TEXT, - stripe_yearly_price_id TEXT - ); - CREATE UNIQUE INDEX idx_tier_code ON tier (code); - CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id); - CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id); - CREATE TABLE IF NOT EXISTS user ( - id TEXT PRIMARY KEY, - tier_id TEXT, - user TEXT NOT NULL, - pass TEXT NOT NULL, - role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL, - prefs JSON NOT NULL DEFAULT '{}', - sync_topic TEXT NOT NULL, - provisioned INT NOT NULL, - stats_messages INT NOT NULL DEFAULT (0), - stats_emails INT NOT NULL DEFAULT (0), - stats_calls INT NOT NULL DEFAULT (0), - stripe_customer_id TEXT, - stripe_subscription_id TEXT, - stripe_subscription_status TEXT, - stripe_subscription_interval TEXT, - stripe_subscription_paid_until INT, - stripe_subscription_cancel_at INT, - created INT NOT NULL, - deleted INT, - FOREIGN KEY (tier_id) REFERENCES tier (id) - ); - CREATE UNIQUE INDEX idx_user ON user (user); - CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id); - CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id); - CREATE TABLE IF NOT EXISTS user_access ( - user_id TEXT NOT NULL, - topic TEXT NOT NULL, - read INT NOT NULL, - write INT NOT NULL, - owner_user_id INT, - provisioned INT NOT NULL, - PRIMARY KEY (user_id, topic), - FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE, - FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE - ); - CREATE TABLE IF NOT EXISTS user_token ( - user_id TEXT NOT NULL, - token TEXT NOT NULL, - label TEXT NOT NULL, - last_access INT NOT NULL, - last_origin TEXT NOT NULL, - expires INT NOT NULL, - provisioned INT NOT NULL, - PRIMARY KEY (user_id, token), - FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE - ); - CREATE UNIQUE INDEX idx_user_token ON user_token (token); - CREATE TABLE IF NOT EXISTS user_phone ( - user_id TEXT NOT NULL, - phone_number TEXT NOT NULL, - PRIMARY KEY (user_id, phone_number), - FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE - ); - CREATE TABLE IF NOT EXISTS schemaVersion ( - id INT PRIMARY KEY, - version INT NOT NULL - ); - INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created) - VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, UNIXEPOCH()) - ON CONFLICT (id) DO NOTHING; - COMMIT; - ` - - sqliteBuiltinStartupQueries = ` - PRAGMA foreign_keys = ON; - ` - // User queries sqliteSelectUserByID = ` SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id @@ -295,126 +204,70 @@ func NewSQLiteStore(filename, startupQueries string) (Store, error) { if err != nil { return nil, err } - if err := setupSQLiteDB(db); err != nil { + if err := setupSQLite(db); err != nil { return nil, err } if err := runSQLiteStartupQueries(db, startupQueries); err != nil { return nil, err } return &commonStore{ - db: db, - queries: sqliteQueries(), + db: db, + queries: storeQueries{ + selectUserByID: sqliteSelectUserByID, + selectUserByName: sqliteSelectUserByName, + selectUserByToken: sqliteSelectUserByToken, + selectUserByStripeID: sqliteSelectUserByStripeID, + selectUsernames: sqliteSelectUsernames, + selectUserCount: sqliteSelectUserCount, + selectUserIDFromUsername: sqliteSelectUserIDFromUsername, + insertUser: sqliteInsertUser, + updateUserPass: sqliteUpdateUserPass, + updateUserRole: sqliteUpdateUserRole, + updateUserProvisioned: sqliteUpdateUserProvisioned, + updateUserPrefs: sqliteUpdateUserPrefs, + updateUserStats: sqliteUpdateUserStats, + updateUserStatsResetAll: sqliteUpdateUserStatsResetAll, + updateUserTier: sqliteUpdateUserTier, + updateUserDeleted: sqliteUpdateUserDeleted, + deleteUser: sqliteDeleteUser, + deleteUserTier: sqliteDeleteUserTier, + deleteUsersMarked: sqliteDeleteUsersMarked, + selectTopicPerms: sqliteSelectTopicPerms, + selectUserAllAccess: sqliteSelectUserAllAccess, + selectUserAccess: sqliteSelectUserAccess, + selectUserReservations: sqliteSelectUserReservations, + selectUserReservationsCount: sqliteSelectUserReservationsCount, + selectUserReservationsOwner: sqliteSelectUserReservationsOwner, + selectUserHasReservation: sqliteSelectUserHasReservation, + selectOtherAccessCount: sqliteSelectOtherAccessCount, + upsertUserAccess: sqliteUpsertUserAccess, + deleteUserAccess: sqliteDeleteUserAccess, + deleteUserAccessProvisioned: sqliteDeleteUserAccessProvisioned, + deleteTopicAccess: sqliteDeleteTopicAccess, + deleteAllAccess: sqliteDeleteAllAccess, + selectToken: sqliteSelectToken, + selectTokens: sqliteSelectTokens, + selectTokenCount: sqliteSelectTokenCount, + selectAllProvisionedTokens: sqliteSelectAllProvisionedTokens, + upsertToken: sqliteUpsertToken, + updateTokenLabel: sqliteUpdateTokenLabel, + updateTokenExpiry: sqliteUpdateTokenExpiry, + updateTokenLastAccess: sqliteUpdateTokenLastAccess, + deleteToken: sqliteDeleteToken, + deleteProvisionedToken: sqliteDeleteProvisionedToken, + deleteAllToken: sqliteDeleteAllToken, + deleteExpiredTokens: sqliteDeleteExpiredTokens, + deleteExcessTokens: sqliteDeleteExcessTokens, + insertTier: sqliteInsertTier, + selectTiers: sqliteSelectTiers, + selectTierByCode: sqliteSelectTierByCode, + selectTierByPriceID: sqliteSelectTierByPriceID, + updateTier: sqliteUpdateTier, + deleteTier: sqliteDeleteTier, + selectPhoneNumbers: sqliteSelectPhoneNumbers, + insertPhoneNumber: sqliteInsertPhoneNumber, + deletePhoneNumber: sqliteDeletePhoneNumber, + updateBilling: sqliteUpdateBilling, + }, }, nil } - -func sqliteQueries() storeQueries { - return storeQueries{ - selectUserByID: sqliteSelectUserByID, - selectUserByName: sqliteSelectUserByName, - selectUserByToken: sqliteSelectUserByToken, - selectUserByStripeID: sqliteSelectUserByStripeID, - selectUsernames: sqliteSelectUsernames, - selectUserCount: sqliteSelectUserCount, - selectUserIDFromUsername: sqliteSelectUserIDFromUsername, - insertUser: sqliteInsertUser, - updateUserPass: sqliteUpdateUserPass, - updateUserRole: sqliteUpdateUserRole, - updateUserProvisioned: sqliteUpdateUserProvisioned, - updateUserPrefs: sqliteUpdateUserPrefs, - updateUserStats: sqliteUpdateUserStats, - updateUserStatsResetAll: sqliteUpdateUserStatsResetAll, - updateUserTier: sqliteUpdateUserTier, - updateUserDeleted: sqliteUpdateUserDeleted, - deleteUser: sqliteDeleteUser, - deleteUserTier: sqliteDeleteUserTier, - deleteUsersMarked: sqliteDeleteUsersMarked, - selectTopicPerms: sqliteSelectTopicPerms, - selectUserAllAccess: sqliteSelectUserAllAccess, - selectUserAccess: sqliteSelectUserAccess, - selectUserReservations: sqliteSelectUserReservations, - selectUserReservationsCount: sqliteSelectUserReservationsCount, - selectUserReservationsOwner: sqliteSelectUserReservationsOwner, - selectUserHasReservation: sqliteSelectUserHasReservation, - selectOtherAccessCount: sqliteSelectOtherAccessCount, - upsertUserAccess: sqliteUpsertUserAccess, - deleteUserAccess: sqliteDeleteUserAccess, - deleteUserAccessProvisioned: sqliteDeleteUserAccessProvisioned, - deleteTopicAccess: sqliteDeleteTopicAccess, - deleteAllAccess: sqliteDeleteAllAccess, - selectToken: sqliteSelectToken, - selectTokens: sqliteSelectTokens, - selectTokenCount: sqliteSelectTokenCount, - selectAllProvisionedTokens: sqliteSelectAllProvisionedTokens, - upsertToken: sqliteUpsertToken, - updateTokenLabel: sqliteUpdateTokenLabel, - updateTokenExpiry: sqliteUpdateTokenExpiry, - updateTokenLastAccess: sqliteUpdateTokenLastAccess, - deleteToken: sqliteDeleteToken, - deleteProvisionedToken: sqliteDeleteProvisionedToken, - deleteAllToken: sqliteDeleteAllToken, - deleteExpiredTokens: sqliteDeleteExpiredTokens, - deleteExcessTokens: sqliteDeleteExcessTokens, - insertTier: sqliteInsertTier, - selectTiers: sqliteSelectTiers, - selectTierByCode: sqliteSelectTierByCode, - selectTierByPriceID: sqliteSelectTierByPriceID, - updateTier: sqliteUpdateTier, - deleteTier: sqliteDeleteTier, - selectPhoneNumbers: sqliteSelectPhoneNumbers, - insertPhoneNumber: sqliteInsertPhoneNumber, - deletePhoneNumber: sqliteDeletePhoneNumber, - updateBilling: sqliteUpdateBilling, - } -} - -func setupSQLiteDB(db *sql.DB) error { - rowsSV, err := db.Query(selectSchemaVersionQuery) - if err != nil { - return setupNewSQLiteDB(db) - } - defer rowsSV.Close() - schemaVersion := 0 - if !rowsSV.Next() { - return fmt.Errorf("cannot determine schema version: database file may be corrupt") - } - if err := rowsSV.Scan(&schemaVersion); err != nil { - return err - } - rowsSV.Close() - if schemaVersion == currentSchemaVersion { - return nil - } else if schemaVersion > currentSchemaVersion { - return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, currentSchemaVersion) - } - for i := schemaVersion; i < currentSchemaVersion; i++ { - fn, ok := migrations[i] - if !ok { - return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1) - } else if err := fn(db); err != nil { - return err - } - } - return nil -} - -func setupNewSQLiteDB(db *sql.DB) error { - if _, err := db.Exec(sqliteCreateTablesQueries); err != nil { - return err - } - if _, err := db.Exec(insertSchemaVersion, currentSchemaVersion); err != nil { - return err - } - return nil -} - -func runSQLiteStartupQueries(db *sql.DB, startupQueries string) error { - if _, err := db.Exec(sqliteBuiltinStartupQueries); err != nil { - return err - } - if startupQueries != "" { - if _, err := db.Exec(startupQueries); err != nil { - return err - } - } - return nil -} diff --git a/user/store_sqlite_migrations.go b/user/store_sqlite_schema.go similarity index 56% rename from user/store_sqlite_migrations.go rename to user/store_sqlite_schema.go index 442bcc05..f8e65cb6 100644 --- a/user/store_sqlite_migrations.go +++ b/user/store_sqlite_schema.go @@ -2,20 +2,116 @@ package user import ( "database/sql" + "fmt" "heckel.io/ntfy/v2/log" "heckel.io/ntfy/v2/util" ) -// SQLite migrations +// Initial SQLite schema const ( - currentSchemaVersion = 6 - insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)` - updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1` - selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1` + sqliteCreateTablesQueries = ` + BEGIN; + CREATE TABLE IF NOT EXISTS tier ( + id TEXT PRIMARY KEY, + code TEXT NOT NULL, + name TEXT NOT NULL, + messages_limit INT NOT NULL, + messages_expiry_duration INT NOT NULL, + emails_limit INT NOT NULL, + calls_limit INT NOT NULL, + reservations_limit INT NOT NULL, + attachment_file_size_limit INT NOT NULL, + attachment_total_size_limit INT NOT NULL, + attachment_expiry_duration INT NOT NULL, + attachment_bandwidth_limit INT NOT NULL, + stripe_monthly_price_id TEXT, + stripe_yearly_price_id TEXT + ); + CREATE UNIQUE INDEX idx_tier_code ON tier (code); + CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id); + CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id); + CREATE TABLE IF NOT EXISTS user ( + id TEXT PRIMARY KEY, + tier_id TEXT, + user TEXT NOT NULL, + pass TEXT NOT NULL, + role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL, + prefs JSON NOT NULL DEFAULT '{}', + sync_topic TEXT NOT NULL, + provisioned INT NOT NULL, + stats_messages INT NOT NULL DEFAULT (0), + stats_emails INT NOT NULL DEFAULT (0), + stats_calls INT NOT NULL DEFAULT (0), + stripe_customer_id TEXT, + stripe_subscription_id TEXT, + stripe_subscription_status TEXT, + stripe_subscription_interval TEXT, + stripe_subscription_paid_until INT, + stripe_subscription_cancel_at INT, + created INT NOT NULL, + deleted INT, + FOREIGN KEY (tier_id) REFERENCES tier (id) + ); + CREATE UNIQUE INDEX idx_user ON user (user); + CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id); + CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id); + CREATE TABLE IF NOT EXISTS user_access ( + user_id TEXT NOT NULL, + topic TEXT NOT NULL, + read INT NOT NULL, + write INT NOT NULL, + owner_user_id INT, + provisioned INT NOT NULL, + PRIMARY KEY (user_id, topic), + FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE, + FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE + ); + CREATE TABLE IF NOT EXISTS user_token ( + user_id TEXT NOT NULL, + token TEXT NOT NULL, + label TEXT NOT NULL, + last_access INT NOT NULL, + last_origin TEXT NOT NULL, + expires INT NOT NULL, + provisioned INT NOT NULL, + PRIMARY KEY (user_id, token), + FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE + ); + CREATE UNIQUE INDEX idx_user_token ON user_token (token); + CREATE TABLE IF NOT EXISTS user_phone ( + user_id TEXT NOT NULL, + phone_number TEXT NOT NULL, + PRIMARY KEY (user_id, phone_number), + FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE + ); + CREATE TABLE IF NOT EXISTS schemaVersion ( + id INT PRIMARY KEY, + version INT NOT NULL + ); + INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created) + VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, UNIXEPOCH()) + ON CONFLICT (id) DO NOTHING; + COMMIT; + ` +) +const ( + sqliteBuiltinStartupQueries = `PRAGMA foreign_keys = ON;` +) + +// Schema version table management for SQLite +const ( + sqliteCurrentSchemaVersion = 6 + sqliteInsertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)` + sqliteUpdateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1` + sqliteSelectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1` +) + +// Schema migrations for SQLite +const ( // 1 -> 2 (complex migration!) - migrate1To2CreateTablesQueries = ` + sqliteMigrate1To2CreateTablesQueries = ` ALTER TABLE user RENAME TO user_old; CREATE TABLE IF NOT EXISTS tier ( id TEXT PRIMARY KEY, @@ -83,12 +179,12 @@ const ( VALUES ('u_everyone', '*', '', 'anonymous', '', UNIXEPOCH()) ON CONFLICT (id) DO NOTHING; ` - migrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old` - migrate1To2InsertUserNoTx = ` + sqliteMigrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old` + sqliteMigrate1To2InsertUserNoTx = ` INSERT INTO user (id, user, pass, role, sync_topic, created) SELECT ?, user, pass, role, ?, UNIXEPOCH() FROM user_old WHERE user = ? ` - migrate1To2InsertFromOldTablesAndDropNoTx = ` + sqliteMigrate1To2InsertFromOldTablesAndDropNoTx = ` INSERT INTO user_access (user_id, topic, read, write) SELECT u.id, a.topic, a.read, a.write FROM user u @@ -99,7 +195,7 @@ const ( ` // 2 -> 3 - migrate2To3UpdateQueries = ` + sqliteMigrate2To3UpdateQueries = ` ALTER TABLE user ADD COLUMN stripe_subscription_interval TEXT; ALTER TABLE tier RENAME COLUMN stripe_price_id TO stripe_monthly_price_id; ALTER TABLE tier ADD COLUMN stripe_yearly_price_id TEXT; @@ -109,7 +205,7 @@ const ( ` // 3 -> 4 - migrate3To4UpdateQueries = ` + sqliteMigrate3To4UpdateQueries = ` ALTER TABLE tier ADD COLUMN calls_limit INT NOT NULL DEFAULT (0); ALTER TABLE user ADD COLUMN stats_calls INT NOT NULL DEFAULT (0); CREATE TABLE IF NOT EXISTS user_phone ( @@ -121,12 +217,12 @@ const ( ` // 4 -> 5 - migrate4To5UpdateQueries = ` + sqliteMigrate4To5UpdateQueries = ` UPDATE user_access SET topic = REPLACE(topic, '_', '\_'); ` // 5 -> 6 - migrate5To6UpdateQueries = ` + sqliteMigrate5To6UpdateQueries = ` PRAGMA foreign_keys=off; -- Alter user table: Add provisioned column @@ -221,16 +317,60 @@ const ( ) var ( - migrations = map[int]func(db *sql.DB) error{ - 1: migrateFrom1, - 2: migrateFrom2, - 3: migrateFrom3, - 4: migrateFrom4, - 5: migrateFrom5, + sqliteMigrations = map[int]func(db *sql.DB) error{ + 1: sqliteMigrateFrom1, + 2: sqliteMigrateFrom2, + 3: sqliteMigrateFrom3, + 4: sqliteMigrateFrom4, + 5: sqliteMigrateFrom5, } ) -func migrateFrom1(db *sql.DB) error { +func setupSQLite(db *sql.DB) error { + var schemaVersion int + err := db.QueryRow(sqliteSelectSchemaVersionQuery).Scan(&schemaVersion) + if err != nil { + return setupNewSQLite(db) + } + if schemaVersion == sqliteCurrentSchemaVersion { + return nil + } else if schemaVersion > sqliteCurrentSchemaVersion { + return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, sqliteCurrentSchemaVersion) + } + for i := schemaVersion; i < sqliteCurrentSchemaVersion; i++ { + fn, ok := sqliteMigrations[i] + if !ok { + return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1) + } else if err := fn(db); err != nil { + return err + } + } + return nil +} + +func setupNewSQLite(db *sql.DB) error { + if _, err := db.Exec(sqliteCreateTablesQueries); err != nil { + return err + } + if _, err := db.Exec(sqliteInsertSchemaVersion, sqliteCurrentSchemaVersion); err != nil { + return err + } + return nil +} + +func runSQLiteStartupQueries(db *sql.DB, startupQueries string) error { + if _, err := db.Exec(sqliteBuiltinStartupQueries); err != nil { + return err + } + if startupQueries != "" { + if _, err := db.Exec(startupQueries); err != nil { + return err + } + } + return nil +} + +func sqliteMigrateFrom1(db *sql.DB) error { log.Tag(tag).Info("Migrating user database schema: from 1 to 2") tx, err := db.Begin() if err != nil { @@ -238,11 +378,11 @@ func migrateFrom1(db *sql.DB) error { } defer tx.Rollback() // Rename user -> user_old, and create new tables - if _, err := tx.Exec(migrate1To2CreateTablesQueries); err != nil { + if _, err := tx.Exec(sqliteMigrate1To2CreateTablesQueries); err != nil { return err } // Insert users from user_old into new user table, with ID and sync_topic - rows, err := tx.Query(migrate1To2SelectAllOldUsernamesNoTx) + rows, err := tx.Query(sqliteMigrate1To2SelectAllOldUsernamesNoTx) if err != nil { return err } @@ -261,15 +401,15 @@ func migrateFrom1(db *sql.DB) error { for _, username := range usernames { userID := util.RandomStringPrefix(userIDPrefix, userIDLength) syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength) - if _, err := tx.Exec(migrate1To2InsertUserNoTx, userID, syncTopic, username); err != nil { + if _, err := tx.Exec(sqliteMigrate1To2InsertUserNoTx, userID, syncTopic, username); err != nil { return err } } // Migrate old "access" table to "user_access" and drop "access" and "user_old" - if _, err := tx.Exec(migrate1To2InsertFromOldTablesAndDropNoTx); err != nil { + if _, err := tx.Exec(sqliteMigrate1To2InsertFromOldTablesAndDropNoTx); err != nil { return err } - if _, err := tx.Exec(updateSchemaVersion, 2); err != nil { + if _, err := tx.Exec(sqliteUpdateSchemaVersion, 2); err != nil { return err } if err := tx.Commit(); err != nil { @@ -278,65 +418,65 @@ func migrateFrom1(db *sql.DB) error { return nil } -func migrateFrom2(db *sql.DB) error { +func sqliteMigrateFrom2(db *sql.DB) error { log.Tag(tag).Info("Migrating user database schema: from 2 to 3") tx, err := db.Begin() if err != nil { return err } defer tx.Rollback() - if _, err := tx.Exec(migrate2To3UpdateQueries); err != nil { + if _, err := tx.Exec(sqliteMigrate2To3UpdateQueries); err != nil { return err } - if _, err := tx.Exec(updateSchemaVersion, 3); err != nil { + if _, err := tx.Exec(sqliteUpdateSchemaVersion, 3); err != nil { return err } return tx.Commit() } -func migrateFrom3(db *sql.DB) error { +func sqliteMigrateFrom3(db *sql.DB) error { log.Tag(tag).Info("Migrating user database schema: from 3 to 4") tx, err := db.Begin() if err != nil { return err } defer tx.Rollback() - if _, err := tx.Exec(migrate3To4UpdateQueries); err != nil { + if _, err := tx.Exec(sqliteMigrate3To4UpdateQueries); err != nil { return err } - if _, err := tx.Exec(updateSchemaVersion, 4); err != nil { + if _, err := tx.Exec(sqliteUpdateSchemaVersion, 4); err != nil { return err } return tx.Commit() } -func migrateFrom4(db *sql.DB) error { +func sqliteMigrateFrom4(db *sql.DB) error { log.Tag(tag).Info("Migrating user database schema: from 4 to 5") tx, err := db.Begin() if err != nil { return err } defer tx.Rollback() - if _, err := tx.Exec(migrate4To5UpdateQueries); err != nil { + if _, err := tx.Exec(sqliteMigrate4To5UpdateQueries); err != nil { return err } - if _, err := tx.Exec(updateSchemaVersion, 5); err != nil { + if _, err := tx.Exec(sqliteUpdateSchemaVersion, 5); err != nil { return err } return tx.Commit() } -func migrateFrom5(db *sql.DB) error { +func sqliteMigrateFrom5(db *sql.DB) error { log.Tag(tag).Info("Migrating user database schema: from 5 to 6") tx, err := db.Begin() if err != nil { return err } defer tx.Rollback() - if _, err := tx.Exec(migrate5To6UpdateQueries); err != nil { + if _, err := tx.Exec(sqliteMigrate5To6UpdateQueries); err != nil { return err } - if _, err := tx.Exec(updateSchemaVersion, 6); err != nil { + if _, err := tx.Exec(sqliteUpdateSchemaVersion, 6); err != nil { return err } return tx.Commit() diff --git a/webpush/store_postgres.go b/webpush/store_postgres.go index fdec1f29..661e18cd 100644 --- a/webpush/store_postgres.go +++ b/webpush/store_postgres.go @@ -2,6 +2,7 @@ package webpush import ( "database/sql" + "fmt" _ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver ) @@ -102,14 +103,13 @@ func NewPostgresStore(dsn string) (Store, error) { } func setupPostgresDB(db *sql.DB) error { - // If 'schema_version' table does not exist or no webpush row, this must be a new database - rows, err := db.Query(pgSelectSchemaVersionQuery) + var schemaVersion int + err := db.QueryRow(pgSelectSchemaVersionQuery).Scan(&schemaVersion) if err != nil { return setupNewPostgresDB(db) } - defer rows.Close() - if !rows.Next() { - return setupNewPostgresDB(db) + if schemaVersion > pgCurrentSchemaVersion { + return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, pgCurrentSchemaVersion) } return nil } diff --git a/webpush/store_sqlite.go b/webpush/store_sqlite.go index 97bacf74..07823672 100644 --- a/webpush/store_sqlite.go +++ b/webpush/store_sqlite.go @@ -2,6 +2,7 @@ package webpush import ( "database/sql" + "fmt" _ "github.com/mattn/go-sqlite3" // SQLite driver ) @@ -82,10 +83,10 @@ func NewSQLiteStore(filename, startupQueries string) (Store, error) { if err != nil { return nil, err } - if err := setupSQLiteWebPushDB(db); err != nil { + if err := setupSQLite(db); err != nil { return nil, err } - if err := runSQLiteWebPushStartupQueries(db, startupQueries); err != nil { + if err := runSQLiteStartupQueries(db, startupQueries); err != nil { return nil, err } return &commonStore{ @@ -108,16 +109,19 @@ func NewSQLiteStore(filename, startupQueries string) (Store, error) { }, nil } -func setupSQLiteWebPushDB(db *sql.DB) error { - // If 'schemaVersion' table does not exist, this must be a new database - rows, err := db.Query(sqliteSelectWebPushSchemaVersionQuery) +func setupSQLite(db *sql.DB) error { + var schemaVersion int + err := db.QueryRow(sqliteSelectWebPushSchemaVersionQuery).Scan(&schemaVersion) if err != nil { - return setupNewSQLiteWebPushDB(db) + return setupNewSQLite(db) } - return rows.Close() + if schemaVersion > sqliteCurrentWebPushSchemaVersion { + return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, sqliteCurrentWebPushSchemaVersion) + } + return nil } -func setupNewSQLiteWebPushDB(db *sql.DB) error { +func setupNewSQLite(db *sql.DB) error { if _, err := db.Exec(sqliteCreateWebPushSubscriptionsTableQuery); err != nil { return err } @@ -127,7 +131,7 @@ func setupNewSQLiteWebPushDB(db *sql.DB) error { return nil } -func runSQLiteWebPushStartupQueries(db *sql.DB, startupQueries string) error { +func runSQLiteStartupQueries(db *sql.DB, startupQueries string) error { if _, err := db.Exec(startupQueries); err != nil { return err }