From 07c3e280bf670597dfb3cbc1f5ad678cef6faaf4 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Mon, 16 Feb 2026 22:39:54 -0500 Subject: [PATCH] Refactor user package to Store interface with PostgreSQL support Extract database operations from Manager into a Store interface with SQLite and PostgreSQL implementations using a shared commonStore. Split SQLite migrations into store_sqlite_migrations.go, use shared schema_version table for PostgreSQL, rename user_user/user_tier tables to "user"/tier, and wire up database-url in CLI commands. --- cmd/user.go | 25 +- server/server.go | 14 +- user/manager.go | 1193 ++--------------- user/manager_test.go | 89 +- user/store.go | 992 ++++++++++++++ user/store_postgres.go | 424 ++++++ user/store_sqlite.go | 420 ++++++ ...grations.go => store_sqlite_migrations.go} | 3 +- user/types.go | 14 + webpush/store_postgres.go | 16 +- 10 files changed, 2053 insertions(+), 1137 deletions(-) create mode 100644 user/store.go create mode 100644 user/store_postgres.go create mode 100644 user/store_sqlite.go rename user/{migrations.go => store_sqlite_migrations.go} (99%) diff --git a/cmd/user.go b/cmd/user.go index 6bf7030e..bf822970 100644 --- a/cmd/user.go +++ b/cmd/user.go @@ -29,6 +29,7 @@ var flagsUser = append( &cli.StringFlag{Name: "config", Aliases: []string{"c"}, EnvVars: []string{"NTFY_CONFIG_FILE"}, Value: server.DefaultConfigFile, DefaultText: server.DefaultConfigFile, Usage: "config file"}, altsrc.NewStringFlag(&cli.StringFlag{Name: "auth-file", Aliases: []string{"auth_file", "H"}, EnvVars: []string{"NTFY_AUTH_FILE"}, Usage: "auth database file used for access control"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "auth-default-access", Aliases: []string{"auth_default_access", "p"}, EnvVars: []string{"NTFY_AUTH_DEFAULT_ACCESS"}, Value: "read-write", Usage: "default permissions if no matching entries in the auth database are found"}), + altsrc.NewStringFlag(&cli.StringFlag{Name: "database-url", Aliases: []string{"database_url"}, EnvVars: []string{"NTFY_DATABASE_URL"}, Usage: "PostgreSQL connection string for database-backed stores"}), ) var cmdUser = &cli.Command{ @@ -365,24 +366,32 @@ func createUserManager(c *cli.Context) (*user.Manager, error) { authFile := c.String("auth-file") authStartupQueries := c.String("auth-startup-queries") authDefaultAccess := c.String("auth-default-access") - if authFile == "" { - return nil, errors.New("option auth-file not set; auth is unconfigured for this server") - } else if !util.FileExists(authFile) { - return nil, errors.New("auth-file does not exist; please start the server at least once to create it") - } + databaseURL := c.String("database-url") authDefault, err := user.ParsePermission(authDefaultAccess) if err != nil { return nil, errors.New("if set, auth-default-access must start set to 'read-write', 'read-only', 'write-only' or 'deny-all'") } authConfig := &user.Config{ - Filename: authFile, - StartupQueries: authStartupQueries, DefaultAccess: authDefault, ProvisionEnabled: false, // Hack: Do not re-provision users on manager initialization BcryptCost: user.DefaultUserPasswordBcryptCost, QueueWriterInterval: user.DefaultUserStatsQueueWriterInterval, } - return user.NewManager(authConfig) + var store user.Store + if databaseURL != "" { + store, err = user.NewPostgresStore(databaseURL) + } else if authFile != "" { + if !util.FileExists(authFile) { + return nil, errors.New("auth-file does not exist; please start the server at least once to create it") + } + store, err = user.NewSQLiteStore(authFile, authStartupQueries) + } else { + return nil, errors.New("option database-url or auth-file not set; auth is unconfigured for this server") + } + if err != nil { + return nil, err + } + return user.NewManager(store, authConfig) } func readPasswordAndConfirm(c *cli.Context) (string, error) { diff --git a/server/server.go b/server/server.go index de8af35f..ea03355b 100644 --- a/server/server.go +++ b/server/server.go @@ -204,9 +204,10 @@ func New(conf *Config) (*Server, error) { } } var userManager *user.Manager - if conf.AuthFile != "" { + if conf.AuthFile != "" || conf.DatabaseURL != "" { authConfig := &user.Config{ Filename: conf.AuthFile, + DatabaseURL: conf.DatabaseURL, StartupQueries: conf.AuthStartupQueries, DefaultAccess: conf.AuthDefault, ProvisionEnabled: true, // Enable provisioning of users and access @@ -216,7 +217,16 @@ func New(conf *Config) (*Server, error) { BcryptCost: conf.AuthBcryptCost, QueueWriterInterval: conf.AuthStatsQueueWriterInterval, } - userManager, err = user.NewManager(authConfig) + var store user.Store + if conf.DatabaseURL != "" { + store, err = user.NewPostgresStore(conf.DatabaseURL) + } else { + store, err = user.NewSQLiteStore(conf.AuthFile, conf.AuthStartupQueries) + } + if err != nil { + return nil, err + } + userManager, err = user.NewManager(store, authConfig) if err != nil { return nil, err } diff --git a/user/manager.go b/user/manager.go index 59aa883a..1e1b27f8 100644 --- a/user/manager.go +++ b/user/manager.go @@ -2,20 +2,15 @@ package user import ( - "database/sql" - "encoding/json" "errors" "fmt" "net/netip" - "path/filepath" "slices" "sync" "time" - "github.com/mattn/go-sqlite3" "golang.org/x/crypto/bcrypt" "heckel.io/ntfy/v2/log" - "heckel.io/ntfy/v2/payments" "heckel.io/ntfy/v2/util" ) @@ -46,313 +41,19 @@ var ( errNoRows = errors.New("no rows found") ) -// Manager-related queries -const ( - createTablesQueries = ` - BEGIN; - CREATE TABLE IF NOT EXISTS tier ( - id TEXT PRIMARY KEY, - code TEXT NOT NULL, - name TEXT NOT NULL, - messages_limit INT NOT NULL, - messages_expiry_duration INT NOT NULL, - emails_limit INT NOT NULL, - calls_limit INT NOT NULL, - reservations_limit INT NOT NULL, - attachment_file_size_limit INT NOT NULL, - attachment_total_size_limit INT NOT NULL, - attachment_expiry_duration INT NOT NULL, - attachment_bandwidth_limit INT NOT NULL, - stripe_monthly_price_id TEXT, - stripe_yearly_price_id TEXT - ); - CREATE UNIQUE INDEX idx_tier_code ON tier (code); - CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id); - CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id); - CREATE TABLE IF NOT EXISTS user ( - id TEXT PRIMARY KEY, - tier_id TEXT, - user TEXT NOT NULL, - pass TEXT NOT NULL, - role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL, - prefs JSON NOT NULL DEFAULT '{}', - sync_topic TEXT NOT NULL, - provisioned INT NOT NULL, - stats_messages INT NOT NULL DEFAULT (0), - stats_emails INT NOT NULL DEFAULT (0), - stats_calls INT NOT NULL DEFAULT (0), - stripe_customer_id TEXT, - stripe_subscription_id TEXT, - stripe_subscription_status TEXT, - stripe_subscription_interval TEXT, - stripe_subscription_paid_until INT, - stripe_subscription_cancel_at INT, - created INT NOT NULL, - deleted INT, - FOREIGN KEY (tier_id) REFERENCES tier (id) - ); - CREATE UNIQUE INDEX idx_user ON user (user); - CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id); - CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id); - CREATE TABLE IF NOT EXISTS user_access ( - user_id TEXT NOT NULL, - topic TEXT NOT NULL, - read INT NOT NULL, - write INT NOT NULL, - owner_user_id INT, - provisioned INT NOT NULL, - PRIMARY KEY (user_id, topic), - FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE, - FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE - ); - CREATE TABLE IF NOT EXISTS user_token ( - user_id TEXT NOT NULL, - token TEXT NOT NULL, - label TEXT NOT NULL, - last_access INT NOT NULL, - last_origin TEXT NOT NULL, - expires INT NOT NULL, - provisioned INT NOT NULL, - PRIMARY KEY (user_id, token), - FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE - ); - CREATE UNIQUE INDEX idx_user_token ON user_token (token); - CREATE TABLE IF NOT EXISTS user_phone ( - user_id TEXT NOT NULL, - phone_number TEXT NOT NULL, - PRIMARY KEY (user_id, phone_number), - FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE - ); - CREATE TABLE IF NOT EXISTS schemaVersion ( - id INT PRIMARY KEY, - version INT NOT NULL - ); - INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created) - VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, UNIXEPOCH()) - ON CONFLICT (id) DO NOTHING; - COMMIT; - ` - - builtinStartupQueries = ` - PRAGMA foreign_keys = ON; - ` - - selectUserByIDQuery = ` - SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id - FROM user u - LEFT JOIN tier t on t.id = u.tier_id - WHERE u.id = ? - ` - selectUserByNameQuery = ` - SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id - FROM user u - LEFT JOIN tier t on t.id = u.tier_id - WHERE user = ? - ` - selectUserByTokenQuery = ` - SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id - FROM user u - JOIN user_token tk on u.id = tk.user_id - LEFT JOIN tier t on t.id = u.tier_id - WHERE tk.token = ? AND (tk.expires = 0 OR tk.expires >= ?) - ` - selectUserByStripeCustomerIDQuery = ` - SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id - FROM user u - LEFT JOIN tier t on t.id = u.tier_id - WHERE u.stripe_customer_id = ? - ` - selectTopicPermsQuery = ` - SELECT read, write - FROM user_access a - JOIN user u ON u.id = a.user_id - WHERE (u.user = ? OR u.user = ?) AND ? LIKE a.topic ESCAPE '\' - ORDER BY u.user DESC, LENGTH(a.topic) DESC, a.write DESC - ` - - insertUserQuery = ` - INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created) - VALUES (?, ?, ?, ?, ?, ?, ?) - ` - selectUsernamesQuery = ` - SELECT user - FROM user - ORDER BY - CASE role - WHEN 'admin' THEN 1 - WHEN 'anonymous' THEN 3 - ELSE 2 - END, user - ` - selectUserCountQuery = `SELECT COUNT(*) FROM user` - selectUserIDFromUsernameQuery = `SELECT id FROM user WHERE user = ?` - updateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?` - updateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?` - updateUserProvisionedQuery = `UPDATE user SET provisioned = ? WHERE user = ?` - updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE id = ?` - updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ?, stats_calls = ? WHERE id = ?` - updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0, stats_calls = 0` - updateUserDeletedQuery = `UPDATE user SET deleted = ? WHERE id = ?` - deleteUsersMarkedQuery = `DELETE FROM user WHERE deleted < ?` - deleteUserQuery = `DELETE FROM user WHERE user = ?` - - upsertUserAccessQuery = ` - INSERT INTO user_access (user_id, topic, read, write, owner_user_id, provisioned) - VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?, (SELECT IIF(?='',NULL,(SELECT id FROM user WHERE user=?))), ?) - ON CONFLICT (user_id, topic) - DO UPDATE SET read=excluded.read, write=excluded.write, owner_user_id=excluded.owner_user_id, provisioned=excluded.provisioned - ` - selectUserAllAccessQuery = ` - SELECT user_id, topic, read, write, provisioned - FROM user_access - ORDER BY LENGTH(topic) DESC, write DESC, read DESC, topic - ` - selectUserAccessQuery = ` - SELECT topic, read, write, provisioned - FROM user_access - WHERE user_id = (SELECT id FROM user WHERE user = ?) - ORDER BY LENGTH(topic) DESC, write DESC, read DESC, topic - ` - selectUserReservationsQuery = ` - SELECT a_user.topic, a_user.read, a_user.write, a_everyone.read AS everyone_read, a_everyone.write AS everyone_write - FROM user_access a_user - LEFT JOIN user_access a_everyone ON a_user.topic = a_everyone.topic AND a_everyone.user_id = (SELECT id FROM user WHERE user = ?) - WHERE a_user.user_id = a_user.owner_user_id - AND a_user.owner_user_id = (SELECT id FROM user WHERE user = ?) - ORDER BY a_user.topic - ` - selectUserReservationsCountQuery = ` - SELECT COUNT(*) - FROM user_access - WHERE user_id = owner_user_id - AND owner_user_id = (SELECT id FROM user WHERE user = ?) - ` - selectUserReservationsOwnerQuery = ` - SELECT owner_user_id - FROM user_access - WHERE topic = ? - AND user_id = owner_user_id - ` - selectUserHasReservationQuery = ` - SELECT COUNT(*) - FROM user_access - WHERE user_id = owner_user_id - AND owner_user_id = (SELECT id FROM user WHERE user = ?) - AND topic = ? - ` - selectOtherAccessCountQuery = ` - SELECT COUNT(*) - FROM user_access - WHERE (topic = ? OR ? LIKE topic ESCAPE '\') - AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM user WHERE user = ?)) - ` - deleteAllAccessQuery = `DELETE FROM user_access` - deleteUserAccessQuery = ` - DELETE FROM user_access - WHERE user_id = (SELECT id FROM user WHERE user = ?) - OR owner_user_id = (SELECT id FROM user WHERE user = ?) - ` - deleteUserAccessProvisionedQuery = `DELETE FROM user_access WHERE provisioned = 1` - deleteTopicAccessQuery = ` - DELETE FROM user_access - WHERE (user_id = (SELECT id FROM user WHERE user = ?) OR owner_user_id = (SELECT id FROM user WHERE user = ?)) - AND topic = ? - ` - - selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?` - selectTokensQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ?` - selectTokenQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ? AND token = ?` - selectAllProvisionedTokensQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE provisioned = 1` - upsertTokenQuery = ` - INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires, provisioned) - VALUES (?, ?, ?, ?, ?, ?, ?) - ON CONFLICT (user_id, token) - DO UPDATE SET label = excluded.label, expires = excluded.expires, provisioned = excluded.provisioned; - ` - updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?` - updateTokenLabelQuery = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?` - updateTokenLastAccessQuery = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?` - deleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?` - deleteProvisionedTokenQuery = `DELETE FROM user_token WHERE token = ?` - deleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?` - deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < ?` - deleteExcessTokensQuery = ` - DELETE FROM user_token - WHERE user_id = ? - AND (user_id, token) NOT IN ( - SELECT user_id, token - FROM user_token - WHERE user_id = ? - ORDER BY expires DESC - LIMIT ? - ) - ` - - selectPhoneNumbersQuery = `SELECT phone_number FROM user_phone WHERE user_id = ?` - insertPhoneNumberQuery = `INSERT INTO user_phone (user_id, phone_number) VALUES (?, ?)` - deletePhoneNumberQuery = `DELETE FROM user_phone WHERE user_id = ? AND phone_number = ?` - - insertTierQuery = ` - INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - ` - updateTierQuery = ` - UPDATE tier - SET name = ?, messages_limit = ?, messages_expiry_duration = ?, emails_limit = ?, calls_limit = ?, reservations_limit = ?, attachment_file_size_limit = ?, attachment_total_size_limit = ?, attachment_expiry_duration = ?, attachment_bandwidth_limit = ?, stripe_monthly_price_id = ?, stripe_yearly_price_id = ? - WHERE code = ? - ` - selectTiersQuery = ` - SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id - FROM tier - ` - selectTierByCodeQuery = ` - SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id - FROM tier - WHERE code = ? - ` - selectTierByPriceIDQuery = ` - SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id - FROM tier - WHERE (stripe_monthly_price_id = ? OR stripe_yearly_price_id = ?) - ` - updateUserTierQuery = `UPDATE user SET tier_id = (SELECT id FROM tier WHERE code = ?) WHERE user = ?` - deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?` - deleteTierQuery = `DELETE FROM tier WHERE code = ?` - - updateBillingQuery = ` - UPDATE user - SET stripe_customer_id = ?, stripe_subscription_id = ?, stripe_subscription_status = ?, stripe_subscription_interval = ?, stripe_subscription_paid_until = ?, stripe_subscription_cancel_at = ? - WHERE user = ? - ` -) - -// Manager is an implementation of Manager. It stores users and access control list -// in a SQLite database. +// Manager handles user authentication, authorization, and management type Manager struct { config *Config - db *sql.DB + store Store // Database store statsQueue map[string]*Stats // "Queue" to asynchronously write user stats to the database (UserID -> Stats) tokenQueue map[string]*TokenUpdate // "Queue" to asynchronously write token access stats to the database (Token ID -> TokenUpdate) mu sync.Mutex } -// Config holds the configuration for the user Manager -type Config struct { - Filename string // Database filename, e.g. "/var/lib/ntfy/user.db" - StartupQueries string // Queries to run on startup, e.g. to create initial users or tiers - DefaultAccess Permission // Default permission if no ACL matches - ProvisionEnabled bool // Hack: Enable auto-provisioning of users and access grants, disabled for "ntfy user" commands - Users []*User // Predefined users to create on startup - Access map[string][]*Grant // Predefined access grants to create on startup (username -> []*Grant) - Tokens map[string][]*Token // Predefined users to create on startup (username -> []*Token) - QueueWriterInterval time.Duration // Interval for the async queue writer to flush stats and token updates to the database - BcryptCost int // Cost of generated passwords; lowering makes testing faster -} - var _ Auther = (*Manager)(nil) // NewManager creates a new Manager instance -func NewManager(config *Config) (*Manager, error) { +func NewManager(store Store, config *Config) (*Manager, error) { // Set defaults if config.BcryptCost <= 0 { config.BcryptCost = DefaultUserPasswordBcryptCost @@ -360,24 +61,8 @@ func NewManager(config *Config) (*Manager, error) { if config.QueueWriterInterval.Seconds() <= 0 { config.QueueWriterInterval = DefaultUserStatsQueueWriterInterval } - // Check the parent directory of the database file (makes for friendly error messages) - parentDir := filepath.Dir(config.Filename) - if !util.FileExists(parentDir) { - return nil, fmt.Errorf("user database directory %s does not exist or is not accessible", parentDir) - } - // Open DB and run setup queries - db, err := sql.Open("sqlite3", config.Filename) - if err != nil { - return nil, err - } - if err := setupDB(db); err != nil { - return nil, err - } - if err := runStartupQueries(db, config.StartupQueries); err != nil { - return nil, err - } manager := &Manager{ - db: db, + store: store, config: config, statsQueue: make(map[string]*Stats), tokenQueue: make(map[string]*TokenUpdate), @@ -396,7 +81,7 @@ func (a *Manager) Authenticate(username, password string) (*User, error) { if username == Everyone { return nil, ErrUnauthenticated } - user, err := a.User(username) + user, err := a.store.User(username) if err != nil { log.Tag(tag).Field("user_name", username).Err(err).Trace("Authentication of user failed (1)") bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks")) @@ -418,7 +103,7 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) { if len(token) != tokenLength { return nil, ErrUnauthenticated } - user, err := a.userByToken(token) + user, err := a.store.UserByToken(token) if err != nil { log.Tag(tag).Field("token", token).Err(err).Trace("Authentication of token failed") return nil, ErrUnauthenticated @@ -431,118 +116,34 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) { // after a fixed duration unless ChangeToken is called. This function also prunes tokens for the // given user, if there are too many of them. func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr, provisioned bool) (*Token, error) { - return queryTx(a.db, func(tx *sql.Tx) (*Token, error) { - return a.createTokenTx(tx, userID, GenerateToken(), label, expires, origin, provisioned) - }) -} - -func (a *Manager) createTokenTx(tx *sql.Tx, userID, token, label string, expires time.Time, origin netip.Addr, provisioned bool) (*Token, error) { + token := GenerateToken() access := time.Now() - if _, err := tx.Exec(upsertTokenQuery, userID, token, label, access.Unix(), origin.String(), expires.Unix(), provisioned); err != nil { - return nil, err - } - rows, err := tx.Query(selectTokenCountQuery, userID) + // Create the token + createdToken, err := a.store.CreateToken(userID, token, label, access, origin, expires, provisioned) if err != nil { return nil, err } - defer rows.Close() - if !rows.Next() { - return nil, errNoRows - } - var tokenCount int - if err := rows.Scan(&tokenCount); err != nil { + // Check token count and prune if necessary + tokenCount, err := a.store.TokenCount(userID) + if err != nil { return nil, err } if tokenCount >= tokenMaxCount { - // This pruning logic is done in two queries for efficiency. The SELECT above is a lookup - // on two indices, whereas the query below is a full table scan. - if _, err := tx.Exec(deleteExcessTokensQuery, userID, userID, tokenMaxCount); err != nil { + if err := a.store.RemoveExcessTokens(userID, tokenMaxCount); err != nil { return nil, err } } - return &Token{ - Value: token, - Label: label, - LastAccess: access, - LastOrigin: origin, - Expires: expires, - Provisioned: provisioned, - }, nil + return createdToken, nil } // Tokens returns all existing tokens for the user with the given user ID func (a *Manager) Tokens(userID string) ([]*Token, error) { - rows, err := a.db.Query(selectTokensQuery, userID) - if err != nil { - return nil, err - } - defer rows.Close() - tokens := make([]*Token, 0) - for { - token, err := a.readToken(rows) - if errors.Is(err, ErrTokenNotFound) { - break - } else if err != nil { - return nil, err - } - tokens = append(tokens, token) - } - return tokens, nil -} - -func (a *Manager) allProvisionedTokens() ([]*Token, error) { - rows, err := a.db.Query(selectAllProvisionedTokensQuery) - if err != nil { - return nil, err - } - defer rows.Close() - tokens := make([]*Token, 0) - for { - token, err := a.readToken(rows) - if errors.Is(err, ErrTokenNotFound) { - break - } else if err != nil { - return nil, err - } - tokens = append(tokens, token) - } - return tokens, nil + return a.store.Tokens(userID) } // Token returns a specific token for a user func (a *Manager) Token(userID, token string) (*Token, error) { - rows, err := a.db.Query(selectTokenQuery, userID, token) - if err != nil { - return nil, err - } - defer rows.Close() - return a.readToken(rows) -} - -func (a *Manager) readToken(rows *sql.Rows) (*Token, error) { - var token, label, lastOrigin string - var lastAccess, expires int64 - var provisioned bool - if !rows.Next() { - return nil, ErrTokenNotFound - } - if err := rows.Scan(&token, &label, &lastAccess, &lastOrigin, &expires, &provisioned); err != nil { - return nil, err - } else if err := rows.Err(); err != nil { - return nil, err - } - lastOriginIP, err := netip.ParseAddr(lastOrigin) - if err != nil { - lastOriginIP = netip.IPv4Unspecified() - } - return &Token{ - Value: token, - Label: label, - LastAccess: time.Unix(lastAccess, 0), - LastOrigin: lastOriginIP, - Expires: time.Unix(expires, 0), - Provisioned: provisioned, - }, nil + return a.store.Token(userID, token) } // ChangeToken updates a token's label and/or expiry date @@ -553,24 +154,16 @@ func (a *Manager) ChangeToken(userID, token string, label *string, expires *time if err := a.CanChangeToken(userID, token); err != nil { return nil, err } - tx, err := a.db.Begin() - if err != nil { - return nil, err - } - defer tx.Rollback() if label != nil { - if _, err := tx.Exec(updateTokenLabelQuery, *label, userID, token); err != nil { + if err := a.store.ChangeTokenLabel(userID, token, *label); err != nil { return nil, err } } if expires != nil { - if _, err := tx.Exec(updateTokenExpiryQuery, expires.Unix(), userID, token); err != nil { + if err := a.store.ChangeTokenExpiry(userID, token, *expires); err != nil { return nil, err } } - if err := tx.Commit(); err != nil { - return nil, err - } return a.Token(userID, token) } @@ -579,19 +172,7 @@ func (a *Manager) RemoveToken(userID, token string) error { if err := a.CanChangeToken(userID, token); err != nil { return err } - return execTx(a.db, func(tx *sql.Tx) error { - return a.removeTokenTx(tx, userID, token) - }) -} - -func (a *Manager) removeTokenTx(tx *sql.Tx, userID, token string) error { - if token == "" { - return errNoTokenProvided - } - if _, err := tx.Exec(deleteTokenQuery, userID, token); err != nil { - return err - } - return nil + return a.store.RemoveToken(userID, token) } // CanChangeToken checks if the token can be changed. If the token is provisioned, it cannot be changed. @@ -607,87 +188,39 @@ func (a *Manager) CanChangeToken(userID, token string) error { // RemoveExpiredTokens deletes all expired tokens from the database func (a *Manager) RemoveExpiredTokens() error { - if _, err := a.db.Exec(deleteExpiredTokensQuery, time.Now().Unix()); err != nil { - return err - } - return nil + return a.store.RemoveExpiredTokens() } // PhoneNumbers returns all phone numbers for the user with the given user ID func (a *Manager) PhoneNumbers(userID string) ([]string, error) { - rows, err := a.db.Query(selectPhoneNumbersQuery, userID) - if err != nil { - return nil, err - } - defer rows.Close() - phoneNumbers := make([]string, 0) - for { - phoneNumber, err := a.readPhoneNumber(rows) - if errors.Is(err, ErrPhoneNumberNotFound) { - break - } else if err != nil { - return nil, err - } - phoneNumbers = append(phoneNumbers, phoneNumber) - } - return phoneNumbers, nil -} - -func (a *Manager) readPhoneNumber(rows *sql.Rows) (string, error) { - var phoneNumber string - if !rows.Next() { - return "", ErrPhoneNumberNotFound - } - if err := rows.Scan(&phoneNumber); err != nil { - return "", err - } else if err := rows.Err(); err != nil { - return "", err - } - return phoneNumber, nil + return a.store.PhoneNumbers(userID) } // AddPhoneNumber adds a phone number to the user with the given user ID func (a *Manager) AddPhoneNumber(userID string, phoneNumber string) error { - if _, err := a.db.Exec(insertPhoneNumberQuery, userID, phoneNumber); err != nil { - if sqliteErr, ok := err.(sqlite3.Error); ok && sqliteErr.ExtendedCode == sqlite3.ErrConstraintUnique { - return ErrPhoneNumberExists - } - return err - } - return nil + return a.store.AddPhoneNumber(userID, phoneNumber) } // RemovePhoneNumber deletes a phone number from the user with the given user ID func (a *Manager) RemovePhoneNumber(userID string, phoneNumber string) error { - _, err := a.db.Exec(deletePhoneNumberQuery, userID, phoneNumber) - return err + return a.store.RemovePhoneNumber(userID, phoneNumber) } -// RemoveDeletedUsers deletes all users that have been marked deleted for +// RemoveDeletedUsers deletes all users that have been marked deleted func (a *Manager) RemoveDeletedUsers() error { - if _, err := a.db.Exec(deleteUsersMarkedQuery, time.Now().Unix()); err != nil { - return err - } - return nil + return a.store.RemoveDeletedUsers() } // ChangeSettings persists the user settings func (a *Manager) ChangeSettings(userID string, prefs *Prefs) error { - b, err := json.Marshal(prefs) - if err != nil { - return err - } - if _, err := a.db.Exec(updateUserPrefsQuery, string(b), userID); err != nil { - return err - } - return nil + return a.store.ChangeSettings(userID, prefs) } // ResetStats resets all user stats in the user database. This touches all users. func (a *Manager) ResetStats() error { a.mu.Lock() // Includes database query to avoid races! defer a.mu.Unlock() - if _, err := a.db.Exec(updateUserStatsResetAllQuery); err != nil { + if err := a.store.ResetStats(); err != nil { return err } a.statsQueue = make(map[string]*Stats) @@ -702,7 +235,7 @@ func (a *Manager) EnqueueUserStats(userID string, stats *Stats) { a.statsQueue[userID] = stats } -// EnqueueTokenUpdate adds the token update to a queue which writes out token access times +// EnqueueTokenUpdate adds the token update to a queue which writes out token access times // in batches at a regular interval func (a *Manager) EnqueueTokenUpdate(tokenID string, update *TokenUpdate) { a.mu.Lock() @@ -732,11 +265,7 @@ func (a *Manager) writeUserStatsQueue() error { statsQueue := a.statsQueue a.statsQueue = make(map[string]*Stats) a.mu.Unlock() - tx, err := a.db.Begin() - if err != nil { - return err - } - defer tx.Rollback() + log.Tag(tag).Debug("Writing user stats queue for %d user(s)", len(statsQueue)) for userID, update := range statsQueue { log. @@ -748,11 +277,11 @@ func (a *Manager) writeUserStatsQueue() error { "calls_count": update.Calls, }). Trace("Updating stats for user %s", userID) - if _, err := tx.Exec(updateUserStatsQuery, update.Messages, update.Emails, update.Calls, userID); err != nil { + if err := a.store.UpdateStats(userID, update); err != nil { return err } } - return tx.Commit() + return nil } func (a *Manager) writeTokenUpdateQueue() error { @@ -765,25 +294,14 @@ func (a *Manager) writeTokenUpdateQueue() error { tokenQueue := a.tokenQueue a.tokenQueue = make(map[string]*TokenUpdate) a.mu.Unlock() - tx, err := a.db.Begin() - if err != nil { - return err - } - defer tx.Rollback() + log.Tag(tag).Debug("Writing token update queue for %d token(s)", len(tokenQueue)) for tokenID, update := range tokenQueue { log.Tag(tag).Trace("Updating token %s with last access time %v", tokenID, update.LastAccess.Unix()) - if err := a.updateTokenLastAccessTx(tx, tokenID, update.LastAccess.Unix(), update.LastOrigin.String()); err != nil { + if err := a.store.UpdateTokenLastAccess(tokenID, update.LastAccess, update.LastOrigin); err != nil { return err } } - return tx.Commit() -} - -func (a *Manager) updateTokenLastAccessTx(tx *sql.Tx, token string, lastAccess int64, lastOrigin string) error { - if _, err := tx.Exec(updateTokenLastAccessQuery, lastAccess, lastOrigin, token); err != nil { - return err - } return nil } @@ -798,23 +316,13 @@ func (a *Manager) Authorize(user *User, topic string, perm Permission) error { username = user.Name } // Select the read/write permissions for this user/topic combo. - // - The query may return two rows (one for everyone, and one for the user), but prioritizes the user. - // - Furthermore, the query prioritizes more specific permissions (longer!) over more generic ones, e.g. "test*" > "*" - // - It also prioritizes write permissions over read permissions - rows, err := a.db.Query(selectTopicPermsQuery, Everyone, username, topic) + read, write, found, err := a.store.AuthorizeTopicAccess(username, topic) if err != nil { return err } - defer rows.Close() - if !rows.Next() { + if !found { return a.resolvePerms(a.config.DefaultAccess, perm) } - var read, write bool - if err := rows.Scan(&read, &write); err != nil { - return err - } else if err := rows.Err(); err != nil { - return err - } return a.resolvePerms(NewPermission(read, write), perm) } @@ -829,18 +337,15 @@ func (a *Manager) resolvePerms(base, perm Permission) error { // AddUser adds a user with the given username, password and role func (a *Manager) AddUser(username, password string, role Role, hashed bool) error { - return execTx(a.db, func(tx *sql.Tx) error { - return a.addUserTx(tx, username, password, role, hashed, false) - }) + return a.addUser(username, password, role, hashed, false) } -// AddUser adds a user with the given username, password and role -func (a *Manager) addUserTx(tx *sql.Tx, username, password string, role Role, hashed, provisioned bool) error { +func (a *Manager) addUser(username, password string, role Role, hashed, provisioned bool) error { if !AllowedUsername(username) || !AllowedRole(role) { return ErrInvalidArgument } var hash string - var err error = nil + var err error if hashed { hash = password if err := ValidPasswordHash(hash, a.config.BcryptCost); err != nil { @@ -852,15 +357,7 @@ func (a *Manager) addUserTx(tx *sql.Tx, username, password string, role Role, ha return err } } - userID := util.RandomStringPrefix(userIDPrefix, userIDLength) - syncTopic, now := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength), time.Now().Unix() - if _, err = tx.Exec(insertUserQuery, userID, username, hash, role, syncTopic, provisioned, now); err != nil { - if errors.Is(err, sqlite3.ErrConstraintUnique) { - return ErrUserExists - } - return err - } - return nil + return a.store.AddUser(username, hash, role, provisioned) } // RemoveUser deletes the user with the given username. The function returns nil on success, even @@ -869,20 +366,7 @@ func (a *Manager) RemoveUser(username string) error { if err := a.CanChangeUser(username); err != nil { return err } - return execTx(a.db, func(tx *sql.Tx) error { - return a.removeUserTx(tx, username) - }) -} - -func (a *Manager) removeUserTx(tx *sql.Tx, username string) error { - if !AllowedUsername(username) { - return ErrInvalidArgument - } - // Rows in user_access, user_token, etc. are deleted via foreign keys - if _, err := tx.Exec(deleteUserQuery, username); err != nil { - return err - } - return nil + return a.store.RemoveUser(username) } // MarkUserRemoved sets the deleted flag on the user, and deletes all access tokens. This prevents @@ -891,297 +375,64 @@ func (a *Manager) MarkUserRemoved(user *User) error { if !AllowedUsername(user.Name) { return ErrInvalidArgument } - tx, err := a.db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - if _, err := tx.Exec(deleteUserAccessQuery, user.Name, user.Name); err != nil { - return err - } - if _, err := tx.Exec(deleteAllTokenQuery, user.ID); err != nil { - return err - } - if _, err := tx.Exec(updateUserDeletedQuery, time.Now().Add(userHardDeleteAfterDuration).Unix(), user.ID); err != nil { - return err - } - return tx.Commit() + return a.store.MarkUserRemoved(user.ID) } // Users returns a list of users. It always also returns the Everyone user ("*"). func (a *Manager) Users() ([]*User, error) { - rows, err := a.db.Query(selectUsernamesQuery) - if err != nil { - return nil, err - } - defer rows.Close() - usernames := make([]string, 0) - for rows.Next() { - var username string - if err := rows.Scan(&username); err != nil { - return nil, err - } else if err := rows.Err(); err != nil { - return nil, err - } - usernames = append(usernames, username) - } - rows.Close() - users := make([]*User, 0) - for _, username := range usernames { - user, err := a.User(username) - if err != nil { - return nil, err - } - users = append(users, user) - } - return users, nil + return a.store.Users() } -// UsersCount returns the number of users in the databsae +// UsersCount returns the number of users in the database func (a *Manager) UsersCount() (int64, error) { - rows, err := a.db.Query(selectUserCountQuery) - if err != nil { - return 0, err - } - defer rows.Close() - if !rows.Next() { - return 0, errNoRows - } - var count int64 - if err := rows.Scan(&count); err != nil { - return 0, err - } - return count, nil + return a.store.UsersCount() } // User returns the user with the given username if it exists, or ErrUserNotFound otherwise. // You may also pass Everyone to retrieve the anonymous user and its Grant list. func (a *Manager) User(username string) (*User, error) { - rows, err := a.db.Query(selectUserByNameQuery, username) - if err != nil { - return nil, err - } - return a.readUser(rows) + return a.store.User(username) } // UserByID returns the user with the given ID if it exists, or ErrUserNotFound otherwise func (a *Manager) UserByID(id string) (*User, error) { - rows, err := a.db.Query(selectUserByIDQuery, id) - if err != nil { - return nil, err - } - return a.readUser(rows) + return a.store.UserByID(id) } // UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise. func (a *Manager) UserByStripeCustomer(stripeCustomerID string) (*User, error) { - rows, err := a.db.Query(selectUserByStripeCustomerIDQuery, stripeCustomerID) - if err != nil { - return nil, err - } - return a.readUser(rows) -} - -func (a *Manager) userByToken(token string) (*User, error) { - rows, err := a.db.Query(selectUserByTokenQuery, token, time.Now().Unix()) - if err != nil { - return nil, err - } - return a.readUser(rows) -} - -func (a *Manager) readUser(rows *sql.Rows) (*User, error) { - defer rows.Close() - var id, username, hash, role, prefs, syncTopic string - var provisioned bool - var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripeSubscriptionInterval, stripeMonthlyPriceID, stripeYearlyPriceID, tierID, tierCode, tierName sql.NullString - var messages, emails, calls int64 - var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64 - if !rows.Next() { - return nil, ErrUserNotFound - } - if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &provisioned, &messages, &emails, &calls, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionInterval, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil { - return nil, err - } else if err := rows.Err(); err != nil { - return nil, err - } - user := &User{ - ID: id, - Name: username, - Hash: hash, - Role: Role(role), - Prefs: &Prefs{}, - SyncTopic: syncTopic, - Provisioned: provisioned, - Stats: &Stats{ - Messages: messages, - Emails: emails, - Calls: calls, - }, - Billing: &Billing{ - StripeCustomerID: stripeCustomerID.String, // May be empty - StripeSubscriptionID: stripeSubscriptionID.String, // May be empty - StripeSubscriptionStatus: payments.SubscriptionStatus(stripeSubscriptionStatus.String), // May be empty - StripeSubscriptionInterval: payments.PriceRecurringInterval(stripeSubscriptionInterval.String), // May be empty - StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0), // May be zero - StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0), // May be zero - }, - Deleted: deleted.Valid, - } - if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil { - return nil, err - } - if tierCode.Valid { - // See readTier() when this is changed! - user.Tier = &Tier{ - ID: tierID.String, - Code: tierCode.String, - Name: tierName.String, - MessageLimit: messagesLimit.Int64, - MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second, - EmailLimit: emailsLimit.Int64, - CallLimit: callsLimit.Int64, - ReservationLimit: reservationsLimit.Int64, - AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64, - AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64, - AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second, - AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64, - StripeMonthlyPriceID: stripeMonthlyPriceID.String, // May be empty - StripeYearlyPriceID: stripeYearlyPriceID.String, // May be empty - } - } - return user, nil + return a.store.UserByStripeCustomer(stripeCustomerID) } // AllGrants returns all user-specific access control entries, mapped to their respective user IDs func (a *Manager) AllGrants() (map[string][]Grant, error) { - rows, err := a.db.Query(selectUserAllAccessQuery) - if err != nil { - return nil, err - } - defer rows.Close() - grants := make(map[string][]Grant, 0) - for rows.Next() { - var userID, topic string - var read, write, provisioned bool - if err := rows.Scan(&userID, &topic, &read, &write, &provisioned); err != nil { - return nil, err - } else if err := rows.Err(); err != nil { - return nil, err - } - if _, ok := grants[userID]; !ok { - grants[userID] = make([]Grant, 0) - } - grants[userID] = append(grants[userID], Grant{ - TopicPattern: fromSQLWildcard(topic), - Permission: NewPermission(read, write), - Provisioned: provisioned, - }) - } - return grants, nil + return a.store.AllGrants() } // Grants returns all user-specific access control entries func (a *Manager) Grants(username string) ([]Grant, error) { - rows, err := a.db.Query(selectUserAccessQuery, username) - if err != nil { - return nil, err - } - defer rows.Close() - grants := make([]Grant, 0) - for rows.Next() { - var topic string - var read, write, provisioned bool - if err := rows.Scan(&topic, &read, &write, &provisioned); err != nil { - return nil, err - } else if err := rows.Err(); err != nil { - return nil, err - } - grants = append(grants, Grant{ - TopicPattern: fromSQLWildcard(topic), - Permission: NewPermission(read, write), - Provisioned: provisioned, - }) - } - return grants, nil + return a.store.Grants(username) } // Reservations returns all user-owned topics, and the associated everyone-access func (a *Manager) Reservations(username string) ([]Reservation, error) { - rows, err := a.db.Query(selectUserReservationsQuery, Everyone, username) - if err != nil { - return nil, err - } - defer rows.Close() - reservations := make([]Reservation, 0) - for rows.Next() { - var topic string - var ownerRead, ownerWrite bool - var everyoneRead, everyoneWrite sql.NullBool - if err := rows.Scan(&topic, &ownerRead, &ownerWrite, &everyoneRead, &everyoneWrite); err != nil { - return nil, err - } else if err := rows.Err(); err != nil { - return nil, err - } - reservations = append(reservations, Reservation{ - Topic: unescapeUnderscore(topic), - Owner: NewPermission(ownerRead, ownerWrite), - Everyone: NewPermission(everyoneRead.Bool, everyoneWrite.Bool), // false if null - }) - } - return reservations, nil + return a.store.Reservations(username) } // HasReservation returns true if the given topic access is owned by the user func (a *Manager) HasReservation(username, topic string) (bool, error) { - rows, err := a.db.Query(selectUserHasReservationQuery, username, escapeUnderscore(topic)) - if err != nil { - return false, err - } - defer rows.Close() - if !rows.Next() { - return false, errNoRows - } - var count int64 - if err := rows.Scan(&count); err != nil { - return false, err - } - return count > 0, nil + return a.store.HasReservation(username, topic) } // ReservationsCount returns the number of reservations owned by this user func (a *Manager) ReservationsCount(username string) (int64, error) { - rows, err := a.db.Query(selectUserReservationsCountQuery, username) - if err != nil { - return 0, err - } - defer rows.Close() - if !rows.Next() { - return 0, errNoRows - } - var count int64 - if err := rows.Scan(&count); err != nil { - return 0, err - } - return count, nil + return a.store.ReservationsCount(username) } // ReservationOwner returns user ID of the user that owns this topic, or an // empty string if it's not owned by anyone func (a *Manager) ReservationOwner(topic string) (string, error) { - rows, err := a.db.Query(selectUserReservationsOwnerQuery, escapeUnderscore(topic)) - if err != nil { - return "", err - } - defer rows.Close() - if !rows.Next() { - return "", nil - } - var ownerUserID string - if err := rows.Scan(&ownerUserID); err != nil { - return "", err - } - return ownerUserID, nil + return a.store.ReservationOwner(topic) } // ChangePassword changes a user's password @@ -1189,9 +440,20 @@ func (a *Manager) ChangePassword(username, password string, hashed bool) error { if err := a.CanChangeUser(username); err != nil { return err } - return execTx(a.db, func(tx *sql.Tx) error { - return a.changePasswordTx(tx, username, password, hashed) - }) + var hash string + var err error + if hashed { + hash = password + if err := ValidPasswordHash(hash, a.config.BcryptCost); err != nil { + return err + } + } else { + hash, err = hashPassword(password, a.config.BcryptCost) + if err != nil { + return err + } + } + return a.store.ChangePassword(username, hash) } // CanChangeUser checks if the user with the given username can be changed. @@ -1206,59 +468,13 @@ func (a *Manager) CanChangeUser(username string) error { return nil } -func (a *Manager) changePasswordTx(tx *sql.Tx, username, password string, hashed bool) error { - var hash string - var err error - if hashed { - hash = password - if err := ValidPasswordHash(hash, a.config.BcryptCost); err != nil { - return err - } - } else { - hash, err = hashPassword(password, a.config.BcryptCost) - if err != nil { - return err - } - } - if _, err := tx.Exec(updateUserPassQuery, hash, username); err != nil { - return err - } - return nil -} - // ChangeRole changes a user's role. When a role is changed from RoleUser to RoleAdmin, // all existing access control entries (Grant) are removed, since they are no longer needed. func (a *Manager) ChangeRole(username string, role Role) error { if err := a.CanChangeUser(username); err != nil { return err } - return execTx(a.db, func(tx *sql.Tx) error { - return a.changeRoleTx(tx, username, role) - }) -} - -func (a *Manager) changeRoleTx(tx *sql.Tx, username string, role Role) error { - if !AllowedUsername(username) || !AllowedRole(role) { - return ErrInvalidArgument - } - if _, err := tx.Exec(updateUserRoleQuery, string(role), username); err != nil { - return err - } - if role == RoleAdmin { - if _, err := tx.Exec(deleteUserAccessQuery, username, username); err != nil { - return err - } - } - return nil -} - -// changeProvisionedTx changes the provisioned status of a user. This is used to mark users as -// provisioned. A provisioned user is a user defined in the config file. -func (a *Manager) changeProvisionedTx(tx *sql.Tx, username string, provisioned bool) error { - if _, err := tx.Exec(updateUserProvisionedQuery, provisioned, username); err != nil { - return err - } - return nil + return a.store.ChangeRole(username, role) } // ChangeTier changes a user's tier using the tier code. This function does not delete reservations, messages, @@ -1273,10 +489,7 @@ func (a *Manager) ChangeTier(username, tier string) error { } else if err := a.checkReservationsLimit(username, t.ReservationLimit); err != nil { return err } - if _, err := a.db.Exec(updateUserTierQuery, tier, username); err != nil { - return err - } - return nil + return a.store.ChangeTier(username, tier) } // ResetTier removes the tier from the given user @@ -1286,8 +499,7 @@ func (a *Manager) ResetTier(username string) error { } else if err := a.checkReservationsLimit(username, 0); err != nil { return err } - _, err := a.db.Exec(deleteUserTierQuery, username) - return err + return a.store.ResetTier(username) } func (a *Manager) checkReservationsLimit(username string, reservationsLimit int64) error { @@ -1312,69 +524,45 @@ func (a *Manager) AllowReservation(username string, topic string) error { if (!AllowedUsername(username) && username != Everyone) || !AllowedTopic(topic) { return ErrInvalidArgument } - rows, err := a.db.Query(selectOtherAccessCountQuery, escapeUnderscore(topic), escapeUnderscore(topic), username) + otherCount, err := a.store.OtherAccessCount(username, topic) if err != nil { return err } - defer rows.Close() - if !rows.Next() { - return errNoRows - } - var otherCount int - if err := rows.Scan(&otherCount); err != nil { - return err - } if otherCount > 0 { return errTopicOwnedByOthers } return nil } -// AllowAccess adds or updates an entry in th access control list for a specific user. It controls +// AllowAccess adds or updates an entry in the access control list for a specific user. It controls // read/write access to a topic. The parameter topicPattern may include wildcards (*). The ACL entry // owner may either be a user (username), or the system (empty). func (a *Manager) AllowAccess(username string, topicPattern string, permission Permission) error { - return execTx(a.db, func(tx *sql.Tx) error { - return a.allowAccessTx(tx, username, topicPattern, permission, false) - }) + return a.allowAccess(username, topicPattern, permission, false) } -func (a *Manager) allowAccessTx(tx *sql.Tx, username string, topicPattern string, permission Permission, provisioned bool) error { +func (a *Manager) allowAccess(username string, topicPattern string, permission Permission, provisioned bool) error { if !AllowedUsername(username) && username != Everyone { return ErrInvalidArgument } else if !AllowedTopicPattern(topicPattern) { return ErrInvalidArgument } - owner := "" - if _, err := tx.Exec(upsertUserAccessQuery, username, toSQLWildcard(topicPattern), permission.IsRead(), permission.IsWrite(), owner, owner, provisioned); err != nil { - return err - } - return nil + return a.store.AllowAccess(username, topicPattern, permission.IsRead(), permission.IsWrite(), "", provisioned) } // ResetAccess removes an access control list entry for a specific username/topic, or (if topic is // empty) for an entire user. The parameter topicPattern may include wildcards (*). func (a *Manager) ResetAccess(username string, topicPattern string) error { - return execTx(a.db, func(tx *sql.Tx) error { - return a.resetAccessTx(tx, username, topicPattern) - }) + return a.resetAccess(username, topicPattern) } -func (a *Manager) resetAccessTx(tx *sql.Tx, username string, topicPattern string) error { +func (a *Manager) resetAccess(username string, topicPattern string) error { if !AllowedUsername(username) && username != Everyone && username != "" { return ErrInvalidArgument } else if !AllowedTopicPattern(topicPattern) && topicPattern != "" { return ErrInvalidArgument } - if username == "" && topicPattern == "" { - _, err := tx.Exec(deleteAllAccessQuery, username) - return err - } else if topicPattern == "" { - _, err := tx.Exec(deleteUserAccessQuery, username, username) - return err - } - _, err := tx.Exec(deleteTopicAccessQuery, username, username, toSQLWildcard(topicPattern)) - return err + return a.store.ResetAccess(username, topicPattern) } // AddReservation creates two access control entries for the given topic: one with full read/write access for the @@ -1384,18 +572,13 @@ func (a *Manager) AddReservation(username string, topic string, everyone Permiss if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) { return ErrInvalidArgument } - tx, err := a.db.Begin() - if err != nil { + if err := a.store.AllowAccess(username, topic, true, true, username, false); err != nil { return err } - defer tx.Rollback() - if _, err := tx.Exec(upsertUserAccessQuery, username, escapeUnderscore(topic), true, true, username, username, false); err != nil { + if err := a.store.AllowAccess(Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username, false); err != nil { return err } - if _, err := tx.Exec(upsertUserAccessQuery, Everyone, escapeUnderscore(topic), everyone.IsRead(), everyone.IsWrite(), username, username, false); err != nil { - return err - } - return tx.Commit() + return nil } // RemoveReservations deletes the access control entries associated with the given username/topic, as @@ -1409,20 +592,15 @@ func (a *Manager) RemoveReservations(username string, topics ...string) error { return ErrInvalidArgument } } - tx, err := a.db.Begin() - if err != nil { - return err - } - defer tx.Rollback() for _, topic := range topics { - if _, err := tx.Exec(deleteTopicAccessQuery, username, username, escapeUnderscore(topic)); err != nil { + if err := a.store.ResetAccess(username, topic); err != nil { return err } - if _, err := tx.Exec(deleteTopicAccessQuery, Everyone, Everyone, escapeUnderscore(topic)); err != nil { + if err := a.store.ResetAccess(Everyone, topic); err != nil { return err } } - return tx.Commit() + return nil } // DefaultAccess returns the default read/write access if no access control entry matches @@ -1432,117 +610,42 @@ func (a *Manager) DefaultAccess() Permission { // AddTier creates a new tier in the database func (a *Manager) AddTier(tier *Tier) error { - if tier.ID == "" { - tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength) - } - if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.CallLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID)); err != nil { - return err - } - return nil + return a.store.AddTier(tier) } // UpdateTier updates a tier's properties in the database func (a *Manager) UpdateTier(tier *Tier) error { - if _, err := a.db.Exec(updateTierQuery, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.CallLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID), tier.Code); err != nil { - return err - } - return nil + return a.store.UpdateTier(tier) } // RemoveTier deletes the tier with the given code func (a *Manager) RemoveTier(code string) error { - if !AllowedTier(code) { - return ErrInvalidArgument - } - // This fails if any user has this tier - if _, err := a.db.Exec(deleteTierQuery, code); err != nil { - return err - } - return nil + return a.store.RemoveTier(code) } // ChangeBilling updates a user's billing fields, namely the Stripe customer ID, and subscription information func (a *Manager) ChangeBilling(username string, billing *Billing) error { - if _, err := a.db.Exec(updateBillingQuery, nullString(billing.StripeCustomerID), nullString(billing.StripeSubscriptionID), nullString(string(billing.StripeSubscriptionStatus)), nullString(string(billing.StripeSubscriptionInterval)), nullInt64(billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(billing.StripeSubscriptionCancelAt.Unix()), username); err != nil { - return err - } - return nil + return a.store.ChangeBilling(username, billing) } // Tiers returns a list of all Tier structs func (a *Manager) Tiers() ([]*Tier, error) { - rows, err := a.db.Query(selectTiersQuery) - if err != nil { - return nil, err - } - defer rows.Close() - tiers := make([]*Tier, 0) - for { - tier, err := a.readTier(rows) - if errors.Is(err, ErrTierNotFound) { - break - } else if err != nil { - return nil, err - } - tiers = append(tiers, tier) - } - return tiers, nil + return a.store.Tiers() } // Tier returns a Tier based on the code, or ErrTierNotFound if it does not exist func (a *Manager) Tier(code string) (*Tier, error) { - rows, err := a.db.Query(selectTierByCodeQuery, code) - if err != nil { - return nil, err - } - defer rows.Close() - return a.readTier(rows) + return a.store.Tier(code) } // TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) { - rows, err := a.db.Query(selectTierByPriceIDQuery, priceID, priceID) - if err != nil { - return nil, err - } - defer rows.Close() - return a.readTier(rows) -} - -func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) { - var id, code, name string - var stripeMonthlyPriceID, stripeYearlyPriceID sql.NullString - var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit sql.NullInt64 - if !rows.Next() { - return nil, ErrTierNotFound - } - if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil { - return nil, err - } else if err := rows.Err(); err != nil { - return nil, err - } - // When changed, note readUser() as well - return &Tier{ - ID: id, - Code: code, - Name: name, - MessageLimit: messagesLimit.Int64, - MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second, - EmailLimit: emailsLimit.Int64, - CallLimit: callsLimit.Int64, - ReservationLimit: reservationsLimit.Int64, - AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64, - AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64, - AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second, - AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64, - StripeMonthlyPriceID: stripeMonthlyPriceID.String, // May be empty - StripeYearlyPriceID: stripeYearlyPriceID.String, // May be empty - }, nil + return a.store.TierByStripePrice(priceID) } // Close closes the underlying database func (a *Manager) Close() error { - return a.db.Close() + return a.store.Close() } // maybeProvisionUsersAccessAndTokens provisions users, access control entries, and tokens based on the config. @@ -1557,29 +660,27 @@ func (a *Manager) maybeProvisionUsersAccessAndTokens() error { provisionUsernames := util.Map(a.config.Users, func(u *User) string { return u.Name }) - return execTx(a.db, func(tx *sql.Tx) error { - if err := a.maybeProvisionUsers(tx, provisionUsernames, existingUsers); err != nil { - return fmt.Errorf("failed to provision users: %v", err) - } - if err := a.maybeProvisionGrants(tx); err != nil { - return fmt.Errorf("failed to provision grants: %v", err) - } - if err := a.maybeProvisionTokens(tx, provisionUsernames); err != nil { - return fmt.Errorf("failed to provision tokens: %v", err) - } - return nil - }) + if err := a.maybeProvisionUsers(provisionUsernames, existingUsers); err != nil { + return fmt.Errorf("failed to provision users: %v", err) + } + if err := a.maybeProvisionGrants(); err != nil { + return fmt.Errorf("failed to provision grants: %v", err) + } + if err := a.maybeProvisionTokens(provisionUsernames); err != nil { + return fmt.Errorf("failed to provision tokens: %v", err) + } + return nil } // maybeProvisionUsers checks if the users in the config are provisioned, and adds or updates them. // It also removes users that are provisioned, but not in the config anymore. -func (a *Manager) maybeProvisionUsers(tx *sql.Tx, provisionUsernames []string, existingUsers []*User) error { +func (a *Manager) maybeProvisionUsers(provisionUsernames []string, existingUsers []*User) error { // Remove users that are provisioned, but not in the config anymore for _, user := range existingUsers { if user.Name == Everyone { continue } else if user.Provisioned && !util.Contains(provisionUsernames, user.Name) { - if err := a.removeUserTx(tx, user.Name); err != nil { + if err := a.store.RemoveUser(user.Name); err != nil { return fmt.Errorf("failed to remove provisioned user %s: %v", user.Name, err) } } @@ -1593,22 +694,22 @@ func (a *Manager) maybeProvisionUsers(tx *sql.Tx, provisionUsernames []string, e return u.Name == user.Name }) if !exists { - if err := a.addUserTx(tx, user.Name, user.Hash, user.Role, true, true); err != nil && !errors.Is(err, ErrUserExists) { + if err := a.addUser(user.Name, user.Hash, user.Role, true, true); err != nil && !errors.Is(err, ErrUserExists) { return fmt.Errorf("failed to add provisioned user %s: %v", user.Name, err) } } else { if !existingUser.Provisioned { - if err := a.changeProvisionedTx(tx, user.Name, true); err != nil { + if err := a.store.ChangeProvisioned(user.Name, true); err != nil { return fmt.Errorf("failed to change provisioned status for user %s: %v", user.Name, err) } } if existingUser.Hash != user.Hash { - if err := a.changePasswordTx(tx, user.Name, user.Hash, true); err != nil { + if err := a.store.ChangePassword(user.Name, user.Hash); err != nil { return fmt.Errorf("failed to change password for provisioned user %s: %v", user.Name, err) } } if existingUser.Role != user.Role { - if err := a.changeRoleTx(tx, user.Name, user.Role); err != nil { + if err := a.store.ChangeRole(user.Name, user.Role); err != nil { return fmt.Errorf("failed to change role for provisioned user %s: %v", user.Name, err) } } @@ -1621,9 +722,9 @@ func (a *Manager) maybeProvisionUsers(tx *sql.Tx, provisionUsernames []string, e // // Unlike users and tokens, grants can be just re-added, because they do not carry any state (such as last // access time) or do not have dependent resources (such as grants or tokens). -func (a *Manager) maybeProvisionGrants(tx *sql.Tx) error { +func (a *Manager) maybeProvisionGrants() error { // Remove all provisioned grants - if _, err := tx.Exec(deleteUserAccessProvisionedQuery); err != nil { + if err := a.store.ResetAllProvisionedAccess(); err != nil { return err } // (Re-)add provisioned grants @@ -1637,10 +738,10 @@ func (a *Manager) maybeProvisionGrants(tx *sql.Tx) error { return fmt.Errorf("adding access control entries is not allowed for admin roles for user %s", username) } for _, grant := range grants { - if err := a.resetAccessTx(tx, username, grant.TopicPattern); err != nil { + if err := a.resetAccess(username, grant.TopicPattern); err != nil { return fmt.Errorf("failed to reset access for user %s and topic %s: %v", username, grant.TopicPattern, err) } - if err := a.allowAccessTx(tx, username, grant.TopicPattern, grant.Permission, true); err != nil { + if err := a.allowAccess(username, grant.TopicPattern, grant.Permission, true); err != nil { return err } } @@ -1648,9 +749,9 @@ func (a *Manager) maybeProvisionGrants(tx *sql.Tx) error { return nil } -func (a *Manager) maybeProvisionTokens(tx *sql.Tx, provisionUsernames []string) error { +func (a *Manager) maybeProvisionTokens(provisionUsernames []string) error { // Remove tokens that are provisioned, but not in the config anymore - existingTokens, err := a.allProvisionedTokens() + existingTokens, err := a.store.AllProvisionedTokens() if err != nil { return fmt.Errorf("failed to retrieve existing provisioned tokens: %v", err) } @@ -1662,7 +763,7 @@ func (a *Manager) maybeProvisionTokens(tx *sql.Tx, provisionUsernames []string) } for _, existingToken := range existingTokens { if !slices.Contains(provisionTokens, existingToken.Value) { - if _, err := tx.Exec(deleteProvisionedTokenQuery, existingToken.Value); err != nil { + if err := a.store.RemoveToken("", existingToken.Value); err != nil { return fmt.Errorf("failed to remove provisioned token %s: %v", existingToken.Value, err) } } @@ -1672,71 +773,15 @@ func (a *Manager) maybeProvisionTokens(tx *sql.Tx, provisionUsernames []string) if !slices.Contains(provisionUsernames, username) && username != Everyone { return fmt.Errorf("user %s is not a provisioned user, refusing to add tokens", username) } - var userID string - row := tx.QueryRow(selectUserIDFromUsernameQuery, username) - if err := row.Scan(&userID); err != nil { - return fmt.Errorf("failed to find provisioned user %s for provisioned tokens", username) + userID, err := a.store.UserIDByUsername(username) + if err != nil { + return fmt.Errorf("failed to find provisioned user %s for provisioned tokens: %v", username, err) } for _, token := range tokens { - if _, err := a.createTokenTx(tx, userID, token.Value, token.Label, time.Unix(0, 0), netip.IPv4Unspecified(), true); err != nil { + if _, err := a.store.CreateToken(userID, token.Value, token.Label, time.Unix(0, 0), netip.IPv4Unspecified(), time.Unix(0, 0), true); err != nil { return err } } } return nil } - -func runStartupQueries(db *sql.DB, startupQueries string) error { - if _, err := db.Exec(startupQueries); err != nil { - return err - } - if _, err := db.Exec(builtinStartupQueries); err != nil { - return err - } - return nil -} - -func setupDB(db *sql.DB) error { - // If 'schemaVersion' table does not exist, this must be a new database - rowsSV, err := db.Query(selectSchemaVersionQuery) - if err != nil { - return setupNewDB(db) - } - defer rowsSV.Close() - - // If 'schemaVersion' table exists, read version and potentially upgrade - schemaVersion := 0 - if !rowsSV.Next() { - return errors.New("cannot determine schema version: database file may be corrupt") - } - if err := rowsSV.Scan(&schemaVersion); err != nil { - return err - } - rowsSV.Close() - - // Do migrations - if schemaVersion == currentSchemaVersion { - return nil - } else if schemaVersion > currentSchemaVersion { - return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, currentSchemaVersion) - } - for i := schemaVersion; i < currentSchemaVersion; i++ { - fn, ok := migrations[i] - if !ok { - return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1) - } else if err := fn(db); err != nil { - return err - } - } - return nil -} - -func setupNewDB(db *sql.DB) error { - if _, err := db.Exec(createTablesQueries); err != nil { - return err - } - if _, err := db.Exec(insertSchemaVersion, currentSchemaVersion); err != nil { - return err - } - return nil -} diff --git a/user/manager_test.go b/user/manager_test.go index 26a70942..a27f18f1 100644 --- a/user/manager_test.go +++ b/user/manager_test.go @@ -224,7 +224,7 @@ func TestManager_MarkUserRemoved_RemoveDeletedUsers(t *testing.T) { require.Nil(t, err) require.True(t, u.Deleted) - _, err = a.db.Exec("UPDATE user SET deleted = ? WHERE id = ?", time.Now().Add(-1*(userHardDeleteAfterDuration+time.Hour)).Unix(), u.ID) + _, err = testDB(a).Exec("UPDATE user SET deleted = ? WHERE id = ?", time.Now().Add(-1*(userHardDeleteAfterDuration+time.Hour)).Unix(), u.ID) require.Nil(t, err) require.Nil(t, a.RemoveDeletedUsers()) @@ -604,14 +604,14 @@ func TestManager_Token_Expire(t *testing.T) { require.Nil(t, err) // Modify token expiration in database - _, err = a.db.Exec("UPDATE user_token SET expires = 1 WHERE token = ?", token1.Value) + _, err = testDB(a).Exec("UPDATE user_token SET expires = 1 WHERE token = ?", token1.Value) require.Nil(t, err) // Now token1 shouldn't work anymore _, err = a.AuthenticateToken(token1.Value) require.Equal(t, ErrUnauthenticated, err) - result, err := a.db.Query("SELECT * from user_token WHERE token = ?", token1.Value) + result, err := testDB(a).Query("SELECT * from user_token WHERE token = ?", token1.Value) require.Nil(t, err) require.True(t, result.Next()) // Still a matching row require.Nil(t, result.Close()) @@ -619,7 +619,7 @@ func TestManager_Token_Expire(t *testing.T) { // Expire tokens and check database rows require.Nil(t, a.RemoveExpiredTokens()) - result, err = a.db.Query("SELECT * from user_token WHERE token = ?", token1.Value) + result, err = testDB(a).Query("SELECT * from user_token WHERE token = ?", token1.Value) require.Nil(t, err) require.False(t, result.Next()) // No matching row! require.Nil(t, result.Close()) @@ -687,7 +687,7 @@ func TestManager_Token_MaxCount_AutoDelete(t *testing.T) { benTokens = append(benTokens, token.Value) // Manually modify expiry date to avoid sorting issues (this is a hack) - _, err = a.db.Exec(`UPDATE user_token SET expires=? WHERE token=?`, baseTime.Add(time.Duration(i)*time.Minute).Unix(), token.Value) + _, err = testDB(a).Exec(`UPDATE user_token SET expires=? WHERE token=?`, baseTime.Add(time.Duration(i)*time.Minute).Unix(), token.Value) require.Nil(t, err) } @@ -715,14 +715,14 @@ func TestManager_Token_MaxCount_AutoDelete(t *testing.T) { } var benCount int - rows, err := a.db.Query(`SELECT COUNT(*) FROM user_token WHERE user_id=?`, ben.ID) + rows, err := testDB(a).Query(`SELECT COUNT(*) FROM user_token WHERE user_id=?`, ben.ID) require.Nil(t, err) require.True(t, rows.Next()) require.Nil(t, rows.Scan(&benCount)) require.Equal(t, 60, benCount) var philCount int - rows, err = a.db.Query(`SELECT COUNT(*) FROM user_token WHERE user_id=?`, phil.ID) + rows, err = testDB(a).Query(`SELECT COUNT(*) FROM user_token WHERE user_id=?`, phil.ID) require.Nil(t, err) require.True(t, rows.Next()) require.Nil(t, rows.Scan(&philCount)) @@ -730,15 +730,13 @@ func TestManager_Token_MaxCount_AutoDelete(t *testing.T) { } func TestManager_EnqueueStats_ResetStats(t *testing.T) { + filename := filepath.Join(t.TempDir(), "db") conf := &Config{ - Filename: filepath.Join(t.TempDir(), "db"), - StartupQueries: "", DefaultAccess: PermissionReadWrite, BcryptCost: bcrypt.MinCost, QueueWriterInterval: 1500 * time.Millisecond, } - a, err := NewManager(conf) - require.Nil(t, err) + a := newTestManagerFromStoreConfig(t, filename, conf) require.Nil(t, a.AddUser("ben", "ben", RoleUser, false)) // Baseline: No messages or emails @@ -779,15 +777,13 @@ func TestManager_EnqueueStats_ResetStats(t *testing.T) { } func TestManager_EnqueueTokenUpdate(t *testing.T) { + filename := filepath.Join(t.TempDir(), "db") conf := &Config{ - Filename: filepath.Join(t.TempDir(), "db"), - StartupQueries: "", DefaultAccess: PermissionReadWrite, BcryptCost: bcrypt.MinCost, QueueWriterInterval: 500 * time.Millisecond, } - a, err := NewManager(conf) - require.Nil(t, err) + a := newTestManagerFromStoreConfig(t, filename, conf) require.Nil(t, a.AddUser("ben", "ben", RoleUser, false)) // Create user and token @@ -819,15 +815,13 @@ func TestManager_EnqueueTokenUpdate(t *testing.T) { } func TestManager_ChangeSettings(t *testing.T) { + filename := filepath.Join(t.TempDir(), "db") conf := &Config{ - Filename: filepath.Join(t.TempDir(), "db"), - StartupQueries: "", DefaultAccess: PermissionReadWrite, BcryptCost: bcrypt.MinCost, QueueWriterInterval: 1500 * time.Millisecond, } - a, err := NewManager(conf) - require.Nil(t, err) + a := newTestManagerFromStoreConfig(t, filename, conf) require.Nil(t, a.AddUser("ben", "ben", RoleUser, false)) // No settings @@ -1053,7 +1047,7 @@ func TestUser_PhoneNumberAddListRemove(t *testing.T) { require.Equal(t, 0, len(phoneNumbers)) // Paranoia check: We do NOT want to keep phone numbers in there - rows, err := a.db.Query(`SELECT * FROM user_phone`) + rows, err := testDB(a).Query(`SELECT * FROM user_phone`) require.Nil(t, err) require.False(t, rows.Next()) require.Nil(t, rows.Close()) @@ -1098,7 +1092,6 @@ func TestManager_Topic_Wildcard_With_Underscore(t *testing.T) { func TestManager_WithProvisionedUsers(t *testing.T) { f := filepath.Join(t.TempDir(), "user.db") conf := &Config{ - Filename: f, DefaultAccess: PermissionReadWrite, ProvisionEnabled: true, Users: []*User{ @@ -1117,8 +1110,7 @@ func TestManager_WithProvisionedUsers(t *testing.T) { }, }, } - a, err := NewManager(conf) - require.Nil(t, err) + a := newTestManagerFromStoreConfig(t, f, conf) // Manually add user require.Nil(t, a.AddUser("philmanual", "manual", RoleUser, false)) @@ -1154,13 +1146,11 @@ func TestManager_WithProvisionedUsers(t *testing.T) { // Update the token last access time and origin (so we can check that it is persisted) lastAccessTime := time.Now().Add(time.Hour) lastOrigin := netip.MustParseAddr("1.1.9.9") - err = execTx(a.db, func(tx *sql.Tx) error { - return a.updateTokenLastAccessTx(tx, tokens[0].Value, lastAccessTime.Unix(), lastOrigin.String()) - }) + err = a.store.UpdateTokenLastAccess(tokens[0].Value, lastAccessTime, lastOrigin) require.Nil(t, err) // Re-open the DB (second app start) - require.Nil(t, a.db.Close()) + require.Nil(t, a.Close()) conf.Users = []*User{ {Name: "philuser", Hash: "$2a$10$AAAAU21sX1uhZamTLJXHuxgVC0Z/GKISibrKCLohPgtG7yIxSk4C", Role: RoleUser}, } @@ -1176,8 +1166,7 @@ func TestManager_WithProvisionedUsers(t *testing.T) { {Value: "tk_u48wqendnkx9er21pqqcadlytbutx", Label: "Another token"}, }, } - a, err = NewManager(conf) - require.Nil(t, err) + a = newTestManagerFromStoreConfig(t, f, conf) // Check that the provisioned users are there users, err = a.Users() @@ -1212,12 +1201,11 @@ func TestManager_WithProvisionedUsers(t *testing.T) { require.Error(t, a.ChangePassword("philuser", "new-pass", false)) // Re-open the DB again (third app start) - require.Nil(t, a.db.Close()) + require.Nil(t, a.Close()) conf.Users = []*User{} conf.Access = map[string][]*Grant{} conf.Tokens = map[string][]*Token{} - a, err = NewManager(conf) - require.Nil(t, err) + a = newTestManagerFromStoreConfig(t, f, conf) // Check that the provisioned users are all gone users, err = a.Users() @@ -1237,17 +1225,16 @@ func TestManager_WithProvisionedUsers(t *testing.T) { require.Equal(t, 0, len(tokens)) var count int - a.db.QueryRow("SELECT COUNT(*) FROM user WHERE provisioned = 1").Scan(&count) + testDB(a).QueryRow("SELECT COUNT(*) FROM user WHERE provisioned = 1").Scan(&count) require.Equal(t, 0, count) - a.db.QueryRow("SELECT COUNT(*) FROM user_grant WHERE provisioned = 1").Scan(&count) + testDB(a).QueryRow("SELECT COUNT(*) FROM user_access WHERE provisioned = 1").Scan(&count) require.Equal(t, 0, count) - a.db.QueryRow("SELECT COUNT(*) FROM user_token WHERE provisioned = 1").Scan(&count) + testDB(a).QueryRow("SELECT COUNT(*) FROM user_token WHERE provisioned = 1").Scan(&count) } func TestManager_UpdateNonProvisionedUsersToProvisionedUsers(t *testing.T) { f := filepath.Join(t.TempDir(), "user.db") conf := &Config{ - Filename: f, DefaultAccess: PermissionReadWrite, ProvisionEnabled: true, Users: []*User{}, @@ -1257,8 +1244,7 @@ func TestManager_UpdateNonProvisionedUsersToProvisionedUsers(t *testing.T) { }, }, } - a, err := NewManager(conf) - require.Nil(t, err) + a := newTestManagerFromStoreConfig(t, f, conf) // Manually add user require.Nil(t, a.AddUser("philuser", "manual", RoleUser, false)) @@ -1290,7 +1276,7 @@ func TestManager_UpdateNonProvisionedUsersToProvisionedUsers(t *testing.T) { require.True(t, grants[0].Provisioned) // Provisioned entry // Re-open the DB (second app start) - require.Nil(t, a.db.Close()) + require.Nil(t, a.Close()) conf.Users = []*User{ {Name: "philuser", Hash: "$2a$10$AAAAU21sX1uhZamTLJXHuxgVC0Z/GKISibrKCLohPgtG7yIxSk4C", Role: RoleUser}, } @@ -1299,8 +1285,7 @@ func TestManager_UpdateNonProvisionedUsersToProvisionedUsers(t *testing.T) { {TopicPattern: "stats", Permission: PermissionReadWrite}, }, } - a, err = NewManager(conf) - require.Nil(t, err) + a = newTestManagerFromStoreConfig(t, f, conf) // Check that the user was "upgraded" to a provisioned user users, err = a.Users() @@ -1383,7 +1368,7 @@ func TestMigrationFrom1(t *testing.T) { // Create manager to trigger migration a := newTestManagerFromFile(t, filename, "", PermissionDenyAll, bcrypt.MinCost, DefaultUserStatsQueueWriterInterval) - checkSchemaVersion(t, a.db) + checkSchemaVersion(t, testDB(a)) users, err := a.Users() require.Nil(t, err) @@ -1526,7 +1511,7 @@ func TestMigrationFrom4(t *testing.T) { // Create manager to trigger migration a := newTestManagerFromFile(t, filename, "", PermissionDenyAll, bcrypt.MinCost, DefaultUserStatsQueueWriterInterval) - checkSchemaVersion(t, a.db) + checkSchemaVersion(t, testDB(a)) // Add another require.Nil(t, a.AllowAccess(Everyone, "left_*", PermissionReadWrite)) @@ -1587,14 +1572,26 @@ func newTestManager(t *testing.T, defaultAccess Permission) *Manager { } func newTestManagerFromFile(t *testing.T, filename, startupQueries string, defaultAccess Permission, bcryptCost int, statsWriterInterval time.Duration) *Manager { + store, err := NewSQLiteStore(filename, startupQueries) + require.Nil(t, err) conf := &Config{ - Filename: filename, - StartupQueries: startupQueries, DefaultAccess: defaultAccess, BcryptCost: bcryptCost, QueueWriterInterval: statsWriterInterval, } - a, err := NewManager(conf) + a, err := NewManager(store, conf) + require.Nil(t, err) + return a +} + +func testDB(a *Manager) *sql.DB { + return a.store.(*commonStore).db +} + +func newTestManagerFromStoreConfig(t *testing.T, filename string, conf *Config) *Manager { + store, err := NewSQLiteStore(filename, "") + require.Nil(t, err) + a, err := NewManager(store, conf) require.Nil(t, err) return a } diff --git a/user/store.go b/user/store.go new file mode 100644 index 00000000..2a276d19 --- /dev/null +++ b/user/store.go @@ -0,0 +1,992 @@ +package user + +import ( + "database/sql" + "encoding/json" + "errors" + "heckel.io/ntfy/v2/payments" + "heckel.io/ntfy/v2/util" + "net/netip" + "time" +) + +// Store is the interface for a user database store +type Store interface { + // User operations + UserByID(id string) (*User, error) + User(username string) (*User, error) + UserByToken(token string) (*User, error) + UserByStripeCustomer(customerID string) (*User, error) + Users() ([]*User, error) + UsersCount() (int64, error) + AddUser(username, hash string, role Role, provisioned bool) error + RemoveUser(username string) error + MarkUserRemoved(userID string) error + RemoveDeletedUsers() error + ChangePassword(username, hash string) error + ChangeRole(username string, role Role) error + ChangeProvisioned(username string, provisioned bool) error + ChangeSettings(userID string, prefs *Prefs) error + ChangeTier(username, tierCode string) error + ResetTier(username string) error + UpdateStats(userID string, stats *Stats) error + ResetStats() error + // Token operations + CreateToken(userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, provisioned bool) (*Token, error) + Token(userID, token string) (*Token, error) + Tokens(userID string) ([]*Token, error) + AllProvisionedTokens() ([]*Token, error) + ChangeTokenLabel(userID, token, label string) error + ChangeTokenExpiry(userID, token string, expires time.Time) error + UpdateTokenLastAccess(token string, lastAccess time.Time, lastOrigin netip.Addr) error + RemoveToken(userID, token string) error + RemoveExpiredTokens() error + TokenCount(userID string) (int, error) + RemoveExcessTokens(userID string, maxCount int) error + // Access operations + AuthorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error) + AllGrants() (map[string][]Grant, error) + Grants(username string) ([]Grant, error) + AllowAccess(username, topicPattern string, read, write bool, ownerUsername string, provisioned bool) error + ResetAccess(username, topicPattern string) error + ResetAllProvisionedAccess() error + Reservations(username string) ([]Reservation, error) + HasReservation(username, topic string) (bool, error) + ReservationsCount(username string) (int64, error) + ReservationOwner(topic string) (string, error) + OtherAccessCount(username, topic string) (int, error) + // Tier operations + AddTier(tier *Tier) error + UpdateTier(tier *Tier) error + RemoveTier(code string) error + Tiers() ([]*Tier, error) + Tier(code string) (*Tier, error) + TierByStripePrice(priceID string) (*Tier, error) + // Phone operations + PhoneNumbers(userID string) ([]string, error) + AddPhoneNumber(userID, phoneNumber string) error + RemovePhoneNumber(userID, phoneNumber string) error + // Billing + ChangeBilling(username string, billing *Billing) error + // Internal helpers + UserIDByUsername(username string) (string, error) + // System + Close() error +} + +// storeQueries holds the database-specific SQL queries +type storeQueries struct { + // User queries + selectUserByID string + selectUserByName string + selectUserByToken string + selectUserByStripeID string + selectUsernames string + selectUserCount string + selectUserIDFromUsername string + insertUser string + updateUserPass string + updateUserRole string + updateUserProvisioned string + updateUserPrefs string + updateUserStats string + updateUserStatsResetAll string + updateUserTier string + updateUserDeleted string + deleteUser string + deleteUserTier string + deleteUsersMarked string + // Access queries + selectTopicPerms string + selectUserAllAccess string + selectUserAccess string + selectUserReservations string + selectUserReservationsCount string + selectUserReservationsOwner string + selectUserHasReservation string + selectOtherAccessCount string + upsertUserAccess string + deleteUserAccess string + deleteUserAccessProvisioned string + deleteTopicAccess string + deleteAllAccess string + // Token queries + selectToken string + selectTokens string + selectTokenCount string + selectAllProvisionedTokens string + upsertToken string + updateTokenLabel string + updateTokenExpiry string + updateTokenLastAccess string + deleteToken string + deleteProvisionedToken string + deleteAllToken string + deleteExpiredTokens string + deleteExcessTokens string + // Tier queries + insertTier string + selectTiers string + selectTierByCode string + selectTierByPriceID string + updateTier string + deleteTier string + // Phone queries + selectPhoneNumbers string + insertPhoneNumber string + deletePhoneNumber string + // Billing queries + updateBilling string +} + +// commonStore implements store operations that work across database backends +type commonStore struct { + db *sql.DB + queries storeQueries +} + +// UserByID returns the user with the given ID if it exists, or ErrUserNotFound otherwise +func (s *commonStore) UserByID(id string) (*User, error) { + rows, err := s.db.Query(s.queries.selectUserByID, id) + if err != nil { + return nil, err + } + return s.readUser(rows) +} + +// User returns the user with the given username if it exists, or ErrUserNotFound otherwise +func (s *commonStore) User(username string) (*User, error) { + rows, err := s.db.Query(s.queries.selectUserByName, username) + if err != nil { + return nil, err + } + return s.readUser(rows) +} + +// UserByToken returns the user with the given token if it exists and is not expired, or ErrUserNotFound otherwise +func (s *commonStore) UserByToken(token string) (*User, error) { + rows, err := s.db.Query(s.queries.selectUserByToken, token, time.Now().Unix()) + if err != nil { + return nil, err + } + return s.readUser(rows) +} + +// UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise +func (s *commonStore) UserByStripeCustomer(customerID string) (*User, error) { + rows, err := s.db.Query(s.queries.selectUserByStripeID, customerID) + if err != nil { + return nil, err + } + return s.readUser(rows) +} + +// Users returns a list of users +func (s *commonStore) Users() ([]*User, error) { + rows, err := s.db.Query(s.queries.selectUsernames) + if err != nil { + return nil, err + } + defer rows.Close() + usernames := make([]string, 0) + for rows.Next() { + var username string + if err := rows.Scan(&username); err != nil { + return nil, err + } else if err := rows.Err(); err != nil { + return nil, err + } + usernames = append(usernames, username) + } + rows.Close() + users := make([]*User, 0) + for _, username := range usernames { + user, err := s.User(username) + if err != nil { + return nil, err + } + users = append(users, user) + } + return users, nil +} + +// UsersCount returns the number of users in the database +func (s *commonStore) UsersCount() (int64, error) { + rows, err := s.db.Query(s.queries.selectUserCount) + if err != nil { + return 0, err + } + defer rows.Close() + if !rows.Next() { + return 0, errNoRows + } + var count int64 + if err := rows.Scan(&count); err != nil { + return 0, err + } + return count, nil +} + +// AddUser adds a user with the given username, password hash and role +func (s *commonStore) AddUser(username, hash string, role Role, provisioned bool) error { + if !AllowedUsername(username) || !AllowedRole(role) { + return ErrInvalidArgument + } + userID := util.RandomStringPrefix(userIDPrefix, userIDLength) + syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength) + now := time.Now().Unix() + if _, err := s.db.Exec(s.queries.insertUser, userID, username, hash, string(role), syncTopic, provisioned, now); err != nil { + if isUniqueConstraintError(err) { + return ErrUserExists + } + return err + } + return nil +} + +// RemoveUser deletes the user with the given username +func (s *commonStore) RemoveUser(username string) error { + if !AllowedUsername(username) { + return ErrInvalidArgument + } + // Rows in user_access, user_token, etc. are deleted via foreign keys + if _, err := s.db.Exec(s.queries.deleteUser, username); err != nil { + return err + } + return nil +} + +// MarkUserRemoved sets the deleted flag on the user, and deletes all access tokens +func (s *commonStore) MarkUserRemoved(userID string) error { + tx, err := s.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + // Get username for deleteUserAccess query + user, err := s.UserByID(userID) + if err != nil { + return err + } + if _, err := tx.Exec(s.queries.deleteUserAccess, user.Name, user.Name); err != nil { + return err + } + if _, err := tx.Exec(s.queries.deleteAllToken, userID); err != nil { + return err + } + deletedTime := time.Now().Add(userHardDeleteAfterDuration).Unix() + if _, err := tx.Exec(s.queries.updateUserDeleted, deletedTime, userID); err != nil { + return err + } + return tx.Commit() +} + +// RemoveDeletedUsers deletes all users that have been marked deleted +func (s *commonStore) RemoveDeletedUsers() error { + if _, err := s.db.Exec(s.queries.deleteUsersMarked, time.Now().Unix()); err != nil { + return err + } + return nil +} + +// ChangePassword changes a user's password +func (s *commonStore) ChangePassword(username, hash string) error { + if _, err := s.db.Exec(s.queries.updateUserPass, hash, username); err != nil { + return err + } + return nil +} + +// ChangeRole changes a user's role +func (s *commonStore) ChangeRole(username string, role Role) error { + if !AllowedUsername(username) || !AllowedRole(role) { + return ErrInvalidArgument + } + tx, err := s.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + if _, err := tx.Exec(s.queries.updateUserRole, string(role), username); err != nil { + return err + } + // If changing to admin, remove all access entries + if role == RoleAdmin { + if _, err := tx.Exec(s.queries.deleteUserAccess, username, username); err != nil { + return err + } + } + return tx.Commit() +} + +// ChangeProvisioned changes the provisioned status of a user +func (s *commonStore) ChangeProvisioned(username string, provisioned bool) error { + if _, err := s.db.Exec(s.queries.updateUserProvisioned, provisioned, username); err != nil { + return err + } + return nil +} + +// ChangeSettings persists the user settings +func (s *commonStore) ChangeSettings(userID string, prefs *Prefs) error { + b, err := json.Marshal(prefs) + if err != nil { + return err + } + if _, err := s.db.Exec(s.queries.updateUserPrefs, string(b), userID); err != nil { + return err + } + return nil +} + +// ChangeTier changes a user's tier using the tier code +func (s *commonStore) ChangeTier(username, tierCode string) error { + if _, err := s.db.Exec(s.queries.updateUserTier, tierCode, username); err != nil { + return err + } + return nil +} + +// ResetTier removes the tier from the given user +func (s *commonStore) ResetTier(username string) error { + if !AllowedUsername(username) && username != Everyone && username != "" { + return ErrInvalidArgument + } + _, err := s.db.Exec(s.queries.deleteUserTier, username) + return err +} + +// UpdateStats updates the user statistics +func (s *commonStore) UpdateStats(userID string, stats *Stats) error { + if _, err := s.db.Exec(s.queries.updateUserStats, stats.Messages, stats.Emails, stats.Calls, userID); err != nil { + return err + } + return nil +} + +// ResetStats resets all user stats in the user database +func (s *commonStore) ResetStats() error { + if _, err := s.db.Exec(s.queries.updateUserStatsResetAll); err != nil { + return err + } + return nil +} +func (s *commonStore) readUser(rows *sql.Rows) (*User, error) { + defer rows.Close() + var id, username, hash, role, prefs, syncTopic string + var provisioned bool + var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripeSubscriptionInterval, stripeMonthlyPriceID, stripeYearlyPriceID, tierID, tierCode, tierName sql.NullString + var messages, emails, calls int64 + var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64 + if !rows.Next() { + return nil, ErrUserNotFound + } + if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &provisioned, &messages, &emails, &calls, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionInterval, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil { + return nil, err + } else if err := rows.Err(); err != nil { + return nil, err + } + user := &User{ + ID: id, + Name: username, + Hash: hash, + Role: Role(role), + Prefs: &Prefs{}, + SyncTopic: syncTopic, + Provisioned: provisioned, + Stats: &Stats{ + Messages: messages, + Emails: emails, + Calls: calls, + }, + Billing: &Billing{ + StripeCustomerID: stripeCustomerID.String, + StripeSubscriptionID: stripeSubscriptionID.String, + StripeSubscriptionStatus: payments.SubscriptionStatus(stripeSubscriptionStatus.String), + StripeSubscriptionInterval: payments.PriceRecurringInterval(stripeSubscriptionInterval.String), + StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0), + StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0), + }, + Deleted: deleted.Valid, + } + if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil { + return nil, err + } + if tierCode.Valid { + user.Tier = &Tier{ + ID: tierID.String, + Code: tierCode.String, + Name: tierName.String, + MessageLimit: messagesLimit.Int64, + MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second, + EmailLimit: emailsLimit.Int64, + CallLimit: callsLimit.Int64, + ReservationLimit: reservationsLimit.Int64, + AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64, + AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64, + AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second, + AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64, + StripeMonthlyPriceID: stripeMonthlyPriceID.String, + StripeYearlyPriceID: stripeYearlyPriceID.String, + } + } + return user, nil +} + +// CreateToken creates a new token +func (s *commonStore) CreateToken(userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, provisioned bool) (*Token, error) { + if _, err := s.db.Exec(s.queries.upsertToken, userID, token, label, lastAccess.Unix(), lastOrigin.String(), expires.Unix(), provisioned); err != nil { + return nil, err + } + return &Token{ + Value: token, + Label: label, + LastAccess: lastAccess, + LastOrigin: lastOrigin, + Expires: expires, + Provisioned: provisioned, + }, nil +} + +// Token returns a specific token for a user +func (s *commonStore) Token(userID, token string) (*Token, error) { + rows, err := s.db.Query(s.queries.selectToken, userID, token) + if err != nil { + return nil, err + } + defer rows.Close() + return s.readToken(rows) +} + +// Tokens returns all existing tokens for the user with the given user ID +func (s *commonStore) Tokens(userID string) ([]*Token, error) { + rows, err := s.db.Query(s.queries.selectTokens, userID) + if err != nil { + return nil, err + } + defer rows.Close() + tokens := make([]*Token, 0) + for { + token, err := s.readToken(rows) + if errors.Is(err, ErrTokenNotFound) { + break + } else if err != nil { + return nil, err + } + tokens = append(tokens, token) + } + return tokens, nil +} + +// AllProvisionedTokens returns all provisioned tokens +func (s *commonStore) AllProvisionedTokens() ([]*Token, error) { + rows, err := s.db.Query(s.queries.selectAllProvisionedTokens) + if err != nil { + return nil, err + } + defer rows.Close() + tokens := make([]*Token, 0) + for { + token, err := s.readToken(rows) + if errors.Is(err, ErrTokenNotFound) { + break + } else if err != nil { + return nil, err + } + tokens = append(tokens, token) + } + return tokens, nil +} + +// ChangeTokenLabel updates a token's label +func (s *commonStore) ChangeTokenLabel(userID, token, label string) error { + if _, err := s.db.Exec(s.queries.updateTokenLabel, label, userID, token); err != nil { + return err + } + return nil +} + +// ChangeTokenExpiry updates a token's expiry time +func (s *commonStore) ChangeTokenExpiry(userID, token string, expires time.Time) error { + if _, err := s.db.Exec(s.queries.updateTokenExpiry, expires.Unix(), userID, token); err != nil { + return err + } + return nil +} + +// UpdateTokenLastAccess updates a token's last access time and origin +func (s *commonStore) UpdateTokenLastAccess(token string, lastAccess time.Time, lastOrigin netip.Addr) error { + if _, err := s.db.Exec(s.queries.updateTokenLastAccess, lastAccess.Unix(), lastOrigin.String(), token); err != nil { + return err + } + return nil +} + +// RemoveToken deletes the token +func (s *commonStore) RemoveToken(userID, token string) error { + if token == "" { + return errNoTokenProvided + } + if _, err := s.db.Exec(s.queries.deleteToken, userID, token); err != nil { + return err + } + return nil +} + +// RemoveExpiredTokens deletes all expired tokens from the database +func (s *commonStore) RemoveExpiredTokens() error { + if _, err := s.db.Exec(s.queries.deleteExpiredTokens, time.Now().Unix()); err != nil { + return err + } + return nil +} + +// TokenCount returns the number of tokens for a user +func (s *commonStore) TokenCount(userID string) (int, error) { + rows, err := s.db.Query(s.queries.selectTokenCount, userID) + if err != nil { + return 0, err + } + defer rows.Close() + if !rows.Next() { + return 0, errNoRows + } + var count int + if err := rows.Scan(&count); err != nil { + return 0, err + } + return count, nil +} + +// RemoveExcessTokens deletes excess tokens beyond the specified maximum +func (s *commonStore) RemoveExcessTokens(userID string, maxCount int) error { + if _, err := s.db.Exec(s.queries.deleteExcessTokens, userID, userID, maxCount); err != nil { + return err + } + return nil +} +func (s *commonStore) readToken(rows *sql.Rows) (*Token, error) { + var token, label, lastOrigin string + var lastAccess, expires int64 + var provisioned bool + if !rows.Next() { + return nil, ErrTokenNotFound + } + if err := rows.Scan(&token, &label, &lastAccess, &lastOrigin, &expires, &provisioned); err != nil { + return nil, err + } else if err := rows.Err(); err != nil { + return nil, err + } + lastOriginIP, err := netip.ParseAddr(lastOrigin) + if err != nil { + lastOriginIP = netip.IPv4Unspecified() + } + return &Token{ + Value: token, + Label: label, + LastAccess: time.Unix(lastAccess, 0), + LastOrigin: lastOriginIP, + Expires: time.Unix(expires, 0), + Provisioned: provisioned, + }, nil +} + +// AuthorizeTopicAccess returns the read/write permissions for the given username and topic. +// The found return value indicates whether an ACL entry was found at all. +func (s *commonStore) AuthorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error) { + rows, err := s.db.Query(s.queries.selectTopicPerms, Everyone, usernameOrEveryone, topic) + if err != nil { + return false, false, false, err + } + defer rows.Close() + if !rows.Next() { + return false, false, false, nil + } + if err := rows.Scan(&read, &write); err != nil { + return false, false, false, err + } else if err := rows.Err(); err != nil { + return false, false, false, err + } + return read, write, true, nil +} + +// AllGrants returns all user-specific access control entries, mapped to their respective user IDs +func (s *commonStore) AllGrants() (map[string][]Grant, error) { + rows, err := s.db.Query(s.queries.selectUserAllAccess) + if err != nil { + return nil, err + } + defer rows.Close() + grants := make(map[string][]Grant, 0) + for rows.Next() { + var userID, topic string + var read, write, provisioned bool + if err := rows.Scan(&userID, &topic, &read, &write, &provisioned); err != nil { + return nil, err + } else if err := rows.Err(); err != nil { + return nil, err + } + if _, ok := grants[userID]; !ok { + grants[userID] = make([]Grant, 0) + } + grants[userID] = append(grants[userID], Grant{ + TopicPattern: fromSQLWildcard(topic), + Permission: NewPermission(read, write), + Provisioned: provisioned, + }) + } + return grants, nil +} + +// Grants returns all user-specific access control entries +func (s *commonStore) Grants(username string) ([]Grant, error) { + rows, err := s.db.Query(s.queries.selectUserAccess, username) + if err != nil { + return nil, err + } + defer rows.Close() + grants := make([]Grant, 0) + for rows.Next() { + var topic string + var read, write, provisioned bool + if err := rows.Scan(&topic, &read, &write, &provisioned); err != nil { + return nil, err + } else if err := rows.Err(); err != nil { + return nil, err + } + grants = append(grants, Grant{ + TopicPattern: fromSQLWildcard(topic), + Permission: NewPermission(read, write), + Provisioned: provisioned, + }) + } + return grants, nil +} + +// AllowAccess adds or updates an entry in the access control list +func (s *commonStore) AllowAccess(username, topicPattern string, read, write bool, ownerUsername string, provisioned bool) error { + if !AllowedUsername(username) && username != Everyone { + return ErrInvalidArgument + } else if !AllowedTopicPattern(topicPattern) { + return ErrInvalidArgument + } + if _, err := s.db.Exec(s.queries.upsertUserAccess, username, toSQLWildcard(topicPattern), read, write, ownerUsername, ownerUsername, provisioned); err != nil { + return err + } + return nil +} + +// ResetAccess removes an access control list entry +func (s *commonStore) ResetAccess(username, topicPattern string) error { + if !AllowedUsername(username) && username != Everyone && username != "" { + return ErrInvalidArgument + } else if !AllowedTopicPattern(topicPattern) && topicPattern != "" { + return ErrInvalidArgument + } + if username == "" && topicPattern == "" { + _, err := s.db.Exec(s.queries.deleteAllAccess) + return err + } else if topicPattern == "" { + _, err := s.db.Exec(s.queries.deleteUserAccess, username, username) + return err + } + _, err := s.db.Exec(s.queries.deleteTopicAccess, username, username, toSQLWildcard(topicPattern)) + return err +} + +// ResetAllProvisionedAccess removes all provisioned access control entries +func (s *commonStore) ResetAllProvisionedAccess() error { + if _, err := s.db.Exec(s.queries.deleteUserAccessProvisioned); err != nil { + return err + } + return nil +} + +// Reservations returns all user-owned topics, and the associated everyone-access +func (s *commonStore) Reservations(username string) ([]Reservation, error) { + rows, err := s.db.Query(s.queries.selectUserReservations, Everyone, username) + if err != nil { + return nil, err + } + defer rows.Close() + reservations := make([]Reservation, 0) + for rows.Next() { + var topic string + var ownerRead, ownerWrite bool + var everyoneRead, everyoneWrite sql.NullBool + if err := rows.Scan(&topic, &ownerRead, &ownerWrite, &everyoneRead, &everyoneWrite); err != nil { + return nil, err + } else if err := rows.Err(); err != nil { + return nil, err + } + reservations = append(reservations, Reservation{ + Topic: unescapeUnderscore(topic), + Owner: NewPermission(ownerRead, ownerWrite), + Everyone: NewPermission(everyoneRead.Bool, everyoneWrite.Bool), + }) + } + return reservations, nil +} + +// HasReservation returns true if the given topic access is owned by the user +func (s *commonStore) HasReservation(username, topic string) (bool, error) { + rows, err := s.db.Query(s.queries.selectUserHasReservation, username, escapeUnderscore(topic)) + if err != nil { + return false, err + } + defer rows.Close() + if !rows.Next() { + return false, errNoRows + } + var count int64 + if err := rows.Scan(&count); err != nil { + return false, err + } + return count > 0, nil +} + +// ReservationsCount returns the number of reservations owned by this user +func (s *commonStore) ReservationsCount(username string) (int64, error) { + rows, err := s.db.Query(s.queries.selectUserReservationsCount, username) + if err != nil { + return 0, err + } + defer rows.Close() + if !rows.Next() { + return 0, errNoRows + } + var count int64 + if err := rows.Scan(&count); err != nil { + return 0, err + } + return count, nil +} + +// ReservationOwner returns user ID of the user that owns this topic, or an empty string if it's not owned by anyone +func (s *commonStore) ReservationOwner(topic string) (string, error) { + rows, err := s.db.Query(s.queries.selectUserReservationsOwner, escapeUnderscore(topic)) + if err != nil { + return "", err + } + defer rows.Close() + if !rows.Next() { + return "", nil + } + var ownerUserID string + if err := rows.Scan(&ownerUserID); err != nil { + return "", err + } + return ownerUserID, nil +} + +// OtherAccessCount returns the number of access entries for the given topic that are not owned by the user +func (s *commonStore) OtherAccessCount(username, topic string) (int, error) { + rows, err := s.db.Query(s.queries.selectOtherAccessCount, escapeUnderscore(topic), escapeUnderscore(topic), username) + if err != nil { + return 0, err + } + defer rows.Close() + if !rows.Next() { + return 0, errNoRows + } + var count int + if err := rows.Scan(&count); err != nil { + return 0, err + } + return count, nil +} + +// AddTier creates a new tier in the database +func (s *commonStore) AddTier(tier *Tier) error { + if tier.ID == "" { + tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength) + } + if _, err := s.db.Exec(s.queries.insertTier, tier.ID, tier.Code, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.CallLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID)); err != nil { + return err + } + return nil +} + +// UpdateTier updates a tier's properties in the database +func (s *commonStore) UpdateTier(tier *Tier) error { + if _, err := s.db.Exec(s.queries.updateTier, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.CallLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID), tier.Code); err != nil { + return err + } + return nil +} + +// RemoveTier deletes the tier with the given code +func (s *commonStore) RemoveTier(code string) error { + if !AllowedTier(code) { + return ErrInvalidArgument + } + if _, err := s.db.Exec(s.queries.deleteTier, code); err != nil { + return err + } + return nil +} + +// Tiers returns a list of all Tier structs +func (s *commonStore) Tiers() ([]*Tier, error) { + rows, err := s.db.Query(s.queries.selectTiers) + if err != nil { + return nil, err + } + defer rows.Close() + tiers := make([]*Tier, 0) + for { + tier, err := s.readTier(rows) + if errors.Is(err, ErrTierNotFound) { + break + } else if err != nil { + return nil, err + } + tiers = append(tiers, tier) + } + return tiers, nil +} + +// Tier returns a Tier based on the code, or ErrTierNotFound if it does not exist +func (s *commonStore) Tier(code string) (*Tier, error) { + rows, err := s.db.Query(s.queries.selectTierByCode, code) + if err != nil { + return nil, err + } + defer rows.Close() + return s.readTier(rows) +} + +// TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist +func (s *commonStore) TierByStripePrice(priceID string) (*Tier, error) { + rows, err := s.db.Query(s.queries.selectTierByPriceID, priceID, priceID) + if err != nil { + return nil, err + } + defer rows.Close() + return s.readTier(rows) +} +func (s *commonStore) readTier(rows *sql.Rows) (*Tier, error) { + var id, code, name string + var stripeMonthlyPriceID, stripeYearlyPriceID sql.NullString + var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit sql.NullInt64 + if !rows.Next() { + return nil, ErrTierNotFound + } + if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil { + return nil, err + } else if err := rows.Err(); err != nil { + return nil, err + } + return &Tier{ + ID: id, + Code: code, + Name: name, + MessageLimit: messagesLimit.Int64, + MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second, + EmailLimit: emailsLimit.Int64, + CallLimit: callsLimit.Int64, + ReservationLimit: reservationsLimit.Int64, + AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64, + AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64, + AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second, + AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64, + StripeMonthlyPriceID: stripeMonthlyPriceID.String, + StripeYearlyPriceID: stripeYearlyPriceID.String, + }, nil +} + +// PhoneNumbers returns all phone numbers for the user with the given user ID +func (s *commonStore) PhoneNumbers(userID string) ([]string, error) { + rows, err := s.db.Query(s.queries.selectPhoneNumbers, userID) + if err != nil { + return nil, err + } + defer rows.Close() + phoneNumbers := make([]string, 0) + for { + phoneNumber, err := s.readPhoneNumber(rows) + if errors.Is(err, ErrPhoneNumberNotFound) { + break + } else if err != nil { + return nil, err + } + phoneNumbers = append(phoneNumbers, phoneNumber) + } + return phoneNumbers, nil +} + +// AddPhoneNumber adds a phone number to the user with the given user ID +func (s *commonStore) AddPhoneNumber(userID, phoneNumber string) error { + if _, err := s.db.Exec(s.queries.insertPhoneNumber, userID, phoneNumber); err != nil { + if isUniqueConstraintError(err) { + return ErrPhoneNumberExists + } + return err + } + return nil +} + +// RemovePhoneNumber deletes a phone number from the user with the given user ID +func (s *commonStore) RemovePhoneNumber(userID, phoneNumber string) error { + _, err := s.db.Exec(s.queries.deletePhoneNumber, userID, phoneNumber) + return err +} +func (s *commonStore) readPhoneNumber(rows *sql.Rows) (string, error) { + var phoneNumber string + if !rows.Next() { + return "", ErrPhoneNumberNotFound + } + if err := rows.Scan(&phoneNumber); err != nil { + return "", err + } else if err := rows.Err(); err != nil { + return "", err + } + return phoneNumber, nil +} + +// ChangeBilling updates a user's billing fields +func (s *commonStore) ChangeBilling(username string, billing *Billing) error { + if _, err := s.db.Exec(s.queries.updateBilling, nullString(billing.StripeCustomerID), nullString(billing.StripeSubscriptionID), nullString(string(billing.StripeSubscriptionStatus)), nullString(string(billing.StripeSubscriptionInterval)), nullInt64(billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(billing.StripeSubscriptionCancelAt.Unix()), username); err != nil { + return err + } + return nil +} + +// UserIDByUsername returns the user ID for the given username +func (s *commonStore) UserIDByUsername(username string) (string, error) { + rows, err := s.db.Query(s.queries.selectUserIDFromUsername, username) + if err != nil { + return "", err + } + defer rows.Close() + if !rows.Next() { + return "", ErrUserNotFound + } + var userID string + if err := rows.Scan(&userID); err != nil { + return "", err + } + return userID, nil +} + +// Close closes the underlying database +func (s *commonStore) Close() error { + return s.db.Close() +} + +// isUniqueConstraintError checks if the error is a unique constraint violation for both SQLite and PostgreSQL +func isUniqueConstraintError(err error) bool { + // Check for SQLite unique constraint error + if sqliteErr, ok := err.(interface{ ExtendedCode() int }); ok { + if sqliteErr.ExtendedCode() == 2067 { // sqlite3.ErrConstraintUnique + return true + } + } + // Check for PostgreSQL unique violation (error code 23505) + if pgErr, ok := err.(interface{ Code() string }); ok { + if pgErr.Code() == "23505" { + return true + } + } + return false +} diff --git a/user/store_postgres.go b/user/store_postgres.go new file mode 100644 index 00000000..bfcf4b54 --- /dev/null +++ b/user/store_postgres.go @@ -0,0 +1,424 @@ +package user + +import ( + "database/sql" + "fmt" + + _ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver +) + +// PostgreSQL schema and queries +const ( + postgresCreateTablesQueries = ` + CREATE TABLE IF NOT EXISTS tier ( + id TEXT PRIMARY KEY, + code TEXT NOT NULL, + name TEXT NOT NULL, + messages_limit BIGINT NOT NULL, + messages_expiry_duration BIGINT NOT NULL, + emails_limit BIGINT NOT NULL, + calls_limit BIGINT NOT NULL, + reservations_limit BIGINT NOT NULL, + attachment_file_size_limit BIGINT NOT NULL, + attachment_total_size_limit BIGINT NOT NULL, + attachment_expiry_duration BIGINT NOT NULL, + attachment_bandwidth_limit BIGINT NOT NULL, + stripe_monthly_price_id TEXT, + stripe_yearly_price_id TEXT, + UNIQUE(code), + UNIQUE(stripe_monthly_price_id), + UNIQUE(stripe_yearly_price_id) + ); + CREATE TABLE IF NOT EXISTS "user" ( + id TEXT PRIMARY KEY, + tier_id TEXT REFERENCES tier(id), + user_name TEXT NOT NULL UNIQUE, + pass TEXT NOT NULL, + role TEXT NOT NULL CHECK (role IN ('anonymous', 'admin', 'user')), + prefs JSONB NOT NULL DEFAULT '{}', + sync_topic TEXT NOT NULL, + provisioned BOOLEAN NOT NULL, + stats_messages BIGINT NOT NULL DEFAULT 0, + stats_emails BIGINT NOT NULL DEFAULT 0, + stats_calls BIGINT NOT NULL DEFAULT 0, + stripe_customer_id TEXT UNIQUE, + stripe_subscription_id TEXT UNIQUE, + stripe_subscription_status TEXT, + stripe_subscription_interval TEXT, + stripe_subscription_paid_until BIGINT, + stripe_subscription_cancel_at BIGINT, + created BIGINT NOT NULL, + deleted BIGINT + ); + CREATE TABLE IF NOT EXISTS user_access ( + user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE, + topic TEXT NOT NULL, + read BOOLEAN NOT NULL, + write BOOLEAN NOT NULL, + owner_user_id TEXT REFERENCES "user"(id) ON DELETE CASCADE, + provisioned BOOLEAN NOT NULL, + PRIMARY KEY (user_id, topic) + ); + CREATE TABLE IF NOT EXISTS user_token ( + user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE, + token TEXT NOT NULL UNIQUE, + label TEXT NOT NULL, + last_access BIGINT NOT NULL, + last_origin TEXT NOT NULL, + expires BIGINT NOT NULL, + provisioned BOOLEAN NOT NULL, + PRIMARY KEY (user_id, token) + ); + CREATE TABLE IF NOT EXISTS user_phone ( + user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE, + phone_number TEXT NOT NULL, + PRIMARY KEY (user_id, phone_number) + ); + CREATE TABLE IF NOT EXISTS schema_version ( + store TEXT PRIMARY KEY, + version INT NOT NULL + ); + INSERT INTO "user" (id, user_name, pass, role, sync_topic, provisioned, created) + VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, EXTRACT(EPOCH FROM NOW())::BIGINT) + ON CONFLICT (id) DO NOTHING; + ` + + // User queries + postgresSelectUserByID = ` + SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id + FROM "user" u + LEFT JOIN tier t on t.id = u.tier_id + WHERE u.id = $1 + ` + postgresSelectUserByName = ` + SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id + FROM "user" u + LEFT JOIN tier t on t.id = u.tier_id + WHERE user_name = $1 + ` + postgresSelectUserByToken = ` + SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id + FROM "user" u + JOIN user_token tk on u.id = tk.user_id + LEFT JOIN tier t on t.id = u.tier_id + WHERE tk.token = $1 AND (tk.expires = 0 OR tk.expires >= $2) + ` + postgresSelectUserByStripeID = ` + SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id + FROM "user" u + LEFT JOIN tier t on t.id = u.tier_id + WHERE u.stripe_customer_id = $1 + ` + postgresSelectUsernames = ` + SELECT user_name + FROM "user" + ORDER BY + CASE role + WHEN 'admin' THEN 1 + WHEN 'anonymous' THEN 3 + ELSE 2 + END, user_name + ` + postgresSelectUserCount = `SELECT COUNT(*) FROM "user"` + postgresSelectUserIDFromUsername = `SELECT id FROM "user" WHERE user_name = $1` + postgresInsertUser = `INSERT INTO "user" (id, user_name, pass, role, sync_topic, provisioned, created) VALUES ($1, $2, $3, $4, $5, $6, $7)` + postgresUpdateUserPass = `UPDATE "user" SET pass = $1 WHERE user_name = $2` + postgresUpdateUserRole = `UPDATE "user" SET role = $1 WHERE user_name = $2` + postgresUpdateUserProvisioned = `UPDATE "user" SET provisioned = $1 WHERE user_name = $2` + postgresUpdateUserPrefs = `UPDATE "user" SET prefs = $1 WHERE id = $2` + postgresUpdateUserStats = `UPDATE "user" SET stats_messages = $1, stats_emails = $2, stats_calls = $3 WHERE id = $4` + postgresUpdateUserStatsResetAll = `UPDATE "user" SET stats_messages = 0, stats_emails = 0, stats_calls = 0` + postgresUpdateUserTier = `UPDATE "user" SET tier_id = (SELECT id FROM tier WHERE code = $1) WHERE user_name = $2` + postgresUpdateUserDeleted = `UPDATE "user" SET deleted = $1 WHERE id = $2` + postgresDeleteUser = `DELETE FROM "user" WHERE user_name = $1` + postgresDeleteUserTier = `UPDATE "user" SET tier_id = null WHERE user_name = $1` + postgresDeleteUsersMarked = `DELETE FROM "user" WHERE deleted < $1` + + // Access queries + postgresSelectTopicPerms = ` + SELECT read, write + FROM user_access a + JOIN "user" u ON u.id = a.user_id + WHERE (u.user_name = $1 OR u.user_name = $2) AND $3 LIKE a.topic ESCAPE '\' + ORDER BY u.user_name DESC, LENGTH(a.topic) DESC, CASE WHEN a.write THEN 1 ELSE 0 END DESC + ` + postgresSelectUserAllAccess = ` + SELECT user_id, topic, read, write, provisioned + FROM user_access + ORDER BY LENGTH(topic) DESC, CASE WHEN write THEN 1 ELSE 0 END DESC, CASE WHEN read THEN 1 ELSE 0 END DESC, topic + ` + postgresSelectUserAccess = ` + SELECT topic, read, write, provisioned + FROM user_access + WHERE user_id = (SELECT id FROM "user" WHERE user_name = $1) + ORDER BY LENGTH(topic) DESC, CASE WHEN write THEN 1 ELSE 0 END DESC, CASE WHEN read THEN 1 ELSE 0 END DESC, topic + ` + postgresSelectUserReservations = ` + SELECT a_user.topic, a_user.read, a_user.write, a_everyone.read AS everyone_read, a_everyone.write AS everyone_write + FROM user_access a_user + LEFT JOIN user_access a_everyone ON a_user.topic = a_everyone.topic AND a_everyone.user_id = (SELECT id FROM "user" WHERE user_name = $1) + WHERE a_user.user_id = a_user.owner_user_id + AND a_user.owner_user_id = (SELECT id FROM "user" WHERE user_name = $2) + ORDER BY a_user.topic + ` + postgresSelectUserReservationsCount = ` + SELECT COUNT(*) + FROM user_access + WHERE user_id = owner_user_id + AND owner_user_id = (SELECT id FROM "user" WHERE user_name = $1) + ` + postgresSelectUserReservationsOwner = ` + SELECT owner_user_id + FROM user_access + WHERE topic = $1 + AND user_id = owner_user_id + ` + postgresSelectUserHasReservation = ` + SELECT COUNT(*) + FROM user_access + WHERE user_id = owner_user_id + AND owner_user_id = (SELECT id FROM "user" WHERE user_name = $1) + AND topic = $2 + ` + postgresSelectOtherAccessCount = ` + SELECT COUNT(*) + FROM user_access + WHERE (topic = $1 OR $2 LIKE topic ESCAPE '\') + AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM "user" WHERE user_name = $3)) + ` + postgresUpsertUserAccess = ` + INSERT INTO user_access (user_id, topic, read, write, owner_user_id, provisioned) + VALUES ( + (SELECT id FROM "user" WHERE user_name = $1), + $2, + $3, + $4, + CASE WHEN $5 = '' THEN NULL ELSE (SELECT id FROM "user" WHERE user_name = $6) END, + $7 + ) + ON CONFLICT (user_id, topic) + DO UPDATE SET read=EXCLUDED.read, write=EXCLUDED.write, owner_user_id=EXCLUDED.owner_user_id, provisioned=EXCLUDED.provisioned + ` + postgresDeleteUserAccess = ` + DELETE FROM user_access + WHERE user_id = (SELECT id FROM "user" WHERE user_name = $1) + OR owner_user_id = (SELECT id FROM "user" WHERE user_name = $2) + ` + postgresDeleteUserAccessProvisioned = `DELETE FROM user_access WHERE provisioned = true` + postgresDeleteTopicAccess = ` + DELETE FROM user_access + WHERE (user_id = (SELECT id FROM "user" WHERE user_name = $1) OR owner_user_id = (SELECT id FROM "user" WHERE user_name = $2)) + AND topic = $3 + ` + postgresDeleteAllAccess = `DELETE FROM user_access` + + // Token queries + postgresSelectToken = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = $1 AND token = $2` + postgresSelectTokens = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = $1` + postgresSelectTokenCount = `SELECT COUNT(*) FROM user_token WHERE user_id = $1` + postgresSelectAllProvisionedTokens = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE provisioned = true` + postgresUpsertToken = ` + INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires, provisioned) + VALUES ($1, $2, $3, $4, $5, $6, $7) + ON CONFLICT (user_id, token) + DO UPDATE SET label = EXCLUDED.label, expires = EXCLUDED.expires, provisioned = EXCLUDED.provisioned + ` + postgresUpdateTokenLabel = `UPDATE user_token SET label = $1 WHERE user_id = $2 AND token = $3` + postgresUpdateTokenExpiry = `UPDATE user_token SET expires = $1 WHERE user_id = $2 AND token = $3` + postgresUpdateTokenLastAccess = `UPDATE user_token SET last_access = $1, last_origin = $2 WHERE token = $3` + postgresDeleteToken = `DELETE FROM user_token WHERE user_id = $1 AND token = $2` + postgresDeleteProvisionedToken = `DELETE FROM user_token WHERE token = $1` + postgresDeleteAllToken = `DELETE FROM user_token WHERE user_id = $1` + postgresDeleteExpiredTokens = `DELETE FROM user_token WHERE expires > 0 AND expires < $1` + postgresDeleteExcessTokens = ` + DELETE FROM user_token + WHERE user_id = $1 + AND (user_id, token) NOT IN ( + SELECT user_id, token + FROM user_token + WHERE user_id = $2 + ORDER BY expires DESC + LIMIT $3 + ) + ` + + // Tier queries + postgresInsertTier = ` + INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) + ` + postgresUpdateTier = ` + UPDATE tier + SET name = $1, messages_limit = $2, messages_expiry_duration = $3, emails_limit = $4, calls_limit = $5, reservations_limit = $6, attachment_file_size_limit = $7, attachment_total_size_limit = $8, attachment_expiry_duration = $9, attachment_bandwidth_limit = $10, stripe_monthly_price_id = $11, stripe_yearly_price_id = $12 + WHERE code = $13 + ` + postgresSelectTiers = ` + SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id + FROM tier + ` + postgresSelectTierByCode = ` + SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id + FROM tier + WHERE code = $1 + ` + postgresSelectTierByPriceID = ` + SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id + FROM tier + WHERE (stripe_monthly_price_id = $1 OR stripe_yearly_price_id = $2) + ` + postgresDeleteTier = `DELETE FROM tier WHERE code = $1` + + // Phone queries + postgresSelectPhoneNumbers = `SELECT phone_number FROM user_phone WHERE user_id = $1` + postgresInsertPhoneNumber = `INSERT INTO user_phone (user_id, phone_number) VALUES ($1, $2)` + postgresDeletePhoneNumber = `DELETE FROM user_phone WHERE user_id = $1 AND phone_number = $2` + + // Billing queries + postgresUpdateBilling = ` + UPDATE "user" + SET stripe_customer_id = $1, stripe_subscription_id = $2, stripe_subscription_status = $3, stripe_subscription_interval = $4, stripe_subscription_paid_until = $5, stripe_subscription_cancel_at = $6 + WHERE user_name = $7 + ` + + // Schema version queries + postgresSelectSchemaVersionExists = `SELECT EXISTS ( + SELECT FROM information_schema.tables + WHERE table_schema = current_schema() + AND table_name = 'schema_version' + )` + postgresSelectSchemaVersion = `SELECT version FROM schema_version WHERE store = 'user'` + postgresInsertSchemaVersion = `INSERT INTO schema_version (store, version) VALUES ('user', $1)` +) + +// NewPostgresStore creates a new PostgreSQL-backed user store +func NewPostgresStore(dsn string) (Store, error) { + db, err := sql.Open("pgx", dsn) + if err != nil { + return nil, err + } + if err := db.Ping(); err != nil { + return nil, err + } + if err := setupPostgresDB(db); err != nil { + return nil, err + } + return &commonStore{ + db: db, + queries: postgresQueries(), + }, nil +} + +func postgresQueries() storeQueries { + return storeQueries{ + // User queries + selectUserByID: postgresSelectUserByID, + selectUserByName: postgresSelectUserByName, + selectUserByToken: postgresSelectUserByToken, + selectUserByStripeID: postgresSelectUserByStripeID, + selectUsernames: postgresSelectUsernames, + selectUserCount: postgresSelectUserCount, + selectUserIDFromUsername: postgresSelectUserIDFromUsername, + insertUser: postgresInsertUser, + updateUserPass: postgresUpdateUserPass, + updateUserRole: postgresUpdateUserRole, + updateUserProvisioned: postgresUpdateUserProvisioned, + updateUserPrefs: postgresUpdateUserPrefs, + updateUserStats: postgresUpdateUserStats, + updateUserStatsResetAll: postgresUpdateUserStatsResetAll, + updateUserTier: postgresUpdateUserTier, + updateUserDeleted: postgresUpdateUserDeleted, + deleteUser: postgresDeleteUser, + deleteUserTier: postgresDeleteUserTier, + deleteUsersMarked: postgresDeleteUsersMarked, + + // Access queries + selectTopicPerms: postgresSelectTopicPerms, + selectUserAllAccess: postgresSelectUserAllAccess, + selectUserAccess: postgresSelectUserAccess, + selectUserReservations: postgresSelectUserReservations, + selectUserReservationsCount: postgresSelectUserReservationsCount, + selectUserReservationsOwner: postgresSelectUserReservationsOwner, + selectUserHasReservation: postgresSelectUserHasReservation, + selectOtherAccessCount: postgresSelectOtherAccessCount, + upsertUserAccess: postgresUpsertUserAccess, + deleteUserAccess: postgresDeleteUserAccess, + deleteUserAccessProvisioned: postgresDeleteUserAccessProvisioned, + deleteTopicAccess: postgresDeleteTopicAccess, + deleteAllAccess: postgresDeleteAllAccess, + + // Token queries + selectToken: postgresSelectToken, + selectTokens: postgresSelectTokens, + selectTokenCount: postgresSelectTokenCount, + selectAllProvisionedTokens: postgresSelectAllProvisionedTokens, + upsertToken: postgresUpsertToken, + updateTokenLabel: postgresUpdateTokenLabel, + updateTokenExpiry: postgresUpdateTokenExpiry, + updateTokenLastAccess: postgresUpdateTokenLastAccess, + deleteToken: postgresDeleteToken, + deleteProvisionedToken: postgresDeleteProvisionedToken, + deleteAllToken: postgresDeleteAllToken, + deleteExpiredTokens: postgresDeleteExpiredTokens, + deleteExcessTokens: postgresDeleteExcessTokens, + + // Tier queries + insertTier: postgresInsertTier, + selectTiers: postgresSelectTiers, + selectTierByCode: postgresSelectTierByCode, + selectTierByPriceID: postgresSelectTierByPriceID, + updateTier: postgresUpdateTier, + deleteTier: postgresDeleteTier, + + // Phone queries + selectPhoneNumbers: postgresSelectPhoneNumbers, + insertPhoneNumber: postgresInsertPhoneNumber, + deletePhoneNumber: postgresDeletePhoneNumber, + + // Billing queries + updateBilling: postgresUpdateBilling, + } +} + +func setupPostgresDB(db *sql.DB) error { + // Check if schema version table exists + var exists bool + if err := db.QueryRow(postgresSelectSchemaVersionExists).Scan(&exists); err != nil { + return err + } + + if !exists { + // New database, create all tables + if _, err := db.Exec(postgresCreateTablesQueries); err != nil { + return err + } + if _, err := db.Exec(postgresInsertSchemaVersion, currentSchemaVersion); err != nil { + return err + } + return nil + } + + // Table exists, check if user store has a row + var schemaVersion int + err := db.QueryRow(postgresSelectSchemaVersion).Scan(&schemaVersion) + if err == sql.ErrNoRows { + // schema_version table exists (e.g. created by webpush) but no user row yet + if _, err := db.Exec(postgresCreateTablesQueries); err != nil { + return err + } + if _, err := db.Exec(postgresInsertSchemaVersion, currentSchemaVersion); err != nil { + return err + } + return nil + } else if err != nil { + return fmt.Errorf("cannot determine schema version: %v", err) + } + + if schemaVersion > currentSchemaVersion { + return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, currentSchemaVersion) + } + + // Note: PostgreSQL migrations will be added when needed + // For now, we only support new installations at the latest schema version + + return nil +} diff --git a/user/store_sqlite.go b/user/store_sqlite.go new file mode 100644 index 00000000..b2c26d37 --- /dev/null +++ b/user/store_sqlite.go @@ -0,0 +1,420 @@ +package user + +import ( + "database/sql" + "fmt" + + _ "github.com/mattn/go-sqlite3" // SQLite driver +) + +// SQLite schema and queries +const ( + sqliteCreateTablesQueries = ` + BEGIN; + CREATE TABLE IF NOT EXISTS tier ( + id TEXT PRIMARY KEY, + code TEXT NOT NULL, + name TEXT NOT NULL, + messages_limit INT NOT NULL, + messages_expiry_duration INT NOT NULL, + emails_limit INT NOT NULL, + calls_limit INT NOT NULL, + reservations_limit INT NOT NULL, + attachment_file_size_limit INT NOT NULL, + attachment_total_size_limit INT NOT NULL, + attachment_expiry_duration INT NOT NULL, + attachment_bandwidth_limit INT NOT NULL, + stripe_monthly_price_id TEXT, + stripe_yearly_price_id TEXT + ); + CREATE UNIQUE INDEX idx_tier_code ON tier (code); + CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id); + CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id); + CREATE TABLE IF NOT EXISTS user ( + id TEXT PRIMARY KEY, + tier_id TEXT, + user TEXT NOT NULL, + pass TEXT NOT NULL, + role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL, + prefs JSON NOT NULL DEFAULT '{}', + sync_topic TEXT NOT NULL, + provisioned INT NOT NULL, + stats_messages INT NOT NULL DEFAULT (0), + stats_emails INT NOT NULL DEFAULT (0), + stats_calls INT NOT NULL DEFAULT (0), + stripe_customer_id TEXT, + stripe_subscription_id TEXT, + stripe_subscription_status TEXT, + stripe_subscription_interval TEXT, + stripe_subscription_paid_until INT, + stripe_subscription_cancel_at INT, + created INT NOT NULL, + deleted INT, + FOREIGN KEY (tier_id) REFERENCES tier (id) + ); + CREATE UNIQUE INDEX idx_user ON user (user); + CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id); + CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id); + CREATE TABLE IF NOT EXISTS user_access ( + user_id TEXT NOT NULL, + topic TEXT NOT NULL, + read INT NOT NULL, + write INT NOT NULL, + owner_user_id INT, + provisioned INT NOT NULL, + PRIMARY KEY (user_id, topic), + FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE, + FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE + ); + CREATE TABLE IF NOT EXISTS user_token ( + user_id TEXT NOT NULL, + token TEXT NOT NULL, + label TEXT NOT NULL, + last_access INT NOT NULL, + last_origin TEXT NOT NULL, + expires INT NOT NULL, + provisioned INT NOT NULL, + PRIMARY KEY (user_id, token), + FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE + ); + CREATE UNIQUE INDEX idx_user_token ON user_token (token); + CREATE TABLE IF NOT EXISTS user_phone ( + user_id TEXT NOT NULL, + phone_number TEXT NOT NULL, + PRIMARY KEY (user_id, phone_number), + FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE + ); + CREATE TABLE IF NOT EXISTS schemaVersion ( + id INT PRIMARY KEY, + version INT NOT NULL + ); + INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created) + VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, UNIXEPOCH()) + ON CONFLICT (id) DO NOTHING; + COMMIT; + ` + + sqliteBuiltinStartupQueries = ` + PRAGMA foreign_keys = ON; + ` + + // User queries + sqliteSelectUserByID = ` + SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id + FROM user u + LEFT JOIN tier t on t.id = u.tier_id + WHERE u.id = ? + ` + sqliteSelectUserByName = ` + SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id + FROM user u + LEFT JOIN tier t on t.id = u.tier_id + WHERE user = ? + ` + sqliteSelectUserByToken = ` + SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id + FROM user u + JOIN user_token tk on u.id = tk.user_id + LEFT JOIN tier t on t.id = u.tier_id + WHERE tk.token = ? AND (tk.expires = 0 OR tk.expires >= ?) + ` + sqliteSelectUserByStripeID = ` + SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id + FROM user u + LEFT JOIN tier t on t.id = u.tier_id + WHERE u.stripe_customer_id = ? + ` + sqliteSelectUsernames = ` + SELECT user + FROM user + ORDER BY + CASE role + WHEN 'admin' THEN 1 + WHEN 'anonymous' THEN 3 + ELSE 2 + END, user + ` + sqliteSelectUserCount = `SELECT COUNT(*) FROM user` + sqliteSelectUserIDFromUsername = `SELECT id FROM user WHERE user = ?` + sqliteInsertUser = `INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created) VALUES (?, ?, ?, ?, ?, ?, ?)` + sqliteUpdateUserPass = `UPDATE user SET pass = ? WHERE user = ?` + sqliteUpdateUserRole = `UPDATE user SET role = ? WHERE user = ?` + sqliteUpdateUserProvisioned = `UPDATE user SET provisioned = ? WHERE user = ?` + sqliteUpdateUserPrefs = `UPDATE user SET prefs = ? WHERE id = ?` + sqliteUpdateUserStats = `UPDATE user SET stats_messages = ?, stats_emails = ?, stats_calls = ? WHERE id = ?` + sqliteUpdateUserStatsResetAll = `UPDATE user SET stats_messages = 0, stats_emails = 0, stats_calls = 0` + sqliteUpdateUserTier = `UPDATE user SET tier_id = (SELECT id FROM tier WHERE code = ?) WHERE user = ?` + sqliteUpdateUserDeleted = `UPDATE user SET deleted = ? WHERE id = ?` + sqliteDeleteUser = `DELETE FROM user WHERE user = ?` + sqliteDeleteUserTier = `UPDATE user SET tier_id = null WHERE user = ?` + sqliteDeleteUsersMarked = `DELETE FROM user WHERE deleted < ?` + + // Access queries + sqliteSelectTopicPerms = ` + SELECT read, write + FROM user_access a + JOIN user u ON u.id = a.user_id + WHERE (u.user = ? OR u.user = ?) AND ? LIKE a.topic ESCAPE '\' + ORDER BY u.user DESC, LENGTH(a.topic) DESC, a.write DESC + ` + sqliteSelectUserAllAccess = ` + SELECT user_id, topic, read, write, provisioned + FROM user_access + ORDER BY LENGTH(topic) DESC, write DESC, read DESC, topic + ` + sqliteSelectUserAccess = ` + SELECT topic, read, write, provisioned + FROM user_access + WHERE user_id = (SELECT id FROM user WHERE user = ?) + ORDER BY LENGTH(topic) DESC, write DESC, read DESC, topic + ` + sqliteSelectUserReservations = ` + SELECT a_user.topic, a_user.read, a_user.write, a_everyone.read AS everyone_read, a_everyone.write AS everyone_write + FROM user_access a_user + LEFT JOIN user_access a_everyone ON a_user.topic = a_everyone.topic AND a_everyone.user_id = (SELECT id FROM user WHERE user = ?) + WHERE a_user.user_id = a_user.owner_user_id + AND a_user.owner_user_id = (SELECT id FROM user WHERE user = ?) + ORDER BY a_user.topic + ` + sqliteSelectUserReservationsCount = ` + SELECT COUNT(*) + FROM user_access + WHERE user_id = owner_user_id + AND owner_user_id = (SELECT id FROM user WHERE user = ?) + ` + sqliteSelectUserReservationsOwner = ` + SELECT owner_user_id + FROM user_access + WHERE topic = ? + AND user_id = owner_user_id + ` + sqliteSelectUserHasReservation = ` + SELECT COUNT(*) + FROM user_access + WHERE user_id = owner_user_id + AND owner_user_id = (SELECT id FROM user WHERE user = ?) + AND topic = ? + ` + sqliteSelectOtherAccessCount = ` + SELECT COUNT(*) + FROM user_access + WHERE (topic = ? OR ? LIKE topic ESCAPE '\') + AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM user WHERE user = ?)) + ` + sqliteUpsertUserAccess = ` + INSERT INTO user_access (user_id, topic, read, write, owner_user_id, provisioned) + VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?, (SELECT IIF(?='',NULL,(SELECT id FROM user WHERE user=?))), ?) + ON CONFLICT (user_id, topic) + DO UPDATE SET read=excluded.read, write=excluded.write, owner_user_id=excluded.owner_user_id, provisioned=excluded.provisioned + ` + sqliteDeleteUserAccess = ` + DELETE FROM user_access + WHERE user_id = (SELECT id FROM user WHERE user = ?) + OR owner_user_id = (SELECT id FROM user WHERE user = ?) + ` + sqliteDeleteUserAccessProvisioned = `DELETE FROM user_access WHERE provisioned = 1` + sqliteDeleteTopicAccess = ` + DELETE FROM user_access + WHERE (user_id = (SELECT id FROM user WHERE user = ?) OR owner_user_id = (SELECT id FROM user WHERE user = ?)) + AND topic = ? + ` + sqliteDeleteAllAccess = `DELETE FROM user_access` + + // Token queries + sqliteSelectToken = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ? AND token = ?` + sqliteSelectTokens = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ?` + sqliteSelectTokenCount = `SELECT COUNT(*) FROM user_token WHERE user_id = ?` + sqliteSelectAllProvisionedTokens = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE provisioned = 1` + sqliteUpsertToken = ` + INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires, provisioned) + VALUES (?, ?, ?, ?, ?, ?, ?) + ON CONFLICT (user_id, token) + DO UPDATE SET label = excluded.label, expires = excluded.expires, provisioned = excluded.provisioned; + ` + sqliteUpdateTokenLabel = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?` + sqliteUpdateTokenExpiry = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?` + sqliteUpdateTokenLastAccess = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?` + sqliteDeleteToken = `DELETE FROM user_token WHERE user_id = ? AND token = ?` + sqliteDeleteProvisionedToken = `DELETE FROM user_token WHERE token = ?` + sqliteDeleteAllToken = `DELETE FROM user_token WHERE user_id = ?` + sqliteDeleteExpiredTokens = `DELETE FROM user_token WHERE expires > 0 AND expires < ?` + sqliteDeleteExcessTokens = ` + DELETE FROM user_token + WHERE user_id = ? + AND (user_id, token) NOT IN ( + SELECT user_id, token + FROM user_token + WHERE user_id = ? + ORDER BY expires DESC + LIMIT ? + ) + ` + + // Tier queries + sqliteInsertTier = ` + INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ` + sqliteUpdateTier = ` + UPDATE tier + SET name = ?, messages_limit = ?, messages_expiry_duration = ?, emails_limit = ?, calls_limit = ?, reservations_limit = ?, attachment_file_size_limit = ?, attachment_total_size_limit = ?, attachment_expiry_duration = ?, attachment_bandwidth_limit = ?, stripe_monthly_price_id = ?, stripe_yearly_price_id = ? + WHERE code = ? + ` + sqliteSelectTiers = ` + SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id + FROM tier + ` + sqliteSelectTierByCode = ` + SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id + FROM tier + WHERE code = ? + ` + sqliteSelectTierByPriceID = ` + SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id + FROM tier + WHERE (stripe_monthly_price_id = ? OR stripe_yearly_price_id = ?) + ` + sqliteDeleteTier = `DELETE FROM tier WHERE code = ?` + + // Phone queries + sqliteSelectPhoneNumbers = `SELECT phone_number FROM user_phone WHERE user_id = ?` + sqliteInsertPhoneNumber = `INSERT INTO user_phone (user_id, phone_number) VALUES (?, ?)` + sqliteDeletePhoneNumber = `DELETE FROM user_phone WHERE user_id = ? AND phone_number = ?` + + // Billing queries + sqliteUpdateBilling = ` + UPDATE user + SET stripe_customer_id = ?, stripe_subscription_id = ?, stripe_subscription_status = ?, stripe_subscription_interval = ?, stripe_subscription_paid_until = ?, stripe_subscription_cancel_at = ? + WHERE user = ? + ` +) + +// NewSQLiteStore creates a new SQLite-backed user store +func NewSQLiteStore(filename, startupQueries string) (Store, error) { + db, err := sql.Open("sqlite3", filename) + if err != nil { + return nil, err + } + if err := setupSQLiteDB(db); err != nil { + return nil, err + } + if err := runSQLiteStartupQueries(db, startupQueries); err != nil { + return nil, err + } + return &commonStore{ + db: db, + queries: sqliteQueries(), + }, nil +} + +func sqliteQueries() storeQueries { + return storeQueries{ + selectUserByID: sqliteSelectUserByID, + selectUserByName: sqliteSelectUserByName, + selectUserByToken: sqliteSelectUserByToken, + selectUserByStripeID: sqliteSelectUserByStripeID, + selectUsernames: sqliteSelectUsernames, + selectUserCount: sqliteSelectUserCount, + selectUserIDFromUsername: sqliteSelectUserIDFromUsername, + insertUser: sqliteInsertUser, + updateUserPass: sqliteUpdateUserPass, + updateUserRole: sqliteUpdateUserRole, + updateUserProvisioned: sqliteUpdateUserProvisioned, + updateUserPrefs: sqliteUpdateUserPrefs, + updateUserStats: sqliteUpdateUserStats, + updateUserStatsResetAll: sqliteUpdateUserStatsResetAll, + updateUserTier: sqliteUpdateUserTier, + updateUserDeleted: sqliteUpdateUserDeleted, + deleteUser: sqliteDeleteUser, + deleteUserTier: sqliteDeleteUserTier, + deleteUsersMarked: sqliteDeleteUsersMarked, + selectTopicPerms: sqliteSelectTopicPerms, + selectUserAllAccess: sqliteSelectUserAllAccess, + selectUserAccess: sqliteSelectUserAccess, + selectUserReservations: sqliteSelectUserReservations, + selectUserReservationsCount: sqliteSelectUserReservationsCount, + selectUserReservationsOwner: sqliteSelectUserReservationsOwner, + selectUserHasReservation: sqliteSelectUserHasReservation, + selectOtherAccessCount: sqliteSelectOtherAccessCount, + upsertUserAccess: sqliteUpsertUserAccess, + deleteUserAccess: sqliteDeleteUserAccess, + deleteUserAccessProvisioned: sqliteDeleteUserAccessProvisioned, + deleteTopicAccess: sqliteDeleteTopicAccess, + deleteAllAccess: sqliteDeleteAllAccess, + selectToken: sqliteSelectToken, + selectTokens: sqliteSelectTokens, + selectTokenCount: sqliteSelectTokenCount, + selectAllProvisionedTokens: sqliteSelectAllProvisionedTokens, + upsertToken: sqliteUpsertToken, + updateTokenLabel: sqliteUpdateTokenLabel, + updateTokenExpiry: sqliteUpdateTokenExpiry, + updateTokenLastAccess: sqliteUpdateTokenLastAccess, + deleteToken: sqliteDeleteToken, + deleteProvisionedToken: sqliteDeleteProvisionedToken, + deleteAllToken: sqliteDeleteAllToken, + deleteExpiredTokens: sqliteDeleteExpiredTokens, + deleteExcessTokens: sqliteDeleteExcessTokens, + insertTier: sqliteInsertTier, + selectTiers: sqliteSelectTiers, + selectTierByCode: sqliteSelectTierByCode, + selectTierByPriceID: sqliteSelectTierByPriceID, + updateTier: sqliteUpdateTier, + deleteTier: sqliteDeleteTier, + selectPhoneNumbers: sqliteSelectPhoneNumbers, + insertPhoneNumber: sqliteInsertPhoneNumber, + deletePhoneNumber: sqliteDeletePhoneNumber, + updateBilling: sqliteUpdateBilling, + } +} + +func setupSQLiteDB(db *sql.DB) error { + rowsSV, err := db.Query(selectSchemaVersionQuery) + if err != nil { + return setupNewSQLiteDB(db) + } + defer rowsSV.Close() + schemaVersion := 0 + if !rowsSV.Next() { + return fmt.Errorf("cannot determine schema version: database file may be corrupt") + } + if err := rowsSV.Scan(&schemaVersion); err != nil { + return err + } + rowsSV.Close() + if schemaVersion == currentSchemaVersion { + return nil + } else if schemaVersion > currentSchemaVersion { + return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, currentSchemaVersion) + } + for i := schemaVersion; i < currentSchemaVersion; i++ { + fn, ok := migrations[i] + if !ok { + return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1) + } else if err := fn(db); err != nil { + return err + } + } + return nil +} + +func setupNewSQLiteDB(db *sql.DB) error { + if _, err := db.Exec(sqliteCreateTablesQueries); err != nil { + return err + } + if _, err := db.Exec(insertSchemaVersion, currentSchemaVersion); err != nil { + return err + } + return nil +} + +func runSQLiteStartupQueries(db *sql.DB, startupQueries string) error { + if _, err := db.Exec(sqliteBuiltinStartupQueries); err != nil { + return err + } + if startupQueries != "" { + if _, err := db.Exec(startupQueries); err != nil { + return err + } + } + return nil +} diff --git a/user/migrations.go b/user/store_sqlite_migrations.go similarity index 99% rename from user/migrations.go rename to user/store_sqlite_migrations.go index a7eb317f..442bcc05 100644 --- a/user/migrations.go +++ b/user/store_sqlite_migrations.go @@ -2,11 +2,12 @@ package user import ( "database/sql" + "heckel.io/ntfy/v2/log" "heckel.io/ntfy/v2/util" ) -// Schema management queries +// SQLite migrations const ( currentSchemaVersion = 6 insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)` diff --git a/user/types.go b/user/types.go index c65cb863..be589643 100644 --- a/user/types.go +++ b/user/types.go @@ -242,6 +242,20 @@ const ( everyoneID = "u_everyone" ) +// Config holds the configuration for the user Manager +type Config struct { + Filename string // Database filename, e.g. "/var/lib/ntfy/user.db" (SQLite) + DatabaseURL string // Database connection string (PostgreSQL) + StartupQueries string // Queries to run on startup, e.g. to create initial users or tiers (SQLite only) + DefaultAccess Permission // Default permission if no ACL matches + ProvisionEnabled bool // Hack: Enable auto-provisioning of users and access grants, disabled for "ntfy user" commands + Users []*User // Predefined users to create on startup + Access map[string][]*Grant // Predefined access grants to create on startup (username -> []*Grant) + Tokens map[string][]*Token // Predefined users to create on startup (username -> []*Token) + QueueWriterInterval time.Duration // Interval for the async queue writer to flush stats and token updates to the database + BcryptCost int // Cost of generated passwords; lowering makes testing faster +} + // Error constants used by the package var ( ErrUnauthenticated = errors.New("unauthenticated") diff --git a/webpush/store_postgres.go b/webpush/store_postgres.go index da69c6e1..fdec1f29 100644 --- a/webpush/store_postgres.go +++ b/webpush/store_postgres.go @@ -25,8 +25,8 @@ const ( PRIMARY KEY (subscription_id, topic) ); CREATE INDEX IF NOT EXISTS idx_webpush_topic ON webpush_subscription_topic (topic); - CREATE TABLE IF NOT EXISTS webpush_schema_version ( - id INT PRIMARY KEY, + CREATE TABLE IF NOT EXISTS schema_version ( + store TEXT PRIMARY KEY, version INT NOT NULL ); ` @@ -65,8 +65,8 @@ const ( // PostgreSQL schema management queries const ( pgCurrentSchemaVersion = 1 - pgInsertSchemaVersion = `INSERT INTO webpush_schema_version VALUES (1, $1)` - pgSelectSchemaVersionQuery = `SELECT version FROM webpush_schema_version WHERE id = 1` + pgInsertSchemaVersion = `INSERT INTO schema_version (store, version) VALUES ('webpush', $1)` + pgSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE store = 'webpush'` ) // NewPostgresStore creates a new PostgreSQL-backed web push store. @@ -102,12 +102,16 @@ func NewPostgresStore(dsn string) (Store, error) { } func setupPostgresDB(db *sql.DB) error { - // If 'webpush_schema_version' table does not exist, this must be a new database + // If 'schema_version' table does not exist or no webpush row, this must be a new database rows, err := db.Query(pgSelectSchemaVersionQuery) if err != nil { return setupNewPostgresDB(db) } - return rows.Close() + defer rows.Close() + if !rows.Next() { + return setupNewPostgresDB(db) + } + return nil } func setupNewPostgresDB(db *sql.DB) error {