Re-add tx integrity

This commit is contained in:
binwiederhier
2026-02-26 16:54:56 -05:00
parent a7d5a9c5d8
commit c66fa92341
4 changed files with 79 additions and 44 deletions

1
.gitignore vendored
View File

@@ -7,6 +7,7 @@ build/
server/docs/
server/site/
tools/fbsend/fbsend
tools/pgimport/pgimport
playground/
secrets/
*.iml

View File

@@ -355,7 +355,7 @@ func (a *Manager) MarkUserRemoved(user *User) error {
if !AllowedUsername(user.Name) {
return ErrInvalidArgument
}
return a.store.MarkUserRemoved(user.ID)
return a.store.MarkUserRemoved(user.ID, user.Name)
}
// Users returns a list of users. It always also returns the Everyone user ("*").
@@ -552,13 +552,7 @@ func (a *Manager) AddReservation(username string, topic string, everyone Permiss
if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) {
return ErrInvalidArgument
}
if err := a.store.AllowAccess(username, topic, true, true, username, false); err != nil {
return err
}
if err := a.store.AllowAccess(Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username, false); err != nil {
return err
}
return nil
return a.store.AddReservation(username, topic, everyone)
}
// RemoveReservations deletes the access control entries associated with the given username/topic, as
@@ -572,15 +566,7 @@ func (a *Manager) RemoveReservations(username string, topics ...string) error {
return ErrInvalidArgument
}
}
for _, topic := range topics {
if err := a.store.ResetAccess(username, topic); err != nil {
return err
}
if err := a.store.ResetAccess(Everyone, topic); err != nil {
return err
}
}
return nil
return a.store.RemoveReservations(username, topics...)
}
// DefaultAccess returns the default read/write access if no access control entry matches

View File

