Manual refinements

This commit is contained in:
binwiederhier
2026-02-16 19:06:45 -05:00
parent a8dcecdb6d
commit bdd20197b3
3 changed files with 28 additions and 26 deletions

View File

@@ -1,6 +1,7 @@
package webpush package webpush
import ( import (
"database/sql"
"errors" "errors"
"net/netip" "net/netip"
"time" "time"
@@ -9,9 +10,9 @@ import (
) )
const ( const (
SubscriptionIDPrefix = "wps_" subscriptionIDPrefix = "wps_"
SubscriptionIDLength = 10 subscriptionIDLength = 10
SubscriptionEndpointLimitPerSubscriberIP = 10 subscriptionEndpointLimitPerSubscriberIP = 10
) )
var ( var (
@@ -50,3 +51,21 @@ func (w *Subscription) Context() log.Context {
"web_push_subscription_endpoint": w.Endpoint, "web_push_subscription_endpoint": w.Endpoint,
} }
} }
func subscriptionsFromRows(rows *sql.Rows) ([]*Subscription, error) {
subscriptions := make([]*Subscription, 0)
for rows.Next() {
var id, endpoint, auth, p256dh, userID string
if err := rows.Scan(&id, &endpoint, &auth, &p256dh, &userID); err != nil {
return nil, err
}
subscriptions = append(subscriptions, &Subscription{
ID: id,
Endpoint: endpoint,
Auth: auth,
P256dh: p256dh,
UserID: userID,
})
}
return subscriptions, nil
}

View File

@@ -2,6 +2,7 @@ package webpush
import ( import (
"database/sql" "database/sql"
"errors"
"net/netip" "net/netip"
"time" "time"
@@ -133,11 +134,11 @@ func (c *PostgresStore) UpsertSubscription(endpoint string, auth, p256dh, userID
// Read existing subscription ID for endpoint (or create new ID) // Read existing subscription ID for endpoint (or create new ID)
var subscriptionID string var subscriptionID string
err = tx.QueryRow(pgSelectSubscriptionIDByEndpoint, endpoint).Scan(&subscriptionID) err = tx.QueryRow(pgSelectSubscriptionIDByEndpoint, endpoint).Scan(&subscriptionID)
if err == sql.ErrNoRows { if errors.Is(err, sql.ErrNoRows) {
if subscriptionCount >= SubscriptionEndpointLimitPerSubscriberIP { if subscriptionCount >= subscriptionEndpointLimitPerSubscriberIP {
return ErrWebPushTooManySubscriptions return ErrWebPushTooManySubscriptions
} }
subscriptionID = util.RandomStringPrefix(SubscriptionIDPrefix, SubscriptionIDLength) subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
} else if err != nil { } else if err != nil {
return err return err
} }

View File

@@ -166,10 +166,10 @@ func (c *SQLiteStore) UpsertSubscription(endpoint string, auth, p256dh, userID s
return err return err
} }
} else { } else {
if subscriptionCount >= SubscriptionEndpointLimitPerSubscriberIP { if subscriptionCount >= subscriptionEndpointLimitPerSubscriberIP {
return ErrWebPushTooManySubscriptions return ErrWebPushTooManySubscriptions
} }
subscriptionID = util.RandomStringPrefix(SubscriptionIDPrefix, SubscriptionIDLength) subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
} }
if err := rows.Close(); err != nil { if err := rows.Close(); err != nil {
return err return err
@@ -262,21 +262,3 @@ func (c *SQLiteStore) SetSubscriptionUpdatedAt(endpoint string, updatedAt int64)
func (c *SQLiteStore) Close() error { func (c *SQLiteStore) Close() error {
return c.db.Close() return c.db.Close()
} }
func subscriptionsFromRows(rows *sql.Rows) ([]*Subscription, error) {
subscriptions := make([]*Subscription, 0)
for rows.Next() {
var id, endpoint, auth, p256dh, userID string
if err := rows.Scan(&id, &endpoint, &auth, &p256dh, &userID); err != nil {
return nil, err
}
subscriptions = append(subscriptions, &Subscription{
ID: id,
Endpoint: endpoint,
Auth: auth,
P256dh: p256dh,
UserID: userID,
})
}
return subscriptions, nil
}