Make user tests work for postgres and sqlite

This commit is contained in:
binwiederhier
2026-02-17 20:14:45 -05:00
parent 1abc1005d0
commit e3a402ed95
2 changed files with 1156 additions and 1045 deletions

View File

@@ -4,6 +4,8 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"net/netip" "net/netip"
"net/url"
"os"
"path/filepath" "path/filepath"
"strings" "strings"
"testing" "testing"
@@ -16,8 +18,53 @@ import (
const minBcryptTimingMillis = int64(40) // Ideally should be >100ms, but this should also run on a Raspberry Pi without massive resources const minBcryptTimingMillis = int64(40) // Ideally should be >100ms, but this should also run on a Raspberry Pi without massive resources
// newStoreFunc creates a Store. Calling it multiple times within the same test
// returns a new Store object pointing at the same underlying data (same SQLite
// file / same PostgreSQL schema), enabling close-and-reopen tests.
type newStoreFunc func() Store
func forEachBackend(t *testing.T, f func(t *testing.T, newStore newStoreFunc)) {
t.Run("sqlite", func(t *testing.T) {
dir := t.TempDir()
f(t, func() Store {
store, err := NewSQLiteStore(filepath.Join(dir, "user.db"), "")
require.Nil(t, err)
return store
})
})
t.Run("postgres", func(t *testing.T) {
dsn := os.Getenv("NTFY_TEST_DATABASE_URL")
if dsn == "" {
t.Skip("NTFY_TEST_DATABASE_URL not set")
}
schema := fmt.Sprintf("test_%s", util.RandomString(10))
setupDB, err := sql.Open("pgx", dsn)
require.Nil(t, err)
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
require.Nil(t, err)
require.Nil(t, setupDB.Close())
u, err := url.Parse(dsn)
require.Nil(t, err)
q := u.Query()
q.Set("search_path", schema)
u.RawQuery = q.Encode()
schemaDSN := u.String()
t.Cleanup(func() {
cleanDB, _ := sql.Open("pgx", dsn)
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
cleanDB.Close()
})
f(t, func() Store {
store, err := NewPostgresStore(schemaDSN)
require.Nil(t, err)
return store
})
})
}
func TestManager_FullScenario_Default_DenyAll(t *testing.T) { func TestManager_FullScenario_Default_DenyAll(t *testing.T) {
a := newTestManagerFromFile(t, filepath.Join(t.TempDir(), "user.db"), "", PermissionDenyAll, DefaultUserPasswordBcryptCost, DefaultUserStatsQueueWriterInterval) forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
a := newTestManager(t, newStore, PermissionDenyAll)
require.Nil(t, a.AddUser("phil", "phil", RoleAdmin, false)) require.Nil(t, a.AddUser("phil", "phil", RoleAdmin, false))
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false)) require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
require.Nil(t, a.AddUser("john", "john", RoleUser, false)) require.Nil(t, a.AddUser("john", "john", RoleUser, false))
@@ -36,7 +83,7 @@ func TestManager_FullScenario_Default_DenyAll(t *testing.T) {
phil, err := a.Authenticate("phil", "phil") phil, err := a.Authenticate("phil", "phil")
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, "phil", phil.Name) require.Equal(t, "phil", phil.Name)
require.True(t, strings.HasPrefix(phil.Hash, "$2a$10$")) require.True(t, strings.HasPrefix(phil.Hash, "$2a$04$"))
require.Equal(t, RoleAdmin, phil.Role) require.Equal(t, RoleAdmin, phil.Role)
philGrants, err := a.Grants("phil") philGrants, err := a.Grants("phil")
@@ -46,7 +93,7 @@ func TestManager_FullScenario_Default_DenyAll(t *testing.T) {
ben, err := a.Authenticate("ben", "ben") ben, err := a.Authenticate("ben", "ben")
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, "ben", ben.Name) require.Equal(t, "ben", ben.Name)
require.True(t, strings.HasPrefix(ben.Hash, "$2a$10$")) require.True(t, strings.HasPrefix(ben.Hash, "$2a$04$"))
require.Equal(t, RoleUser, ben.Role) require.Equal(t, RoleUser, ben.Role)
benGrants, err := a.Grants("ben") benGrants, err := a.Grants("ben")
@@ -61,7 +108,7 @@ func TestManager_FullScenario_Default_DenyAll(t *testing.T) {
john, err := a.Authenticate("john", "john") john, err := a.Authenticate("john", "john")
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, "john", john.Name) require.Equal(t, "john", john.Name)
require.True(t, strings.HasPrefix(john.Hash, "$2a$10$")) require.True(t, strings.HasPrefix(john.Hash, "$2a$04$"))
require.Equal(t, RoleUser, john.Role) require.Equal(t, RoleUser, john.Role)
johnGrants, err := a.Grants("john") johnGrants, err := a.Grants("john")
@@ -127,13 +174,14 @@ func TestManager_FullScenario_Default_DenyAll(t *testing.T) {
require.Nil(t, a.Authorize(nil, "everyonewrite", PermissionWrite)) require.Nil(t, a.Authorize(nil, "everyonewrite", PermissionWrite))
require.Nil(t, a.Authorize(nil, "up1234", PermissionWrite)) // Wildcard permission require.Nil(t, a.Authorize(nil, "up1234", PermissionWrite)) // Wildcard permission
require.Nil(t, a.Authorize(nil, "up5678", PermissionWrite)) require.Nil(t, a.Authorize(nil, "up5678", PermissionWrite))
})
} }
func TestManager_Access_Order_LengthWriteRead(t *testing.T) { func TestManager_Access_Order_LengthWriteRead(t *testing.T) {
// This test validates issue #914 / #917, i.e. that write permissions are prioritized over read permissions, // This test validates issue #914 / #917, i.e. that write permissions are prioritized over read permissions,
// and longer ACL rules are prioritized as well. // and longer ACL rules are prioritized as well.
forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
a := newTestManagerFromFile(t, filepath.Join(t.TempDir(), "user.db"), "", PermissionDenyAll, DefaultUserPasswordBcryptCost, DefaultUserStatsQueueWriterInterval) a := newTestManager(t, newStore, PermissionDenyAll)
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false)) require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
require.Nil(t, a.AllowAccess("ben", "test*", PermissionReadWrite)) require.Nil(t, a.AllowAccess("ben", "test*", PermissionReadWrite))
require.Nil(t, a.AllowAccess("ben", "*", PermissionRead)) require.Nil(t, a.AllowAccess("ben", "*", PermissionRead))
@@ -143,12 +191,15 @@ func TestManager_Access_Order_LengthWriteRead(t *testing.T) {
require.Nil(t, a.Authorize(ben, "any-topic-can-be-read", PermissionRead)) require.Nil(t, a.Authorize(ben, "any-topic-can-be-read", PermissionRead))
require.Nil(t, a.Authorize(ben, "this-too", PermissionRead)) require.Nil(t, a.Authorize(ben, "this-too", PermissionRead))
require.Nil(t, a.Authorize(ben, "test123", PermissionWrite)) require.Nil(t, a.Authorize(ben, "test123", PermissionWrite))
})
} }
func TestManager_AddUser_Invalid(t *testing.T) { func TestManager_AddUser_Invalid(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
a := newTestManager(t, newStore, PermissionDenyAll)
require.Equal(t, ErrInvalidArgument, a.AddUser(" invalid ", "pass", RoleAdmin, false)) require.Equal(t, ErrInvalidArgument, a.AddUser(" invalid ", "pass", RoleAdmin, false))
require.Equal(t, ErrInvalidArgument, a.AddUser("validuser", "pass", "invalid-role", false)) require.Equal(t, ErrInvalidArgument, a.AddUser("validuser", "pass", "invalid-role", false))
})
} }
func TestManager_AddUser_Timing(t *testing.T) { func TestManager_AddUser_Timing(t *testing.T) {
@@ -159,7 +210,8 @@ func TestManager_AddUser_Timing(t *testing.T) {
} }
func TestManager_AddUser_And_Query(t *testing.T) { func TestManager_AddUser_And_Query(t *testing.T) {
a := newTestManagerFromFile(t, filepath.Join(t.TempDir(), "user.db"), "", PermissionDenyAll, DefaultUserPasswordBcryptCost, DefaultUserStatsQueueWriterInterval) forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
a := newTestManager(t, newStore, PermissionDenyAll)
require.Nil(t, a.AddUser("user", "pass", RoleAdmin, false)) require.Nil(t, a.AddUser("user", "pass", RoleAdmin, false))
require.Nil(t, a.ChangeBilling("user", &Billing{ require.Nil(t, a.ChangeBilling("user", &Billing{
StripeCustomerID: "acct_123", StripeCustomerID: "acct_123",
@@ -181,10 +233,12 @@ func TestManager_AddUser_And_Query(t *testing.T) {
u3, err := a.UserByStripeCustomer("acct_123") u3, err := a.UserByStripeCustomer("acct_123")
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, u.ID, u3.ID) require.Equal(t, u.ID, u3.ID)
})
} }
func TestManager_MarkUserRemoved_RemoveDeletedUsers(t *testing.T) { func TestManager_MarkUserRemoved_RemoveDeletedUsers(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
a := newTestManager(t, newStore, PermissionDenyAll)
// Create user, add reservations and token // Create user, add reservations and token
require.Nil(t, a.AddUser("user", "pass", RoleAdmin, false)) require.Nil(t, a.AddUser("user", "pass", RoleAdmin, false))
@@ -225,16 +279,20 @@ func TestManager_MarkUserRemoved_RemoveDeletedUsers(t *testing.T) {
require.Nil(t, err) require.Nil(t, err)
require.True(t, u.Deleted) require.True(t, u.Deleted)
_, err = testDB(a).Exec("UPDATE user SET deleted = ? WHERE id = ?", time.Now().Add(-1*(userHardDeleteAfterDuration+time.Hour)).Unix(), u.ID) // Backdate the deleted timestamp so RemoveDeletedUsers will prune the user
q := a.store.(*commonStore).queries.updateUserDeleted
_, err = testDB(a).Exec(q, time.Now().Add(-1*(userHardDeleteAfterDuration+time.Hour)).Unix(), u.ID)
require.Nil(t, err) require.Nil(t, err)
require.Nil(t, a.RemoveDeletedUsers()) require.Nil(t, a.RemoveDeletedUsers())
_, err = a.User("user") _, err = a.User("user")
require.Equal(t, ErrUserNotFound, err) require.Equal(t, ErrUserNotFound, err)
})
} }
func TestManager_CreateToken_Only_Lower(t *testing.T) { func TestManager_CreateToken_Only_Lower(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
a := newTestManager(t, newStore, PermissionDenyAll)
// Create user, add reservations and token // Create user, add reservations and token
require.Nil(t, a.AddUser("user", "pass", RoleAdmin, false)) require.Nil(t, a.AddUser("user", "pass", RoleAdmin, false))
@@ -244,10 +302,12 @@ func TestManager_CreateToken_Only_Lower(t *testing.T) {
token, err := a.CreateToken(u.ID, "", time.Now().Add(time.Hour), netip.IPv4Unspecified(), false) token, err := a.CreateToken(u.ID, "", time.Now().Add(time.Hour), netip.IPv4Unspecified(), false)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, token.Value, strings.ToLower(token.Value)) require.Equal(t, token.Value, strings.ToLower(token.Value))
})
} }
func TestManager_UserManagement(t *testing.T) { func TestManager_UserManagement(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
a := newTestManager(t, newStore, PermissionDenyAll)
require.Nil(t, a.AddUser("phil", "phil", RoleAdmin, false)) require.Nil(t, a.AddUser("phil", "phil", RoleAdmin, false))
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false)) require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite)) require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite))
@@ -335,10 +395,12 @@ func TestManager_UserManagement(t *testing.T) {
require.Equal(t, 2, len(users)) require.Equal(t, 2, len(users))
require.Equal(t, "phil", users[0].Name) require.Equal(t, "phil", users[0].Name)
require.Equal(t, "*", users[1].Name) require.Equal(t, "*", users[1].Name)
})
} }
func TestManager_ChangePassword(t *testing.T) { func TestManager_ChangePassword(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
a := newTestManager(t, newStore, PermissionDenyAll)
require.Nil(t, a.AddUser("phil", "phil", RoleAdmin, false)) require.Nil(t, a.AddUser("phil", "phil", RoleAdmin, false))
require.Nil(t, a.AddUser("jane", "$2a$10$OyqU72muEy7VMd1SAU2Iru5IbeSMgrtCGHu/fWLmxL1MwlijQXWbG", RoleUser, true)) require.Nil(t, a.AddUser("jane", "$2a$10$OyqU72muEy7VMd1SAU2Iru5IbeSMgrtCGHu/fWLmxL1MwlijQXWbG", RoleUser, true))
@@ -359,10 +421,12 @@ func TestManager_ChangePassword(t *testing.T) {
require.Equal(t, ErrUnauthenticated, err) require.Equal(t, ErrUnauthenticated, err)
_, err = a.Authenticate("jane", "newpass") _, err = a.Authenticate("jane", "newpass")
require.Nil(t, err) require.Nil(t, err)
})
} }
func TestManager_ChangeRole(t *testing.T) { func TestManager_ChangeRole(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
a := newTestManager(t, newStore, PermissionDenyAll)
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false)) require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite)) require.Nil(t, a.AllowAccess("ben", "mytopic", PermissionReadWrite))
require.Nil(t, a.AllowAccess("ben", "readme", PermissionRead)) require.Nil(t, a.AllowAccess("ben", "readme", PermissionRead))
@@ -384,10 +448,12 @@ func TestManager_ChangeRole(t *testing.T) {
benGrants, err = a.Grants("ben") benGrants, err = a.Grants("ben")
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, 0, len(benGrants)) require.Equal(t, 0, len(benGrants))
})
} }
func TestManager_Reservations(t *testing.T) { func TestManager_Reservations(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
a := newTestManager(t, newStore, PermissionDenyAll)
require.Nil(t, a.AddUser("phil", "phil", RoleUser, false)) require.Nil(t, a.AddUser("phil", "phil", RoleUser, false))
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false)) require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
require.Nil(t, a.AddReservation("ben", "ztopic_", PermissionDenyAll)) require.Nil(t, a.AddReservation("ben", "ztopic_", PermissionDenyAll))
@@ -454,10 +520,12 @@ func TestManager_Reservations(t *testing.T) {
count, err = a.ReservationsCount("ben") count, err = a.ReservationsCount("ben")
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, int64(0), count) require.Equal(t, int64(0), count)
})
} }
func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) { func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
a := newTestManager(t, newStore, PermissionDenyAll)
require.Nil(t, a.AddTier(&Tier{ require.Nil(t, a.AddTier(&Tier{
Code: "pro", Code: "pro",
Name: "ntfy Pro", Name: "ntfy Pro",
@@ -513,10 +581,12 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
everyoneGrants, err = a.Grants(Everyone) everyoneGrants, err = a.Grants(Everyone)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, 0, len(everyoneGrants)) require.Equal(t, 0, len(everyoneGrants))
})
} }
func TestManager_Token_Valid(t *testing.T) { func TestManager_Token_Valid(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
a := newTestManager(t, newStore, PermissionDenyAll)
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false)) require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
u, err := a.User("ben") u, err := a.User("ben")
@@ -557,10 +627,12 @@ func TestManager_Token_Valid(t *testing.T) {
tokens, err = a.Tokens(u.ID) tokens, err = a.Tokens(u.ID)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, 0, len(tokens)) require.Equal(t, 0, len(tokens))
})
} }
func TestManager_Token_Invalid(t *testing.T) { func TestManager_Token_Invalid(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
a := newTestManager(t, newStore, PermissionDenyAll)
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false)) require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
u, err := a.AuthenticateToken(strings.Repeat("x", 32)) // 32 == token length u, err := a.AuthenticateToken(strings.Repeat("x", 32)) // 32 == token length
@@ -570,16 +642,20 @@ func TestManager_Token_Invalid(t *testing.T) {
u, err = a.AuthenticateToken("not long enough anyway") u, err = a.AuthenticateToken("not long enough anyway")
require.Nil(t, u) require.Nil(t, u)
require.Equal(t, ErrUnauthenticated, err) require.Equal(t, ErrUnauthenticated, err)
})
} }
func TestManager_Token_NotFound(t *testing.T) { func TestManager_Token_NotFound(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
a := newTestManager(t, newStore, PermissionDenyAll)
_, err := a.Token("u_bla", "notfound") _, err := a.Token("u_bla", "notfound")
require.Equal(t, ErrTokenNotFound, err) require.Equal(t, ErrTokenNotFound, err)
})
} }
func TestManager_Token_Expire(t *testing.T) { func TestManager_Token_Expire(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
a := newTestManager(t, newStore, PermissionDenyAll)
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false)) require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
u, err := a.User("ben") u, err := a.User("ben")
@@ -604,30 +680,33 @@ func TestManager_Token_Expire(t *testing.T) {
_, err = a.AuthenticateToken(token2.Value) _, err = a.AuthenticateToken(token2.Value)
require.Nil(t, err) require.Nil(t, err)
// Modify token expiration in database // Expire token1 via the API
_, err = testDB(a).Exec("UPDATE user_token SET expires = 1 WHERE token = ?", token1.Value) _, err = a.ChangeToken(u.ID, token1.Value, nil, util.Time(time.Unix(1, 0)))
require.Nil(t, err) require.Nil(t, err)
// Now token1 shouldn't work anymore // Now token1 shouldn't work anymore
_, err = a.AuthenticateToken(token1.Value) _, err = a.AuthenticateToken(token1.Value)
require.Equal(t, ErrUnauthenticated, err) require.Equal(t, ErrUnauthenticated, err)
result, err := testDB(a).Query("SELECT * from user_token WHERE token = ?", token1.Value) // But the token row should still exist
tokens, err := a.Tokens(u.ID)
require.Nil(t, err) require.Nil(t, err)
require.True(t, result.Next()) // Still a matching row require.Equal(t, token1.Value, tokens[0].Value)
require.Nil(t, result.Close()) require.Equal(t, 2, len(tokens))
// Expire tokens and check database rows // Expire tokens and check that token1 is gone
require.Nil(t, a.RemoveExpiredTokens()) require.Nil(t, a.RemoveExpiredTokens())
result, err = testDB(a).Query("SELECT * from user_token WHERE token = ?", token1.Value) tokens, err = a.Tokens(u.ID)
require.Nil(t, err) require.Nil(t, err)
require.False(t, result.Next()) // No matching row! require.Equal(t, 1, len(tokens))
require.Nil(t, result.Close()) require.Equal(t, token2.Value, tokens[0].Value)
})
} }
func TestManager_Token_Extend(t *testing.T) { func TestManager_Token_Extend(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
a := newTestManager(t, newStore, PermissionDenyAll)
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false)) require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
// Try to extend token for user without token // Try to extend token for user without token
@@ -651,12 +730,13 @@ func TestManager_Token_Extend(t *testing.T) {
require.Equal(t, "changed label", extendedToken.Label) require.Equal(t, "changed label", extendedToken.Label)
require.True(t, token.Expires.Unix() < extendedToken.Expires.Unix()) require.True(t, token.Expires.Unix() < extendedToken.Expires.Unix())
require.True(t, time.Now().Add(99*time.Hour).Unix() < extendedToken.Expires.Unix()) require.True(t, time.Now().Add(99*time.Hour).Unix() < extendedToken.Expires.Unix())
})
} }
func TestManager_Token_MaxCount_AutoDelete(t *testing.T) { func TestManager_Token_MaxCount_AutoDelete(t *testing.T) {
// Tests that tokens are automatically deleted when the maximum number of tokens is reached // Tests that tokens are automatically deleted when the maximum number of tokens is reached
forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
a := newTestManager(t, PermissionDenyAll) a := newTestManager(t, newStore, PermissionDenyAll)
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false)) require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
require.Nil(t, a.AddUser("phil", "phil", RoleUser, false)) require.Nil(t, a.AddUser("phil", "phil", RoleUser, false))
@@ -688,7 +768,7 @@ func TestManager_Token_MaxCount_AutoDelete(t *testing.T) {
benTokens = append(benTokens, token.Value) benTokens = append(benTokens, token.Value)
// Manually modify expiry date to avoid sorting issues (this is a hack) // Manually modify expiry date to avoid sorting issues (this is a hack)
_, err = testDB(a).Exec(`UPDATE user_token SET expires=? WHERE token=?`, baseTime.Add(time.Duration(i)*time.Minute).Unix(), token.Value) _, err = a.ChangeToken(ben.ID, token.Value, nil, util.Time(baseTime.Add(time.Duration(i)*time.Minute)))
require.Nil(t, err) require.Nil(t, err)
} }
@@ -715,29 +795,24 @@ func TestManager_Token_MaxCount_AutoDelete(t *testing.T) {
require.Equal(t, philTokens[i], userWithToken.Token) require.Equal(t, philTokens[i], userWithToken.Token)
} }
var benCount int benTokensList, err := a.Tokens(ben.ID)
rows, err := testDB(a).Query(`SELECT COUNT(*) FROM user_token WHERE user_id=?`, ben.ID)
require.Nil(t, err) require.Nil(t, err)
require.True(t, rows.Next()) require.Equal(t, 60, len(benTokensList))
require.Nil(t, rows.Scan(&benCount))
require.Equal(t, 60, benCount)
var philCount int philTokensList, err := a.Tokens(phil.ID)
rows, err = testDB(a).Query(`SELECT COUNT(*) FROM user_token WHERE user_id=?`, phil.ID)
require.Nil(t, err) require.Nil(t, err)
require.True(t, rows.Next()) require.Equal(t, 2, len(philTokensList))
require.Nil(t, rows.Scan(&philCount)) })
require.Equal(t, 2, philCount)
} }
func TestManager_EnqueueStats_ResetStats(t *testing.T) { func TestManager_EnqueueStats_ResetStats(t *testing.T) {
filename := filepath.Join(t.TempDir(), "db") forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
conf := &Config{ conf := &Config{
DefaultAccess: PermissionReadWrite, DefaultAccess: PermissionReadWrite,
BcryptCost: bcrypt.MinCost, BcryptCost: bcrypt.MinCost,
QueueWriterInterval: 1500 * time.Millisecond, QueueWriterInterval: 1500 * time.Millisecond,
} }
a := newTestManagerFromStoreConfig(t, filename, conf) a := newTestManagerFromStoreConfig(t, newStore, conf)
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false)) require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
// Baseline: No messages or emails // Baseline: No messages or emails
@@ -775,16 +850,17 @@ func TestManager_EnqueueStats_ResetStats(t *testing.T) {
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, int64(0), u.Stats.Messages) require.Equal(t, int64(0), u.Stats.Messages)
require.Equal(t, int64(0), u.Stats.Emails) require.Equal(t, int64(0), u.Stats.Emails)
})
} }
func TestManager_EnqueueTokenUpdate(t *testing.T) { func TestManager_EnqueueTokenUpdate(t *testing.T) {
filename := filepath.Join(t.TempDir(), "db") forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
conf := &Config{ conf := &Config{
DefaultAccess: PermissionReadWrite, DefaultAccess: PermissionReadWrite,
BcryptCost: bcrypt.MinCost, BcryptCost: bcrypt.MinCost,
QueueWriterInterval: 500 * time.Millisecond, QueueWriterInterval: 500 * time.Millisecond,
} }
a := newTestManagerFromStoreConfig(t, filename, conf) a := newTestManagerFromStoreConfig(t, newStore, conf)
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false)) require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
// Create user and token // Create user and token
@@ -813,16 +889,17 @@ func TestManager_EnqueueTokenUpdate(t *testing.T) {
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, time.Unix(111, 0).UTC().Unix(), token3.LastAccess.Unix()) require.Equal(t, time.Unix(111, 0).UTC().Unix(), token3.LastAccess.Unix())
require.Equal(t, netip.MustParseAddr("1.2.3.3"), token3.LastOrigin) require.Equal(t, netip.MustParseAddr("1.2.3.3"), token3.LastOrigin)
})
} }
func TestManager_ChangeSettings(t *testing.T) { func TestManager_ChangeSettings(t *testing.T) {
filename := filepath.Join(t.TempDir(), "db") forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
conf := &Config{ conf := &Config{
DefaultAccess: PermissionReadWrite, DefaultAccess: PermissionReadWrite,
BcryptCost: bcrypt.MinCost, BcryptCost: bcrypt.MinCost,
QueueWriterInterval: 1500 * time.Millisecond, QueueWriterInterval: 1500 * time.Millisecond,
} }
a := newTestManagerFromStoreConfig(t, filename, conf) a := newTestManagerFromStoreConfig(t, newStore, conf)
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false)) require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
// No settings // No settings
@@ -859,10 +936,12 @@ func TestManager_ChangeSettings(t *testing.T) {
require.Equal(t, "https://ntfy.sh", u.Prefs.Subscriptions[0].BaseURL) require.Equal(t, "https://ntfy.sh", u.Prefs.Subscriptions[0].BaseURL)
require.Equal(t, "mytopic", u.Prefs.Subscriptions[0].Topic) require.Equal(t, "mytopic", u.Prefs.Subscriptions[0].Topic)
require.Equal(t, util.String("My Topic"), u.Prefs.Subscriptions[0].DisplayName) require.Equal(t, util.String("My Topic"), u.Prefs.Subscriptions[0].DisplayName)
})
} }
func TestManager_Tier_Create_Update_List_Delete(t *testing.T) { func TestManager_Tier_Create_Update_List_Delete(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
a := newTestManager(t, newStore, PermissionDenyAll)
// Create tier and user // Create tier and user
require.Nil(t, a.AddTier(&Tier{ require.Nil(t, a.AddTier(&Tier{
@@ -977,10 +1056,12 @@ func TestManager_Tier_Create_Update_List_Delete(t *testing.T) {
require.Equal(t, 1, len(tiers)) require.Equal(t, 1, len(tiers))
require.Equal(t, "pro", tiers[0].Code) require.Equal(t, "pro", tiers[0].Code)
require.Equal(t, "pro", tiers[0].Code) require.Equal(t, "pro", tiers[0].Code)
})
} }
func TestAccount_Tier_Create_With_ID(t *testing.T) { func TestAccount_Tier_Create_With_ID(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
a := newTestManager(t, newStore, PermissionDenyAll)
require.Nil(t, a.AddTier(&Tier{ require.Nil(t, a.AddTier(&Tier{
ID: "ti_123", ID: "ti_123",
@@ -990,10 +1071,12 @@ func TestAccount_Tier_Create_With_ID(t *testing.T) {
ti, err := a.Tier("pro") ti, err := a.Tier("pro")
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, "ti_123", ti.ID) require.Equal(t, "ti_123", ti.ID)
})
} }
func TestManager_Tier_Change_And_Reset(t *testing.T) { func TestManager_Tier_Change_And_Reset(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
a := newTestManager(t, newStore, PermissionDenyAll)
// Create tier and user // Create tier and user
require.Nil(t, a.AddTier(&Tier{ require.Nil(t, a.AddTier(&Tier{
@@ -1027,10 +1110,12 @@ func TestManager_Tier_Change_And_Reset(t *testing.T) {
// Resetting after removing all reservations // Resetting after removing all reservations
require.Nil(t, a.RemoveReservations("phil", "topic1", "topic2", "topic3")) require.Nil(t, a.RemoveReservations("phil", "topic1", "topic2", "topic3"))
require.Nil(t, a.ResetTier("phil")) require.Nil(t, a.ResetTier("phil"))
})
} }
func TestUser_PhoneNumberAddListRemove(t *testing.T) { func TestUser_PhoneNumberAddListRemove(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
a := newTestManager(t, newStore, PermissionDenyAll)
require.Nil(t, a.AddUser("phil", "phil", RoleUser, false)) require.Nil(t, a.AddUser("phil", "phil", RoleUser, false))
phil, err := a.User("phil") phil, err := a.User("phil")
@@ -1052,10 +1137,12 @@ func TestUser_PhoneNumberAddListRemove(t *testing.T) {
require.Nil(t, err) require.Nil(t, err)
require.False(t, rows.Next()) require.False(t, rows.Next())
require.Nil(t, rows.Close()) require.Nil(t, rows.Close())
})
} }
func TestUser_PhoneNumberAdd_Multiple_Users_Same_Number(t *testing.T) { func TestUser_PhoneNumberAdd_Multiple_Users_Same_Number(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
a := newTestManager(t, newStore, PermissionDenyAll)
require.Nil(t, a.AddUser("phil", "phil", RoleUser, false)) require.Nil(t, a.AddUser("phil", "phil", RoleUser, false))
require.Nil(t, a.AddUser("ben", "ben", RoleUser, false)) require.Nil(t, a.AddUser("ben", "ben", RoleUser, false))
@@ -1065,11 +1152,12 @@ func TestUser_PhoneNumberAdd_Multiple_Users_Same_Number(t *testing.T) {
require.Nil(t, err) require.Nil(t, err)
require.Nil(t, a.AddPhoneNumber(phil.ID, "+1234567890")) require.Nil(t, a.AddPhoneNumber(phil.ID, "+1234567890"))
require.Nil(t, a.AddPhoneNumber(ben.ID, "+1234567890")) require.Nil(t, a.AddPhoneNumber(ben.ID, "+1234567890"))
})
} }
func TestManager_Topic_Wildcard_With_Asterisk_Underscore(t *testing.T) { func TestManager_Topic_Wildcard_With_Asterisk_Underscore(t *testing.T) {
f := filepath.Join(t.TempDir(), "user.db") forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
a := newTestManagerFromFile(t, f, "", PermissionDenyAll, DefaultUserPasswordBcryptCost, DefaultUserStatsQueueWriterInterval) a := newTestManager(t, newStore, PermissionDenyAll)
require.Nil(t, a.AllowAccess(Everyone, "*_", PermissionRead)) require.Nil(t, a.AllowAccess(Everyone, "*_", PermissionRead))
require.Nil(t, a.AllowAccess(Everyone, "__*_", PermissionRead)) require.Nil(t, a.AllowAccess(Everyone, "__*_", PermissionRead))
require.Nil(t, a.Authorize(nil, "allowed_", PermissionRead)) require.Nil(t, a.Authorize(nil, "allowed_", PermissionRead))
@@ -1078,20 +1166,22 @@ func TestManager_Topic_Wildcard_With_Asterisk_Underscore(t *testing.T) {
require.Equal(t, ErrUnauthorized, a.Authorize(nil, "notallowed", PermissionRead)) require.Equal(t, ErrUnauthorized, a.Authorize(nil, "notallowed", PermissionRead))
require.Equal(t, ErrUnauthorized, a.Authorize(nil, "_notallowed", PermissionRead)) require.Equal(t, ErrUnauthorized, a.Authorize(nil, "_notallowed", PermissionRead))
require.Equal(t, ErrUnauthorized, a.Authorize(nil, "__notallowed", PermissionRead)) require.Equal(t, ErrUnauthorized, a.Authorize(nil, "__notallowed", PermissionRead))
})
} }
func TestManager_Topic_Wildcard_With_Underscore(t *testing.T) { func TestManager_Topic_Wildcard_With_Underscore(t *testing.T) {
f := filepath.Join(t.TempDir(), "user.db") forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
a := newTestManagerFromFile(t, f, "", PermissionDenyAll, DefaultUserPasswordBcryptCost, DefaultUserStatsQueueWriterInterval) a := newTestManager(t, newStore, PermissionDenyAll)
require.Nil(t, a.AllowAccess(Everyone, "mytopic_", PermissionReadWrite)) require.Nil(t, a.AllowAccess(Everyone, "mytopic_", PermissionReadWrite))
require.Nil(t, a.Authorize(nil, "mytopic_", PermissionRead)) require.Nil(t, a.Authorize(nil, "mytopic_", PermissionRead))
require.Nil(t, a.Authorize(nil, "mytopic_", PermissionWrite)) require.Nil(t, a.Authorize(nil, "mytopic_", PermissionWrite))
require.Equal(t, ErrUnauthorized, a.Authorize(nil, "mytopicX", PermissionRead)) require.Equal(t, ErrUnauthorized, a.Authorize(nil, "mytopicX", PermissionRead))
require.Equal(t, ErrUnauthorized, a.Authorize(nil, "mytopicX", PermissionWrite)) require.Equal(t, ErrUnauthorized, a.Authorize(nil, "mytopicX", PermissionWrite))
})
} }
func TestManager_WithProvisionedUsers(t *testing.T) { func TestManager_WithProvisionedUsers(t *testing.T) {
f := filepath.Join(t.TempDir(), "user.db") forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
conf := &Config{ conf := &Config{
DefaultAccess: PermissionReadWrite, DefaultAccess: PermissionReadWrite,
ProvisionEnabled: true, ProvisionEnabled: true,
@@ -1111,7 +1201,7 @@ func TestManager_WithProvisionedUsers(t *testing.T) {
}, },
}, },
} }
a := newTestManagerFromStoreConfig(t, f, conf) a := newTestManagerFromStoreConfig(t, newStore, conf)
// Manually add user // Manually add user
require.Nil(t, a.AddUser("philmanual", "manual", RoleUser, false)) require.Nil(t, a.AddUser("philmanual", "manual", RoleUser, false))
@@ -1167,7 +1257,7 @@ func TestManager_WithProvisionedUsers(t *testing.T) {
{Value: "tk_u48wqendnkx9er21pqqcadlytbutx", Label: "Another token"}, {Value: "tk_u48wqendnkx9er21pqqcadlytbutx", Label: "Another token"},
}, },
} }
a = newTestManagerFromStoreConfig(t, f, conf) a = newTestManagerFromStoreConfig(t, newStore, conf)
// Check that the provisioned users are there // Check that the provisioned users are there
users, err = a.Users() users, err = a.Users()
@@ -1206,7 +1296,7 @@ func TestManager_WithProvisionedUsers(t *testing.T) {
conf.Users = []*User{} conf.Users = []*User{}
conf.Access = map[string][]*Grant{} conf.Access = map[string][]*Grant{}
conf.Tokens = map[string][]*Token{} conf.Tokens = map[string][]*Token{}
a = newTestManagerFromStoreConfig(t, f, conf) a = newTestManagerFromStoreConfig(t, newStore, conf)
// Check that the provisioned users are all gone // Check that the provisioned users are all gone
users, err = a.Users() users, err = a.Users()
@@ -1225,16 +1315,25 @@ func TestManager_WithProvisionedUsers(t *testing.T) {
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, 0, len(tokens)) require.Equal(t, 0, len(tokens))
var count int // Verify no provisioned data remains
testDB(a).QueryRow("SELECT COUNT(*) FROM user WHERE provisioned = 1").Scan(&count) for _, u := range users {
require.Equal(t, 0, count) require.False(t, u.Provisioned)
testDB(a).QueryRow("SELECT COUNT(*) FROM user_access WHERE provisioned = 1").Scan(&count) userGrants, err := a.Grants(u.Name)
require.Equal(t, 0, count) require.Nil(t, err)
testDB(a).QueryRow("SELECT COUNT(*) FROM user_token WHERE provisioned = 1").Scan(&count) for _, g := range userGrants {
require.False(t, g.Provisioned)
}
userTokens, err := a.Tokens(u.ID)
require.Nil(t, err)
for _, tk := range userTokens {
require.False(t, tk.Provisioned)
}
}
})
} }
func TestManager_UpdateNonProvisionedUsersToProvisionedUsers(t *testing.T) { func TestManager_UpdateNonProvisionedUsersToProvisionedUsers(t *testing.T) {
f := filepath.Join(t.TempDir(), "user.db") forEachBackend(t, func(t *testing.T, newStore newStoreFunc) {
conf := &Config{ conf := &Config{
DefaultAccess: PermissionReadWrite, DefaultAccess: PermissionReadWrite,
ProvisionEnabled: true, ProvisionEnabled: true,
@@ -1245,7 +1344,7 @@ func TestManager_UpdateNonProvisionedUsersToProvisionedUsers(t *testing.T) {
}, },
}, },
} }
a := newTestManagerFromStoreConfig(t, f, conf) a := newTestManagerFromStoreConfig(t, newStore, conf)
// Manually add user // Manually add user
require.Nil(t, a.AddUser("philuser", "manual", RoleUser, false)) require.Nil(t, a.AddUser("philuser", "manual", RoleUser, false))
@@ -1286,7 +1385,7 @@ func TestManager_UpdateNonProvisionedUsersToProvisionedUsers(t *testing.T) {
{TopicPattern: "stats", Permission: PermissionReadWrite}, {TopicPattern: "stats", Permission: PermissionReadWrite},
}, },
} }
a = newTestManagerFromStoreConfig(t, f, conf) a = newTestManagerFromStoreConfig(t, newStore, conf)
// Check that the user was "upgraded" to a provisioned user // Check that the user was "upgraded" to a provisioned user
users, err = a.Users() users, err = a.Users()
@@ -1310,6 +1409,7 @@ func TestManager_UpdateNonProvisionedUsersToProvisionedUsers(t *testing.T) {
grants, err = a.Grants(Everyone) grants, err = a.Grants(Everyone)
require.Nil(t, err) require.Nil(t, err)
require.Empty(t, grants) require.Empty(t, grants)
})
} }
func TestToFromSQLWildcard(t *testing.T) { func TestToFromSQLWildcard(t *testing.T) {
@@ -1568,8 +1668,16 @@ func checkSchemaVersion(t *testing.T, db *sql.DB) {
require.Nil(t, rows.Close()) require.Nil(t, rows.Close())
} }
func newTestManager(t *testing.T, defaultAccess Permission) *Manager { func newTestManager(t *testing.T, newStore newStoreFunc, defaultAccess Permission) *Manager {
return newTestManagerFromFile(t, filepath.Join(t.TempDir(), "user.db"), "", defaultAccess, bcrypt.MinCost, DefaultUserStatsQueueWriterInterval) store := newStore()
a, err := NewManager(store, &Config{
DefaultAccess: defaultAccess,
BcryptCost: bcrypt.MinCost,
QueueWriterInterval: DefaultUserStatsQueueWriterInterval,
})
require.Nil(t, err)
t.Cleanup(func() { a.Close() })
return a
} }
func newTestManagerFromFile(t *testing.T, filename, startupQueries string, defaultAccess Permission, bcryptCost int, statsWriterInterval time.Duration) *Manager { func newTestManagerFromFile(t *testing.T, filename, startupQueries string, defaultAccess Permission, bcryptCost int, statsWriterInterval time.Duration) *Manager {
@@ -1585,14 +1693,14 @@ func newTestManagerFromFile(t *testing.T, filename, startupQueries string, defau
return a return a
} }
func newTestManagerFromStoreConfig(t *testing.T, newStore newStoreFunc, conf *Config) *Manager {
store := newStore()
a, err := NewManager(store, conf)
require.Nil(t, err)
t.Cleanup(func() { a.Close() })
return a
}
func testDB(a *Manager) *sql.DB { func testDB(a *Manager) *sql.DB {
return a.store.(*commonStore).db return a.store.(*commonStore).db
} }
func newTestManagerFromStoreConfig(t *testing.T, filename string, conf *Config) *Manager {
store, err := NewSQLiteStore(filename, "")
require.Nil(t, err)
a, err := NewManager(store, conf)
require.Nil(t, err)
return a
}

View File

@@ -19,6 +19,7 @@ type Store interface {
User(username string) (*User, error) User(username string) (*User, error)
UserByToken(token string) (*User, error) UserByToken(token string) (*User, error)
UserByStripeCustomer(customerID string) (*User, error) UserByStripeCustomer(customerID string) (*User, error)
UserIDByUsername(username string) (string, error)
Users() ([]*User, error) Users() ([]*User, error)
UsersCount() (int64, error) UsersCount() (int64, error)
AddUser(username, hash string, role Role, provisioned bool) error AddUser(username, hash string, role Role, provisioned bool) error
@@ -33,6 +34,7 @@ type Store interface {
ResetTier(username string) error ResetTier(username string) error
UpdateStats(userID string, stats *Stats) error UpdateStats(userID string, stats *Stats) error
ResetStats() error ResetStats() error
// Token operations // Token operations
CreateToken(userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, provisioned bool) (*Token, error) CreateToken(userID, token, label string, lastAccess time.Time, lastOrigin netip.Addr, expires time.Time, provisioned bool) (*Token, error)
Token(userID, token string) (*Token, error) Token(userID, token string) (*Token, error)
@@ -45,6 +47,7 @@ type Store interface {
RemoveExpiredTokens() error RemoveExpiredTokens() error
TokenCount(userID string) (int, error) TokenCount(userID string) (int, error)
RemoveExcessTokens(userID string, maxCount int) error RemoveExcessTokens(userID string, maxCount int) error
// Access operations // Access operations
AuthorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error) AuthorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error)
AllGrants() (map[string][]Grant, error) AllGrants() (map[string][]Grant, error)
@@ -57,6 +60,7 @@ type Store interface {
ReservationsCount(username string) (int64, error) ReservationsCount(username string) (int64, error)
ReservationOwner(topic string) (string, error) ReservationOwner(topic string) (string, error)
OtherAccessCount(username, topic string) (int, error) OtherAccessCount(username, topic string) (int, error)
// Tier operations // Tier operations
AddTier(tier *Tier) error AddTier(tier *Tier) error
UpdateTier(tier *Tier) error UpdateTier(tier *Tier) error
@@ -64,15 +68,14 @@ type Store interface {
Tiers() ([]*Tier, error) Tiers() ([]*Tier, error)
Tier(code string) (*Tier, error) Tier(code string) (*Tier, error)
TierByStripePrice(priceID string) (*Tier, error) TierByStripePrice(priceID string) (*Tier, error)
// Phone operations // Phone operations
PhoneNumbers(userID string) ([]string, error) PhoneNumbers(userID string) ([]string, error)
AddPhoneNumber(userID, phoneNumber string) error AddPhoneNumber(userID, phoneNumber string) error
RemovePhoneNumber(userID, phoneNumber string) error RemovePhoneNumber(userID, phoneNumber string) error
// Billing
// Other stuff
ChangeBilling(username string, billing *Billing) error ChangeBilling(username string, billing *Billing) error
// Internal helpers
UserIDByUsername(username string) (string, error)
// System
Close() error Close() error
} }