Manual refinements

This commit is contained in:
binwiederhier
2026-02-16 19:06:45 -05:00
parent a8dcecdb6d
commit bdd20197b3
3 changed files with 28 additions and 26 deletions

View File

@@ -1,6 +1,7 @@
package webpush
import (
"database/sql"
"errors"
"net/netip"
"time"
@@ -9,9 +10,9 @@ import (
)
const (
SubscriptionIDPrefix = "wps_"
SubscriptionIDLength = 10
SubscriptionEndpointLimitPerSubscriberIP = 10
subscriptionIDPrefix = "wps_"
subscriptionIDLength = 10
subscriptionEndpointLimitPerSubscriberIP = 10
)
var (
@@ -50,3 +51,21 @@ func (w *Subscription) Context() log.Context {
"web_push_subscription_endpoint": w.Endpoint,
}
}
func subscriptionsFromRows(rows *sql.Rows) ([]*Subscription, error) {
subscriptions := make([]*Subscription, 0)
for rows.Next() {
var id, endpoint, auth, p256dh, userID string
if err := rows.Scan(&id, &endpoint, &auth, &p256dh, &userID); err != nil {
return nil, err
}
subscriptions = append(subscriptions, &Subscription{
ID: id,
Endpoint: endpoint,
Auth: auth,
P256dh: p256dh,
UserID: userID,
})
}
return subscriptions, nil
}

View File

@@ -2,6 +2,7 @@ package webpush
import (
"database/sql"
"errors"
"net/netip"
"time"
@@ -133,11 +134,11 @@ func (c *PostgresStore) UpsertSubscription(endpoint string, auth, p256dh, userID
// Read existing subscription ID for endpoint (or create new ID)
var subscriptionID string
err = tx.QueryRow(pgSelectSubscriptionIDByEndpoint, endpoint).Scan(&subscriptionID)
if err == sql.ErrNoRows {
if subscriptionCount >= SubscriptionEndpointLimitPerSubscriberIP {
if errors.Is(err, sql.ErrNoRows) {
if subscriptionCount >= subscriptionEndpointLimitPerSubscriberIP {
return ErrWebPushTooManySubscriptions
}
subscriptionID = util.RandomStringPrefix(SubscriptionIDPrefix, SubscriptionIDLength)
subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
} else if err != nil {
return err
}

View File

@@ -166,10 +166,10 @@ func (c *SQLiteStore) UpsertSubscription(endpoint string, auth, p256dh, userID s
return err
}
} else {
if subscriptionCount >= SubscriptionEndpointLimitPerSubscriberIP {
if subscriptionCount >= subscriptionEndpointLimitPerSubscriberIP {
return ErrWebPushTooManySubscriptions
}
subscriptionID = util.RandomStringPrefix(SubscriptionIDPrefix, SubscriptionIDLength)
subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
}
if err := rows.Close(); err != nil {
return err
@@ -262,21 +262,3 @@ func (c *SQLiteStore) SetSubscriptionUpdatedAt(endpoint string, updatedAt int64)
func (c *SQLiteStore) Close() error {
return c.db.Close()
}
func subscriptionsFromRows(rows *sql.Rows) ([]*Subscription, error) {
subscriptions := make([]*Subscription, 0)
for rows.Next() {
var id, endpoint, auth, p256dh, userID string
if err := rows.Scan(&id, &endpoint, &auth, &p256dh, &userID); err != nil {
return nil, err
}
subscriptions = append(subscriptions, &Subscription{
ID: id,
Endpoint: endpoint,
Auth: auth,
P256dh: p256dh,
UserID: userID,
})
}
return subscriptions, nil
}