Files
ntfy/user/manager.go
2026-02-28 17:35:35 -05:00

1497 lines
49 KiB
Go

// Package user deals with authentication and authorization against topics
package user
import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"net/netip"
"slices"
"strings"
"sync"
"time"
"golang.org/x/crypto/bcrypt"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/payments"
"heckel.io/ntfy/v2/util"
)
// Identifier, token and housekeeping constants used throughout the user manager.
const (
	// Tier IDs
	tierIDPrefix = "ti_"
	tierIDLength = 8

	// Per-user sync topics
	syncTopicPrefix = "st_"
	syncTopicLength = 16

	// User IDs
	userIDPrefix = "u_"
	userIDLength = 12

	// Access tokens; createToken prunes a user's tokens down to tokenMaxCount
	tokenPrefix   = "tk_"
	tokenLength   = 32
	tokenMaxCount = 60

	// userAuthIntentionalSlowDownHash is a throwaway bcrypt hash that Authenticate compares against when a
	// login cannot succeed anyway (unknown or deleted user), so that such failures take about as long as a
	// real password check. Its cost factor ($10$) should match DefaultUserPasswordBcryptCost.
	userAuthIntentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy"

	// userHardDeleteAfterDuration is the delay between marking a user deleted and its scheduled hard deletion
	userHardDeleteAfterDuration = 7 * 24 * time.Hour

	// tag is the log tag used by this package
	tag = "user_manager"
)
// Defaults that initManager applies when the corresponding Config fields are zero or negative.
const (
	// DefaultUserPasswordBcryptCost is the bcrypt work factor used when Config.BcryptCost is unset
	DefaultUserPasswordBcryptCost = 10

	// DefaultUserStatsQueueWriterInterval is the flush interval of the stats/token queues when
	// Config.QueueWriterInterval is unset
	DefaultUserStatsQueueWriterInterval = 33 * time.Second
)
// errNoTokenProvided is returned when an operation requires a token but the token argument is empty.
var errNoTokenProvided = errors.New("no token provided")

// errTopicOwnedByOthers is returned by AllowReservation when other users hold ACL entries for the topic.
var errTopicOwnedByOthers = errors.New("topic owned by others")

// errNoRows is returned when a query that must yield a row (e.g. a COUNT) yields none.
var errNoRows = errors.New("no rows found")
// Manager handles user authentication, authorization, and management. All persistent state lives in db,
// accessed through the statements in queries. User stats and token access updates are buffered in the
// in-memory queues below and flushed in batches by a background goroutine (asyncQueueWriter).
type Manager struct {
	config     *Config
	db         *sql.DB
	queries    storeQueries
	statsQueue map[string]*Stats       // "Queue" to asynchronously write user stats to the database (UserID -> Stats)
	tokenQueue map[string]*TokenUpdate // "Queue" to asynchronously write token access stats to the database (Token ID -> TokenUpdate)
	mu         sync.Mutex              // Guards statsQueue and tokenQueue
}

