Small refinements

This commit is contained in:
binwiederhier
2026-02-28 21:15:51 -05:00
parent ccbd02331c
commit c19377109e
3 changed files with 12 additions and 25 deletions

View File

@@ -103,7 +103,7 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) {
if len(token) != tokenLength { if len(token) != tokenLength {
return nil, ErrUnauthenticated return nil, ErrUnauthenticated
} }
user, err := a.UserByToken(token) user, err := a.userByToken(token)
if err != nil { if err != nil {
log.Tag(tag).Field("token", token).Err(err).Trace("Authentication of token failed") log.Tag(tag).Field("token", token).Err(err).Trace("Authentication of token failed")
return nil, ErrUnauthenticated return nil, ErrUnauthenticated
@@ -114,9 +114,6 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) {
// AddUser adds a user with the given username, password and role // AddUser adds a user with the given username, password and role
func (a *Manager) AddUser(username, password string, role Role, hashed bool) error { func (a *Manager) AddUser(username, password string, role Role, hashed bool) error {
if !AllowedUsername(username) || !AllowedRole(role) {
return ErrInvalidArgument
}
hash, err := a.maybeHashPassword(password, hashed) hash, err := a.maybeHashPassword(password, hashed)
if err != nil { if err != nil {
return err return err
@@ -416,8 +413,8 @@ func (a *Manager) UserByID(id string) (*User, error) {
return a.readUser(rows) return a.readUser(rows)
} }
// UserByToken returns the user with the given token if it exists and is not expired, or ErrUserNotFound otherwise // userByToken returns the user with the given token if it exists and is not expired, or ErrUserNotFound otherwise
func (a *Manager) UserByToken(token string) (*User, error) { func (a *Manager) userByToken(token string) (*User, error) {
rows, err := a.db.Query(a.queries.selectUserByToken, token, time.Now().Unix()) rows, err := a.db.Query(a.queries.selectUserByToken, token, time.Now().Unix())
if err != nil { if err != nil {
return nil, err return nil, err
@@ -581,7 +578,7 @@ func (a *Manager) Authorize(user *User, topic string, perm Permission) error {
username = user.Name username = user.Name
} }
// Select the read/write permissions for this user/topic combo. // Select the read/write permissions for this user/topic combo.
read, write, found, err := a.AuthorizeTopicAccess(username, topic) read, write, found, err := a.authorizeTopicAccess(username, topic)
if err != nil { if err != nil {
return err return err
} }
@@ -663,13 +660,13 @@ func (a *Manager) AllowReservation(username string, topic string) error {
return nil return nil
} }
// AuthorizeTopicAccess returns the read/write permissions for the given username and topic. // authorizeTopicAccess returns the read/write permissions for the given username and topic.
// The found return value indicates whether an ACL entry was found at all. // The found return value indicates whether an ACL entry was found at all.
// //
// - The query may return two rows (one for everyone, and one for the user), but prioritizes the user. // - The query may return two rows (one for everyone, and one for the user), but prioritizes the user.
// - Furthermore, the query prioritizes more specific permissions (longer!) over more generic ones, e.g. "test*" > "*" // - Furthermore, the query prioritizes more specific permissions (longer!) over more generic ones, e.g. "test*" > "*"
// - It also prioritizes write permissions over read permissions // - It also prioritizes write permissions over read permissions
func (a *Manager) AuthorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error) { func (a *Manager) authorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error) {
rows, err := a.db.Query(a.queries.selectTopicPerms, Everyone, usernameOrEveryone, topic) rows, err := a.db.Query(a.queries.selectTopicPerms, Everyone, usernameOrEveryone, topic)
if err != nil { if err != nil {
return false, false, false, err return false, false, false, err
@@ -873,14 +870,6 @@ func (a *Manager) OtherAccessCount(username, topic string) (int, error) {
return count, nil return count, nil
} }
// ResetAllProvisionedAccess removes all provisioned access control entries
func (a *Manager) ResetAllProvisionedAccess() error {
if _, err := a.db.Exec(a.queries.deleteUserAccessProvisioned); err != nil {
return err
}
return nil
}
func (a *Manager) addReservationAccessTx(tx *sql.Tx, username, topic string, read, write bool, ownerUsername string) error { func (a *Manager) addReservationAccessTx(tx *sql.Tx, username, topic string, read, write bool, ownerUsername string) error {
if !AllowedUsername(username) && username != Everyone { if !AllowedUsername(username) && username != Everyone {
return ErrInvalidArgument return ErrInvalidArgument
@@ -1024,8 +1013,7 @@ func (a *Manager) Tokens(userID string) ([]*Token, error) {
return tokens, nil return tokens, nil
} }
// AllProvisionedTokens returns all provisioned tokens func (a *Manager) allProvisionedTokens() ([]*Token, error) {
func (a *Manager) AllProvisionedTokens() ([]*Token, error) {
rows, err := a.db.Query(a.queries.selectAllProvisionedTokens) rows, err := a.db.Query(a.queries.selectAllProvisionedTokens)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -1299,7 +1287,7 @@ func (a *Manager) maybeProvisionUsersAccessAndTokens() error {
provisionUsernames := util.Map(a.config.Users, func(u *User) string { provisionUsernames := util.Map(a.config.Users, func(u *User) string {
return u.Name return u.Name
}) })
existingTokens, err := a.AllProvisionedTokens() existingTokens, err := a.allProvisionedTokens()
if err != nil { if err != nil {
return err return err
} }

View File

@@ -1797,7 +1797,7 @@ func TestStoreUserByToken(t *testing.T) {
require.Nil(t, err) require.Nil(t, err)
require.NotEmpty(t, tk.Value) require.NotEmpty(t, tk.Value)
u2, err := manager.UserByToken(tk.Value) u2, err := manager.userByToken(tk.Value)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, "phil", u2.Name) require.Equal(t, "phil", u2.Name)
}) })
@@ -2054,7 +2054,7 @@ func TestStoreAuthorizeTopicAccess(t *testing.T) {
require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false))
require.Nil(t, manager.AllowAccess("phil", "mytopic", PermissionReadWrite)) require.Nil(t, manager.AllowAccess("phil", "mytopic", PermissionReadWrite))
read, write, found, err := manager.AuthorizeTopicAccess("phil", "mytopic") read, write, found, err := manager.authorizeTopicAccess("phil", "mytopic")
require.Nil(t, err) require.Nil(t, err)
require.True(t, found) require.True(t, found)
require.True(t, read) require.True(t, read)
@@ -2066,7 +2066,7 @@ func TestStoreAuthorizeTopicAccessNotFound(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, manager *Manager) { forEachStoreBackend(t, func(t *testing.T, manager *Manager) {
require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false))
_, _, found, err := manager.AuthorizeTopicAccess("phil", "other") _, _, found, err := manager.authorizeTopicAccess("phil", "other")
require.Nil(t, err) require.Nil(t, err)
require.False(t, found) require.False(t, found)
}) })
@@ -2077,7 +2077,7 @@ func TestStoreAuthorizeTopicAccessDenyAll(t *testing.T) {
require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false)) require.Nil(t, manager.AddUser("phil", "mypass", RoleUser, false))
require.Nil(t, manager.AllowAccess("phil", "secret", PermissionDenyAll)) require.Nil(t, manager.AllowAccess("phil", "secret", PermissionDenyAll))
read, write, found, err := manager.AuthorizeTopicAccess("phil", "secret") read, write, found, err := manager.authorizeTopicAccess("phil", "secret")
require.Nil(t, err) require.Nil(t, err)
require.True(t, found) require.True(t, found)
require.False(t, read) require.False(t, read)

View File

@@ -343,4 +343,3 @@ type storeQueries struct {
// Billing queries // Billing queries
updateBilling string updateBilling string
} }