diff --git a/user/store.go b/user/store.go index 7ebb5b58..6f5fb606 100644 --- a/user/store.go +++ b/user/store.go @@ -155,18 +155,18 @@ type commonStore struct { queries storeQueries } -// UserByID returns the user with the given ID if it exists, or ErrUserNotFound otherwise -func (s *commonStore) UserByID(id string) (*User, error) { - rows, err := s.db.Query(s.queries.selectUserByID, id) +// User returns the user with the given username if it exists, or ErrUserNotFound otherwise +func (s *commonStore) User(username string) (*User, error) { + rows, err := s.db.Query(s.queries.selectUserByName, username) if err != nil { return nil, err } return s.readUser(rows) } -// User returns the user with the given username if it exists, or ErrUserNotFound otherwise -func (s *commonStore) User(username string) (*User, error) { - rows, err := s.db.Query(s.queries.selectUserByName, username) +// UserByID returns the user with the given ID if it exists, or ErrUserNotFound otherwise +func (s *commonStore) UserByID(id string) (*User, error) { + rows, err := s.db.Query(s.queries.selectUserByID, id) if err != nil { return nil, err } @@ -273,7 +273,7 @@ func (s *commonStore) MarkUserRemoved(userID, username string) error { return err } defer tx.Rollback() - if _, err := tx.Exec(s.queries.deleteUserAccess, username, username); err != nil { + if err := s.resetUserAccessTx(tx, username); err != nil { return err } if _, err := tx.Exec(s.queries.deleteAllToken, userID); err != nil { @@ -317,7 +317,7 @@ func (s *commonStore) ChangeRole(username string, role Role) error { } // If changing to admin, remove all access entries if role == RoleAdmin { - if _, err := tx.Exec(s.queries.deleteUserAccess, username, username); err != nil { + if err := s.resetUserAccessTx(tx, username); err != nil { return err } } @@ -383,6 +383,7 @@ func (s *commonStore) ResetStats() error { } return nil } + func (s *commonStore) readUser(rows *sql.Rows) (*User, error) { defer rows.Close() var id, username, hash, role, prefs, syncTopic string @@ -688,37 +689,44 @@ func (s *commonStore) Grants(username string) ([]Grant, error) { // AllowAccess adds or updates an entry in the access control list func (s *commonStore) AllowAccess(username, topicPattern string, read, write bool, ownerUsername string, provisioned bool) error { + return s.allowAccessTx(s.db, username, topicPattern, read, write, ownerUsername, provisioned) +} + +func (s *commonStore) allowAccessTx(tx execer, username, topicPattern string, read, write bool, ownerUsername string, provisioned bool) error { if !AllowedUsername(username) && username != Everyone { return ErrInvalidArgument } else if !AllowedTopicPattern(topicPattern) { return ErrInvalidArgument } - return s.allowAccessTx(s.db, username, topicPattern, read, write, ownerUsername, provisioned) -} - -func (s *commonStore) allowAccessTx(tx execer, username, topicPattern string, read, write bool, ownerUsername string, provisioned bool) error { _, err := tx.Exec(s.queries.upsertUserAccess, username, toSQLWildcard(topicPattern), read, write, ownerUsername, ownerUsername, provisioned) return err } // ResetAccess removes an access control list entry func (s *commonStore) ResetAccess(username, topicPattern string) error { + if username == "" && topicPattern == "" { + _, err := s.db.Exec(s.queries.deleteAllAccess) + return err + } else if topicPattern == "" { + return s.resetUserAccessTx(s.db, username) + } + return s.resetTopicAccessTx(s.db, username, topicPattern) +} + +func (s *commonStore) resetUserAccessTx(tx execer, username string) error { + if !AllowedUsername(username) && username != Everyone { + return ErrInvalidArgument + } + _, err := tx.Exec(s.queries.deleteUserAccess, username, username) + return err +} + +func (s *commonStore) resetTopicAccessTx(tx execer, username, topicPattern string) error { if !AllowedUsername(username) && username != Everyone && username != "" { return ErrInvalidArgument } else if !AllowedTopicPattern(topicPattern) && topicPattern != "" { return ErrInvalidArgument } - if username == "" && topicPattern == "" { - _, err := s.db.Exec(s.queries.deleteAllAccess) - return err - } else if topicPattern == "" { - _, err := s.db.Exec(s.queries.deleteUserAccess, username, username) - return err - } - return s.resetTopicAccessTx(s.db, username, topicPattern) -} - -func (s *commonStore) resetTopicAccessTx(tx execer, username, topicPattern string) error { _, err := tx.Exec(s.queries.deleteTopicAccess, username, username, toSQLWildcard(topicPattern)) return err }