From 28a436c0d2eda18f4e6e68afbfc2cc23e91dd27c Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Mon, 23 Feb 2026 14:11:36 -0500 Subject: [PATCH] Optimizations --- user/manager.go | 5 +---- user/manager_test.go | 2 +- user/store.go | 26 ++++++++++++++++++++------ user/store_postgres.go | 1 + user/store_sqlite.go | 1 + user/store_test.go | 2 +- 6 files changed, 25 insertions(+), 12 deletions(-) diff --git a/user/manager.go b/user/manager.go index ce890317..2a962be2 100644 --- a/user/manager.go +++ b/user/manager.go @@ -281,11 +281,8 @@ func (a *Manager) writeTokenUpdateQueue() error { log.Tag(tag).Debug("Writing token update queue for %d token(s)", len(tokenQueue)) for tokenID, update := range tokenQueue { log.Tag(tag).Trace("Updating token %s with last access time %v", tokenID, update.LastAccess.Unix()) - if err := a.store.UpdateTokenLastAccess(tokenID, update.LastAccess, update.LastOrigin); err != nil { - return err - } } - return nil + return a.store.UpdateTokenLastAccess(tokenQueue) } // Authorize returns nil if the given user has access to the given topic using the desired diff --git a/user/manager_test.go b/user/manager_test.go index b1858ad4..ea976a3f 100644 --- a/user/manager_test.go +++ b/user/manager_test.go @@ -1218,7 +1218,7 @@ func TestManager_WithProvisionedUsers(t *testing.T) { // Update the token last access time and origin (so we can check that it is persisted) lastAccessTime := time.Now().Add(time.Hour) lastOrigin := netip.MustParseAddr("1.1.9.9") - err = a.store.UpdateTokenLastAccess(tokens[0].Value, lastAccessTime, lastOrigin) + err = a.store.UpdateTokenLastAccess(map[string]*TokenUpdate{tokens[0].Value: {LastAccess: lastAccessTime, LastOrigin: lastOrigin}}) require.Nil(t, err) // Re-open the DB (second app start) diff --git a/user/store.go b/user/store.go index f03c883e..fd9a3fbb 100644 --- a/user/store.go +++ b/user/store.go @@ -41,7 +41,7 @@ type Store interface { Tokens(userID string) ([]*Token, error) AllProvisionedTokens() ([]*Token, error) ChangeToken(userID, token, label string, expires time.Time) error - UpdateTokenLastAccess(token string, lastAccess time.Time, lastOrigin netip.Addr) error + UpdateTokenLastAccess(updates map[string]*TokenUpdate) error RemoveToken(userID, token string) error RemoveProvisionedToken(token string) error RemoveExpiredTokens() error @@ -116,6 +116,7 @@ type storeQueries struct { // Token queries selectToken string selectTokens string + selectTokenCount string selectAllProvisionedTokens string upsertToken string updateToken string @@ -457,9 +458,15 @@ func (s *commonStore) CreateToken(userID, token, label string, lastAccess time.T return nil, err } if maxTokenCount > 0 { - if _, err := tx.Exec(s.queries.deleteExcessTokens, userID, userID, maxTokenCount); err != nil { + var tokenCount int + if err := tx.QueryRow(s.queries.selectTokenCount, userID).Scan(&tokenCount); err != nil { return nil, err } + if tokenCount > maxTokenCount { + if _, err := tx.Exec(s.queries.deleteExcessTokens, userID, userID, maxTokenCount); err != nil { + return nil, err + } + } } if err := tx.Commit(); err != nil { return nil, err @@ -532,12 +539,19 @@ func (s *commonStore) ChangeToken(userID, token, label string, expires time.Time return nil } -// UpdateTokenLastAccess updates a token's last access time and origin -func (s *commonStore) UpdateTokenLastAccess(token string, lastAccess time.Time, lastOrigin netip.Addr) error { - if _, err := s.db.Exec(s.queries.updateTokenLastAccess, lastAccess.Unix(), lastOrigin.String(), token); err != nil { +// UpdateTokenLastAccess updates the last access time and origin for one or more tokens in a single transaction +func (s *commonStore) UpdateTokenLastAccess(updates map[string]*TokenUpdate) error { + tx, err := s.db.Begin() + if err != nil { return err } - return nil + defer tx.Rollback() + for token, update := range updates { + if _, err := tx.Exec(s.queries.updateTokenLastAccess, update.LastAccess.Unix(), update.LastOrigin.String(), token); err != nil { + return err + } + } + return tx.Commit() } // RemoveToken deletes the token diff --git a/user/store_postgres.go b/user/store_postgres.go index c97f4800..c4c9577c 100644 --- a/user/store_postgres.go +++ b/user/store_postgres.go @@ -250,6 +250,7 @@ func NewPostgresStore(db *sql.DB) (Store, error) { // Token queries selectToken: postgresSelectTokenQuery, selectTokens: postgresSelectTokensQuery, + selectTokenCount: postgresSelectTokenCountQuery, selectAllProvisionedTokens: postgresSelectAllProvisionedTokensQuery, upsertToken: postgresUpsertTokenQuery, updateToken: postgresUpdateTokenQuery, diff --git a/user/store_sqlite.go b/user/store_sqlite.go index 602b88d4..c9d6d33f 100644 --- a/user/store_sqlite.go +++ b/user/store_sqlite.go @@ -254,6 +254,7 @@ func NewSQLiteStore(filename, startupQueries string) (Store, error) { deleteAllAccess: sqliteDeleteAllAccessQuery, selectToken: sqliteSelectTokenQuery, selectTokens: sqliteSelectTokensQuery, + selectTokenCount: sqliteSelectTokenCountQuery, selectAllProvisionedTokens: sqliteSelectAllProvisionedTokensQuery, upsertToken: sqliteUpsertTokenQuery, updateToken: sqliteUpdateTokenQuery, diff --git a/user/store_test.go b/user/store_test.go index 1294cc86..c5c40a45 100644 --- a/user/store_test.go +++ b/user/store_test.go @@ -286,7 +286,7 @@ func TestStoreTokenUpdateLastAccess(t *testing.T) { newTime := time.Now().Add(5 * time.Minute) newOrigin := netip.MustParseAddr("5.5.5.5") - require.Nil(t, store.UpdateTokenLastAccess("tk_abc", newTime, newOrigin)) + require.Nil(t, store.UpdateTokenLastAccess(map[string]*user.TokenUpdate{"tk_abc": {LastAccess: newTime, LastOrigin: newOrigin}})) tk, err := store.Token(u.ID, "tk_abc") require.Nil(t, err)