Re-add execTx
This commit is contained in:
307
user/manager.go
307
user/manager.go
@@ -114,38 +114,27 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) {
|
||||
|
||||
// AddUser adds a user with the given username, password and role
|
||||
func (a *Manager) AddUser(username, password string, role Role, hashed bool) error {
|
||||
return a.addUser(username, password, role, hashed, false)
|
||||
}
|
||||
|
||||
func (a *Manager) addUser(username, password string, role Role, hashed, provisioned bool) error {
|
||||
if !AllowedUsername(username) || !AllowedRole(role) {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
var hash string
|
||||
var err error
|
||||
if hashed {
|
||||
hash = password
|
||||
if err := ValidPasswordHash(hash, a.config.BcryptCost); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
hash, err = hashPassword(password, a.config.BcryptCost)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hash, err := a.maybeHashPassword(password, hashed)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return a.insertUser(username, hash, role, provisioned)
|
||||
return execTx(a.db, func(tx *sql.Tx) error {
|
||||
return a.addUserTx(tx, username, hash, role, false)
|
||||
})
|
||||
}
|
||||
|
||||
// insertUser adds a user with the given username, password hash and role to the database
|
||||
func (a *Manager) insertUser(username, hash string, role Role, provisioned bool) error {
|
||||
// addUserTx adds a user with the given username, password hash and role to the database
|
||||
func (a *Manager) addUserTx(tx *sql.Tx, username, hash string, role Role, provisioned bool) error {
|
||||
if !AllowedUsername(username) || !AllowedRole(role) {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
|
||||
syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
|
||||
now := time.Now().Unix()
|
||||
if _, err := a.db.Exec(a.queries.insertUser, userID, username, hash, string(role), syncTopic, provisioned, now); err != nil {
|
||||
if _, err := tx.Exec(a.queries.insertUser, userID, username, hash, string(role), syncTopic, provisioned, now); err != nil {
|
||||
if isUniqueConstraintError(err) {
|
||||
return ErrUserExists
|
||||
}
|
||||
@@ -160,16 +149,18 @@ func (a *Manager) RemoveUser(username string) error {
|
||||
if err := a.CanChangeUser(username); err != nil {
|
||||
return err
|
||||
}
|
||||
return a.removeUser(username)
|
||||
return execTx(a.db, func(tx *sql.Tx) error {
|
||||
return a.removeUserTx(tx, username)
|
||||
})
|
||||
}
|
||||
|
||||
// removeUser deletes the user with the given username
|
||||
func (a *Manager) removeUser(username string) error {
|
||||
// removeUserTx deletes the user with the given username
|
||||
func (a *Manager) removeUserTx(tx *sql.Tx, username string) error {
|
||||
if !AllowedUsername(username) {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
// Rows in user_access, user_token, etc. are deleted via foreign keys
|
||||
if _, err := a.db.Exec(a.queries.deleteUser, username); err != nil {
|
||||
if _, err := tx.Exec(a.queries.deleteUser, username); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
@@ -181,22 +172,19 @@ func (a *Manager) MarkUserRemoved(user *User) error {
|
||||
if !AllowedUsername(user.Name) {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
tx, err := a.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if err := a.resetUserAccessTx(tx, user.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(a.queries.deleteAllToken, user.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
deletedTime := time.Now().Add(userHardDeleteAfterDuration).Unix()
|
||||
if _, err := tx.Exec(a.queries.updateUserDeleted, deletedTime, user.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
return execTx(a.db, func(tx *sql.Tx) error {
|
||||
if err := a.resetUserAccessTx(tx, user.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(a.queries.deleteAllToken, user.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
deletedTime := time.Now().Add(userHardDeleteAfterDuration).Unix()
|
||||
if _, err := tx.Exec(a.queries.updateUserDeleted, deletedTime, user.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveDeletedUsers deletes all users that have been marked deleted
|
||||
@@ -212,25 +200,18 @@ func (a *Manager) ChangePassword(username, password string, hashed bool) error {
|
||||
if err := a.CanChangeUser(username); err != nil {
|
||||
return err
|
||||
}
|
||||
var hash string
|
||||
var err error
|
||||
if hashed {
|
||||
hash = password
|
||||
if err := ValidPasswordHash(hash, a.config.BcryptCost); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
hash, err = hashPassword(password, a.config.BcryptCost)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hash, err := a.maybeHashPassword(password, hashed)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return a.changePasswordHash(username, hash)
|
||||
return execTx(a.db, func(tx *sql.Tx) error {
|
||||
return a.changePasswordHashTx(tx, username, hash)
|
||||
})
|
||||
}
|
||||
|
||||
// changePasswordHash changes a user's password hash in the database
|
||||
func (a *Manager) changePasswordHash(username, hash string) error {
|
||||
if _, err := a.db.Exec(a.queries.updateUserPass, hash, username); err != nil {
|
||||
// changePasswordHashTx changes a user's password hash in the database
|
||||
func (a *Manager) changePasswordHashTx(tx *sql.Tx, username, hash string) error {
|
||||
if _, err := tx.Exec(a.queries.updateUserPass, hash, username); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
@@ -242,19 +223,16 @@ func (a *Manager) ChangeRole(username string, role Role) error {
|
||||
if err := a.CanChangeUser(username); err != nil {
|
||||
return err
|
||||
}
|
||||
return a.changeRole(username, role)
|
||||
return execTx(a.db, func(tx *sql.Tx) error {
|
||||
return a.changeRoleTx(tx, username, role)
|
||||
})
|
||||
}
|
||||
|
||||
// changeRole changes a user's role
|
||||
func (a *Manager) changeRole(username string, role Role) error {
|
||||
// changeRoleTx changes a user's role
|
||||
func (a *Manager) changeRoleTx(tx *sql.Tx, username string, role Role) error {
|
||||
if !AllowedUsername(username) || !AllowedRole(role) {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
tx, err := a.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(a.queries.updateUserRole, string(role), username); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -264,7 +242,7 @@ func (a *Manager) changeRole(username string, role Role) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
return nil
|
||||
}
|
||||
|
||||
// CanChangeUser checks if the user with the given username can be changed.
|
||||
@@ -281,7 +259,14 @@ func (a *Manager) CanChangeUser(username string) error {
|
||||
|
||||
// ChangeProvisioned changes the provisioned status of a user
|
||||
func (a *Manager) ChangeProvisioned(username string, provisioned bool) error {
|
||||
if _, err := a.db.Exec(a.queries.updateUserProvisioned, provisioned, username); err != nil {
|
||||
return execTx(a.db, func(tx *sql.Tx) error {
|
||||
return a.changeProvisionedTx(tx, username, provisioned)
|
||||
})
|
||||
}
|
||||
|
||||
// changeProvisionedTx changes the provisioned status of a user
|
||||
func (a *Manager) changeProvisionedTx(tx *sql.Tx, username string, provisioned bool) error {
|
||||
if _, err := tx.Exec(a.queries.updateUserProvisioned, provisioned, username); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
@@ -357,17 +342,14 @@ func (a *Manager) ResetStats() error {
|
||||
|
||||
// UpdateStats updates statistics for one or more users in a single transaction
|
||||
func (a *Manager) UpdateStats(stats map[string]*Stats) error {
|
||||
tx, err := a.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
for userID, update := range stats {
|
||||
if _, err := tx.Exec(a.queries.updateUserStats, update.Messages, update.Emails, update.Calls, userID); err != nil {
|
||||
return err
|
||||
return execTx(a.db, func(tx *sql.Tx) error {
|
||||
for userID, update := range stats {
|
||||
if _, err := tx.Exec(a.queries.updateUserStats, update.Messages, update.Emails, update.Calls, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// EnqueueUserStats adds the user to a queue which writes out user stats (messages, emails, ..) in
|
||||
@@ -578,6 +560,16 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (a *Manager) maybeHashPassword(password string, hashed bool) (string, error) {
|
||||
if hashed {
|
||||
if err := ValidPasswordHash(password, a.config.BcryptCost); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return password, nil
|
||||
}
|
||||
return hashPassword(password, a.config.BcryptCost)
|
||||
}
|
||||
|
||||
// Authorize returns nil if the given user has access to the given topic using the desired
|
||||
// permission. The user param may be nil to signal an anonymous user.
|
||||
func (a *Manager) Authorize(user *User, topic string, perm Permission) error {
|
||||
@@ -612,37 +604,42 @@ func (a *Manager) resolvePerms(base, perm Permission) error {
|
||||
// read/write access to a topic. The parameter topicPattern may include wildcards (*). The ACL entry
|
||||
// owner may either be a user (username), or the system (empty).
|
||||
func (a *Manager) AllowAccess(username string, topicPattern string, permission Permission) error {
|
||||
return a.allowAccess(username, topicPattern, permission, false)
|
||||
return execTx(a.db, func(tx *sql.Tx) error {
|
||||
return a.allowAccessTx(tx, username, topicPattern, permission, false)
|
||||
})
|
||||
}
|
||||
|
||||
func (a *Manager) allowAccess(username string, topicPattern string, permission Permission, provisioned bool) error {
|
||||
func (a *Manager) allowAccessTx(tx *sql.Tx, username string, topicPattern string, permission Permission, provisioned bool) error {
|
||||
if !AllowedUsername(username) && username != Everyone {
|
||||
return ErrInvalidArgument
|
||||
} else if !AllowedTopicPattern(topicPattern) {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
return a.allowAccessTx(a.db, username, topicPattern, permission.IsRead(), permission.IsWrite(), "", provisioned)
|
||||
_, err := tx.Exec(a.queries.upsertUserAccess, username, toSQLWildcard(topicPattern), permission.IsRead(), permission.IsWrite(), "", "", provisioned)
|
||||
return err
|
||||
}
|
||||
|
||||
// ResetAccess removes an access control list entry for a specific username/topic, or (if topic is
|
||||
// empty) for an entire user. The parameter topicPattern may include wildcards (*).
|
||||
func (a *Manager) ResetAccess(username string, topicPattern string) error {
|
||||
return a.resetAccess(username, topicPattern)
|
||||
return execTx(a.db, func(tx *sql.Tx) error {
|
||||
return a.resetAccessTx(tx, username, topicPattern)
|
||||
})
|
||||
}
|
||||
|
||||
func (a *Manager) resetAccess(username string, topicPattern string) error {
|
||||
func (a *Manager) resetAccessTx(tx *sql.Tx, username string, topicPattern string) error {
|
||||
if !AllowedUsername(username) && username != Everyone && username != "" {
|
||||
return ErrInvalidArgument
|
||||
} else if !AllowedTopicPattern(topicPattern) && topicPattern != "" {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
if username == "" && topicPattern == "" {
|
||||
_, err := a.db.Exec(a.queries.deleteAllAccess)
|
||||
_, err := tx.Exec(a.queries.deleteAllAccess)
|
||||
return err
|
||||
} else if topicPattern == "" {
|
||||
return a.resetUserAccessTx(a.db, username)
|
||||
return a.resetUserAccessTx(tx, username)
|
||||
}
|
||||
return a.resetTopicAccessTx(a.db, username, topicPattern)
|
||||
return a.resetTopicAccessTx(tx, username, topicPattern)
|
||||
}
|
||||
|
||||
// DefaultAccess returns the default read/write access if no access control entry matches
|
||||
@@ -749,18 +746,12 @@ func (a *Manager) AddReservation(username string, topic string, everyone Permiss
|
||||
if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
tx, err := a.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if err := a.allowAccessTx(tx, username, topic, true, true, username, false); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := a.allowAccessTx(tx, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username, false); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
return execTx(a.db, func(tx *sql.Tx) error {
|
||||
if err := a.addReservationAccessTx(tx, username, topic, true, true, username); err != nil {
|
||||
return err
|
||||
}
|
||||
return a.addReservationAccessTx(tx, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username)
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveReservations deletes the access control entries associated with the given username/topic,
|
||||
@@ -775,20 +766,17 @@ func (a *Manager) RemoveReservations(username string, topics ...string) error {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
}
|
||||
tx, err := a.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
for _, topic := range topics {
|
||||
if err := a.resetTopicAccessTx(tx, username, topic); err != nil {
|
||||
return err
|
||||
return execTx(a.db, func(tx *sql.Tx) error {
|
||||
for _, topic := range topics {
|
||||
if err := a.resetTopicAccessTx(tx, username, topic); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := a.resetTopicAccessTx(tx, Everyone, topic); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := a.resetTopicAccessTx(tx, Everyone, topic); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// Reservations returns all user-owned topics, and the associated everyone-access
|
||||
@@ -893,17 +881,17 @@ func (a *Manager) ResetAllProvisionedAccess() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Manager) allowAccessTx(tx execer, username, topicPattern string, read, write bool, ownerUsername string, provisioned bool) error {
|
||||
func (a *Manager) addReservationAccessTx(tx *sql.Tx, username, topic string, read, write bool, ownerUsername string) error {
|
||||
if !AllowedUsername(username) && username != Everyone {
|
||||
return ErrInvalidArgument
|
||||
} else if !AllowedTopicPattern(topicPattern) {
|
||||
} else if !AllowedTopicPattern(topic) {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
_, err := tx.Exec(a.queries.upsertUserAccess, username, toSQLWildcard(topicPattern), read, write, ownerUsername, ownerUsername, provisioned)
|
||||
_, err := tx.Exec(a.queries.upsertUserAccess, username, toSQLWildcard(topic), read, write, ownerUsername, ownerUsername, false)
|
||||
return err
|
||||
}
|
||||
|
||||
func (a *Manager) resetUserAccessTx(tx execer, username string) error {
|
||||
func (a *Manager) resetUserAccessTx(tx *sql.Tx, username string) error {
|
||||
if !AllowedUsername(username) && username != Everyone {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
@@ -911,7 +899,7 @@ func (a *Manager) resetUserAccessTx(tx execer, username string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (a *Manager) resetTopicAccessTx(tx execer, username, topicPattern string) error {
|
||||
func (a *Manager) resetTopicAccessTx(tx *sql.Tx, username, topicPattern string) error {
|
||||
if !AllowedUsername(username) && username != Everyone && username != "" {
|
||||
return ErrInvalidArgument
|
||||
} else if !AllowedTopicPattern(topicPattern) && topicPattern != "" {
|
||||
@@ -925,17 +913,14 @@ func (a *Manager) resetTopicAccessTx(tx execer, username, topicPattern string) e
|
||||
// after a fixed duration unless ChangeToken is called. This function also prunes tokens for the
|
||||
// given user, if there are too many of them.
|
||||
func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr, provisioned bool) (*Token, error) {
|
||||
return a.createToken(userID, GenerateToken(), label, time.Now(), origin, expires, tokenMaxCount, provisioned)
|
||||
return queryTx(a.db, func(tx *sql.Tx) (*Token, error) {
|
||||
return a.createTokenTx(tx, userID, GenerateToken(), label, time.Now(), origin, expires, tokenMaxCount, provisioned)
|
||||
})
|
||||
}
|
||||
|
||||
// createToken creates a new token and prunes excess tokens if the count exceeds maxTokenCount.
|
||||
// createTokenTx creates a new token and prunes excess tokens if the count exceeds maxTokenCount.
|
||||
// If maxTokenCount is 0, no pruning is performed.
|
||||
func (a *Manager) createToken(userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, maxTokenCount int, provisioned bool) (*Token, error) {
|
||||
tx, err := a.db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
func (a *Manager) createTokenTx(tx *sql.Tx, userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, maxTokenCount int, provisioned bool) (*Token, error) {
|
||||
if _, err := tx.Exec(a.queries.upsertToken, userID, token, label, lastAccess.Unix(), lastOrigin.String(), expires.Unix(), provisioned); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -950,9 +935,6 @@ func (a *Manager) createToken(userID, token, label string, lastAccess time.Time,
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Token{
|
||||
Value: token,
|
||||
Label: label,
|
||||
@@ -1064,17 +1046,14 @@ func (a *Manager) AllProvisionedTokens() ([]*Token, error) {
|
||||
|
||||
// UpdateTokenLastAccess updates the last access time and origin for one or more tokens in a single transaction
|
||||
func (a *Manager) UpdateTokenLastAccess(updates map[string]*TokenUpdate) error {
|
||||
tx, err := a.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
for token, update := range updates {
|
||||
if _, err := tx.Exec(a.queries.updateTokenLastAccess, update.LastAccess.Unix(), update.LastOrigin.String(), token); err != nil {
|
||||
return err
|
||||
return execTx(a.db, func(tx *sql.Tx) error {
|
||||
for token, update := range updates {
|
||||
if _, err := tx.Exec(a.queries.updateTokenLastAccess, update.LastAccess.Unix(), update.LastOrigin.String(), token); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveProvisionedToken deletes a provisioned token by value, regardless of user
|
||||
@@ -1320,27 +1299,33 @@ func (a *Manager) maybeProvisionUsersAccessAndTokens() error {
|
||||
provisionUsernames := util.Map(a.config.Users, func(u *User) string {
|
||||
return u.Name
|
||||
})
|
||||
if err := a.maybeProvisionUsers(provisionUsernames, existingUsers); err != nil {
|
||||
return fmt.Errorf("failed to provision users: %v", err)
|
||||
existingTokens, err := a.AllProvisionedTokens()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := a.maybeProvisionGrants(); err != nil {
|
||||
return fmt.Errorf("failed to provision grants: %v", err)
|
||||
}
|
||||
if err := a.maybeProvisionTokens(provisionUsernames); err != nil {
|
||||
return fmt.Errorf("failed to provision tokens: %v", err)
|
||||
}
|
||||
return nil
|
||||
return execTx(a.db, func(tx *sql.Tx) error {
|
||||
if err := a.maybeProvisionUsers(tx, provisionUsernames, existingUsers); err != nil {
|
||||
return fmt.Errorf("failed to provision users: %v", err)
|
||||
}
|
||||
if err := a.maybeProvisionGrants(tx); err != nil {
|
||||
return fmt.Errorf("failed to provision grants: %v", err)
|
||||
}
|
||||
if err := a.maybeProvisionTokens(tx, provisionUsernames, existingTokens); err != nil {
|
||||
return fmt.Errorf("failed to provision tokens: %v", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// maybeProvisionUsers checks if the users in the config are provisioned, and adds or updates them.
|
||||
// It also removes users that are provisioned, but not in the config anymore.
|
||||
func (a *Manager) maybeProvisionUsers(provisionUsernames []string, existingUsers []*User) error {
|
||||
func (a *Manager) maybeProvisionUsers(tx *sql.Tx, provisionUsernames []string, existingUsers []*User) error {
|
||||
// Remove users that are provisioned, but not in the config anymore
|
||||
for _, user := range existingUsers {
|
||||
if user.Name == Everyone {
|
||||
continue
|
||||
} else if user.Provisioned && !util.Contains(provisionUsernames, user.Name) {
|
||||
if err := a.removeUser(user.Name); err != nil {
|
||||
if err := a.removeUserTx(tx, user.Name); err != nil {
|
||||
return fmt.Errorf("failed to remove provisioned user %s: %v", user.Name, err)
|
||||
}
|
||||
}
|
||||
@@ -1354,22 +1339,22 @@ func (a *Manager) maybeProvisionUsers(provisionUsernames []string, existingUsers
|
||||
return u.Name == user.Name
|
||||
})
|
||||
if !exists {
|
||||
if err := a.addUser(user.Name, user.Hash, user.Role, true, true); err != nil && !errors.Is(err, ErrUserExists) {
|
||||
if err := a.addUserTx(tx, user.Name, user.Hash, user.Role, true); err != nil && !errors.Is(err, ErrUserExists) {
|
||||
return fmt.Errorf("failed to add provisioned user %s: %v", user.Name, err)
|
||||
}
|
||||
} else {
|
||||
if !existingUser.Provisioned {
|
||||
if err := a.ChangeProvisioned(user.Name, true); err != nil {
|
||||
if err := a.changeProvisionedTx(tx, user.Name, true); err != nil {
|
||||
return fmt.Errorf("failed to change provisioned status for user %s: %v", user.Name, err)
|
||||
}
|
||||
}
|
||||
if existingUser.Hash != user.Hash {
|
||||
if err := a.changePasswordHash(user.Name, user.Hash); err != nil {
|
||||
if err := a.changePasswordHashTx(tx, user.Name, user.Hash); err != nil {
|
||||
return fmt.Errorf("failed to change password for provisioned user %s: %v", user.Name, err)
|
||||
}
|
||||
}
|
||||
if existingUser.Role != user.Role {
|
||||
if err := a.changeRole(user.Name, user.Role); err != nil {
|
||||
if err := a.changeRoleTx(tx, user.Name, user.Role); err != nil {
|
||||
return fmt.Errorf("failed to change role for provisioned user %s: %v", user.Name, err)
|
||||
}
|
||||
}
|
||||
@@ -1382,9 +1367,9 @@ func (a *Manager) maybeProvisionUsers(provisionUsernames []string, existingUsers
|
||||
//
|
||||
// Unlike users and tokens, grants can be just re-added, because they do not carry any state (such as last
|
||||
// access time) or do not have dependent resources (such as grants or tokens).
|
||||
func (a *Manager) maybeProvisionGrants() error {
|
||||
func (a *Manager) maybeProvisionGrants(tx *sql.Tx) error {
|
||||
// Remove all provisioned grants
|
||||
if err := a.ResetAllProvisionedAccess(); err != nil {
|
||||
if _, err := tx.Exec(a.queries.deleteUserAccessProvisioned); err != nil {
|
||||
return err
|
||||
}
|
||||
// (Re-)add provisioned grants
|
||||
@@ -1398,10 +1383,10 @@ func (a *Manager) maybeProvisionGrants() error {
|
||||
return fmt.Errorf("adding access control entries is not allowed for admin roles for user %s", username)
|
||||
}
|
||||
for _, grant := range grants {
|
||||
if err := a.resetAccess(username, grant.TopicPattern); err != nil {
|
||||
if err := a.resetAccessTx(tx, username, grant.TopicPattern); err != nil {
|
||||
return fmt.Errorf("failed to reset access for user %s and topic %s: %v", username, grant.TopicPattern, err)
|
||||
}
|
||||
if err := a.allowAccess(username, grant.TopicPattern, grant.Permission, true); err != nil {
|
||||
if err := a.allowAccessTx(tx, username, grant.TopicPattern, grant.Permission, true); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -1409,12 +1394,8 @@ func (a *Manager) maybeProvisionGrants() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Manager) maybeProvisionTokens(provisionUsernames []string) error {
|
||||
func (a *Manager) maybeProvisionTokens(tx *sql.Tx, provisionUsernames []string, existingTokens []*Token) error {
|
||||
// Remove tokens that are provisioned, but not in the config anymore
|
||||
existingTokens, err := a.AllProvisionedTokens()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to retrieve existing provisioned tokens: %v", err)
|
||||
}
|
||||
var provisionTokens []string
|
||||
for _, userTokens := range a.config.Tokens {
|
||||
for _, token := range userTokens {
|
||||
@@ -1423,7 +1404,7 @@ func (a *Manager) maybeProvisionTokens(provisionUsernames []string) error {
|
||||
}
|
||||
for _, existingToken := range existingTokens {
|
||||
if !slices.Contains(provisionTokens, existingToken.Value) {
|
||||
if err := a.RemoveProvisionedToken(existingToken.Value); err != nil {
|
||||
if _, err := tx.Exec(a.queries.deleteProvisionedToken, existingToken.Value); err != nil {
|
||||
return fmt.Errorf("failed to remove provisioned token %s: %v", existingToken.Value, err)
|
||||
}
|
||||
}
|
||||
@@ -1433,12 +1414,12 @@ func (a *Manager) maybeProvisionTokens(provisionUsernames []string) error {
|
||||
if !slices.Contains(provisionUsernames, username) && username != Everyone {
|
||||
return fmt.Errorf("user %s is not a provisioned user, refusing to add tokens", username)
|
||||
}
|
||||
userID, err := a.UserIDByUsername(username)
|
||||
if err != nil {
|
||||
var userID string
|
||||
if err := tx.QueryRow(a.queries.selectUserIDFromUsername, username).Scan(&userID); err != nil {
|
||||
return fmt.Errorf("failed to find provisioned user %s for provisioned tokens: %v", username, err)
|
||||
}
|
||||
for _, token := range tokens {
|
||||
if _, err := a.createToken(userID, token.Value, token.Label, time.Unix(0, 0), netip.IPv4Unspecified(), time.Unix(0, 0), 0, true); err != nil {
|
||||
if _, err := a.createTokenTx(tx, userID, token.Value, token.Label, time.Unix(0, 0), netip.IPv4Unspecified(), time.Unix(0, 0), 0, true); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"net/netip"
|
||||
"strings"
|
||||
@@ -345,8 +344,3 @@ type storeQueries struct {
|
||||
updateBilling string
|
||||
}
|
||||
|
||||
// execer is satisfied by both *sql.DB and *sql.Tx, allowing helper methods
|
||||
// to be used both standalone and within a transaction.
|
||||
type execer interface {
|
||||
Exec(query string, args ...any) (sql.Result, error)
|
||||
}
|
||||
|
||||
32
user/util.go
32
user/util.go
@@ -113,3 +113,35 @@ func escapeUnderscore(s string) string {
|
||||
func unescapeUnderscore(s string) string {
|
||||
return strings.ReplaceAll(s, "\\_", "_")
|
||||
}
|
||||
|
||||
// execTx executes a function in a transaction. If the function returns an error, the transaction is rolled back.
|
||||
func execTx(db *sql.DB, f func(tx *sql.Tx) error) error {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if err := f(tx); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// queryTx executes a function in a transaction and returns the result. If the function
|
||||
// returns an error, the transaction is rolled back.
|
||||
func queryTx[T any](db *sql.DB, f func(tx *sql.Tx) (T, error)) (T, error) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
var zero T
|
||||
return zero, err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
t, err := f(tx)
|
||||
if err != nil {
|
||||
return t, err
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return t, err
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user