// Compile-time check that *Manager implements Auther
var _ Auther = (*Manager)(nil)
// initManager performs the startup work shared by all database backends: it fills in config defaults
// (bcrypt cost, queue writer interval), allocates the write queues, provisions users, ACL entries and
// tokens from the config (if provisioning is enabled), and finally starts the background queue writer.
func initManager(manager *Manager) error {
	cfg := manager.config // pointer, so the defaults below are written back into the manager's config
	if cfg.BcryptCost <= 0 {
		cfg.BcryptCost = DefaultUserPasswordBcryptCost
	}
	if cfg.QueueWriterInterval.Seconds() <= 0 {
		cfg.QueueWriterInterval = DefaultUserStatsQueueWriterInterval
	}
	manager.statsQueue, manager.tokenQueue = make(map[string]*Stats), make(map[string]*TokenUpdate)
	if err := manager.maybeProvisionUsersAccessAndTokens(); err != nil {
		return err
	}
	// NOTE(review): this goroutine has no stop signal; it keeps running after Close
	go manager.asyncQueueWriter(cfg.QueueWriterInterval)
	return nil
}
// Authenticate verifies username and password and returns the matching User, unless the user does not
// exist, is marked as deleted, or is the Everyone pseudo-user. Every failure returns ErrUnauthenticated.
// Unknown and deleted users still cost one bcrypt comparison, so the response time does not reveal
// whether an account exists.
func (a *Manager) Authenticate(username, password string) (*User, error) {
	if username == Everyone {
		return nil, ErrUnauthenticated
	}
	slowDown := func() {
		// Result intentionally ignored: this only burns time comparable to a real password check
		_ = bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks"))
	}
	user, err := a.User(username)
	switch {
	case err != nil:
		log.Tag(tag).Field("user_name", username).Err(err).Trace("Authentication of user failed (1)")
		slowDown()
		return nil, ErrUnauthenticated
	case user.Deleted:
		log.Tag(tag).Field("user_name", username).Trace("Authentication of user failed (2): user marked deleted")
		slowDown()
		return nil, ErrUnauthenticated
	}
	if err := bcrypt.CompareHashAndPassword([]byte(user.Hash), []byte(password)); err != nil {
		log.Tag(tag).Field("user_name", username).Err(err).Trace("Authentication of user failed (3)")
		return nil, ErrUnauthenticated
	}
	return user, nil
}
// AuthenticateToken resolves an access token to its owner. Tokens whose length differs from tokenLength
// are rejected without a database query. All failures return ErrUnauthenticated. On success the
// returned User has its Token field set to the token that was used.
func (a *Manager) AuthenticateToken(token string) (*User, error) {
	if len(token) != tokenLength {
		return nil, ErrUnauthenticated
	}
	u, err := a.UserByToken(token)
	if err != nil {
		// NOTE(review): this writes the full (secret) token value to the trace log
		log.Tag(tag).Field("token", token).Err(err).Trace("Authentication of token failed")
		return nil, ErrUnauthenticated
	}
	u.Token = token
	return u, nil
}
// CreateToken issues a freshly generated token for the user and returns it. The token's last access is
// set to now, and it expires at expires (use ChangeToken to extend it). If the user then holds more than
// tokenMaxCount tokens, the excess ones are pruned.
func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr, provisioned bool) (*Token, error) {
	value := GenerateToken()
	now := time.Now()
	return a.createToken(userID, value, label, now, origin, expires, tokenMaxCount, provisioned)
}
// ChangeToken updates a token's label and/or expiry date and returns the updated token. A nil label or
// expires leaves the respective field unchanged.
//
// It returns errNoTokenProvided for an empty token, the lookup error (e.g. ErrTokenNotFound) if the token
// does not exist for this user, and ErrProvisionedTokenChange for provisioned tokens.
func (a *Manager) ChangeToken(userID, token string, label *string, expires *time.Time) (*Token, error) {
	if token == "" {
		return nil, errNoTokenProvided
	}
	// One lookup serves both as the existence check and the provisioned check (the same rule as
	// canChangeToken). Previously the token was fetched twice.
	t, err := a.Token(userID, token)
	if err != nil {
		return nil, err
	} else if t.Provisioned {
		return nil, ErrProvisionedTokenChange
	}
	if label != nil {
		t.Label = *label
	}
	if expires != nil {
		t.Expires = *expires
	}
	if err := a.changeToken(userID, token, t.Label, t.Expires); err != nil {
		return nil, err
	}
	return t, nil
}
// RemoveToken deletes the given token of the given user. If the token cannot be looked up, the lookup
// error is returned. Provisioned tokens are refused with ErrProvisionedTokenChange.
func (a *Manager) RemoveToken(userID, token string) error {
	err := a.canChangeToken(userID, token)
	if err == nil {
		err = a.removeToken(userID, token)
	}
	return err
}
// canChangeToken returns nil if the user's token exists and is not provisioned. Otherwise it returns the
// lookup error, or ErrProvisionedTokenChange for provisioned tokens, which are managed via the config.
func (a *Manager) canChangeToken(userID, token string) error {
	t, err := a.Token(userID, token)
	switch {
	case err != nil:
		return err
	case t.Provisioned:
		return ErrProvisionedTokenChange
	default:
		return nil
	}
}
// ResetStats zeroes the stats of every user in the database and discards any not-yet-written stats in
// the in-memory queue.
func (a *Manager) ResetStats() error {
	// The lock is held across the database write so that EnqueueUserStats cannot add entries between
	// the reset and the queue being cleared.
	a.mu.Lock()
	defer a.mu.Unlock()
	err := a.resetStats()
	if err == nil {
		a.statsQueue = make(map[string]*Stats)
	}
	return err
}
// resetStats runs the "reset all user stats" statement. It does not touch the in-memory queue.
func (a *Manager) resetStats() error {
	_, err := a.db.Exec(a.queries.updateUserStatsResetAll)
	return err
}
// EnqueueUserStats records the latest stats (messages, emails, calls) for a user. They are written to the
// database in batches by the background queue writer. A later call for the same user replaces the
// earlier entry.
func (a *Manager) EnqueueUserStats(userID string, stats *Stats) {
	a.mu.Lock()
	defer a.mu.Unlock()
	a.statsQueue[userID] = stats // last write wins
}
// EnqueueTokenUpdate records the latest access time/origin of a token. It is written to the database in
// batches by the background queue writer. A later call for the same token replaces the earlier entry.
func (a *Manager) EnqueueTokenUpdate(tokenID string, update *TokenUpdate) {
	a.mu.Lock()
	defer a.mu.Unlock()
	a.tokenQueue[tokenID] = update // last write wins
}
// asyncQueueWriter flushes the user stats queue and then the token update queue once per interval.
// Failures are logged as warnings and retried implicitly on the next tick with whatever is queued then.
// It never returns and never stops its ticker.
func (a *Manager) asyncQueueWriter(interval time.Duration) {
	ticker := time.NewTicker(interval)
	for {
		<-ticker.C
		if err := a.writeUserStatsQueue(); err != nil {
			log.Tag(tag).Err(err).Warn("Writing user stats queue failed")
		}
		if err := a.writeTokenUpdateQueue(); err != nil {
			log.Tag(tag).Err(err).Warn("Writing token update queue failed")
		}
	}
}
// writeUserStatsQueue takes the pending user stats out of the queue, replacing it with an empty one
// under the lock, and writes them to the database in one transaction (UpdateStats) without holding the lock.
// If the write fails, the taken entries are not re-queued.
func (a *Manager) writeUserStatsQueue() error {
	a.mu.Lock()
	pending := a.statsQueue
	empty := len(pending) == 0 // decided under the lock; the map may be written to once unlocked
	if !empty {
		a.statsQueue = make(map[string]*Stats)
	}
	a.mu.Unlock()
	if empty {
		log.Tag(tag).Trace("No user stats updates to commit")
		return nil
	}
	log.Tag(tag).Debug("Writing user stats queue for %d user(s)", len(pending))
	for userID, s := range pending {
		fields := log.Context{
			"user_id":        userID,
			"messages_count": s.Messages,
			"emails_count":   s.Emails,
			"calls_count":    s.Calls,
		}
		log.Tag(tag).Fields(fields).Trace("Updating stats for user %s", userID)
	}
	return a.UpdateStats(pending)
}
// writeTokenUpdateQueue takes the pending token access updates out of the queue, replacing it with an
// empty one under the lock, and writes them in one transaction (UpdateTokenLastAccess) without holding
// the lock. If the write fails, the taken entries are not re-queued.
func (a *Manager) writeTokenUpdateQueue() error {
	a.mu.Lock()
	pending := a.tokenQueue
	empty := len(pending) == 0 // decided under the lock; the map may be written to once unlocked
	if !empty {
		a.tokenQueue = make(map[string]*TokenUpdate)
	}
	a.mu.Unlock()
	if empty {
		log.Tag(tag).Trace("No token updates to commit")
		return nil
	}
	log.Tag(tag).Debug("Writing token update queue for %d token(s)", len(pending))
	for tokenID, u := range pending {
		log.Tag(tag).Trace("Updating token %s with last access time %v", tokenID, u.LastAccess.Unix())
	}
	return a.UpdateTokenLastAccess(pending)
}
// Authorize returns nil if user may access topic with the requested permission, and ErrUnauthorized
// (or a lookup error) otherwise. A nil user means an anonymous visitor and is checked as Everyone.
// Admins are always allowed. If no ACL entry matches, the configured default access applies.
func (a *Manager) Authorize(user *User, topic string, perm Permission) error {
	username := Everyone
	if user != nil {
		if user.Role == RoleAdmin {
			return nil // Admin can do everything
		}
		username = user.Name
	}
	read, write, found, err := a.AuthorizeTopicAccess(username, topic)
	if err != nil {
		return err
	}
	granted := a.config.DefaultAccess
	if found {
		granted = NewPermission(read, write)
	}
	return a.resolvePerms(granted, perm)
}
// resolvePerms returns nil if base grants the requested permission (read or write), and ErrUnauthorized
// otherwise. Requested permissions other than PermissionRead/PermissionWrite are always refused.
func (a *Manager) resolvePerms(base, perm Permission) error {
	allowed := (perm == PermissionRead && base.IsRead()) || (perm == PermissionWrite && base.IsWrite())
	if allowed {
		return nil
	}
	return ErrUnauthorized
}
// AddUser creates a regular (non-provisioned) user. If hashed is true, password is taken as an existing
// password hash and validated. Otherwise it is treated as plaintext and hashed.
func (a *Manager) AddUser(username, password string, role Role, hashed bool) error {
	const provisioned = false
	return a.addUser(username, password, role, hashed, provisioned)
}
// addUser validates username and role, derives the stored hash (validating a pre-hashed password with
// ValidPasswordHash, or hashing a plaintext one with the configured bcrypt cost), and inserts the user.
// It returns ErrInvalidArgument for a disallowed username/role, and ErrUserExists for duplicates.
func (a *Manager) addUser(username, password string, role Role, hashed, provisioned bool) error {
	if !AllowedUsername(username) || !AllowedRole(role) {
		return ErrInvalidArgument
	}
	hash := password
	if hashed {
		if err := ValidPasswordHash(hash, a.config.BcryptCost); err != nil {
			return err
		}
	} else {
		h, err := hashPassword(password, a.config.BcryptCost)
		if err != nil {
			return err
		}
		hash = h
	}
	return a.insertUser(username, hash, role, provisioned)
}
// RemoveUser deletes the user with the given username.
//
// The user is looked up first (CanChangeUser), so a nonexistent user yields that lookup's error
// (ErrUserNotFound), and a provisioned user yields ErrProvisionedUserChange.
func (a *Manager) RemoveUser(username string) error {
	err := a.CanChangeUser(username)
	if err != nil {
		return err
	}
	return a.removeUser(username)
}
// MarkUserRemoved soft-deletes a user. It removes the user's ACL entries and tokens and flags the user as
// deleted, so Authenticate rejects it from then on. The row itself is scheduled for hard deletion
// userHardDeleteAfterDuration from now, which RemoveDeletedUsers performs.
//
// NOTE(review): unlike RemoveUser, this does not refuse provisioned users — confirm that is intended.
func (a *Manager) MarkUserRemoved(user *User) error {
	if AllowedUsername(user.Name) {
		return a.markUserRemoved(user.ID, user.Name)
	}
	return ErrInvalidArgument
}
// ChangePassword sets a new password for the given user. If hashed is true, password must already be a
// hash accepted by ValidPasswordHash. Otherwise it is hashed with the configured bcrypt cost.
// Provisioned users are refused (see CanChangeUser).
func (a *Manager) ChangePassword(username, password string, hashed bool) error {
	if err := a.CanChangeUser(username); err != nil {
		return err
	}
	hash := password
	if hashed {
		if err := ValidPasswordHash(hash, a.config.BcryptCost); err != nil {
			return err
		}
	} else {
		h, err := hashPassword(password, a.config.BcryptCost)
		if err != nil {
			return err
		}
		hash = h
	}
	return a.changePasswordHash(username, hash)
}
// CanChangeUser returns nil if the user exists and is not provisioned. It returns the lookup error (e.g.
// ErrUserNotFound) if the user cannot be loaded, and ErrProvisionedUserChange for users defined in the
// config file, which must not be modified at runtime.
func (a *Manager) CanChangeUser(username string) error {
	user, err := a.User(username)
	switch {
	case err != nil:
		return err
	case user.Provisioned:
		return ErrProvisionedUserChange
	default:
		return nil
	}
}
// ChangeRole sets a (non-provisioned) user's role. When the new role is RoleAdmin, the user's access
// control entries are removed as part of the same transaction, since admins do not need them.
func (a *Manager) ChangeRole(username string, role Role) error {
	err := a.CanChangeUser(username)
	if err != nil {
		return err
	}
	return a.changeRole(username, role)
}
// ChangeTier assigns the tier with the given code to a user. It fails with ErrTooManyReservations if
// the new tier's reservation limit is lower than the user's current number of reservations. It does
// not delete reservations, messages or attachments that exceed other limits of the new tier; that has
// to be done elsewhere.
func (a *Manager) ChangeTier(username, tier string) error {
	if !AllowedUsername(username) {
		return ErrInvalidArgument
	}
	t, err := a.Tier(tier)
	if err != nil {
		return err
	}
	if err := a.checkReservationsLimit(username, t.ReservationLimit); err != nil {
		return err
	}
	return a.changeTierCode(username, tier)
}
// ResetTier removes the tier from the given user (Everyone and "" are also accepted as names). Because
// having no tier means a reservation limit of 0, it fails with ErrTooManyReservations if the user
// currently holds reservations under a tier.
func (a *Manager) ResetTier(username string) error {
	validName := AllowedUsername(username) || username == Everyone || username == ""
	if !validName {
		return ErrInvalidArgument
	}
	if err := a.checkReservationsLimit(username, 0); err != nil {
		return err
	}
	return a.resetTierCode(username)
}
// checkReservationsLimit returns ErrTooManyReservations if the limit is being lowered below the user's
// current tier limit and the user holds more reservations than the new limit. Users without a tier,
// and limits that are not lower, pass without counting reservations.
func (a *Manager) checkReservationsLimit(username string, reservationsLimit int64) error {
	u, err := a.User(username)
	if err != nil {
		return err
	}
	if u.Tier == nil || reservationsLimit >= u.Tier.ReservationLimit {
		return nil // limit is not being lowered; nothing to verify
	}
	reservations, err := a.Reservations(username)
	if err != nil {
		return err
	}
	if int64(len(reservations)) > reservationsLimit {
		return ErrTooManyReservations
	}
	return nil
}
// AllowReservation checks whether username (or Everyone) may reserve topic. It returns
// errTopicOwnedByOthers if any ACL entry for the topic belongs to somebody else, and ErrInvalidArgument
// for invalid usernames or topics.
func (a *Manager) AllowReservation(username string, topic string) error {
	validUser := AllowedUsername(username) || username == Everyone
	if !validUser || !AllowedTopic(topic) {
		return ErrInvalidArgument
	}
	otherCount, err := a.OtherAccessCount(username, topic)
	switch {
	case err != nil:
		return err
	case otherCount > 0:
		return errTopicOwnedByOthers
	}
	return nil
}
// AllowAccess creates or updates a (non-provisioned) ACL entry granting username read and/or write access
// to topicPattern, which may contain wildcards (*). username may also be Everyone.
func (a *Manager) AllowAccess(username string, topicPattern string, permission Permission) error {
	const provisioned = false
	return a.allowAccess(username, topicPattern, permission, provisioned)
}
// allowAccess validates username/topicPattern and writes the ACL entry directly against the database
// (no enclosing transaction), recording whether it was provisioned from the config.
func (a *Manager) allowAccess(username string, topicPattern string, permission Permission, provisioned bool) error {
	if (!AllowedUsername(username) && username != Everyone) || !AllowedTopicPattern(topicPattern) {
		return ErrInvalidArgument
	}
	read, write := permission.IsRead(), permission.IsWrite()
	// NOTE(review): the "" argument is presumably the entry's owner (i.e. the system) — confirm against allowAccessTx
	return a.allowAccessTx(a.db, username, topicPattern, read, write, "", provisioned)
}
// ResetAccess deletes ACL entries. With both arguments empty it deletes all entries. With only
// topicPattern empty it deletes all entries of username. Otherwise it deletes the entries of username
// for topicPattern, which may contain wildcards (*).
func (a *Manager) ResetAccess(username string, topicPattern string) error {
	return a.resetAccess(username, topicPattern)
}
// resetAccess validates its arguments (empty values allowed) and dispatches to the matching delete:
// everything, all entries of one user, or one user/topic pattern combination.
func (a *Manager) resetAccess(username string, topicPattern string) error {
	if !AllowedUsername(username) && username != Everyone && username != "" {
		return ErrInvalidArgument
	} else if !AllowedTopicPattern(topicPattern) && topicPattern != "" {
		return ErrInvalidArgument
	}
	switch {
	case username == "" && topicPattern == "":
		_, err := a.db.Exec(a.queries.deleteAllAccess)
		return err
	case topicPattern == "":
		return a.resetUserAccessTx(a.db, username)
	default:
		return a.resetTopicAccessTx(a.db, username, topicPattern)
	}
}
// DefaultAccess returns the configured permission that Authorize falls back to when no ACL entry matches.
func (a *Manager) DefaultAccess() Permission {
	access := a.config.DefaultAccess
	return access
}
// Close flushes any queued user stats and token access updates to the database and then closes it.
// Before this change, updates queued since the last background flush were silently lost on shutdown.
// The database is closed even if flushing fails. All errors are returned joined (nil if none).
//
// NOTE(review): the background goroutine started by initManager (asyncQueueWriter) is not stopped here;
// if entries are enqueued after Close, its next tick will log a failed write against the closed database.
func (a *Manager) Close() error {
	statsErr := a.writeUserStatsQueue()
	tokenErr := a.writeTokenUpdateQueue()
	return errors.Join(statsErr, tokenErr, a.db.Close())
}
// maybeProvisionUsersAccessAndTokens provisions users, access control entries, and tokens based on the
// config, in that order, if provisioning is enabled. Errors are wrapped with %w so callers can still
// match the underlying cause with errors.Is / errors.As.
func (a *Manager) maybeProvisionUsersAccessAndTokens() error {
	if !a.config.ProvisionEnabled {
		return nil
	}
	existingUsers, err := a.Users()
	if err != nil {
		return err
	}
	provisionUsernames := util.Map(a.config.Users, func(u *User) string {
		return u.Name
	})
	if err := a.maybeProvisionUsers(provisionUsernames, existingUsers); err != nil {
		return fmt.Errorf("failed to provision users: %w", err)
	}
	if err := a.maybeProvisionGrants(); err != nil {
		return fmt.Errorf("failed to provision grants: %w", err)
	}
	if err := a.maybeProvisionTokens(provisionUsernames); err != nil {
		return fmt.Errorf("failed to provision tokens: %w", err)
	}
	return nil
}
// maybeProvisionUsers reconciles the database with the users defined in the config:
//   - provisioned users that are no longer in the config are removed;
//   - config users that do not exist yet are added as provisioned (pre-hashed password);
//   - existing config users are marked provisioned, and their password hash and role are updated if they differ.
//
// The Everyone pseudo-user is skipped in both directions. Errors are wrapped with %w.
func (a *Manager) maybeProvisionUsers(provisionUsernames []string, existingUsers []*User) error {
	// Remove users that are provisioned, but not in the config anymore
	for _, user := range existingUsers {
		if user.Name == Everyone {
			continue
		} else if user.Provisioned && !slices.Contains(provisionUsernames, user.Name) {
			if err := a.removeUser(user.Name); err != nil {
				return fmt.Errorf("failed to remove provisioned user %s: %w", user.Name, err)
			}
		}
	}
	// Add or update provisioned users
	for _, user := range a.config.Users {
		if user.Name == Everyone {
			continue
		}
		existingUser, exists := util.Find(existingUsers, func(u *User) bool {
			return u.Name == user.Name
		})
		if !exists {
			if err := a.addUser(user.Name, user.Hash, user.Role, true, true); err != nil && !errors.Is(err, ErrUserExists) {
				return fmt.Errorf("failed to add provisioned user %s: %w", user.Name, err)
			}
			continue
		}
		if !existingUser.Provisioned {
			if err := a.ChangeProvisioned(user.Name, true); err != nil {
				return fmt.Errorf("failed to change provisioned status for user %s: %w", user.Name, err)
			}
		}
		if existingUser.Hash != user.Hash {
			if err := a.changePasswordHash(user.Name, user.Hash); err != nil {
				return fmt.Errorf("failed to change password for provisioned user %s: %w", user.Name, err)
			}
		}
		if existingUser.Role != user.Role {
			if err := a.changeRole(user.Name, user.Role); err != nil {
				return fmt.Errorf("failed to change role for provisioned user %s: %w", user.Name, err)
			}
		}
	}
	return nil
}
// maybeProvisionGrants removes all provisioned grants, and (re-)adds the grants from the config.
//
// Unlike users and tokens, grants can simply be re-added: they carry no state (such as a last access
// time) and nothing depends on them (whereas users, for instance, own grants and tokens).
//
// Grants may only be defined for provisioned users or Everyone, and not for admins. Before each grant is
// added, any existing entry for the same user/topic pattern is reset, including non-provisioned ones.
// Errors are wrapped with %w.
func (a *Manager) maybeProvisionGrants() error {
	// Remove all provisioned grants
	if err := a.ResetAllProvisionedAccess(); err != nil {
		return err
	}
	// (Re-)add provisioned grants
	for username, grants := range a.config.Access {
		user, exists := util.Find(a.config.Users, func(u *User) bool {
			return u.Name == username
		})
		if !exists && username != Everyone {
			return fmt.Errorf("user %s is not a provisioned user, refusing to add ACL entry", username)
		} else if user != nil && user.Role == RoleAdmin {
			return fmt.Errorf("adding access control entries is not allowed for admin roles for user %s", username)
		}
		for _, grant := range grants {
			if err := a.resetAccess(username, grant.TopicPattern); err != nil {
				return fmt.Errorf("failed to reset access for user %s and topic %s: %w", username, grant.TopicPattern, err)
			}
			if err := a.allowAccess(username, grant.TopicPattern, grant.Permission, true); err != nil {
				return fmt.Errorf("failed to add access for user %s and topic %s: %w", username, grant.TopicPattern, err)
			}
		}
	}
	return nil
}
// maybeProvisionTokens reconciles provisioned tokens with the config. Provisioned tokens whose value no
// longer appears in the config are removed, and every config token is upserted for its user as a
// provisioned token: zero last-access/expiry timestamps, unspecified origin, no pruning. Tokens may only
// be defined for provisioned users (or Everyone). Errors are wrapped with %w.
func (a *Manager) maybeProvisionTokens(provisionUsernames []string) error {
	// Remove tokens that are provisioned, but not in the config anymore
	existingTokens, err := a.AllProvisionedTokens()
	if err != nil {
		return fmt.Errorf("failed to retrieve existing provisioned tokens: %w", err)
	}
	// Set of configured token values, to avoid a linear scan per existing token
	provisionTokens := make(map[string]struct{})
	for _, userTokens := range a.config.Tokens {
		for _, token := range userTokens {
			provisionTokens[token.Value] = struct{}{}
		}
	}
	for _, existingToken := range existingTokens {
		if _, configured := provisionTokens[existingToken.Value]; !configured {
			if err := a.RemoveProvisionedToken(existingToken.Value); err != nil {
				return fmt.Errorf("failed to remove provisioned token %s: %w", existingToken.Value, err)
			}
		}
	}
	// (Re-)add provisioned tokens
	for username, tokens := range a.config.Tokens {
		if !slices.Contains(provisionUsernames, username) && username != Everyone {
			return fmt.Errorf("user %s is not a provisioned user, refusing to add tokens", username)
		}
		userID, err := a.UserIDByUsername(username)
		if err != nil {
			return fmt.Errorf("failed to find provisioned user %s for provisioned tokens: %w", username, err)
		}
		for _, token := range tokens {
			if _, err := a.createToken(userID, token.Value, token.Label, time.Unix(0, 0), netip.IPv4Unspecified(), time.Unix(0, 0), 0, true); err != nil {
				return err
			}
		}
	}
	return nil
}
// User loads the user with the given username. It returns ErrUserNotFound if no such user exists.
func (a *Manager) User(username string) (*User, error) {
	rows, err := a.db.Query(a.queries.selectUserByName, username)
	if err == nil {
		return a.readUser(rows) // readUser closes rows
	}
	return nil, err
}
// UserByID loads the user with the given user ID. It returns ErrUserNotFound if no such user exists.
func (a *Manager) UserByID(id string) (*User, error) {
	rows, err := a.db.Query(a.queries.selectUserByID, id)
	if err == nil {
		return a.readUser(rows) // readUser closes rows
	}
	return nil, err
}
// UserByToken loads the owner of the given token if the token exists and is not expired. The current
// Unix time is passed to the query for the expiry comparison. It returns ErrUserNotFound otherwise.
func (a *Manager) UserByToken(token string) (*User, error) {
	now := time.Now().Unix()
	rows, err := a.db.Query(a.queries.selectUserByToken, token, now)
	if err == nil {
		return a.readUser(rows) // readUser closes rows
	}
	return nil, err
}
// UserByStripeCustomer loads the user linked to the given Stripe customer ID. It returns
// ErrUserNotFound if there is none.
func (a *Manager) UserByStripeCustomer(customerID string) (*User, error) {
	rows, err := a.db.Query(a.queries.selectUserByStripeCustomerID, customerID)
	if err == nil {
		return a.readUser(rows) // readUser closes rows
	}
	return nil, err
}
// Users returns all users. The usernames are listed first, and each user is then loaded in full via User.
//
// Errors that end row iteration are now reported. Previously rows.Err() was only consulted after a
// successful Scan, so a failure that made rows.Next() return false went unnoticed and a truncated
// list was returned as success.
func (a *Manager) Users() ([]*User, error) {
	rows, err := a.db.Query(a.queries.selectUsernames)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	usernames := make([]string, 0)
	for rows.Next() {
		var username string
		if err := rows.Scan(&username); err != nil {
			return nil, err
		}
		usernames = append(usernames, username)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	// Release the result set (and its connection) before issuing the per-user queries below;
	// the deferred Close is then a no-op.
	rows.Close()
	users := make([]*User, 0, len(usernames))
	for _, username := range usernames {
		user, err := a.User(username)
		if err != nil {
			return nil, err
		}
		users = append(users, user)
	}
	return users, nil
}
// UsersCount returns the number of users in the database. It returns errNoRows if the count query
// yields no row. Query and iteration errors are returned as-is; previously they were masked as errNoRows.
func (a *Manager) UsersCount() (int64, error) {
	var count int64
	err := a.db.QueryRow(a.queries.selectUserCount).Scan(&count)
	if errors.Is(err, sql.ErrNoRows) {
		return 0, errNoRows
	} else if err != nil {
		return 0, err
	}
	return count, nil
}
// insertUser stores a new user with a random user ID and random sync topic. It returns
// ErrInvalidArgument for a disallowed username/role, and ErrUserExists if a unique constraint
// (e.g. the username) is violated.
func (a *Manager) insertUser(username, hash string, role Role, provisioned bool) error {
	if !AllowedUsername(username) || !AllowedRole(role) {
		return ErrInvalidArgument
	}
	var (
		userID    = util.RandomStringPrefix(userIDPrefix, userIDLength)
		syncTopic = util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
		nowUnix   = time.Now().Unix()
	)
	_, err := a.db.Exec(a.queries.insertUser, userID, username, hash, string(role), syncTopic, provisioned, nowUnix)
	switch {
	case err == nil:
		return nil
	case isUniqueConstraintError(err):
		return ErrUserExists
	default:
		return err
	}
}
// removeUser deletes the user row for username. It does not check whether the user is provisioned.
func (a *Manager) removeUser(username string) error {
	if !AllowedUsername(username) {
		return ErrInvalidArgument
	}
	// Rows in user_access, user_token, etc. are deleted via foreign keys
	_, err := a.db.Exec(a.queries.deleteUser, username)
	return err
}
// markUserRemoved soft-deletes a user in a single transaction. It resets the user's ACL entries, deletes
// all of the user's tokens, and sets the deleted column to now + userHardDeleteAfterDuration.
func (a *Manager) markUserRemoved(userID, username string) error {
	tx, err := a.db.Begin()
	if err != nil {
		return err
	}
	defer tx.Rollback() // no effect once Commit has succeeded
	if err = a.resetUserAccessTx(tx, username); err != nil {
		return err
	}
	if _, err = tx.Exec(a.queries.deleteAllToken, userID); err != nil {
		return err
	}
	hardDeleteAt := time.Now().Add(userHardDeleteAfterDuration).Unix()
	if _, err = tx.Exec(a.queries.updateUserDeleted, hardDeleteAt, userID); err != nil {
		return err
	}
	return tx.Commit()
}
// RemoveDeletedUsers hard-deletes users that have been marked deleted. The current Unix time is passed to
// the query, presumably to select only users whose scheduled deletion time (set by markUserRemoved) has
// passed — confirm against deleteUsersMarked.
func (a *Manager) RemoveDeletedUsers() error {
	_, err := a.db.Exec(a.queries.deleteUsersMarked, time.Now().Unix())
	return err
}
// changePasswordHash stores hash as the user's password hash. No validation is performed here.
func (a *Manager) changePasswordHash(username, hash string) error {
	_, err := a.db.Exec(a.queries.updateUserPass, hash, username)
	return err
}
// changeRole updates a user's role in one transaction. Promotion to RoleAdmin also resets all of the
// user's ACL entries in the same transaction.
func (a *Manager) changeRole(username string, role Role) error {
	if !AllowedUsername(username) || !AllowedRole(role) {
		return ErrInvalidArgument
	}
	tx, err := a.db.Begin()
	if err != nil {
		return err
	}
	defer tx.Rollback() // no effect once Commit has succeeded
	if _, err = tx.Exec(a.queries.updateUserRole, string(role), username); err != nil {
		return err
	}
	if role == RoleAdmin {
		// Admins can do everything, so their access entries are obsolete
		if err = a.resetUserAccessTx(tx, username); err != nil {
			return err
		}
	}
	return tx.Commit()
}
// ChangeProvisioned sets or clears the provisioned flag of the given user.
func (a *Manager) ChangeProvisioned(username string, provisioned bool) error {
	_, err := a.db.Exec(a.queries.updateUserProvisioned, provisioned, username)
	return err
}
// ChangeSettings stores the user's preferences, serialized as JSON.
func (a *Manager) ChangeSettings(userID string, prefs *Prefs) error {
	encoded, err := json.Marshal(prefs)
	if err != nil {
		return err
	}
	_, err = a.db.Exec(a.queries.updateUserPrefs, string(encoded), userID)
	return err
}
// changeTierCode points the user at the tier identified by tierCode. No limit checks are done here.
func (a *Manager) changeTierCode(username, tierCode string) error {
	_, err := a.db.Exec(a.queries.updateUserTier, tierCode, username)
	return err
}
// resetTierCode clears the user's tier in the database. Everyone and "" are accepted as usernames.
func (a *Manager) resetTierCode(username string) error {
	validName := AllowedUsername(username) || username == Everyone || username == ""
	if !validName {
		return ErrInvalidArgument
	}
	_, err := a.db.Exec(a.queries.deleteUserTier, username)
	return err
}
// UpdateStats writes message/email/call counters for each user ID in stats. All updates run in one
// transaction, so either all of them are committed or none are.
func (a *Manager) UpdateStats(stats map[string]*Stats) error {
	tx, err := a.db.Begin()
	if err != nil {
		return err
	}
	defer tx.Rollback() // no effect once Commit has succeeded
	for userID, s := range stats {
		if _, err = tx.Exec(a.queries.updateUserStats, s.Messages, s.Emails, s.Calls, userID); err != nil {
			return err
		}
	}
	return tx.Commit()
}
// readUser reads the first user row from rows and always closes rows. The Scan column order below must
// match the column order of the selectUser* queries.
//
// It returns ErrUserNotFound if the result set is empty. If iteration failed, the iteration error is
// returned instead; previously it was reported as ErrUserNotFound. Nullable Stripe and tier columns are
// read as zero values when NULL. A Tier is attached only if the row has a tier code, and Deleted is true
// if the deleted column is non-NULL.
func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
	defer rows.Close()
	var id, username, hash, role, prefs, syncTopic string
	var provisioned bool
	var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripeSubscriptionInterval, stripeMonthlyPriceID, stripeYearlyPriceID, tierID, tierCode, tierName sql.NullString
	var messages, emails, calls int64
	var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
	if !rows.Next() {
		// Next also returns false when iteration failed; don't report that as "user not found"
		if err := rows.Err(); err != nil {
			return nil, err
		}
		return nil, ErrUserNotFound
	}
	if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &provisioned, &messages, &emails, &calls, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionInterval, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
		return nil, err
	} else if err := rows.Err(); err != nil {
		return nil, err
	}
	user := &User{
		ID:          id,
		Name:        username,
		Hash:        hash,
		Role:        Role(role),
		Prefs:       &Prefs{},
		SyncTopic:   syncTopic,
		Provisioned: provisioned,
		Stats: &Stats{
			Messages: messages,
			Emails:   emails,
			Calls:    calls,
		},
		Billing: &Billing{
			StripeCustomerID:            stripeCustomerID.String,                                          // May be empty
			StripeSubscriptionID:        stripeSubscriptionID.String,                                      // May be empty
			StripeSubscriptionStatus:    payments.SubscriptionStatus(stripeSubscriptionStatus.String),     // May be empty
			StripeSubscriptionInterval:  payments.PriceRecurringInterval(stripeSubscriptionInterval.String), // May be empty
			StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0),                  // May be zero
			StripeSubscriptionCancelAt:  time.Unix(stripeSubscriptionCancelAt.Int64, 0),                   // May be zero
		},
		Deleted: deleted.Valid,
	}
	if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil {
		return nil, err
	}
	if tierCode.Valid {
		// See readTier() when this is changed!
		user.Tier = &Tier{
			ID:                       tierID.String,
			Code:                     tierCode.String,
			Name:                     tierName.String,
			MessageLimit:             messagesLimit.Int64,
			MessageExpiryDuration:    time.Duration(messagesExpiryDuration.Int64) * time.Second,
			EmailLimit:               emailsLimit.Int64,
			CallLimit:                callsLimit.Int64,
			ReservationLimit:         reservationsLimit.Int64,
			AttachmentFileSizeLimit:  attachmentFileSizeLimit.Int64,
			AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
			AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
			AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
			StripeMonthlyPriceID:     stripeMonthlyPriceID.String, // May be empty
			StripeYearlyPriceID:      stripeYearlyPriceID.String,  // May be empty
		}
	}
	return user, nil
}
// createToken upserts a token for userID in one transaction and returns it. If maxTokenCount > 0 and
// the user then holds more than maxTokenCount tokens, the excess tokens are deleted in the same
// transaction. A maxTokenCount of 0 disables pruning.
func (a *Manager) createToken(userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, maxTokenCount int, provisioned bool) (*Token, error) {
	tx, err := a.db.Begin()
	if err != nil {
		return nil, err
	}
	defer tx.Rollback() // no effect once Commit has succeeded
	if _, err = tx.Exec(a.queries.upsertToken, userID, token, label, lastAccess.Unix(), lastOrigin.String(), expires.Unix(), provisioned); err != nil {
		return nil, err
	}
	if maxTokenCount > 0 {
		var count int
		if err = tx.QueryRow(a.queries.selectTokenCount, userID).Scan(&count); err != nil {
			return nil, err
		}
		if count > maxTokenCount {
			if _, err = tx.Exec(a.queries.deleteExcessTokens, userID, userID, maxTokenCount); err != nil {
				return nil, err
			}
		}
	}
	if err = tx.Commit(); err != nil {
		return nil, err
	}
	created := &Token{
		Value:       token,
		Label:       label,
		LastAccess:  lastAccess,
		LastOrigin:  lastOrigin,
		Expires:     expires,
		Provisioned: provisioned,
	}
	return created, nil
}
// Token loads a single token belonging to userID. It returns ErrTokenNotFound if there is no such token.
func (a *Manager) Token(userID, token string) (*Token, error) {
	rows, err := a.db.Query(a.queries.selectToken, userID, token)
	if err == nil {
		defer rows.Close()
		return a.readToken(rows)
	}
	return nil, err
}
// Tokens lists all tokens of the user with the given ID. It returns a non-nil (possibly empty) slice.
func (a *Manager) Tokens(userID string) ([]*Token, error) {
	rows, err := a.db.Query(a.queries.selectTokens, userID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	tokens := make([]*Token, 0)
	for {
		t, err := a.readToken(rows)
		if err != nil {
			if errors.Is(err, ErrTokenNotFound) {
				return tokens, nil // readToken signals the end of the result set this way
			}
			return nil, err
		}
		tokens = append(tokens, t)
	}
}
// AllProvisionedTokens lists every provisioned token across all users. It returns a non-nil (possibly
// empty) slice.
func (a *Manager) AllProvisionedTokens() ([]*Token, error) {
	rows, err := a.db.Query(a.queries.selectAllProvisionedTokens)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	tokens := make([]*Token, 0)
	for {
		t, err := a.readToken(rows)
		if err != nil {
			if errors.Is(err, ErrTokenNotFound) {
				return tokens, nil // readToken signals the end of the result set this way
			}
			return nil, err
		}
		tokens = append(tokens, t)
	}
}
// changeToken writes a new label and expiry time (as Unix seconds) for the user's token.
func (a *Manager) changeToken(userID, token, label string, expires time.Time) error {
	_, err := a.db.Exec(a.queries.updateToken, label, expires.Unix(), userID, token)
	return err
}
// UpdateTokenLastAccess writes last access time and origin for each token in updates. All updates run in
// one transaction, so either all of them are committed or none are.
func (a *Manager) UpdateTokenLastAccess(updates map[string]*TokenUpdate) error {
	tx, err := a.db.Begin()
	if err != nil {
		return err
	}
	defer tx.Rollback() // no effect once Commit has succeeded
	for token, u := range updates {
		if _, err = tx.Exec(a.queries.updateTokenLastAccess, u.LastAccess.Unix(), u.LastOrigin.String(), token); err != nil {
			return err
		}
	}
	return tx.Commit()
}
// removeToken deletes the user's token. It returns errNoTokenProvided for an empty token. Deleting a
// token that does not exist is not an error.
func (a *Manager) removeToken(userID, token string) error {
	if token == "" {
		return errNoTokenProvided
	}
	_, err := a.db.Exec(a.queries.deleteToken, userID, token)
	return err
}
// RemoveProvisionedToken deletes a provisioned token by value, regardless of which
// user it belongs to. An empty token is rejected with errNoTokenProvided.
func (a *Manager) RemoveProvisionedToken(token string) error {
	if token == "" {
		return errNoTokenProvided
	}
	_, err := a.db.Exec(a.queries.deleteProvisionedToken, token)
	return err
}
// RemoveExpiredTokens runs the deleteExpiredTokens query with the current time
// (Unix seconds) as the cutoff.
func (a *Manager) RemoveExpiredTokens() error {
	cutoff := time.Now().Unix()
	_, err := a.db.Exec(a.queries.deleteExpiredTokens, cutoff)
	return err
}
// readToken advances rows by one and scans the current row into a Token.
//
// It returns ErrTokenNotFound when the result set is exhausted. If Next stopped
// because of a failure, it returns the error from rows.Err() instead, so callers that
// loop until ErrTokenNotFound do not silently return truncated results. Timestamps
// are read as Unix seconds. A lastOrigin value that is not a parsable IP address
// falls back to the unspecified IPv4 address.
func (a *Manager) readToken(rows *sql.Rows) (*Token, error) {
	var token, label, lastOrigin string
	var lastAccess, expires int64
	var provisioned bool
	if !rows.Next() {
		// Next returns false both at the end of the result set and on error; only the former means "not found"
		if err := rows.Err(); err != nil {
			return nil, err
		}
		return nil, ErrTokenNotFound
	}
	if err := rows.Scan(&token, &label, &lastAccess, &lastOrigin, &expires, &provisioned); err != nil {
		return nil, err
	}
	lastOriginIP, err := netip.ParseAddr(lastOrigin)
	if err != nil {
		lastOriginIP = netip.IPv4Unspecified()
	}
	return &Token{
		Value:       token,
		Label:       label,
		LastAccess:  time.Unix(lastAccess, 0),
		LastOrigin:  lastOriginIP,
		Expires:     time.Unix(expires, 0),
		Provisioned: provisioned,
	}, nil
}
// AuthorizeTopicAccess returns the read/write permissions for the given username and topic.
// The found return value indicates whether an ACL entry was found at all.
//
//   - The query may return two rows (one for everyone, and one for the user), but prioritizes the user.
//   - Furthermore, the query prioritizes more specific permissions (longer!) over more generic ones, e.g. "test*" > "*"
//   - It also prioritizes write permissions over read permissions
//
// Database and iteration errors are returned as err. They are not reported as
// "no entry found".
func (a *Manager) AuthorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error) {
	// Only the first row matters: given the prioritization above, it is the winning entry.
	// QueryRow discards any further rows and, unlike a bare rows.Next(), surfaces iteration errors.
	err = a.db.QueryRow(a.queries.selectTopicPerms, Everyone, usernameOrEveryone, topic).Scan(&read, &write)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return false, false, false, nil
	case err != nil:
		return false, false, false, err
	}
	return read, write, true, nil
}
// AllGrants returns all user-specific access control entries, mapped to their respective user IDs.
// Topic patterns are converted back from their SQL wildcard form with fromSQLWildcard.
// Errors that end the row iteration early are returned instead of a partial map.
func (a *Manager) AllGrants() (map[string][]Grant, error) {
	rows, err := a.db.Query(a.queries.selectUserAllAccess)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	grants := make(map[string][]Grant)
	for rows.Next() {
		var userID, topic string
		var read, write, provisioned bool
		if err := rows.Scan(&userID, &topic, &read, &write, &provisioned); err != nil {
			return nil, err
		}
		// Appending to a missing key starts from a nil slice, so no explicit initialization is needed
		grants[userID] = append(grants[userID], Grant{
			TopicPattern: fromSQLWildcard(topic),
			Permission:   NewPermission(read, write),
			Provisioned:  provisioned,
		})
	}
	// rows.Next() returns false on both completion and failure; distinguish the two here
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return grants, nil
}
// Grants returns all user-specific access control entries for the given username.
// Topic patterns are converted back from their SQL wildcard form with fromSQLWildcard.
// On success the returned slice is never nil. Errors that end the row iteration early
// are returned instead of a partial list.
func (a *Manager) Grants(username string) ([]Grant, error) {
	rows, err := a.db.Query(a.queries.selectUserAccess, username)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	grants := make([]Grant, 0)
	for rows.Next() {
		var topic string
		var read, write, provisioned bool
		if err := rows.Scan(&topic, &read, &write, &provisioned); err != nil {
			return nil, err
		}
		grants = append(grants, Grant{
			TopicPattern: fromSQLWildcard(topic),
			Permission:   NewPermission(read, write),
			Provisioned:  provisioned,
		})
	}
	// rows.Next() returns false on both completion and failure; distinguish the two here
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return grants, nil
}
// allowAccessTx upserts an access control entry for username and topicPattern inside tx,
// storing the read/write flags, ownerUsername and the provisioned flag. It returns
// ErrInvalidArgument without writing anything unless username is a valid username or
// Everyone and topicPattern is a valid topic pattern. The pattern is converted with
// toSQLWildcard before being stored.
func (a *Manager) allowAccessTx(tx execer, username, topicPattern string, read, write bool, ownerUsername string, provisioned bool) error {
	switch {
	case !(AllowedUsername(username) || username == Everyone):
		return ErrInvalidArgument
	case !AllowedTopicPattern(topicPattern):
		return ErrInvalidArgument
	}
	sqlPattern := toSQLWildcard(topicPattern)
	// ownerUsername is bound twice, presumably because upsertUserAccess references it in two placeholders
	_, err := tx.Exec(a.queries.upsertUserAccess, username, sqlPattern, read, write, ownerUsername, ownerUsername, provisioned)
	return err
}
// resetUserAccessTx runs the deleteUserAccess query for username inside tx.
// username must be a valid username or Everyone; otherwise ErrInvalidArgument is returned.
func (a *Manager) resetUserAccessTx(tx execer, username string) error {
	if !(AllowedUsername(username) || username == Everyone) {
		return ErrInvalidArgument
	}
	// username is bound twice to match the two placeholders passed to deleteUserAccess
	if _, err := tx.Exec(a.queries.deleteUserAccess, username, username); err != nil {
		return err
	}
	return nil
}
// resetTopicAccessTx runs the deleteTopicAccess query for username and topicPattern inside tx.
// username may be a valid username, Everyone, or empty. topicPattern may be a valid
// topic pattern or empty. Anything else yields ErrInvalidArgument. What an empty value
// matches is decided by the deleteTopicAccess query, presumably "all" (not visible here;
// verify).
func (a *Manager) resetTopicAccessTx(tx execer, username, topicPattern string) error {
	if !(AllowedUsername(username) || username == Everyone || username == "") {
		return ErrInvalidArgument
	}
	if !(AllowedTopicPattern(topicPattern) || topicPattern == "") {
		return ErrInvalidArgument
	}
	sqlPattern := toSQLWildcard(topicPattern)
	_, err := tx.Exec(a.queries.deleteTopicAccess, username, username, sqlPattern)
	return err
}
// ResetAllProvisionedAccess removes all provisioned access control entries by running
// the deleteUserAccessProvisioned query.
func (a *Manager) ResetAllProvisionedAccess() error {
	_, err := a.db.Exec(a.queries.deleteUserAccessProvisioned)
	return err
}
// AddReservation reserves topic for username by writing two access control entries in
// a single transaction:
//   - username gets full read/write access;
//   - Everyone gets the access described by the everyone permission.
//
// Both entries record username as owner and are not marked as provisioned. username must
// be a valid, non-Everyone username and topic a valid topic, otherwise ErrInvalidArgument
// is returned.
func (a *Manager) AddReservation(username string, topic string, everyone Permission) error {
	if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) {
		return ErrInvalidArgument
	}
	tx, err := a.db.Begin()
	if err != nil {
		return err
	}
	// Undo the first entry if the second one fails; a no-op after a successful Commit
	defer tx.Rollback()
	const notProvisioned = false
	if err = a.allowAccessTx(tx, username, topic, true, true, username, notProvisioned); err != nil {
		return err
	}
	everyoneRead, everyoneWrite := everyone.IsRead(), everyone.IsWrite()
	if err = a.allowAccessTx(tx, Everyone, topic, everyoneRead, everyoneWrite, username, notProvisioned); err != nil {
		return err
	}
	return tx.Commit()
}
// RemoveReservations deletes, for each topic, the username/topic access entry and the
// Everyone/topic entry. All deletions run in one transaction. Every topic is validated
// before the transaction starts. An invalid or Everyone username, an empty topic list, or
// any invalid topic yields ErrInvalidArgument.
func (a *Manager) RemoveReservations(username string, topics ...string) error {
	if !AllowedUsername(username) || username == Everyone || len(topics) == 0 {
		return ErrInvalidArgument
	}
	for i := range topics {
		if !AllowedTopic(topics[i]) {
			return ErrInvalidArgument
		}
	}
	tx, err := a.db.Begin()
	if err != nil {
		return err
	}
	defer tx.Rollback()
	for _, topic := range topics {
		// Owner entry first, then the Everyone entry for the same topic
		for _, subject := range []string{username, Everyone} {
			if resetErr := a.resetTopicAccessTx(tx, subject, topic); resetErr != nil {
				return resetErr
			}
		}
	}
	return tx.Commit()
}
// Reservations returns all topics reserved by username, together with the owner's
// permission and the associated Everyone permission.
//
// The Everyone columns are nullable. A NULL value scans as false, so a missing Everyone
// entry yields NewPermission(false, false). Errors that end the row iteration early are
// returned instead of a partial list.
func (a *Manager) Reservations(username string) ([]Reservation, error) {
	rows, err := a.db.Query(a.queries.selectUserReservations, Everyone, username)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	reservations := make([]Reservation, 0)
	for rows.Next() {
		var topic string
		var ownerRead, ownerWrite bool
		var everyoneRead, everyoneWrite sql.NullBool
		if err := rows.Scan(&topic, &ownerRead, &ownerWrite, &everyoneRead, &everyoneWrite); err != nil {
			return nil, err
		}
		reservations = append(reservations, Reservation{
			Topic:    fromSQLWildcard(topic),
			Owner:    NewPermission(ownerRead, ownerWrite),
			Everyone: NewPermission(everyoneRead.Bool, everyoneWrite.Bool),
		})
	}
	// rows.Next() returns false on both completion and failure; distinguish the two here
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return reservations, nil
}
// HasReservation returns true if the given topic access is owned by the user, based on a
// count returned by the selectUserHasReservation query. The topic is passed through
// escapeUnderscore, presumably so "_" is matched literally in a LIKE pattern (verify).
// If the query yields no row at all, errNoRows is returned. Iteration errors are
// returned as is.
func (a *Manager) HasReservation(username, topic string) (bool, error) {
	var count int64
	err := a.db.QueryRow(a.queries.selectUserHasReservation, username, escapeUnderscore(topic)).Scan(&count)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return false, errNoRows
	case err != nil:
		return false, err
	}
	return count > 0, nil
}
// ReservationsCount returns the number of reservations owned by this user. If the count
// query yields no row, errNoRows is returned. Query and iteration errors are returned
// as is.
func (a *Manager) ReservationsCount(username string) (int64, error) {
	var count int64
	err := a.db.QueryRow(a.queries.selectUserReservationsCount, username).Scan(&count)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return 0, errNoRows
	case err != nil:
		return 0, err
	}
	return count, nil
}
// ReservationOwner returns the user ID of the user that owns this topic, or an empty string
// if it's not owned by anyone. The topic is passed through escapeUnderscore before
// querying. Database errors, including errors that end row iteration, are returned.
// They are never reported as "not owned".
func (a *Manager) ReservationOwner(topic string) (string, error) {
	var ownerUserID string
	err := a.db.QueryRow(a.queries.selectUserReservationsOwner, escapeUnderscore(topic)).Scan(&ownerUserID)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return "", nil
	case err != nil:
		return "", err
	}
	return ownerUserID, nil
}
// OtherAccessCount returns the number of access entries for the given topic that are not
// owned by the user. The escaped topic is bound twice, as the selectOtherAccessCount
// query expects. If the query yields no row, errNoRows is returned. Query and iteration
// errors are returned as is.
func (a *Manager) OtherAccessCount(username, topic string) (int, error) {
	escapedTopic := escapeUnderscore(topic)
	var count int
	err := a.db.QueryRow(a.queries.selectOtherAccessCount, escapedTopic, escapedTopic, username).Scan(&count)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return 0, errNoRows
	case err != nil:
		return 0, err
	}
	return count, nil
}
// AddTier inserts tier into the database. If tier.ID is empty, a random ID with the
// tierIDPrefix is generated and written back into tier before the insert. Durations
// are stored as whole seconds, and empty Stripe price IDs are passed through nullString.
func (a *Manager) AddTier(tier *Tier) error {
	if tier.ID == "" {
		tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength)
	}
	messageExpirySeconds := int64(tier.MessageExpiryDuration.Seconds())
	attachmentExpirySeconds := int64(tier.AttachmentExpiryDuration.Seconds())
	_, err := a.db.Exec(a.queries.insertTier,
		tier.ID,
		tier.Code,
		tier.Name,
		tier.MessageLimit,
		messageExpirySeconds,
		tier.EmailLimit,
		tier.CallLimit,
		tier.ReservationLimit,
		tier.AttachmentFileSizeLimit,
		tier.AttachmentTotalSizeLimit,
		attachmentExpirySeconds,
		tier.AttachmentBandwidthLimit,
		nullString(tier.StripeMonthlyPriceID),
		nullString(tier.StripeYearlyPriceID),
	)
	return err
}
// UpdateTier overwrites the properties of the tier identified by tier.Code. The tier's
// ID and code themselves are not changed. Durations are stored as whole seconds.
func (a *Manager) UpdateTier(tier *Tier) error {
	messageExpirySeconds := int64(tier.MessageExpiryDuration.Seconds())
	attachmentExpirySeconds := int64(tier.AttachmentExpiryDuration.Seconds())
	_, err := a.db.Exec(a.queries.updateTier,
		tier.Name,
		tier.MessageLimit,
		messageExpirySeconds,
		tier.EmailLimit,
		tier.CallLimit,
		tier.ReservationLimit,
		tier.AttachmentFileSizeLimit,
		tier.AttachmentTotalSizeLimit,
		attachmentExpirySeconds,
		tier.AttachmentBandwidthLimit,
		nullString(tier.StripeMonthlyPriceID),
		nullString(tier.StripeYearlyPriceID),
		tier.Code, // WHERE clause key
	)
	return err
}
// RemoveTier deletes the tier with the given code. An invalid code yields ErrInvalidArgument.
func (a *Manager) RemoveTier(code string) error {
	if !AllowedTier(code) {
		return ErrInvalidArgument
	}
	// This fails if any user has this tier
	_, err := a.db.Exec(a.queries.deleteTier, code)
	return err
}
// Tiers returns all tiers selected by the selectTiers query. On success the slice is
// never nil (possibly empty).
func (a *Manager) Tiers() ([]*Tier, error) {
	rows, queryErr := a.db.Query(a.queries.selectTiers)
	if queryErr != nil {
		return nil, queryErr
	}
	defer rows.Close()
	all := make([]*Tier, 0)
	for {
		t, readErr := a.readTier(rows)
		switch {
		case errors.Is(readErr, ErrTierNotFound):
			// readTier signals the end of the result set with ErrTierNotFound
			return all, nil
		case readErr != nil:
			return nil, readErr
		}
		all = append(all, t)
	}
}
// Tier returns the tier with the given code, or ErrTierNotFound if it does not exist.
func (a *Manager) Tier(code string) (*Tier, error) {
	rows, queryErr := a.db.Query(a.queries.selectTierByCode, code)
	if queryErr != nil {
		return nil, queryErr
	}
	defer rows.Close()
	tier, readErr := a.readTier(rows)
	return tier, readErr
}
// TierByStripePrice returns the tier associated with the given Stripe price ID, or
// ErrTierNotFound if none matches. priceID is bound twice, presumably so the query can
// match either the monthly or the yearly price column (verify against selectTierByPriceID).
func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
	rows, queryErr := a.db.Query(a.queries.selectTierByPriceID, priceID, priceID)
	if queryErr != nil {
		return nil, queryErr
	}
	defer rows.Close()
	tier, readErr := a.readTier(rows)
	return tier, readErr
}
// readTier advances rows by one and scans the current row into a Tier.
//
// It returns ErrTierNotFound when the result set is exhausted. If Next stopped because
// of a failure, it returns the error from rows.Err() instead, so callers that loop until
// ErrTierNotFound do not silently return truncated results. NULL limit columns become 0,
// NULL Stripe price IDs become "", and the two expiry columns are read as seconds.
func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
	var id, code, name string
	var stripeMonthlyPriceID, stripeYearlyPriceID sql.NullString
	var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit sql.NullInt64
	if !rows.Next() {
		// Next returns false both at the end of the result set and on error; only the former means "not found"
		if err := rows.Err(); err != nil {
			return nil, err
		}
		return nil, ErrTierNotFound
	}
	if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
		return nil, err
	}
	// When changed, note readUser() as well
	return &Tier{
		ID:                       id,
		Code:                     code,
		Name:                     name,
		MessageLimit:             messagesLimit.Int64,
		MessageExpiryDuration:    time.Duration(messagesExpiryDuration.Int64) * time.Second,
		EmailLimit:               emailsLimit.Int64,
		CallLimit:                callsLimit.Int64,
		ReservationLimit:         reservationsLimit.Int64,
		AttachmentFileSizeLimit:  attachmentFileSizeLimit.Int64,
		AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
		AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
		AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
		StripeMonthlyPriceID:     stripeMonthlyPriceID.String, // May be empty
		StripeYearlyPriceID:      stripeYearlyPriceID.String,  // May be empty
	}, nil
}
// PhoneNumbers returns all phone numbers stored for the user with the given user ID.
// On success the slice is never nil (possibly empty).
func (a *Manager) PhoneNumbers(userID string) ([]string, error) {
	rows, queryErr := a.db.Query(a.queries.selectPhoneNumbers, userID)
	if queryErr != nil {
		return nil, queryErr
	}
	defer rows.Close()
	numbers := make([]string, 0)
	for {
		number, readErr := a.readPhoneNumber(rows)
		switch {
		case errors.Is(readErr, ErrPhoneNumberNotFound):
			// readPhoneNumber signals the end of the result set with ErrPhoneNumberNotFound
			return numbers, nil
		case readErr != nil:
			return nil, readErr
		}
		numbers = append(numbers, number)
	}
}
// AddPhoneNumber adds a phone number to the user with the given user ID. A unique
// constraint violation (as detected by isUniqueConstraintError) is reported as
// ErrPhoneNumberExists. Other database errors are returned unchanged.
func (a *Manager) AddPhoneNumber(userID, phoneNumber string) error {
	_, err := a.db.Exec(a.queries.insertPhoneNumber, userID, phoneNumber)
	switch {
	case err == nil:
		return nil
	case isUniqueConstraintError(err):
		return ErrPhoneNumberExists
	default:
		return err
	}
}
// RemovePhoneNumber deletes phoneNumber from the user with the given user ID using the
// deletePhoneNumber query.
func (a *Manager) RemovePhoneNumber(userID, phoneNumber string) error {
	if _, err := a.db.Exec(a.queries.deletePhoneNumber, userID, phoneNumber); err != nil {
		return err
	}
	return nil
}
// readPhoneNumber advances rows by one and scans the current row's phone number.
// It returns ErrPhoneNumberNotFound when the result set is exhausted. If Next stopped
// because of a failure, it returns the error from rows.Err() instead, so callers that
// loop until ErrPhoneNumberNotFound do not silently return truncated results.
func (a *Manager) readPhoneNumber(rows *sql.Rows) (string, error) {
	var phoneNumber string
	if !rows.Next() {
		// Next returns false both at the end of the result set and on error; only the former means "not found"
		if err := rows.Err(); err != nil {
			return "", err
		}
		return "", ErrPhoneNumberNotFound
	}
	if err := rows.Scan(&phoneNumber); err != nil {
		return "", err
	}
	return phoneNumber, nil
}
// ChangeBilling overwrites the Stripe billing fields of the user with the given username.
// String fields go through nullString and timestamps are stored as Unix seconds via
// nullInt64.
//
// NOTE(review): for a zero time.Time, Unix() returns -62135596800, not 0. Confirm that
// nullInt64 and the readers of these columns handle an unset PaidUntil/CancelAt as intended.
func (a *Manager) ChangeBilling(username string, billing *Billing) error {
	customerID := nullString(billing.StripeCustomerID)
	subscriptionID := nullString(billing.StripeSubscriptionID)
	subscriptionStatus := nullString(string(billing.StripeSubscriptionStatus))
	subscriptionInterval := nullString(string(billing.StripeSubscriptionInterval))
	paidUntil := nullInt64(billing.StripeSubscriptionPaidUntil.Unix())
	cancelAt := nullInt64(billing.StripeSubscriptionCancelAt.Unix())
	_, err := a.db.Exec(a.queries.updateBilling, customerID, subscriptionID, subscriptionStatus, subscriptionInterval, paidUntil, cancelAt, username)
	return err
}
// UserIDByUsername returns the user ID for the given username, or ErrUserNotFound if no
// such user exists. Query and iteration errors are returned as is. They are never
// reported as ErrUserNotFound.
func (a *Manager) UserIDByUsername(username string) (string, error) {
	var userID string
	err := a.db.QueryRow(a.queries.selectUserIDFromUsername, username).Scan(&userID)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return "", ErrUserNotFound
	case err != nil:
		return "", err
	}
	return userID, nil
}
// isUniqueConstraintError checks if the error is a unique constraint violation for both
// SQLite and PostgreSQL. The check is purely textual:
//   - SQLite errors contain "UNIQUE constraint failed";
//   - PostgreSQL errors are expected to carry SQLSTATE 23505 (unique_violation) in their
//     message, which depends on the driver's Error() format.
//
// A nil error is never a violation. Previously a nil err panicked on err.Error().
func isUniqueConstraintError(err error) bool {
	if err == nil {
		return false
	}
	msg := err.Error()
	return strings.Contains(msg, "UNIQUE constraint failed") || strings.Contains(msg, "23505")
}