Make more consistent

This commit is contained in:
binwiederhier
2026-02-23 13:49:54 -05:00
parent 90d0eca14d
commit b02366b42b
5 changed files with 78 additions and 122 deletions

View File

@@ -116,24 +116,7 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) {
// after a fixed duration unless ChangeToken is called. This function also prunes tokens for the
// given user, if there are too many of them.
func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr, provisioned bool) (*Token, error) {
token := GenerateToken()
access := time.Now()
// Create the token
createdToken, err := a.store.CreateToken(userID, token, label, access, origin, expires, provisioned)
if err != nil {
return nil, err
}
// Check token count and prune if necessary
tokenCount, err := a.store.TokenCount(userID)
if err != nil {
return nil, err
}
if tokenCount >= tokenMaxCount {
if err := a.store.RemoveExcessTokens(userID, tokenMaxCount); err != nil {
return nil, err
}
}
return createdToken, nil
return a.store.CreateToken(userID, GenerateToken(), label, time.Now(), origin, expires, tokenMaxCount, provisioned)
}
// Tokens returns all existing tokens for the user with the given user ID
@@ -151,32 +134,35 @@ func (a *Manager) ChangeToken(userID, token string, label *string, expires *time
if token == "" {
return nil, errNoTokenProvided
}
if err := a.CanChangeToken(userID, token); err != nil {
if err := a.canChangeToken(userID, token); err != nil {
return nil, err
}
t, err := a.store.Token(userID, token)
if err != nil {
return nil, err
}
if label != nil {
if err := a.store.ChangeTokenLabel(userID, token, *label); err != nil {
return nil, err
}
t.Label = *label
}
if expires != nil {
if err := a.store.ChangeTokenExpiry(userID, token, *expires); err != nil {
return nil, err
}
t.Expires = *expires
}
return a.Token(userID, token)
if err := a.store.ChangeToken(userID, token, t.Label, t.Expires); err != nil {
return nil, err
}
return t, nil
}
// RemoveToken deletes the token defined in User.Token
func (a *Manager) RemoveToken(userID, token string) error {
if err := a.CanChangeToken(userID, token); err != nil {
if err := a.canChangeToken(userID, token); err != nil {
return err
}
return a.store.RemoveToken(userID, token)
}
// CanChangeToken checks if the token can be changed. If the token is provisioned, it cannot be changed.
func (a *Manager) CanChangeToken(userID, token string) error {
// canChangeToken checks if the token can be changed. If the token is provisioned, it cannot be changed.
func (a *Manager) canChangeToken(userID, token string) error {
t, err := a.Token(userID, token)
if err != nil {
return err
@@ -277,11 +263,8 @@ func (a *Manager) writeUserStatsQueue() error {
"calls_count": update.Calls,
}).
Trace("Updating stats for user %s", userID)
if err := a.store.UpdateStats(userID, update); err != nil {
return err
}
}
return nil
return a.store.UpdateStats(statsQueue)
}
func (a *Manager) writeTokenUpdateQueue() error {
@@ -778,7 +761,7 @@ func (a *Manager) maybeProvisionTokens(provisionUsernames []string) error {
return fmt.Errorf("failed to find provisioned user %s for provisioned tokens: %v", username, err)
}
for _, token := range tokens {
if _, err := a.store.CreateToken(userID, token.Value, token.Label, time.Unix(0, 0), netip.IPv4Unspecified(), time.Unix(0, 0), true); err != nil {
if _, err := a.store.CreateToken(userID, token.Value, token.Label, time.Unix(0, 0), netip.IPv4Unspecified(), time.Unix(0, 0), 0, true); err != nil {
return err
}
}

View File

@@ -32,22 +32,19 @@ type Store interface {
ChangeSettings(userID string, prefs *Prefs) error
ChangeTier(username, tierCode string) error
ResetTier(username string) error
UpdateStats(userID string, stats *Stats) error
UpdateStats(stats map[string]*Stats) error
ResetStats() error
// Token operations
CreateToken(userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, provisioned bool) (*Token, error)
CreateToken(userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, maxTokenCount int, provisioned bool) (*Token, error)
Token(userID, token string) (*Token, error)
Tokens(userID string) ([]*Token, error)
AllProvisionedTokens() ([]*Token, error)
ChangeTokenLabel(userID, token, label string) error
ChangeTokenExpiry(userID, token string, expires time.Time) error
ChangeToken(userID, token, label string, expires time.Time) error
UpdateTokenLastAccess(token string, lastAccess time.Time, lastOrigin netip.Addr) error
RemoveToken(userID, token string) error
RemoveProvisionedToken(token string) error
RemoveExpiredTokens() error
TokenCount(userID string) (int, error)
RemoveExcessTokens(userID string, maxCount int) error
// Access operations
AuthorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error)
@@ -119,11 +116,9 @@ type storeQueries struct {
// Token queries
selectToken string
selectTokens string
selectTokenCount string
selectAllProvisionedTokens string
upsertToken string
updateTokenLabel string
updateTokenExpiry string
updateToken string
updateTokenLastAccess string
deleteToken string
deleteProvisionedToken string
@@ -366,12 +361,19 @@ func (s *commonStore) ResetTier(username string) error {
return err
}
// UpdateStats updates the user statistics
func (s *commonStore) UpdateStats(userID string, stats *Stats) error {
if _, err := s.db.Exec(s.queries.updateUserStats, stats.Messages, stats.Emails, stats.Calls, userID); err != nil {
// UpdateStats updates statistics for one or more users in a single transaction
func (s *commonStore) UpdateStats(stats map[string]*Stats) error {
tx, err := s.db.Begin()
if err != nil {
return err
}
return nil
defer tx.Rollback()
for userID, update := range stats {
if _, err := tx.Exec(s.queries.updateUserStats, update.Messages, update.Emails, update.Calls, userID); err != nil {
return err
}
}
return tx.Commit()
}
// ResetStats resets all user stats in the user database
@@ -443,9 +445,23 @@ func (s *commonStore) readUser(rows *sql.Rows) (*User, error) {
return user, nil
}
// CreateToken creates a new token
func (s *commonStore) CreateToken(userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, provisioned bool) (*Token, error) {
if _, err := s.db.Exec(s.queries.upsertToken, userID, token, label, lastAccess.Unix(), lastOrigin.String(), expires.Unix(), provisioned); err != nil {
// CreateToken creates a new token and prunes excess tokens if the count exceeds maxTokenCount.
// If maxTokenCount is 0, no pruning is performed.
func (s *commonStore) CreateToken(userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, maxTokenCount int, provisioned bool) (*Token, error) {
tx, err := s.db.Begin()
if err != nil {
return nil, err
}
defer tx.Rollback()
if _, err := tx.Exec(s.queries.upsertToken, userID, token, label, lastAccess.Unix(), lastOrigin.String(), expires.Unix(), provisioned); err != nil {
return nil, err
}
if maxTokenCount > 0 {
if _, err := tx.Exec(s.queries.deleteExcessTokens, userID, userID, maxTokenCount); err != nil {
return nil, err
}
}
if err := tx.Commit(); err != nil {
return nil, err
}
return &Token{
@@ -508,17 +524,9 @@ func (s *commonStore) AllProvisionedTokens() ([]*Token, error) {
return tokens, nil
}
// ChangeTokenLabel updates a token's label
func (s *commonStore) ChangeTokenLabel(userID, token, label string) error {
if _, err := s.db.Exec(s.queries.updateTokenLabel, label, userID, token); err != nil {
return err
}
return nil
}
// ChangeTokenExpiry updates a token's expiry time
func (s *commonStore) ChangeTokenExpiry(userID, token string, expires time.Time) error {
if _, err := s.db.Exec(s.queries.updateTokenExpiry, expires.Unix(), userID, token); err != nil {
// ChangeToken updates a token's label and expiry time
func (s *commonStore) ChangeToken(userID, token, label string, expires time.Time) error {
if _, err := s.db.Exec(s.queries.updateToken, label, expires.Unix(), userID, token); err != nil {
return err
}
return nil
@@ -562,30 +570,6 @@ func (s *commonStore) RemoveExpiredTokens() error {
return nil
}
// TokenCount returns the number of tokens for a user
func (s *commonStore) TokenCount(userID string) (int, error) {
rows, err := s.db.Query(s.queries.selectTokenCount, userID)
if err != nil {
return 0, err
}
defer rows.Close()
if !rows.Next() {
return 0, errNoRows
}
var count int
if err := rows.Scan(&count); err != nil {
return 0, err
}
return count, nil
}
// RemoveExcessTokens deletes excess tokens beyond the specified maximum
func (s *commonStore) RemoveExcessTokens(userID string, maxCount int) error {
if _, err := s.db.Exec(s.queries.deleteExcessTokens, userID, userID, maxCount); err != nil {
return err
}
return nil
}
func (s *commonStore) readToken(rows *sql.Rows) (*Token, error) {
var token, label, lastOrigin string
var lastAccess, expires int64

View File

@@ -146,8 +146,7 @@ const (
ON CONFLICT (user_id, token)
DO UPDATE SET label = excluded.label, expires = excluded.expires, provisioned = excluded.provisioned
`
postgresUpdateTokenLabelQuery = `UPDATE user_token SET label = $1 WHERE user_id = $2 AND token = $3`
postgresUpdateTokenExpiryQuery = `UPDATE user_token SET expires = $1 WHERE user_id = $2 AND token = $3`
postgresUpdateTokenQuery = `UPDATE user_token SET label = $1, expires = $2 WHERE user_id = $3 AND token = $4`
postgresUpdateTokenLastAccessQuery = `UPDATE user_token SET last_access = $1, last_origin = $2 WHERE token = $3`
postgresDeleteTokenQuery = `DELETE FROM user_token WHERE user_id = $1 AND token = $2`
postgresDeleteProvisionedTokenQuery = `DELETE FROM user_token WHERE token = $1`
@@ -251,11 +250,9 @@ func NewPostgresStore(db *sql.DB) (Store, error) {
// Token queries
selectToken: postgresSelectTokenQuery,
selectTokens: postgresSelectTokensQuery,
selectTokenCount: postgresSelectTokenCountQuery,
selectAllProvisionedTokens: postgresSelectAllProvisionedTokensQuery,
upsertToken: postgresUpsertTokenQuery,
updateTokenLabel: postgresUpdateTokenLabelQuery,
updateTokenExpiry: postgresUpdateTokenExpiryQuery,
updateToken: postgresUpdateTokenQuery,
updateTokenLastAccess: postgresUpdateTokenLastAccessQuery,
deleteToken: postgresDeleteTokenQuery,
deleteProvisionedToken: postgresDeleteProvisionedTokenQuery,

View File

@@ -144,8 +144,7 @@ const (
ON CONFLICT (user_id, token)
DO UPDATE SET label = excluded.label, expires = excluded.expires, provisioned = excluded.provisioned
`
sqliteUpdateTokenLabelQuery = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?`
sqliteUpdateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?`
sqliteUpdateTokenQuery = `UPDATE user_token SET label = ?, expires = ? WHERE user_id = ? AND token = ?`
sqliteUpdateTokenLastAccessQuery = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?`
sqliteDeleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
sqliteDeleteProvisionedTokenQuery = `DELETE FROM user_token WHERE token = ?`
@@ -255,11 +254,9 @@ func NewSQLiteStore(filename, startupQueries string) (Store, error) {
deleteAllAccess: sqliteDeleteAllAccessQuery,
selectToken: sqliteSelectTokenQuery,
selectTokens: sqliteSelectTokensQuery,
selectTokenCount: sqliteSelectTokenCountQuery,
selectAllProvisionedTokens: sqliteSelectAllProvisionedTokensQuery,
upsertToken: sqliteUpsertTokenQuery,
updateTokenLabel: sqliteUpdateTokenLabelQuery,
updateTokenExpiry: sqliteUpdateTokenExpiryQuery,
updateToken: sqliteUpdateTokenQuery,
updateTokenLastAccess: sqliteUpdateTokenLastAccessQuery,
deleteToken: sqliteDeleteTokenQuery,
deleteProvisionedToken: sqliteDeleteProvisionedTokenQuery,

View File

@@ -78,7 +78,7 @@ func TestStoreUserByToken(t *testing.T) {
u, err := store.User("phil")
require.Nil(t, err)
tk, err := store.CreateToken(u.ID, "tk_test123", "test token", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(24*time.Hour), false)
tk, err := store.CreateToken(u.ID, "tk_test123", "test token", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(24*time.Hour), 0, false)
require.Nil(t, err)
require.Equal(t, "tk_test123", tk.Value)
@@ -165,7 +165,7 @@ func TestStoreTokens(t *testing.T) {
expires := now.Add(24 * time.Hour)
origin := netip.MustParseAddr("9.9.9.9")
tk, err := store.CreateToken(u.ID, "tk_abc", "my token", now, origin, expires, false)
tk, err := store.CreateToken(u.ID, "tk_abc", "my token", now, origin, expires, 0, false)
require.Nil(t, err)
require.Equal(t, "tk_abc", tk.Value)
require.Equal(t, "my token", tk.Label)
@@ -181,27 +181,25 @@ func TestStoreTokens(t *testing.T) {
require.Nil(t, err)
require.Len(t, tokens, 1)
require.Equal(t, "tk_abc", tokens[0].Value)
// Token count
count, err := store.TokenCount(u.ID)
require.Nil(t, err)
require.Equal(t, 1, count)
})
}
func TestStoreTokenChangeLabel(t *testing.T) {
func TestStoreTokenChange(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
u, err := store.User("phil")
require.Nil(t, err)
_, err = store.CreateToken(u.ID, "tk_abc", "old label", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), false)
expires := time.Now().Add(time.Hour)
_, err = store.CreateToken(u.ID, "tk_abc", "old label", time.Now(), netip.MustParseAddr("1.2.3.4"), expires, 0, false)
require.Nil(t, err)
require.Nil(t, store.ChangeTokenLabel(u.ID, "tk_abc", "new label"))
newExpires := time.Now().Add(2 * time.Hour)
require.Nil(t, store.ChangeToken(u.ID, "tk_abc", "new label", newExpires))
tk, err := store.Token(u.ID, "tk_abc")
require.Nil(t, err)
require.Equal(t, "new label", tk.Label)
require.Equal(t, newExpires.Unix(), tk.Expires.Unix())
})
}
@@ -211,7 +209,7 @@ func TestStoreTokenRemove(t *testing.T) {
u, err := store.User("phil")
require.Nil(t, err)
_, err = store.CreateToken(u.ID, "tk_abc", "label", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), false)
_, err = store.CreateToken(u.ID, "tk_abc", "label", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), 0, false)
require.Nil(t, err)
require.Nil(t, store.RemoveToken(u.ID, "tk_abc"))
@@ -227,9 +225,9 @@ func TestStoreTokenRemoveExpired(t *testing.T) {
require.Nil(t, err)
// Create expired token and active token
_, err = store.CreateToken(u.ID, "tk_expired", "expired", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(-time.Hour), false)
_, err = store.CreateToken(u.ID, "tk_expired", "expired", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(-time.Hour), 0, false)
require.Nil(t, err)
_, err = store.CreateToken(u.ID, "tk_active", "active", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), false)
_, err = store.CreateToken(u.ID, "tk_active", "active", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), 0, false)
require.Nil(t, err)
require.Nil(t, store.RemoveExpiredTokens())
@@ -245,28 +243,25 @@ func TestStoreTokenRemoveExpired(t *testing.T) {
})
}
func TestStoreTokenRemoveExcess(t *testing.T) {
func TestStoreTokenCreatePrunesExcess(t *testing.T) {
forEachStoreBackend(t, func(t *testing.T, store user.Store) {
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
u, err := store.User("phil")
require.Nil(t, err)
// Create 3 tokens with increasing expiry
for i, name := range []string{"tk_a", "tk_b", "tk_c"} {
_, err = store.CreateToken(u.ID, name, name, time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Duration(i+1)*time.Hour), false)
// Create 2 tokens with no pruning
for i, name := range []string{"tk_a", "tk_b"} {
_, err = store.CreateToken(u.ID, name, name, time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Duration(i+1)*time.Hour), 0, false)
require.Nil(t, err)
}
count, err := store.TokenCount(u.ID)
// Create a 3rd token with maxTokenCount=2, which should prune tk_a (earliest expiry)
_, err = store.CreateToken(u.ID, "tk_c", "tk_c", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(3*time.Hour), 2, false)
require.Nil(t, err)
require.Equal(t, 3, count)
// Remove excess, keep only 2 (the ones with latest expiry: tk_b, tk_c)
require.Nil(t, store.RemoveExcessTokens(u.ID, 2))
count, err = store.TokenCount(u.ID)
tokens, err := store.Tokens(u.ID)
require.Nil(t, err)
require.Equal(t, 2, count)
require.Equal(t, 2, len(tokens))
// tk_a should be removed (earliest expiry)
_, err = store.Token(u.ID, "tk_a")
@@ -286,7 +281,7 @@ func TestStoreTokenUpdateLastAccess(t *testing.T) {
u, err := store.User("phil")
require.Nil(t, err)
_, err = store.CreateToken(u.ID, "tk_abc", "label", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), false)
_, err = store.CreateToken(u.ID, "tk_abc", "label", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), 0, false)
require.Nil(t, err)
newTime := time.Now().Add(5 * time.Minute)
@@ -628,7 +623,7 @@ func TestStoreUpdateStats(t *testing.T) {
require.Nil(t, err)
stats := &user.Stats{Messages: 42, Emails: 3, Calls: 1}
require.Nil(t, store.UpdateStats(u.ID, stats))
require.Nil(t, store.UpdateStats(map[string]*user.Stats{u.ID: stats}))
u2, err := store.User("phil")
require.Nil(t, err)
@@ -644,7 +639,7 @@ func TestStoreResetStats(t *testing.T) {
u, err := store.User("phil")
require.Nil(t, err)
require.Nil(t, store.UpdateStats(u.ID, &user.Stats{Messages: 42, Emails: 3, Calls: 1}))
require.Nil(t, store.UpdateStats(map[string]*user.Stats{u.ID: {Messages: 42, Emails: 3, Calls: 1}}))
require.Nil(t, store.ResetStats())
u2, err := store.User("phil")