diff --git a/user/manager.go b/user/manager.go index 1d900604..59aa883a 100644 --- a/user/manager.go +++ b/user/manager.go @@ -6,17 +6,17 @@ import ( "encoding/json" "errors" "fmt" + "net/netip" + "path/filepath" + "slices" + "sync" + "time" + "github.com/mattn/go-sqlite3" "golang.org/x/crypto/bcrypt" "heckel.io/ntfy/v2/log" "heckel.io/ntfy/v2/payments" "heckel.io/ntfy/v2/util" - "net/netip" - "path/filepath" - "slices" - "strings" - "sync" - "time" ) const ( @@ -326,229 +326,6 @@ const ( ` ) -// Schema management queries -const ( - currentSchemaVersion = 6 - insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)` - updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1` - selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1` - - // 1 -> 2 (complex migration!) - migrate1To2CreateTablesQueries = ` - ALTER TABLE user RENAME TO user_old; - CREATE TABLE IF NOT EXISTS tier ( - id TEXT PRIMARY KEY, - code TEXT NOT NULL, - name TEXT NOT NULL, - messages_limit INT NOT NULL, - messages_expiry_duration INT NOT NULL, - emails_limit INT NOT NULL, - reservations_limit INT NOT NULL, - attachment_file_size_limit INT NOT NULL, - attachment_total_size_limit INT NOT NULL, - attachment_expiry_duration INT NOT NULL, - attachment_bandwidth_limit INT NOT NULL, - stripe_price_id TEXT - ); - CREATE UNIQUE INDEX idx_tier_code ON tier (code); - CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id); - CREATE TABLE IF NOT EXISTS user ( - id TEXT PRIMARY KEY, - tier_id TEXT, - user TEXT NOT NULL, - pass TEXT NOT NULL, - role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL, - prefs JSON NOT NULL DEFAULT '{}', - sync_topic TEXT NOT NULL, - stats_messages INT NOT NULL DEFAULT (0), - stats_emails INT NOT NULL DEFAULT (0), - stripe_customer_id TEXT, - stripe_subscription_id TEXT, - stripe_subscription_status TEXT, - stripe_subscription_paid_until INT, - stripe_subscription_cancel_at INT, - created INT NOT NULL, - deleted INT, - FOREIGN KEY (tier_id) REFERENCES tier (id) - ); - CREATE UNIQUE INDEX idx_user ON user (user); - CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id); - CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id); - CREATE TABLE IF NOT EXISTS user_access ( - user_id TEXT NOT NULL, - topic TEXT NOT NULL, - read INT NOT NULL, - write INT NOT NULL, - owner_user_id INT, - PRIMARY KEY (user_id, topic), - FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE, - FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE - ); - CREATE TABLE IF NOT EXISTS user_token ( - user_id TEXT NOT NULL, - token TEXT NOT NULL, - label TEXT NOT NULL, - last_access INT NOT NULL, - last_origin TEXT NOT NULL, - expires INT NOT NULL, - PRIMARY KEY (user_id, token), - FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE - ); - CREATE TABLE IF NOT EXISTS schemaVersion ( - id INT PRIMARY KEY, - version INT NOT NULL - ); - INSERT INTO user (id, user, pass, role, sync_topic, created) - VALUES ('u_everyone', '*', '', 'anonymous', '', UNIXEPOCH()) - ON CONFLICT (id) DO NOTHING; - ` - migrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old` - migrate1To2InsertUserNoTx = ` - INSERT INTO user (id, user, pass, role, sync_topic, created) - SELECT ?, user, pass, role, ?, UNIXEPOCH() FROM user_old WHERE user = ? - ` - migrate1To2InsertFromOldTablesAndDropNoTx = ` - INSERT INTO user_access (user_id, topic, read, write) - SELECT u.id, a.topic, a.read, a.write - FROM user u - JOIN access a ON u.user = a.user; - - DROP TABLE access; - DROP TABLE user_old; - ` - - // 2 -> 3 - migrate2To3UpdateQueries = ` - ALTER TABLE user ADD COLUMN stripe_subscription_interval TEXT; - ALTER TABLE tier RENAME COLUMN stripe_price_id TO stripe_monthly_price_id; - ALTER TABLE tier ADD COLUMN stripe_yearly_price_id TEXT; - DROP INDEX IF EXISTS idx_tier_price_id; - CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id); - CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id); - ` - - // 3 -> 4 - migrate3To4UpdateQueries = ` - ALTER TABLE tier ADD COLUMN calls_limit INT NOT NULL DEFAULT (0); - ALTER TABLE user ADD COLUMN stats_calls INT NOT NULL DEFAULT (0); - CREATE TABLE IF NOT EXISTS user_phone ( - user_id TEXT NOT NULL, - phone_number TEXT NOT NULL, - PRIMARY KEY (user_id, phone_number), - FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE - ); - ` - - // 4 -> 5 - migrate4To5UpdateQueries = ` - UPDATE user_access SET topic = REPLACE(topic, '_', '\_'); - ` - - // 5 -> 6 - migrate5To6UpdateQueries = ` - PRAGMA foreign_keys=off; - - -- Alter user table: Add provisioned column - ALTER TABLE user RENAME TO user_old; - CREATE TABLE IF NOT EXISTS user ( - id TEXT PRIMARY KEY, - tier_id TEXT, - user TEXT NOT NULL, - pass TEXT NOT NULL, - role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL, - prefs JSON NOT NULL DEFAULT '{}', - sync_topic TEXT NOT NULL, - provisioned INT NOT NULL, - stats_messages INT NOT NULL DEFAULT (0), - stats_emails INT NOT NULL DEFAULT (0), - stats_calls INT NOT NULL DEFAULT (0), - stripe_customer_id TEXT, - stripe_subscription_id TEXT, - stripe_subscription_status TEXT, - stripe_subscription_interval TEXT, - stripe_subscription_paid_until INT, - stripe_subscription_cancel_at INT, - created INT NOT NULL, - deleted INT, - FOREIGN KEY (tier_id) REFERENCES tier (id) - ); - INSERT INTO user - SELECT - id, - tier_id, - user, - pass, - role, - prefs, - sync_topic, - 0, -- provisioned - stats_messages, - stats_emails, - stats_calls, - stripe_customer_id, - stripe_subscription_id, - stripe_subscription_status, - stripe_subscription_interval, - stripe_subscription_paid_until, - stripe_subscription_cancel_at, - created, - deleted - FROM user_old; - DROP TABLE user_old; - - -- Alter user_access table: Add provisioned column - ALTER TABLE user_access RENAME TO user_access_old; - CREATE TABLE user_access ( - user_id TEXT NOT NULL, - topic TEXT NOT NULL, - read INT NOT NULL, - write INT NOT NULL, - owner_user_id INT, - provisioned INTEGER NOT NULL, - PRIMARY KEY (user_id, topic), - FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE, - FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE - ); - INSERT INTO user_access SELECT *, 0 FROM user_access_old; - DROP TABLE user_access_old; - - -- Alter user_token table: Add provisioned column - ALTER TABLE user_token RENAME TO user_token_old; - CREATE TABLE IF NOT EXISTS user_token ( - user_id TEXT NOT NULL, - token TEXT NOT NULL, - label TEXT NOT NULL, - last_access INT NOT NULL, - last_origin TEXT NOT NULL, - expires INT NOT NULL, - provisioned INT NOT NULL, - PRIMARY KEY (user_id, token), - FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE - ); - INSERT INTO user_token SELECT *, 0 FROM user_token_old; - DROP TABLE user_token_old; - - -- Recreate indices - CREATE UNIQUE INDEX idx_user ON user (user); - CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id); - CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id); - CREATE UNIQUE INDEX idx_user_token ON user_token (token); - - -- Re-enable foreign keys - PRAGMA foreign_keys=on; - ` -) - -var ( - migrations = map[int]func(db *sql.DB) error{ - 1: migrateFrom1, - 2: migrateFrom2, - 3: migrateFrom3, - 4: migrateFrom4, - 5: migrateFrom5, - } -) - // Manager is an implementation of Manager. It stores users and access control list // in a SQLite database. type Manager struct { @@ -1840,7 +1617,7 @@ func (a *Manager) maybeProvisionUsers(tx *sql.Tx, provisionUsernames []string, e return nil } -// maybyProvisionGrants removes all provisioned grants, and (re-)adds the grants from the config. +// maybeProvisionGrants removes all provisioned grants, and (re-)adds the grants from the config. // // Unlike users and tokens, grants can be just re-added, because they do not carry any state (such as last // access time) or do not have dependent resources (such as grants or tokens). @@ -1909,26 +1686,6 @@ func (a *Manager) maybeProvisionTokens(tx *sql.Tx, provisionUsernames []string) return nil } -// toSQLWildcard converts a wildcard string to a SQL wildcard string. It only allows '*' as wildcards, -// and escapes '_', assuming '\' as escape character. -func toSQLWildcard(s string) string { - return escapeUnderscore(strings.ReplaceAll(s, "*", "%")) -} - -// fromSQLWildcard converts a SQL wildcard string to a wildcard string. It converts '%' to '*', -// and removes the '\_' escape character. -func fromSQLWildcard(s string) string { - return strings.ReplaceAll(unescapeUnderscore(s), "%", "*") -} - -func escapeUnderscore(s string) string { - return strings.ReplaceAll(s, "_", "\\_") -} - -func unescapeUnderscore(s string) string { - return strings.ReplaceAll(s, "\\_", "_") -} - func runStartupQueries(db *sql.DB, startupQueries string) error { if _, err := db.Exec(startupQueries); err != nil { return err @@ -1983,161 +1740,3 @@ func setupNewDB(db *sql.DB) error { } return nil } - -func migrateFrom1(db *sql.DB) error { - log.Tag(tag).Info("Migrating user database schema: from 1 to 2") - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - // Rename user -> user_old, and create new tables - if _, err := tx.Exec(migrate1To2CreateTablesQueries); err != nil { - return err - } - // Insert users from user_old into new user table, with ID and sync_topic - rows, err := tx.Query(migrate1To2SelectAllOldUsernamesNoTx) - if err != nil { - return err - } - defer rows.Close() - usernames := make([]string, 0) - for rows.Next() { - var username string - if err := rows.Scan(&username); err != nil { - return err - } - usernames = append(usernames, username) - } - if err := rows.Close(); err != nil { - return err - } - for _, username := range usernames { - userID := util.RandomStringPrefix(userIDPrefix, userIDLength) - syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength) - if _, err := tx.Exec(migrate1To2InsertUserNoTx, userID, syncTopic, username); err != nil { - return err - } - } - // Migrate old "access" table to "user_access" and drop "access" and "user_old" - if _, err := tx.Exec(migrate1To2InsertFromOldTablesAndDropNoTx); err != nil { - return err - } - if _, err := tx.Exec(updateSchemaVersion, 2); err != nil { - return err - } - if err := tx.Commit(); err != nil { - return err - } - return nil -} - -func migrateFrom2(db *sql.DB) error { - log.Tag(tag).Info("Migrating user database schema: from 2 to 3") - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - if _, err := tx.Exec(migrate2To3UpdateQueries); err != nil { - return err - } - if _, err := tx.Exec(updateSchemaVersion, 3); err != nil { - return err - } - return tx.Commit() -} - -func migrateFrom3(db *sql.DB) error { - log.Tag(tag).Info("Migrating user database schema: from 3 to 4") - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - if _, err := tx.Exec(migrate3To4UpdateQueries); err != nil { - return err - } - if _, err := tx.Exec(updateSchemaVersion, 4); err != nil { - return err - } - return tx.Commit() -} - -func migrateFrom4(db *sql.DB) error { - log.Tag(tag).Info("Migrating user database schema: from 4 to 5") - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - if _, err := tx.Exec(migrate4To5UpdateQueries); err != nil { - return err - } - if _, err := tx.Exec(updateSchemaVersion, 5); err != nil { - return err - } - return tx.Commit() -} - -func migrateFrom5(db *sql.DB) error { - log.Tag(tag).Info("Migrating user database schema: from 5 to 6") - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - if _, err := tx.Exec(migrate5To6UpdateQueries); err != nil { - return err - } - if _, err := tx.Exec(updateSchemaVersion, 6); err != nil { - return err - } - return tx.Commit() -} - -func nullString(s string) sql.NullString { - if s == "" { - return sql.NullString{} - } - return sql.NullString{String: s, Valid: true} -} - -func nullInt64(v int64) sql.NullInt64 { - if v == 0 { - return sql.NullInt64{} - } - return sql.NullInt64{Int64: v, Valid: true} -} - -// execTx executes a function in a transaction. If the function returns an error, the transaction is rolled back. -func execTx(db *sql.DB, f func(tx *sql.Tx) error) error { - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - if err := f(tx); err != nil { - return err - } - return tx.Commit() -} - -// queryTx executes a function in a transaction and returns the result. If the function -// returns an error, the transaction is rolled back. -func queryTx[T any](db *sql.DB, f func(tx *sql.Tx) (T, error)) (T, error) { - tx, err := db.Begin() - if err != nil { - var zero T - return zero, err - } - defer tx.Rollback() - t, err := f(tx) - if err != nil { - return t, err - } - if err := tx.Commit(); err != nil { - return t, err - } - return t, nil -} diff --git a/user/migrations.go b/user/migrations.go new file mode 100644 index 00000000..a7eb317f --- /dev/null +++ b/user/migrations.go @@ -0,0 +1,342 @@ +package user + +import ( + "database/sql" + "heckel.io/ntfy/v2/log" + "heckel.io/ntfy/v2/util" +) + +// Schema management queries +const ( + currentSchemaVersion = 6 + insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)` + updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1` + selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1` + + // 1 -> 2 (complex migration!) + migrate1To2CreateTablesQueries = ` + ALTER TABLE user RENAME TO user_old; + CREATE TABLE IF NOT EXISTS tier ( + id TEXT PRIMARY KEY, + code TEXT NOT NULL, + name TEXT NOT NULL, + messages_limit INT NOT NULL, + messages_expiry_duration INT NOT NULL, + emails_limit INT NOT NULL, + reservations_limit INT NOT NULL, + attachment_file_size_limit INT NOT NULL, + attachment_total_size_limit INT NOT NULL, + attachment_expiry_duration INT NOT NULL, + attachment_bandwidth_limit INT NOT NULL, + stripe_price_id TEXT + ); + CREATE UNIQUE INDEX idx_tier_code ON tier (code); + CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id); + CREATE TABLE IF NOT EXISTS user ( + id TEXT PRIMARY KEY, + tier_id TEXT, + user TEXT NOT NULL, + pass TEXT NOT NULL, + role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL, + prefs JSON NOT NULL DEFAULT '{}', + sync_topic TEXT NOT NULL, + stats_messages INT NOT NULL DEFAULT (0), + stats_emails INT NOT NULL DEFAULT (0), + stripe_customer_id TEXT, + stripe_subscription_id TEXT, + stripe_subscription_status TEXT, + stripe_subscription_paid_until INT, + stripe_subscription_cancel_at INT, + created INT NOT NULL, + deleted INT, + FOREIGN KEY (tier_id) REFERENCES tier (id) + ); + CREATE UNIQUE INDEX idx_user ON user (user); + CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id); + CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id); + CREATE TABLE IF NOT EXISTS user_access ( + user_id TEXT NOT NULL, + topic TEXT NOT NULL, + read INT NOT NULL, + write INT NOT NULL, + owner_user_id INT, + PRIMARY KEY (user_id, topic), + FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE, + FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE + ); + CREATE TABLE IF NOT EXISTS user_token ( + user_id TEXT NOT NULL, + token TEXT NOT NULL, + label TEXT NOT NULL, + last_access INT NOT NULL, + last_origin TEXT NOT NULL, + expires INT NOT NULL, + PRIMARY KEY (user_id, token), + FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE + ); + CREATE TABLE IF NOT EXISTS schemaVersion ( + id INT PRIMARY KEY, + version INT NOT NULL + ); + INSERT INTO user (id, user, pass, role, sync_topic, created) + VALUES ('u_everyone', '*', '', 'anonymous', '', UNIXEPOCH()) + ON CONFLICT (id) DO NOTHING; + ` + migrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old` + migrate1To2InsertUserNoTx = ` + INSERT INTO user (id, user, pass, role, sync_topic, created) + SELECT ?, user, pass, role, ?, UNIXEPOCH() FROM user_old WHERE user = ? + ` + migrate1To2InsertFromOldTablesAndDropNoTx = ` + INSERT INTO user_access (user_id, topic, read, write) + SELECT u.id, a.topic, a.read, a.write + FROM user u + JOIN access a ON u.user = a.user; + + DROP TABLE access; + DROP TABLE user_old; + ` + + // 2 -> 3 + migrate2To3UpdateQueries = ` + ALTER TABLE user ADD COLUMN stripe_subscription_interval TEXT; + ALTER TABLE tier RENAME COLUMN stripe_price_id TO stripe_monthly_price_id; + ALTER TABLE tier ADD COLUMN stripe_yearly_price_id TEXT; + DROP INDEX IF EXISTS idx_tier_price_id; + CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id); + CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id); + ` + + // 3 -> 4 + migrate3To4UpdateQueries = ` + ALTER TABLE tier ADD COLUMN calls_limit INT NOT NULL DEFAULT (0); + ALTER TABLE user ADD COLUMN stats_calls INT NOT NULL DEFAULT (0); + CREATE TABLE IF NOT EXISTS user_phone ( + user_id TEXT NOT NULL, + phone_number TEXT NOT NULL, + PRIMARY KEY (user_id, phone_number), + FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE + ); + ` + + // 4 -> 5 + migrate4To5UpdateQueries = ` + UPDATE user_access SET topic = REPLACE(topic, '_', '\_'); + ` + + // 5 -> 6 + migrate5To6UpdateQueries = ` + PRAGMA foreign_keys=off; + + -- Alter user table: Add provisioned column + ALTER TABLE user RENAME TO user_old; + CREATE TABLE IF NOT EXISTS user ( + id TEXT PRIMARY KEY, + tier_id TEXT, + user TEXT NOT NULL, + pass TEXT NOT NULL, + role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL, + prefs JSON NOT NULL DEFAULT '{}', + sync_topic TEXT NOT NULL, + provisioned INT NOT NULL, + stats_messages INT NOT NULL DEFAULT (0), + stats_emails INT NOT NULL DEFAULT (0), + stats_calls INT NOT NULL DEFAULT (0), + stripe_customer_id TEXT, + stripe_subscription_id TEXT, + stripe_subscription_status TEXT, + stripe_subscription_interval TEXT, + stripe_subscription_paid_until INT, + stripe_subscription_cancel_at INT, + created INT NOT NULL, + deleted INT, + FOREIGN KEY (tier_id) REFERENCES tier (id) + ); + INSERT INTO user + SELECT + id, + tier_id, + user, + pass, + role, + prefs, + sync_topic, + 0, -- provisioned + stats_messages, + stats_emails, + stats_calls, + stripe_customer_id, + stripe_subscription_id, + stripe_subscription_status, + stripe_subscription_interval, + stripe_subscription_paid_until, + stripe_subscription_cancel_at, + created, + deleted + FROM user_old; + DROP TABLE user_old; + + -- Alter user_access table: Add provisioned column + ALTER TABLE user_access RENAME TO user_access_old; + CREATE TABLE user_access ( + user_id TEXT NOT NULL, + topic TEXT NOT NULL, + read INT NOT NULL, + write INT NOT NULL, + owner_user_id INT, + provisioned INTEGER NOT NULL, + PRIMARY KEY (user_id, topic), + FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE, + FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE + ); + INSERT INTO user_access SELECT *, 0 FROM user_access_old; + DROP TABLE user_access_old; + + -- Alter user_token table: Add provisioned column + ALTER TABLE user_token RENAME TO user_token_old; + CREATE TABLE IF NOT EXISTS user_token ( + user_id TEXT NOT NULL, + token TEXT NOT NULL, + label TEXT NOT NULL, + last_access INT NOT NULL, + last_origin TEXT NOT NULL, + expires INT NOT NULL, + provisioned INT NOT NULL, + PRIMARY KEY (user_id, token), + FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE + ); + INSERT INTO user_token SELECT *, 0 FROM user_token_old; + DROP TABLE user_token_old; + + -- Recreate indices + CREATE UNIQUE INDEX idx_user ON user (user); + CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id); + CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id); + CREATE UNIQUE INDEX idx_user_token ON user_token (token); + + -- Re-enable foreign keys + PRAGMA foreign_keys=on; + ` +) + +var ( + migrations = map[int]func(db *sql.DB) error{ + 1: migrateFrom1, + 2: migrateFrom2, + 3: migrateFrom3, + 4: migrateFrom4, + 5: migrateFrom5, + } +) + +func migrateFrom1(db *sql.DB) error { + log.Tag(tag).Info("Migrating user database schema: from 1 to 2") + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + // Rename user -> user_old, and create new tables + if _, err := tx.Exec(migrate1To2CreateTablesQueries); err != nil { + return err + } + // Insert users from user_old into new user table, with ID and sync_topic + rows, err := tx.Query(migrate1To2SelectAllOldUsernamesNoTx) + if err != nil { + return err + } + defer rows.Close() + usernames := make([]string, 0) + for rows.Next() { + var username string + if err := rows.Scan(&username); err != nil { + return err + } + usernames = append(usernames, username) + } + if err := rows.Close(); err != nil { + return err + } + for _, username := range usernames { + userID := util.RandomStringPrefix(userIDPrefix, userIDLength) + syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength) + if _, err := tx.Exec(migrate1To2InsertUserNoTx, userID, syncTopic, username); err != nil { + return err + } + } + // Migrate old "access" table to "user_access" and drop "access" and "user_old" + if _, err := tx.Exec(migrate1To2InsertFromOldTablesAndDropNoTx); err != nil { + return err + } + if _, err := tx.Exec(updateSchemaVersion, 2); err != nil { + return err + } + if err := tx.Commit(); err != nil { + return err + } + return nil +} + +func migrateFrom2(db *sql.DB) error { + log.Tag(tag).Info("Migrating user database schema: from 2 to 3") + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + if _, err := tx.Exec(migrate2To3UpdateQueries); err != nil { + return err + } + if _, err := tx.Exec(updateSchemaVersion, 3); err != nil { + return err + } + return tx.Commit() +} + +func migrateFrom3(db *sql.DB) error { + log.Tag(tag).Info("Migrating user database schema: from 3 to 4") + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + if _, err := tx.Exec(migrate3To4UpdateQueries); err != nil { + return err + } + if _, err := tx.Exec(updateSchemaVersion, 4); err != nil { + return err + } + return tx.Commit() +} + +func migrateFrom4(db *sql.DB) error { + log.Tag(tag).Info("Migrating user database schema: from 4 to 5") + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + if _, err := tx.Exec(migrate4To5UpdateQueries); err != nil { + return err + } + if _, err := tx.Exec(updateSchemaVersion, 5); err != nil { + return err + } + return tx.Commit() +} + +func migrateFrom5(db *sql.DB) error { + log.Tag(tag).Info("Migrating user database schema: from 5 to 6") + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + if _, err := tx.Exec(migrate5To6UpdateQueries); err != nil { + return err + } + if _, err := tx.Exec(updateSchemaVersion, 6); err != nil { + return err + } + return tx.Commit() +} diff --git a/user/util.go b/user/util.go index 91230fae..170b7717 100644 --- a/user/util.go +++ b/user/util.go @@ -1,10 +1,12 @@ package user import ( - "golang.org/x/crypto/bcrypt" - "heckel.io/ntfy/v2/util" + "database/sql" "regexp" "strings" + + "golang.org/x/crypto/bcrypt" + "heckel.io/ntfy/v2/util" ) var ( @@ -77,3 +79,69 @@ func hashPassword(password string, cost int) (string, error) { } return string(hash), nil } + +func nullString(s string) sql.NullString { + if s == "" { + return sql.NullString{} + } + return sql.NullString{String: s, Valid: true} +} + +func nullInt64(v int64) sql.NullInt64 { + if v == 0 { + return sql.NullInt64{} + } + return sql.NullInt64{Int64: v, Valid: true} +} + +// execTx executes a function in a transaction. If the function returns an error, the transaction is rolled back. +func execTx(db *sql.DB, f func(tx *sql.Tx) error) error { + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + if err := f(tx); err != nil { + return err + } + return tx.Commit() +} + +// queryTx executes a function in a transaction and returns the result. If the function +// returns an error, the transaction is rolled back. +func queryTx[T any](db *sql.DB, f func(tx *sql.Tx) (T, error)) (T, error) { + tx, err := db.Begin() + if err != nil { + var zero T + return zero, err + } + defer tx.Rollback() + t, err := f(tx) + if err != nil { + return t, err + } + if err := tx.Commit(); err != nil { + return t, err + } + return t, nil +} + +// toSQLWildcard converts a wildcard string to a SQL wildcard string. It only allows '*' as wildcards, +// and escapes '_', assuming '\' as escape character. +func toSQLWildcard(s string) string { + return escapeUnderscore(strings.ReplaceAll(s, "*", "%")) +} + +// fromSQLWildcard converts a SQL wildcard string to a wildcard string. It converts '%' to '*', +// and removes the '\_' escape character. +func fromSQLWildcard(s string) string { + return strings.ReplaceAll(unescapeUnderscore(s), "%", "*") +} + +func escapeUnderscore(s string) string { + return strings.ReplaceAll(s, "_", "\\_") +} + +func unescapeUnderscore(s string) string { + return strings.ReplaceAll(s, "\\_", "_") +}