diff --git a/cmd/user.go b/cmd/user.go index 808920be..e0041151 100644 --- a/cmd/user.go +++ b/cmd/user.go @@ -378,25 +378,19 @@ func createUserManager(c *cli.Context) (*user.Manager, error) { BcryptCost: user.DefaultUserPasswordBcryptCost, QueueWriterInterval: user.DefaultUserStatsQueueWriterInterval, } - var store user.Store if databaseURL != "" { pool, dbErr := db.OpenPostgres(databaseURL) if dbErr != nil { return nil, dbErr } - store, err = user.NewPostgresStore(pool) + return user.NewPostgresManager(pool, authConfig) } else if authFile != "" { if !util.FileExists(authFile) { return nil, errors.New("auth-file does not exist; please start the server at least once to create it") } - store, err = user.NewSQLiteStore(authFile, authStartupQueries) - } else { - return nil, errors.New("option database-url or auth-file not set; auth is unconfigured for this server") + return user.NewSQLiteManager(authFile, authStartupQueries, authConfig) } - if err != nil { - return nil, err - } - return user.NewManager(store, authConfig) + return nil, errors.New("option database-url or auth-file not set; auth is unconfigured for this server") } func readPasswordAndConfirm(c *cli.Context) (string, error) { diff --git a/server/server.go b/server/server.go index 1b6c8d10..55caaa53 100644 --- a/server/server.go +++ b/server/server.go @@ -235,19 +235,14 @@ func New(conf *Config) (*Server, error) { BcryptCost: conf.AuthBcryptCost, QueueWriterInterval: conf.AuthStatsQueueWriterInterval, } - var store user.Store if pool != nil { - store, err = user.NewPostgresStore(pool) + userManager, err = user.NewPostgresManager(pool, authConfig) } else { - store, err = user.NewSQLiteStore(conf.AuthFile, conf.AuthStartupQueries) + userManager, err = user.NewSQLiteManager(conf.AuthFile, conf.AuthStartupQueries, authConfig) } if err != nil { return nil, err } - userManager, err = user.NewManager(store, authConfig) - if err != nil { - return nil, err - } } var firebaseClient *firebaseClient if conf.FirebaseKeyFile != "" { diff --git a/user/manager.go b/user/manager.go index 34a6d33b..d18bbc8f 100644 --- a/user/manager.go +++ b/user/manager.go @@ -2,15 +2,19 @@ package user import ( + "database/sql" + "encoding/json" "errors" "fmt" "net/netip" "slices" + "strings" "sync" "time" "golang.org/x/crypto/bcrypt" "heckel.io/ntfy/v2/log" + "heckel.io/ntfy/v2/payments" "heckel.io/ntfy/v2/util" ) @@ -44,7 +48,8 @@ var ( // Manager handles user authentication, authorization, and management type Manager struct { config *Config - store Store // Database store + db *sql.DB + queries storeQueries statsQueue map[string]*Stats // "Queue" to asynchronously write user stats to the database (UserID -> Stats) tokenQueue map[string]*TokenUpdate // "Queue" to asynchronously write token access stats to the database (Token ID -> TokenUpdate) mu sync.Mutex @@ -52,26 +57,21 @@ type Manager struct { var _ Auther = (*Manager)(nil) -// NewManager creates a new Manager instance -func NewManager(store Store, config *Config) (*Manager, error) { - // Set defaults - if config.BcryptCost <= 0 { - config.BcryptCost = DefaultUserPasswordBcryptCost +// initManager sets defaults and runs startup tasks common to all backends +func initManager(manager *Manager) error { + if manager.config.BcryptCost <= 0 { + manager.config.BcryptCost = DefaultUserPasswordBcryptCost } - if config.QueueWriterInterval.Seconds() <= 0 { - config.QueueWriterInterval = DefaultUserStatsQueueWriterInterval - } - manager := &Manager{ - store: store, - config: config, - statsQueue: make(map[string]*Stats), - tokenQueue: make(map[string]*TokenUpdate), + if manager.config.QueueWriterInterval.Seconds() <= 0 { + manager.config.QueueWriterInterval = DefaultUserStatsQueueWriterInterval } + manager.statsQueue = make(map[string]*Stats) + manager.tokenQueue = make(map[string]*TokenUpdate) if err := manager.maybeProvisionUsersAccessAndTokens(); err != nil { - return nil, err + return err } - go manager.asyncQueueWriter(config.QueueWriterInterval) - return manager, nil + go manager.asyncQueueWriter(manager.config.QueueWriterInterval) + return nil } // Authenticate checks username and password and returns a User if correct, and the user has not been @@ -81,7 +81,7 @@ func (a *Manager) Authenticate(username, password string) (*User, error) { if username == Everyone { return nil, ErrUnauthenticated } - user, err := a.store.User(username) + user, err := a.User(username) if err != nil { log.Tag(tag).Field("user_name", username).Err(err).Trace("Authentication of user failed (1)") bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks")) @@ -103,7 +103,7 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) { if len(token) != tokenLength { return nil, ErrUnauthenticated } - user, err := a.store.UserByToken(token) + user, err := a.UserByToken(token) if err != nil { log.Tag(tag).Field("token", token).Err(err).Trace("Authentication of token failed") return nil, ErrUnauthenticated @@ -116,17 +116,7 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) { // after a fixed duration unless ChangeToken is called. This function also prunes tokens for the // given user, if there are too many of them. func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr, provisioned bool) (*Token, error) { - return a.store.CreateToken(userID, GenerateToken(), label, time.Now(), origin, expires, tokenMaxCount, provisioned) -} - -// Tokens returns all existing tokens for the user with the given user ID -func (a *Manager) Tokens(userID string) ([]*Token, error) { - return a.store.Tokens(userID) -} - -// Token returns a specific token for a user -func (a *Manager) Token(userID, token string) (*Token, error) { - return a.store.Token(userID, token) + return a.createToken(userID, GenerateToken(), label, time.Now(), origin, expires, tokenMaxCount, provisioned) } // ChangeToken updates a token's label and/or expiry date @@ -137,7 +127,7 @@ func (a *Manager) ChangeToken(userID, token string, label *string, expires *time if err := a.canChangeToken(userID, token); err != nil { return nil, err } - t, err := a.store.Token(userID, token) + t, err := a.Token(userID, token) if err != nil { return nil, err } @@ -147,7 +137,7 @@ func (a *Manager) ChangeToken(userID, token string, label *string, expires *time if expires != nil { t.Expires = *expires } - if err := a.store.ChangeToken(userID, token, t.Label, t.Expires); err != nil { + if err := a.changeToken(userID, token, t.Label, t.Expires); err != nil { return nil, err } return t, nil @@ -158,7 +148,7 @@ func (a *Manager) RemoveToken(userID, token string) error { if err := a.canChangeToken(userID, token); err != nil { return err } - return a.store.RemoveToken(userID, token) + return a.removeToken(userID, token) } // canChangeToken checks if the token can be changed. If the token is provisioned, it cannot be changed. @@ -172,47 +162,25 @@ func (a *Manager) canChangeToken(userID, token string) error { return nil } -// RemoveExpiredTokens deletes all expired tokens from the database -func (a *Manager) RemoveExpiredTokens() error { - return a.store.RemoveExpiredTokens() -} - -// PhoneNumbers returns all phone numbers for the user with the given user ID -func (a *Manager) PhoneNumbers(userID string) ([]string, error) { - return a.store.PhoneNumbers(userID) -} - -// AddPhoneNumber adds a phone number to the user with the given user ID -func (a *Manager) AddPhoneNumber(userID string, phoneNumber string) error { - return a.store.AddPhoneNumber(userID, phoneNumber) -} - -// RemovePhoneNumber deletes a phone number from the user with the given user ID -func (a *Manager) RemovePhoneNumber(userID string, phoneNumber string) error { - return a.store.RemovePhoneNumber(userID, phoneNumber) -} - -// RemoveDeletedUsers deletes all users that have been marked deleted -func (a *Manager) RemoveDeletedUsers() error { - return a.store.RemoveDeletedUsers() -} - -// ChangeSettings persists the user settings -func (a *Manager) ChangeSettings(userID string, prefs *Prefs) error { - return a.store.ChangeSettings(userID, prefs) -} - // ResetStats resets all user stats in the user database. This touches all users. func (a *Manager) ResetStats() error { a.mu.Lock() // Includes database query to avoid races! defer a.mu.Unlock() - if err := a.store.ResetStats(); err != nil { + if err := a.resetStats(); err != nil { return err } a.statsQueue = make(map[string]*Stats) return nil } +// resetStats resets all user stats in the user database +func (a *Manager) resetStats() error { + if _, err := a.db.Exec(a.queries.updateUserStatsResetAll); err != nil { + return err + } + return nil +} + // EnqueueUserStats adds the user to a queue which writes out user stats (messages, emails, ..) in // batches at a regular interval func (a *Manager) EnqueueUserStats(userID string, stats *Stats) { @@ -264,7 +232,7 @@ func (a *Manager) writeUserStatsQueue() error { }). Trace("Updating stats for user %s", userID) } - return a.store.UpdateStats(statsQueue) + return a.UpdateStats(statsQueue) } func (a *Manager) writeTokenUpdateQueue() error { @@ -282,7 +250,7 @@ func (a *Manager) writeTokenUpdateQueue() error { for tokenID, update := range tokenQueue { log.Tag(tag).Trace("Updating token %s with last access time %v", tokenID, update.LastAccess.Unix()) } - return a.store.UpdateTokenLastAccess(tokenQueue) + return a.UpdateTokenLastAccess(tokenQueue) } // Authorize returns nil if the given user has access to the given topic using the desired @@ -296,7 +264,7 @@ func (a *Manager) Authorize(user *User, topic string, perm Permission) error { username = user.Name } // Select the read/write permissions for this user/topic combo. - read, write, found, err := a.store.AuthorizeTopicAccess(username, topic) + read, write, found, err := a.AuthorizeTopicAccess(username, topic) if err != nil { return err } @@ -337,7 +305,7 @@ func (a *Manager) addUser(username, password string, role Role, hashed, provisio return err } } - return a.store.AddUser(username, hash, role, provisioned) + return a.insertUser(username, hash, role, provisioned) } // RemoveUser deletes the user with the given username. The function returns nil on success, even @@ -346,7 +314,7 @@ func (a *Manager) RemoveUser(username string) error { if err := a.CanChangeUser(username); err != nil { return err } - return a.store.RemoveUser(username) + return a.removeUser(username) } // MarkUserRemoved sets the deleted flag on the user, and deletes all access tokens. This prevents @@ -355,64 +323,7 @@ func (a *Manager) MarkUserRemoved(user *User) error { if !AllowedUsername(user.Name) { return ErrInvalidArgument } - return a.store.MarkUserRemoved(user.ID, user.Name) -} - -// Users returns a list of users. It always also returns the Everyone user ("*"). -func (a *Manager) Users() ([]*User, error) { - return a.store.Users() -} - -// UsersCount returns the number of users in the database -func (a *Manager) UsersCount() (int64, error) { - return a.store.UsersCount() -} - -// User returns the user with the given username if it exists, or ErrUserNotFound otherwise. -// You may also pass Everyone to retrieve the anonymous user and its Grant list. -func (a *Manager) User(username string) (*User, error) { - return a.store.User(username) -} - -// UserByID returns the user with the given ID if it exists, or ErrUserNotFound otherwise -func (a *Manager) UserByID(id string) (*User, error) { - return a.store.UserByID(id) -} - -// UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise. -func (a *Manager) UserByStripeCustomer(stripeCustomerID string) (*User, error) { - return a.store.UserByStripeCustomer(stripeCustomerID) -} - -// AllGrants returns all user-specific access control entries, mapped to their respective user IDs -func (a *Manager) AllGrants() (map[string][]Grant, error) { - return a.store.AllGrants() -} - -// Grants returns all user-specific access control entries -func (a *Manager) Grants(username string) ([]Grant, error) { - return a.store.Grants(username) -} - -// Reservations returns all user-owned topics, and the associated everyone-access -func (a *Manager) Reservations(username string) ([]Reservation, error) { - return a.store.Reservations(username) -} - -// HasReservation returns true if the given topic access is owned by the user -func (a *Manager) HasReservation(username, topic string) (bool, error) { - return a.store.HasReservation(username, topic) -} - -// ReservationsCount returns the number of reservations owned by this user -func (a *Manager) ReservationsCount(username string) (int64, error) { - return a.store.ReservationsCount(username) -} - -// ReservationOwner returns user ID of the user that owns this topic, or an -// empty string if it's not owned by anyone -func (a *Manager) ReservationOwner(topic string) (string, error) { - return a.store.ReservationOwner(topic) + return a.markUserRemoved(user.ID, user.Name) } // ChangePassword changes a user's password @@ -433,7 +344,7 @@ func (a *Manager) ChangePassword(username, password string, hashed bool) error { return err } } - return a.store.ChangePassword(username, hash) + return a.changePasswordHash(username, hash) } // CanChangeUser checks if the user with the given username can be changed. @@ -454,7 +365,7 @@ func (a *Manager) ChangeRole(username string, role Role) error { if err := a.CanChangeUser(username); err != nil { return err } - return a.store.ChangeRole(username, role) + return a.changeRole(username, role) } // ChangeTier changes a user's tier using the tier code. This function does not delete reservations, messages, @@ -469,7 +380,7 @@ func (a *Manager) ChangeTier(username, tier string) error { } else if err := a.checkReservationsLimit(username, t.ReservationLimit); err != nil { return err } - return a.store.ChangeTier(username, tier) + return a.changeTierCode(username, tier) } // ResetTier removes the tier from the given user @@ -479,7 +390,7 @@ func (a *Manager) ResetTier(username string) error { } else if err := a.checkReservationsLimit(username, 0); err != nil { return err } - return a.store.ResetTier(username) + return a.resetTierCode(username) } func (a *Manager) checkReservationsLimit(username string, reservationsLimit int64) error { @@ -504,7 +415,7 @@ func (a *Manager) AllowReservation(username string, topic string) error { if (!AllowedUsername(username) && username != Everyone) || !AllowedTopic(topic) { return ErrInvalidArgument } - otherCount, err := a.store.OtherAccessCount(username, topic) + otherCount, err := a.OtherAccessCount(username, topic) if err != nil { return err } @@ -527,7 +438,7 @@ func (a *Manager) allowAccess(username string, topicPattern string, permission P } else if !AllowedTopicPattern(topicPattern) { return ErrInvalidArgument } - return a.store.AllowAccess(username, topicPattern, permission.IsRead(), permission.IsWrite(), "", provisioned) + return a.allowAccessTx(a.db, username, topicPattern, permission.IsRead(), permission.IsWrite(), "", provisioned) } // ResetAccess removes an access control list entry for a specific username/topic, or (if topic is @@ -542,31 +453,13 @@ func (a *Manager) resetAccess(username string, topicPattern string) error { } else if !AllowedTopicPattern(topicPattern) && topicPattern != "" { return ErrInvalidArgument } - return a.store.ResetAccess(username, topicPattern) -} - -// AddReservation creates two access control entries for the given topic: one with full read/write access for the -// given user, and one for Everyone with the permission passed as everyone. The user also owns the entries, and -// can modify or delete them. -func (a *Manager) AddReservation(username string, topic string, everyone Permission) error { - if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) { - return ErrInvalidArgument + if username == "" && topicPattern == "" { + _, err := a.db.Exec(a.queries.deleteAllAccess) + return err + } else if topicPattern == "" { + return a.resetUserAccessTx(a.db, username) } - return a.store.AddReservation(username, topic, everyone) -} - -// RemoveReservations deletes the access control entries associated with the given username/topic, as -// well as all entries with Everyone/topic. This is the counterpart for AddReservation. -func (a *Manager) RemoveReservations(username string, topics ...string) error { - if !AllowedUsername(username) || username == Everyone || len(topics) == 0 { - return ErrInvalidArgument - } - for _, topic := range topics { - if !AllowedTopic(topic) { - return ErrInvalidArgument - } - } - return a.store.RemoveReservations(username, topics...) + return a.resetTopicAccessTx(a.db, username, topicPattern) } // DefaultAccess returns the default read/write access if no access control entry matches @@ -574,44 +467,9 @@ func (a *Manager) DefaultAccess() Permission { return a.config.DefaultAccess } -// AddTier creates a new tier in the database -func (a *Manager) AddTier(tier *Tier) error { - return a.store.AddTier(tier) -} - -// UpdateTier updates a tier's properties in the database -func (a *Manager) UpdateTier(tier *Tier) error { - return a.store.UpdateTier(tier) -} - -// RemoveTier deletes the tier with the given code -func (a *Manager) RemoveTier(code string) error { - return a.store.RemoveTier(code) -} - -// ChangeBilling updates a user's billing fields, namely the Stripe customer ID, and subscription information -func (a *Manager) ChangeBilling(username string, billing *Billing) error { - return a.store.ChangeBilling(username, billing) -} - -// Tiers returns a list of all Tier structs -func (a *Manager) Tiers() ([]*Tier, error) { - return a.store.Tiers() -} - -// Tier returns a Tier based on the code, or ErrTierNotFound if it does not exist -func (a *Manager) Tier(code string) (*Tier, error) { - return a.store.Tier(code) -} - -// TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist -func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) { - return a.store.TierByStripePrice(priceID) -} - // Close closes the underlying database func (a *Manager) Close() error { - return a.store.Close() + return a.db.Close() } // maybeProvisionUsersAccessAndTokens provisions users, access control entries, and tokens based on the config. @@ -646,7 +504,7 @@ func (a *Manager) maybeProvisionUsers(provisionUsernames []string, existingUsers if user.Name == Everyone { continue } else if user.Provisioned && !util.Contains(provisionUsernames, user.Name) { - if err := a.store.RemoveUser(user.Name); err != nil { + if err := a.removeUser(user.Name); err != nil { return fmt.Errorf("failed to remove provisioned user %s: %v", user.Name, err) } } @@ -665,17 +523,17 @@ func (a *Manager) maybeProvisionUsers(provisionUsernames []string, existingUsers } } else { if !existingUser.Provisioned { - if err := a.store.ChangeProvisioned(user.Name, true); err != nil { + if err := a.ChangeProvisioned(user.Name, true); err != nil { return fmt.Errorf("failed to change provisioned status for user %s: %v", user.Name, err) } } if existingUser.Hash != user.Hash { - if err := a.store.ChangePassword(user.Name, user.Hash); err != nil { + if err := a.changePasswordHash(user.Name, user.Hash); err != nil { return fmt.Errorf("failed to change password for provisioned user %s: %v", user.Name, err) } } if existingUser.Role != user.Role { - if err := a.store.ChangeRole(user.Name, user.Role); err != nil { + if err := a.changeRole(user.Name, user.Role); err != nil { return fmt.Errorf("failed to change role for provisioned user %s: %v", user.Name, err) } } @@ -690,7 +548,7 @@ func (a *Manager) maybeProvisionUsers(provisionUsernames []string, existingUsers // access time) or do not have dependent resources (such as grants or tokens). func (a *Manager) maybeProvisionGrants() error { // Remove all provisioned grants - if err := a.store.ResetAllProvisionedAccess(); err != nil { + if err := a.ResetAllProvisionedAccess(); err != nil { return err } // (Re-)add provisioned grants @@ -717,7 +575,7 @@ func (a *Manager) maybeProvisionGrants() error { func (a *Manager) maybeProvisionTokens(provisionUsernames []string) error { // Remove tokens that are provisioned, but not in the config anymore - existingTokens, err := a.store.AllProvisionedTokens() + existingTokens, err := a.AllProvisionedTokens() if err != nil { return fmt.Errorf("failed to retrieve existing provisioned tokens: %v", err) } @@ -729,7 +587,7 @@ func (a *Manager) maybeProvisionTokens(provisionUsernames []string) error { } for _, existingToken := range existingTokens { if !slices.Contains(provisionTokens, existingToken.Value) { - if err := a.store.RemoveProvisionedToken(existingToken.Value); err != nil { + if err := a.RemoveProvisionedToken(existingToken.Value); err != nil { return fmt.Errorf("failed to remove provisioned token %s: %v", existingToken.Value, err) } } @@ -739,15 +597,900 @@ func (a *Manager) maybeProvisionTokens(provisionUsernames []string) error { if !slices.Contains(provisionUsernames, username) && username != Everyone { return fmt.Errorf("user %s is not a provisioned user, refusing to add tokens", username) } - userID, err := a.store.UserIDByUsername(username) + userID, err := a.UserIDByUsername(username) if err != nil { return fmt.Errorf("failed to find provisioned user %s for provisioned tokens: %v", username, err) } for _, token := range tokens { - if _, err := a.store.CreateToken(userID, token.Value, token.Label, time.Unix(0, 0), netip.IPv4Unspecified(), time.Unix(0, 0), 0, true); err != nil { + if _, err := a.createToken(userID, token.Value, token.Label, time.Unix(0, 0), netip.IPv4Unspecified(), time.Unix(0, 0), 0, true); err != nil { return err } } } return nil } + +// User returns the user with the given username if it exists, or ErrUserNotFound otherwise +func (a *Manager) User(username string) (*User, error) { + rows, err := a.db.Query(a.queries.selectUserByName, username) + if err != nil { + return nil, err + } + return a.readUser(rows) +} + +// UserByID returns the user with the given ID if it exists, or ErrUserNotFound otherwise +func (a *Manager) UserByID(id string) (*User, error) { + rows, err := a.db.Query(a.queries.selectUserByID, id) + if err != nil { + return nil, err + } + return a.readUser(rows) +} + +// UserByToken returns the user with the given token if it exists and is not expired, or ErrUserNotFound otherwise +func (a *Manager) UserByToken(token string) (*User, error) { + rows, err := a.db.Query(a.queries.selectUserByToken, token, time.Now().Unix()) + if err != nil { + return nil, err + } + return a.readUser(rows) +} + +// UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise +func (a *Manager) UserByStripeCustomer(customerID string) (*User, error) { + rows, err := a.db.Query(a.queries.selectUserByStripeCustomerID, customerID) + if err != nil { + return nil, err + } + return a.readUser(rows) +} + +// Users returns a list of users +func (a *Manager) Users() ([]*User, error) { + rows, err := a.db.Query(a.queries.selectUsernames) + if err != nil { + return nil, err + } + defer rows.Close() + usernames := make([]string, 0) + for rows.Next() { + var username string + if err := rows.Scan(&username); err != nil { + return nil, err + } else if err := rows.Err(); err != nil { + return nil, err + } + usernames = append(usernames, username) + } + rows.Close() + users := make([]*User, 0) + for _, username := range usernames { + user, err := a.User(username) + if err != nil { + return nil, err + } + users = append(users, user) + } + return users, nil +} + +// UsersCount returns the number of users in the database +func (a *Manager) UsersCount() (int64, error) { + rows, err := a.db.Query(a.queries.selectUserCount) + if err != nil { + return 0, err + } + defer rows.Close() + if !rows.Next() { + return 0, errNoRows + } + var count int64 + if err := rows.Scan(&count); err != nil { + return 0, err + } + return count, nil +} + +// insertUser adds a user with the given username, password hash and role to the database +func (a *Manager) insertUser(username, hash string, role Role, provisioned bool) error { + if !AllowedUsername(username) || !AllowedRole(role) { + return ErrInvalidArgument + } + userID := util.RandomStringPrefix(userIDPrefix, userIDLength) + syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength) + now := time.Now().Unix() + if _, err := a.db.Exec(a.queries.insertUser, userID, username, hash, string(role), syncTopic, provisioned, now); err != nil { + if isUniqueConstraintError(err) { + return ErrUserExists + } + return err + } + return nil +} + +// removeUser deletes the user with the given username +func (a *Manager) removeUser(username string) error { + if !AllowedUsername(username) { + return ErrInvalidArgument + } + // Rows in user_access, user_token, etc. are deleted via foreign keys + if _, err := a.db.Exec(a.queries.deleteUser, username); err != nil { + return err + } + return nil +} + +// markUserRemoved sets the deleted flag on the user, and deletes all access tokens +func (a *Manager) markUserRemoved(userID, username string) error { + tx, err := a.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + if err := a.resetUserAccessTx(tx, username); err != nil { + return err + } + if _, err := tx.Exec(a.queries.deleteAllToken, userID); err != nil { + return err + } + deletedTime := time.Now().Add(userHardDeleteAfterDuration).Unix() + if _, err := tx.Exec(a.queries.updateUserDeleted, deletedTime, userID); err != nil { + return err + } + return tx.Commit() +} + +// RemoveDeletedUsers deletes all users that have been marked deleted +func (a *Manager) RemoveDeletedUsers() error { + if _, err := a.db.Exec(a.queries.deleteUsersMarked, time.Now().Unix()); err != nil { + return err + } + return nil +} + +// changePasswordHash changes a user's password hash in the database +func (a *Manager) changePasswordHash(username, hash string) error { + if _, err := a.db.Exec(a.queries.updateUserPass, hash, username); err != nil { + return err + } + return nil +} + +// changeRole changes a user's role +func (a *Manager) changeRole(username string, role Role) error { + if !AllowedUsername(username) || !AllowedRole(role) { + return ErrInvalidArgument + } + tx, err := a.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + if _, err := tx.Exec(a.queries.updateUserRole, string(role), username); err != nil { + return err + } + // If changing to admin, remove all access entries + if role == RoleAdmin { + if err := a.resetUserAccessTx(tx, username); err != nil { + return err + } + } + return tx.Commit() +} + +// ChangeProvisioned changes the provisioned status of a user +func (a *Manager) ChangeProvisioned(username string, provisioned bool) error { + if _, err := a.db.Exec(a.queries.updateUserProvisioned, provisioned, username); err != nil { + return err + } + return nil +} + +// ChangeSettings persists the user settings +func (a *Manager) ChangeSettings(userID string, prefs *Prefs) error { + b, err := json.Marshal(prefs) + if err != nil { + return err + } + if _, err := a.db.Exec(a.queries.updateUserPrefs, string(b), userID); err != nil { + return err + } + return nil +} + +// changeTierCode changes a user's tier using the tier code in the database +func (a *Manager) changeTierCode(username, tierCode string) error { + if _, err := a.db.Exec(a.queries.updateUserTier, tierCode, username); err != nil { + return err + } + return nil +} + +// resetTierCode removes the tier from the given user in the database +func (a *Manager) resetTierCode(username string) error { + if !AllowedUsername(username) && username != Everyone && username != "" { + return ErrInvalidArgument + } + _, err := a.db.Exec(a.queries.deleteUserTier, username) + return err +} + +// UpdateStats updates statistics for one or more users in a single transaction +func (a *Manager) UpdateStats(stats map[string]*Stats) error { + tx, err := a.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + for userID, update := range stats { + if _, err := tx.Exec(a.queries.updateUserStats, update.Messages, update.Emails, update.Calls, userID); err != nil { + return err + } + } + return tx.Commit() +} + +func (a *Manager) readUser(rows *sql.Rows) (*User, error) { + defer rows.Close() + var id, username, hash, role, prefs, syncTopic string + var provisioned bool + var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripeSubscriptionInterval, stripeMonthlyPriceID, stripeYearlyPriceID, tierID, tierCode, tierName sql.NullString + var messages, emails, calls int64 + var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64 + if !rows.Next() { + return nil, ErrUserNotFound + } + if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &provisioned, &messages, &emails, &calls, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionInterval, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil { + return nil, err + } else if err := rows.Err(); err != nil { + return nil, err + } + user := &User{ + ID: id, + Name: username, + Hash: hash, + Role: Role(role), + Prefs: &Prefs{}, + SyncTopic: syncTopic, + Provisioned: provisioned, + Stats: &Stats{ + Messages: messages, + Emails: emails, + Calls: calls, + }, + Billing: &Billing{ + StripeCustomerID: stripeCustomerID.String, // May be empty + StripeSubscriptionID: stripeSubscriptionID.String, // May be empty + StripeSubscriptionStatus: payments.SubscriptionStatus(stripeSubscriptionStatus.String), // May be empty + StripeSubscriptionInterval: payments.PriceRecurringInterval(stripeSubscriptionInterval.String), // May be empty + StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0), // May be zero + StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0), // May be zero + }, + Deleted: deleted.Valid, + } + if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil { + return nil, err + } + if tierCode.Valid { + // See readTier() when this is changed! + user.Tier = &Tier{ + ID: tierID.String, + Code: tierCode.String, + Name: tierName.String, + MessageLimit: messagesLimit.Int64, + MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second, + EmailLimit: emailsLimit.Int64, + CallLimit: callsLimit.Int64, + ReservationLimit: reservationsLimit.Int64, + AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64, + AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64, + AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second, + AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64, + StripeMonthlyPriceID: stripeMonthlyPriceID.String, // May be empty + StripeYearlyPriceID: stripeYearlyPriceID.String, // May be empty + } + } + return user, nil +} + +// createToken creates a new token and prunes excess tokens if the count exceeds maxTokenCount. +// If maxTokenCount is 0, no pruning is performed. +func (a *Manager) createToken(userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, maxTokenCount int, provisioned bool) (*Token, error) { + tx, err := a.db.Begin() + if err != nil { + return nil, err + } + defer tx.Rollback() + if _, err := tx.Exec(a.queries.upsertToken, userID, token, label, lastAccess.Unix(), lastOrigin.String(), expires.Unix(), provisioned); err != nil { + return nil, err + } + if maxTokenCount > 0 { + var tokenCount int + if err := tx.QueryRow(a.queries.selectTokenCount, userID).Scan(&tokenCount); err != nil { + return nil, err + } + if tokenCount > maxTokenCount { + if _, err := tx.Exec(a.queries.deleteExcessTokens, userID, userID, maxTokenCount); err != nil { + return nil, err + } + } + } + if err := tx.Commit(); err != nil { + return nil, err + } + return &Token{ + Value: token, + Label: label, + LastAccess: lastAccess, + LastOrigin: lastOrigin, + Expires: expires, + Provisioned: provisioned, + }, nil +} + +// Token returns a specific token for a user +func (a *Manager) Token(userID, token string) (*Token, error) { + rows, err := a.db.Query(a.queries.selectToken, userID, token) + if err != nil { + return nil, err + } + defer rows.Close() + return a.readToken(rows) +} + +// Tokens returns all existing tokens for the user with the given user ID +func (a *Manager) Tokens(userID string) ([]*Token, error) { + rows, err := a.db.Query(a.queries.selectTokens, userID) + if err != nil { + return nil, err + } + defer rows.Close() + tokens := make([]*Token, 0) + for { + token, err := a.readToken(rows) + if errors.Is(err, ErrTokenNotFound) { + break + } else if err != nil { + return nil, err + } + tokens = append(tokens, token) + } + return tokens, nil +} + +// AllProvisionedTokens returns all provisioned tokens +func (a *Manager) AllProvisionedTokens() ([]*Token, error) { + rows, err := a.db.Query(a.queries.selectAllProvisionedTokens) + if err != nil { + return nil, err + } + defer rows.Close() + tokens := make([]*Token, 0) + for { + token, err := a.readToken(rows) + if errors.Is(err, ErrTokenNotFound) { + break + } else if err != nil { + return nil, err + } + tokens = append(tokens, token) + } + return tokens, nil +} + +// changeToken updates a token's label and expiry time +func (a *Manager) changeToken(userID, token, label string, expires time.Time) error { + if _, err := a.db.Exec(a.queries.updateToken, label, expires.Unix(), userID, token); err != nil { + return err + } + return nil +} + +// UpdateTokenLastAccess updates the last access time and origin for one or more tokens in a single transaction +func (a *Manager) UpdateTokenLastAccess(updates map[string]*TokenUpdate) error { + tx, err := a.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + for token, update := range updates { + if _, err := tx.Exec(a.queries.updateTokenLastAccess, update.LastAccess.Unix(), update.LastOrigin.String(), token); err != nil { + return err + } + } + return tx.Commit() +} + +// removeToken deletes the token +func (a *Manager) removeToken(userID, token string) error { + if token == "" { + return errNoTokenProvided + } + if _, err := a.db.Exec(a.queries.deleteToken, userID, token); err != nil { + return err + } + return nil +} + +// RemoveProvisionedToken deletes a provisioned token by value, regardless of user +func (a *Manager) RemoveProvisionedToken(token string) error { + if token == "" { + return errNoTokenProvided + } + if _, err := a.db.Exec(a.queries.deleteProvisionedToken, token); err != nil { + return err + } + return nil +} + +// RemoveExpiredTokens deletes all expired tokens from the database +func (a *Manager) RemoveExpiredTokens() error { + if _, err := a.db.Exec(a.queries.deleteExpiredTokens, time.Now().Unix()); err != nil { + return err + } + return nil +} + +func (a *Manager) readToken(rows *sql.Rows) (*Token, error) { + var token, label, lastOrigin string + var lastAccess, expires int64 + var provisioned bool + if !rows.Next() { + return nil, ErrTokenNotFound + } + if err := rows.Scan(&token, &label, &lastAccess, &lastOrigin, &expires, &provisioned); err != nil { + return nil, err + } else if err := rows.Err(); err != nil { + return nil, err + } + lastOriginIP, err := netip.ParseAddr(lastOrigin) + if err != nil { + lastOriginIP = netip.IPv4Unspecified() + } + return &Token{ + Value: token, + Label: label, + LastAccess: time.Unix(lastAccess, 0), + LastOrigin: lastOriginIP, + Expires: time.Unix(expires, 0), + Provisioned: provisioned, + }, nil +} + +// AuthorizeTopicAccess returns the read/write permissions for the given username and topic. +// The found return value indicates whether an ACL entry was found at all. +// +// - The query may return two rows (one for everyone, and one for the user), but prioritizes the user. +// - Furthermore, the query prioritizes more specific permissions (longer!) over more generic ones, e.g. "test*" > "*" +// - It also prioritizes write permissions over read permissions +func (a *Manager) AuthorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error) { + rows, err := a.db.Query(a.queries.selectTopicPerms, Everyone, usernameOrEveryone, topic) + if err != nil { + return false, false, false, err + } + defer rows.Close() + if !rows.Next() { + return false, false, false, nil + } + if err := rows.Scan(&read, &write); err != nil { + return false, false, false, err + } else if err := rows.Err(); err != nil { + return false, false, false, err + } + return read, write, true, nil +} + +// AllGrants returns all user-specific access control entries, mapped to their respective user IDs +func (a *Manager) AllGrants() (map[string][]Grant, error) { + rows, err := a.db.Query(a.queries.selectUserAllAccess) + if err != nil { + return nil, err + } + defer rows.Close() + grants := make(map[string][]Grant, 0) + for rows.Next() { + var userID, topic string + var read, write, provisioned bool + if err := rows.Scan(&userID, &topic, &read, &write, &provisioned); err != nil { + return nil, err + } else if err := rows.Err(); err != nil { + return nil, err + } + if _, ok := grants[userID]; !ok { + grants[userID] = make([]Grant, 0) + } + grants[userID] = append(grants[userID], Grant{ + TopicPattern: fromSQLWildcard(topic), + Permission: NewPermission(read, write), + Provisioned: provisioned, + }) + } + return grants, nil +} + +// Grants returns all user-specific access control entries +func (a *Manager) Grants(username string) ([]Grant, error) { + rows, err := a.db.Query(a.queries.selectUserAccess, username) + if err != nil { + return nil, err + } + defer rows.Close() + grants := make([]Grant, 0) + for rows.Next() { + var topic string + var read, write, provisioned bool + if err := rows.Scan(&topic, &read, &write, &provisioned); err != nil { + return nil, err + } else if err := rows.Err(); err != nil { + return nil, err + } + grants = append(grants, Grant{ + TopicPattern: fromSQLWildcard(topic), + Permission: NewPermission(read, write), + Provisioned: provisioned, + }) + } + return grants, nil +} + +func (a *Manager) allowAccessTx(tx execer, username, topicPattern string, read, write bool, ownerUsername string, provisioned bool) error { + if !AllowedUsername(username) && username != Everyone { + return ErrInvalidArgument + } else if !AllowedTopicPattern(topicPattern) { + return ErrInvalidArgument + } + _, err := tx.Exec(a.queries.upsertUserAccess, username, toSQLWildcard(topicPattern), read, write, ownerUsername, ownerUsername, provisioned) + return err +} + +func (a *Manager) resetUserAccessTx(tx execer, username string) error { + if !AllowedUsername(username) && username != Everyone { + return ErrInvalidArgument + } + _, err := tx.Exec(a.queries.deleteUserAccess, username, username) + return err +} + +func (a *Manager) resetTopicAccessTx(tx execer, username, topicPattern string) error { + if !AllowedUsername(username) && username != Everyone && username != "" { + return ErrInvalidArgument + } else if !AllowedTopicPattern(topicPattern) && topicPattern != "" { + return ErrInvalidArgument + } + _, err := tx.Exec(a.queries.deleteTopicAccess, username, username, toSQLWildcard(topicPattern)) + return err +} + +// ResetAllProvisionedAccess removes all provisioned access control entries +func (a *Manager) ResetAllProvisionedAccess() error { + if _, err := a.db.Exec(a.queries.deleteUserAccessProvisioned); err != nil { + return err + } + return nil +} + +// AddReservation creates two access control entries for the given topic: one with full read/write +// access for the given user, and one for Everyone with the given permission. Both entries are +// created atomically in a single transaction. +func (a *Manager) AddReservation(username string, topic string, everyone Permission) error { + if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) { + return ErrInvalidArgument + } + tx, err := a.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + if err := a.allowAccessTx(tx, username, topic, true, true, username, false); err != nil { + return err + } + if err := a.allowAccessTx(tx, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username, false); err != nil { + return err + } + return tx.Commit() +} + +// RemoveReservations deletes the access control entries associated with the given username/topic, +// as well as all entries with Everyone/topic. All deletions are performed atomically in a single +// transaction. +func (a *Manager) RemoveReservations(username string, topics ...string) error { + if !AllowedUsername(username) || username == Everyone || len(topics) == 0 { + return ErrInvalidArgument + } + for _, topic := range topics { + if !AllowedTopic(topic) { + return ErrInvalidArgument + } + } + tx, err := a.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + for _, topic := range topics { + if err := a.resetTopicAccessTx(tx, username, topic); err != nil { + return err + } + if err := a.resetTopicAccessTx(tx, Everyone, topic); err != nil { + return err + } + } + return tx.Commit() +} + +// Reservations returns all user-owned topics, and the associated everyone-access +func (a *Manager) Reservations(username string) ([]Reservation, error) { + rows, err := a.db.Query(a.queries.selectUserReservations, Everyone, username) + if err != nil { + return nil, err + } + defer rows.Close() + reservations := make([]Reservation, 0) + for rows.Next() { + var topic string + var ownerRead, ownerWrite bool + var everyoneRead, everyoneWrite sql.NullBool + if err := rows.Scan(&topic, &ownerRead, &ownerWrite, &everyoneRead, &everyoneWrite); err != nil { + return nil, err + } else if err := rows.Err(); err != nil { + return nil, err + } + reservations = append(reservations, Reservation{ + Topic: fromSQLWildcard(topic), + Owner: NewPermission(ownerRead, ownerWrite), + Everyone: NewPermission(everyoneRead.Bool, everyoneWrite.Bool), + }) + } + return reservations, nil +} + +// HasReservation returns true if the given topic access is owned by the user +func (a *Manager) HasReservation(username, topic string) (bool, error) { + rows, err := a.db.Query(a.queries.selectUserHasReservation, username, escapeUnderscore(topic)) + if err != nil { + return false, err + } + defer rows.Close() + if !rows.Next() { + return false, errNoRows + } + var count int64 + if err := rows.Scan(&count); err != nil { + return false, err + } + return count > 0, nil +} + +// ReservationsCount returns the number of reservations owned by this user +func (a *Manager) ReservationsCount(username string) (int64, error) { + rows, err := a.db.Query(a.queries.selectUserReservationsCount, username) + if err != nil { + return 0, err + } + defer rows.Close() + if !rows.Next() { + return 0, errNoRows + } + var count int64 + if err := rows.Scan(&count); err != nil { + return 0, err + } + return count, nil +} + +// ReservationOwner returns user ID of the user that owns this topic, or an empty string if it's not owned by anyone +func (a *Manager) ReservationOwner(topic string) (string, error) { + rows, err := a.db.Query(a.queries.selectUserReservationsOwner, escapeUnderscore(topic)) + if err != nil { + return "", err + } + defer rows.Close() + if !rows.Next() { + return "", nil + } + var ownerUserID string + if err := rows.Scan(&ownerUserID); err != nil { + return "", err + } + return ownerUserID, nil +} + +// OtherAccessCount returns the number of access entries for the given topic that are not owned by the user +func (a *Manager) OtherAccessCount(username, topic string) (int, error) { + rows, err := a.db.Query(a.queries.selectOtherAccessCount, escapeUnderscore(topic), escapeUnderscore(topic), username) + if err != nil { + return 0, err + } + defer rows.Close() + if !rows.Next() { + return 0, errNoRows + } + var count int + if err := rows.Scan(&count); err != nil { + return 0, err + } + return count, nil +} + +// AddTier creates a new tier in the database +func (a *Manager) AddTier(tier *Tier) error { + if tier.ID == "" { + tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength) + } + if _, err := a.db.Exec(a.queries.insertTier, tier.ID, tier.Code, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.CallLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID)); err != nil { + return err + } + return nil +} + +// UpdateTier updates a tier's properties in the database +func (a *Manager) UpdateTier(tier *Tier) error { + if _, err := a.db.Exec(a.queries.updateTier, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.CallLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID), tier.Code); err != nil { + return err + } + return nil +} + +// RemoveTier deletes the tier with the given code +func (a *Manager) RemoveTier(code string) error { + if !AllowedTier(code) { + return ErrInvalidArgument + } + // This fails if any user has this tier + if _, err := a.db.Exec(a.queries.deleteTier, code); err != nil { + return err + } + return nil +} + +// Tiers returns a list of all Tier structs +func (a *Manager) Tiers() ([]*Tier, error) { + rows, err := a.db.Query(a.queries.selectTiers) + if err != nil { + return nil, err + } + defer rows.Close() + tiers := make([]*Tier, 0) + for { + tier, err := a.readTier(rows) + if errors.Is(err, ErrTierNotFound) { + break + } else if err != nil { + return nil, err + } + tiers = append(tiers, tier) + } + return tiers, nil +} + +// Tier returns a Tier based on the code, or ErrTierNotFound if it does not exist +func (a *Manager) Tier(code string) (*Tier, error) { + rows, err := a.db.Query(a.queries.selectTierByCode, code) + if err != nil { + return nil, err + } + defer rows.Close() + return a.readTier(rows) +} + +// TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist +func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) { + rows, err := a.db.Query(a.queries.selectTierByPriceID, priceID, priceID) + if err != nil { + return nil, err + } + defer rows.Close() + return a.readTier(rows) +} + +func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) { + var id, code, name string + var stripeMonthlyPriceID, stripeYearlyPriceID sql.NullString + var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit sql.NullInt64 + if !rows.Next() { + return nil, ErrTierNotFound + } + if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil { + return nil, err + } else if err := rows.Err(); err != nil { + return nil, err + } + // When changed, note readUser() as well + return &Tier{ + ID: id, + Code: code, + Name: name, + MessageLimit: messagesLimit.Int64, + MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second, + EmailLimit: emailsLimit.Int64, + CallLimit: callsLimit.Int64, + ReservationLimit: reservationsLimit.Int64, + AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64, + AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64, + AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second, + AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64, + StripeMonthlyPriceID: stripeMonthlyPriceID.String, // May be empty + StripeYearlyPriceID: stripeYearlyPriceID.String, // May be empty + }, nil +} + +// PhoneNumbers returns all phone numbers for the user with the given user ID +func (a *Manager) PhoneNumbers(userID string) ([]string, error) { + rows, err := a.db.Query(a.queries.selectPhoneNumbers, userID) + if err != nil { + return nil, err + } + defer rows.Close() + phoneNumbers := make([]string, 0) + for { + phoneNumber, err := a.readPhoneNumber(rows) + if errors.Is(err, ErrPhoneNumberNotFound) { + break + } else if err != nil { + return nil, err + } + phoneNumbers = append(phoneNumbers, phoneNumber) + } + return phoneNumbers, nil +} + +// AddPhoneNumber adds a phone number to the user with the given user ID +func (a *Manager) AddPhoneNumber(userID, phoneNumber string) error { + if _, err := a.db.Exec(a.queries.insertPhoneNumber, userID, phoneNumber); err != nil { + if isUniqueConstraintError(err) { + return ErrPhoneNumberExists + } + return err + } + return nil +} + +// RemovePhoneNumber deletes a phone number from the user with the given user ID +func (a *Manager) RemovePhoneNumber(userID, phoneNumber string) error { + _, err := a.db.Exec(a.queries.deletePhoneNumber, userID, phoneNumber) + return err +} +func (a *Manager) readPhoneNumber(rows *sql.Rows) (string, error) { + var phoneNumber string + if !rows.Next() { + return "", ErrPhoneNumberNotFound + } + if err := rows.Scan(&phoneNumber); err != nil { + return "", err + } else if err := rows.Err(); err != nil { + return "", err + } + return phoneNumber, nil +} + +// ChangeBilling updates a user's billing fields +func (a *Manager) ChangeBilling(username string, billing *Billing) error { + if _, err := a.db.Exec(a.queries.updateBilling, nullString(billing.StripeCustomerID), nullString(billing.StripeSubscriptionID), nullString(string(billing.StripeSubscriptionStatus)), nullString(string(billing.StripeSubscriptionInterval)), nullInt64(billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(billing.StripeSubscriptionCancelAt.Unix()), username); err != nil { + return err + } + return nil +} + +// UserIDByUsername returns the user ID for the given username +func (a *Manager) UserIDByUsername(username string) (string, error) { + rows, err := a.db.Query(a.queries.selectUserIDFromUsername, username) + if err != nil { + return "", err + } + defer rows.Close() + if !rows.Next() { + return "", ErrUserNotFound + } + var userID string + if err := rows.Scan(&userID); err != nil { + return "", err + } + return userID, nil +} + +// isUniqueConstraintError checks if the error is a unique constraint violation for both SQLite and PostgreSQL +func isUniqueConstraintError(err error) bool { + errStr := err.Error() + return strings.Contains(errStr, "UNIQUE constraint failed") || strings.Contains(errStr, "23505") +} diff --git a/user/store_postgres.go b/user/manager_postgres.go similarity index 97% rename from user/store_postgres.go rename to user/manager_postgres.go index c4c9577c..e396fbea 100644 --- a/user/store_postgres.go +++ b/user/manager_postgres.go @@ -203,13 +203,14 @@ const ( ` ) -// NewPostgresStore creates a new PostgreSQL-backed user store using an existing database connection pool. -func NewPostgresStore(db *sql.DB) (Store, error) { +// NewPostgresManager creates a new Manager backed by a PostgreSQL database using an existing connection pool. +func NewPostgresManager(db *sql.DB, config *Config) (*Manager, error) { if err := setupPostgres(db); err != nil { return nil, err } - return &commonStore{ - db: db, + manager := &Manager{ + config: config, + db: db, queries: storeQueries{ // User queries selectUserByID: postgresSelectUserByIDQuery, @@ -277,5 +278,9 @@ func NewPostgresStore(db *sql.DB) (Store, error) { // Billing queries updateBilling: postgresUpdateBillingQuery, }, - }, nil + } + if err := initManager(manager); err != nil { + return nil, err + } + return manager, nil } diff --git a/user/store_postgres_schema.go b/user/manager_postgres_schema.go similarity index 100% rename from user/store_postgres_schema.go rename to user/manager_postgres_schema.go diff --git a/user/store_sqlite.go b/user/manager_sqlite.go similarity index 98% rename from user/store_sqlite.go rename to user/manager_sqlite.go index c9d6d33f..7ca6e711 100644 --- a/user/store_sqlite.go +++ b/user/manager_sqlite.go @@ -201,8 +201,8 @@ const ( ` ) -// NewSQLiteStore creates a new SQLite-backed user store -func NewSQLiteStore(filename, startupQueries string) (Store, error) { +// NewSQLiteManager creates a new Manager backed by a SQLite database +func NewSQLiteManager(filename, startupQueries string, config *Config) (*Manager, error) { parentDir := filepath.Dir(filename) if !util.FileExists(parentDir) { return nil, fmt.Errorf("user database directory %s does not exist or is not accessible", parentDir) @@ -217,8 +217,9 @@ func NewSQLiteStore(filename, startupQueries string) (Store, error) { if err := runSQLiteStartupQueries(db, startupQueries); err != nil { return nil, err } - return &commonStore{ - db: db, + manager := &Manager{ + config: config, + db: db, queries: storeQueries{ selectUserByID: sqliteSelectUserByIDQuery, selectUserByName: sqliteSelectUserByNameQuery, @@ -275,5 +276,9 @@ func NewSQLiteStore(filename, startupQueries string) (Store, error) { deletePhoneNumber: sqliteDeletePhoneNumberQuery, updateBilling: sqliteUpdateBillingQuery, }, - }, nil + } + if err := initManager(manager); err != nil { + return nil, err + } + return manager, nil } diff --git a/user/store_sqlite_schema.go b/user/manager_sqlite_schema.go similarity index 100% rename from user/store_sqlite_schema.go rename to user/manager_sqlite_schema.go diff --git a/user/manager_test.go b/user/manager_test.go index ea976a3f..8ba6fe3b 100644 --- a/user/manager_test.go +++ b/user/manager_test.go @@ -18,35 +18,36 @@ import ( const minBcryptTimingMillis = int64(40) // Ideally should be >100ms, but this should also run on a Raspberry Pi without massive resources -// newStoreFunc creates a Store. Calling it multiple times within the same test -// returns a new Store object pointing at the same underlying data (same SQLite -// file / same PostgreSQL schema), enabling close-and-reopen tests. -type newStoreFunc func() Store +// newManagerFunc creates a Manager with the given config. Calling it multiple +// times within the same test returns a new Manager pointing at the same +// underlying data (same SQLite file / same PostgreSQL schema), enabling +// close-and-reopen tests. +type newManagerFunc func(config *Config) *Manager -func forEachBackend(t *testing.T, f func(t *testing.T, newStore newStoreFunc)) { +func forEachBackend(t *testing.T, f func(t *testing.T, newManager newManagerFunc)) { t.Run("sqlite", func(t *testing.T) { dir := t.TempDir() - f(t, func() Store { - store, err := NewSQLiteStore(filepath.Join(dir, "user.db"), "") + f(t, func(config *Config) *Manager { + a, err := NewSQLiteManager(filepath.Join(dir, "user.db"), "", config) require.Nil(t, err) - return store + return a }) }) t.Run("postgres", func(t *testing.T) { schemaDSN := dbtest.CreateTestPostgresSchema(t) - f(t, func() Store { + f(t, func(config *Config) *Manager { pool, err := db.OpenPostgres(schemaDSN) require.Nil(t, err) - store, err := NewPostgresStore(pool) + a, err := NewPostgresManager(pool, config) require.Nil(t, err) - return store + return a }) }) } func TestManager_FullScenario_Default_DenyAll(t *testing.T) { - forEachBackend(t, func(t *testing.T, newStore newStoreFunc) { - a := newTestManager(t, newStore, PermissionDenyAll) + forEachBackend(t, func(t *testing.T, newManager newManagerFunc) { + a := newTestManager(t, newManager, PermissionDenyAll) require.Nil(t, a.AddUser("phil", "phil", RoleAdmin, false)) require.Nil(t, a.AddUser("ben", "ben", RoleUser, false)) require.Nil(t, a.AddUser("john", "john", RoleUser, false)) @@ -162,8 +163,8 @@ func TestManager_FullScenario_Default_DenyAll(t *testing.T) { func TestManager_Access_Order_LengthWriteRead(t *testing.T) { // This test validates issue #914 / #917, i.e. that write permissions are prioritized over read permissions, // and longer ACL rules are prioritized as well. - forEachBackend(t, func(t *testing.T, newStore newStoreFunc) { - a := newTestManager(t, newStore, PermissionDenyAll) + forEachBackend(t, func(t *testing.T, newManager newManagerFunc) { + a := newTestManager(t, newManager, PermissionDenyAll) require.Nil(t, a.AddUser("ben", "ben", RoleUser, false)) require.Nil(t, a.AllowAccess("ben", "test*", PermissionReadWrite)) require.Nil(t, a.AllowAccess("ben", "*", PermissionRead)) @@ -177,8 +178,8 @@ func TestManager_Access_Order_LengthWriteRead(t *testing.T) { } func TestManager_AddUser_Invalid(t *testing.T) { - forEachBackend(t, func(t *testing.T, newStore newStoreFunc) { - a := newTestManager(t, newStore, PermissionDenyAll) + forEachBackend(t, func(t *testing.T, newManager newManagerFunc) { + a := newTestManager(t, newManager, PermissionDenyAll) require.Equal(t, ErrInvalidArgument, a.AddUser(" invalid ", "pass", RoleAdmin, false)) require.Equal(t, ErrInvalidArgument, a.AddUser("validuser", "pass", "invalid-role", false)) }) @@ -192,8 +193,8 @@ func TestManager_AddUser_Timing(t *testing.T) { } func TestManager_AddUser_And_Query(t *testing.T) { - forEachBackend(t, func(t *testing.T, newStore newStoreFunc) { - a := newTestManager(t, newStore, PermissionDenyAll) + forEachBackend(t, func(t *testing.T, newManager newManagerFunc) { + a := newTestManager(t, newManager, PermissionDenyAll) require.Nil(t, a.AddUser("user", "pass", RoleAdmin, false)) require.Nil(t, a.ChangeBilling("user", &Billing{ StripeCustomerID: "acct_123", @@ -219,8 +220,8 @@ func TestManager_AddUser_And_Query(t *testing.T) { } func TestManager_MarkUserRemoved_RemoveDeletedUsers(t *testing.T) { - forEachBackend(t, func(t *testing.T, newStore newStoreFunc) { - a := newTestManager(t, newStore, PermissionDenyAll) + forEachBackend(t, func(t *testing.T, newManager newManagerFunc) { + a := newTestManager(t, newManager, PermissionDenyAll) // Create user, add reservations and token require.Nil(t, a.AddUser("user", "pass", RoleAdmin, false)) @@ -262,8 +263,7 @@ func TestManager_MarkUserRemoved_RemoveDeletedUsers(t *testing.T) { require.True(t, u.Deleted) // Backdate the deleted timestamp so RemoveDeletedUsers will prune the user - q := a.store.(*commonStore).queries.updateUserDeleted - _, err = testDB(a).Exec(q, time.Now().Add(-1*(userHardDeleteAfterDuration+time.Hour)).Unix(), u.ID) + _, err = testDB(a).Exec(a.queries.updateUserDeleted, time.Now().Add(-1*(userHardDeleteAfterDuration+time.Hour)).Unix(), u.ID) require.Nil(t, err) require.Nil(t, a.RemoveDeletedUsers()) @@ -273,8 +273,8 @@ func TestManager_MarkUserRemoved_RemoveDeletedUsers(t *testing.T) { } func TestManager_CreateToken_Only_Lower(t *testing.T) { - forEachBackend(t, func(t *testing.T, newStore newStoreFunc) { - a := newTestManager(t, newStore, PermissionDenyAll) + forEachBackend(t, func(t *testing.T, newManager newManagerFunc) { + a := newTestManager(t, newManager, PermissionDenyAll) // Create user, add reservations and token require.Nil(t, a.AddUser("user", "pass", RoleAdmin, false)) @@ -288,8 +288,8 @@ func TestManager_CreateToken_Only_Lower(t *testing.T) { } func TestManager_UserManagement(t *testing.T) { - forEachBackend(t, func(t *testing.T, newStore newStoreFunc) { - a := newTestManager(t, newStore, PermissionDenyAll) + forEachBackend(t, func(t *testing.T, newManager newManagerFunc) { + a := newTestManager(t, newManager, PermissionDenyAll) require.Nil(t, a.AddUser("phil", "phil", RoleAdmin, false)) require.Nil(t, a.AddUser("ben", "ben", RoleUser, false)) require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite)) @@ -381,8 +381,8 @@ func TestManager_UserManagement(t *testing.T) { } func TestManager_ChangePassword(t *testing.T) { - forEachBackend(t, func(t *testing.T, newStore newStoreFunc) { - a := newTestManager(t, newStore, PermissionDenyAll) + forEachBackend(t, func(t *testing.T, newManager newManagerFunc) { + a := newTestManager(t, newManager, PermissionDenyAll) require.Nil(t, a.AddUser("phil", "phil", RoleAdmin, false)) require.Nil(t, a.AddUser("jane", "$2a$10$OyqU72muEy7VMd1SAU2Iru5IbeSMgrtCGHu/fWLmxL1MwlijQXWbG", RoleUser, true)) @@ -407,8 +407,8 @@ func TestManager_ChangePassword(t *testing.T) { } func TestManager_ChangeRole(t *testing.T) { - forEachBackend(t, func(t *testing.T, newStore newStoreFunc) { - a := newTestManager(t, newStore, PermissionDenyAll) + forEachBackend(t, func(t *testing.T, newManager newManagerFunc) { + a := newTestManager(t, newManager, PermissionDenyAll) require.Nil(t, a.AddUser("ben", "ben", RoleUser, false)) require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite)) require.Nil(t, a.AllowAccess("ben", "readme", PermissionRead)) @@ -434,8 +434,8 @@ func TestManager_ChangeRole(t *testing.T) { } func TestManager_Reservations(t *testing.T) { - forEachBackend(t, func(t *testing.T, newStore newStoreFunc) { - a := newTestManager(t, newStore, PermissionDenyAll) + forEachBackend(t, func(t *testing.T, newManager newManagerFunc) { + a := newTestManager(t, newManager, PermissionDenyAll) require.Nil(t, a.AddUser("phil", "phil", RoleUser, false)) require.Nil(t, a.AddUser("ben", "ben", RoleUser, false)) require.Nil(t, a.AddReservation("ben", "ztopic_", PermissionDenyAll)) @@ -506,8 +506,8 @@ func TestManager_Reservations(t *testing.T) { } func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) { - forEachBackend(t, func(t *testing.T, newStore newStoreFunc) { - a := newTestManager(t, newStore, PermissionDenyAll) + forEachBackend(t, func(t *testing.T, newManager newManagerFunc) { + a := newTestManager(t, newManager, PermissionDenyAll) require.Nil(t, a.AddTier(&Tier{ Code: "pro", Name: "ntfy Pro", @@ -567,8 +567,8 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) { } func TestManager_Token_Valid(t *testing.T) { - forEachBackend(t, func(t *testing.T, newStore newStoreFunc) { - a := newTestManager(t, newStore, PermissionDenyAll) + forEachBackend(t, func(t *testing.T, newManager newManagerFunc) { + a := newTestManager(t, newManager, PermissionDenyAll) require.Nil(t, a.AddUser("ben", "ben", RoleUser, false)) u, err := a.User("ben") @@ -613,8 +613,8 @@ func TestManager_Token_Valid(t *testing.T) { } func TestManager_Token_Invalid(t *testing.T) { - forEachBackend(t, func(t *testing.T, newStore newStoreFunc) { - a := newTestManager(t, newStore, PermissionDenyAll) + forEachBackend(t, func(t *testing.T, newManager newManagerFunc) { + a := newTestManager(t, newManager, PermissionDenyAll) require.Nil(t, a.AddUser("ben", "ben", RoleUser, false)) u, err := a.AuthenticateToken(strings.Repeat("x", 32)) // 32 == token length @@ -628,16 +628,16 @@ func TestManager_Token_Invalid(t *testing.T) { } func TestManager_Token_NotFound(t *testing.T) { - forEachBackend(t, func(t *testing.T, newStore newStoreFunc) { - a := newTestManager(t, newStore, PermissionDenyAll) + forEachBackend(t, func(t *testing.T, newManager newManagerFunc) { + a := newTestManager(t, newManager, PermissionDenyAll) _, err := a.Token("u_bla", "notfound") require.Equal(t, ErrTokenNotFound, err) }) } func TestManager_Token_Expire(t *testing.T) { - forEachBackend(t, func(t *testing.T, newStore newStoreFunc) { - a := newTestManager(t, newStore, PermissionDenyAll) + forEachBackend(t, func(t *testing.T, newManager newManagerFunc) { + a := newTestManager(t, newManager, PermissionDenyAll) require.Nil(t, a.AddUser("ben", "ben", RoleUser, false)) u, err := a.User("ben") @@ -686,8 +686,8 @@ func TestManager_Token_Expire(t *testing.T) { } func TestManager_Token_Extend(t *testing.T) { - forEachBackend(t, func(t *testing.T, newStore newStoreFunc) { - a := newTestManager(t, newStore, PermissionDenyAll) + forEachBackend(t, func(t *testing.T, newManager newManagerFunc) { + a := newTestManager(t, newManager, PermissionDenyAll) require.Nil(t, a.AddUser("ben", "ben", RoleUser, false)) // Try to extend token for user without token @@ -716,8 +716,8 @@ func TestManager_Token_Extend(t *testing.T) { func TestManager_Token_MaxCount_AutoDelete(t *testing.T) { // Tests that tokens are automatically deleted when the maximum number of tokens is reached - forEachBackend(t, func(t *testing.T, newStore newStoreFunc) { - a := newTestManager(t, newStore, PermissionDenyAll) + forEachBackend(t, func(t *testing.T, newManager newManagerFunc) { + a := newTestManager(t, newManager, PermissionDenyAll) require.Nil(t, a.AddUser("ben", "ben", RoleUser, false)) require.Nil(t, a.AddUser("phil", "phil", RoleUser, false)) @@ -787,13 +787,13 @@ func TestManager_Token_MaxCount_AutoDelete(t *testing.T) { } func TestManager_EnqueueStats_ResetStats(t *testing.T) { - forEachBackend(t, func(t *testing.T, newStore newStoreFunc) { + forEachBackend(t, func(t *testing.T, newManager newManagerFunc) { conf := &Config{ DefaultAccess: PermissionReadWrite, BcryptCost: bcrypt.MinCost, QueueWriterInterval: 1500 * time.Millisecond, } - a := newTestManagerFromStoreConfig(t, newStore, conf) + a := newTestManagerFromConfig(t, newManager, conf) require.Nil(t, a.AddUser("ben", "ben", RoleUser, false)) // Baseline: No messages or emails @@ -835,13 +835,13 @@ func TestManager_EnqueueStats_ResetStats(t *testing.T) { } func TestManager_EnqueueTokenUpdate(t *testing.T) { - forEachBackend(t, func(t *testing.T, newStore newStoreFunc) { + forEachBackend(t, func(t *testing.T, newManager newManagerFunc) { conf := &Config{ DefaultAccess: PermissionReadWrite, BcryptCost: bcrypt.MinCost, QueueWriterInterval: 500 * time.Millisecond, } - a := newTestManagerFromStoreConfig(t, newStore, conf) + a := newTestManagerFromConfig(t, newManager, conf) require.Nil(t, a.AddUser("ben", "ben", RoleUser, false)) // Create user and token @@ -874,13 +874,13 @@ func TestManager_EnqueueTokenUpdate(t *testing.T) { } func TestManager_ChangeSettings(t *testing.T) { - forEachBackend(t, func(t *testing.T, newStore newStoreFunc) { + forEachBackend(t, func(t *testing.T, newManager newManagerFunc) { conf := &Config{ DefaultAccess: PermissionReadWrite, BcryptCost: bcrypt.MinCost, QueueWriterInterval: 1500 * time.Millisecond, } - a := newTestManagerFromStoreConfig(t, newStore, conf) + a := newTestManagerFromConfig(t, newManager, conf) require.Nil(t, a.AddUser("ben", "ben", RoleUser, false)) // No settings @@ -921,8 +921,8 @@ func TestManager_ChangeSettings(t *testing.T) { } func TestManager_Tier_Create_Update_List_Delete(t *testing.T) { - forEachBackend(t, func(t *testing.T, newStore newStoreFunc) { - a := newTestManager(t, newStore, PermissionDenyAll) + forEachBackend(t, func(t *testing.T, newManager newManagerFunc) { + a := newTestManager(t, newManager, PermissionDenyAll) // Create tier and user require.Nil(t, a.AddTier(&Tier{ @@ -1041,8 +1041,8 @@ func TestManager_Tier_Create_Update_List_Delete(t *testing.T) { } func TestAccount_Tier_Create_With_ID(t *testing.T) { - forEachBackend(t, func(t *testing.T, newStore newStoreFunc) { - a := newTestManager(t, newStore, PermissionDenyAll) + forEachBackend(t, func(t *testing.T, newManager newManagerFunc) { + a := newTestManager(t, newManager, PermissionDenyAll) require.Nil(t, a.AddTier(&Tier{ ID: "ti_123", @@ -1056,8 +1056,8 @@ func TestAccount_Tier_Create_With_ID(t *testing.T) { } func TestManager_Tier_Change_And_Reset(t *testing.T) { - forEachBackend(t, func(t *testing.T, newStore newStoreFunc) { - a := newTestManager(t, newStore, PermissionDenyAll) + forEachBackend(t, func(t *testing.T, newManager newManagerFunc) { + a := newTestManager(t, newManager, PermissionDenyAll) // Create tier and user require.Nil(t, a.AddTier(&Tier{ @@ -1095,8 +1095,8 @@ func TestManager_Tier_Change_And_Reset(t *testing.T) { } func TestUser_PhoneNumberAddListRemove(t *testing.T) { - forEachBackend(t, func(t *testing.T, newStore newStoreFunc) { - a := newTestManager(t, newStore, PermissionDenyAll) + forEachBackend(t, func(t *testing.T, newManager newManagerFunc) { + a := newTestManager(t, newManager, PermissionDenyAll) require.Nil(t, a.AddUser("phil", "phil", RoleUser, false)) phil, err := a.User("phil") @@ -1122,8 +1122,8 @@ func TestUser_PhoneNumberAddListRemove(t *testing.T) { } func TestUser_PhoneNumberAdd_Multiple_Users_Same_Number(t *testing.T) { - forEachBackend(t, func(t *testing.T, newStore newStoreFunc) { - a := newTestManager(t, newStore, PermissionDenyAll) + forEachBackend(t, func(t *testing.T, newManager newManagerFunc) { + a := newTestManager(t, newManager, PermissionDenyAll) require.Nil(t, a.AddUser("phil", "phil", RoleUser, false)) require.Nil(t, a.AddUser("ben", "ben", RoleUser, false)) @@ -1137,8 +1137,8 @@ func TestUser_PhoneNumberAdd_Multiple_Users_Same_Number(t *testing.T) { } func TestManager_Topic_Wildcard_With_Asterisk_Underscore(t *testing.T) { - forEachBackend(t, func(t *testing.T, newStore newStoreFunc) { - a := newTestManager(t, newStore, PermissionDenyAll) + forEachBackend(t, func(t *testing.T, newManager newManagerFunc) { + a := newTestManager(t, newManager, PermissionDenyAll) require.Nil(t, a.AllowAccess(Everyone, "*_", PermissionRead)) require.Nil(t, a.AllowAccess(Everyone, "__*_", PermissionRead)) require.Nil(t, a.Authorize(nil, "allowed_", PermissionRead)) @@ -1151,8 +1151,8 @@ func TestManager_Topic_Wildcard_With_Asterisk_Underscore(t *testing.T) { } func TestManager_Topic_Wildcard_With_Underscore(t *testing.T) { - forEachBackend(t, func(t *testing.T, newStore newStoreFunc) { - a := newTestManager(t, newStore, PermissionDenyAll) + forEachBackend(t, func(t *testing.T, newManager newManagerFunc) { + a := newTestManager(t, newManager, PermissionDenyAll) require.Nil(t, a.AllowAccess(Everyone, "mytopic_", PermissionReadWrite)) require.Nil(t, a.Authorize(nil, "mytopic_", PermissionRead)) require.Nil(t, a.Authorize(nil, "mytopic_", PermissionWrite)) @@ -1162,7 +1162,7 @@ func TestManager_Topic_Wildcard_With_Underscore(t *testing.T) { } func TestManager_WithProvisionedUsers(t *testing.T) { - forEachBackend(t, func(t *testing.T, newStore newStoreFunc) { + forEachBackend(t, func(t *testing.T, newManager newManagerFunc) { conf := &Config{ DefaultAccess: PermissionReadWrite, ProvisionEnabled: true, @@ -1182,7 +1182,7 @@ func TestManager_WithProvisionedUsers(t *testing.T) { }, }, } - a := newTestManagerFromStoreConfig(t, newStore, conf) + a := newTestManagerFromConfig(t, newManager, conf) // Manually add user require.Nil(t, a.AddUser("philmanual", "manual", RoleUser, false)) @@ -1218,7 +1218,7 @@ func TestManager_WithProvisionedUsers(t *testing.T) { // Update the token last access time and origin (so we can check that it is persisted) lastAccessTime := time.Now().Add(time.Hour) lastOrigin := netip.MustParseAddr("1.1.9.9") - err = a.store.UpdateTokenLastAccess(map[string]*TokenUpdate{tokens[0].Value: {LastAccess: lastAccessTime, LastOrigin: lastOrigin}}) + err = a.UpdateTokenLastAccess(map[string]*TokenUpdate{tokens[0].Value: {LastAccess: lastAccessTime, LastOrigin: lastOrigin}}) require.Nil(t, err) // Re-open the DB (second app start) @@ -1238,7 +1238,7 @@ func TestManager_WithProvisionedUsers(t *testing.T) { {Value: "tk_u48wqendnkx9er21pqqcadlytbutx", Label: "Another token"}, }, } - a = newTestManagerFromStoreConfig(t, newStore, conf) + a = newTestManagerFromConfig(t, newManager, conf) // Check that the provisioned users are there users, err = a.Users() @@ -1277,7 +1277,7 @@ func TestManager_WithProvisionedUsers(t *testing.T) { conf.Users = []*User{} conf.Access = map[string][]*Grant{} conf.Tokens = map[string][]*Token{} - a = newTestManagerFromStoreConfig(t, newStore, conf) + a = newTestManagerFromConfig(t, newManager, conf) // Check that the provisioned users are all gone users, err = a.Users() @@ -1314,7 +1314,7 @@ func TestManager_WithProvisionedUsers(t *testing.T) { } func TestManager_WithProvisionedUsers_RemoveToken(t *testing.T) { - forEachBackend(t, func(t *testing.T, newStore newStoreFunc) { + forEachBackend(t, func(t *testing.T, newManager newManagerFunc) { conf := &Config{ DefaultAccess: PermissionReadWrite, ProvisionEnabled: true, @@ -1328,7 +1328,7 @@ func TestManager_WithProvisionedUsers_RemoveToken(t *testing.T) { }, }, } - a := newTestManagerFromStoreConfig(t, newStore, conf) + a := newTestManagerFromConfig(t, newManager, conf) users, err := a.Users() require.Nil(t, err) @@ -1351,7 +1351,7 @@ func TestManager_WithProvisionedUsers_RemoveToken(t *testing.T) { {Value: "tk_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", Label: "Token A"}, }, } - a = newTestManagerFromStoreConfig(t, newStore, conf) + a = newTestManagerFromConfig(t, newManager, conf) tokens, err = a.Tokens(philUserID) require.Nil(t, err) @@ -1361,7 +1361,7 @@ func TestManager_WithProvisionedUsers_RemoveToken(t *testing.T) { } func TestManager_UpdateNonProvisionedUsersToProvisionedUsers(t *testing.T) { - forEachBackend(t, func(t *testing.T, newStore newStoreFunc) { + forEachBackend(t, func(t *testing.T, newManager newManagerFunc) { conf := &Config{ DefaultAccess: PermissionReadWrite, ProvisionEnabled: true, @@ -1372,7 +1372,7 @@ func TestManager_UpdateNonProvisionedUsersToProvisionedUsers(t *testing.T) { }, }, } - a := newTestManagerFromStoreConfig(t, newStore, conf) + a := newTestManagerFromConfig(t, newManager, conf) // Manually add user require.Nil(t, a.AddUser("philuser", "manual", RoleUser, false)) @@ -1413,7 +1413,7 @@ func TestManager_UpdateNonProvisionedUsersToProvisionedUsers(t *testing.T) { {TopicPattern: "stats", Permission: PermissionReadWrite}, }, } - a = newTestManagerFromStoreConfig(t, newStore, conf) + a = newTestManagerFromConfig(t, newManager, conf) // Check that the user was "upgraded" to a provisioned user users, err = a.Users() @@ -1696,39 +1696,714 @@ func checkSchemaVersion(t *testing.T, db *sql.DB) { require.Nil(t, rows.Close()) } -func newTestManager(t *testing.T, newStore newStoreFunc, defaultAccess Permission) *Manager { - store := newStore() - a, err := NewManager(store, &Config{ +func newTestManager(t *testing.T, newManager newManagerFunc, defaultAccess Permission) *Manager { + a := newManager(&Config{ DefaultAccess: defaultAccess, BcryptCost: bcrypt.MinCost, QueueWriterInterval: DefaultUserStatsQueueWriterInterval, }) - require.Nil(t, err) t.Cleanup(func() { a.Close() }) return a } func newTestManagerFromFile(t *testing.T, filename, startupQueries string, defaultAccess Permission, bcryptCost int, statsWriterInterval time.Duration) *Manager { - store, err := NewSQLiteStore(filename, startupQueries) - require.Nil(t, err) - conf := &Config{ + a, err := NewSQLiteManager(filename, startupQueries, &Config{ DefaultAccess: defaultAccess, BcryptCost: bcryptCost, QueueWriterInterval: statsWriterInterval, - } - a, err := NewManager(store, conf) + }) require.Nil(t, err) return a } -func newTestManagerFromStoreConfig(t *testing.T, newStore newStoreFunc, conf *Config) *Manager { - store := newStore() - a, err := NewManager(store, conf) - require.Nil(t, err) +func newTestManagerFromConfig(t *testing.T, newManager newManagerFunc, conf *Config) *Manager { + a := newManager(conf) t.Cleanup(func() { a.Close() }) return a } func testDB(a *Manager) *sql.DB { - return a.store.(*commonStore).db + return a.db +} + +func forEachStoreBackend(t *testing.T, f func(t *testing.T, manager *Manager)) { + t.Run("sqlite", func(t *testing.T) { + manager, err := NewSQLiteManager(filepath.Join(t.TempDir(), "user.db"), "", &Config{}) + require.Nil(t, err) + t.Cleanup(func() { manager.Close() }) + f(t, manager) + }) + t.Run("postgres", func(t *testing.T) { + testDB := dbtest.CreateTestPostgres(t) + manager, err := NewPostgresManager(testDB, &Config{}) + require.Nil(t, err) + f(t, manager) + }) +} + +func TestStoreAddUser(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) + u, err := manager.User("phil") + require.Nil(t, err) + require.Equal(t, "phil", u.Name) + require.Equal(t, RoleUser, u.Role) + require.False(t, u.Provisioned) + require.NotEmpty(t, u.ID) + require.NotEmpty(t, u.SyncTopic) + }) +} + +func TestStoreAddUserAlreadyExists(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + require.Nil(t, manager.AddUser("phil", "pass1", RoleUser, false)) + require.Equal(t, ErrUserExists, manager.AddUser("phil", "pass2", RoleUser, false)) + }) +} + +func TestStoreRemoveUser(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) + u, err := manager.User("phil") + require.Nil(t, err) + require.Equal(t, "phil", u.Name) + + require.Nil(t, manager.RemoveUser("phil")) + _, err = manager.User("phil") + require.Equal(t, ErrUserNotFound, err) + }) +} + +func TestStoreUserByID(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + require.Nil(t, manager.AddUser("phil", "mypass", RoleAdmin, false)) + u, err := manager.User("phil") + require.Nil(t, err) + + u2, err := manager.UserByID(u.ID) + require.Nil(t, err) + require.Equal(t, u.Name, u2.Name) + require.Equal(t, u.ID, u2.ID) + }) +} + +func TestStoreUserByToken(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) + u, err := manager.User("phil") + require.Nil(t, err) + + tk, err := manager.CreateToken(u.ID, "test token", time.Now().Add(24*time.Hour), netip.MustParseAddr("1.2.3.4"), false) + require.Nil(t, err) + require.NotEmpty(t, tk.Value) + + u2, err := manager.UserByToken(tk.Value) + require.Nil(t, err) + require.Equal(t, "phil", u2.Name) + }) +} + +func TestStoreUserByStripeCustomer(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) + require.Nil(t, manager.ChangeBilling("phil", &Billing{ + StripeCustomerID: "cus_test123", + StripeSubscriptionID: "sub_test123", + })) + + u, err := manager.UserByStripeCustomer("cus_test123") + require.Nil(t, err) + require.Equal(t, "phil", u.Name) + require.Equal(t, "cus_test123", u.Billing.StripeCustomerID) + }) +} + +func TestStoreUsers(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) + require.Nil(t, manager.AddUser("ben", "benpass", RoleAdmin, false)) + + users, err := manager.Users() + require.Nil(t, err) + require.True(t, len(users) >= 3) // phil, ben, and the everyone user + }) +} + +func TestStoreUsersCount(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + count, err := manager.UsersCount() + require.Nil(t, err) + require.True(t, count >= 1) // At least the everyone user + + require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) + count2, err := manager.UsersCount() + require.Nil(t, err) + require.Equal(t, count+1, count2) + }) +} + +func TestStoreChangePassword(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) + u, err := manager.User("phil") + require.Nil(t, err) + require.NotEmpty(t, u.Hash) + + require.Nil(t, manager.ChangePassword("phil", "newpass", false)) + u, err = manager.User("phil") + require.Nil(t, err) + require.NotEmpty(t, u.Hash) + }) +} + +func TestStoreChangeRole(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) + u, err := manager.User("phil") + require.Nil(t, err) + require.Equal(t, RoleUser, u.Role) + + require.Nil(t, manager.ChangeRole("phil", RoleAdmin)) + u, err = manager.User("phil") + require.Nil(t, err) + require.Equal(t, RoleAdmin, u.Role) + }) +} + +func TestStoreTokens(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) + u, err := manager.User("phil") + require.Nil(t, err) + + expires := time.Now().Add(24 * time.Hour) + origin := netip.MustParseAddr("9.9.9.9") + + tk, err := manager.CreateToken(u.ID, "my token", expires, origin, false) + require.Nil(t, err) + require.NotEmpty(t, tk.Value) + require.Equal(t, "my token", tk.Label) + + // Get single token + tk2, err := manager.Token(u.ID, tk.Value) + require.Nil(t, err) + require.Equal(t, tk.Value, tk2.Value) + require.Equal(t, "my token", tk2.Label) + + // Get all tokens + tokens, err := manager.Tokens(u.ID) + require.Nil(t, err) + require.Len(t, tokens, 1) + require.Equal(t, tk.Value, tokens[0].Value) + }) +} + +func TestStoreTokenChange(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) + u, err := manager.User("phil") + require.Nil(t, err) + + expires := time.Now().Add(time.Hour) + tk, err := manager.CreateToken(u.ID, "old label", expires, netip.MustParseAddr("1.2.3.4"), false) + require.Nil(t, err) + + newLabel := "new label" + newExpires := time.Now().Add(2 * time.Hour) + tk2, err := manager.ChangeToken(u.ID, tk.Value, &newLabel, &newExpires) + require.Nil(t, err) + require.Equal(t, "new label", tk2.Label) + require.Equal(t, newExpires.Unix(), tk2.Expires.Unix()) + }) +} + +func TestStoreTokenRemove(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) + u, err := manager.User("phil") + require.Nil(t, err) + + tk, err := manager.CreateToken(u.ID, "label", time.Now().Add(time.Hour), netip.MustParseAddr("1.2.3.4"), false) + require.Nil(t, err) + + require.Nil(t, manager.RemoveToken(u.ID, tk.Value)) + _, err = manager.Token(u.ID, tk.Value) + require.Equal(t, ErrTokenNotFound, err) + }) +} + +func TestStoreTokenRemoveExpired(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) + u, err := manager.User("phil") + require.Nil(t, err) + + // Create expired token and active token + tkExpired, err := manager.CreateToken(u.ID, "expired", time.Now().Add(-time.Hour), netip.MustParseAddr("1.2.3.4"), false) + require.Nil(t, err) + tkActive, err := manager.CreateToken(u.ID, "active", time.Now().Add(time.Hour), netip.MustParseAddr("1.2.3.4"), false) + require.Nil(t, err) + + require.Nil(t, manager.RemoveExpiredTokens()) + + // Expired token should be gone + _, err = manager.Token(u.ID, tkExpired.Value) + require.Equal(t, ErrTokenNotFound, err) + + // Active token should still exist + tk, err := manager.Token(u.ID, tkActive.Value) + require.Nil(t, err) + require.Equal(t, tkActive.Value, tk.Value) + }) +} + +func TestStoreTokenCreatePrunesExcess(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) + u, err := manager.User("phil") + require.Nil(t, err) + + // Create several tokens + var tokenValues []string + for i := 0; i < 3; i++ { + tk, err := manager.CreateToken(u.ID, "label", time.Now().Add(time.Duration(i+1)*time.Hour), netip.MustParseAddr("1.2.3.4"), false) + require.Nil(t, err) + tokenValues = append(tokenValues, tk.Value) + } + + tokens, err := manager.Tokens(u.ID) + require.Nil(t, err) + require.True(t, len(tokens) >= 3) + }) +} + +func TestStoreTokenUpdateLastAccess(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) + u, err := manager.User("phil") + require.Nil(t, err) + + tk, err := manager.CreateToken(u.ID, "label", time.Now().Add(time.Hour), netip.MustParseAddr("1.2.3.4"), false) + require.Nil(t, err) + + newTime := time.Now().Add(5 * time.Minute) + newOrigin := netip.MustParseAddr("5.5.5.5") + manager.EnqueueTokenUpdate(tk.Value, &TokenUpdate{LastAccess: newTime, LastOrigin: newOrigin}) + }) +} + +func TestStoreAllowAccess(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) + + require.Nil(t, manager.AllowAccess("phil", "mytopic", PermissionReadWrite)) + grants, err := manager.Grants("phil") + require.Nil(t, err) + require.Len(t, grants, 1) + require.Equal(t, "mytopic", grants[0].TopicPattern) + require.True(t, grants[0].Permission.IsReadWrite()) + }) +} + +func TestStoreAllowAccessReadOnly(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) + + require.Nil(t, manager.AllowAccess("phil", "announcements", PermissionRead)) + grants, err := manager.Grants("phil") + require.Nil(t, err) + require.Len(t, grants, 1) + require.True(t, grants[0].Permission.IsRead()) + require.False(t, grants[0].Permission.IsWrite()) + }) +} + +func TestStoreResetAccess(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) + require.Nil(t, manager.AllowAccess("phil", "topic1", PermissionReadWrite)) + require.Nil(t, manager.AllowAccess("phil", "topic2", PermissionRead)) + + grants, err := manager.Grants("phil") + require.Nil(t, err) + require.Len(t, grants, 2) + + require.Nil(t, manager.ResetAccess("phil", "topic1")) + grants, err = manager.Grants("phil") + require.Nil(t, err) + require.Len(t, grants, 1) + require.Equal(t, "topic2", grants[0].TopicPattern) + }) +} + +func TestStoreResetAccessAll(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) + require.Nil(t, manager.AllowAccess("phil", "topic1", PermissionReadWrite)) + require.Nil(t, manager.AllowAccess("phil", "topic2", PermissionRead)) + + require.Nil(t, manager.ResetAccess("phil", "")) + grants, err := manager.Grants("phil") + require.Nil(t, err) + require.Len(t, grants, 0) + }) +} + +func TestStoreAuthorizeTopicAccess(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) + require.Nil(t, manager.AllowAccess("phil", "mytopic", PermissionReadWrite)) + + read, write, found, err := manager.AuthorizeTopicAccess("phil", "mytopic") + require.Nil(t, err) + require.True(t, found) + require.True(t, read) + require.True(t, write) + }) +} + +func TestStoreAuthorizeTopicAccessNotFound(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) + + _, _, found, err := manager.AuthorizeTopicAccess("phil", "other") + require.Nil(t, err) + require.False(t, found) + }) +} + +func TestStoreAuthorizeTopicAccessDenyAll(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) + require.Nil(t, manager.AllowAccess("phil", "secret", PermissionDenyAll)) + + read, write, found, err := manager.AuthorizeTopicAccess("phil", "secret") + require.Nil(t, err) + require.True(t, found) + require.False(t, read) + require.False(t, write) + }) +} + +func TestStoreReservations(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) + require.Nil(t, manager.AddReservation("phil", "mytopic", PermissionRead)) + + reservations, err := manager.Reservations("phil") + require.Nil(t, err) + require.Len(t, reservations, 1) + require.Equal(t, "mytopic", reservations[0].Topic) + require.True(t, reservations[0].Owner.IsReadWrite()) + require.True(t, reservations[0].Everyone.IsRead()) + require.False(t, reservations[0].Everyone.IsWrite()) + }) +} + +func TestStoreReservationsCount(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) + require.Nil(t, manager.AddReservation("phil", "topic1", PermissionReadWrite)) + require.Nil(t, manager.AddReservation("phil", "topic2", PermissionReadWrite)) + + count, err := manager.ReservationsCount("phil") + require.Nil(t, err) + require.Equal(t, int64(2), count) + }) +} + +func TestStoreHasReservation(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) + require.Nil(t, manager.AddReservation("phil", "mytopic", PermissionReadWrite)) + + has, err := manager.HasReservation("phil", "mytopic") + require.Nil(t, err) + require.True(t, has) + + has, err = manager.HasReservation("phil", "other") + require.Nil(t, err) + require.False(t, has) + }) +} + +func TestStoreReservationOwner(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) + require.Nil(t, manager.AddReservation("phil", "mytopic", PermissionReadWrite)) + + owner, err := manager.ReservationOwner("mytopic") + require.Nil(t, err) + require.NotEmpty(t, owner) // Returns the user ID + + owner, err = manager.ReservationOwner("unowned") + require.Nil(t, err) + require.Empty(t, owner) + }) +} + +func TestStoreTiers(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + tier := &Tier{ + ID: "ti_test", + Code: "pro", + Name: "Pro", + MessageLimit: 5000, + MessageExpiryDuration: 24 * time.Hour, + EmailLimit: 100, + CallLimit: 10, + ReservationLimit: 20, + AttachmentFileSizeLimit: 10 * 1024 * 1024, + AttachmentTotalSizeLimit: 100 * 1024 * 1024, + AttachmentExpiryDuration: 48 * time.Hour, + AttachmentBandwidthLimit: 500 * 1024 * 1024, + } + require.Nil(t, manager.AddTier(tier)) + + // Get by code + t2, err := manager.Tier("pro") + require.Nil(t, err) + require.Equal(t, "ti_test", t2.ID) + require.Equal(t, "pro", t2.Code) + require.Equal(t, "Pro", t2.Name) + require.Equal(t, int64(5000), t2.MessageLimit) + require.Equal(t, int64(100), t2.EmailLimit) + require.Equal(t, int64(10), t2.CallLimit) + require.Equal(t, int64(20), t2.ReservationLimit) + + // List all tiers + tiers, err := manager.Tiers() + require.Nil(t, err) + require.Len(t, tiers, 1) + require.Equal(t, "pro", tiers[0].Code) + }) +} + +func TestStoreTierUpdate(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + tier := &Tier{ + ID: "ti_test", + Code: "pro", + Name: "Pro", + } + require.Nil(t, manager.AddTier(tier)) + + tier.Name = "Professional" + tier.MessageLimit = 9999 + require.Nil(t, manager.UpdateTier(tier)) + + t2, err := manager.Tier("pro") + require.Nil(t, err) + require.Equal(t, "Professional", t2.Name) + require.Equal(t, int64(9999), t2.MessageLimit) + }) +} + +func TestStoreTierRemove(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + tier := &Tier{ + ID: "ti_test", + Code: "pro", + Name: "Pro", + } + require.Nil(t, manager.AddTier(tier)) + + t2, err := manager.Tier("pro") + require.Nil(t, err) + require.Equal(t, "pro", t2.Code) + + require.Nil(t, manager.RemoveTier("pro")) + _, err = manager.Tier("pro") + require.Equal(t, ErrTierNotFound, err) + }) +} + +func TestStoreTierByStripePrice(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + tier := &Tier{ + ID: "ti_test", + Code: "pro", + Name: "Pro", + StripeMonthlyPriceID: "price_monthly", + StripeYearlyPriceID: "price_yearly", + } + require.Nil(t, manager.AddTier(tier)) + + t2, err := manager.TierByStripePrice("price_monthly") + require.Nil(t, err) + require.Equal(t, "pro", t2.Code) + + t3, err := manager.TierByStripePrice("price_yearly") + require.Nil(t, err) + require.Equal(t, "pro", t3.Code) + }) +} + +func TestStoreChangeTier(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + tier := &Tier{ + ID: "ti_test", + Code: "pro", + Name: "Pro", + } + require.Nil(t, manager.AddTier(tier)) + require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) + require.Nil(t, manager.ChangeTier("phil", "pro")) + + u, err := manager.User("phil") + require.Nil(t, err) + require.NotNil(t, u.Tier) + require.Equal(t, "pro", u.Tier.Code) + }) +} + +func TestStorePhoneNumbers(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) + u, err := manager.User("phil") + require.Nil(t, err) + + require.Nil(t, manager.AddPhoneNumber(u.ID, "+1234567890")) + require.Nil(t, manager.AddPhoneNumber(u.ID, "+0987654321")) + + numbers, err := manager.PhoneNumbers(u.ID) + require.Nil(t, err) + require.Len(t, numbers, 2) + + require.Nil(t, manager.RemovePhoneNumber(u.ID, "+1234567890")) + numbers, err = manager.PhoneNumbers(u.ID) + require.Nil(t, err) + require.Len(t, numbers, 1) + require.Equal(t, "+0987654321", numbers[0]) + }) +} + +func TestStoreChangeSettings(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) + u, err := manager.User("phil") + require.Nil(t, err) + + lang := "de" + prefs := &Prefs{Language: &lang} + require.Nil(t, manager.ChangeSettings(u.ID, prefs)) + + u2, err := manager.User("phil") + require.Nil(t, err) + require.NotNil(t, u2.Prefs) + require.Equal(t, "de", *u2.Prefs.Language) + }) +} + +func TestStoreChangeBilling(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) + + billing := &Billing{ + StripeCustomerID: "cus_123", + StripeSubscriptionID: "sub_456", + } + require.Nil(t, manager.ChangeBilling("phil", billing)) + + u, err := manager.User("phil") + require.Nil(t, err) + require.Equal(t, "cus_123", u.Billing.StripeCustomerID) + require.Equal(t, "sub_456", u.Billing.StripeSubscriptionID) + }) +} + +func TestStoreUpdateStats(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) + u, err := manager.User("phil") + require.Nil(t, err) + + stats := &Stats{Messages: 42, Emails: 3, Calls: 1} + require.Nil(t, manager.UpdateStats(map[string]*Stats{u.ID: stats})) + + u2, err := manager.User("phil") + require.Nil(t, err) + require.Equal(t, int64(42), u2.Stats.Messages) + require.Equal(t, int64(3), u2.Stats.Emails) + require.Equal(t, int64(1), u2.Stats.Calls) + }) +} + +func TestStoreResetStats(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) + u, err := manager.User("phil") + require.Nil(t, err) + + require.Nil(t, manager.UpdateStats(map[string]*Stats{u.ID: {Messages: 42, Emails: 3, Calls: 1}})) + require.Nil(t, manager.ResetStats()) + + u2, err := manager.User("phil") + require.Nil(t, err) + require.Equal(t, int64(0), u2.Stats.Messages) + require.Equal(t, int64(0), u2.Stats.Emails) + require.Equal(t, int64(0), u2.Stats.Calls) + }) +} + +func TestStoreMarkUserRemoved(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) + u, err := manager.User("phil") + require.Nil(t, err) + + require.Nil(t, manager.MarkUserRemoved(u)) + + u2, err := manager.User("phil") + require.Nil(t, err) + require.True(t, u2.Deleted) + }) +} + +func TestStoreRemoveDeletedUsers(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) + u, err := manager.User("phil") + require.Nil(t, err) + + require.Nil(t, manager.MarkUserRemoved(u)) + + // RemoveDeletedUsers only removes users past the hard-delete duration (7 days). + // Immediately after marking, the user should still exist. + require.Nil(t, manager.RemoveDeletedUsers()) + u2, err := manager.User("phil") + require.Nil(t, err) + require.True(t, u2.Deleted) + }) +} + +func TestStoreAllGrants(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) + require.Nil(t, manager.AddUser("ben", "benpass", RoleUser, false)) + phil, err := manager.User("phil") + require.Nil(t, err) + ben, err := manager.User("ben") + require.Nil(t, err) + + require.Nil(t, manager.AllowAccess("phil", "topic1", PermissionReadWrite)) + require.Nil(t, manager.AllowAccess("ben", "topic2", PermissionRead)) + + grants, err := manager.AllGrants() + require.Nil(t, err) + require.Contains(t, grants, phil.ID) + require.Contains(t, grants, ben.ID) + }) +} + +func TestStoreOtherAccessCount(t *testing.T) { + forEachStoreBackend(t, func(t *testing.T, manager *Manager) { + require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) + require.Nil(t, manager.AddUser("ben", "benpass", RoleUser, false)) + require.Nil(t, manager.AddReservation("ben", "mytopic", PermissionReadWrite)) + + count, err := manager.OtherAccessCount("phil", "mytopic") + require.Nil(t, err) + require.Equal(t, 2, count) // ben's owner entry + everyone entry + }) } diff --git a/user/store.go b/user/store.go deleted file mode 100644 index 4460f181..00000000 --- a/user/store.go +++ /dev/null @@ -1,1059 +0,0 @@ -package user - -import ( - "database/sql" - "encoding/json" - "errors" - "net/netip" - "strings" - "time" - - "heckel.io/ntfy/v2/payments" - "heckel.io/ntfy/v2/util" -) - -// Store is the interface for a user database store -type Store interface { - // User operations - UserByID(id string) (*User, error) - User(username string) (*User, error) - UserByToken(token string) (*User, error) - UserByStripeCustomer(customerID string) (*User, error) - UserIDByUsername(username string) (string, error) - Users() ([]*User, error) - UsersCount() (int64, error) - AddUser(username, hash string, role Role, provisioned bool) error - RemoveUser(username string) error - MarkUserRemoved(userID, username string) error - RemoveDeletedUsers() error - ChangePassword(username, hash string) error - ChangeRole(username string, role Role) error - ChangeProvisioned(username string, provisioned bool) error - ChangeSettings(userID string, prefs *Prefs) error - ChangeTier(username, tierCode string) error - ResetTier(username string) error - UpdateStats(stats map[string]*Stats) error - ResetStats() error - - // Token operations - CreateToken(userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, maxTokenCount int, provisioned bool) (*Token, error) - Token(userID, token string) (*Token, error) - Tokens(userID string) ([]*Token, error) - AllProvisionedTokens() ([]*Token, error) - ChangeToken(userID, token, label string, expires time.Time) error - UpdateTokenLastAccess(updates map[string]*TokenUpdate) error - RemoveToken(userID, token string) error - RemoveProvisionedToken(token string) error - RemoveExpiredTokens() error - - // Access operations - AuthorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error) - AllGrants() (map[string][]Grant, error) - Grants(username string) ([]Grant, error) - AllowAccess(username, topicPattern string, read, write bool, ownerUsername string, provisioned bool) error - ResetAccess(username, topicPattern string) error - ResetAllProvisionedAccess() error - AddReservation(username, topic string, everyone Permission) error - RemoveReservations(username string, topics ...string) error - Reservations(username string) ([]Reservation, error) - HasReservation(username, topic string) (bool, error) - ReservationsCount(username string) (int64, error) - ReservationOwner(topic string) (string, error) - OtherAccessCount(username, topic string) (int, error) - - // Tier operations - AddTier(tier *Tier) error - UpdateTier(tier *Tier) error - RemoveTier(code string) error - Tiers() ([]*Tier, error) - Tier(code string) (*Tier, error) - TierByStripePrice(priceID string) (*Tier, error) - - // Phone operations - PhoneNumbers(userID string) ([]string, error) - AddPhoneNumber(userID, phoneNumber string) error - RemovePhoneNumber(userID, phoneNumber string) error - - // Other stuff - ChangeBilling(username string, billing *Billing) error - Close() error -} - -// storeQueries holds the database-specific SQL queries -type storeQueries struct { - // User queries - selectUserByID string - selectUserByName string - selectUserByToken string - selectUserByStripeCustomerID string - selectUsernames string - selectUserCount string - selectUserIDFromUsername string - insertUser string - updateUserPass string - updateUserRole string - updateUserProvisioned string - updateUserPrefs string - updateUserStats string - updateUserStatsResetAll string - updateUserTier string - updateUserDeleted string - deleteUser string - deleteUserTier string - deleteUsersMarked string - // Access queries - selectTopicPerms string - selectUserAllAccess string - selectUserAccess string - selectUserReservations string - selectUserReservationsCount string - selectUserReservationsOwner string - selectUserHasReservation string - selectOtherAccessCount string - upsertUserAccess string - deleteUserAccess string - deleteUserAccessProvisioned string - deleteTopicAccess string - deleteAllAccess string - // Token queries - selectToken string - selectTokens string - selectTokenCount string - selectAllProvisionedTokens string - upsertToken string - updateToken string - updateTokenLastAccess string - deleteToken string - deleteProvisionedToken string - deleteAllToken string - deleteExpiredTokens string - deleteExcessTokens string - // Tier queries - insertTier string - selectTiers string - selectTierByCode string - selectTierByPriceID string - updateTier string - deleteTier string - // Phone queries - selectPhoneNumbers string - insertPhoneNumber string - deletePhoneNumber string - // Billing queries - updateBilling string -} - -// execer is satisfied by both *sql.DB and *sql.Tx, allowing helper methods -// to be used both standalone and within a transaction. -type execer interface { - Exec(query string, args ...any) (sql.Result, error) -} - -// commonStore implements store operations that work across database backends -type commonStore struct { - db *sql.DB - queries storeQueries -} - -// User returns the user with the given username if it exists, or ErrUserNotFound otherwise -func (s *commonStore) User(username string) (*User, error) { - rows, err := s.db.Query(s.queries.selectUserByName, username) - if err != nil { - return nil, err - } - return s.readUser(rows) -} - -// UserByID returns the user with the given ID if it exists, or ErrUserNotFound otherwise -func (s *commonStore) UserByID(id string) (*User, error) { - rows, err := s.db.Query(s.queries.selectUserByID, id) - if err != nil { - return nil, err - } - return s.readUser(rows) -} - -// UserByToken returns the user with the given token if it exists and is not expired, or ErrUserNotFound otherwise -func (s *commonStore) UserByToken(token string) (*User, error) { - rows, err := s.db.Query(s.queries.selectUserByToken, token, time.Now().Unix()) - if err != nil { - return nil, err - } - return s.readUser(rows) -} - -// UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise -func (s *commonStore) UserByStripeCustomer(customerID string) (*User, error) { - rows, err := s.db.Query(s.queries.selectUserByStripeCustomerID, customerID) - if err != nil { - return nil, err - } - return s.readUser(rows) -} - -// Users returns a list of users -func (s *commonStore) Users() ([]*User, error) { - rows, err := s.db.Query(s.queries.selectUsernames) - if err != nil { - return nil, err - } - defer rows.Close() - usernames := make([]string, 0) - for rows.Next() { - var username string - if err := rows.Scan(&username); err != nil { - return nil, err - } else if err := rows.Err(); err != nil { - return nil, err - } - usernames = append(usernames, username) - } - rows.Close() - users := make([]*User, 0) - for _, username := range usernames { - user, err := s.User(username) - if err != nil { - return nil, err - } - users = append(users, user) - } - return users, nil -} - -// UsersCount returns the number of users in the database -func (s *commonStore) UsersCount() (int64, error) { - rows, err := s.db.Query(s.queries.selectUserCount) - if err != nil { - return 0, err - } - defer rows.Close() - if !rows.Next() { - return 0, errNoRows - } - var count int64 - if err := rows.Scan(&count); err != nil { - return 0, err - } - return count, nil -} - -// AddUser adds a user with the given username, password hash and role -func (s *commonStore) AddUser(username, hash string, role Role, provisioned bool) error { - if !AllowedUsername(username) || !AllowedRole(role) { - return ErrInvalidArgument - } - userID := util.RandomStringPrefix(userIDPrefix, userIDLength) - syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength) - now := time.Now().Unix() - if _, err := s.db.Exec(s.queries.insertUser, userID, username, hash, string(role), syncTopic, provisioned, now); err != nil { - if isUniqueConstraintError(err) { - return ErrUserExists - } - return err - } - return nil -} - -// RemoveUser deletes the user with the given username -func (s *commonStore) RemoveUser(username string) error { - if !AllowedUsername(username) { - return ErrInvalidArgument - } - // Rows in user_access, user_token, etc. are deleted via foreign keys - if _, err := s.db.Exec(s.queries.deleteUser, username); err != nil { - return err - } - return nil -} - -// MarkUserRemoved sets the deleted flag on the user, and deletes all access tokens -func (s *commonStore) MarkUserRemoved(userID, username string) error { - tx, err := s.db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - if err := s.resetUserAccessTx(tx, username); err != nil { - return err - } - if _, err := tx.Exec(s.queries.deleteAllToken, userID); err != nil { - return err - } - deletedTime := time.Now().Add(userHardDeleteAfterDuration).Unix() - if _, err := tx.Exec(s.queries.updateUserDeleted, deletedTime, userID); err != nil { - return err - } - return tx.Commit() -} - -// RemoveDeletedUsers deletes all users that have been marked deleted -func (s *commonStore) RemoveDeletedUsers() error { - if _, err := s.db.Exec(s.queries.deleteUsersMarked, time.Now().Unix()); err != nil { - return err - } - return nil -} - -// ChangePassword changes a user's password -func (s *commonStore) ChangePassword(username, hash string) error { - if _, err := s.db.Exec(s.queries.updateUserPass, hash, username); err != nil { - return err - } - return nil -} - -// ChangeRole changes a user's role -func (s *commonStore) ChangeRole(username string, role Role) error { - if !AllowedUsername(username) || !AllowedRole(role) { - return ErrInvalidArgument - } - tx, err := s.db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - if _, err := tx.Exec(s.queries.updateUserRole, string(role), username); err != nil { - return err - } - // If changing to admin, remove all access entries - if role == RoleAdmin { - if err := s.resetUserAccessTx(tx, username); err != nil { - return err - } - } - return tx.Commit() -} - -// ChangeProvisioned changes the provisioned status of a user -func (s *commonStore) ChangeProvisioned(username string, provisioned bool) error { - if _, err := s.db.Exec(s.queries.updateUserProvisioned, provisioned, username); err != nil { - return err - } - return nil -} - -// ChangeSettings persists the user settings -func (s *commonStore) ChangeSettings(userID string, prefs *Prefs) error { - b, err := json.Marshal(prefs) - if err != nil { - return err - } - if _, err := s.db.Exec(s.queries.updateUserPrefs, string(b), userID); err != nil { - return err - } - return nil -} - -// ChangeTier changes a user's tier using the tier code -func (s *commonStore) ChangeTier(username, tierCode string) error { - if _, err := s.db.Exec(s.queries.updateUserTier, tierCode, username); err != nil { - return err - } - return nil -} - -// ResetTier removes the tier from the given user -func (s *commonStore) ResetTier(username string) error { - if !AllowedUsername(username) && username != Everyone && username != "" { - return ErrInvalidArgument - } - _, err := s.db.Exec(s.queries.deleteUserTier, username) - return err -} - -// UpdateStats updates statistics for one or more users in a single transaction -func (s *commonStore) UpdateStats(stats map[string]*Stats) error { - tx, err := s.db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - for userID, update := range stats { - if _, err := tx.Exec(s.queries.updateUserStats, update.Messages, update.Emails, update.Calls, userID); err != nil { - return err - } - } - return tx.Commit() -} - -// ResetStats resets all user stats in the user database -func (s *commonStore) ResetStats() error { - if _, err := s.db.Exec(s.queries.updateUserStatsResetAll); err != nil { - return err - } - return nil -} - -func (s *commonStore) readUser(rows *sql.Rows) (*User, error) { - defer rows.Close() - var id, username, hash, role, prefs, syncTopic string - var provisioned bool - var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripeSubscriptionInterval, stripeMonthlyPriceID, stripeYearlyPriceID, tierID, tierCode, tierName sql.NullString - var messages, emails, calls int64 - var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64 - if !rows.Next() { - return nil, ErrUserNotFound - } - if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &provisioned, &messages, &emails, &calls, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionInterval, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil { - return nil, err - } else if err := rows.Err(); err != nil { - return nil, err - } - user := &User{ - ID: id, - Name: username, - Hash: hash, - Role: Role(role), - Prefs: &Prefs{}, - SyncTopic: syncTopic, - Provisioned: provisioned, - Stats: &Stats{ - Messages: messages, - Emails: emails, - Calls: calls, - }, - Billing: &Billing{ - StripeCustomerID: stripeCustomerID.String, // May be empty - StripeSubscriptionID: stripeSubscriptionID.String, // May be empty - StripeSubscriptionStatus: payments.SubscriptionStatus(stripeSubscriptionStatus.String), // May be empty - StripeSubscriptionInterval: payments.PriceRecurringInterval(stripeSubscriptionInterval.String), // May be empty - StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0), // May be zero - StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0), // May be zero - }, - Deleted: deleted.Valid, - } - if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil { - return nil, err - } - if tierCode.Valid { - // See readTier() when this is changed! - user.Tier = &Tier{ - ID: tierID.String, - Code: tierCode.String, - Name: tierName.String, - MessageLimit: messagesLimit.Int64, - MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second, - EmailLimit: emailsLimit.Int64, - CallLimit: callsLimit.Int64, - ReservationLimit: reservationsLimit.Int64, - AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64, - AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64, - AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second, - AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64, - StripeMonthlyPriceID: stripeMonthlyPriceID.String, // May be empty - StripeYearlyPriceID: stripeYearlyPriceID.String, // May be empty - } - } - return user, nil -} - -// CreateToken creates a new token and prunes excess tokens if the count exceeds maxTokenCount. -// If maxTokenCount is 0, no pruning is performed. -func (s *commonStore) CreateToken(userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, maxTokenCount int, provisioned bool) (*Token, error) { - tx, err := s.db.Begin() - if err != nil { - return nil, err - } - defer tx.Rollback() - if _, err := tx.Exec(s.queries.upsertToken, userID, token, label, lastAccess.Unix(), lastOrigin.String(), expires.Unix(), provisioned); err != nil { - return nil, err - } - if maxTokenCount > 0 { - var tokenCount int - if err := tx.QueryRow(s.queries.selectTokenCount, userID).Scan(&tokenCount); err != nil { - return nil, err - } - if tokenCount > maxTokenCount { - if _, err := tx.Exec(s.queries.deleteExcessTokens, userID, userID, maxTokenCount); err != nil { - return nil, err - } - } - } - if err := tx.Commit(); err != nil { - return nil, err - } - return &Token{ - Value: token, - Label: label, - LastAccess: lastAccess, - LastOrigin: lastOrigin, - Expires: expires, - Provisioned: provisioned, - }, nil -} - -// Token returns a specific token for a user -func (s *commonStore) Token(userID, token string) (*Token, error) { - rows, err := s.db.Query(s.queries.selectToken, userID, token) - if err != nil { - return nil, err - } - defer rows.Close() - return s.readToken(rows) -} - -// Tokens returns all existing tokens for the user with the given user ID -func (s *commonStore) Tokens(userID string) ([]*Token, error) { - rows, err := s.db.Query(s.queries.selectTokens, userID) - if err != nil { - return nil, err - } - defer rows.Close() - tokens := make([]*Token, 0) - for { - token, err := s.readToken(rows) - if errors.Is(err, ErrTokenNotFound) { - break - } else if err != nil { - return nil, err - } - tokens = append(tokens, token) - } - return tokens, nil -} - -// AllProvisionedTokens returns all provisioned tokens -func (s *commonStore) AllProvisionedTokens() ([]*Token, error) { - rows, err := s.db.Query(s.queries.selectAllProvisionedTokens) - if err != nil { - return nil, err - } - defer rows.Close() - tokens := make([]*Token, 0) - for { - token, err := s.readToken(rows) - if errors.Is(err, ErrTokenNotFound) { - break - } else if err != nil { - return nil, err - } - tokens = append(tokens, token) - } - return tokens, nil -} - -// ChangeToken updates a token's label and expiry time -func (s *commonStore) ChangeToken(userID, token, label string, expires time.Time) error { - if _, err := s.db.Exec(s.queries.updateToken, label, expires.Unix(), userID, token); err != nil { - return err - } - return nil -} - -// UpdateTokenLastAccess updates the last access time and origin for one or more tokens in a single transaction -func (s *commonStore) UpdateTokenLastAccess(updates map[string]*TokenUpdate) error { - tx, err := s.db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - for token, update := range updates { - if _, err := tx.Exec(s.queries.updateTokenLastAccess, update.LastAccess.Unix(), update.LastOrigin.String(), token); err != nil { - return err - } - } - return tx.Commit() -} - -// RemoveToken deletes the token -func (s *commonStore) RemoveToken(userID, token string) error { - if token == "" { - return errNoTokenProvided - } - if _, err := s.db.Exec(s.queries.deleteToken, userID, token); err != nil { - return err - } - return nil -} - -// RemoveProvisionedToken deletes a provisioned token by value, regardless of user -func (s *commonStore) RemoveProvisionedToken(token string) error { - if token == "" { - return errNoTokenProvided - } - if _, err := s.db.Exec(s.queries.deleteProvisionedToken, token); err != nil { - return err - } - return nil -} - -// RemoveExpiredTokens deletes all expired tokens from the database -func (s *commonStore) RemoveExpiredTokens() error { - if _, err := s.db.Exec(s.queries.deleteExpiredTokens, time.Now().Unix()); err != nil { - return err - } - return nil -} - -func (s *commonStore) readToken(rows *sql.Rows) (*Token, error) { - var token, label, lastOrigin string - var lastAccess, expires int64 - var provisioned bool - if !rows.Next() { - return nil, ErrTokenNotFound - } - if err := rows.Scan(&token, &label, &lastAccess, &lastOrigin, &expires, &provisioned); err != nil { - return nil, err - } else if err := rows.Err(); err != nil { - return nil, err - } - lastOriginIP, err := netip.ParseAddr(lastOrigin) - if err != nil { - lastOriginIP = netip.IPv4Unspecified() - } - return &Token{ - Value: token, - Label: label, - LastAccess: time.Unix(lastAccess, 0), - LastOrigin: lastOriginIP, - Expires: time.Unix(expires, 0), - Provisioned: provisioned, - }, nil -} - -// AuthorizeTopicAccess returns the read/write permissions for the given username and topic. -// The found return value indicates whether an ACL entry was found at all. -// -// - The query may return two rows (one for everyone, and one for the user), but prioritizes the user. -// - Furthermore, the query prioritizes more specific permissions (longer!) over more generic ones, e.g. "test*" > "*" -// - It also prioritizes write permissions over read permissions -func (s *commonStore) AuthorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error) { - rows, err := s.db.Query(s.queries.selectTopicPerms, Everyone, usernameOrEveryone, topic) - if err != nil { - return false, false, false, err - } - defer rows.Close() - if !rows.Next() { - return false, false, false, nil - } - if err := rows.Scan(&read, &write); err != nil { - return false, false, false, err - } else if err := rows.Err(); err != nil { - return false, false, false, err - } - return read, write, true, nil -} - -// AllGrants returns all user-specific access control entries, mapped to their respective user IDs -func (s *commonStore) AllGrants() (map[string][]Grant, error) { - rows, err := s.db.Query(s.queries.selectUserAllAccess) - if err != nil { - return nil, err - } - defer rows.Close() - grants := make(map[string][]Grant, 0) - for rows.Next() { - var userID, topic string - var read, write, provisioned bool - if err := rows.Scan(&userID, &topic, &read, &write, &provisioned); err != nil { - return nil, err - } else if err := rows.Err(); err != nil { - return nil, err - } - if _, ok := grants[userID]; !ok { - grants[userID] = make([]Grant, 0) - } - grants[userID] = append(grants[userID], Grant{ - TopicPattern: fromSQLWildcard(topic), - Permission: NewPermission(read, write), - Provisioned: provisioned, - }) - } - return grants, nil -} - -// Grants returns all user-specific access control entries -func (s *commonStore) Grants(username string) ([]Grant, error) { - rows, err := s.db.Query(s.queries.selectUserAccess, username) - if err != nil { - return nil, err - } - defer rows.Close() - grants := make([]Grant, 0) - for rows.Next() { - var topic string - var read, write, provisioned bool - if err := rows.Scan(&topic, &read, &write, &provisioned); err != nil { - return nil, err - } else if err := rows.Err(); err != nil { - return nil, err - } - grants = append(grants, Grant{ - TopicPattern: fromSQLWildcard(topic), - Permission: NewPermission(read, write), - Provisioned: provisioned, - }) - } - return grants, nil -} - -// AllowAccess adds or updates an entry in the access control list -func (s *commonStore) AllowAccess(username, topicPattern string, read, write bool, ownerUsername string, provisioned bool) error { - return s.allowAccessTx(s.db, username, topicPattern, read, write, ownerUsername, provisioned) -} - -func (s *commonStore) allowAccessTx(tx execer, username, topicPattern string, read, write bool, ownerUsername string, provisioned bool) error { - if !AllowedUsername(username) && username != Everyone { - return ErrInvalidArgument - } else if !AllowedTopicPattern(topicPattern) { - return ErrInvalidArgument - } - _, err := tx.Exec(s.queries.upsertUserAccess, username, toSQLWildcard(topicPattern), read, write, ownerUsername, ownerUsername, provisioned) - return err -} - -// ResetAccess removes an access control list entry -func (s *commonStore) ResetAccess(username, topicPattern string) error { - if username == "" && topicPattern == "" { - _, err := s.db.Exec(s.queries.deleteAllAccess) - return err - } else if topicPattern == "" { - return s.resetUserAccessTx(s.db, username) - } - return s.resetTopicAccessTx(s.db, username, topicPattern) -} - -func (s *commonStore) resetUserAccessTx(tx execer, username string) error { - if !AllowedUsername(username) && username != Everyone { - return ErrInvalidArgument - } - _, err := tx.Exec(s.queries.deleteUserAccess, username, username) - return err -} - -func (s *commonStore) resetTopicAccessTx(tx execer, username, topicPattern string) error { - if !AllowedUsername(username) && username != Everyone && username != "" { - return ErrInvalidArgument - } else if !AllowedTopicPattern(topicPattern) && topicPattern != "" { - return ErrInvalidArgument - } - _, err := tx.Exec(s.queries.deleteTopicAccess, username, username, toSQLWildcard(topicPattern)) - return err -} - -// ResetAllProvisionedAccess removes all provisioned access control entries -func (s *commonStore) ResetAllProvisionedAccess() error { - if _, err := s.db.Exec(s.queries.deleteUserAccessProvisioned); err != nil { - return err - } - return nil -} - -// AddReservation creates two access control entries for the given topic: one with full read/write -// access for the given user, and one for Everyone with the given permission. Both entries are -// created atomically in a single transaction. -func (s *commonStore) AddReservation(username, topic string, everyone Permission) error { - tx, err := s.db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - if err := s.allowAccessTx(tx, username, topic, true, true, username, false); err != nil { - return err - } - if err := s.allowAccessTx(tx, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username, false); err != nil { - return err - } - return tx.Commit() -} - -// RemoveReservations deletes the access control entries associated with the given username/topic, -// as well as all entries with Everyone/topic. All deletions are performed atomically in a single -// transaction. -func (s *commonStore) RemoveReservations(username string, topics ...string) error { - tx, err := s.db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - for _, topic := range topics { - if err := s.resetTopicAccessTx(tx, username, topic); err != nil { - return err - } - if err := s.resetTopicAccessTx(tx, Everyone, topic); err != nil { - return err - } - } - return tx.Commit() -} - -// Reservations returns all user-owned topics, and the associated everyone-access -func (s *commonStore) Reservations(username string) ([]Reservation, error) { - rows, err := s.db.Query(s.queries.selectUserReservations, Everyone, username) - if err != nil { - return nil, err - } - defer rows.Close() - reservations := make([]Reservation, 0) - for rows.Next() { - var topic string - var ownerRead, ownerWrite bool - var everyoneRead, everyoneWrite sql.NullBool - if err := rows.Scan(&topic, &ownerRead, &ownerWrite, &everyoneRead, &everyoneWrite); err != nil { - return nil, err - } else if err := rows.Err(); err != nil { - return nil, err - } - reservations = append(reservations, Reservation{ - Topic: fromSQLWildcard(topic), - Owner: NewPermission(ownerRead, ownerWrite), - Everyone: NewPermission(everyoneRead.Bool, everyoneWrite.Bool), - }) - } - return reservations, nil -} - -// HasReservation returns true if the given topic access is owned by the user -func (s *commonStore) HasReservation(username, topic string) (bool, error) { - rows, err := s.db.Query(s.queries.selectUserHasReservation, username, escapeUnderscore(topic)) - if err != nil { - return false, err - } - defer rows.Close() - if !rows.Next() { - return false, errNoRows - } - var count int64 - if err := rows.Scan(&count); err != nil { - return false, err - } - return count > 0, nil -} - -// ReservationsCount returns the number of reservations owned by this user -func (s *commonStore) ReservationsCount(username string) (int64, error) { - rows, err := s.db.Query(s.queries.selectUserReservationsCount, username) - if err != nil { - return 0, err - } - defer rows.Close() - if !rows.Next() { - return 0, errNoRows - } - var count int64 - if err := rows.Scan(&count); err != nil { - return 0, err - } - return count, nil -} - -// ReservationOwner returns user ID of the user that owns this topic, or an empty string if it's not owned by anyone -func (s *commonStore) ReservationOwner(topic string) (string, error) { - rows, err := s.db.Query(s.queries.selectUserReservationsOwner, escapeUnderscore(topic)) - if err != nil { - return "", err - } - defer rows.Close() - if !rows.Next() { - return "", nil - } - var ownerUserID string - if err := rows.Scan(&ownerUserID); err != nil { - return "", err - } - return ownerUserID, nil -} - -// OtherAccessCount returns the number of access entries for the given topic that are not owned by the user -func (s *commonStore) OtherAccessCount(username, topic string) (int, error) { - rows, err := s.db.Query(s.queries.selectOtherAccessCount, escapeUnderscore(topic), escapeUnderscore(topic), username) - if err != nil { - return 0, err - } - defer rows.Close() - if !rows.Next() { - return 0, errNoRows - } - var count int - if err := rows.Scan(&count); err != nil { - return 0, err - } - return count, nil -} - -// AddTier creates a new tier in the database -func (s *commonStore) AddTier(tier *Tier) error { - if tier.ID == "" { - tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength) - } - if _, err := s.db.Exec(s.queries.insertTier, tier.ID, tier.Code, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.CallLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID)); err != nil { - return err - } - return nil -} - -// UpdateTier updates a tier's properties in the database -func (s *commonStore) UpdateTier(tier *Tier) error { - if _, err := s.db.Exec(s.queries.updateTier, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.CallLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID), tier.Code); err != nil { - return err - } - return nil -} - -// RemoveTier deletes the tier with the given code -func (s *commonStore) RemoveTier(code string) error { - if !AllowedTier(code) { - return ErrInvalidArgument - } - // This fails if any user has this tier - if _, err := s.db.Exec(s.queries.deleteTier, code); err != nil { - return err - } - return nil -} - -// Tiers returns a list of all Tier structs -func (s *commonStore) Tiers() ([]*Tier, error) { - rows, err := s.db.Query(s.queries.selectTiers) - if err != nil { - return nil, err - } - defer rows.Close() - tiers := make([]*Tier, 0) - for { - tier, err := s.readTier(rows) - if errors.Is(err, ErrTierNotFound) { - break - } else if err != nil { - return nil, err - } - tiers = append(tiers, tier) - } - return tiers, nil -} - -// Tier returns a Tier based on the code, or ErrTierNotFound if it does not exist -func (s *commonStore) Tier(code string) (*Tier, error) { - rows, err := s.db.Query(s.queries.selectTierByCode, code) - if err != nil { - return nil, err - } - defer rows.Close() - return s.readTier(rows) -} - -// TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist -func (s *commonStore) TierByStripePrice(priceID string) (*Tier, error) { - rows, err := s.db.Query(s.queries.selectTierByPriceID, priceID, priceID) - if err != nil { - return nil, err - } - defer rows.Close() - return s.readTier(rows) -} - -func (s *commonStore) readTier(rows *sql.Rows) (*Tier, error) { - var id, code, name string - var stripeMonthlyPriceID, stripeYearlyPriceID sql.NullString - var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit sql.NullInt64 - if !rows.Next() { - return nil, ErrTierNotFound - } - if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil { - return nil, err - } else if err := rows.Err(); err != nil { - return nil, err - } - // When changed, note readUser() as well - return &Tier{ - ID: id, - Code: code, - Name: name, - MessageLimit: messagesLimit.Int64, - MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second, - EmailLimit: emailsLimit.Int64, - CallLimit: callsLimit.Int64, - ReservationLimit: reservationsLimit.Int64, - AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64, - AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64, - AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second, - AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64, - StripeMonthlyPriceID: stripeMonthlyPriceID.String, // May be empty - StripeYearlyPriceID: stripeYearlyPriceID.String, // May be empty - }, nil -} - -// PhoneNumbers returns all phone numbers for the user with the given user ID -func (s *commonStore) PhoneNumbers(userID string) ([]string, error) { - rows, err := s.db.Query(s.queries.selectPhoneNumbers, userID) - if err != nil { - return nil, err - } - defer rows.Close() - phoneNumbers := make([]string, 0) - for { - phoneNumber, err := s.readPhoneNumber(rows) - if errors.Is(err, ErrPhoneNumberNotFound) { - break - } else if err != nil { - return nil, err - } - phoneNumbers = append(phoneNumbers, phoneNumber) - } - return phoneNumbers, nil -} - -// AddPhoneNumber adds a phone number to the user with the given user ID -func (s *commonStore) AddPhoneNumber(userID, phoneNumber string) error { - if _, err := s.db.Exec(s.queries.insertPhoneNumber, userID, phoneNumber); err != nil { - if isUniqueConstraintError(err) { - return ErrPhoneNumberExists - } - return err - } - return nil -} - -// RemovePhoneNumber deletes a phone number from the user with the given user ID -func (s *commonStore) RemovePhoneNumber(userID, phoneNumber string) error { - _, err := s.db.Exec(s.queries.deletePhoneNumber, userID, phoneNumber) - return err -} -func (s *commonStore) readPhoneNumber(rows *sql.Rows) (string, error) { - var phoneNumber string - if !rows.Next() { - return "", ErrPhoneNumberNotFound - } - if err := rows.Scan(&phoneNumber); err != nil { - return "", err - } else if err := rows.Err(); err != nil { - return "", err - } - return phoneNumber, nil -} - -// ChangeBilling updates a user's billing fields -func (s *commonStore) ChangeBilling(username string, billing *Billing) error { - if _, err := s.db.Exec(s.queries.updateBilling, nullString(billing.StripeCustomerID), nullString(billing.StripeSubscriptionID), nullString(string(billing.StripeSubscriptionStatus)), nullString(string(billing.StripeSubscriptionInterval)), nullInt64(billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(billing.StripeSubscriptionCancelAt.Unix()), username); err != nil { - return err - } - return nil -} - -// UserIDByUsername returns the user ID for the given username -func (s *commonStore) UserIDByUsername(username string) (string, error) { - rows, err := s.db.Query(s.queries.selectUserIDFromUsername, username) - if err != nil { - return "", err - } - defer rows.Close() - if !rows.Next() { - return "", ErrUserNotFound - } - var userID string - if err := rows.Scan(&userID); err != nil { - return "", err - } - return userID, nil -} - -// Close closes the underlying database -func (s *commonStore) Close() error { - return s.db.Close() -} - -// isUniqueConstraintError checks if the error is a unique constraint violation for both SQLite and PostgreSQL -func isUniqueConstraintError(err error) bool { - errStr := err.Error() - return strings.Contains(errStr, "UNIQUE constraint failed") || strings.Contains(errStr, "23505") -} diff --git a/user/store_test.go b/user/store_test.go deleted file mode 100644 index 6ed88722..00000000 --- a/user/store_test.go +++ /dev/null @@ -1,713 +0,0 @@ -package user_test - -import ( - "net/netip" - "path/filepath" - "testing" - "time" - - "github.com/stretchr/testify/require" - dbtest "heckel.io/ntfy/v2/db/test" - "heckel.io/ntfy/v2/user" -) - -func forEachStoreBackend(t *testing.T, f func(t *testing.T, store user.Store)) { - t.Run("sqlite", func(t *testing.T) { - store, err := user.NewSQLiteStore(filepath.Join(t.TempDir(), "user.db"), "") - require.Nil(t, err) - t.Cleanup(func() { store.Close() }) - f(t, store) - }) - t.Run("postgres", func(t *testing.T) { - testDB := dbtest.CreateTestPostgres(t) - store, err := user.NewPostgresStore(testDB) - require.Nil(t, err) - f(t, store) - }) -} - -func TestStoreAddUser(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - u, err := store.User("phil") - require.Nil(t, err) - require.Equal(t, "phil", u.Name) - require.Equal(t, user.RoleUser, u.Role) - require.False(t, u.Provisioned) - require.NotEmpty(t, u.ID) - require.NotEmpty(t, u.SyncTopic) - }) -} - -func TestStoreAddUserAlreadyExists(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - require.Equal(t, user.ErrUserExists, store.AddUser("phil", "philhash", user.RoleUser, false)) - }) -} - -func TestStoreRemoveUser(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - u, err := store.User("phil") - require.Nil(t, err) - require.Equal(t, "phil", u.Name) - - require.Nil(t, store.RemoveUser("phil")) - _, err = store.User("phil") - require.Equal(t, user.ErrUserNotFound, err) - }) -} - -func TestStoreUserByID(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleAdmin, false)) - u, err := store.User("phil") - require.Nil(t, err) - - u2, err := store.UserByID(u.ID) - require.Nil(t, err) - require.Equal(t, u.Name, u2.Name) - require.Equal(t, u.ID, u2.ID) - }) -} - -func TestStoreUserByToken(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - u, err := store.User("phil") - require.Nil(t, err) - - tk, err := store.CreateToken(u.ID, "tk_test123", "test token", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(24*time.Hour), 0, false) - require.Nil(t, err) - require.Equal(t, "tk_test123", tk.Value) - - u2, err := store.UserByToken(tk.Value) - require.Nil(t, err) - require.Equal(t, "phil", u2.Name) - }) -} - -func TestStoreUserByStripeCustomer(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - require.Nil(t, store.ChangeBilling("phil", &user.Billing{ - StripeCustomerID: "cus_test123", - StripeSubscriptionID: "sub_test123", - })) - - u, err := store.UserByStripeCustomer("cus_test123") - require.Nil(t, err) - require.Equal(t, "phil", u.Name) - require.Equal(t, "cus_test123", u.Billing.StripeCustomerID) - }) -} - -func TestStoreUsers(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - require.Nil(t, store.AddUser("ben", "benhash", user.RoleAdmin, false)) - - users, err := store.Users() - require.Nil(t, err) - require.True(t, len(users) >= 3) // phil, ben, and the everyone user - }) -} - -func TestStoreUsersCount(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - count, err := store.UsersCount() - require.Nil(t, err) - require.True(t, count >= 1) // At least the everyone user - - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - count2, err := store.UsersCount() - require.Nil(t, err) - require.Equal(t, count+1, count2) - }) -} - -func TestStoreChangePassword(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - u, err := store.User("phil") - require.Nil(t, err) - require.Equal(t, "philhash", u.Hash) - - require.Nil(t, store.ChangePassword("phil", "newhash")) - u, err = store.User("phil") - require.Nil(t, err) - require.Equal(t, "newhash", u.Hash) - }) -} - -func TestStoreChangeRole(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - u, err := store.User("phil") - require.Nil(t, err) - require.Equal(t, user.RoleUser, u.Role) - - require.Nil(t, store.ChangeRole("phil", user.RoleAdmin)) - u, err = store.User("phil") - require.Nil(t, err) - require.Equal(t, user.RoleAdmin, u.Role) - }) -} - -func TestStoreTokens(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - u, err := store.User("phil") - require.Nil(t, err) - - now := time.Now() - expires := now.Add(24 * time.Hour) - origin := netip.MustParseAddr("9.9.9.9") - - tk, err := store.CreateToken(u.ID, "tk_abc", "my token", now, origin, expires, 0, false) - require.Nil(t, err) - require.Equal(t, "tk_abc", tk.Value) - require.Equal(t, "my token", tk.Label) - - // Get single token - tk2, err := store.Token(u.ID, "tk_abc") - require.Nil(t, err) - require.Equal(t, "tk_abc", tk2.Value) - require.Equal(t, "my token", tk2.Label) - - // Get all tokens - tokens, err := store.Tokens(u.ID) - require.Nil(t, err) - require.Len(t, tokens, 1) - require.Equal(t, "tk_abc", tokens[0].Value) - }) -} - -func TestStoreTokenChange(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - u, err := store.User("phil") - require.Nil(t, err) - - expires := time.Now().Add(time.Hour) - _, err = store.CreateToken(u.ID, "tk_abc", "old label", time.Now(), netip.MustParseAddr("1.2.3.4"), expires, 0, false) - require.Nil(t, err) - - newExpires := time.Now().Add(2 * time.Hour) - require.Nil(t, store.ChangeToken(u.ID, "tk_abc", "new label", newExpires)) - tk, err := store.Token(u.ID, "tk_abc") - require.Nil(t, err) - require.Equal(t, "new label", tk.Label) - require.Equal(t, newExpires.Unix(), tk.Expires.Unix()) - }) -} - -func TestStoreTokenRemove(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - u, err := store.User("phil") - require.Nil(t, err) - - _, err = store.CreateToken(u.ID, "tk_abc", "label", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), 0, false) - require.Nil(t, err) - - require.Nil(t, store.RemoveToken(u.ID, "tk_abc")) - _, err = store.Token(u.ID, "tk_abc") - require.Equal(t, user.ErrTokenNotFound, err) - }) -} - -func TestStoreTokenRemoveExpired(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - u, err := store.User("phil") - require.Nil(t, err) - - // Create expired token and active token - _, err = store.CreateToken(u.ID, "tk_expired", "expired", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(-time.Hour), 0, false) - require.Nil(t, err) - _, err = store.CreateToken(u.ID, "tk_active", "active", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), 0, false) - require.Nil(t, err) - - require.Nil(t, store.RemoveExpiredTokens()) - - // Expired token should be gone - _, err = store.Token(u.ID, "tk_expired") - require.Equal(t, user.ErrTokenNotFound, err) - - // Active token should still exist - tk, err := store.Token(u.ID, "tk_active") - require.Nil(t, err) - require.Equal(t, "tk_active", tk.Value) - }) -} - -func TestStoreTokenCreatePrunesExcess(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - u, err := store.User("phil") - require.Nil(t, err) - - // Create 2 tokens with no pruning - for i, name := range []string{"tk_a", "tk_b"} { - _, err = store.CreateToken(u.ID, name, name, time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Duration(i+1)*time.Hour), 0, false) - require.Nil(t, err) - } - - // Create a 3rd token with maxTokenCount=2, which should prune tk_a (earliest expiry) - _, err = store.CreateToken(u.ID, "tk_c", "tk_c", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(3*time.Hour), 2, false) - require.Nil(t, err) - - tokens, err := store.Tokens(u.ID) - require.Nil(t, err) - require.Equal(t, 2, len(tokens)) - - // tk_a should be removed (earliest expiry) - _, err = store.Token(u.ID, "tk_a") - require.Equal(t, user.ErrTokenNotFound, err) - - // tk_b and tk_c should remain - _, err = store.Token(u.ID, "tk_b") - require.Nil(t, err) - _, err = store.Token(u.ID, "tk_c") - require.Nil(t, err) - }) -} - -func TestStoreTokenUpdateLastAccess(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - u, err := store.User("phil") - require.Nil(t, err) - - _, err = store.CreateToken(u.ID, "tk_abc", "label", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), 0, false) - require.Nil(t, err) - - newTime := time.Now().Add(5 * time.Minute) - newOrigin := netip.MustParseAddr("5.5.5.5") - require.Nil(t, store.UpdateTokenLastAccess(map[string]*user.TokenUpdate{"tk_abc": {LastAccess: newTime, LastOrigin: newOrigin}})) - - tk, err := store.Token(u.ID, "tk_abc") - require.Nil(t, err) - require.Equal(t, newTime.Unix(), tk.LastAccess.Unix()) - require.Equal(t, newOrigin, tk.LastOrigin) - }) -} - -func TestStoreAllowAccess(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - - require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "", false)) - grants, err := store.Grants("phil") - require.Nil(t, err) - require.Len(t, grants, 1) - require.Equal(t, "mytopic", grants[0].TopicPattern) - require.True(t, grants[0].Permission.IsReadWrite()) - }) -} - -func TestStoreAllowAccessReadOnly(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - - require.Nil(t, store.AllowAccess("phil", "announcements", true, false, "", false)) - grants, err := store.Grants("phil") - require.Nil(t, err) - require.Len(t, grants, 1) - require.True(t, grants[0].Permission.IsRead()) - require.False(t, grants[0].Permission.IsWrite()) - }) -} - -func TestStoreResetAccess(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - require.Nil(t, store.AllowAccess("phil", "topic1", true, true, "", false)) - require.Nil(t, store.AllowAccess("phil", "topic2", true, false, "", false)) - - grants, err := store.Grants("phil") - require.Nil(t, err) - require.Len(t, grants, 2) - - require.Nil(t, store.ResetAccess("phil", "topic1")) - grants, err = store.Grants("phil") - require.Nil(t, err) - require.Len(t, grants, 1) - require.Equal(t, "topic2", grants[0].TopicPattern) - }) -} - -func TestStoreResetAccessAll(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - require.Nil(t, store.AllowAccess("phil", "topic1", true, true, "", false)) - require.Nil(t, store.AllowAccess("phil", "topic2", true, false, "", false)) - - require.Nil(t, store.ResetAccess("phil", "")) - grants, err := store.Grants("phil") - require.Nil(t, err) - require.Len(t, grants, 0) - }) -} - -func TestStoreAuthorizeTopicAccess(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "", false)) - - read, write, found, err := store.AuthorizeTopicAccess("phil", "mytopic") - require.Nil(t, err) - require.True(t, found) - require.True(t, read) - require.True(t, write) - }) -} - -func TestStoreAuthorizeTopicAccessNotFound(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - - _, _, found, err := store.AuthorizeTopicAccess("phil", "other") - require.Nil(t, err) - require.False(t, found) - }) -} - -func TestStoreAuthorizeTopicAccessDenyAll(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - require.Nil(t, store.AllowAccess("phil", "secret", false, false, "", false)) - - read, write, found, err := store.AuthorizeTopicAccess("phil", "secret") - require.Nil(t, err) - require.True(t, found) - require.False(t, read) - require.False(t, write) - }) -} - -func TestStoreReservations(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "phil", false)) - require.Nil(t, store.AllowAccess(user.Everyone, "mytopic", true, false, "phil", false)) - - reservations, err := store.Reservations("phil") - require.Nil(t, err) - require.Len(t, reservations, 1) - require.Equal(t, "mytopic", reservations[0].Topic) - require.True(t, reservations[0].Owner.IsReadWrite()) - require.True(t, reservations[0].Everyone.IsRead()) - require.False(t, reservations[0].Everyone.IsWrite()) - }) -} - -func TestStoreReservationsCount(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - require.Nil(t, store.AllowAccess("phil", "topic1", true, true, "phil", false)) - require.Nil(t, store.AllowAccess("phil", "topic2", true, true, "phil", false)) - - count, err := store.ReservationsCount("phil") - require.Nil(t, err) - require.Equal(t, int64(2), count) - }) -} - -func TestStoreHasReservation(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "phil", false)) - - has, err := store.HasReservation("phil", "mytopic") - require.Nil(t, err) - require.True(t, has) - - has, err = store.HasReservation("phil", "other") - require.Nil(t, err) - require.False(t, has) - }) -} - -func TestStoreReservationOwner(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "phil", false)) - - owner, err := store.ReservationOwner("mytopic") - require.Nil(t, err) - require.NotEmpty(t, owner) // Returns the user ID - - owner, err = store.ReservationOwner("unowned") - require.Nil(t, err) - require.Empty(t, owner) - }) -} - -func TestStoreTiers(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - tier := &user.Tier{ - ID: "ti_test", - Code: "pro", - Name: "Pro", - MessageLimit: 5000, - MessageExpiryDuration: 24 * time.Hour, - EmailLimit: 100, - CallLimit: 10, - ReservationLimit: 20, - AttachmentFileSizeLimit: 10 * 1024 * 1024, - AttachmentTotalSizeLimit: 100 * 1024 * 1024, - AttachmentExpiryDuration: 48 * time.Hour, - AttachmentBandwidthLimit: 500 * 1024 * 1024, - } - require.Nil(t, store.AddTier(tier)) - - // Get by code - t2, err := store.Tier("pro") - require.Nil(t, err) - require.Equal(t, "ti_test", t2.ID) - require.Equal(t, "pro", t2.Code) - require.Equal(t, "Pro", t2.Name) - require.Equal(t, int64(5000), t2.MessageLimit) - require.Equal(t, int64(100), t2.EmailLimit) - require.Equal(t, int64(10), t2.CallLimit) - require.Equal(t, int64(20), t2.ReservationLimit) - - // List all tiers - tiers, err := store.Tiers() - require.Nil(t, err) - require.Len(t, tiers, 1) - require.Equal(t, "pro", tiers[0].Code) - }) -} - -func TestStoreTierUpdate(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - tier := &user.Tier{ - ID: "ti_test", - Code: "pro", - Name: "Pro", - } - require.Nil(t, store.AddTier(tier)) - - tier.Name = "Professional" - tier.MessageLimit = 9999 - require.Nil(t, store.UpdateTier(tier)) - - t2, err := store.Tier("pro") - require.Nil(t, err) - require.Equal(t, "Professional", t2.Name) - require.Equal(t, int64(9999), t2.MessageLimit) - }) -} - -func TestStoreTierRemove(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - tier := &user.Tier{ - ID: "ti_test", - Code: "pro", - Name: "Pro", - } - require.Nil(t, store.AddTier(tier)) - - t2, err := store.Tier("pro") - require.Nil(t, err) - require.Equal(t, "pro", t2.Code) - - require.Nil(t, store.RemoveTier("pro")) - _, err = store.Tier("pro") - require.Equal(t, user.ErrTierNotFound, err) - }) -} - -func TestStoreTierByStripePrice(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - tier := &user.Tier{ - ID: "ti_test", - Code: "pro", - Name: "Pro", - StripeMonthlyPriceID: "price_monthly", - StripeYearlyPriceID: "price_yearly", - } - require.Nil(t, store.AddTier(tier)) - - t2, err := store.TierByStripePrice("price_monthly") - require.Nil(t, err) - require.Equal(t, "pro", t2.Code) - - t3, err := store.TierByStripePrice("price_yearly") - require.Nil(t, err) - require.Equal(t, "pro", t3.Code) - }) -} - -func TestStoreChangeTier(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - tier := &user.Tier{ - ID: "ti_test", - Code: "pro", - Name: "Pro", - } - require.Nil(t, store.AddTier(tier)) - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - require.Nil(t, store.ChangeTier("phil", "pro")) - - u, err := store.User("phil") - require.Nil(t, err) - require.NotNil(t, u.Tier) - require.Equal(t, "pro", u.Tier.Code) - }) -} - -func TestStorePhoneNumbers(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - u, err := store.User("phil") - require.Nil(t, err) - - require.Nil(t, store.AddPhoneNumber(u.ID, "+1234567890")) - require.Nil(t, store.AddPhoneNumber(u.ID, "+0987654321")) - - numbers, err := store.PhoneNumbers(u.ID) - require.Nil(t, err) - require.Len(t, numbers, 2) - - require.Nil(t, store.RemovePhoneNumber(u.ID, "+1234567890")) - numbers, err = store.PhoneNumbers(u.ID) - require.Nil(t, err) - require.Len(t, numbers, 1) - require.Equal(t, "+0987654321", numbers[0]) - }) -} - -func TestStoreChangeSettings(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - u, err := store.User("phil") - require.Nil(t, err) - - lang := "de" - prefs := &user.Prefs{Language: &lang} - require.Nil(t, store.ChangeSettings(u.ID, prefs)) - - u2, err := store.User("phil") - require.Nil(t, err) - require.NotNil(t, u2.Prefs) - require.Equal(t, "de", *u2.Prefs.Language) - }) -} - -func TestStoreChangeBilling(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - - billing := &user.Billing{ - StripeCustomerID: "cus_123", - StripeSubscriptionID: "sub_456", - } - require.Nil(t, store.ChangeBilling("phil", billing)) - - u, err := store.User("phil") - require.Nil(t, err) - require.Equal(t, "cus_123", u.Billing.StripeCustomerID) - require.Equal(t, "sub_456", u.Billing.StripeSubscriptionID) - }) -} - -func TestStoreUpdateStats(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - u, err := store.User("phil") - require.Nil(t, err) - - stats := &user.Stats{Messages: 42, Emails: 3, Calls: 1} - require.Nil(t, store.UpdateStats(map[string]*user.Stats{u.ID: stats})) - - u2, err := store.User("phil") - require.Nil(t, err) - require.Equal(t, int64(42), u2.Stats.Messages) - require.Equal(t, int64(3), u2.Stats.Emails) - require.Equal(t, int64(1), u2.Stats.Calls) - }) -} - -func TestStoreResetStats(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - u, err := store.User("phil") - require.Nil(t, err) - - require.Nil(t, store.UpdateStats(map[string]*user.Stats{u.ID: {Messages: 42, Emails: 3, Calls: 1}})) - require.Nil(t, store.ResetStats()) - - u2, err := store.User("phil") - require.Nil(t, err) - require.Equal(t, int64(0), u2.Stats.Messages) - require.Equal(t, int64(0), u2.Stats.Emails) - require.Equal(t, int64(0), u2.Stats.Calls) - }) -} - -func TestStoreMarkUserRemoved(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - u, err := store.User("phil") - require.Nil(t, err) - - require.Nil(t, store.MarkUserRemoved(u.ID, u.Name)) - - u2, err := store.User("phil") - require.Nil(t, err) - require.True(t, u2.Deleted) - }) -} - -func TestStoreRemoveDeletedUsers(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - u, err := store.User("phil") - require.Nil(t, err) - - require.Nil(t, store.MarkUserRemoved(u.ID, u.Name)) - - // RemoveDeletedUsers only removes users past the hard-delete duration (7 days). - // Immediately after marking, the user should still exist. - require.Nil(t, store.RemoveDeletedUsers()) - u2, err := store.User("phil") - require.Nil(t, err) - require.True(t, u2.Deleted) - }) -} - -func TestStoreAllGrants(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - require.Nil(t, store.AddUser("ben", "benhash", user.RoleUser, false)) - phil, err := store.User("phil") - require.Nil(t, err) - ben, err := store.User("ben") - require.Nil(t, err) - - require.Nil(t, store.AllowAccess("phil", "topic1", true, true, "", false)) - require.Nil(t, store.AllowAccess("ben", "topic2", true, false, "", false)) - - grants, err := store.AllGrants() - require.Nil(t, err) - require.Contains(t, grants, phil.ID) - require.Contains(t, grants, ben.ID) - }) -} - -func TestStoreOtherAccessCount(t *testing.T) { - forEachStoreBackend(t, func(t *testing.T, store user.Store) { - require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) - require.Nil(t, store.AddUser("ben", "benhash", user.RoleUser, false)) - require.Nil(t, store.AllowAccess("ben", "mytopic", true, true, "ben", false)) - - count, err := store.OtherAccessCount("phil", "mytopic") - require.Nil(t, err) - require.Equal(t, 1, count) - }) -} diff --git a/user/types.go b/user/types.go index be589643..f5d5e70e 100644 --- a/user/types.go +++ b/user/types.go @@ -1,12 +1,14 @@ package user import ( + "database/sql" "errors" - "heckel.io/ntfy/v2/log" - "heckel.io/ntfy/v2/payments" "net/netip" "strings" "time" + + "heckel.io/ntfy/v2/log" + "heckel.io/ntfy/v2/payments" ) // User is a struct that represents a user @@ -273,3 +275,78 @@ var ( ErrProvisionedUserChange = errors.New("cannot change or delete provisioned user") ErrProvisionedTokenChange = errors.New("cannot change or delete provisioned token") ) + +// storeQueries holds the database-specific SQL queries +type storeQueries struct { + // User queries + selectUserByID string + selectUserByName string + selectUserByToken string + selectUserByStripeCustomerID string + selectUsernames string + selectUserCount string + selectUserIDFromUsername string + insertUser string + updateUserPass string + updateUserRole string + updateUserProvisioned string + updateUserPrefs string + updateUserStats string + updateUserStatsResetAll string + updateUserTier string + updateUserDeleted string + deleteUser string + deleteUserTier string + deleteUsersMarked string + + // Access queries + selectTopicPerms string + selectUserAllAccess string + selectUserAccess string + selectUserReservations string + selectUserReservationsCount string + selectUserReservationsOwner string + selectUserHasReservation string + selectOtherAccessCount string + upsertUserAccess string + deleteUserAccess string + deleteUserAccessProvisioned string + deleteTopicAccess string + deleteAllAccess string + + // Token queries + selectToken string + selectTokens string + selectTokenCount string + selectAllProvisionedTokens string + upsertToken string + updateToken string + updateTokenLastAccess string + deleteToken string + deleteProvisionedToken string + deleteAllToken string + deleteExpiredTokens string + deleteExcessTokens string + + // Tier queries + insertTier string + selectTiers string + selectTierByCode string + selectTierByPriceID string + updateTier string + deleteTier string + + // Phone queries + selectPhoneNumbers string + insertPhoneNumber string + deletePhoneNumber string + + // Billing queries + updateBilling string +} + +// execer is satisfied by both *sql.DB and *sql.Tx, allowing helper methods +// to be used both standalone and within a transaction. +type execer interface { + Exec(query string, args ...any) (sql.Result, error) +}