Re-add tx integrity

This commit is contained in:
binwiederhier
2026-02-26 16:54:56 -05:00
parent a7d5a9c5d8
commit c66fa92341
4 changed files with 79 additions and 44 deletions

1
.gitignore vendored
View File

@@ -7,6 +7,7 @@ build/
server/docs/ server/docs/
server/site/ server/site/
tools/fbsend/fbsend tools/fbsend/fbsend
tools/pgimport/pgimport
playground/ playground/
secrets/ secrets/
*.iml *.iml

View File

@@ -355,7 +355,7 @@ func (a *Manager) MarkUserRemoved(user *User) error {
if !AllowedUsername(user.Name) { if !AllowedUsername(user.Name) {
return ErrInvalidArgument return ErrInvalidArgument
} }
return a.store.MarkUserRemoved(user.ID) return a.store.MarkUserRemoved(user.ID, user.Name)
} }
// Users returns a list of users. It always also returns the Everyone user ("*"). // Users returns a list of users. It always also returns the Everyone user ("*").
@@ -552,13 +552,7 @@ func (a *Manager) AddReservation(username string, topic string, everyone Permiss
if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) { if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) {
return ErrInvalidArgument return ErrInvalidArgument
} }
if err := a.store.AllowAccess(username, topic, true, true, username, false); err != nil { return a.store.AddReservation(username, topic, everyone)
return err
}
if err := a.store.AllowAccess(Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username, false); err != nil {
return err
}
return nil
} }
// RemoveReservations deletes the access control entries associated with the given username/topic, as // RemoveReservations deletes the access control entries associated with the given username/topic, as
@@ -572,15 +566,7 @@ func (a *Manager) RemoveReservations(username string, topics ...string) error {
return ErrInvalidArgument return ErrInvalidArgument
} }
} }
for _, topic := range topics { return a.store.RemoveReservations(username, topics...)
if err := a.store.ResetAccess(username, topic); err != nil {
return err
}
if err := a.store.ResetAccess(Everyone, topic); err != nil {
return err
}
}
return nil
} }
// DefaultAccess returns the default read/write access if no access control entry matches // DefaultAccess returns the default read/write access if no access control entry matches

View File