@@ -24,7 +24,7 @@ type Store interface {
UsersCount() (int64, error)
AddUser(username, hash string, role Role, provisioned bool) error
RemoveUser(username string) error
MarkUserRemoved(userID string) error
MarkUserRemoved(userID, username string) error
RemoveDeletedUsers() error
ChangePassword(username, hash string) error
ChangeRole(username string, role Role) error
@@ -53,6 +53,8 @@ type Store interface {
AllowAccess(username, topicPattern string, read, write bool, ownerUsername string, provisioned bool) error
ResetAccess(username, topicPattern string) error
ResetAllProvisionedAccess() error
AddReservation(username, topic string, everyone Permission) error
RemoveReservations(username string, topics ...string) error
Reservations(username string) ([]Reservation, error)
HasReservation(username, topic string) (bool, error)
ReservationsCount(username string) (int64, error)
@@ -141,6 +143,12 @@ type storeQueries struct {
updateBilling string
}
// execer is satisfied by both *sql.DB and *sql.Tx, allowing helper methods
// to be used both standalone and within a transaction.
type execer interface {
Exec(query string, args ...any) (sql.Result, error)
}
// commonStore implements store operations that work across database backends
type commonStore struct {
db *sql.DB
@@ -259,22 +267,13 @@ func (s *commonStore) RemoveUser(username string) error {
}
// MarkUserRemoved sets the deleted flag on the user, and deletes all access tokens
func (s *commonStore) MarkUserRemoved(userID string) error {
func (s *commonStore) MarkUserRemoved(userID, username string) error {
tx, err := s.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
// Get username for deleteUserAccess query (must be inside tx for consistency)
rows, err := tx.Query(s.queries.selectUserByID, userID)
if err != nil {
return err
}
user, err := s.readUser(rows)
if err != nil {
return err
}
if _, err := tx.Exec(s.queries.deleteUserAccess, user.Name, user.Name); err != nil {
if _, err := tx.Exec(s.queries.deleteUserAccess, username, username); err != nil {
return err
}
if _, err := tx.Exec(s.queries.deleteAllToken, userID); err != nil {
@@ -413,12 +412,12 @@ func (s *commonStore) readUser(rows *sql.Rows) (*User, error) {
Calls: calls,
},
Billing: &Billing{
StripeCustomerID: stripeCustomerID.String,
StripeSubscriptionID: stripeSubscriptionID.String,
StripeSubscriptionStatus: payments.SubscriptionStatus(stripeSubscriptionStatus.String),
StripeSubscriptionInterval: payments.PriceRecurringInterval(stripeSubscriptionInterval.String),
StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0),
StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0),
StripeCustomerID: stripeCustomerID.String, // May be empty
StripeSubscriptionID: stripeSubscriptionID.String, // May be empty
StripeSubscriptionStatus: payments.SubscriptionStatus(stripeSubscriptionStatus.String), // May be empty
StripeSubscriptionInterval: payments.PriceRecurringInterval(stripeSubscriptionInterval.String), // May be empty
StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0), // May be zero
StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0), // May be zero
},
Deleted: deleted.Valid,
}
@@ -426,6 +425,7 @@ func (s *commonStore) readUser(rows *sql.Rows) (*User, error) {
return nil, err
}
if tierCode.Valid {
// See readTier() when this is changed!
user.Tier = &Tier{
ID: tierID.String,
Code: tierCode.String,
@@ -439,8 +439,8 @@ func (s *commonStore) readUser(rows *sql.Rows) (*User, error) {
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
StripeMonthlyPriceID: stripeMonthlyPriceID.String,
StripeYearlyPriceID: stripeYearlyPriceID.String,
StripeMonthlyPriceID: stripeMonthlyPriceID.String, // May be empty
StripeYearlyPriceID: stripeYearlyPriceID.String, // May be empty
}
}
return user, nil
@@ -612,6 +612,10 @@ func (s *commonStore) readToken(rows *sql.Rows) (*Token, error) {
// AuthorizeTopicAccess returns the read/write permissions for the given username and topic.
// The found return value indicates whether an ACL entry was found at all.
//
// - The query may return two rows (one for everyone, and one for the user), but prioritizes the user.
// - Furthermore, the query prioritizes more specific permissions (longer!) over more generic ones, e.g. "test*" > "*"
// - It also prioritizes write permissions over read permissions
func (s *commonStore) AuthorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error) {
rows, err := s.db.Query(s.queries.selectTopicPerms, Everyone, usernameOrEveryone, topic)
if err != nil {
@@ -689,10 +693,12 @@ func (s *commonStore) AllowAccess(username, topicPattern string, read, write boo
} else if !AllowedTopicPattern(topicPattern) {
return ErrInvalidArgument
}
if _, err := s.db.Exec(s.queries.upsertUserAccess, username, toSQLWildcard(topicPattern), read, write, ownerUsername, ownerUsername, provisioned); err != nil {
return err
}
return nil
return s.allowAccessTx(s.db, username, topicPattern, read, write, ownerUsername, provisioned)
}
func (s *commonStore) allowAccessTx(tx execer, username, topicPattern string, read, write bool, ownerUsername string, provisioned bool) error {
_, err := tx.Exec(s.queries.upsertUserAccess, username, toSQLWildcard(topicPattern), read, write, ownerUsername, ownerUsername, provisioned)
return err
}
// ResetAccess removes an access control list entry
@@ -709,7 +715,11 @@ func (s *commonStore) ResetAccess(username, topicPattern string) error {
_, err := s.db.Exec(s.queries.deleteUserAccess, username, username)
return err
}
_, err := s.db.Exec(s.queries.deleteTopicAccess, username, username, toSQLWildcard(topicPattern))
return s.resetTopicAccessTx(s.db, username, topicPattern)
}
func (s *commonStore) resetTopicAccessTx(tx execer, username, topicPattern string) error {
_, err := tx.Exec(s.queries.deleteTopicAccess, username, username, toSQLWildcard(topicPattern))
return err
}
@@ -721,6 +731,44 @@ func (s *commonStore) ResetAllProvisionedAccess() error {
return nil
}
// AddReservation creates two access control entries for the given topic: one with full read/write
// access for the given user, and one for Everyone with the given permission. Both entries are
// created atomically in a single transaction.
func (s *commonStore) AddReservation(username, topic string, everyone Permission) error {
tx, err := s.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if err := s.allowAccessTx(tx, username, topic, true, true, username, false); err != nil {
return err
}
if err := s.allowAccessTx(tx, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username, false); err != nil {
return err
}
return tx.Commit()
}
// RemoveReservations deletes the access control entries associated with the given username/topic,
// as well as all entries with Everyone/topic. All deletions are performed atomically in a single
// transaction.
func (s *commonStore) RemoveReservations(username string, topics ...string) error {
tx, err := s.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
for _, topic := range topics {
if err := s.resetTopicAccessTx(tx, username, topic); err != nil {
return err
}
if err := s.resetTopicAccessTx(tx, Everyone, topic); err != nil {
return err
}
}
return tx.Commit()
}
// Reservations returns all user-owned topics, and the associated everyone-access
func (s *commonStore) Reservations(username string) ([]Reservation, error) {
rows, err := s.db.Query(s.queries.selectUserReservations, Everyone, username)

View File

@@ -656,7 +656,7 @@ func TestStoreMarkUserRemoved(t *testing.T) {
u, err := store.User("phil")
require.Nil(t, err)
require.Nil(t, store.MarkUserRemoved(u.ID))
require.Nil(t, store.MarkUserRemoved(u.ID, u.Name))
u2, err := store.User("phil")
require.Nil(t, err)
@@ -670,7 +670,7 @@ func TestStoreRemoveDeletedUsers(t *testing.T) {
u, err := store.User("phil")
require.Nil(t, err)
require.Nil(t, store.MarkUserRemoved(u.ID))
require.Nil(t, store.MarkUserRemoved(u.ID, u.Name))
// RemoveDeletedUsers only removes users past the hard-delete duration (7 days).
// Immediately after marking, the user should still exist.