Re-add execTx
This commit is contained in:
307
user/manager.go
307
user/manager.go
@@ -114,38 +114,27 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) {
|
|||||||
|
|
||||||
// AddUser adds a user with the given username, password and role
|
// AddUser adds a user with the given username, password and role
|
||||||
func (a *Manager) AddUser(username, password string, role Role, hashed bool) error {
|
func (a *Manager) AddUser(username, password string, role Role, hashed bool) error {
|
||||||
return a.addUser(username, password, role, hashed, false)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Manager) addUser(username, password string, role Role, hashed, provisioned bool) error {
|
|
||||||
if !AllowedUsername(username) || !AllowedRole(role) {
|
if !AllowedUsername(username) || !AllowedRole(role) {
|
||||||
return ErrInvalidArgument
|
return ErrInvalidArgument
|
||||||
}
|
}
|
||||||
var hash string
|
hash, err := a.maybeHashPassword(password, hashed)
|
||||||
var err error
|
if err != nil {
|
||||||
if hashed {
|
return err
|
||||||
hash = password
|
|
||||||
if err := ValidPasswordHash(hash, a.config.BcryptCost); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
hash, err = hashPassword(password, a.config.BcryptCost)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return a.insertUser(username, hash, role, provisioned)
|
return execTx(a.db, func(tx *sql.Tx) error {
|
||||||
|
return a.addUserTx(tx, username, hash, role, false)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// insertUser adds a user with the given username, password hash and role to the database
|
// addUserTx adds a user with the given username, password hash and role to the database
|
||||||
func (a *Manager) insertUser(username, hash string, role Role, provisioned bool) error {
|
func (a *Manager) addUserTx(tx *sql.Tx, username, hash string, role Role, provisioned bool) error {
|
||||||
if !AllowedUsername(username) || !AllowedRole(role) {
|
if !AllowedUsername(username) || !AllowedRole(role) {
|
||||||
return ErrInvalidArgument
|
return ErrInvalidArgument
|
||||||
}
|
}
|
||||||
userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
|
userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
|
||||||
syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
|
syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
if _, err := a.db.Exec(a.queries.insertUser, userID, username, hash, string(role), syncTopic, provisioned, now); err != nil {
|
if _, err := tx.Exec(a.queries.insertUser, userID, username, hash, string(role), syncTopic, provisioned, now); err != nil {
|
||||||
if isUniqueConstraintError(err) {
|
if isUniqueConstraintError(err) {
|
||||||
return ErrUserExists
|
return ErrUserExists
|
||||||
}
|
}
|
||||||
@@ -160,16 +149,18 @@ func (a *Manager) RemoveUser(username string) error {
|
|||||||
if err := a.CanChangeUser(username); err != nil {
|
if err := a.CanChangeUser(username); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return a.removeUser(username)
|
return execTx(a.db, func(tx *sql.Tx) error {
|
||||||
|
return a.removeUserTx(tx, username)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// removeUser deletes the user with the given username
|
// removeUserTx deletes the user with the given username
|
||||||
func (a *Manager) removeUser(username string) error {
|
func (a *Manager) removeUserTx(tx *sql.Tx, username string) error {
|
||||||
if !AllowedUsername(username) {
|
if !AllowedUsername(username) {
|
||||||
return ErrInvalidArgument
|
return ErrInvalidArgument
|
||||||
}
|
}
|
||||||
// Rows in user_access, user_token, etc. are deleted via foreign keys
|
// Rows in user_access, user_token, etc. are deleted via foreign keys
|
||||||
if _, err := a.db.Exec(a.queries.deleteUser, username); err != nil {
|
if _, err := tx.Exec(a.queries.deleteUser, username); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -181,22 +172,19 @@ func (a *Manager) MarkUserRemoved(user *User) error {
|
|||||||
if !AllowedUsername(user.Name) {
|
if !AllowedUsername(user.Name) {
|
||||||
return ErrInvalidArgument
|
return ErrInvalidArgument
|
||||||
}
|
}
|
||||||
tx, err := a.db.Begin()
|
return execTx(a.db, func(tx *sql.Tx) error {
|
||||||
if err != nil {
|
if err := a.resetUserAccessTx(tx, user.Name); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
if _, err := tx.Exec(a.queries.deleteAllToken, user.ID); err != nil {
|
||||||
if err := a.resetUserAccessTx(tx, user.Name); err != nil {
|
return err
|
||||||
return err
|
}
|
||||||
}
|
deletedTime := time.Now().Add(userHardDeleteAfterDuration).Unix()
|
||||||
if _, err := tx.Exec(a.queries.deleteAllToken, user.ID); err != nil {
|
if _, err := tx.Exec(a.queries.updateUserDeleted, deletedTime, user.ID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
deletedTime := time.Now().Add(userHardDeleteAfterDuration).Unix()
|
return nil
|
||||||
if _, err := tx.Exec(a.queries.updateUserDeleted, deletedTime, user.ID); err != nil {
|
})
|
||||||
return err
|
|
||||||
}
|
|
||||||
return tx.Commit()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveDeletedUsers deletes all users that have been marked deleted
|
// RemoveDeletedUsers deletes all users that have been marked deleted
|
||||||
@@ -212,25 +200,18 @@ func (a *Manager) ChangePassword(username, password string, hashed bool) error {
|
|||||||
if err := a.CanChangeUser(username); err != nil {
|
if err := a.CanChangeUser(username); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
var hash string
|
hash, err := a.maybeHashPassword(password, hashed)
|
||||||
var err error
|
if err != nil {
|
||||||
if hashed {
|
return err
|
||||||
hash = password
|
|
||||||
if err := ValidPasswordHash(hash, a.config.BcryptCost); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
hash, err = hashPassword(password, a.config.BcryptCost)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return a.changePasswordHash(username, hash)
|
return execTx(a.db, func(tx *sql.Tx) error {
|
||||||
|
return a.changePasswordHashTx(tx, username, hash)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// changePasswordHash changes a user's password hash in the database
|
// changePasswordHashTx changes a user's password hash in the database
|
||||||
func (a *Manager) changePasswordHash(username, hash string) error {
|
func (a *Manager) changePasswordHashTx(tx *sql.Tx, username, hash string) error {
|
||||||
if _, err := a.db.Exec(a.queries.updateUserPass, hash, username); err != nil {
|
if _, err := tx.Exec(a.queries.updateUserPass, hash, username); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -242,19 +223,16 @@ func (a *Manager) ChangeRole(username string, role Role) error {
|
|||||||
if err := a.CanChangeUser(username); err != nil {
|
if err := a.CanChangeUser(username); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return a.changeRole(username, role)
|
return execTx(a.db, func(tx *sql.Tx) error {
|
||||||
|
return a.changeRoleTx(tx, username, role)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// changeRole changes a user's role
|
// changeRoleTx changes a user's role
|
||||||
func (a *Manager) changeRole(username string, role Role) error {
|
func (a *Manager) changeRoleTx(tx *sql.Tx, username string, role Role) error {
|
||||||
if !AllowedUsername(username) || !AllowedRole(role) {
|
if !AllowedUsername(username) || !AllowedRole(role) {
|
||||||
return ErrInvalidArgument
|
return ErrInvalidArgument
|
||||||
}
|
}
|
||||||
tx, err := a.db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer tx.Rollback()
|
|
||||||
if _, err := tx.Exec(a.queries.updateUserRole, string(role), username); err != nil {
|
if _, err := tx.Exec(a.queries.updateUserRole, string(role), username); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -264,7 +242,7 @@ func (a *Manager) changeRole(username string, role Role) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return tx.Commit()
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CanChangeUser checks if the user with the given username can be changed.
|
// CanChangeUser checks if the user with the given username can be changed.
|
||||||
@@ -281,7 +259,14 @@ func (a *Manager) CanChangeUser(username string) error {
|
|||||||
|
|
||||||
// ChangeProvisioned changes the provisioned status of a user
|
// ChangeProvisioned changes the provisioned status of a user
|
||||||
func (a *Manager) ChangeProvisioned(username string, provisioned bool) error {
|
func (a *Manager) ChangeProvisioned(username string, provisioned bool) error {
|
||||||
if _, err := a.db.Exec(a.queries.updateUserProvisioned, provisioned, username); err != nil {
|
return execTx(a.db, func(tx *sql.Tx) error {
|
||||||
|
return a.changeProvisionedTx(tx, username, provisioned)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// changeProvisionedTx changes the provisioned status of a user
|
||||||
|
func (a *Manager) changeProvisionedTx(tx *sql.Tx, username string, provisioned bool) error {
|
||||||
|
if _, err := tx.Exec(a.queries.updateUserProvisioned, provisioned, username); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -357,17 +342,14 @@ func (a *Manager) ResetStats() error {
|
|||||||
|
|
||||||
// UpdateStats updates statistics for one or more users in a single transaction
|
// UpdateStats updates statistics for one or more users in a single transaction
|
||||||
func (a *Manager) UpdateStats(stats map[string]*Stats) error {
|
func (a *Manager) UpdateStats(stats map[string]*Stats) error {
|
||||||
tx, err := a.db.Begin()
|
return execTx(a.db, func(tx *sql.Tx) error {
|
||||||
if err != nil {
|
for userID, update := range stats {
|
||||||
return err
|
if _, err := tx.Exec(a.queries.updateUserStats, update.Messages, update.Emails, update.Calls, userID); err != nil {
|
||||||
}
|
return err
|
||||||
defer tx.Rollback()
|
}
|
||||||
for userID, update := range stats {
|
|
||||||
if _, err := tx.Exec(a.queries.updateUserStats, update.Messages, update.Emails, update.Calls, userID); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
}
|
return nil
|
||||||
return tx.Commit()
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// EnqueueUserStats adds the user to a queue which writes out user stats (messages, emails, ..) in
|
// EnqueueUserStats adds the user to a queue which writes out user stats (messages, emails, ..) in
|
||||||
@@ -578,6 +560,16 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
|
|||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Manager) maybeHashPassword(password string, hashed bool) (string, error) {
|
||||||
|
if hashed {
|
||||||
|
if err := ValidPasswordHash(password, a.config.BcryptCost); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return password, nil
|
||||||
|
}
|
||||||
|
return hashPassword(password, a.config.BcryptCost)
|
||||||
|
}
|
||||||
|
|
||||||
// Authorize returns nil if the given user has access to the given topic using the desired
|
// Authorize returns nil if the given user has access to the given topic using the desired
|
||||||
// permission. The user param may be nil to signal an anonymous user.
|
// permission. The user param may be nil to signal an anonymous user.
|
||||||
func (a *Manager) Authorize(user *User, topic string, perm Permission) error {
|
func (a *Manager) Authorize(user *User, topic string, perm Permission) error {
|
||||||
@@ -612,37 +604,42 @@ func (a *Manager) resolvePerms(base, perm Permission) error {
|
|||||||
// read/write access to a topic. The parameter topicPattern may include wildcards (*). The ACL entry
|
// read/write access to a topic. The parameter topicPattern may include wildcards (*). The ACL entry
|
||||||
// owner may either be a user (username), or the system (empty).
|
// owner may either be a user (username), or the system (empty).
|
||||||
func (a *Manager) AllowAccess(username string, topicPattern string, permission Permission) error {
|
func (a *Manager) AllowAccess(username string, topicPattern string, permission Permission) error {
|
||||||
return a.allowAccess(username, topicPattern, permission, false)
|
return execTx(a.db, func(tx *sql.Tx) error {
|
||||||
|
return a.allowAccessTx(tx, username, topicPattern, permission, false)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Manager) allowAccess(username string, topicPattern string, permission Permission, provisioned bool) error {
|
func (a *Manager) allowAccessTx(tx *sql.Tx, username string, topicPattern string, permission Permission, provisioned bool) error {
|
||||||
if !AllowedUsername(username) && username != Everyone {
|
if !AllowedUsername(username) && username != Everyone {
|
||||||
return ErrInvalidArgument
|
return ErrInvalidArgument
|
||||||
} else if !AllowedTopicPattern(topicPattern) {
|
} else if !AllowedTopicPattern(topicPattern) {
|
||||||
return ErrInvalidArgument
|
return ErrInvalidArgument
|
||||||
}
|
}
|
||||||
return a.allowAccessTx(a.db, username, topicPattern, permission.IsRead(), permission.IsWrite(), "", provisioned)
|
_, err := tx.Exec(a.queries.upsertUserAccess, username, toSQLWildcard(topicPattern), permission.IsRead(), permission.IsWrite(), "", "", provisioned)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResetAccess removes an access control list entry for a specific username/topic, or (if topic is
|
// ResetAccess removes an access control list entry for a specific username/topic, or (if topic is
|
||||||
// empty) for an entire user. The parameter topicPattern may include wildcards (*).
|
// empty) for an entire user. The parameter topicPattern may include wildcards (*).
|
||||||
func (a *Manager) ResetAccess(username string, topicPattern string) error {
|
func (a *Manager) ResetAccess(username string, topicPattern string) error {
|
||||||
return a.resetAccess(username, topicPattern)
|
return execTx(a.db, func(tx *sql.Tx) error {
|
||||||
|
return a.resetAccessTx(tx, username, topicPattern)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Manager) resetAccess(username string, topicPattern string) error {
|
func (a *Manager) resetAccessTx(tx *sql.Tx, username string, topicPattern string) error {
|
||||||
if !AllowedUsername(username) && username != Everyone && username != "" {
|
if !AllowedUsername(username) && username != Everyone && username != "" {
|
||||||
return ErrInvalidArgument
|
return ErrInvalidArgument
|
||||||
} else if !AllowedTopicPattern(topicPattern) && topicPattern != "" {
|
} else if !AllowedTopicPattern(topicPattern) && topicPattern != "" {
|
||||||
return ErrInvalidArgument
|
return ErrInvalidArgument
|
||||||
}
|
}
|
||||||
if username == "" && topicPattern == "" {
|
if username == "" && topicPattern == "" {
|
||||||
_, err := a.db.Exec(a.queries.deleteAllAccess)
|
_, err := tx.Exec(a.queries.deleteAllAccess)
|
||||||
return err
|
return err
|
||||||
} else if topicPattern == "" {
|
} else if topicPattern == "" {
|
||||||
return a.resetUserAccessTx(a.db, username)
|
return a.resetUserAccessTx(tx, username)
|
||||||
}
|
}
|
||||||
return a.resetTopicAccessTx(a.db, username, topicPattern)
|
return a.resetTopicAccessTx(tx, username, topicPattern)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultAccess returns the default read/write access if no access control entry matches
|
// DefaultAccess returns the default read/write access if no access control entry matches
|
||||||
@@ -749,18 +746,12 @@ func (a *Manager) AddReservation(username string, topic string, everyone Permiss
|
|||||||
if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) {
|
if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) {
|
||||||
return ErrInvalidArgument
|
return ErrInvalidArgument
|
||||||
}
|
}
|
||||||
tx, err := a.db.Begin()
|
return execTx(a.db, func(tx *sql.Tx) error {
|
||||||
if err != nil {
|
if err := a.addReservationAccessTx(tx, username, topic, true, true, username); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
return a.addReservationAccessTx(tx, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username)
|
||||||
if err := a.allowAccessTx(tx, username, topic, true, true, username, false); err != nil {
|
})
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := a.allowAccessTx(tx, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username, false); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return tx.Commit()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveReservations deletes the access control entries associated with the given username/topic,
|
// RemoveReservations deletes the access control entries associated with the given username/topic,
|
||||||
@@ -775,20 +766,17 @@ func (a *Manager) RemoveReservations(username string, topics ...string) error {
|
|||||||
return ErrInvalidArgument
|
return ErrInvalidArgument
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
tx, err := a.db.Begin()
|
return execTx(a.db, func(tx *sql.Tx) error {
|
||||||
if err != nil {
|
for _, topic := range topics {
|
||||||
return err
|
if err := a.resetTopicAccessTx(tx, username, topic); err != nil {
|
||||||
}
|
return err
|
||||||
defer tx.Rollback()
|
}
|
||||||
for _, topic := range topics {
|
if err := a.resetTopicAccessTx(tx, Everyone, topic); err != nil {
|
||||||
if err := a.resetTopicAccessTx(tx, username, topic); err != nil {
|
return err
|
||||||
return err
|
}
|
||||||
}
|
}
|
||||||
if err := a.resetTopicAccessTx(tx, Everyone, topic); err != nil {
|
return nil
|
||||||
return err
|
})
|
||||||
}
|
|
||||||
}
|
|
||||||
return tx.Commit()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reservations returns all user-owned topics, and the associated everyone-access
|
// Reservations returns all user-owned topics, and the associated everyone-access
|
||||||
@@ -893,17 +881,17 @@ func (a *Manager) ResetAllProvisionedAccess() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Manager) allowAccessTx(tx execer, username, topicPattern string, read, write bool, ownerUsername string, provisioned bool) error {
|
func (a *Manager) addReservationAccessTx(tx *sql.Tx, username, topic string, read, write bool, ownerUsername string) error {
|
||||||
if !AllowedUsername(username) && username != Everyone {
|
if !AllowedUsername(username) && username != Everyone {
|
||||||
return ErrInvalidArgument
|
return ErrInvalidArgument
|
||||||
} else if !AllowedTopicPattern(topicPattern) {
|
} else if !AllowedTopicPattern(topic) {
|
||||||
return ErrInvalidArgument
|
return ErrInvalidArgument
|
||||||
}
|
}
|
||||||
_, err := tx.Exec(a.queries.upsertUserAccess, username, toSQLWildcard(topicPattern), read, write, ownerUsername, ownerUsername, provisioned)
|
_, err := tx.Exec(a.queries.upsertUserAccess, username, toSQLWildcard(topic), read, write, ownerUsername, ownerUsername, false)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Manager) resetUserAccessTx(tx execer, username string) error {
|
func (a *Manager) resetUserAccessTx(tx *sql.Tx, username string) error {
|
||||||
if !AllowedUsername(username) && username != Everyone {
|
if !AllowedUsername(username) && username != Everyone {
|
||||||
return ErrInvalidArgument
|
return ErrInvalidArgument
|
||||||
}
|
}
|
||||||
@@ -911,7 +899,7 @@ func (a *Manager) resetUserAccessTx(tx execer, username string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Manager) resetTopicAccessTx(tx execer, username, topicPattern string) error {
|
func (a *Manager) resetTopicAccessTx(tx *sql.Tx, username, topicPattern string) error {
|
||||||
if !AllowedUsername(username) && username != Everyone && username != "" {
|
if !AllowedUsername(username) && username != Everyone && username != "" {
|
||||||
return ErrInvalidArgument
|
return ErrInvalidArgument
|
||||||
} else if !AllowedTopicPattern(topicPattern) && topicPattern != "" {
|
} else if !AllowedTopicPattern(topicPattern) && topicPattern != "" {
|
||||||
@@ -925,17 +913,14 @@ func (a *Manager) resetTopicAccessTx(tx execer, username, topicPattern string) e
|
|||||||
// after a fixed duration unless ChangeToken is called. This function also prunes tokens for the
|
// after a fixed duration unless ChangeToken is called. This function also prunes tokens for the
|
||||||
// given user, if there are too many of them.
|
// given user, if there are too many of them.
|
||||||
func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr, provisioned bool) (*Token, error) {
|
func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr, provisioned bool) (*Token, error) {
|
||||||
return a.createToken(userID, GenerateToken(), label, time.Now(), origin, expires, tokenMaxCount, provisioned)
|
return queryTx(a.db, func(tx *sql.Tx) (*Token, error) {
|
||||||
|
return a.createTokenTx(tx, userID, GenerateToken(), label, time.Now(), origin, expires, tokenMaxCount, provisioned)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// createToken creates a new token and prunes excess tokens if the count exceeds maxTokenCount.
|
// createTokenTx creates a new token and prunes excess tokens if the count exceeds maxTokenCount.
|
||||||
// If maxTokenCount is 0, no pruning is performed.
|
// If maxTokenCount is 0, no pruning is performed.
|
||||||
func (a *Manager) createToken(userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, maxTokenCount int, provisioned bool) (*Token, error) {
|
func (a *Manager) createTokenTx(tx *sql.Tx, userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, maxTokenCount int, provisioned bool) (*Token, error) {
|
||||||
tx, err := a.db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer tx.Rollback()
|
|
||||||
if _, err := tx.Exec(a.queries.upsertToken, userID, token, label, lastAccess.Unix(), lastOrigin.String(), expires.Unix(), provisioned); err != nil {
|
if _, err := tx.Exec(a.queries.upsertToken, userID, token, label, lastAccess.Unix(), lastOrigin.String(), expires.Unix(), provisioned); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -950,9 +935,6 @@ func (a *Manager) createToken(userID, token, label string, lastAccess time.Time,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := tx.Commit(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &Token{
|
return &Token{
|
||||||
Value: token,
|
Value: token,
|
||||||
Label: label,
|
Label: label,
|
||||||
@@ -1064,17 +1046,14 @@ func (a *Manager) AllProvisionedTokens() ([]*Token, error) {
|
|||||||
|
|
||||||
// UpdateTokenLastAccess updates the last access time and origin for one or more tokens in a single transaction
|
// UpdateTokenLastAccess updates the last access time and origin for one or more tokens in a single transaction
|
||||||
func (a *Manager) UpdateTokenLastAccess(updates map[string]*TokenUpdate) error {
|
func (a *Manager) UpdateTokenLastAccess(updates map[string]*TokenUpdate) error {
|
||||||
tx, err := a.db.Begin()
|
return execTx(a.db, func(tx *sql.Tx) error {
|
||||||
if err != nil {
|
for token, update := range updates {
|
||||||
return err
|
if _, err := tx.Exec(a.queries.updateTokenLastAccess, update.LastAccess.Unix(), update.LastOrigin.String(), token); err != nil {
|
||||||
}
|
return err
|
||||||
defer tx.Rollback()
|
}
|
||||||
for token, update := range updates {
|
|
||||||
if _, err := tx.Exec(a.queries.updateTokenLastAccess, update.LastAccess.Unix(), update.LastOrigin.String(), token); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
}
|
return nil
|
||||||
return tx.Commit()
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveProvisionedToken deletes a provisioned token by value, regardless of user
|
// RemoveProvisionedToken deletes a provisioned token by value, regardless of user
|
||||||
@@ -1320,27 +1299,33 @@ func (a *Manager) maybeProvisionUsersAccessAndTokens() error {
|
|||||||
provisionUsernames := util.Map(a.config.Users, func(u *User) string {
|
provisionUsernames := util.Map(a.config.Users, func(u *User) string {
|
||||||
return u.Name
|
return u.Name
|
||||||
})
|
})
|
||||||
if err := a.maybeProvisionUsers(provisionUsernames, existingUsers); err != nil {
|
existingTokens, err := a.AllProvisionedTokens()
|
||||||
return fmt.Errorf("failed to provision users: %v", err)
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
if err := a.maybeProvisionGrants(); err != nil {
|
return execTx(a.db, func(tx *sql.Tx) error {
|
||||||
return fmt.Errorf("failed to provision grants: %v", err)
|
if err := a.maybeProvisionUsers(tx, provisionUsernames, existingUsers); err != nil {
|
||||||
}
|
return fmt.Errorf("failed to provision users: %v", err)
|
||||||
if err := a.maybeProvisionTokens(provisionUsernames); err != nil {
|
}
|
||||||
return fmt.Errorf("failed to provision tokens: %v", err)
|
if err := a.maybeProvisionGrants(tx); err != nil {
|
||||||
}
|
return fmt.Errorf("failed to provision grants: %v", err)
|
||||||
return nil
|
}
|
||||||
|
if err := a.maybeProvisionTokens(tx, provisionUsernames, existingTokens); err != nil {
|
||||||
|
return fmt.Errorf("failed to provision tokens: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// maybeProvisionUsers checks if the users in the config are provisioned, and adds or updates them.
|
// maybeProvisionUsers checks if the users in the config are provisioned, and adds or updates them.
|
||||||
// It also removes users that are provisioned, but not in the config anymore.
|
// It also removes users that are provisioned, but not in the config anymore.
|
||||||
func (a *Manager) maybeProvisionUsers(provisionUsernames []string, existingUsers []*User) error {
|
func (a *Manager) maybeProvisionUsers(tx *sql.Tx, provisionUsernames []string, existingUsers []*User) error {
|
||||||
// Remove users that are provisioned, but not in the config anymore
|
// Remove users that are provisioned, but not in the config anymore
|
||||||
for _, user := range existingUsers {
|
for _, user := range existingUsers {
|
||||||
if user.Name == Everyone {
|
if user.Name == Everyone {
|
||||||
continue
|
continue
|
||||||
} else if user.Provisioned && !util.Contains(provisionUsernames, user.Name) {
|
} else if user.Provisioned && !util.Contains(provisionUsernames, user.Name) {
|
||||||
if err := a.removeUser(user.Name); err != nil {
|
if err := a.removeUserTx(tx, user.Name); err != nil {
|
||||||
return fmt.Errorf("failed to remove provisioned user %s: %v", user.Name, err)
|
return fmt.Errorf("failed to remove provisioned user %s: %v", user.Name, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1354,22 +1339,22 @@ func (a *Manager) maybeProvisionUsers(provisionUsernames []string, existingUsers
|
|||||||
return u.Name == user.Name
|
return u.Name == user.Name
|
||||||
})
|
})
|
||||||
if !exists {
|
if !exists {
|
||||||
if err := a.addUser(user.Name, user.Hash, user.Role, true, true); err != nil && !errors.Is(err, ErrUserExists) {
|
if err := a.addUserTx(tx, user.Name, user.Hash, user.Role, true); err != nil && !errors.Is(err, ErrUserExists) {
|
||||||
return fmt.Errorf("failed to add provisioned user %s: %v", user.Name, err)
|
return fmt.Errorf("failed to add provisioned user %s: %v", user.Name, err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if !existingUser.Provisioned {
|
if !existingUser.Provisioned {
|
||||||
if err := a.ChangeProvisioned(user.Name, true); err != nil {
|
if err := a.changeProvisionedTx(tx, user.Name, true); err != nil {
|
||||||
return fmt.Errorf("failed to change provisioned status for user %s: %v", user.Name, err)
|
return fmt.Errorf("failed to change provisioned status for user %s: %v", user.Name, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if existingUser.Hash != user.Hash {
|
if existingUser.Hash != user.Hash {
|
||||||
if err := a.changePasswordHash(user.Name, user.Hash); err != nil {
|
if err := a.changePasswordHashTx(tx, user.Name, user.Hash); err != nil {
|
||||||
return fmt.Errorf("failed to change password for provisioned user %s: %v", user.Name, err)
|
return fmt.Errorf("failed to change password for provisioned user %s: %v", user.Name, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if existingUser.Role != user.Role {
|
if existingUser.Role != user.Role {
|
||||||
if err := a.changeRole(user.Name, user.Role); err != nil {
|
if err := a.changeRoleTx(tx, user.Name, user.Role); err != nil {
|
||||||
return fmt.Errorf("failed to change role for provisioned user %s: %v", user.Name, err)
|
return fmt.Errorf("failed to change role for provisioned user %s: %v", user.Name, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1382,9 +1367,9 @@ func (a *Manager) maybeProvisionUsers(provisionUsernames []string, existingUsers
|
|||||||
//
|
//
|
||||||
// Unlike users and tokens, grants can be just re-added, because they do not carry any state (such as last
|
// Unlike users and tokens, grants can be just re-added, because they do not carry any state (such as last
|
||||||
// access time) or do not have dependent resources (such as grants or tokens).
|
// access time) or do not have dependent resources (such as grants or tokens).
|
||||||
func (a *Manager) maybeProvisionGrants() error {
|
func (a *Manager) maybeProvisionGrants(tx *sql.Tx) error {
|
||||||
// Remove all provisioned grants
|
// Remove all provisioned grants
|
||||||
if err := a.ResetAllProvisionedAccess(); err != nil {
|
if _, err := tx.Exec(a.queries.deleteUserAccessProvisioned); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// (Re-)add provisioned grants
|
// (Re-)add provisioned grants
|
||||||
@@ -1398,10 +1383,10 @@ func (a *Manager) maybeProvisionGrants() error {
|
|||||||
return fmt.Errorf("adding access control entries is not allowed for admin roles for user %s", username)
|
return fmt.Errorf("adding access control entries is not allowed for admin roles for user %s", username)
|
||||||
}
|
}
|
||||||
for _, grant := range grants {
|
for _, grant := range grants {
|
||||||
if err := a.resetAccess(username, grant.TopicPattern); err != nil {
|
if err := a.resetAccessTx(tx, username, grant.TopicPattern); err != nil {
|
||||||
return fmt.Errorf("failed to reset access for user %s and topic %s: %v", username, grant.TopicPattern, err)
|
return fmt.Errorf("failed to reset access for user %s and topic %s: %v", username, grant.TopicPattern, err)
|
||||||
}
|
}
|
||||||
if err := a.allowAccess(username, grant.TopicPattern, grant.Permission, true); err != nil {
|
if err := a.allowAccessTx(tx, username, grant.TopicPattern, grant.Permission, true); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1409,12 +1394,8 @@ func (a *Manager) maybeProvisionGrants() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Manager) maybeProvisionTokens(provisionUsernames []string) error {
|
func (a *Manager) maybeProvisionTokens(tx *sql.Tx, provisionUsernames []string, existingTokens []*Token) error {
|
||||||
// Remove tokens that are provisioned, but not in the config anymore
|
// Remove tokens that are provisioned, but not in the config anymore
|
||||||
existingTokens, err := a.AllProvisionedTokens()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to retrieve existing provisioned tokens: %v", err)
|
|
||||||
}
|
|
||||||
var provisionTokens []string
|
var provisionTokens []string
|
||||||
for _, userTokens := range a.config.Tokens {
|
for _, userTokens := range a.config.Tokens {
|
||||||
for _, token := range userTokens {
|
for _, token := range userTokens {
|
||||||
@@ -1423,7 +1404,7 @@ func (a *Manager) maybeProvisionTokens(provisionUsernames []string) error {
|
|||||||
}
|
}
|
||||||
for _, existingToken := range existingTokens {
|
for _, existingToken := range existingTokens {
|
||||||
if !slices.Contains(provisionTokens, existingToken.Value) {
|
if !slices.Contains(provisionTokens, existingToken.Value) {
|
||||||
if err := a.RemoveProvisionedToken(existingToken.Value); err != nil {
|
if _, err := tx.Exec(a.queries.deleteProvisionedToken, existingToken.Value); err != nil {
|
||||||
return fmt.Errorf("failed to remove provisioned token %s: %v", existingToken.Value, err)
|
return fmt.Errorf("failed to remove provisioned token %s: %v", existingToken.Value, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1433,12 +1414,12 @@ func (a *Manager) maybeProvisionTokens(provisionUsernames []string) error {
|
|||||||
if !slices.Contains(provisionUsernames, username) && username != Everyone {
|
if !slices.Contains(provisionUsernames, username) && username != Everyone {
|
||||||
return fmt.Errorf("user %s is not a provisioned user, refusing to add tokens", username)
|
return fmt.Errorf("user %s is not a provisioned user, refusing to add tokens", username)
|
||||||
}
|
}
|
||||||
userID, err := a.UserIDByUsername(username)
|
var userID string
|
||||||
if err != nil {
|
if err := tx.QueryRow(a.queries.selectUserIDFromUsername, username).Scan(&userID); err != nil {
|
||||||
return fmt.Errorf("failed to find provisioned user %s for provisioned tokens: %v", username, err)
|
return fmt.Errorf("failed to find provisioned user %s for provisioned tokens: %v", username, err)
|
||||||
}
|
}
|
||||||
for _, token := range tokens {
|
for _, token := range tokens {
|
||||||
if _, err := a.createToken(userID, token.Value, token.Label, time.Unix(0, 0), netip.IPv4Unspecified(), time.Unix(0, 0), 0, true); err != nil {
|
if _, err := a.createTokenTx(tx, userID, token.Value, token.Label, time.Unix(0, 0), netip.IPv4Unspecified(), time.Unix(0, 0), 0, true); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package user
|
package user
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"errors"
|
"errors"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -345,8 +344,3 @@ type storeQueries struct {
|
|||||||
updateBilling string
|
updateBilling string
|
||||||
}
|
}
|
||||||
|
|
||||||
// execer is satisfied by both *sql.DB and *sql.Tx, allowing helper methods
|
|
||||||
// to be used both standalone and within a transaction.
|
|
||||||
type execer interface {
|
|
||||||
Exec(query string, args ...any) (sql.Result, error)
|
|
||||||
}
|
|
||||||
|
|||||||
32
user/util.go
32
user/util.go
@@ -113,3 +113,35 @@ func escapeUnderscore(s string) string {
|
|||||||
func unescapeUnderscore(s string) string {
|
func unescapeUnderscore(s string) string {
|
||||||
return strings.ReplaceAll(s, "\\_", "_")
|
return strings.ReplaceAll(s, "\\_", "_")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// execTx executes a function in a transaction. If the function returns an error, the transaction is rolled back.
|
||||||
|
func execTx(db *sql.DB, f func(tx *sql.Tx) error) error {
|
||||||
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
if err := f(tx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
// queryTx executes a function in a transaction and returns the result. If the function
|
||||||
|
// returns an error, the transaction is rolled back.
|
||||||
|
func queryTx[T any](db *sql.DB, f func(tx *sql.Tx) (T, error)) (T, error) {
|
||||||
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
var zero T
|
||||||
|
return zero, err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
t, err := f(tx)
|
||||||
|
if err != nil {
|
||||||
|
return t, err
|
||||||
|
}
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return t, err
|
||||||
|
}
|
||||||
|
return t, nil
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user