Refactor user package to Store interface with PostgreSQL support
Extract database operations from Manager into a Store interface with SQLite and PostgreSQL implementations using a shared commonStore. Split SQLite migrations into store_sqlite_migrations.go, use shared schema_version table for PostgreSQL, rename user_user/user_tier tables to "user"/tier, and wire up database-url in CLI commands.
This commit is contained in:
25
cmd/user.go
25
cmd/user.go
@@ -29,6 +29,7 @@ var flagsUser = append(
|
||||
&cli.StringFlag{Name: "config", Aliases: []string{"c"}, EnvVars: []string{"NTFY_CONFIG_FILE"}, Value: server.DefaultConfigFile, DefaultText: server.DefaultConfigFile, Usage: "config file"},
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "auth-file", Aliases: []string{"auth_file", "H"}, EnvVars: []string{"NTFY_AUTH_FILE"}, Usage: "auth database file used for access control"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "auth-default-access", Aliases: []string{"auth_default_access", "p"}, EnvVars: []string{"NTFY_AUTH_DEFAULT_ACCESS"}, Value: "read-write", Usage: "default permissions if no matching entries in the auth database are found"}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "database-url", Aliases: []string{"database_url"}, EnvVars: []string{"NTFY_DATABASE_URL"}, Usage: "PostgreSQL connection string for database-backed stores"}),
|
||||
)
|
||||
|
||||
var cmdUser = &cli.Command{
|
||||
@@ -365,24 +366,32 @@ func createUserManager(c *cli.Context) (*user.Manager, error) {
|
||||
authFile := c.String("auth-file")
|
||||
authStartupQueries := c.String("auth-startup-queries")
|
||||
authDefaultAccess := c.String("auth-default-access")
|
||||
if authFile == "" {
|
||||
return nil, errors.New("option auth-file not set; auth is unconfigured for this server")
|
||||
} else if !util.FileExists(authFile) {
|
||||
return nil, errors.New("auth-file does not exist; please start the server at least once to create it")
|
||||
}
|
||||
databaseURL := c.String("database-url")
|
||||
authDefault, err := user.ParsePermission(authDefaultAccess)
|
||||
if err != nil {
|
||||
return nil, errors.New("if set, auth-default-access must start set to 'read-write', 'read-only', 'write-only' or 'deny-all'")
|
||||
}
|
||||
authConfig := &user.Config{
|
||||
Filename: authFile,
|
||||
StartupQueries: authStartupQueries,
|
||||
DefaultAccess: authDefault,
|
||||
ProvisionEnabled: false, // Hack: Do not re-provision users on manager initialization
|
||||
BcryptCost: user.DefaultUserPasswordBcryptCost,
|
||||
QueueWriterInterval: user.DefaultUserStatsQueueWriterInterval,
|
||||
}
|
||||
return user.NewManager(authConfig)
|
||||
var store user.Store
|
||||
if databaseURL != "" {
|
||||
store, err = user.NewPostgresStore(databaseURL)
|
||||
} else if authFile != "" {
|
||||
if !util.FileExists(authFile) {
|
||||
return nil, errors.New("auth-file does not exist; please start the server at least once to create it")
|
||||
}
|
||||
store, err = user.NewSQLiteStore(authFile, authStartupQueries)
|
||||
} else {
|
||||
return nil, errors.New("option database-url or auth-file not set; auth is unconfigured for this server")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return user.NewManager(store, authConfig)
|
||||
}
|
||||
|
||||
func readPasswordAndConfirm(c *cli.Context) (string, error) {
|
||||
|
||||
@@ -204,9 +204,10 @@ func New(conf *Config) (*Server, error) {
|
||||
}
|
||||
}
|
||||
var userManager *user.Manager
|
||||
if conf.AuthFile != "" {
|
||||
if conf.AuthFile != "" || conf.DatabaseURL != "" {
|
||||
authConfig := &user.Config{
|
||||
Filename: conf.AuthFile,
|
||||
DatabaseURL: conf.DatabaseURL,
|
||||
StartupQueries: conf.AuthStartupQueries,
|
||||
DefaultAccess: conf.AuthDefault,
|
||||
ProvisionEnabled: true, // Enable provisioning of users and access
|
||||
@@ -216,7 +217,16 @@ func New(conf *Config) (*Server, error) {
|
||||
BcryptCost: conf.AuthBcryptCost,
|
||||
QueueWriterInterval: conf.AuthStatsQueueWriterInterval,
|
||||
}
|
||||
userManager, err = user.NewManager(authConfig)
|
||||
var store user.Store
|
||||
if conf.DatabaseURL != "" {
|
||||
store, err = user.NewPostgresStore(conf.DatabaseURL)
|
||||
} else {
|
||||
store, err = user.NewSQLiteStore(conf.AuthFile, conf.AuthStartupQueries)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userManager, err = user.NewManager(store, authConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
1193
user/manager.go
1193
user/manager.go
File diff suppressed because it is too large
Load Diff
@@ -224,7 +224,7 @@ func TestManager_MarkUserRemoved_RemoveDeletedUsers(t *testing.T) {
|
||||
require.Nil(t, err)
|
||||
require.True(t, u.Deleted)
|
||||
|
||||
_, err = a.db.Exec("UPDATE user SET deleted = ? WHERE id = ?", time.Now().Add(-1*(userHardDeleteAfterDuration+time.Hour)).Unix(), u.ID)
|
||||
_, err = testDB(a).Exec("UPDATE user SET deleted = ? WHERE id = ?", time.Now().Add(-1*(userHardDeleteAfterDuration+time.Hour)).Unix(), u.ID)
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, a.RemoveDeletedUsers())
|
||||
|
||||
@@ -604,14 +604,14 @@ func TestManager_Token_Expire(t *testing.T) {
|
||||
require.Nil(t, err)
|
||||
|
||||
// Modify token expiration in database
|
||||
_, err = a.db.Exec("UPDATE user_token SET expires = 1 WHERE token = ?", token1.Value)
|
||||
_, err = testDB(a).Exec("UPDATE user_token SET expires = 1 WHERE token = ?", token1.Value)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Now token1 shouldn't work anymore
|
||||
_, err = a.AuthenticateToken(token1.Value)
|
||||
require.Equal(t, ErrUnauthenticated, err)
|
||||
|
||||
result, err := a.db.Query("SELECT * from user_token WHERE token = ?", token1.Value)
|
||||
result, err := testDB(a).Query("SELECT * from user_token WHERE token = ?", token1.Value)
|
||||
require.Nil(t, err)
|
||||
require.True(t, result.Next()) // Still a matching row
|
||||
require.Nil(t, result.Close())
|
||||
@@ -619,7 +619,7 @@ func TestManager_Token_Expire(t *testing.T) {
|
||||
// Expire tokens and check database rows
|
||||
require.Nil(t, a.RemoveExpiredTokens())
|
||||
|
||||
result, err = a.db.Query("SELECT * from user_token WHERE token = ?", token1.Value)
|
||||
result, err = testDB(a).Query("SELECT * from user_token WHERE token = ?", token1.Value)
|
||||
require.Nil(t, err)
|
||||
require.False(t, result.Next()) // No matching row!
|
||||
require.Nil(t, result.Close())
|
||||
@@ -687,7 +687,7 @@ func TestManager_Token_MaxCount_AutoDelete(t *testing.T) {
|
||||
benTokens = append(benTokens, token.Value)
|
||||
|
||||
// Manually modify expiry date to avoid sorting issues (this is a hack)
|
||||
_, err = a.db.Exec(`UPDATE user_token SET expires=? WHERE token=?`, baseTime.Add(time.Duration(i)*time.Minute).Unix(), token.Value)
|
||||
_, err = testDB(a).Exec(`UPDATE user_token SET expires=? WHERE token=?`, baseTime.Add(time.Duration(i)*time.Minute).Unix(), token.Value)
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
@@ -715,14 +715,14 @@ func TestManager_Token_MaxCount_AutoDelete(t *testing.T) {
|
||||
}
|
||||
|
||||
var benCount int
|
||||
rows, err := a.db.Query(`SELECT COUNT(*) FROM user_token WHERE user_id=?`, ben.ID)
|
||||
rows, err := testDB(a).Query(`SELECT COUNT(*) FROM user_token WHERE user_id=?`, ben.ID)
|
||||
require.Nil(t, err)
|
||||
require.True(t, rows.Next())
|
||||
require.Nil(t, rows.Scan(&benCount))
|
||||
require.Equal(t, 60, benCount)
|
||||
|
||||
var philCount int
|
||||
rows, err = a.db.Query(`SELECT COUNT(*) FROM user_token WHERE user_id=?`, phil.ID)
|
||||
rows, err = testDB(a).Query(`SELECT COUNT(*) FROM user_token WHERE user_id=?`, phil.ID)
|
||||
require.Nil(t, err)
|
||||
require.True(t, rows.Next())
|
||||
require.Nil(t, rows.Scan(&philCount))
|
||||
@@ -730,15 +730,13 @@ func TestManager_Token_MaxCount_AutoDelete(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestManager_EnqueueStats_ResetStats(t *testing.T) {
|
||||
filename := filepath.Join(t.TempDir(), "db")
|
||||
conf := &Config{
|
||||
Filename: filepath.Join(t.TempDir(), "db"),
|
||||
StartupQueries: "",
|
||||
DefaultAccess: PermissionReadWrite,
|
||||
BcryptCost: bcrypt.MinCost,
|
||||
QueueWriterInterval: 1500 * time.Millisecond,
|
||||
}
|
||||
a, err := NewManager(conf)
|
||||
require.Nil(t, err)
|
||||
a := newTestManagerFromStoreConfig(t, filename, conf)
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
|
||||
|
||||
// Baseline: No messages or emails
|
||||
@@ -779,15 +777,13 @@ func TestManager_EnqueueStats_ResetStats(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestManager_EnqueueTokenUpdate(t *testing.T) {
|
||||
filename := filepath.Join(t.TempDir(), "db")
|
||||
conf := &Config{
|
||||
Filename: filepath.Join(t.TempDir(), "db"),
|
||||
StartupQueries: "",
|
||||
DefaultAccess: PermissionReadWrite,
|
||||
BcryptCost: bcrypt.MinCost,
|
||||
QueueWriterInterval: 500 * time.Millisecond,
|
||||
}
|
||||
a, err := NewManager(conf)
|
||||
require.Nil(t, err)
|
||||
a := newTestManagerFromStoreConfig(t, filename, conf)
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
|
||||
|
||||
// Create user and token
|
||||
@@ -819,15 +815,13 @@ func TestManager_EnqueueTokenUpdate(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestManager_ChangeSettings(t *testing.T) {
|
||||
filename := filepath.Join(t.TempDir(), "db")
|
||||
conf := &Config{
|
||||
Filename: filepath.Join(t.TempDir(), "db"),
|
||||
StartupQueries: "",
|
||||
DefaultAccess: PermissionReadWrite,
|
||||
BcryptCost: bcrypt.MinCost,
|
||||
QueueWriterInterval: 1500 * time.Millisecond,
|
||||
}
|
||||
a, err := NewManager(conf)
|
||||
require.Nil(t, err)
|
||||
a := newTestManagerFromStoreConfig(t, filename, conf)
|
||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
|
||||
|
||||
// No settings
|
||||
@@ -1053,7 +1047,7 @@ func TestUser_PhoneNumberAddListRemove(t *testing.T) {
|
||||
require.Equal(t, 0, len(phoneNumbers))
|
||||
|
||||
// Paranoia check: We do NOT want to keep phone numbers in there
|
||||
rows, err := a.db.Query(`SELECT * FROM user_phone`)
|
||||
rows, err := testDB(a).Query(`SELECT * FROM user_phone`)
|
||||
require.Nil(t, err)
|
||||
require.False(t, rows.Next())
|
||||
require.Nil(t, rows.Close())
|
||||
@@ -1098,7 +1092,6 @@ func TestManager_Topic_Wildcard_With_Underscore(t *testing.T) {
|
||||
func TestManager_WithProvisionedUsers(t *testing.T) {
|
||||
f := filepath.Join(t.TempDir(), "user.db")
|
||||
conf := &Config{
|
||||
Filename: f,
|
||||
DefaultAccess: PermissionReadWrite,
|
||||
ProvisionEnabled: true,
|
||||
Users: []*User{
|
||||
@@ -1117,8 +1110,7 @@ func TestManager_WithProvisionedUsers(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
a, err := NewManager(conf)
|
||||
require.Nil(t, err)
|
||||
a := newTestManagerFromStoreConfig(t, f, conf)
|
||||
|
||||
// Manually add user
|
||||
require.Nil(t, a.AddUser("philmanual", "manual", RoleUser, false))
|
||||
@@ -1154,13 +1146,11 @@ func TestManager_WithProvisionedUsers(t *testing.T) {
|
||||
// Update the token last access time and origin (so we can check that it is persisted)
|
||||
lastAccessTime := time.Now().Add(time.Hour)
|
||||
lastOrigin := netip.MustParseAddr("1.1.9.9")
|
||||
err = execTx(a.db, func(tx *sql.Tx) error {
|
||||
return a.updateTokenLastAccessTx(tx, tokens[0].Value, lastAccessTime.Unix(), lastOrigin.String())
|
||||
})
|
||||
err = a.store.UpdateTokenLastAccess(tokens[0].Value, lastAccessTime, lastOrigin)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Re-open the DB (second app start)
|
||||
require.Nil(t, a.db.Close())
|
||||
require.Nil(t, a.Close())
|
||||
conf.Users = []*User{
|
||||
{Name: "philuser", Hash: "$2a$10$AAAAU21sX1uhZamTLJXHuxgVC0Z/GKISibrKCLohPgtG7yIxSk4C", Role: RoleUser},
|
||||
}
|
||||
@@ -1176,8 +1166,7 @@ func TestManager_WithProvisionedUsers(t *testing.T) {
|
||||
{Value: "tk_u48wqendnkx9er21pqqcadlytbutx", Label: "Another token"},
|
||||
},
|
||||
}
|
||||
a, err = NewManager(conf)
|
||||
require.Nil(t, err)
|
||||
a = newTestManagerFromStoreConfig(t, f, conf)
|
||||
|
||||
// Check that the provisioned users are there
|
||||
users, err = a.Users()
|
||||
@@ -1212,12 +1201,11 @@ func TestManager_WithProvisionedUsers(t *testing.T) {
|
||||
require.Error(t, a.ChangePassword("philuser", "new-pass", false))
|
||||
|
||||
// Re-open the DB again (third app start)
|
||||
require.Nil(t, a.db.Close())
|
||||
require.Nil(t, a.Close())
|
||||
conf.Users = []*User{}
|
||||
conf.Access = map[string][]*Grant{}
|
||||
conf.Tokens = map[string][]*Token{}
|
||||
a, err = NewManager(conf)
|
||||
require.Nil(t, err)
|
||||
a = newTestManagerFromStoreConfig(t, f, conf)
|
||||
|
||||
// Check that the provisioned users are all gone
|
||||
users, err = a.Users()
|
||||
@@ -1237,17 +1225,16 @@ func TestManager_WithProvisionedUsers(t *testing.T) {
|
||||
require.Equal(t, 0, len(tokens))
|
||||
|
||||
var count int
|
||||
a.db.QueryRow("SELECT COUNT(*) FROM user WHERE provisioned = 1").Scan(&count)
|
||||
testDB(a).QueryRow("SELECT COUNT(*) FROM user WHERE provisioned = 1").Scan(&count)
|
||||
require.Equal(t, 0, count)
|
||||
a.db.QueryRow("SELECT COUNT(*) FROM user_grant WHERE provisioned = 1").Scan(&count)
|
||||
testDB(a).QueryRow("SELECT COUNT(*) FROM user_access WHERE provisioned = 1").Scan(&count)
|
||||
require.Equal(t, 0, count)
|
||||
a.db.QueryRow("SELECT COUNT(*) FROM user_token WHERE provisioned = 1").Scan(&count)
|
||||
testDB(a).QueryRow("SELECT COUNT(*) FROM user_token WHERE provisioned = 1").Scan(&count)
|
||||
}
|
||||
|
||||
func TestManager_UpdateNonProvisionedUsersToProvisionedUsers(t *testing.T) {
|
||||
f := filepath.Join(t.TempDir(), "user.db")
|
||||
conf := &Config{
|
||||
Filename: f,
|
||||
DefaultAccess: PermissionReadWrite,
|
||||
ProvisionEnabled: true,
|
||||
Users: []*User{},
|
||||
@@ -1257,8 +1244,7 @@ func TestManager_UpdateNonProvisionedUsersToProvisionedUsers(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
a, err := NewManager(conf)
|
||||
require.Nil(t, err)
|
||||
a := newTestManagerFromStoreConfig(t, f, conf)
|
||||
|
||||
// Manually add user
|
||||
require.Nil(t, a.AddUser("philuser", "manual", RoleUser, false))
|
||||
@@ -1290,7 +1276,7 @@ func TestManager_UpdateNonProvisionedUsersToProvisionedUsers(t *testing.T) {
|
||||
require.True(t, grants[0].Provisioned) // Provisioned entry
|
||||
|
||||
// Re-open the DB (second app start)
|
||||
require.Nil(t, a.db.Close())
|
||||
require.Nil(t, a.Close())
|
||||
conf.Users = []*User{
|
||||
{Name: "philuser", Hash: "$2a$10$AAAAU21sX1uhZamTLJXHuxgVC0Z/GKISibrKCLohPgtG7yIxSk4C", Role: RoleUser},
|
||||
}
|
||||
@@ -1299,8 +1285,7 @@ func TestManager_UpdateNonProvisionedUsersToProvisionedUsers(t *testing.T) {
|
||||
{TopicPattern: "stats", Permission: PermissionReadWrite},
|
||||
},
|
||||
}
|
||||
a, err = NewManager(conf)
|
||||
require.Nil(t, err)
|
||||
a = newTestManagerFromStoreConfig(t, f, conf)
|
||||
|
||||
// Check that the user was "upgraded" to a provisioned user
|
||||
users, err = a.Users()
|
||||
@@ -1383,7 +1368,7 @@ func TestMigrationFrom1(t *testing.T) {
|
||||
|
||||
// Create manager to trigger migration
|
||||
a := newTestManagerFromFile(t, filename, "", PermissionDenyAll, bcrypt.MinCost, DefaultUserStatsQueueWriterInterval)
|
||||
checkSchemaVersion(t, a.db)
|
||||
checkSchemaVersion(t, testDB(a))
|
||||
|
||||
users, err := a.Users()
|
||||
require.Nil(t, err)
|
||||
@@ -1526,7 +1511,7 @@ func TestMigrationFrom4(t *testing.T) {
|
||||
|
||||
// Create manager to trigger migration
|
||||
a := newTestManagerFromFile(t, filename, "", PermissionDenyAll, bcrypt.MinCost, DefaultUserStatsQueueWriterInterval)
|
||||
checkSchemaVersion(t, a.db)
|
||||
checkSchemaVersion(t, testDB(a))
|
||||
|
||||
// Add another
|
||||
require.Nil(t, a.AllowAccess(Everyone, "left_*", PermissionReadWrite))
|
||||
@@ -1587,14 +1572,26 @@ func newTestManager(t *testing.T, defaultAccess Permission) *Manager {
|
||||
}
|
||||
|
||||
func newTestManagerFromFile(t *testing.T, filename, startupQueries string, defaultAccess Permission, bcryptCost int, statsWriterInterval time.Duration) *Manager {
|
||||
store, err := NewSQLiteStore(filename, startupQueries)
|
||||
require.Nil(t, err)
|
||||
conf := &Config{
|
||||
Filename: filename,
|
||||
StartupQueries: startupQueries,
|
||||
DefaultAccess: defaultAccess,
|
||||
BcryptCost: bcryptCost,
|
||||
QueueWriterInterval: statsWriterInterval,
|
||||
}
|
||||
a, err := NewManager(conf)
|
||||
a, err := NewManager(store, conf)
|
||||
require.Nil(t, err)
|
||||
return a
|
||||
}
|
||||
|
||||
func testDB(a *Manager) *sql.DB {
|
||||
return a.store.(*commonStore).db
|
||||
}
|
||||
|
||||
func newTestManagerFromStoreConfig(t *testing.T, filename string, conf *Config) *Manager {
|
||||
store, err := NewSQLiteStore(filename, "")
|
||||
require.Nil(t, err)
|
||||
a, err := NewManager(store, conf)
|
||||
require.Nil(t, err)
|
||||
return a
|
||||
}
|
||||
|
||||
992
user/store.go
Normal file
992
user/store.go
Normal file
@@ -0,0 +1,992 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"heckel.io/ntfy/v2/payments"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
"net/netip"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Store is the interface for a user database store
|
||||
type Store interface {
|
||||
// User operations
|
||||
UserByID(id string) (*User, error)
|
||||
User(username string) (*User, error)
|
||||
UserByToken(token string) (*User, error)
|
||||
UserByStripeCustomer(customerID string) (*User, error)
|
||||
Users() ([]*User, error)
|
||||
UsersCount() (int64, error)
|
||||
AddUser(username, hash string, role Role, provisioned bool) error
|
||||
RemoveUser(username string) error
|
||||
MarkUserRemoved(userID string) error
|
||||
RemoveDeletedUsers() error
|
||||
ChangePassword(username, hash string) error
|
||||
ChangeRole(username string, role Role) error
|
||||
ChangeProvisioned(username string, provisioned bool) error
|
||||
ChangeSettings(userID string, prefs *Prefs) error
|
||||
ChangeTier(username, tierCode string) error
|
||||
ResetTier(username string) error
|
||||
UpdateStats(userID string, stats *Stats) error
|
||||
ResetStats() error
|
||||
// Token operations
|
||||
CreateToken(userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, provisioned bool) (*Token, error)
|
||||
Token(userID, token string) (*Token, error)
|
||||
Tokens(userID string) ([]*Token, error)
|
||||
AllProvisionedTokens() ([]*Token, error)
|
||||
ChangeTokenLabel(userID, token, label string) error
|
||||
ChangeTokenExpiry(userID, token string, expires time.Time) error
|
||||
UpdateTokenLastAccess(token string, lastAccess time.Time, lastOrigin netip.Addr) error
|
||||
RemoveToken(userID, token string) error
|
||||
RemoveExpiredTokens() error
|
||||
TokenCount(userID string) (int, error)
|
||||
RemoveExcessTokens(userID string, maxCount int) error
|
||||
// Access operations
|
||||
AuthorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error)
|
||||
AllGrants() (map[string][]Grant, error)
|
||||
Grants(username string) ([]Grant, error)
|
||||
AllowAccess(username, topicPattern string, read, write bool, ownerUsername string, provisioned bool) error
|
||||
ResetAccess(username, topicPattern string) error
|
||||
ResetAllProvisionedAccess() error
|
||||
Reservations(username string) ([]Reservation, error)
|
||||
HasReservation(username, topic string) (bool, error)
|
||||
ReservationsCount(username string) (int64, error)
|
||||
ReservationOwner(topic string) (string, error)
|
||||
OtherAccessCount(username, topic string) (int, error)
|
||||
// Tier operations
|
||||
AddTier(tier *Tier) error
|
||||
UpdateTier(tier *Tier) error
|
||||
RemoveTier(code string) error
|
||||
Tiers() ([]*Tier, error)
|
||||
Tier(code string) (*Tier, error)
|
||||
TierByStripePrice(priceID string) (*Tier, error)
|
||||
// Phone operations
|
||||
PhoneNumbers(userID string) ([]string, error)
|
||||
AddPhoneNumber(userID, phoneNumber string) error
|
||||
RemovePhoneNumber(userID, phoneNumber string) error
|
||||
// Billing
|
||||
ChangeBilling(username string, billing *Billing) error
|
||||
// Internal helpers
|
||||
UserIDByUsername(username string) (string, error)
|
||||
// System
|
||||
Close() error
|
||||
}
|
||||
|
||||
// storeQueries holds the database-specific SQL queries
|
||||
type storeQueries struct {
|
||||
// User queries
|
||||
selectUserByID string
|
||||
selectUserByName string
|
||||
selectUserByToken string
|
||||
selectUserByStripeID string
|
||||
selectUsernames string
|
||||
selectUserCount string
|
||||
selectUserIDFromUsername string
|
||||
insertUser string
|
||||
updateUserPass string
|
||||
updateUserRole string
|
||||
updateUserProvisioned string
|
||||
updateUserPrefs string
|
||||
updateUserStats string
|
||||
updateUserStatsResetAll string
|
||||
updateUserTier string
|
||||
updateUserDeleted string
|
||||
deleteUser string
|
||||
deleteUserTier string
|
||||
deleteUsersMarked string
|
||||
// Access queries
|
||||
selectTopicPerms string
|
||||
selectUserAllAccess string
|
||||
selectUserAccess string
|
||||
selectUserReservations string
|
||||
selectUserReservationsCount string
|
||||
selectUserReservationsOwner string
|
||||
selectUserHasReservation string
|
||||
selectOtherAccessCount string
|
||||
upsertUserAccess string
|
||||
deleteUserAccess string
|
||||
deleteUserAccessProvisioned string
|
||||
deleteTopicAccess string
|
||||
deleteAllAccess string
|
||||
// Token queries
|
||||
selectToken string
|
||||
selectTokens string
|
||||
selectTokenCount string
|
||||
selectAllProvisionedTokens string
|
||||
upsertToken string
|
||||
updateTokenLabel string
|
||||
updateTokenExpiry string
|
||||
updateTokenLastAccess string
|
||||
deleteToken string
|
||||
deleteProvisionedToken string
|
||||
deleteAllToken string
|
||||
deleteExpiredTokens string
|
||||
deleteExcessTokens string
|
||||
// Tier queries
|
||||
insertTier string
|
||||
selectTiers string
|
||||
selectTierByCode string
|
||||
selectTierByPriceID string
|
||||
updateTier string
|
||||
deleteTier string
|
||||
// Phone queries
|
||||
selectPhoneNumbers string
|
||||
insertPhoneNumber string
|
||||
deletePhoneNumber string
|
||||
// Billing queries
|
||||
updateBilling string
|
||||
}
|
||||
|
||||
// commonStore implements store operations that work across database backends
|
||||
type commonStore struct {
|
||||
db *sql.DB
|
||||
queries storeQueries
|
||||
}
|
||||
|
||||
// UserByID returns the user with the given ID if it exists, or ErrUserNotFound otherwise
|
||||
func (s *commonStore) UserByID(id string) (*User, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUserByID, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.readUser(rows)
|
||||
}
|
||||
|
||||
// User returns the user with the given username if it exists, or ErrUserNotFound otherwise
|
||||
func (s *commonStore) User(username string) (*User, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUserByName, username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.readUser(rows)
|
||||
}
|
||||
|
||||
// UserByToken returns the user with the given token if it exists and is not expired, or ErrUserNotFound otherwise
|
||||
func (s *commonStore) UserByToken(token string) (*User, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUserByToken, token, time.Now().Unix())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.readUser(rows)
|
||||
}
|
||||
|
||||
// UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise
|
||||
func (s *commonStore) UserByStripeCustomer(customerID string) (*User, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUserByStripeID, customerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.readUser(rows)
|
||||
}
|
||||
|
||||
// Users returns a list of users
|
||||
func (s *commonStore) Users() ([]*User, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUsernames)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
usernames := make([]string, 0)
|
||||
for rows.Next() {
|
||||
var username string
|
||||
if err := rows.Scan(&username); err != nil {
|
||||
return nil, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
usernames = append(usernames, username)
|
||||
}
|
||||
rows.Close()
|
||||
users := make([]*User, 0)
|
||||
for _, username := range usernames {
|
||||
user, err := s.User(username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
users = append(users, user)
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// UsersCount returns the number of users in the database
|
||||
func (s *commonStore) UsersCount() (int64, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUserCount)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return 0, errNoRows
|
||||
}
|
||||
var count int64
|
||||
if err := rows.Scan(&count); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// AddUser adds a user with the given username, password hash and role
|
||||
func (s *commonStore) AddUser(username, hash string, role Role, provisioned bool) error {
|
||||
if !AllowedUsername(username) || !AllowedRole(role) {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
|
||||
syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
|
||||
now := time.Now().Unix()
|
||||
if _, err := s.db.Exec(s.queries.insertUser, userID, username, hash, string(role), syncTopic, provisioned, now); err != nil {
|
||||
if isUniqueConstraintError(err) {
|
||||
return ErrUserExists
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveUser deletes the user with the given username
|
||||
func (s *commonStore) RemoveUser(username string) error {
|
||||
if !AllowedUsername(username) {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
// Rows in user_access, user_token, etc. are deleted via foreign keys
|
||||
if _, err := s.db.Exec(s.queries.deleteUser, username); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkUserRemoved sets the deleted flag on the user, and deletes all access tokens
|
||||
func (s *commonStore) MarkUserRemoved(userID string) error {
|
||||
tx, err := s.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
// Get username for deleteUserAccess query
|
||||
user, err := s.UserByID(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(s.queries.deleteUserAccess, user.Name, user.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(s.queries.deleteAllToken, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
deletedTime := time.Now().Add(userHardDeleteAfterDuration).Unix()
|
||||
if _, err := tx.Exec(s.queries.updateUserDeleted, deletedTime, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// RemoveDeletedUsers deletes all users that have been marked deleted
|
||||
func (s *commonStore) RemoveDeletedUsers() error {
|
||||
if _, err := s.db.Exec(s.queries.deleteUsersMarked, time.Now().Unix()); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ChangePassword changes a user's password
|
||||
func (s *commonStore) ChangePassword(username, hash string) error {
|
||||
if _, err := s.db.Exec(s.queries.updateUserPass, hash, username); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ChangeRole changes a user's role
|
||||
func (s *commonStore) ChangeRole(username string, role Role) error {
|
||||
if !AllowedUsername(username) || !AllowedRole(role) {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
tx, err := s.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(s.queries.updateUserRole, string(role), username); err != nil {
|
||||
return err
|
||||
}
|
||||
// If changing to admin, remove all access entries
|
||||
if role == RoleAdmin {
|
||||
if _, err := tx.Exec(s.queries.deleteUserAccess, username, username); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// ChangeProvisioned changes the provisioned status of a user
|
||||
func (s *commonStore) ChangeProvisioned(username string, provisioned bool) error {
|
||||
if _, err := s.db.Exec(s.queries.updateUserProvisioned, provisioned, username); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ChangeSettings persists the user settings
|
||||
func (s *commonStore) ChangeSettings(userID string, prefs *Prefs) error {
|
||||
b, err := json.Marshal(prefs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := s.db.Exec(s.queries.updateUserPrefs, string(b), userID); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ChangeTier changes a user's tier using the tier code
|
||||
func (s *commonStore) ChangeTier(username, tierCode string) error {
|
||||
if _, err := s.db.Exec(s.queries.updateUserTier, tierCode, username); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetTier removes the tier from the given user
|
||||
func (s *commonStore) ResetTier(username string) error {
|
||||
if !AllowedUsername(username) && username != Everyone && username != "" {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
_, err := s.db.Exec(s.queries.deleteUserTier, username)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateStats updates the user statistics
|
||||
func (s *commonStore) UpdateStats(userID string, stats *Stats) error {
|
||||
if _, err := s.db.Exec(s.queries.updateUserStats, stats.Messages, stats.Emails, stats.Calls, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetStats resets all user stats in the user database
|
||||
func (s *commonStore) ResetStats() error {
|
||||
if _, err := s.db.Exec(s.queries.updateUserStatsResetAll); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (s *commonStore) readUser(rows *sql.Rows) (*User, error) {
|
||||
defer rows.Close()
|
||||
var id, username, hash, role, prefs, syncTopic string
|
||||
var provisioned bool
|
||||
var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripeSubscriptionInterval, stripeMonthlyPriceID, stripeYearlyPriceID, tierID, tierCode, tierName sql.NullString
|
||||
var messages, emails, calls int64
|
||||
var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
|
||||
if !rows.Next() {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &provisioned, &messages, &emails, &calls, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionInterval, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
|
||||
return nil, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
user := &User{
|
||||
ID: id,
|
||||
Name: username,
|
||||
Hash: hash,
|
||||
Role: Role(role),
|
||||
Prefs: &Prefs{},
|
||||
SyncTopic: syncTopic,
|
||||
Provisioned: provisioned,
|
||||
Stats: &Stats{
|
||||
Messages: messages,
|
||||
Emails: emails,
|
||||
Calls: calls,
|
||||
},
|
||||
Billing: &Billing{
|
||||
StripeCustomerID: stripeCustomerID.String,
|
||||
StripeSubscriptionID: stripeSubscriptionID.String,
|
||||
StripeSubscriptionStatus: payments.SubscriptionStatus(stripeSubscriptionStatus.String),
|
||||
StripeSubscriptionInterval: payments.PriceRecurringInterval(stripeSubscriptionInterval.String),
|
||||
StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0),
|
||||
StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0),
|
||||
},
|
||||
Deleted: deleted.Valid,
|
||||
}
|
||||
if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tierCode.Valid {
|
||||
user.Tier = &Tier{
|
||||
ID: tierID.String,
|
||||
Code: tierCode.String,
|
||||
Name: tierName.String,
|
||||
MessageLimit: messagesLimit.Int64,
|
||||
MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
|
||||
EmailLimit: emailsLimit.Int64,
|
||||
CallLimit: callsLimit.Int64,
|
||||
ReservationLimit: reservationsLimit.Int64,
|
||||
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
|
||||
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
|
||||
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
|
||||
AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
|
||||
StripeMonthlyPriceID: stripeMonthlyPriceID.String,
|
||||
StripeYearlyPriceID: stripeYearlyPriceID.String,
|
||||
}
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// CreateToken creates a new token
|
||||
func (s *commonStore) CreateToken(userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, provisioned bool) (*Token, error) {
|
||||
if _, err := s.db.Exec(s.queries.upsertToken, userID, token, label, lastAccess.Unix(), lastOrigin.String(), expires.Unix(), provisioned); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Token{
|
||||
Value: token,
|
||||
Label: label,
|
||||
LastAccess: lastAccess,
|
||||
LastOrigin: lastOrigin,
|
||||
Expires: expires,
|
||||
Provisioned: provisioned,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Token returns a specific token for a user
|
||||
func (s *commonStore) Token(userID, token string) (*Token, error) {
|
||||
rows, err := s.db.Query(s.queries.selectToken, userID, token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return s.readToken(rows)
|
||||
}
|
||||
|
||||
// Tokens returns all existing tokens for the user with the given user ID
|
||||
func (s *commonStore) Tokens(userID string) ([]*Token, error) {
|
||||
rows, err := s.db.Query(s.queries.selectTokens, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
tokens := make([]*Token, 0)
|
||||
for {
|
||||
token, err := s.readToken(rows)
|
||||
if errors.Is(err, ErrTokenNotFound) {
|
||||
break
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokens = append(tokens, token)
|
||||
}
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// AllProvisionedTokens returns all provisioned tokens
|
||||
func (s *commonStore) AllProvisionedTokens() ([]*Token, error) {
|
||||
rows, err := s.db.Query(s.queries.selectAllProvisionedTokens)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
tokens := make([]*Token, 0)
|
||||
for {
|
||||
token, err := s.readToken(rows)
|
||||
if errors.Is(err, ErrTokenNotFound) {
|
||||
break
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokens = append(tokens, token)
|
||||
}
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// ChangeTokenLabel updates a token's label
|
||||
func (s *commonStore) ChangeTokenLabel(userID, token, label string) error {
|
||||
if _, err := s.db.Exec(s.queries.updateTokenLabel, label, userID, token); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ChangeTokenExpiry updates a token's expiry time
|
||||
func (s *commonStore) ChangeTokenExpiry(userID, token string, expires time.Time) error {
|
||||
if _, err := s.db.Exec(s.queries.updateTokenExpiry, expires.Unix(), userID, token); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateTokenLastAccess updates a token's last access time and origin
|
||||
func (s *commonStore) UpdateTokenLastAccess(token string, lastAccess time.Time, lastOrigin netip.Addr) error {
|
||||
if _, err := s.db.Exec(s.queries.updateTokenLastAccess, lastAccess.Unix(), lastOrigin.String(), token); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveToken deletes the token
|
||||
func (s *commonStore) RemoveToken(userID, token string) error {
|
||||
if token == "" {
|
||||
return errNoTokenProvided
|
||||
}
|
||||
if _, err := s.db.Exec(s.queries.deleteToken, userID, token); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveExpiredTokens deletes all expired tokens from the database
|
||||
func (s *commonStore) RemoveExpiredTokens() error {
|
||||
if _, err := s.db.Exec(s.queries.deleteExpiredTokens, time.Now().Unix()); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TokenCount returns the number of tokens for a user
|
||||
func (s *commonStore) TokenCount(userID string) (int, error) {
|
||||
rows, err := s.db.Query(s.queries.selectTokenCount, userID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return 0, errNoRows
|
||||
}
|
||||
var count int
|
||||
if err := rows.Scan(&count); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// RemoveExcessTokens deletes excess tokens beyond the specified maximum
|
||||
func (s *commonStore) RemoveExcessTokens(userID string, maxCount int) error {
|
||||
if _, err := s.db.Exec(s.queries.deleteExcessTokens, userID, userID, maxCount); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (s *commonStore) readToken(rows *sql.Rows) (*Token, error) {
|
||||
var token, label, lastOrigin string
|
||||
var lastAccess, expires int64
|
||||
var provisioned bool
|
||||
if !rows.Next() {
|
||||
return nil, ErrTokenNotFound
|
||||
}
|
||||
if err := rows.Scan(&token, &label, &lastAccess, &lastOrigin, &expires, &provisioned); err != nil {
|
||||
return nil, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lastOriginIP, err := netip.ParseAddr(lastOrigin)
|
||||
if err != nil {
|
||||
lastOriginIP = netip.IPv4Unspecified()
|
||||
}
|
||||
return &Token{
|
||||
Value: token,
|
||||
Label: label,
|
||||
LastAccess: time.Unix(lastAccess, 0),
|
||||
LastOrigin: lastOriginIP,
|
||||
Expires: time.Unix(expires, 0),
|
||||
Provisioned: provisioned,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AuthorizeTopicAccess returns the read/write permissions for the given username and topic.
|
||||
// The found return value indicates whether an ACL entry was found at all.
|
||||
func (s *commonStore) AuthorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error) {
|
||||
rows, err := s.db.Query(s.queries.selectTopicPerms, Everyone, usernameOrEveryone, topic)
|
||||
if err != nil {
|
||||
return false, false, false, err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return false, false, false, nil
|
||||
}
|
||||
if err := rows.Scan(&read, &write); err != nil {
|
||||
return false, false, false, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return false, false, false, err
|
||||
}
|
||||
return read, write, true, nil
|
||||
}
|
||||
|
||||
// AllGrants returns all user-specific access control entries, mapped to their respective user IDs
|
||||
func (s *commonStore) AllGrants() (map[string][]Grant, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUserAllAccess)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
grants := make(map[string][]Grant, 0)
|
||||
for rows.Next() {
|
||||
var userID, topic string
|
||||
var read, write, provisioned bool
|
||||
if err := rows.Scan(&userID, &topic, &read, &write, &provisioned); err != nil {
|
||||
return nil, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, ok := grants[userID]; !ok {
|
||||
grants[userID] = make([]Grant, 0)
|
||||
}
|
||||
grants[userID] = append(grants[userID], Grant{
|
||||
TopicPattern: fromSQLWildcard(topic),
|
||||
Permission: NewPermission(read, write),
|
||||
Provisioned: provisioned,
|
||||
})
|
||||
}
|
||||
return grants, nil
|
||||
}
|
||||
|
||||
// Grants returns all user-specific access control entries
|
||||
func (s *commonStore) Grants(username string) ([]Grant, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUserAccess, username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
grants := make([]Grant, 0)
|
||||
for rows.Next() {
|
||||
var topic string
|
||||
var read, write, provisioned bool
|
||||
if err := rows.Scan(&topic, &read, &write, &provisioned); err != nil {
|
||||
return nil, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
grants = append(grants, Grant{
|
||||
TopicPattern: fromSQLWildcard(topic),
|
||||
Permission: NewPermission(read, write),
|
||||
Provisioned: provisioned,
|
||||
})
|
||||
}
|
||||
return grants, nil
|
||||
}
|
||||
|
||||
// AllowAccess adds or updates an entry in the access control list
|
||||
func (s *commonStore) AllowAccess(username, topicPattern string, read, write bool, ownerUsername string, provisioned bool) error {
|
||||
if !AllowedUsername(username) && username != Everyone {
|
||||
return ErrInvalidArgument
|
||||
} else if !AllowedTopicPattern(topicPattern) {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
if _, err := s.db.Exec(s.queries.upsertUserAccess, username, toSQLWildcard(topicPattern), read, write, ownerUsername, ownerUsername, provisioned); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetAccess removes an access control list entry
|
||||
func (s *commonStore) ResetAccess(username, topicPattern string) error {
|
||||
if !AllowedUsername(username) && username != Everyone && username != "" {
|
||||
return ErrInvalidArgument
|
||||
} else if !AllowedTopicPattern(topicPattern) && topicPattern != "" {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
if username == "" && topicPattern == "" {
|
||||
_, err := s.db.Exec(s.queries.deleteAllAccess)
|
||||
return err
|
||||
} else if topicPattern == "" {
|
||||
_, err := s.db.Exec(s.queries.deleteUserAccess, username, username)
|
||||
return err
|
||||
}
|
||||
_, err := s.db.Exec(s.queries.deleteTopicAccess, username, username, toSQLWildcard(topicPattern))
|
||||
return err
|
||||
}
|
||||
|
||||
// ResetAllProvisionedAccess removes all provisioned access control entries
|
||||
func (s *commonStore) ResetAllProvisionedAccess() error {
|
||||
if _, err := s.db.Exec(s.queries.deleteUserAccessProvisioned); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reservations returns all user-owned topics, and the associated everyone-access
|
||||
func (s *commonStore) Reservations(username string) ([]Reservation, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUserReservations, Everyone, username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
reservations := make([]Reservation, 0)
|
||||
for rows.Next() {
|
||||
var topic string
|
||||
var ownerRead, ownerWrite bool
|
||||
var everyoneRead, everyoneWrite sql.NullBool
|
||||
if err := rows.Scan(&topic, &ownerRead, &ownerWrite, &everyoneRead, &everyoneWrite); err != nil {
|
||||
return nil, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reservations = append(reservations, Reservation{
|
||||
Topic: unescapeUnderscore(topic),
|
||||
Owner: NewPermission(ownerRead, ownerWrite),
|
||||
Everyone: NewPermission(everyoneRead.Bool, everyoneWrite.Bool),
|
||||
})
|
||||
}
|
||||
return reservations, nil
|
||||
}
|
||||
|
||||
// HasReservation returns true if the given topic access is owned by the user
|
||||
func (s *commonStore) HasReservation(username, topic string) (bool, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUserHasReservation, username, escapeUnderscore(topic))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return false, errNoRows
|
||||
}
|
||||
var count int64
|
||||
if err := rows.Scan(&count); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// ReservationsCount returns the number of reservations owned by this user
|
||||
func (s *commonStore) ReservationsCount(username string) (int64, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUserReservationsCount, username)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return 0, errNoRows
|
||||
}
|
||||
var count int64
|
||||
if err := rows.Scan(&count); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// ReservationOwner returns user ID of the user that owns this topic, or an empty string if it's not owned by anyone
|
||||
func (s *commonStore) ReservationOwner(topic string) (string, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUserReservationsOwner, escapeUnderscore(topic))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return "", nil
|
||||
}
|
||||
var ownerUserID string
|
||||
if err := rows.Scan(&ownerUserID); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return ownerUserID, nil
|
||||
}
|
||||
|
||||
// OtherAccessCount returns the number of access entries for the given topic that are not owned by the user
|
||||
func (s *commonStore) OtherAccessCount(username, topic string) (int, error) {
|
||||
rows, err := s.db.Query(s.queries.selectOtherAccessCount, escapeUnderscore(topic), escapeUnderscore(topic), username)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return 0, errNoRows
|
||||
}
|
||||
var count int
|
||||
if err := rows.Scan(&count); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// AddTier creates a new tier in the database
|
||||
func (s *commonStore) AddTier(tier *Tier) error {
|
||||
if tier.ID == "" {
|
||||
tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength)
|
||||
}
|
||||
if _, err := s.db.Exec(s.queries.insertTier, tier.ID, tier.Code, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.CallLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateTier updates a tier's properties in the database
|
||||
func (s *commonStore) UpdateTier(tier *Tier) error {
|
||||
if _, err := s.db.Exec(s.queries.updateTier, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.CallLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID), tier.Code); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveTier deletes the tier with the given code
|
||||
func (s *commonStore) RemoveTier(code string) error {
|
||||
if !AllowedTier(code) {
|
||||
return ErrInvalidArgument
|
||||
}
|
||||
if _, err := s.db.Exec(s.queries.deleteTier, code); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Tiers returns a list of all Tier structs
|
||||
func (s *commonStore) Tiers() ([]*Tier, error) {
|
||||
rows, err := s.db.Query(s.queries.selectTiers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
tiers := make([]*Tier, 0)
|
||||
for {
|
||||
tier, err := s.readTier(rows)
|
||||
if errors.Is(err, ErrTierNotFound) {
|
||||
break
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tiers = append(tiers, tier)
|
||||
}
|
||||
return tiers, nil
|
||||
}
|
||||
|
||||
// Tier returns a Tier based on the code, or ErrTierNotFound if it does not exist
|
||||
func (s *commonStore) Tier(code string) (*Tier, error) {
|
||||
rows, err := s.db.Query(s.queries.selectTierByCode, code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return s.readTier(rows)
|
||||
}
|
||||
|
||||
// TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist
|
||||
func (s *commonStore) TierByStripePrice(priceID string) (*Tier, error) {
|
||||
rows, err := s.db.Query(s.queries.selectTierByPriceID, priceID, priceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return s.readTier(rows)
|
||||
}
|
||||
func (s *commonStore) readTier(rows *sql.Rows) (*Tier, error) {
|
||||
var id, code, name string
|
||||
var stripeMonthlyPriceID, stripeYearlyPriceID sql.NullString
|
||||
var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit sql.NullInt64
|
||||
if !rows.Next() {
|
||||
return nil, ErrTierNotFound
|
||||
}
|
||||
if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
|
||||
return nil, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Tier{
|
||||
ID: id,
|
||||
Code: code,
|
||||
Name: name,
|
||||
MessageLimit: messagesLimit.Int64,
|
||||
MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
|
||||
EmailLimit: emailsLimit.Int64,
|
||||
CallLimit: callsLimit.Int64,
|
||||
ReservationLimit: reservationsLimit.Int64,
|
||||
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
|
||||
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
|
||||
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
|
||||
AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
|
||||
StripeMonthlyPriceID: stripeMonthlyPriceID.String,
|
||||
StripeYearlyPriceID: stripeYearlyPriceID.String,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// PhoneNumbers returns all phone numbers for the user with the given user ID
|
||||
func (s *commonStore) PhoneNumbers(userID string) ([]string, error) {
|
||||
rows, err := s.db.Query(s.queries.selectPhoneNumbers, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
phoneNumbers := make([]string, 0)
|
||||
for {
|
||||
phoneNumber, err := s.readPhoneNumber(rows)
|
||||
if errors.Is(err, ErrPhoneNumberNotFound) {
|
||||
break
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
phoneNumbers = append(phoneNumbers, phoneNumber)
|
||||
}
|
||||
return phoneNumbers, nil
|
||||
}
|
||||
|
||||
// AddPhoneNumber adds a phone number to the user with the given user ID
|
||||
func (s *commonStore) AddPhoneNumber(userID, phoneNumber string) error {
|
||||
if _, err := s.db.Exec(s.queries.insertPhoneNumber, userID, phoneNumber); err != nil {
|
||||
if isUniqueConstraintError(err) {
|
||||
return ErrPhoneNumberExists
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemovePhoneNumber deletes a phone number from the user with the given user ID
|
||||
func (s *commonStore) RemovePhoneNumber(userID, phoneNumber string) error {
|
||||
_, err := s.db.Exec(s.queries.deletePhoneNumber, userID, phoneNumber)
|
||||
return err
|
||||
}
|
||||
func (s *commonStore) readPhoneNumber(rows *sql.Rows) (string, error) {
|
||||
var phoneNumber string
|
||||
if !rows.Next() {
|
||||
return "", ErrPhoneNumberNotFound
|
||||
}
|
||||
if err := rows.Scan(&phoneNumber); err != nil {
|
||||
return "", err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return phoneNumber, nil
|
||||
}
|
||||
|
||||
// ChangeBilling updates a user's billing fields
|
||||
func (s *commonStore) ChangeBilling(username string, billing *Billing) error {
|
||||
if _, err := s.db.Exec(s.queries.updateBilling, nullString(billing.StripeCustomerID), nullString(billing.StripeSubscriptionID), nullString(string(billing.StripeSubscriptionStatus)), nullString(string(billing.StripeSubscriptionInterval)), nullInt64(billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(billing.StripeSubscriptionCancelAt.Unix()), username); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UserIDByUsername returns the user ID for the given username
|
||||
func (s *commonStore) UserIDByUsername(username string) (string, error) {
|
||||
rows, err := s.db.Query(s.queries.selectUserIDFromUsername, username)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return "", ErrUserNotFound
|
||||
}
|
||||
var userID string
|
||||
if err := rows.Scan(&userID); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
// Close closes the underlying database
|
||||
func (s *commonStore) Close() error {
|
||||
return s.db.Close()
|
||||
}
|
||||
|
||||
// isUniqueConstraintError checks if the error is a unique constraint violation for both SQLite and PostgreSQL
|
||||
func isUniqueConstraintError(err error) bool {
|
||||
// Check for SQLite unique constraint error
|
||||
if sqliteErr, ok := err.(interface{ ExtendedCode() int }); ok {
|
||||
if sqliteErr.ExtendedCode() == 2067 { // sqlite3.ErrConstraintUnique
|
||||
return true
|
||||
}
|
||||
}
|
||||
// Check for PostgreSQL unique violation (error code 23505)
|
||||
if pgErr, ok := err.(interface{ Code() string }); ok {
|
||||
if pgErr.Code() == "23505" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
424
user/store_postgres.go
Normal file
424
user/store_postgres.go
Normal file
@@ -0,0 +1,424 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
|
||||
)
|
||||
|
||||
// PostgreSQL schema and queries
|
||||
const (
|
||||
postgresCreateTablesQueries = `
|
||||
CREATE TABLE IF NOT EXISTS tier (
|
||||
id TEXT PRIMARY KEY,
|
||||
code TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
messages_limit BIGINT NOT NULL,
|
||||
messages_expiry_duration BIGINT NOT NULL,
|
||||
emails_limit BIGINT NOT NULL,
|
||||
calls_limit BIGINT NOT NULL,
|
||||
reservations_limit BIGINT NOT NULL,
|
||||
attachment_file_size_limit BIGINT NOT NULL,
|
||||
attachment_total_size_limit BIGINT NOT NULL,
|
||||
attachment_expiry_duration BIGINT NOT NULL,
|
||||
attachment_bandwidth_limit BIGINT NOT NULL,
|
||||
stripe_monthly_price_id TEXT,
|
||||
stripe_yearly_price_id TEXT,
|
||||
UNIQUE(code),
|
||||
UNIQUE(stripe_monthly_price_id),
|
||||
UNIQUE(stripe_yearly_price_id)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS "user" (
|
||||
id TEXT PRIMARY KEY,
|
||||
tier_id TEXT REFERENCES tier(id),
|
||||
user_name TEXT NOT NULL UNIQUE,
|
||||
pass TEXT NOT NULL,
|
||||
role TEXT NOT NULL CHECK (role IN ('anonymous', 'admin', 'user')),
|
||||
prefs JSONB NOT NULL DEFAULT '{}',
|
||||
sync_topic TEXT NOT NULL,
|
||||
provisioned BOOLEAN NOT NULL,
|
||||
stats_messages BIGINT NOT NULL DEFAULT 0,
|
||||
stats_emails BIGINT NOT NULL DEFAULT 0,
|
||||
stats_calls BIGINT NOT NULL DEFAULT 0,
|
||||
stripe_customer_id TEXT UNIQUE,
|
||||
stripe_subscription_id TEXT UNIQUE,
|
||||
stripe_subscription_status TEXT,
|
||||
stripe_subscription_interval TEXT,
|
||||
stripe_subscription_paid_until BIGINT,
|
||||
stripe_subscription_cancel_at BIGINT,
|
||||
created BIGINT NOT NULL,
|
||||
deleted BIGINT
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS user_access (
|
||||
user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
|
||||
topic TEXT NOT NULL,
|
||||
read BOOLEAN NOT NULL,
|
||||
write BOOLEAN NOT NULL,
|
||||
owner_user_id TEXT REFERENCES "user"(id) ON DELETE CASCADE,
|
||||
provisioned BOOLEAN NOT NULL,
|
||||
PRIMARY KEY (user_id, topic)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS user_token (
|
||||
user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
|
||||
token TEXT NOT NULL UNIQUE,
|
||||
label TEXT NOT NULL,
|
||||
last_access BIGINT NOT NULL,
|
||||
last_origin TEXT NOT NULL,
|
||||
expires BIGINT NOT NULL,
|
||||
provisioned BOOLEAN NOT NULL,
|
||||
PRIMARY KEY (user_id, token)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS user_phone (
|
||||
user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
|
||||
phone_number TEXT NOT NULL,
|
||||
PRIMARY KEY (user_id, phone_number)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS schema_version (
|
||||
store TEXT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
INSERT INTO "user" (id, user_name, pass, role, sync_topic, provisioned, created)
|
||||
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, EXTRACT(EPOCH FROM NOW())::BIGINT)
|
||||
ON CONFLICT (id) DO NOTHING;
|
||||
`
|
||||
|
||||
// User queries
|
||||
postgresSelectUserByID = `
|
||||
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||
FROM "user" u
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE u.id = $1
|
||||
`
|
||||
postgresSelectUserByName = `
|
||||
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||
FROM "user" u
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE user_name = $1
|
||||
`
|
||||
postgresSelectUserByToken = `
|
||||
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||
FROM "user" u
|
||||
JOIN user_token tk on u.id = tk.user_id
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE tk.token = $1 AND (tk.expires = 0 OR tk.expires >= $2)
|
||||
`
|
||||
postgresSelectUserByStripeID = `
|
||||
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||
FROM "user" u
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE u.stripe_customer_id = $1
|
||||
`
|
||||
postgresSelectUsernames = `
|
||||
SELECT user_name
|
||||
FROM "user"
|
||||
ORDER BY
|
||||
CASE role
|
||||
WHEN 'admin' THEN 1
|
||||
WHEN 'anonymous' THEN 3
|
||||
ELSE 2
|
||||
END, user_name
|
||||
`
|
||||
postgresSelectUserCount = `SELECT COUNT(*) FROM "user"`
|
||||
postgresSelectUserIDFromUsername = `SELECT id FROM "user" WHERE user_name = $1`
|
||||
postgresInsertUser = `INSERT INTO "user" (id, user_name, pass, role, sync_topic, provisioned, created) VALUES ($1, $2, $3, $4, $5, $6, $7)`
|
||||
postgresUpdateUserPass = `UPDATE "user" SET pass = $1 WHERE user_name = $2`
|
||||
postgresUpdateUserRole = `UPDATE "user" SET role = $1 WHERE user_name = $2`
|
||||
postgresUpdateUserProvisioned = `UPDATE "user" SET provisioned = $1 WHERE user_name = $2`
|
||||
postgresUpdateUserPrefs = `UPDATE "user" SET prefs = $1 WHERE id = $2`
|
||||
postgresUpdateUserStats = `UPDATE "user" SET stats_messages = $1, stats_emails = $2, stats_calls = $3 WHERE id = $4`
|
||||
postgresUpdateUserStatsResetAll = `UPDATE "user" SET stats_messages = 0, stats_emails = 0, stats_calls = 0`
|
||||
postgresUpdateUserTier = `UPDATE "user" SET tier_id = (SELECT id FROM tier WHERE code = $1) WHERE user_name = $2`
|
||||
postgresUpdateUserDeleted = `UPDATE "user" SET deleted = $1 WHERE id = $2`
|
||||
postgresDeleteUser = `DELETE FROM "user" WHERE user_name = $1`
|
||||
postgresDeleteUserTier = `UPDATE "user" SET tier_id = null WHERE user_name = $1`
|
||||
postgresDeleteUsersMarked = `DELETE FROM "user" WHERE deleted < $1`
|
||||
|
||||
// Access queries
|
||||
postgresSelectTopicPerms = `
|
||||
SELECT read, write
|
||||
FROM user_access a
|
||||
JOIN "user" u ON u.id = a.user_id
|
||||
WHERE (u.user_name = $1 OR u.user_name = $2) AND $3 LIKE a.topic ESCAPE '\'
|
||||
ORDER BY u.user_name DESC, LENGTH(a.topic) DESC, CASE WHEN a.write THEN 1 ELSE 0 END DESC
|
||||
`
|
||||
postgresSelectUserAllAccess = `
|
||||
SELECT user_id, topic, read, write, provisioned
|
||||
FROM user_access
|
||||
ORDER BY LENGTH(topic) DESC, CASE WHEN write THEN 1 ELSE 0 END DESC, CASE WHEN read THEN 1 ELSE 0 END DESC, topic
|
||||
`
|
||||
postgresSelectUserAccess = `
|
||||
SELECT topic, read, write, provisioned
|
||||
FROM user_access
|
||||
WHERE user_id = (SELECT id FROM "user" WHERE user_name = $1)
|
||||
ORDER BY LENGTH(topic) DESC, CASE WHEN write THEN 1 ELSE 0 END DESC, CASE WHEN read THEN 1 ELSE 0 END DESC, topic
|
||||
`
|
||||
postgresSelectUserReservations = `
|
||||
SELECT a_user.topic, a_user.read, a_user.write, a_everyone.read AS everyone_read, a_everyone.write AS everyone_write
|
||||
FROM user_access a_user
|
||||
LEFT JOIN user_access a_everyone ON a_user.topic = a_everyone.topic AND a_everyone.user_id = (SELECT id FROM "user" WHERE user_name = $1)
|
||||
WHERE a_user.user_id = a_user.owner_user_id
|
||||
AND a_user.owner_user_id = (SELECT id FROM "user" WHERE user_name = $2)
|
||||
ORDER BY a_user.topic
|
||||
`
|
||||
postgresSelectUserReservationsCount = `
|
||||
SELECT COUNT(*)
|
||||
FROM user_access
|
||||
WHERE user_id = owner_user_id
|
||||
AND owner_user_id = (SELECT id FROM "user" WHERE user_name = $1)
|
||||
`
|
||||
postgresSelectUserReservationsOwner = `
|
||||
SELECT owner_user_id
|
||||
FROM user_access
|
||||
WHERE topic = $1
|
||||
AND user_id = owner_user_id
|
||||
`
|
||||
postgresSelectUserHasReservation = `
|
||||
SELECT COUNT(*)
|
||||
FROM user_access
|
||||
WHERE user_id = owner_user_id
|
||||
AND owner_user_id = (SELECT id FROM "user" WHERE user_name = $1)
|
||||
AND topic = $2
|
||||
`
|
||||
postgresSelectOtherAccessCount = `
|
||||
SELECT COUNT(*)
|
||||
FROM user_access
|
||||
WHERE (topic = $1 OR $2 LIKE topic ESCAPE '\')
|
||||
AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM "user" WHERE user_name = $3))
|
||||
`
|
||||
postgresUpsertUserAccess = `
|
||||
INSERT INTO user_access (user_id, topic, read, write, owner_user_id, provisioned)
|
||||
VALUES (
|
||||
(SELECT id FROM "user" WHERE user_name = $1),
|
||||
$2,
|
||||
$3,
|
||||
$4,
|
||||
CASE WHEN $5 = '' THEN NULL ELSE (SELECT id FROM "user" WHERE user_name = $6) END,
|
||||
$7
|
||||
)
|
||||
ON CONFLICT (user_id, topic)
|
||||
DO UPDATE SET read=EXCLUDED.read, write=EXCLUDED.write, owner_user_id=EXCLUDED.owner_user_id, provisioned=EXCLUDED.provisioned
|
||||
`
|
||||
postgresDeleteUserAccess = `
|
||||
DELETE FROM user_access
|
||||
WHERE user_id = (SELECT id FROM "user" WHERE user_name = $1)
|
||||
OR owner_user_id = (SELECT id FROM "user" WHERE user_name = $2)
|
||||
`
|
||||
postgresDeleteUserAccessProvisioned = `DELETE FROM user_access WHERE provisioned = true`
|
||||
postgresDeleteTopicAccess = `
|
||||
DELETE FROM user_access
|
||||
WHERE (user_id = (SELECT id FROM "user" WHERE user_name = $1) OR owner_user_id = (SELECT id FROM "user" WHERE user_name = $2))
|
||||
AND topic = $3
|
||||
`
|
||||
postgresDeleteAllAccess = `DELETE FROM user_access`
|
||||
|
||||
// Token queries
|
||||
postgresSelectToken = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = $1 AND token = $2`
|
||||
postgresSelectTokens = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = $1`
|
||||
postgresSelectTokenCount = `SELECT COUNT(*) FROM user_token WHERE user_id = $1`
|
||||
postgresSelectAllProvisionedTokens = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE provisioned = true`
|
||||
postgresUpsertToken = `
|
||||
INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires, provisioned)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
ON CONFLICT (user_id, token)
|
||||
DO UPDATE SET label = EXCLUDED.label, expires = EXCLUDED.expires, provisioned = EXCLUDED.provisioned
|
||||
`
|
||||
postgresUpdateTokenLabel = `UPDATE user_token SET label = $1 WHERE user_id = $2 AND token = $3`
|
||||
postgresUpdateTokenExpiry = `UPDATE user_token SET expires = $1 WHERE user_id = $2 AND token = $3`
|
||||
postgresUpdateTokenLastAccess = `UPDATE user_token SET last_access = $1, last_origin = $2 WHERE token = $3`
|
||||
postgresDeleteToken = `DELETE FROM user_token WHERE user_id = $1 AND token = $2`
|
||||
postgresDeleteProvisionedToken = `DELETE FROM user_token WHERE token = $1`
|
||||
postgresDeleteAllToken = `DELETE FROM user_token WHERE user_id = $1`
|
||||
postgresDeleteExpiredTokens = `DELETE FROM user_token WHERE expires > 0 AND expires < $1`
|
||||
postgresDeleteExcessTokens = `
|
||||
DELETE FROM user_token
|
||||
WHERE user_id = $1
|
||||
AND (user_id, token) NOT IN (
|
||||
SELECT user_id, token
|
||||
FROM user_token
|
||||
WHERE user_id = $2
|
||||
ORDER BY expires DESC
|
||||
LIMIT $3
|
||||
)
|
||||
`
|
||||
|
||||
// Tier queries
|
||||
postgresInsertTier = `
|
||||
INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
|
||||
`
|
||||
postgresUpdateTier = `
|
||||
UPDATE tier
|
||||
SET name = $1, messages_limit = $2, messages_expiry_duration = $3, emails_limit = $4, calls_limit = $5, reservations_limit = $6, attachment_file_size_limit = $7, attachment_total_size_limit = $8, attachment_expiry_duration = $9, attachment_bandwidth_limit = $10, stripe_monthly_price_id = $11, stripe_yearly_price_id = $12
|
||||
WHERE code = $13
|
||||
`
|
||||
postgresSelectTiers = `
|
||||
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
|
||||
FROM tier
|
||||
`
|
||||
postgresSelectTierByCode = `
|
||||
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
|
||||
FROM tier
|
||||
WHERE code = $1
|
||||
`
|
||||
postgresSelectTierByPriceID = `
|
||||
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
|
||||
FROM tier
|
||||
WHERE (stripe_monthly_price_id = $1 OR stripe_yearly_price_id = $2)
|
||||
`
|
||||
postgresDeleteTier = `DELETE FROM tier WHERE code = $1`
|
||||
|
||||
// Phone queries
|
||||
postgresSelectPhoneNumbers = `SELECT phone_number FROM user_phone WHERE user_id = $1`
|
||||
postgresInsertPhoneNumber = `INSERT INTO user_phone (user_id, phone_number) VALUES ($1, $2)`
|
||||
postgresDeletePhoneNumber = `DELETE FROM user_phone WHERE user_id = $1 AND phone_number = $2`
|
||||
|
||||
// Billing queries
|
||||
postgresUpdateBilling = `
|
||||
UPDATE "user"
|
||||
SET stripe_customer_id = $1, stripe_subscription_id = $2, stripe_subscription_status = $3, stripe_subscription_interval = $4, stripe_subscription_paid_until = $5, stripe_subscription_cancel_at = $6
|
||||
WHERE user_name = $7
|
||||
`
|
||||
|
||||
// Schema version queries
|
||||
postgresSelectSchemaVersionExists = `SELECT EXISTS (
|
||||
SELECT FROM information_schema.tables
|
||||
WHERE table_schema = current_schema()
|
||||
AND table_name = 'schema_version'
|
||||
)`
|
||||
postgresSelectSchemaVersion = `SELECT version FROM schema_version WHERE store = 'user'`
|
||||
postgresInsertSchemaVersion = `INSERT INTO schema_version (store, version) VALUES ('user', $1)`
|
||||
)
|
||||
|
||||
// NewPostgresStore creates a new PostgreSQL-backed user store
|
||||
func NewPostgresStore(dsn string) (Store, error) {
|
||||
db, err := sql.Open("pgx", dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := db.Ping(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := setupPostgresDB(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &commonStore{
|
||||
db: db,
|
||||
queries: postgresQueries(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func postgresQueries() storeQueries {
|
||||
return storeQueries{
|
||||
// User queries
|
||||
selectUserByID: postgresSelectUserByID,
|
||||
selectUserByName: postgresSelectUserByName,
|
||||
selectUserByToken: postgresSelectUserByToken,
|
||||
selectUserByStripeID: postgresSelectUserByStripeID,
|
||||
selectUsernames: postgresSelectUsernames,
|
||||
selectUserCount: postgresSelectUserCount,
|
||||
selectUserIDFromUsername: postgresSelectUserIDFromUsername,
|
||||
insertUser: postgresInsertUser,
|
||||
updateUserPass: postgresUpdateUserPass,
|
||||
updateUserRole: postgresUpdateUserRole,
|
||||
updateUserProvisioned: postgresUpdateUserProvisioned,
|
||||
updateUserPrefs: postgresUpdateUserPrefs,
|
||||
updateUserStats: postgresUpdateUserStats,
|
||||
updateUserStatsResetAll: postgresUpdateUserStatsResetAll,
|
||||
updateUserTier: postgresUpdateUserTier,
|
||||
updateUserDeleted: postgresUpdateUserDeleted,
|
||||
deleteUser: postgresDeleteUser,
|
||||
deleteUserTier: postgresDeleteUserTier,
|
||||
deleteUsersMarked: postgresDeleteUsersMarked,
|
||||
|
||||
// Access queries
|
||||
selectTopicPerms: postgresSelectTopicPerms,
|
||||
selectUserAllAccess: postgresSelectUserAllAccess,
|
||||
selectUserAccess: postgresSelectUserAccess,
|
||||
selectUserReservations: postgresSelectUserReservations,
|
||||
selectUserReservationsCount: postgresSelectUserReservationsCount,
|
||||
selectUserReservationsOwner: postgresSelectUserReservationsOwner,
|
||||
selectUserHasReservation: postgresSelectUserHasReservation,
|
||||
selectOtherAccessCount: postgresSelectOtherAccessCount,
|
||||
upsertUserAccess: postgresUpsertUserAccess,
|
||||
deleteUserAccess: postgresDeleteUserAccess,
|
||||
deleteUserAccessProvisioned: postgresDeleteUserAccessProvisioned,
|
||||
deleteTopicAccess: postgresDeleteTopicAccess,
|
||||
deleteAllAccess: postgresDeleteAllAccess,
|
||||
|
||||
// Token queries
|
||||
selectToken: postgresSelectToken,
|
||||
selectTokens: postgresSelectTokens,
|
||||
selectTokenCount: postgresSelectTokenCount,
|
||||
selectAllProvisionedTokens: postgresSelectAllProvisionedTokens,
|
||||
upsertToken: postgresUpsertToken,
|
||||
updateTokenLabel: postgresUpdateTokenLabel,
|
||||
updateTokenExpiry: postgresUpdateTokenExpiry,
|
||||
updateTokenLastAccess: postgresUpdateTokenLastAccess,
|
||||
deleteToken: postgresDeleteToken,
|
||||
deleteProvisionedToken: postgresDeleteProvisionedToken,
|
||||
deleteAllToken: postgresDeleteAllToken,
|
||||
deleteExpiredTokens: postgresDeleteExpiredTokens,
|
||||
deleteExcessTokens: postgresDeleteExcessTokens,
|
||||
|
||||
// Tier queries
|
||||
insertTier: postgresInsertTier,
|
||||
selectTiers: postgresSelectTiers,
|
||||
selectTierByCode: postgresSelectTierByCode,
|
||||
selectTierByPriceID: postgresSelectTierByPriceID,
|
||||
updateTier: postgresUpdateTier,
|
||||
deleteTier: postgresDeleteTier,
|
||||
|
||||
// Phone queries
|
||||
selectPhoneNumbers: postgresSelectPhoneNumbers,
|
||||
insertPhoneNumber: postgresInsertPhoneNumber,
|
||||
deletePhoneNumber: postgresDeletePhoneNumber,
|
||||
|
||||
// Billing queries
|
||||
updateBilling: postgresUpdateBilling,
|
||||
}
|
||||
}
|
||||
|
||||
func setupPostgresDB(db *sql.DB) error {
|
||||
// Check if schema version table exists
|
||||
var exists bool
|
||||
if err := db.QueryRow(postgresSelectSchemaVersionExists).Scan(&exists); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !exists {
|
||||
// New database, create all tables
|
||||
if _, err := db.Exec(postgresCreateTablesQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(postgresInsertSchemaVersion, currentSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Table exists, check if user store has a row
|
||||
var schemaVersion int
|
||||
err := db.QueryRow(postgresSelectSchemaVersion).Scan(&schemaVersion)
|
||||
if err == sql.ErrNoRows {
|
||||
// schema_version table exists (e.g. created by webpush) but no user row yet
|
||||
if _, err := db.Exec(postgresCreateTablesQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(postgresInsertSchemaVersion, currentSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("cannot determine schema version: %v", err)
|
||||
}
|
||||
|
||||
if schemaVersion > currentSchemaVersion {
|
||||
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, currentSchemaVersion)
|
||||
}
|
||||
|
||||
// Note: PostgreSQL migrations will be added when needed
|
||||
// For now, we only support new installations at the latest schema version
|
||||
|
||||
return nil
|
||||
}
|
||||
420
user/store_sqlite.go
Normal file
420
user/store_sqlite.go
Normal file
@@ -0,0 +1,420 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
)
|
||||
|
||||
// SQLite schema and queries
|
||||
const (
|
||||
sqliteCreateTablesQueries = `
|
||||
BEGIN;
|
||||
CREATE TABLE IF NOT EXISTS tier (
|
||||
id TEXT PRIMARY KEY,
|
||||
code TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
messages_limit INT NOT NULL,
|
||||
messages_expiry_duration INT NOT NULL,
|
||||
emails_limit INT NOT NULL,
|
||||
calls_limit INT NOT NULL,
|
||||
reservations_limit INT NOT NULL,
|
||||
attachment_file_size_limit INT NOT NULL,
|
||||
attachment_total_size_limit INT NOT NULL,
|
||||
attachment_expiry_duration INT NOT NULL,
|
||||
attachment_bandwidth_limit INT NOT NULL,
|
||||
stripe_monthly_price_id TEXT,
|
||||
stripe_yearly_price_id TEXT
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_tier_code ON tier (code);
|
||||
CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
|
||||
CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
|
||||
CREATE TABLE IF NOT EXISTS user (
|
||||
id TEXT PRIMARY KEY,
|
||||
tier_id TEXT,
|
||||
user TEXT NOT NULL,
|
||||
pass TEXT NOT NULL,
|
||||
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
|
||||
prefs JSON NOT NULL DEFAULT '{}',
|
||||
sync_topic TEXT NOT NULL,
|
||||
provisioned INT NOT NULL,
|
||||
stats_messages INT NOT NULL DEFAULT (0),
|
||||
stats_emails INT NOT NULL DEFAULT (0),
|
||||
stats_calls INT NOT NULL DEFAULT (0),
|
||||
stripe_customer_id TEXT,
|
||||
stripe_subscription_id TEXT,
|
||||
stripe_subscription_status TEXT,
|
||||
stripe_subscription_interval TEXT,
|
||||
stripe_subscription_paid_until INT,
|
||||
stripe_subscription_cancel_at INT,
|
||||
created INT NOT NULL,
|
||||
deleted INT,
|
||||
FOREIGN KEY (tier_id) REFERENCES tier (id)
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_user ON user (user);
|
||||
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
|
||||
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
|
||||
CREATE TABLE IF NOT EXISTS user_access (
|
||||
user_id TEXT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
read INT NOT NULL,
|
||||
write INT NOT NULL,
|
||||
owner_user_id INT,
|
||||
provisioned INT NOT NULL,
|
||||
PRIMARY KEY (user_id, topic),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS user_token (
|
||||
user_id TEXT NOT NULL,
|
||||
token TEXT NOT NULL,
|
||||
label TEXT NOT NULL,
|
||||
last_access INT NOT NULL,
|
||||
last_origin TEXT NOT NULL,
|
||||
expires INT NOT NULL,
|
||||
provisioned INT NOT NULL,
|
||||
PRIMARY KEY (user_id, token),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_user_token ON user_token (token);
|
||||
CREATE TABLE IF NOT EXISTS user_phone (
|
||||
user_id TEXT NOT NULL,
|
||||
phone_number TEXT NOT NULL,
|
||||
PRIMARY KEY (user_id, phone_number),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||
id INT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created)
|
||||
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, UNIXEPOCH())
|
||||
ON CONFLICT (id) DO NOTHING;
|
||||
COMMIT;
|
||||
`
|
||||
|
||||
sqliteBuiltinStartupQueries = `
|
||||
PRAGMA foreign_keys = ON;
|
||||
`
|
||||
|
||||
// User queries
|
||||
sqliteSelectUserByID = `
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||
FROM user u
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE u.id = ?
|
||||
`
|
||||
sqliteSelectUserByName = `
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||
FROM user u
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE user = ?
|
||||
`
|
||||
sqliteSelectUserByToken = `
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||
FROM user u
|
||||
JOIN user_token tk on u.id = tk.user_id
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE tk.token = ? AND (tk.expires = 0 OR tk.expires >= ?)
|
||||
`
|
||||
sqliteSelectUserByStripeID = `
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||
FROM user u
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE u.stripe_customer_id = ?
|
||||
`
|
||||
sqliteSelectUsernames = `
|
||||
SELECT user
|
||||
FROM user
|
||||
ORDER BY
|
||||
CASE role
|
||||
WHEN 'admin' THEN 1
|
||||
WHEN 'anonymous' THEN 3
|
||||
ELSE 2
|
||||
END, user
|
||||
`
|
||||
sqliteSelectUserCount = `SELECT COUNT(*) FROM user`
|
||||
sqliteSelectUserIDFromUsername = `SELECT id FROM user WHERE user = ?`
|
||||
sqliteInsertUser = `INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created) VALUES (?, ?, ?, ?, ?, ?, ?)`
|
||||
sqliteUpdateUserPass = `UPDATE user SET pass = ? WHERE user = ?`
|
||||
sqliteUpdateUserRole = `UPDATE user SET role = ? WHERE user = ?`
|
||||
sqliteUpdateUserProvisioned = `UPDATE user SET provisioned = ? WHERE user = ?`
|
||||
sqliteUpdateUserPrefs = `UPDATE user SET prefs = ? WHERE id = ?`
|
||||
sqliteUpdateUserStats = `UPDATE user SET stats_messages = ?, stats_emails = ?, stats_calls = ? WHERE id = ?`
|
||||
sqliteUpdateUserStatsResetAll = `UPDATE user SET stats_messages = 0, stats_emails = 0, stats_calls = 0`
|
||||
sqliteUpdateUserTier = `UPDATE user SET tier_id = (SELECT id FROM tier WHERE code = ?) WHERE user = ?`
|
||||
sqliteUpdateUserDeleted = `UPDATE user SET deleted = ? WHERE id = ?`
|
||||
sqliteDeleteUser = `DELETE FROM user WHERE user = ?`
|
||||
sqliteDeleteUserTier = `UPDATE user SET tier_id = null WHERE user = ?`
|
||||
sqliteDeleteUsersMarked = `DELETE FROM user WHERE deleted < ?`
|
||||
|
||||
// Access queries
|
||||
sqliteSelectTopicPerms = `
|
||||
SELECT read, write
|
||||
FROM user_access a
|
||||
JOIN user u ON u.id = a.user_id
|
||||
WHERE (u.user = ? OR u.user = ?) AND ? LIKE a.topic ESCAPE '\'
|
||||
ORDER BY u.user DESC, LENGTH(a.topic) DESC, a.write DESC
|
||||
`
|
||||
sqliteSelectUserAllAccess = `
|
||||
SELECT user_id, topic, read, write, provisioned
|
||||
FROM user_access
|
||||
ORDER BY LENGTH(topic) DESC, write DESC, read DESC, topic
|
||||
`
|
||||
sqliteSelectUserAccess = `
|
||||
SELECT topic, read, write, provisioned
|
||||
FROM user_access
|
||||
WHERE user_id = (SELECT id FROM user WHERE user = ?)
|
||||
ORDER BY LENGTH(topic) DESC, write DESC, read DESC, topic
|
||||
`
|
||||
sqliteSelectUserReservations = `
|
||||
SELECT a_user.topic, a_user.read, a_user.write, a_everyone.read AS everyone_read, a_everyone.write AS everyone_write
|
||||
FROM user_access a_user
|
||||
LEFT JOIN user_access a_everyone ON a_user.topic = a_everyone.topic AND a_everyone.user_id = (SELECT id FROM user WHERE user = ?)
|
||||
WHERE a_user.user_id = a_user.owner_user_id
|
||||
AND a_user.owner_user_id = (SELECT id FROM user WHERE user = ?)
|
||||
ORDER BY a_user.topic
|
||||
`
|
||||
sqliteSelectUserReservationsCount = `
|
||||
SELECT COUNT(*)
|
||||
FROM user_access
|
||||
WHERE user_id = owner_user_id
|
||||
AND owner_user_id = (SELECT id FROM user WHERE user = ?)
|
||||
`
|
||||
sqliteSelectUserReservationsOwner = `
|
||||
SELECT owner_user_id
|
||||
FROM user_access
|
||||
WHERE topic = ?
|
||||
AND user_id = owner_user_id
|
||||
`
|
||||
sqliteSelectUserHasReservation = `
|
||||
SELECT COUNT(*)
|
||||
FROM user_access
|
||||
WHERE user_id = owner_user_id
|
||||
AND owner_user_id = (SELECT id FROM user WHERE user = ?)
|
||||
AND topic = ?
|
||||
`
|
||||
sqliteSelectOtherAccessCount = `
|
||||
SELECT COUNT(*)
|
||||
FROM user_access
|
||||
WHERE (topic = ? OR ? LIKE topic ESCAPE '\')
|
||||
AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM user WHERE user = ?))
|
||||
`
|
||||
sqliteUpsertUserAccess = `
|
||||
INSERT INTO user_access (user_id, topic, read, write, owner_user_id, provisioned)
|
||||
VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?, (SELECT IIF(?='',NULL,(SELECT id FROM user WHERE user=?))), ?)
|
||||
ON CONFLICT (user_id, topic)
|
||||
DO UPDATE SET read=excluded.read, write=excluded.write, owner_user_id=excluded.owner_user_id, provisioned=excluded.provisioned
|
||||
`
|
||||
sqliteDeleteUserAccess = `
|
||||
DELETE FROM user_access
|
||||
WHERE user_id = (SELECT id FROM user WHERE user = ?)
|
||||
OR owner_user_id = (SELECT id FROM user WHERE user = ?)
|
||||
`
|
||||
sqliteDeleteUserAccessProvisioned = `DELETE FROM user_access WHERE provisioned = 1`
|
||||
sqliteDeleteTopicAccess = `
|
||||
DELETE FROM user_access
|
||||
WHERE (user_id = (SELECT id FROM user WHERE user = ?) OR owner_user_id = (SELECT id FROM user WHERE user = ?))
|
||||
AND topic = ?
|
||||
`
|
||||
sqliteDeleteAllAccess = `DELETE FROM user_access`
|
||||
|
||||
// Token queries
|
||||
sqliteSelectToken = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ? AND token = ?`
|
||||
sqliteSelectTokens = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ?`
|
||||
sqliteSelectTokenCount = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
|
||||
sqliteSelectAllProvisionedTokens = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE provisioned = 1`
|
||||
sqliteUpsertToken = `
|
||||
INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires, provisioned)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT (user_id, token)
|
||||
DO UPDATE SET label = excluded.label, expires = excluded.expires, provisioned = excluded.provisioned;
|
||||
`
|
||||
sqliteUpdateTokenLabel = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?`
|
||||
sqliteUpdateTokenExpiry = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?`
|
||||
sqliteUpdateTokenLastAccess = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?`
|
||||
sqliteDeleteToken = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
|
||||
sqliteDeleteProvisionedToken = `DELETE FROM user_token WHERE token = ?`
|
||||
sqliteDeleteAllToken = `DELETE FROM user_token WHERE user_id = ?`
|
||||
sqliteDeleteExpiredTokens = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
|
||||
sqliteDeleteExcessTokens = `
|
||||
DELETE FROM user_token
|
||||
WHERE user_id = ?
|
||||
AND (user_id, token) NOT IN (
|
||||
SELECT user_id, token
|
||||
FROM user_token
|
||||
WHERE user_id = ?
|
||||
ORDER BY expires DESC
|
||||
LIMIT ?
|
||||
)
|
||||
`
|
||||
|
||||
// Tier queries
|
||||
sqliteInsertTier = `
|
||||
INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
sqliteUpdateTier = `
|
||||
UPDATE tier
|
||||
SET name = ?, messages_limit = ?, messages_expiry_duration = ?, emails_limit = ?, calls_limit = ?, reservations_limit = ?, attachment_file_size_limit = ?, attachment_total_size_limit = ?, attachment_expiry_duration = ?, attachment_bandwidth_limit = ?, stripe_monthly_price_id = ?, stripe_yearly_price_id = ?
|
||||
WHERE code = ?
|
||||
`
|
||||
sqliteSelectTiers = `
|
||||
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
|
||||
FROM tier
|
||||
`
|
||||
sqliteSelectTierByCode = `
|
||||
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
|
||||
FROM tier
|
||||
WHERE code = ?
|
||||
`
|
||||
sqliteSelectTierByPriceID = `
|
||||
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
|
||||
FROM tier
|
||||
WHERE (stripe_monthly_price_id = ? OR stripe_yearly_price_id = ?)
|
||||
`
|
||||
sqliteDeleteTier = `DELETE FROM tier WHERE code = ?`
|
||||
|
||||
// Phone queries
|
||||
sqliteSelectPhoneNumbers = `SELECT phone_number FROM user_phone WHERE user_id = ?`
|
||||
sqliteInsertPhoneNumber = `INSERT INTO user_phone (user_id, phone_number) VALUES (?, ?)`
|
||||
sqliteDeletePhoneNumber = `DELETE FROM user_phone WHERE user_id = ? AND phone_number = ?`
|
||||
|
||||
// Billing queries
|
||||
sqliteUpdateBilling = `
|
||||
UPDATE user
|
||||
SET stripe_customer_id = ?, stripe_subscription_id = ?, stripe_subscription_status = ?, stripe_subscription_interval = ?, stripe_subscription_paid_until = ?, stripe_subscription_cancel_at = ?
|
||||
WHERE user = ?
|
||||
`
|
||||
)
|
||||
|
||||
// NewSQLiteStore creates a new SQLite-backed user store
|
||||
func NewSQLiteStore(filename, startupQueries string) (Store, error) {
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := setupSQLiteDB(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := runSQLiteStartupQueries(db, startupQueries); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &commonStore{
|
||||
db: db,
|
||||
queries: sqliteQueries(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func sqliteQueries() storeQueries {
|
||||
return storeQueries{
|
||||
selectUserByID: sqliteSelectUserByID,
|
||||
selectUserByName: sqliteSelectUserByName,
|
||||
selectUserByToken: sqliteSelectUserByToken,
|
||||
selectUserByStripeID: sqliteSelectUserByStripeID,
|
||||
selectUsernames: sqliteSelectUsernames,
|
||||
selectUserCount: sqliteSelectUserCount,
|
||||
selectUserIDFromUsername: sqliteSelectUserIDFromUsername,
|
||||
insertUser: sqliteInsertUser,
|
||||
updateUserPass: sqliteUpdateUserPass,
|
||||
updateUserRole: sqliteUpdateUserRole,
|
||||
updateUserProvisioned: sqliteUpdateUserProvisioned,
|
||||
updateUserPrefs: sqliteUpdateUserPrefs,
|
||||
updateUserStats: sqliteUpdateUserStats,
|
||||
updateUserStatsResetAll: sqliteUpdateUserStatsResetAll,
|
||||
updateUserTier: sqliteUpdateUserTier,
|
||||
updateUserDeleted: sqliteUpdateUserDeleted,
|
||||
deleteUser: sqliteDeleteUser,
|
||||
deleteUserTier: sqliteDeleteUserTier,
|
||||
deleteUsersMarked: sqliteDeleteUsersMarked,
|
||||
selectTopicPerms: sqliteSelectTopicPerms,
|
||||
selectUserAllAccess: sqliteSelectUserAllAccess,
|
||||
selectUserAccess: sqliteSelectUserAccess,
|
||||
selectUserReservations: sqliteSelectUserReservations,
|
||||
selectUserReservationsCount: sqliteSelectUserReservationsCount,
|
||||
selectUserReservationsOwner: sqliteSelectUserReservationsOwner,
|
||||
selectUserHasReservation: sqliteSelectUserHasReservation,
|
||||
selectOtherAccessCount: sqliteSelectOtherAccessCount,
|
||||
upsertUserAccess: sqliteUpsertUserAccess,
|
||||
deleteUserAccess: sqliteDeleteUserAccess,
|
||||
deleteUserAccessProvisioned: sqliteDeleteUserAccessProvisioned,
|
||||
deleteTopicAccess: sqliteDeleteTopicAccess,
|
||||
deleteAllAccess: sqliteDeleteAllAccess,
|
||||
selectToken: sqliteSelectToken,
|
||||
selectTokens: sqliteSelectTokens,
|
||||
selectTokenCount: sqliteSelectTokenCount,
|
||||
selectAllProvisionedTokens: sqliteSelectAllProvisionedTokens,
|
||||
upsertToken: sqliteUpsertToken,
|
||||
updateTokenLabel: sqliteUpdateTokenLabel,
|
||||
updateTokenExpiry: sqliteUpdateTokenExpiry,
|
||||
updateTokenLastAccess: sqliteUpdateTokenLastAccess,
|
||||
deleteToken: sqliteDeleteToken,
|
||||
deleteProvisionedToken: sqliteDeleteProvisionedToken,
|
||||
deleteAllToken: sqliteDeleteAllToken,
|
||||
deleteExpiredTokens: sqliteDeleteExpiredTokens,
|
||||
deleteExcessTokens: sqliteDeleteExcessTokens,
|
||||
insertTier: sqliteInsertTier,
|
||||
selectTiers: sqliteSelectTiers,
|
||||
selectTierByCode: sqliteSelectTierByCode,
|
||||
selectTierByPriceID: sqliteSelectTierByPriceID,
|
||||
updateTier: sqliteUpdateTier,
|
||||
deleteTier: sqliteDeleteTier,
|
||||
selectPhoneNumbers: sqliteSelectPhoneNumbers,
|
||||
insertPhoneNumber: sqliteInsertPhoneNumber,
|
||||
deletePhoneNumber: sqliteDeletePhoneNumber,
|
||||
updateBilling: sqliteUpdateBilling,
|
||||
}
|
||||
}
|
||||
|
||||
func setupSQLiteDB(db *sql.DB) error {
|
||||
rowsSV, err := db.Query(selectSchemaVersionQuery)
|
||||
if err != nil {
|
||||
return setupNewSQLiteDB(db)
|
||||
}
|
||||
defer rowsSV.Close()
|
||||
schemaVersion := 0
|
||||
if !rowsSV.Next() {
|
||||
return fmt.Errorf("cannot determine schema version: database file may be corrupt")
|
||||
}
|
||||
if err := rowsSV.Scan(&schemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
rowsSV.Close()
|
||||
if schemaVersion == currentSchemaVersion {
|
||||
return nil
|
||||
} else if schemaVersion > currentSchemaVersion {
|
||||
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, currentSchemaVersion)
|
||||
}
|
||||
for i := schemaVersion; i < currentSchemaVersion; i++ {
|
||||
fn, ok := migrations[i]
|
||||
if !ok {
|
||||
return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1)
|
||||
} else if err := fn(db); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupNewSQLiteDB(db *sql.DB) error {
|
||||
if _, err := db.Exec(sqliteCreateTablesQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(insertSchemaVersion, currentSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func runSQLiteStartupQueries(db *sql.DB, startupQueries string) error {
|
||||
if _, err := db.Exec(sqliteBuiltinStartupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if startupQueries != "" {
|
||||
if _, err := db.Exec(startupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -2,11 +2,12 @@ package user
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
// Schema management queries
|
||||
// SQLite migrations
|
||||
const (
|
||||
currentSchemaVersion = 6
|
||||
insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
|
||||
@@ -242,6 +242,20 @@ const (
|
||||
everyoneID = "u_everyone"
|
||||
)
|
||||
|
||||
// Config holds the configuration for the user Manager
|
||||
type Config struct {
|
||||
Filename string // Database filename, e.g. "/var/lib/ntfy/user.db" (SQLite)
|
||||
DatabaseURL string // Database connection string (PostgreSQL)
|
||||
StartupQueries string // Queries to run on startup, e.g. to create initial users or tiers (SQLite only)
|
||||
DefaultAccess Permission // Default permission if no ACL matches
|
||||
ProvisionEnabled bool // Hack: Enable auto-provisioning of users and access grants, disabled for "ntfy user" commands
|
||||
Users []*User // Predefined users to create on startup
|
||||
Access map[string][]*Grant // Predefined access grants to create on startup (username -> []*Grant)
|
||||
Tokens map[string][]*Token // Predefined users to create on startup (username -> []*Token)
|
||||
QueueWriterInterval time.Duration // Interval for the async queue writer to flush stats and token updates to the database
|
||||
BcryptCost int // Cost of generated passwords; lowering makes testing faster
|
||||
}
|
||||
|
||||
// Error constants used by the package
|
||||
var (
|
||||
ErrUnauthenticated = errors.New("unauthenticated")
|
||||
|
||||
@@ -25,8 +25,8 @@ const (
|
||||
PRIMARY KEY (subscription_id, topic)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_webpush_topic ON webpush_subscription_topic (topic);
|
||||
CREATE TABLE IF NOT EXISTS webpush_schema_version (
|
||||
id INT PRIMARY KEY,
|
||||
CREATE TABLE IF NOT EXISTS schema_version (
|
||||
store TEXT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
`
|
||||
@@ -65,8 +65,8 @@ const (
|
||||
// PostgreSQL schema management queries
|
||||
const (
|
||||
pgCurrentSchemaVersion = 1
|
||||
pgInsertSchemaVersion = `INSERT INTO webpush_schema_version VALUES (1, $1)`
|
||||
pgSelectSchemaVersionQuery = `SELECT version FROM webpush_schema_version WHERE id = 1`
|
||||
pgInsertSchemaVersion = `INSERT INTO schema_version (store, version) VALUES ('webpush', $1)`
|
||||
pgSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE store = 'webpush'`
|
||||
)
|
||||
|
||||
// NewPostgresStore creates a new PostgreSQL-backed web push store.
|
||||
@@ -102,12 +102,16 @@ func NewPostgresStore(dsn string) (Store, error) {
|
||||
}
|
||||
|
||||
func setupPostgresDB(db *sql.DB) error {
|
||||
// If 'webpush_schema_version' table does not exist, this must be a new database
|
||||
// If 'schema_version' table does not exist or no webpush row, this must be a new database
|
||||
rows, err := db.Query(pgSelectSchemaVersionQuery)
|
||||
if err != nil {
|
||||
return setupNewPostgresDB(db)
|
||||
}
|
||||
return rows.Close()
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return setupNewPostgresDB(db)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupNewPostgresDB(db *sql.DB) error {
|
||||
|
||||
Reference in New Issue
Block a user