Re-org
This commit is contained in:
@@ -3,14 +3,15 @@ package user
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
"net/netip"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
const minBcryptTimingMillis = int64(40) // Ideally should be >100ms, but this should also run on a Raspberry Pi without massive resources
|
||||
@@ -1563,7 +1564,7 @@ func checkSchemaVersion(t *testing.T, db *sql.DB) {
|
||||
|
||||
var schemaVersion int
|
||||
require.Nil(t, rows.Scan(&schemaVersion))
|
||||
require.Equal(t, currentSchemaVersion, schemaVersion)
|
||||
require.Equal(t, sqliteCurrentSchemaVersion, schemaVersion)
|
||||
require.Nil(t, rows.Close())
|
||||
}
|
||||
|
||||
|
||||
@@ -2,87 +2,12 @@ package user
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
|
||||
)
|
||||
|
||||
// PostgreSQL schema and queries
|
||||
// PostgreSQL queries
|
||||
const (
|
||||
postgresCreateTablesQueries = `
|
||||
CREATE TABLE IF NOT EXISTS tier (
|
||||
id TEXT PRIMARY KEY,
|
||||
code TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
messages_limit BIGINT NOT NULL,
|
||||
messages_expiry_duration BIGINT NOT NULL,
|
||||
emails_limit BIGINT NOT NULL,
|
||||
calls_limit BIGINT NOT NULL,
|
||||
reservations_limit BIGINT NOT NULL,
|
||||
attachment_file_size_limit BIGINT NOT NULL,
|
||||
attachment_total_size_limit BIGINT NOT NULL,
|
||||
attachment_expiry_duration BIGINT NOT NULL,
|
||||
attachment_bandwidth_limit BIGINT NOT NULL,
|
||||
stripe_monthly_price_id TEXT,
|
||||
stripe_yearly_price_id TEXT,
|
||||
UNIQUE(code),
|
||||
UNIQUE(stripe_monthly_price_id),
|
||||
UNIQUE(stripe_yearly_price_id)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS "user" (
|
||||
id TEXT PRIMARY KEY,
|
||||
tier_id TEXT REFERENCES tier(id),
|
||||
user_name TEXT NOT NULL UNIQUE,
|
||||
pass TEXT NOT NULL,
|
||||
role TEXT NOT NULL CHECK (role IN ('anonymous', 'admin', 'user')),
|
||||
prefs JSONB NOT NULL DEFAULT '{}',
|
||||
sync_topic TEXT NOT NULL,
|
||||
provisioned BOOLEAN NOT NULL,
|
||||
stats_messages BIGINT NOT NULL DEFAULT 0,
|
||||
stats_emails BIGINT NOT NULL DEFAULT 0,
|
||||
stats_calls BIGINT NOT NULL DEFAULT 0,
|
||||
stripe_customer_id TEXT UNIQUE,
|
||||
stripe_subscription_id TEXT UNIQUE,
|
||||
stripe_subscription_status TEXT,
|
||||
stripe_subscription_interval TEXT,
|
||||
stripe_subscription_paid_until BIGINT,
|
||||
stripe_subscription_cancel_at BIGINT,
|
||||
created BIGINT NOT NULL,
|
||||
deleted BIGINT
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS user_access (
|
||||
user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
|
||||
topic TEXT NOT NULL,
|
||||
read BOOLEAN NOT NULL,
|
||||
write BOOLEAN NOT NULL,
|
||||
owner_user_id TEXT REFERENCES "user"(id) ON DELETE CASCADE,
|
||||
provisioned BOOLEAN NOT NULL,
|
||||
PRIMARY KEY (user_id, topic)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS user_token (
|
||||
user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
|
||||
token TEXT NOT NULL UNIQUE,
|
||||
label TEXT NOT NULL,
|
||||
last_access BIGINT NOT NULL,
|
||||
last_origin TEXT NOT NULL,
|
||||
expires BIGINT NOT NULL,
|
||||
provisioned BOOLEAN NOT NULL,
|
||||
PRIMARY KEY (user_id, token)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS user_phone (
|
||||
user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
|
||||
phone_number TEXT NOT NULL,
|
||||
PRIMARY KEY (user_id, phone_number)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS schema_version (
|
||||
store TEXT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
INSERT INTO "user" (id, user_name, pass, role, sync_topic, provisioned, created)
|
||||
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, EXTRACT(EPOCH FROM NOW())::BIGINT)
|
||||
ON CONFLICT (id) DO NOTHING;
|
||||
`
|
||||
|
||||
// User queries
|
||||
postgresSelectUserByID = `
|
||||
SELECT u.id, u.user_name, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, u.deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||
@@ -279,15 +204,6 @@ const (
|
||||
SET stripe_customer_id = $1, stripe_subscription_id = $2, stripe_subscription_status = $3, stripe_subscription_interval = $4, stripe_subscription_paid_until = $5, stripe_subscription_cancel_at = $6
|
||||
WHERE user_name = $7
|
||||
`
|
||||
|
||||
// Schema version queries
|
||||
postgresSelectSchemaVersionExists = `SELECT EXISTS (
|
||||
SELECT FROM information_schema.tables
|
||||
WHERE table_schema = current_schema()
|
||||
AND table_name = 'schema_version'
|
||||
)`
|
||||
postgresSelectSchemaVersion = `SELECT version FROM schema_version WHERE store = 'user'`
|
||||
postgresInsertSchemaVersion = `INSERT INTO schema_version (store, version) VALUES ('user', $1)`
|
||||
)
|
||||
|
||||
// NewPostgresStore creates a new PostgreSQL-backed user store
|
||||
@@ -299,126 +215,78 @@ func NewPostgresStore(dsn string) (Store, error) {
|
||||
if err := db.Ping(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := setupPostgresDB(db); err != nil {
|
||||
if err := setupPostgres(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &commonStore{
|
||||
db: db,
|
||||
queries: postgresQueries(),
|
||||
db: db,
|
||||
queries: storeQueries{
|
||||
// User queries
|
||||
selectUserByID: postgresSelectUserByID,
|
||||
selectUserByName: postgresSelectUserByName,
|
||||
selectUserByToken: postgresSelectUserByToken,
|
||||
selectUserByStripeID: postgresSelectUserByStripeID,
|
||||
selectUsernames: postgresSelectUsernames,
|
||||
selectUserCount: postgresSelectUserCount,
|
||||
selectUserIDFromUsername: postgresSelectUserIDFromUsername,
|
||||
insertUser: postgresInsertUser,
|
||||
updateUserPass: postgresUpdateUserPass,
|
||||
updateUserRole: postgresUpdateUserRole,
|
||||
updateUserProvisioned: postgresUpdateUserProvisioned,
|
||||
updateUserPrefs: postgresUpdateUserPrefs,
|
||||
updateUserStats: postgresUpdateUserStats,
|
||||
updateUserStatsResetAll: postgresUpdateUserStatsResetAll,
|
||||
updateUserTier: postgresUpdateUserTier,
|
||||
updateUserDeleted: postgresUpdateUserDeleted,
|
||||
deleteUser: postgresDeleteUser,
|
||||
deleteUserTier: postgresDeleteUserTier,
|
||||
deleteUsersMarked: postgresDeleteUsersMarked,
|
||||
|
||||
// Access queries
|
||||
selectTopicPerms: postgresSelectTopicPerms,
|
||||
selectUserAllAccess: postgresSelectUserAllAccess,
|
||||
selectUserAccess: postgresSelectUserAccess,
|
||||
selectUserReservations: postgresSelectUserReservations,
|
||||
selectUserReservationsCount: postgresSelectUserReservationsCount,
|
||||
selectUserReservationsOwner: postgresSelectUserReservationsOwner,
|
||||
selectUserHasReservation: postgresSelectUserHasReservation,
|
||||
selectOtherAccessCount: postgresSelectOtherAccessCount,
|
||||
upsertUserAccess: postgresUpsertUserAccess,
|
||||
deleteUserAccess: postgresDeleteUserAccess,
|
||||
deleteUserAccessProvisioned: postgresDeleteUserAccessProvisioned,
|
||||
deleteTopicAccess: postgresDeleteTopicAccess,
|
||||
deleteAllAccess: postgresDeleteAllAccess,
|
||||
|
||||
// Token queries
|
||||
selectToken: postgresSelectToken,
|
||||
selectTokens: postgresSelectTokens,
|
||||
selectTokenCount: postgresSelectTokenCount,
|
||||
selectAllProvisionedTokens: postgresSelectAllProvisionedTokens,
|
||||
upsertToken: postgresUpsertToken,
|
||||
updateTokenLabel: postgresUpdateTokenLabel,
|
||||
updateTokenExpiry: postgresUpdateTokenExpiry,
|
||||
updateTokenLastAccess: postgresUpdateTokenLastAccess,
|
||||
deleteToken: postgresDeleteToken,
|
||||
deleteProvisionedToken: postgresDeleteProvisionedToken,
|
||||
deleteAllToken: postgresDeleteAllToken,
|
||||
deleteExpiredTokens: postgresDeleteExpiredTokens,
|
||||
deleteExcessTokens: postgresDeleteExcessTokens,
|
||||
|
||||
// Tier queries
|
||||
insertTier: postgresInsertTier,
|
||||
selectTiers: postgresSelectTiers,
|
||||
selectTierByCode: postgresSelectTierByCode,
|
||||
selectTierByPriceID: postgresSelectTierByPriceID,
|
||||
updateTier: postgresUpdateTier,
|
||||
deleteTier: postgresDeleteTier,
|
||||
|
||||
// Phone queries
|
||||
selectPhoneNumbers: postgresSelectPhoneNumbers,
|
||||
insertPhoneNumber: postgresInsertPhoneNumber,
|
||||
deletePhoneNumber: postgresDeletePhoneNumber,
|
||||
|
||||
// Billing queries
|
||||
updateBilling: postgresUpdateBilling,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func postgresQueries() storeQueries {
|
||||
return storeQueries{
|
||||
// User queries
|
||||
selectUserByID: postgresSelectUserByID,
|
||||
selectUserByName: postgresSelectUserByName,
|
||||
selectUserByToken: postgresSelectUserByToken,
|
||||
selectUserByStripeID: postgresSelectUserByStripeID,
|
||||
selectUsernames: postgresSelectUsernames,
|
||||
selectUserCount: postgresSelectUserCount,
|
||||
selectUserIDFromUsername: postgresSelectUserIDFromUsername,
|
||||
insertUser: postgresInsertUser,
|
||||
updateUserPass: postgresUpdateUserPass,
|
||||
updateUserRole: postgresUpdateUserRole,
|
||||
updateUserProvisioned: postgresUpdateUserProvisioned,
|
||||
updateUserPrefs: postgresUpdateUserPrefs,
|
||||
updateUserStats: postgresUpdateUserStats,
|
||||
updateUserStatsResetAll: postgresUpdateUserStatsResetAll,
|
||||
updateUserTier: postgresUpdateUserTier,
|
||||
updateUserDeleted: postgresUpdateUserDeleted,
|
||||
deleteUser: postgresDeleteUser,
|
||||
deleteUserTier: postgresDeleteUserTier,
|
||||
deleteUsersMarked: postgresDeleteUsersMarked,
|
||||
|
||||
// Access queries
|
||||
selectTopicPerms: postgresSelectTopicPerms,
|
||||
selectUserAllAccess: postgresSelectUserAllAccess,
|
||||
selectUserAccess: postgresSelectUserAccess,
|
||||
selectUserReservations: postgresSelectUserReservations,
|
||||
selectUserReservationsCount: postgresSelectUserReservationsCount,
|
||||
selectUserReservationsOwner: postgresSelectUserReservationsOwner,
|
||||
selectUserHasReservation: postgresSelectUserHasReservation,
|
||||
selectOtherAccessCount: postgresSelectOtherAccessCount,
|
||||
upsertUserAccess: postgresUpsertUserAccess,
|
||||
deleteUserAccess: postgresDeleteUserAccess,
|
||||
deleteUserAccessProvisioned: postgresDeleteUserAccessProvisioned,
|
||||
deleteTopicAccess: postgresDeleteTopicAccess,
|
||||
deleteAllAccess: postgresDeleteAllAccess,
|
||||
|
||||
// Token queries
|
||||
selectToken: postgresSelectToken,
|
||||
selectTokens: postgresSelectTokens,
|
||||
selectTokenCount: postgresSelectTokenCount,
|
||||
selectAllProvisionedTokens: postgresSelectAllProvisionedTokens,
|
||||
upsertToken: postgresUpsertToken,
|
||||
updateTokenLabel: postgresUpdateTokenLabel,
|
||||
updateTokenExpiry: postgresUpdateTokenExpiry,
|
||||
updateTokenLastAccess: postgresUpdateTokenLastAccess,
|
||||
deleteToken: postgresDeleteToken,
|
||||
deleteProvisionedToken: postgresDeleteProvisionedToken,
|
||||
deleteAllToken: postgresDeleteAllToken,
|
||||
deleteExpiredTokens: postgresDeleteExpiredTokens,
|
||||
deleteExcessTokens: postgresDeleteExcessTokens,
|
||||
|
||||
// Tier queries
|
||||
insertTier: postgresInsertTier,
|
||||
selectTiers: postgresSelectTiers,
|
||||
selectTierByCode: postgresSelectTierByCode,
|
||||
selectTierByPriceID: postgresSelectTierByPriceID,
|
||||
updateTier: postgresUpdateTier,
|
||||
deleteTier: postgresDeleteTier,
|
||||
|
||||
// Phone queries
|
||||
selectPhoneNumbers: postgresSelectPhoneNumbers,
|
||||
insertPhoneNumber: postgresInsertPhoneNumber,
|
||||
deletePhoneNumber: postgresDeletePhoneNumber,
|
||||
|
||||
// Billing queries
|
||||
updateBilling: postgresUpdateBilling,
|
||||
}
|
||||
}
|
||||
|
||||
func setupPostgresDB(db *sql.DB) error {
|
||||
// Check if schema version table exists
|
||||
var exists bool
|
||||
if err := db.QueryRow(postgresSelectSchemaVersionExists).Scan(&exists); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !exists {
|
||||
// New database, create all tables
|
||||
if _, err := db.Exec(postgresCreateTablesQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(postgresInsertSchemaVersion, currentSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Table exists, check if user store has a row
|
||||
var schemaVersion int
|
||||
err := db.QueryRow(postgresSelectSchemaVersion).Scan(&schemaVersion)
|
||||
if err == sql.ErrNoRows {
|
||||
// schema_version table exists (e.g. created by webpush) but no user row yet
|
||||
if _, err := db.Exec(postgresCreateTablesQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(postgresInsertSchemaVersion, currentSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("cannot determine schema version: %v", err)
|
||||
}
|
||||
|
||||
if schemaVersion > currentSchemaVersion {
|
||||
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, currentSchemaVersion)
|
||||
}
|
||||
|
||||
// Note: PostgreSQL migrations will be added when needed
|
||||
// For now, we only support new installations at the latest schema version
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
113
user/store_postgres_schema.go
Normal file
113
user/store_postgres_schema.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Initial PostgreSQL schema
|
||||
const (
|
||||
postgresCreateTablesQueries = `
|
||||
CREATE TABLE IF NOT EXISTS tier (
|
||||
id TEXT PRIMARY KEY,
|
||||
code TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
messages_limit BIGINT NOT NULL,
|
||||
messages_expiry_duration BIGINT NOT NULL,
|
||||
emails_limit BIGINT NOT NULL,
|
||||
calls_limit BIGINT NOT NULL,
|
||||
reservations_limit BIGINT NOT NULL,
|
||||
attachment_file_size_limit BIGINT NOT NULL,
|
||||
attachment_total_size_limit BIGINT NOT NULL,
|
||||
attachment_expiry_duration BIGINT NOT NULL,
|
||||
attachment_bandwidth_limit BIGINT NOT NULL,
|
||||
stripe_monthly_price_id TEXT,
|
||||
stripe_yearly_price_id TEXT,
|
||||
UNIQUE(code),
|
||||
UNIQUE(stripe_monthly_price_id),
|
||||
UNIQUE(stripe_yearly_price_id)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS "user" (
|
||||
id TEXT PRIMARY KEY,
|
||||
tier_id TEXT REFERENCES tier(id),
|
||||
user_name TEXT NOT NULL UNIQUE,
|
||||
pass TEXT NOT NULL,
|
||||
role TEXT NOT NULL CHECK (role IN ('anonymous', 'admin', 'user')),
|
||||
prefs JSONB NOT NULL DEFAULT '{}',
|
||||
sync_topic TEXT NOT NULL,
|
||||
provisioned BOOLEAN NOT NULL,
|
||||
stats_messages BIGINT NOT NULL DEFAULT 0,
|
||||
stats_emails BIGINT NOT NULL DEFAULT 0,
|
||||
stats_calls BIGINT NOT NULL DEFAULT 0,
|
||||
stripe_customer_id TEXT UNIQUE,
|
||||
stripe_subscription_id TEXT UNIQUE,
|
||||
stripe_subscription_status TEXT,
|
||||
stripe_subscription_interval TEXT,
|
||||
stripe_subscription_paid_until BIGINT,
|
||||
stripe_subscription_cancel_at BIGINT,
|
||||
created BIGINT NOT NULL,
|
||||
deleted BIGINT
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS user_access (
|
||||
user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
|
||||
topic TEXT NOT NULL,
|
||||
read BOOLEAN NOT NULL,
|
||||
write BOOLEAN NOT NULL,
|
||||
owner_user_id TEXT REFERENCES "user"(id) ON DELETE CASCADE,
|
||||
provisioned BOOLEAN NOT NULL,
|
||||
PRIMARY KEY (user_id, topic)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS user_token (
|
||||
user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
|
||||
token TEXT NOT NULL UNIQUE,
|
||||
label TEXT NOT NULL,
|
||||
last_access BIGINT NOT NULL,
|
||||
last_origin TEXT NOT NULL,
|
||||
expires BIGINT NOT NULL,
|
||||
provisioned BOOLEAN NOT NULL,
|
||||
PRIMARY KEY (user_id, token)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS user_phone (
|
||||
user_id TEXT NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
|
||||
phone_number TEXT NOT NULL,
|
||||
PRIMARY KEY (user_id, phone_number)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS schema_version (
|
||||
store TEXT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
INSERT INTO "user" (id, user_name, pass, role, sync_topic, provisioned, created)
|
||||
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, EXTRACT(EPOCH FROM NOW())::BIGINT)
|
||||
ON CONFLICT (id) DO NOTHING;
|
||||
`
|
||||
)
|
||||
|
||||
// Schema table management queries for Postgres
|
||||
const (
|
||||
postgresCurrentSchemaVersion = 6
|
||||
postgresSelectSchemaVersion = `SELECT version FROM schema_version WHERE store = 'user'`
|
||||
postgresInsertSchemaVersion = `INSERT INTO schema_version (store, version) VALUES ('user', $1)`
|
||||
)
|
||||
|
||||
func setupPostgres(db *sql.DB) error {
|
||||
var schemaVersion int
|
||||
err := db.QueryRow(postgresSelectSchemaVersion).Scan(&schemaVersion)
|
||||
if err != nil {
|
||||
return setupNewPostgres(db)
|
||||
}
|
||||
if schemaVersion > postgresCurrentSchemaVersion {
|
||||
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, postgresCurrentSchemaVersion)
|
||||
}
|
||||
// Note: PostgreSQL migrations will be added when needed
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupNewPostgres(db *sql.DB) error {
|
||||
if _, err := db.Exec(postgresCreateTablesQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(postgresInsertSchemaVersion, postgresCurrentSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -2,102 +2,11 @@ package user
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
)
|
||||
|
||||
// SQLite schema and queries
|
||||
const (
|
||||
sqliteCreateTablesQueries = `
|
||||
BEGIN;
|
||||
CREATE TABLE IF NOT EXISTS tier (
|
||||
id TEXT PRIMARY KEY,
|
||||
code TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
messages_limit INT NOT NULL,
|
||||
messages_expiry_duration INT NOT NULL,
|
||||
emails_limit INT NOT NULL,
|
||||
calls_limit INT NOT NULL,
|
||||
reservations_limit INT NOT NULL,
|
||||
attachment_file_size_limit INT NOT NULL,
|
||||
attachment_total_size_limit INT NOT NULL,
|
||||
attachment_expiry_duration INT NOT NULL,
|
||||
attachment_bandwidth_limit INT NOT NULL,
|
||||
stripe_monthly_price_id TEXT,
|
||||
stripe_yearly_price_id TEXT
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_tier_code ON tier (code);
|
||||
CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
|
||||
CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
|
||||
CREATE TABLE IF NOT EXISTS user (
|
||||
id TEXT PRIMARY KEY,
|
||||
tier_id TEXT,
|
||||
user TEXT NOT NULL,
|
||||
pass TEXT NOT NULL,
|
||||
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
|
||||
prefs JSON NOT NULL DEFAULT '{}',
|
||||
sync_topic TEXT NOT NULL,
|
||||
provisioned INT NOT NULL,
|
||||
stats_messages INT NOT NULL DEFAULT (0),
|
||||
stats_emails INT NOT NULL DEFAULT (0),
|
||||
stats_calls INT NOT NULL DEFAULT (0),
|
||||
stripe_customer_id TEXT,
|
||||
stripe_subscription_id TEXT,
|
||||
stripe_subscription_status TEXT,
|
||||
stripe_subscription_interval TEXT,
|
||||
stripe_subscription_paid_until INT,
|
||||
stripe_subscription_cancel_at INT,
|
||||
created INT NOT NULL,
|
||||
deleted INT,
|
||||
FOREIGN KEY (tier_id) REFERENCES tier (id)
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_user ON user (user);
|
||||
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
|
||||
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
|
||||
CREATE TABLE IF NOT EXISTS user_access (
|
||||
user_id TEXT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
read INT NOT NULL,
|
||||
write INT NOT NULL,
|
||||
owner_user_id INT,
|
||||
provisioned INT NOT NULL,
|
||||
PRIMARY KEY (user_id, topic),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS user_token (
|
||||
user_id TEXT NOT NULL,
|
||||
token TEXT NOT NULL,
|
||||
label TEXT NOT NULL,
|
||||
last_access INT NOT NULL,
|
||||
last_origin TEXT NOT NULL,
|
||||
expires INT NOT NULL,
|
||||
provisioned INT NOT NULL,
|
||||
PRIMARY KEY (user_id, token),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_user_token ON user_token (token);
|
||||
CREATE TABLE IF NOT EXISTS user_phone (
|
||||
user_id TEXT NOT NULL,
|
||||
phone_number TEXT NOT NULL,
|
||||
PRIMARY KEY (user_id, phone_number),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||
id INT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created)
|
||||
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, UNIXEPOCH())
|
||||
ON CONFLICT (id) DO NOTHING;
|
||||
COMMIT;
|
||||
`
|
||||
|
||||
sqliteBuiltinStartupQueries = `
|
||||
PRAGMA foreign_keys = ON;
|
||||
`
|
||||
|
||||
// User queries
|
||||
sqliteSelectUserByID = `
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
|
||||
@@ -295,126 +204,70 @@ func NewSQLiteStore(filename, startupQueries string) (Store, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := setupSQLiteDB(db); err != nil {
|
||||
if err := setupSQLite(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := runSQLiteStartupQueries(db, startupQueries); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &commonStore{
|
||||
db: db,
|
||||
queries: sqliteQueries(),
|
||||
db: db,
|
||||
queries: storeQueries{
|
||||
selectUserByID: sqliteSelectUserByID,
|
||||
selectUserByName: sqliteSelectUserByName,
|
||||
selectUserByToken: sqliteSelectUserByToken,
|
||||
selectUserByStripeID: sqliteSelectUserByStripeID,
|
||||
selectUsernames: sqliteSelectUsernames,
|
||||
selectUserCount: sqliteSelectUserCount,
|
||||
selectUserIDFromUsername: sqliteSelectUserIDFromUsername,
|
||||
insertUser: sqliteInsertUser,
|
||||
updateUserPass: sqliteUpdateUserPass,
|
||||
updateUserRole: sqliteUpdateUserRole,
|
||||
updateUserProvisioned: sqliteUpdateUserProvisioned,
|
||||
updateUserPrefs: sqliteUpdateUserPrefs,
|
||||
updateUserStats: sqliteUpdateUserStats,
|
||||
updateUserStatsResetAll: sqliteUpdateUserStatsResetAll,
|
||||
updateUserTier: sqliteUpdateUserTier,
|
||||
updateUserDeleted: sqliteUpdateUserDeleted,
|
||||
deleteUser: sqliteDeleteUser,
|
||||
deleteUserTier: sqliteDeleteUserTier,
|
||||
deleteUsersMarked: sqliteDeleteUsersMarked,
|
||||
selectTopicPerms: sqliteSelectTopicPerms,
|
||||
selectUserAllAccess: sqliteSelectUserAllAccess,
|
||||
selectUserAccess: sqliteSelectUserAccess,
|
||||
selectUserReservations: sqliteSelectUserReservations,
|
||||
selectUserReservationsCount: sqliteSelectUserReservationsCount,
|
||||
selectUserReservationsOwner: sqliteSelectUserReservationsOwner,
|
||||
selectUserHasReservation: sqliteSelectUserHasReservation,
|
||||
selectOtherAccessCount: sqliteSelectOtherAccessCount,
|
||||
upsertUserAccess: sqliteUpsertUserAccess,
|
||||
deleteUserAccess: sqliteDeleteUserAccess,
|
||||
deleteUserAccessProvisioned: sqliteDeleteUserAccessProvisioned,
|
||||
deleteTopicAccess: sqliteDeleteTopicAccess,
|
||||
deleteAllAccess: sqliteDeleteAllAccess,
|
||||
selectToken: sqliteSelectToken,
|
||||
selectTokens: sqliteSelectTokens,
|
||||
selectTokenCount: sqliteSelectTokenCount,
|
||||
selectAllProvisionedTokens: sqliteSelectAllProvisionedTokens,
|
||||
upsertToken: sqliteUpsertToken,
|
||||
updateTokenLabel: sqliteUpdateTokenLabel,
|
||||
updateTokenExpiry: sqliteUpdateTokenExpiry,
|
||||
updateTokenLastAccess: sqliteUpdateTokenLastAccess,
|
||||
deleteToken: sqliteDeleteToken,
|
||||
deleteProvisionedToken: sqliteDeleteProvisionedToken,
|
||||
deleteAllToken: sqliteDeleteAllToken,
|
||||
deleteExpiredTokens: sqliteDeleteExpiredTokens,
|
||||
deleteExcessTokens: sqliteDeleteExcessTokens,
|
||||
insertTier: sqliteInsertTier,
|
||||
selectTiers: sqliteSelectTiers,
|
||||
selectTierByCode: sqliteSelectTierByCode,
|
||||
selectTierByPriceID: sqliteSelectTierByPriceID,
|
||||
updateTier: sqliteUpdateTier,
|
||||
deleteTier: sqliteDeleteTier,
|
||||
selectPhoneNumbers: sqliteSelectPhoneNumbers,
|
||||
insertPhoneNumber: sqliteInsertPhoneNumber,
|
||||
deletePhoneNumber: sqliteDeletePhoneNumber,
|
||||
updateBilling: sqliteUpdateBilling,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func sqliteQueries() storeQueries {
|
||||
return storeQueries{
|
||||
selectUserByID: sqliteSelectUserByID,
|
||||
selectUserByName: sqliteSelectUserByName,
|
||||
selectUserByToken: sqliteSelectUserByToken,
|
||||
selectUserByStripeID: sqliteSelectUserByStripeID,
|
||||
selectUsernames: sqliteSelectUsernames,
|
||||
selectUserCount: sqliteSelectUserCount,
|
||||
selectUserIDFromUsername: sqliteSelectUserIDFromUsername,
|
||||
insertUser: sqliteInsertUser,
|
||||
updateUserPass: sqliteUpdateUserPass,
|
||||
updateUserRole: sqliteUpdateUserRole,
|
||||
updateUserProvisioned: sqliteUpdateUserProvisioned,
|
||||
updateUserPrefs: sqliteUpdateUserPrefs,
|
||||
updateUserStats: sqliteUpdateUserStats,
|
||||
updateUserStatsResetAll: sqliteUpdateUserStatsResetAll,
|
||||
updateUserTier: sqliteUpdateUserTier,
|
||||
updateUserDeleted: sqliteUpdateUserDeleted,
|
||||
deleteUser: sqliteDeleteUser,
|
||||
deleteUserTier: sqliteDeleteUserTier,
|
||||
deleteUsersMarked: sqliteDeleteUsersMarked,
|
||||
selectTopicPerms: sqliteSelectTopicPerms,
|
||||
selectUserAllAccess: sqliteSelectUserAllAccess,
|
||||
selectUserAccess: sqliteSelectUserAccess,
|
||||
selectUserReservations: sqliteSelectUserReservations,
|
||||
selectUserReservationsCount: sqliteSelectUserReservationsCount,
|
||||
selectUserReservationsOwner: sqliteSelectUserReservationsOwner,
|
||||
selectUserHasReservation: sqliteSelectUserHasReservation,
|
||||
selectOtherAccessCount: sqliteSelectOtherAccessCount,
|
||||
upsertUserAccess: sqliteUpsertUserAccess,
|
||||
deleteUserAccess: sqliteDeleteUserAccess,
|
||||
deleteUserAccessProvisioned: sqliteDeleteUserAccessProvisioned,
|
||||
deleteTopicAccess: sqliteDeleteTopicAccess,
|
||||
deleteAllAccess: sqliteDeleteAllAccess,
|
||||
selectToken: sqliteSelectToken,
|
||||
selectTokens: sqliteSelectTokens,
|
||||
selectTokenCount: sqliteSelectTokenCount,
|
||||
selectAllProvisionedTokens: sqliteSelectAllProvisionedTokens,
|
||||
upsertToken: sqliteUpsertToken,
|
||||
updateTokenLabel: sqliteUpdateTokenLabel,
|
||||
updateTokenExpiry: sqliteUpdateTokenExpiry,
|
||||
updateTokenLastAccess: sqliteUpdateTokenLastAccess,
|
||||
deleteToken: sqliteDeleteToken,
|
||||
deleteProvisionedToken: sqliteDeleteProvisionedToken,
|
||||
deleteAllToken: sqliteDeleteAllToken,
|
||||
deleteExpiredTokens: sqliteDeleteExpiredTokens,
|
||||
deleteExcessTokens: sqliteDeleteExcessTokens,
|
||||
insertTier: sqliteInsertTier,
|
||||
selectTiers: sqliteSelectTiers,
|
||||
selectTierByCode: sqliteSelectTierByCode,
|
||||
selectTierByPriceID: sqliteSelectTierByPriceID,
|
||||
updateTier: sqliteUpdateTier,
|
||||
deleteTier: sqliteDeleteTier,
|
||||
selectPhoneNumbers: sqliteSelectPhoneNumbers,
|
||||
insertPhoneNumber: sqliteInsertPhoneNumber,
|
||||
deletePhoneNumber: sqliteDeletePhoneNumber,
|
||||
updateBilling: sqliteUpdateBilling,
|
||||
}
|
||||
}
|
||||
|
||||
func setupSQLiteDB(db *sql.DB) error {
|
||||
rowsSV, err := db.Query(selectSchemaVersionQuery)
|
||||
if err != nil {
|
||||
return setupNewSQLiteDB(db)
|
||||
}
|
||||
defer rowsSV.Close()
|
||||
schemaVersion := 0
|
||||
if !rowsSV.Next() {
|
||||
return fmt.Errorf("cannot determine schema version: database file may be corrupt")
|
||||
}
|
||||
if err := rowsSV.Scan(&schemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
rowsSV.Close()
|
||||
if schemaVersion == currentSchemaVersion {
|
||||
return nil
|
||||
} else if schemaVersion > currentSchemaVersion {
|
||||
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, currentSchemaVersion)
|
||||
}
|
||||
for i := schemaVersion; i < currentSchemaVersion; i++ {
|
||||
fn, ok := migrations[i]
|
||||
if !ok {
|
||||
return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1)
|
||||
} else if err := fn(db); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupNewSQLiteDB(db *sql.DB) error {
|
||||
if _, err := db.Exec(sqliteCreateTablesQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(insertSchemaVersion, currentSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func runSQLiteStartupQueries(db *sql.DB, startupQueries string) error {
|
||||
if _, err := db.Exec(sqliteBuiltinStartupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if startupQueries != "" {
|
||||
if _, err := db.Exec(startupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2,20 +2,116 @@ package user
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
// SQLite migrations
|
||||
// Initial SQLite schema
|
||||
const (
|
||||
currentSchemaVersion = 6
|
||||
insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
|
||||
updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
|
||||
selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
|
||||
sqliteCreateTablesQueries = `
|
||||
BEGIN;
|
||||
CREATE TABLE IF NOT EXISTS tier (
|
||||
id TEXT PRIMARY KEY,
|
||||
code TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
messages_limit INT NOT NULL,
|
||||
messages_expiry_duration INT NOT NULL,
|
||||
emails_limit INT NOT NULL,
|
||||
calls_limit INT NOT NULL,
|
||||
reservations_limit INT NOT NULL,
|
||||
attachment_file_size_limit INT NOT NULL,
|
||||
attachment_total_size_limit INT NOT NULL,
|
||||
attachment_expiry_duration INT NOT NULL,
|
||||
attachment_bandwidth_limit INT NOT NULL,
|
||||
stripe_monthly_price_id TEXT,
|
||||
stripe_yearly_price_id TEXT
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_tier_code ON tier (code);
|
||||
CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
|
||||
CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
|
||||
CREATE TABLE IF NOT EXISTS user (
|
||||
id TEXT PRIMARY KEY,
|
||||
tier_id TEXT,
|
||||
user TEXT NOT NULL,
|
||||
pass TEXT NOT NULL,
|
||||
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
|
||||
prefs JSON NOT NULL DEFAULT '{}',
|
||||
sync_topic TEXT NOT NULL,
|
||||
provisioned INT NOT NULL,
|
||||
stats_messages INT NOT NULL DEFAULT (0),
|
||||
stats_emails INT NOT NULL DEFAULT (0),
|
||||
stats_calls INT NOT NULL DEFAULT (0),
|
||||
stripe_customer_id TEXT,
|
||||
stripe_subscription_id TEXT,
|
||||
stripe_subscription_status TEXT,
|
||||
stripe_subscription_interval TEXT,
|
||||
stripe_subscription_paid_until INT,
|
||||
stripe_subscription_cancel_at INT,
|
||||
created INT NOT NULL,
|
||||
deleted INT,
|
||||
FOREIGN KEY (tier_id) REFERENCES tier (id)
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_user ON user (user);
|
||||
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
|
||||
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
|
||||
CREATE TABLE IF NOT EXISTS user_access (
|
||||
user_id TEXT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
read INT NOT NULL,
|
||||
write INT NOT NULL,
|
||||
owner_user_id INT,
|
||||
provisioned INT NOT NULL,
|
||||
PRIMARY KEY (user_id, topic),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS user_token (
|
||||
user_id TEXT NOT NULL,
|
||||
token TEXT NOT NULL,
|
||||
label TEXT NOT NULL,
|
||||
last_access INT NOT NULL,
|
||||
last_origin TEXT NOT NULL,
|
||||
expires INT NOT NULL,
|
||||
provisioned INT NOT NULL,
|
||||
PRIMARY KEY (user_id, token),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE UNIQUE INDEX idx_user_token ON user_token (token);
|
||||
CREATE TABLE IF NOT EXISTS user_phone (
|
||||
user_id TEXT NOT NULL,
|
||||
phone_number TEXT NOT NULL,
|
||||
PRIMARY KEY (user_id, phone_number),
|
||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||
id INT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created)
|
||||
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, UNIXEPOCH())
|
||||
ON CONFLICT (id) DO NOTHING;
|
||||
COMMIT;
|
||||
`
|
||||
)
|
||||
|
||||
const (
|
||||
sqliteBuiltinStartupQueries = `PRAGMA foreign_keys = ON;`
|
||||
)
|
||||
|
||||
// Schema version table management for SQLite
|
||||
const (
|
||||
sqliteCurrentSchemaVersion = 6
|
||||
sqliteInsertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
|
||||
sqliteUpdateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
|
||||
sqliteSelectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
|
||||
)
|
||||
|
||||
// Schema migrations for SQLite
|
||||
const (
|
||||
// 1 -> 2 (complex migration!)
|
||||
migrate1To2CreateTablesQueries = `
|
||||
sqliteMigrate1To2CreateTablesQueries = `
|
||||
ALTER TABLE user RENAME TO user_old;
|
||||
CREATE TABLE IF NOT EXISTS tier (
|
||||
id TEXT PRIMARY KEY,
|
||||
@@ -83,12 +179,12 @@ const (
|
||||
VALUES ('u_everyone', '*', '', 'anonymous', '', UNIXEPOCH())
|
||||
ON CONFLICT (id) DO NOTHING;
|
||||
`
|
||||
migrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old`
|
||||
migrate1To2InsertUserNoTx = `
|
||||
sqliteMigrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old`
|
||||
sqliteMigrate1To2InsertUserNoTx = `
|
||||
INSERT INTO user (id, user, pass, role, sync_topic, created)
|
||||
SELECT ?, user, pass, role, ?, UNIXEPOCH() FROM user_old WHERE user = ?
|
||||
`
|
||||
migrate1To2InsertFromOldTablesAndDropNoTx = `
|
||||
sqliteMigrate1To2InsertFromOldTablesAndDropNoTx = `
|
||||
INSERT INTO user_access (user_id, topic, read, write)
|
||||
SELECT u.id, a.topic, a.read, a.write
|
||||
FROM user u
|
||||
@@ -99,7 +195,7 @@ const (
|
||||
`
|
||||
|
||||
// 2 -> 3
|
||||
migrate2To3UpdateQueries = `
|
||||
sqliteMigrate2To3UpdateQueries = `
|
||||
ALTER TABLE user ADD COLUMN stripe_subscription_interval TEXT;
|
||||
ALTER TABLE tier RENAME COLUMN stripe_price_id TO stripe_monthly_price_id;
|
||||
ALTER TABLE tier ADD COLUMN stripe_yearly_price_id TEXT;
|
||||
@@ -109,7 +205,7 @@ const (
|
||||
`
|
||||
|
||||
// 3 -> 4
|
||||
migrate3To4UpdateQueries = `
|
||||
sqliteMigrate3To4UpdateQueries = `
|
||||
ALTER TABLE tier ADD COLUMN calls_limit INT NOT NULL DEFAULT (0);
|
||||
ALTER TABLE user ADD COLUMN stats_calls INT NOT NULL DEFAULT (0);
|
||||
CREATE TABLE IF NOT EXISTS user_phone (
|
||||
@@ -121,12 +217,12 @@ const (
|
||||
`
|
||||
|
||||
// 4 -> 5
|
||||
migrate4To5UpdateQueries = `
|
||||
sqliteMigrate4To5UpdateQueries = `
|
||||
UPDATE user_access SET topic = REPLACE(topic, '_', '\_');
|
||||
`
|
||||
|
||||
// 5 -> 6
|
||||
migrate5To6UpdateQueries = `
|
||||
sqliteMigrate5To6UpdateQueries = `
|
||||
PRAGMA foreign_keys=off;
|
||||
|
||||
-- Alter user table: Add provisioned column
|
||||
@@ -221,16 +317,60 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
migrations = map[int]func(db *sql.DB) error{
|
||||
1: migrateFrom1,
|
||||
2: migrateFrom2,
|
||||
3: migrateFrom3,
|
||||
4: migrateFrom4,
|
||||
5: migrateFrom5,
|
||||
sqliteMigrations = map[int]func(db *sql.DB) error{
|
||||
1: sqliteMigrateFrom1,
|
||||
2: sqliteMigrateFrom2,
|
||||
3: sqliteMigrateFrom3,
|
||||
4: sqliteMigrateFrom4,
|
||||
5: sqliteMigrateFrom5,
|
||||
}
|
||||
)
|
||||
|
||||
func migrateFrom1(db *sql.DB) error {
|
||||
func setupSQLite(db *sql.DB) error {
|
||||
var schemaVersion int
|
||||
err := db.QueryRow(sqliteSelectSchemaVersionQuery).Scan(&schemaVersion)
|
||||
if err != nil {
|
||||
return setupNewSQLite(db)
|
||||
}
|
||||
if schemaVersion == sqliteCurrentSchemaVersion {
|
||||
return nil
|
||||
} else if schemaVersion > sqliteCurrentSchemaVersion {
|
||||
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, sqliteCurrentSchemaVersion)
|
||||
}
|
||||
for i := schemaVersion; i < sqliteCurrentSchemaVersion; i++ {
|
||||
fn, ok := sqliteMigrations[i]
|
||||
if !ok {
|
||||
return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1)
|
||||
} else if err := fn(db); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupNewSQLite(db *sql.DB) error {
|
||||
if _, err := db.Exec(sqliteCreateTablesQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(sqliteInsertSchemaVersion, sqliteCurrentSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func runSQLiteStartupQueries(db *sql.DB, startupQueries string) error {
|
||||
if _, err := db.Exec(sqliteBuiltinStartupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if startupQueries != "" {
|
||||
if _, err := db.Exec(startupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom1(db *sql.DB) error {
|
||||
log.Tag(tag).Info("Migrating user database schema: from 1 to 2")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
@@ -238,11 +378,11 @@ func migrateFrom1(db *sql.DB) error {
|
||||
}
|
||||
defer tx.Rollback()
|
||||
// Rename user -> user_old, and create new tables
|
||||
if _, err := tx.Exec(migrate1To2CreateTablesQueries); err != nil {
|
||||
if _, err := tx.Exec(sqliteMigrate1To2CreateTablesQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
// Insert users from user_old into new user table, with ID and sync_topic
|
||||
rows, err := tx.Query(migrate1To2SelectAllOldUsernamesNoTx)
|
||||
rows, err := tx.Query(sqliteMigrate1To2SelectAllOldUsernamesNoTx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -261,15 +401,15 @@ func migrateFrom1(db *sql.DB) error {
|
||||
for _, username := range usernames {
|
||||
userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
|
||||
syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
|
||||
if _, err := tx.Exec(migrate1To2InsertUserNoTx, userID, syncTopic, username); err != nil {
|
||||
if _, err := tx.Exec(sqliteMigrate1To2InsertUserNoTx, userID, syncTopic, username); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// Migrate old "access" table to "user_access" and drop "access" and "user_old"
|
||||
if _, err := tx.Exec(migrate1To2InsertFromOldTablesAndDropNoTx); err != nil {
|
||||
if _, err := tx.Exec(sqliteMigrate1To2InsertFromOldTablesAndDropNoTx); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(updateSchemaVersion, 2); err != nil {
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 2); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
@@ -278,65 +418,65 @@ func migrateFrom1(db *sql.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateFrom2(db *sql.DB) error {
|
||||
func sqliteMigrateFrom2(db *sql.DB) error {
|
||||
log.Tag(tag).Info("Migrating user database schema: from 2 to 3")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(migrate2To3UpdateQueries); err != nil {
|
||||
if _, err := tx.Exec(sqliteMigrate2To3UpdateQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(updateSchemaVersion, 3); err != nil {
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 3); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func migrateFrom3(db *sql.DB) error {
|
||||
func sqliteMigrateFrom3(db *sql.DB) error {
|
||||
log.Tag(tag).Info("Migrating user database schema: from 3 to 4")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(migrate3To4UpdateQueries); err != nil {
|
||||
if _, err := tx.Exec(sqliteMigrate3To4UpdateQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(updateSchemaVersion, 4); err != nil {
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 4); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func migrateFrom4(db *sql.DB) error {
|
||||
func sqliteMigrateFrom4(db *sql.DB) error {
|
||||
log.Tag(tag).Info("Migrating user database schema: from 4 to 5")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(migrate4To5UpdateQueries); err != nil {
|
||||
if _, err := tx.Exec(sqliteMigrate4To5UpdateQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(updateSchemaVersion, 5); err != nil {
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 5); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func migrateFrom5(db *sql.DB) error {
|
||||
func sqliteMigrateFrom5(db *sql.DB) error {
|
||||
log.Tag(tag).Info("Migrating user database schema: from 5 to 6")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(migrate5To6UpdateQueries); err != nil {
|
||||
if _, err := tx.Exec(sqliteMigrate5To6UpdateQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(updateSchemaVersion, 6); err != nil {
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 6); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
@@ -2,6 +2,7 @@ package webpush
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
|
||||
)
|
||||
@@ -102,14 +103,13 @@ func NewPostgresStore(dsn string) (Store, error) {
|
||||
}
|
||||
|
||||
func setupPostgresDB(db *sql.DB) error {
|
||||
// If 'schema_version' table does not exist or no webpush row, this must be a new database
|
||||
rows, err := db.Query(pgSelectSchemaVersionQuery)
|
||||
var schemaVersion int
|
||||
err := db.QueryRow(pgSelectSchemaVersionQuery).Scan(&schemaVersion)
|
||||
if err != nil {
|
||||
return setupNewPostgresDB(db)
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return setupNewPostgresDB(db)
|
||||
if schemaVersion > pgCurrentSchemaVersion {
|
||||
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, pgCurrentSchemaVersion)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package webpush
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
)
|
||||
@@ -82,10 +83,10 @@ func NewSQLiteStore(filename, startupQueries string) (Store, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := setupSQLiteWebPushDB(db); err != nil {
|
||||
if err := setupSQLite(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := runSQLiteWebPushStartupQueries(db, startupQueries); err != nil {
|
||||
if err := runSQLiteStartupQueries(db, startupQueries); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &commonStore{
|
||||
@@ -108,16 +109,19 @@ func NewSQLiteStore(filename, startupQueries string) (Store, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func setupSQLiteWebPushDB(db *sql.DB) error {
|
||||
// If 'schemaVersion' table does not exist, this must be a new database
|
||||
rows, err := db.Query(sqliteSelectWebPushSchemaVersionQuery)
|
||||
func setupSQLite(db *sql.DB) error {
|
||||
var schemaVersion int
|
||||
err := db.QueryRow(sqliteSelectWebPushSchemaVersionQuery).Scan(&schemaVersion)
|
||||
if err != nil {
|
||||
return setupNewSQLiteWebPushDB(db)
|
||||
return setupNewSQLite(db)
|
||||
}
|
||||
return rows.Close()
|
||||
if schemaVersion > sqliteCurrentWebPushSchemaVersion {
|
||||
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, sqliteCurrentWebPushSchemaVersion)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupNewSQLiteWebPushDB(db *sql.DB) error {
|
||||
func setupNewSQLite(db *sql.DB) error {
|
||||
if _, err := db.Exec(sqliteCreateWebPushSubscriptionsTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -127,7 +131,7 @@ func setupNewSQLiteWebPushDB(db *sql.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func runSQLiteWebPushStartupQueries(db *sql.DB, startupQueries string) error {
|
||||
func runSQLiteStartupQueries(db *sql.DB, startupQueries string) error {
|
||||
if _, err := db.Exec(startupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user