Compare commits
6 Commits
postgres-w
...
postgres-w
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ae5e1fe8d8 | ||
|
|
e3a402ed95 | ||
|
|
1abc1005d0 | ||
|
|
909c3fe17b | ||
|
|
07c3e280bf | ||
|
|
60fa50f0d5 |
25
cmd/user.go
25
cmd/user.go
@@ -29,6 +29,7 @@ var flagsUser = append(
|
|||||||
&cli.StringFlag{Name: "config", Aliases: []string{"c"}, EnvVars: []string{"NTFY_CONFIG_FILE"}, Value: server.DefaultConfigFile, DefaultText: server.DefaultConfigFile, Usage: "config file"},
|
&cli.StringFlag{Name: "config", Aliases: []string{"c"}, EnvVars: []string{"NTFY_CONFIG_FILE"}, Value: server.DefaultConfigFile, DefaultText: server.DefaultConfigFile, Usage: "config file"},
|
||||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "auth-file", Aliases: []string{"auth_file", "H"}, EnvVars: []string{"NTFY_AUTH_FILE"}, Usage: "auth database file used for access control"}),
|
altsrc.NewStringFlag(&cli.StringFlag{Name: "auth-file", Aliases: []string{"auth_file", "H"}, EnvVars: []string{"NTFY_AUTH_FILE"}, Usage: "auth database file used for access control"}),
|
||||||
altsrc.NewStringFlag(&cli.StringFlag{Name: "auth-default-access", Aliases: []string{"auth_default_access", "p"}, EnvVars: []string{"NTFY_AUTH_DEFAULT_ACCESS"}, Value: "read-write", Usage: "default permissions if no matching entries in the auth database are found"}),
|
altsrc.NewStringFlag(&cli.StringFlag{Name: "auth-default-access", Aliases: []string{"auth_default_access", "p"}, EnvVars: []string{"NTFY_AUTH_DEFAULT_ACCESS"}, Value: "read-write", Usage: "default permissions if no matching entries in the auth database are found"}),
|
||||||
|
altsrc.NewStringFlag(&cli.StringFlag{Name: "database-url", Aliases: []string{"database_url"}, EnvVars: []string{"NTFY_DATABASE_URL"}, Usage: "PostgreSQL connection string for database-backed stores"}),
|
||||||
)
|
)
|
||||||
|
|
||||||
var cmdUser = &cli.Command{
|
var cmdUser = &cli.Command{
|
||||||
@@ -365,24 +366,32 @@ func createUserManager(c *cli.Context) (*user.Manager, error) {
|
|||||||
authFile := c.String("auth-file")
|
authFile := c.String("auth-file")
|
||||||
authStartupQueries := c.String("auth-startup-queries")
|
authStartupQueries := c.String("auth-startup-queries")
|
||||||
authDefaultAccess := c.String("auth-default-access")
|
authDefaultAccess := c.String("auth-default-access")
|
||||||
if authFile == "" {
|
databaseURL := c.String("database-url")
|
||||||
return nil, errors.New("option auth-file not set; auth is unconfigured for this server")
|
|
||||||
} else if !util.FileExists(authFile) {
|
|
||||||
return nil, errors.New("auth-file does not exist; please start the server at least once to create it")
|
|
||||||
}
|
|
||||||
authDefault, err := user.ParsePermission(authDefaultAccess)
|
authDefault, err := user.ParsePermission(authDefaultAccess)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.New("if set, auth-default-access must start set to 'read-write', 'read-only', 'write-only' or 'deny-all'")
|
return nil, errors.New("if set, auth-default-access must start set to 'read-write', 'read-only', 'write-only' or 'deny-all'")
|
||||||
}
|
}
|
||||||
authConfig := &user.Config{
|
authConfig := &user.Config{
|
||||||
Filename: authFile,
|
|
||||||
StartupQueries: authStartupQueries,
|
|
||||||
DefaultAccess: authDefault,
|
DefaultAccess: authDefault,
|
||||||
ProvisionEnabled: false, // Hack: Do not re-provision users on manager initialization
|
ProvisionEnabled: false, // Hack: Do not re-provision users on manager initialization
|
||||||
BcryptCost: user.DefaultUserPasswordBcryptCost,
|
BcryptCost: user.DefaultUserPasswordBcryptCost,
|
||||||
QueueWriterInterval: user.DefaultUserStatsQueueWriterInterval,
|
QueueWriterInterval: user.DefaultUserStatsQueueWriterInterval,
|
||||||
}
|
}
|
||||||
return user.NewManager(authConfig)
|
var store user.Store
|
||||||
|
if databaseURL != "" {
|
||||||
|
store, err = user.NewPostgresStore(databaseURL)
|
||||||
|
} else if authFile != "" {
|
||||||
|
if !util.FileExists(authFile) {
|
||||||
|
return nil, errors.New("auth-file does not exist; please start the server at least once to create it")
|
||||||
|
}
|
||||||
|
store, err = user.NewSQLiteStore(authFile, authStartupQueries)
|
||||||
|
} else {
|
||||||
|
return nil, errors.New("option database-url or auth-file not set; auth is unconfigured for this server")
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return user.NewManager(store, authConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
func readPasswordAndConfirm(c *cli.Context) (string, error) {
|
func readPasswordAndConfirm(c *cli.Context) (string, error) {
|
||||||
|
|||||||
@@ -204,9 +204,10 @@ func New(conf *Config) (*Server, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
var userManager *user.Manager
|
var userManager *user.Manager
|
||||||
if conf.AuthFile != "" {
|
if conf.AuthFile != "" || conf.DatabaseURL != "" {
|
||||||
authConfig := &user.Config{
|
authConfig := &user.Config{
|
||||||
Filename: conf.AuthFile,
|
Filename: conf.AuthFile,
|
||||||
|
DatabaseURL: conf.DatabaseURL,
|
||||||
StartupQueries: conf.AuthStartupQueries,
|
StartupQueries: conf.AuthStartupQueries,
|
||||||
DefaultAccess: conf.AuthDefault,
|
DefaultAccess: conf.AuthDefault,
|
||||||
ProvisionEnabled: true, // Enable provisioning of users and access
|
ProvisionEnabled: true, // Enable provisioning of users and access
|
||||||
@@ -216,7 +217,16 @@ func New(conf *Config) (*Server, error) {
|
|||||||
BcryptCost: conf.AuthBcryptCost,
|
BcryptCost: conf.AuthBcryptCost,
|
||||||
QueueWriterInterval: conf.AuthStatsQueueWriterInterval,
|
QueueWriterInterval: conf.AuthStatsQueueWriterInterval,
|
||||||
}
|
}
|
||||||
userManager, err = user.NewManager(authConfig)
|
var store user.Store
|
||||||
|
if conf.DatabaseURL != "" {
|
||||||
|
store, err = user.NewPostgresStore(conf.DatabaseURL)
|
||||||
|
} else {
|
||||||
|
store, err = user.NewSQLiteStore(conf.AuthFile, conf.AuthStartupQueries)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
userManager, err = user.NewManager(store, authConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
1177
user/manager.go
1177
user/manager.go
File diff suppressed because it is too large
Load Diff
@@ -3,20 +3,68 @@ package user
|
|||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
|
||||||
"heckel.io/ntfy/v2/util"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
"heckel.io/ntfy/v2/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
const minBcryptTimingMillis = int64(40) // Ideally should be >100ms, but this should also run on a Raspberry Pi without massive resources
|
const minBcryptTimingMillis = int64(40) // Ideally should be >100ms, but this should also run on a Raspberry Pi without massive resources
|
||||||
|
|
||||||
|
// newStoreFunc creates a Store. Calling it multiple times within the same test
|
||||||
|
// returns a new Store object pointing at the same underlying data (same SQLite
|
||||||
|
// file / same PostgreSQL schema), enabling close-and-reopen tests.
|
||||||
|
type newStoreFunc func() Store
|
||||||
|
|
||||||
|
func forEachBackend(t *testing.T, f func(t *testing.T, newStore newStoreFunc)) {
|
||||||
|
t.Run("sqlite", func(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
f(t, func() Store {
|
||||||
|
store, err := NewSQLiteStore(filepath.Join(dir, "user.db"), "")
|
||||||
|
require.Nil(t, err)
|
||||||
|
return store
|
||||||
|
})
|
||||||
|
})
|
||||||
|
t.Run("postgres", func(t *testing.T) {
|
||||||
|
dsn := os.Getenv("NTFY_TEST_DATABASE_URL")
|
||||||
|
if dsn == "" {
|
||||||
|
t.Skip("NTFY_TEST_DATABASE_URL not set")
|
||||||
|
}
|
||||||
|
schema := fmt.Sprintf("test_%s", util.RandomString(10))
|
||||||
|
setupDB, err := sql.Open("pgx", dsn)
|
||||||
|
require.Nil(t, err)
|
||||||
|
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Nil(t, setupDB.Close())
|
||||||
|
u, err := url.Parse(dsn)
|
||||||
|
require.Nil(t, err)
|
||||||
|
q := u.Query()
|
||||||
|
q.Set("search_path", schema)
|
||||||
|
u.RawQuery = q.Encode()
|
||||||
|
schemaDSN := u.String()
|
||||||
|
t.Cleanup(func() {
|
||||||
|
cleanDB, _ := sql.Open("pgx", dsn)
|
||||||
|
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
|
||||||
|
cleanDB.Close()
|
||||||
|
})
|
||||||
|
f(t, func() Store {
|
||||||
|
store, err := NewPostgresStore(schemaDSN)
|
||||||
|
require.Nil(t, err)
|
||||||
|
return store
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestManager_FullScenario_Default_DenyAll(t *testing.T) {
|
func TestManager_FullScenario_Default_DenyAll(t *testing.T) {
|
||||||
a := newTestManagerFromFile(t, filepath.Join(t.TempDir(), "user.db"), "", PermissionDenyAll, DefaultUserPasswordBcryptCost, DefaultUserStatsQueueWriterInterval)
|
forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
|
||||||
|
a := newTestManager(t, newStore, PermissionDenyAll)
|
||||||
require.Nil(t, a.AddUser("phil", "phil", RoleAdmin, false))
|
require.Nil(t, a.AddUser("phil", "phil", RoleAdmin, false))
|
||||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
|
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
|
||||||
require.Nil(t, a.AddUser("john", "john", RoleUser, false))
|
require.Nil(t, a.AddUser("john", "john", RoleUser, false))
|
||||||
@@ -35,7 +83,7 @@ func TestManager_FullScenario_Default_DenyAll(t *testing.T) {
|
|||||||
phil, err := a.Authenticate("phil", "phil")
|
phil, err := a.Authenticate("phil", "phil")
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, "phil", phil.Name)
|
require.Equal(t, "phil", phil.Name)
|
||||||
require.True(t, strings.HasPrefix(phil.Hash, "$2a$10$"))
|
require.True(t, strings.HasPrefix(phil.Hash, "$2a$04$"))
|
||||||
require.Equal(t, RoleAdmin, phil.Role)
|
require.Equal(t, RoleAdmin, phil.Role)
|
||||||
|
|
||||||
philGrants, err := a.Grants("phil")
|
philGrants, err := a.Grants("phil")
|
||||||
@@ -45,7 +93,7 @@ func TestManager_FullScenario_Default_DenyAll(t *testing.T) {
|
|||||||
ben, err := a.Authenticate("ben", "ben")
|
ben, err := a.Authenticate("ben", "ben")
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, "ben", ben.Name)
|
require.Equal(t, "ben", ben.Name)
|
||||||
require.True(t, strings.HasPrefix(ben.Hash, "$2a$10$"))
|
require.True(t, strings.HasPrefix(ben.Hash, "$2a$04$"))
|
||||||
require.Equal(t, RoleUser, ben.Role)
|
require.Equal(t, RoleUser, ben.Role)
|
||||||
|
|
||||||
benGrants, err := a.Grants("ben")
|
benGrants, err := a.Grants("ben")
|
||||||
@@ -60,7 +108,7 @@ func TestManager_FullScenario_Default_DenyAll(t *testing.T) {
|
|||||||
john, err := a.Authenticate("john", "john")
|
john, err := a.Authenticate("john", "john")
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, "john", john.Name)
|
require.Equal(t, "john", john.Name)
|
||||||
require.True(t, strings.HasPrefix(john.Hash, "$2a$10$"))
|
require.True(t, strings.HasPrefix(john.Hash, "$2a$04$"))
|
||||||
require.Equal(t, RoleUser, john.Role)
|
require.Equal(t, RoleUser, john.Role)
|
||||||
|
|
||||||
johnGrants, err := a.Grants("john")
|
johnGrants, err := a.Grants("john")
|
||||||
@@ -126,13 +174,14 @@ func TestManager_FullScenario_Default_DenyAll(t *testing.T) {
|
|||||||
require.Nil(t, a.Authorize(nil, "everyonewrite", PermissionWrite))
|
require.Nil(t, a.Authorize(nil, "everyonewrite", PermissionWrite))
|
||||||
require.Nil(t, a.Authorize(nil, "up1234", PermissionWrite)) // Wildcard permission
|
require.Nil(t, a.Authorize(nil, "up1234", PermissionWrite)) // Wildcard permission
|
||||||
require.Nil(t, a.Authorize(nil, "up5678", PermissionWrite))
|
require.Nil(t, a.Authorize(nil, "up5678", PermissionWrite))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_Access_Order_LengthWriteRead(t *testing.T) {
|
func TestManager_Access_Order_LengthWriteRead(t *testing.T) {
|
||||||
// This test validates issue #914 / #917, i.e. that write permissions are prioritized over read permissions,
|
// This test validates issue #914 / #917, i.e. that write permissions are prioritized over read permissions,
|
||||||
// and longer ACL rules are prioritized as well.
|
// and longer ACL rules are prioritized as well.
|
||||||
|
forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
|
||||||
a := newTestManagerFromFile(t, filepath.Join(t.TempDir(), "user.db"), "", PermissionDenyAll, DefaultUserPasswordBcryptCost, DefaultUserStatsQueueWriterInterval)
|
a := newTestManager(t, newStore, PermissionDenyAll)
|
||||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
|
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
|
||||||
require.Nil(t, a.AllowAccess("ben", "test*", PermissionReadWrite))
|
require.Nil(t, a.AllowAccess("ben", "test*", PermissionReadWrite))
|
||||||
require.Nil(t, a.AllowAccess("ben", "*", PermissionRead))
|
require.Nil(t, a.AllowAccess("ben", "*", PermissionRead))
|
||||||
@@ -142,12 +191,15 @@ func TestManager_Access_Order_LengthWriteRead(t *testing.T) {
|
|||||||
require.Nil(t, a.Authorize(ben, "any-topic-can-be-read", PermissionRead))
|
require.Nil(t, a.Authorize(ben, "any-topic-can-be-read", PermissionRead))
|
||||||
require.Nil(t, a.Authorize(ben, "this-too", PermissionRead))
|
require.Nil(t, a.Authorize(ben, "this-too", PermissionRead))
|
||||||
require.Nil(t, a.Authorize(ben, "test123", PermissionWrite))
|
require.Nil(t, a.Authorize(ben, "test123", PermissionWrite))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_AddUser_Invalid(t *testing.T) {
|
func TestManager_AddUser_Invalid(t *testing.T) {
|
||||||
a := newTestManager(t, PermissionDenyAll)
|
forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
|
||||||
|
a := newTestManager(t, newStore, PermissionDenyAll)
|
||||||
require.Equal(t, ErrInvalidArgument, a.AddUser(" invalid ", "pass", RoleAdmin, false))
|
require.Equal(t, ErrInvalidArgument, a.AddUser(" invalid ", "pass", RoleAdmin, false))
|
||||||
require.Equal(t, ErrInvalidArgument, a.AddUser("validuser", "pass", "invalid-role", false))
|
require.Equal(t, ErrInvalidArgument, a.AddUser("validuser", "pass", "invalid-role", false))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_AddUser_Timing(t *testing.T) {
|
func TestManager_AddUser_Timing(t *testing.T) {
|
||||||
@@ -158,7 +210,8 @@ func TestManager_AddUser_Timing(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_AddUser_And_Query(t *testing.T) {
|
func TestManager_AddUser_And_Query(t *testing.T) {
|
||||||
a := newTestManagerFromFile(t, filepath.Join(t.TempDir(), "user.db"), "", PermissionDenyAll, DefaultUserPasswordBcryptCost, DefaultUserStatsQueueWriterInterval)
|
forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
|
||||||
|
a := newTestManager(t, newStore, PermissionDenyAll)
|
||||||
require.Nil(t, a.AddUser("user", "pass", RoleAdmin, false))
|
require.Nil(t, a.AddUser("user", "pass", RoleAdmin, false))
|
||||||
require.Nil(t, a.ChangeBilling("user", &Billing{
|
require.Nil(t, a.ChangeBilling("user", &Billing{
|
||||||
StripeCustomerID: "acct_123",
|
StripeCustomerID: "acct_123",
|
||||||
@@ -180,10 +233,12 @@ func TestManager_AddUser_And_Query(t *testing.T) {
|
|||||||
u3, err := a.UserByStripeCustomer("acct_123")
|
u3, err := a.UserByStripeCustomer("acct_123")
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, u.ID, u3.ID)
|
require.Equal(t, u.ID, u3.ID)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_MarkUserRemoved_RemoveDeletedUsers(t *testing.T) {
|
func TestManager_MarkUserRemoved_RemoveDeletedUsers(t *testing.T) {
|
||||||
a := newTestManager(t, PermissionDenyAll)
|
forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
|
||||||
|
a := newTestManager(t, newStore, PermissionDenyAll)
|
||||||
|
|
||||||
// Create user, add reservations and token
|
// Create user, add reservations and token
|
||||||
require.Nil(t, a.AddUser("user", "pass", RoleAdmin, false))
|
require.Nil(t, a.AddUser("user", "pass", RoleAdmin, false))
|
||||||
@@ -224,16 +279,20 @@ func TestManager_MarkUserRemoved_RemoveDeletedUsers(t *testing.T) {
|
|||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.True(t, u.Deleted)
|
require.True(t, u.Deleted)
|
||||||
|
|
||||||
_, err = a.db.Exec("UPDATE user SET deleted = ? WHERE id = ?", time.Now().Add(-1*(userHardDeleteAfterDuration+time.Hour)).Unix(), u.ID)
|
// Backdate the deleted timestamp so RemoveDeletedUsers will prune the user
|
||||||
|
q := a.store.(*commonStore).queries.updateUserDeleted
|
||||||
|
_, err = testDB(a).Exec(q, time.Now().Add(-1*(userHardDeleteAfterDuration+time.Hour)).Unix(), u.ID)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Nil(t, a.RemoveDeletedUsers())
|
require.Nil(t, a.RemoveDeletedUsers())
|
||||||
|
|
||||||
_, err = a.User("user")
|
_, err = a.User("user")
|
||||||
require.Equal(t, ErrUserNotFound, err)
|
require.Equal(t, ErrUserNotFound, err)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_CreateToken_Only_Lower(t *testing.T) {
|
func TestManager_CreateToken_Only_Lower(t *testing.T) {
|
||||||
a := newTestManager(t, PermissionDenyAll)
|
forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
|
||||||
|
a := newTestManager(t, newStore, PermissionDenyAll)
|
||||||
|
|
||||||
// Create user, add reservations and token
|
// Create user, add reservations and token
|
||||||
require.Nil(t, a.AddUser("user", "pass", RoleAdmin, false))
|
require.Nil(t, a.AddUser("user", "pass", RoleAdmin, false))
|
||||||
@@ -243,10 +302,12 @@ func TestManager_CreateToken_Only_Lower(t *testing.T) {
|
|||||||
token, err := a.CreateToken(u.ID, "", time.Now().Add(time.Hour), netip.IPv4Unspecified(), false)
|
token, err := a.CreateToken(u.ID, "", time.Now().Add(time.Hour), netip.IPv4Unspecified(), false)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, token.Value, strings.ToLower(token.Value))
|
require.Equal(t, token.Value, strings.ToLower(token.Value))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_UserManagement(t *testing.T) {
|
func TestManager_UserManagement(t *testing.T) {
|
||||||
a := newTestManager(t, PermissionDenyAll)
|
forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
|
||||||
|
a := newTestManager(t, newStore, PermissionDenyAll)
|
||||||
require.Nil(t, a.AddUser("phil", "phil", RoleAdmin, false))
|
require.Nil(t, a.AddUser("phil", "phil", RoleAdmin, false))
|
||||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
|
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
|
||||||
require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite))
|
require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite))
|
||||||
@@ -334,10 +395,12 @@ func TestManager_UserManagement(t *testing.T) {
|
|||||||
require.Equal(t, 2, len(users))
|
require.Equal(t, 2, len(users))
|
||||||
require.Equal(t, "phil", users[0].Name)
|
require.Equal(t, "phil", users[0].Name)
|
||||||
require.Equal(t, "*", users[1].Name)
|
require.Equal(t, "*", users[1].Name)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_ChangePassword(t *testing.T) {
|
func TestManager_ChangePassword(t *testing.T) {
|
||||||
a := newTestManager(t, PermissionDenyAll)
|
forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
|
||||||
|
a := newTestManager(t, newStore, PermissionDenyAll)
|
||||||
require.Nil(t, a.AddUser("phil", "phil", RoleAdmin, false))
|
require.Nil(t, a.AddUser("phil", "phil", RoleAdmin, false))
|
||||||
require.Nil(t, a.AddUser("jane", "$2a$10$OyqU72muEy7VMd1SAU2Iru5IbeSMgrtCGHu/fWLmxL1MwlijQXWbG", RoleUser, true))
|
require.Nil(t, a.AddUser("jane", "$2a$10$OyqU72muEy7VMd1SAU2Iru5IbeSMgrtCGHu/fWLmxL1MwlijQXWbG", RoleUser, true))
|
||||||
|
|
||||||
@@ -358,10 +421,12 @@ func TestManager_ChangePassword(t *testing.T) {
|
|||||||
require.Equal(t, ErrUnauthenticated, err)
|
require.Equal(t, ErrUnauthenticated, err)
|
||||||
_, err = a.Authenticate("jane", "newpass")
|
_, err = a.Authenticate("jane", "newpass")
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_ChangeRole(t *testing.T) {
|
func TestManager_ChangeRole(t *testing.T) {
|
||||||
a := newTestManager(t, PermissionDenyAll)
|
forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
|
||||||
|
a := newTestManager(t, newStore, PermissionDenyAll)
|
||||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
|
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
|
||||||
require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite))
|
require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite))
|
||||||
require.Nil(t, a.AllowAccess("ben", "readme", PermissionRead))
|
require.Nil(t, a.AllowAccess("ben", "readme", PermissionRead))
|
||||||
@@ -383,10 +448,12 @@ func TestManager_ChangeRole(t *testing.T) {
|
|||||||
benGrants, err = a.Grants("ben")
|
benGrants, err = a.Grants("ben")
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, 0, len(benGrants))
|
require.Equal(t, 0, len(benGrants))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_Reservations(t *testing.T) {
|
func TestManager_Reservations(t *testing.T) {
|
||||||
a := newTestManager(t, PermissionDenyAll)
|
forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
|
||||||
|
a := newTestManager(t, newStore, PermissionDenyAll)
|
||||||
require.Nil(t, a.AddUser("phil", "phil", RoleUser, false))
|
require.Nil(t, a.AddUser("phil", "phil", RoleUser, false))
|
||||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
|
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
|
||||||
require.Nil(t, a.AddReservation("ben", "ztopic_", PermissionDenyAll))
|
require.Nil(t, a.AddReservation("ben", "ztopic_", PermissionDenyAll))
|
||||||
@@ -453,10 +520,12 @@ func TestManager_Reservations(t *testing.T) {
|
|||||||
count, err = a.ReservationsCount("ben")
|
count, err = a.ReservationsCount("ben")
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, int64(0), count)
|
require.Equal(t, int64(0), count)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
|
func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
|
||||||
a := newTestManager(t, PermissionDenyAll)
|
forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
|
||||||
|
a := newTestManager(t, newStore, PermissionDenyAll)
|
||||||
require.Nil(t, a.AddTier(&Tier{
|
require.Nil(t, a.AddTier(&Tier{
|
||||||
Code: "pro",
|
Code: "pro",
|
||||||
Name: "ntfy Pro",
|
Name: "ntfy Pro",
|
||||||
@@ -512,10 +581,12 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
|
|||||||
everyoneGrants, err = a.Grants(Everyone)
|
everyoneGrants, err = a.Grants(Everyone)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, 0, len(everyoneGrants))
|
require.Equal(t, 0, len(everyoneGrants))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_Token_Valid(t *testing.T) {
|
func TestManager_Token_Valid(t *testing.T) {
|
||||||
a := newTestManager(t, PermissionDenyAll)
|
forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
|
||||||
|
a := newTestManager(t, newStore, PermissionDenyAll)
|
||||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
|
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
|
||||||
|
|
||||||
u, err := a.User("ben")
|
u, err := a.User("ben")
|
||||||
@@ -556,10 +627,12 @@ func TestManager_Token_Valid(t *testing.T) {
|
|||||||
tokens, err = a.Tokens(u.ID)
|
tokens, err = a.Tokens(u.ID)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, 0, len(tokens))
|
require.Equal(t, 0, len(tokens))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_Token_Invalid(t *testing.T) {
|
func TestManager_Token_Invalid(t *testing.T) {
|
||||||
a := newTestManager(t, PermissionDenyAll)
|
forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
|
||||||
|
a := newTestManager(t, newStore, PermissionDenyAll)
|
||||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
|
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
|
||||||
|
|
||||||
u, err := a.AuthenticateToken(strings.Repeat("x", 32)) // 32 == token length
|
u, err := a.AuthenticateToken(strings.Repeat("x", 32)) // 32 == token length
|
||||||
@@ -569,16 +642,20 @@ func TestManager_Token_Invalid(t *testing.T) {
|
|||||||
u, err = a.AuthenticateToken("not long enough anyway")
|
u, err = a.AuthenticateToken("not long enough anyway")
|
||||||
require.Nil(t, u)
|
require.Nil(t, u)
|
||||||
require.Equal(t, ErrUnauthenticated, err)
|
require.Equal(t, ErrUnauthenticated, err)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_Token_NotFound(t *testing.T) {
|
func TestManager_Token_NotFound(t *testing.T) {
|
||||||
a := newTestManager(t, PermissionDenyAll)
|
forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
|
||||||
|
a := newTestManager(t, newStore, PermissionDenyAll)
|
||||||
_, err := a.Token("u_bla", "notfound")
|
_, err := a.Token("u_bla", "notfound")
|
||||||
require.Equal(t, ErrTokenNotFound, err)
|
require.Equal(t, ErrTokenNotFound, err)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_Token_Expire(t *testing.T) {
|
func TestManager_Token_Expire(t *testing.T) {
|
||||||
a := newTestManager(t, PermissionDenyAll)
|
forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
|
||||||
|
a := newTestManager(t, newStore, PermissionDenyAll)
|
||||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
|
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
|
||||||
|
|
||||||
u, err := a.User("ben")
|
u, err := a.User("ben")
|
||||||
@@ -603,30 +680,33 @@ func TestManager_Token_Expire(t *testing.T) {
|
|||||||
_, err = a.AuthenticateToken(token2.Value)
|
_, err = a.AuthenticateToken(token2.Value)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
// Modify token expiration in database
|
// Expire token1 via the API
|
||||||
_, err = a.db.Exec("UPDATE user_token SET expires = 1 WHERE token = ?", token1.Value)
|
_, err = a.ChangeToken(u.ID, token1.Value, nil, util.Time(time.Unix(1, 0)))
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
// Now token1 shouldn't work anymore
|
// Now token1 shouldn't work anymore
|
||||||
_, err = a.AuthenticateToken(token1.Value)
|
_, err = a.AuthenticateToken(token1.Value)
|
||||||
require.Equal(t, ErrUnauthenticated, err)
|
require.Equal(t, ErrUnauthenticated, err)
|
||||||
|
|
||||||
result, err := a.db.Query("SELECT * from user_token WHERE token = ?", token1.Value)
|
// But the token row should still exist
|
||||||
|
tokens, err := a.Tokens(u.ID)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.True(t, result.Next()) // Still a matching row
|
require.Equal(t, token1.Value, tokens[0].Value)
|
||||||
require.Nil(t, result.Close())
|
require.Equal(t, 2, len(tokens))
|
||||||
|
|
||||||
// Expire tokens and check database rows
|
// Expire tokens and check that token1 is gone
|
||||||
require.Nil(t, a.RemoveExpiredTokens())
|
require.Nil(t, a.RemoveExpiredTokens())
|
||||||
|
|
||||||
result, err = a.db.Query("SELECT * from user_token WHERE token = ?", token1.Value)
|
tokens, err = a.Tokens(u.ID)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.False(t, result.Next()) // No matching row!
|
require.Equal(t, 1, len(tokens))
|
||||||
require.Nil(t, result.Close())
|
require.Equal(t, token2.Value, tokens[0].Value)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_Token_Extend(t *testing.T) {
|
func TestManager_Token_Extend(t *testing.T) {
|
||||||
a := newTestManager(t, PermissionDenyAll)
|
forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
|
||||||
|
a := newTestManager(t, newStore, PermissionDenyAll)
|
||||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
|
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
|
||||||
|
|
||||||
// Try to extend token for user without token
|
// Try to extend token for user without token
|
||||||
@@ -650,12 +730,13 @@ func TestManager_Token_Extend(t *testing.T) {
|
|||||||
require.Equal(t, "changed label", extendedToken.Label)
|
require.Equal(t, "changed label", extendedToken.Label)
|
||||||
require.True(t, token.Expires.Unix() < extendedToken.Expires.Unix())
|
require.True(t, token.Expires.Unix() < extendedToken.Expires.Unix())
|
||||||
require.True(t, time.Now().Add(99*time.Hour).Unix() < extendedToken.Expires.Unix())
|
require.True(t, time.Now().Add(99*time.Hour).Unix() < extendedToken.Expires.Unix())
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_Token_MaxCount_AutoDelete(t *testing.T) {
|
func TestManager_Token_MaxCount_AutoDelete(t *testing.T) {
|
||||||
// Tests that tokens are automatically deleted when the maximum number of tokens is reached
|
// Tests that tokens are automatically deleted when the maximum number of tokens is reached
|
||||||
|
forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
|
||||||
a := newTestManager(t, PermissionDenyAll)
|
a := newTestManager(t, newStore, PermissionDenyAll)
|
||||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
|
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
|
||||||
require.Nil(t, a.AddUser("phil", "phil", RoleUser, false))
|
require.Nil(t, a.AddUser("phil", "phil", RoleUser, false))
|
||||||
|
|
||||||
@@ -687,7 +768,7 @@ func TestManager_Token_MaxCount_AutoDelete(t *testing.T) {
|
|||||||
benTokens = append(benTokens, token.Value)
|
benTokens = append(benTokens, token.Value)
|
||||||
|
|
||||||
// Manually modify expiry date to avoid sorting issues (this is a hack)
|
// Manually modify expiry date to avoid sorting issues (this is a hack)
|
||||||
_, err = a.db.Exec(`UPDATE user_token SET expires=? WHERE token=?`, baseTime.Add(time.Duration(i)*time.Minute).Unix(), token.Value)
|
_, err = a.ChangeToken(ben.ID, token.Value, nil, util.Time(baseTime.Add(time.Duration(i)*time.Minute)))
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -714,31 +795,24 @@ func TestManager_Token_MaxCount_AutoDelete(t *testing.T) {
|
|||||||
require.Equal(t, philTokens[i], userWithToken.Token)
|
require.Equal(t, philTokens[i], userWithToken.Token)
|
||||||
}
|
}
|
||||||
|
|
||||||
var benCount int
|
benTokensList, err := a.Tokens(ben.ID)
|
||||||
rows, err := a.db.Query(`SELECT COUNT(*) FROM user_token WHERE user_id=?`, ben.ID)
|
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.True(t, rows.Next())
|
require.Equal(t, 60, len(benTokensList))
|
||||||
require.Nil(t, rows.Scan(&benCount))
|
|
||||||
require.Equal(t, 60, benCount)
|
|
||||||
|
|
||||||
var philCount int
|
philTokensList, err := a.Tokens(phil.ID)
|
||||||
rows, err = a.db.Query(`SELECT COUNT(*) FROM user_token WHERE user_id=?`, phil.ID)
|
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.True(t, rows.Next())
|
require.Equal(t, 2, len(philTokensList))
|
||||||
require.Nil(t, rows.Scan(&philCount))
|
})
|
||||||
require.Equal(t, 2, philCount)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_EnqueueStats_ResetStats(t *testing.T) {
|
func TestManager_EnqueueStats_ResetStats(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
|
||||||
conf := &Config{
|
conf := &Config{
|
||||||
Filename: filepath.Join(t.TempDir(), "db"),
|
|
||||||
StartupQueries: "",
|
|
||||||
DefaultAccess: PermissionReadWrite,
|
DefaultAccess: PermissionReadWrite,
|
||||||
BcryptCost: bcrypt.MinCost,
|
BcryptCost: bcrypt.MinCost,
|
||||||
QueueWriterInterval: 1500 * time.Millisecond,
|
QueueWriterInterval: 1500 * time.Millisecond,
|
||||||
}
|
}
|
||||||
a, err := NewManager(conf)
|
a := newTestManagerFromStoreConfig(t, newStore, conf)
|
||||||
require.Nil(t, err)
|
|
||||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
|
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
|
||||||
|
|
||||||
// Baseline: No messages or emails
|
// Baseline: No messages or emails
|
||||||
@@ -776,18 +850,17 @@ func TestManager_EnqueueStats_ResetStats(t *testing.T) {
|
|||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, int64(0), u.Stats.Messages)
|
require.Equal(t, int64(0), u.Stats.Messages)
|
||||||
require.Equal(t, int64(0), u.Stats.Emails)
|
require.Equal(t, int64(0), u.Stats.Emails)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_EnqueueTokenUpdate(t *testing.T) {
|
func TestManager_EnqueueTokenUpdate(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
|
||||||
conf := &Config{
|
conf := &Config{
|
||||||
Filename: filepath.Join(t.TempDir(), "db"),
|
|
||||||
StartupQueries: "",
|
|
||||||
DefaultAccess: PermissionReadWrite,
|
DefaultAccess: PermissionReadWrite,
|
||||||
BcryptCost: bcrypt.MinCost,
|
BcryptCost: bcrypt.MinCost,
|
||||||
QueueWriterInterval: 500 * time.Millisecond,
|
QueueWriterInterval: 500 * time.Millisecond,
|
||||||
}
|
}
|
||||||
a, err := NewManager(conf)
|
a := newTestManagerFromStoreConfig(t, newStore, conf)
|
||||||
require.Nil(t, err)
|
|
||||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
|
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
|
||||||
|
|
||||||
// Create user and token
|
// Create user and token
|
||||||
@@ -816,18 +889,17 @@ func TestManager_EnqueueTokenUpdate(t *testing.T) {
|
|||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, time.Unix(111, 0).UTC().Unix(), token3.LastAccess.Unix())
|
require.Equal(t, time.Unix(111, 0).UTC().Unix(), token3.LastAccess.Unix())
|
||||||
require.Equal(t, netip.MustParseAddr("1.2.3.3"), token3.LastOrigin)
|
require.Equal(t, netip.MustParseAddr("1.2.3.3"), token3.LastOrigin)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_ChangeSettings(t *testing.T) {
|
func TestManager_ChangeSettings(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
|
||||||
conf := &Config{
|
conf := &Config{
|
||||||
Filename: filepath.Join(t.TempDir(), "db"),
|
|
||||||
StartupQueries: "",
|
|
||||||
DefaultAccess: PermissionReadWrite,
|
DefaultAccess: PermissionReadWrite,
|
||||||
BcryptCost: bcrypt.MinCost,
|
BcryptCost: bcrypt.MinCost,
|
||||||
QueueWriterInterval: 1500 * time.Millisecond,
|
QueueWriterInterval: 1500 * time.Millisecond,
|
||||||
}
|
}
|
||||||
a, err := NewManager(conf)
|
a := newTestManagerFromStoreConfig(t, newStore, conf)
|
||||||
require.Nil(t, err)
|
|
||||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
|
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
|
||||||
|
|
||||||
// No settings
|
// No settings
|
||||||
@@ -864,10 +936,12 @@ func TestManager_ChangeSettings(t *testing.T) {
|
|||||||
require.Equal(t, "https://ntfy.sh", u.Prefs.Subscriptions[0].BaseURL)
|
require.Equal(t, "https://ntfy.sh", u.Prefs.Subscriptions[0].BaseURL)
|
||||||
require.Equal(t, "mytopic", u.Prefs.Subscriptions[0].Topic)
|
require.Equal(t, "mytopic", u.Prefs.Subscriptions[0].Topic)
|
||||||
require.Equal(t, util.String("My Topic"), u.Prefs.Subscriptions[0].DisplayName)
|
require.Equal(t, util.String("My Topic"), u.Prefs.Subscriptions[0].DisplayName)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_Tier_Create_Update_List_Delete(t *testing.T) {
|
func TestManager_Tier_Create_Update_List_Delete(t *testing.T) {
|
||||||
a := newTestManager(t, PermissionDenyAll)
|
forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
|
||||||
|
a := newTestManager(t, newStore, PermissionDenyAll)
|
||||||
|
|
||||||
// Create tier and user
|
// Create tier and user
|
||||||
require.Nil(t, a.AddTier(&Tier{
|
require.Nil(t, a.AddTier(&Tier{
|
||||||
@@ -982,10 +1056,12 @@ func TestManager_Tier_Create_Update_List_Delete(t *testing.T) {
|
|||||||
require.Equal(t, 1, len(tiers))
|
require.Equal(t, 1, len(tiers))
|
||||||
require.Equal(t, "pro", tiers[0].Code)
|
require.Equal(t, "pro", tiers[0].Code)
|
||||||
require.Equal(t, "pro", tiers[0].Code)
|
require.Equal(t, "pro", tiers[0].Code)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAccount_Tier_Create_With_ID(t *testing.T) {
|
func TestAccount_Tier_Create_With_ID(t *testing.T) {
|
||||||
a := newTestManager(t, PermissionDenyAll)
|
forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
|
||||||
|
a := newTestManager(t, newStore, PermissionDenyAll)
|
||||||
|
|
||||||
require.Nil(t, a.AddTier(&Tier{
|
require.Nil(t, a.AddTier(&Tier{
|
||||||
ID: "ti_123",
|
ID: "ti_123",
|
||||||
@@ -995,10 +1071,12 @@ func TestAccount_Tier_Create_With_ID(t *testing.T) {
|
|||||||
ti, err := a.Tier("pro")
|
ti, err := a.Tier("pro")
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, "ti_123", ti.ID)
|
require.Equal(t, "ti_123", ti.ID)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_Tier_Change_And_Reset(t *testing.T) {
|
func TestManager_Tier_Change_And_Reset(t *testing.T) {
|
||||||
a := newTestManager(t, PermissionDenyAll)
|
forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
|
||||||
|
a := newTestManager(t, newStore, PermissionDenyAll)
|
||||||
|
|
||||||
// Create tier and user
|
// Create tier and user
|
||||||
require.Nil(t, a.AddTier(&Tier{
|
require.Nil(t, a.AddTier(&Tier{
|
||||||
@@ -1032,10 +1110,12 @@ func TestManager_Tier_Change_And_Reset(t *testing.T) {
|
|||||||
// Resetting after removing all reservations
|
// Resetting after removing all reservations
|
||||||
require.Nil(t, a.RemoveReservations("phil", "topic1", "topic2", "topic3"))
|
require.Nil(t, a.RemoveReservations("phil", "topic1", "topic2", "topic3"))
|
||||||
require.Nil(t, a.ResetTier("phil"))
|
require.Nil(t, a.ResetTier("phil"))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUser_PhoneNumberAddListRemove(t *testing.T) {
|
func TestUser_PhoneNumberAddListRemove(t *testing.T) {
|
||||||
a := newTestManager(t, PermissionDenyAll)
|
forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
|
||||||
|
a := newTestManager(t, newStore, PermissionDenyAll)
|
||||||
|
|
||||||
require.Nil(t, a.AddUser("phil", "phil", RoleUser, false))
|
require.Nil(t, a.AddUser("phil", "phil", RoleUser, false))
|
||||||
phil, err := a.User("phil")
|
phil, err := a.User("phil")
|
||||||
@@ -1053,14 +1133,16 @@ func TestUser_PhoneNumberAddListRemove(t *testing.T) {
|
|||||||
require.Equal(t, 0, len(phoneNumbers))
|
require.Equal(t, 0, len(phoneNumbers))
|
||||||
|
|
||||||
// Paranoia check: We do NOT want to keep phone numbers in there
|
// Paranoia check: We do NOT want to keep phone numbers in there
|
||||||
rows, err := a.db.Query(`SELECT * FROM user_phone`)
|
rows, err := testDB(a).Query(`SELECT * FROM user_phone`)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.False(t, rows.Next())
|
require.False(t, rows.Next())
|
||||||
require.Nil(t, rows.Close())
|
require.Nil(t, rows.Close())
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUser_PhoneNumberAdd_Multiple_Users_Same_Number(t *testing.T) {
|
func TestUser_PhoneNumberAdd_Multiple_Users_Same_Number(t *testing.T) {
|
||||||
a := newTestManager(t, PermissionDenyAll)
|
forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
|
||||||
|
a := newTestManager(t, newStore, PermissionDenyAll)
|
||||||
|
|
||||||
require.Nil(t, a.AddUser("phil", "phil", RoleUser, false))
|
require.Nil(t, a.AddUser("phil", "phil", RoleUser, false))
|
||||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
|
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
|
||||||
@@ -1070,11 +1152,12 @@ func TestUser_PhoneNumberAdd_Multiple_Users_Same_Number(t *testing.T) {
|
|||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Nil(t, a.AddPhoneNumber(phil.ID, "+1234567890"))
|
require.Nil(t, a.AddPhoneNumber(phil.ID, "+1234567890"))
|
||||||
require.Nil(t, a.AddPhoneNumber(ben.ID, "+1234567890"))
|
require.Nil(t, a.AddPhoneNumber(ben.ID, "+1234567890"))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_Topic_Wildcard_With_Asterisk_Underscore(t *testing.T) {
|
func TestManager_Topic_Wildcard_With_Asterisk_Underscore(t *testing.T) {
|
||||||
f := filepath.Join(t.TempDir(), "user.db")
|
forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
|
||||||
a := newTestManagerFromFile(t, f, "", PermissionDenyAll, DefaultUserPasswordBcryptCost, DefaultUserStatsQueueWriterInterval)
|
a := newTestManager(t, newStore, PermissionDenyAll)
|
||||||
require.Nil(t, a.AllowAccess(Everyone, "*_", PermissionRead))
|
require.Nil(t, a.AllowAccess(Everyone, "*_", PermissionRead))
|
||||||
require.Nil(t, a.AllowAccess(Everyone, "__*_", PermissionRead))
|
require.Nil(t, a.AllowAccess(Everyone, "__*_", PermissionRead))
|
||||||
require.Nil(t, a.Authorize(nil, "allowed_", PermissionRead))
|
require.Nil(t, a.Authorize(nil, "allowed_", PermissionRead))
|
||||||
@@ -1083,22 +1166,23 @@ func TestManager_Topic_Wildcard_With_Asterisk_Underscore(t *testing.T) {
|
|||||||
require.Equal(t, ErrUnauthorized, a.Authorize(nil, "notallowed", PermissionRead))
|
require.Equal(t, ErrUnauthorized, a.Authorize(nil, "notallowed", PermissionRead))
|
||||||
require.Equal(t, ErrUnauthorized, a.Authorize(nil, "_notallowed", PermissionRead))
|
require.Equal(t, ErrUnauthorized, a.Authorize(nil, "_notallowed", PermissionRead))
|
||||||
require.Equal(t, ErrUnauthorized, a.Authorize(nil, "__notallowed", PermissionRead))
|
require.Equal(t, ErrUnauthorized, a.Authorize(nil, "__notallowed", PermissionRead))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_Topic_Wildcard_With_Underscore(t *testing.T) {
|
func TestManager_Topic_Wildcard_With_Underscore(t *testing.T) {
|
||||||
f := filepath.Join(t.TempDir(), "user.db")
|
forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
|
||||||
a := newTestManagerFromFile(t, f, "", PermissionDenyAll, DefaultUserPasswordBcryptCost, DefaultUserStatsQueueWriterInterval)
|
a := newTestManager(t, newStore, PermissionDenyAll)
|
||||||
require.Nil(t, a.AllowAccess(Everyone, "mytopic_", PermissionReadWrite))
|
require.Nil(t, a.AllowAccess(Everyone, "mytopic_", PermissionReadWrite))
|
||||||
require.Nil(t, a.Authorize(nil, "mytopic_", PermissionRead))
|
require.Nil(t, a.Authorize(nil, "mytopic_", PermissionRead))
|
||||||
require.Nil(t, a.Authorize(nil, "mytopic_", PermissionWrite))
|
require.Nil(t, a.Authorize(nil, "mytopic_", PermissionWrite))
|
||||||
require.Equal(t, ErrUnauthorized, a.Authorize(nil, "mytopicX", PermissionRead))
|
require.Equal(t, ErrUnauthorized, a.Authorize(nil, "mytopicX", PermissionRead))
|
||||||
require.Equal(t, ErrUnauthorized, a.Authorize(nil, "mytopicX", PermissionWrite))
|
require.Equal(t, ErrUnauthorized, a.Authorize(nil, "mytopicX", PermissionWrite))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_WithProvisionedUsers(t *testing.T) {
|
func TestManager_WithProvisionedUsers(t *testing.T) {
|
||||||
f := filepath.Join(t.TempDir(), "user.db")
|
forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
|
||||||
conf := &Config{
|
conf := &Config{
|
||||||
Filename: f,
|
|
||||||
DefaultAccess: PermissionReadWrite,
|
DefaultAccess: PermissionReadWrite,
|
||||||
ProvisionEnabled: true,
|
ProvisionEnabled: true,
|
||||||
Users: []*User{
|
Users: []*User{
|
||||||
@@ -1117,8 +1201,7 @@ func TestManager_WithProvisionedUsers(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
a, err := NewManager(conf)
|
a := newTestManagerFromStoreConfig(t, newStore, conf)
|
||||||
require.Nil(t, err)
|
|
||||||
|
|
||||||
// Manually add user
|
// Manually add user
|
||||||
require.Nil(t, a.AddUser("philmanual", "manual", RoleUser, false))
|
require.Nil(t, a.AddUser("philmanual", "manual", RoleUser, false))
|
||||||
@@ -1154,13 +1237,11 @@ func TestManager_WithProvisionedUsers(t *testing.T) {
|
|||||||
// Update the token last access time and origin (so we can check that it is persisted)
|
// Update the token last access time and origin (so we can check that it is persisted)
|
||||||
lastAccessTime := time.Now().Add(time.Hour)
|
lastAccessTime := time.Now().Add(time.Hour)
|
||||||
lastOrigin := netip.MustParseAddr("1.1.9.9")
|
lastOrigin := netip.MustParseAddr("1.1.9.9")
|
||||||
err = execTx(a.db, func(tx *sql.Tx) error {
|
err = a.store.UpdateTokenLastAccess(tokens[0].Value, lastAccessTime, lastOrigin)
|
||||||
return a.updateTokenLastAccessTx(tx, tokens[0].Value, lastAccessTime.Unix(), lastOrigin.String())
|
|
||||||
})
|
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
// Re-open the DB (second app start)
|
// Re-open the DB (second app start)
|
||||||
require.Nil(t, a.db.Close())
|
require.Nil(t, a.Close())
|
||||||
conf.Users = []*User{
|
conf.Users = []*User{
|
||||||
{Name: "philuser", Hash: "$2a$10$AAAAU21sX1uhZamTLJXHuxgVC0Z/GKISibrKCLohPgtG7yIxSk4C", Role: RoleUser},
|
{Name: "philuser", Hash: "$2a$10$AAAAU21sX1uhZamTLJXHuxgVC0Z/GKISibrKCLohPgtG7yIxSk4C", Role: RoleUser},
|
||||||
}
|
}
|
||||||
@@ -1176,8 +1257,7 @@ func TestManager_WithProvisionedUsers(t *testing.T) {
|
|||||||
{Value: "tk_u48wqendnkx9er21pqqcadlytbutx", Label: "Another token"},
|
{Value: "tk_u48wqendnkx9er21pqqcadlytbutx", Label: "Another token"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
a, err = NewManager(conf)
|
a = newTestManagerFromStoreConfig(t, newStore, conf)
|
||||||
require.Nil(t, err)
|
|
||||||
|
|
||||||
// Check that the provisioned users are there
|
// Check that the provisioned users are there
|
||||||
users, err = a.Users()
|
users, err = a.Users()
|
||||||
@@ -1212,12 +1292,11 @@ func TestManager_WithProvisionedUsers(t *testing.T) {
|
|||||||
require.Error(t, a.ChangePassword("philuser", "new-pass", false))
|
require.Error(t, a.ChangePassword("philuser", "new-pass", false))
|
||||||
|
|
||||||
// Re-open the DB again (third app start)
|
// Re-open the DB again (third app start)
|
||||||
require.Nil(t, a.db.Close())
|
require.Nil(t, a.Close())
|
||||||
conf.Users = []*User{}
|
conf.Users = []*User{}
|
||||||
conf.Access = map[string][]*Grant{}
|
conf.Access = map[string][]*Grant{}
|
||||||
conf.Tokens = map[string][]*Token{}
|
conf.Tokens = map[string][]*Token{}
|
||||||
a, err = NewManager(conf)
|
a = newTestManagerFromStoreConfig(t, newStore, conf)
|
||||||
require.Nil(t, err)
|
|
||||||
|
|
||||||
// Check that the provisioned users are all gone
|
// Check that the provisioned users are all gone
|
||||||
users, err = a.Users()
|
users, err = a.Users()
|
||||||
@@ -1236,18 +1315,26 @@ func TestManager_WithProvisionedUsers(t *testing.T) {
|
|||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, 0, len(tokens))
|
require.Equal(t, 0, len(tokens))
|
||||||
|
|
||||||
var count int
|
// Verify no provisioned data remains
|
||||||
a.db.QueryRow("SELECT COUNT(*) FROM user WHERE provisioned = 1").Scan(&count)
|
for _, u := range users {
|
||||||
require.Equal(t, 0, count)
|
require.False(t, u.Provisioned)
|
||||||
a.db.QueryRow("SELECT COUNT(*) FROM user_grant WHERE provisioned = 1").Scan(&count)
|
userGrants, err := a.Grants(u.Name)
|
||||||
require.Equal(t, 0, count)
|
require.Nil(t, err)
|
||||||
a.db.QueryRow("SELECT COUNT(*) FROM user_token WHERE provisioned = 1").Scan(&count)
|
for _, g := range userGrants {
|
||||||
|
require.False(t, g.Provisioned)
|
||||||
|
}
|
||||||
|
userTokens, err := a.Tokens(u.ID)
|
||||||
|
require.Nil(t, err)
|
||||||
|
for _, tk := range userTokens {
|
||||||
|
require.False(t, tk.Provisioned)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_UpdateNonProvisionedUsersToProvisionedUsers(t *testing.T) {
|
func TestManager_UpdateNonProvisionedUsersToProvisionedUsers(t *testing.T) {
|
||||||
f := filepath.Join(t.TempDir(), "user.db")
|
forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
|
||||||
conf := &Config{
|
conf := &Config{
|
||||||
Filename: f,
|
|
||||||
DefaultAccess: PermissionReadWrite,
|
DefaultAccess: PermissionReadWrite,
|
||||||
ProvisionEnabled: true,
|
ProvisionEnabled: true,
|
||||||
Users: []*User{},
|
Users: []*User{},
|
||||||
@@ -1257,8 +1344,7 @@ func TestManager_UpdateNonProvisionedUsersToProvisionedUsers(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
a, err := NewManager(conf)
|
a := newTestManagerFromStoreConfig(t, newStore, conf)
|
||||||
require.Nil(t, err)
|
|
||||||
|
|
||||||
// Manually add user
|
// Manually add user
|
||||||
require.Nil(t, a.AddUser("philuser", "manual", RoleUser, false))
|
require.Nil(t, a.AddUser("philuser", "manual", RoleUser, false))
|
||||||
@@ -1290,7 +1376,7 @@ func TestManager_UpdateNonProvisionedUsersToProvisionedUsers(t *testing.T) {
|
|||||||
require.True(t, grants[0].Provisioned) // Provisioned entry
|
require.True(t, grants[0].Provisioned) // Provisioned entry
|
||||||
|
|
||||||
// Re-open the DB (second app start)
|
// Re-open the DB (second app start)
|
||||||
require.Nil(t, a.db.Close())
|
require.Nil(t, a.Close())
|
||||||
conf.Users = []*User{
|
conf.Users = []*User{
|
||||||
{Name: "philuser", Hash: "$2a$10$AAAAU21sX1uhZamTLJXHuxgVC0Z/GKISibrKCLohPgtG7yIxSk4C", Role: RoleUser},
|
{Name: "philuser", Hash: "$2a$10$AAAAU21sX1uhZamTLJXHuxgVC0Z/GKISibrKCLohPgtG7yIxSk4C", Role: RoleUser},
|
||||||
}
|
}
|
||||||
@@ -1299,8 +1385,7 @@ func TestManager_UpdateNonProvisionedUsersToProvisionedUsers(t *testing.T) {
|
|||||||
{TopicPattern: "stats", Permission: PermissionReadWrite},
|
{TopicPattern: "stats", Permission: PermissionReadWrite},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
a, err = NewManager(conf)
|
a = newTestManagerFromStoreConfig(t, newStore, conf)
|
||||||
require.Nil(t, err)
|
|
||||||
|
|
||||||
// Check that the user was "upgraded" to a provisioned user
|
// Check that the user was "upgraded" to a provisioned user
|
||||||
users, err = a.Users()
|
users, err = a.Users()
|
||||||
@@ -1324,6 +1409,7 @@ func TestManager_UpdateNonProvisionedUsersToProvisionedUsers(t *testing.T) {
|
|||||||
grants, err = a.Grants(Everyone)
|
grants, err = a.Grants(Everyone)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Empty(t, grants)
|
require.Empty(t, grants)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestToFromSQLWildcard(t *testing.T) {
|
func TestToFromSQLWildcard(t *testing.T) {
|
||||||
@@ -1383,7 +1469,7 @@ func TestMigrationFrom1(t *testing.T) {
|
|||||||
|
|
||||||
// Create manager to trigger migration
|
// Create manager to trigger migration
|
||||||
a := newTestManagerFromFile(t, filename, "", PermissionDenyAll, bcrypt.MinCost, DefaultUserStatsQueueWriterInterval)
|
a := newTestManagerFromFile(t, filename, "", PermissionDenyAll, bcrypt.MinCost, DefaultUserStatsQueueWriterInterval)
|
||||||
checkSchemaVersion(t, a.db)
|
checkSchemaVersion(t, testDB(a))
|
||||||
|
|
||||||
users, err := a.Users()
|
users, err := a.Users()
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
@@ -1526,7 +1612,7 @@ func TestMigrationFrom4(t *testing.T) {
|
|||||||
|
|
||||||
// Create manager to trigger migration
|
// Create manager to trigger migration
|
||||||
a := newTestManagerFromFile(t, filename, "", PermissionDenyAll, bcrypt.MinCost, DefaultUserStatsQueueWriterInterval)
|
a := newTestManagerFromFile(t, filename, "", PermissionDenyAll, bcrypt.MinCost, DefaultUserStatsQueueWriterInterval)
|
||||||
checkSchemaVersion(t, a.db)
|
checkSchemaVersion(t, testDB(a))
|
||||||
|
|
||||||
// Add another
|
// Add another
|
||||||
require.Nil(t, a.AllowAccess(Everyone, "left_*", PermissionReadWrite))
|
require.Nil(t, a.AllowAccess(Everyone, "left_*", PermissionReadWrite))
|
||||||
@@ -1578,23 +1664,43 @@ func checkSchemaVersion(t *testing.T, db *sql.DB) {
|
|||||||
|
|
||||||
var schemaVersion int
|
var schemaVersion int
|
||||||
require.Nil(t, rows.Scan(&schemaVersion))
|
require.Nil(t, rows.Scan(&schemaVersion))
|
||||||
require.Equal(t, currentSchemaVersion, schemaVersion)
|
require.Equal(t, sqliteCurrentSchemaVersion, schemaVersion)
|
||||||
require.Nil(t, rows.Close())
|
require.Nil(t, rows.Close())
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTestManager(t *testing.T, defaultAccess Permission) *Manager {
|
func newTestManager(t *testing.T, newStore newStoreFunc, defaultAccess Permission) *Manager {
|
||||||
return newTestManagerFromFile(t, filepath.Join(t.TempDir(), "user.db"), "", defaultAccess, bcrypt.MinCost, DefaultUserStatsQueueWriterInterval)
|
store := newStore()
|
||||||
|
a, err := NewManager(store, &Config{
|
||||||
|
DefaultAccess: defaultAccess,
|
||||||
|
BcryptCost: bcrypt.MinCost,
|
||||||
|
QueueWriterInterval: DefaultUserStatsQueueWriterInterval,
|
||||||
|
})
|
||||||
|
require.Nil(t, err)
|
||||||
|
t.Cleanup(func() { a.Close() })
|
||||||
|
return a
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTestManagerFromFile(t *testing.T, filename, startupQueries string, defaultAccess Permission, bcryptCost int, statsWriterInterval time.Duration) *Manager {
|
func newTestManagerFromFile(t *testing.T, filename, startupQueries string, defaultAccess Permission, bcryptCost int, statsWriterInterval time.Duration) *Manager {
|
||||||
|
store, err := NewSQLiteStore(filename, startupQueries)
|
||||||
|
require.Nil(t, err)
|
||||||
conf := &Config{
|
conf := &Config{
|
||||||
Filename: filename,
|
|
||||||
StartupQueries: startupQueries,
|
|
||||||
DefaultAccess: defaultAccess,
|
DefaultAccess: defaultAccess,
|
||||||
BcryptCost: bcryptCost,
|
BcryptCost: bcryptCost,
|
||||||
QueueWriterInterval: statsWriterInterval,
|
QueueWriterInterval: statsWriterInterval,
|
||||||
}
|
}
|
||||||
a, err := NewManager(conf)
|
a, err := NewManager(store, conf)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
return a
|
return a
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newTestManagerFromStoreConfig(t *testing.T, newStore newStoreFunc, conf *Config) *Manager {
|
||||||
|
store := newStore()
|
||||||
|
a, err := NewManager(store, conf)
|
||||||
|
require.Nil(t, err)
|
||||||
|
t.Cleanup(func() { a.Close() })
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
|
func testDB(a *Manager) *sql.DB {
|
||||||
|
return a.store.(*commonStore).db
|
||||||
|
}
|
||||||
|
|||||||
986
user/store.go
Normal file
986
user/store.go
Normal file
@@ -0,0 +1,986 @@
|
|||||||
|
package user
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"net/netip"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"heckel.io/ntfy/v2/payments"
|
||||||
|
"heckel.io/ntfy/v2/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Store is the interface for a user database store
|
||||||
|
type Store interface {
|
||||||
|
// User operations
|
||||||
|
UserByID(id string) (*User, error)
|
||||||
|
User(username string) (*User, error)
|
||||||
|
UserByToken(token string) (*User, error)
|
||||||
|
UserByStripeCustomer(customerID string) (*User, error)
|
||||||
|
UserIDByUsername(username string) (string, error)
|
||||||
|
Users() ([]*User, error)
|
||||||
|
UsersCount() (int64, error)
|
||||||
|
AddUser(username, hash string, role Role, provisioned bool) error
|
||||||
|
RemoveUser(username string) error
|
||||||
|
MarkUserRemoved(userID string) error
|
||||||
|
RemoveDeletedUsers() error
|
||||||
|
ChangePassword(username, hash string) error
|
||||||
|
ChangeRole(username string, role Role) error
|
||||||
|
ChangeProvisioned(username string, provisioned bool) error
|
||||||
|
ChangeSettings(userID string, prefs *Prefs) error
|
||||||
|
ChangeTier(username, tierCode string) error
|
||||||
|
ResetTier(username string) error
|
||||||
|
UpdateStats(userID string, stats *Stats) error
|
||||||
|
ResetStats() error
|
||||||
|
|
||||||
|
// Token operations
|
||||||
|
CreateToken(userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, provisioned bool) (*Token, error)
|
||||||
|
Token(userID, token string) (*Token, error)
|
||||||
|
Tokens(userID string) ([]*Token, error)
|
||||||
|
AllProvisionedTokens() ([]*Token, error)
|
||||||
|
ChangeTokenLabel(userID, token, label string) error
|
||||||
|
ChangeTokenExpiry(userID, token string, expires time.Time) error
|
||||||
|
UpdateTokenLastAccess(token string, lastAccess time.Time, lastOrigin netip.Addr) error
|
||||||
|
RemoveToken(userID, token string) error
|
||||||
|
RemoveExpiredTokens() error
|
||||||
|
TokenCount(userID string) (int, error)
|
||||||
|
RemoveExcessTokens(userID string, maxCount int) error
|
||||||
|
|
||||||
|
// Access operations
|
||||||
|
AuthorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error)
|
||||||
|
AllGrants() (map[string][]Grant, error)
|
||||||
|
Grants(username string) ([]Grant, error)
|
||||||
|
AllowAccess(username, topicPattern string, read, write bool, ownerUsername string, provisioned bool) error
|
||||||
|
ResetAccess(username, topicPattern string) error
|
||||||
|
ResetAllProvisionedAccess() error
|
||||||
|
Reservations(username string) ([]Reservation, error)
|
||||||
|
HasReservation(username, topic string) (bool, error)
|
||||||
|
ReservationsCount(username string) (int64, error)
|
||||||
|
ReservationOwner(topic string) (string, error)
|
||||||
|
OtherAccessCount(username, topic string) (int, error)
|
||||||
|
|
||||||
|
// Tier operations
|
||||||
|
AddTier(tier *Tier) error
|
||||||
|
UpdateTier(tier *Tier) error
|
||||||
|
RemoveTier(code string) error
|
||||||
|
Tiers() ([]*Tier, error)
|
||||||
|
Tier(code string) (*Tier, error)
|
||||||
|
TierByStripePrice(priceID string) (*Tier, error)
|
||||||
|
|
||||||
|
// Phone operations
|
||||||
|
PhoneNumbers(userID string) ([]string, error)
|
||||||
|
AddPhoneNumber(userID, phoneNumber string) error
|
||||||
|
RemovePhoneNumber(userID, phoneNumber string) error
|
||||||
|
|
||||||
|
// Other stuff
|
||||||
|
ChangeBilling(username string, billing *Billing) error
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// storeQueries holds the database-specific SQL queries
|
||||||
|
type storeQueries struct {
|
||||||
|
// User queries
|
||||||
|
selectUserByID string
|
||||||
|
selectUserByName string
|
||||||
|
selectUserByToken string
|
||||||
|
selectUserByStripeID string
|
||||||
|
selectUsernames string
|
||||||
|
selectUserCount string
|
||||||
|
selectUserIDFromUsername string
|
||||||
|
insertUser string
|
||||||
|
updateUserPass string
|
||||||
|
updateUserRole string
|
||||||
|
updateUserProvisioned string
|
||||||
|
updateUserPrefs string
|
||||||
|
updateUserStats string
|
||||||
|
updateUserStatsResetAll string
|
||||||
|
updateUserTier string
|
||||||
|
updateUserDeleted string
|
||||||
|
deleteUser string
|
||||||
|
deleteUserTier string
|
||||||
|
deleteUsersMarked string
|
||||||
|
// Access queries
|
||||||
|
selectTopicPerms string
|
||||||
|
selectUserAllAccess string
|
||||||
|
selectUserAccess string
|
||||||
|
selectUserReservations string
|
||||||
|
selectUserReservationsCount string
|
||||||
|
selectUserReservationsOwner string
|
||||||
|
selectUserHasReservation string
|
||||||
|
selectOtherAccessCount string
|
||||||
|
upsertUserAccess string
|
||||||
|
deleteUserAccess string
|
||||||
|
deleteUserAccessProvisioned string
|
||||||
|
deleteTopicAccess string
|
||||||
|
deleteAllAccess string
|
||||||
|
// Token queries
|
||||||
|
selectToken string
|
||||||
|
selectTokens string
|
||||||
|
selectTokenCount string
|
||||||
|
selectAllProvisionedTokens string
|
||||||
|
upsertToken string
|
||||||
|
updateTokenLabel string
|
||||||
|
updateTokenExpiry string
|
||||||
|
updateTokenLastAccess string
|
||||||
|
deleteToken string
|
||||||
|
deleteProvisionedToken string
|
||||||
|
deleteAllToken string
|
||||||
|
deleteExpiredTokens string
|
||||||
|
deleteExcessTokens string
|
||||||
|
// Tier queries
|
||||||
|
insertTier string
|
||||||
|
selectTiers string
|
||||||
|
selectTierByCode string
|
||||||
|
selectTierByPriceID string
|
||||||
|
updateTier string
|
||||||
|
deleteTier string
|
||||||
|
// Phone queries
|
||||||
|
selectPhoneNumbers string
|
||||||
|
insertPhoneNumber string
|
||||||
|
deletePhoneNumber string
|
||||||
|
// Billing queries
|
||||||
|
updateBilling string
|
||||||
|
}
|
||||||
|
|
||||||
|
// commonStore implements store operations that work across database backends
|
||||||
|
type commonStore struct {
|
||||||
|
db *sql.DB
|
||||||
|
queries storeQueries
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserByID returns the user with the given ID if it exists, or ErrUserNotFound otherwise
|
||||||
|
func (s *commonStore) UserByID(id string) (*User, error) {
|
||||||
|
rows, err := s.db.Query(s.queries.selectUserByID, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s.readUser(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
// User returns the user with the given username if it exists, or ErrUserNotFound otherwise
|
||||||
|
func (s *commonStore) User(username string) (*User, error) {
|
||||||
|
rows, err := s.db.Query(s.queries.selectUserByName, username)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s.readUser(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserByToken returns the user with the given token if it exists and is not expired, or ErrUserNotFound otherwise
|
||||||
|
func (s *commonStore) UserByToken(token string) (*User, error) {
|
||||||
|
rows, err := s.db.Query(s.queries.selectUserByToken, token, time.Now().Unix())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s.readUser(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise
|
||||||
|
func (s *commonStore) UserByStripeCustomer(customerID string) (*User, error) {
|
||||||
|
rows, err := s.db.Query(s.queries.selectUserByStripeID, customerID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s.readUser(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Users returns a list of users
|
||||||
|
func (s *commonStore) Users() ([]*User, error) {
|
||||||
|
rows, err := s.db.Query(s.queries.selectUsernames)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
usernames := make([]string, 0)
|
||||||
|
for rows.Next() {
|
||||||
|
var username string
|
||||||
|
if err := rows.Scan(&username); err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
usernames = append(usernames, username)
|
||||||
|
}
|
||||||
|
rows.Close()
|
||||||
|
users := make([]*User, 0)
|
||||||
|
for _, username := range usernames {
|
||||||
|
user, err := s.User(username)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
users = append(users, user)
|
||||||
|
}
|
||||||
|
return users, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UsersCount returns the number of users in the database
|
||||||
|
func (s *commonStore) UsersCount() (int64, error) {
|
||||||
|
rows, err := s.db.Query(s.queries.selectUserCount)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
if !rows.Next() {
|
||||||
|
return 0, errNoRows
|
||||||
|
}
|
||||||
|
var count int64
|
||||||
|
if err := rows.Scan(&count); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return count, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddUser adds a user with the given username, password hash and role
|
||||||
|
func (s *commonStore) AddUser(username, hash string, role Role, provisioned bool) error {
|
||||||
|
if !AllowedUsername(username) || !AllowedRole(role) {
|
||||||
|
return ErrInvalidArgument
|
||||||
|
}
|
||||||
|
userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
|
||||||
|
syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
|
||||||
|
now := time.Now().Unix()
|
||||||
|
if _, err := s.db.Exec(s.queries.insertUser, userID, username, hash, string(role), syncTopic, provisioned, now); err != nil {
|
||||||
|
if isUniqueConstraintError(err) {
|
||||||
|
return ErrUserExists
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveUser deletes the user with the given username
|
||||||
|
func (s *commonStore) RemoveUser(username string) error {
|
||||||
|
if !AllowedUsername(username) {
|
||||||
|
return ErrInvalidArgument
|
||||||
|
}
|
||||||
|
// Rows in user_access, user_token, etc. are deleted via foreign keys
|
||||||
|
if _, err := s.db.Exec(s.queries.deleteUser, username); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkUserRemoved sets the deleted flag on the user, and deletes all access tokens
|
||||||
|
func (s *commonStore) MarkUserRemoved(userID string) error {
|
||||||
|
tx, err := s.db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
// Get username for deleteUserAccess query
|
||||||
|
user, err := s.UserByID(userID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := tx.Exec(s.queries.deleteUserAccess, user.Name, user.Name); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := tx.Exec(s.queries.deleteAllToken, userID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
deletedTime := time.Now().Add(userHardDeleteAfterDuration).Unix()
|
||||||
|
if _, err := tx.Exec(s.queries.updateUserDeleted, deletedTime, userID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveDeletedUsers deletes all users that have been marked deleted
|
||||||
|
func (s *commonStore) RemoveDeletedUsers() error {
|
||||||
|
if _, err := s.db.Exec(s.queries.deleteUsersMarked, time.Now().Unix()); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChangePassword changes a user's password
|
||||||
|
func (s *commonStore) ChangePassword(username, hash string) error {
|
||||||
|
if _, err := s.db.Exec(s.queries.updateUserPass, hash, username); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChangeRole changes a user's role
|
||||||
|
func (s *commonStore) ChangeRole(username string, role Role) error {
|
||||||
|
if !AllowedUsername(username) || !AllowedRole(role) {
|
||||||
|
return ErrInvalidArgument
|
||||||
|
}
|
||||||
|
tx, err := s.db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
if _, err := tx.Exec(s.queries.updateUserRole, string(role), username); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// If changing to admin, remove all access entries
|
||||||
|
if role == RoleAdmin {
|
||||||
|
if _, err := tx.Exec(s.queries.deleteUserAccess, username, username); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChangeProvisioned changes the provisioned status of a user
|
||||||
|
func (s *commonStore) ChangeProvisioned(username string, provisioned bool) error {
|
||||||
|
if _, err := s.db.Exec(s.queries.updateUserProvisioned, provisioned, username); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChangeSettings persists the user settings
|
||||||
|
func (s *commonStore) ChangeSettings(userID string, prefs *Prefs) error {
|
||||||
|
b, err := json.Marshal(prefs)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := s.db.Exec(s.queries.updateUserPrefs, string(b), userID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChangeTier changes a user's tier using the tier code
|
||||||
|
func (s *commonStore) ChangeTier(username, tierCode string) error {
|
||||||
|
if _, err := s.db.Exec(s.queries.updateUserTier, tierCode, username); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetTier removes the tier from the given user
|
||||||
|
func (s *commonStore) ResetTier(username string) error {
|
||||||
|
if !AllowedUsername(username) && username != Everyone && username != "" {
|
||||||
|
return ErrInvalidArgument
|
||||||
|
}
|
||||||
|
_, err := s.db.Exec(s.queries.deleteUserTier, username)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateStats updates the user statistics
|
||||||
|
func (s *commonStore) UpdateStats(userID string, stats *Stats) error {
|
||||||
|
if _, err := s.db.Exec(s.queries.updateUserStats, stats.Messages, stats.Emails, stats.Calls, userID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetStats resets all user stats in the user database
|
||||||
|
func (s *commonStore) ResetStats() error {
|
||||||
|
if _, err := s.db.Exec(s.queries.updateUserStatsResetAll); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (s *commonStore) readUser(rows *sql.Rows) (*User, error) {
|
||||||
|
defer rows.Close()
|
||||||
|
var id, username, hash, role, prefs, syncTopic string
|
||||||
|
var provisioned bool
|
||||||
|
var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripeSubscriptionInterval, stripeMonthlyPriceID, stripeYearlyPriceID, tierID, tierCode, tierName sql.NullString
|
||||||
|
var messages, emails, calls int64
|
||||||
|
var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
|
||||||
|
if !rows.Next() {
|
||||||
|
return nil, ErrUserNotFound
|
||||||
|
}
|
||||||
|
if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &provisioned, &messages, &emails, &calls, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionInterval, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
user := &User{
|
||||||
|
ID: id,
|
||||||
|
Name: username,
|
||||||
|
Hash: hash,
|
||||||
|
Role: Role(role),
|
||||||
|
Prefs: &Prefs{},
|
||||||
|
SyncTopic: syncTopic,
|
||||||
|
Provisioned: provisioned,
|
||||||
|
Stats: &Stats{
|
||||||
|
Messages: messages,
|
||||||
|
Emails: emails,
|
||||||
|
Calls: calls,
|
||||||
|
},
|
||||||
|
Billing: &Billing{
|
||||||
|
StripeCustomerID: stripeCustomerID.String,
|
||||||
|
StripeSubscriptionID: stripeSubscriptionID.String,
|
||||||
|
StripeSubscriptionStatus: payments.SubscriptionStatus(stripeSubscriptionStatus.String),
|
||||||
|
StripeSubscriptionInterval: payments.PriceRecurringInterval(stripeSubscriptionInterval.String),
|
||||||
|
StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0),
|
||||||
|
StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0),
|
||||||
|
},
|
||||||
|
Deleted: deleted.Valid,
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if tierCode.Valid {
|
||||||
|
user.Tier = &Tier{
|
||||||
|
ID: tierID.String,
|
||||||
|
Code: tierCode.String,
|
||||||
|
Name: tierName.String,
|
||||||
|
MessageLimit: messagesLimit.Int64,
|
||||||
|
MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
|
||||||
|
EmailLimit: emailsLimit.Int64,
|
||||||
|
CallLimit: callsLimit.Int64,
|
||||||
|
ReservationLimit: reservationsLimit.Int64,
|
||||||
|
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
|
||||||
|
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
|
||||||
|
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
|
||||||
|
AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
|
||||||
|
StripeMonthlyPriceID: stripeMonthlyPriceID.String,
|
||||||
|
StripeYearlyPriceID: stripeYearlyPriceID.String,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateToken creates a new token
|
||||||
|
func (s *commonStore) CreateToken(userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, provisioned bool) (*Token, error) {
|
||||||
|
if _, err := s.db.Exec(s.queries.upsertToken, userID, token, label, lastAccess.Unix(), lastOrigin.String(), expires.Unix(), provisioned); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &Token{
|
||||||
|
Value: token,
|
||||||
|
Label: label,
|
||||||
|
LastAccess: lastAccess,
|
||||||
|
LastOrigin: lastOrigin,
|
||||||
|
Expires: expires,
|
||||||
|
Provisioned: provisioned,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Token returns a specific token for a user
|
||||||
|
func (s *commonStore) Token(userID, token string) (*Token, error) {
|
||||||
|
rows, err := s.db.Query(s.queries.selectToken, userID, token)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
return s.readToken(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tokens returns all existing tokens for the user with the given user ID
|
||||||
|
func (s *commonStore) Tokens(userID string) ([]*Token, error) {
|
||||||
|
rows, err := s.db.Query(s.queries.selectTokens, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
tokens := make([]*Token, 0)
|
||||||
|
for {
|
||||||
|
token, err := s.readToken(rows)
|
||||||
|
if errors.Is(err, ErrTokenNotFound) {
|
||||||
|
break
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tokens = append(tokens, token)
|
||||||
|
}
|
||||||
|
return tokens, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllProvisionedTokens returns all provisioned tokens
|
||||||
|
func (s *commonStore) AllProvisionedTokens() ([]*Token, error) {
|
||||||
|
rows, err := s.db.Query(s.queries.selectAllProvisionedTokens)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
tokens := make([]*Token, 0)
|
||||||
|
for {
|
||||||
|
token, err := s.readToken(rows)
|
||||||
|
if errors.Is(err, ErrTokenNotFound) {
|
||||||
|
break
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tokens = append(tokens, token)
|
||||||
|
}
|
||||||
|
return tokens, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChangeTokenLabel updates a token's label
|
||||||
|
func (s *commonStore) ChangeTokenLabel(userID, token, label string) error {
|
||||||
|
if _, err := s.db.Exec(s.queries.updateTokenLabel, label, userID, token); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChangeTokenExpiry updates a token's expiry time
|
||||||
|
func (s *commonStore) ChangeTokenExpiry(userID, token string, expires time.Time) error {
|
||||||
|
if _, err := s.db.Exec(s.queries.updateTokenExpiry, expires.Unix(), userID, token); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateTokenLastAccess updates a token's last access time and origin
|
||||||
|
func (s *commonStore) UpdateTokenLastAccess(token string, lastAccess time.Time, lastOrigin netip.Addr) error {
|
||||||
|
if _, err := s.db.Exec(s.queries.updateTokenLastAccess, lastAccess.Unix(), lastOrigin.String(), token); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveToken deletes the token
|
||||||
|
func (s *commonStore) RemoveToken(userID, token string) error {
|
||||||
|
if token == "" {
|
||||||
|
return errNoTokenProvided
|
||||||
|
}
|
||||||
|
if _, err := s.db.Exec(s.queries.deleteToken, userID, token); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveExpiredTokens deletes all expired tokens from the database
|
||||||
|
func (s *commonStore) RemoveExpiredTokens() error {
|
||||||
|
if _, err := s.db.Exec(s.queries.deleteExpiredTokens, time.Now().Unix()); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenCount returns the number of tokens for a user
|
||||||
|
func (s *commonStore) TokenCount(userID string) (int, error) {
|
||||||
|
rows, err := s.db.Query(s.queries.selectTokenCount, userID)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
if !rows.Next() {
|
||||||
|
return 0, errNoRows
|
||||||
|
}
|
||||||
|
var count int
|
||||||
|
if err := rows.Scan(&count); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return count, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveExcessTokens deletes excess tokens beyond the specified maximum
|
||||||
|
func (s *commonStore) RemoveExcessTokens(userID string, maxCount int) error {
|
||||||
|
if _, err := s.db.Exec(s.queries.deleteExcessTokens, userID, userID, maxCount); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (s *commonStore) readToken(rows *sql.Rows) (*Token, error) {
|
||||||
|
var token, label, lastOrigin string
|
||||||
|
var lastAccess, expires int64
|
||||||
|
var provisioned bool
|
||||||
|
if !rows.Next() {
|
||||||
|
return nil, ErrTokenNotFound
|
||||||
|
}
|
||||||
|
if err := rows.Scan(&token, &label, &lastAccess, &lastOrigin, &expires, &provisioned); err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
lastOriginIP, err := netip.ParseAddr(lastOrigin)
|
||||||
|
if err != nil {
|
||||||
|
lastOriginIP = netip.IPv4Unspecified()
|
||||||
|
}
|
||||||
|
return &Token{
|
||||||
|
Value: token,
|
||||||
|
Label: label,
|
||||||
|
LastAccess: time.Unix(lastAccess, 0),
|
||||||
|
LastOrigin: lastOriginIP,
|
||||||
|
Expires: time.Unix(expires, 0),
|
||||||
|
Provisioned: provisioned,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthorizeTopicAccess returns the read/write permissions for the given username and topic.
|
||||||
|
// The found return value indicates whether an ACL entry was found at all.
|
||||||
|
func (s *commonStore) AuthorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error) {
|
||||||
|
rows, err := s.db.Query(s.queries.selectTopicPerms, Everyone, usernameOrEveryone, topic)
|
||||||
|
if err != nil {
|
||||||
|
return false, false, false, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
if !rows.Next() {
|
||||||
|
return false, false, false, nil
|
||||||
|
}
|
||||||
|
if err := rows.Scan(&read, &write); err != nil {
|
||||||
|
return false, false, false, err
|
||||||
|
} else if err := rows.Err(); err != nil {
|
||||||
|
return false, false, false, err
|
||||||
|
}
|
||||||
|
return read, write, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllGrants returns all user-specific access control entries, mapped to their respective user IDs
|
||||||
|
func (s *commonStore) AllGrants() (map[string][]Grant, error) {
|
||||||
|
rows, err := s.db.Query(s.queries.selectUserAllAccess)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
grants := make(map[string][]Grant, 0)
|
||||||
|
for rows.Next() {
|
||||||
|
var userID, topic string
|
||||||
|
var read, write, provisioned bool
|
||||||
|
if err := rows.Scan(&userID, &topic, &read, &write, &provisioned); err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if _, ok := grants[userID]; !ok {
|
||||||
|
grants[userID] = make([]Grant, 0)
|
||||||
|
}
|
||||||
|
grants[userID] = append(grants[userID], Grant{
|
||||||
|
TopicPattern: fromSQLWildcard(topic),
|
||||||
|
Permission: NewPermission(read, write),
|
||||||
|
Provisioned: provisioned,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return grants, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Grants returns all user-specific access control entries
|
||||||
|
func (s *commonStore) Grants(username string) ([]Grant, error) {
|
||||||
|
rows, err := s.db.Query(s.queries.selectUserAccess, username)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
grants := make([]Grant, 0)
|
||||||
|
for rows.Next() {
|
||||||
|
var topic string
|
||||||
|
var read, write, provisioned bool
|
||||||
|
if err := rows.Scan(&topic, &read, &write, &provisioned); err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
grants = append(grants, Grant{
|
||||||
|
TopicPattern: fromSQLWildcard(topic),
|
||||||
|
Permission: NewPermission(read, write),
|
||||||
|
Provisioned: provisioned,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return grants, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllowAccess adds or updates an entry in the access control list
|
||||||
|
func (s *commonStore) AllowAccess(username, topicPattern string, read, write bool, ownerUsername string, provisioned bool) error {
|
||||||
|
if !AllowedUsername(username) && username != Everyone {
|
||||||
|
return ErrInvalidArgument
|
||||||
|
} else if !AllowedTopicPattern(topicPattern) {
|
||||||
|
return ErrInvalidArgument
|
||||||
|
}
|
||||||
|
if _, err := s.db.Exec(s.queries.upsertUserAccess, username, toSQLWildcard(topicPattern), read, write, ownerUsername, ownerUsername, provisioned); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetAccess removes an access control list entry
|
||||||
|
func (s *commonStore) ResetAccess(username, topicPattern string) error {
|
||||||
|
if !AllowedUsername(username) && username != Everyone && username != "" {
|
||||||
|
return ErrInvalidArgument
|
||||||
|
} else if !AllowedTopicPattern(topicPattern) && topicPattern != "" {
|
||||||
|
return ErrInvalidArgument
|
||||||
|
}
|
||||||
|
if username == "" && topicPattern == "" {
|
||||||
|
_, err := s.db.Exec(s.queries.deleteAllAccess)
|
||||||
|
return err
|
||||||
|
} else if topicPattern == "" {
|
||||||
|
_, err := s.db.Exec(s.queries.deleteUserAccess, username, username)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err := s.db.Exec(s.queries.deleteTopicAccess, username, username, toSQLWildcard(topicPattern))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetAllProvisionedAccess removes all provisioned access control entries
|
||||||
|
func (s *commonStore) ResetAllProvisionedAccess() error {
|
||||||
|
if _, err := s.db.Exec(s.queries.deleteUserAccessProvisioned); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reservations returns all user-owned topics, and the associated everyone-access
|
||||||
|
func (s *commonStore) Reservations(username string) ([]Reservation, error) {
|
||||||
|
rows, err := s.db.Query(s.queries.selectUserReservations, Everyone, username)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
reservations := make([]Reservation, 0)
|
||||||
|
for rows.Next() {
|
||||||
|
var topic string
|
||||||
|
var ownerRead, ownerWrite bool
|
||||||
|
var everyoneRead, everyoneWrite sql.NullBool
|
||||||
|
if err := rows.Scan(&topic, &ownerRead, &ownerWrite, &everyoneRead, &everyoneWrite); err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
reservations = append(reservations, Reservation{
|
||||||
|
Topic: unescapeUnderscore(topic),
|
||||||
|
Owner: NewPermission(ownerRead, ownerWrite),
|
||||||
|
Everyone: NewPermission(everyoneRead.Bool, everyoneWrite.Bool),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return reservations, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasReservation returns true if the given topic access is owned by the user
|
||||||
|
func (s *commonStore) HasReservation(username, topic string) (bool, error) {
|
||||||
|
rows, err := s.db.Query(s.queries.selectUserHasReservation, username, escapeUnderscore(topic))
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
if !rows.Next() {
|
||||||
|
return false, errNoRows
|
||||||
|
}
|
||||||
|
var count int64
|
||||||
|
if err := rows.Scan(&count); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return count > 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReservationsCount returns the number of reservations owned by this user
|
||||||
|
func (s *commonStore) ReservationsCount(username string) (int64, error) {
|
||||||
|
rows, err := s.db.Query(s.queries.selectUserReservationsCount, username)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
if !rows.Next() {
|
||||||
|
return 0, errNoRows
|
||||||
|
}
|
||||||
|
var count int64
|
||||||
|
if err := rows.Scan(&count); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return count, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReservationOwner returns user ID of the user that owns this topic, or an empty string if it's not owned by anyone
|
||||||
|
func (s *commonStore) ReservationOwner(topic string) (string, error) {
|
||||||
|
rows, err := s.db.Query(s.queries.selectUserReservationsOwner, escapeUnderscore(topic))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
if !rows.Next() {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
var ownerUserID string
|
||||||
|
if err := rows.Scan(&ownerUserID); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return ownerUserID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// OtherAccessCount returns the number of access entries for the given topic that are not owned by the user
|
||||||
|
func (s *commonStore) OtherAccessCount(username, topic string) (int, error) {
|
||||||
|
rows, err := s.db.Query(s.queries.selectOtherAccessCount, escapeUnderscore(topic), escapeUnderscore(topic), username)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
if !rows.Next() {
|
||||||
|
return 0, errNoRows
|
||||||
|
}
|
||||||
|
var count int
|
||||||
|
if err := rows.Scan(&count); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return count, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddTier creates a new tier in the database
|
||||||
|
func (s *commonStore) AddTier(tier *Tier) error {
|
||||||
|
if tier.ID == "" {
|
||||||
|
tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength)
|
||||||
|
}
|
||||||
|
if _, err := s.db.Exec(s.queries.insertTier, tier.ID, tier.Code, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.CallLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateTier updates a tier's properties in the database
|
||||||
|
func (s *commonStore) UpdateTier(tier *Tier) error {
|
||||||
|
if _, err := s.db.Exec(s.queries.updateTier, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.CallLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID), tier.Code); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveTier deletes the tier with the given code
|
||||||
|
func (s *commonStore) RemoveTier(code string) error {
|
||||||
|
if !AllowedTier(code) {
|
||||||
|
return ErrInvalidArgument
|
||||||
|
}
|
||||||
|
if _, err := s.db.Exec(s.queries.deleteTier, code); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tiers returns a list of all Tier structs
|
||||||
|
func (s *commonStore) Tiers() ([]*Tier, error) {
|
||||||
|
rows, err := s.db.Query(s.queries.selectTiers)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
tiers := make([]*Tier, 0)
|
||||||
|
for {
|
||||||
|
tier, err := s.readTier(rows)
|
||||||
|
if errors.Is(err, ErrTierNotFound) {
|
||||||
|
break
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tiers = append(tiers, tier)
|
||||||
|
}
|
||||||
|
return tiers, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tier returns a Tier based on the code, or ErrTierNotFound if it does not exist
|
||||||
|
func (s *commonStore) Tier(code string) (*Tier, error) {
|
||||||
|
rows, err := s.db.Query(s.queries.selectTierByCode, code)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
return s.readTier(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist
|
||||||
|
func (s *commonStore) TierByStripePrice(priceID string) (*Tier, error) {
|
||||||
|
rows, err := s.db.Query(s.queries.selectTierByPriceID, priceID, priceID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
return s.readTier(rows)
|
||||||
|
}
|
||||||
|
func (s *commonStore) readTier(rows *sql.Rows) (*Tier, error) {
|
||||||
|
var id, code, name string
|
||||||
|
var stripeMonthlyPriceID, stripeYearlyPriceID sql.NullString
|
||||||
|
var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit sql.NullInt64
|
||||||
|
if !rows.Next() {
|
||||||
|
return nil, ErrTierNotFound
|
||||||
|
}
|
||||||
|
if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &Tier{
|
||||||
|
ID: id,
|
||||||
|
Code: code,
|
||||||
|
Name: name,
|
||||||
|
MessageLimit: messagesLimit.Int64,
|
||||||
|
MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
|
||||||
|
EmailLimit: emailsLimit.Int64,
|
||||||
|
CallLimit: callsLimit.Int64,
|
||||||
|
ReservationLimit: reservationsLimit.Int64,
|
||||||
|
AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
|
||||||
|
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
|
||||||
|
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
|
||||||
|
AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
|
||||||
|
StripeMonthlyPriceID: stripeMonthlyPriceID.String,
|
||||||
|
StripeYearlyPriceID: stripeYearlyPriceID.String,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PhoneNumbers returns all phone numbers for the user with the given user ID
|
||||||
|
func (s *commonStore) PhoneNumbers(userID string) ([]string, error) {
|
||||||
|
rows, err := s.db.Query(s.queries.selectPhoneNumbers, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
phoneNumbers := make([]string, 0)
|
||||||
|
for {
|
||||||
|
phoneNumber, err := s.readPhoneNumber(rows)
|
||||||
|
if errors.Is(err, ErrPhoneNumberNotFound) {
|
||||||
|
break
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
phoneNumbers = append(phoneNumbers, phoneNumber)
|
||||||
|
}
|
||||||
|
return phoneNumbers, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddPhoneNumber adds a phone number to the user with the given user ID
|
||||||
|
func (s *commonStore) AddPhoneNumber(userID, phoneNumber string) error {
|
||||||
|
if _, err := s.db.Exec(s.queries.insertPhoneNumber, userID, phoneNumber); err != nil {
|
||||||
|
if isUniqueConstraintError(err) {
|
||||||
|
return ErrPhoneNumberExists
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemovePhoneNumber deletes a phone number from the user with the given user ID
|
||||||
|
func (s *commonStore) RemovePhoneNumber(userID, phoneNumber string) error {
|
||||||
|
_, err := s.db.Exec(s.queries.deletePhoneNumber, userID, phoneNumber)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
func (s *commonStore) readPhoneNumber(rows *sql.Rows) (string, error) {
|
||||||
|
var phoneNumber string
|
||||||
|
if !rows.Next() {
|
||||||
|
return "", ErrPhoneNumberNotFound
|
||||||
|
}
|
||||||
|
if err := rows.Scan(&phoneNumber); err != nil {
|
||||||
|
return "", err
|
||||||
|
} else if err := rows.Err(); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return phoneNumber, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChangeBilling updates a user's billing fields
|
||||||
|
func (s *commonStore) ChangeBilling(username string, billing *Billing) error {
|
||||||
|
if _, err := s.db.Exec(s.queries.updateBilling, nullString(billing.StripeCustomerID), nullString(billing.StripeSubscriptionID), nullString(string(billing.StripeSubscriptionStatus)), nullString(string(billing.StripeSubscriptionInterval)), nullInt64(billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(billing.StripeSubscriptionCancelAt.Unix()), username); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserIDByUsername returns the user ID for the given username
|
||||||
|
func (s *commonStore) UserIDByUsername(username string) (string, error) {
|
||||||
|
rows, err := s.db.Query(s.queries.selectUserIDFromUsername, username)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
if !rows.Next() {
|
||||||
|
return "", ErrUserNotFound
|
||||||
|
}
|
||||||
|
var userID string
|
||||||
|
if err := rows.Scan(&userID); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return userID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the underlying database
|
||||||
|
func (s *commonStore) Close() error {
|
||||||
|
return s.db.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// isUniqueConstraintError checks if the error is a unique constraint violation for both SQLite and PostgreSQL
|
||||||
|
func isUniqueConstraintError(err error) bool {
|
||||||
|
errStr := err.Error()
|
||||||
|
return strings.Contains(errStr, "UNIQUE constraint failed") || strings.Contains(errStr, "23505")
|
||||||
|
}
|
||||||
292
user/store_postgres.go
Normal file
292
user/store_postgres.go
Normal file
@@ -0,0 +1,292 @@
|
|||||||
|
package user
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
|
||||||
|
)
|
||||||
|
|
||||||
|
// PostgreSQL queries
|
||||||
|
const (
|
||||||
|
// User queries
|
||||||
|
postgresSelectUserByID = `
|
||||||
|
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||||
|
FROM "user" u
|
||||||
|
LEFT JOIN tier t on t.id = u.tier_id
|
||||||
|
WHERE u.id = $1
|
||||||
|
`
|
||||||
|
postgresSelectUserByName = `
|
||||||
|
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||||
|
FROM "user" u
|
||||||
|
LEFT JOIN tier t on t.id = u.tier_id
|
||||||
|
WHERE user_name = $1
|
||||||
|
`
|
||||||
|
postgresSelectUserByToken = `
|
||||||
|
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||||
|
FROM "user" u
|
||||||
|
JOIN user_token tk on u.id = tk.user_id
|
||||||
|
LEFT JOIN tier t on t.id = u.tier_id
|
||||||
|
WHERE tk.token = $1 AND (tk.expires = 0 OR tk.expires >= $2)
|
||||||
|
`
|
||||||
|
postgresSelectUserByStripeID = `
|
||||||
|
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||||
|
FROM "user" u
|
||||||
|
LEFT JOIN tier t on t.id = u.tier_id
|
||||||
|
WHERE u.stripe_customer_id = $1
|
||||||
|
`
|
||||||
|
postgresSelectUsernames = `
|
||||||
|
SELECT user_name
|
||||||
|
FROM "user"
|
||||||
|
ORDER BY
|
||||||
|
CASE role
|
||||||
|
WHEN 'admin' THEN 1
|
||||||
|
WHEN 'anonymous' THEN 3
|
||||||
|
ELSE 2
|
||||||
|
END, user_name
|
||||||
|
`
|
||||||
|
postgresSelectUserCount = `SELECT COUNT(*) FROM "user"`
|
||||||
|
postgresSelectUserIDFromUsername = `SELECT id FROM "user" WHERE user_name = $1`
|
||||||
|
postgresInsertUser = `INSERT INTO "user" (id, user_name, pass, role, sync_topic, provisioned, created) VALUES ($1, $2, $3, $4, $5, $6, $7)`
|
||||||
|
postgresUpdateUserPass = `UPDATE "user" SET pass = $1 WHERE user_name = $2`
|
||||||
|
postgresUpdateUserRole = `UPDATE "user" SET role = $1 WHERE user_name = $2`
|
||||||
|
postgresUpdateUserProvisioned = `UPDATE "user" SET provisioned = $1 WHERE user_name = $2`
|
||||||
|
postgresUpdateUserPrefs = `UPDATE "user" SET prefs = $1 WHERE id = $2`
|
||||||
|
postgresUpdateUserStats = `UPDATE "user" SET stats_messages = $1, stats_emails = $2, stats_calls = $3 WHERE id = $4`
|
||||||
|
postgresUpdateUserStatsResetAll = `UPDATE "user" SET stats_messages = 0, stats_emails = 0, stats_calls = 0`
|
||||||
|
postgresUpdateUserTier = `UPDATE "user" SET tier_id = (SELECT id FROM tier WHERE code = $1) WHERE user_name = $2`
|
||||||
|
postgresUpdateUserDeleted = `UPDATE "user" SET deleted = $1 WHERE id = $2`
|
||||||
|
postgresDeleteUser = `DELETE FROM "user" WHERE user_name = $1`
|
||||||
|
postgresDeleteUserTier = `UPDATE "user" SET tier_id = null WHERE user_name = $1`
|
||||||
|
postgresDeleteUsersMarked = `DELETE FROM "user" WHERE deleted < $1`
|
||||||
|
|
||||||
|
// Access queries
|
||||||
|
postgresSelectTopicPerms = `
|
||||||
|
SELECT read, write
|
||||||
|
FROM user_access a
|
||||||
|
JOIN "user" u ON u.id = a.user_id
|
||||||
|
WHERE (u.user_name = $1 OR u.user_name = $2) AND $3 LIKE a.topic ESCAPE '\'
|
||||||
|
ORDER BY u.user_name DESC, LENGTH(a.topic) DESC, CASE WHEN a.write THEN 1 ELSE 0 END DESC
|
||||||
|
`
|
||||||
|
postgresSelectUserAllAccess = `
|
||||||
|
SELECT user_id, topic, read, write, provisioned
|
||||||
|
FROM user_access
|
||||||
|
ORDER BY LENGTH(topic) DESC, CASE WHEN write THEN 1 ELSE 0 END DESC, CASE WHEN read THEN 1 ELSE 0 END DESC, topic
|
||||||
|
`
|
||||||
|
postgresSelectUserAccess = `
|
||||||
|
SELECT topic, read, write, provisioned
|
||||||
|
FROM user_access
|
||||||
|
WHERE user_id = (SELECT id FROM "user" WHERE user_name = $1)
|
||||||
|
ORDER BY LENGTH(topic) DESC, CASE WHEN write THEN 1 ELSE 0 END DESC, CASE WHEN read THEN 1 ELSE 0 END DESC, topic
|
||||||
|
`
|
||||||
|
postgresSelectUserReservations = `
|
||||||
|
SELECT a_user.topic, a_user.read, a_user.write, a_everyone.read AS everyone_read, a_everyone.write AS everyone_write
|
||||||
|
FROM user_access a_user
|
||||||
|
LEFT JOIN user_access a_everyone ON a_user.topic = a_everyone.topic AND a_everyone.user_id = (SELECT id FROM "user" WHERE user_name = $1)
|
||||||
|
WHERE a_user.user_id = a_user.owner_user_id
|
||||||
|
AND a_user.owner_user_id = (SELECT id FROM "user" WHERE user_name = $2)
|
||||||
|
ORDER BY a_user.topic
|
||||||
|
`
|
||||||
|
postgresSelectUserReservationsCount = `
|
||||||
|
SELECT COUNT(*)
|
||||||
|
FROM user_access
|
||||||
|
WHERE user_id = owner_user_id
|
||||||
|
AND owner_user_id = (SELECT id FROM "user" WHERE user_name = $1)
|
||||||
|
`
|
||||||
|
postgresSelectUserReservationsOwner = `
|
||||||
|
SELECT owner_user_id
|
||||||
|
FROM user_access
|
||||||
|
WHERE topic = $1
|
||||||
|
AND user_id = owner_user_id
|
||||||
|
`
|
||||||
|
postgresSelectUserHasReservation = `
|
||||||
|
SELECT COUNT(*)
|
||||||
|
FROM user_access
|
||||||
|
WHERE user_id = owner_user_id
|
||||||
|
AND owner_user_id = (SELECT id FROM "user" WHERE user_name = $1)
|
||||||
|
AND topic = $2
|
||||||
|
`
|
||||||
|
postgresSelectOtherAccessCount = `
|
||||||
|
SELECT COUNT(*)
|
||||||
|
FROM user_access
|
||||||
|
WHERE (topic = $1 OR $2 LIKE topic ESCAPE '\')
|
||||||
|
AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM "user" WHERE user_name = $3))
|
||||||
|
`
|
||||||
|
postgresUpsertUserAccess = `
|
||||||
|
INSERT INTO user_access (user_id, topic, read, write, owner_user_id, provisioned)
|
||||||
|
VALUES (
|
||||||
|
(SELECT id FROM "user" WHERE user_name = $1),
|
||||||
|
$2,
|
||||||
|
$3,
|
||||||
|
$4,
|
||||||
|
CASE WHEN $5 = '' THEN NULL ELSE (SELECT id FROM "user" WHERE user_name = $6) END,
|
||||||
|
$7
|
||||||
|
)
|
||||||
|
ON CONFLICT (user_id, topic)
|
||||||
|
DO UPDATE SET read=EXCLUDED.read, write=EXCLUDED.write, owner_user_id=EXCLUDED.owner_user_id, provisioned=EXCLUDED.provisioned
|
||||||
|
`
|
||||||
|
postgresDeleteUserAccess = `
|
||||||
|
DELETE FROM user_access
|
||||||
|
WHERE user_id = (SELECT id FROM "user" WHERE user_name = $1)
|
||||||
|
OR owner_user_id = (SELECT id FROM "user" WHERE user_name = $2)
|
||||||
|
`
|
||||||
|
postgresDeleteUserAccessProvisioned = `DELETE FROM user_access WHERE provisioned = true`
|
||||||
|
postgresDeleteTopicAccess = `
|
||||||
|
DELETE FROM user_access
|
||||||
|
WHERE (user_id = (SELECT id FROM "user" WHERE user_name = $1) OR owner_user_id = (SELECT id FROM "user" WHERE user_name = $2))
|
||||||
|
AND topic = $3
|
||||||
|
`
|
||||||
|
postgresDeleteAllAccess = `DELETE FROM user_access`
|
||||||
|
|
||||||
|
// Token queries
|
||||||
|
postgresSelectToken = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = $1 AND token = $2`
|
||||||
|
postgresSelectTokens = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = $1`
|
||||||
|
postgresSelectTokenCount = `SELECT COUNT(*) FROM user_token WHERE user_id = $1`
|
||||||
|
postgresSelectAllProvisionedTokens = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE provisioned = true`
|
||||||
|
postgresUpsertToken = `
|
||||||
|
INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires, provisioned)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||||
|
ON CONFLICT (user_id, token)
|
||||||
|
DO UPDATE SET label = EXCLUDED.label, expires = EXCLUDED.expires, provisioned = EXCLUDED.provisioned
|
||||||
|
`
|
||||||
|
postgresUpdateTokenLabel = `UPDATE user_token SET label = $1 WHERE user_id = $2 AND token = $3`
|
||||||
|
postgresUpdateTokenExpiry = `UPDATE user_token SET expires = $1 WHERE user_id = $2 AND token = $3`
|
||||||
|
postgresUpdateTokenLastAccess = `UPDATE user_token SET last_access = $1, last_origin = $2 WHERE token = $3`
|
||||||
|
postgresDeleteToken = `DELETE FROM user_token WHERE user_id = $1 AND token = $2`
|
||||||
|
postgresDeleteProvisionedToken = `DELETE FROM user_token WHERE token = $1`
|
||||||
|
postgresDeleteAllToken = `DELETE FROM user_token WHERE user_id = $1`
|
||||||
|
postgresDeleteExpiredTokens = `DELETE FROM user_token WHERE expires > 0 AND expires < $1`
|
||||||
|
postgresDeleteExcessTokens = `
|
||||||
|
DELETE FROM user_token
|
||||||
|
WHERE user_id = $1
|
||||||
|
AND (user_id, token) NOT IN (
|
||||||
|
SELECT user_id, token
|
||||||
|
FROM user_token
|
||||||
|
WHERE user_id = $2
|
||||||
|
ORDER BY expires DESC
|
||||||
|
LIMIT $3
|
||||||
|
)
|
||||||
|
`
|
||||||
|
|
||||||
|
// Tier queries
|
||||||
|
postgresInsertTier = `
|
||||||
|
INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
|
||||||
|
`
|
||||||
|
postgresUpdateTier = `
|
||||||
|
UPDATE tier
|
||||||
|
SET name = $1, messages_limit = $2, messages_expiry_duration = $3, emails_limit = $4, calls_limit = $5, reservations_limit = $6, attachment_file_size_limit = $7, attachment_total_size_limit = $8, attachment_expiry_duration = $9, attachment_bandwidth_limit = $10, stripe_monthly_price_id = $11, stripe_yearly_price_id = $12
|
||||||
|
WHERE code = $13
|
||||||
|
`
|
||||||
|
postgresSelectTiers = `
|
||||||
|
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
|
||||||
|
FROM tier
|
||||||
|
`
|
||||||
|
postgresSelectTierByCode = `
|
||||||
|
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
|
||||||
|
FROM tier
|
||||||
|
WHERE code = $1
|
||||||
|
`
|
||||||
|
postgresSelectTierByPriceID = `
|
||||||
|
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
|
||||||
|
FROM tier
|
||||||
|
WHERE (stripe_monthly_price_id = $1 OR stripe_yearly_price_id = $2)
|
||||||
|
`
|
||||||
|
postgresDeleteTier = `DELETE FROM tier WHERE code = $1`
|
||||||
|
|
||||||
|
// Phone queries
|
||||||
|
postgresSelectPhoneNumbers = `SELECT phone_number FROM user_phone WHERE user_id = $1`
|
||||||
|
postgresInsertPhoneNumber = `INSERT INTO user_phone (user_id, phone_number) VALUES ($1, $2)`
|
||||||
|
postgresDeletePhoneNumber = `DELETE FROM user_phone WHERE user_id = $1 AND phone_number = $2`
|
||||||
|
|
||||||
|
// Billing queries
|
||||||
|
postgresUpdateBilling = `
|
||||||
|
UPDATE "user"
|
||||||
|
SET stripe_customer_id = $1, stripe_subscription_id = $2, stripe_subscription_status = $3, stripe_subscription_interval = $4, stripe_subscription_paid_until = $5, stripe_subscription_cancel_at = $6
|
||||||
|
WHERE user_name = $7
|
||||||
|
`
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewPostgresStore creates a new PostgreSQL-backed user store
|
||||||
|
func NewPostgresStore(dsn string) (Store, error) {
|
||||||
|
db, err := sql.Open("pgx", dsn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := db.Ping(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := setupPostgres(db); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &commonStore{
|
||||||
|
db: db,
|
||||||
|
queries: storeQueries{
|
||||||
|
// User queries
|
||||||
|
selectUserByID: postgresSelectUserByID,
|
||||||
|
selectUserByName: postgresSelectUserByName,
|
||||||
|
selectUserByToken: postgresSelectUserByToken,
|
||||||
|
selectUserByStripeID: postgresSelectUserByStripeID,
|
||||||
|
selectUsernames: postgresSelectUsernames,
|
||||||
|
selectUserCount: postgresSelectUserCount,
|
||||||
|
selectUserIDFromUsername: postgresSelectUserIDFromUsername,
|
||||||
|
insertUser: postgresInsertUser,
|
||||||
|
updateUserPass: postgresUpdateUserPass,
|
||||||
|
updateUserRole: postgresUpdateUserRole,
|
||||||
|
updateUserProvisioned: postgresUpdateUserProvisioned,
|
||||||
|
updateUserPrefs: postgresUpdateUserPrefs,
|
||||||
|
updateUserStats: postgresUpdateUserStats,
|
||||||
|
updateUserStatsResetAll: postgresUpdateUserStatsResetAll,
|
||||||
|
updateUserTier: postgresUpdateUserTier,
|
||||||
|
updateUserDeleted: postgresUpdateUserDeleted,
|
||||||
|
deleteUser: postgresDeleteUser,
|
||||||
|
deleteUserTier: postgresDeleteUserTier,
|
||||||
|
deleteUsersMarked: postgresDeleteUsersMarked,
|
||||||
|
|
||||||
|
// Access queries
|
||||||
|
selectTopicPerms: postgresSelectTopicPerms,
|
||||||
|
selectUserAllAccess: postgresSelectUserAllAccess,
|
||||||
|
selectUserAccess: postgresSelectUserAccess,
|
||||||
|
selectUserReservations: postgresSelectUserReservations,
|
||||||
|
selectUserReservationsCount: postgresSelectUserReservationsCount,
|
||||||
|
selectUserReservationsOwner: postgresSelectUserReservationsOwner,
|
||||||
|
selectUserHasReservation: postgresSelectUserHasReservation,
|
||||||
|
selectOtherAccessCount: postgresSelectOtherAccessCount,
|
||||||
|
upsertUserAccess: postgresUpsertUserAccess,
|
||||||
|
deleteUserAccess: postgresDeleteUserAccess,
|
||||||
|
deleteUserAccessProvisioned: postgresDeleteUserAccessProvisioned,
|
||||||
|
deleteTopicAccess: postgresDeleteTopicAccess,
|
||||||
|
deleteAllAccess: postgresDeleteAllAccess,
|
||||||
|
|
||||||
|
// Token queries
|
||||||
|
selectToken: postgresSelectToken,
|
||||||
|
selectTokens: postgresSelectTokens,
|
||||||
|
selectTokenCount: postgresSelectTokenCount,
|
||||||
|
selectAllProvisionedTokens: postgresSelectAllProvisionedTokens,
|
||||||
|
upsertToken: postgresUpsertToken,
|
||||||
|
updateTokenLabel: postgresUpdateTokenLabel,
|
||||||
|
updateTokenExpiry: postgresUpdateTokenExpiry,
|
||||||
|
updateTokenLastAccess: postgresUpdateTokenLastAccess,
|
||||||
|
deleteToken: postgresDeleteToken,
|
||||||
|
deleteProvisionedToken: postgresDeleteProvisionedToken,
|
||||||
|
deleteAllToken: postgresDeleteAllToken,
|
||||||
|
deleteExpiredTokens: postgresDeleteExpiredTokens,
|
||||||
|
deleteExcessTokens: postgresDeleteExcessTokens,
|
||||||
|
|
||||||
|
// Tier queries
|
||||||
|
insertTier: postgresInsertTier,
|
||||||
|
selectTiers: postgresSelectTiers,
|
||||||
|
selectTierByCode: postgresSelectTierByCode,
|
||||||
|
selectTierByPriceID: postgresSelectTierByPriceID,
|
||||||
|
updateTier: postgresUpdateTier,
|
||||||
|
deleteTier: postgresDeleteTier,
|
||||||
|
|
||||||
|
// Phone queries
|
||||||
|
selectPhoneNumbers: postgresSelectPhoneNumbers,
|
||||||
|
insertPhoneNumber: postgresInsertPhoneNumber,
|
||||||
|
deletePhoneNumber: postgresDeletePhoneNumber,
|
||||||
|
|
||||||
|
// Billing queries
|
||||||
|
updateBilling: postgresUpdateBilling,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
113
user/store_postgres_schema.go
Normal file
113
user/store_postgres_schema.go
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
package user
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Initial PostgreSQL schema
|
||||||
|
const (
|
||||||
|
postgresCreateTablesQueries = `
|
||||||
|
CREATE TABLE IF NOT EXISTS tier (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
code TEXT NOT NULL,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
messages_limit BIGINT NOT NULL,
|
||||||
|
messages_expiry_duration BIGINT NOT NULL,
|
||||||
|
emails_limit BIGINT NOT NULL,
|
||||||
|
calls_limit BIGINT NOT NULL,
|
||||||
|
reservations_limit BIGINT NOT NULL,
|
||||||
|
attachment_file_size_limit BIGINT NOT NULL,
|
||||||
|
attachment_total_size_limit BIGINT NOT NULL,
|
||||||
|
attachment_expiry_duration BIGINT NOT NULL,
|
||||||
|
attachment_bandwidth_limit BIGINT NOT NULL,
|
||||||
|
stripe_monthly_price_id TEXT,
|
||||||
|
stripe_yearly_price_id TEXT,
|
||||||
|
UNIQUE(code),
|
||||||
|
UNIQUE(stripe_monthly_price_id),
|
||||||
|
UNIQUE(stripe_yearly_price_id)
|
||||||
|
);
|
||||||
|
CREATE TABLE IF NOT EXISTS "user" (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
tier_id TEXT REFERENCES tier(id),
|
||||||
|
user_name TEXT NOT NULL UNIQUE,
|
||||||
|
pass TEXT NOT NULL,
|
||||||
|
role TEXT NOT NULL CHECK (role IN ('anonymous', 'admin', 'user')),
|
||||||
|
prefs JSONB NOT NULL DEFAULT '{}',
|
||||||
|
sync_topic TEXT NOT NULL,
|
||||||
|
provisioned BOOLEAN NOT NULL,
|
||||||
|
stats_messages BIGINT NOT NULL DEFAULT 0,
|
||||||
|
stats_emails BIGINT NOT NULL DEFAULT 0,
|
||||||
|
stats_calls BIGINT NOT NULL DEFAULT 0,
|
||||||
|
stripe_customer_id TEXT UNIQUE,
|
||||||
|
stripe_subscription_id TEXT UNIQUE,
|
||||||
|
stripe_subscription_status TEXT,
|
||||||
|
stripe_subscription_interval TEXT,
|
||||||
|
stripe_subscription_paid_until BIGINT,
|
||||||
|
stripe_subscription_cancel_at BIGINT,
|
||||||
|
created BIGINT NOT NULL,
|
||||||
|
deleted BIGINT
|
||||||
|
);
|
||||||
|
CREATE TABLE IF NOT EXISTS user_access (
|
||||||
|
user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
|
||||||
|
topic TEXT NOT NULL,
|
||||||
|
read BOOLEAN NOT NULL,
|
||||||
|
write BOOLEAN NOT NULL,
|
||||||
|
owner_user_id TEXT REFERENCES "user"(id) ON DELETE CASCADE,
|
||||||
|
provisioned BOOLEAN NOT NULL,
|
||||||
|
PRIMARY KEY (user_id, topic)
|
||||||
|
);
|
||||||
|
CREATE TABLE IF NOT EXISTS user_token (
|
||||||
|
user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
|
||||||
|
token TEXT NOT NULL UNIQUE,
|
||||||
|
label TEXT NOT NULL,
|
||||||
|
last_access BIGINT NOT NULL,
|
||||||
|
last_origin TEXT NOT NULL,
|
||||||
|
expires BIGINT NOT NULL,
|
||||||
|
provisioned BOOLEAN NOT NULL,
|
||||||
|
PRIMARY KEY (user_id, token)
|
||||||
|
);
|
||||||
|
CREATE TABLE IF NOT EXISTS user_phone (
|
||||||
|
user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
|
||||||
|
phone_number TEXT NOT NULL,
|
||||||
|
PRIMARY KEY (user_id, phone_number)
|
||||||
|
);
|
||||||
|
CREATE TABLE IF NOT EXISTS schema_version (
|
||||||
|
store TEXT PRIMARY KEY,
|
||||||
|
version INT NOT NULL
|
||||||
|
);
|
||||||
|
INSERT INTO "user" (id, user_name, pass, role, sync_topic, provisioned, created)
|
||||||
|
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, EXTRACT(EPOCH FROM NOW())::BIGINT)
|
||||||
|
ON CONFLICT (id) DO NOTHING;
|
||||||
|
`
|
||||||
|
)
|
||||||
|
|
||||||
|
// Schema table management queries for Postgres
|
||||||
|
const (
|
||||||
|
postgresCurrentSchemaVersion = 6
|
||||||
|
postgresSelectSchemaVersion = `SELECT version FROM schema_version WHERE store = 'user'`
|
||||||
|
postgresInsertSchemaVersion = `INSERT INTO schema_version (store, version) VALUES ('user', $1)`
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupPostgres(db *sql.DB) error {
|
||||||
|
var schemaVersion int
|
||||||
|
err := db.QueryRow(postgresSelectSchemaVersion).Scan(&schemaVersion)
|
||||||
|
if err != nil {
|
||||||
|
return setupNewPostgres(db)
|
||||||
|
}
|
||||||
|
if schemaVersion > postgresCurrentSchemaVersion {
|
||||||
|
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, postgresCurrentSchemaVersion)
|
||||||
|
}
|
||||||
|
// Note: PostgreSQL migrations will be added when needed
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupNewPostgres(db *sql.DB) error {
|
||||||
|
if _, err := db.Exec(postgresCreateTablesQueries); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := db.Exec(postgresInsertSchemaVersion, postgresCurrentSchemaVersion); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
208
user/store_postgres_test.go
Normal file
208
user/store_postgres_test.go
Normal file
@@ -0,0 +1,208 @@
|
|||||||
|
package user_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"heckel.io/ntfy/v2/user"
|
||||||
|
"heckel.io/ntfy/v2/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTestPostgresStore(t *testing.T) user.Store {
|
||||||
|
dsn := os.Getenv("NTFY_TEST_DATABASE_URL")
|
||||||
|
if dsn == "" {
|
||||||
|
t.Skip("NTFY_TEST_DATABASE_URL not set, skipping PostgreSQL tests")
|
||||||
|
}
|
||||||
|
// Create a unique schema for this test
|
||||||
|
schema := fmt.Sprintf("test_%s", util.RandomString(10))
|
||||||
|
setupDB, err := sql.Open("pgx", dsn)
|
||||||
|
require.Nil(t, err)
|
||||||
|
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Nil(t, setupDB.Close())
|
||||||
|
// Open store with search_path set to the new schema
|
||||||
|
u, err := url.Parse(dsn)
|
||||||
|
require.Nil(t, err)
|
||||||
|
q := u.Query()
|
||||||
|
q.Set("search_path", schema)
|
||||||
|
u.RawQuery = q.Encode()
|
||||||
|
store, err := user.NewPostgresStore(u.String())
|
||||||
|
require.Nil(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
store.Close()
|
||||||
|
cleanDB, err := sql.Open("pgx", dsn)
|
||||||
|
if err == nil {
|
||||||
|
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
|
||||||
|
cleanDB.Close()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return store
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreAddUser(t *testing.T) {
|
||||||
|
testStoreAddUser(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreAddUserAlreadyExists(t *testing.T) {
|
||||||
|
testStoreAddUserAlreadyExists(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreRemoveUser(t *testing.T) {
|
||||||
|
testStoreRemoveUser(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreUserByID(t *testing.T) {
|
||||||
|
testStoreUserByID(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreUserByToken(t *testing.T) {
|
||||||
|
testStoreUserByToken(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreUserByStripeCustomer(t *testing.T) {
|
||||||
|
testStoreUserByStripeCustomer(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreUsers(t *testing.T) {
|
||||||
|
testStoreUsers(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreUsersCount(t *testing.T) {
|
||||||
|
testStoreUsersCount(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreChangePassword(t *testing.T) {
|
||||||
|
testStoreChangePassword(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreChangeRole(t *testing.T) {
|
||||||
|
testStoreChangeRole(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreTokens(t *testing.T) {
|
||||||
|
testStoreTokens(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreTokenChangeLabel(t *testing.T) {
|
||||||
|
testStoreTokenChangeLabel(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreTokenRemove(t *testing.T) {
|
||||||
|
testStoreTokenRemove(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreTokenRemoveExpired(t *testing.T) {
|
||||||
|
testStoreTokenRemoveExpired(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreTokenRemoveExcess(t *testing.T) {
|
||||||
|
testStoreTokenRemoveExcess(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreTokenUpdateLastAccess(t *testing.T) {
|
||||||
|
testStoreTokenUpdateLastAccess(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreAllowAccess(t *testing.T) {
|
||||||
|
testStoreAllowAccess(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreAllowAccessReadOnly(t *testing.T) {
|
||||||
|
testStoreAllowAccessReadOnly(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreResetAccess(t *testing.T) {
|
||||||
|
testStoreResetAccess(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreResetAccessAll(t *testing.T) {
|
||||||
|
testStoreResetAccessAll(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreAuthorizeTopicAccess(t *testing.T) {
|
||||||
|
testStoreAuthorizeTopicAccess(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreAuthorizeTopicAccessNotFound(t *testing.T) {
|
||||||
|
testStoreAuthorizeTopicAccessNotFound(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreAuthorizeTopicAccessDenyAll(t *testing.T) {
|
||||||
|
testStoreAuthorizeTopicAccessDenyAll(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreReservations(t *testing.T) {
|
||||||
|
testStoreReservations(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreReservationsCount(t *testing.T) {
|
||||||
|
testStoreReservationsCount(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreHasReservation(t *testing.T) {
|
||||||
|
testStoreHasReservation(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreReservationOwner(t *testing.T) {
|
||||||
|
testStoreReservationOwner(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreTiers(t *testing.T) {
|
||||||
|
testStoreTiers(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreTierUpdate(t *testing.T) {
|
||||||
|
testStoreTierUpdate(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreTierRemove(t *testing.T) {
|
||||||
|
testStoreTierRemove(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreTierByStripePrice(t *testing.T) {
|
||||||
|
testStoreTierByStripePrice(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreChangeTier(t *testing.T) {
|
||||||
|
testStoreChangeTier(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStorePhoneNumbers(t *testing.T) {
|
||||||
|
testStorePhoneNumbers(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreChangeSettings(t *testing.T) {
|
||||||
|
testStoreChangeSettings(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreChangeBilling(t *testing.T) {
|
||||||
|
testStoreChangeBilling(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreUpdateStats(t *testing.T) {
|
||||||
|
testStoreUpdateStats(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreResetStats(t *testing.T) {
|
||||||
|
testStoreResetStats(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreMarkUserRemoved(t *testing.T) {
|
||||||
|
testStoreMarkUserRemoved(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreRemoveDeletedUsers(t *testing.T) {
|
||||||
|
testStoreRemoveDeletedUsers(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreAllGrants(t *testing.T) {
|
||||||
|
testStoreAllGrants(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStoreOtherAccessCount(t *testing.T) {
|
||||||
|
testStoreOtherAccessCount(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
273
user/store_sqlite.go
Normal file
273
user/store_sqlite.go
Normal file
@@ -0,0 +1,273 @@
|
|||||||
|
package user
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// User queries
|
||||||
|
sqliteSelectUserByID = `
|
||||||
|
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||||
|
FROM user u
|
||||||
|
LEFT JOIN tier t on t.id = u.tier_id
|
||||||
|
WHERE u.id = ?
|
||||||
|
`
|
||||||
|
sqliteSelectUserByName = `
|
||||||
|
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||||
|
FROM user u
|
||||||
|
LEFT JOIN tier t on t.id = u.tier_id
|
||||||
|
WHERE user = ?
|
||||||
|
`
|
||||||
|
sqliteSelectUserByToken = `
|
||||||
|
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||||
|
FROM user u
|
||||||
|
JOIN user_token tk on u.id = tk.user_id
|
||||||
|
LEFT JOIN tier t on t.id = u.tier_id
|
||||||
|
WHERE tk.token = ? AND (tk.expires = 0 OR tk.expires >= ?)
|
||||||
|
`
|
||||||
|
sqliteSelectUserByStripeID = `
|
||||||
|
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||||
|
FROM user u
|
||||||
|
LEFT JOIN tier t on t.id = u.tier_id
|
||||||
|
WHERE u.stripe_customer_id = ?
|
||||||
|
`
|
||||||
|
sqliteSelectUsernames = `
|
||||||
|
SELECT user
|
||||||
|
FROM user
|
||||||
|
ORDER BY
|
||||||
|
CASE role
|
||||||
|
WHEN 'admin' THEN 1
|
||||||
|
WHEN 'anonymous' THEN 3
|
||||||
|
ELSE 2
|
||||||
|
END, user
|
||||||
|
`
|
||||||
|
sqliteSelectUserCount = `SELECT COUNT(*) FROM user`
|
||||||
|
sqliteSelectUserIDFromUsername = `SELECT id FROM user WHERE user = ?`
|
||||||
|
sqliteInsertUser = `INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created) VALUES (?, ?, ?, ?, ?, ?, ?)`
|
||||||
|
sqliteUpdateUserPass = `UPDATE user SET pass = ? WHERE user = ?`
|
||||||
|
sqliteUpdateUserRole = `UPDATE user SET role = ? WHERE user = ?`
|
||||||
|
sqliteUpdateUserProvisioned = `UPDATE user SET provisioned = ? WHERE user = ?`
|
||||||
|
sqliteUpdateUserPrefs = `UPDATE user SET prefs = ? WHERE id = ?`
|
||||||
|
sqliteUpdateUserStats = `UPDATE user SET stats_messages = ?, stats_emails = ?, stats_calls = ? WHERE id = ?`
|
||||||
|
sqliteUpdateUserStatsResetAll = `UPDATE user SET stats_messages = 0, stats_emails = 0, stats_calls = 0`
|
||||||
|
sqliteUpdateUserTier = `UPDATE user SET tier_id = (SELECT id FROM tier WHERE code = ?) WHERE user = ?`
|
||||||
|
sqliteUpdateUserDeleted = `UPDATE user SET deleted = ? WHERE id = ?`
|
||||||
|
sqliteDeleteUser = `DELETE FROM user WHERE user = ?`
|
||||||
|
sqliteDeleteUserTier = `UPDATE user SET tier_id = null WHERE user = ?`
|
||||||
|
sqliteDeleteUsersMarked = `DELETE FROM user WHERE deleted < ?`
|
||||||
|
|
||||||
|
// Access queries
|
||||||
|
sqliteSelectTopicPerms = `
|
||||||
|
SELECT read, write
|
||||||
|
FROM user_access a
|
||||||
|
JOIN user u ON u.id = a.user_id
|
||||||
|
WHERE (u.user = ? OR u.user = ?) AND ? LIKE a.topic ESCAPE '\'
|
||||||
|
ORDER BY u.user DESC, LENGTH(a.topic) DESC, a.write DESC
|
||||||
|
`
|
||||||
|
sqliteSelectUserAllAccess = `
|
||||||
|
SELECT user_id, topic, read, write, provisioned
|
||||||
|
FROM user_access
|
||||||
|
ORDER BY LENGTH(topic) DESC, write DESC, read DESC, topic
|
||||||
|
`
|
||||||
|
sqliteSelectUserAccess = `
|
||||||
|
SELECT topic, read, write, provisioned
|
||||||
|
FROM user_access
|
||||||
|
WHERE user_id = (SELECT id FROM user WHERE user = ?)
|
||||||
|
ORDER BY LENGTH(topic) DESC, write DESC, read DESC, topic
|
||||||
|
`
|
||||||
|
sqliteSelectUserReservations = `
|
||||||
|
SELECT a_user.topic, a_user.read, a_user.write, a_everyone.read AS everyone_read, a_everyone.write AS everyone_write
|
||||||
|
FROM user_access a_user
|
||||||
|
LEFT JOIN user_access a_everyone ON a_user.topic = a_everyone.topic AND a_everyone.user_id = (SELECT id FROM user WHERE user = ?)
|
||||||
|
WHERE a_user.user_id = a_user.owner_user_id
|
||||||
|
AND a_user.owner_user_id = (SELECT id FROM user WHERE user = ?)
|
||||||
|
ORDER BY a_user.topic
|
||||||
|
`
|
||||||
|
sqliteSelectUserReservationsCount = `
|
||||||
|
SELECT COUNT(*)
|
||||||
|
FROM user_access
|
||||||
|
WHERE user_id = owner_user_id
|
||||||
|
AND owner_user_id = (SELECT id FROM user WHERE user = ?)
|
||||||
|
`
|
||||||
|
sqliteSelectUserReservationsOwner = `
|
||||||
|
SELECT owner_user_id
|
||||||
|
FROM user_access
|
||||||
|
WHERE topic = ?
|
||||||
|
AND user_id = owner_user_id
|
||||||
|
`
|
||||||
|
sqliteSelectUserHasReservation = `
|
||||||
|
SELECT COUNT(*)
|
||||||
|
FROM user_access
|
||||||
|
WHERE user_id = owner_user_id
|
||||||
|
AND owner_user_id = (SELECT id FROM user WHERE user = ?)
|
||||||
|
AND topic = ?
|
||||||
|
`
|
||||||
|
sqliteSelectOtherAccessCount = `
|
||||||
|
SELECT COUNT(*)
|
||||||
|
FROM user_access
|
||||||
|
WHERE (topic = ? OR ? LIKE topic ESCAPE '\')
|
||||||
|
AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM user WHERE user = ?))
|
||||||
|
`
|
||||||
|
sqliteUpsertUserAccess = `
|
||||||
|
INSERT INTO user_access (user_id, topic, read, write, owner_user_id, provisioned)
|
||||||
|
VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?, (SELECT IIF(?='',NULL,(SELECT id FROM user WHERE user=?))), ?)
|
||||||
|
ON CONFLICT (user_id, topic)
|
||||||
|
DO UPDATE SET read=excluded.read, write=excluded.write, owner_user_id=excluded.owner_user_id, provisioned=excluded.provisioned
|
||||||
|
`
|
||||||
|
sqliteDeleteUserAccess = `
|
||||||
|
DELETE FROM user_access
|
||||||
|
WHERE user_id = (SELECT id FROM user WHERE user = ?)
|
||||||
|
OR owner_user_id = (SELECT id FROM user WHERE user = ?)
|
||||||
|
`
|
||||||
|
sqliteDeleteUserAccessProvisioned = `DELETE FROM user_access WHERE provisioned = 1`
|
||||||
|
sqliteDeleteTopicAccess = `
|
||||||
|
DELETE FROM user_access
|
||||||
|
WHERE (user_id = (SELECT id FROM user WHERE user = ?) OR owner_user_id = (SELECT id FROM user WHERE user = ?))
|
||||||
|
AND topic = ?
|
||||||
|
`
|
||||||
|
sqliteDeleteAllAccess = `DELETE FROM user_access`
|
||||||
|
|
||||||
|
// Token queries
|
||||||
|
sqliteSelectToken = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ? AND token = ?`
|
||||||
|
sqliteSelectTokens = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ?`
|
||||||
|
sqliteSelectTokenCount = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
|
||||||
|
sqliteSelectAllProvisionedTokens = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE provisioned = 1`
|
||||||
|
sqliteUpsertToken = `
|
||||||
|
INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires, provisioned)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||||
|
ON CONFLICT (user_id, token)
|
||||||
|
DO UPDATE SET label = excluded.label, expires = excluded.expires, provisioned = excluded.provisioned;
|
||||||
|
`
|
||||||
|
sqliteUpdateTokenLabel = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?`
|
||||||
|
sqliteUpdateTokenExpiry = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?`
|
||||||
|
sqliteUpdateTokenLastAccess = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?`
|
||||||
|
sqliteDeleteToken = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
|
||||||
|
sqliteDeleteProvisionedToken = `DELETE FROM user_token WHERE token = ?`
|
||||||
|
sqliteDeleteAllToken = `DELETE FROM user_token WHERE user_id = ?`
|
||||||
|
sqliteDeleteExpiredTokens = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
|
||||||
|
sqliteDeleteExcessTokens = `
|
||||||
|
DELETE FROM user_token
|
||||||
|
WHERE user_id = ?
|
||||||
|
AND (user_id, token) NOT IN (
|
||||||
|
SELECT user_id, token
|
||||||
|
FROM user_token
|
||||||
|
WHERE user_id = ?
|
||||||
|
ORDER BY expires DESC
|
||||||
|
LIMIT ?
|
||||||
|
)
|
||||||
|
`
|
||||||
|
|
||||||
|
// Tier queries
|
||||||
|
sqliteInsertTier = `
|
||||||
|
INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
`
|
||||||
|
sqliteUpdateTier = `
|
||||||
|
UPDATE tier
|
||||||
|
SET name = ?, messages_limit = ?, messages_expiry_duration = ?, emails_limit = ?, calls_limit = ?, reservations_limit = ?, attachment_file_size_limit = ?, attachment_total_size_limit = ?, attachment_expiry_duration = ?, attachment_bandwidth_limit = ?, stripe_monthly_price_id = ?, stripe_yearly_price_id = ?
|
||||||
|
WHERE code = ?
|
||||||
|
`
|
||||||
|
sqliteSelectTiers = `
|
||||||
|
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
|
||||||
|
FROM tier
|
||||||
|
`
|
||||||
|
sqliteSelectTierByCode = `
|
||||||
|
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
|
||||||
|
FROM tier
|
||||||
|
WHERE code = ?
|
||||||
|
`
|
||||||
|
sqliteSelectTierByPriceID = `
|
||||||
|
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
|
||||||
|
FROM tier
|
||||||
|
WHERE (stripe_monthly_price_id = ? OR stripe_yearly_price_id = ?)
|
||||||
|
`
|
||||||
|
sqliteDeleteTier = `DELETE FROM tier WHERE code = ?`
|
||||||
|
|
||||||
|
// Phone queries
|
||||||
|
sqliteSelectPhoneNumbers = `SELECT phone_number FROM user_phone WHERE user_id = ?`
|
||||||
|
sqliteInsertPhoneNumber = `INSERT INTO user_phone (user_id, phone_number) VALUES (?, ?)`
|
||||||
|
sqliteDeletePhoneNumber = `DELETE FROM user_phone WHERE user_id = ? AND phone_number = ?`
|
||||||
|
|
||||||
|
// Billing queries
|
||||||
|
sqliteUpdateBilling = `
|
||||||
|
UPDATE user
|
||||||
|
SET stripe_customer_id = ?, stripe_subscription_id = ?, stripe_subscription_status = ?, stripe_subscription_interval = ?, stripe_subscription_paid_until = ?, stripe_subscription_cancel_at = ?
|
||||||
|
WHERE user = ?
|
||||||
|
`
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewSQLiteStore creates a new SQLite-backed user store
|
||||||
|
func NewSQLiteStore(filename, startupQueries string) (Store, error) {
|
||||||
|
db, err := sql.Open("sqlite3", filename)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := setupSQLite(db); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := runSQLiteStartupQueries(db, startupQueries); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &commonStore{
|
||||||
|
db: db,
|
||||||
|
queries: storeQueries{
|
||||||
|
selectUserByID: sqliteSelectUserByID,
|
||||||
|
selectUserByName: sqliteSelectUserByName,
|
||||||
|
selectUserByToken: sqliteSelectUserByToken,
|
||||||
|
selectUserByStripeID: sqliteSelectUserByStripeID,
|
||||||
|
selectUsernames: sqliteSelectUsernames,
|
||||||
|
selectUserCount: sqliteSelectUserCount,
|
||||||
|
selectUserIDFromUsername: sqliteSelectUserIDFromUsername,
|
||||||
|
insertUser: sqliteInsertUser,
|
||||||
|
updateUserPass: sqliteUpdateUserPass,
|
||||||
|
updateUserRole: sqliteUpdateUserRole,
|
||||||
|
updateUserProvisioned: sqliteUpdateUserProvisioned,
|
||||||
|
updateUserPrefs: sqliteUpdateUserPrefs,
|
||||||
|
updateUserStats: sqliteUpdateUserStats,
|
||||||
|
updateUserStatsResetAll: sqliteUpdateUserStatsResetAll,
|
||||||
|
updateUserTier: sqliteUpdateUserTier,
|
||||||
|
updateUserDeleted: sqliteUpdateUserDeleted,
|
||||||
|
deleteUser: sqliteDeleteUser,
|
||||||
|
deleteUserTier: sqliteDeleteUserTier,
|
||||||
|
deleteUsersMarked: sqliteDeleteUsersMarked,
|
||||||
|
selectTopicPerms: sqliteSelectTopicPerms,
|
||||||
|
selectUserAllAccess: sqliteSelectUserAllAccess,
|
||||||
|
selectUserAccess: sqliteSelectUserAccess,
|
||||||
|
selectUserReservations: sqliteSelectUserReservations,
|
||||||
|
selectUserReservationsCount: sqliteSelectUserReservationsCount,
|
||||||
|
selectUserReservationsOwner: sqliteSelectUserReservationsOwner,
|
||||||
|
selectUserHasReservation: sqliteSelectUserHasReservation,
|
||||||
|
selectOtherAccessCount: sqliteSelectOtherAccessCount,
|
||||||
|
upsertUserAccess: sqliteUpsertUserAccess,
|
||||||
|
deleteUserAccess: sqliteDeleteUserAccess,
|
||||||
|
deleteUserAccessProvisioned: sqliteDeleteUserAccessProvisioned,
|
||||||
|
deleteTopicAccess: sqliteDeleteTopicAccess,
|
||||||
|
deleteAllAccess: sqliteDeleteAllAccess,
|
||||||
|
selectToken: sqliteSelectToken,
|
||||||
|
selectTokens: sqliteSelectTokens,
|
||||||
|
selectTokenCount: sqliteSelectTokenCount,
|
||||||
|
selectAllProvisionedTokens: sqliteSelectAllProvisionedTokens,
|
||||||
|
upsertToken: sqliteUpsertToken,
|
||||||
|
updateTokenLabel: sqliteUpdateTokenLabel,
|
||||||
|
updateTokenExpiry: sqliteUpdateTokenExpiry,
|
||||||
|
updateTokenLastAccess: sqliteUpdateTokenLastAccess,
|
||||||
|
deleteToken: sqliteDeleteToken,
|
||||||
|
deleteProvisionedToken: sqliteDeleteProvisionedToken,
|
||||||
|
deleteAllToken: sqliteDeleteAllToken,
|
||||||
|
deleteExpiredTokens: sqliteDeleteExpiredTokens,
|
||||||
|
deleteExcessTokens: sqliteDeleteExcessTokens,
|
||||||
|
insertTier: sqliteInsertTier,
|
||||||
|
selectTiers: sqliteSelectTiers,
|
||||||
|
selectTierByCode: sqliteSelectTierByCode,
|
||||||
|
selectTierByPriceID: sqliteSelectTierByPriceID,
|
||||||
|
updateTier: sqliteUpdateTier,
|
||||||
|
deleteTier: sqliteDeleteTier,
|
||||||
|
selectPhoneNumbers: sqliteSelectPhoneNumbers,
|
||||||
|
insertPhoneNumber: sqliteInsertPhoneNumber,
|
||||||
|
deletePhoneNumber: sqliteDeletePhoneNumber,
|
||||||
|
updateBilling: sqliteUpdateBilling,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
@@ -2,19 +2,116 @@ package user
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
"heckel.io/ntfy/v2/log"
|
"heckel.io/ntfy/v2/log"
|
||||||
"heckel.io/ntfy/v2/util"
|
"heckel.io/ntfy/v2/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Schema management queries
|
// Initial SQLite schema
|
||||||
const (
|
const (
|
||||||
currentSchemaVersion = 6
|
sqliteCreateTablesQueries = `
|
||||||
insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
|
BEGIN;
|
||||||
updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
|
CREATE TABLE IF NOT EXISTS tier (
|
||||||
selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
|
id TEXT PRIMARY KEY,
|
||||||
|
code TEXT NOT NULL,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
messages_limit INT NOT NULL,
|
||||||
|
messages_expiry_duration INT NOT NULL,
|
||||||
|
emails_limit INT NOT NULL,
|
||||||
|
calls_limit INT NOT NULL,
|
||||||
|
reservations_limit INT NOT NULL,
|
||||||
|
attachment_file_size_limit INT NOT NULL,
|
||||||
|
attachment_total_size_limit INT NOT NULL,
|
||||||
|
attachment_expiry_duration INT NOT NULL,
|
||||||
|
attachment_bandwidth_limit INT NOT NULL,
|
||||||
|
stripe_monthly_price_id TEXT,
|
||||||
|
stripe_yearly_price_id TEXT
|
||||||
|
);
|
||||||
|
CREATE UNIQUE INDEX idx_tier_code ON tier (code);
|
||||||
|
CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
|
||||||
|
CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
|
||||||
|
CREATE TABLE IF NOT EXISTS user (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
tier_id TEXT,
|
||||||
|
user TEXT NOT NULL,
|
||||||
|
pass TEXT NOT NULL,
|
||||||
|
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
|
||||||
|
prefs JSON NOT NULL DEFAULT '{}',
|
||||||
|
sync_topic TEXT NOT NULL,
|
||||||
|
provisioned INT NOT NULL,
|
||||||
|
stats_messages INT NOT NULL DEFAULT (0),
|
||||||
|
stats_emails INT NOT NULL DEFAULT (0),
|
||||||
|
stats_calls INT NOT NULL DEFAULT (0),
|
||||||
|
stripe_customer_id TEXT,
|
||||||
|
stripe_subscription_id TEXT,
|
||||||
|
stripe_subscription_status TEXT,
|
||||||
|
stripe_subscription_interval TEXT,
|
||||||
|
stripe_subscription_paid_until INT,
|
||||||
|
stripe_subscription_cancel_at INT,
|
||||||
|
created INT NOT NULL,
|
||||||
|
deleted INT,
|
||||||
|
FOREIGN KEY (tier_id) REFERENCES tier (id)
|
||||||
|
);
|
||||||
|
CREATE UNIQUE INDEX idx_user ON user (user);
|
||||||
|
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
|
||||||
|
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
|
||||||
|
CREATE TABLE IF NOT EXISTS user_access (
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
topic TEXT NOT NULL,
|
||||||
|
read INT NOT NULL,
|
||||||
|
write INT NOT NULL,
|
||||||
|
owner_user_id INT,
|
||||||
|
provisioned INT NOT NULL,
|
||||||
|
PRIMARY KEY (user_id, topic),
|
||||||
|
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
|
||||||
|
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
CREATE TABLE IF NOT EXISTS user_token (
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
token TEXT NOT NULL,
|
||||||
|
label TEXT NOT NULL,
|
||||||
|
last_access INT NOT NULL,
|
||||||
|
last_origin TEXT NOT NULL,
|
||||||
|
expires INT NOT NULL,
|
||||||
|
provisioned INT NOT NULL,
|
||||||
|
PRIMARY KEY (user_id, token),
|
||||||
|
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
CREATE UNIQUE INDEX idx_user_token ON user_token (token);
|
||||||
|
CREATE TABLE IF NOT EXISTS user_phone (
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
phone_number TEXT NOT NULL,
|
||||||
|
PRIMARY KEY (user_id, phone_number),
|
||||||
|
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||||
|
id INT PRIMARY KEY,
|
||||||
|
version INT NOT NULL
|
||||||
|
);
|
||||||
|
INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created)
|
||||||
|
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, UNIXEPOCH())
|
||||||
|
ON CONFLICT (id) DO NOTHING;
|
||||||
|
COMMIT;
|
||||||
|
`
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
sqliteBuiltinStartupQueries = `PRAGMA foreign_keys = ON;`
|
||||||
|
)
|
||||||
|
|
||||||
|
// Schema version table management for SQLite
|
||||||
|
const (
|
||||||
|
sqliteCurrentSchemaVersion = 6
|
||||||
|
sqliteInsertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
|
||||||
|
sqliteUpdateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
|
||||||
|
sqliteSelectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
|
||||||
|
)
|
||||||
|
|
||||||
|
// Schema migrations for SQLite
|
||||||
|
const (
|
||||||
// 1 -> 2 (complex migration!)
|
// 1 -> 2 (complex migration!)
|
||||||
migrate1To2CreateTablesQueries = `
|
sqliteMigrate1To2CreateTablesQueries = `
|
||||||
ALTER TABLE user RENAME TO user_old;
|
ALTER TABLE user RENAME TO user_old;
|
||||||
CREATE TABLE IF NOT EXISTS tier (
|
CREATE TABLE IF NOT EXISTS tier (
|
||||||
id TEXT PRIMARY KEY,
|
id TEXT PRIMARY KEY,
|
||||||
@@ -82,12 +179,12 @@ const (
|
|||||||
VALUES ('u_everyone', '*', '', 'anonymous', '', UNIXEPOCH())
|
VALUES ('u_everyone', '*', '', 'anonymous', '', UNIXEPOCH())
|
||||||
ON CONFLICT (id) DO NOTHING;
|
ON CONFLICT (id) DO NOTHING;
|
||||||
`
|
`
|
||||||
migrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old`
|
sqliteMigrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old`
|
||||||
migrate1To2InsertUserNoTx = `
|
sqliteMigrate1To2InsertUserNoTx = `
|
||||||
INSERT INTO user (id, user, pass, role, sync_topic, created)
|
INSERT INTO user (id, user, pass, role, sync_topic, created)
|
||||||
SELECT ?, user, pass, role, ?, UNIXEPOCH() FROM user_old WHERE user = ?
|
SELECT ?, user, pass, role, ?, UNIXEPOCH() FROM user_old WHERE user = ?
|
||||||
`
|
`
|
||||||
migrate1To2InsertFromOldTablesAndDropNoTx = `
|
sqliteMigrate1To2InsertFromOldTablesAndDropNoTx = `
|
||||||
INSERT INTO user_access (user_id, topic, read, write)
|
INSERT INTO user_access (user_id, topic, read, write)
|
||||||
SELECT u.id, a.topic, a.read, a.write
|
SELECT u.id, a.topic, a.read, a.write
|
||||||
FROM user u
|
FROM user u
|
||||||
@@ -98,7 +195,7 @@ const (
|
|||||||
`
|
`
|
||||||
|
|
||||||
// 2 -> 3
|
// 2 -> 3
|
||||||
migrate2To3UpdateQueries = `
|
sqliteMigrate2To3UpdateQueries = `
|
||||||
ALTER TABLE user ADD COLUMN stripe_subscription_interval TEXT;
|
ALTER TABLE user ADD COLUMN stripe_subscription_interval TEXT;
|
||||||
ALTER TABLE tier RENAME COLUMN stripe_price_id TO stripe_monthly_price_id;
|
ALTER TABLE tier RENAME COLUMN stripe_price_id TO stripe_monthly_price_id;
|
||||||
ALTER TABLE tier ADD COLUMN stripe_yearly_price_id TEXT;
|
ALTER TABLE tier ADD COLUMN stripe_yearly_price_id TEXT;
|
||||||
@@ -108,7 +205,7 @@ const (
|
|||||||
`
|
`
|
||||||
|
|
||||||
// 3 -> 4
|
// 3 -> 4
|
||||||
migrate3To4UpdateQueries = `
|
sqliteMigrate3To4UpdateQueries = `
|
||||||
ALTER TABLE tier ADD COLUMN calls_limit INT NOT NULL DEFAULT (0);
|
ALTER TABLE tier ADD COLUMN calls_limit INT NOT NULL DEFAULT (0);
|
||||||
ALTER TABLE user ADD COLUMN stats_calls INT NOT NULL DEFAULT (0);
|
ALTER TABLE user ADD COLUMN stats_calls INT NOT NULL DEFAULT (0);
|
||||||
CREATE TABLE IF NOT EXISTS user_phone (
|
CREATE TABLE IF NOT EXISTS user_phone (
|
||||||
@@ -120,12 +217,12 @@ const (
|
|||||||
`
|
`
|
||||||
|
|
||||||
// 4 -> 5
|
// 4 -> 5
|
||||||
migrate4To5UpdateQueries = `
|
sqliteMigrate4To5UpdateQueries = `
|
||||||
UPDATE user_access SET topic = REPLACE(topic, '_', '\_');
|
UPDATE user_access SET topic = REPLACE(topic, '_', '\_');
|
||||||
`
|
`
|
||||||
|
|
||||||
// 5 -> 6
|
// 5 -> 6
|
||||||
migrate5To6UpdateQueries = `
|
sqliteMigrate5To6UpdateQueries = `
|
||||||
PRAGMA foreign_keys=off;
|
PRAGMA foreign_keys=off;
|
||||||
|
|
||||||
-- Alter user table: Add provisioned column
|
-- Alter user table: Add provisioned column
|
||||||
@@ -220,16 +317,60 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
migrations = map[int]func(db *sql.DB) error{
|
sqliteMigrations = map[int]func(db *sql.DB) error{
|
||||||
1: migrateFrom1,
|
1: sqliteMigrateFrom1,
|
||||||
2: migrateFrom2,
|
2: sqliteMigrateFrom2,
|
||||||
3: migrateFrom3,
|
3: sqliteMigrateFrom3,
|
||||||
4: migrateFrom4,
|
4: sqliteMigrateFrom4,
|
||||||
5: migrateFrom5,
|
5: sqliteMigrateFrom5,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
func migrateFrom1(db *sql.DB) error {
|
func setupSQLite(db *sql.DB) error {
|
||||||
|
var schemaVersion int
|
||||||
|
err := db.QueryRow(sqliteSelectSchemaVersionQuery).Scan(&schemaVersion)
|
||||||
|
if err != nil {
|
||||||
|
return setupNewSQLite(db)
|
||||||
|
}
|
||||||
|
if schemaVersion == sqliteCurrentSchemaVersion {
|
||||||
|
return nil
|
||||||
|
} else if schemaVersion > sqliteCurrentSchemaVersion {
|
||||||
|
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, sqliteCurrentSchemaVersion)
|
||||||
|
}
|
||||||
|
for i := schemaVersion; i < sqliteCurrentSchemaVersion; i++ {
|
||||||
|
fn, ok := sqliteMigrations[i]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1)
|
||||||
|
} else if err := fn(db); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupNewSQLite(db *sql.DB) error {
|
||||||
|
if _, err := db.Exec(sqliteCreateTablesQueries); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := db.Exec(sqliteInsertSchemaVersion, sqliteCurrentSchemaVersion); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func runSQLiteStartupQueries(db *sql.DB, startupQueries string) error {
|
||||||
|
if _, err := db.Exec(sqliteBuiltinStartupQueries); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if startupQueries != "" {
|
||||||
|
if _, err := db.Exec(startupQueries); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func sqliteMigrateFrom1(db *sql.DB) error {
|
||||||
log.Tag(tag).Info("Migrating user database schema: from 1 to 2")
|
log.Tag(tag).Info("Migrating user database schema: from 1 to 2")
|
||||||
tx, err := db.Begin()
|
tx, err := db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -237,11 +378,11 @@ func migrateFrom1(db *sql.DB) error {
|
|||||||
}
|
}
|
||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
// Rename user -> user_old, and create new tables
|
// Rename user -> user_old, and create new tables
|
||||||
if _, err := tx.Exec(migrate1To2CreateTablesQueries); err != nil {
|
if _, err := tx.Exec(sqliteMigrate1To2CreateTablesQueries); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// Insert users from user_old into new user table, with ID and sync_topic
|
// Insert users from user_old into new user table, with ID and sync_topic
|
||||||
rows, err := tx.Query(migrate1To2SelectAllOldUsernamesNoTx)
|
rows, err := tx.Query(sqliteMigrate1To2SelectAllOldUsernamesNoTx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -260,15 +401,15 @@ func migrateFrom1(db *sql.DB) error {
|
|||||||
for _, username := range usernames {
|
for _, username := range usernames {
|
||||||
userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
|
userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
|
||||||
syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
|
syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
|
||||||
if _, err := tx.Exec(migrate1To2InsertUserNoTx, userID, syncTopic, username); err != nil {
|
if _, err := tx.Exec(sqliteMigrate1To2InsertUserNoTx, userID, syncTopic, username); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Migrate old "access" table to "user_access" and drop "access" and "user_old"
|
// Migrate old "access" table to "user_access" and drop "access" and "user_old"
|
||||||
if _, err := tx.Exec(migrate1To2InsertFromOldTablesAndDropNoTx); err != nil {
|
if _, err := tx.Exec(sqliteMigrate1To2InsertFromOldTablesAndDropNoTx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := tx.Exec(updateSchemaVersion, 2); err != nil {
|
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 2); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := tx.Commit(); err != nil {
|
if err := tx.Commit(); err != nil {
|
||||||
@@ -277,65 +418,65 @@ func migrateFrom1(db *sql.DB) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func migrateFrom2(db *sql.DB) error {
|
func sqliteMigrateFrom2(db *sql.DB) error {
|
||||||
log.Tag(tag).Info("Migrating user database schema: from 2 to 3")
|
log.Tag(tag).Info("Migrating user database schema: from 2 to 3")
|
||||||
tx, err := db.Begin()
|
tx, err := db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
if _, err := tx.Exec(migrate2To3UpdateQueries); err != nil {
|
if _, err := tx.Exec(sqliteMigrate2To3UpdateQueries); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := tx.Exec(updateSchemaVersion, 3); err != nil {
|
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 3); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
}
|
}
|
||||||
|
|
||||||
func migrateFrom3(db *sql.DB) error {
|
func sqliteMigrateFrom3(db *sql.DB) error {
|
||||||
log.Tag(tag).Info("Migrating user database schema: from 3 to 4")
|
log.Tag(tag).Info("Migrating user database schema: from 3 to 4")
|
||||||
tx, err := db.Begin()
|
tx, err := db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
if _, err := tx.Exec(migrate3To4UpdateQueries); err != nil {
|
if _, err := tx.Exec(sqliteMigrate3To4UpdateQueries); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := tx.Exec(updateSchemaVersion, 4); err != nil {
|
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 4); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
}
|
}
|
||||||
|
|
||||||
func migrateFrom4(db *sql.DB) error {
|
func sqliteMigrateFrom4(db *sql.DB) error {
|
||||||
log.Tag(tag).Info("Migrating user database schema: from 4 to 5")
|
log.Tag(tag).Info("Migrating user database schema: from 4 to 5")
|
||||||
tx, err := db.Begin()
|
tx, err := db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
if _, err := tx.Exec(migrate4To5UpdateQueries); err != nil {
|
if _, err := tx.Exec(sqliteMigrate4To5UpdateQueries); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := tx.Exec(updateSchemaVersion, 5); err != nil {
|
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 5); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
}
|
}
|
||||||
|
|
||||||
func migrateFrom5(db *sql.DB) error {
|
func sqliteMigrateFrom5(db *sql.DB) error {
|
||||||
log.Tag(tag).Info("Migrating user database schema: from 5 to 6")
|
log.Tag(tag).Info("Migrating user database schema: from 5 to 6")
|
||||||
tx, err := db.Begin()
|
tx, err := db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
if _, err := tx.Exec(migrate5To6UpdateQueries); err != nil {
|
if _, err := tx.Exec(sqliteMigrate5To6UpdateQueries); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := tx.Exec(updateSchemaVersion, 6); err != nil {
|
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 6); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
180
user/store_sqlite_test.go
Normal file
180
user/store_sqlite_test.go
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
package user_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"heckel.io/ntfy/v2/user"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTestSQLiteStore(t *testing.T) user.Store {
|
||||||
|
store, err := user.NewSQLiteStore(filepath.Join(t.TempDir(), "user.db"), "")
|
||||||
|
require.Nil(t, err)
|
||||||
|
t.Cleanup(func() { store.Close() })
|
||||||
|
return store
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreAddUser(t *testing.T) {
|
||||||
|
testStoreAddUser(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreAddUserAlreadyExists(t *testing.T) {
|
||||||
|
testStoreAddUserAlreadyExists(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreRemoveUser(t *testing.T) {
|
||||||
|
testStoreRemoveUser(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreUserByID(t *testing.T) {
|
||||||
|
testStoreUserByID(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreUserByToken(t *testing.T) {
|
||||||
|
testStoreUserByToken(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreUserByStripeCustomer(t *testing.T) {
|
||||||
|
testStoreUserByStripeCustomer(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreUsers(t *testing.T) {
|
||||||
|
testStoreUsers(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreUsersCount(t *testing.T) {
|
||||||
|
testStoreUsersCount(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreChangePassword(t *testing.T) {
|
||||||
|
testStoreChangePassword(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreChangeRole(t *testing.T) {
|
||||||
|
testStoreChangeRole(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreTokens(t *testing.T) {
|
||||||
|
testStoreTokens(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreTokenChangeLabel(t *testing.T) {
|
||||||
|
testStoreTokenChangeLabel(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreTokenRemove(t *testing.T) {
|
||||||
|
testStoreTokenRemove(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreTokenRemoveExpired(t *testing.T) {
|
||||||
|
testStoreTokenRemoveExpired(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreTokenRemoveExcess(t *testing.T) {
|
||||||
|
testStoreTokenRemoveExcess(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreTokenUpdateLastAccess(t *testing.T) {
|
||||||
|
testStoreTokenUpdateLastAccess(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreAllowAccess(t *testing.T) {
|
||||||
|
testStoreAllowAccess(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreAllowAccessReadOnly(t *testing.T) {
|
||||||
|
testStoreAllowAccessReadOnly(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreResetAccess(t *testing.T) {
|
||||||
|
testStoreResetAccess(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreResetAccessAll(t *testing.T) {
|
||||||
|
testStoreResetAccessAll(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreAuthorizeTopicAccess(t *testing.T) {
|
||||||
|
testStoreAuthorizeTopicAccess(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreAuthorizeTopicAccessNotFound(t *testing.T) {
|
||||||
|
testStoreAuthorizeTopicAccessNotFound(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreAuthorizeTopicAccessDenyAll(t *testing.T) {
|
||||||
|
testStoreAuthorizeTopicAccessDenyAll(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreReservations(t *testing.T) {
|
||||||
|
testStoreReservations(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreReservationsCount(t *testing.T) {
|
||||||
|
testStoreReservationsCount(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreHasReservation(t *testing.T) {
|
||||||
|
testStoreHasReservation(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreReservationOwner(t *testing.T) {
|
||||||
|
testStoreReservationOwner(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreTiers(t *testing.T) {
|
||||||
|
testStoreTiers(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreTierUpdate(t *testing.T) {
|
||||||
|
testStoreTierUpdate(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreTierRemove(t *testing.T) {
|
||||||
|
testStoreTierRemove(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreTierByStripePrice(t *testing.T) {
|
||||||
|
testStoreTierByStripePrice(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreChangeTier(t *testing.T) {
|
||||||
|
testStoreChangeTier(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStorePhoneNumbers(t *testing.T) {
|
||||||
|
testStorePhoneNumbers(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreChangeSettings(t *testing.T) {
|
||||||
|
testStoreChangeSettings(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreChangeBilling(t *testing.T) {
|
||||||
|
testStoreChangeBilling(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreUpdateStats(t *testing.T) {
|
||||||
|
testStoreUpdateStats(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreResetStats(t *testing.T) {
|
||||||
|
testStoreResetStats(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreMarkUserRemoved(t *testing.T) {
|
||||||
|
testStoreMarkUserRemoved(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreRemoveDeletedUsers(t *testing.T) {
|
||||||
|
testStoreRemoveDeletedUsers(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreAllGrants(t *testing.T) {
|
||||||
|
testStoreAllGrants(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLiteStoreOtherAccessCount(t *testing.T) {
|
||||||
|
testStoreOtherAccessCount(t, newTestSQLiteStore(t))
|
||||||
|
}
|
||||||
619
user/store_test.go
Normal file
619
user/store_test.go
Normal file
@@ -0,0 +1,619 @@
|
|||||||
|
package user_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"heckel.io/ntfy/v2/user"
|
||||||
|
)
|
||||||
|
|
||||||
|
func testStoreAddUser(t *testing.T, store user.Store) {
|
||||||
|
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||||
|
u, err := store.User("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, "phil", u.Name)
|
||||||
|
require.Equal(t, user.RoleUser, u.Role)
|
||||||
|
require.False(t, u.Provisioned)
|
||||||
|
require.NotEmpty(t, u.ID)
|
||||||
|
require.NotEmpty(t, u.SyncTopic)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreAddUserAlreadyExists(t *testing.T, store user.Store) {
|
||||||
|
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||||
|
require.Equal(t, user.ErrUserExists, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreRemoveUser(t *testing.T, store user.Store) {
|
||||||
|
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||||
|
u, err := store.User("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, "phil", u.Name)
|
||||||
|
|
||||||
|
require.Nil(t, store.RemoveUser("phil"))
|
||||||
|
_, err = store.User("phil")
|
||||||
|
require.Equal(t, user.ErrUserNotFound, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreUserByID(t *testing.T, store user.Store) {
|
||||||
|
require.Nil(t, store.AddUser("phil", "philhash", user.RoleAdmin, false))
|
||||||
|
u, err := store.User("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
u2, err := store.UserByID(u.ID)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, u.Name, u2.Name)
|
||||||
|
require.Equal(t, u.ID, u2.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreUserByToken(t *testing.T, store user.Store) {
|
||||||
|
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||||
|
u, err := store.User("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
tk, err := store.CreateToken(u.ID, "tk_test123", "test token", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(24*time.Hour), false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, "tk_test123", tk.Value)
|
||||||
|
|
||||||
|
u2, err := store.UserByToken(tk.Value)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, "phil", u2.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreUserByStripeCustomer(t *testing.T, store user.Store) {
|
||||||
|
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||||
|
require.Nil(t, store.ChangeBilling("phil", &user.Billing{
|
||||||
|
StripeCustomerID: "cus_test123",
|
||||||
|
StripeSubscriptionID: "sub_test123",
|
||||||
|
}))
|
||||||
|
|
||||||
|
u, err := store.UserByStripeCustomer("cus_test123")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, "phil", u.Name)
|
||||||
|
require.Equal(t, "cus_test123", u.Billing.StripeCustomerID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreUsers(t *testing.T, store user.Store) {
|
||||||
|
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||||
|
require.Nil(t, store.AddUser("ben", "benhash", user.RoleAdmin, false))
|
||||||
|
|
||||||
|
users, err := store.Users()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.True(t, len(users) >= 3) // phil, ben, and the everyone user
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreUsersCount(t *testing.T, store user.Store) {
|
||||||
|
count, err := store.UsersCount()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.True(t, count >= 1) // At least the everyone user
|
||||||
|
|
||||||
|
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||||
|
count2, err := store.UsersCount()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, count+1, count2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreChangePassword(t *testing.T, store user.Store) {
|
||||||
|
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||||
|
u, err := store.User("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, "philhash", u.Hash)
|
||||||
|
|
||||||
|
require.Nil(t, store.ChangePassword("phil", "newhash"))
|
||||||
|
u, err = store.User("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, "newhash", u.Hash)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreChangeRole(t *testing.T, store user.Store) {
|
||||||
|
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||||
|
u, err := store.User("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, user.RoleUser, u.Role)
|
||||||
|
|
||||||
|
require.Nil(t, store.ChangeRole("phil", user.RoleAdmin))
|
||||||
|
u, err = store.User("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, user.RoleAdmin, u.Role)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreTokens(t *testing.T, store user.Store) {
|
||||||
|
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||||
|
u, err := store.User("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
expires := now.Add(24 * time.Hour)
|
||||||
|
origin := netip.MustParseAddr("9.9.9.9")
|
||||||
|
|
||||||
|
tk, err := store.CreateToken(u.ID, "tk_abc", "my token", now, origin, expires, false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, "tk_abc", tk.Value)
|
||||||
|
require.Equal(t, "my token", tk.Label)
|
||||||
|
|
||||||
|
// Get single token
|
||||||
|
tk2, err := store.Token(u.ID, "tk_abc")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, "tk_abc", tk2.Value)
|
||||||
|
require.Equal(t, "my token", tk2.Label)
|
||||||
|
|
||||||
|
// Get all tokens
|
||||||
|
tokens, err := store.Tokens(u.ID)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Len(t, tokens, 1)
|
||||||
|
require.Equal(t, "tk_abc", tokens[0].Value)
|
||||||
|
|
||||||
|
// Token count
|
||||||
|
count, err := store.TokenCount(u.ID)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 1, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreTokenChangeLabel(t *testing.T, store user.Store) {
|
||||||
|
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||||
|
u, err := store.User("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
_, err = store.CreateToken(u.ID, "tk_abc", "old label", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
require.Nil(t, store.ChangeTokenLabel(u.ID, "tk_abc", "new label"))
|
||||||
|
tk, err := store.Token(u.ID, "tk_abc")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, "new label", tk.Label)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreTokenRemove(t *testing.T, store user.Store) {
|
||||||
|
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||||
|
u, err := store.User("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
_, err = store.CreateToken(u.ID, "tk_abc", "label", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
require.Nil(t, store.RemoveToken(u.ID, "tk_abc"))
|
||||||
|
_, err = store.Token(u.ID, "tk_abc")
|
||||||
|
require.Equal(t, user.ErrTokenNotFound, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreTokenRemoveExpired(t *testing.T, store user.Store) {
|
||||||
|
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||||
|
u, err := store.User("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
// Create expired token and active token
|
||||||
|
_, err = store.CreateToken(u.ID, "tk_expired", "expired", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(-time.Hour), false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
_, err = store.CreateToken(u.ID, "tk_active", "active", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
require.Nil(t, store.RemoveExpiredTokens())
|
||||||
|
|
||||||
|
// Expired token should be gone
|
||||||
|
_, err = store.Token(u.ID, "tk_expired")
|
||||||
|
require.Equal(t, user.ErrTokenNotFound, err)
|
||||||
|
|
||||||
|
// Active token should still exist
|
||||||
|
tk, err := store.Token(u.ID, "tk_active")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, "tk_active", tk.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreTokenRemoveExcess(t *testing.T, store user.Store) {
|
||||||
|
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||||
|
u, err := store.User("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
// Create 3 tokens with increasing expiry
|
||||||
|
for i, name := range []string{"tk_a", "tk_b", "tk_c"} {
|
||||||
|
_, err = store.CreateToken(u.ID, name, name, time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Duration(i+1)*time.Hour), false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
count, err := store.TokenCount(u.ID)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 3, count)
|
||||||
|
|
||||||
|
// Remove excess, keep only 2 (the ones with latest expiry: tk_b, tk_c)
|
||||||
|
require.Nil(t, store.RemoveExcessTokens(u.ID, 2))
|
||||||
|
|
||||||
|
count, err = store.TokenCount(u.ID)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 2, count)
|
||||||
|
|
||||||
|
// tk_a should be removed (earliest expiry)
|
||||||
|
_, err = store.Token(u.ID, "tk_a")
|
||||||
|
require.Equal(t, user.ErrTokenNotFound, err)
|
||||||
|
|
||||||
|
// tk_b and tk_c should remain
|
||||||
|
_, err = store.Token(u.ID, "tk_b")
|
||||||
|
require.Nil(t, err)
|
||||||
|
_, err = store.Token(u.ID, "tk_c")
|
||||||
|
require.Nil(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreTokenUpdateLastAccess(t *testing.T, store user.Store) {
|
||||||
|
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||||
|
u, err := store.User("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
_, err = store.CreateToken(u.ID, "tk_abc", "label", time.Now(), netip.MustParseAddr("1.2.3.4"), time.Now().Add(time.Hour), false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
newTime := time.Now().Add(5 * time.Minute)
|
||||||
|
newOrigin := netip.MustParseAddr("5.5.5.5")
|
||||||
|
require.Nil(t, store.UpdateTokenLastAccess("tk_abc", newTime, newOrigin))
|
||||||
|
|
||||||
|
tk, err := store.Token(u.ID, "tk_abc")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, newTime.Unix(), tk.LastAccess.Unix())
|
||||||
|
require.Equal(t, newOrigin, tk.LastOrigin)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreAllowAccess(t *testing.T, store user.Store) {
|
||||||
|
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||||
|
|
||||||
|
require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "", false))
|
||||||
|
grants, err := store.Grants("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Len(t, grants, 1)
|
||||||
|
require.Equal(t, "mytopic", grants[0].TopicPattern)
|
||||||
|
require.True(t, grants[0].Permission.IsReadWrite())
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreAllowAccessReadOnly(t *testing.T, store user.Store) {
|
||||||
|
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||||
|
|
||||||
|
require.Nil(t, store.AllowAccess("phil", "announcements", true, false, "", false))
|
||||||
|
grants, err := store.Grants("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Len(t, grants, 1)
|
||||||
|
require.True(t, grants[0].Permission.IsRead())
|
||||||
|
require.False(t, grants[0].Permission.IsWrite())
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreResetAccess(t *testing.T, store user.Store) {
|
||||||
|
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||||
|
require.Nil(t, store.AllowAccess("phil", "topic1", true, true, "", false))
|
||||||
|
require.Nil(t, store.AllowAccess("phil", "topic2", true, false, "", false))
|
||||||
|
|
||||||
|
grants, err := store.Grants("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Len(t, grants, 2)
|
||||||
|
|
||||||
|
require.Nil(t, store.ResetAccess("phil", "topic1"))
|
||||||
|
grants, err = store.Grants("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Len(t, grants, 1)
|
||||||
|
require.Equal(t, "topic2", grants[0].TopicPattern)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreResetAccessAll(t *testing.T, store user.Store) {
|
||||||
|
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||||
|
require.Nil(t, store.AllowAccess("phil", "topic1", true, true, "", false))
|
||||||
|
require.Nil(t, store.AllowAccess("phil", "topic2", true, false, "", false))
|
||||||
|
|
||||||
|
require.Nil(t, store.ResetAccess("phil", ""))
|
||||||
|
grants, err := store.Grants("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Len(t, grants, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreAuthorizeTopicAccess(t *testing.T, store user.Store) {
|
||||||
|
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||||
|
require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "", false))
|
||||||
|
|
||||||
|
read, write, found, err := store.AuthorizeTopicAccess("phil", "mytopic")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.True(t, found)
|
||||||
|
require.True(t, read)
|
||||||
|
require.True(t, write)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreAuthorizeTopicAccessNotFound(t *testing.T, store user.Store) {
|
||||||
|
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||||
|
|
||||||
|
_, _, found, err := store.AuthorizeTopicAccess("phil", "other")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.False(t, found)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreAuthorizeTopicAccessDenyAll(t *testing.T, store user.Store) {
|
||||||
|
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||||
|
require.Nil(t, store.AllowAccess("phil", "secret", false, false, "", false))
|
||||||
|
|
||||||
|
read, write, found, err := store.AuthorizeTopicAccess("phil", "secret")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.True(t, found)
|
||||||
|
require.False(t, read)
|
||||||
|
require.False(t, write)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreReservations(t *testing.T, store user.Store) {
|
||||||
|
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||||
|
require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "phil", false))
|
||||||
|
require.Nil(t, store.AllowAccess(user.Everyone, "mytopic", true, false, "phil", false))
|
||||||
|
|
||||||
|
reservations, err := store.Reservations("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Len(t, reservations, 1)
|
||||||
|
require.Equal(t, "mytopic", reservations[0].Topic)
|
||||||
|
require.True(t, reservations[0].Owner.IsReadWrite())
|
||||||
|
require.True(t, reservations[0].Everyone.IsRead())
|
||||||
|
require.False(t, reservations[0].Everyone.IsWrite())
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreReservationsCount(t *testing.T, store user.Store) {
|
||||||
|
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||||
|
require.Nil(t, store.AllowAccess("phil", "topic1", true, true, "phil", false))
|
||||||
|
require.Nil(t, store.AllowAccess("phil", "topic2", true, true, "phil", false))
|
||||||
|
|
||||||
|
count, err := store.ReservationsCount("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, int64(2), count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreHasReservation(t *testing.T, store user.Store) {
|
||||||
|
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||||
|
require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "phil", false))
|
||||||
|
|
||||||
|
has, err := store.HasReservation("phil", "mytopic")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.True(t, has)
|
||||||
|
|
||||||
|
has, err = store.HasReservation("phil", "other")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.False(t, has)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreReservationOwner(t *testing.T, store user.Store) {
|
||||||
|
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||||
|
require.Nil(t, store.AllowAccess("phil", "mytopic", true, true, "phil", false))
|
||||||
|
|
||||||
|
owner, err := store.ReservationOwner("mytopic")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.NotEmpty(t, owner) // Returns the user ID
|
||||||
|
|
||||||
|
owner, err = store.ReservationOwner("unowned")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Empty(t, owner)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreTiers(t *testing.T, store user.Store) {
|
||||||
|
tier := &user.Tier{
|
||||||
|
ID: "ti_test",
|
||||||
|
Code: "pro",
|
||||||
|
Name: "Pro",
|
||||||
|
MessageLimit: 5000,
|
||||||
|
MessageExpiryDuration: 24 * time.Hour,
|
||||||
|
EmailLimit: 100,
|
||||||
|
CallLimit: 10,
|
||||||
|
ReservationLimit: 20,
|
||||||
|
AttachmentFileSizeLimit: 10 * 1024 * 1024,
|
||||||
|
AttachmentTotalSizeLimit: 100 * 1024 * 1024,
|
||||||
|
AttachmentExpiryDuration: 48 * time.Hour,
|
||||||
|
AttachmentBandwidthLimit: 500 * 1024 * 1024,
|
||||||
|
}
|
||||||
|
require.Nil(t, store.AddTier(tier))
|
||||||
|
|
||||||
|
// Get by code
|
||||||
|
t2, err := store.Tier("pro")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, "ti_test", t2.ID)
|
||||||
|
require.Equal(t, "pro", t2.Code)
|
||||||
|
require.Equal(t, "Pro", t2.Name)
|
||||||
|
require.Equal(t, int64(5000), t2.MessageLimit)
|
||||||
|
require.Equal(t, int64(100), t2.EmailLimit)
|
||||||
|
require.Equal(t, int64(10), t2.CallLimit)
|
||||||
|
require.Equal(t, int64(20), t2.ReservationLimit)
|
||||||
|
|
||||||
|
// List all tiers
|
||||||
|
tiers, err := store.Tiers()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Len(t, tiers, 1)
|
||||||
|
require.Equal(t, "pro", tiers[0].Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreTierUpdate(t *testing.T, store user.Store) {
|
||||||
|
tier := &user.Tier{
|
||||||
|
ID: "ti_test",
|
||||||
|
Code: "pro",
|
||||||
|
Name: "Pro",
|
||||||
|
}
|
||||||
|
require.Nil(t, store.AddTier(tier))
|
||||||
|
|
||||||
|
tier.Name = "Professional"
|
||||||
|
tier.MessageLimit = 9999
|
||||||
|
require.Nil(t, store.UpdateTier(tier))
|
||||||
|
|
||||||
|
t2, err := store.Tier("pro")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, "Professional", t2.Name)
|
||||||
|
require.Equal(t, int64(9999), t2.MessageLimit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreTierRemove(t *testing.T, store user.Store) {
|
||||||
|
tier := &user.Tier{
|
||||||
|
ID: "ti_test",
|
||||||
|
Code: "pro",
|
||||||
|
Name: "Pro",
|
||||||
|
}
|
||||||
|
require.Nil(t, store.AddTier(tier))
|
||||||
|
|
||||||
|
t2, err := store.Tier("pro")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, "pro", t2.Code)
|
||||||
|
|
||||||
|
require.Nil(t, store.RemoveTier("pro"))
|
||||||
|
_, err = store.Tier("pro")
|
||||||
|
require.Equal(t, user.ErrTierNotFound, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreTierByStripePrice(t *testing.T, store user.Store) {
|
||||||
|
tier := &user.Tier{
|
||||||
|
ID: "ti_test",
|
||||||
|
Code: "pro",
|
||||||
|
Name: "Pro",
|
||||||
|
StripeMonthlyPriceID: "price_monthly",
|
||||||
|
StripeYearlyPriceID: "price_yearly",
|
||||||
|
}
|
||||||
|
require.Nil(t, store.AddTier(tier))
|
||||||
|
|
||||||
|
t2, err := store.TierByStripePrice("price_monthly")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, "pro", t2.Code)
|
||||||
|
|
||||||
|
t3, err := store.TierByStripePrice("price_yearly")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, "pro", t3.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreChangeTier(t *testing.T, store user.Store) {
|
||||||
|
tier := &user.Tier{
|
||||||
|
ID: "ti_test",
|
||||||
|
Code: "pro",
|
||||||
|
Name: "Pro",
|
||||||
|
}
|
||||||
|
require.Nil(t, store.AddTier(tier))
|
||||||
|
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||||
|
require.Nil(t, store.ChangeTier("phil", "pro"))
|
||||||
|
|
||||||
|
u, err := store.User("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.NotNil(t, u.Tier)
|
||||||
|
require.Equal(t, "pro", u.Tier.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStorePhoneNumbers(t *testing.T, store user.Store) {
|
||||||
|
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||||
|
u, err := store.User("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
require.Nil(t, store.AddPhoneNumber(u.ID, "+1234567890"))
|
||||||
|
require.Nil(t, store.AddPhoneNumber(u.ID, "+0987654321"))
|
||||||
|
|
||||||
|
numbers, err := store.PhoneNumbers(u.ID)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Len(t, numbers, 2)
|
||||||
|
|
||||||
|
require.Nil(t, store.RemovePhoneNumber(u.ID, "+1234567890"))
|
||||||
|
numbers, err = store.PhoneNumbers(u.ID)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Len(t, numbers, 1)
|
||||||
|
require.Equal(t, "+0987654321", numbers[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreChangeSettings(t *testing.T, store user.Store) {
|
||||||
|
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||||
|
u, err := store.User("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
lang := "de"
|
||||||
|
prefs := &user.Prefs{Language: &lang}
|
||||||
|
require.Nil(t, store.ChangeSettings(u.ID, prefs))
|
||||||
|
|
||||||
|
u2, err := store.User("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.NotNil(t, u2.Prefs)
|
||||||
|
require.Equal(t, "de", *u2.Prefs.Language)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreChangeBilling(t *testing.T, store user.Store) {
|
||||||
|
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||||
|
|
||||||
|
billing := &user.Billing{
|
||||||
|
StripeCustomerID: "cus_123",
|
||||||
|
StripeSubscriptionID: "sub_456",
|
||||||
|
}
|
||||||
|
require.Nil(t, store.ChangeBilling("phil", billing))
|
||||||
|
|
||||||
|
u, err := store.User("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, "cus_123", u.Billing.StripeCustomerID)
|
||||||
|
require.Equal(t, "sub_456", u.Billing.StripeSubscriptionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreUpdateStats(t *testing.T, store user.Store) {
|
||||||
|
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||||
|
u, err := store.User("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
stats := &user.Stats{Messages: 42, Emails: 3, Calls: 1}
|
||||||
|
require.Nil(t, store.UpdateStats(u.ID, stats))
|
||||||
|
|
||||||
|
u2, err := store.User("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, int64(42), u2.Stats.Messages)
|
||||||
|
require.Equal(t, int64(3), u2.Stats.Emails)
|
||||||
|
require.Equal(t, int64(1), u2.Stats.Calls)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreResetStats(t *testing.T, store user.Store) {
|
||||||
|
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||||
|
u, err := store.User("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
require.Nil(t, store.UpdateStats(u.ID, &user.Stats{Messages: 42, Emails: 3, Calls: 1}))
|
||||||
|
require.Nil(t, store.ResetStats())
|
||||||
|
|
||||||
|
u2, err := store.User("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, int64(0), u2.Stats.Messages)
|
||||||
|
require.Equal(t, int64(0), u2.Stats.Emails)
|
||||||
|
require.Equal(t, int64(0), u2.Stats.Calls)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreMarkUserRemoved(t *testing.T, store user.Store) {
|
||||||
|
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||||
|
u, err := store.User("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
require.Nil(t, store.MarkUserRemoved(u.ID))
|
||||||
|
|
||||||
|
u2, err := store.User("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.True(t, u2.Deleted)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreRemoveDeletedUsers(t *testing.T, store user.Store) {
|
||||||
|
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||||
|
u, err := store.User("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
require.Nil(t, store.MarkUserRemoved(u.ID))
|
||||||
|
|
||||||
|
// RemoveDeletedUsers only removes users past the hard-delete duration (7 days).
|
||||||
|
// Immediately after marking, the user should still exist.
|
||||||
|
require.Nil(t, store.RemoveDeletedUsers())
|
||||||
|
u2, err := store.User("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.True(t, u2.Deleted)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreAllGrants(t *testing.T, store user.Store) {
|
||||||
|
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||||
|
require.Nil(t, store.AddUser("ben", "benhash", user.RoleUser, false))
|
||||||
|
phil, err := store.User("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
ben, err := store.User("ben")
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
require.Nil(t, store.AllowAccess("phil", "topic1", true, true, "", false))
|
||||||
|
require.Nil(t, store.AllowAccess("ben", "topic2", true, false, "", false))
|
||||||
|
|
||||||
|
grants, err := store.AllGrants()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Contains(t, grants, phil.ID)
|
||||||
|
require.Contains(t, grants, ben.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStoreOtherAccessCount(t *testing.T, store user.Store) {
|
||||||
|
require.Nil(t, store.AddUser("phil", "philhash", user.RoleUser, false))
|
||||||
|
require.Nil(t, store.AddUser("ben", "benhash", user.RoleUser, false))
|
||||||
|
require.Nil(t, store.AllowAccess("ben", "mytopic", true, true, "ben", false))
|
||||||
|
|
||||||
|
count, err := store.OtherAccessCount("phil", "mytopic")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 1, count)
|
||||||
|
}
|
||||||
@@ -242,6 +242,20 @@ const (
|
|||||||
everyoneID = "u_everyone"
|
everyoneID = "u_everyone"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Config holds the configuration for the user Manager
|
||||||
|
type Config struct {
|
||||||
|
Filename string // Database filename, e.g. "/var/lib/ntfy/user.db" (SQLite)
|
||||||
|
DatabaseURL string // Database connection string (PostgreSQL)
|
||||||
|
StartupQueries string // Queries to run on startup, e.g. to create initial users or tiers (SQLite only)
|
||||||
|
DefaultAccess Permission // Default permission if no ACL matches
|
||||||
|
ProvisionEnabled bool // Hack: Enable auto-provisioning of users and access grants, disabled for "ntfy user" commands
|
||||||
|
Users []*User // Predefined users to create on startup
|
||||||
|
Access map[string][]*Grant // Predefined access grants to create on startup (username -> []*Grant)
|
||||||
|
Tokens map[string][]*Token // Predefined users to create on startup (username -> []*Token)
|
||||||
|
QueueWriterInterval time.Duration // Interval for the async queue writer to flush stats and token updates to the database
|
||||||
|
BcryptCost int // Cost of generated passwords; lowering makes testing faster
|
||||||
|
}
|
||||||
|
|
||||||
// Error constants used by the package
|
// Error constants used by the package
|
||||||
var (
|
var (
|
||||||
ErrUnauthenticated = errors.New("unauthenticated")
|
ErrUnauthenticated = errors.New("unauthenticated")
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package webpush
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
|
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
|
||||||
)
|
)
|
||||||
@@ -25,8 +26,8 @@ const (
|
|||||||
PRIMARY KEY (subscription_id, topic)
|
PRIMARY KEY (subscription_id, topic)
|
||||||
);
|
);
|
||||||
CREATE INDEX IF NOT EXISTS idx_webpush_topic ON webpush_subscription_topic (topic);
|
CREATE INDEX IF NOT EXISTS idx_webpush_topic ON webpush_subscription_topic (topic);
|
||||||
CREATE TABLE IF NOT EXISTS webpush_schema_version (
|
CREATE TABLE IF NOT EXISTS schema_version (
|
||||||
id INT PRIMARY KEY,
|
store TEXT PRIMARY KEY,
|
||||||
version INT NOT NULL
|
version INT NOT NULL
|
||||||
);
|
);
|
||||||
`
|
`
|
||||||
@@ -65,8 +66,8 @@ const (
|
|||||||
// PostgreSQL schema management queries
|
// PostgreSQL schema management queries
|
||||||
const (
|
const (
|
||||||
pgCurrentSchemaVersion = 1
|
pgCurrentSchemaVersion = 1
|
||||||
pgInsertSchemaVersion = `INSERT INTO webpush_schema_version VALUES (1, $1)`
|
pgInsertSchemaVersion = `INSERT INTO schema_version (store, version) VALUES ('webpush', $1)`
|
||||||
pgSelectSchemaVersionQuery = `SELECT version FROM webpush_schema_version WHERE id = 1`
|
pgSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE store = 'webpush'`
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewPostgresStore creates a new PostgreSQL-backed web push store.
|
// NewPostgresStore creates a new PostgreSQL-backed web push store.
|
||||||
@@ -102,12 +103,15 @@ func NewPostgresStore(dsn string) (Store, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func setupPostgresDB(db *sql.DB) error {
|
func setupPostgresDB(db *sql.DB) error {
|
||||||
// If 'webpush_schema_version' table does not exist, this must be a new database
|
var schemaVersion int
|
||||||
rows, err := db.Query(pgSelectSchemaVersionQuery)
|
err := db.QueryRow(pgSelectSchemaVersionQuery).Scan(&schemaVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return setupNewPostgresDB(db)
|
return setupNewPostgresDB(db)
|
||||||
}
|
}
|
||||||
return rows.Close()
|
if schemaVersion > pgCurrentSchemaVersion {
|
||||||
|
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, pgCurrentSchemaVersion)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupNewPostgresDB(db *sql.DB) error {
|
func setupNewPostgresDB(db *sql.DB) error {
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package webpush
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||||
)
|
)
|
||||||
@@ -82,10 +83,10 @@ func NewSQLiteStore(filename, startupQueries string) (Store, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err := setupSQLiteWebPushDB(db); err != nil {
|
if err := setupSQLite(db); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err := runSQLiteWebPushStartupQueries(db, startupQueries); err != nil {
|
if err := runSQLiteStartupQueries(db, startupQueries); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &commonStore{
|
return &commonStore{
|
||||||
@@ -108,16 +109,19 @@ func NewSQLiteStore(filename, startupQueries string) (Store, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupSQLiteWebPushDB(db *sql.DB) error {
|
func setupSQLite(db *sql.DB) error {
|
||||||
// If 'schemaVersion' table does not exist, this must be a new database
|
var schemaVersion int
|
||||||
rows, err := db.Query(sqliteSelectWebPushSchemaVersionQuery)
|
err := db.QueryRow(sqliteSelectWebPushSchemaVersionQuery).Scan(&schemaVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return setupNewSQLiteWebPushDB(db)
|
return setupNewSQLite(db)
|
||||||
}
|
}
|
||||||
return rows.Close()
|
if schemaVersion > sqliteCurrentWebPushSchemaVersion {
|
||||||
|
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, sqliteCurrentWebPushSchemaVersion)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupNewSQLiteWebPushDB(db *sql.DB) error {
|
func setupNewSQLite(db *sql.DB) error {
|
||||||
if _, err := db.Exec(sqliteCreateWebPushSubscriptionsTableQuery); err != nil {
|
if _, err := db.Exec(sqliteCreateWebPushSubscriptionsTableQuery); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -127,7 +131,7 @@ func setupNewSQLiteWebPushDB(db *sql.DB) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func runSQLiteWebPushStartupQueries(db *sql.DB, startupQueries string) error {
|
func runSQLiteStartupQueries(db *sql.DB, startupQueries string) error {
|
||||||
if _, err := db.Exec(startupQueries); err != nil {
|
if _, err := db.Exec(startupQueries); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user