@@ -24,7 +24,7 @@ type Store interface {
UsersCount() (int64, error) UsersCount() (int64, error)
AddUser(username, hash string, role Role, provisioned bool) error AddUser(username, hash string, role Role, provisioned bool) error
RemoveUser(username string) error RemoveUser(username string) error
MarkUserRemoved(userID string) error MarkUserRemoved(userID, username string) error
RemoveDeletedUsers() error RemoveDeletedUsers() error
ChangePassword(username, hash string) error ChangePassword(username, hash string) error
ChangeRole(username string, role Role) error ChangeRole(username string, role Role) error
@@ -53,6 +53,8 @@ type Store interface {
AllowAccess(username, topicPattern string, read, write bool, ownerUsername string, provisioned bool) error AllowAccess(username, topicPattern string, read, write bool, ownerUsername string, provisioned bool) error
ResetAccess(username, topicPattern string) error ResetAccess(username, topicPattern string) error
ResetAllProvisionedAccess() error ResetAllProvisionedAccess() error
AddReservation(username, topic string, everyone Permission) error
RemoveReservations(username string, topics ...string) error
Reservations(username string) ([]Reservation, error) Reservations(username string) ([]Reservation, error)
HasReservation(username, topic string) (bool, error) HasReservation(username, topic string) (bool, error)
ReservationsCount(username string) (int64, error) ReservationsCount(username string) (int64, error)
@@ -141,6 +143,12 @@ type storeQueries struct {
updateBilling string updateBilling string
} }
// execer is satisfied by both *sql.DB and *sql.Tx, allowing helper methods
// to be used both standalone and within a transaction.
type execer interface {
Exec(query string, args ...any) (sql.Result, error)
}
// commonStore implements store operations that work across database backends // commonStore implements store operations that work across database backends
type commonStore struct { type commonStore struct {
db *sql.DB db *sql.DB
@@ -259,22 +267,13 @@ func (s *commonStore) RemoveUser(username string) error {
} }
// MarkUserRemoved sets the deleted flag on the user, and deletes all access tokens // MarkUserRemoved sets the deleted flag on the user, and deletes all access tokens
func (s *commonStore) MarkUserRemoved(userID string) error { func (s *commonStore) MarkUserRemoved(userID, username string) error {
tx, err := s.db.Begin() tx, err := s.db.Begin()
if err != nil { if err != nil {
return err return err
} }
defer tx.Rollback() defer tx.Rollback()
// Get username for deleteUserAccess query (must be inside tx for consistency) if _, err := tx.Exec(s.queries.deleteUserAccess, username, username); err != nil {
rows, err := tx.Query(s.queries.selectUserByID, userID)
if err != nil {
return err
}
user, err := s.readUser(rows)
if err != nil {
return err
}
if _, err := tx.Exec(s.queries.deleteUserAccess, user.Name, user.Name); err != nil {
return err return err
} }
if _, err := tx.Exec(s.queries.deleteAllToken, userID); err != nil { if _, err := tx.Exec(s.queries.deleteAllToken, userID); err != nil {
@@ -413,12 +412,12 @@ func (s *commonStore) readUser(rows *sql.Rows) (*User, error) {
Calls: calls, Calls: calls,
}, },
Billing: &Billing{ Billing: &Billing{
StripeCustomerID: stripeCustomerID.String, StripeCustomerID: stripeCustomerID.String, // May be empty
StripeSubscriptionID: stripeSubscriptionID.String, StripeSubscriptionID: stripeSubscriptionID.String, // May be empty
StripeSubscriptionStatus: payments.SubscriptionStatus(stripeSubscriptionStatus.String), StripeSubscriptionStatus: payments.SubscriptionStatus(stripeSubscriptionStatus.String), // May be empty
StripeSubscriptionInterval: payments.PriceRecurringInterval(stripeSubscriptionInterval.String), StripeSubscriptionInterval: payments.PriceRecurringInterval(stripeSubscriptionInterval.String), // May be empty
StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0), StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0), // May be zero
StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0), StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0), // May be zero
}, },
Deleted: deleted.Valid, Deleted: deleted.Valid,
} }
@@ -426,6 +425,7 @@ func (s *commonStore) readUser(rows *sql.Rows) (*User, error) {
return nil, err return nil, err
} }
if tierCode.Valid { if tierCode.Valid {
// See readTier() when this is changed!
user.Tier = &Tier{ user.Tier = &Tier{
ID: tierID.String, ID: tierID.String,
Code: tierCode.String, Code: tierCode.String,
@@ -439,8 +439,8 @@ func (s *commonStore) readUser(rows *sql.Rows) (*User, error) {
AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64, AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second, AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64, AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
StripeMonthlyPriceID: stripeMonthlyPriceID.String, StripeMonthlyPriceID: stripeMonthlyPriceID.String, // May be empty
StripeYearlyPriceID: stripeYearlyPriceID.String, StripeYearlyPriceID: stripeYearlyPriceID.String, // May be empty
} }
} }
return user, nil return user, nil
@@ -612,6 +612,10 @@ func (s *commonStore) readToken(rows *sql.Rows) (*Token, error) {
// AuthorizeTopicAccess returns the read/write permissions for the given username and topic. // AuthorizeTopicAccess returns the read/write permissions for the given username and topic.
// The found return value indicates whether an ACL entry was found at all. // The found return value indicates whether an ACL entry was found at all.
//
// - The query may return two rows (one for everyone, and one for the user), but prioritizes the user.
// - Furthermore, the query prioritizes more specific permissions (longer!) over more generic ones, e.g. "test*" > "*"
// - It also prioritizes write permissions over read permissions
func (s *commonStore) AuthorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error) { func (s *commonStore) AuthorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error) {
rows, err := s.db.Query(s.queries.selectTopicPerms, Everyone, usernameOrEveryone, topic) rows, err := s.db.Query(s.queries.selectTopicPerms, Everyone, usernameOrEveryone, topic)
if err != nil { if err != nil {
@@ -689,10 +693,12 @@ func (s *commonStore) AllowAccess(username, topicPattern string, read, write boo
} else if !AllowedTopicPattern(topicPattern) { } else if !AllowedTopicPattern(topicPattern) {
return ErrInvalidArgument return ErrInvalidArgument
} }
if _, err := s.db.Exec(s.queries.upsertUserAccess, username, toSQLWildcard(topicPattern), read, write, ownerUsername, ownerUsername, provisioned); err != nil { return s.allowAccessTx(s.db, username, topicPattern, read, write, ownerUsername, provisioned)
return err }
}
return nil func (s *commonStore) allowAccessTx(tx execer, username, topicPattern string, read, write bool, ownerUsername string, provisioned bool) error {
_, err := tx.Exec(s.queries.upsertUserAccess, username, toSQLWildcard(topicPattern), read, write, ownerUsername, ownerUsername, provisioned)
return err
} }
// ResetAccess removes an access control list entry // ResetAccess removes an access control list entry
@@ -709,7 +715,11 @@ func (s *commonStore) ResetAccess(username, topicPattern string) error {
_, err := s.db.Exec(s.queries.deleteUserAccess, username, username) _, err := s.db.Exec(s.queries.deleteUserAccess, username, username)
return err return err
} }
_, err := s.db.Exec(s.queries.deleteTopicAccess, username, username, toSQLWildcard(topicPattern)) return s.resetTopicAccessTx(s.db, username, topicPattern)
}
func (s *commonStore) resetTopicAccessTx(tx execer, username, topicPattern string) error {
_, err := tx.Exec(s.queries.deleteTopicAccess, username, username, toSQLWildcard(topicPattern))
return err return err
} }
@@ -721,6 +731,44 @@ func (s *commonStore) ResetAllProvisionedAccess() error {
return nil return nil
} }
// AddReservation creates two access control entries for the given topic: one with full read/write
// access for the given user, and one for Everyone with the given permission. Both entries are
// created atomically in a single transaction.
func (s *commonStore) AddReservation(username, topic string, everyone Permission) error {
tx, err := s.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if err := s.allowAccessTx(tx, username, topic, true, true, username, false); err != nil {
return err
}
if err := s.allowAccessTx(tx, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username, false); err != nil {
return err
}
return tx.Commit()
}
// RemoveReservations deletes the access control entries associated with the given username/topic,
// as well as all entries with Everyone/topic. All deletions are performed atomically in a single
// transaction.
func (s *commonStore) RemoveReservations(username string, topics ...string) error {
tx, err := s.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
for _, topic := range topics {
if err := s.resetTopicAccessTx(tx, username, topic); err != nil {
return err
}
if err := s.resetTopicAccessTx(tx, Everyone, topic); err != nil {
return err
}
}
return tx.Commit()
}
// Reservations returns all user-owned topics, and the associated everyone-access // Reservations returns all user-owned topics, and the associated everyone-access
func (s *commonStore) Reservations(username string) ([]Reservation, error) { func (s *commonStore) Reservations(username string) ([]Reservation, error) {
rows, err := s.db.Query(s.queries.selectUserReservations, Everyone, username) rows, err := s.db.Query(s.queries.selectUserReservations, Everyone, username)

View File

@@ -656,7 +656,7 @@ func TestStoreMarkUserRemoved(t *testing.T) {
u, err := store.User("phil") u, err := store.User("phil")
require.Nil(t, err) require.Nil(t, err)
require.Nil(t, store.MarkUserRemoved(u.ID)) require.Nil(t, store.MarkUserRemoved(u.ID, u.Name))
u2, err := store.User("phil") u2, err := store.User("phil")
require.Nil(t, err) require.Nil(t, err)
@@ -670,7 +670,7 @@ func TestStoreRemoveDeletedUsers(t *testing.T) {
u, err := store.User("phil") u, err := store.User("phil")
require.Nil(t, err) require.Nil(t, err)
require.Nil(t, store.MarkUserRemoved(u.ID)) require.Nil(t, store.MarkUserRemoved(u.ID, u.Name))
// RemoveDeletedUsers only removes users past the hard-delete duration (7 days). // RemoveDeletedUsers only removes users past the hard-delete duration (7 days).
// Immediately after marking, the user should still exist. // Immediately after marking, the user should still exist.