Optimizations

This commit is contained in:
binwiederhier
2026-02-23 14:11:36 -05:00
parent b02366b42b
commit 28a436c0d2
6 changed files with 25 additions and 12 deletions

View File

@@ -281,11 +281,8 @@ func (a *Manager) writeTokenUpdateQueue() error {
log.Tag(tag).Debug("Writing token update queue for %d token(s)", len(tokenQueue)) log.Tag(tag).Debug("Writing token update queue for %d token(s)", len(tokenQueue))
for tokenID, update := range tokenQueue { for tokenID, update := range tokenQueue {
log.Tag(tag).Trace("Updating token %s with last access time %v", tokenID, update.LastAccess.Unix()) log.Tag(tag).Trace("Updating token %s with last access time %v", tokenID, update.LastAccess.Unix())
if err := a.store.UpdateTokenLastAccess(tokenID, update.LastAccess, update.LastOrigin); err != nil {
return err
}
} }
return nil return a.store.UpdateTokenLastAccess(tokenQueue)
} }
// Authorize returns nil if the given user has access to the given topic using the desired // Authorize returns nil if the given user has access to the given topic using the desired

View File

@@ -1218,7 +1218,7 @@ func TestManager_WithProvisionedUsers(t *testing.T) {
// Update the token last access time and origin (so we can check that it is persisted) // Update the token last access time and origin (so we can check that it is persisted)
lastAccessTime := time.Now().Add(time.Hour) lastAccessTime := time.Now().Add(time.Hour)
lastOrigin := netip.MustParseAddr("1.1.9.9") lastOrigin := netip.MustParseAddr("1.1.9.9")
err = a.store.UpdateTokenLastAccess(tokens[0].Value, lastAccessTime, lastOrigin) err = a.store.UpdateTokenLastAccess(map[string]*TokenUpdate{tokens[0].Value: {LastAccess: lastAccessTime, LastOrigin: lastOrigin}})
require.Nil(t, err) require.Nil(t, err)
// Re-open the DB (second app start) // Re-open the DB (second app start)

View File

@@ -41,7 +41,7 @@ type Store interface {
Tokens(userID string) ([]*Token, error) Tokens(userID string) ([]*Token, error)
AllProvisionedTokens() ([]*Token, error) AllProvisionedTokens() ([]*Token, error)
ChangeToken(userID, token, label string, expires time.Time) error ChangeToken(userID, token, label string, expires time.Time) error
UpdateTokenLastAccess(token string, lastAccess time.Time, lastOrigin netip.Addr) error UpdateTokenLastAccess(updates map[string]*TokenUpdate) error
RemoveToken(userID, token string) error RemoveToken(userID, token string) error
RemoveProvisionedToken(token string) error RemoveProvisionedToken(token string) error
RemoveExpiredTokens() error RemoveExpiredTokens() error
@@ -116,6 +116,7 @@ type storeQueries struct {
// Token queries // Token queries
selectToken string selectToken string
selectTokens string selectTokens string
selectTokenCount string
selectAllProvisionedTokens string selectAllProvisionedTokens string
upsertToken string upsertToken string
updateToken string updateToken string
@@ -457,9 +458,15 @@ func (s *commonStore) CreateToken(userID, token, label string, lastAccess time.T
return nil, err return nil, err
} }
if maxTokenCount > 0 { if maxTokenCount > 0 {
if _, err := tx.Exec(s.queries.deleteExcessTokens, userID, userID, maxTokenCount); err != nil { var tokenCount int
if err := tx.QueryRow(s.queries.selectTokenCount, userID).Scan(&tokenCount); err != nil {
return nil, err return nil, err
} }
if tokenCount > maxTokenCount {
if _, err := tx.Exec(s.queries.deleteExcessTokens, userID, userID, maxTokenCount); err != nil {
return nil, err
}
}
} }
if err := tx.Commit(); err != nil { if err := tx.Commit(); err != nil {
return nil, err return nil, err
@@ -532,12 +539,19 @@ func (s *commonStore) ChangeToken(userID, token, label string, expires time.Time
return nil return nil
} }
// UpdateTokenLastAccess updates a token's last access time and origin // UpdateTokenLastAccess updates the last access time and origin for one or more tokens in a single transaction
func (s *commonStore) UpdateTokenLastAccess(token string, lastAccess time.Time, lastOrigin netip.Addr) error { func (s *commonStore) UpdateTokenLastAccess(updates map[string]*TokenUpdate) error {
if _, err := s.db.Exec(s.queries.updateTokenLastAccess, lastAccess.Unix(), lastOrigin.String(), token); err != nil { tx, err := s.db.Begin()
if err != nil {
return err return err
} }
return nil defer tx.Rollback()
for token, update := range updates {
if _, err := tx.Exec(s.queries.updateTokenLastAccess, update.LastAccess.Unix(), update.LastOrigin.String(), token); err != nil {
return err
}
}
return tx.Commit()
} }
// RemoveToken deletes the token // RemoveToken deletes the token

View File

@@ -250,6 +250,7 @@ func NewPostgresStore(db *sql.DB) (Store, error) {
// Token queries // Token queries
selectToken: postgresSelectTokenQuery, selectToken: postgresSelectTokenQuery,
selectTokens: postgresSelectTokensQuery, selectTokens: postgresSelectTokensQuery,
selectTokenCount: postgresSelectTokenCountQuery,
selectAllProvisionedTokens: postgresSelectAllProvisionedTokensQuery, selectAllProvisionedTokens: postgresSelectAllProvisionedTokensQuery,
upsertToken: postgresUpsertTokenQuery, upsertToken: postgresUpsertTokenQuery,
updateToken: postgresUpdateTokenQuery, updateToken: postgresUpdateTokenQuery,

View File

@@ -254,6 +254,7 @@ func NewSQLiteStore(filename, startupQueries string) (Store, error) {
deleteAllAccess: sqliteDeleteAllAccessQuery, deleteAllAccess: sqliteDeleteAllAccessQuery,
selectToken: sqliteSelectTokenQuery, selectToken: sqliteSelectTokenQuery,
selectTokens: sqliteSelectTokensQuery, selectTokens: sqliteSelectTokensQuery,
selectTokenCount: sqliteSelectTokenCountQuery,
selectAllProvisionedTokens: sqliteSelectAllProvisionedTokensQuery, selectAllProvisionedTokens: sqliteSelectAllProvisionedTokensQuery,
upsertToken: sqliteUpsertTokenQuery, upsertToken: sqliteUpsertTokenQuery,
updateToken: sqliteUpdateTokenQuery, updateToken: sqliteUpdateTokenQuery,

View File

@@ -286,7 +286,7 @@ func TestStoreTokenUpdateLastAccess(t *testing.T) {
newTime := time.Now().Add(5 * time.Minute) newTime := time.Now().Add(5 * time.Minute)
newOrigin := netip.MustParseAddr("5.5.5.5") newOrigin := netip.MustParseAddr("5.5.5.5")
require.Nil(t, store.UpdateTokenLastAccess("tk_abc", newTime, newOrigin)) require.Nil(t, store.UpdateTokenLastAccess(map[string]*user.TokenUpdate{"tk_abc": {LastAccess: newTime, LastOrigin: newOrigin}}))
tk, err := store.Token(u.ID, "tk_abc") tk, err := store.Token(u.ID, "tk_abc")
require.Nil(t, err) require.Nil(t, err)