Add "auth-tokens"
This commit is contained in:
165
user/manager.go
165
user/manager.go
@@ -111,9 +111,11 @@ const (
|
||||
last_access INT NOT NULL,
|
||||
last_origin TEXT NOT NULL,
|
||||
expires INT NOT NULL,
|
||||
provisioned INT NOT NULL,
|
||||
PRIMARY KEY (user_id, token),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_user_token ON user_token (token);
|
||||
CREATE TABLE IF NOT EXISTS user_phone (
|
||||
user_id TEXT NOT NULL,
|
||||
phone_number TEXT NOT NULL,
|
||||
@@ -181,16 +183,17 @@ const (
|
||||
ELSE 2
|
||||
END, user
|
||||
`
|
||||
selectUserCountQuery = `SELECT COUNT(*) FROM user`
|
||||
updateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?`
|
||||
updateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?`
|
||||
updateUserProvisionedQuery = `UPDATE user SET provisioned = ? WHERE user = ?`
|
||||
updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE id = ?`
|
||||
updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ?, stats_calls = ? WHERE id = ?`
|
||||
updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0, stats_calls = 0`
|
||||
updateUserDeletedQuery = `UPDATE user SET deleted = ? WHERE id = ?`
|
||||
deleteUsersMarkedQuery = `DELETE FROM user WHERE deleted < ?`
|
||||
deleteUserQuery = `DELETE FROM user WHERE user = ?`
|
||||
selectUserCountQuery = `SELECT COUNT(*) FROM user`
|
||||
selectUserIDFromUsernameQuery = `SELECT id FROM user WHERE user = ?`
|
||||
updateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?`
|
||||
updateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?`
|
||||
updateUserProvisionedQuery = `UPDATE user SET provisioned = ? WHERE user = ?`
|
||||
updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE id = ?`
|
||||
updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ?, stats_calls = ? WHERE id = ?`
|
||||
updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0, stats_calls = 0`
|
||||
updateUserDeletedQuery = `UPDATE user SET deleted = ? WHERE id = ?`
|
||||
deleteUsersMarkedQuery = `DELETE FROM user WHERE deleted < ?`
|
||||
deleteUserQuery = `DELETE FROM user WHERE user = ?`
|
||||
|
||||
upsertUserAccessQuery = `
|
||||
INSERT INTO user_access (user_id, topic, read, write, owner_user_id, provisioned)
|
||||
@@ -220,7 +223,7 @@ const (
|
||||
selectUserReservationsCountQuery = `
|
||||
SELECT COUNT(*)
|
||||
FROM user_access
|
||||
WHERE user_id = owner_user_id
|
||||
WHERE user_id = owner_user_id
|
||||
AND owner_user_id = (SELECT id FROM user WHERE user = ?)
|
||||
`
|
||||
selectUserReservationsOwnerQuery = `
|
||||
@@ -255,17 +258,23 @@ const (
|
||||
AND topic = ?
|
||||
`
|
||||
|
||||
selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
|
||||
selectTokensQuery = `SELECT token, label, last_access, last_origin, expires FROM user_token WHERE user_id = ?`
|
||||
selectTokenQuery = `SELECT token, label, last_access, last_origin, expires FROM user_token WHERE user_id = ? AND token = ?`
|
||||
insertTokenQuery = `INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires) VALUES (?, ?, ?, ?, ?, ?)`
|
||||
updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?`
|
||||
updateTokenLabelQuery = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?`
|
||||
updateTokenLastAccessQuery = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?`
|
||||
deleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
|
||||
deleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?`
|
||||
deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
|
||||
deleteExcessTokensQuery = `
|
||||
selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
|
||||
selectTokensQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ?`
|
||||
selectTokenQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ? AND token = ?`
|
||||
upsertTokenQuery = `
|
||||
INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires, provisioned)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT (user_id, token)
|
||||
DO UPDATE SET label = excluded.label, expires = excluded.expires, provisioned = excluded.provisioned;
|
||||
`
|
||||
updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?`
|
||||
updateTokenLabelQuery = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?`
|
||||
updateTokenLastAccessQuery = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?`
|
||||
deleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
|
||||
deleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?`
|
||||
deleteTokensProvisionedQuery = `DELETE FROM user_token WHERE provisioned = 1`
|
||||
deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
|
||||
deleteExcessTokensQuery = `
|
||||
DELETE FROM user_token
|
||||
WHERE user_id = ?
|
||||
AND (user_id, token) NOT IN (
|
||||
@@ -470,7 +479,7 @@ const (
|
||||
role,
|
||||
prefs,
|
||||
sync_topic,
|
||||
0,
|
||||
0, -- provisioned
|
||||
stats_messages,
|
||||
stats_emails,
|
||||
stats_calls,
|
||||
@@ -480,7 +489,8 @@ const (
|
||||
stripe_subscription_interval,
|
||||
stripe_subscription_paid_until,
|
||||
stripe_subscription_cancel_at,
|
||||
created, deleted
|
||||
created,
|
||||
deleted
|
||||
FROM user_old;
|
||||
DROP TABLE user_old;
|
||||
|
||||
@@ -500,10 +510,27 @@ const (
|
||||
INSERT INTO user_access SELECT *, 0 FROM user_access_old;
|
||||
DROP TABLE user_access_old;
|
||||
|
||||
-- Alter user_token table: Add provisioned column
|
||||
ALTER TABLE user_token RENAME TO user_token_old;
|
||||
CREATE TABLE IF NOT EXISTS user_token (
|
||||
user_id TEXT NOT NULL,
|
||||
token TEXT NOT NULL,
|
||||
label TEXT NOT NULL,
|
||||
last_access INT NOT NULL,
|
||||
last_origin TEXT NOT NULL,
|
||||
expires INT NOT NULL,
|
||||
provisioned INT NOT NULL,
|
||||
PRIMARY KEY (user_id, token),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
INSERT INTO user_token SELECT *, 0 FROM user_token_old;
|
||||
DROP TABLE user_token_old;
|
||||
|
||||
-- Recreate indices
|
||||
CREATE UNIQUE INDEX idx_user ON user (user);
|
||||
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
|
||||
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
|
||||
CREATE UNIQUE INDEX idx_user_token ON user_token (token);
|
||||
|
||||
-- Re-enable foreign keys
|
||||
PRAGMA foreign_keys=on;
|
||||
@@ -537,7 +564,8 @@ type Config struct {
|
||||
DefaultAccess Permission // Default permission if no ACL matches
|
||||
ProvisionEnabled bool // Enable auto-provisioning of users and access grants, disabled for "ntfy user" commands
|
||||
Users []*User // Predefined users to create on startup
|
||||
Access map[string][]*Grant // Predefined access grants to create on startup
|
||||
Access map[string][]*Grant // Predefined access grants to create on startup (username -> []*Grant)
|
||||
Tokens map[string][]*Token // Predefined users to create on startup (username -> []*Token)
|
||||
QueueWriterInterval time.Duration // Interval for the async queue writer to flush stats and token updates to the database
|
||||
BcryptCost int // Cost of generated passwords; lowering makes testing faster
|
||||
}
|
||||
@@ -623,15 +651,15 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) {
|
||||
// CreateToken generates a random token for the given user and returns it. The token expires
|
||||
// after a fixed duration unless ChangeToken is called. This function also prunes tokens for the
|
||||
// given user, if there are too many of them.
|
||||
func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr) (*Token, error) {
|
||||
token := util.RandomLowerStringPrefix(tokenPrefix, tokenLength) // Lowercase only to support "<topic>+<token>@<domain>" email addresses
|
||||
tx, err := a.db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr, provisioned bool) (*Token, error) {
|
||||
return queryTx(a.db, func(tx *sql.Tx) (*Token, error) {
|
||||
return a.createTokenTx(tx, userID, GenerateToken(), label, expires, origin, provisioned)
|
||||
})
|
||||
}
|
||||
|
||||
func (a *Manager) createTokenTx(tx *sql.Tx, userID, token, label string, expires time.Time, origin netip.Addr, provisioned bool) (*Token, error) {
|
||||
access := time.Now()
|
||||
if _, err := tx.Exec(insertTokenQuery, userID, token, label, access.Unix(), origin.String(), expires.Unix()); err != nil {
|
||||
if _, err := tx.Exec(upsertTokenQuery, userID, token, label, access.Unix(), origin.String(), expires.Unix(), provisioned); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rows, err := tx.Query(selectTokenCountQuery, userID)
|
||||
@@ -653,15 +681,13 @@ func (a *Manager) CreateToken(userID, label string, expires time.Time, origin ne
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Token{
|
||||
Value: token,
|
||||
Label: label,
|
||||
LastAccess: access,
|
||||
LastOrigin: origin,
|
||||
Expires: expires,
|
||||
Value: token,
|
||||
Label: label,
|
||||
LastAccess: access,
|
||||
LastOrigin: origin,
|
||||
Expires: expires,
|
||||
Provisioned: provisioned,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -698,10 +724,11 @@ func (a *Manager) Token(userID, token string) (*Token, error) {
|
||||
func (a *Manager) readToken(rows *sql.Rows) (*Token, error) {
|
||||
var token, label, lastOrigin string
|
||||
var lastAccess, expires int64
|
||||
var provisioned bool
|
||||
if !rows.Next() {
|
||||
return nil, ErrTokenNotFound
|
||||
}
|
||||
if err := rows.Scan(&token, &label, &lastAccess, &lastOrigin, &expires); err != nil {
|
||||
if err := rows.Scan(&token, &label, &lastAccess, &lastOrigin, &expires, &provisioned); err != nil {
|
||||
return nil, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
@@ -711,11 +738,12 @@ func (a *Manager) readToken(rows *sql.Rows) (*Token, error) {
|
||||
lastOriginIP = netip.IPv4Unspecified()
|
||||
}
|
||||
return &Token{
|
||||
Value: token,
|
||||
Label: label,
|
||||
LastAccess: time.Unix(lastAccess, 0),
|
||||
LastOrigin: lastOriginIP,
|
||||
Expires: time.Unix(expires, 0),
|
||||
Value: token,
|
||||
Label: label,
|
||||
LastAccess: time.Unix(lastAccess, 0),
|
||||
LastOrigin: lastOriginIP,
|
||||
Expires: time.Unix(expires, 0),
|
||||
Provisioned: provisioned,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -774,7 +802,7 @@ func (a *Manager) PhoneNumbers(userID string) ([]string, error) {
|
||||
phoneNumbers := make([]string, 0)
|
||||
for {
|
||||
phoneNumber, err := a.readPhoneNumber(rows)
|
||||
if err == ErrPhoneNumberNotFound {
|
||||
if errors.Is(err, ErrPhoneNumberNotFound) {
|
||||
break
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
@@ -1757,6 +1785,28 @@ func (a *Manager) maybeProvisionUsersAndAccess() error {
|
||||
}
|
||||
}
|
||||
}
|
||||
// Remove and (re-)add provisioned tokens
|
||||
if _, err := tx.Exec(deleteTokensProvisionedQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
for username, tokens := range a.config.Tokens {
|
||||
_, exists := util.Find(a.config.Users, func(u *User) bool {
|
||||
return u.Name == username
|
||||
})
|
||||
if !exists && username != Everyone {
|
||||
return fmt.Errorf("user %s is not a provisioned user, refusing to add tokens", username)
|
||||
}
|
||||
var userID string
|
||||
row := tx.QueryRow(selectUserIDFromUsernameQuery, username)
|
||||
if err := row.Scan(&userID); err != nil {
|
||||
return fmt.Errorf("failed to find provisioned user %s for provisioned tokens", username)
|
||||
}
|
||||
for _, token := range tokens {
|
||||
if _, err = a.createTokenTx(tx, userID, token.Value, token.Label, time.Unix(0, 0), netip.IPv4Unspecified(), true); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
@@ -1974,3 +2024,22 @@ func execTx(db *sql.DB, f func(tx *sql.Tx) error) error {
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// queryTx executes a function in a transaction and returns the result. If the function
|
||||
// returns an error, the transaction is rolled back.
|
||||
func queryTx[T any](db *sql.DB, f func(tx *sql.Tx) (T, error)) (T, error) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
var zero T
|
||||
return zero, err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
t, err := f(tx)
|
||||
if err != nil {
|
||||
return t, err
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return t, err
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user