Compare commits
6 Commits
postgres-w
...
postgres-w
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ae5e1fe8d8 | ||
|
|
e3a402ed95 | ||
|
|
1abc1005d0 | ||
|
|
909c3fe17b | ||
|
|
07c3e280bf | ||
|
|
60fa50f0d5 |
25
cmd/user.go
25
cmd/user.go
@@ -29,6 +29,7 @@ var flagsUser = append(
|
||||
&cli.StringFlag{Name: "config", Aliases: []string{"c"}, EnvVars: []string{"NTFY_CONFIG_FILE"}, Value: server.DefaultConfigFile, DefaultText: server.DefaultConfigFile, Usage: "config file"},
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "auth-file", Aliases: []string{"auth_file", "H"}, EnvVars: []string{"NTFY_AUTH_FILE"}, Usage: "auth database file used for access control"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "auth-default-access", Aliases: []string{"auth_default_access", "p"}, EnvVars: []string{"NTFY_AUTH_DEFAULT_ACCESS"}, Value: "read-write", Usage: "default permissions if no matching entries in the auth database are found"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "database-url", Aliases: []string{"database_url"}, EnvVars: []string{"NTFY_DATABASE_URL"}, Usage: "PostgreSQL connection string for database-backed stores"}),
|
||||
)
|
||||
|
||||
var cmdUser = &cli.Command{
|
||||
@@ -365,24 +366,32 @@ func createUserManager(c *cli.Context) (*user.Manager, error) {
|
||||
authFile := c.String("auth-file")
|
||||
authStartupQueries := c.String("auth-startup-queries")
|
||||
authDefaultAccess := c.String("auth-default-access")
|
||||
if authFile == "" {
|
||||
return nil, errors.New("option auth-file not set; auth is unconfigured for this server")
|
||||
} else if !util.FileExists(authFile) {
|
||||
return nil, errors.New("auth-file does not exist; please start the server at least once to create it")
|
||||
}
|
||||
databaseURL := c.String("database-url")
|
||||
authDefault, err := user.ParsePermission(authDefaultAccess)
|
||||
if err != nil {
|
||||
return nil, errors.New("if set, auth-default-access must start set to 'read-write', 'read-only', 'write-only' or 'deny-all'")
|
||||
}
|
||||
authConfig := &user.Config{
|
||||
Filename: authFile,
|
||||
StartupQueries: authStartupQueries,
|
||||
DefaultAccess: authDefault,
|
||||
ProvisionEnabled: false, // Hack: Do not re-provision users on manager initialization
|
||||
BcryptCost: user.DefaultUserPasswordBcryptCost,
|
||||
QueueWriterInterval: user.DefaultUserStatsQueueWriterInterval,
|
||||
}
|
||||
return user.NewManager(authConfig)
|
||||
var store user.Store
|
||||
if databaseURL != "" {
|
||||
store, err = user.NewPostgresStore(databaseURL)
|
||||
} else if authFile != "" {
|
||||
if !util.FileExists(authFile) {
|
||||
return nil, errors.New("auth-file does not exist; please start the server at least once to create it")
|
||||
}
|
||||
store, err = user.NewSQLiteStore(authFile, authStartupQueries)
|
||||
} else {
|
||||
return nil, errors.New("option database-url or auth-file not set; auth is unconfigured for this server")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return user.NewManager(store, authConfig)
|
||||
}
|
||||
|
||||
func readPasswordAndConfirm(c *cli.Context) (string, error) {
|
||||
|
||||
@@ -204,9 +204,10 @@ func New(conf *Config) (*Server, error) {
|
||||
}
|
||||
}
|
||||
var userManager *user.Manager
|
||||
if conf.AuthFile != "" {
|
||||
if conf.AuthFile != "" || conf.DatabaseURL != "" {
|
||||
authConfig := &user.Config{
|
||||
Filename: conf.AuthFile,
|
||||
DatabaseURL: conf.DatabaseURL,
|
||||
StartupQueries: conf.AuthStartupQueries,
|
||||
DefaultAccess: conf.AuthDefault,
|
||||
ProvisionEnabled: true, // Enable provisioning of users and access
|
||||
@@ -216,7 +217,16 @@ func New(conf *Config) (*Server, error) {
|
||||
BcryptCost: conf.AuthBcryptCost,
|
||||
QueueWriterInterval: conf.AuthStatsQueueWriterInterval,
|
||||
}
|
||||
userManager, err = user.NewManager(authConfig)
|
||||
var store user.Store
|
||||
if conf.DatabaseURL != "" {
|
||||
store, err = user.NewPostgresStore(conf.DatabaseURL)
|
||||
} else {
|
||||
store, err = user.NewSQLiteStore(conf.AuthFile, conf.AuthStartupQueries)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userManager, err = user.NewManager(store, authConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
1193
user/manager.go
1193
user/manager.go
File diff suppressed because it is too large
Load Diff
2240
user/manager_test.go
2240
user/manager_test.go
File diff suppressed because it is too large
Load Diff
986
user/store.go
Normal file
986
user/store.go
Normal file
@@ -0,0 +1,986 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"heckel.io/ntfy/v2/payments"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
// Store is the interface for a user database store
|
||||
type Store interface {
|
||||
// User operations
|
||||
UserByID(id string) (*User, error)
|
||||
User(username string) (*User, error)
|
||||
UserByToken(token string) (*User, error)
|
||||
UserByStripeCustomer(customerID string) (*User, error)
|
||||
UserIDByUsername(username string) (string, error)
|
||||
Users() ([]*User, error)
|
||||
UsersCount() (int64, error)
|
||||
AddUser(username, hash string, role Role, provisioned bool) error
|
||||
RemoveUser(username string) error
|
||||
MarkUserRemoved(userID string) error
|
||||
RemoveDeletedUsers() error
|
||||
ChangePassword(username, hash string) error
|
||||
ChangeRole(username string, role Role) error
|
||||
ChangeProvisioned(username string, provisioned bool) error
|
||||
ChangeSettings(userID string, prefs *Prefs) error
|
||||
ChangeTier(username, tierCode string) error
|
||||
ResetTier(username string) error
|
||||
UpdateStats(userID string, stats *Stats) error
|
||||
ResetStats() error
|
||||
|
||||
// Token operations
|
||||
CreateToken(userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, provisioned bool) (*Token, error)
|
||||
Token(userID, token string) (*Token, error)
|
||||
Tokens(userID string) ([]*Token, error)
|
||||
AllProvisionedTokens() ([]*Token, error)
|
||||
ChangeTokenLabel(userID, token, label string) error
|
||||
ChangeTokenExpiry(userID, token string, expires time.Time) error
|
||||
UpdateTokenLastAccess(token string, lastAccess time.Time, lastOrigin netip.Addr) error
|
||||
RemoveToken(userID, token string) error
|
||||
RemoveExpiredTokens() error
|
||||
TokenCount(userID string) (int, error)
|
||||
RemoveExcessTokens(userID string, maxCount int) error
|
||||
|
||||
// Access operations
|
||||
AuthorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error)
|
||||
AllGrants() (map[string][]Grant, error)
|
||||
Grants(username string) ([]Grant, error)
|
||||
AllowAccess(username, topicPattern string, read, write bool, ownerUsername string, provisioned bool) error
|
||||
ResetAccess(username, topicPattern string) error
|
||||
ResetAllProvisionedAccess() error
|
||||
Reservations(username string) ([]Reservation, error)
|
||||
HasReservation(username, topic string) (bool, error)
|
||||
ReservationsCount(username string) (int64, error)
|
||||
ReservationOwner(topic string) (string, error)
|
||||
OtherAccessCount(username, topic string) (int, error)
|
||||
|
||||
// Tier operations
|
||||
AddTier(tier *Tier) error
|
||||
UpdateTier(tier *Tier) error
|
||||
RemoveTier(code string) error
|
||||
Tiers() ([]*Tier, error)
|
||||
Tier(code string) (*Tier, error)
|
||||
TierByStripePrice(priceID string) (*Tier, error)
|
||||
|
||||
// Phone operations
|
||||
PhoneNumbers(userID string) ([]string, error)
|
||||
AddPhoneNumber(userID, phoneNumber string) error
|
||||
RemovePhoneNumber(userID, phoneNumber string) error
|
||||
|
||||
// Other stuff
|
||||
ChangeBilling(username string, billing *Billing) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
// storeQueries holds the database-specific SQL queries
|
||||
type storeQueries struct {
|
||||
// User queries
|
||||
selectUserByID string
|
||||
selectUserByName string
|
||||
selectUserByToken string
|
||||
selectUserByStripeID string
|
||||
selectUsernames string
|
||||
selectUserCount string
|
||||
selectUserIDFromUsername string
|
||||
insertUser string
|
||||
updateUserPass string
|
||||
updateUserRole string
|
||||
updateUserProvisioned string
|
||||
updateUserPrefs string
|
||||
updateUserStats string
|
||||
updateUserStatsResetAll string
|
||||
updateUserTier string
|
||||
updateUserDeleted string
|
||||
deleteUser string
|
||||
deleteUserTier string
|
||||
deleteUsersMarked string
|
||||
// Access queries
|
||||
selectTopicPerms string
|
||||
selectUserAllAccess string
|
||||
selectUserAccess string
|
||||
selectUserReservations string
|
||||
selectUserReservationsCount string
|
||||
selectUserReservationsOwner string
|
||||
selectUserHasReservation string
|
||||
selectOtherAccessCount string
|
||||
upsertUserAccess string
|
||||
deleteUserAccess string
|
||||
deleteUserAccessProvisioned string
|
||||
deleteTopicAccess string
|
||||
deleteAllAccess string
|
||||
// Token queries
|
||||
selectToken string
|
||||
selectTokens string
|
||||
selectTokenCount string
|
||||
selectAllProvisionedTokens string
|
||||
upsertToken string
|
||||
updateTokenLabel string
|
||||
updateTokenExpiry string
|
||||
updateTokenLastAccess string
|
||||
deleteToken string
|
||||
deleteProvisionedToken string
|
||||
deleteAllToken string
|
||||
deleteExpiredTokens string
|
||||
deleteExcessTokens string
|
||||
// Tier queries
|
||||
insertTier string
|
||||
selectTiers string
|
||||
selectTierByCode string
|
||||
selectTierByPriceID string
|
||||
updateTier string
|
||||
deleteTier string
|
||||
// Phone queries
|
||||
selectPhoneNumbers string
|
||||
insertPhoneNumber string
|
||||
deletePhoneNumber string
|
||||
// Billing queries
|
||||
updateBilling string
|
||||
}
|
||||
|
||||
// commonStore implements store operations that work across database backends
|
||||
type commonStore struct {
|
||||
db *sql.DB
|
||||
queries storeQueries
|
||||
}
|
||||
|
||||
// UserByID returns the user with the given ID if it exists, or ErrUserNotFound otherwise
|
||||
func (s *commonStore) UserByID(id string) (*User, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUserByID, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.readUser(rows)
|
||||
}
|
||||
|
||||
// User returns the user with the given username if it exists, or ErrUserNotFound otherwise
|
||||
func (s *commonStore) User(username string) (*User, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUserByName, username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.readUser(rows)
|
||||
}
|
||||
|
||||
// UserByToken returns the user with the given token if it exists and is not expired, or ErrUserNotFound otherwise
|
||||
func (s *commonStore) UserByToken(token string) (*User, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUserByToken, token, time.Now().Unix())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.readUser(rows)
|
||||
}
|
||||
|
||||
// UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise
|
||||
func (s *commonStore) UserByStripeCustomer(customerID string) (*User, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUserByStripeID, customerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.readUser(rows)
|
||||
}
|
||||
|
||||
// Users returns a list of users
|
||||
func (s *commonStore) Users() ([]*User, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUsernames)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
usernames := make([]string, 0)
|
||||
for rows.Next() {
|
||||
var username string
|
||||
if err := rows.Scan(&username); err != nil {
|
||||
return nil, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
usernames = append(usernames, username)
|
||||
}
|
||||
rows.Close()
|
||||
users := make([]*User, 0)
|
||||
for _, username := range usernames {
|
||||
user, err := s.User(username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
users = append(users, user)
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// UsersCount returns the number of users in the database
|
||||
func (s *commonStore) UsersCount() (int64, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUserCount)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return 0, errNoRows
|
||||
}
|
||||
var count int64
|
||||
if err := rows.Scan(&count); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// AddUser adds a user with the given username, password hash and role
|
||||
func (s *commonStore) AddUser(username, hash string, role Role, provisioned bool) error {
|
||||
if !AllowedUsername(username) || !AllowedRole(role) {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
|
||||
syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
|
||||
now := time.Now().Unix()
|
||||
if _, err := s.db.Exec(s.queries.insertUser, userID, username, hash, string(role), syncTopic, provisioned, now); err != nil {
|
||||
if isUniqueConstraintError(err) {
|
||||
return ErrUserExists
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveUser deletes the user with the given username
|
||||
func (s *commonStore) RemoveUser(username string) error {
|
||||
if !AllowedUsername(username) {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
// Rows in user_access, user_token, etc. are deleted via foreign keys
|
||||
if _, err := s.db.Exec(s.queries.deleteUser, username); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkUserRemoved sets the deleted flag on the user, and deletes all access tokens
|
||||
func (s *commonStore) MarkUserRemoved(userID string) error {
|
||||
tx, err := s.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
// Get username for deleteUserAccess query
|
||||
user, err := s.UserByID(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(s.queries.deleteUserAccess, user.Name, user.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(s.queries.deleteAllToken, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
deletedTime := time.Now().Add(userHardDeleteAfterDuration).Unix()
|
||||
if _, err := tx.Exec(s.queries.updateUserDeleted, deletedTime, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// RemoveDeletedUsers deletes all users that have been marked deleted
|
||||
func (s *commonStore) RemoveDeletedUsers() error {
|
||||
if _, err := s.db.Exec(s.queries.deleteUsersMarked, time.Now().Unix()); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ChangePassword changes a user's password
|
||||
func (s *commonStore) ChangePassword(username, hash string) error {
|
||||
if _, err := s.db.Exec(s.queries.updateUserPass, hash, username); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ChangeRole changes a user's role
|
||||
func (s *commonStore) ChangeRole(username string, role Role) error {
|
||||
if !AllowedUsername(username) || !AllowedRole(role) {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
tx, err := s.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(s.queries.updateUserRole, string(role), username); err != nil {
|
||||
return err
|
||||
}
|
||||
// If changing to admin, remove all access entries
|
||||
if role == RoleAdmin {
|
||||
if _, err := tx.Exec(s.queries.deleteUserAccess, username, username); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// ChangeProvisioned changes the provisioned status of a user
|
||||
func (s *commonStore) ChangeProvisioned(username string, provisioned bool) error {
|
||||
if _, err := s.db.Exec(s.queries.updateUserProvisioned, provisioned, username); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ChangeSettings persists the user settings
|
||||
func (s *commonStore) ChangeSettings(userID string, prefs *Prefs) error {
|
||||
b, err := json.Marshal(prefs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := s.db.Exec(s.queries.updateUserPrefs, string(b), userID); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ChangeTier changes a user's tier using the tier code
|
||||
func (s *commonStore) ChangeTier(username, tierCode string) error {
|
||||
if _, err := s.db.Exec(s.queries.updateUserTier, tierCode, username); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetTier removes the tier from the given user
|
||||
func (s *commonStore) ResetTier(username string) error {
|
||||
if !AllowedUsername(username) && username != Everyone && username != "" {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
_, err := s.db.Exec(s.queries.deleteUserTier, username)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateStats updates the user statistics
|
||||
func (s *commonStore) UpdateStats(userID string, stats *Stats) error {
|
||||
if _, err := s.db.Exec(s.queries.updateUserStats, stats.Messages, stats.Emails, stats.Calls, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetStats resets all user stats in the user database
|
||||
func (s *commonStore) ResetStats() error {
|
||||
if _, err := s.db.Exec(s.queries.updateUserStatsResetAll); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (s *commonStore) readUser(rows *sql.Rows) (*User, error) {
|
||||
defer rows.Close()
|
||||
var id, username, hash, role, prefs, syncTopic string
|
||||
var provisioned bool
|
||||
var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripeSubscriptionInterval, stripeMonthlyPriceID, stripeYearlyPriceID, tierID, tierCode, tierName sql.NullString
|
||||
var messages, emails, calls int64
|
||||
var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
|
||||
if !rows.Next() {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &provisioned, &messages, &emails, &calls, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionInterval, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
|
||||
return nil, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
user := &User{
|
||||
ID: id,
|
||||
Name: username,
|
||||
Hash: hash,
|
||||
Role: Role(role),
|
||||
Prefs: &Prefs{},
|
||||
SyncTopic: syncTopic,
|
||||
Provisioned: provisioned,
|
||||
Stats: &Stats{
|
||||
Messages: messages,
|
||||
Emails: emails,
|
||||
Calls: calls,
|
||||
},
|
||||
Billing: &Billing{
|
||||
StripeCustomerID: stripeCustomerID.String,
|
||||
StripeSubscriptionID: stripeSubscriptionID.String,
|
||||
StripeSubscriptionStatus: payments.SubscriptionStatus(stripeSubscriptionStatus.String),
|
||||
StripeSubscriptionInterval: payments.PriceRecurringInterval(stripeSubscriptionInterval.String),
|
||||
StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0),
|
||||
StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0),
|
||||
},
|
||||
Deleted: deleted.Valid,
|
||||
}
|
||||
if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tierCode.Valid {
|
||||
user.Tier = &Tier{
|
||||
ID: tierID.String,
|
||||
Code: tierCode.String,
|
||||
Name: tierName.String,
|
||||
MessageLimit: messagesLimit.Int64,
|
||||
MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
|
||||
EmailLimit: emailsLimit.Int64,
|
||||
CallLimit: callsLimit.Int64,
|
||||
ReservationLimit: reservationsLimit.Int64,
|
||||
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
|
||||
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
|
||||
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
|
||||
AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
|
||||
StripeMonthlyPriceID: stripeMonthlyPriceID.String,
|
||||
StripeYearlyPriceID: stripeYearlyPriceID.String,
|
||||
}
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// CreateToken creates a new token
|
||||
func (s *commonStore) CreateToken(userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, provisioned bool) (*Token, error) {
|
||||
if _, err := s.db.Exec(s.queries.upsertToken, userID, token, label, lastAccess.Unix(), lastOrigin.String(), expires.Unix(), provisioned); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Token{
|
||||
Value: token,
|
||||
Label: label,
|
||||
LastAccess: lastAccess,
|
||||
LastOrigin: lastOrigin,
|
||||
Expires: expires,
|
||||
Provisioned: provisioned,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Token returns a specific token for a user
|
||||
func (s *commonStore) Token(userID, token string) (*Token, error) {
|
||||
rows, err := s.db.Query(s.queries.selectToken, userID, token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return s.readToken(rows)
|
||||
}
|
||||
|
||||
// Tokens returns all existing tokens for the user with the given user ID
|
||||
func (s *commonStore) Tokens(userID string) ([]*Token, error) {
|
||||
rows, err := s.db.Query(s.queries.selectTokens, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
tokens := make([]*Token, 0)
|
||||
for {
|
||||
token, err := s.readToken(rows)
|
||||
if errors.Is(err, ErrTokenNotFound) {
|
||||
break
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokens = append(tokens, token)
|
||||
}
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// AllProvisionedTokens returns all provisioned tokens
|
||||
func (s *commonStore) AllProvisionedTokens() ([]*Token, error) {
|
||||
rows, err := s.db.Query(s.queries.selectAllProvisionedTokens)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
tokens := make([]*Token, 0)
|
||||
for {
|
||||
token, err := s.readToken(rows)
|
||||
if errors.Is(err, ErrTokenNotFound) {
|
||||
break
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokens = append(tokens, token)
|
||||
}
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// ChangeTokenLabel updates a token's label
|
||||
func (s *commonStore) ChangeTokenLabel(userID, token, label string) error {
|
||||
if _, err := s.db.Exec(s.queries.updateTokenLabel, label, userID, token); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ChangeTokenExpiry updates a token's expiry time
|
||||
func (s *commonStore) ChangeTokenExpiry(userID, token string, expires time.Time) error {
|
||||
if _, err := s.db.Exec(s.queries.updateTokenExpiry, expires.Unix(), userID, token); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateTokenLastAccess updates a token's last access time and origin
|
||||
func (s *commonStore) UpdateTokenLastAccess(token string, lastAccess time.Time, lastOrigin netip.Addr) error {
|
||||
if _, err := s.db.Exec(s.queries.updateTokenLastAccess, lastAccess.Unix(), lastOrigin.String(), token); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveToken deletes the token
|
||||
func (s *commonStore) RemoveToken(userID, token string) error {
|
||||
if token == "" {
|
||||
return errNoTokenProvided
|
||||
}
|
||||
if _, err := s.db.Exec(s.queries.deleteToken, userID, token); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveExpiredTokens deletes all expired tokens from the database
|
||||
func (s *commonStore) RemoveExpiredTokens() error {
|
||||
if _, err := s.db.Exec(s.queries.deleteExpiredTokens, time.Now().Unix()); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TokenCount returns the number of tokens for a user
|
||||
func (s *commonStore) TokenCount(userID string) (int, error) {
|
||||
rows, err := s.db.Query(s.queries.selectTokenCount, userID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return 0, errNoRows
|
||||
}
|
||||
var count int
|
||||
if err := rows.Scan(&count); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// RemoveExcessTokens deletes excess tokens beyond the specified maximum
|
||||
func (s *commonStore) RemoveExcessTokens(userID string, maxCount int) error {
|
||||
if _, err := s.db.Exec(s.queries.deleteExcessTokens, userID, userID, maxCount); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (s *commonStore) readToken(rows *sql.Rows) (*Token, error) {
|
||||
var token, label, lastOrigin string
|
||||
var lastAccess, expires int64
|
||||
var provisioned bool
|
||||
if !rows.Next() {
|
||||
return nil, ErrTokenNotFound
|
||||
}
|
||||
if err := rows.Scan(&token, &label, &lastAccess, &lastOrigin, &expires, &provisioned); err != nil {
|
||||
return nil, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lastOriginIP, err := netip.ParseAddr(lastOrigin)
|
||||
if err != nil {
|
||||
lastOriginIP = netip.IPv4Unspecified()
|
||||
}
|
||||
return &Token{
|
||||
Value: token,
|
||||
Label: label,
|
||||
LastAccess: time.Unix(lastAccess, 0),
|
||||
LastOrigin: lastOriginIP,
|
||||
Expires: time.Unix(expires, 0),
|
||||
Provisioned: provisioned,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AuthorizeTopicAccess returns the read/write permissions for the given username and topic.
|
||||
// The found return value indicates whether an ACL entry was found at all.
|
||||
func (s *commonStore) AuthorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error) {
|
||||
rows, err := s.db.Query(s.queries.selectTopicPerms, Everyone, usernameOrEveryone, topic)
|
||||
if err != nil {
|
||||
return false, false, false, err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return false, false, false, nil
|
||||
}
|
||||
if err := rows.Scan(&read, &write); err != nil {
|
||||
return false, false, false, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return false, false, false, err
|
||||
}
|
||||
return read, write, true, nil
|
||||
}
|
||||
|
||||
// AllGrants returns all user-specific access control entries, mapped to their respective user IDs
|
||||
func (s *commonStore) AllGrants() (map[string][]Grant, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUserAllAccess)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
grants := make(map[string][]Grant, 0)
|
||||
for rows.Next() {
|
||||
var userID, topic string
|
||||
var read, write, provisioned bool
|
||||
if err := rows.Scan(&userID, &topic, &read, &write, &provisioned); err != nil {
|
||||
return nil, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, ok := grants[userID]; !ok {
|
||||
grants[userID] = make([]Grant, 0)
|
||||
}
|
||||
grants[userID] = append(grants[userID], Grant{
|
||||
TopicPattern: fromSQLWildcard(topic),
|
||||
Permission: NewPermission(read, write),
|
||||
Provisioned: provisioned,
|
||||
})
|
||||
}
|
||||
return grants, nil
|
||||
}
|
||||
|
||||
// Grants returns all user-specific access control entries
|
||||
func (s *commonStore) Grants(username string) ([]Grant, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUserAccess, username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
grants := make([]Grant, 0)
|
||||
for rows.Next() {
|
||||
var topic string
|
||||
var read, write, provisioned bool
|
||||
if err := rows.Scan(&topic, &read, &write, &provisioned); err != nil {
|
||||
return nil, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
grants = append(grants, Grant{
|
||||
TopicPattern: fromSQLWildcard(topic),
|
||||
Permission: NewPermission(read, write),
|
||||
Provisioned: provisioned,
|
||||
})
|
||||
}
|
||||
return grants, nil
|
||||
}
|
||||
|
||||
// AllowAccess adds or updates an entry in the access control list
|
||||
func (s *commonStore) AllowAccess(username, topicPattern string, read, write bool, ownerUsername string, provisioned bool) error {
|
||||
if !AllowedUsername(username) && username != Everyone {
|
||||
return ErrInvalidArgument
|
||||
} else if !AllowedTopicPattern(topicPattern) {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
if _, err := s.db.Exec(s.queries.upsertUserAccess, username, toSQLWildcard(topicPattern), read, write, ownerUsername, ownerUsername, provisioned); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetAccess removes an access control list entry
|
||||
func (s *commonStore) ResetAccess(username, topicPattern string) error {
|
||||
if !AllowedUsername(username) && username != Everyone && username != "" {
|
||||
return ErrInvalidArgument
|
||||
} else if !AllowedTopicPattern(topicPattern) && topicPattern != "" {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
if username == "" && topicPattern == "" {
|
||||
_, err := s.db.Exec(s.queries.deleteAllAccess)
|
||||
return err
|
||||
} else if topicPattern == "" {
|
||||
_, err := s.db.Exec(s.queries.deleteUserAccess, username, username)
|
||||
return err
|
||||
}
|
||||
_, err := s.db.Exec(s.queries.deleteTopicAccess, username, username, toSQLWildcard(topicPattern))
|
||||
return err
|
||||
}
|
||||
|
||||
// ResetAllProvisionedAccess removes all provisioned access control entries
|
||||
func (s *commonStore) ResetAllProvisionedAccess() error {
|
||||
if _, err := s.db.Exec(s.queries.deleteUserAccessProvisioned); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reservations returns all user-owned topics, and the associated everyone-access
|
||||
func (s *commonStore) Reservations(username string) ([]Reservation, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUserReservations, Everyone, username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
reservations := make([]Reservation, 0)
|
||||
for rows.Next() {
|
||||
var topic string
|
||||
var ownerRead, ownerWrite bool
|
||||
var everyoneRead, everyoneWrite sql.NullBool
|
||||
if err := rows.Scan(&topic, &ownerRead, &ownerWrite, &everyoneRead, &everyoneWrite); err != nil {
|
||||
return nil, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reservations = append(reservations, Reservation{
|
||||
Topic: unescapeUnderscore(topic),
|
||||
Owner: NewPermission(ownerRead, ownerWrite),
|
||||
Everyone: NewPermission(everyoneRead.Bool, everyoneWrite.Bool),
|
||||
})
|
||||
}
|
||||
return reservations, nil
|
||||
}
|
||||
|
||||
// HasReservation returns true if the given topic access is owned by the user
|
||||
func (s *commonStore) HasReservation(username, topic string) (bool, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUserHasReservation, username, escapeUnderscore(topic))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return false, errNoRows
|
||||
}
|
||||
var count int64
|
||||
if err := rows.Scan(&count); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// ReservationsCount returns the number of reservations owned by this user
|
||||
func (s *commonStore) ReservationsCount(username string) (int64, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUserReservationsCount, username)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return 0, errNoRows
|
||||
}
|
||||
var count int64
|
||||
if err := rows.Scan(&count); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// ReservationOwner returns user ID of the user that owns this topic, or an empty string if it's not owned by anyone
|
||||
func (s *commonStore) ReservationOwner(topic string) (string, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUserReservationsOwner, escapeUnderscore(topic))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return "", nil
|
||||
}
|
||||
var ownerUserID string
|
||||
if err := rows.Scan(&ownerUserID); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return ownerUserID, nil
|
||||
}
|
||||
|
||||
// OtherAccessCount returns the number of access entries for the given topic that are not owned by the user
|
||||
func (s *commonStore) OtherAccessCount(username, topic string) (int, error) {
|
||||
rows, err := s.db.Query(s.queries.selectOtherAccessCount, escapeUnderscore(topic), escapeUnderscore(topic), username)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return 0, errNoRows
|
||||
}
|
||||
var count int
|
||||
if err := rows.Scan(&count); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// AddTier creates a new tier in the database
|
||||
func (s *commonStore) AddTier(tier *Tier) error {
|
||||
if tier.ID == "" {
|
||||
tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength)
|
||||
}
|
||||
if _, err := s.db.Exec(s.queries.insertTier, tier.ID, tier.Code, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.CallLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateTier updates a tier's properties in the database
|
||||
func (s *commonStore) UpdateTier(tier *Tier) error {
|
||||
if _, err := s.db.Exec(s.queries.updateTier, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.CallLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID), tier.Code); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveTier deletes the tier with the given code
|
||||
func (s *commonStore) RemoveTier(code string) error {
|
||||
if !AllowedTier(code) {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
if _, err := s.db.Exec(s.queries.deleteTier, code); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Tiers returns a list of all Tier structs
|
||||
func (s *commonStore) Tiers() ([]*Tier, error) {
|
||||
rows, err := s.db.Query(s.queries.selectTiers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
tiers := make([]*Tier, 0)
|
||||
for {
|
||||
tier, err := s.readTier(rows)
|
||||
if errors.Is(err, ErrTierNotFound) {
|
||||
break
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tiers = append(tiers, tier)
|
||||
}
|
||||
return tiers, nil
|
||||
}
|
||||
|
||||
// Tier returns a Tier based on the code, or ErrTierNotFound if it does not exist
|
||||
func (s *commonStore) Tier(code string) (*Tier, error) {
|
||||
rows, err := s.db.Query(s.queries.selectTierByCode, code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return s.readTier(rows)
|
||||
}
|
||||
|
||||
// TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist
|
||||
func (s *commonStore) TierByStripePrice(priceID string) (*Tier, error) {
|
||||
rows, err := s.db.Query(s.queries.selectTierByPriceID, priceID, priceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return s.readTier(rows)
|
||||
}
|
||||
func (s *commonStore) readTier(rows *sql.Rows) (*Tier, error) {
|
||||
var id, code, name string
|
||||
var stripeMonthlyPriceID, stripeYearlyPriceID sql.NullString
|
||||
var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit sql.NullInt64
|
||||
if !rows.Next() {
|
||||
return nil, ErrTierNotFound
|
||||
}
|
||||
if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
|
||||
return nil, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Tier{
|
||||
ID: id,
|
||||
Code: code,
|
||||
Name: name,
|
||||
MessageLimit: messagesLimit.Int64,
|
||||
MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
|
||||
EmailLimit: emailsLimit.Int64,
|
||||
CallLimit: callsLimit.Int64,
|
||||
ReservationLimit: reservationsLimit.Int64,
|
||||
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
|
||||
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
|
||||
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
|
||||
AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
|
||||
StripeMonthlyPriceID: stripeMonthlyPriceID.String,
|
||||
StripeYearlyPriceID: stripeYearlyPriceID.String,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// PhoneNumbers returns all phone numbers for the user with the given user ID
|
||||
func (s *commonStore) PhoneNumbers(userID string) ([]string, error) {
|
||||
rows, err := s.db.Query(s.queries.selectPhoneNumbers, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
phoneNumbers := make([]string, 0)
|
||||
for {
|
||||
phoneNumber, err := s.readPhoneNumber(rows)
|
||||
if errors.Is(err, ErrPhoneNumberNotFound) {
|
||||
break
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
phoneNumbers = append(phoneNumbers, phoneNumber)
|
||||
}
|
||||
return phoneNumbers, nil
|
||||
}
|
||||
|
||||
// AddPhoneNumber adds a phone number to the user with the given user ID
|
||||
func (s *commonStore) AddPhoneNumber(userID, phoneNumber string) error {
|
||||
if _, err := s.db.Exec(s.queries.insertPhoneNumber, userID, phoneNumber); err != nil {
|
||||
if isUniqueConstraintError(err) {
|
||||
return ErrPhoneNumberExists
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemovePhoneNumber deletes a phone number from the user with the given user ID
|
||||
func (s *commonStore) RemovePhoneNumber(userID, phoneNumber string) error {
|
||||
_, err := s.db.Exec(s.queries.deletePhoneNumber, userID, phoneNumber)
|
||||
return err
|
||||
}
|
||||
func (s *commonStore) readPhoneNumber(rows *sql.Rows) (string, error) {
|
||||
var phoneNumber string
|
||||
if !rows.Next() {
|
||||
return "", ErrPhoneNumberNotFound
|
||||
}
|
||||
if err := rows.Scan(&phoneNumber); err != nil {
|
||||
return "", err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return phoneNumber, nil
|
||||
}
|
||||
|
||||
// ChangeBilling updates a user's billing fields
|
||||
func (s *commonStore) ChangeBilling(username string, billing *Billing) error {
|
||||
if _, err := s.db.Exec(s.queries.updateBilling, nullString(billing.StripeCustomerID), nullString(billing.StripeSubscriptionID), nullString(string(billing.StripeSubscriptionStatus)), nullString(string(billing.StripeSubscriptionInterval)), nullInt64(billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(billing.StripeSubscriptionCancelAt.Unix()), username); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UserIDByUsername returns the user ID for the given username
|
||||
func (s *commonStore) UserIDByUsername(username string) (string, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUserIDFromUsername, username)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return "", ErrUserNotFound
|
||||
}
|
||||
var userID string
|
||||
if err := rows.Scan(&userID); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
// Close closes the underlying database
|
||||
func (s *commonStore) Close() error {
|
||||
return s.db.Close()
|
||||
}
|
||||
|
||||
// isUniqueConstraintError checks if the error is a unique constraint violation for both SQLite and PostgreSQL
|
||||
func isUniqueConstraintError(err error) bool {
|
||||
errStr := err.Error()
|
||||
return strings.Contains(errStr, "UNIQUE constraint failed") || strings.Contains(errStr, "23505")
|
||||
}
|
||||
292
user/store_postgres.go
Normal file
292
user/store_postgres.go
Normal file
@@ -0,0 +1,292 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
|
||||
)
|
||||
|
||||
// PostgreSQL queries
|
||||
const (
|
||||
// User queries
|
||||
postgresSelectUserByID = `
|
||||
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||
FROM "user" u
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE u.id = $1
|
||||
`
|
||||
postgresSelectUserByName = `
|
||||
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||
FROM "user" u
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE user_name = $1
|
||||
`
|
||||
postgresSelectUserByToken = `
|
||||
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||
FROM "user" u
|
||||
JOIN user_token tk on u.id = tk.user_id
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE tk.token = $1 AND (tk.expires = 0 OR tk.expires >= $2)
|
||||
`
|
||||
postgresSelectUserByStripeID = `
|
||||
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||
FROM "user" u
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE u.stripe_customer_id = $1
|
||||
`
|
||||
postgresSelectUsernames = `
|
||||
SELECT user_name
|
||||
FROM "user"
|
||||
ORDER BY
|
||||
CASE role
|
||||
WHEN 'admin' THEN 1
|
||||
WHEN 'anonymous' THEN 3
|
||||
ELSE 2
|
||||
END, user_name
|
||||
`
|
||||
postgresSelectUserCount = `SELECT COUNT(*) FROM "user"`
|
||||
postgresSelectUserIDFromUsername = `SELECT id FROM "user" WHERE user_name = $1`
|
||||
postgresInsertUser = `INSERT INTO "user" (id, user_name, pass, role, sync_topic, provisioned, created) VALUES ($1, $2, $3, $4, $5, $6, $7)`
|
||||
postgresUpdateUserPass = `UPDATE "user" SET pass = $1 WHERE user_name = $2`
|
||||
postgresUpdateUserRole = `UPDATE "user" SET role = $1 WHERE user_name = $2`
|
||||
postgresUpdateUserProvisioned = `UPDATE "user" SET provisioned = $1 WHERE user_name = $2`
|
||||
postgresUpdateUserPrefs = `UPDATE "user" SET prefs = $1 WHERE id = $2`
|
||||
postgresUpdateUserStats = `UPDATE "user" SET stats_messages = $1, stats_emails = $2, stats_calls = $3 WHERE id = $4`
|
||||
postgresUpdateUserStatsResetAll = `UPDATE "user" SET stats_messages = 0, stats_emails = 0, stats_calls = 0`
|
||||
postgresUpdateUserTier = `UPDATE "user" SET tier_id = (SELECT id FROM tier WHERE code = $1) WHERE user_name = $2`
|
||||
postgresUpdateUserDeleted = `UPDATE "user" SET deleted = $1 WHERE id = $2`
|
||||
postgresDeleteUser = `DELETE FROM "user" WHERE user_name = $1`
|
||||
postgresDeleteUserTier = `UPDATE "user" SET tier_id = null WHERE user_name = $1`
|
||||
postgresDeleteUsersMarked = `DELETE FROM "user" WHERE deleted < $1`
|
||||
|
||||
// Access queries
|
||||
postgresSelectTopicPerms = `
|
||||
SELECT read, write
|
||||
FROM user_access a
|
||||
JOIN "user" u ON u.id = a.user_id
|
||||
WHERE (u.user_name = $1 OR u.user_name = $2) AND $3 LIKE a.topic ESCAPE '\'
|
||||
ORDER BY u.user_name DESC, LENGTH(a.topic) DESC, CASE WHEN a.write THEN 1 ELSE 0 END DESC
|
||||
`
|
||||
postgresSelectUserAllAccess = `
|
||||
SELECT user_id, topic, read, write, provisioned
|
||||
FROM user_access
|
||||
ORDER BY LENGTH(topic) DESC, CASE WHEN write THEN 1 ELSE 0 END DESC, CASE WHEN read THEN 1 ELSE 0 END DESC, topic
|
||||
`
|
||||
postgresSelectUserAccess = `
|
||||
SELECT topic, read, write, provisioned
|
||||
FROM user_access
|
||||
WHERE user_id = (SELECT id FROM "user" WHERE user_name = $1)
|
||||
ORDER BY LENGTH(topic) DESC, CASE WHEN write THEN 1 ELSE 0 END DESC, CASE WHEN read THEN 1 ELSE 0 END DESC, topic
|
||||
`
|
||||
postgresSelectUserReservations = `
|
||||
SELECT a_user.topic, a_user.read, a_user.write, a_everyone.read AS everyone_read, a_everyone.write AS everyone_write
|
||||
FROM user_access a_user
|
||||
LEFT JOIN user_access a_everyone ON a_user.topic = a_everyone.topic AND a_everyone.user_id = (SELECT id FROM "user" WHERE user_name = $1)
|
||||
WHERE a_user.user_id = a_user.owner_user_id
|
||||
AND a_user.owner_user_id = (SELECT id FROM "user" WHERE user_name = $2)
|
||||
ORDER BY a_user.topic
|
||||
`
|
||||
postgresSelectUserReservationsCount = `
|
||||
SELECT COUNT(*)
|
||||
FROM user_access
|
||||
WHERE user_id = owner_user_id
|
||||
AND owner_user_id = (SELECT id FROM "user" WHERE user_name = $1)
|
||||
`
|
||||
postgresSelectUserReservationsOwner = `
|
||||
SELECT owner_user_id
|
||||
FROM user_access
|
||||
WHERE topic = $1
|
||||
AND user_id = owner_user_id
|
||||
`
|
||||
postgresSelectUserHasReservation = `
|
||||
SELECT COUNT(*)
|
||||
FROM user_access
|
||||
WHERE user_id = owner_user_id
|
||||
AND owner_user_id = (SELECT id FROM "user" WHERE user_name = $1)
|
||||
AND topic = $2
|
||||
`
|
||||
postgresSelectOtherAccessCount = `
|
||||
SELECT COUNT(*)
|
||||
FROM user_access
|
||||
WHERE (topic = $1 OR $2 LIKE topic ESCAPE '\')
|
||||
AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM "user" WHERE user_name = $3))
|
||||
`
|
||||
postgresUpsertUserAccess = `
|
||||
INSERT INTO user_access (user_id, topic, read, write, owner_user_id, provisioned)
|
||||
VALUES (
|
||||
(SELECT id FROM "user" WHERE user_name = $1),
|
||||
$2,
|
||||
$3,
|
||||
$4,
|
||||
CASE WHEN $5 = '' THEN NULL ELSE (SELECT id FROM "user" WHERE user_name = $6) END,
|
||||
$7
|
||||
)
|
||||
ON CONFLICT (user_id, topic)
|
||||
DO UPDATE SET read=EXCLUDED.read, write=EXCLUDED.write, owner_user_id=EXCLUDED.owner_user_id, provisioned=EXCLUDED.provisioned
|
||||
`
|
||||
postgresDeleteUserAccess = `
|
||||
DELETE FROM user_access
|
||||
WHERE user_id = (SELECT id FROM "user" WHERE user_name = $1)
|
||||
OR owner_user_id = (SELECT id FROM "user" WHERE user_name = $2)
|
||||
`
|
||||
postgresDeleteUserAccessProvisioned = `DELETE FROM user_access WHERE provisioned = true`
|
||||
postgresDeleteTopicAccess = `
|
||||
DELETE FROM user_access
|
||||
WHERE (user_id = (SELECT id FROM "user" WHERE user_name = $1) OR owner_user_id = (SELECT id FROM "user" WHERE user_name = $2))
|
||||
AND topic = $3
|
||||
`
|
||||
postgresDeleteAllAccess = `DELETE FROM user_access`
|
||||
|
||||
// Token queries
|
||||
postgresSelectToken = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = $1 AND token = $2`
|
||||
postgresSelectTokens = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = $1`
|
||||
postgresSelectTokenCount = `SELECT COUNT(*) FROM user_token WHERE user_id = $1`
|
||||
postgresSelectAllProvisionedTokens = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE provisioned = true`
|
||||
postgresUpsertToken = `
|
||||
INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires, provisioned)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
ON CONFLICT (user_id, token)
|
||||
DO UPDATE SET label = EXCLUDED.label, expires = EXCLUDED.expires, provisioned = EXCLUDED.provisioned
|
||||
`
|
||||
postgresUpdateTokenLabel = `UPDATE user_token SET label = $1 WHERE user_id = $2 AND token = $3`
|
||||
postgresUpdateTokenExpiry = `UPDATE user_token SET expires = $1 WHERE user_id = $2 AND token = $3`
|
||||
postgresUpdateTokenLastAccess = `UPDATE user_token SET last_access = $1, last_origin = $2 WHERE token = $3`
|
||||
postgresDeleteToken = `DELETE FROM user_token WHERE user_id = $1 AND token = $2`
|
||||
postgresDeleteProvisionedToken = `DELETE FROM user_token WHERE token = $1`
|
||||
postgresDeleteAllToken = `DELETE FROM user_token WHERE user_id = $1`
|
||||
postgresDeleteExpiredTokens = `DELETE FROM user_token WHERE expires > 0 AND expires < $1`
|
||||
postgresDeleteExcessTokens = `
|
||||
DELETE FROM user_token
|
||||
WHERE user_id = $1
|
||||
AND (user_id, token) NOT IN (
|
||||
SELECT user_id, token
|
||||
FROM user_token
|
||||
WHERE user_id = $2
|
||||
ORDER BY expires DESC
|
||||
LIMIT $3
|
||||
)
|
||||
`
|
||||
|
||||
// Tier queries
|
||||
postgresInsertTier = `
|
||||
INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
|
||||
`
|
||||
postgresUpdateTier = `
|
||||
UPDATE tier
|
||||
SET name = $1, messages_limit = $2, messages_expiry_duration = $3, emails_limit = $4, calls_limit = $5, reservations_limit = $6, attachment_file_size_limit = $7, attachment_total_size_limit = $8, attachment_expiry_duration = $9, attachment_bandwidth_limit = $10, stripe_monthly_price_id = $11, stripe_yearly_price_id = $12
|
||||
WHERE code = $13
|
||||
`
|
||||
postgresSelectTiers = `
|
||||
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
|
||||
FROM tier
|
||||
`
|
||||
postgresSelectTierByCode = `
|
||||
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
|
||||
FROM tier
|
||||
WHERE code = $1
|
||||
`
|
||||
postgresSelectTierByPriceID = `
|
||||
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
|
||||
FROM tier
|
||||
WHERE (stripe_monthly_price_id = $1 OR stripe_yearly_price_id = $2)
|
||||
`
|
||||
postgresDeleteTier = `DELETE FROM tier WHERE code = $1`
|
||||
|
||||
// Phone queries
|
||||
postgresSelectPhoneNumbers = `SELECT phone_number FROM user_phone WHERE user_id = $1`
|
||||
postgresInsertPhoneNumber = `INSERT INTO user_phone (user_id, phone_number) VALUES ($1, $2)`
|
||||
postgresDeletePhoneNumber = `DELETE FROM user_phone WHERE user_id = $1 AND phone_number = $2`
|
||||
|
||||
// Billing queries
|
||||
postgresUpdateBilling = `
|
||||
UPDATE "user"
|
||||
SET stripe_customer_id = $1, stripe_subscription_id = $2, stripe_subscription_status = $3, stripe_subscription_interval = $4, stripe_subscription_paid_until = $5, stripe_subscription_cancel_at = $6
|
||||
WHERE user_name = $7
|
||||
`
|
||||
)
|
||||
|
||||
// NewPostgresStore creates a new PostgreSQL-backed user store
|
||||
func NewPostgresStore(dsn string) (Store, error) {
|
||||
db, err := sql.Open("pgx", dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := db.Ping(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := setupPostgres(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &commonStore{
|
||||
db: db,
|
||||
queries: storeQueries{
|
||||
// User queries
|
||||
selectUserByID: postgresSelectUserByID,
|
||||
selectUserByName: postgresSelectUserByName,
|
||||
selectUserByToken: postgresSelectUserByToken,
|
||||
selectUserByStripeID: postgresSelectUserByStripeID,
|
||||
selectUsernames: postgresSelectUsernames,
|
||||
selectUserCount: postgresSelectUserCount,
|
||||
selectUserIDFromUsername: postgresSelectUserIDFromUsername,
|
||||
insertUser: postgresInsertUser,
|
||||
updateUserPass: postgresUpdateUserPass,
|
||||
updateUserRole: postgresUpdateUserRole,
|
||||
updateUserProvisioned: postgresUpdateUserProvisioned,
|
||||
updateUserPrefs: postgresUpdateUserPrefs,
|
||||
updateUserStats: postgresUpdateUserStats,
|
||||
updateUserStatsResetAll: postgresUpdateUserStatsResetAll,
|
||||
updateUserTier: postgresUpdateUserTier,
|
||||
updateUserDeleted: postgresUpdateUserDeleted,
|
||||
deleteUser: postgresDeleteUser,
|
||||
deleteUserTier: postgresDeleteUserTier,
|
||||
deleteUsersMarked: postgresDeleteUsersMarked,
|
||||
|
||||
// Access queries
|
||||
selectTopicPerms: postgresSelectTopicPerms,
|
||||
selectUserAllAccess: postgresSelectUserAllAccess,
|
||||
selectUserAccess: postgresSelectUserAccess,
|
||||
selectUserReservations: postgresSelectUserReservations,
|
||||
selectUserReservationsCount: postgresSelectUserReservationsCount,
|
||||
selectUserReservationsOwner: postgresSelectUserReservationsOwner,
|
||||
selectUserHasReservation: postgresSelectUserHasReservation,
|
||||
selectOtherAccessCount: postgresSelectOtherAccessCount,
|
||||
upsertUserAccess: postgresUpsertUserAccess,
|
||||
deleteUserAccess: postgresDeleteUserAccess,
|
||||
deleteUserAccessProvisioned: postgresDeleteUserAccessProvisioned,
|
||||
deleteTopicAccess: postgresDeleteTopicAccess,
|
||||
deleteAllAccess: postgresDeleteAllAccess,
|
||||
|
||||
// Token queries
|
||||
selectToken: postgresSelectToken,
|
||||
selectTokens: postgresSelectTokens,
|
||||
selectTokenCount: postgresSelectTokenCount,
|
||||
selectAllProvisionedTokens: postgresSelectAllProvisionedTokens,
|
||||
upsertToken: postgresUpsertToken,
|
||||
updateTokenLabel: postgresUpdateTokenLabel,
|
||||
updateTokenExpiry: postgresUpdateTokenExpiry,
|
||||
updateTokenLastAccess: postgresUpdateTokenLastAccess,
|
||||
deleteToken: postgresDeleteToken,
|
||||
deleteProvisionedToken: postgresDeleteProvisionedToken,
|
||||
deleteAllToken: postgresDeleteAllToken,
|
||||
deleteExpiredTokens: postgresDeleteExpiredTokens,
|
||||
deleteExcessTokens: postgresDeleteExcessTokens,
|
||||
|
||||
// Tier queries
|
||||
insertTier: postgresInsertTier,
|
||||
selectTiers: postgresSelectTiers,
|
||||
selectTierByCode: postgresSelectTierByCode,
|
||||
selectTierByPriceID: postgresSelectTierByPriceID,
|
||||
updateTier: postgresUpdateTier,
|
||||
deleteTier: postgresDeleteTier,
|
||||
|
||||
// Phone queries
|
||||
selectPhoneNumbers: postgresSelectPhoneNumbers,
|
||||
insertPhoneNumber: postgresInsertPhoneNumber,
|
||||
deletePhoneNumber: postgresDeletePhoneNumber,
|
||||
|
||||
// Billing queries
|
||||
updateBilling: postgresUpdateBilling,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
113
user/store_postgres_schema.go
Normal file
113
user/store_postgres_schema.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Initial PostgreSQL schema
|
||||
const (
|
||||
postgresCreateTablesQueries = `
|
||||
CREATE TABLE IF NOT EXISTS tier (
|
||||
id TEXT PRIMARY KEY,
|
||||
code TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
messages_limit BIGINT NOT NULL,
|
||||
messages_expiry_duration BIGINT NOT NULL,
|
||||
emails_limit BIGINT NOT NULL,
|
||||
calls_limit BIGINT NOT NULL,
|
||||
reservations_limit BIGINT NOT NULL,
|
||||
attachment_file_size_limit BIGINT NOT NULL,
|
||||
attachment_total_size_limit BIGINT NOT NULL,
|
||||
attachment_expiry_duration BIGINT NOT NULL,
|
||||
attachment_bandwidth_limit BIGINT NOT NULL,
|
||||
stripe_monthly_price_id TEXT,
|
||||
stripe_yearly_price_id TEXT,
|
||||
UNIQUE(code),
|
||||
UNIQUE(stripe_monthly_price_id),
|
||||
UNIQUE(stripe_yearly_price_id)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS "user" (
|
||||
id TEXT PRIMARY KEY,
|
||||
tier_id TEXT REFERENCES tier(id),
|
||||
user_name TEXT NOT NULL UNIQUE,
|
||||
pass TEXT NOT NULL,
|
||||
role TEXT NOT NULL CHECK (role IN ('anonymous', 'admin', 'user')),
|
||||
prefs JSONB NOT NULL DEFAULT '{}',
|
||||
sync_topic TEXT NOT NULL,
|
||||
provisioned BOOLEAN NOT NULL,
|
||||
stats_messages BIGINT NOT NULL DEFAULT 0,
|
||||
stats_emails BIGINT NOT NULL DEFAULT 0,
|
||||
stats_calls BIGINT NOT NULL DEFAULT 0,
|
||||
stripe_customer_id TEXT UNIQUE,
|
||||
stripe_subscription_id TEXT UNIQUE,
|
||||
stripe_subscription_status TEXT,
|
||||
stripe_subscription_interval TEXT,
|
||||
stripe_subscription_paid_until BIGINT,
|
||||
stripe_subscription_cancel_at BIGINT,
|
||||
created BIGINT NOT NULL,
|
||||
deleted BIGINT
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS user_access (
|
||||
user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
|
||||
topic TEXT NOT NULL,
|
||||
read BOOLEAN NOT NULL,
|
||||
write BOOLEAN NOT NULL,
|
||||
owner_user_id TEXT REFERENCES "user"(id) ON DELETE CASCADE,
|
||||
provisioned BOOLEAN NOT NULL,
|
||||
PRIMARY KEY (user_id, topic)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS user_token (
|
||||
user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
|
||||
token TEXT NOT NULL UNIQUE,
|
||||
label TEXT NOT NULL,
|
||||
last_access BIGINT NOT NULL,
|
||||
last_origin TEXT NOT NULL,
|
||||
expires BIGINT NOT NULL,
|
||||
provisioned BOOLEAN NOT NULL,
|
||||
PRIMARY KEY (user_id, token)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS user_phone (
|
||||
user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
|
||||
phone_number TEXT NOT NULL,
|
||||
PRIMARY KEY (user_id, phone_number)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS schema_version (
|
||||
store TEXT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
INSERT INTO "user" (id, user_name, pass, role, sync_topic, provisioned, created)
|
||||
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, EXTRACT(EPOCH FROM NOW())::BIGINT)
|
||||
ON CONFLICT (id) DO NOTHING;
|
||||
`
|
||||
)
|
||||
|
||||
// Schema table management queries for Postgres
|
||||
const (
|
||||
postgresCurrentSchemaVersion = 6
|
||||
postgresSelectSchemaVersion = `SELECT version FROM schema_version WHERE store = 'user'`
|
||||
postgresInsertSchemaVersion = `INSERT INTO schema_version (store, version) VALUES ('user', $1)`
|
||||
)
|
||||
|
||||
func setupPostgres(db *sql.DB) error {
|
||||
var schemaVersion int
|
||||
err := db.QueryRow(postgresSelectSchemaVersion).Scan(&schemaVersion)
|
||||
if err != nil {
|
||||
return setupNewPostgres(db)
|
||||
}
|
||||
if schemaVersion > postgresCurrentSchemaVersion {
|
||||
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, postgresCurrentSchemaVersion)
|
||||
}
|
||||
// Note: PostgreSQL migrations will be added when needed
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupNewPostgres(db *sql.DB) error {
|
||||
if _, err := db.Exec(postgresCreateTablesQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(postgresInsertSchemaVersion, postgresCurrentSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
208
user/store_postgres_test.go
Normal file
208
user/store_postgres_test.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package user_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
func newTestPostgresStore(t *testing.T) user.Store {
|
||||
dsn := os.Getenv("NTFY_TEST_DATABASE_URL")
|
||||
if dsn == "" {
|
||||
t.Skip("NTFY_TEST_DATABASE_URL not set, skipping PostgreSQL tests")
|
||||
}
|
||||
// Create a unique schema for this test
|
||||
schema := fmt.Sprintf("test_%s", util.RandomString(10))
|
||||
setupDB, err := sql.Open("pgx", dsn)
|
||||
require.Nil(t, err)
|
||||
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, setupDB.Close())
|
||||
// Open store with search_path set to the new schema
|
||||
u, err := url.Parse(dsn)
|
||||
require.Nil(t, err)
|
||||
q := u.Query()
|
||||
q.Set("search_path", schema)
|
||||
u.RawQuery = q.Encode()
|
||||
store, err := user.NewPostgresStore(u.String())
|
||||
require.Nil(t, err)
|
||||
t.Cleanup(func() {
|
||||
store.Close()
|
||||
cleanDB, err := sql.Open("pgx", dsn)
|
||||
if err == nil {
|
||||
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
|
||||
cleanDB.Close()
|
||||
}
|
||||
})
|
||||
return store
|
||||
}
|
||||
|
||||
func TestPostgresStoreAddUser(t *testing.T) {
|
||||
testStoreAddUser(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreAddUserAlreadyExists(t *testing.T) {
|
||||
testStoreAddUserAlreadyExists(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreRemoveUser(t *testing.T) {
|
||||
testStoreRemoveUser(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreUserByID(t *testing.T) {
|
||||
testStoreUserByID(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreUserByToken(t *testing.T) {
|
||||
testStoreUserByToken(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreUserByStripeCustomer(t *testing.T) {
|
||||
testStoreUserByStripeCustomer(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreUsers(t *testing.T) {
|
||||
testStoreUsers(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreUsersCount(t *testing.T) {
|
||||
testStoreUsersCount(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreChangePassword(t *testing.T) {
|
||||
testStoreChangePassword(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreChangeRole(t *testing.T) {
|
||||
testStoreChangeRole(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreTokens(t *testing.T) {
|
||||
testStoreTokens(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreTokenChangeLabel(t *testing.T) {
|
||||
testStoreTokenChangeLabel(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreTokenRemove(t *testing.T) {
|
||||
testStoreTokenRemove(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreTokenRemoveExpired(t *testing.T) {
|
||||
testStoreTokenRemoveExpired(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreTokenRemoveExcess(t *testing.T) {
|
||||
testStoreTokenRemoveExcess(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreTokenUpdateLastAccess(t *testing.T) {
|
||||
testStoreTokenUpdateLastAccess(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreAllowAccess(t *testing.T) {
|
||||
testStoreAllowAccess(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreAllowAccessReadOnly(t *testing.T) {
|
||||
testStoreAllowAccessReadOnly(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreResetAccess(t *testing.T) {
|
||||
testStoreResetAccess(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreResetAccessAll(t *testing.T) {
|
||||
testStoreResetAccessAll(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreAuthorizeTopicAccess(t *testing.T) {
|
||||
testStoreAuthorizeTopicAccess(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreAuthorizeTopicAccessNotFound(t *testing.T) {
|
||||
testStoreAuthorizeTopicAccessNotFound(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreAuthorizeTopicAccessDenyAll(t *testing.T) {
|
||||
testStoreAuthorizeTopicAccessDenyAll(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreReservations(t *testing.T) {
|
||||
testStoreReservations(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreReservationsCount(t *testing.T) {
|
||||
testStoreReservationsCount(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreHasReservation(t *testing.T) {
|
||||
testStoreHasReservation(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreReservationOwner(t *testing.T) {
|
||||
testStoreReservationOwner(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreTiers(t *testing.T) {
|
||||
testStoreTiers(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreTierUpdate(t *testing.T) {
|
||||
testStoreTierUpdate(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreTierRemove(t *testing.T) {
|
||||
testStoreTierRemove(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreTierByStripePrice(t *testing.T) {
|
||||
testStoreTierByStripePrice(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreChangeTier(t *testing.T) {
|
||||
testStoreChangeTier(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStorePhoneNumbers(t *testing.T) {
|
||||
testStorePhoneNumbers(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreChangeSettings(t *testing.T) {
|
||||
testStoreChangeSettings(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreChangeBilling(t *testing.T) {
|
||||
testStoreChangeBilling(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreUpdateStats(t *testing.T) {
|
||||
testStoreUpdateStats(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreResetStats(t *testing.T) {
|
||||
testStoreResetStats(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreMarkUserRemoved(t *testing.T) {
|
||||
testStoreMarkUserRemoved(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreRemoveDeletedUsers(t *testing.T) {
|
||||
testStoreRemoveDeletedUsers(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreAllGrants(t *testing.T) {
|
||||
testStoreAllGrants(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStoreOtherAccessCount(t *testing.T) {
|
||||
testStoreOtherAccessCount(t, newTestPostgresStore(t))
|
||||
}
|
||||
273
user/store_sqlite.go
Normal file
273
user/store_sqlite.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
)
|
||||
|
||||
const (
|
||||
// User queries
|
||||
sqliteSelectUserByID = `
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||
FROM user u
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE u.id = ?
|
||||
`
|
||||
sqliteSelectUserByName = `
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||
FROM user u
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE user = ?
|
||||
`
|
||||
sqliteSelectUserByToken = `
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||
FROM user u
|
||||
JOIN user_token tk on u.id = tk.user_id
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE tk.token = ? AND (tk.expires = 0 OR tk.expires >= ?)
|
||||
`
|
||||
sqliteSelectUserByStripeID = `
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||
FROM user u
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE u.stripe_customer_id = ?
|
||||
`
|
||||
sqliteSelectUsernames = `
|
||||
SELECT user
|
||||
FROM user
|
||||
ORDER BY
|
||||
CASE role
|
||||
WHEN 'admin' THEN 1
|
||||
WHEN 'anonymous' THEN 3
|
||||
ELSE 2
|
||||
END, user
|
||||
`
|
||||
sqliteSelectUserCount = `SELECT COUNT(*) FROM user`
|
||||
sqliteSelectUserIDFromUsername = `SELECT id FROM user WHERE user = ?`
|
||||
sqliteInsertUser = `INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created) VALUES (?, ?, ?, ?, ?, ?, ?)`
|
||||
sqliteUpdateUserPass = `UPDATE user SET pass = ? WHERE user = ?`
|
||||
sqliteUpdateUserRole = `UPDATE user SET role = ? WHERE user = ?`
|
||||
sqliteUpdateUserProvisioned = `UPDATE user SET provisioned = ? WHERE user = ?`
|
||||
sqliteUpdateUserPrefs = `UPDATE user SET prefs = ? WHERE id = ?`
|
||||
sqliteUpdateUserStats = `UPDATE user SET stats_messages = ?, stats_emails = ?, stats_calls = ? WHERE id = ?`
|
||||
sqliteUpdateUserStatsResetAll = `UPDATE user SET stats_messages = 0, stats_emails = 0, stats_calls = 0`
|
||||
sqliteUpdateUserTier = `UPDATE user SET tier_id = (SELECT id FROM tier WHERE code = ?) WHERE user = ?`
|
||||
sqliteUpdateUserDeleted = `UPDATE user SET deleted = ? WHERE id = ?`
|
||||
sqliteDeleteUser = `DELETE FROM user WHERE user = ?`
|
||||
sqliteDeleteUserTier = `UPDATE user SET tier_id = null WHERE user = ?`
|
||||
sqliteDeleteUsersMarked = `DELETE FROM user WHERE deleted < ?`
|
||||
|
||||
// Access queries
|
||||
sqliteSelectTopicPerms = `
|
||||
SELECT read, write
|
||||
FROM user_access a
|
||||
JOIN user u ON u.id = a.user_id
|
||||
WHERE (u.user = ? OR u.user = ?) AND ? LIKE a.topic ESCAPE '\'
|
||||
ORDER BY u.user DESC, LENGTH(a.topic) DESC, a.write DESC
|
||||
`
|
||||
sqliteSelectUserAllAccess = `
|
||||
SELECT user_id, topic, read, write, provisioned
|
||||
FROM user_access
|
||||
ORDER BY LENGTH(topic) DESC, write DESC, read DESC, topic
|
||||
`
|
||||
sqliteSelectUserAccess = `
|
||||
SELECT topic, read, write, provisioned
|
||||
FROM user_access
|
||||
WHERE user_id = (SELECT id FROM user WHERE user = ?)
|
||||
ORDER BY LENGTH(topic) DESC, write DESC, read DESC, topic
|
||||
`
|
||||
sqliteSelectUserReservations = `
|
||||
SELECT a_user.topic, a_user.read, a_user.write, a_everyone.read AS everyone_read, a_everyone.write AS everyone_write
|
||||
FROM user_access a_user
|
||||
LEFT JOIN user_access a_everyone ON a_user.topic = a_everyone.topic AND a_everyone.user_id = (SELECT id FROM user WHERE user = ?)
|
||||
WHERE a_user.user_id = a_user.owner_user_id
|
||||
AND a_user.owner_user_id = (SELECT id FROM user WHERE user = ?)
|
||||
ORDER BY a_user.topic
|
||||
`
|
||||
sqliteSelectUserReservationsCount = `
|
||||
SELECT COUNT(*)
|
||||
FROM user_access
|
||||
WHERE user_id = owner_user_id
|
||||
AND owner_user_id = (SELECT id FROM user WHERE user = ?)
|
||||
`
|
||||
sqliteSelectUserReservationsOwner = `
|
||||
SELECT owner_user_id
|
||||
FROM user_access
|
||||
WHERE topic = ?
|
||||
AND user_id = owner_user_id
|
||||
`
|
||||
sqliteSelectUserHasReservation = `
|
||||
SELECT COUNT(*)
|
||||
FROM user_access
|
||||
WHERE user_id = owner_user_id
|
||||
AND owner_user_id = (SELECT id FROM user WHERE user = ?)
|
||||
AND topic = ?
|
||||
`
|
||||
sqliteSelectOtherAccessCount = `
|
||||
SELECT COUNT(*)
|
||||
FROM user_access
|
||||
WHERE (topic = ? OR ? LIKE topic ESCAPE '\')
|
||||
AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM user WHERE user = ?))
|
||||
`
|
||||
sqliteUpsertUserAccess = `
|
||||
INSERT INTO user_access (user_id, topic, read, write, owner_user_id, provisioned)
|
||||
VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?, (SELECT IIF(?='',NULL,(SELECT id FROM user WHERE user=?))), ?)
|
||||
ON CONFLICT (user_id, topic)
|
||||
DO UPDATE SET read=excluded.read, write=excluded.write, owner_user_id=excluded.owner_user_id, provisioned=excluded.provisioned
|
||||
`
|
||||
sqliteDeleteUserAccess = `
|
||||
DELETE FROM user_access
|
||||
WHERE user_id = (SELECT id FROM user WHERE user = ?)
|
||||
OR owner_user_id = (SELECT id FROM user WHERE user = ?)
|
||||
`
|
||||
sqliteDeleteUserAccessProvisioned = `DELETE FROM user_access WHERE provisioned = 1`
|
||||
sqliteDeleteTopicAccess = `
|
||||
DELETE FROM user_access
|
||||
WHERE (user_id = (SELECT id FROM user WHERE user = ?) OR owner_user_id = (SELECT id FROM user WHERE user = ?))
|
||||
AND topic = ?
|
||||
`
|
||||
sqliteDeleteAllAccess = `DELETE FROM user_access`
|
||||
|
||||
// Token queries
|
||||
sqliteSelectToken = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ? AND token = ?`
|
||||
sqliteSelectTokens = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ?`
|
||||
sqliteSelectTokenCount = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
|
||||
sqliteSelectAllProvisionedTokens = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE provisioned = 1`
|
||||
sqliteUpsertToken = `
|
||||
INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires, provisioned)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT (user_id, token)
|
||||
DO UPDATE SET label = excluded.label, expires = excluded.expires, provisioned = excluded.provisioned;
|
||||
`
|
||||
sqliteUpdateTokenLabel = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?`
|
||||
sqliteUpdateTokenExpiry = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?`
|
||||
sqliteUpdateTokenLastAccess = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?`
|
||||
sqliteDeleteToken = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
|
||||
sqliteDeleteProvisionedToken = `DELETE FROM user_token WHERE token = ?`
|
||||
sqliteDeleteAllToken = `DELETE FROM user_token WHERE user_id = ?`
|
||||
sqliteDeleteExpiredTokens = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
|
||||
sqliteDeleteExcessTokens = `
|
||||
DELETE FROM user_token
|
||||
WHERE user_id = ?
|
||||
AND (user_id, token) NOT IN (
|
||||
SELECT user_id, token
|
||||
FROM user_token
|
||||
WHERE user_id = ?
|
||||
ORDER BY expires DESC
|
||||
LIMIT ?
|
||||
)
|
||||
`
|
||||
|
||||
// Tier queries
|
||||
sqliteInsertTier = `
|
||||
INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
sqliteUpdateTier = `
|
||||
UPDATE tier
|
||||
SET name = ?, messages_limit = ?, messages_expiry_duration = ?, emails_limit = ?, calls_limit = ?, reservations_limit = ?, attachment_file_size_limit = ?, attachment_total_size_limit = ?, attachment_expiry_duration = ?, attachment_bandwidth_limit = ?, stripe_monthly_price_id = ?, stripe_yearly_price_id = ?
|
||||
WHERE code = ?
|
||||
`
|
||||
sqliteSelectTiers = `
|
||||
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
|
||||
FROM tier
|
||||
`
|
||||
sqliteSelectTierByCode = `
|
||||
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
|
||||
FROM tier
|
||||
WHERE code = ?
|
||||
`
|
||||
sqliteSelectTierByPriceID = `
|
||||
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
|
||||
FROM tier
|
||||
WHERE (stripe_monthly_price_id = ? OR stripe_yearly_price_id = ?)
|
||||
`
|
||||
sqliteDeleteTier = `DELETE FROM tier WHERE code = ?`
|
||||
|
||||
// Phone queries
|
||||
sqliteSelectPhoneNumbers = `SELECT phone_number FROM user_phone WHERE user_id = ?`
|
||||
sqliteInsertPhoneNumber = `INSERT INTO user_phone (user_id, phone_number) VALUES (?, ?)`
|
||||
sqliteDeletePhoneNumber = `DELETE FROM user_phone WHERE user_id = ? AND phone_number = ?`
|
||||
|
||||
// Billing queries
|
||||
sqliteUpdateBilling = `
|
||||
UPDATE user
|
||||
SET stripe_customer_id = ?, stripe_subscription_id = ?, stripe_subscription_status = ?, stripe_subscription_interval = ?, stripe_subscription_paid_until = ?, stripe_subscription_cancel_at = ?
|
||||
WHERE user = ?
|
||||
`
|
||||
)
|
||||
|
||||
// NewSQLiteStore creates a new SQLite-backed user store
|
||||
func NewSQLiteStore(filename, startupQueries string) (Store, error) {
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := setupSQLite(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := runSQLiteStartupQueries(db, startupQueries); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &commonStore{
|
||||
db: db,
|
||||
queries: storeQueries{
|
||||
selectUserByID: sqliteSelectUserByID,
|
||||
selectUserByName: sqliteSelectUserByName,
|
||||
selectUserByToken: sqliteSelectUserByToken,
|
||||
selectUserByStripeID: sqliteSelectUserByStripeID,
|
||||
selectUsernames: sqliteSelectUsernames,
|
||||
selectUserCount: sqliteSelectUserCount,
|
||||
selectUserIDFromUsername: sqliteSelectUserIDFromUsername,
|
||||
insertUser: sqliteInsertUser,
|
||||
updateUserPass: sqliteUpdateUserPass,
|
||||
updateUserRole: sqliteUpdateUserRole,
|
||||
updateUserProvisioned: sqliteUpdateUserProvisioned,
|
||||
updateUserPrefs: sqliteUpdateUserPrefs,
|
||||
updateUserStats: sqliteUpdateUserStats,
|
||||
updateUserStatsResetAll: sqliteUpdateUserStatsResetAll,
|
||||
updateUserTier: sqliteUpdateUserTier,
|
||||
updateUserDeleted: sqliteUpdateUserDeleted,
|
||||
deleteUser: sqliteDeleteUser,
|
||||
deleteUserTier: sqliteDeleteUserTier,
|
||||
deleteUsersMarked: sqliteDeleteUsersMarked,
|
||||
selectTopicPerms: sqliteSelectTopicPerms,
|
||||
selectUserAllAccess: sqliteSelectUserAllAccess,
|
||||
selectUserAccess: sqliteSelectUserAccess,
|
||||
selectUserReservations: sqliteSelectUserReservations,
|
||||
selectUserReservationsCount: sqliteSelectUserReservationsCount,
|
||||
selectUserReservationsOwner: sqliteSelectUserReservationsOwner,
|
||||
selectUserHasReservation: sqliteSelectUserHasReservation,
|
||||
selectOtherAccessCount: sqliteSelectOtherAccessCount,
|
||||
upsertUserAccess: sqliteUpsertUserAccess,
|
||||
deleteUserAccess: sqliteDeleteUserAccess,
|
||||
deleteUserAccessProvisioned: sqliteDeleteUserAccessProvisioned,
|
||||
deleteTopicAccess: sqliteDeleteTopicAccess,
|
||||
deleteAllAccess: sqliteDeleteAllAccess,
|
||||
selectToken: sqliteSelectToken,
|
||||
selectTokens: sqliteSelectTokens,
|
||||
selectTokenCount: sqliteSelectTokenCount,
|
||||
selectAllProvisionedTokens: sqliteSelectAllProvisionedTokens,
|
||||
upsertToken: sqliteUpsertToken,
|
||||
updateTokenLabel: sqliteUpdateTokenLabel,
|
||||
updateTokenExpiry: sqliteUpdateTokenExpiry,
|
||||
updateTokenLastAccess: sqliteUpdateTokenLastAccess,
|
||||
deleteToken: sqliteDeleteToken,
|
||||
deleteProvisionedToken: sqliteDeleteProvisionedToken,
|
||||
deleteAllToken: sqliteDeleteAllToken,
|
||||
deleteExpiredTokens: sqliteDeleteExpiredTokens,
|
||||
deleteExcessTokens: sqliteDeleteExcessTokens,
|
||||
insertTier: sqliteInsertTier,
|
||||
selectTiers: sqliteSelectTiers,
|
||||
selectTierByCode: sqliteSelectTierByCode,
|
||||
selectTierByPriceID: sqliteSelectTierByPriceID,
|
||||
updateTier: sqliteUpdateTier,
|
||||
deleteTier: sqliteDeleteTier,
|
||||
selectPhoneNumbers: sqliteSelectPhoneNumbers,
|
||||
insertPhoneNumber: sqliteInsertPhoneNumber,
|
||||
deletePhoneNumber: sqliteDeletePhoneNumber,
|
||||
updateBilling: sqliteUpdateBilling,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
@@ -2,19 +2,116 @@ package user
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
// Schema management queries
|
||||
// Initial SQLite schema
|
||||
const (
|
||||
currentSchemaVersion = 6
|
||||
insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
|
||||
updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
|
||||
selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
|
||||
sqliteCreateTablesQueries = `
|
||||
BEGIN;
|
||||
CREATE TABLE IF NOT EXISTS tier (
|
||||
id TEXT PRIMARY KEY,
|
||||
code TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
messages_limit INT NOT NULL,
|
||||
messages_expiry_duration INT NOT NULL,
|
||||
emails_limit INT NOT NULL,
|
||||
calls_limit INT NOT NULL,
|
||||
reservations_limit INT NOT NULL,
|
||||
attachment_file_size_limit INT NOT NULL,
|
||||
attachment_total_size_limit INT NOT NULL,
|
||||
attachment_expiry_duration INT NOT NULL,
|
||||
attachment_bandwidth_limit INT NOT NULL,
|
||||
stripe_monthly_price_id TEXT,
|
||||
stripe_yearly_price_id TEXT
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_tier_code ON tier (code);
|
||||
CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
|
||||
CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
|
||||
CREATE TABLE IF NOT EXISTS user (
|
||||
id TEXT PRIMARY KEY,
|
||||
tier_id TEXT,
|
||||
user TEXT NOT NULL,
|
||||
pass TEXT NOT NULL,
|
||||
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
|
||||
prefs JSON NOT NULL DEFAULT '{}',
|
||||
sync_topic TEXT NOT NULL,
|
||||
provisioned INT NOT NULL,
|
||||
stats_messages INT NOT NULL DEFAULT (0),
|
||||
stats_emails INT NOT NULL DEFAULT (0),
|
||||
stats_calls INT NOT NULL DEFAULT (0),
|
||||
stripe_customer_id TEXT,
|
||||
stripe_subscription_id TEXT,
|
||||
stripe_subscription_status TEXT,
|
||||
stripe_subscription_interval TEXT,
|
||||
stripe_subscription_paid_until INT,
|
||||
stripe_subscription_cancel_at INT,
|
||||
created INT NOT NULL,
|
||||
deleted INT,
|
||||
FOREIGN KEY (tier_id) REFERENCES tier (id)
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_user ON user (user);
|
||||
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
|
||||
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
|
||||
CREATE TABLE IF NOT EXISTS user_access (
|
||||
user_id TEXT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
read INT NOT NULL,
|
||||
write INT NOT NULL,
|
||||
owner_user_id INT,
|
||||
provisioned INT NOT NULL,
|
||||
PRIMARY KEY (user_id, topic),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS user_token (
|
||||
user_id TEXT NOT NULL,
|
||||
token TEXT NOT NULL,
|
||||
label TEXT NOT NULL,
|
||||
last_access INT NOT NULL,
|
||||
last_origin TEXT NOT NULL,
|
||||
expires INT NOT NULL,
|
||||
provisioned INT NOT NULL,
|
||||
PRIMARY KEY (user_id, token),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_user_token ON user_token (token);
|
||||
CREATE TABLE IF NOT EXISTS user_phone (
|
||||
user_id TEXT NOT NULL,
|
||||
phone_number TEXT NOT NULL,
|
||||
PRIMARY KEY (user_id, phone_number),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||
id INT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created)
|
||||
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, UNIXEPOCH())
|
||||
ON CONFLICT (id) DO NOTHING;
|
||||
COMMIT;
|
||||
`
|
||||
)
|
||||
|
||||
const (
|
||||
sqliteBuiltinStartupQueries = `PRAGMA foreign_keys = ON;`
|
||||
)
|
||||
|
||||
// Schema version table management for SQLite
|
||||
const (
|
||||
sqliteCurrentSchemaVersion = 6
|
||||
sqliteInsertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
|
||||
sqliteUpdateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
|
||||
sqliteSelectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
|
||||
)
|
||||
|
||||
// Schema migrations for SQLite
|
||||
const (
|
||||
// 1 -> 2 (complex migration!)
|
||||
migrate1To2CreateTablesQueries = `
|
||||
sqliteMigrate1To2CreateTablesQueries = `
|
||||
ALTER TABLE user RENAME TO user_old;
|
||||
CREATE TABLE IF NOT EXISTS tier (
|
||||
id TEXT PRIMARY KEY,
|
||||
@@ -82,12 +179,12 @@ const (
|
||||
VALUES ('u_everyone', '*', '', 'anonymous', '', UNIXEPOCH())
|
||||
ON CONFLICT (id) DO NOTHING;
|
||||
`
|
||||
migrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old`
|
||||
migrate1To2InsertUserNoTx = `
|
||||
sqliteMigrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old`
|
||||
sqliteMigrate1To2InsertUserNoTx = `
|
||||
INSERT INTO user (id, user, pass, role, sync_topic, created)
|
||||
SELECT ?, user, pass, role, ?, UNIXEPOCH() FROM user_old WHERE user = ?
|
||||
`
|
||||
migrate1To2InsertFromOldTablesAndDropNoTx = `
|
||||
sqliteMigrate1To2InsertFromOldTablesAndDropNoTx = `
|
||||
INSERT INTO user_access (user_id, topic, read, write)
|
||||
SELECT u.id, a.topic, a.read, a.write
|
||||
FROM user u
|
||||
@@ -98,7 +195,7 @@ const (
|
||||
`
|
||||
|
||||
// 2 -> 3
|
||||
migrate2To3UpdateQueries = `
|
||||
sqliteMigrate2To3UpdateQueries = `
|
||||
ALTER TABLE user ADD COLUMN stripe_subscription_interval TEXT;
|
||||
ALTER TABLE tier RENAME COLUMN stripe_price_id TO stripe_monthly_price_id;
|
||||
ALTER TABLE tier ADD COLUMN stripe_yearly_price_id TEXT;
|
||||
@@ -108,7 +205,7 @@ const (
|
||||
`
|
||||
|
||||
// 3 -> 4
|
||||
migrate3To4UpdateQueries = `
|
||||
sqliteMigrate3To4UpdateQueries = `
|
||||
ALTER TABLE tier ADD COLUMN calls_limit INT NOT NULL DEFAULT (0);
|
||||
ALTER TABLE user ADD COLUMN stats_calls INT NOT NULL DEFAULT (0);
|
||||
CREATE TABLE IF NOT EXISTS user_phone (
|
||||
@@ -120,12 +217,12 @@ const (
|
||||
`
|
||||
|
||||
// 4 -> 5
|
||||
migrate4To5UpdateQueries = `
|
||||
sqliteMigrate4To5UpdateQueries = `
|
||||
UPDATE user_access SET topic = REPLACE(topic, '_', '\_');
|
||||
`
|
||||
|
||||
// 5 -> 6
|
||||
migrate5To6UpdateQueries = `
|
||||
sqliteMigrate5To6UpdateQueries = `
|
||||
PRAGMA foreign_keys=off;
|
||||
|
||||
-- Alter user table: Add provisioned column
|
||||
@@ -220,16 +317,60 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
migrations = map[int]func(db *sql.DB) error{
|
||||
1: migrateFrom1,
|
||||
2: migrateFrom2,
|
||||
3: migrateFrom3,
|
||||
4: migrateFrom4,
|
||||
5: migrateFrom5,
|
||||
sqliteMigrations = map[int]func(db *sql.DB) error{
|
||||
1: sqliteMigrateFrom1,
|
||||
2: sqliteMigrateFrom2,
|
||||
3: sqliteMigrateFrom3,
|
||||
4: sqliteMigrateFrom4,
|
||||
5: sqliteMigrateFrom5,
|
||||
}
|
||||
)
|
||||
|
||||
func migrateFrom1(db *sql.DB) error {
|
||||
func setupSQLite(db *sql.DB) error {
|
||||
var schemaVersion int
|
||||
err := db.QueryRow(sqliteSelectSchemaVersionQuery).Scan(&schemaVersion)
|
||||
if err != nil {
|
||||
return setupNewSQLite(db)
|
||||
}
|
||||
if schemaVersion == sqliteCurrentSchemaVersion {
|
||||
return nil
|
||||
} else if schemaVersion > sqliteCurrentSchemaVersion {
|
||||
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, sqliteCurrentSchemaVersion)
|
||||
}
|
||||
for i := schemaVersion; i < sqliteCurrentSchemaVersion; i++ {
|
||||
fn, ok := sqliteMigrations[i]
|
||||
if !ok {
|
||||
return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1)
|
||||
} else if err := fn(db); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupNewSQLite(db *sql.DB) error {
|
||||
if _, err := db.Exec(sqliteCreateTablesQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(sqliteInsertSchemaVersion, sqliteCurrentSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func runSQLiteStartupQueries(db *sql.DB, startupQueries string) error {
|
||||
if _, err := db.Exec(sqliteBuiltinStartupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if startupQueries != "" {
|
||||
if _, err := db.Exec(startupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom1(db *sql.DB) error {
|
||||
log.Tag(tag).Info("Migrating user database schema: from 1 to 2")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
@@ -237,11 +378,11 @@ func migrateFrom1(db *sql.DB) error {
|
||||
}
|
||||
defer tx.Rollback()
|
||||
// Rename user -> user_old, and create new tables
|
||||
if _, err := tx.Exec(migrate1To2CreateTablesQueries); err != nil {
|
||||
if _, err := tx.Exec(sqliteMigrate1To2CreateTablesQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
// Insert users from user_old into new user table, with ID and sync_topic
|
||||
rows, err := tx.Query(migrate1To2SelectAllOldUsernamesNoTx)
|
||||
rows, err := tx.Query(sqliteMigrate1To2SelectAllOldUsernamesNoTx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -260,15 +401,15 @@ func migrateFrom1(db *sql.DB) error {
|
||||
for _, username := range usernames {
|
||||
userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
|
||||
syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
|
||||
if _, err := tx.Exec(migrate1To2InsertUserNoTx, userID, syncTopic, username); err != nil {
|
||||
if _, err := tx.Exec(sqliteMigrate1To2InsertUserNoTx, userID, syncTopic, username); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// Migrate old "access" table to "user_access" and drop "access" and "user_old"
|
||||
if _, err := tx.Exec(migrate1To2InsertFromOldTablesAndDropNoTx); err != nil {
|
||||
if _, err := tx.Exec(sqliteMigrate1To2InsertFromOldTablesAndDropNoTx); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(updateSchemaVersion, 2); err != nil {
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 2); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
@@ -277,65 +418,65 @@ func migrateFrom1(db *sql.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateFrom2(db *sql.DB) error {
|
||||
func sqliteMigrateFrom2(db *sql.DB) error {
|
||||
log.Tag(tag).Info("Migrating user database schema: from 2 to 3")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(migrate2To3UpdateQueries); err != nil {
|
||||
if _, err := tx.Exec(sqliteMigrate2To3UpdateQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(updateSchemaVersion, 3); err != nil {
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 3); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func migrateFrom3(db *sql.DB) error {
|
||||
func sqliteMigrateFrom3(db *sql.DB) error {
|
||||
log.Tag(tag).Info("Migrating user database schema: from 3 to 4")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(migrate3To4UpdateQueries); err != nil {
|
||||
if _, err := tx.Exec(sqliteMigrate3To4UpdateQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(updateSchemaVersion, 4); err != nil {
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 4); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func migrateFrom4(db *sql.DB) error {
|
||||
func sqliteMigrateFrom4(db *sql.DB) error {
|
||||
log.Tag(tag).Info("Migrating user database schema: from 4 to 5")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(migrate4To5UpdateQueries); err != nil {
|
||||
if _, err := tx.Exec(sqliteMigrate4To5UpdateQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(updateSchemaVersion, 5); err != nil {
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 5); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func migrateFrom5(db *sql.DB) error {
|
||||
func sqliteMigrateFrom5(db *sql.DB) error {
|
||||
log.Tag(tag).Info("Migrating user database schema: from 5 to 6")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(migrate5To6UpdateQueries); err != nil {
|
||||
if _, err := tx.Exec(sqliteMigrate5To6UpdateQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(updateSchemaVersion, 6); err != nil {
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 6); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
180
user/store_sqlite_test.go
Normal file
180
user/store_sqlite_test.go
Normal file
@@ -0,0 +1,180 @@
|
||||
package user_test
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
)
|
||||
|
||||
func newTestSQLiteStore(t *testing.T) user.Store {
|
||||
store, err := user.NewSQLiteStore(filepath.Join(t.TempDir(), "user.db"), "")
|
||||
require.Nil(t, err)
|
||||
t.Cleanup(func() { store.Close() })
|
||||
return store
|
||||
}
|
||||
|
||||
func TestSQLiteStoreAddUser(t *testing.T) {
|
||||
testStoreAddUser(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreAddUserAlreadyExists(t *testing.T) {
|
||||
testStoreAddUserAlreadyExists(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreRemoveUser(t *testing.T) {
|
||||
testStoreRemoveUser(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreUserByID(t *testing.T) {
|
||||
testStoreUserByID(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreUserByToken(t *testing.T) {
|
||||
testStoreUserByToken(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreUserByStripeCustomer(t *testing.T) {
|
||||
testStoreUserByStripeCustomer(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreUsers(t *testing.T) {
|
||||
testStoreUsers(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreUsersCount(t *testing.T) {
|
||||
testStoreUsersCount(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreChangePassword(t *testing.T) {
|
||||
testStoreChangePassword(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreChangeRole(t *testing.T) {
|
||||
testStoreChangeRole(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreTokens(t *testing.T) {
|
||||
testStoreTokens(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreTokenChangeLabel(t *testing.T) {
|
||||
testStoreTokenChangeLabel(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreTokenRemove(t *testing.T) {
|
||||
testStoreTokenRemove(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreTokenRemoveExpired(t *testing.T) {
|
||||
testStoreTokenRemoveExpired(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreTokenRemoveExcess(t *testing.T) {
|
||||
testStoreTokenRemoveExcess(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreTokenUpdateLastAccess(t *testing.T) {
|
||||
testStoreTokenUpdateLastAccess(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreAllowAccess(t *testing.T) {
|
||||
testStoreAllowAccess(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreAllowAccessReadOnly(t *testing.T) {
|
||||
testStoreAllowAccessReadOnly(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreResetAccess(t *testing.T) {
|
||||
testStoreResetAccess(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreResetAccessAll(t *testing.T) {
|
||||
testStoreResetAccessAll(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreAuthorizeTopicAccess(t *testing.T) {
|
||||
testStoreAuthorizeTopicAccess(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreAuthorizeTopicAccessNotFound(t *testing.T) {
|
||||
testStoreAuthorizeTopicAccessNotFound(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreAuthorizeTopicAccessDenyAll(t *testing.T) {
|
||||
testStoreAuthorizeTopicAccessDenyAll(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreReservations(t *testing.T) {
|
||||
testStoreReservations(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreReservationsCount(t *testing.T) {
|
||||
testStoreReservationsCount(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreHasReservation(t *testing.T) {
|
||||
testStoreHasReservation(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreReservationOwner(t *testing.T) {
|
||||
testStoreReservationOwner(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreTiers(t *testing.T) {
|
||||
testStoreTiers(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreTierUpdate(t *testing.T) {
|
||||
testStoreTierUpdate(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreTierRemove(t *testing.T) {
|
||||
testStoreTierRemove(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreTierByStripePrice(t *testing.T) {
|
||||
testStoreTierByStripePrice(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreChangeTier(t *testing.T) {
|
||||
testStoreChangeTier(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStorePhoneNumbers(t *testing.T) {
|
||||
testStorePhoneNumbers(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreChangeSettings(t *testing.T) {
|
||||
testStoreChangeSettings(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreChangeBilling(t *testing.T) {
|
||||
testStoreChangeBilling(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreUpdateStats(t *testing.T) {
|
||||
testStoreUpdateStats(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreResetStats(t *testing.T) {
|
||||
testStoreResetStats(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreMarkUserRemoved(t *testing.T) {
|
||||
testStoreMarkUserRemoved(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreRemoveDeletedUsers(t *testing.T) {
|
||||
testStoreRemoveDeletedUsers(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreAllGrants(t *testing.T) {
|
||||
testStoreAllGrants(t, newTestSQLiteStore(t))
|
||||
}
|
||||
|
||||
func TestSQLiteStoreOtherAccessCount(t *testing.T) {
|
||||
testStoreOtherAccessCount(t, newTestSQLiteStore(t))
|
||||
}
|
||||
619
user/store_test.go
Normal file
619
user/store_test.go
Normal file
@@ -0,0 +1,619 @@
|
||||
package user_test
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/user"
|
||||
)
|
||||
|
||||
func testStoreAddUser(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "phil", u.Name)
|
||||
require.Equal(t, user.RoleUser, u.Role)
|
||||
require.False(t, u.Provisioned)
|
||||
require.NotEmpty(t, u.ID)
|
||||
require.NotEmpty(t, u.SyncTopic)
|
||||
}
|
||||
|
||||
func testStoreAddUserAlreadyExists(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
require.Equal(t, user.ErrUserExists, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
}
|
||||
|
||||
func testStoreRemoveUser(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "phil", u.Name)
|
||||
|
||||
require.Nil(t, store.RemoveUser("phil"))
|
||||
_, err = store.User("phil")
|
||||
require.Equal(t, user.ErrUserNotFound, err)
|
||||
}
|
||||
|
||||
func testStoreUserByID(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleAdmin, false))
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
||||
u2, err := store.UserByID(u.ID)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, u.Name, u2.Name)
|
||||
require.Equal(t, u.ID, u2.ID)
|
||||
}
|
||||
|
||||
func testStoreUserByToken(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
||||
tk, err := store.CreateToken(u.ID, "tk_test123", "test token", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(24*time.Hour), false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "tk_test123", tk.Value)
|
||||
|
||||
u2, err := store.UserByToken(tk.Value)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "phil", u2.Name)
|
||||
}
|
||||
|
||||
func testStoreUserByStripeCustomer(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
require.Nil(t, store.ChangeBilling("phil", &user.Billing{
|
||||
StripeCustomerID: "cus_test123",
|
||||
StripeSubscriptionID: "sub_test123",
|
||||
}))
|
||||
|
||||
u, err := store.UserByStripeCustomer("cus_test123")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "phil", u.Name)
|
||||
require.Equal(t, "cus_test123", u.Billing.StripeCustomerID)
|
||||
}
|
||||
|
||||
func testStoreUsers(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
require.Nil(t, store.AddUser("ben", "benhash", user.RoleAdmin, false))
|
||||
|
||||
users, err := store.Users()
|
||||
require.Nil(t, err)
|
||||
require.True(t, len(users) >= 3) // phil, ben, and the everyone user
|
||||
}
|
||||
|
||||
func testStoreUsersCount(t *testing.T, store user.Store) {
|
||||
count, err := store.UsersCount()
|
||||
require.Nil(t, err)
|
||||
require.True(t, count >= 1) // At least the everyone user
|
||||
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
count2, err := store.UsersCount()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, count+1, count2)
|
||||
}
|
||||
|
||||
func testStoreChangePassword(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "philhash", u.Hash)
|
||||
|
||||
require.Nil(t, store.ChangePassword("phil", "newhash"))
|
||||
u, err = store.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "newhash", u.Hash)
|
||||
}
|
||||
|
||||
func testStoreChangeRole(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, user.RoleUser, u.Role)
|
||||
|
||||
require.Nil(t, store.ChangeRole("phil", user.RoleAdmin))
|
||||
u, err = store.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, user.RoleAdmin, u.Role)
|
||||
}
|
||||
|
||||
func testStoreTokens(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
||||
now := time.Now()
|
||||
expires := now.Add(24 * time.Hour)
|
||||
origin := netip.MustParseAddr("9.9.9.9")
|
||||
|
||||
tk, err := store.CreateToken(u.ID, "tk_abc", "my token", now, origin, expires, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "tk_abc", tk.Value)
|
||||
require.Equal(t, "my token", tk.Label)
|
||||
|
||||
// Get single token
|
||||
tk2, err := store.Token(u.ID, "tk_abc")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "tk_abc", tk2.Value)
|
||||
require.Equal(t, "my token", tk2.Label)
|
||||
|
||||
// Get all tokens
|
||||
tokens, err := store.Tokens(u.ID)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, tokens, 1)
|
||||
require.Equal(t, "tk_abc", tokens[0].Value)
|
||||
|
||||
// Token count
|
||||
count, err := store.TokenCount(u.ID)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func testStoreTokenChangeLabel(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
||||
_, err = store.CreateToken(u.ID, "tk_abc", "old label", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), false)
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Nil(t, store.ChangeTokenLabel(u.ID, "tk_abc", "new label"))
|
||||
tk, err := store.Token(u.ID, "tk_abc")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "new label", tk.Label)
|
||||
}
|
||||
|
||||
func testStoreTokenRemove(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
||||
_, err = store.CreateToken(u.ID, "tk_abc", "label", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), false)
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Nil(t, store.RemoveToken(u.ID, "tk_abc"))
|
||||
_, err = store.Token(u.ID, "tk_abc")
|
||||
require.Equal(t, user.ErrTokenNotFound, err)
|
||||
}
|
||||
|
||||
func testStoreTokenRemoveExpired(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
||||
// Create expired token and active token
|
||||
_, err = store.CreateToken(u.ID, "tk_expired", "expired", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(-time.Hour), false)
|
||||
require.Nil(t, err)
|
||||
_, err = store.CreateToken(u.ID, "tk_active", "active", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), false)
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Nil(t, store.RemoveExpiredTokens())
|
||||
|
||||
// Expired token should be gone
|
||||
_, err = store.Token(u.ID, "tk_expired")
|
||||
require.Equal(t, user.ErrTokenNotFound, err)
|
||||
|
||||
// Active token should still exist
|
||||
tk, err := store.Token(u.ID, "tk_active")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "tk_active", tk.Value)
|
||||
}
|
||||
|
||||
func testStoreTokenRemoveExcess(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
||||
// Create 3 tokens with increasing expiry
|
||||
for i, name := range []string{"tk_a", "tk_b", "tk_c"} {
|
||||
_, err = store.CreateToken(u.ID, name, name, time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Duration(i+1)*time.Hour), false)
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
count, err := store.TokenCount(u.ID)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 3, count)
|
||||
|
||||
// Remove excess, keep only 2 (the ones with latest expiry: tk_b, tk_c)
|
||||
require.Nil(t, store.RemoveExcessTokens(u.ID, 2))
|
||||
|
||||
count, err = store.TokenCount(u.ID)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, count)
|
||||
|
||||
// tk_a should be removed (earliest expiry)
|
||||
_, err = store.Token(u.ID, "tk_a")
|
||||
require.Equal(t, user.ErrTokenNotFound, err)
|
||||
|
||||
// tk_b and tk_c should remain
|
||||
_, err = store.Token(u.ID, "tk_b")
|
||||
require.Nil(t, err)
|
||||
_, err = store.Token(u.ID, "tk_c")
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
func testStoreTokenUpdateLastAccess(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
||||
_, err = store.CreateToken(u.ID, "tk_abc", "label", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), false)
|
||||
require.Nil(t, err)
|
||||
|
||||
newTime := time.Now().Add(5 * time.Minute)
|
||||
newOrigin := netip.MustParseAddr("5.5.5.5")
|
||||
require.Nil(t, store.UpdateTokenLastAccess("tk_abc", newTime, newOrigin))
|
||||
|
||||
tk, err := store.Token(u.ID, "tk_abc")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, newTime.Unix(), tk.LastAccess.Unix())
|
||||
require.Equal(t, newOrigin, tk.LastOrigin)
|
||||
}
|
||||
|
||||
func testStoreAllowAccess(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
|
||||
require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "", false))
|
||||
grants, err := store.Grants("phil")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, grants, 1)
|
||||
require.Equal(t, "mytopic", grants[0].TopicPattern)
|
||||
require.True(t, grants[0].Permission.IsReadWrite())
|
||||
}
|
||||
|
||||
func testStoreAllowAccessReadOnly(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
|
||||
require.Nil(t, store.AllowAccess("phil", "announcements", true, false, "", false))
|
||||
grants, err := store.Grants("phil")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, grants, 1)
|
||||
require.True(t, grants[0].Permission.IsRead())
|
||||
require.False(t, grants[0].Permission.IsWrite())
|
||||
}
|
||||
|
||||
func testStoreResetAccess(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
require.Nil(t, store.AllowAccess("phil", "topic1", true, true, "", false))
|
||||
require.Nil(t, store.AllowAccess("phil", "topic2", true, false, "", false))
|
||||
|
||||
grants, err := store.Grants("phil")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, grants, 2)
|
||||
|
||||
require.Nil(t, store.ResetAccess("phil", "topic1"))
|
||||
grants, err = store.Grants("phil")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, grants, 1)
|
||||
require.Equal(t, "topic2", grants[0].TopicPattern)
|
||||
}
|
||||
|
||||
func testStoreResetAccessAll(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
require.Nil(t, store.AllowAccess("phil", "topic1", true, true, "", false))
|
||||
require.Nil(t, store.AllowAccess("phil", "topic2", true, false, "", false))
|
||||
|
||||
require.Nil(t, store.ResetAccess("phil", ""))
|
||||
grants, err := store.Grants("phil")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, grants, 0)
|
||||
}
|
||||
|
||||
func testStoreAuthorizeTopicAccess(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "", false))
|
||||
|
||||
read, write, found, err := store.AuthorizeTopicAccess("phil", "mytopic")
|
||||
require.Nil(t, err)
|
||||
require.True(t, found)
|
||||
require.True(t, read)
|
||||
require.True(t, write)
|
||||
}
|
||||
|
||||
func testStoreAuthorizeTopicAccessNotFound(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
|
||||
_, _, found, err := store.AuthorizeTopicAccess("phil", "other")
|
||||
require.Nil(t, err)
|
||||
require.False(t, found)
|
||||
}
|
||||
|
||||
func testStoreAuthorizeTopicAccessDenyAll(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
require.Nil(t, store.AllowAccess("phil", "secret", false, false, "", false))
|
||||
|
||||
read, write, found, err := store.AuthorizeTopicAccess("phil", "secret")
|
||||
require.Nil(t, err)
|
||||
require.True(t, found)
|
||||
require.False(t, read)
|
||||
require.False(t, write)
|
||||
}
|
||||
|
||||
func testStoreReservations(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "phil", false))
|
||||
require.Nil(t, store.AllowAccess(user.Everyone, "mytopic", true, false, "phil", false))
|
||||
|
||||
reservations, err := store.Reservations("phil")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, reservations, 1)
|
||||
require.Equal(t, "mytopic", reservations[0].Topic)
|
||||
require.True(t, reservations[0].Owner.IsReadWrite())
|
||||
require.True(t, reservations[0].Everyone.IsRead())
|
||||
require.False(t, reservations[0].Everyone.IsWrite())
|
||||
}
|
||||
|
||||
func testStoreReservationsCount(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
require.Nil(t, store.AllowAccess("phil", "topic1", true, true, "phil", false))
|
||||
require.Nil(t, store.AllowAccess("phil", "topic2", true, true, "phil", false))
|
||||
|
||||
count, err := store.ReservationsCount("phil")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(2), count)
|
||||
}
|
||||
|
||||
func testStoreHasReservation(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "phil", false))
|
||||
|
||||
has, err := store.HasReservation("phil", "mytopic")
|
||||
require.Nil(t, err)
|
||||
require.True(t, has)
|
||||
|
||||
has, err = store.HasReservation("phil", "other")
|
||||
require.Nil(t, err)
|
||||
require.False(t, has)
|
||||
}
|
||||
|
||||
func testStoreReservationOwner(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "phil", false))
|
||||
|
||||
owner, err := store.ReservationOwner("mytopic")
|
||||
require.Nil(t, err)
|
||||
require.NotEmpty(t, owner) // Returns the user ID
|
||||
|
||||
owner, err = store.ReservationOwner("unowned")
|
||||
require.Nil(t, err)
|
||||
require.Empty(t, owner)
|
||||
}
|
||||
|
||||
func testStoreTiers(t *testing.T, store user.Store) {
|
||||
tier := &user.Tier{
|
||||
ID: "ti_test",
|
||||
Code: "pro",
|
||||
Name: "Pro",
|
||||
MessageLimit: 5000,
|
||||
MessageExpiryDuration: 24 * time.Hour,
|
||||
EmailLimit: 100,
|
||||
CallLimit: 10,
|
||||
ReservationLimit: 20,
|
||||
AttachmentFileSizeLimit: 10 * 1024 * 1024,
|
||||
AttachmentTotalSizeLimit: 100 * 1024 * 1024,
|
||||
AttachmentExpiryDuration: 48 * time.Hour,
|
||||
AttachmentBandwidthLimit: 500 * 1024 * 1024,
|
||||
}
|
||||
require.Nil(t, store.AddTier(tier))
|
||||
|
||||
// Get by code
|
||||
t2, err := store.Tier("pro")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "ti_test", t2.ID)
|
||||
require.Equal(t, "pro", t2.Code)
|
||||
require.Equal(t, "Pro", t2.Name)
|
||||
require.Equal(t, int64(5000), t2.MessageLimit)
|
||||
require.Equal(t, int64(100), t2.EmailLimit)
|
||||
require.Equal(t, int64(10), t2.CallLimit)
|
||||
require.Equal(t, int64(20), t2.ReservationLimit)
|
||||
|
||||
// List all tiers
|
||||
tiers, err := store.Tiers()
|
||||
require.Nil(t, err)
|
||||
require.Len(t, tiers, 1)
|
||||
require.Equal(t, "pro", tiers[0].Code)
|
||||
}
|
||||
|
||||
func testStoreTierUpdate(t *testing.T, store user.Store) {
|
||||
tier := &user.Tier{
|
||||
ID: "ti_test",
|
||||
Code: "pro",
|
||||
Name: "Pro",
|
||||
}
|
||||
require.Nil(t, store.AddTier(tier))
|
||||
|
||||
tier.Name = "Professional"
|
||||
tier.MessageLimit = 9999
|
||||
require.Nil(t, store.UpdateTier(tier))
|
||||
|
||||
t2, err := store.Tier("pro")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "Professional", t2.Name)
|
||||
require.Equal(t, int64(9999), t2.MessageLimit)
|
||||
}
|
||||
|
||||
func testStoreTierRemove(t *testing.T, store user.Store) {
|
||||
tier := &user.Tier{
|
||||
ID: "ti_test",
|
||||
Code: "pro",
|
||||
Name: "Pro",
|
||||
}
|
||||
require.Nil(t, store.AddTier(tier))
|
||||
|
||||
t2, err := store.Tier("pro")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "pro", t2.Code)
|
||||
|
||||
require.Nil(t, store.RemoveTier("pro"))
|
||||
_, err = store.Tier("pro")
|
||||
require.Equal(t, user.ErrTierNotFound, err)
|
||||
}
|
||||
|
||||
func testStoreTierByStripePrice(t *testing.T, store user.Store) {
|
||||
tier := &user.Tier{
|
||||
ID: "ti_test",
|
||||
Code: "pro",
|
||||
Name: "Pro",
|
||||
StripeMonthlyPriceID: "price_monthly",
|
||||
StripeYearlyPriceID: "price_yearly",
|
||||
}
|
||||
require.Nil(t, store.AddTier(tier))
|
||||
|
||||
t2, err := store.TierByStripePrice("price_monthly")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "pro", t2.Code)
|
||||
|
||||
t3, err := store.TierByStripePrice("price_yearly")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "pro", t3.Code)
|
||||
}
|
||||
|
||||
func testStoreChangeTier(t *testing.T, store user.Store) {
|
||||
tier := &user.Tier{
|
||||
ID: "ti_test",
|
||||
Code: "pro",
|
||||
Name: "Pro",
|
||||
}
|
||||
require.Nil(t, store.AddTier(tier))
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
require.Nil(t, store.ChangeTier("phil", "pro"))
|
||||
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, u.Tier)
|
||||
require.Equal(t, "pro", u.Tier.Code)
|
||||
}
|
||||
|
||||
func testStorePhoneNumbers(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Nil(t, store.AddPhoneNumber(u.ID, "+1234567890"))
|
||||
require.Nil(t, store.AddPhoneNumber(u.ID, "+0987654321"))
|
||||
|
||||
numbers, err := store.PhoneNumbers(u.ID)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, numbers, 2)
|
||||
|
||||
require.Nil(t, store.RemovePhoneNumber(u.ID, "+1234567890"))
|
||||
numbers, err = store.PhoneNumbers(u.ID)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, numbers, 1)
|
||||
require.Equal(t, "+0987654321", numbers[0])
|
||||
}
|
||||
|
||||
func testStoreChangeSettings(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
||||
lang := "de"
|
||||
prefs := &user.Prefs{Language: &lang}
|
||||
require.Nil(t, store.ChangeSettings(u.ID, prefs))
|
||||
|
||||
u2, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, u2.Prefs)
|
||||
require.Equal(t, "de", *u2.Prefs.Language)
|
||||
}
|
||||
|
||||
func testStoreChangeBilling(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
|
||||
billing := &user.Billing{
|
||||
StripeCustomerID: "cus_123",
|
||||
StripeSubscriptionID: "sub_456",
|
||||
}
|
||||
require.Nil(t, store.ChangeBilling("phil", billing))
|
||||
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "cus_123", u.Billing.StripeCustomerID)
|
||||
require.Equal(t, "sub_456", u.Billing.StripeSubscriptionID)
|
||||
}
|
||||
|
||||
func testStoreUpdateStats(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
||||
stats := &user.Stats{Messages: 42, Emails: 3, Calls: 1}
|
||||
require.Nil(t, store.UpdateStats(u.ID, stats))
|
||||
|
||||
u2, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(42), u2.Stats.Messages)
|
||||
require.Equal(t, int64(3), u2.Stats.Emails)
|
||||
require.Equal(t, int64(1), u2.Stats.Calls)
|
||||
}
|
||||
|
||||
func testStoreResetStats(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Nil(t, store.UpdateStats(u.ID, &user.Stats{Messages: 42, Emails: 3, Calls: 1}))
|
||||
require.Nil(t, store.ResetStats())
|
||||
|
||||
u2, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(0), u2.Stats.Messages)
|
||||
require.Equal(t, int64(0), u2.Stats.Emails)
|
||||
require.Equal(t, int64(0), u2.Stats.Calls)
|
||||
}
|
||||
|
||||
func testStoreMarkUserRemoved(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Nil(t, store.MarkUserRemoved(u.ID))
|
||||
|
||||
u2, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.True(t, u2.Deleted)
|
||||
}
|
||||
|
||||
func testStoreRemoveDeletedUsers(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
u, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Nil(t, store.MarkUserRemoved(u.ID))
|
||||
|
||||
// RemoveDeletedUsers only removes users past the hard-delete duration (7 days).
|
||||
// Immediately after marking, the user should still exist.
|
||||
require.Nil(t, store.RemoveDeletedUsers())
|
||||
u2, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.True(t, u2.Deleted)
|
||||
}
|
||||
|
||||
func testStoreAllGrants(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
require.Nil(t, store.AddUser("ben", "benhash", user.RoleUser, false))
|
||||
phil, err := store.User("phil")
|
||||
require.Nil(t, err)
|
||||
ben, err := store.User("ben")
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Nil(t, store.AllowAccess("phil", "topic1", true, true, "", false))
|
||||
require.Nil(t, store.AllowAccess("ben", "topic2", true, false, "", false))
|
||||
|
||||
grants, err := store.AllGrants()
|
||||
require.Nil(t, err)
|
||||
require.Contains(t, grants, phil.ID)
|
||||
require.Contains(t, grants, ben.ID)
|
||||
}
|
||||
|
||||
func testStoreOtherAccessCount(t *testing.T, store user.Store) {
|
||||
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||
require.Nil(t, store.AddUser("ben", "benhash", user.RoleUser, false))
|
||||
require.Nil(t, store.AllowAccess("ben", "mytopic", true, true, "ben", false))
|
||||
|
||||
count, err := store.OtherAccessCount("phil", "mytopic")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
@@ -242,6 +242,20 @@ const (
|
||||
everyoneID = "u_everyone"
|
||||
)
|
||||
|
||||
// Config holds the configuration for the user Manager
|
||||
type Config struct {
|
||||
Filename string // Database filename, e.g. "/var/lib/ntfy/user.db" (SQLite)
|
||||
DatabaseURL string // Database connection string (PostgreSQL)
|
||||
StartupQueries string // Queries to run on startup, e.g. to create initial users or tiers (SQLite only)
|
||||
DefaultAccess Permission // Default permission if no ACL matches
|
||||
ProvisionEnabled bool // Hack: Enable auto-provisioning of users and access grants, disabled for "ntfy user" commands
|
||||
Users []*User // Predefined users to create on startup
|
||||
Access map[string][]*Grant // Predefined access grants to create on startup (username -> []*Grant)
|
||||
Tokens map[string][]*Token // Predefined users to create on startup (username -> []*Token)
|
||||
QueueWriterInterval time.Duration // Interval for the async queue writer to flush stats and token updates to the database
|
||||
BcryptCost int // Cost of generated passwords; lowering makes testing faster
|
||||
}
|
||||
|
||||
// Error constants used by the package
|
||||
var (
|
||||
ErrUnauthenticated = errors.New("unauthenticated")
|
||||
|
||||
@@ -2,6 +2,7 @@ package webpush
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
|
||||
)
|
||||
@@ -25,8 +26,8 @@ const (
|
||||
PRIMARY KEY (subscription_id, topic)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_webpush_topic ON webpush_subscription_topic (topic);
|
||||
CREATE TABLE IF NOT EXISTS webpush_schema_version (
|
||||
id INT PRIMARY KEY,
|
||||
CREATE TABLE IF NOT EXISTS schema_version (
|
||||
store TEXT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
`
|
||||
@@ -65,8 +66,8 @@ const (
|
||||
// PostgreSQL schema management queries
|
||||
const (
|
||||
pgCurrentSchemaVersion = 1
|
||||
pgInsertSchemaVersion = `INSERT INTO webpush_schema_version VALUES (1, $1)`
|
||||
pgSelectSchemaVersionQuery = `SELECT version FROM webpush_schema_version WHERE id = 1`
|
||||
pgInsertSchemaVersion = `INSERT INTO schema_version (store, version) VALUES ('webpush', $1)`
|
||||
pgSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE store = 'webpush'`
|
||||
)
|
||||
|
||||
// NewPostgresStore creates a new PostgreSQL-backed web push store.
|
||||
@@ -102,12 +103,15 @@ func NewPostgresStore(dsn string) (Store, error) {
|
||||
}
|
||||
|
||||
func setupPostgresDB(db *sql.DB) error {
|
||||
// If 'webpush_schema_version' table does not exist, this must be a new database
|
||||
rows, err := db.Query(pgSelectSchemaVersionQuery)
|
||||
var schemaVersion int
|
||||
err := db.QueryRow(pgSelectSchemaVersionQuery).Scan(&schemaVersion)
|
||||
if err != nil {
|
||||
return setupNewPostgresDB(db)
|
||||
}
|
||||
return rows.Close()
|
||||
if schemaVersion > pgCurrentSchemaVersion {
|
||||
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, pgCurrentSchemaVersion)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupNewPostgresDB(db *sql.DB) error {
|
||||
|
||||
@@ -2,6 +2,7 @@ package webpush
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
)
|
||||
@@ -82,10 +83,10 @@ func NewSQLiteStore(filename, startupQueries string) (Store, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := setupSQLiteWebPushDB(db); err != nil {
|
||||
if err := setupSQLite(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := runSQLiteWebPushStartupQueries(db, startupQueries); err != nil {
|
||||
if err := runSQLiteStartupQueries(db, startupQueries); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &commonStore{
|
||||
@@ -108,16 +109,19 @@ func NewSQLiteStore(filename, startupQueries string) (Store, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func setupSQLiteWebPushDB(db *sql.DB) error {
|
||||
// If 'schemaVersion' table does not exist, this must be a new database
|
||||
rows, err := db.Query(sqliteSelectWebPushSchemaVersionQuery)
|
||||
func setupSQLite(db *sql.DB) error {
|
||||
var schemaVersion int
|
||||
err := db.QueryRow(sqliteSelectWebPushSchemaVersionQuery).Scan(&schemaVersion)
|
||||
if err != nil {
|
||||
return setupNewSQLiteWebPushDB(db)
|
||||
return setupNewSQLite(db)
|
||||
}
|
||||
return rows.Close()
|
||||
if schemaVersion > sqliteCurrentWebPushSchemaVersion {
|
||||
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, sqliteCurrentWebPushSchemaVersion)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupNewSQLiteWebPushDB(db *sql.DB) error {
|
||||
func setupNewSQLite(db *sql.DB) error {
|
||||
if _, err := db.Exec(sqliteCreateWebPushSubscriptionsTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -127,7 +131,7 @@ func setupNewSQLiteWebPushDB(db *sql.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func runSQLiteWebPushStartupQueries(db *sql.DB, startupQueries string) error {
|
||||
func runSQLiteStartupQueries(db *sql.DB, startupQueries string) error {
|
||||
if _, err := db.Exec(startupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user