Merge branch 'main' of github.com:binwiederhier/ntfy into html-emails
This commit is contained in:
@@ -1,10 +1,11 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"heckel.io/ntfy/user"
|
||||
"io/fs"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"heckel.io/ntfy/user"
|
||||
)
|
||||
|
||||
// Defines default config settings (excluding limits, see below)
|
||||
@@ -22,6 +23,12 @@ const (
|
||||
DefaultStripePriceCacheDuration = 3 * time.Hour // Time to keep Stripe prices cached in memory before a refresh is needed
|
||||
)
|
||||
|
||||
// Defines default Web Push settings
|
||||
const (
|
||||
DefaultWebPushExpiryWarningDuration = 7 * 24 * time.Hour
|
||||
DefaultWebPushExpiryDuration = 9 * 24 * time.Hour
|
||||
)
|
||||
|
||||
// Defines all global and per-visitor limits
|
||||
// - message size limit: the max number of bytes for a message
|
||||
// - total topic limit: max number of topics overall
|
||||
@@ -92,12 +99,13 @@ type Config struct {
|
||||
KeepaliveInterval time.Duration
|
||||
ManagerInterval time.Duration
|
||||
DisallowedTopics []string
|
||||
WebRootIsApp bool
|
||||
WebRoot string // empty to disable
|
||||
DelayedSenderInterval time.Duration
|
||||
FirebaseKeepaliveInterval time.Duration
|
||||
FirebasePollInterval time.Duration
|
||||
FirebaseQuotaExceededPenaltyDuration time.Duration
|
||||
UpstreamBaseURL string
|
||||
UpstreamAccessToken string
|
||||
SMTPSenderAddr string
|
||||
SMTPSenderUser string
|
||||
SMTPSenderPass string
|
||||
@@ -105,6 +113,12 @@ type Config struct {
|
||||
SMTPServerListen string
|
||||
SMTPServerDomain string
|
||||
SMTPServerAddrPrefix string
|
||||
TwilioAccount string
|
||||
TwilioAuthToken string
|
||||
TwilioPhoneNumber string
|
||||
TwilioCallsBaseURL string
|
||||
TwilioVerifyBaseURL string
|
||||
TwilioVerifyService string
|
||||
MetricsEnable bool
|
||||
MetricsListenHTTP string
|
||||
ProfileListenHTTP string
|
||||
@@ -133,13 +147,19 @@ type Config struct {
|
||||
StripeWebhookKey string
|
||||
StripePriceCacheDuration time.Duration
|
||||
BillingContact string
|
||||
EnableWeb bool
|
||||
EnableSignup bool // Enable creation of accounts via API and UI
|
||||
EnableLogin bool
|
||||
EnableReservations bool // Allow users with role "user" to own/reserve topics
|
||||
EnableMetrics bool
|
||||
AccessControlAllowOrigin string // CORS header field to restrict access from web clients
|
||||
Version string // injected by App
|
||||
WebPushPrivateKey string
|
||||
WebPushPublicKey string
|
||||
WebPushFile string
|
||||
WebPushEmailAddress string
|
||||
WebPushStartupQueries string
|
||||
WebPushExpiryDuration time.Duration
|
||||
WebPushExpiryWarningDuration time.Duration
|
||||
}
|
||||
|
||||
// NewConfig instantiates a default new server config
|
||||
@@ -171,12 +191,13 @@ func NewConfig() *Config {
|
||||
KeepaliveInterval: DefaultKeepaliveInterval,
|
||||
ManagerInterval: DefaultManagerInterval,
|
||||
DisallowedTopics: DefaultDisallowedTopics,
|
||||
WebRootIsApp: false,
|
||||
WebRoot: "/",
|
||||
DelayedSenderInterval: DefaultDelayedSenderInterval,
|
||||
FirebaseKeepaliveInterval: DefaultFirebaseKeepaliveInterval,
|
||||
FirebasePollInterval: DefaultFirebasePollInterval,
|
||||
FirebaseQuotaExceededPenaltyDuration: DefaultFirebaseQuotaExceededPenaltyDuration,
|
||||
UpstreamBaseURL: "",
|
||||
UpstreamAccessToken: "",
|
||||
SMTPSenderAddr: "",
|
||||
SMTPSenderUser: "",
|
||||
SMTPSenderPass: "",
|
||||
@@ -184,6 +205,12 @@ func NewConfig() *Config {
|
||||
SMTPServerListen: "",
|
||||
SMTPServerDomain: "",
|
||||
SMTPServerAddrPrefix: "",
|
||||
TwilioCallsBaseURL: "https://api.twilio.com", // Override for tests
|
||||
TwilioAccount: "",
|
||||
TwilioAuthToken: "",
|
||||
TwilioPhoneNumber: "",
|
||||
TwilioVerifyBaseURL: "https://verify.twilio.com", // Override for tests
|
||||
TwilioVerifyService: "",
|
||||
MessageLimit: DefaultMessageLengthLimit,
|
||||
MinDelay: DefaultMinDelay,
|
||||
MaxDelay: DefaultMaxDelay,
|
||||
@@ -209,11 +236,16 @@ func NewConfig() *Config {
|
||||
StripeWebhookKey: "",
|
||||
StripePriceCacheDuration: DefaultStripePriceCacheDuration,
|
||||
BillingContact: "",
|
||||
EnableWeb: true,
|
||||
EnableSignup: false,
|
||||
EnableLogin: false,
|
||||
EnableReservations: false,
|
||||
AccessControlAllowOrigin: "*",
|
||||
Version: "",
|
||||
WebPushPrivateKey: "",
|
||||
WebPushPublicKey: "",
|
||||
WebPushFile: "",
|
||||
WebPushEmailAddress: "",
|
||||
WebPushExpiryDuration: DefaultWebPushExpiryDuration,
|
||||
WebPushExpiryWarningDuration: DefaultWebPushExpiryWarningDuration,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -106,12 +106,25 @@ var (
|
||||
errHTTPBadRequestNotAPaidUser = &errHTTP{40027, http.StatusBadRequest, "invalid request: not a paid user", "", nil}
|
||||
errHTTPBadRequestBillingRequestInvalid = &errHTTP{40028, http.StatusBadRequest, "invalid request: not a valid billing request", "", nil}
|
||||
errHTTPBadRequestBillingSubscriptionExists = &errHTTP{40029, http.StatusBadRequest, "invalid request: billing subscription already exists", "", nil}
|
||||
errHTTPBadRequestTierInvalid = &errHTTP{40030, http.StatusBadRequest, "invalid request: tier does not exist", "", nil}
|
||||
errHTTPBadRequestUserNotFound = &errHTTP{40031, http.StatusBadRequest, "invalid request: user does not exist", "", nil}
|
||||
errHTTPBadRequestPhoneCallsDisabled = &errHTTP{40032, http.StatusBadRequest, "invalid request: calling is disabled", "https://ntfy.sh/docs/config/#phone-calls", nil}
|
||||
errHTTPBadRequestPhoneNumberInvalid = &errHTTP{40033, http.StatusBadRequest, "invalid request: phone number invalid", "https://ntfy.sh/docs/publish/#phone-calls", nil}
|
||||
errHTTPBadRequestPhoneNumberNotVerified = &errHTTP{40034, http.StatusBadRequest, "invalid request: phone number not verified, or no matching verified numbers found", "https://ntfy.sh/docs/publish/#phone-calls", nil}
|
||||
errHTTPBadRequestAnonymousCallsNotAllowed = &errHTTP{40035, http.StatusBadRequest, "invalid request: anonymous phone calls are not allowed", "https://ntfy.sh/docs/publish/#phone-calls", nil}
|
||||
errHTTPBadRequestPhoneNumberVerifyChannelInvalid = &errHTTP{40036, http.StatusBadRequest, "invalid request: verification channel must be 'sms' or 'call'", "https://ntfy.sh/docs/publish/#phone-calls", nil}
|
||||
errHTTPBadRequestDelayNoCall = &errHTTP{40037, http.StatusBadRequest, "delayed call notifications are not supported", "", nil}
|
||||
errHTTPBadRequestWebPushSubscriptionInvalid = &errHTTP{40038, http.StatusBadRequest, "invalid request: web push payload malformed", "", nil}
|
||||
errHTTPBadRequestWebPushEndpointUnknown = &errHTTP{40039, http.StatusBadRequest, "invalid request: web push endpoint unknown", "", nil}
|
||||
errHTTPBadRequestWebPushTopicCountTooHigh = &errHTTP{40040, http.StatusBadRequest, "invalid request: too many web push topic subscriptions", "", nil}
|
||||
errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", "", nil}
|
||||
errHTTPUnauthorized = &errHTTP{40101, http.StatusUnauthorized, "unauthorized", "https://ntfy.sh/docs/publish/#authentication", nil}
|
||||
errHTTPForbidden = &errHTTP{40301, http.StatusForbidden, "forbidden", "https://ntfy.sh/docs/publish/#authentication", nil}
|
||||
errHTTPConflictUserExists = &errHTTP{40901, http.StatusConflict, "conflict: user already exists", "", nil}
|
||||
errHTTPConflictTopicReserved = &errHTTP{40902, http.StatusConflict, "conflict: access control entry for topic or topic pattern already exists", "", nil}
|
||||
errHTTPConflictSubscriptionExists = &errHTTP{40903, http.StatusConflict, "conflict: topic subscription already exists", "", nil}
|
||||
errHTTPConflictPhoneNumberExists = &errHTTP{40904, http.StatusConflict, "conflict: phone number already exists", "", nil}
|
||||
errHTTPGonePhoneVerificationExpired = &errHTTP{41001, http.StatusGone, "phone number verification expired or does not exist", "", nil}
|
||||
errHTTPEntityTooLargeAttachment = &errHTTP{41301, http.StatusRequestEntityTooLarge, "attachment too large, or bandwidth limit reached", "https://ntfy.sh/docs/publish/#limitations", nil}
|
||||
errHTTPEntityTooLargeMatrixRequest = &errHTTP{41302, http.StatusRequestEntityTooLarge, "Matrix request is larger than the max allowed length", "", nil}
|
||||
errHTTPEntityTooLargeJSONBody = &errHTTP{41303, http.StatusRequestEntityTooLarge, "JSON body too large", "", nil}
|
||||
@@ -124,8 +137,10 @@ var (
|
||||
errHTTPTooManyRequestsLimitReservations = &errHTTP{42907, http.StatusTooManyRequests, "limit reached: too many topic reservations for this user", "", nil}
|
||||
errHTTPTooManyRequestsLimitMessages = &errHTTP{42908, http.StatusTooManyRequests, "limit reached: daily message quota reached", "https://ntfy.sh/docs/publish/#limitations", nil}
|
||||
errHTTPTooManyRequestsLimitAuthFailure = &errHTTP{42909, http.StatusTooManyRequests, "limit reached: too many auth failures", "https://ntfy.sh/docs/publish/#limitations", nil} // FIXME document limit
|
||||
errHTTPTooManyRequestsLimitCalls = &errHTTP{42910, http.StatusTooManyRequests, "limit reached: daily phone call quota reached", "https://ntfy.sh/docs/publish/#limitations", nil}
|
||||
errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", "", nil}
|
||||
errHTTPInternalErrorInvalidPath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid path", "", nil}
|
||||
errHTTPInternalErrorMissingBaseURL = &errHTTP{50003, http.StatusInternalServerError, "internal server error: base-url must be be configured for this feature", "https://ntfy.sh/docs/config/", nil}
|
||||
errHTTPInternalErrorWebPushUnableToPublish = &errHTTP{50004, http.StatusInternalServerError, "internal server error: unable to publish web push message", "", nil}
|
||||
errHTTPInsufficientStorageUnifiedPush = &errHTTP{50701, http.StatusInsufficientStorage, "cannot publish to UnifiedPush topic without previously active subscriber", "", nil}
|
||||
)
|
||||
|
||||
@@ -20,6 +20,7 @@ const (
|
||||
tagFirebase = "firebase"
|
||||
tagSMTP = "smtp" // Receive email
|
||||
tagEmail = "email" // Send email
|
||||
tagTwilio = "twilio"
|
||||
tagFileCache = "file_cache"
|
||||
tagMessageCache = "message_cache"
|
||||
tagStripe = "stripe"
|
||||
@@ -28,6 +29,7 @@ const (
|
||||
tagResetter = "resetter"
|
||||
tagWebsocket = "websocket"
|
||||
tagMatrix = "matrix"
|
||||
tagWebPush = "webpush"
|
||||
)
|
||||
|
||||
var (
|
||||
|
||||
File diff suppressed because one or more lines are too long
1857
server/mailer_emoji_map.json
Normal file
1857
server/mailer_emoji_map.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -17,6 +17,7 @@ import (
|
||||
var (
|
||||
errUnexpectedMessageType = errors.New("unexpected message type")
|
||||
errMessageNotFound = errors.New("message not found")
|
||||
errNoRows = errors.New("no rows found")
|
||||
)
|
||||
|
||||
// Messages cache
|
||||
@@ -44,6 +45,7 @@ const (
|
||||
attachment_deleted INT NOT NULL,
|
||||
sender TEXT NOT NULL,
|
||||
user TEXT NOT NULL,
|
||||
content_type TEXT NOT NULL,
|
||||
encoding TEXT NOT NULL,
|
||||
published INT NOT NULL
|
||||
);
|
||||
@@ -54,46 +56,51 @@ const (
|
||||
CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender);
|
||||
CREATE INDEX IF NOT EXISTS idx_user ON messages (user);
|
||||
CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
|
||||
CREATE TABLE IF NOT EXISTS stats (
|
||||
key TEXT PRIMARY KEY,
|
||||
value INT
|
||||
);
|
||||
INSERT INTO stats (key, value) VALUES ('messages', 0);
|
||||
COMMIT;
|
||||
`
|
||||
insertMessageQuery = `
|
||||
INSERT INTO messages (mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user, encoding, published)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
INSERT INTO messages (mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user, content_type, encoding, published)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
deleteMessageQuery = `DELETE FROM messages WHERE mid = ?`
|
||||
updateMessagesForTopicExpiryQuery = `UPDATE messages SET expires = ? WHERE topic = ?`
|
||||
selectRowIDFromMessageID = `SELECT id FROM messages WHERE mid = ?` // Do not include topic, see #336 and TestServer_PollSinceID_MultipleTopics
|
||||
selectMessagesByIDQuery = `
|
||||
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
|
||||
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE mid = ?
|
||||
`
|
||||
selectMessagesSinceTimeQuery = `
|
||||
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
|
||||
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE topic = ? AND time >= ? AND published = 1
|
||||
ORDER BY time, id
|
||||
`
|
||||
selectMessagesSinceTimeIncludeScheduledQuery = `
|
||||
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
|
||||
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE topic = ? AND time >= ?
|
||||
ORDER BY time, id
|
||||
`
|
||||
selectMessagesSinceIDQuery = `
|
||||
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
|
||||
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE topic = ? AND id > ? AND published = 1
|
||||
ORDER BY time, id
|
||||
`
|
||||
selectMessagesSinceIDIncludeScheduledQuery = `
|
||||
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
|
||||
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE topic = ? AND (id > ? OR published = 0)
|
||||
ORDER BY time, id
|
||||
`
|
||||
selectMessagesDueQuery = `
|
||||
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
|
||||
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE time <= ? AND published = 0
|
||||
ORDER BY time, id
|
||||
@@ -108,11 +115,14 @@ const (
|
||||
selectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires <= ? AND attachment_deleted = 0`
|
||||
selectAttachmentsSizeBySenderQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = '' AND sender = ? AND attachment_expires >= ?`
|
||||
selectAttachmentsSizeByUserIDQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?`
|
||||
|
||||
selectStatsQuery = `SELECT value FROM stats WHERE key = 'messages'`
|
||||
updateStatsQuery = `UPDATE stats SET value = ? WHERE key = 'messages'`
|
||||
)
|
||||
|
||||
// Schema management queries
|
||||
const (
|
||||
currentSchemaVersion = 10
|
||||
currentSchemaVersion = 12
|
||||
createSchemaVersionTableQuery = `
|
||||
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||
id INT PRIMARY KEY,
|
||||
@@ -222,20 +232,36 @@ const (
|
||||
CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
|
||||
`
|
||||
migrate9To10UpdateMessageExpiryQuery = `UPDATE messages SET expires = time + ?`
|
||||
|
||||
// 10 -> 11
|
||||
migrate10To11AlterMessagesTableQuery = `
|
||||
CREATE TABLE IF NOT EXISTS stats (
|
||||
key TEXT PRIMARY KEY,
|
||||
value INT
|
||||
);
|
||||
INSERT INTO stats (key, value) VALUES ('messages', 0);
|
||||
`
|
||||
|
||||
// 11 -> 12
|
||||
migrate11To12AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages ADD COLUMN content_type TEXT NOT NULL DEFAULT('');
|
||||
`
|
||||
)
|
||||
|
||||
var (
|
||||
migrations = map[int]func(db *sql.DB, cacheDuration time.Duration) error{
|
||||
0: migrateFrom0,
|
||||
1: migrateFrom1,
|
||||
2: migrateFrom2,
|
||||
3: migrateFrom3,
|
||||
4: migrateFrom4,
|
||||
5: migrateFrom5,
|
||||
6: migrateFrom6,
|
||||
7: migrateFrom7,
|
||||
8: migrateFrom8,
|
||||
9: migrateFrom9,
|
||||
0: migrateFrom0,
|
||||
1: migrateFrom1,
|
||||
2: migrateFrom2,
|
||||
3: migrateFrom3,
|
||||
4: migrateFrom4,
|
||||
5: migrateFrom5,
|
||||
6: migrateFrom6,
|
||||
7: migrateFrom7,
|
||||
8: migrateFrom8,
|
||||
9: migrateFrom9,
|
||||
10: migrateFrom10,
|
||||
11: migrateFrom11,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -251,7 +277,7 @@ func newSqliteCache(filename, startupQueries string, cacheDuration time.Duration
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := setupDB(db, startupQueries, cacheDuration); err != nil {
|
||||
if err := setupMessagesDB(db, startupQueries, cacheDuration); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var queue *util.BatchingQueue[*message]
|
||||
@@ -365,6 +391,7 @@ func (c *messageCache) addMessages(ms []*message) error {
|
||||
attachmentDeleted, // Always zero
|
||||
sender,
|
||||
m.User,
|
||||
m.ContentType,
|
||||
m.Encoding,
|
||||
published,
|
||||
)
|
||||
@@ -637,7 +664,7 @@ func readMessages(rows *sql.Rows) ([]*message, error) {
|
||||
func readMessage(rows *sql.Rows) (*message, error) {
|
||||
var timestamp, expires, attachmentSize, attachmentExpires int64
|
||||
var priority int
|
||||
var id, topic, msg, title, tagsStr, click, icon, actionsStr, attachmentName, attachmentType, attachmentURL, sender, user, encoding string
|
||||
var id, topic, msg, title, tagsStr, click, icon, actionsStr, attachmentName, attachmentType, attachmentURL, sender, user, contentType, encoding string
|
||||
err := rows.Scan(
|
||||
&id,
|
||||
×tamp,
|
||||
@@ -657,6 +684,7 @@ func readMessage(rows *sql.Rows) (*message, error) {
|
||||
&attachmentURL,
|
||||
&sender,
|
||||
&user,
|
||||
&contentType,
|
||||
&encoding,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -687,30 +715,51 @@ func readMessage(rows *sql.Rows) (*message, error) {
|
||||
}
|
||||
}
|
||||
return &message{
|
||||
ID: id,
|
||||
Time: timestamp,
|
||||
Expires: expires,
|
||||
Event: messageEvent,
|
||||
Topic: topic,
|
||||
Message: msg,
|
||||
Title: title,
|
||||
Priority: priority,
|
||||
Tags: tags,
|
||||
Click: click,
|
||||
Icon: icon,
|
||||
Actions: actions,
|
||||
Attachment: att,
|
||||
Sender: senderIP, // Must parse assuming database must be correct
|
||||
User: user,
|
||||
Encoding: encoding,
|
||||
ID: id,
|
||||
Time: timestamp,
|
||||
Expires: expires,
|
||||
Event: messageEvent,
|
||||
Topic: topic,
|
||||
Message: msg,
|
||||
Title: title,
|
||||
Priority: priority,
|
||||
Tags: tags,
|
||||
Click: click,
|
||||
Icon: icon,
|
||||
Actions: actions,
|
||||
Attachment: att,
|
||||
Sender: senderIP, // Must parse assuming database must be correct
|
||||
User: user,
|
||||
ContentType: contentType,
|
||||
Encoding: encoding,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *messageCache) UpdateStats(messages int64) error {
|
||||
_, err := c.db.Exec(updateStatsQuery, messages)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *messageCache) Stats() (messages int64, err error) {
|
||||
rows, err := c.db.Query(selectStatsQuery)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return 0, errNoRows
|
||||
}
|
||||
if err := rows.Scan(&messages); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func (c *messageCache) Close() error {
|
||||
return c.db.Close()
|
||||
}
|
||||
|
||||
func setupDB(db *sql.DB, startupQueries string, cacheDuration time.Duration) error {
|
||||
func setupMessagesDB(db *sql.DB, startupQueries string, cacheDuration time.Duration) error {
|
||||
// Run startup queries
|
||||
if startupQueries != "" {
|
||||
if _, err := db.Exec(startupQueries); err != nil {
|
||||
@@ -889,3 +938,35 @@ func migrateFrom9(db *sql.DB, cacheDuration time.Duration) error {
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func migrateFrom10(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 10 to 11")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(migrate10To11AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(updateSchemaVersion, 11); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func migrateFrom11(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 11 to 12")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(migrate11To12AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(updateSchemaVersion, 12); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
295
server/server.go
295
server/server.go
@@ -9,13 +9,6 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/emersion/go-smtp"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"heckel.io/ntfy/log"
|
||||
"heckel.io/ntfy/user"
|
||||
"heckel.io/ntfy/util"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -32,6 +25,14 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/emersion/go-smtp"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"heckel.io/ntfy/log"
|
||||
"heckel.io/ntfy/user"
|
||||
"heckel.io/ntfy/util"
|
||||
)
|
||||
|
||||
// Server is the main server, providing the UI and API for ntfy
|
||||
@@ -48,15 +49,17 @@ type Server struct {
|
||||
topics map[string]*topic
|
||||
visitors map[string]*visitor // ip:<ip> or user:<user>
|
||||
firebaseClient *firebaseClient
|
||||
messages int64
|
||||
messages int64 // Total number of messages (persisted if messageCache enabled)
|
||||
messagesHistory []int64 // Last n values of the messages counter, used to determine rate
|
||||
userManager *user.Manager // Might be nil!
|
||||
messageCache *messageCache // Database that stores the messages
|
||||
webPush *webPushStore // Database that stores web push subscriptions
|
||||
fileCache *fileCache // File system based cache that stores attachments
|
||||
stripe stripeAPI // Stripe API, can be replaced with a mock
|
||||
priceCache *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!)
|
||||
metricsHandler http.Handler // Handles /metrics if enable-metrics set, and listen-metrics-http not set
|
||||
closeChan chan bool
|
||||
mu sync.Mutex
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// handleFunc extends the normal http.HandlerFunc to be able to easily return errors
|
||||
@@ -75,17 +78,26 @@ var (
|
||||
publishPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}/(publish|send|trigger)$`)
|
||||
|
||||
webConfigPath = "/config.js"
|
||||
webManifestPath = "/manifest.webmanifest"
|
||||
webRootHTMLPath = "/app.html"
|
||||
webServiceWorkerPath = "/sw.js"
|
||||
accountPath = "/account"
|
||||
matrixPushPath = "/_matrix/push/v1/notify"
|
||||
metricsPath = "/metrics"
|
||||
apiHealthPath = "/v1/health"
|
||||
apiTiers = "/v1/tiers"
|
||||
apiStatsPath = "/v1/stats"
|
||||
apiWebPushPath = "/v1/webpush"
|
||||
apiTiersPath = "/v1/tiers"
|
||||
apiUsersPath = "/v1/users"
|
||||
apiUsersAccessPath = "/v1/users/access"
|
||||
apiAccountPath = "/v1/account"
|
||||
apiAccountTokenPath = "/v1/account/token"
|
||||
apiAccountPasswordPath = "/v1/account/password"
|
||||
apiAccountSettingsPath = "/v1/account/settings"
|
||||
apiAccountSubscriptionPath = "/v1/account/subscription"
|
||||
apiAccountReservationPath = "/v1/account/reservation"
|
||||
apiAccountPhonePath = "/v1/account/phone"
|
||||
apiAccountPhoneVerifyPath = "/v1/account/phone/verify"
|
||||
apiAccountBillingPortalPath = "/v1/account/billing/portal"
|
||||
apiAccountBillingWebhookPath = "/v1/account/billing/webhook"
|
||||
apiAccountBillingSubscriptionPath = "/v1/account/billing/subscription"
|
||||
@@ -96,13 +108,13 @@ var (
|
||||
docsRegex = regexp.MustCompile(`^/docs(|/.*)$`)
|
||||
fileRegex = regexp.MustCompile(`^/file/([-_A-Za-z0-9]{1,64})(?:\.[A-Za-z0-9]{1,16})?$`)
|
||||
urlRegex = regexp.MustCompile(`^https?://`)
|
||||
phoneNumberRegex = regexp.MustCompile(`^\+\d{1,100}$`)
|
||||
|
||||
//go:embed site
|
||||
webFs embed.FS
|
||||
webFsCached = &util.CachingEmbedFS{ModTime: time.Now(), FS: webFs}
|
||||
webSiteDir = "/site"
|
||||
webHomeIndex = "/home.html" // Landing page, only if "web-root: home"
|
||||
webAppIndex = "/app.html" // React app
|
||||
webFs embed.FS
|
||||
webFsCached = &util.CachingEmbedFS{ModTime: time.Now(), FS: webFs}
|
||||
webSiteDir = "/site"
|
||||
webAppIndex = "/app.html" // React app
|
||||
|
||||
//go:embed docs
|
||||
docsStaticFs embed.FS
|
||||
@@ -116,9 +128,10 @@ const (
|
||||
newMessageBody = "New message" // Used in poll requests as generic message
|
||||
defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment
|
||||
encodingBase64 = "base64" // Used mainly for binary UnifiedPush messages
|
||||
jsonBodyBytesLimit = 16384
|
||||
unifiedPushTopicPrefix = "up" // Temporarily, we rate limit all "up*" topics based on the subscriber
|
||||
unifiedPushTopicLength = 14
|
||||
jsonBodyBytesLimit = 16384 // Max number of bytes for a JSON request body
|
||||
unifiedPushTopicPrefix = "up" // Temporarily, we rate limit all "up*" topics based on the subscriber
|
||||
unifiedPushTopicLength = 14 // Length of UnifiedPush topics, including the "up" part
|
||||
messagesHistoryMax = 10 // Number of message count values to keep in memory
|
||||
)
|
||||
|
||||
// WebSocket constants
|
||||
@@ -144,10 +157,21 @@ func New(conf *Config) (*Server, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var webPush *webPushStore
|
||||
if conf.WebPushPublicKey != "" {
|
||||
webPush, err = newWebPushStore(conf.WebPushFile, conf.WebPushStartupQueries)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
topics, err := messageCache.Topics()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
messages, err := messageCache.Stats()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var fileCache *fileCache
|
||||
if conf.AttachmentCacheDir != "" {
|
||||
fileCache, err = newFileCache(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit)
|
||||
@@ -177,15 +201,18 @@ func New(conf *Config) (*Server, error) {
|
||||
firebaseClient = newFirebaseClient(sender, auther)
|
||||
}
|
||||
s := &Server{
|
||||
config: conf,
|
||||
messageCache: messageCache,
|
||||
fileCache: fileCache,
|
||||
firebaseClient: firebaseClient,
|
||||
smtpSender: mailer,
|
||||
topics: topics,
|
||||
userManager: userManager,
|
||||
visitors: make(map[string]*visitor),
|
||||
stripe: stripe,
|
||||
config: conf,
|
||||
messageCache: messageCache,
|
||||
webPush: webPush,
|
||||
fileCache: fileCache,
|
||||
firebaseClient: firebaseClient,
|
||||
smtpSender: mailer,
|
||||
topics: topics,
|
||||
userManager: userManager,
|
||||
messages: messages,
|
||||
messagesHistory: []int64{messages},
|
||||
visitors: make(map[string]*visitor),
|
||||
stripe: stripe,
|
||||
}
|
||||
s.priceCache = util.NewLookupCache(s.fetchStripePrices, conf.StripePriceCacheDuration)
|
||||
return s, nil
|
||||
@@ -329,6 +356,9 @@ func (s *Server) closeDatabases() {
|
||||
s.userManager.Close()
|
||||
}
|
||||
s.messageCache.Close()
|
||||
if s.webPush != nil {
|
||||
s.webPush.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// handle is the main entry point for all HTTP requests
|
||||
@@ -395,14 +425,26 @@ func (s *Server) handleError(w http.ResponseWriter, r *http.Request, v *visitor,
|
||||
}
|
||||
|
||||
func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
if r.Method == http.MethodGet && r.URL.Path == "/" {
|
||||
return s.ensureWebEnabled(s.handleHome)(w, r, v)
|
||||
if r.Method == http.MethodGet && r.URL.Path == "/" && s.config.WebRoot == "/" {
|
||||
return s.ensureWebEnabled(s.handleRoot)(w, r, v)
|
||||
} else if r.Method == http.MethodHead && r.URL.Path == "/" {
|
||||
return s.ensureWebEnabled(s.handleEmpty)(w, r, v)
|
||||
} else if r.Method == http.MethodGet && r.URL.Path == apiHealthPath {
|
||||
return s.handleHealth(w, r, v)
|
||||
} else if r.Method == http.MethodGet && r.URL.Path == webConfigPath {
|
||||
return s.ensureWebEnabled(s.handleWebConfig)(w, r, v)
|
||||
} else if r.Method == http.MethodGet && r.URL.Path == webManifestPath {
|
||||
return s.ensureWebPushEnabled(s.handleWebManifest)(w, r, v)
|
||||
} else if r.Method == http.MethodGet && r.URL.Path == apiUsersPath {
|
||||
return s.ensureAdmin(s.handleUsersGet)(w, r, v)
|
||||
} else if r.Method == http.MethodPut && r.URL.Path == apiUsersPath {
|
||||
return s.ensureAdmin(s.handleUsersAdd)(w, r, v)
|
||||
} else if r.Method == http.MethodDelete && r.URL.Path == apiUsersPath {
|
||||
return s.ensureAdmin(s.handleUsersDelete)(w, r, v)
|
||||
} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && r.URL.Path == apiUsersAccessPath {
|
||||
return s.ensureAdmin(s.handleAccessAllow)(w, r, v)
|
||||
} else if r.Method == http.MethodDelete && r.URL.Path == apiUsersAccessPath {
|
||||
return s.ensureAdmin(s.handleAccessReset)(w, r, v)
|
||||
} else if r.Method == http.MethodPost && r.URL.Path == apiAccountPath {
|
||||
return s.ensureUserManager(s.handleAccountCreate)(w, r, v)
|
||||
} else if r.Method == http.MethodGet && r.URL.Path == apiAccountPath {
|
||||
@@ -441,13 +483,25 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
|
||||
return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingPortalSessionCreate))(w, r, v)
|
||||
} else if r.Method == http.MethodPost && r.URL.Path == apiAccountBillingWebhookPath {
|
||||
return s.ensurePaymentsEnabled(s.ensureUserManager(s.handleAccountBillingWebhook))(w, r, v) // This request comes from Stripe!
|
||||
} else if r.Method == http.MethodGet && r.URL.Path == apiTiers {
|
||||
} else if r.Method == http.MethodPut && r.URL.Path == apiAccountPhoneVerifyPath {
|
||||
return s.ensureUser(s.ensureCallsEnabled(s.withAccountSync(s.handleAccountPhoneNumberVerify)))(w, r, v)
|
||||
} else if r.Method == http.MethodPut && r.URL.Path == apiAccountPhonePath {
|
||||
return s.ensureUser(s.ensureCallsEnabled(s.withAccountSync(s.handleAccountPhoneNumberAdd)))(w, r, v)
|
||||
} else if r.Method == http.MethodDelete && r.URL.Path == apiAccountPhonePath {
|
||||
return s.ensureUser(s.ensureCallsEnabled(s.withAccountSync(s.handleAccountPhoneNumberDelete)))(w, r, v)
|
||||
} else if r.Method == http.MethodPost && apiWebPushPath == r.URL.Path {
|
||||
return s.ensureWebPushEnabled(s.limitRequests(s.handleWebPushUpdate))(w, r, v)
|
||||
} else if r.Method == http.MethodDelete && apiWebPushPath == r.URL.Path {
|
||||
return s.ensureWebPushEnabled(s.limitRequests(s.handleWebPushDelete))(w, r, v)
|
||||
} else if r.Method == http.MethodGet && r.URL.Path == apiStatsPath {
|
||||
return s.handleStats(w, r, v)
|
||||
} else if r.Method == http.MethodGet && r.URL.Path == apiTiersPath {
|
||||
return s.ensurePaymentsEnabled(s.handleBillingTiersGet)(w, r, v)
|
||||
} else if r.Method == http.MethodGet && r.URL.Path == matrixPushPath {
|
||||
return s.handleMatrixDiscovery(w)
|
||||
} else if r.Method == http.MethodGet && r.URL.Path == metricsPath && s.metricsHandler != nil {
|
||||
return s.handleMetrics(w, r, v)
|
||||
} else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) {
|
||||
} else if r.Method == http.MethodGet && (staticRegex.MatchString(r.URL.Path) || r.URL.Path == webServiceWorkerPath || r.URL.Path == webRootHTMLPath) {
|
||||
return s.ensureWebEnabled(s.handleStatic)(w, r, v)
|
||||
} else if r.Method == http.MethodGet && docsRegex.MatchString(r.URL.Path) {
|
||||
return s.ensureWebEnabled(s.handleDocs)(w, r, v)
|
||||
@@ -479,12 +533,8 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
|
||||
return errHTTPNotFound
|
||||
}
|
||||
|
||||
func (s *Server) handleHome(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
if s.config.WebRootIsApp {
|
||||
r.URL.Path = webAppIndex
|
||||
} else {
|
||||
r.URL.Path = webHomeIndex
|
||||
}
|
||||
func (s *Server) handleRoot(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
r.URL.Path = webAppIndex
|
||||
return s.handleStatic(w, r, v)
|
||||
}
|
||||
|
||||
@@ -516,18 +566,18 @@ func (s *Server) handleHealth(w http.ResponseWriter, _ *http.Request, _ *visitor
|
||||
}
|
||||
|
||||
func (s *Server) handleWebConfig(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
|
||||
appRoot := "/"
|
||||
if !s.config.WebRootIsApp {
|
||||
appRoot = "/app"
|
||||
}
|
||||
response := &apiConfigResponse{
|
||||
BaseURL: "", // Will translate to window.location.origin
|
||||
AppRoot: appRoot,
|
||||
AppRoot: s.config.WebRoot,
|
||||
EnableLogin: s.config.EnableLogin,
|
||||
EnableSignup: s.config.EnableSignup,
|
||||
EnablePayments: s.config.StripeSecretKey != "",
|
||||
EnableCalls: s.config.TwilioAccount != "",
|
||||
EnableEmails: s.config.SMTPSenderFrom != "",
|
||||
EnableReservations: s.config.EnableReservations,
|
||||
EnableWebPush: s.config.WebPushPublicKey != "",
|
||||
BillingContact: s.config.BillingContact,
|
||||
WebPushPublicKey: s.config.WebPushPublicKey,
|
||||
DisallowedTopics: s.config.DisallowedTopics,
|
||||
}
|
||||
b, err := json.MarshalIndent(response, "", " ")
|
||||
@@ -539,6 +589,25 @@ func (s *Server) handleWebConfig(w http.ResponseWriter, _ *http.Request, _ *visi
|
||||
return err
|
||||
}
|
||||
|
||||
// handleWebManifest serves the web app manifest for the progressive web app (PWA)
|
||||
func (s *Server) handleWebManifest(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
|
||||
response := &webManifestResponse{
|
||||
Name: "ntfy web",
|
||||
Description: "ntfy lets you send push notifications via scripts from any computer or phone",
|
||||
ShortName: "ntfy",
|
||||
Scope: "/",
|
||||
StartURL: s.config.WebRoot,
|
||||
Display: "standalone",
|
||||
BackgroundColor: "#ffffff",
|
||||
ThemeColor: "#317f6f",
|
||||
Icons: []*webManifestIcon{
|
||||
{SRC: "/static/images/pwa-192x192.png", Sizes: "192x192", Type: "image/png"},
|
||||
{SRC: "/static/images/pwa-512x512.png", Sizes: "512x512", Type: "image/png"},
|
||||
},
|
||||
}
|
||||
return s.writeJSONWithContentType(w, response, "application/manifest+json")
|
||||
}
|
||||
|
||||
// handleMetrics returns Prometheus metrics. This endpoint is only called if enable-metrics is set,
|
||||
// and listen-metrics-http is not set.
|
||||
func (s *Server) handleMetrics(w http.ResponseWriter, r *http.Request, _ *visitor) error {
|
||||
@@ -546,17 +615,34 @@ func (s *Server) handleMetrics(w http.ResponseWriter, r *http.Request, _ *visito
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleStatic returns all static resources (excluding the docs), including the web app
|
||||
func (s *Server) handleStatic(w http.ResponseWriter, r *http.Request, _ *visitor) error {
|
||||
r.URL.Path = webSiteDir + r.URL.Path
|
||||
util.Gzip(http.FileServer(http.FS(webFsCached))).ServeHTTP(w, r)
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleDocs returns static resources related to the docs
|
||||
func (s *Server) handleDocs(w http.ResponseWriter, r *http.Request, _ *visitor) error {
|
||||
util.Gzip(http.FileServer(http.FS(docsStaticCached))).ServeHTTP(w, r)
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleStats returns the publicly available server stats
|
||||
func (s *Server) handleStats(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
|
||||
s.mu.RLock()
|
||||
messages, n, rate := s.messages, len(s.messagesHistory), float64(0)
|
||||
if n > 1 {
|
||||
rate = float64(s.messagesHistory[n-1]-s.messagesHistory[0]) / (float64(n-1) * s.config.ManagerInterval.Seconds())
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
response := &apiStatsResponse{
|
||||
Messages: messages,
|
||||
MessagesRate: rate,
|
||||
}
|
||||
return s.writeJSON(w, response)
|
||||
}
|
||||
|
||||
// handleFile processes the download of attachment files. The method handles GET and HEAD requests against a file.
|
||||
// Before streaming the file to a client, it locates uploader (m.Sender or m.User) in the message cache, so it
|
||||
// can associate the download bandwidth with the uploader.
|
||||
@@ -623,6 +709,9 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor)
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
if m.Attachment.Name != "" {
|
||||
w.Header().Set("Content-Disposition", "attachment; filename="+strconv.Quote(m.Attachment.Name))
|
||||
}
|
||||
_, err = io.Copy(util.NewContentTypeWriter(w, r.URL.Path), f)
|
||||
return err
|
||||
}
|
||||
@@ -649,7 +738,7 @@ func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*message, e
|
||||
return nil, err
|
||||
}
|
||||
m := newDefaultMessage(t.ID, "")
|
||||
cache, firebase, email, unifiedpush, e := s.parsePublishParams(r, m)
|
||||
cache, firebase, email, call, unifiedpush, e := s.parsePublishParams(r, m)
|
||||
if e != nil {
|
||||
return nil, e.With(t)
|
||||
}
|
||||
@@ -663,6 +752,14 @@ func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*message, e
|
||||
return nil, errHTTPTooManyRequestsLimitMessages.With(t)
|
||||
} else if email != "" && !vrate.EmailAllowed() {
|
||||
return nil, errHTTPTooManyRequestsLimitEmails.With(t)
|
||||
} else if call != "" {
|
||||
var httpErr *errHTTP
|
||||
call, httpErr = s.convertPhoneNumber(v.User(), call)
|
||||
if httpErr != nil {
|
||||
return nil, httpErr.With(t)
|
||||
} else if !vrate.CallAllowed() {
|
||||
return nil, errHTTPTooManyRequestsLimitCalls.With(t)
|
||||
}
|
||||
}
|
||||
if m.PollID != "" {
|
||||
m = newPollRequestMessage(t.ID, m.PollID)
|
||||
@@ -687,6 +784,7 @@ func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*message, e
|
||||
"message_firebase": firebase,
|
||||
"message_unifiedpush": unifiedpush,
|
||||
"message_email": email,
|
||||
"message_call": call,
|
||||
})
|
||||
if ev.IsTrace() {
|
||||
ev.Field("message_body", util.MaybeMarshalJSON(m)).Trace("Received message")
|
||||
@@ -703,9 +801,15 @@ func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*message, e
|
||||
if s.smtpSender != nil && email != "" {
|
||||
go s.sendEmail(v, m, email)
|
||||
}
|
||||
if s.config.UpstreamBaseURL != "" {
|
||||
if s.config.TwilioAccount != "" && call != "" {
|
||||
go s.callPhone(v, r, m, call)
|
||||
}
|
||||
if s.config.UpstreamBaseURL != "" && !unifiedpush { // UP messages are not sent to upstream
|
||||
go s.forwardPollRequest(v, m)
|
||||
}
|
||||
if s.config.WebPushPublicKey != "" {
|
||||
go s.publishToWebPushEndpoints(v, m)
|
||||
}
|
||||
} else {
|
||||
logvrm(v, r, m).Tag(tagPublish).Debug("Message delayed, will process later")
|
||||
}
|
||||
@@ -798,7 +902,11 @@ func (s *Server) forwardPollRequest(v *visitor, m *message) {
|
||||
logvm(v, m).Err(err).Warn("Unable to publish poll request")
|
||||
return
|
||||
}
|
||||
req.Header.Set("User-Agent", "ntfy/"+s.config.Version)
|
||||
req.Header.Set("X-Poll-ID", m.ID)
|
||||
if s.config.UpstreamAccessToken != "" {
|
||||
req.Header.Set("Authorization", util.BearerAuth(s.config.UpstreamAccessToken))
|
||||
}
|
||||
var httpClient = &http.Client{
|
||||
Timeout: time.Second * 10,
|
||||
}
|
||||
@@ -807,12 +915,16 @@ func (s *Server) forwardPollRequest(v *visitor, m *message) {
|
||||
logvm(v, m).Err(err).Warn("Unable to publish poll request")
|
||||
return
|
||||
} else if response.StatusCode != http.StatusOK {
|
||||
logvm(v, m).Err(err).Warn("Unable to publish poll request, unexpected HTTP status: %d", response.StatusCode)
|
||||
if response.StatusCode == http.StatusTooManyRequests {
|
||||
logvm(v, m).Err(err).Warn("Unable to publish poll request, the upstream server %s responded with HTTP %s; you may solve this by sending fewer daily messages, or by configuring upstream-access-token (assuming you have an account with higher rate limits) ", s.config.UpstreamBaseURL, response.Status)
|
||||
} else {
|
||||
logvm(v, m).Err(err).Warn("Unable to publish poll request, the upstream server %s responded with HTTP %s", s.config.UpstreamBaseURL, response.Status)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, firebase bool, email string, unifiedpush bool, err *errHTTP) {
|
||||
func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, firebase bool, email, call string, unifiedpush bool, err *errHTTP) {
|
||||
cache = readBoolParam(r, true, "x-cache", "cache")
|
||||
firebase = readBoolParam(r, true, "x-firebase", "firebase")
|
||||
m.Title = readParam(r, "x-title", "title", "t")
|
||||
@@ -828,7 +940,7 @@ func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, fi
|
||||
}
|
||||
if attach != "" {
|
||||
if !urlRegex.MatchString(attach) {
|
||||
return false, false, "", false, errHTTPBadRequestAttachmentURLInvalid
|
||||
return false, false, "", "", false, errHTTPBadRequestAttachmentURLInvalid
|
||||
}
|
||||
m.Attachment.URL = attach
|
||||
if m.Attachment.Name == "" {
|
||||
@@ -846,13 +958,19 @@ func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, fi
|
||||
}
|
||||
if icon != "" {
|
||||
if !urlRegex.MatchString(icon) {
|
||||
return false, false, "", false, errHTTPBadRequestIconURLInvalid
|
||||
return false, false, "", "", false, errHTTPBadRequestIconURLInvalid
|
||||
}
|
||||
m.Icon = icon
|
||||
}
|
||||
email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e")
|
||||
if s.smtpSender == nil && email != "" {
|
||||
return false, false, "", false, errHTTPBadRequestEmailDisabled
|
||||
return false, false, "", "", false, errHTTPBadRequestEmailDisabled
|
||||
}
|
||||
call = readParam(r, "x-call", "call")
|
||||
if call != "" && (s.config.TwilioAccount == "" || s.userManager == nil) {
|
||||
return false, false, "", "", false, errHTTPBadRequestPhoneCallsDisabled
|
||||
} else if call != "" && !isBoolValue(call) && !phoneNumberRegex.MatchString(call) {
|
||||
return false, false, "", "", false, errHTTPBadRequestPhoneNumberInvalid
|
||||
}
|
||||
messageStr := strings.ReplaceAll(readParam(r, "x-message", "message", "m"), "\\n", "\n")
|
||||
if messageStr != "" {
|
||||
@@ -861,24 +979,27 @@ func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, fi
|
||||
var e error
|
||||
m.Priority, e = util.ParsePriority(readParam(r, "x-priority", "priority", "prio", "p"))
|
||||
if e != nil {
|
||||
return false, false, "", false, errHTTPBadRequestPriorityInvalid
|
||||
return false, false, "", "", false, errHTTPBadRequestPriorityInvalid
|
||||
}
|
||||
m.Tags = readCommaSeparatedParam(r, "x-tags", "tags", "tag", "ta")
|
||||
delayStr := readParam(r, "x-delay", "delay", "x-at", "at", "x-in", "in")
|
||||
if delayStr != "" {
|
||||
if !cache {
|
||||
return false, false, "", false, errHTTPBadRequestDelayNoCache
|
||||
return false, false, "", "", false, errHTTPBadRequestDelayNoCache
|
||||
}
|
||||
if email != "" {
|
||||
return false, false, "", false, errHTTPBadRequestDelayNoEmail // we cannot store the email address (yet)
|
||||
return false, false, "", "", false, errHTTPBadRequestDelayNoEmail // we cannot store the email address (yet)
|
||||
}
|
||||
if call != "" {
|
||||
return false, false, "", "", false, errHTTPBadRequestDelayNoCall // we cannot store the phone number (yet)
|
||||
}
|
||||
delay, err := util.ParseFutureTime(delayStr, time.Now())
|
||||
if err != nil {
|
||||
return false, false, "", false, errHTTPBadRequestDelayCannotParse
|
||||
return false, false, "", "", false, errHTTPBadRequestDelayCannotParse
|
||||
} else if delay.Unix() < time.Now().Add(s.config.MinDelay).Unix() {
|
||||
return false, false, "", false, errHTTPBadRequestDelayTooSmall
|
||||
return false, false, "", "", false, errHTTPBadRequestDelayTooSmall
|
||||
} else if delay.Unix() > time.Now().Add(s.config.MaxDelay).Unix() {
|
||||
return false, false, "", false, errHTTPBadRequestDelayTooLarge
|
||||
return false, false, "", "", false, errHTTPBadRequestDelayTooLarge
|
||||
}
|
||||
m.Time = delay.Unix()
|
||||
}
|
||||
@@ -886,9 +1007,13 @@ func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, fi
|
||||
if actionsStr != "" {
|
||||
m.Actions, e = parseActions(actionsStr)
|
||||
if e != nil {
|
||||
return false, false, "", false, errHTTPBadRequestActionsInvalid.Wrap(e.Error())
|
||||
return false, false, "", "", false, errHTTPBadRequestActionsInvalid.Wrap(e.Error())
|
||||
}
|
||||
}
|
||||
contentType, markdown := readParam(r, "content-type", "content_type"), readBoolParam(r, false, "x-markdown", "markdown", "md")
|
||||
if markdown || strings.ToLower(contentType) == "text/markdown" {
|
||||
m.ContentType = "text/markdown"
|
||||
}
|
||||
unifiedpush = readBoolParam(r, false, "x-unifiedpush", "unifiedpush", "up") // see GET too!
|
||||
if unifiedpush {
|
||||
firebase = false
|
||||
@@ -900,7 +1025,7 @@ func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, fi
|
||||
cache = false
|
||||
email = ""
|
||||
}
|
||||
return cache, firebase, email, unifiedpush, nil
|
||||
return cache, firebase, email, call, unifiedpush, nil
|
||||
}
|
||||
|
||||
// handlePublishBody consumes the PUT/POST body and decides whether the body is an attachment or the message.
|
||||
@@ -1170,7 +1295,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Subscription connections can be canceled externally, see topic.CancelSubscribers
|
||||
// Subscription connections can be canceled externally, see topic.CancelSubscribersExceptUser
|
||||
cancelCtx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
@@ -1412,6 +1537,7 @@ func (s *Server) handleOptions(w http.ResponseWriter, _ *http.Request, _ *visito
|
||||
return nil
|
||||
}
|
||||
|
||||
// topicFromPath returns the topic from a root path (e.g. /mytopic), creating it if it doesn't exist.
|
||||
func (s *Server) topicFromPath(path string) (*topic, error) {
|
||||
parts := strings.Split(path, "/")
|
||||
if len(parts) < 2 {
|
||||
@@ -1420,6 +1546,7 @@ func (s *Server) topicFromPath(path string) (*topic, error) {
|
||||
return s.topicFromID(parts[1])
|
||||
}
|
||||
|
||||
// topicsFromPath returns the topic from a root path (e.g. /mytopic,mytopic2), creating it if it doesn't exist.
|
||||
func (s *Server) topicsFromPath(path string) ([]*topic, string, error) {
|
||||
parts := strings.Split(path, "/")
|
||||
if len(parts) < 2 {
|
||||
@@ -1433,6 +1560,7 @@ func (s *Server) topicsFromPath(path string) ([]*topic, string, error) {
|
||||
return topics, parts[1], nil
|
||||
}
|
||||
|
||||
// topicsFromIDs returns the topics with the given IDs, creating them if they don't exist.
|
||||
func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
@@ -1452,6 +1580,7 @@ func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) {
|
||||
return topics, nil
|
||||
}
|
||||
|
||||
// topicFromID returns the topic with the given ID, creating it if it doesn't exist.
|
||||
func (s *Server) topicFromID(id string) (*topic, error) {
|
||||
topics, err := s.topicsFromIDs(id)
|
||||
if err != nil {
|
||||
@@ -1460,6 +1589,23 @@ func (s *Server) topicFromID(id string) (*topic, error) {
|
||||
return topics[0], nil
|
||||
}
|
||||
|
||||
// topicsFromPattern returns a list of topics matching the given pattern, but it does not create them.
|
||||
func (s *Server) topicsFromPattern(pattern string) ([]*topic, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
patternRegexp, err := regexp.Compile("^" + strings.ReplaceAll(pattern, "*", ".*") + "$")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
topics := make([]*topic, 0)
|
||||
for _, t := range s.topics {
|
||||
if patternRegexp.MatchString(t.ID) {
|
||||
topics = append(topics, t)
|
||||
}
|
||||
}
|
||||
return topics, nil
|
||||
}
|
||||
|
||||
func (s *Server) runSMTPServer() error {
|
||||
s.smtpServerBackend = newMailBackend(s.config, s.handle)
|
||||
s.smtpServer = smtp.NewServer(s.smtpServerBackend)
|
||||
@@ -1580,9 +1726,9 @@ func (s *Server) sendDelayedMessages() error {
|
||||
|
||||
func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
|
||||
logvm(v, m).Debug("Sending delayed message")
|
||||
s.mu.Lock()
|
||||
s.mu.RLock()
|
||||
t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published
|
||||
s.mu.Unlock()
|
||||
s.mu.RUnlock()
|
||||
if ok {
|
||||
go func() {
|
||||
// We do not rate-limit messages here, since we've rate limited them in the PUT/POST handler
|
||||
@@ -1597,6 +1743,9 @@ func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
|
||||
if s.config.UpstreamBaseURL != "" {
|
||||
go s.forwardPollRequest(v, m)
|
||||
}
|
||||
if s.config.WebPushPublicKey != "" {
|
||||
go s.publishToWebPushEndpoints(v, m)
|
||||
}
|
||||
if err := s.messageCache.MarkPublished(m); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1640,6 +1789,9 @@ func (s *Server) transformBodyJSON(next handleFunc) handleFunc {
|
||||
if m.Icon != "" {
|
||||
r.Header.Set("X-Icon", m.Icon)
|
||||
}
|
||||
if m.Markdown {
|
||||
r.Header.Set("X-Markdown", "yes")
|
||||
}
|
||||
if len(m.Actions) > 0 {
|
||||
actionsStr, err := json.Marshal(m.Actions)
|
||||
if err != nil {
|
||||
@@ -1653,6 +1805,9 @@ func (s *Server) transformBodyJSON(next handleFunc) handleFunc {
|
||||
if m.Delay != "" {
|
||||
r.Header.Set("X-Delay", m.Delay)
|
||||
}
|
||||
if m.Call != "" {
|
||||
r.Header.Set("X-Call", m.Call)
|
||||
}
|
||||
return next(w, r, v)
|
||||
}
|
||||
}
|
||||
@@ -1814,10 +1969,28 @@ func (s *Server) visitor(ip netip.Addr, user *user.User) *visitor {
|
||||
}
|
||||
|
||||
func (s *Server) writeJSON(w http.ResponseWriter, v any) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
return s.writeJSONWithContentType(w, v, "application/json")
|
||||
}
|
||||
|
||||
func (s *Server) writeJSONWithContentType(w http.ResponseWriter, v any, contentType string) error {
|
||||
w.Header().Set("Content-Type", contentType)
|
||||
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
||||
if err := json.NewEncoder(w).Encode(v); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) updateAndWriteStats(messagesCount int64) {
|
||||
s.mu.Lock()
|
||||
s.messagesHistory = append(s.messagesHistory, messagesCount)
|
||||
if len(s.messagesHistory) > messagesHistoryMax {
|
||||
s.messagesHistory = s.messagesHistory[1:]
|
||||
}
|
||||
s.mu.Unlock()
|
||||
go func() {
|
||||
if err := s.messageCache.UpdateStats(messagesCount); err != nil {
|
||||
log.Tag(tagManager).Err(err).Warn("Cannot write messages stats")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -144,6 +144,39 @@
|
||||
# smtp-server-domain:
|
||||
# smtp-server-addr-prefix:
|
||||
|
||||
# Web Push support (background notifications for browsers)
|
||||
#
|
||||
# If enabled, allows ntfy to receive push notifications, even when the ntfy web app is closed. When enabled, users
|
||||
# can enable background notifications in the web app. Once enabled, ntfy will forward published messages to the push
|
||||
# endpoint, which will then forward it to the browser.
|
||||
#
|
||||
# You must configure web-push-public/private key, web-push-file, and web-push-email-address below to enable Web Push.
|
||||
# Run "ntfy webpush keys" to generate the keys.
|
||||
#
|
||||
# - web-push-public-key is the generated VAPID public key, e.g. AA1234BBCCddvveekaabcdfqwertyuiopasdfghjklzxcvbnm1234567890
|
||||
# - web-push-private-key is the generated VAPID private key, e.g. AA2BB1234567890abcdefzxcvbnm1234567890
|
||||
# - web-push-file is a database file to keep track of browser subscription endpoints, e.g. `/var/cache/ntfy/webpush.db`
|
||||
# - web-push-email-address is the admin email address send to the push provider, e.g. `sysadmin@example.com`
|
||||
# - web-push-startup-queries is an optional list of queries to run on startup`
|
||||
#
|
||||
# web-push-public-key:
|
||||
# web-push-private-key:
|
||||
# web-push-file:
|
||||
# web-push-email-address:
|
||||
# web-push-startup-queries:
|
||||
|
||||
# If enabled, ntfy can perform voice calls via Twilio via the "X-Call" header.
|
||||
#
|
||||
# - twilio-account is the Twilio account SID, e.g. AC12345beefbeef67890beefbeef122586
|
||||
# - twilio-auth-token is the Twilio auth token, e.g. affebeef258625862586258625862586
|
||||
# - twilio-phone-number is the outgoing phone number you purchased, e.g. +18775132586
|
||||
# - twilio-verify-service is the Twilio Verify service SID, e.g. VA12345beefbeef67890beefbeef122586
|
||||
#
|
||||
# twilio-account:
|
||||
# twilio-auth-token:
|
||||
# twilio-phone-number:
|
||||
# twilio-verify-service:
|
||||
|
||||
# Interval in which keepalive messages are sent to the client. This is to prevent
|
||||
# intermediaries closing the connection for inactivity.
|
||||
#
|
||||
@@ -167,11 +200,13 @@
|
||||
#
|
||||
# disallowed-topics:
|
||||
|
||||
# Defines if the root route (/) is pointing to the landing page (as on ntfy.sh) or the
|
||||
# web app. If you self-host, you don't want to change this.
|
||||
# Can be "app" (default), "home" or "disable" to disable the web app entirely.
|
||||
# Defines the root path of the web app, or disables the web app entirely.
|
||||
#
|
||||
# web-root: app
|
||||
# Can be any simple path, e.g. "/", "/app", or "/ntfy". For backwards-compatibility reasons,
|
||||
# the values "app" (maps to "/"), "home" (maps to "/app"), or "disable" (maps to "") to disable
|
||||
# the web app entirely.
|
||||
#
|
||||
# web-root: /
|
||||
|
||||
# Various feature flags used to control the web app, and API access, mainly around user and
|
||||
# account management.
|
||||
@@ -194,7 +229,12 @@
|
||||
# the message ID of the original message, instructing the iOS app to poll this server for the actual message contents.
|
||||
# This is to prevent the upstream server and Firebase/APNS from being able to read the message.
|
||||
#
|
||||
# - upstream-base-url is the base URL of the upstream server. Should be "https://ntfy.sh".
|
||||
# - upstream-access-token is the token used to authenticate with the upstream server. This is only required
|
||||
# if you exceed the upstream rate limits, or the uptream server requires authentication.
|
||||
#
|
||||
# upstream-base-url:
|
||||
# upstream-access-token:
|
||||
|
||||
# Rate limiting: Total number of topics before the server rejects new topics.
|
||||
#
|
||||
@@ -302,6 +342,10 @@
|
||||
# - "field -> level" to match any value, e.g. "time_taken_ms -> debug"
|
||||
# Warning: Using log-level-overrides has a performance penalty. Only use it for temporary debugging.
|
||||
#
|
||||
# Check your permissions:
|
||||
# If you are running ntfy with systemd, make sure this log file is owned by the
|
||||
# ntfy user and group by running: chown ntfy.ntfy <filename>.
|
||||
#
|
||||
# Example (good for production):
|
||||
# log-level: info
|
||||
# log-format: json
|
||||
|
||||
@@ -56,6 +56,7 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, r *http.Request, v *vis
|
||||
Messages: limits.MessageLimit,
|
||||
MessagesExpiryDuration: int64(limits.MessageExpiryDuration.Seconds()),
|
||||
Emails: limits.EmailLimit,
|
||||
Calls: limits.CallLimit,
|
||||
Reservations: limits.ReservationsLimit,
|
||||
AttachmentTotalSize: limits.AttachmentTotalSizeLimit,
|
||||
AttachmentFileSize: limits.AttachmentFileSizeLimit,
|
||||
@@ -67,6 +68,8 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, r *http.Request, v *vis
|
||||
MessagesRemaining: stats.MessagesRemaining,
|
||||
Emails: stats.Emails,
|
||||
EmailsRemaining: stats.EmailsRemaining,
|
||||
Calls: stats.Calls,
|
||||
CallsRemaining: stats.CallsRemaining,
|
||||
Reservations: stats.Reservations,
|
||||
ReservationsRemaining: stats.ReservationsRemaining,
|
||||
AttachmentTotalSize: stats.AttachmentTotalSize,
|
||||
@@ -105,17 +108,19 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, r *http.Request, v *vis
|
||||
CancelAt: u.Billing.StripeSubscriptionCancelAt.Unix(),
|
||||
}
|
||||
}
|
||||
reservations, err := s.userManager.Reservations(u.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(reservations) > 0 {
|
||||
response.Reservations = make([]*apiAccountReservation, 0)
|
||||
for _, r := range reservations {
|
||||
response.Reservations = append(response.Reservations, &apiAccountReservation{
|
||||
Topic: r.Topic,
|
||||
Everyone: r.Everyone.String(),
|
||||
})
|
||||
if s.config.EnableReservations {
|
||||
reservations, err := s.userManager.Reservations(u.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(reservations) > 0 {
|
||||
response.Reservations = make([]*apiAccountReservation, 0)
|
||||
for _, r := range reservations {
|
||||
response.Reservations = append(response.Reservations, &apiAccountReservation{
|
||||
Topic: r.Topic,
|
||||
Everyone: r.Everyone.String(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
tokens, err := s.userManager.Tokens(u.ID)
|
||||
@@ -138,6 +143,15 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, r *http.Request, v *vis
|
||||
})
|
||||
}
|
||||
}
|
||||
if s.config.TwilioAccount != "" {
|
||||
phoneNumbers, err := s.userManager.PhoneNumbers(u.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(phoneNumbers) > 0 {
|
||||
response.PhoneNumbers = phoneNumbers
|
||||
}
|
||||
}
|
||||
} else {
|
||||
response.Username = user.Everyone
|
||||
response.Role = string(user.RoleAnonymous)
|
||||
@@ -156,6 +170,11 @@ func (s *Server) handleAccountDelete(w http.ResponseWriter, r *http.Request, v *
|
||||
if _, err := s.userManager.Authenticate(u.Name, req.Password); err != nil {
|
||||
return errHTTPBadRequestIncorrectPasswordConfirmation
|
||||
}
|
||||
if s.webPush != nil && u.ID != "" {
|
||||
if err := s.webPush.RemoveSubscriptionsByUserID(u.ID); err != nil {
|
||||
logvr(v, r).Err(err).Warn("Error removing web push subscriptions for %s", u.Name)
|
||||
}
|
||||
}
|
||||
if u.Billing.StripeSubscriptionID != "" {
|
||||
logvr(v, r).Tag(tagStripe).Info("Canceling billing subscription for user %s", u.Name)
|
||||
if _, err := s.stripe.CancelSubscription(u.Billing.StripeSubscriptionID); err != nil {
|
||||
@@ -444,7 +463,7 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.CancelSubscribers(u.ID)
|
||||
t.CancelSubscribersExceptUser(u.ID)
|
||||
return s.writeJSON(w, newSuccessResponse())
|
||||
}
|
||||
|
||||
@@ -511,6 +530,72 @@ func (s *Server) maybeRemoveMessagesAndExcessReservations(r *http.Request, v *vi
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) handleAccountPhoneNumberVerify(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
u := v.User()
|
||||
req, err := readJSONWithLimit[apiAccountPhoneNumberVerifyRequest](r.Body, jsonBodyBytesLimit, false)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if !phoneNumberRegex.MatchString(req.Number) {
|
||||
return errHTTPBadRequestPhoneNumberInvalid
|
||||
} else if req.Channel != "sms" && req.Channel != "call" {
|
||||
return errHTTPBadRequestPhoneNumberVerifyChannelInvalid
|
||||
}
|
||||
// Check user is allowed to add phone numbers
|
||||
if u == nil || (u.IsUser() && u.Tier == nil) {
|
||||
return errHTTPUnauthorized
|
||||
} else if u.IsUser() && u.Tier.CallLimit == 0 {
|
||||
return errHTTPUnauthorized
|
||||
}
|
||||
// Check if phone number exists
|
||||
phoneNumbers, err := s.userManager.PhoneNumbers(u.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if util.Contains(phoneNumbers, req.Number) {
|
||||
return errHTTPConflictPhoneNumberExists
|
||||
}
|
||||
// Actually add the unverified number, and send verification
|
||||
logvr(v, r).Tag(tagAccount).Field("phone_number", req.Number).Debug("Sending phone number verification")
|
||||
if err := s.verifyPhoneNumber(v, r, req.Number, req.Channel); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.writeJSON(w, newSuccessResponse())
|
||||
}
|
||||
|
||||
func (s *Server) handleAccountPhoneNumberAdd(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
u := v.User()
|
||||
req, err := readJSONWithLimit[apiAccountPhoneNumberAddRequest](r.Body, jsonBodyBytesLimit, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !phoneNumberRegex.MatchString(req.Number) {
|
||||
return errHTTPBadRequestPhoneNumberInvalid
|
||||
}
|
||||
if err := s.verifyPhoneNumberCheck(v, r, req.Number, req.Code); err != nil {
|
||||
return err
|
||||
}
|
||||
logvr(v, r).Tag(tagAccount).Field("phone_number", req.Number).Debug("Adding phone number as verified")
|
||||
if err := s.userManager.AddPhoneNumber(u.ID, req.Number); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.writeJSON(w, newSuccessResponse())
|
||||
}
|
||||
|
||||
func (s *Server) handleAccountPhoneNumberDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
u := v.User()
|
||||
req, err := readJSONWithLimit[apiAccountPhoneNumberAddRequest](r.Body, jsonBodyBytesLimit, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !phoneNumberRegex.MatchString(req.Number) {
|
||||
return errHTTPBadRequestPhoneNumberInvalid
|
||||
}
|
||||
logvr(v, r).Tag(tagAccount).Field("phone_number", req.Number).Debug("Deleting phone number")
|
||||
if err := s.userManager.RemovePhoneNumber(u.ID, req.Number); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.writeJSON(w, newSuccessResponse())
|
||||
}
|
||||
|
||||
// publishSyncEventAsync kicks of a Go routine to publish a sync message to the user's sync topic
|
||||
func (s *Server) publishSyncEventAsync(v *visitor) {
|
||||
go func() {
|
||||
|
||||
@@ -151,6 +151,8 @@ func TestAccount_Get_Anonymous(t *testing.T) {
|
||||
require.Equal(t, int64(1004), account.Stats.MessagesRemaining)
|
||||
require.Equal(t, int64(0), account.Stats.Emails)
|
||||
require.Equal(t, int64(24), account.Stats.EmailsRemaining)
|
||||
require.Equal(t, int64(0), account.Stats.Calls)
|
||||
require.Equal(t, int64(0), account.Stats.CallsRemaining)
|
||||
|
||||
rr = request(t, s, "POST", "/mytopic", "", nil)
|
||||
require.Equal(t, 200, rr.Code)
|
||||
@@ -498,6 +500,8 @@ func TestAccount_Reservation_AddAdminSuccess(t *testing.T) {
|
||||
func TestAccount_Reservation_AddRemoveUserWithTierSuccess(t *testing.T) {
|
||||
conf := newTestConfigWithAuthFile(t)
|
||||
conf.EnableSignup = true
|
||||
conf.EnableReservations = true
|
||||
conf.TwilioAccount = "dummy"
|
||||
s := newTestServer(t, conf)
|
||||
|
||||
// Create user
|
||||
@@ -510,6 +514,7 @@ func TestAccount_Reservation_AddRemoveUserWithTierSuccess(t *testing.T) {
|
||||
MessageLimit: 123,
|
||||
MessageExpiryDuration: 86400 * time.Second,
|
||||
EmailLimit: 32,
|
||||
CallLimit: 10,
|
||||
ReservationLimit: 2,
|
||||
AttachmentFileSizeLimit: 1231231,
|
||||
AttachmentTotalSizeLimit: 123123,
|
||||
@@ -551,6 +556,7 @@ func TestAccount_Reservation_AddRemoveUserWithTierSuccess(t *testing.T) {
|
||||
require.Equal(t, int64(123), account.Limits.Messages)
|
||||
require.Equal(t, int64(86400), account.Limits.MessagesExpiryDuration)
|
||||
require.Equal(t, int64(32), account.Limits.Emails)
|
||||
require.Equal(t, int64(10), account.Limits.Calls)
|
||||
require.Equal(t, int64(2), account.Limits.Reservations)
|
||||
require.Equal(t, int64(1231231), account.Limits.AttachmentFileSize)
|
||||
require.Equal(t, int64(123123), account.Limits.AttachmentTotalSize)
|
||||
|
||||
143
server/server_admin.go
Normal file
143
server/server_admin.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"heckel.io/ntfy/user"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func (s *Server) handleUsersGet(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
users, err := s.userManager.Users()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
grants, err := s.userManager.AllGrants()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
usersResponse := make([]*apiUserResponse, len(users))
|
||||
for i, u := range users {
|
||||
tier := ""
|
||||
if u.Tier != nil {
|
||||
tier = u.Tier.Code
|
||||
}
|
||||
userGrants := make([]*apiUserGrantResponse, len(grants[u.ID]))
|
||||
for i, g := range grants[u.ID] {
|
||||
userGrants[i] = &apiUserGrantResponse{
|
||||
Topic: g.TopicPattern,
|
||||
Permission: g.Allow.String(),
|
||||
}
|
||||
}
|
||||
usersResponse[i] = &apiUserResponse{
|
||||
Username: u.Name,
|
||||
Role: string(u.Role),
|
||||
Tier: tier,
|
||||
Grants: userGrants,
|
||||
}
|
||||
}
|
||||
return s.writeJSON(w, usersResponse)
|
||||
}
|
||||
|
||||
func (s *Server) handleUsersAdd(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
req, err := readJSONWithLimit[apiUserAddRequest](r.Body, jsonBodyBytesLimit, false)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if !user.AllowedUsername(req.Username) || req.Password == "" {
|
||||
return errHTTPBadRequest.Wrap("username invalid, or password missing")
|
||||
}
|
||||
u, err := s.userManager.User(req.Username)
|
||||
if err != nil && err != user.ErrUserNotFound {
|
||||
return err
|
||||
} else if u != nil {
|
||||
return errHTTPConflictUserExists
|
||||
}
|
||||
var tier *user.Tier
|
||||
if req.Tier != "" {
|
||||
tier, err = s.userManager.Tier(req.Tier)
|
||||
if err == user.ErrTierNotFound {
|
||||
return errHTTPBadRequestTierInvalid
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := s.userManager.AddUser(req.Username, req.Password, user.RoleUser); err != nil {
|
||||
return err
|
||||
}
|
||||
if tier != nil {
|
||||
if err := s.userManager.ChangeTier(req.Username, req.Tier); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return s.writeJSON(w, newSuccessResponse())
|
||||
}
|
||||
|
||||
func (s *Server) handleUsersDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
req, err := readJSONWithLimit[apiUserDeleteRequest](r.Body, jsonBodyBytesLimit, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u, err := s.userManager.User(req.Username)
|
||||
if err == user.ErrUserNotFound {
|
||||
return errHTTPBadRequestUserNotFound
|
||||
} else if err != nil {
|
||||
return err
|
||||
} else if !u.IsUser() {
|
||||
return errHTTPUnauthorized.Wrap("can only remove regular users from API")
|
||||
}
|
||||
if err := s.userManager.RemoveUser(req.Username); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.killUserSubscriber(u, "*"); err != nil { // FIXME super inefficient
|
||||
return err
|
||||
}
|
||||
return s.writeJSON(w, newSuccessResponse())
|
||||
}
|
||||
|
||||
func (s *Server) handleAccessAllow(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
req, err := readJSONWithLimit[apiAccessAllowRequest](r.Body, jsonBodyBytesLimit, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = s.userManager.User(req.Username)
|
||||
if err == user.ErrUserNotFound {
|
||||
return errHTTPBadRequestUserNotFound
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
permission, err := user.ParsePermission(req.Permission)
|
||||
if err != nil {
|
||||
return errHTTPBadRequestPermissionInvalid
|
||||
}
|
||||
if err := s.userManager.AllowAccess(req.Username, req.Topic, permission); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.writeJSON(w, newSuccessResponse())
|
||||
}
|
||||
|
||||
func (s *Server) handleAccessReset(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
req, err := readJSONWithLimit[apiAccessResetRequest](r.Body, jsonBodyBytesLimit, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u, err := s.userManager.User(req.Username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.userManager.ResetAccess(req.Username, req.Topic); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.killUserSubscriber(u, req.Topic); err != nil { // This may be a pattern
|
||||
return err
|
||||
}
|
||||
return s.writeJSON(w, newSuccessResponse())
|
||||
}
|
||||
|
||||
func (s *Server) killUserSubscriber(u *user.User, topicPattern string) error {
|
||||
topics, err := s.topicsFromPattern(topicPattern)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, t := range topics {
|
||||
t.CancelSubscriberUser(u.ID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
181
server/server_admin_test.go
Normal file
181
server/server_admin_test.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/user"
|
||||
"heckel.io/ntfy/util"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestUser_AddRemove(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||
defer s.closeDatabases()
|
||||
|
||||
// Create admin, tier
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "tier1",
|
||||
}))
|
||||
|
||||
// Create user via API
|
||||
rr := request(t, s, "PUT", "/v1/users", `{"username": "ben", "password":"ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Create user with tier via API
|
||||
rr = request(t, s, "PUT", "/v1/users", `{"username": "emma", "password":"emma", "tier": "tier1"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Check users
|
||||
users, err := s.userManager.Users()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 4, len(users))
|
||||
require.Equal(t, "phil", users[0].Name)
|
||||
require.Equal(t, "ben", users[1].Name)
|
||||
require.Equal(t, user.RoleUser, users[1].Role)
|
||||
require.Nil(t, users[1].Tier)
|
||||
require.Equal(t, "emma", users[2].Name)
|
||||
require.Equal(t, user.RoleUser, users[2].Role)
|
||||
require.Equal(t, "tier1", users[2].Tier.Code)
|
||||
require.Equal(t, user.Everyone, users[3].Name)
|
||||
|
||||
// Delete user via API
|
||||
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
}
|
||||
|
||||
func TestUser_AddRemove_Failures(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||
defer s.closeDatabases()
|
||||
|
||||
// Create admin
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
|
||||
|
||||
// Cannot create user with invalid username
|
||||
rr := request(t, s, "PUT", "/v1/users", `{"username": "not valid", "password":"ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 400, rr.Code)
|
||||
|
||||
// Cannot create user if user already exists
|
||||
rr = request(t, s, "PUT", "/v1/users", `{"username": "phil", "password":"phil"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 40901, toHTTPError(t, rr.Body.String()).Code)
|
||||
|
||||
// Cannot create user with invalid tier
|
||||
rr = request(t, s, "PUT", "/v1/users", `{"username": "emma", "password":"emma", "tier": "invalid"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 40030, toHTTPError(t, rr.Body.String()).Code)
|
||||
|
||||
// Cannot delete user as non-admin
|
||||
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 401, rr.Code)
|
||||
|
||||
// Delete user via API
|
||||
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
}
|
||||
|
||||
func TestAccess_AllowReset(t *testing.T) {
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.AuthDefault = user.PermissionDenyAll
|
||||
s := newTestServer(t, c)
|
||||
defer s.closeDatabases()
|
||||
|
||||
// User and admin
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
|
||||
|
||||
// Subscribing not allowed
|
||||
rr := request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 403, rr.Code)
|
||||
|
||||
// Grant access
|
||||
rr = request(t, s, "POST", "/v1/users/access", `{"username": "ben", "topic":"gold", "permission":"ro"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Now subscribing is allowed
|
||||
rr = request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Reset access
|
||||
rr = request(t, s, "DELETE", "/v1/users/access", `{"username": "ben", "topic":"gold"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Subscribing not allowed (again)
|
||||
rr = request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 403, rr.Code)
|
||||
}
|
||||
|
||||
func TestAccess_AllowReset_NonAdminAttempt(t *testing.T) {
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.AuthDefault = user.PermissionDenyAll
|
||||
s := newTestServer(t, c)
|
||||
defer s.closeDatabases()
|
||||
|
||||
// User
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
|
||||
|
||||
// Grant access fails, because non-admin
|
||||
rr := request(t, s, "POST", "/v1/users/access", `{"username": "ben", "topic":"gold", "permission":"ro"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 401, rr.Code)
|
||||
}
|
||||
|
||||
func TestAccess_AllowReset_KillConnection(t *testing.T) {
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.AuthDefault = user.PermissionDenyAll
|
||||
s := newTestServer(t, c)
|
||||
defer s.closeDatabases()
|
||||
|
||||
// User and admin, grant access to "gol*" topics
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
|
||||
require.Nil(t, s.userManager.AllowAccess("ben", "gol*", user.PermissionRead)) // Wildcard!
|
||||
|
||||
start, timeTaken := time.Now(), atomic.Int64{}
|
||||
go func() {
|
||||
rr := request(t, s, "GET", "/gold/json", "", map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
timeTaken.Store(time.Since(start).Milliseconds())
|
||||
}()
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Reset access
|
||||
rr := request(t, s, "DELETE", "/v1/users/access", `{"username": "ben", "topic":"gol*"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Wait for connection to be killed; this will fail if the connection is never killed
|
||||
waitFor(t, func() bool {
|
||||
return timeTaken.Load() >= 500
|
||||
})
|
||||
}
|
||||
@@ -144,17 +144,18 @@ func toFirebaseMessage(m *message, auther user.Auther) (*messaging.Message, erro
|
||||
}
|
||||
if allowForward {
|
||||
data = map[string]string{
|
||||
"id": m.ID,
|
||||
"time": fmt.Sprintf("%d", m.Time),
|
||||
"event": m.Event,
|
||||
"topic": m.Topic,
|
||||
"priority": fmt.Sprintf("%d", m.Priority),
|
||||
"tags": strings.Join(m.Tags, ","),
|
||||
"click": m.Click,
|
||||
"icon": m.Icon,
|
||||
"title": m.Title,
|
||||
"message": m.Message,
|
||||
"encoding": m.Encoding,
|
||||
"id": m.ID,
|
||||
"time": fmt.Sprintf("%d", m.Time),
|
||||
"event": m.Event,
|
||||
"topic": m.Topic,
|
||||
"priority": fmt.Sprintf("%d", m.Priority),
|
||||
"tags": strings.Join(m.Tags, ","),
|
||||
"click": m.Click,
|
||||
"icon": m.Icon,
|
||||
"title": m.Title,
|
||||
"message": m.Message,
|
||||
"content_type": m.ContentType,
|
||||
"encoding": m.Encoding,
|
||||
}
|
||||
if len(m.Actions) > 0 {
|
||||
actions, err := json.Marshal(m.Actions)
|
||||
|
||||
@@ -182,6 +182,7 @@ func TestToFirebaseMessage_Message_Normal_Allowed(t *testing.T) {
|
||||
"title": "some title",
|
||||
"message": "this is a message",
|
||||
"actions": `[{"id":"123","action":"view","label":"Open page","clear":true,"url":"https://ntfy.sh"},{"id":"456","action":"http","label":"Close door","clear":false,"url":"https://door.com/close","method":"PUT","headers":{"really":"yes"}}]`,
|
||||
"content_type": "",
|
||||
"encoding": "",
|
||||
"attachment_name": "some file.jpg",
|
||||
"attachment_type": "image/jpeg",
|
||||
@@ -203,6 +204,7 @@ func TestToFirebaseMessage_Message_Normal_Allowed(t *testing.T) {
|
||||
"title": "some title",
|
||||
"message": "this is a message",
|
||||
"actions": `[{"id":"123","action":"view","label":"Open page","clear":true,"url":"https://ntfy.sh"},{"id":"456","action":"http","label":"Close door","clear":false,"url":"https://door.com/close","method":"PUT","headers":{"really":"yes"}}]`,
|
||||
"content_type": "",
|
||||
"encoding": "",
|
||||
"attachment_name": "some file.jpg",
|
||||
"attachment_type": "image/jpeg",
|
||||
|
||||
@@ -15,6 +15,7 @@ func (s *Server) execManager() {
|
||||
s.pruneTokens()
|
||||
s.pruneAttachments()
|
||||
s.pruneMessages()
|
||||
s.pruneAndNotifyWebPushSubscriptions()
|
||||
|
||||
// Message count per topic
|
||||
var messagesCached int
|
||||
@@ -73,9 +74,14 @@ func (s *Server) execManager() {
|
||||
}
|
||||
|
||||
// Print stats
|
||||
s.mu.Lock()
|
||||
s.mu.RLock()
|
||||
messagesCount, topicsCount, visitorsCount := s.messages, len(s.topics), len(s.visitors)
|
||||
s.mu.Unlock()
|
||||
s.mu.RUnlock()
|
||||
|
||||
// Update stats
|
||||
s.updateAndWriteStats(messagesCount)
|
||||
|
||||
// Log stats
|
||||
log.
|
||||
Tag(tagManager).
|
||||
Fields(log.Context{
|
||||
|
||||
@@ -15,6 +15,8 @@ var (
|
||||
metricEmailsPublishedFailure prometheus.Counter
|
||||
metricEmailsReceivedSuccess prometheus.Counter
|
||||
metricEmailsReceivedFailure prometheus.Counter
|
||||
metricCallsMadeSuccess prometheus.Counter
|
||||
metricCallsMadeFailure prometheus.Counter
|
||||
metricUnifiedPushPublishedSuccess prometheus.Counter
|
||||
metricMatrixPublishedSuccess prometheus.Counter
|
||||
metricMatrixPublishedFailure prometheus.Counter
|
||||
@@ -57,6 +59,12 @@ func initMetrics() {
|
||||
metricEmailsReceivedFailure = prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Name: "ntfy_emails_received_failure",
|
||||
})
|
||||
metricCallsMadeSuccess = prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Name: "ntfy_calls_made_success",
|
||||
})
|
||||
metricCallsMadeFailure = prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Name: "ntfy_calls_made_failure",
|
||||
})
|
||||
metricUnifiedPushPublishedSuccess = prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Name: "ntfy_unifiedpush_published_success",
|
||||
})
|
||||
@@ -95,6 +103,8 @@ func initMetrics() {
|
||||
metricEmailsPublishedFailure,
|
||||
metricEmailsReceivedSuccess,
|
||||
metricEmailsReceivedFailure,
|
||||
metricCallsMadeSuccess,
|
||||
metricCallsMadeFailure,
|
||||
metricUnifiedPushPublishedSuccess,
|
||||
metricMatrixPublishedSuccess,
|
||||
metricMatrixPublishedFailure,
|
||||
|
||||
@@ -51,7 +51,16 @@ func (s *Server) limitRequestsWithTopic(next handleFunc) handleFunc {
|
||||
|
||||
func (s *Server) ensureWebEnabled(next handleFunc) handleFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
if !s.config.EnableWeb {
|
||||
if s.config.WebRoot == "" {
|
||||
return errHTTPNotFound
|
||||
}
|
||||
return next(w, r, v)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) ensureWebPushEnabled(next handleFunc) handleFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
if s.config.WebRoot == "" || s.config.WebPushPublicKey == "" {
|
||||
return errHTTPNotFound
|
||||
}
|
||||
return next(w, r, v)
|
||||
@@ -76,6 +85,24 @@ func (s *Server) ensureUser(next handleFunc) handleFunc {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) ensureAdmin(next handleFunc) handleFunc {
|
||||
return s.ensureUserManager(func(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
if !v.User().IsAdmin() {
|
||||
return errHTTPUnauthorized
|
||||
}
|
||||
return next(w, r, v)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) ensureCallsEnabled(next handleFunc) handleFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
if s.config.TwilioAccount == "" || s.userManager == nil {
|
||||
return errHTTPNotFound
|
||||
}
|
||||
return next(w, r, v)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) ensurePaymentsEnabled(next handleFunc) handleFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
if s.config.StripeSecretKey == "" || s.stripe == nil {
|
||||
|
||||
@@ -68,6 +68,7 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _
|
||||
Messages: freeTier.MessageLimit,
|
||||
MessagesExpiryDuration: int64(freeTier.MessageExpiryDuration.Seconds()),
|
||||
Emails: freeTier.EmailLimit,
|
||||
Calls: freeTier.CallLimit,
|
||||
Reservations: freeTier.ReservationsLimit,
|
||||
AttachmentTotalSize: freeTier.AttachmentTotalSizeLimit,
|
||||
AttachmentFileSize: freeTier.AttachmentFileSizeLimit,
|
||||
@@ -96,6 +97,7 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _
|
||||
Messages: tier.MessageLimit,
|
||||
MessagesExpiryDuration: int64(tier.MessageExpiryDuration.Seconds()),
|
||||
Emails: tier.EmailLimit,
|
||||
Calls: tier.CallLimit,
|
||||
Reservations: tier.ReservationLimit,
|
||||
AttachmentTotalSize: tier.AttachmentTotalSizeLimit,
|
||||
AttachmentFileSize: tier.AttachmentFileSizeLimit,
|
||||
|
||||
@@ -18,11 +18,11 @@ import (
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/SherClockHolmes/webpush-go"
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/log"
|
||||
"heckel.io/ntfy/util"
|
||||
@@ -220,11 +220,7 @@ func TestServer_StaticSites(t *testing.T) {
|
||||
|
||||
rr = request(t, s, "GET", "/mytopic", "", nil)
|
||||
require.Equal(t, 200, rr.Code)
|
||||
require.Contains(t, rr.Body.String(), `<meta name="robots" content="noindex, nofollow"/>`)
|
||||
|
||||
rr = request(t, s, "GET", "/static/css/home.css", "", nil)
|
||||
require.Equal(t, 200, rr.Code)
|
||||
require.Contains(t, rr.Body.String(), `/* general styling */`)
|
||||
require.Contains(t, rr.Body.String(), `<meta name="robots" content="noindex, nofollow" />`)
|
||||
|
||||
rr = request(t, s, "GET", "/docs", "", nil)
|
||||
require.Equal(t, 301, rr.Code)
|
||||
@@ -234,7 +230,7 @@ func TestServer_StaticSites(t *testing.T) {
|
||||
|
||||
func TestServer_WebEnabled(t *testing.T) {
|
||||
conf := newTestConfig(t)
|
||||
conf.EnableWeb = false
|
||||
conf.WebRoot = "" // Disable web app
|
||||
s := newTestServer(t, conf)
|
||||
|
||||
rr := request(t, s, "GET", "/", "", nil)
|
||||
@@ -243,11 +239,17 @@ func TestServer_WebEnabled(t *testing.T) {
|
||||
rr = request(t, s, "GET", "/config.js", "", nil)
|
||||
require.Equal(t, 404, rr.Code)
|
||||
|
||||
rr = request(t, s, "GET", "/sw.js", "", nil)
|
||||
require.Equal(t, 404, rr.Code)
|
||||
|
||||
rr = request(t, s, "GET", "/app.html", "", nil)
|
||||
require.Equal(t, 404, rr.Code)
|
||||
|
||||
rr = request(t, s, "GET", "/static/css/home.css", "", nil)
|
||||
require.Equal(t, 404, rr.Code)
|
||||
|
||||
conf2 := newTestConfig(t)
|
||||
conf2.EnableWeb = true
|
||||
conf2.WebRoot = "/"
|
||||
s2 := newTestServer(t, conf2)
|
||||
|
||||
rr = request(t, s2, "GET", "/", "", nil)
|
||||
@@ -256,8 +258,34 @@ func TestServer_WebEnabled(t *testing.T) {
|
||||
rr = request(t, s2, "GET", "/config.js", "", nil)
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
rr = request(t, s2, "GET", "/static/css/home.css", "", nil)
|
||||
rr = request(t, s2, "GET", "/sw.js", "", nil)
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
rr = request(t, s2, "GET", "/app.html", "", nil)
|
||||
require.Equal(t, 200, rr.Code)
|
||||
}
|
||||
|
||||
func TestServer_WebPushEnabled(t *testing.T) {
|
||||
conf := newTestConfig(t)
|
||||
conf.WebRoot = "" // Disable web app
|
||||
s := newTestServer(t, conf)
|
||||
|
||||
rr := request(t, s, "GET", "/manifest.webmanifest", "", nil)
|
||||
require.Equal(t, 404, rr.Code)
|
||||
|
||||
conf2 := newTestConfig(t)
|
||||
s2 := newTestServer(t, conf2)
|
||||
|
||||
rr = request(t, s2, "GET", "/manifest.webmanifest", "", nil)
|
||||
require.Equal(t, 404, rr.Code)
|
||||
|
||||
conf3 := newTestConfigWithWebPush(t)
|
||||
s3 := newTestServer(t, conf3)
|
||||
|
||||
rr = request(t, s3, "GET", "/manifest.webmanifest", "", nil)
|
||||
require.Equal(t, 200, rr.Code)
|
||||
require.Equal(t, "application/manifest+json", rr.Header().Get("Content-Type"))
|
||||
|
||||
}
|
||||
|
||||
func TestServer_PublishLargeMessage(t *testing.T) {
|
||||
@@ -301,6 +329,27 @@ func TestServer_PublishPriority(t *testing.T) {
|
||||
require.Equal(t, 40007, toHTTPError(t, response.Body.String()).Code)
|
||||
}
|
||||
|
||||
func TestServer_PublishPriority_SpecialHTTPHeader(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfig(t))
|
||||
|
||||
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
|
||||
"Priority": "u=4",
|
||||
"X-Priority": "5",
|
||||
})
|
||||
require.Equal(t, 5, toMessage(t, response.Body.String()).Priority)
|
||||
|
||||
response = request(t, s, "POST", "/mytopic?priority=4", "test", map[string]string{
|
||||
"Priority": "u=9",
|
||||
})
|
||||
require.Equal(t, 4, toMessage(t, response.Body.String()).Priority)
|
||||
|
||||
response = request(t, s, "POST", "/mytopic", "test", map[string]string{
|
||||
"p": "2",
|
||||
"priority": "u=9, i",
|
||||
})
|
||||
require.Equal(t, 2, toMessage(t, response.Body.String()).Priority)
|
||||
}
|
||||
|
||||
func TestServer_PublishGETOnlyOneTopic(t *testing.T) {
|
||||
// This tests a bug that allowed publishing topics with a comma in the name (no ticket)
|
||||
|
||||
@@ -463,6 +512,8 @@ func TestServer_PublishAtAndPrune(t *testing.T) {
|
||||
messages := toMessages(t, response.Body.String())
|
||||
require.Equal(t, 1, len(messages)) // Not affected by pruning
|
||||
require.Equal(t, "a message", messages[0].Message)
|
||||
|
||||
time.Sleep(time.Second) // FIXME CI failing not sure why
|
||||
}
|
||||
|
||||
func TestServer_PublishAndMultiPoll(t *testing.T) {
|
||||
@@ -1199,7 +1250,20 @@ func TestServer_PublishDelayedEmail_Fail(t *testing.T) {
|
||||
"E-Mail": "test@example.com",
|
||||
"Delay": "20 min",
|
||||
})
|
||||
require.Equal(t, 400, response.Code)
|
||||
require.Equal(t, 40003, toHTTPError(t, response.Body.String()).Code)
|
||||
}
|
||||
|
||||
func TestServer_PublishDelayedCall_Fail(t *testing.T) {
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.TwilioAccount = "AC1234567890"
|
||||
c.TwilioAuthToken = "AAEAA1234567890"
|
||||
c.TwilioPhoneNumber = "+1234567890"
|
||||
s := newTestServer(t, c)
|
||||
response := request(t, s, "PUT", "/mytopic", "fail", map[string]string{
|
||||
"Call": "yes",
|
||||
"Delay": "20 min",
|
||||
})
|
||||
require.Equal(t, 40037, toHTTPError(t, response.Body.String()).Code)
|
||||
}
|
||||
|
||||
func TestServer_PublishEmailNoMailer_Fail(t *testing.T) {
|
||||
@@ -1477,6 +1541,39 @@ func TestServer_PublishActions_AndPoll(t *testing.T) {
|
||||
require.Equal(t, "target_temp_f=65", m.Actions[1].Body)
|
||||
}
|
||||
|
||||
func TestServer_PublishMarkdown(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfig(t))
|
||||
response := request(t, s, "PUT", "/mytopic", "**make this bold**", map[string]string{
|
||||
"Content-Type": "text/markdown",
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
|
||||
m := toMessage(t, response.Body.String())
|
||||
require.Equal(t, "**make this bold**", m.Message)
|
||||
require.Equal(t, "text/markdown", m.ContentType)
|
||||
}
|
||||
|
||||
func TestServer_PublishMarkdown_QueryParam(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfig(t))
|
||||
response := request(t, s, "PUT", "/mytopic?md=1", "**make this bold**", nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
|
||||
m := toMessage(t, response.Body.String())
|
||||
require.Equal(t, "**make this bold**", m.Message)
|
||||
require.Equal(t, "text/markdown", m.ContentType)
|
||||
}
|
||||
|
||||
func TestServer_PublishMarkdown_NotMarkdown(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfig(t))
|
||||
response := request(t, s, "PUT", "/mytopic", "**make this bold**", map[string]string{
|
||||
"Content-Type": "not-markdown",
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
|
||||
m := toMessage(t, response.Body.String())
|
||||
require.Equal(t, "", m.ContentType)
|
||||
}
|
||||
|
||||
func TestServer_PublishAsJSON(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfig(t))
|
||||
body := `{"topic":"mytopic","message":"A message","title":"a title\nwith lines","tags":["tag1","tag 2"],` +
|
||||
@@ -1494,12 +1591,25 @@ func TestServer_PublishAsJSON(t *testing.T) {
|
||||
require.Equal(t, "google.pdf", m.Attachment.Name)
|
||||
require.Equal(t, "http://ntfy.sh", m.Click)
|
||||
require.Equal(t, "https://ntfy.sh/static/img/ntfy.png", m.Icon)
|
||||
require.Equal(t, "", m.ContentType)
|
||||
|
||||
require.Equal(t, 4, m.Priority)
|
||||
require.True(t, m.Time > time.Now().Unix()+29*60)
|
||||
require.True(t, m.Time < time.Now().Unix()+31*60)
|
||||
}
|
||||
|
||||
func TestServer_PublishAsJSON_Markdown(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfig(t))
|
||||
body := `{"topic":"mytopic","message":"**This is bold**","markdown":true}`
|
||||
response := request(t, s, "PUT", "/", body, nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
|
||||
m := toMessage(t, response.Body.String())
|
||||
require.Equal(t, "mytopic", m.Topic)
|
||||
require.Equal(t, "**This is bold**", m.Message)
|
||||
require.Equal(t, "text/markdown", m.ContentType)
|
||||
}
|
||||
|
||||
func TestServer_PublishAsJSON_RateLimit_MessageDailyLimit(t *testing.T) {
|
||||
// Publishing as JSON follows a different path. This ensures that rate
|
||||
// limiting works for this endpoint as well
|
||||
@@ -2106,8 +2216,8 @@ func TestServer_PublishWhileUpdatingStatsWithLotsOfMessages(t *testing.T) {
|
||||
start = time.Now()
|
||||
response := request(t, s, "PUT", "/mytopic", "some body", nil)
|
||||
m := toMessage(t, response.Body.String())
|
||||
assert.Equal(t, "some body", m.Message)
|
||||
assert.True(t, time.Since(start) < 100*time.Millisecond)
|
||||
require.Equal(t, "some body", m.Message)
|
||||
require.True(t, time.Since(start) < 100*time.Millisecond)
|
||||
log.Info("Done: Publishing message; took %s", time.Since(start).Round(time.Millisecond))
|
||||
|
||||
// Wait for all goroutines
|
||||
@@ -2399,6 +2509,184 @@ func TestServer_SubscriberRateLimiting_ProtectedTopics_WithDefaultReadWrite(t *t
|
||||
require.Nil(t, s.topics["announcements"].rateVisitor)
|
||||
}
|
||||
|
||||
func TestServer_MessageHistoryAndStatsEndpoint(t *testing.T) {
|
||||
c := newTestConfig(t)
|
||||
c.ManagerInterval = 2 * time.Second
|
||||
s := newTestServer(t, c)
|
||||
|
||||
// Publish some messages, and get stats
|
||||
for i := 0; i < 5; i++ {
|
||||
response := request(t, s, "POST", "/mytopic", "some message", nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
}
|
||||
require.Equal(t, int64(5), s.messages)
|
||||
require.Equal(t, []int64{0}, s.messagesHistory)
|
||||
|
||||
response := request(t, s, "GET", "/v1/stats", "", nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.Equal(t, `{"messages":5,"messages_rate":0}`+"\n", response.Body.String())
|
||||
|
||||
// Run manager and see message history update
|
||||
s.execManager()
|
||||
require.Equal(t, []int64{0, 5}, s.messagesHistory)
|
||||
|
||||
response = request(t, s, "GET", "/v1/stats", "", nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.Equal(t, `{"messages":5,"messages_rate":2.5}`+"\n", response.Body.String()) // 5 messages in 2 seconds = 2.5 messages per second
|
||||
|
||||
// Publish some more messages
|
||||
for i := 0; i < 10; i++ {
|
||||
response := request(t, s, "POST", "/mytopic", "some message", nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
}
|
||||
require.Equal(t, int64(15), s.messages)
|
||||
require.Equal(t, []int64{0, 5}, s.messagesHistory)
|
||||
|
||||
response = request(t, s, "GET", "/v1/stats", "", nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.Equal(t, `{"messages":15,"messages_rate":2.5}`+"\n", response.Body.String()) // Rate did not update yet
|
||||
|
||||
// Run manager and see message history update
|
||||
s.execManager()
|
||||
require.Equal(t, []int64{0, 5, 15}, s.messagesHistory)
|
||||
|
||||
response = request(t, s, "GET", "/v1/stats", "", nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.Equal(t, `{"messages":15,"messages_rate":3.75}`+"\n", response.Body.String()) // 15 messages in 4 seconds = 3.75 messages per second
|
||||
}
|
||||
|
||||
func TestServer_MessageHistoryMaxSize(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfig(t))
|
||||
for i := 0; i < 20; i++ {
|
||||
s.messages = int64(i)
|
||||
s.execManager()
|
||||
}
|
||||
require.Equal(t, []int64{10, 11, 12, 13, 14, 15, 16, 17, 18, 19}, s.messagesHistory)
|
||||
}
|
||||
|
||||
func TestServer_MessageCountPersistence(t *testing.T) {
|
||||
c := newTestConfig(t)
|
||||
s := newTestServer(t, c)
|
||||
s.messages = 1234
|
||||
s.execManager()
|
||||
waitFor(t, func() bool {
|
||||
messages, err := s.messageCache.Stats()
|
||||
require.Nil(t, err)
|
||||
return messages == 1234
|
||||
})
|
||||
|
||||
s = newTestServer(t, c)
|
||||
require.Equal(t, int64(1234), s.messages)
|
||||
}
|
||||
|
||||
func TestServer_PublishWithUTF8MimeHeader(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfig(t))
|
||||
|
||||
response := request(t, s, "POST", "/mytopic", "some attachment", map[string]string{
|
||||
"X-Filename": "some =?UTF-8?q?=C3=A4?=ttachment.txt",
|
||||
"X-Message": "=?UTF-8?B?8J+HqfCfh6o=?=",
|
||||
"X-Title": "=?UTF-8?B?bnRmeSDlvojmo5I=?=, no really I mean it! =?UTF-8?Q?This is q=C3=BC=C3=B6ted-print=C3=A4ble.?=",
|
||||
"X-Tags": "=?UTF-8?B?8J+HqfCfh6o=?=, =?UTF-8?B?bnRmeSDlvojmo5I=?=",
|
||||
"X-Click": "=?uTf-8?b?aHR0cHM6Ly/wn5KpLmxh?=",
|
||||
"X-Actions": "http, \"=?utf-8?q?Mettre =C3=A0 jour?=\", \"https://my.tld/webhook/netbird-update\"; =?utf-8?b?aHR0cCwg6L+Z5piv5LiA5Liq5qCH562+LCBodHRwczovL/CfkqkubGE=?=",
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
m := toMessage(t, response.Body.String())
|
||||
require.Equal(t, "🇩🇪", m.Message)
|
||||
require.Equal(t, "ntfy 很棒, no really I mean it! This is qüöted-printäble.", m.Title)
|
||||
require.Equal(t, "some ättachment.txt", m.Attachment.Name)
|
||||
require.Equal(t, "🇩🇪", m.Tags[0])
|
||||
require.Equal(t, "ntfy 很棒", m.Tags[1])
|
||||
require.Equal(t, "https://💩.la", m.Click)
|
||||
require.Equal(t, "Mettre à jour", m.Actions[0].Label)
|
||||
require.Equal(t, "http", m.Actions[1].Action)
|
||||
require.Equal(t, "这是一个标签", m.Actions[1].Label)
|
||||
require.Equal(t, "https://💩.la", m.Actions[1].URL)
|
||||
}
|
||||
|
||||
func TestServer_UpstreamBaseURL_Success(t *testing.T) {
|
||||
var pollID atomic.Pointer[string]
|
||||
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "/87c9cddf7b0105f5fe849bf084c6e600be0fde99be3223335199b4965bd7b735", r.URL.Path)
|
||||
require.Equal(t, "", string(body))
|
||||
require.NotEmpty(t, r.Header.Get("X-Poll-ID"))
|
||||
pollID.Store(util.String(r.Header.Get("X-Poll-ID")))
|
||||
}))
|
||||
defer upstreamServer.Close()
|
||||
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.BaseURL = "http://myserver.internal"
|
||||
c.UpstreamBaseURL = upstreamServer.URL
|
||||
s := newTestServer(t, c)
|
||||
|
||||
// Send message, and wait for upstream server to receive it
|
||||
response := request(t, s, "PUT", "/mytopic", `hi there`, nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
m := toMessage(t, response.Body.String())
|
||||
require.NotEmpty(t, m.ID)
|
||||
require.Equal(t, "hi there", m.Message)
|
||||
waitFor(t, func() bool {
|
||||
pID := pollID.Load()
|
||||
return pID != nil && *pID == m.ID
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_UpstreamBaseURL_With_Access_Token_Success(t *testing.T) {
|
||||
var pollID atomic.Pointer[string]
|
||||
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "/a1c72bcb4daf5af54d13ef86aea8f76c11e8b88320d55f1811d5d7b173bcc1df", r.URL.Path)
|
||||
require.Equal(t, "Bearer tk_1234567890", r.Header.Get("Authorization"))
|
||||
require.Equal(t, "", string(body))
|
||||
require.NotEmpty(t, r.Header.Get("X-Poll-ID"))
|
||||
pollID.Store(util.String(r.Header.Get("X-Poll-ID")))
|
||||
}))
|
||||
defer upstreamServer.Close()
|
||||
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.BaseURL = "http://myserver.internal"
|
||||
c.UpstreamBaseURL = upstreamServer.URL
|
||||
c.UpstreamAccessToken = "tk_1234567890"
|
||||
s := newTestServer(t, c)
|
||||
|
||||
// Send message, and wait for upstream server to receive it
|
||||
response := request(t, s, "PUT", "/mytopic1", `hi there`, nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
m := toMessage(t, response.Body.String())
|
||||
require.NotEmpty(t, m.ID)
|
||||
require.Equal(t, "hi there", m.Message)
|
||||
waitFor(t, func() bool {
|
||||
pID := pollID.Load()
|
||||
return pID != nil && *pID == m.ID
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_UpstreamBaseURL_DoNotForwardUnifiedPush(t *testing.T) {
|
||||
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Fatal("UnifiedPush messages should not be forwarded")
|
||||
}))
|
||||
defer upstreamServer.Close()
|
||||
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.BaseURL = "http://myserver.internal"
|
||||
c.UpstreamBaseURL = upstreamServer.URL
|
||||
s := newTestServer(t, c)
|
||||
|
||||
// Send UP message, this should not forward to upstream server
|
||||
response := request(t, s, "PUT", "/mytopic?up=1", `hi there`, nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
m := toMessage(t, response.Body.String())
|
||||
require.NotEmpty(t, m.ID)
|
||||
require.Equal(t, "hi there", m.Message)
|
||||
|
||||
// Forwarding is done asynchronously, so wait a bit.
|
||||
// This ensures that the t.Fatal above is actually not triggered.
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
|
||||
func newTestConfig(t *testing.T) *Config {
|
||||
conf := NewConfig()
|
||||
conf.BaseURL = "http://127.0.0.1:12345"
|
||||
@@ -2408,19 +2696,33 @@ func newTestConfig(t *testing.T) *Config {
|
||||
return conf
|
||||
}
|
||||
|
||||
func newTestConfigWithAuthFile(t *testing.T) *Config {
|
||||
conf := newTestConfig(t)
|
||||
func configureAuth(t *testing.T, conf *Config) *Config {
|
||||
conf.AuthFile = filepath.Join(t.TempDir(), "user.db")
|
||||
conf.AuthStartupQueries = "pragma journal_mode = WAL; pragma synchronous = normal; pragma temp_store = memory;"
|
||||
conf.AuthBcryptCost = bcrypt.MinCost // This speeds up tests a lot
|
||||
return conf
|
||||
}
|
||||
|
||||
func newTestConfigWithAuthFile(t *testing.T) *Config {
|
||||
conf := newTestConfig(t)
|
||||
conf = configureAuth(t, conf)
|
||||
return conf
|
||||
}
|
||||
|
||||
func newTestConfigWithWebPush(t *testing.T) *Config {
|
||||
conf := newTestConfig(t)
|
||||
privateKey, publicKey, err := webpush.GenerateVAPIDKeys()
|
||||
require.Nil(t, err)
|
||||
conf.WebPushFile = filepath.Join(t.TempDir(), "webpush.db")
|
||||
conf.WebPushEmailAddress = "testing@example.com"
|
||||
conf.WebPushPrivateKey = privateKey
|
||||
conf.WebPushPublicKey = publicKey
|
||||
return conf
|
||||
}
|
||||
|
||||
func newTestServer(t *testing.T, config *Config) *Server {
|
||||
server, err := New(config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.Nil(t, err)
|
||||
return server
|
||||
}
|
||||
|
||||
@@ -2500,7 +2802,7 @@ func waitForWithMaxWait(t *testing.T, maxWait time.Duration, f func() bool) {
|
||||
if f() {
|
||||
return
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
t.Fatalf("Function f did not succeed after %v: %v", maxWait, string(debug.Stack()))
|
||||
}
|
||||
|
||||
176
server/server_twilio.go
Normal file
176
server/server_twilio.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"heckel.io/ntfy/log"
|
||||
"heckel.io/ntfy/user"
|
||||
"heckel.io/ntfy/util"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
twilioCallFormat = `
|
||||
<Response>
|
||||
<Pause length="1"/>
|
||||
<Say loop="3">
|
||||
You have a message from notify on topic %s. Message:
|
||||
<break time="1s"/>
|
||||
%s
|
||||
<break time="1s"/>
|
||||
End of message.
|
||||
<break time="1s"/>
|
||||
This message was sent by user %s. It will be repeated three times.
|
||||
To unsubscribe from calls like this, remove your phone number in the notify web app.
|
||||
<break time="3s"/>
|
||||
</Say>
|
||||
<Say>Goodbye.</Say>
|
||||
</Response>`
|
||||
)
|
||||
|
||||
// convertPhoneNumber checks if the given phone number is verified for the given user, and if so, returns the verified
|
||||
// phone number. It also converts a boolean string ("yes", "1", "true") to the first verified phone number.
|
||||
// If the user is anonymous, it will return an error.
|
||||
func (s *Server) convertPhoneNumber(u *user.User, phoneNumber string) (string, *errHTTP) {
|
||||
if u == nil {
|
||||
return "", errHTTPBadRequestAnonymousCallsNotAllowed
|
||||
}
|
||||
phoneNumbers, err := s.userManager.PhoneNumbers(u.ID)
|
||||
if err != nil {
|
||||
return "", errHTTPInternalError
|
||||
} else if len(phoneNumbers) == 0 {
|
||||
return "", errHTTPBadRequestPhoneNumberNotVerified
|
||||
}
|
||||
if toBool(phoneNumber) {
|
||||
return phoneNumbers[0], nil
|
||||
} else if util.Contains(phoneNumbers, phoneNumber) {
|
||||
return phoneNumber, nil
|
||||
}
|
||||
for _, p := range phoneNumbers {
|
||||
if p == phoneNumber {
|
||||
return phoneNumber, nil
|
||||
}
|
||||
}
|
||||
return "", errHTTPBadRequestPhoneNumberNotVerified
|
||||
}
|
||||
|
||||
// callPhone calls the Twilio API to make a phone call to the given phone number, using the given message.
|
||||
// Failures will be logged, but not returned to the caller.
|
||||
func (s *Server) callPhone(v *visitor, r *http.Request, m *message, to string) {
|
||||
u, sender := v.User(), m.Sender.String()
|
||||
if u != nil {
|
||||
sender = u.Name
|
||||
}
|
||||
body := fmt.Sprintf(twilioCallFormat, xmlEscapeText(m.Topic), xmlEscapeText(m.Message), xmlEscapeText(sender))
|
||||
data := url.Values{}
|
||||
data.Set("From", s.config.TwilioPhoneNumber)
|
||||
data.Set("To", to)
|
||||
data.Set("Twiml", body)
|
||||
ev := logvrm(v, r, m).Tag(tagTwilio).Field("twilio_to", to).FieldIf("twilio_body", body, log.TraceLevel).Debug("Sending Twilio request")
|
||||
response, err := s.callPhoneInternal(data)
|
||||
if err != nil {
|
||||
ev.Field("twilio_response", response).Err(err).Warn("Error sending Twilio request")
|
||||
minc(metricCallsMadeFailure)
|
||||
return
|
||||
}
|
||||
ev.FieldIf("twilio_response", response, log.TraceLevel).Debug("Received successful Twilio response")
|
||||
minc(metricCallsMadeSuccess)
|
||||
}
|
||||
|
||||
func (s *Server) callPhoneInternal(data url.Values) (string, error) {
|
||||
requestURL := fmt.Sprintf("%s/2010-04-01/Accounts/%s/Calls.json", s.config.TwilioCallsBaseURL, s.config.TwilioAccount)
|
||||
req, err := http.NewRequest(http.MethodPost, requestURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("User-Agent", "ntfy/"+s.config.Version)
|
||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Authorization", util.BasicAuth(s.config.TwilioAccount, s.config.TwilioAuthToken))
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
response, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(response), nil
|
||||
}
|
||||
|
||||
func (s *Server) verifyPhoneNumber(v *visitor, r *http.Request, phoneNumber, channel string) error {
|
||||
ev := logvr(v, r).Tag(tagTwilio).Field("twilio_to", phoneNumber).Field("twilio_channel", channel).Debug("Sending phone verification")
|
||||
data := url.Values{}
|
||||
data.Set("To", phoneNumber)
|
||||
data.Set("Channel", channel)
|
||||
requestURL := fmt.Sprintf("%s/v2/Services/%s/Verifications", s.config.TwilioVerifyBaseURL, s.config.TwilioVerifyService)
|
||||
req, err := http.NewRequest(http.MethodPost, requestURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("User-Agent", "ntfy/"+s.config.Version)
|
||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Authorization", util.BasicAuth(s.config.TwilioAccount, s.config.TwilioAuthToken))
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
response, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
ev.Err(err).Warn("Error sending Twilio phone verification request")
|
||||
return err
|
||||
}
|
||||
ev.FieldIf("twilio_response", string(response), log.TraceLevel).Debug("Received Twilio phone verification response")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) verifyPhoneNumberCheck(v *visitor, r *http.Request, phoneNumber, code string) error {
|
||||
ev := logvr(v, r).Tag(tagTwilio).Field("twilio_to", phoneNumber).Debug("Checking phone verification")
|
||||
data := url.Values{}
|
||||
data.Set("To", phoneNumber)
|
||||
data.Set("Code", code)
|
||||
requestURL := fmt.Sprintf("%s/v2/Services/%s/VerificationCheck", s.config.TwilioVerifyBaseURL, s.config.TwilioVerifyService)
|
||||
req, err := http.NewRequest(http.MethodPost, requestURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("User-Agent", "ntfy/"+s.config.Version)
|
||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Authorization", util.BasicAuth(s.config.TwilioAccount, s.config.TwilioAuthToken))
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if resp.StatusCode != http.StatusOK {
|
||||
if ev.IsTrace() {
|
||||
response, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ev.Field("twilio_response", string(response))
|
||||
}
|
||||
ev.Warn("Twilio phone verification failed with status code %d", resp.StatusCode)
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return errHTTPGonePhoneVerificationExpired
|
||||
}
|
||||
return errHTTPInternalError
|
||||
}
|
||||
response, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ev.IsTrace() {
|
||||
ev.Field("twilio_response", string(response)).Trace("Received successful Twilio phone verification response")
|
||||
} else if ev.IsDebug() {
|
||||
ev.Debug("Received successful Twilio phone verification response")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func xmlEscapeText(text string) string {
|
||||
var buf bytes.Buffer
|
||||
_ = xml.EscapeText(&buf, []byte(text))
|
||||
return buf.String()
|
||||
}
|
||||
264
server/server_twilio_test.go
Normal file
264
server/server_twilio_test.go
Normal file
@@ -0,0 +1,264 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/user"
|
||||
"heckel.io/ntfy/util"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestServer_Twilio_Call_Add_Verify_Call_Delete_Success(t *testing.T) {
|
||||
var called, verified atomic.Bool
|
||||
var code atomic.Pointer[string]
|
||||
twilioVerifyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
|
||||
if r.URL.Path == "/v2/Services/VA1234567890/Verifications" {
|
||||
if code.Load() != nil {
|
||||
t.Fatal("Should be only called once")
|
||||
}
|
||||
require.Equal(t, "Channel=sms&To=%2B12223334444", string(body))
|
||||
code.Store(util.String("123456"))
|
||||
} else if r.URL.Path == "/v2/Services/VA1234567890/VerificationCheck" {
|
||||
if verified.Load() {
|
||||
t.Fatal("Should be only called once")
|
||||
}
|
||||
require.Equal(t, "Code=123456&To=%2B12223334444", string(body))
|
||||
verified.Store(true)
|
||||
} else {
|
||||
t.Fatal("Unexpected path:", r.URL.Path)
|
||||
}
|
||||
}))
|
||||
defer twilioVerifyServer.Close()
|
||||
twilioCallsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if called.Load() {
|
||||
t.Fatal("Should be only called once")
|
||||
}
|
||||
body, err := io.ReadAll(r.Body)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
|
||||
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
|
||||
require.Equal(t, "From=%2B1234567890&To=%2B12223334444&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+loop%3D%223%22%3E%0A%09%09You+have+a+message+from+notify+on+topic+mytopic.+Message%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09End+of+message.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09This+message+was+sent+by+user+phil.+It+will+be+repeated+three+times.%0A%09%09To+unsubscribe+from+calls+like+this%2C+remove+your+phone+number+in+the+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay%3EGoodbye.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
|
||||
called.Store(true)
|
||||
}))
|
||||
defer twilioCallsServer.Close()
|
||||
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.TwilioVerifyBaseURL = twilioVerifyServer.URL
|
||||
c.TwilioCallsBaseURL = twilioCallsServer.URL
|
||||
c.TwilioAccount = "AC1234567890"
|
||||
c.TwilioAuthToken = "AAEAA1234567890"
|
||||
c.TwilioPhoneNumber = "+1234567890"
|
||||
c.TwilioVerifyService = "VA1234567890"
|
||||
s := newTestServer(t, c)
|
||||
|
||||
// Add tier and user
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "pro",
|
||||
MessageLimit: 10,
|
||||
CallLimit: 1,
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
u, err := s.userManager.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
||||
// Send verification code for phone number
|
||||
response := request(t, s, "PUT", "/v1/account/phone/verify", `{"number":"+12223334444","channel":"sms"}`, map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
waitFor(t, func() bool {
|
||||
return *code.Load() == "123456"
|
||||
})
|
||||
|
||||
// Add phone number with code
|
||||
response = request(t, s, "PUT", "/v1/account/phone", `{"number":"+12223334444","code":"123456"}`, map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
waitFor(t, func() bool {
|
||||
return verified.Load()
|
||||
})
|
||||
phoneNumbers, err := s.userManager.PhoneNumbers(u.ID)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(phoneNumbers))
|
||||
require.Equal(t, "+12223334444", phoneNumbers[0])
|
||||
|
||||
// Do the thing
|
||||
response = request(t, s, "POST", "/mytopic", "hi there", map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
"x-call": "yes",
|
||||
})
|
||||
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
|
||||
waitFor(t, func() bool {
|
||||
return called.Load()
|
||||
})
|
||||
|
||||
// Remove the phone number
|
||||
response = request(t, s, "DELETE", "/v1/account/phone", `{"number":"+12223334444"}`, map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
|
||||
// Verify the phone number is gone from the DB
|
||||
phoneNumbers, err = s.userManager.PhoneNumbers(u.ID)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 0, len(phoneNumbers))
|
||||
}
|
||||
|
||||
func TestServer_Twilio_Call_Success(t *testing.T) {
|
||||
var called atomic.Bool
|
||||
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if called.Load() {
|
||||
t.Fatal("Should be only called once")
|
||||
}
|
||||
body, err := io.ReadAll(r.Body)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
|
||||
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
|
||||
require.Equal(t, "From=%2B1234567890&To=%2B11122233344&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+loop%3D%223%22%3E%0A%09%09You+have+a+message+from+notify+on+topic+mytopic.+Message%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09End+of+message.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09This+message+was+sent+by+user+phil.+It+will+be+repeated+three+times.%0A%09%09To+unsubscribe+from+calls+like+this%2C+remove+your+phone+number+in+the+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay%3EGoodbye.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
|
||||
called.Store(true)
|
||||
}))
|
||||
defer twilioServer.Close()
|
||||
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.TwilioCallsBaseURL = twilioServer.URL
|
||||
c.TwilioAccount = "AC1234567890"
|
||||
c.TwilioAuthToken = "AAEAA1234567890"
|
||||
c.TwilioPhoneNumber = "+1234567890"
|
||||
s := newTestServer(t, c)
|
||||
|
||||
// Add tier and user
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "pro",
|
||||
MessageLimit: 10,
|
||||
CallLimit: 1,
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
u, err := s.userManager.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, s.userManager.AddPhoneNumber(u.ID, "+11122233344"))
|
||||
|
||||
// Do the thing
|
||||
response := request(t, s, "POST", "/mytopic", "hi there", map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
"x-call": "+11122233344",
|
||||
})
|
||||
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
|
||||
waitFor(t, func() bool {
|
||||
return called.Load()
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_Twilio_Call_Success_With_Yes(t *testing.T) {
|
||||
var called atomic.Bool
|
||||
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if called.Load() {
|
||||
t.Fatal("Should be only called once")
|
||||
}
|
||||
body, err := io.ReadAll(r.Body)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
|
||||
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
|
||||
require.Equal(t, "From=%2B1234567890&To=%2B11122233344&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+loop%3D%223%22%3E%0A%09%09You+have+a+message+from+notify+on+topic+mytopic.+Message%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09End+of+message.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09This+message+was+sent+by+user+phil.+It+will+be+repeated+three+times.%0A%09%09To+unsubscribe+from+calls+like+this%2C+remove+your+phone+number+in+the+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay%3EGoodbye.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
|
||||
called.Store(true)
|
||||
}))
|
||||
defer twilioServer.Close()
|
||||
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.TwilioCallsBaseURL = twilioServer.URL
|
||||
c.TwilioAccount = "AC1234567890"
|
||||
c.TwilioAuthToken = "AAEAA1234567890"
|
||||
c.TwilioPhoneNumber = "+1234567890"
|
||||
s := newTestServer(t, c)
|
||||
|
||||
// Add tier and user
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "pro",
|
||||
MessageLimit: 10,
|
||||
CallLimit: 1,
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
u, err := s.userManager.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, s.userManager.AddPhoneNumber(u.ID, "+11122233344"))
|
||||
|
||||
// Do the thing
|
||||
response := request(t, s, "POST", "/mytopic", "hi there", map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
"x-call": "yes", // <<<------
|
||||
})
|
||||
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
|
||||
waitFor(t, func() bool {
|
||||
return called.Load()
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_Twilio_Call_UnverifiedNumber(t *testing.T) {
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.TwilioCallsBaseURL = "http://dummy.invalid"
|
||||
c.TwilioAccount = "AC1234567890"
|
||||
c.TwilioAuthToken = "AAEAA1234567890"
|
||||
c.TwilioPhoneNumber = "+1234567890"
|
||||
s := newTestServer(t, c)
|
||||
|
||||
// Add tier and user
|
||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||
Code: "pro",
|
||||
MessageLimit: 10,
|
||||
CallLimit: 1,
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
|
||||
// Do the thing
|
||||
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
|
||||
"authorization": util.BasicAuth("phil", "phil"),
|
||||
"x-call": "+11122233344",
|
||||
})
|
||||
require.Equal(t, 40034, toHTTPError(t, response.Body.String()).Code)
|
||||
}
|
||||
|
||||
func TestServer_Twilio_Call_InvalidNumber(t *testing.T) {
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.TwilioCallsBaseURL = "https://127.0.0.1"
|
||||
c.TwilioAccount = "AC1234567890"
|
||||
c.TwilioAuthToken = "AAEAA1234567890"
|
||||
c.TwilioPhoneNumber = "+1234567890"
|
||||
s := newTestServer(t, c)
|
||||
|
||||
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
|
||||
"x-call": "+invalid",
|
||||
})
|
||||
require.Equal(t, 40033, toHTTPError(t, response.Body.String()).Code)
|
||||
}
|
||||
|
||||
func TestServer_Twilio_Call_Anonymous(t *testing.T) {
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.TwilioCallsBaseURL = "https://127.0.0.1"
|
||||
c.TwilioAccount = "AC1234567890"
|
||||
c.TwilioAuthToken = "AAEAA1234567890"
|
||||
c.TwilioPhoneNumber = "+1234567890"
|
||||
s := newTestServer(t, c)
|
||||
|
||||
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
|
||||
"x-call": "+123123",
|
||||
})
|
||||
require.Equal(t, 40035, toHTTPError(t, response.Body.String()).Code)
|
||||
}
|
||||
|
||||
func TestServer_Twilio_Call_Unconfigured(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfig(t))
|
||||
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
|
||||
"x-call": "+1234",
|
||||
})
|
||||
require.Equal(t, 40032, toHTTPError(t, response.Body.String()).Code)
|
||||
}
|
||||
171
server/server_webpush.go
Normal file
171
server/server_webpush.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/SherClockHolmes/webpush-go"
|
||||
"heckel.io/ntfy/log"
|
||||
"heckel.io/ntfy/user"
|
||||
)
|
||||
|
||||
const (
|
||||
webPushTopicSubscribeLimit = 50
|
||||
)
|
||||
|
||||
var (
|
||||
webPushAllowedEndpointsPatterns = []string{
|
||||
"https://*.google.com/",
|
||||
"https://*.googleapis.com/",
|
||||
"https://*.mozilla.com/",
|
||||
"https://*.mozaws.net/",
|
||||
"https://*.windows.com/",
|
||||
"https://*.microsoft.com/",
|
||||
"https://*.apple.com/",
|
||||
}
|
||||
webPushAllowedEndpointsRegex *regexp.Regexp
|
||||
)
|
||||
|
||||
func init() {
|
||||
for i, pattern := range webPushAllowedEndpointsPatterns {
|
||||
webPushAllowedEndpointsPatterns[i] = strings.ReplaceAll(strings.ReplaceAll(pattern, ".", "\\."), "*", ".+")
|
||||
}
|
||||
allPatterns := fmt.Sprintf("^(%s)", strings.Join(webPushAllowedEndpointsPatterns, "|"))
|
||||
webPushAllowedEndpointsRegex = regexp.MustCompile(allPatterns)
|
||||
}
|
||||
|
||||
func (s *Server) handleWebPushUpdate(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
req, err := readJSONWithLimit[apiWebPushUpdateSubscriptionRequest](r.Body, jsonBodyBytesLimit, false)
|
||||
if err != nil || req.Endpoint == "" || req.P256dh == "" || req.Auth == "" {
|
||||
return errHTTPBadRequestWebPushSubscriptionInvalid
|
||||
} else if !webPushAllowedEndpointsRegex.MatchString(req.Endpoint) {
|
||||
return errHTTPBadRequestWebPushEndpointUnknown
|
||||
} else if len(req.Topics) > webPushTopicSubscribeLimit {
|
||||
return errHTTPBadRequestWebPushTopicCountTooHigh
|
||||
}
|
||||
topics, err := s.topicsFromIDs(req.Topics...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if s.userManager != nil {
|
||||
u := v.User()
|
||||
for _, t := range topics {
|
||||
if err := s.userManager.Authorize(u, t.ID, user.PermissionRead); err != nil {
|
||||
logvr(v, r).With(t).Err(err).Debug("Access to topic %s not authorized", t.ID)
|
||||
return errHTTPForbidden.With(t)
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := s.webPush.UpsertSubscription(req.Endpoint, req.Auth, req.P256dh, v.MaybeUserID(), v.IP(), req.Topics); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.writeJSON(w, newSuccessResponse())
|
||||
}
|
||||
|
||||
func (s *Server) handleWebPushDelete(w http.ResponseWriter, r *http.Request, _ *visitor) error {
|
||||
req, err := readJSONWithLimit[apiWebPushUpdateSubscriptionRequest](r.Body, jsonBodyBytesLimit, false)
|
||||
if err != nil || req.Endpoint == "" {
|
||||
return errHTTPBadRequestWebPushSubscriptionInvalid
|
||||
}
|
||||
if err := s.webPush.RemoveSubscriptionsByEndpoint(req.Endpoint); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.writeJSON(w, newSuccessResponse())
|
||||
}
|
||||
|
||||
func (s *Server) publishToWebPushEndpoints(v *visitor, m *message) {
|
||||
subscriptions, err := s.webPush.SubscriptionsForTopic(m.Topic)
|
||||
if err != nil {
|
||||
logvm(v, m).Err(err).With(v, m).Warn("Unable to publish web push messages")
|
||||
return
|
||||
}
|
||||
log.Tag(tagWebPush).With(v, m).Debug("Publishing web push message to %d subscribers", len(subscriptions))
|
||||
payload, err := json.Marshal(newWebPushPayload(fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic), m))
|
||||
if err != nil {
|
||||
log.Tag(tagWebPush).Err(err).With(v, m).Warn("Unable to marshal expiring payload")
|
||||
return
|
||||
}
|
||||
for _, subscription := range subscriptions {
|
||||
if err := s.sendWebPushNotification(subscription, payload, v, m); err != nil {
|
||||
log.Tag(tagWebPush).Err(err).With(v, m, subscription).Warn("Unable to publish web push message")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) pruneAndNotifyWebPushSubscriptions() {
|
||||
if s.config.WebPushPublicKey == "" {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
if err := s.pruneAndNotifyWebPushSubscriptionsInternal(); err != nil {
|
||||
log.Tag(tagWebPush).Err(err).Warn("Unable to prune or notify web push subscriptions")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Server) pruneAndNotifyWebPushSubscriptionsInternal() error {
|
||||
// Expire old subscriptions
|
||||
if err := s.webPush.RemoveExpiredSubscriptions(s.config.WebPushExpiryDuration); err != nil {
|
||||
return err
|
||||
}
|
||||
// Notify subscriptions that will expire soon
|
||||
subscriptions, err := s.webPush.SubscriptionsExpiring(s.config.WebPushExpiryWarningDuration)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if len(subscriptions) == 0 {
|
||||
return nil
|
||||
}
|
||||
payload, err := json.Marshal(newWebPushSubscriptionExpiringPayload())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
warningSent := make([]*webPushSubscription, 0)
|
||||
for _, subscription := range subscriptions {
|
||||
if err := s.sendWebPushNotification(subscription, payload); err != nil {
|
||||
log.Tag(tagWebPush).Err(err).With(subscription).Warn("Unable to publish expiry imminent warning")
|
||||
continue
|
||||
}
|
||||
warningSent = append(warningSent, subscription)
|
||||
}
|
||||
if err := s.webPush.MarkExpiryWarningSent(warningSent); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Tag(tagWebPush).Debug("Expired old subscriptions and published %d expiry imminent warnings", len(subscriptions))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) sendWebPushNotification(sub *webPushSubscription, message []byte, contexters ...log.Contexter) error {
|
||||
log.Tag(tagWebPush).With(sub).With(contexters...).Debug("Sending web push message")
|
||||
payload := &webpush.Subscription{
|
||||
Endpoint: sub.Endpoint,
|
||||
Keys: webpush.Keys{
|
||||
Auth: sub.Auth,
|
||||
P256dh: sub.P256dh,
|
||||
},
|
||||
}
|
||||
resp, err := webpush.SendNotification(message, payload, &webpush.Options{
|
||||
Subscriber: s.config.WebPushEmailAddress,
|
||||
VAPIDPublicKey: s.config.WebPushPublicKey,
|
||||
VAPIDPrivateKey: s.config.WebPushPrivateKey,
|
||||
Urgency: webpush.UrgencyHigh, // iOS requires this to ensure delivery
|
||||
TTL: int(s.config.CacheDuration.Seconds()),
|
||||
})
|
||||
if err != nil {
|
||||
log.Tag(tagWebPush).With(sub).With(contexters...).Err(err).Debug("Unable to publish web push message, removing endpoint")
|
||||
if err := s.webPush.RemoveSubscriptionsByEndpoint(sub.Endpoint); err != nil {
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
if (resp.StatusCode < 200 || resp.StatusCode > 299) && resp.StatusCode != 429 {
|
||||
log.Tag(tagWebPush).With(sub).With(contexters...).Field("response_code", resp.StatusCode).Debug("Unable to publish web push message, unexpected response")
|
||||
if err := s.webPush.RemoveSubscriptionsByEndpoint(sub.Endpoint); err != nil {
|
||||
return err
|
||||
}
|
||||
return errHTTPInternalErrorWebPushUnableToPublish.With(sub).With(contexters...)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
256
server/server_webpush_test.go
Normal file
256
server/server_webpush_test.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/user"
|
||||
"heckel.io/ntfy/util"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
testWebPushEndpoint = "https://updates.push.services.mozilla.com/wpush/v1/AAABBCCCDDEEEFFF"
|
||||
)
|
||||
|
||||
func TestServer_WebPush_Disabled(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfig(t))
|
||||
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
|
||||
require.Equal(t, 404, response.Code)
|
||||
}
|
||||
|
||||
func TestServer_WebPush_TopicAdd(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
||||
|
||||
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Len(t, subs, 1)
|
||||
require.Equal(t, subs[0].Endpoint, testWebPushEndpoint)
|
||||
require.Equal(t, subs[0].P256dh, "p256dh-key")
|
||||
require.Equal(t, subs[0].Auth, "auth-key")
|
||||
require.Equal(t, subs[0].UserID, "")
|
||||
}
|
||||
|
||||
func TestServer_WebPush_TopicAdd_InvalidEndpoint(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, "https://ddos-target.example.com/webpush"), nil)
|
||||
require.Equal(t, 400, response.Code)
|
||||
require.Equal(t, `{"code":40039,"http":400,"error":"invalid request: web push endpoint unknown"}`+"\n", response.Body.String())
|
||||
}
|
||||
|
||||
func TestServer_WebPush_TopicAdd_TooManyTopics(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
|
||||
topicList := make([]string, 51)
|
||||
for i := range topicList {
|
||||
topicList[i] = util.RandomString(5)
|
||||
}
|
||||
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, topicList, testWebPushEndpoint), nil)
|
||||
require.Equal(t, 400, response.Code)
|
||||
require.Equal(t, `{"code":40040,"http":400,"error":"invalid request: too many web push topic subscriptions"}`+"\n", response.Body.String())
|
||||
}
|
||||
|
||||
func TestServer_WebPush_TopicUnsubscribe(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
|
||||
addSubscription(t, s, testWebPushEndpoint, "test-topic")
|
||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{}, testWebPushEndpoint), nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
||||
|
||||
requireSubscriptionCount(t, s, "test-topic", 0)
|
||||
}
|
||||
|
||||
func TestServer_WebPush_Delete(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
|
||||
addSubscription(t, s, testWebPushEndpoint, "test-topic")
|
||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||
|
||||
response := request(t, s, "DELETE", "/v1/webpush", fmt.Sprintf(`{"endpoint":"%s"}`, testWebPushEndpoint), nil)
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
||||
|
||||
requireSubscriptionCount(t, s, "test-topic", 0)
|
||||
}
|
||||
|
||||
func TestServer_WebPush_TopicSubscribeProtected_Allowed(t *testing.T) {
|
||||
config := configureAuth(t, newTestConfigWithWebPush(t))
|
||||
config.AuthDefault = user.PermissionDenyAll
|
||||
s := newTestServer(t, config)
|
||||
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
|
||||
require.Nil(t, s.userManager.AllowAccess("ben", "test-topic", user.PermissionReadWrite))
|
||||
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
||||
|
||||
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
require.True(t, strings.HasPrefix(subs[0].UserID, "u_"))
|
||||
}
|
||||
|
||||
func TestServer_WebPush_TopicSubscribeProtected_Denied(t *testing.T) {
|
||||
config := configureAuth(t, newTestConfigWithWebPush(t))
|
||||
config.AuthDefault = user.PermissionDenyAll
|
||||
s := newTestServer(t, config)
|
||||
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
|
||||
require.Equal(t, 403, response.Code)
|
||||
|
||||
requireSubscriptionCount(t, s, "test-topic", 0)
|
||||
}
|
||||
|
||||
func TestServer_WebPush_DeleteAccountUnsubscribe(t *testing.T) {
|
||||
config := configureAuth(t, newTestConfigWithWebPush(t))
|
||||
s := newTestServer(t, config)
|
||||
|
||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
|
||||
require.Nil(t, s.userManager.AllowAccess("ben", "test-topic", user.PermissionReadWrite))
|
||||
|
||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
|
||||
require.Equal(t, 200, response.Code)
|
||||
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
||||
|
||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||
|
||||
request(t, s, "DELETE", "/v1/account", `{"password":"ben"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("ben", "ben"),
|
||||
})
|
||||
// should've been deleted with the account
|
||||
requireSubscriptionCount(t, s, "test-topic", 0)
|
||||
}
|
||||
|
||||
func TestServer_WebPush_Publish(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
|
||||
var received atomic.Bool
|
||||
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := io.ReadAll(r.Body)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "/push-receive", r.URL.Path)
|
||||
require.Equal(t, "high", r.Header.Get("Urgency"))
|
||||
require.Equal(t, "", r.Header.Get("Topic"))
|
||||
received.Store(true)
|
||||
}))
|
||||
defer pushService.Close()
|
||||
|
||||
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic")
|
||||
request(t, s, "POST", "/test-topic", "web push test", nil)
|
||||
|
||||
waitFor(t, func() bool {
|
||||
return received.Load()
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_WebPush_Publish_RemoveOnError(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
|
||||
var received atomic.Bool
|
||||
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := io.ReadAll(r.Body)
|
||||
require.Nil(t, err)
|
||||
w.WriteHeader(http.StatusGone)
|
||||
received.Store(true)
|
||||
}))
|
||||
defer pushService.Close()
|
||||
|
||||
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic", "test-topic-abc")
|
||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||
requireSubscriptionCount(t, s, "test-topic-abc", 1)
|
||||
|
||||
request(t, s, "POST", "/test-topic", "web push test", nil)
|
||||
|
||||
waitFor(t, func() bool {
|
||||
return received.Load()
|
||||
})
|
||||
|
||||
// Receiving the 410 should've caused the publisher to expire all subscriptions on the endpoint
|
||||
|
||||
requireSubscriptionCount(t, s, "test-topic", 0)
|
||||
requireSubscriptionCount(t, s, "test-topic-abc", 0)
|
||||
}
|
||||
|
||||
func TestServer_WebPush_Expiry(t *testing.T) {
|
||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||
|
||||
var received atomic.Bool
|
||||
|
||||
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := io.ReadAll(r.Body)
|
||||
require.Nil(t, err)
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(``))
|
||||
received.Store(true)
|
||||
}))
|
||||
defer pushService.Close()
|
||||
|
||||
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic")
|
||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||
|
||||
_, err := s.webPush.db.Exec("UPDATE subscription SET updated_at = ?", time.Now().Add(-7*24*time.Hour).Unix())
|
||||
require.Nil(t, err)
|
||||
|
||||
s.pruneAndNotifyWebPushSubscriptions()
|
||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||
|
||||
waitFor(t, func() bool {
|
||||
return received.Load()
|
||||
})
|
||||
|
||||
_, err = s.webPush.db.Exec("UPDATE subscription SET updated_at = ?", time.Now().Add(-9*24*time.Hour).Unix())
|
||||
require.Nil(t, err)
|
||||
|
||||
s.pruneAndNotifyWebPushSubscriptions()
|
||||
waitFor(t, func() bool {
|
||||
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
|
||||
require.Nil(t, err)
|
||||
return len(subs) == 0
|
||||
})
|
||||
}
|
||||
|
||||
func payloadForTopics(t *testing.T, topics []string, endpoint string) string {
|
||||
topicsJSON, err := json.Marshal(topics)
|
||||
require.Nil(t, err)
|
||||
|
||||
return fmt.Sprintf(`{
|
||||
"topics": %s,
|
||||
"endpoint": "%s",
|
||||
"p256dh": "p256dh-key",
|
||||
"auth": "auth-key"
|
||||
}`, topicsJSON, endpoint)
|
||||
}
|
||||
|
||||
func addSubscription(t *testing.T, s *Server, endpoint string, topics ...string) {
|
||||
require.Nil(t, s.webPush.UpsertSubscription(endpoint, "kSC3T8aN1JCQxxPdrFLrZg", "BMKKbxdUU_xLS7G1Wh5AN8PvWOjCzkCuKZYb8apcqYrDxjOF_2piggBnoJLQYx9IeSD70fNuwawI3e9Y8m3S3PE", "u_123", netip.MustParseAddr("1.2.3.4"), topics)) // Test auth and p256dh
|
||||
}
|
||||
|
||||
func requireSubscriptionCount(t *testing.T, s *Server, topic string, expectedLength int) {
|
||||
subs, err := s.webPush.SubscriptionsForTopic(topic)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, expectedLength)
|
||||
}
|
||||
@@ -4,14 +4,15 @@ import (
|
||||
_ "embed" // required by go:embed
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"heckel.io/ntfy/log"
|
||||
"heckel.io/ntfy/util"
|
||||
"mime"
|
||||
"net"
|
||||
"net/smtp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"heckel.io/ntfy/log"
|
||||
"heckel.io/ntfy/util"
|
||||
)
|
||||
|
||||
type mailer interface {
|
||||
@@ -131,31 +132,23 @@ This message was sent by {ip} at {time} via {topicURL}`
|
||||
}
|
||||
|
||||
var (
|
||||
//go:embed "mailer_emoji.json"
|
||||
//go:embed "mailer_emoji_map.json"
|
||||
emojisJSON string
|
||||
)
|
||||
|
||||
type emoji struct {
|
||||
Emoji string `json:"emoji"`
|
||||
Aliases []string `json:"aliases"`
|
||||
}
|
||||
|
||||
func toEmojis(tags []string) (emojisOut []string, tagsOut []string, err error) {
|
||||
var emojis []emoji
|
||||
if err = json.Unmarshal([]byte(emojisJSON), &emojis); err != nil {
|
||||
var emojiMap map[string]string
|
||||
if err = json.Unmarshal([]byte(emojisJSON), &emojiMap); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
tagsOut = make([]string, 0)
|
||||
emojisOut = make([]string, 0)
|
||||
nextTag:
|
||||
for _, t := range tags { // TODO Super inefficient; we should just create a .json file with a map
|
||||
for _, e := range emojis {
|
||||
if util.Contains(e.Aliases, t) {
|
||||
emojisOut = append(emojisOut, e.Emoji)
|
||||
continue nextTag
|
||||
}
|
||||
for _, t := range tags {
|
||||
if emoji, ok := emojiMap[t]; ok {
|
||||
emojisOut = append(emojisOut, emoji)
|
||||
} else {
|
||||
tagsOut = append(tagsOut, t)
|
||||
}
|
||||
tagsOut = append(tagsOut, t)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"mime/quotedprintable"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -296,6 +297,8 @@ func readTextMailBody(reader io.Reader, contentType, transferEncoding string) (s
|
||||
func readPlainTextMailBody(reader io.Reader, transferEncoding string) (string, error) {
|
||||
if strings.ToLower(transferEncoding) == "base64" {
|
||||
reader = base64.NewDecoder(base64.StdEncoding, reader)
|
||||
} else if strings.ToLower(transferEncoding) == "quoted-printable" {
|
||||
reader = quotedprintable.NewReader(reader)
|
||||
}
|
||||
body, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
|
||||
@@ -303,6 +303,39 @@ BBBBBBBBBBBBBBBBBBBBBBBBB`
|
||||
writeAndReadUntilLine(t, email, c, scanner, "250 2.0.0 OK: queued")
|
||||
}
|
||||
|
||||
func TestSmtpBackend_Plaintext_QuotedPrintable(t *testing.T) {
|
||||
email := `EHLO example.com
|
||||
MAIL FROM: phil@example.com
|
||||
RCPT TO: mytopic@ntfy.sh
|
||||
DATA
|
||||
Date: Tue, 28 Dec 2021 00:30:10 +0100
|
||||
Message-ID: <CAAvm79YP0C=Rt1N=KWmSUBB87KK2rRChmdzKqF1vCwMEUiVzLQ@mail.gmail.com>
|
||||
Subject: and one more
|
||||
From: Phil <phil@example.com>
|
||||
To: mytopic@ntfy.sh
|
||||
Content-Type: text/plain; charset="UTF-8"
|
||||
Content-Transfer-Encoding: quoted-printable
|
||||
|
||||
what's
|
||||
=C3=A0&=C3=A9"'(-=C3=A8_=C3=A7=C3=A0)
|
||||
=3D=3D=3D=3D=3D
|
||||
up
|
||||
.
|
||||
`
|
||||
s, c, conf, scanner := newTestSMTPServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/mytopic", r.URL.Path)
|
||||
require.Equal(t, "and one more", r.Header.Get("Title"))
|
||||
require.Equal(t, `what's
|
||||
à&é"'(-è_çà)
|
||||
=====
|
||||
up`, readAll(t, r.Body))
|
||||
})
|
||||
conf.SMTPServerAddrPrefix = ""
|
||||
defer s.Close()
|
||||
defer c.Close()
|
||||
writeAndReadUntilLine(t, email, c, scanner, "250 2.0.0 OK: queued")
|
||||
}
|
||||
|
||||
func TestSmtpBackend_Unsupported(t *testing.T) {
|
||||
email := `EHLO example.com
|
||||
MAIL FROM: phil@example.com
|
||||
@@ -390,6 +423,49 @@ L0VOIj4KClRoaXMgaXMgYSB0ZXN0IG1lc3NhZ2UgZnJvbSBUcnVlTkFTIENPUkUuCg==
|
||||
writeAndReadUntilLine(t, email, c, scanner, "250 2.0.0 OK: queued")
|
||||
}
|
||||
|
||||
func TestSmtpBackend_MultipartQuotedPrintable(t *testing.T) {
|
||||
email := `EHLO example.com
|
||||
MAIL FROM: phil@example.com
|
||||
RCPT TO: ntfy-mytopic@ntfy.sh
|
||||
DATA
|
||||
MIME-Version: 1.0
|
||||
Date: Tue, 28 Dec 2021 00:30:10 +0100
|
||||
Message-ID: <CAAvm79YP0C=Rt1N=KWmSUBB87KK2rRChmdzKqF1vCwMEUiVzLQ@mail.gmail.com>
|
||||
Subject: and one more
|
||||
From: Phil <phil@example.com>
|
||||
To: ntfy-mytopic@ntfy.sh
|
||||
Content-Type: multipart/alternative; boundary="000000000000f3320b05d42915c9"
|
||||
|
||||
--000000000000f3320b05d42915c9
|
||||
Content-Type: text/html; charset="UTF-8"
|
||||
|
||||
html, ignore me
|
||||
|
||||
--000000000000f3320b05d42915c9
|
||||
Content-Type: text/plain; charset="UTF-8"
|
||||
Content-Transfer-Encoding: quoted-printable
|
||||
|
||||
what's
|
||||
=C3=A0&=C3=A9"'(-=C3=A8_=C3=A7=C3=A0)
|
||||
=3D=3D=3D=3D=3D
|
||||
up
|
||||
|
||||
--000000000000f3320b05d42915c9--
|
||||
.
|
||||
`
|
||||
s, c, _, scanner := newTestSMTPServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/mytopic", r.URL.Path)
|
||||
require.Equal(t, "and one more", r.Header.Get("Title"))
|
||||
require.Equal(t, `what's
|
||||
à&é"'(-è_çà)
|
||||
=====
|
||||
up`, readAll(t, r.Body))
|
||||
})
|
||||
defer s.Close()
|
||||
defer c.Close()
|
||||
writeAndReadUntilLine(t, email, c, scanner, "250 2.0.0 OK: queued")
|
||||
}
|
||||
|
||||
func TestSmtpBackend_NestedMultipartBase64(t *testing.T) {
|
||||
email := `EHLO example.com
|
||||
MAIL FROM: test@mydomain.me
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"heckel.io/ntfy/log"
|
||||
"heckel.io/ntfy/util"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"heckel.io/ntfy/log"
|
||||
"heckel.io/ntfy/util"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -44,10 +45,16 @@ func newTopic(id string) *topic {
|
||||
}
|
||||
|
||||
// Subscribe subscribes to this topic
|
||||
func (t *topic) Subscribe(s subscriber, userID string, cancel func()) int {
|
||||
func (t *topic) Subscribe(s subscriber, userID string, cancel func()) (subscriberID int) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
subscriberID := rand.Int()
|
||||
for i := 0; i < 5; i++ { // Best effort retry
|
||||
subscriberID = rand.Int()
|
||||
_, exists := t.subscribers[subscriberID]
|
||||
if !exists {
|
||||
break
|
||||
}
|
||||
}
|
||||
t.subscribers[subscriberID] = &topicSubscriber{
|
||||
userID: userID, // May be empty
|
||||
subscriber: s,
|
||||
@@ -134,24 +141,40 @@ func (t *topic) Keepalive() {
|
||||
t.lastAccess = time.Now()
|
||||
}
|
||||
|
||||
// CancelSubscribers calls the cancel function for all subscribers, forcing
|
||||
func (t *topic) CancelSubscribers(exceptUserID string) {
|
||||
// CancelSubscribersExceptUser calls the cancel function for all subscribers, forcing
|
||||
func (t *topic) CancelSubscribersExceptUser(exceptUserID string) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
for _, s := range t.subscribers {
|
||||
if s.userID != exceptUserID {
|
||||
log.
|
||||
Tag(tagSubscribe).
|
||||
With(t).
|
||||
Fields(log.Context{
|
||||
"user_id": s.userID,
|
||||
}).
|
||||
Debug("Canceling subscriber %s", s.userID)
|
||||
s.cancel()
|
||||
t.cancelUserSubscriber(s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CancelSubscriberUser kills the subscriber with the given user ID
|
||||
func (t *topic) CancelSubscriberUser(userID string) {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
for _, s := range t.subscribers {
|
||||
if s.userID == userID {
|
||||
t.cancelUserSubscriber(s)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *topic) cancelUserSubscriber(s *topicSubscriber) {
|
||||
log.
|
||||
Tag(tagSubscribe).
|
||||
With(t).
|
||||
Fields(log.Context{
|
||||
"user_id": s.userID,
|
||||
}).
|
||||
Debug("Canceling subscriber with user ID %s", s.userID)
|
||||
s.cancel()
|
||||
}
|
||||
|
||||
func (t *topic) Context() log.Context {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"math/rand"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTopic_CancelSubscribers(t *testing.T) {
|
||||
func TestTopic_CancelSubscribersExceptUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
subFn := func(v *visitor, msg *message) error {
|
||||
@@ -25,11 +27,34 @@ func TestTopic_CancelSubscribers(t *testing.T) {
|
||||
to.Subscribe(subFn, "", cancelFn1)
|
||||
to.Subscribe(subFn, "u_phil", cancelFn2)
|
||||
|
||||
to.CancelSubscribers("u_phil")
|
||||
to.CancelSubscribersExceptUser("u_phil")
|
||||
require.True(t, canceled1.Load())
|
||||
require.False(t, canceled2.Load())
|
||||
}
|
||||
|
||||
func TestTopic_CancelSubscribersUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
subFn := func(v *visitor, msg *message) error {
|
||||
return nil
|
||||
}
|
||||
canceled1 := atomic.Bool{}
|
||||
cancelFn1 := func() {
|
||||
canceled1.Store(true)
|
||||
}
|
||||
canceled2 := atomic.Bool{}
|
||||
cancelFn2 := func() {
|
||||
canceled2.Store(true)
|
||||
}
|
||||
to := newTopic("mytopic")
|
||||
to.Subscribe(subFn, "u_another", cancelFn1)
|
||||
to.Subscribe(subFn, "u_phil", cancelFn2)
|
||||
|
||||
to.CancelSubscriberUser("u_phil")
|
||||
require.False(t, canceled1.Load())
|
||||
require.True(t, canceled2.Load())
|
||||
}
|
||||
|
||||
func TestTopic_Keepalive(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -39,3 +64,29 @@ func TestTopic_Keepalive(t *testing.T) {
|
||||
require.True(t, to.LastAccess().Unix() >= time.Now().Unix()-2)
|
||||
require.True(t, to.LastAccess().Unix() <= time.Now().Unix()+2)
|
||||
}
|
||||
|
||||
func TestTopic_Subscribe_DuplicateID(t *testing.T) {
|
||||
t.Parallel()
|
||||
to := newTopic("mytopic")
|
||||
|
||||
// Fix random seed to force same number generation
|
||||
rand.Seed(1)
|
||||
a := rand.Int()
|
||||
to.subscribers[a] = &topicSubscriber{
|
||||
userID: "a",
|
||||
subscriber: nil,
|
||||
cancel: func() {},
|
||||
}
|
||||
|
||||
subFn := func(v *visitor, msg *message) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Force rand.Int to generate the same id once more
|
||||
rand.Seed(1)
|
||||
id := to.Subscribe(subFn, "b", func() {})
|
||||
res := to.subscribers[id]
|
||||
|
||||
require.NotEqual(t, id, a)
|
||||
require.Equal(t, "b", res.userID, "b")
|
||||
}
|
||||
|
||||
171
server/types.go
171
server/types.go
@@ -1,12 +1,13 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"heckel.io/ntfy/log"
|
||||
"heckel.io/ntfy/user"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"heckel.io/ntfy/log"
|
||||
"heckel.io/ntfy/user"
|
||||
|
||||
"heckel.io/ntfy/util"
|
||||
)
|
||||
|
||||
@@ -24,23 +25,24 @@ const (
|
||||
|
||||
// message represents a message published to a topic
|
||||
type message struct {
|
||||
ID string `json:"id"` // Random message ID
|
||||
Time int64 `json:"time"` // Unix time in seconds
|
||||
Expires int64 `json:"expires,omitempty"` // Unix time in seconds (not required for open/keepalive)
|
||||
Event string `json:"event"` // One of the above
|
||||
Topic string `json:"topic"`
|
||||
Title string `json:"title,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Priority int `json:"priority,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
Click string `json:"click,omitempty"`
|
||||
Icon string `json:"icon,omitempty"`
|
||||
Actions []*action `json:"actions,omitempty"`
|
||||
Attachment *attachment `json:"attachment,omitempty"`
|
||||
PollID string `json:"poll_id,omitempty"`
|
||||
Encoding string `json:"encoding,omitempty"` // empty for raw UTF-8, or "base64" for encoded bytes
|
||||
Sender netip.Addr `json:"-"` // IP address of uploader, used for rate limiting
|
||||
User string `json:"-"` // Username of the uploader, used to associated attachments
|
||||
ID string `json:"id"` // Random message ID
|
||||
Time int64 `json:"time"` // Unix time in seconds
|
||||
Expires int64 `json:"expires,omitempty"` // Unix time in seconds (not required for open/keepalive)
|
||||
Event string `json:"event"` // One of the above
|
||||
Topic string `json:"topic"`
|
||||
Title string `json:"title,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Priority int `json:"priority,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
Click string `json:"click,omitempty"`
|
||||
Icon string `json:"icon,omitempty"`
|
||||
Actions []*action `json:"actions,omitempty"`
|
||||
Attachment *attachment `json:"attachment,omitempty"`
|
||||
PollID string `json:"poll_id,omitempty"`
|
||||
ContentType string `json:"content_type,omitempty"` // text/plain by default (if empty), or text/markdown
|
||||
Encoding string `json:"encoding,omitempty"` // empty for raw UTF-8, or "base64" for encoded bytes
|
||||
Sender netip.Addr `json:"-"` // IP address of uploader, used for rate limiting
|
||||
User string `json:"-"` // UserID of the uploader, used to associated attachments
|
||||
}
|
||||
|
||||
func (m *message) Context() log.Context {
|
||||
@@ -99,8 +101,10 @@ type publishMessage struct {
|
||||
Icon string `json:"icon"`
|
||||
Actions []action `json:"actions"`
|
||||
Attach string `json:"attach"`
|
||||
Markdown bool `json:"markdown"`
|
||||
Filename string `json:"filename"`
|
||||
Email string `json:"email"`
|
||||
Call string `json:"call"`
|
||||
Delay string `json:"delay"`
|
||||
}
|
||||
|
||||
@@ -239,6 +243,45 @@ type apiHealthResponse struct {
|
||||
Healthy bool `json:"healthy"`
|
||||
}
|
||||
|
||||
type apiStatsResponse struct {
|
||||
Messages int64 `json:"messages"`
|
||||
MessagesRate float64 `json:"messages_rate"` // Average number of messages per second
|
||||
}
|
||||
|
||||
type apiUserAddRequest struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Tier string `json:"tier"`
|
||||
// Do not add 'role' here. We don't want to add admins via the API.
|
||||
}
|
||||
|
||||
type apiUserResponse struct {
|
||||
Username string `json:"username"`
|
||||
Role string `json:"role"`
|
||||
Tier string `json:"tier,omitempty"`
|
||||
Grants []*apiUserGrantResponse `json:"grants,omitempty"`
|
||||
}
|
||||
|
||||
type apiUserGrantResponse struct {
|
||||
Topic string `json:"topic"` // This may be a pattern
|
||||
Permission string `json:"permission"`
|
||||
}
|
||||
|
||||
type apiUserDeleteRequest struct {
|
||||
Username string `json:"username"`
|
||||
}
|
||||
|
||||
type apiAccessAllowRequest struct {
|
||||
Username string `json:"username"`
|
||||
Topic string `json:"topic"` // This may be a pattern
|
||||
Permission string `json:"permission"`
|
||||
}
|
||||
|
||||
type apiAccessResetRequest struct {
|
||||
Username string `json:"username"`
|
||||
Topic string `json:"topic"`
|
||||
}
|
||||
|
||||
type apiAccountCreateRequest struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
@@ -272,6 +315,16 @@ type apiAccountTokenResponse struct {
|
||||
Expires int64 `json:"expires,omitempty"` // Unix timestamp
|
||||
}
|
||||
|
||||
type apiAccountPhoneNumberVerifyRequest struct {
|
||||
Number string `json:"number"`
|
||||
Channel string `json:"channel"`
|
||||
}
|
||||
|
||||
type apiAccountPhoneNumberAddRequest struct {
|
||||
Number string `json:"number"`
|
||||
Code string `json:"code"` // Only set when adding a phone number
|
||||
}
|
||||
|
||||
type apiAccountTier struct {
|
||||
Code string `json:"code"`
|
||||
Name string `json:"name"`
|
||||
@@ -282,6 +335,7 @@ type apiAccountLimits struct {
|
||||
Messages int64 `json:"messages"`
|
||||
MessagesExpiryDuration int64 `json:"messages_expiry_duration"`
|
||||
Emails int64 `json:"emails"`
|
||||
Calls int64 `json:"calls"`
|
||||
Reservations int64 `json:"reservations"`
|
||||
AttachmentTotalSize int64 `json:"attachment_total_size"`
|
||||
AttachmentFileSize int64 `json:"attachment_file_size"`
|
||||
@@ -294,6 +348,8 @@ type apiAccountStats struct {
|
||||
MessagesRemaining int64 `json:"messages_remaining"`
|
||||
Emails int64 `json:"emails"`
|
||||
EmailsRemaining int64 `json:"emails_remaining"`
|
||||
Calls int64 `json:"calls"`
|
||||
CallsRemaining int64 `json:"calls_remaining"`
|
||||
Reservations int64 `json:"reservations"`
|
||||
ReservationsRemaining int64 `json:"reservations_remaining"`
|
||||
AttachmentTotalSize int64 `json:"attachment_total_size"`
|
||||
@@ -323,6 +379,7 @@ type apiAccountResponse struct {
|
||||
Subscriptions []*user.Subscription `json:"subscriptions,omitempty"`
|
||||
Reservations []*apiAccountReservation `json:"reservations,omitempty"`
|
||||
Tokens []*apiAccountTokenResponse `json:"tokens,omitempty"`
|
||||
PhoneNumbers []string `json:"phone_numbers,omitempty"`
|
||||
Tier *apiAccountTier `json:"tier,omitempty"`
|
||||
Limits *apiAccountLimits `json:"limits,omitempty"`
|
||||
Stats *apiAccountStats `json:"stats,omitempty"`
|
||||
@@ -340,8 +397,12 @@ type apiConfigResponse struct {
|
||||
EnableLogin bool `json:"enable_login"`
|
||||
EnableSignup bool `json:"enable_signup"`
|
||||
EnablePayments bool `json:"enable_payments"`
|
||||
EnableCalls bool `json:"enable_calls"`
|
||||
EnableEmails bool `json:"enable_emails"`
|
||||
EnableReservations bool `json:"enable_reservations"`
|
||||
EnableWebPush bool `json:"enable_web_push"`
|
||||
BillingContact string `json:"billing_contact"`
|
||||
WebPushPublicKey string `json:"web_push_public_key"`
|
||||
DisallowedTopics []string `json:"disallowed_topics"`
|
||||
}
|
||||
|
||||
@@ -406,3 +467,75 @@ type apiStripeSubscriptionDeletedEvent struct {
|
||||
ID string `json:"id"`
|
||||
Customer string `json:"customer"`
|
||||
}
|
||||
|
||||
type apiWebPushUpdateSubscriptionRequest struct {
|
||||
Endpoint string `json:"endpoint"`
|
||||
Auth string `json:"auth"`
|
||||
P256dh string `json:"p256dh"`
|
||||
Topics []string `json:"topics"`
|
||||
}
|
||||
|
||||
// List of possible Web Push events (see sw.js)
|
||||
const (
|
||||
webPushMessageEvent = "message"
|
||||
webPushExpiringEvent = "subscription_expiring"
|
||||
)
|
||||
|
||||
type webPushPayload struct {
|
||||
Event string `json:"event"`
|
||||
SubscriptionID string `json:"subscription_id"`
|
||||
Message *message `json:"message"`
|
||||
}
|
||||
|
||||
func newWebPushPayload(subscriptionID string, message *message) *webPushPayload {
|
||||
return &webPushPayload{
|
||||
Event: webPushMessageEvent,
|
||||
SubscriptionID: subscriptionID,
|
||||
Message: message,
|
||||
}
|
||||
}
|
||||
|
||||
type webPushControlMessagePayload struct {
|
||||
Event string `json:"event"`
|
||||
}
|
||||
|
||||
func newWebPushSubscriptionExpiringPayload() *webPushControlMessagePayload {
|
||||
return &webPushControlMessagePayload{
|
||||
Event: webPushExpiringEvent,
|
||||
}
|
||||
}
|
||||
|
||||
type webPushSubscription struct {
|
||||
ID string
|
||||
Endpoint string
|
||||
Auth string
|
||||
P256dh string
|
||||
UserID string
|
||||
}
|
||||
|
||||
func (w *webPushSubscription) Context() log.Context {
|
||||
return map[string]any{
|
||||
"web_push_subscription_id": w.ID,
|
||||
"web_push_subscription_user_id": w.UserID,
|
||||
"web_push_subscription_endpoint": w.Endpoint,
|
||||
}
|
||||
}
|
||||
|
||||
// https://developer.mozilla.org/en-US/docs/Web/Manifest
|
||||
type webManifestResponse struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
ShortName string `json:"short_name"`
|
||||
Scope string `json:"scope"`
|
||||
StartURL string `json:"start_url"`
|
||||
Display string `json:"display"`
|
||||
BackgroundColor string `json:"background_color"`
|
||||
ThemeColor string `json:"theme_color"`
|
||||
Icons []*webManifestIcon `json:"icons"`
|
||||
}
|
||||
|
||||
type webManifestIcon struct {
|
||||
SRC string `json:"src"`
|
||||
Sizes string `json:"sizes"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
@@ -5,16 +5,31 @@ import (
|
||||
"fmt"
|
||||
"heckel.io/ntfy/util"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
mimeDecoder mime.WordDecoder
|
||||
priorityHeaderIgnoreRegex = regexp.MustCompile(`^u=\d,\s*(i|\d)$|^u=\d$`)
|
||||
)
|
||||
|
||||
func readBoolParam(r *http.Request, defaultValue bool, names ...string) bool {
|
||||
value := strings.ToLower(readParam(r, names...))
|
||||
if value == "" {
|
||||
return defaultValue
|
||||
}
|
||||
return toBool(value)
|
||||
}
|
||||
|
||||
func isBoolValue(value string) bool {
|
||||
return value == "1" || value == "yes" || value == "true" || value == "0" || value == "no" || value == "false"
|
||||
}
|
||||
|
||||
func toBool(value string) bool {
|
||||
return value == "1" || value == "yes" || value == "true"
|
||||
}
|
||||
|
||||
@@ -39,9 +54,9 @@ func readParam(r *http.Request, names ...string) string {
|
||||
|
||||
func readHeaderParam(r *http.Request, names ...string) string {
|
||||
for _, name := range names {
|
||||
value := r.Header.Get(name)
|
||||
value := strings.TrimSpace(maybeDecodeHeader(name, r.Header.Get(name)))
|
||||
if value != "" {
|
||||
return strings.TrimSpace(value)
|
||||
return value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
@@ -114,3 +129,27 @@ func fromContext[T any](r *http.Request, key contextKey) (T, error) {
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// maybeDecodeHeader decodes the given header value if it is MIME encoded, e.g. "=?utf-8?q?Hello_World?=",
|
||||
// or returns the original header value if it is not MIME encoded. It also calls maybeIgnoreSpecialHeader
|
||||
// to ignore new HTTP "Priority" header.
|
||||
func maybeDecodeHeader(name, value string) string {
|
||||
decoded, err := mimeDecoder.DecodeHeader(value)
|
||||
if err != nil {
|
||||
return maybeIgnoreSpecialHeader(name, value)
|
||||
}
|
||||
return maybeIgnoreSpecialHeader(name, decoded)
|
||||
}
|
||||
|
||||
// maybeIgnoreSpecialHeader ignores new HTTP "Priority" header (see https://datatracker.ietf.org/doc/html/draft-ietf-httpbis-priority)
|
||||
//
|
||||
// Cloudflare (and potentially other providers) add this to requests when forwarding to the backend (ntfy),
|
||||
// so we just ignore it. If the "Priority" header is set to "u=*, i" or "u=*" (by Cloudflare), the header will be ignored.
|
||||
// Returning an empty string will allow the rest of the logic to continue searching for another header (x-priority, prio, p),
|
||||
// or in the Query parameters.
|
||||
func maybeIgnoreSpecialHeader(name, value string) string {
|
||||
if strings.ToLower(name) == "priority" && priorityHeaderIgnoreRegex.MatchString(strings.TrimSpace(value)) {
|
||||
return ""
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
@@ -2,9 +2,9 @@ package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/require"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -75,3 +75,16 @@ Accept: */*
|
||||
(peeked bytes not UTF-8, peek limit of 4096 bytes reached, hex: ` + fmt.Sprintf("%x", body[:4096]) + ` ...)`
|
||||
require.Equal(t, expected, renderHTTPRequest(r))
|
||||
}
|
||||
|
||||
func TestMaybeIgnoreSpecialHeader(t *testing.T) {
|
||||
require.Empty(t, maybeIgnoreSpecialHeader("priority", "u=1"))
|
||||
require.Empty(t, maybeIgnoreSpecialHeader("Priority", "u=1"))
|
||||
require.Empty(t, maybeIgnoreSpecialHeader("Priority", "u=1, i"))
|
||||
}
|
||||
|
||||
func TestMaybeDecodeHeaders(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "http://ntfy.sh/mytopic/json?since=all", nil)
|
||||
r.Header.Set("Priority", "u=1") // Cloudflare priority header
|
||||
r.Header.Set("X-Priority", "5") // ntfy priority header
|
||||
require.Equal(t, "5", readHeaderParam(r, "x-priority", "priority", "p"))
|
||||
}
|
||||
|
||||
@@ -24,6 +24,10 @@ const (
|
||||
// visitorDefaultReservationsLimit is the amount of topic names a user without a tier is allowed to reserve.
|
||||
// This number is zero, and changing it may have unintended consequences in the web app, or otherwise
|
||||
visitorDefaultReservationsLimit = int64(0)
|
||||
|
||||
// visitorDefaultCallsLimit is the amount of calls a user without a tier is allowed to make.
|
||||
// This number is zero, because phone numbers have to be verified first.
|
||||
visitorDefaultCallsLimit = int64(0)
|
||||
)
|
||||
|
||||
// Constants used to convert a tier-user's MessageLimit (see user.Tier) into adequate request limiter
|
||||
@@ -56,6 +60,7 @@ type visitor struct {
|
||||
requestLimiter *rate.Limiter // Rate limiter for (almost) all requests (including messages)
|
||||
messagesLimiter *util.FixedLimiter // Rate limiter for messages
|
||||
emailsLimiter *util.RateLimiter // Rate limiter for emails
|
||||
callsLimiter *util.FixedLimiter // Rate limiter for calls
|
||||
subscriptionLimiter *util.FixedLimiter // Fixed limiter for active subscriptions (ongoing connections)
|
||||
bandwidthLimiter *util.RateLimiter // Limiter for attachment bandwidth downloads
|
||||
accountLimiter *rate.Limiter // Rate limiter for account creation, may be nil
|
||||
@@ -79,6 +84,7 @@ type visitorLimits struct {
|
||||
EmailLimit int64
|
||||
EmailLimitBurst int
|
||||
EmailLimitReplenish rate.Limit
|
||||
CallLimit int64
|
||||
ReservationsLimit int64
|
||||
AttachmentTotalSizeLimit int64
|
||||
AttachmentFileSizeLimit int64
|
||||
@@ -91,6 +97,8 @@ type visitorStats struct {
|
||||
MessagesRemaining int64
|
||||
Emails int64
|
||||
EmailsRemaining int64
|
||||
Calls int64
|
||||
CallsRemaining int64
|
||||
Reservations int64
|
||||
ReservationsRemaining int64
|
||||
AttachmentTotalSize int64
|
||||
@@ -107,10 +115,11 @@ const (
|
||||
)
|
||||
|
||||
func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor {
|
||||
var messages, emails int64
|
||||
var messages, emails, calls int64
|
||||
if user != nil {
|
||||
messages = user.Stats.Messages
|
||||
emails = user.Stats.Emails
|
||||
calls = user.Stats.Calls
|
||||
}
|
||||
v := &visitor{
|
||||
config: conf,
|
||||
@@ -124,11 +133,12 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana
|
||||
requestLimiter: nil, // Set in resetLimiters
|
||||
messagesLimiter: nil, // Set in resetLimiters, may be nil
|
||||
emailsLimiter: nil, // Set in resetLimiters
|
||||
callsLimiter: nil, // Set in resetLimiters, may be nil
|
||||
bandwidthLimiter: nil, // Set in resetLimiters
|
||||
accountLimiter: nil, // Set in resetLimiters, may be nil
|
||||
authLimiter: nil, // Set in resetLimiters, may be nil
|
||||
}
|
||||
v.resetLimitersNoLock(messages, emails, false)
|
||||
v.resetLimitersNoLock(messages, emails, calls, false)
|
||||
return v
|
||||
}
|
||||
|
||||
@@ -147,12 +157,19 @@ func (v *visitor) contextNoLock() log.Context {
|
||||
"visitor_messages": info.Stats.Messages,
|
||||
"visitor_messages_limit": info.Limits.MessageLimit,
|
||||
"visitor_messages_remaining": info.Stats.MessagesRemaining,
|
||||
"visitor_emails": info.Stats.Emails,
|
||||
"visitor_emails_limit": info.Limits.EmailLimit,
|
||||
"visitor_emails_remaining": info.Stats.EmailsRemaining,
|
||||
"visitor_request_limiter_limit": v.requestLimiter.Limit(),
|
||||
"visitor_request_limiter_tokens": v.requestLimiter.Tokens(),
|
||||
}
|
||||
if v.config.SMTPSenderFrom != "" {
|
||||
fields["visitor_emails"] = info.Stats.Emails
|
||||
fields["visitor_emails_limit"] = info.Limits.EmailLimit
|
||||
fields["visitor_emails_remaining"] = info.Stats.EmailsRemaining
|
||||
}
|
||||
if v.config.TwilioAccount != "" {
|
||||
fields["visitor_calls"] = info.Stats.Calls
|
||||
fields["visitor_calls_limit"] = info.Limits.CallLimit
|
||||
fields["visitor_calls_remaining"] = info.Stats.CallsRemaining
|
||||
}
|
||||
if v.authLimiter != nil {
|
||||
fields["visitor_auth_limiter_limit"] = v.authLimiter.Limit()
|
||||
fields["visitor_auth_limiter_tokens"] = v.authLimiter.Tokens()
|
||||
@@ -216,6 +233,12 @@ func (v *visitor) EmailAllowed() bool {
|
||||
return v.emailsLimiter.Allow()
|
||||
}
|
||||
|
||||
func (v *visitor) CallAllowed() bool {
|
||||
v.mu.RLock() // limiters could be replaced!
|
||||
defer v.mu.RUnlock()
|
||||
return v.callsLimiter.Allow()
|
||||
}
|
||||
|
||||
func (v *visitor) SubscriptionAllowed() bool {
|
||||
v.mu.RLock() // limiters could be replaced!
|
||||
defer v.mu.RUnlock()
|
||||
@@ -296,6 +319,7 @@ func (v *visitor) Stats() *user.Stats {
|
||||
return &user.Stats{
|
||||
Messages: v.messagesLimiter.Value(),
|
||||
Emails: v.emailsLimiter.Value(),
|
||||
Calls: v.callsLimiter.Value(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -304,6 +328,7 @@ func (v *visitor) ResetStats() {
|
||||
defer v.mu.RUnlock()
|
||||
v.emailsLimiter.Reset()
|
||||
v.messagesLimiter.Reset()
|
||||
v.callsLimiter.Reset()
|
||||
}
|
||||
|
||||
// User returns the visitor user, or nil if there is none
|
||||
@@ -334,11 +359,11 @@ func (v *visitor) SetUser(u *user.User) {
|
||||
shouldResetLimiters := v.user.TierID() != u.TierID() // TierID works with nil receiver
|
||||
v.user = u // u may be nil!
|
||||
if shouldResetLimiters {
|
||||
var messages, emails int64
|
||||
var messages, emails, calls int64
|
||||
if u != nil {
|
||||
messages, emails = u.Stats.Messages, u.Stats.Emails
|
||||
messages, emails, calls = u.Stats.Messages, u.Stats.Emails, u.Stats.Calls
|
||||
}
|
||||
v.resetLimitersNoLock(messages, emails, true)
|
||||
v.resetLimitersNoLock(messages, emails, calls, true)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -353,11 +378,12 @@ func (v *visitor) MaybeUserID() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (v *visitor) resetLimitersNoLock(messages, emails int64, enqueueUpdate bool) {
|
||||
func (v *visitor) resetLimitersNoLock(messages, emails, calls int64, enqueueUpdate bool) {
|
||||
limits := v.limitsNoLock()
|
||||
v.requestLimiter = rate.NewLimiter(limits.RequestLimitReplenish, limits.RequestLimitBurst)
|
||||
v.messagesLimiter = util.NewFixedLimiterWithValue(limits.MessageLimit, messages)
|
||||
v.emailsLimiter = util.NewRateLimiterWithValue(limits.EmailLimitReplenish, limits.EmailLimitBurst, emails)
|
||||
v.callsLimiter = util.NewFixedLimiterWithValue(limits.CallLimit, calls)
|
||||
v.bandwidthLimiter = util.NewBytesLimiter(int(limits.AttachmentBandwidthLimit), oneDay)
|
||||
if v.user == nil {
|
||||
v.accountLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAccountCreationLimitReplenish), v.config.VisitorAccountCreationLimitBurst)
|
||||
@@ -370,6 +396,7 @@ func (v *visitor) resetLimitersNoLock(messages, emails int64, enqueueUpdate bool
|
||||
go v.userManager.EnqueueUserStats(v.user.ID, &user.Stats{
|
||||
Messages: messages,
|
||||
Emails: emails,
|
||||
Calls: calls,
|
||||
})
|
||||
}
|
||||
log.Fields(v.contextNoLock()).Debug("Rate limiters reset for visitor") // Must be after function, because contextNoLock() describes rate limiters
|
||||
@@ -398,6 +425,7 @@ func tierBasedVisitorLimits(conf *Config, tier *user.Tier) *visitorLimits {
|
||||
EmailLimit: tier.EmailLimit,
|
||||
EmailLimitBurst: util.MinMax(int(float64(tier.EmailLimit)*visitorEmailLimitBurstRate), conf.VisitorEmailLimitBurst, visitorEmailLimitBurstMax),
|
||||
EmailLimitReplenish: dailyLimitToRate(tier.EmailLimit),
|
||||
CallLimit: tier.CallLimit,
|
||||
ReservationsLimit: tier.ReservationLimit,
|
||||
AttachmentTotalSizeLimit: tier.AttachmentTotalSizeLimit,
|
||||
AttachmentFileSizeLimit: tier.AttachmentFileSizeLimit,
|
||||
@@ -420,6 +448,7 @@ func configBasedVisitorLimits(conf *Config) *visitorLimits {
|
||||
EmailLimit: replenishDurationToDailyLimit(conf.VisitorEmailLimitReplenish), // Approximation!
|
||||
EmailLimitBurst: conf.VisitorEmailLimitBurst,
|
||||
EmailLimitReplenish: rate.Every(conf.VisitorEmailLimitReplenish),
|
||||
CallLimit: visitorDefaultCallsLimit,
|
||||
ReservationsLimit: visitorDefaultReservationsLimit,
|
||||
AttachmentTotalSizeLimit: conf.VisitorAttachmentTotalSizeLimit,
|
||||
AttachmentFileSizeLimit: conf.AttachmentFileSizeLimit,
|
||||
@@ -465,12 +494,15 @@ func (v *visitor) Info() (*visitorInfo, error) {
|
||||
func (v *visitor) infoLightNoLock() *visitorInfo {
|
||||
messages := v.messagesLimiter.Value()
|
||||
emails := v.emailsLimiter.Value()
|
||||
calls := v.callsLimiter.Value()
|
||||
limits := v.limitsNoLock()
|
||||
stats := &visitorStats{
|
||||
Messages: messages,
|
||||
MessagesRemaining: zeroIfNegative(limits.MessageLimit - messages),
|
||||
Emails: emails,
|
||||
EmailsRemaining: zeroIfNegative(limits.EmailLimit - emails),
|
||||
Calls: calls,
|
||||
CallsRemaining: zeroIfNegative(limits.CallLimit - calls),
|
||||
}
|
||||
return &visitorInfo{
|
||||
Limits: limits,
|
||||
|
||||
280
server/webpush_store.go
Normal file
280
server/webpush_store.go
Normal file
@@ -0,0 +1,280 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"heckel.io/ntfy/util"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
)
|
||||
|
||||
const (
|
||||
subscriptionIDPrefix = "wps_"
|
||||
subscriptionIDLength = 10
|
||||
subscriptionEndpointLimitPerSubscriberIP = 10
|
||||
)
|
||||
|
||||
var (
|
||||
errWebPushNoRows = errors.New("no rows found")
|
||||
errWebPushTooManySubscriptions = errors.New("too many subscriptions")
|
||||
errWebPushUserIDCannotBeEmpty = errors.New("user ID cannot be empty")
|
||||
)
|
||||
|
||||
const (
|
||||
createWebPushSubscriptionsTableQuery = `
|
||||
BEGIN;
|
||||
CREATE TABLE IF NOT EXISTS subscription (
|
||||
id TEXT PRIMARY KEY,
|
||||
endpoint TEXT NOT NULL,
|
||||
key_auth TEXT NOT NULL,
|
||||
key_p256dh TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
subscriber_ip TEXT NOT NULL,
|
||||
updated_at INT NOT NULL,
|
||||
warned_at INT NOT NULL DEFAULT 0
|
||||
);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_endpoint ON subscription (endpoint);
|
||||
CREATE INDEX IF NOT EXISTS idx_subscriber_ip ON subscription (subscriber_ip);
|
||||
CREATE TABLE IF NOT EXISTS subscription_topic (
|
||||
subscription_id TEXT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
PRIMARY KEY (subscription_id, topic),
|
||||
FOREIGN KEY (subscription_id) REFERENCES subscription (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON subscription_topic (topic);
|
||||
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||
id INT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
COMMIT;
|
||||
`
|
||||
builtinStartupQueries = `
|
||||
PRAGMA foreign_keys = ON;
|
||||
`
|
||||
|
||||
selectWebPushSubscriptionIDByEndpoint = `SELECT id FROM subscription WHERE endpoint = ?`
|
||||
selectWebPushSubscriptionCountBySubscriberIP = `SELECT COUNT(*) FROM subscription WHERE subscriber_ip = ?`
|
||||
selectWebPushSubscriptionsForTopicQuery = `
|
||||
SELECT id, endpoint, key_auth, key_p256dh, user_id
|
||||
FROM subscription_topic st
|
||||
JOIN subscription s ON s.id = st.subscription_id
|
||||
WHERE st.topic = ?
|
||||
ORDER BY endpoint
|
||||
`
|
||||
selectWebPushSubscriptionsExpiringSoonQuery = `
|
||||
SELECT id, endpoint, key_auth, key_p256dh, user_id
|
||||
FROM subscription
|
||||
WHERE warned_at = 0 AND updated_at <= ?
|
||||
`
|
||||
insertWebPushSubscriptionQuery = `
|
||||
INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT (endpoint)
|
||||
DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at
|
||||
`
|
||||
updateWebPushSubscriptionWarningSentQuery = `UPDATE subscription SET warned_at = ? WHERE id = ?`
|
||||
deleteWebPushSubscriptionByEndpointQuery = `DELETE FROM subscription WHERE endpoint = ?`
|
||||
deleteWebPushSubscriptionByUserIDQuery = `DELETE FROM subscription WHERE user_id = ?`
|
||||
deleteWebPushSubscriptionByAgeQuery = `DELETE FROM subscription WHERE updated_at <= ?` // Full table scan!
|
||||
|
||||
insertWebPushSubscriptionTopicQuery = `INSERT INTO subscription_topic (subscription_id, topic) VALUES (?, ?)`
|
||||
deleteWebPushSubscriptionTopicAllQuery = `DELETE FROM subscription_topic WHERE subscription_id = ?`
|
||||
)
|
||||
|
||||
// Schema management queries
|
||||
const (
|
||||
currentWebPushSchemaVersion = 1
|
||||
insertWebPushSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
|
||||
selectWebPushSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
|
||||
)
|
||||
|
||||
type webPushStore struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func newWebPushStore(filename, startupQueries string) (*webPushStore, error) {
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := setupWebPushDB(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := runWebPushStartupQueries(db, startupQueries); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &webPushStore{
|
||||
db: db,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func setupWebPushDB(db *sql.DB) error {
|
||||
// If 'schemaVersion' table does not exist, this must be a new database
|
||||
rows, err := db.Query(selectWebPushSchemaVersionQuery)
|
||||
if err != nil {
|
||||
return setupNewWebPushDB(db)
|
||||
}
|
||||
return rows.Close()
|
||||
}
|
||||
|
||||
func setupNewWebPushDB(db *sql.DB) error {
|
||||
if _, err := db.Exec(createWebPushSubscriptionsTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(insertWebPushSchemaVersion, currentWebPushSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func runWebPushStartupQueries(db *sql.DB, startupQueries string) error {
|
||||
if _, err := db.Exec(startupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(builtinStartupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID. It always first deletes all
|
||||
// existing entries for a given endpoint.
|
||||
func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error {
|
||||
tx, err := c.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
// Read number of subscriptions for subscriber IP address
|
||||
rowsCount, err := tx.Query(selectWebPushSubscriptionCountBySubscriberIP, subscriberIP.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rowsCount.Close()
|
||||
var subscriptionCount int
|
||||
if !rowsCount.Next() {
|
||||
return errWebPushNoRows
|
||||
}
|
||||
if err := rowsCount.Scan(&subscriptionCount); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := rowsCount.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
// Read existing subscription ID for endpoint (or create new ID)
|
||||
rows, err := tx.Query(selectWebPushSubscriptionIDByEndpoint, endpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
var subscriptionID string
|
||||
if rows.Next() {
|
||||
if err := rows.Scan(&subscriptionID); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if subscriptionCount >= subscriptionEndpointLimitPerSubscriberIP {
|
||||
return errWebPushTooManySubscriptions
|
||||
}
|
||||
subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
// Insert or update subscription
|
||||
updatedAt, warnedAt := time.Now().Unix(), 0
|
||||
if _, err = tx.Exec(insertWebPushSubscriptionQuery, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil {
|
||||
return err
|
||||
}
|
||||
// Replace all subscription topics
|
||||
if _, err := tx.Exec(deleteWebPushSubscriptionTopicAllQuery, subscriptionID); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, topic := range topics {
|
||||
if _, err = tx.Exec(insertWebPushSubscriptionTopicQuery, subscriptionID, topic); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// SubscriptionsForTopic returns all subscriptions for the given topic
|
||||
func (c *webPushStore) SubscriptionsForTopic(topic string) ([]*webPushSubscription, error) {
|
||||
rows, err := c.db.Query(selectWebPushSubscriptionsForTopicQuery, topic)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return c.subscriptionsFromRows(rows)
|
||||
}
|
||||
|
||||
// SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period
|
||||
func (c *webPushStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*webPushSubscription, error) {
|
||||
rows, err := c.db.Query(selectWebPushSubscriptionsExpiringSoonQuery, time.Now().Add(-warnAfter).Unix())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return c.subscriptionsFromRows(rows)
|
||||
}
|
||||
|
||||
// MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon
|
||||
func (c *webPushStore) MarkExpiryWarningSent(subscriptions []*webPushSubscription) error {
|
||||
tx, err := c.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
for _, subscription := range subscriptions {
|
||||
if _, err := tx.Exec(updateWebPushSubscriptionWarningSentQuery, time.Now().Unix(), subscription.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (c *webPushStore) subscriptionsFromRows(rows *sql.Rows) ([]*webPushSubscription, error) {
|
||||
subscriptions := make([]*webPushSubscription, 0)
|
||||
for rows.Next() {
|
||||
var id, endpoint, auth, p256dh, userID string
|
||||
if err := rows.Scan(&id, &endpoint, &auth, &p256dh, &userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
subscriptions = append(subscriptions, &webPushSubscription{
|
||||
ID: id,
|
||||
Endpoint: endpoint,
|
||||
Auth: auth,
|
||||
P256dh: p256dh,
|
||||
UserID: userID,
|
||||
})
|
||||
}
|
||||
return subscriptions, nil
|
||||
}
|
||||
|
||||
// RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint
|
||||
func (c *webPushStore) RemoveSubscriptionsByEndpoint(endpoint string) error {
|
||||
_, err := c.db.Exec(deleteWebPushSubscriptionByEndpointQuery, endpoint)
|
||||
return err
|
||||
}
|
||||
|
||||
// RemoveSubscriptionsByUserID removes all subscriptions for the given user ID
|
||||
func (c *webPushStore) RemoveSubscriptionsByUserID(userID string) error {
|
||||
if userID == "" {
|
||||
return errWebPushUserIDCannotBeEmpty
|
||||
}
|
||||
_, err := c.db.Exec(deleteWebPushSubscriptionByUserIDQuery, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// RemoveExpiredSubscriptions removes all subscriptions that have not been updated for a given time period
|
||||
func (c *webPushStore) RemoveExpiredSubscriptions(expireAfter time.Duration) error {
|
||||
_, err := c.db.Exec(deleteWebPushSubscriptionByAgeQuery, time.Now().Add(-expireAfter).Unix())
|
||||
return err
|
||||
}
|
||||
|
||||
// Close closes the underlying database connection
|
||||
func (c *webPushStore) Close() error {
|
||||
return c.db.Close()
|
||||
}
|
||||
199
server/webpush_store_test.go
Normal file
199
server/webpush_store_test.go
Normal file
@@ -0,0 +1,199 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/require"
|
||||
"net/netip"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestWebPushStore_UpsertSubscription_SubscriptionsForTopic(t *testing.T) {
|
||||
webPush := newTestWebPushStore(t)
|
||||
defer webPush.Close()
|
||||
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
|
||||
|
||||
subs, err := webPush.SubscriptionsForTopic("test-topic")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
require.Equal(t, subs[0].Endpoint, testWebPushEndpoint)
|
||||
require.Equal(t, subs[0].P256dh, "p256dh-key")
|
||||
require.Equal(t, subs[0].Auth, "auth-key")
|
||||
require.Equal(t, subs[0].UserID, "u_1234")
|
||||
|
||||
subs2, err := webPush.SubscriptionsForTopic("mytopic")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs2, 1)
|
||||
require.Equal(t, subs[0].Endpoint, subs2[0].Endpoint)
|
||||
}
|
||||
|
||||
func TestWebPushStore_UpsertSubscription_SubscriberIPLimitReached(t *testing.T) {
|
||||
webPush := newTestWebPushStore(t)
|
||||
defer webPush.Close()
|
||||
|
||||
// Insert 10 subscriptions with the same IP address
|
||||
for i := 0; i < 10; i++ {
|
||||
endpoint := fmt.Sprintf(testWebPushEndpoint+"%d", i)
|
||||
require.Nil(t, webPush.UpsertSubscription(endpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
|
||||
}
|
||||
|
||||
// Another one for the same endpoint should be fine
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
|
||||
|
||||
// But with a different endpoint it should fail
|
||||
require.Equal(t, errWebPushTooManySubscriptions, webPush.UpsertSubscription(testWebPushEndpoint+"11", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
|
||||
|
||||
// But with a different IP address it should be fine again
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"99", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("9.9.9.9"), []string{"test-topic", "mytopic"}))
|
||||
}
|
||||
|
||||
func TestWebPushStore_UpsertSubscription_UpdateTopics(t *testing.T) {
|
||||
webPush := newTestWebPushStore(t)
|
||||
defer webPush.Close()
|
||||
|
||||
// Insert subscription with two topics, and another with one topic
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"1", "auth-key", "p256dh-key", "", netip.MustParseAddr("9.9.9.9"), []string{"topic1"}))
|
||||
|
||||
subs, err := webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 2)
|
||||
require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint)
|
||||
require.Equal(t, testWebPushEndpoint+"1", subs[1].Endpoint)
|
||||
|
||||
subs, err = webPush.SubscriptionsForTopic("topic2")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint)
|
||||
|
||||
// Update the first subscription to have only one topic
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
|
||||
|
||||
subs, err = webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 2)
|
||||
require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint)
|
||||
|
||||
subs, err = webPush.SubscriptionsForTopic("topic2")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 0)
|
||||
}
|
||||
|
||||
func TestWebPushStore_RemoveSubscriptionsByEndpoint(t *testing.T) {
|
||||
webPush := newTestWebPushStore(t)
|
||||
defer webPush.Close()
|
||||
|
||||
// Insert subscription with two topics
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||
subs, err := webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
|
||||
// And remove it again
|
||||
require.Nil(t, webPush.RemoveSubscriptionsByEndpoint(testWebPushEndpoint))
|
||||
subs, err = webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 0)
|
||||
}
|
||||
|
||||
func TestWebPushStore_RemoveSubscriptionsByUserID(t *testing.T) {
|
||||
webPush := newTestWebPushStore(t)
|
||||
defer webPush.Close()
|
||||
|
||||
// Insert subscription with two topics
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||
subs, err := webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
|
||||
// And remove it again
|
||||
require.Nil(t, webPush.RemoveSubscriptionsByUserID("u_1234"))
|
||||
subs, err = webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 0)
|
||||
}
|
||||
|
||||
func TestWebPushStore_RemoveSubscriptionsByUserID_Empty(t *testing.T) {
|
||||
webPush := newTestWebPushStore(t)
|
||||
defer webPush.Close()
|
||||
require.Equal(t, errWebPushUserIDCannotBeEmpty, webPush.RemoveSubscriptionsByUserID(""))
|
||||
}
|
||||
|
||||
func TestWebPushStore_MarkExpiryWarningSent(t *testing.T) {
|
||||
webPush := newTestWebPushStore(t)
|
||||
defer webPush.Close()
|
||||
|
||||
// Insert subscription with two topics
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||
subs, err := webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
|
||||
// Mark them as warning sent
|
||||
require.Nil(t, webPush.MarkExpiryWarningSent(subs))
|
||||
|
||||
rows, err := webPush.db.Query("SELECT endpoint FROM subscription WHERE warned_at > 0")
|
||||
require.Nil(t, err)
|
||||
defer rows.Close()
|
||||
var endpoint string
|
||||
require.True(t, rows.Next())
|
||||
require.Nil(t, rows.Scan(&endpoint))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, testWebPushEndpoint, endpoint)
|
||||
require.False(t, rows.Next())
|
||||
}
|
||||
|
||||
func TestWebPushStore_SubscriptionsExpiring(t *testing.T) {
|
||||
webPush := newTestWebPushStore(t)
|
||||
defer webPush.Close()
|
||||
|
||||
// Insert subscription with two topics
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||
subs, err := webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
|
||||
// Fake-mark them as soon-to-expire
|
||||
_, err = webPush.db.Exec("UPDATE subscription SET updated_at = ? WHERE endpoint = ?", time.Now().Add(-8*24*time.Hour).Unix(), testWebPushEndpoint)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Should not be cleaned up yet
|
||||
require.Nil(t, webPush.RemoveExpiredSubscriptions(9*24*time.Hour))
|
||||
|
||||
// Run expiration
|
||||
subs, err = webPush.SubscriptionsExpiring(7 * 24 * time.Hour)
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
require.Equal(t, testWebPushEndpoint, subs[0].Endpoint)
|
||||
}
|
||||
|
||||
func TestWebPushStore_RemoveExpiredSubscriptions(t *testing.T) {
|
||||
webPush := newTestWebPushStore(t)
|
||||
defer webPush.Close()
|
||||
|
||||
// Insert subscription with two topics
|
||||
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
|
||||
subs, err := webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 1)
|
||||
|
||||
// Fake-mark them as expired
|
||||
_, err = webPush.db.Exec("UPDATE subscription SET updated_at = ? WHERE endpoint = ?", time.Now().Add(-10*24*time.Hour).Unix(), testWebPushEndpoint)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Run expiration
|
||||
require.Nil(t, webPush.RemoveExpiredSubscriptions(9*24*time.Hour))
|
||||
|
||||
// List again, should be 0
|
||||
subs, err = webPush.SubscriptionsForTopic("topic1")
|
||||
require.Nil(t, err)
|
||||
require.Len(t, subs, 0)
|
||||
}
|
||||
|
||||
func newTestWebPushStore(t *testing.T) *webPushStore {
|
||||
webPush, err := newWebPushStore(filepath.Join(t.TempDir(), "webpush.db"), "")
|
||||
require.Nil(t, err)
|
||||
return webPush
|
||||
}
|
||||
Reference in New Issue
Block a user