Make more consistent

This commit is contained in:
binwiederhier
2026-02-23 13:49:54 -05:00
parent 90d0eca14d
commit b02366b42b
5 changed files with 78 additions and 122 deletions

View File

@@ -32,22 +32,19 @@ type Store interface {
ChangeSettings(userID string, prefs *Prefs) error
ChangeTier(username, tierCode string) error
ResetTier(username string) error
UpdateStats(userID string, stats *Stats) error
UpdateStats(stats map[string]*Stats) error
ResetStats() error
// Token operations
CreateToken(userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, provisioned bool) (*Token, error)
CreateToken(userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, maxTokenCount int, provisioned bool) (*Token, error)
Token(userID, token string) (*Token, error)
Tokens(userID string) ([]*Token, error)
AllProvisionedTokens() ([]*Token, error)
ChangeTokenLabel(userID, token, label string) error
ChangeTokenExpiry(userID, token string, expires time.Time) error
ChangeToken(userID, token, label string, expires time.Time) error
UpdateTokenLastAccess(token string, lastAccess time.Time, lastOrigin netip.Addr) error
RemoveToken(userID, token string) error
RemoveProvisionedToken(token string) error
RemoveExpiredTokens() error
TokenCount(userID string) (int, error)
RemoveExcessTokens(userID string, maxCount int) error
// Access operations
AuthorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error)
@@ -119,11 +116,9 @@ type storeQueries struct {
// Token queries
selectToken string
selectTokens string
selectTokenCount string
selectAllProvisionedTokens string
upsertToken string
updateTokenLabel string
updateTokenExpiry string
updateToken string
updateTokenLastAccess string
deleteToken string
deleteProvisionedToken string
@@ -366,12 +361,19 @@ func (s *commonStore) ResetTier(username string) error {
return err
}
// UpdateStats updates the user statistics
func (s *commonStore) UpdateStats(userID string, stats *Stats) error {
if _, err := s.db.Exec(s.queries.updateUserStats, stats.Messages, stats.Emails, stats.Calls, userID); err != nil {
// UpdateStats updates statistics for one or more users in a single transaction
func (s *commonStore) UpdateStats(stats map[string]*Stats) error {
tx, err := s.db.Begin()
if err != nil {
return err
}
return nil
defer tx.Rollback()
for userID, update := range stats {
if _, err := tx.Exec(s.queries.updateUserStats, update.Messages, update.Emails, update.Calls, userID); err != nil {
return err
}
}
return tx.Commit()
}
// ResetStats resets all user stats in the user database
@@ -443,9 +445,23 @@ func (s *commonStore) readUser(rows *sql.Rows) (*User, error) {
return user, nil
}
// CreateToken creates a new token
func (s *commonStore) CreateToken(userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, provisioned bool) (*Token, error) {
if _, err := s.db.Exec(s.queries.upsertToken, userID, token, label, lastAccess.Unix(), lastOrigin.String(), expires.Unix(), provisioned); err != nil {
// CreateToken creates a new token and prunes excess tokens if the count exceeds maxTokenCount.
// If maxTokenCount is 0, no pruning is performed.
func (s *commonStore) CreateToken(userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, maxTokenCount int, provisioned bool) (*Token, error) {
tx, err := s.db.Begin()
if err != nil {
return nil, err
}
defer tx.Rollback()
if _, err := tx.Exec(s.queries.upsertToken, userID, token, label, lastAccess.Unix(), lastOrigin.String(), expires.Unix(), provisioned); err != nil {
return nil, err
}
if maxTokenCount > 0 {
if _, err := tx.Exec(s.queries.deleteExcessTokens, userID, userID, maxTokenCount); err != nil {
return nil, err
}
}
if err := tx.Commit(); err != nil {
return nil, err
}
return &Token{
@@ -508,17 +524,9 @@ func (s *commonStore) AllProvisionedTokens() ([]*Token, error) {
return tokens, nil
}
// ChangeTokenLabel updates a token's label
func (s *commonStore) ChangeTokenLabel(userID, token, label string) error {
if _, err := s.db.Exec(s.queries.updateTokenLabel, label, userID, token); err != nil {
return err
}
return nil
}
// ChangeTokenExpiry updates a token's expiry time
func (s *commonStore) ChangeTokenExpiry(userID, token string, expires time.Time) error {
if _, err := s.db.Exec(s.queries.updateTokenExpiry, expires.Unix(), userID, token); err != nil {
// ChangeToken updates a token's label and expiry time
func (s *commonStore) ChangeToken(userID, token, label string, expires time.Time) error {
if _, err := s.db.Exec(s.queries.updateToken, label, expires.Unix(), userID, token); err != nil {
return err
}
return nil
@@ -562,30 +570,6 @@ func (s *commonStore) RemoveExpiredTokens() error {
return nil
}
// TokenCount returns the number of tokens for a user
func (s *commonStore) TokenCount(userID string) (int, error) {
rows, err := s.db.Query(s.queries.selectTokenCount, userID)
if err != nil {
return 0, err
}
defer rows.Close()
if !rows.Next() {
return 0, errNoRows
}
var count int
if err := rows.Scan(&count); err != nil {
return 0, err
}
return count, nil
}
// RemoveExcessTokens deletes excess tokens beyond the specified maximum
func (s *commonStore) RemoveExcessTokens(userID string, maxCount int) error {
if _, err := s.db.Exec(s.queries.deleteExcessTokens, userID, userID, maxCount); err != nil {
return err
}
return nil
}
func (s *commonStore) readToken(rows *sql.Rows) (*Token, error) {
var token, label, lastOrigin string
var lastAccess, expires int64