diff --git a/.gitignore b/.gitignore index cf10bc33..3a362286 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ build/ server/docs/ server/site/ tools/fbsend/fbsend +tools/pgimport/pgimport playground/ secrets/ *.iml diff --git a/user/manager.go b/user/manager.go index 2a962be2..34a6d33b 100644 --- a/user/manager.go +++ b/user/manager.go @@ -355,7 +355,7 @@ func (a *Manager) MarkUserRemoved(user *User) error { if !AllowedUsername(user.Name) { return ErrInvalidArgument } - return a.store.MarkUserRemoved(user.ID) + return a.store.MarkUserRemoved(user.ID, user.Name) } // Users returns a list of users. It always also returns the Everyone user ("*"). @@ -552,13 +552,7 @@ func (a *Manager) AddReservation(username string, topic string, everyone Permiss if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) { return ErrInvalidArgument } - if err := a.store.AllowAccess(username, topic, true, true, username, false); err != nil { - return err - } - if err := a.store.AllowAccess(Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username, false); err != nil { - return err - } - return nil + return a.store.AddReservation(username, topic, everyone) } // RemoveReservations deletes the access control entries associated with the given username/topic, as @@ -572,15 +566,7 @@ func (a *Manager) RemoveReservations(username string, topics ...string) error { return ErrInvalidArgument } } - for _, topic := range topics { - if err := a.store.ResetAccess(username, topic); err != nil { - return err - } - if err := a.store.ResetAccess(Everyone, topic); err != nil { - return err - } - } - return nil + return a.store.RemoveReservations(username, topics...) } // DefaultAccess returns the default read/write access if no access control entry matches diff --git a/user/store.go b/user/store.go index fd9a3fbb..7ebb5b58 100644 --- a/user/store.go +++ b/user/store.go @@ -24,7 +24,7 @@ type Store interface { UsersCount() (int64, error) AddUser(username, hash string, role Role, provisioned bool) error RemoveUser(username string) error - MarkUserRemoved(userID string) error + MarkUserRemoved(userID, username string) error RemoveDeletedUsers() error ChangePassword(username, hash string) error ChangeRole(username string, role Role) error @@ -53,6 +53,8 @@ type Store interface { AllowAccess(username, topicPattern string, read, write bool, ownerUsername string, provisioned bool) error ResetAccess(username, topicPattern string) error ResetAllProvisionedAccess() error + AddReservation(username, topic string, everyone Permission) error + RemoveReservations(username string, topics ...string) error Reservations(username string) ([]Reservation, error) HasReservation(username, topic string) (bool, error) ReservationsCount(username string) (int64, error) @@ -141,6 +143,12 @@ type storeQueries struct { updateBilling string } +// execer is satisfied by both *sql.DB and *sql.Tx, allowing helper methods +// to be used both standalone and within a transaction. +type execer interface { + Exec(query string, args ...any) (sql.Result, error) +} + // commonStore implements store operations that work across database backends type commonStore struct { db *sql.DB @@ -259,22 +267,13 @@ func (s *commonStore) RemoveUser(username string) error { } // MarkUserRemoved sets the deleted flag on the user, and deletes all access tokens -func (s *commonStore) MarkUserRemoved(userID string) error { +func (s *commonStore) MarkUserRemoved(userID, username string) error { tx, err := s.db.Begin() if err != nil { return err } defer tx.Rollback() - // Get username for deleteUserAccess query (must be inside tx for consistency) - rows, err := tx.Query(s.queries.selectUserByID, userID) - if err != nil { - return err - } - user, err := s.readUser(rows) - if err != nil { - return err - } - if _, err := tx.Exec(s.queries.deleteUserAccess, user.Name, user.Name); err != nil { + if _, err := tx.Exec(s.queries.deleteUserAccess, username, username); err != nil { return err } if _, err := tx.Exec(s.queries.deleteAllToken, userID); err != nil { @@ -413,12 +412,12 @@ func (s *commonStore) readUser(rows *sql.Rows) (*User, error) { Calls: calls, }, Billing: &Billing{ - StripeCustomerID: stripeCustomerID.String, - StripeSubscriptionID: stripeSubscriptionID.String, - StripeSubscriptionStatus: payments.SubscriptionStatus(stripeSubscriptionStatus.String), - StripeSubscriptionInterval: payments.PriceRecurringInterval(stripeSubscriptionInterval.String), - StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0), - StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0), + StripeCustomerID: stripeCustomerID.String, // May be empty + StripeSubscriptionID: stripeSubscriptionID.String, // May be empty + StripeSubscriptionStatus: payments.SubscriptionStatus(stripeSubscriptionStatus.String), // May be empty + StripeSubscriptionInterval: payments.PriceRecurringInterval(stripeSubscriptionInterval.String), // May be empty + StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0), // May be zero + StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0), // May be zero }, Deleted: deleted.Valid, } @@ -426,6 +425,7 @@ func (s *commonStore) readUser(rows *sql.Rows) (*User, error) { return nil, err } if tierCode.Valid { + // See readTier() when this is changed! user.Tier = &Tier{ ID: tierID.String, Code: tierCode.String, @@ -439,8 +439,8 @@ func (s *commonStore) readUser(rows *sql.Rows) (*User, error) { AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64, AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second, AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64, - StripeMonthlyPriceID: stripeMonthlyPriceID.String, - StripeYearlyPriceID: stripeYearlyPriceID.String, + StripeMonthlyPriceID: stripeMonthlyPriceID.String, // May be empty + StripeYearlyPriceID: stripeYearlyPriceID.String, // May be empty } } return user, nil @@ -612,6 +612,10 @@ func (s *commonStore) readToken(rows *sql.Rows) (*Token, error) { // AuthorizeTopicAccess returns the read/write permissions for the given username and topic. // The found return value indicates whether an ACL entry was found at all. +// +// - The query may return two rows (one for everyone, and one for the user), but prioritizes the user. +// - Furthermore, the query prioritizes more specific permissions (longer!) over more generic ones, e.g. "test*" > "*" +// - It also prioritizes write permissions over read permissions func (s *commonStore) AuthorizeTopicAccess(usernameOrEveryone, topic string) (read, write, found bool, err error) { rows, err := s.db.Query(s.queries.selectTopicPerms, Everyone, usernameOrEveryone, topic) if err != nil { @@ -689,10 +693,12 @@ func (s *commonStore) AllowAccess(username, topicPattern string, read, write boo } else if !AllowedTopicPattern(topicPattern) { return ErrInvalidArgument } - if _, err := s.db.Exec(s.queries.upsertUserAccess, username, toSQLWildcard(topicPattern), read, write, ownerUsername, ownerUsername, provisioned); err != nil { - return err - } - return nil + return s.allowAccessTx(s.db, username, topicPattern, read, write, ownerUsername, provisioned) +} + +func (s *commonStore) allowAccessTx(tx execer, username, topicPattern string, read, write bool, ownerUsername string, provisioned bool) error { + _, err := tx.Exec(s.queries.upsertUserAccess, username, toSQLWildcard(topicPattern), read, write, ownerUsername, ownerUsername, provisioned) + return err } // ResetAccess removes an access control list entry @@ -709,7 +715,11 @@ func (s *commonStore) ResetAccess(username, topicPattern string) error { _, err := s.db.Exec(s.queries.deleteUserAccess, username, username) return err } - _, err := s.db.Exec(s.queries.deleteTopicAccess, username, username, toSQLWildcard(topicPattern)) + return s.resetTopicAccessTx(s.db, username, topicPattern) +} + +func (s *commonStore) resetTopicAccessTx(tx execer, username, topicPattern string) error { + _, err := tx.Exec(s.queries.deleteTopicAccess, username, username, toSQLWildcard(topicPattern)) return err } @@ -721,6 +731,44 @@ func (s *commonStore) ResetAllProvisionedAccess() error { return nil } +// AddReservation creates two access control entries for the given topic: one with full read/write +// access for the given user, and one for Everyone with the given permission. Both entries are +// created atomically in a single transaction. +func (s *commonStore) AddReservation(username, topic string, everyone Permission) error { + tx, err := s.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + if err := s.allowAccessTx(tx, username, topic, true, true, username, false); err != nil { + return err + } + if err := s.allowAccessTx(tx, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username, false); err != nil { + return err + } + return tx.Commit() +} + +// RemoveReservations deletes the access control entries associated with the given username/topic, +// as well as all entries with Everyone/topic. All deletions are performed atomically in a single +// transaction. +func (s *commonStore) RemoveReservations(username string, topics ...string) error { + tx, err := s.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + for _, topic := range topics { + if err := s.resetTopicAccessTx(tx, username, topic); err != nil { + return err + } + if err := s.resetTopicAccessTx(tx, Everyone, topic); err != nil { + return err + } + } + return tx.Commit() +} + // Reservations returns all user-owned topics, and the associated everyone-access func (s *commonStore) Reservations(username string) ([]Reservation, error) { rows, err := s.db.Query(s.queries.selectUserReservations, Everyone, username) diff --git a/user/store_test.go b/user/store_test.go index c5c40a45..6ed88722 100644 --- a/user/store_test.go +++ b/user/store_test.go @@ -656,7 +656,7 @@ func TestStoreMarkUserRemoved(t *testing.T) { u, err := store.User("phil") require.Nil(t, err) - require.Nil(t, store.MarkUserRemoved(u.ID)) + require.Nil(t, store.MarkUserRemoved(u.ID, u.Name)) u2, err := store.User("phil") require.Nil(t, err) @@ -670,7 +670,7 @@ func TestStoreRemoveDeletedUsers(t *testing.T) { u, err := store.User("phil") require.Nil(t, err) - require.Nil(t, store.MarkUserRemoved(u.ID)) + require.Nil(t, store.MarkUserRemoved(u.ID, u.Name)) // RemoveDeletedUsers only removes users past the hard-delete duration (7 days). // Immediately after marking, the user should still exist.