Manual refinements
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package webpush
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"net/netip"
|
||||
"time"
|
||||
@@ -9,9 +10,9 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
SubscriptionIDPrefix = "wps_"
|
||||
SubscriptionIDLength = 10
|
||||
SubscriptionEndpointLimitPerSubscriberIP = 10
|
||||
subscriptionIDPrefix = "wps_"
|
||||
subscriptionIDLength = 10
|
||||
subscriptionEndpointLimitPerSubscriberIP = 10
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -50,3 +51,21 @@ func (w *Subscription) Context() log.Context {
|
||||
"web_push_subscription_endpoint": w.Endpoint,
|
||||
}
|
||||
}
|
||||
|
||||
func subscriptionsFromRows(rows *sql.Rows) ([]*Subscription, error) {
|
||||
subscriptions := make([]*Subscription, 0)
|
||||
for rows.Next() {
|
||||
var id, endpoint, auth, p256dh, userID string
|
||||
if err := rows.Scan(&id, &endpoint, &auth, &p256dh, &userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
subscriptions = append(subscriptions, &Subscription{
|
||||
ID: id,
|
||||
Endpoint: endpoint,
|
||||
Auth: auth,
|
||||
P256dh: p256dh,
|
||||
UserID: userID,
|
||||
})
|
||||
}
|
||||
return subscriptions, nil
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package webpush
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
@@ -133,11 +134,11 @@ func (c *PostgresStore) UpsertSubscription(endpoint string, auth, p256dh, userID
|
||||
// Read existing subscription ID for endpoint (or create new ID)
|
||||
var subscriptionID string
|
||||
err = tx.QueryRow(pgSelectSubscriptionIDByEndpoint, endpoint).Scan(&subscriptionID)
|
||||
if err == sql.ErrNoRows {
|
||||
if subscriptionCount >= SubscriptionEndpointLimitPerSubscriberIP {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
if subscriptionCount >= subscriptionEndpointLimitPerSubscriberIP {
|
||||
return ErrWebPushTooManySubscriptions
|
||||
}
|
||||
subscriptionID = util.RandomStringPrefix(SubscriptionIDPrefix, SubscriptionIDLength)
|
||||
subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -166,10 +166,10 @@ func (c *SQLiteStore) UpsertSubscription(endpoint string, auth, p256dh, userID s
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if subscriptionCount >= SubscriptionEndpointLimitPerSubscriberIP {
|
||||
if subscriptionCount >= subscriptionEndpointLimitPerSubscriberIP {
|
||||
return ErrWebPushTooManySubscriptions
|
||||
}
|
||||
subscriptionID = util.RandomStringPrefix(SubscriptionIDPrefix, SubscriptionIDLength)
|
||||
subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return err
|
||||
@@ -262,21 +262,3 @@ func (c *SQLiteStore) SetSubscriptionUpdatedAt(endpoint string, updatedAt int64)
|
||||
func (c *SQLiteStore) Close() error {
|
||||
return c.db.Close()
|
||||
}
|
||||
|
||||
func subscriptionsFromRows(rows *sql.Rows) ([]*Subscription, error) {
|
||||
subscriptions := make([]*Subscription, 0)
|
||||
for rows.Next() {
|
||||
var id, endpoint, auth, p256dh, userID string
|
||||
if err := rows.Scan(&id, &endpoint, &auth, &p256dh, &userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
subscriptions = append(subscriptions, &Subscription{
|
||||
ID: id,
|
||||
Endpoint: endpoint,
|
||||
Auth: auth,
|
||||
P256dh: p256dh,
|
||||
UserID: userID,
|
||||
})
|
||||
}
|
||||
return subscriptions, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user