Re-add execTx

This commit is contained in:
binwiederhier
2026-02-28 19:49:01 -05:00
parent 542aa403d2
commit ccbd02331c
3 changed files with 176 additions and 169 deletions

View File

@@ -114,38 +114,27 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) {
// AddUser adds a user with the given username, password and role // AddUser adds a user with the given username, password and role
func (a *Manager) AddUser(username, password string, role Role, hashed bool) error { func (a *Manager) AddUser(username, password string, role Role, hashed bool) error {
return a.addUser(username, password, role, hashed, false)
}
func (a *Manager) addUser(username, password string, role Role, hashed, provisioned bool) error {
if !AllowedUsername(username) || !AllowedRole(role) { if !AllowedUsername(username) || !AllowedRole(role) {
return ErrInvalidArgument return ErrInvalidArgument
} }
var hash string hash, err := a.maybeHashPassword(password, hashed)
var err error if err != nil {
if hashed { return err
hash = password
if err := ValidPasswordHash(hash, a.config.BcryptCost); err != nil {
return err
}
} else {
hash, err = hashPassword(password, a.config.BcryptCost)
if err != nil {
return err
}
} }
return a.insertUser(username, hash, role, provisioned) return execTx(a.db, func(tx *sql.Tx) error {
return a.addUserTx(tx, username, hash, role, false)
})
} }
// insertUser adds a user with the given username, password hash and role to the database // addUserTx adds a user with the given username, password hash and role to the database
func (a *Manager) insertUser(username, hash string, role Role, provisioned bool) error { func (a *Manager) addUserTx(tx *sql.Tx, username, hash string, role Role, provisioned bool) error {
if !AllowedUsername(username) || !AllowedRole(role) { if !AllowedUsername(username) || !AllowedRole(role) {
return ErrInvalidArgument return ErrInvalidArgument
} }
userID := util.RandomStringPrefix(userIDPrefix, userIDLength) userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength) syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
now := time.Now().Unix() now := time.Now().Unix()
if _, err := a.db.Exec(a.queries.insertUser, userID, username, hash, string(role), syncTopic, provisioned, now); err != nil { if _, err := tx.Exec(a.queries.insertUser, userID, username, hash, string(role), syncTopic, provisioned, now); err != nil {
if isUniqueConstraintError(err) { if isUniqueConstraintError(err) {
return ErrUserExists return ErrUserExists
} }
@@ -160,16 +149,18 @@ func (a *Manager) RemoveUser(username string) error {
if err := a.CanChangeUser(username); err != nil { if err := a.CanChangeUser(username); err != nil {
return err return err
} }
return a.removeUser(username) return execTx(a.db, func(tx *sql.Tx) error {
return a.removeUserTx(tx, username)
})
} }
// removeUser deletes the user with the given username // removeUserTx deletes the user with the given username
func (a *Manager) removeUser(username string) error { func (a *Manager) removeUserTx(tx *sql.Tx, username string) error {
if !AllowedUsername(username) { if !AllowedUsername(username) {
return ErrInvalidArgument return ErrInvalidArgument
} }
// Rows in user_access, user_token, etc. are deleted via foreign keys // Rows in user_access, user_token, etc. are deleted via foreign keys
if _, err := a.db.Exec(a.queries.deleteUser, username); err != nil { if _, err := tx.Exec(a.queries.deleteUser, username); err != nil {
return err return err
} }
return nil return nil
@@ -181,22 +172,19 @@ func (a *Manager) MarkUserRemoved(user *User) error {
if !AllowedUsername(user.Name) { if !AllowedUsername(user.Name) {
return ErrInvalidArgument return ErrInvalidArgument
} }
tx, err := a.db.Begin() return execTx(a.db, func(tx *sql.Tx) error {
if err != nil { if err := a.resetUserAccessTx(tx, user.Name); err != nil {
return err return err
} }
defer tx.Rollback() if _, err := tx.Exec(a.queries.deleteAllToken, user.ID); err != nil {
if err := a.resetUserAccessTx(tx, user.Name); err != nil { return err
return err }
} deletedTime := time.Now().Add(userHardDeleteAfterDuration).Unix()
if _, err := tx.Exec(a.queries.deleteAllToken, user.ID); err != nil { if _, err := tx.Exec(a.queries.updateUserDeleted, deletedTime, user.ID); err != nil {
return err return err
} }
deletedTime := time.Now().Add(userHardDeleteAfterDuration).Unix() return nil
if _, err := tx.Exec(a.queries.updateUserDeleted, deletedTime, user.ID); err != nil { })
return err
}
return tx.Commit()
} }
// RemoveDeletedUsers deletes all users that have been marked deleted // RemoveDeletedUsers deletes all users that have been marked deleted
@@ -212,25 +200,18 @@ func (a *Manager) ChangePassword(username, password string, hashed bool) error {
if err := a.CanChangeUser(username); err != nil { if err := a.CanChangeUser(username); err != nil {
return err return err
} }
var hash string hash, err := a.maybeHashPassword(password, hashed)
var err error if err != nil {
if hashed { return err
hash = password
if err := ValidPasswordHash(hash, a.config.BcryptCost); err != nil {
return err
}
} else {
hash, err = hashPassword(password, a.config.BcryptCost)
if err != nil {
return err
}
} }
return a.changePasswordHash(username, hash) return execTx(a.db, func(tx *sql.Tx) error {
return a.changePasswordHashTx(tx, username, hash)
})
} }
// changePasswordHash changes a user's password hash in the database // changePasswordHashTx changes a user's password hash in the database
func (a *Manager) changePasswordHash(username, hash string) error { func (a *Manager) changePasswordHashTx(tx *sql.Tx, username, hash string) error {
if _, err := a.db.Exec(a.queries.updateUserPass, hash, username); err != nil { if _, err := tx.Exec(a.queries.updateUserPass, hash, username); err != nil {
return err return err
} }
return nil return nil
@@ -242,19 +223,16 @@ func (a *Manager) ChangeRole(username string, role Role) error {
if err := a.CanChangeUser(username); err != nil { if err := a.CanChangeUser(username); err != nil {
return err return err
} }
return a.changeRole(username, role) return execTx(a.db, func(tx *sql.Tx) error {
return a.changeRoleTx(tx, username, role)
})
} }
// changeRole changes a user's role // changeRoleTx changes a user's role
func (a *Manager) changeRole(username string, role Role) error { func (a *Manager) changeRoleTx(tx *sql.Tx, username string, role Role) error {
if !AllowedUsername(username) || !AllowedRole(role) { if !AllowedUsername(username) || !AllowedRole(role) {
return ErrInvalidArgument return ErrInvalidArgument
} }
tx, err := a.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(a.queries.updateUserRole, string(role), username); err != nil { if _, err := tx.Exec(a.queries.updateUserRole, string(role), username); err != nil {
return err return err
} }
@@ -264,7 +242,7 @@ func (a *Manager) changeRole(username string, role Role) error {
return err return err
} }
} }
return tx.Commit() return nil
} }
// CanChangeUser checks if the user with the given username can be changed. // CanChangeUser checks if the user with the given username can be changed.
@@ -281,7 +259,14 @@ func (a *Manager) CanChangeUser(username string) error {
// ChangeProvisioned changes the provisioned status of a user // ChangeProvisioned changes the provisioned status of a user
func (a *Manager) ChangeProvisioned(username string, provisioned bool) error { func (a *Manager) ChangeProvisioned(username string, provisioned bool) error {
if _, err := a.db.Exec(a.queries.updateUserProvisioned, provisioned, username); err != nil { return execTx(a.db, func(tx *sql.Tx) error {
return a.changeProvisionedTx(tx, username, provisioned)
})
}
// changeProvisionedTx changes the provisioned status of a user
func (a *Manager) changeProvisionedTx(tx *sql.Tx, username string, provisioned bool) error {
if _, err := tx.Exec(a.queries.updateUserProvisioned, provisioned, username); err != nil {
return err return err
} }
return nil return nil
@@ -357,17 +342,14 @@ func (a *Manager) ResetStats() error {
// UpdateStats updates statistics for one or more users in a single transaction // UpdateStats updates statistics for one or more users in a single transaction
func (a *Manager) UpdateStats(stats map[string]*Stats) error { func (a *Manager) UpdateStats(stats map[string]*Stats) error {
tx, err := a.db.Begin() return execTx(a.db, func(tx *sql.Tx) error {
if err != nil { for userID, update := range stats {
return err if _, err := tx.Exec(a.queries.updateUserStats, update.Messages, update.Emails, update.Calls, userID); err != nil {
} return err
defer tx.Rollback() }
for userID, update := range stats {
if _, err := tx.Exec(a.queries.updateUserStats, update.Messages, update.Emails, update.Calls, userID); err != nil {
return err
} }
} return nil
return tx.Commit() })
} }
// EnqueueUserStats adds the user to a queue which writes out user stats (messages, emails, ..) in // EnqueueUserStats adds the user to a queue which writes out user stats (messages, emails, ..) in
@@ -578,6 +560,16 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
return user, nil return user, nil
} }
func (a *Manager) maybeHashPassword(password string, hashed bool) (string, error) {
if hashed {
if err := ValidPasswordHash(password, a.config.BcryptCost); err != nil {
return "", err
}
return password, nil
}
return hashPassword(password, a.config.BcryptCost)
}
// Authorize returns nil if the given user has access to the given topic using the desired // Authorize returns nil if the given user has access to the given topic using the desired
// permission. The user param may be nil to signal an anonymous user. // permission. The user param may be nil to signal an anonymous user.
func (a *Manager) Authorize(user *User, topic string, perm Permission) error { func (a *Manager) Authorize(user *User, topic string, perm Permission) error {
@@ -612,37 +604,42 @@ func (a *Manager) resolvePerms(base, perm Permission) error {
// read/write access to a topic. The parameter topicPattern may include wildcards (*). The ACL entry // read/write access to a topic. The parameter topicPattern may include wildcards (*). The ACL entry
// owner may either be a user (username), or the system (empty). // owner may either be a user (username), or the system (empty).
func (a *Manager) AllowAccess(username string, topicPattern string, permission Permission) error { func (a *Manager) AllowAccess(username string, topicPattern string, permission Permission) error {
return a.allowAccess(username, topicPattern, permission, false) return execTx(a.db, func(tx *sql.Tx) error {
return a.allowAccessTx(tx, username, topicPattern, permission, false)
})
} }
func (a *Manager) allowAccess(username string, topicPattern string, permission Permission, provisioned bool) error { func (a *Manager) allowAccessTx(tx *sql.Tx, username string, topicPattern string, permission Permission, provisioned bool) error {
if !AllowedUsername(username) && username != Everyone { if !AllowedUsername(username) && username != Everyone {
return ErrInvalidArgument return ErrInvalidArgument
} else if !AllowedTopicPattern(topicPattern) { } else if !AllowedTopicPattern(topicPattern) {
return ErrInvalidArgument return ErrInvalidArgument
} }
return a.allowAccessTx(a.db, username, topicPattern, permission.IsRead(), permission.IsWrite(), "", provisioned) _, err := tx.Exec(a.queries.upsertUserAccess, username, toSQLWildcard(topicPattern), permission.IsRead(), permission.IsWrite(), "", "", provisioned)
return err
} }
// ResetAccess removes an access control list entry for a specific username/topic, or (if topic is // ResetAccess removes an access control list entry for a specific username/topic, or (if topic is
// empty) for an entire user. The parameter topicPattern may include wildcards (*). // empty) for an entire user. The parameter topicPattern may include wildcards (*).
func (a *Manager) ResetAccess(username string, topicPattern string) error { func (a *Manager) ResetAccess(username string, topicPattern string) error {
return a.resetAccess(username, topicPattern) return execTx(a.db, func(tx *sql.Tx) error {
return a.resetAccessTx(tx, username, topicPattern)
})
} }
func (a *Manager) resetAccess(username string, topicPattern string) error { func (a *Manager) resetAccessTx(tx *sql.Tx, username string, topicPattern string) error {
if !AllowedUsername(username) && username != Everyone && username != "" { if !AllowedUsername(username) && username != Everyone && username != "" {
return ErrInvalidArgument return ErrInvalidArgument
} else if !AllowedTopicPattern(topicPattern) && topicPattern != "" { } else if !AllowedTopicPattern(topicPattern) && topicPattern != "" {
return ErrInvalidArgument return ErrInvalidArgument
} }
if username == "" && topicPattern == "" { if username == "" && topicPattern == "" {
_, err := a.db.Exec(a.queries.deleteAllAccess) _, err := tx.Exec(a.queries.deleteAllAccess)
return err return err
} else if topicPattern == "" { } else if topicPattern == "" {
return a.resetUserAccessTx(a.db, username) return a.resetUserAccessTx(tx, username)
} }
return a.resetTopicAccessTx(a.db, username, topicPattern) return a.resetTopicAccessTx(tx, username, topicPattern)
} }
// DefaultAccess returns the default read/write access if no access control entry matches // DefaultAccess returns the default read/write access if no access control entry matches
@@ -749,18 +746,12 @@ func (a *Manager) AddReservation(username string, topic string, everyone Permiss
if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) { if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) {
return ErrInvalidArgument return ErrInvalidArgument
} }
tx, err := a.db.Begin() return execTx(a.db, func(tx *sql.Tx) error {
if err != nil { if err := a.addReservationAccessTx(tx, username, topic, true, true, username); err != nil {
return err return err
} }
defer tx.Rollback() return a.addReservationAccessTx(tx, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username)
if err := a.allowAccessTx(tx, username, topic, true, true, username, false); err != nil { })
return err
}
if err := a.allowAccessTx(tx, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username, false); err != nil {
return err
}
return tx.Commit()
} }
// RemoveReservations deletes the access control entries associated with the given username/topic, // RemoveReservations deletes the access control entries associated with the given username/topic,
@@ -775,20 +766,17 @@ func (a *Manager) RemoveReservations(username string, topics ...string) error {
return ErrInvalidArgument return ErrInvalidArgument
} }
} }
tx, err := a.db.Begin() return execTx(a.db, func(tx *sql.Tx) error {
if err != nil { for _, topic := range topics {
return err if err := a.resetTopicAccessTx(tx, username, topic); err != nil {
} return err
defer tx.Rollback() }
for _, topic := range topics { if err := a.resetTopicAccessTx(tx, Everyone, topic); err != nil {
if err := a.resetTopicAccessTx(tx, username, topic); err != nil { return err
return err }
} }
if err := a.resetTopicAccessTx(tx, Everyone, topic); err != nil { return nil
return err })
}
}
return tx.Commit()
} }
// Reservations returns all user-owned topics, and the associated everyone-access // Reservations returns all user-owned topics, and the associated everyone-access
@@ -893,17 +881,17 @@ func (a *Manager) ResetAllProvisionedAccess() error {
return nil return nil
} }
func (a *Manager) allowAccessTx(tx execer, username, topicPattern string, read, write bool, ownerUsername string, provisioned bool) error { func (a *Manager) addReservationAccessTx(tx *sql.Tx, username, topic string, read, write bool, ownerUsername string) error {
if !AllowedUsername(username) && username != Everyone { if !AllowedUsername(username) && username != Everyone {
return ErrInvalidArgument return ErrInvalidArgument
} else if !AllowedTopicPattern(topicPattern) { } else if !AllowedTopicPattern(topic) {
return ErrInvalidArgument return ErrInvalidArgument
} }
_, err := tx.Exec(a.queries.upsertUserAccess, username, toSQLWildcard(topicPattern), read, write, ownerUsername, ownerUsername, provisioned) _, err := tx.Exec(a.queries.upsertUserAccess, username, toSQLWildcard(topic), read, write, ownerUsername, ownerUsername, false)
return err return err
} }
func (a *Manager) resetUserAccessTx(tx execer, username string) error { func (a *Manager) resetUserAccessTx(tx *sql.Tx, username string) error {
if !AllowedUsername(username) && username != Everyone { if !AllowedUsername(username) && username != Everyone {
return ErrInvalidArgument return ErrInvalidArgument
} }
@@ -911,7 +899,7 @@ func (a *Manager) resetUserAccessTx(tx execer, username string) error {
return err return err
} }
func (a *Manager) resetTopicAccessTx(tx execer, username, topicPattern string) error { func (a *Manager) resetTopicAccessTx(tx *sql.Tx, username, topicPattern string) error {
if !AllowedUsername(username) && username != Everyone && username != "" { if !AllowedUsername(username) && username != Everyone && username != "" {
return ErrInvalidArgument return ErrInvalidArgument
} else if !AllowedTopicPattern(topicPattern) && topicPattern != "" { } else if !AllowedTopicPattern(topicPattern) && topicPattern != "" {
@@ -925,17 +913,14 @@ func (a *Manager) resetTopicAccessTx(tx execer, username, topicPattern string) e
// after a fixed duration unless ChangeToken is called. This function also prunes tokens for the // after a fixed duration unless ChangeToken is called. This function also prunes tokens for the
// given user, if there are too many of them. // given user, if there are too many of them.
func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr, provisioned bool) (*Token, error) { func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr, provisioned bool) (*Token, error) {
return a.createToken(userID, GenerateToken(), label, time.Now(), origin, expires, tokenMaxCount, provisioned) return queryTx(a.db, func(tx *sql.Tx) (*Token, error) {
return a.createTokenTx(tx, userID, GenerateToken(), label, time.Now(), origin, expires, tokenMaxCount, provisioned)
})
} }
// createToken creates a new token and prunes excess tokens if the count exceeds maxTokenCount. // createTokenTx creates a new token and prunes excess tokens if the count exceeds maxTokenCount.
// If maxTokenCount is 0, no pruning is performed. // If maxTokenCount is 0, no pruning is performed.
func (a *Manager) createToken(userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, maxTokenCount int, provisioned bool) (*Token, error) { func (a *Manager) createTokenTx(tx *sql.Tx, userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, maxTokenCount int, provisioned bool) (*Token, error) {
tx, err := a.db.Begin()
if err != nil {
return nil, err
}
defer tx.Rollback()
if _, err := tx.Exec(a.queries.upsertToken, userID, token, label, lastAccess.Unix(), lastOrigin.String(), expires.Unix(), provisioned); err != nil { if _, err := tx.Exec(a.queries.upsertToken, userID, token, label, lastAccess.Unix(), lastOrigin.String(), expires.Unix(), provisioned); err != nil {
return nil, err return nil, err
} }
@@ -950,9 +935,6 @@ func (a *Manager) createToken(userID, token, label string, lastAccess time.Time,
} }
} }
} }
if err := tx.Commit(); err != nil {
return nil, err
}
return &Token{ return &Token{
Value: token, Value: token,
Label: label, Label: label,
@@ -1064,17 +1046,14 @@ func (a *Manager) AllProvisionedTokens() ([]*Token, error) {
// UpdateTokenLastAccess updates the last access time and origin for one or more tokens in a single transaction // UpdateTokenLastAccess updates the last access time and origin for one or more tokens in a single transaction
func (a *Manager) UpdateTokenLastAccess(updates map[string]*TokenUpdate) error { func (a *Manager) UpdateTokenLastAccess(updates map[string]*TokenUpdate) error {
tx, err := a.db.Begin() return execTx(a.db, func(tx *sql.Tx) error {
if err != nil { for token, update := range updates {
return err if _, err := tx.Exec(a.queries.updateTokenLastAccess, update.LastAccess.Unix(), update.LastOrigin.String(), token); err != nil {
} return err
defer tx.Rollback() }
for token, update := range updates {
if _, err := tx.Exec(a.queries.updateTokenLastAccess, update.LastAccess.Unix(), update.LastOrigin.String(), token); err != nil {
return err
} }
} return nil
return tx.Commit() })
} }
// RemoveProvisionedToken deletes a provisioned token by value, regardless of user // RemoveProvisionedToken deletes a provisioned token by value, regardless of user
@@ -1320,27 +1299,33 @@ func (a *Manager) maybeProvisionUsersAccessAndTokens() error {
provisionUsernames := util.Map(a.config.Users, func(u *User) string { provisionUsernames := util.Map(a.config.Users, func(u *User) string {
return u.Name return u.Name
}) })
if err := a.maybeProvisionUsers(provisionUsernames, existingUsers); err != nil { existingTokens, err := a.AllProvisionedTokens()
return fmt.Errorf("failed to provision users: %v", err) if err != nil {
return err
} }
if err := a.maybeProvisionGrants(); err != nil { return execTx(a.db, func(tx *sql.Tx) error {
return fmt.Errorf("failed to provision grants: %v", err) if err := a.maybeProvisionUsers(tx, provisionUsernames, existingUsers); err != nil {
} return fmt.Errorf("failed to provision users: %v", err)
if err := a.maybeProvisionTokens(provisionUsernames); err != nil { }
return fmt.Errorf("failed to provision tokens: %v", err) if err := a.maybeProvisionGrants(tx); err != nil {
} return fmt.Errorf("failed to provision grants: %v", err)
return nil }
if err := a.maybeProvisionTokens(tx, provisionUsernames, existingTokens); err != nil {
return fmt.Errorf("failed to provision tokens: %v", err)
}
return nil
})
} }
// maybeProvisionUsers checks if the users in the config are provisioned, and adds or updates them. // maybeProvisionUsers checks if the users in the config are provisioned, and adds or updates them.
// It also removes users that are provisioned, but not in the config anymore. // It also removes users that are provisioned, but not in the config anymore.
func (a *Manager) maybeProvisionUsers(provisionUsernames []string, existingUsers []*User) error { func (a *Manager) maybeProvisionUsers(tx *sql.Tx, provisionUsernames []string, existingUsers []*User) error {
// Remove users that are provisioned, but not in the config anymore // Remove users that are provisioned, but not in the config anymore
for _, user := range existingUsers { for _, user := range existingUsers {
if user.Name == Everyone { if user.Name == Everyone {
continue continue
} else if user.Provisioned && !util.Contains(provisionUsernames, user.Name) { } else if user.Provisioned && !util.Contains(provisionUsernames, user.Name) {
if err := a.removeUser(user.Name); err != nil { if err := a.removeUserTx(tx, user.Name); err != nil {
return fmt.Errorf("failed to remove provisioned user %s: %v", user.Name, err) return fmt.Errorf("failed to remove provisioned user %s: %v", user.Name, err)
} }
} }
@@ -1354,22 +1339,22 @@ func (a *Manager) maybeProvisionUsers(provisionUsernames []string, existingUsers
return u.Name == user.Name return u.Name == user.Name
}) })
if !exists { if !exists {
if err := a.addUser(user.Name, user.Hash, user.Role, true, true); err != nil && !errors.Is(err, ErrUserExists) { if err := a.addUserTx(tx, user.Name, user.Hash, user.Role, true); err != nil && !errors.Is(err, ErrUserExists) {
return fmt.Errorf("failed to add provisioned user %s: %v", user.Name, err) return fmt.Errorf("failed to add provisioned user %s: %v", user.Name, err)
} }
} else { } else {
if !existingUser.Provisioned { if !existingUser.Provisioned {
if err := a.ChangeProvisioned(user.Name, true); err != nil { if err := a.changeProvisionedTx(tx, user.Name, true); err != nil {
return fmt.Errorf("failed to change provisioned status for user %s: %v", user.Name, err) return fmt.Errorf("failed to change provisioned status for user %s: %v", user.Name, err)
} }
} }
if existingUser.Hash != user.Hash { if existingUser.Hash != user.Hash {
if err := a.changePasswordHash(user.Name, user.Hash); err != nil { if err := a.changePasswordHashTx(tx, user.Name, user.Hash); err != nil {
return fmt.Errorf("failed to change password for provisioned user %s: %v", user.Name, err) return fmt.Errorf("failed to change password for provisioned user %s: %v", user.Name, err)
} }
} }
if existingUser.Role != user.Role { if existingUser.Role != user.Role {
if err := a.changeRole(user.Name, user.Role); err != nil { if err := a.changeRoleTx(tx, user.Name, user.Role); err != nil {
return fmt.Errorf("failed to change role for provisioned user %s: %v", user.Name, err) return fmt.Errorf("failed to change role for provisioned user %s: %v", user.Name, err)
} }
} }
@@ -1382,9 +1367,9 @@ func (a *Manager) maybeProvisionUsers(provisionUsernames []string, existingUsers
// //
// Unlike users and tokens, grants can be just re-added, because they do not carry any state (such as last // Unlike users and tokens, grants can be just re-added, because they do not carry any state (such as last
// access time) or do not have dependent resources (such as grants or tokens). // access time) or do not have dependent resources (such as grants or tokens).
func (a *Manager) maybeProvisionGrants() error { func (a *Manager) maybeProvisionGrants(tx *sql.Tx) error {
// Remove all provisioned grants // Remove all provisioned grants
if err := a.ResetAllProvisionedAccess(); err != nil { if _, err := tx.Exec(a.queries.deleteUserAccessProvisioned); err != nil {
return err return err
} }
// (Re-)add provisioned grants // (Re-)add provisioned grants
@@ -1398,10 +1383,10 @@ func (a *Manager) maybeProvisionGrants() error {
return fmt.Errorf("adding access control entries is not allowed for admin roles for user %s", username) return fmt.Errorf("adding access control entries is not allowed for admin roles for user %s", username)
} }
for _, grant := range grants { for _, grant := range grants {
if err := a.resetAccess(username, grant.TopicPattern); err != nil { if err := a.resetAccessTx(tx, username, grant.TopicPattern); err != nil {
return fmt.Errorf("failed to reset access for user %s and topic %s: %v", username, grant.TopicPattern, err) return fmt.Errorf("failed to reset access for user %s and topic %s: %v", username, grant.TopicPattern, err)
} }
if err := a.allowAccess(username, grant.TopicPattern, grant.Permission, true); err != nil { if err := a.allowAccessTx(tx, username, grant.TopicPattern, grant.Permission, true); err != nil {
return err return err
} }
} }
@@ -1409,12 +1394,8 @@ func (a *Manager) maybeProvisionGrants() error {
return nil return nil
} }
func (a *Manager) maybeProvisionTokens(provisionUsernames []string) error { func (a *Manager) maybeProvisionTokens(tx *sql.Tx, provisionUsernames []string, existingTokens []*Token) error {
// Remove tokens that are provisioned, but not in the config anymore // Remove tokens that are provisioned, but not in the config anymore
existingTokens, err := a.AllProvisionedTokens()
if err != nil {
return fmt.Errorf("failed to retrieve existing provisioned tokens: %v", err)
}
var provisionTokens []string var provisionTokens []string
for _, userTokens := range a.config.Tokens { for _, userTokens := range a.config.Tokens {
for _, token := range userTokens { for _, token := range userTokens {
@@ -1423,7 +1404,7 @@ func (a *Manager) maybeProvisionTokens(provisionUsernames []string) error {
} }
for _, existingToken := range existingTokens { for _, existingToken := range existingTokens {
if !slices.Contains(provisionTokens, existingToken.Value) { if !slices.Contains(provisionTokens, existingToken.Value) {
if err := a.RemoveProvisionedToken(existingToken.Value); err != nil { if _, err := tx.Exec(a.queries.deleteProvisionedToken, existingToken.Value); err != nil {
return fmt.Errorf("failed to remove provisioned token %s: %v", existingToken.Value, err) return fmt.Errorf("failed to remove provisioned token %s: %v", existingToken.Value, err)
} }
} }
@@ -1433,12 +1414,12 @@ func (a *Manager) maybeProvisionTokens(provisionUsernames []string) error {
if !slices.Contains(provisionUsernames, username) && username != Everyone { if !slices.Contains(provisionUsernames, username) && username != Everyone {
return fmt.Errorf("user %s is not a provisioned user, refusing to add tokens", username) return fmt.Errorf("user %s is not a provisioned user, refusing to add tokens", username)
} }
userID, err := a.UserIDByUsername(username) var userID string
if err != nil { if err := tx.QueryRow(a.queries.selectUserIDFromUsername, username).Scan(&userID); err != nil {
return fmt.Errorf("failed to find provisioned user %s for provisioned tokens: %v", username, err) return fmt.Errorf("failed to find provisioned user %s for provisioned tokens: %v", username, err)
} }
for _, token := range tokens { for _, token := range tokens {
if _, err := a.createToken(userID, token.Value, token.Label, time.Unix(0, 0), netip.IPv4Unspecified(), time.Unix(0, 0), 0, true); err != nil { if _, err := a.createTokenTx(tx, userID, token.Value, token.Label, time.Unix(0, 0), netip.IPv4Unspecified(), time.Unix(0, 0), 0, true); err != nil {
return err return err
} }
} }

View File

@@ -1,7 +1,6 @@
package user package user
import ( import (
"database/sql"
"errors" "errors"
"net/netip" "net/netip"
"strings" "strings"
@@ -345,8 +344,3 @@ type storeQueries struct {
updateBilling string updateBilling string
} }
// execer is satisfied by both *sql.DB and *sql.Tx, allowing helper methods
// to be used both standalone and within a transaction.
type execer interface {
Exec(query string, args ...any) (sql.Result, error)
}

View File

@@ -113,3 +113,35 @@ func escapeUnderscore(s string) string {
func unescapeUnderscore(s string) string { func unescapeUnderscore(s string) string {
return strings.ReplaceAll(s, "\\_", "_") return strings.ReplaceAll(s, "\\_", "_")
} }
// execTx executes a function in a transaction. If the function returns an error, the transaction is rolled back.
func execTx(db *sql.DB, f func(tx *sql.Tx) error) error {
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if err := f(tx); err != nil {
return err
}
return tx.Commit()
}
// queryTx executes a function in a transaction and returns the result. If the function
// returns an error, the transaction is rolled back.
func queryTx[T any](db *sql.DB, f func(tx *sql.Tx) (T, error)) (T, error) {
tx, err := db.Begin()
if err != nil {
var zero T
return zero, err
}
defer tx.Rollback()
t, err := f(tx)
if err != nil {
return t, err
}
if err := tx.Commit(); err != nil {
return t, err
}
return t, nil
}