diff --git a/user/manager.go b/user/manager.go index 1d900604..59aa883a 100644 --- a/user/manager.go +++ b/user/manager.go @@ -6,17 +6,17 @@ import ( "encoding/json" "errors" "fmt" + "net/netip" + "path/filepath" + "slices" + "sync" + "time" + "github.com/mattn/go-sqlite3" "golang.org/x/crypto/bcrypt" "heckel.io/ntfy/v2/log" "heckel.io/ntfy/v2/payments" "heckel.io/ntfy/v2/util" - "net/netip" - "path/filepath" - "slices" - "strings" - "sync" - "time" ) const ( @@ -326,229 +326,6 @@ const ( ` ) -// Schema management queries -const ( - currentSchemaVersion = 6 - insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)` - updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1` - selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1` - - // 1 -> 2 (complex migration!) - migrate1To2CreateTablesQueries = ` - ALTER TABLE user RENAME TO user_old; - CREATE TABLE IF NOT EXISTS tier ( - id TEXT PRIMARY KEY, - code TEXT NOT NULL, - name TEXT NOT NULL, - messages_limit INT NOT NULL, - messages_expiry_duration INT NOT NULL, - emails_limit INT NOT NULL, - reservations_limit INT NOT NULL, - attachment_file_size_limit INT NOT NULL, - attachment_total_size_limit INT NOT NULL, - attachment_expiry_duration INT NOT NULL, - attachment_bandwidth_limit INT NOT NULL, - stripe_price_id TEXT - ); - CREATE UNIQUE INDEX idx_tier_code ON tier (code); - CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id); - CREATE TABLE IF NOT EXISTS user ( - id TEXT PRIMARY KEY, - tier_id TEXT, - user TEXT NOT NULL, - pass TEXT NOT NULL, - role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL, - prefs JSON NOT NULL DEFAULT '{}', - sync_topic TEXT NOT NULL, - stats_messages INT NOT NULL DEFAULT (0), - stats_emails INT NOT NULL DEFAULT (0), - stripe_customer_id TEXT, - stripe_subscription_id TEXT, - stripe_subscription_status TEXT, - stripe_subscription_paid_until INT, - stripe_subscription_cancel_at INT, - created INT NOT NULL, - deleted INT, - FOREIGN KEY (tier_id) REFERENCES tier (id) - ); - CREATE UNIQUE INDEX idx_user ON user (user); - CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id); - CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id); - CREATE TABLE IF NOT EXISTS user_access ( - user_id TEXT NOT NULL, - topic TEXT NOT NULL, - read INT NOT NULL, - write INT NOT NULL, - owner_user_id INT, - PRIMARY KEY (user_id, topic), - FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE, - FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE - ); - CREATE TABLE IF NOT EXISTS user_token ( - user_id TEXT NOT NULL, - token TEXT NOT NULL, - label TEXT NOT NULL, - last_access INT NOT NULL, - last_origin TEXT NOT NULL, - expires INT NOT NULL, - PRIMARY KEY (user_id, token), - FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE - ); - CREATE TABLE IF NOT EXISTS schemaVersion ( - id INT PRIMARY KEY, - version INT NOT NULL - ); - INSERT INTO user (id, user, pass, role, sync_topic, created) - VALUES ('u_everyone', '*', '', 'anonymous', '', UNIXEPOCH()) - ON CONFLICT (id) DO NOTHING; - ` - migrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old` - migrate1To2InsertUserNoTx = ` - INSERT INTO user (id, user, pass, role, sync_topic, created) - SELECT ?, user, pass, role, ?, UNIXEPOCH() FROM user_old WHERE user = ? - ` - migrate1To2InsertFromOldTablesAndDropNoTx = ` - INSERT INTO user_access (user_id, topic, read, write) - SELECT u.id, a.topic, a.read, a.write - FROM user u - JOIN access a ON u.user = a.user; - - DROP TABLE access; - DROP TABLE user_old; - ` - - // 2 -> 3 - migrate2To3UpdateQueries = ` - ALTER TABLE user ADD COLUMN stripe_subscription_interval TEXT; - ALTER TABLE tier RENAME COLUMN stripe_price_id TO stripe_monthly_price_id; - ALTER TABLE tier ADD COLUMN stripe_yearly_price_id TEXT; - DROP INDEX IF EXISTS idx_tier_price_id; - CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id); - CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id); - ` - - // 3 -> 4 - migrate3To4UpdateQueries = ` - ALTER TABLE tier ADD COLUMN calls_limit INT NOT NULL DEFAULT (0); - ALTER TABLE user ADD COLUMN stats_calls INT NOT NULL DEFAULT (0); - CREATE TABLE IF NOT EXISTS user_phone ( - user_id TEXT NOT NULL, - phone_number TEXT NOT NULL, - PRIMARY KEY (user_id, phone_number), - FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE - ); - ` - - // 4 -> 5 - migrate4To5UpdateQueries = ` - UPDATE user_access SET topic = REPLACE(topic, '_', '\_'); - ` - - // 5 -> 6 - migrate5To6UpdateQueries = ` - PRAGMA foreign_keys=off; - - -- Alter user table: Add provisioned column - ALTER TABLE user RENAME TO user_old; - CREATE TABLE IF NOT EXISTS user ( - id TEXT PRIMARY KEY, - tier_id TEXT, - user TEXT NOT NULL, - pass TEXT NOT NULL, - role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL, - prefs JSON NOT NULL DEFAULT '{}', - sync_topic TEXT NOT NULL, - provisioned INT NOT NULL, - stats_messages INT NOT NULL DEFAULT (0), - stats_emails INT NOT NULL DEFAULT (0), - stats_calls INT NOT NULL DEFAULT (0), - stripe_customer_id TEXT, - stripe_subscription_id TEXT, - stripe_subscription_status TEXT, - stripe_subscription_interval TEXT, - stripe_subscription_paid_until INT, - stripe_subscription_cancel_at INT, - created INT NOT NULL, - deleted INT, - FOREIGN KEY (tier_id) REFERENCES tier (id) - ); - INSERT INTO user - SELECT - id, - tier_id, - user, - pass, - role, - prefs, - sync_topic, - 0, -- provisioned - stats_messages, - stats_emails, - stats_calls, - stripe_customer_id, - stripe_subscription_id, - stripe_subscription_status, - stripe_subscription_interval, - stripe_subscription_paid_until, - stripe_subscription_cancel_at, - created, - deleted - FROM user_old; - DROP TABLE user_old; - - -- Alter user_access table: Add provisioned column - ALTER TABLE user_access RENAME TO user_access_old; - CREATE TABLE user_access ( - user_id TEXT NOT NULL, - topic TEXT NOT NULL, - read INT NOT NULL, - write INT NOT NULL, - owner_user_id INT, - provisioned INTEGER NOT NULL, - PRIMARY KEY (user_id, topic), - FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE, - FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE - ); - INSERT INTO user_access SELECT *, 0 FROM user_access_old; - DROP TABLE user_access_old; - - -- Alter user_token table: Add provisioned column - ALTER TABLE user_token RENAME TO user_token_old; - CREATE TABLE IF NOT EXISTS user_token ( - user_id TEXT NOT NULL, - token TEXT NOT NULL, - label TEXT NOT NULL, - last_access INT NOT NULL, - last_origin TEXT NOT NULL, - expires INT NOT NULL, - provisioned INT NOT NULL, - PRIMARY KEY (user_id, token), - FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE - ); - INSERT INTO user_token SELECT *, 0 FROM user_token_old; - DROP TABLE user_token_old; - - -- Recreate indices - CREATE UNIQUE INDEX idx_user ON user (user); - CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id); - CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id); - CREATE UNIQUE INDEX idx_user_token ON user_token (token); - - -- Re-enable foreign keys - PRAGMA foreign_keys=on; - ` -) - -var ( - migrations = map[int]func(db *sql.DB) error{ - 1: migrateFrom1, - 2: migrateFrom2, - 3: migrateFrom3, - 4: migrateFrom4, - 5: migrateFrom5, - } -) - // Manager is an implementation of Manager. It stores users and access control list // in a SQLite database. type Manager struct { @@ -1840,7 +1617,7 @@ func (a *Manager) maybeProvisionUsers(tx *sql.Tx, provisionUsernames []string, e return nil } -// maybyProvisionGrants removes all provisioned grants, and (re-)adds the grants from the config. +// maybeProvisionGrants removes all provisioned grants, and (re-)adds the grants from the config. // // Unlike users and tokens, grants can be just re-added, because they do not carry any state (such as last // access time) or do not have dependent resources (such as grants or tokens). @@ -1909,26 +1686,6 @@ func (a *Manager) maybeProvisionTokens(tx *sql.Tx, provisionUsernames []string) return nil } -// toSQLWildcard converts a wildcard string to a SQL wildcard string. It only allows '*' as wildcards, -// and escapes '_', assuming '\' as escape character. -func toSQLWildcard(s string) string { - return escapeUnderscore(strings.ReplaceAll(s, "*", "%")) -} - -// fromSQLWildcard converts a SQL wildcard string to a wildcard string. It converts '%' to '*', -// and removes the '\_' escape character. -func fromSQLWildcard(s string) string { - return strings.ReplaceAll(unescapeUnderscore(s), "%", "*") -} - -func escapeUnderscore(s string) string { - return strings.ReplaceAll(s, "_", "\\_") -} - -func unescapeUnderscore(s string) string { - return strings.ReplaceAll(s, "\\_", "_") -} - func runStartupQueries(db *sql.DB, startupQueries string) error { if _, err := db.Exec(startupQueries); err != nil { return err @@ -1983,161 +1740,3 @@ func setupNewDB(db *sql.DB) error { } return nil } - -func migrateFrom1(db *sql.DB) error { - log.Tag(tag).Info("Migrating user database schema: from 1 to 2") - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - // Rename user -> user_old, and create new tables - if _, err := tx.Exec(migrate1To2CreateTablesQueries); err != nil { - return err - } - // Insert users from user_old into new user table, with ID and sync_topic - rows, err := tx.Query(migrate1To2SelectAllOldUsernamesNoTx) - if err != nil { - return err - } - defer rows.Close() - usernames := make([]string, 0) - for rows.Next() { - var username string - if err := rows.Scan(&username); err != nil { - return err - } - usernames = append(usernames, username) - } - if err := rows.Close(); err != nil { - return err - } - for _, username := range usernames { - userID := util.RandomStringPrefix(userIDPrefix, userIDLength) - syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength) - if _, err := tx.Exec(migrate1To2InsertUserNoTx, userID, syncTopic, username); err != nil { - return err - } - } - // Migrate old "access" table to "user_access" and drop "access" and "user_old" - if _, err := tx.Exec(migrate1To2InsertFromOldTablesAndDropNoTx); err != nil { - return err - } - if _, err := tx.Exec(updateSchemaVersion, 2); err != nil { - return err - } - if err := tx.Commit(); err != nil { - return err - } - return nil -} - -func migrateFrom2(db *sql.DB) error { - log.Tag(tag).Info("Migrating user database schema: from 2 to 3") - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - if _, err := tx.Exec(migrate2To3UpdateQueries); err != nil { - return err - } - if _, err := tx.Exec(updateSchemaVersion, 3); err != nil { - return err - } - return tx.Commit() -} - -func migrateFrom3(db *sql.DB) error { - log.Tag(tag).Info("Migrating user database schema: from 3 to 4") - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - if _, err := tx.Exec(migrate3To4UpdateQueries); err != nil { - return err - } - if _, err := tx.Exec(updateSchemaVersion, 4); err != nil { - return err - } - return tx.Commit() -} - -func migrateFrom4(db *sql.DB) error { - log.Tag(tag).Info("Migrating user database schema: from 4 to 5") - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - if _, err := tx.Exec(migrate4To5UpdateQueries); err != nil { - return err - } - if _, err := tx.Exec(updateSchemaVersion, 5); err != nil { - return err - } - return tx.Commit() -} - -func migrateFrom5(db *sql.DB) error { - log.Tag(tag).Info("Migrating user database schema: from 5 to 6") - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - if _, err := tx.Exec(migrate5To6UpdateQueries); err != nil { - return err - } - if _, err := tx.Exec(updateSchemaVersion, 6); err != nil { - return err - } - return tx.Commit() -} - -func nullString(s string) sql.NullString { - if s == "" { - return sql.NullString{} - } - return sql.NullString{String: s, Valid: true} -} - -func nullInt64(v int64) sql.NullInt64 { - if v == 0 { - return sql.NullInt64{} - } - return sql.NullInt64{Int64: v, Valid: true} -} - -// execTx executes a function in a transaction. If the function returns an error, the transaction is rolled back. -func execTx(db *sql.DB, f func(tx *sql.Tx) error) error { - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - if err := f(tx); err != nil { - return err - } - return tx.Commit() -} - -// queryTx executes a function in a transaction and returns the result. If the function -// returns an error, the transaction is rolled back. -func queryTx[T any](db *sql.DB, f func(tx *sql.Tx) (T, error)) (T, error) { - tx, err := db.Begin() - if err != nil { - var zero T - return zero, err - } - defer tx.Rollback() - t, err := f(tx) - if err != nil { - return t, err - } - if err := tx.Commit(); err != nil { - return t, err - } - return t, nil -} diff --git a/user/migrations.go b/user/migrations.go new file mode 100644 index 00000000..a7eb317f --- /dev/null +++ b/user/migrations.go @@ -0,0 +1,342 @@ +package user + +import ( + "database/sql" + "heckel.io/ntfy/v2/log" + "heckel.io/ntfy/v2/util" +) + +// Schema management queries +const ( + currentSchemaVersion = 6 + insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)` + updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1` + selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1` + + // 1 -> 2 (complex migration!) + migrate1To2CreateTablesQueries = ` + ALTER TABLE user RENAME TO user_old; + CREATE TABLE IF NOT EXISTS tier ( + id TEXT PRIMARY KEY, + code TEXT NOT NULL, + name TEXT NOT NULL, + messages_limit INT NOT NULL, + messages_expiry_duration INT NOT NULL, + emails_limit INT NOT NULL, + reservations_limit INT NOT NULL, + attachment_file_size_limit INT NOT NULL, + attachment_total_size_limit INT NOT NULL, + attachment_expiry_duration INT NOT NULL, + attachment_bandwidth_limit INT NOT NULL, + stripe_price_id TEXT + ); + CREATE UNIQUE INDEX idx_tier_code ON tier (code); + CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id); + CREATE TABLE IF NOT EXISTS user ( + id TEXT PRIMARY KEY, + tier_id TEXT, + user TEXT NOT NULL, + pass TEXT NOT NULL, + role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL, + prefs JSON NOT NULL DEFAULT '{}', + sync_topic TEXT NOT NULL, + stats_messages INT NOT NULL DEFAULT (0), + stats_emails INT NOT NULL DEFAULT (0), + stripe_customer_id TEXT, + stripe_subscription_id TEXT, + stripe_subscription_status TEXT, + stripe_subscription_paid_until INT, + stripe_subscription_cancel_at INT, + created INT NOT NULL, + deleted INT, + FOREIGN KEY (tier_id) REFERENCES tier (id) + ); + CREATE UNIQUE INDEX idx_user ON user (user); + CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id); + CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id); + CREATE TABLE IF NOT EXISTS user_access ( + user_id TEXT NOT NULL, + topic TEXT NOT NULL, + read INT NOT NULL, + write INT NOT NULL, + owner_user_id INT, + PRIMARY KEY (user_id, topic), + FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE, + FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE + ); + CREATE TABLE IF NOT EXISTS user_token ( + user_id TEXT NOT NULL, + token TEXT NOT NULL, + label TEXT NOT NULL, + last_access INT NOT NULL, + last_origin TEXT NOT NULL, + expires INT NOT NULL, + PRIMARY KEY (user_id, token), + FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE + ); + CREATE TABLE IF NOT EXISTS schemaVersion ( + id INT PRIMARY KEY, + version INT NOT NULL + ); + INSERT INTO user (id, user, pass, role, sync_topic, created) + VALUES ('u_everyone', '*', '', 'anonymous', '', UNIXEPOCH()) + ON CONFLICT (id) DO NOTHING; + ` + migrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old` + migrate1To2InsertUserNoTx = ` + INSERT INTO user (id, user, pass, role, sync_topic, created) + SELECT ?, user, pass, role, ?, UNIXEPOCH() FROM user_old WHERE user = ? + ` + migrate1To2InsertFromOldTablesAndDropNoTx = ` + INSERT INTO user_access (user_id, topic, read, write) + SELECT u.id, a.topic, a.read, a.write + FROM user u + JOIN access a ON u.user = a.user; + + DROP TABLE access; + DROP TABLE user_old; + ` + + // 2 -> 3 + migrate2To3UpdateQueries = ` + ALTER TABLE user ADD COLUMN stripe_subscription_interval TEXT; + ALTER TABLE tier RENAME COLUMN stripe_price_id TO stripe_monthly_price_id; + ALTER TABLE tier ADD COLUMN stripe_yearly_price_id TEXT; + DROP INDEX IF EXISTS idx_tier_price_id; + CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id); + CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id); + ` + + // 3 -> 4 + migrate3To4UpdateQueries = ` + ALTER TABLE tier ADD COLUMN calls_limit INT NOT NULL DEFAULT (0); + ALTER TABLE user ADD COLUMN stats_calls INT NOT NULL DEFAULT (0); + CREATE TABLE IF NOT EXISTS user_phone ( + user_id TEXT NOT NULL, + phone_number TEXT NOT NULL, + PRIMARY KEY (user_id, phone_number), + FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE + ); + ` + + // 4 -> 5 + migrate4To5UpdateQueries = ` + UPDATE user_access SET topic = REPLACE(topic, '_', '\_'); + ` + + // 5 -> 6 + migrate5To6UpdateQueries = ` + PRAGMA foreign_keys=off; + + -- Alter user table: Add provisioned column + ALTER TABLE user RENAME TO user_old; + CREATE TABLE IF NOT EXISTS user ( + id TEXT PRIMARY KEY, + tier_id TEXT, + user TEXT NOT NULL, + pass TEXT NOT NULL, + role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL, + prefs JSON NOT NULL DEFAULT '{}', + sync_topic TEXT NOT NULL, + provisioned INT NOT NULL, + stats_messages INT NOT NULL DEFAULT (0), + stats_emails INT NOT NULL DEFAULT (0), + stats_calls INT NOT NULL DEFAULT (0), + stripe_customer_id TEXT, + stripe_subscription_id TEXT, + stripe_subscription_status TEXT, + stripe_subscription_interval TEXT, + stripe_subscription_paid_until INT, + stripe_subscription_cancel_at INT, + created INT NOT NULL, + deleted INT, + FOREIGN KEY (tier_id) REFERENCES tier (id) + ); + INSERT INTO user + SELECT + id, + tier_id, + user, + pass, + role, + prefs, + sync_topic, + 0, -- provisioned + stats_messages, + stats_emails, + stats_calls, + stripe_customer_id, + stripe_subscription_id, + stripe_subscription_status, + stripe_subscription_interval, + stripe_subscription_paid_until, + stripe_subscription_cancel_at, + created, + deleted + FROM user_old; + DROP TABLE user_old; + + -- Alter user_access table: Add provisioned column + ALTER TABLE user_access RENAME TO user_access_old; + CREATE TABLE user_access ( + user_id TEXT NOT NULL, + topic TEXT NOT NULL, + read INT NOT NULL, + write INT NOT NULL, + owner_user_id INT, + provisioned INTEGER NOT NULL, + PRIMARY KEY (user_id, topic), + FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE, + FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE + ); + INSERT INTO user_access SELECT *, 0 FROM user_access_old; + DROP TABLE user_access_old; + + -- Alter user_token table: Add provisioned column + ALTER TABLE user_token RENAME TO user_token_old; + CREATE TABLE IF NOT EXISTS user_token ( + user_id TEXT NOT NULL, + token TEXT NOT NULL, + label TEXT NOT NULL, + last_access INT NOT NULL, + last_origin TEXT NOT NULL, + expires INT NOT NULL, + provisioned INT NOT NULL, + PRIMARY KEY (user_id, token), + FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE + ); + INSERT INTO user_token SELECT *, 0 FROM user_token_old; + DROP TABLE user_token_old; + + -- Recreate indices + CREATE UNIQUE INDEX idx_user ON user (user); + CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id); + CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id); + CREATE UNIQUE INDEX idx_user_token ON user_token (token); + + -- Re-enable foreign keys + PRAGMA foreign_keys=on; + ` +) + +var ( + migrations = map[int]func(db *sql.DB) error{ + 1: migrateFrom1, + 2: migrateFrom2, + 3: migrateFrom3, + 4: migrateFrom4, + 5: migrateFrom5, + } +) + +func migrateFrom1(db *sql.DB) error { + log.Tag(tag).Info("Migrating user database schema: from 1 to 2") + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + // Rename user -> user_old, and create new tables + if _, err := tx.Exec(migrate1To2CreateTablesQueries); err != nil { + return err + } + // Insert users from user_old into new user table, with ID and sync_topic + rows, err := tx.Query(migrate1To2SelectAllOldUsernamesNoTx) + if err != nil { + return err + } + defer rows.Close() + usernames := make([]string, 0) + for rows.Next() { + var username string + if err := rows.Scan(&username); err != nil { + return err + } + usernames = append(usernames, username) + } + if err := rows.Close(); err != nil { + return err + } + for _, username := range usernames { + userID := util.RandomStringPrefix(userIDPrefix, userIDLength) + syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength) + if _, err := tx.Exec(migrate1To2InsertUserNoTx, userID, syncTopic, username); err != nil { + return err + } + } + // Migrate old "access" table to "user_access" and drop "access" and "user_old" + if _, err := tx.Exec(migrate1To2InsertFromOldTablesAndDropNoTx); err != nil { + return err + } + if _, err := tx.Exec(updateSchemaVersion, 2); err != nil { + return err + } + if err := tx.Commit(); err != nil { + return err + } + return nil +} + +func migrateFrom2(db *sql.DB) error { + log.Tag(tag).Info("Migrating user database schema: from 2 to 3") + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + if _, err := tx.Exec(migrate2To3UpdateQueries); err != nil { + return err + } + if _, err := tx.Exec(updateSchemaVersion, 3); err != nil { + return err + } + return tx.Commit() +} + +func migrateFrom3(db *sql.DB) error { + log.Tag(tag).Info("Migrating user database schema: from 3 to 4") + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + if _, err := tx.Exec(migrate3To4UpdateQueries); err != nil { + return err + } + if _, err := tx.Exec(updateSchemaVersion, 4); err != nil { + return err + } + return tx.Commit() +} + +func migrateFrom4(db *sql.DB) error { + log.Tag(tag).Info("Migrating user database schema: from 4 to 5") + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + if _, err := tx.Exec(migrate4To5UpdateQueries); err != nil { + return err + } + if _, err := tx.Exec(updateSchemaVersion, 5); err != nil { + return err + } + return tx.Commit() +} + +func migrateFrom5(db *sql.DB) error { + log.Tag(tag).Info("Migrating user database schema: from 5 to 6") + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + if _, err := tx.Exec(migrate5To6UpdateQueries); err != nil { + return err + } + if _, err := tx.Exec(updateSchemaVersion, 6); err != nil { + return err + } + return tx.Commit() +} diff --git a/user/util.go b/user/util.go index 91230fae..170b7717 100644 --- a/user/util.go +++ b/user/util.go @@ -1,10 +1,12 @@ package user import ( - "golang.org/x/crypto/bcrypt" - "heckel.io/ntfy/v2/util" + "database/sql" "regexp" "strings" + + "golang.org/x/crypto/bcrypt" + "heckel.io/ntfy/v2/util" ) var ( @@ -77,3 +79,69 @@ func hashPassword(password string, cost int) (string, error) { } return string(hash), nil } + +func nullString(s string) sql.NullString { + if s == "" { + return sql.NullString{} + } + return sql.NullString{String: s, Valid: true} +} + +func nullInt64(v int64) sql.NullInt64 { + if v == 0 { + return sql.NullInt64{} + } + return sql.NullInt64{Int64: v, Valid: true} +} + +// execTx executes a function in a transaction. If the function returns an error, the transaction is rolled back. +func execTx(db *sql.DB, f func(tx *sql.Tx) error) error { + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + if err := f(tx); err != nil { + return err + } + return tx.Commit() +} + +// queryTx executes a function in a transaction and returns the result. If the function +// returns an error, the transaction is rolled back. +func queryTx[T any](db *sql.DB, f func(tx *sql.Tx) (T, error)) (T, error) { + tx, err := db.Begin() + if err != nil { + var zero T + return zero, err + } + defer tx.Rollback() + t, err := f(tx) + if err != nil { + return t, err + } + if err := tx.Commit(); err != nil { + return t, err + } + return t, nil +} + +// toSQLWildcard converts a wildcard string to a SQL wildcard string. It only allows '*' as wildcards, +// and escapes '_', assuming '\' as escape character. +func toSQLWildcard(s string) string { + return escapeUnderscore(strings.ReplaceAll(s, "*", "%")) +} + +// fromSQLWildcard converts a SQL wildcard string to a wildcard string. It converts '%' to '*', +// and removes the '\_' escape character. +func fromSQLWildcard(s string) string { + return strings.ReplaceAll(unescapeUnderscore(s), "%", "*") +} + +func escapeUnderscore(s string) string { + return strings.ReplaceAll(s, "_", "\\_") +} + +func unescapeUnderscore(s string) string { + return strings.ReplaceAll(s, "\\_", "_") +} diff --git a/user/util_test.go b/user/util_test.go new file mode 100644 index 00000000..97c4bc4a --- /dev/null +++ b/user/util_test.go @@ -0,0 +1,281 @@ +package user + +import ( + "github.com/stretchr/testify/require" + "strings" + "testing" +) + +func TestAllowedRole(t *testing.T) { + require.True(t, AllowedRole(RoleUser)) + require.True(t, AllowedRole(RoleAdmin)) + require.False(t, AllowedRole(RoleAnonymous)) + require.False(t, AllowedRole(Role("invalid"))) + require.False(t, AllowedRole(Role(""))) + require.False(t, AllowedRole(Role("superadmin"))) +} + +func TestAllowedTopic(t *testing.T) { + // Valid topics + require.True(t, AllowedTopic("test")) + require.True(t, AllowedTopic("mytopic")) + require.True(t, AllowedTopic("topic123")) + require.True(t, AllowedTopic("my-topic")) + require.True(t, AllowedTopic("my_topic")) + require.True(t, AllowedTopic("Topic123")) + require.True(t, AllowedTopic("a")) + require.True(t, AllowedTopic(strings.Repeat("a", 64))) // Max length + + // Invalid topics - wildcards not allowed + require.False(t, AllowedTopic("topic*")) + require.False(t, AllowedTopic("*")) + require.False(t, AllowedTopic("my*topic")) + + // Invalid topics - special characters + require.False(t, AllowedTopic("my topic")) // Space + require.False(t, AllowedTopic("my.topic")) // Dot + require.False(t, AllowedTopic("my/topic")) // Slash + require.False(t, AllowedTopic("my@topic")) // At sign + require.False(t, AllowedTopic("my+topic")) // Plus + require.False(t, AllowedTopic("topic!")) // Exclamation + require.False(t, AllowedTopic("topic#")) // Hash + require.False(t, AllowedTopic("topic$")) // Dollar + require.False(t, AllowedTopic("topic%")) // Percent + require.False(t, AllowedTopic("topic&")) // Ampersand + require.False(t, AllowedTopic("my\\topic")) // Backslash + + // Invalid topics - length + require.False(t, AllowedTopic("")) // Empty + require.False(t, AllowedTopic(strings.Repeat("a", 65))) // Too long +} + +func TestAllowedTopicPattern(t *testing.T) { + // Valid patterns - same as AllowedTopic + require.True(t, AllowedTopicPattern("test")) + require.True(t, AllowedTopicPattern("mytopic")) + require.True(t, AllowedTopicPattern("topic123")) + require.True(t, AllowedTopicPattern("my-topic")) + require.True(t, AllowedTopicPattern("my_topic")) + require.True(t, AllowedTopicPattern("a")) + require.True(t, AllowedTopicPattern(strings.Repeat("a", 64))) // Max length + + // Valid patterns - with wildcards + require.True(t, AllowedTopicPattern("*")) + require.True(t, AllowedTopicPattern("topic*")) + require.True(t, AllowedTopicPattern("*topic")) + require.True(t, AllowedTopicPattern("my*topic")) + require.True(t, AllowedTopicPattern("***")) + require.True(t, AllowedTopicPattern("test_*")) + require.True(t, AllowedTopicPattern("my-*-topic")) + require.True(t, AllowedTopicPattern(strings.Repeat("*", 64))) // Max length with wildcards + + // Invalid patterns - special characters (other than wildcard) + require.False(t, AllowedTopicPattern("my topic")) // Space + require.False(t, AllowedTopicPattern("my.topic")) // Dot + require.False(t, AllowedTopicPattern("my/topic")) // Slash + require.False(t, AllowedTopicPattern("my@topic")) // At sign + require.False(t, AllowedTopicPattern("my+topic")) // Plus + require.False(t, AllowedTopicPattern("topic!")) // Exclamation + require.False(t, AllowedTopicPattern("topic#")) // Hash + require.False(t, AllowedTopicPattern("topic$")) // Dollar + require.False(t, AllowedTopicPattern("topic%")) // Percent + require.False(t, AllowedTopicPattern("topic&")) // Ampersand + require.False(t, AllowedTopicPattern("my\\topic")) // Backslash + + // Invalid patterns - length + require.False(t, AllowedTopicPattern("")) // Empty + require.False(t, AllowedTopicPattern(strings.Repeat("a", 65))) // Too long +} + +func TestValidPasswordHash(t *testing.T) { + // Valid bcrypt hashes with different versions + require.Nil(t, ValidPasswordHash("$2a$10$EEp6gBheOsqEFsXlo523E.gBVoeg1ytphXiEvTPlNzkenBlHZBPQy", 10)) + require.Nil(t, ValidPasswordHash("$2b$10$YLiO8U21sX1uhZamTLJXHuxgVC0Z/GKISibrKCLohPgtG7yIxSk4C", 10)) + require.Nil(t, ValidPasswordHash("$2y$12$1234567890123456789012u1234567890123456789012345678901", 10)) + + // Valid hash with minimum cost + require.Nil(t, ValidPasswordHash("$2a$04$1234567890123456789012u1234567890123456789012345678901", 4)) + + // Invalid - wrong prefix + require.Equal(t, ErrPasswordHashInvalid, ValidPasswordHash("$2c$10$EEp6gBheOsqEFsXlo523E.gBVoeg1ytphXiEvTPlNzkenBlHZBPQy", 10)) + require.Equal(t, ErrPasswordHashInvalid, ValidPasswordHash("$3a$10$EEp6gBheOsqEFsXlo523E.gBVoeg1ytphXiEvTPlNzkenBlHZBPQy", 10)) + require.Equal(t, ErrPasswordHashInvalid, ValidPasswordHash("bcrypt$10$hash", 10)) + require.Equal(t, ErrPasswordHashInvalid, ValidPasswordHash("nothash", 10)) + require.Equal(t, ErrPasswordHashInvalid, ValidPasswordHash("", 10)) + + // Invalid - malformed hash + require.NotNil(t, ValidPasswordHash("$2a$10$tooshort", 10)) + require.NotNil(t, ValidPasswordHash("$2a$10", 10)) + require.NotNil(t, ValidPasswordHash("$2a$", 10)) + + // Invalid - cost too low + require.Equal(t, ErrPasswordHashWeak, ValidPasswordHash("$2a$04$1234567890123456789012u1234567890123456789012345678901", 10)) + require.Equal(t, ErrPasswordHashWeak, ValidPasswordHash("$2a$09$EEp6gBheOsqEFsXlo523E.gBVoeg1ytphXiEvTPlNzkenBlHZBPQy", 10)) + + // Edge case - cost exactly at minimum + require.Nil(t, ValidPasswordHash("$2a$10$EEp6gBheOsqEFsXlo523E.gBVoeg1ytphXiEvTPlNzkenBlHZBPQy", 10)) +} + +func TestValidToken(t *testing.T) { + // Valid tokens + require.True(t, ValidToken("tk_1234567890123456789012345678x")) + require.True(t, ValidToken("tk_abcdefghijklmnopqrstuvwxyzabc")) + require.True(t, ValidToken("tk_ABCDEFGHIJKLMNOPQRSTUVWXYZABC")) + require.True(t, ValidToken("tk_012345678901234567890123456ab")) + require.True(t, ValidToken("tk_-----------------------------")) + require.True(t, ValidToken("tk______________________________")) + + // Invalid tokens - wrong prefix + require.False(t, ValidToken("tx_1234567890123456789012345678x")) + require.False(t, ValidToken("tk1234567890123456789012345678xy")) + require.False(t, ValidToken("token_1234567890123456789012345")) + + // Invalid tokens - wrong length + require.False(t, ValidToken("tk_")) // Too short + require.False(t, ValidToken("tk_123")) // Too short + require.False(t, ValidToken("tk_123456789012345678901234567890")) // Too long (30 chars after prefix) + require.False(t, ValidToken("tk_123456789012345678901234567")) // Too short (28 chars) + + // Invalid tokens - invalid characters + require.False(t, ValidToken("tk_123456789012345678901234567!@")) + require.False(t, ValidToken("tk_12345678901234567890123456 8x")) + require.False(t, ValidToken("tk_123456789012345678901234567.x")) + require.False(t, ValidToken("tk_123456789012345678901234567*x")) + + // Invalid tokens - no prefix + require.False(t, ValidToken("1234567890123456789012345678901x")) + require.False(t, ValidToken("")) +} + +func TestGenerateToken(t *testing.T) { + // Generate multiple tokens + tokens := make(map[string]bool) + for i := 0; i < 100; i++ { + token := GenerateToken() + + // Check format + require.True(t, strings.HasPrefix(token, "tk_"), "Token should start with tk_") + require.Equal(t, 32, len(token), "Token should be 32 characters long") + + // Check it's valid + require.True(t, ValidToken(token), "Generated token should be valid") + + // Check it's lowercase + require.Equal(t, strings.ToLower(token), token, "Token should be lowercase") + + // Check uniqueness + require.False(t, tokens[token], "Token should be unique") + tokens[token] = true + } + + // Verify we got 100 unique tokens + require.Equal(t, 100, len(tokens)) +} + +func TestHashPassword(t *testing.T) { + password := "test-password-123" + + // Hash the password + hash, err := HashPassword(password) + require.Nil(t, err) + require.NotEmpty(t, hash) + + // Check it's a valid bcrypt hash + require.Nil(t, ValidPasswordHash(hash, DefaultUserPasswordBcryptCost)) + + // Check it starts with correct prefix + require.True(t, strings.HasPrefix(hash, "$2a$")) + + // Hash the same password again - should produce different hash + hash2, err := HashPassword(password) + require.Nil(t, err) + require.NotEqual(t, hash, hash2, "Same password should produce different hashes (salt)") + + // Empty password should still work + emptyHash, err := HashPassword("") + require.Nil(t, err) + require.NotEmpty(t, emptyHash) + require.Nil(t, ValidPasswordHash(emptyHash, DefaultUserPasswordBcryptCost)) +} + +func TestHashPassword_WithCost(t *testing.T) { + password := "test-password" + + // Test with different costs + hash4, err := hashPassword(password, 4) + require.Nil(t, err) + require.True(t, strings.HasPrefix(hash4, "$2a$04$")) + + hash10, err := hashPassword(password, 10) + require.Nil(t, err) + require.True(t, strings.HasPrefix(hash10, "$2a$10$")) + + hash12, err := hashPassword(password, 12) + require.Nil(t, err) + require.True(t, strings.HasPrefix(hash12, "$2a$12$")) + + // All should be valid + require.Nil(t, ValidPasswordHash(hash4, 4)) + require.Nil(t, ValidPasswordHash(hash10, 10)) + require.Nil(t, ValidPasswordHash(hash12, 12)) +} + +func TestUser_TierID(t *testing.T) { + // User with tier + u := &User{ + Tier: &Tier{ + ID: "ti_123", + Code: "pro", + }, + } + require.Equal(t, "ti_123", u.TierID()) + + // User without tier + u2 := &User{ + Tier: nil, + } + require.Equal(t, "", u2.TierID()) + + // Nil user + var u3 *User + require.Equal(t, "", u3.TierID()) +} + +func TestUser_IsAdmin(t *testing.T) { + admin := &User{Role: RoleAdmin} + require.True(t, admin.IsAdmin()) + require.False(t, admin.IsUser()) + + user := &User{Role: RoleUser} + require.False(t, user.IsAdmin()) + + anonymous := &User{Role: RoleAnonymous} + require.False(t, anonymous.IsAdmin()) + + // Nil user + var nilUser *User + require.False(t, nilUser.IsAdmin()) +} + +func TestUser_IsUser(t *testing.T) { + user := &User{Role: RoleUser} + require.True(t, user.IsUser()) + require.False(t, user.IsAdmin()) + + admin := &User{Role: RoleAdmin} + require.False(t, admin.IsUser()) + + anonymous := &User{Role: RoleAnonymous} + require.False(t, anonymous.IsUser()) + + // Nil user + var nilUser *User + require.False(t, nilUser.IsUser()) +} + +func TestPermission_String(t *testing.T) { + require.Equal(t, "read-write", PermissionReadWrite.String()) + require.Equal(t, "read-only", PermissionRead.String()) + require.Equal(t, "write-only", PermissionWrite.String()) + require.Equal(t, "deny-all", PermissionDenyAll.String()) +}