diff --git a/user/manager.go b/user/manager.go index 534ff86d..ce890317 100644 --- a/user/manager.go +++ b/user/manager.go @@ -116,24 +116,7 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) { // after a fixed duration unless ChangeToken is called. This function also prunes tokens for the // given user, if there are too many of them. func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr, provisioned bool) (*Token, error) { - token := GenerateToken() - access := time.Now() - // Create the token - createdToken, err := a.store.CreateToken(userID, token, label, access, origin, expires, provisioned) - if err != nil { - return nil, err - } - // Check token count and prune if necessary - tokenCount, err := a.store.TokenCount(userID) - if err != nil { - return nil, err - } - if tokenCount >= tokenMaxCount { - if err := a.store.RemoveExcessTokens(userID, tokenMaxCount); err != nil { - return nil, err - } - } - return createdToken, nil + return a.store.CreateToken(userID, GenerateToken(), label, time.Now(), origin, expires, tokenMaxCount, provisioned) } // Tokens returns all existing tokens for the user with the given user ID @@ -151,32 +134,35 @@ func (a *Manager) ChangeToken(userID, token string, label *string, expires *time if token == "" { return nil, errNoTokenProvided } - if err := a.CanChangeToken(userID, token); err != nil { + if err := a.canChangeToken(userID, token); err != nil { + return nil, err + } + t, err := a.store.Token(userID, token) + if err != nil { return nil, err } if label != nil { - if err := a.store.ChangeTokenLabel(userID, token, *label); err != nil { - return nil, err - } + t.Label = *label } if expires != nil { - if err := a.store.ChangeTokenExpiry(userID, token, *expires); err != nil { - return nil, err - } + t.Expires = *expires } - return a.Token(userID, token) + if err := a.store.ChangeToken(userID, token, t.Label, t.Expires); err != nil { + return nil, err + } + return t, nil } // RemoveToken deletes the token defined in User.Token func (a *Manager) RemoveToken(userID, token string) error { - if err := a.CanChangeToken(userID, token); err != nil { + if err := a.canChangeToken(userID, token); err != nil { return err } return a.store.RemoveToken(userID, token) } -// CanChangeToken checks if the token can be changed. If the token is provisioned, it cannot be changed. -func (a *Manager) CanChangeToken(userID, token string) error { +// canChangeToken checks if the token can be changed. If the token is provisioned, it cannot be changed. +func (a *Manager) canChangeToken(userID, token string) error { t, err := a.Token(userID, token) if err != nil { return err @@ -277,11 +263,8 @@ func (a *Manager) writeUserStatsQueue() error { "calls_count": update.Calls, }). Trace("Updating stats for user %s", userID) - if err := a.store.UpdateStats(userID, update); err != nil { - return err - } } - return nil + return a.store.UpdateStats(statsQueue) } func (a *Manager) writeTokenUpdateQueue() error { @@ -778,7 +761,7 @@ func (a *Manager) maybeProvisionTokens(provisionUsernames []string) error { return fmt.Errorf("failed to find provisioned user %s for provisioned tokens: %v", username, err) } for _, token := range tokens { - if _, err := a.store.CreateToken(userID, token.Value, token.Label, time.Unix(0, 0), netip.IPv4Unspecified(), time.Unix(0, 0), true); err != nil { + if _, err := a.store.CreateToken(userID, token.Value, token.Label, time.Unix(0, 0), netip.IPv4Unspecified(), time.Unix(0, 0), 0, true); err != nil { return err } } diff --git a/user/store.go b/user/store.go index 0784ae42..f03c883e 100644 --- a/user/store.go +++ b/user/store.go @@ -32,22 +32,19 @@ type Store interface { ChangeSettings(userID string, prefs *Prefs) error ChangeTier(username, tierCode string) error ResetTier(username string) error - UpdateStats(userID string, stats *Stats) error + UpdateStats(stats map[string]*Stats) error ResetStats() error // Token operations - CreateToken(userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, provisioned bool) (*Token, error) + CreateToken(userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, maxTokenCount int, provisioned bool) (*Token, error) Token(userID, token string) (*Token, error) Tokens(userID string) ([]*Token, error) AllProvisionedTokens() ([]*Token, error) - ChangeTokenLabel(userID, token, label string) error - ChangeTokenExpiry(userID, token string, expires time.Time) error + ChangeToken(userID, token, label string, expires time.Time) error UpdateTokenLastAccess(token string, lastAccess time.Time, lastOrigin netip.Addr) error RemoveToken(userID, token string) error RemoveProvisionedToken(token string) error RemoveExpiredTokens() error - TokenCount(userID string) (int, error) - RemoveExcessTokens(userID string, maxCount int) error // Access operations AuthorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error) @@ -119,11 +116,9 @@ type storeQueries struct { // Token queries selectToken string selectTokens string - selectTokenCount string selectAllProvisionedTokens string upsertToken string - updateTokenLabel string - updateTokenExpiry string + updateToken string updateTokenLastAccess string deleteToken string deleteProvisionedToken string @@ -366,12 +361,19 @@ func (s *commonStore) ResetTier(username string) error { return err } -// UpdateStats updates the user statistics -func (s *commonStore) UpdateStats(userID string, stats *Stats) error { - if _, err := s.db.Exec(s.queries.updateUserStats, stats.Messages, stats.Emails, stats.Calls, userID); err != nil { +// UpdateStats updates statistics for one or more users in a single transaction +func (s *commonStore) UpdateStats(stats map[string]*Stats) error { + tx, err := s.db.Begin() + if err != nil { return err } - return nil + defer tx.Rollback() + for userID, update := range stats { + if _, err := tx.Exec(s.queries.updateUserStats, update.Messages, update.Emails, update.Calls, userID); err != nil { + return err + } + } + return tx.Commit() } // ResetStats resets all user stats in the user database @@ -443,9 +445,23 @@ func (s *commonStore) readUser(rows *sql.Rows) (*User, error) { return user, nil } -// CreateToken creates a new token -func (s *commonStore) CreateToken(userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, provisioned bool) (*Token, error) { - if _, err := s.db.Exec(s.queries.upsertToken, userID, token, label, lastAccess.Unix(), lastOrigin.String(), expires.Unix(), provisioned); err != nil { +// CreateToken creates a new token and prunes excess tokens if the count exceeds maxTokenCount. +// If maxTokenCount is 0, no pruning is performed. +func (s *commonStore) CreateToken(userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, maxTokenCount int, provisioned bool) (*Token, error) { + tx, err := s.db.Begin() + if err != nil { + return nil, err + } + defer tx.Rollback() + if _, err := tx.Exec(s.queries.upsertToken, userID, token, label, lastAccess.Unix(), lastOrigin.String(), expires.Unix(), provisioned); err != nil { + return nil, err + } + if maxTokenCount > 0 { + if _, err := tx.Exec(s.queries.deleteExcessTokens, userID, userID, maxTokenCount); err != nil { + return nil, err + } + } + if err := tx.Commit(); err != nil { return nil, err } return &Token{ @@ -508,17 +524,9 @@ func (s *commonStore) AllProvisionedTokens() ([]*Token, error) { return tokens, nil } -// ChangeTokenLabel updates a token's label -func (s *commonStore) ChangeTokenLabel(userID, token, label string) error { - if _, err := s.db.Exec(s.queries.updateTokenLabel, label, userID, token); err != nil { - return err - } - return nil -} - -// ChangeTokenExpiry updates a token's expiry time -func (s *commonStore) ChangeTokenExpiry(userID, token string, expires time.Time) error { - if _, err := s.db.Exec(s.queries.updateTokenExpiry, expires.Unix(), userID, token); err != nil { +// ChangeToken updates a token's label and expiry time +func (s *commonStore) ChangeToken(userID, token, label string, expires time.Time) error { + if _, err := s.db.Exec(s.queries.updateToken, label, expires.Unix(), userID, token); err != nil { return err } return nil @@ -562,30 +570,6 @@ func (s *commonStore) RemoveExpiredTokens() error { return nil } -// TokenCount returns the number of tokens for a user -func (s *commonStore) TokenCount(userID string) (int, error) { - rows, err := s.db.Query(s.queries.selectTokenCount, userID) - if err != nil { - return 0, err - } - defer rows.Close() - if !rows.Next() { - return 0, errNoRows - } - var count int - if err := rows.Scan(&count); err != nil { - return 0, err - } - return count, nil -} - -// RemoveExcessTokens deletes excess tokens beyond the specified maximum -func (s *commonStore) RemoveExcessTokens(userID string, maxCount int) error { - if _, err := s.db.Exec(s.queries.deleteExcessTokens, userID, userID, maxCount); err != nil { - return err - } - return nil -} func (s *commonStore) readToken(rows *sql.Rows) (*Token, error) { var token, label, lastOrigin string var lastAccess, expires int64 diff --git a/user/store_postgres.go b/user/store_postgres.go index 229bfd6d..c97f4800 100644 --- a/user/store_postgres.go +++ b/user/store_postgres.go @@ -146,8 +146,7 @@ const ( ON CONFLICT (user_id, token) DO UPDATE SET label = excluded.label, expires = excluded.expires, provisioned = excluded.provisioned ` - postgresUpdateTokenLabelQuery = `UPDATE user_token SET label = $1 WHERE user_id = $2 AND token = $3` - postgresUpdateTokenExpiryQuery = `UPDATE user_token SET expires = $1 WHERE user_id = $2 AND token = $3` + postgresUpdateTokenQuery = `UPDATE user_token SET label = $1, expires = $2 WHERE user_id = $3 AND token = $4` postgresUpdateTokenLastAccessQuery = `UPDATE user_token SET last_access = $1, last_origin = $2 WHERE token = $3` postgresDeleteTokenQuery = `DELETE FROM user_token WHERE user_id = $1 AND token = $2` postgresDeleteProvisionedTokenQuery = `DELETE FROM user_token WHERE token = $1` @@ -251,11 +250,9 @@ func NewPostgresStore(db *sql.DB) (Store, error) { // Token queries selectToken: postgresSelectTokenQuery, selectTokens: postgresSelectTokensQuery, - selectTokenCount: postgresSelectTokenCountQuery, selectAllProvisionedTokens: postgresSelectAllProvisionedTokensQuery, upsertToken: postgresUpsertTokenQuery, - updateTokenLabel: postgresUpdateTokenLabelQuery, - updateTokenExpiry: postgresUpdateTokenExpiryQuery, + updateToken: postgresUpdateTokenQuery, updateTokenLastAccess: postgresUpdateTokenLastAccessQuery, deleteToken: postgresDeleteTokenQuery, deleteProvisionedToken: postgresDeleteProvisionedTokenQuery, diff --git a/user/store_sqlite.go b/user/store_sqlite.go index 57c61f0c..602b88d4 100644 --- a/user/store_sqlite.go +++ b/user/store_sqlite.go @@ -144,8 +144,7 @@ const ( ON CONFLICT (user_id, token) DO UPDATE SET label = excluded.label, expires = excluded.expires, provisioned = excluded.provisioned ` - sqliteUpdateTokenLabelQuery = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?` - sqliteUpdateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?` + sqliteUpdateTokenQuery = `UPDATE user_token SET label = ?, expires = ? WHERE user_id = ? AND token = ?` sqliteUpdateTokenLastAccessQuery = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?` sqliteDeleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?` sqliteDeleteProvisionedTokenQuery = `DELETE FROM user_token WHERE token = ?` @@ -255,11 +254,9 @@ func NewSQLiteStore(filename, startupQueries string) (Store, error) { deleteAllAccess: sqliteDeleteAllAccessQuery, selectToken: sqliteSelectTokenQuery, selectTokens: sqliteSelectTokensQuery, - selectTokenCount: sqliteSelectTokenCountQuery, selectAllProvisionedTokens: sqliteSelectAllProvisionedTokensQuery, upsertToken: sqliteUpsertTokenQuery, - updateTokenLabel: sqliteUpdateTokenLabelQuery, - updateTokenExpiry: sqliteUpdateTokenExpiryQuery, + updateToken: sqliteUpdateTokenQuery, updateTokenLastAccess: sqliteUpdateTokenLastAccessQuery, deleteToken: sqliteDeleteTokenQuery, deleteProvisionedToken: sqliteDeleteProvisionedTokenQuery, diff --git a/user/store_test.go b/user/store_test.go index 7dc3ef38..1294cc86 100644 --- a/user/store_test.go +++ b/user/store_test.go @@ -78,7 +78,7 @@ func TestStoreUserByToken(t *testing.T) { u, err := store.User("phil") require.Nil(t, err) - tk, err := store.CreateToken(u.ID, "tk_test123", "test token", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(24*time.Hour), false) + tk, err := store.CreateToken(u.ID, "tk_test123", "test token", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(24*time.Hour), 0, false) require.Nil(t, err) require.Equal(t, "tk_test123", tk.Value) @@ -165,7 +165,7 @@ func TestStoreTokens(t *testing.T) { expires := now.Add(24 * time.Hour) origin := netip.MustParseAddr("9.9.9.9") - tk, err := store.CreateToken(u.ID, "tk_abc", "my token", now, origin, expires, false) + tk, err := store.CreateToken(u.ID, "tk_abc", "my token", now, origin, expires, 0, false) require.Nil(t, err) require.Equal(t, "tk_abc", tk.Value) require.Equal(t, "my token", tk.Label) @@ -181,27 +181,25 @@ func TestStoreTokens(t *testing.T) { require.Nil(t, err) require.Len(t, tokens, 1) require.Equal(t, "tk_abc", tokens[0].Value) - - // Token count - count, err := store.TokenCount(u.ID) - require.Nil(t, err) - require.Equal(t, 1, count) }) } -func TestStoreTokenChangeLabel(t *testing.T) { +func TestStoreTokenChange(t *testing.T) { forEachStoreBackend(t, func(t *testing.T, store user.Store) { require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) u, err := store.User("phil") require.Nil(t, err) - _, err = store.CreateToken(u.ID, "tk_abc", "old label", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), false) + expires := time.Now().Add(time.Hour) + _, err = store.CreateToken(u.ID, "tk_abc", "old label", time.Now(), netip.MustParseAddr("1.2.3.4"), expires, 0, false) require.Nil(t, err) - require.Nil(t, store.ChangeTokenLabel(u.ID, "tk_abc", "new label")) + newExpires := time.Now().Add(2 * time.Hour) + require.Nil(t, store.ChangeToken(u.ID, "tk_abc", "new label", newExpires)) tk, err := store.Token(u.ID, "tk_abc") require.Nil(t, err) require.Equal(t, "new label", tk.Label) + require.Equal(t, newExpires.Unix(), tk.Expires.Unix()) }) } @@ -211,7 +209,7 @@ func TestStoreTokenRemove(t *testing.T) { u, err := store.User("phil") require.Nil(t, err) - _, err = store.CreateToken(u.ID, "tk_abc", "label", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), false) + _, err = store.CreateToken(u.ID, "tk_abc", "label", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), 0, false) require.Nil(t, err) require.Nil(t, store.RemoveToken(u.ID, "tk_abc")) @@ -227,9 +225,9 @@ func TestStoreTokenRemoveExpired(t *testing.T) { require.Nil(t, err) // Create expired token and active token - _, err = store.CreateToken(u.ID, "tk_expired", "expired", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(-time.Hour), false) + _, err = store.CreateToken(u.ID, "tk_expired", "expired", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(-time.Hour), 0, false) require.Nil(t, err) - _, err = store.CreateToken(u.ID, "tk_active", "active", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), false) + _, err = store.CreateToken(u.ID, "tk_active", "active", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), 0, false) require.Nil(t, err) require.Nil(t, store.RemoveExpiredTokens()) @@ -245,28 +243,25 @@ func TestStoreTokenRemoveExpired(t *testing.T) { }) } -func TestStoreTokenRemoveExcess(t *testing.T) { +func TestStoreTokenCreatePrunesExcess(t *testing.T) { forEachStoreBackend(t, func(t *testing.T, store user.Store) { require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false)) u, err := store.User("phil") require.Nil(t, err) - // Create 3 tokens with increasing expiry - for i, name := range []string{"tk_a", "tk_b", "tk_c"} { - _, err = store.CreateToken(u.ID, name, name, time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Duration(i+1)*time.Hour), false) + // Create 2 tokens with no pruning + for i, name := range []string{"tk_a", "tk_b"} { + _, err = store.CreateToken(u.ID, name, name, time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Duration(i+1)*time.Hour), 0, false) require.Nil(t, err) } - count, err := store.TokenCount(u.ID) + // Create a 3rd token with maxTokenCount=2, which should prune tk_a (earliest expiry) + _, err = store.CreateToken(u.ID, "tk_c", "tk_c", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(3*time.Hour), 2, false) require.Nil(t, err) - require.Equal(t, 3, count) - // Remove excess, keep only 2 (the ones with latest expiry: tk_b, tk_c) - require.Nil(t, store.RemoveExcessTokens(u.ID, 2)) - - count, err = store.TokenCount(u.ID) + tokens, err := store.Tokens(u.ID) require.Nil(t, err) - require.Equal(t, 2, count) + require.Equal(t, 2, len(tokens)) // tk_a should be removed (earliest expiry) _, err = store.Token(u.ID, "tk_a") @@ -286,7 +281,7 @@ func TestStoreTokenUpdateLastAccess(t *testing.T) { u, err := store.User("phil") require.Nil(t, err) - _, err = store.CreateToken(u.ID, "tk_abc", "label", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), false) + _, err = store.CreateToken(u.ID, "tk_abc", "label", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), 0, false) require.Nil(t, err) newTime := time.Now().Add(5 * time.Minute) @@ -628,7 +623,7 @@ func TestStoreUpdateStats(t *testing.T) { require.Nil(t, err) stats := &user.Stats{Messages: 42, Emails: 3, Calls: 1} - require.Nil(t, store.UpdateStats(u.ID, stats)) + require.Nil(t, store.UpdateStats(map[string]*user.Stats{u.ID: stats})) u2, err := store.User("phil") require.Nil(t, err) @@ -644,7 +639,7 @@ func TestStoreResetStats(t *testing.T) { u, err := store.User("phil") require.Nil(t, err) - require.Nil(t, store.UpdateStats(u.ID, &user.Stats{Messages: 42, Emails: 3, Calls: 1})) + require.Nil(t, store.UpdateStats(map[string]*user.Stats{u.ID: {Messages: 42, Emails: 3, Calls: 1}})) require.Nil(t, store.ResetStats()) u2, err := store.User("phil")