// Package user deals with authentication and authorization against topics package user import ( "database/sql" "encoding/json" "errors" "fmt" "net/netip" "slices" "strings" "sync" "time" "golang.org/x/crypto/bcrypt" "heckel.io/ntfy/v2/log" "heckel.io/ntfy/v2/payments" "heckel.io/ntfy/v2/util" ) const ( tierIDPrefix = "ti_" tierIDLength = 8 syncTopicPrefix = "st_" syncTopicLength = 16 userIDPrefix = "u_" userIDLength = 12 userAuthIntentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match DefaultUserPasswordBcryptCost userHardDeleteAfterDuration = 7 * 24 * time.Hour tokenPrefix = "tk_" tokenLength = 32 tokenMaxCount = 60 // Only keep this many tokens in the table per user tag = "user_manager" ) // Default constants that may be overridden by configs const ( DefaultUserStatsQueueWriterInterval = 33 * time.Second DefaultUserPasswordBcryptCost = 10 ) var ( errNoTokenProvided = errors.New("no token provided") errTopicOwnedByOthers = errors.New("topic owned by others") errNoRows = errors.New("no rows found") ) // Manager handles user authentication, authorization, and management type Manager struct { config *Config db *sql.DB queries queries statsQueue map[string]*Stats // "Queue" to asynchronously write user stats to the database (UserID -> Stats) tokenQueue map[string]*TokenUpdate // "Queue" to asynchronously write token access stats to the database (Token ID -> TokenUpdate) mu sync.Mutex } var _ Auther = (*Manager)(nil) func newManager(db *sql.DB, queries queries, config *Config) (*Manager, error) { if config.BcryptCost <= 0 { config.BcryptCost = DefaultUserPasswordBcryptCost } if config.QueueWriterInterval.Seconds() <= 0 { config.QueueWriterInterval = DefaultUserStatsQueueWriterInterval } manager := &Manager{ config: config, db: db, statsQueue: make(map[string]*Stats), tokenQueue: make(map[string]*TokenUpdate), queries: queries, } if err := manager.maybeProvisionUsersAccessAndTokens(); err != nil { return nil, err } go manager.asyncQueueWriter(manager.config.QueueWriterInterval) return manager, nil } // Authenticate checks username and password and returns a User if correct, and the user has not been // marked as deleted. The method returns in constant-ish time, regardless of whether the user exists or // the password is correct or incorrect. func (a *Manager) Authenticate(username, password string) (*User, error) { if username == Everyone { return nil, ErrUnauthenticated } user, err := a.User(username) if err != nil { log.Tag(tag).Field("user_name", username).Err(err).Trace("Authentication of user failed (1)") bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks")) return nil, ErrUnauthenticated } else if user.Deleted { log.Tag(tag).Field("user_name", username).Trace("Authentication of user failed (2): user marked deleted") bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks")) return nil, ErrUnauthenticated } else if err := bcrypt.CompareHashAndPassword([]byte(user.Hash), []byte(password)); err != nil { log.Tag(tag).Field("user_name", username).Err(err).Trace("Authentication of user failed (3)") return nil, ErrUnauthenticated } return user, nil } // AuthenticateToken checks if the token exists and returns the associated User if it does. // The method sets the User.Token value to the token that was used for authentication. func (a *Manager) AuthenticateToken(token string) (*User, error) { if len(token) != tokenLength { return nil, ErrUnauthenticated } user, err := a.userByToken(token) if err != nil { log.Tag(tag).Field("token", token).Err(err).Trace("Authentication of token failed") return nil, ErrUnauthenticated } user.Token = token return user, nil } // AddUser adds a user with the given username, password and role func (a *Manager) AddUser(username, password string, role Role, hashed bool) error { hash, err := a.maybeHashPassword(password, hashed) if err != nil { return err } return execTx(a.db, func(tx *sql.Tx) error { return a.addUserTx(tx, username, hash, role, false) }) } // addUserTx adds a user with the given username, password hash and role to the database func (a *Manager) addUserTx(tx *sql.Tx, username, hash string, role Role, provisioned bool) error { if !AllowedUsername(username) || !AllowedRole(role) { return ErrInvalidArgument } userID := util.RandomStringPrefix(userIDPrefix, userIDLength) syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength) now := time.Now().Unix() if _, err := tx.Exec(a.queries.insertUser, userID, username, hash, string(role), syncTopic, provisioned, now); err != nil { if isUniqueConstraintError(err) { return ErrUserExists } return err } return nil } // RemoveUser deletes the user with the given username. The function returns nil on success, even // if the user did not exist in the first place. func (a *Manager) RemoveUser(username string) error { if err := a.CanChangeUser(username); err != nil { return err } return execTx(a.db, func(tx *sql.Tx) error { return a.removeUserTx(tx, username) }) } // removeUserTx deletes the user with the given username func (a *Manager) removeUserTx(tx *sql.Tx, username string) error { if !AllowedUsername(username) { return ErrInvalidArgument } // Rows in user_access, user_token, etc. are deleted via foreign keys if _, err := tx.Exec(a.queries.deleteUser, username); err != nil { return err } return nil } // MarkUserRemoved sets the deleted flag on the user, and deletes all access tokens. This prevents // successful auth via Authenticate. A background process will delete the user at a later date. func (a *Manager) MarkUserRemoved(user *User) error { if !AllowedUsername(user.Name) { return ErrInvalidArgument } return execTx(a.db, func(tx *sql.Tx) error { if err := a.resetUserAccessTx(tx, user.Name); err != nil { return err } if _, err := tx.Exec(a.queries.deleteAllToken, user.ID); err != nil { return err } deletedTime := time.Now().Add(userHardDeleteAfterDuration).Unix() if _, err := tx.Exec(a.queries.updateUserDeleted, deletedTime, user.ID); err != nil { return err } return nil }) } // RemoveDeletedUsers deletes all users that have been marked deleted func (a *Manager) RemoveDeletedUsers() error { if _, err := a.db.Exec(a.queries.deleteUsersMarked, time.Now().Unix()); err != nil { return err } return nil } // ChangePassword changes a user's password func (a *Manager) ChangePassword(username, password string, hashed bool) error { if err := a.CanChangeUser(username); err != nil { return err } hash, err := a.maybeHashPassword(password, hashed) if err != nil { return err } return execTx(a.db, func(tx *sql.Tx) error { return a.changePasswordHashTx(tx, username, hash) }) } // changePasswordHashTx changes a user's password hash in the database func (a *Manager) changePasswordHashTx(tx *sql.Tx, username, hash string) error { if _, err := tx.Exec(a.queries.updateUserPass, hash, username); err != nil { return err } return nil } // ChangeRole changes a user's role. When a role is changed from RoleUser to RoleAdmin, // all existing access control entries (Grant) are removed, since they are no longer needed. func (a *Manager) ChangeRole(username string, role Role) error { if err := a.CanChangeUser(username); err != nil { return err } return execTx(a.db, func(tx *sql.Tx) error { return a.changeRoleTx(tx, username, role) }) } // changeRoleTx changes a user's role func (a *Manager) changeRoleTx(tx *sql.Tx, username string, role Role) error { if !AllowedUsername(username) || !AllowedRole(role) { return ErrInvalidArgument } if _, err := tx.Exec(a.queries.updateUserRole, string(role), username); err != nil { return err } // If changing to admin, remove all access entries if role == RoleAdmin { if err := a.resetUserAccessTx(tx, username); err != nil { return err } } return nil } // CanChangeUser checks if the user with the given username can be changed. // This is used to prevent changes to provisioned users, which are defined in the config file. func (a *Manager) CanChangeUser(username string) error { user, err := a.User(username) if err != nil { return err } else if user.Provisioned { return ErrProvisionedUserChange } return nil } // changeProvisionedTx changes the provisioned status of a user func (a *Manager) changeProvisionedTx(tx *sql.Tx, username string, provisioned bool) error { if _, err := tx.Exec(a.queries.updateUserProvisioned, provisioned, username); err != nil { return err } return nil } // ChangeSettings persists the user settings func (a *Manager) ChangeSettings(userID string, prefs *Prefs) error { b, err := json.Marshal(prefs) if err != nil { return err } if _, err := a.db.Exec(a.queries.updateUserPrefs, string(b), userID); err != nil { return err } return nil } // ChangeTier changes a user's tier using the tier code. This function does not delete reservations, messages, // or attachments, even if the new tier has lower limits in this regard. That has to be done elsewhere. func (a *Manager) ChangeTier(username, tier string) error { if !AllowedUsername(username) { return ErrInvalidArgument } t, err := a.Tier(tier) if err != nil { return err } else if err := a.checkReservationsLimit(username, t.ReservationLimit); err != nil { return err } if _, err := a.db.Exec(a.queries.updateUserTier, tier, username); err != nil { return err } return nil } // ResetTier removes the tier from the given user func (a *Manager) ResetTier(username string) error { if !AllowedUsername(username) && username != Everyone && username != "" { return ErrInvalidArgument } else if err := a.checkReservationsLimit(username, 0); err != nil { return err } _, err := a.db.Exec(a.queries.deleteUserTier, username) return err } func (a *Manager) checkReservationsLimit(username string, reservationsLimit int64) error { u, err := a.User(username) if err != nil { return err } if u.Tier != nil && reservationsLimit < u.Tier.ReservationLimit { reservations, err := a.Reservations(username) if err != nil { return err } else if int64(len(reservations)) > reservationsLimit { return ErrTooManyReservations } } return nil } // ResetStats resets all user stats in the user database. This touches all users. func (a *Manager) ResetStats() error { a.mu.Lock() // Includes database query to avoid races! defer a.mu.Unlock() if _, err := a.db.Exec(a.queries.updateUserStatsResetAll); err != nil { return err } a.statsQueue = make(map[string]*Stats) return nil } // EnqueueUserStats adds the user to a queue which writes out user stats (messages, emails, ..) in // batches at a regular interval func (a *Manager) EnqueueUserStats(userID string, stats *Stats) { a.mu.Lock() defer a.mu.Unlock() a.statsQueue[userID] = stats } func (a *Manager) asyncQueueWriter(interval time.Duration) { ticker := time.NewTicker(interval) for range ticker.C { if err := a.writeUserStatsQueue(); err != nil { log.Tag(tag).Err(err).Warn("Writing user stats queue failed") } if err := a.writeTokenUpdateQueue(); err != nil { log.Tag(tag).Err(err).Warn("Writing token update queue failed") } } } func (a *Manager) writeUserStatsQueue() error { a.mu.Lock() if len(a.statsQueue) == 0 { a.mu.Unlock() log.Tag(tag).Trace("No user stats updates to commit") return nil } statsQueue := a.statsQueue a.statsQueue = make(map[string]*Stats) a.mu.Unlock() return execTx(a.db, func(tx *sql.Tx) error { log.Tag(tag).Debug("Writing user stats queue for %d user(s)", len(statsQueue)) for userID, update := range statsQueue { log. Tag(tag). Fields(log.Context{ "user_id": userID, "messages_count": update.Messages, "emails_count": update.Emails, "calls_count": update.Calls, }). Trace("Updating stats for user %s", userID) if _, err := tx.Exec(a.queries.updateUserStats, update.Messages, update.Emails, update.Calls, userID); err != nil { return err } } return nil }) } // User returns the user with the given username if it exists, or ErrUserNotFound otherwise func (a *Manager) User(username string) (*User, error) { rows, err := a.db.Query(a.queries.selectUserByName, username) if err != nil { return nil, err } return a.readUser(rows) } // UserByID returns the user with the given ID if it exists, or ErrUserNotFound otherwise func (a *Manager) UserByID(id string) (*User, error) { rows, err := a.db.Query(a.queries.selectUserByID, id) if err != nil { return nil, err } return a.readUser(rows) } // userByToken returns the user with the given token if it exists and is not expired, or ErrUserNotFound otherwise func (a *Manager) userByToken(token string) (*User, error) { rows, err := a.db.Query(a.queries.selectUserByToken, token, time.Now().Unix()) if err != nil { return nil, err } return a.readUser(rows) } // UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise func (a *Manager) UserByStripeCustomer(customerID string) (*User, error) { rows, err := a.db.Query(a.queries.selectUserByStripeCustomerID, customerID) if err != nil { return nil, err } return a.readUser(rows) } // Users returns a list of users func (a *Manager) Users() ([]*User, error) { rows, err := a.db.Query(a.queries.selectUsernames) if err != nil { return nil, err } defer rows.Close() usernames := make([]string, 0) for rows.Next() { var username string if err := rows.Scan(&username); err != nil { return nil, err } else if err := rows.Err(); err != nil { return nil, err } usernames = append(usernames, username) } rows.Close() users := make([]*User, 0) for _, username := range usernames { user, err := a.User(username) if err != nil { return nil, err } users = append(users, user) } return users, nil } // UsersCount returns the number of users in the database func (a *Manager) UsersCount() (int64, error) { rows, err := a.db.Query(a.queries.selectUserCount) if err != nil { return 0, err } defer rows.Close() if !rows.Next() { return 0, errNoRows } var count int64 if err := rows.Scan(&count); err != nil { return 0, err } return count, nil } func (a *Manager) readUser(rows *sql.Rows) (*User, error) { defer rows.Close() var id, username, hash, role, prefs, syncTopic string var provisioned bool var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripeSubscriptionInterval, stripeMonthlyPriceID, stripeYearlyPriceID, tierID, tierCode, tierName sql.NullString var messages, emails, calls int64 var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64 if !rows.Next() { return nil, ErrUserNotFound } if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &provisioned, &messages, &emails, &calls, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionInterval, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil { return nil, err } else if err := rows.Err(); err != nil { return nil, err } user := &User{ ID: id, Name: username, Hash: hash, Role: Role(role), Prefs: &Prefs{}, SyncTopic: syncTopic, Provisioned: provisioned, Stats: &Stats{ Messages: messages, Emails: emails, Calls: calls, }, Billing: &Billing{ StripeCustomerID: stripeCustomerID.String, // May be empty StripeSubscriptionID: stripeSubscriptionID.String, // May be empty StripeSubscriptionStatus: payments.SubscriptionStatus(stripeSubscriptionStatus.String), // May be empty StripeSubscriptionInterval: payments.PriceRecurringInterval(stripeSubscriptionInterval.String), // May be empty StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0), // May be zero StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0), // May be zero }, Deleted: deleted.Valid, } if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil { return nil, err } if tierCode.Valid { // See readTier() when this is changed! user.Tier = &Tier{ ID: tierID.String, Code: tierCode.String, Name: tierName.String, MessageLimit: messagesLimit.Int64, MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second, EmailLimit: emailsLimit.Int64, CallLimit: callsLimit.Int64, ReservationLimit: reservationsLimit.Int64, AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64, AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64, AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second, AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64, StripeMonthlyPriceID: stripeMonthlyPriceID.String, // May be empty StripeYearlyPriceID: stripeYearlyPriceID.String, // May be empty } } return user, nil } func (a *Manager) maybeHashPassword(password string, hashed bool) (string, error) { if hashed { if err := ValidPasswordHash(password, a.config.BcryptCost); err != nil { return "", err } return password, nil } return hashPassword(password, a.config.BcryptCost) } // Authorize returns nil if the given user has access to the given topic using the desired // permission. The user param may be nil to signal an anonymous user. func (a *Manager) Authorize(user *User, topic string, perm Permission) error { if user != nil && user.Role == RoleAdmin { return nil // Admin can do everything } username := Everyone if user != nil { username = user.Name } // Select the read/write permissions for this user/topic combo. read, write, found, err := a.authorizeTopicAccess(username, topic) if err != nil { return err } else if !found { return a.resolvePerms(a.config.DefaultAccess, perm) } return a.resolvePerms(NewPermission(read, write), perm) } func (a *Manager) resolvePerms(base, perm Permission) error { if perm == PermissionRead && base.IsRead() { return nil } else if perm == PermissionWrite && base.IsWrite() { return nil } return ErrUnauthorized } // AllowAccess adds or updates an entry in the access control list for a specific user. It controls // read/write access to a topic. The parameter topicPattern may include wildcards (*). The ACL entry // owner may either be a user (username), or the system (empty). func (a *Manager) AllowAccess(username string, topicPattern string, permission Permission) error { return execTx(a.db, func(tx *sql.Tx) error { return a.allowAccessTx(tx, username, topicPattern, permission, false) }) } func (a *Manager) allowAccessTx(tx *sql.Tx, username string, topicPattern string, permission Permission, provisioned bool) error { if !AllowedUsername(username) && username != Everyone { return ErrInvalidArgument } else if !AllowedTopicPattern(topicPattern) { return ErrInvalidArgument } _, err := tx.Exec(a.queries.upsertUserAccess, username, toSQLWildcard(topicPattern), permission.IsRead(), permission.IsWrite(), "", "", provisioned) return err } // ResetAccess removes an access control list entry for a specific username/topic, or (if topic is // empty) for an entire user. The parameter topicPattern may include wildcards (*). func (a *Manager) ResetAccess(username string, topicPattern string) error { return execTx(a.db, func(tx *sql.Tx) error { return a.resetAccessTx(tx, username, topicPattern) }) } func (a *Manager) resetAccessTx(tx *sql.Tx, username string, topicPattern string) error { if !AllowedUsername(username) && username != Everyone && username != "" { return ErrInvalidArgument } else if !AllowedTopicPattern(topicPattern) && topicPattern != "" { return ErrInvalidArgument } if username == "" && topicPattern == "" { _, err := tx.Exec(a.queries.deleteAllAccess) return err } else if topicPattern == "" { return a.resetUserAccessTx(tx, username) } return a.resetTopicAccessTx(tx, username, topicPattern) } // DefaultAccess returns the default read/write access if no access control entry matches func (a *Manager) DefaultAccess() Permission { return a.config.DefaultAccess } // AllowReservation tests if a user may create an access control entry for the given topic. // If there are any ACL entries that are not owned by the user, an error is returned. func (a *Manager) AllowReservation(username string, topic string) error { if (!AllowedUsername(username) && username != Everyone) || !AllowedTopic(topic) { return ErrInvalidArgument } otherCount, err := a.otherAccessCount(username, topic) if err != nil { return err } if otherCount > 0 { return errTopicOwnedByOthers } return nil } // authorizeTopicAccess returns the read/write permissions for the given username and topic. // The found return value indicates whether an ACL entry was found at all. // // - The query may return two rows (one for everyone, and one for the user), but prioritizes the user. // - Furthermore, the query prioritizes more specific permissions (longer!) over more generic ones, e.g. "test*" > "*" // - It also prioritizes write permissions over read permissions func (a *Manager) authorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error) { rows, err := a.db.Query(a.queries.selectTopicPerms, Everyone, usernameOrEveryone, topic) if err != nil { return false, false, false, err } defer rows.Close() if !rows.Next() { return false, false, false, nil } if err := rows.Scan(&read, &write); err != nil { return false, false, false, err } else if err := rows.Err(); err != nil { return false, false, false, err } return read, write, true, nil } // AllGrants returns all user-specific access control entries, mapped to their respective user IDs func (a *Manager) AllGrants() (map[string][]Grant, error) { rows, err := a.db.Query(a.queries.selectUserAllAccess) if err != nil { return nil, err } defer rows.Close() grants := make(map[string][]Grant, 0) for rows.Next() { var userID, topic string var read, write, provisioned bool if err := rows.Scan(&userID, &topic, &read, &write, &provisioned); err != nil { return nil, err } else if err := rows.Err(); err != nil { return nil, err } if _, ok := grants[userID]; !ok { grants[userID] = make([]Grant, 0) } grants[userID] = append(grants[userID], Grant{ TopicPattern: fromSQLWildcard(topic), Permission: NewPermission(read, write), Provisioned: provisioned, }) } return grants, nil } // Grants returns all user-specific access control entries func (a *Manager) Grants(username string) ([]Grant, error) { rows, err := a.db.Query(a.queries.selectUserAccess, username) if err != nil { return nil, err } defer rows.Close() grants := make([]Grant, 0) for rows.Next() { var topic string var read, write, provisioned bool if err := rows.Scan(&topic, &read, &write, &provisioned); err != nil { return nil, err } else if err := rows.Err(); err != nil { return nil, err } grants = append(grants, Grant{ TopicPattern: fromSQLWildcard(topic), Permission: NewPermission(read, write), Provisioned: provisioned, }) } return grants, nil } // AddReservation creates two access control entries for the given topic: one with full read/write // access for the given user, and one for Everyone with the given permission. Both entries are // created atomically in a single transaction. func (a *Manager) AddReservation(username string, topic string, everyone Permission) error { if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) { return ErrInvalidArgument } return execTx(a.db, func(tx *sql.Tx) error { if err := a.addReservationAccessTx(tx, username, topic, true, true, username); err != nil { return err } return a.addReservationAccessTx(tx, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username) }) } // RemoveReservations deletes the access control entries associated with the given username/topic, // as well as all entries with Everyone/topic. All deletions are performed atomically in a single // transaction. func (a *Manager) RemoveReservations(username string, topics ...string) error { if !AllowedUsername(username) || username == Everyone || len(topics) == 0 { return ErrInvalidArgument } for _, topic := range topics { if !AllowedTopic(topic) { return ErrInvalidArgument } } return execTx(a.db, func(tx *sql.Tx) error { for _, topic := range topics { if err := a.resetTopicAccessTx(tx, username, topic); err != nil { return err } if err := a.resetTopicAccessTx(tx, Everyone, topic); err != nil { return err } } return nil }) } // Reservations returns all user-owned topics, and the associated everyone-access func (a *Manager) Reservations(username string) ([]Reservation, error) { rows, err := a.db.Query(a.queries.selectUserReservations, Everyone, username) if err != nil { return nil, err } defer rows.Close() reservations := make([]Reservation, 0) for rows.Next() { var topic string var ownerRead, ownerWrite bool var everyoneRead, everyoneWrite sql.NullBool if err := rows.Scan(&topic, &ownerRead, &ownerWrite, &everyoneRead, &everyoneWrite); err != nil { return nil, err } else if err := rows.Err(); err != nil { return nil, err } reservations = append(reservations, Reservation{ Topic: fromSQLWildcard(topic), Owner: NewPermission(ownerRead, ownerWrite), Everyone: NewPermission(everyoneRead.Bool, everyoneWrite.Bool), }) } return reservations, nil } // HasReservation returns true if the given topic access is owned by the user func (a *Manager) HasReservation(username, topic string) (bool, error) { rows, err := a.db.Query(a.queries.selectUserHasReservation, username, escapeUnderscore(topic)) if err != nil { return false, err } defer rows.Close() if !rows.Next() { return false, errNoRows } var count int64 if err := rows.Scan(&count); err != nil { return false, err } return count > 0, nil } // ReservationsCount returns the number of reservations owned by this user func (a *Manager) ReservationsCount(username string) (int64, error) { rows, err := a.db.Query(a.queries.selectUserReservationsCount, username) if err != nil { return 0, err } defer rows.Close() if !rows.Next() { return 0, errNoRows } var count int64 if err := rows.Scan(&count); err != nil { return 0, err } return count, nil } // ReservationOwner returns user ID of the user that owns this topic, or an empty string if it's not owned by anyone func (a *Manager) ReservationOwner(topic string) (string, error) { rows, err := a.db.Query(a.queries.selectUserReservationsOwner, escapeUnderscore(topic)) if err != nil { return "", err } defer rows.Close() if !rows.Next() { return "", nil } var ownerUserID string if err := rows.Scan(&ownerUserID); err != nil { return "", err } return ownerUserID, nil } // otherAccessCount returns the number of access entries for the given topic that are not owned by the user func (a *Manager) otherAccessCount(username, topic string) (int, error) { rows, err := a.db.Query(a.queries.selectOtherAccessCount, escapeUnderscore(topic), escapeUnderscore(topic), username) if err != nil { return 0, err } defer rows.Close() if !rows.Next() { return 0, errNoRows } var count int if err := rows.Scan(&count); err != nil { return 0, err } return count, nil } func (a *Manager) addReservationAccessTx(tx *sql.Tx, username, topic string, read, write bool, ownerUsername string) error { if !AllowedUsername(username) && username != Everyone { return ErrInvalidArgument } else if !AllowedTopicPattern(topic) { return ErrInvalidArgument } _, err := tx.Exec(a.queries.upsertUserAccess, username, toSQLWildcard(topic), read, write, ownerUsername, ownerUsername, false) return err } func (a *Manager) resetUserAccessTx(tx *sql.Tx, username string) error { if !AllowedUsername(username) && username != Everyone { return ErrInvalidArgument } _, err := tx.Exec(a.queries.deleteUserAccess, username, username) return err } func (a *Manager) resetTopicAccessTx(tx *sql.Tx, username, topicPattern string) error { if !AllowedUsername(username) && username != Everyone && username != "" { return ErrInvalidArgument } else if !AllowedTopicPattern(topicPattern) && topicPattern != "" { return ErrInvalidArgument } _, err := tx.Exec(a.queries.deleteTopicAccess, username, username, toSQLWildcard(topicPattern)) return err } // CreateToken generates a random token for the given user and returns it. The token expires // after a fixed duration unless ChangeToken is called. This function also prunes tokens for the // given user, if there are too many of them. func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr, provisioned bool) (*Token, error) { return queryTx(a.db, func(tx *sql.Tx) (*Token, error) { return a.createTokenTx(tx, userID, GenerateToken(), label, time.Now(), origin, expires, tokenMaxCount, provisioned) }) } // createTokenTx creates a new token and prunes excess tokens if the count exceeds maxTokenCount. // If maxTokenCount is 0, no pruning is performed. func (a *Manager) createTokenTx(tx *sql.Tx, userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, maxTokenCount int, provisioned bool) (*Token, error) { if _, err := tx.Exec(a.queries.upsertToken, userID, token, label, lastAccess.Unix(), lastOrigin.String(), expires.Unix(), provisioned); err != nil { return nil, err } if maxTokenCount > 0 { var tokenCount int if err := tx.QueryRow(a.queries.selectTokenCount, userID).Scan(&tokenCount); err != nil { return nil, err } if tokenCount > maxTokenCount { // This pruning logic is done in two queries for efficiency. The SELECT above is a lookup // on two indices, whereas the query below is a full table scan. if _, err := tx.Exec(a.queries.deleteExcessTokens, userID, userID, maxTokenCount); err != nil { return nil, err } } } return &Token{ Value: token, Label: label, LastAccess: lastAccess, LastOrigin: lastOrigin, Expires: expires, Provisioned: provisioned, }, nil } // ChangeToken updates a token's label and/or expiry date func (a *Manager) ChangeToken(userID, token string, label *string, expires *time.Time) (*Token, error) { if token == "" { return nil, errNoTokenProvided } if err := a.canChangeToken(userID, token); err != nil { return nil, err } t, err := a.Token(userID, token) if err != nil { return nil, err } if label != nil { t.Label = *label } if expires != nil { t.Expires = *expires } if _, err := a.db.Exec(a.queries.updateToken, t.Label, t.Expires.Unix(), userID, token); err != nil { return nil, err } return t, nil } // RemoveToken deletes the token defined in User.Token func (a *Manager) RemoveToken(userID, token string) error { if err := a.canChangeToken(userID, token); err != nil { return err } if token == "" { return errNoTokenProvided } if _, err := a.db.Exec(a.queries.deleteToken, userID, token); err != nil { return err } return nil } // canChangeToken checks if the token can be changed. If the token is provisioned, it cannot be changed. func (a *Manager) canChangeToken(userID, token string) error { t, err := a.Token(userID, token) if err != nil { return err } else if t.Provisioned { return ErrProvisionedTokenChange } return nil } // Token returns a specific token for a user func (a *Manager) Token(userID, token string) (*Token, error) { rows, err := a.db.Query(a.queries.selectToken, userID, token) if err != nil { return nil, err } defer rows.Close() return a.readToken(rows) } // Tokens returns all existing tokens for the user with the given user ID func (a *Manager) Tokens(userID string) ([]*Token, error) { rows, err := a.db.Query(a.queries.selectTokens, userID) if err != nil { return nil, err } defer rows.Close() tokens := make([]*Token, 0) for { token, err := a.readToken(rows) if errors.Is(err, ErrTokenNotFound) { break } else if err != nil { return nil, err } tokens = append(tokens, token) } return tokens, nil } func (a *Manager) allProvisionedTokens() ([]*Token, error) { rows, err := a.db.Query(a.queries.selectAllProvisionedTokens) if err != nil { return nil, err } defer rows.Close() tokens := make([]*Token, 0) for { token, err := a.readToken(rows) if errors.Is(err, ErrTokenNotFound) { break } else if err != nil { return nil, err } tokens = append(tokens, token) } return tokens, nil } // RemoveExpiredTokens deletes all expired tokens from the database func (a *Manager) RemoveExpiredTokens() error { if _, err := a.db.Exec(a.queries.deleteExpiredTokens, time.Now().Unix()); err != nil { return err } return nil } // EnqueueTokenUpdate adds the token update to a queue which writes out token access times // in batches at a regular interval func (a *Manager) EnqueueTokenUpdate(tokenID string, update *TokenUpdate) { a.mu.Lock() defer a.mu.Unlock() a.tokenQueue[tokenID] = update } func (a *Manager) writeTokenUpdateQueue() error { a.mu.Lock() if len(a.tokenQueue) == 0 { a.mu.Unlock() log.Tag(tag).Trace("No token updates to commit") return nil } tokenQueue := a.tokenQueue a.tokenQueue = make(map[string]*TokenUpdate) a.mu.Unlock() return execTx(a.db, func(tx *sql.Tx) error { log.Tag(tag).Debug("Writing token update queue for %d token(s)", len(tokenQueue)) for tokenID, update := range tokenQueue { log.Tag(tag).Trace("Updating token %s with last access time %v", tokenID, update.LastAccess.Unix()) if err := a.updateTokenLastAccessTx(tx, tokenID, update.LastAccess.Unix(), update.LastOrigin.String()); err != nil { return err } } return nil }) } func (a *Manager) updateTokenLastAccessTx(tx *sql.Tx, token string, lastAccess int64, lastOrigin string) error { if _, err := tx.Exec(a.queries.updateTokenLastAccess, lastAccess, lastOrigin, token); err != nil { return err } return nil } func (a *Manager) readToken(rows *sql.Rows) (*Token, error) { var token, label, lastOrigin string var lastAccess, expires int64 var provisioned bool if !rows.Next() { return nil, ErrTokenNotFound } if err := rows.Scan(&token, &label, &lastAccess, &lastOrigin, &expires, &provisioned); err != nil { return nil, err } else if err := rows.Err(); err != nil { return nil, err } lastOriginIP, err := netip.ParseAddr(lastOrigin) if err != nil { lastOriginIP = netip.IPv4Unspecified() } return &Token{ Value: token, Label: label, LastAccess: time.Unix(lastAccess, 0), LastOrigin: lastOriginIP, Expires: time.Unix(expires, 0), Provisioned: provisioned, }, nil } // AddTier creates a new tier in the database func (a *Manager) AddTier(tier *Tier) error { if tier.ID == "" { tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength) } if _, err := a.db.Exec(a.queries.insertTier, tier.ID, tier.Code, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.CallLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID)); err != nil { return err } return nil } // UpdateTier updates a tier's properties in the database func (a *Manager) UpdateTier(tier *Tier) error { if _, err := a.db.Exec(a.queries.updateTier, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.CallLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID), tier.Code); err != nil { return err } return nil } // RemoveTier deletes the tier with the given code func (a *Manager) RemoveTier(code string) error { if !AllowedTier(code) { return ErrInvalidArgument } // This fails if any user has this tier if _, err := a.db.Exec(a.queries.deleteTier, code); err != nil { return err } return nil } // Tiers returns a list of all Tier structs func (a *Manager) Tiers() ([]*Tier, error) { rows, err := a.db.Query(a.queries.selectTiers) if err != nil { return nil, err } defer rows.Close() tiers := make([]*Tier, 0) for { tier, err := a.readTier(rows) if errors.Is(err, ErrTierNotFound) { break } else if err != nil { return nil, err } tiers = append(tiers, tier) } return tiers, nil } // Tier returns a Tier based on the code, or ErrTierNotFound if it does not exist func (a *Manager) Tier(code string) (*Tier, error) { rows, err := a.db.Query(a.queries.selectTierByCode, code) if err != nil { return nil, err } defer rows.Close() return a.readTier(rows) } // TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) { rows, err := a.db.Query(a.queries.selectTierByPriceID, priceID, priceID) if err != nil { return nil, err } defer rows.Close() return a.readTier(rows) } func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) { var id, code, name string var stripeMonthlyPriceID, stripeYearlyPriceID sql.NullString var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit sql.NullInt64 if !rows.Next() { return nil, ErrTierNotFound } if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil { return nil, err } else if err := rows.Err(); err != nil { return nil, err } // When changed, note readUser() as well return &Tier{ ID: id, Code: code, Name: name, MessageLimit: messagesLimit.Int64, MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second, EmailLimit: emailsLimit.Int64, CallLimit: callsLimit.Int64, ReservationLimit: reservationsLimit.Int64, AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64, AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64, AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second, AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64, StripeMonthlyPriceID: stripeMonthlyPriceID.String, // May be empty StripeYearlyPriceID: stripeYearlyPriceID.String, // May be empty }, nil } // PhoneNumbers returns all phone numbers for the user with the given user ID func (a *Manager) PhoneNumbers(userID string) ([]string, error) { rows, err := a.db.Query(a.queries.selectPhoneNumbers, userID) if err != nil { return nil, err } defer rows.Close() phoneNumbers := make([]string, 0) for { phoneNumber, err := a.readPhoneNumber(rows) if errors.Is(err, ErrPhoneNumberNotFound) { break } else if err != nil { return nil, err } phoneNumbers = append(phoneNumbers, phoneNumber) } return phoneNumbers, nil } // AddPhoneNumber adds a phone number to the user with the given user ID func (a *Manager) AddPhoneNumber(userID, phoneNumber string) error { if _, err := a.db.Exec(a.queries.insertPhoneNumber, userID, phoneNumber); err != nil { if isUniqueConstraintError(err) { return ErrPhoneNumberExists } return err } return nil } // RemovePhoneNumber deletes a phone number from the user with the given user ID func (a *Manager) RemovePhoneNumber(userID, phoneNumber string) error { _, err := a.db.Exec(a.queries.deletePhoneNumber, userID, phoneNumber) return err } func (a *Manager) readPhoneNumber(rows *sql.Rows) (string, error) { var phoneNumber string if !rows.Next() { return "", ErrPhoneNumberNotFound } if err := rows.Scan(&phoneNumber); err != nil { return "", err } else if err := rows.Err(); err != nil { return "", err } return phoneNumber, nil } // ChangeBilling updates a user's billing fields func (a *Manager) ChangeBilling(username string, billing *Billing) error { if _, err := a.db.Exec(a.queries.updateBilling, nullString(billing.StripeCustomerID), nullString(billing.StripeSubscriptionID), nullString(string(billing.StripeSubscriptionStatus)), nullString(string(billing.StripeSubscriptionInterval)), nullInt64(billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(billing.StripeSubscriptionCancelAt.Unix()), username); err != nil { return err } return nil } // maybeProvisionUsersAccessAndTokens provisions users, access control entries, and tokens based on the config. func (a *Manager) maybeProvisionUsersAccessAndTokens() error { if !a.config.ProvisionEnabled { return nil } existingUsers, err := a.Users() if err != nil { return err } provisionUsernames := util.Map(a.config.Users, func(u *User) string { return u.Name }) existingTokens, err := a.allProvisionedTokens() if err != nil { return err } return execTx(a.db, func(tx *sql.Tx) error { if err := a.maybeProvisionUsers(tx, provisionUsernames, existingUsers); err != nil { return fmt.Errorf("failed to provision users: %v", err) } if err := a.maybeProvisionGrants(tx); err != nil { return fmt.Errorf("failed to provision grants: %v", err) } if err := a.maybeProvisionTokens(tx, provisionUsernames, existingTokens); err != nil { return fmt.Errorf("failed to provision tokens: %v", err) } return nil }) } // maybeProvisionUsers checks if the users in the config are provisioned, and adds or updates them. // It also removes users that are provisioned, but not in the config anymore. func (a *Manager) maybeProvisionUsers(tx *sql.Tx, provisionUsernames []string, existingUsers []*User) error { // Remove users that are provisioned, but not in the config anymore for _, user := range existingUsers { if user.Name == Everyone { continue } else if user.Provisioned && !util.Contains(provisionUsernames, user.Name) { if err := a.removeUserTx(tx, user.Name); err != nil { return fmt.Errorf("failed to remove provisioned user %s: %v", user.Name, err) } } } // Add or update provisioned users for _, user := range a.config.Users { if user.Name == Everyone { continue } existingUser, exists := util.Find(existingUsers, func(u *User) bool { return u.Name == user.Name }) if !exists { if err := a.addUserTx(tx, user.Name, user.Hash, user.Role, true); err != nil && !errors.Is(err, ErrUserExists) { return fmt.Errorf("failed to add provisioned user %s: %v", user.Name, err) } } else { if !existingUser.Provisioned { if err := a.changeProvisionedTx(tx, user.Name, true); err != nil { return fmt.Errorf("failed to change provisioned status for user %s: %v", user.Name, err) } } if existingUser.Hash != user.Hash { if err := a.changePasswordHashTx(tx, user.Name, user.Hash); err != nil { return fmt.Errorf("failed to change password for provisioned user %s: %v", user.Name, err) } } if existingUser.Role != user.Role { if err := a.changeRoleTx(tx, user.Name, user.Role); err != nil { return fmt.Errorf("failed to change role for provisioned user %s: %v", user.Name, err) } } } } return nil } // maybeProvisionGrants removes all provisioned grants, and (re-)adds the grants from the config. // // Unlike users and tokens, grants can be just re-added, because they do not carry any state (such as last // access time) or do not have dependent resources (such as grants or tokens). func (a *Manager) maybeProvisionGrants(tx *sql.Tx) error { // Remove all provisioned grants if _, err := tx.Exec(a.queries.deleteUserAccessProvisioned); err != nil { return err } // (Re-)add provisioned grants for username, grants := range a.config.Access { user, exists := util.Find(a.config.Users, func(u *User) bool { return u.Name == username }) if !exists && username != Everyone { return fmt.Errorf("user %s is not a provisioned user, refusing to add ACL entry", username) } else if user != nil && user.Role == RoleAdmin { return fmt.Errorf("adding access control entries is not allowed for admin roles for user %s", username) } for _, grant := range grants { if err := a.resetAccessTx(tx, username, grant.TopicPattern); err != nil { return fmt.Errorf("failed to reset access for user %s and topic %s: %v", username, grant.TopicPattern, err) } if err := a.allowAccessTx(tx, username, grant.TopicPattern, grant.Permission, true); err != nil { return err } } } return nil } func (a *Manager) maybeProvisionTokens(tx *sql.Tx, provisionUsernames []string, existingTokens []*Token) error { // Remove tokens that are provisioned, but not in the config anymore var provisionTokens []string for _, userTokens := range a.config.Tokens { for _, token := range userTokens { provisionTokens = append(provisionTokens, token.Value) } } for _, existingToken := range existingTokens { if !slices.Contains(provisionTokens, existingToken.Value) { if _, err := tx.Exec(a.queries.deleteProvisionedToken, existingToken.Value); err != nil { return fmt.Errorf("failed to remove provisioned token %s: %v", existingToken.Value, err) } } } // (Re-)add provisioned tokens for username, tokens := range a.config.Tokens { if !slices.Contains(provisionUsernames, username) && username != Everyone { return fmt.Errorf("user %s is not a provisioned user, refusing to add tokens", username) } var userID string if err := tx.QueryRow(a.queries.selectUserIDFromUsername, username).Scan(&userID); err != nil { return fmt.Errorf("failed to find provisioned user %s for provisioned tokens: %v", username, err) } for _, token := range tokens { if _, err := a.createTokenTx(tx, userID, token.Value, token.Label, time.Unix(0, 0), netip.IPv4Unspecified(), time.Unix(0, 0), 0, true); err != nil { return err } } } return nil } // Close closes the underlying database func (a *Manager) Close() error { return a.db.Close() } // isUniqueConstraintError checks if the error is a unique constraint violation for both SQLite and PostgreSQL func isUniqueConstraintError(err error) bool { errStr := err.Error() return strings.Contains(errStr, "UNIQUE constraint failed") || strings.Contains(errStr, "23505") }