From 9f987e66faa6417668ba3d8fffccfec98c9d925e Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Thu, 31 Jul 2025 08:36:05 +0200 Subject: [PATCH] Make sure tokens are updated instead of deleted/re-added --- cmd/webpush_test.go | 5 +- user/manager.go | 266 +++++++++++++++++++++++++++---------------- user/manager_test.go | 26 ++++- 3 files changed, 197 insertions(+), 100 deletions(-) diff --git a/cmd/webpush_test.go b/cmd/webpush_test.go index 01e1a7a1..3c1de4f2 100644 --- a/cmd/webpush_test.go +++ b/cmd/webpush_test.go @@ -1,6 +1,7 @@ package cmd import ( + "path/filepath" "testing" "github.com/stretchr/testify/require" @@ -15,10 +16,12 @@ func TestCLI_WebPush_GenerateKeys(t *testing.T) { } func TestCLI_WebPush_WriteKeysToFile(t *testing.T) { + tempDir := t.TempDir() + t.Chdir(tempDir) app, _, _, stderr := newTestApp() require.Nil(t, runWebPushCommand(app, server.NewConfig(), "keys", "--output-file=key-file.yaml")) require.Contains(t, stderr.String(), "Web Push keys written to key-file.yaml") - require.FileExists(t, "key-file.yaml") + require.FileExists(t, filepath.Join(tempDir, "key-file.yaml")) } func runWebPushCommand(app *cli.App, conf *server.Config, args ...string) error { diff --git a/user/manager.go b/user/manager.go index 5e68b177..9d7219f6 100644 --- a/user/manager.go +++ b/user/manager.go @@ -13,6 +13,7 @@ import ( "heckel.io/ntfy/v2/util" "net/netip" "path/filepath" + "slices" "strings" "sync" "time" @@ -258,23 +259,24 @@ const ( AND topic = ? ` - selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?` - selectTokensQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ?` - selectTokenQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ? AND token = ?` - upsertTokenQuery = ` + selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?` + selectTokensQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ?` + selectTokenQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ? AND token = ?` + selectAllProvisionedTokensQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE provisioned = 1` + upsertTokenQuery = ` INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires, provisioned) VALUES (?, ?, ?, ?, ?, ?, ?) ON CONFLICT (user_id, token) DO UPDATE SET label = excluded.label, expires = excluded.expires, provisioned = excluded.provisioned; ` - updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?` - updateTokenLabelQuery = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?` - updateTokenLastAccessQuery = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?` - deleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?` - deleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?` - deleteTokensProvisionedQuery = `DELETE FROM user_token WHERE provisioned = 1` - deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < ?` - deleteExcessTokensQuery = ` + updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?` + updateTokenLabelQuery = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?` + updateTokenLastAccessQuery = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?` + deleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?` + deleteProvisionedTokenQuery = `DELETE FROM user_token WHERE token = ?` + deleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?` + deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < ?` + deleteExcessTokensQuery = ` DELETE FROM user_token WHERE user_id = ? AND (user_id, token) NOT IN ( @@ -711,6 +713,25 @@ func (a *Manager) Tokens(userID string) ([]*Token, error) { return tokens, nil } +func (a *Manager) allProvisionedTokens() ([]*Token, error) { + rows, err := a.db.Query(selectAllProvisionedTokensQuery) + if err != nil { + return nil, err + } + defer rows.Close() + tokens := make([]*Token, 0) + for { + token, err := a.readToken(rows) + if errors.Is(err, ErrTokenNotFound) { + break + } else if err != nil { + return nil, err + } + tokens = append(tokens, token) + } + return tokens, nil +} + // Token returns a specific token for a user func (a *Manager) Token(userID, token string) (*Token, error) { rows, err := a.db.Query(selectTokenQuery, userID, token) @@ -775,10 +796,16 @@ func (a *Manager) ChangeToken(userID, token string, label *string, expires *time // RemoveToken deletes the token defined in User.Token func (a *Manager) RemoveToken(userID, token string) error { + return execTx(a.db, func(tx *sql.Tx) error { + return a.removeTokenTx(tx, userID, token) + }) +} + +func (a *Manager) removeTokenTx(tx *sql.Tx, userID, token string) error { if token == "" { return errNoTokenProvided } - if _, err := a.db.Exec(deleteTokenQuery, userID, token); err != nil { + if _, err := tx.Exec(deleteTokenQuery, userID, token); err != nil { return err } return nil @@ -952,13 +979,20 @@ func (a *Manager) writeTokenUpdateQueue() error { log.Tag(tag).Debug("Writing token update queue for %d token(s)", len(tokenQueue)) for tokenID, update := range tokenQueue { log.Tag(tag).Trace("Updating token %s with last access time %v", tokenID, update.LastAccess.Unix()) - if _, err := tx.Exec(updateTokenLastAccessQuery, update.LastAccess.Unix(), update.LastOrigin.String(), tokenID); err != nil { + if err := a.updateTokenLastAccessTx(tx, tokenID, update.LastAccess.Unix(), update.LastOrigin.String()); err != nil { return err } } return tx.Commit() } +func (a *Manager) updateTokenLastAccessTx(tx *sql.Tx, token string, lastAccess int64, lastOrigin string) error { + if _, err := tx.Exec(updateTokenLastAccessQuery, lastAccess, lastOrigin, token); err != nil { + return err + } + return nil +} + // Authorize returns nil if the given user has access to the given topic using the desired // permission. The user param may be nil to signal an anonymous user. func (a *Manager) Authorize(user *User, topic string, perm Permission) error { @@ -1706,7 +1740,7 @@ func (a *Manager) maybeProvisionUsersAndAccess() error { if !a.config.ProvisionEnabled { return nil } - users, err := a.Users() + existingUsers, err := a.Users() if err != nil { return err } @@ -1714,94 +1748,134 @@ func (a *Manager) maybeProvisionUsersAndAccess() error { return u.Name }) return execTx(a.db, func(tx *sql.Tx) error { - // Remove users that are provisioned, but not in the config anymore - for _, user := range users { - if user.Name == Everyone { - continue - } else if user.Provisioned && !util.Contains(provisionUsernames, user.Name) { - if err := a.removeUserTx(tx, user.Name); err != nil { - return fmt.Errorf("failed to remove provisioned user %s: %v", user.Name, err) - } - } + if err := a.maybeProvisionUsers(tx, provisionUsernames, existingUsers); err != nil { + return fmt.Errorf("failed to provision users: %v", err) } - // Add or update provisioned users - for _, user := range a.config.Users { - if user.Name == Everyone { - continue - } - existingUser, exists := util.Find(users, func(u *User) bool { - return u.Name == user.Name - }) - if !exists { - if err := a.addUserTx(tx, user.Name, user.Hash, user.Role, true, true); err != nil && !errors.Is(err, ErrUserExists) { - return fmt.Errorf("failed to add provisioned user %s: %v", user.Name, err) - } - } else { - if !existingUser.Provisioned { - if err := a.changeProvisionedTx(tx, user.Name, true); err != nil { - return fmt.Errorf("failed to change provisioned status for user %s: %v", user.Name, err) - } - } - if existingUser.Hash != user.Hash { - if err := a.changePasswordTx(tx, user.Name, user.Hash, true); err != nil { - return fmt.Errorf("failed to change password for provisioned user %s: %v", user.Name, err) - } - } - if existingUser.Role != user.Role { - if err := a.changeRoleTx(tx, user.Name, user.Role); err != nil { - return fmt.Errorf("failed to change role for provisioned user %s: %v", user.Name, err) - } - } - } + if err := a.maybeProvisionGrants(tx); err != nil { + return fmt.Errorf("failed to provision grants: %v", err) } - // Remove and (re-)add provisioned grants - if _, err := tx.Exec(deleteUserAccessProvisionedQuery); err != nil { - return err - } - for username, grants := range a.config.Access { - user, exists := util.Find(a.config.Users, func(u *User) bool { - return u.Name == username - }) - if !exists && username != Everyone { - return fmt.Errorf("user %s is not a provisioned user, refusing to add ACL entry", username) - } else if user != nil && user.Role == RoleAdmin { - return fmt.Errorf("adding access control entries is not allowed for admin roles for user %s", username) - } - for _, grant := range grants { - if err := a.resetAccessTx(tx, username, grant.TopicPattern); err != nil { - return fmt.Errorf("failed to reset access for user %s and topic %s: %v", username, grant.TopicPattern, err) - } - if err := a.allowAccessTx(tx, username, grant.TopicPattern, grant.Permission, true); err != nil { - return err - } - } - } - // Remove and (re-)add provisioned tokens - if _, err := tx.Exec(deleteTokensProvisionedQuery); err != nil { - return err - } - for username, tokens := range a.config.Tokens { - _, exists := util.Find(a.config.Users, func(u *User) bool { - return u.Name == username - }) - if !exists && username != Everyone { - return fmt.Errorf("user %s is not a provisioned user, refusing to add tokens", username) - } - var userID string - row := tx.QueryRow(selectUserIDFromUsernameQuery, username) - if err := row.Scan(&userID); err != nil { - return fmt.Errorf("failed to find provisioned user %s for provisioned tokens", username) - } - for _, token := range tokens { - if _, err = a.createTokenTx(tx, userID, token.Value, token.Label, time.Unix(0, 0), netip.IPv4Unspecified(), true); err != nil { - return err - } - } + if err := a.maybeProvisionTokens(tx, provisionUsernames); err != nil { + return fmt.Errorf("failed to provision tokens: %v", err) } return nil }) } +// maybeProvisionUsers checks if the users in the config are provisioned, and adds or updates them. +// It also removes users that are provisioned, but not in the config anymore. +func (a *Manager) maybeProvisionUsers(tx *sql.Tx, provisionUsernames []string, existingUsers []*User) error { + // Remove users that are provisioned, but not in the config anymore + for _, user := range existingUsers { + if user.Name == Everyone { + continue + } else if user.Provisioned && !util.Contains(provisionUsernames, user.Name) { + if err := a.removeUserTx(tx, user.Name); err != nil { + return fmt.Errorf("failed to remove provisioned user %s: %v", user.Name, err) + } + } + } + // Add or update provisioned users + for _, user := range a.config.Users { + if user.Name == Everyone { + continue + } + existingUser, exists := util.Find(existingUsers, func(u *User) bool { + return u.Name == user.Name + }) + if !exists { + if err := a.addUserTx(tx, user.Name, user.Hash, user.Role, true, true); err != nil && !errors.Is(err, ErrUserExists) { + return fmt.Errorf("failed to add provisioned user %s: %v", user.Name, err) + } + } else { + if !existingUser.Provisioned { + if err := a.changeProvisionedTx(tx, user.Name, true); err != nil { + return fmt.Errorf("failed to change provisioned status for user %s: %v", user.Name, err) + } + } + if existingUser.Hash != user.Hash { + if err := a.changePasswordTx(tx, user.Name, user.Hash, true); err != nil { + return fmt.Errorf("failed to change password for provisioned user %s: %v", user.Name, err) + } + } + if existingUser.Role != user.Role { + if err := a.changeRoleTx(tx, user.Name, user.Role); err != nil { + return fmt.Errorf("failed to change role for provisioned user %s: %v", user.Name, err) + } + } + } + } + return nil +} + +// maybyProvisionGrants removes all provisioned grants, and (re-)adds the grants from the config. +// +// Unlike users and tokens, grants can be just re-added, because they do not carry any state (such as last +// access time) or do not have dependent resources (such as grants or tokens). +func (a *Manager) maybeProvisionGrants(tx *sql.Tx) error { + // Remove all provisioned grants + if _, err := tx.Exec(deleteUserAccessProvisionedQuery); err != nil { + return err + } + // (Re-)add provisioned grants + for username, grants := range a.config.Access { + user, exists := util.Find(a.config.Users, func(u *User) bool { + return u.Name == username + }) + if !exists && username != Everyone { + return fmt.Errorf("user %s is not a provisioned user, refusing to add ACL entry", username) + } else if user != nil && user.Role == RoleAdmin { + return fmt.Errorf("adding access control entries is not allowed for admin roles for user %s", username) + } + for _, grant := range grants { + if err := a.resetAccessTx(tx, username, grant.TopicPattern); err != nil { + return fmt.Errorf("failed to reset access for user %s and topic %s: %v", username, grant.TopicPattern, err) + } + if err := a.allowAccessTx(tx, username, grant.TopicPattern, grant.Permission, true); err != nil { + return err + } + } + } + return nil +} + +func (a *Manager) maybeProvisionTokens(tx *sql.Tx, provisionUsernames []string) error { + // Remove tokens that are provisioned, but not in the config anymore + existingTokens, err := a.allProvisionedTokens() + if err != nil { + return fmt.Errorf("failed to retrieve existing provisioned tokens: %v", err) + } + var provisionTokens []string + for _, userTokens := range a.config.Tokens { + for _, token := range userTokens { + provisionTokens = append(provisionTokens, token.Value) + } + } + for _, existingToken := range existingTokens { + if !slices.Contains(provisionTokens, existingToken.Value) { + if _, err := tx.Exec(deleteProvisionedTokenQuery, existingToken.Value); err != nil { + return fmt.Errorf("failed to remove provisioned token %s: %v", existingToken.Value, err) + } + } + } + // (Re-)add provisioned tokens + for username, tokens := range a.config.Tokens { + if !slices.Contains(provisionUsernames, username) && username != Everyone { + return fmt.Errorf("user %s is not a provisioned user, refusing to add tokens", username) + } + var userID string + row := tx.QueryRow(selectUserIDFromUsernameQuery, username) + if err := row.Scan(&userID); err != nil { + return fmt.Errorf("failed to find provisioned user %s for provisioned tokens", username) + } + for _, token := range tokens { + if _, err := a.createTokenTx(tx, userID, token.Value, token.Label, time.Unix(0, 0), netip.IPv4Unspecified(), true); err != nil { + return err + } + } + } + return nil +} + // toSQLWildcard converts a wildcard string to a SQL wildcard string. It only allows '*' as wildcards, // and escapes '_', assuming '\' as escape character. func toSQLWildcard(s string) string { diff --git a/user/manager_test.go b/user/manager_test.go index f93e51fe..d51b9b96 100644 --- a/user/manager_test.go +++ b/user/manager_test.go @@ -1152,6 +1152,14 @@ func TestManager_WithProvisionedUsers(t *testing.T) { require.Equal(t, "Alerts token", tokens[0].Label) require.True(t, tokens[0].Provisioned) + // Update the token last access time and origin (so we can check that it is persisted) + lastAccessTime := time.Now().Add(time.Hour) + lastOrigin := netip.MustParseAddr("1.1.9.9") + err = execTx(a.db, func(tx *sql.Tx) error { + return a.updateTokenLastAccessTx(tx, tokens[0].Value, lastAccessTime.Unix(), lastOrigin.String()) + }) + require.Nil(t, err) + // Re-open the DB (second app start) require.Nil(t, a.db.Close()) conf.Users = []*User{ @@ -1165,7 +1173,8 @@ func TestManager_WithProvisionedUsers(t *testing.T) { } conf.Tokens = map[string][]*Token{ "philuser": { - {Value: "tk_op56p8lz5bf3cxkz9je99v9oc3XXX", Label: "Alerts token updated"}, + {Value: "tk_op56p8lz5bf3cxkz9je99v9oc37lo", Label: "Alerts token updated"}, + {Value: "tk_u48wqendnkx9er21pqqcadlytbutx", Label: "Another token"}, }, } a, err = NewManager(conf) @@ -1191,10 +1200,14 @@ func TestManager_WithProvisionedUsers(t *testing.T) { tokens, err = a.Tokens(provisionedUserID) require.Nil(t, err) - require.Equal(t, 1, len(tokens)) - require.Equal(t, "tk_op56p8lz5bf3cxkz9je99v9oc3XXX", tokens[0].Value) + require.Equal(t, 2, len(tokens)) + require.Equal(t, "tk_op56p8lz5bf3cxkz9je99v9oc37lo", tokens[0].Value) require.Equal(t, "Alerts token updated", tokens[0].Label) + require.Equal(t, lastAccessTime.Unix(), tokens[0].LastAccess.Unix()) + require.Equal(t, lastOrigin, tokens[0].LastOrigin) require.True(t, tokens[0].Provisioned) + require.Equal(t, "tk_u48wqendnkx9er21pqqcadlytbutx", tokens[1].Value) + require.Equal(t, "Another token", tokens[1].Label) // Re-open the DB again (third app start) require.Nil(t, a.db.Close()) @@ -1220,6 +1233,13 @@ func TestManager_WithProvisionedUsers(t *testing.T) { tokens, err = a.Tokens(provisionedUserID) require.Nil(t, err) require.Equal(t, 0, len(tokens)) + + var count int + a.db.QueryRow("SELECT COUNT(*) FROM user WHERE provisioned = 1").Scan(&count) + require.Equal(t, 0, count) + a.db.QueryRow("SELECT COUNT(*) FROM user_grant WHERE provisioned = 1").Scan(&count) + require.Equal(t, 0, count) + a.db.QueryRow("SELECT COUNT(*) FROM user_token WHERE provisioned = 1").Scan(&count) } func TestManager_UpdateNonProvisionedUsersToProvisionedUsers(t *testing.T) {