Merge branch 'main' of github.com:binwiederhier/ntfy into html-emails

This commit is contained in:
binwiederhier
2023-11-16 06:11:00 -05:00
249 changed files with 39439 additions and 23403 deletions

View File

@@ -1,10 +1,11 @@
package server
import (
"heckel.io/ntfy/user"
"io/fs"
"net/netip"
"time"
"heckel.io/ntfy/user"
)
// Defines default config settings (excluding limits, see below)
@@ -22,6 +23,12 @@ const (
DefaultStripePriceCacheDuration = 3 * time.Hour // Time to keep Stripe prices cached in memory before a refresh is needed
)
// Defines default Web Push settings
const (
DefaultWebPushExpiryWarningDuration = 7 * 24 * time.Hour
DefaultWebPushExpiryDuration = 9 * 24 * time.Hour
)
// Defines all global and per-visitor limits
// - message size limit: the max number of bytes for a message
// - total topic limit: max number of topics overall
@@ -92,12 +99,13 @@ type Config struct {
KeepaliveInterval time.Duration
ManagerInterval time.Duration
DisallowedTopics []string
WebRootIsApp bool
WebRoot string // empty to disable
DelayedSenderInterval time.Duration
FirebaseKeepaliveInterval time.Duration
FirebasePollInterval time.Duration
FirebaseQuotaExceededPenaltyDuration time.Duration
UpstreamBaseURL string
UpstreamAccessToken string
SMTPSenderAddr string
SMTPSenderUser string
SMTPSenderPass string
@@ -105,6 +113,12 @@ type Config struct {
SMTPServerListen string
SMTPServerDomain string
SMTPServerAddrPrefix string
TwilioAccount string
TwilioAuthToken string
TwilioPhoneNumber string
TwilioCallsBaseURL string
TwilioVerifyBaseURL string
TwilioVerifyService string
MetricsEnable bool
MetricsListenHTTP string
ProfileListenHTTP string
@@ -133,13 +147,19 @@ type Config struct {
StripeWebhookKey string
StripePriceCacheDuration time.Duration
BillingContact string
EnableWeb bool
EnableSignup bool // Enable creation of accounts via API and UI
EnableLogin bool
EnableReservations bool // Allow users with role "user" to own/reserve topics
EnableMetrics bool
AccessControlAllowOrigin string // CORS header field to restrict access from web clients
Version string // injected by App
WebPushPrivateKey string
WebPushPublicKey string
WebPushFile string
WebPushEmailAddress string
WebPushStartupQueries string
WebPushExpiryDuration time.Duration
WebPushExpiryWarningDuration time.Duration
}
// NewConfig instantiates a default new server config
@@ -171,12 +191,13 @@ func NewConfig() *Config {
KeepaliveInterval: DefaultKeepaliveInterval,
ManagerInterval: DefaultManagerInterval,
DisallowedTopics: DefaultDisallowedTopics,
WebRootIsApp: false,
WebRoot: "/",
DelayedSenderInterval: DefaultDelayedSenderInterval,
FirebaseKeepaliveInterval: DefaultFirebaseKeepaliveInterval,
FirebasePollInterval: DefaultFirebasePollInterval,
FirebaseQuotaExceededPenaltyDuration: DefaultFirebaseQuotaExceededPenaltyDuration,
UpstreamBaseURL: "",
UpstreamAccessToken: "",
SMTPSenderAddr: "",
SMTPSenderUser: "",
SMTPSenderPass: "",
@@ -184,6 +205,12 @@ func NewConfig() *Config {
SMTPServerListen: "",
SMTPServerDomain: "",
SMTPServerAddrPrefix: "",
TwilioCallsBaseURL: "https://api.twilio.com", // Override for tests
TwilioAccount: "",
TwilioAuthToken: "",
TwilioPhoneNumber: "",
TwilioVerifyBaseURL: "https://verify.twilio.com", // Override for tests
TwilioVerifyService: "",
MessageLimit: DefaultMessageLengthLimit,
MinDelay: DefaultMinDelay,
MaxDelay: DefaultMaxDelay,
@@ -209,11 +236,16 @@ func NewConfig() *Config {
StripeWebhookKey: "",
StripePriceCacheDuration: DefaultStripePriceCacheDuration,
BillingContact: "",
EnableWeb: true,
EnableSignup: false,
EnableLogin: false,
EnableReservations: false,
AccessControlAllowOrigin: "*",
Version: "",
WebPushPrivateKey: "",
WebPushPublicKey: "",
WebPushFile: "",
WebPushEmailAddress: "",
WebPushExpiryDuration: DefaultWebPushExpiryDuration,
WebPushExpiryWarningDuration: DefaultWebPushExpiryWarningDuration,
}
}

View File

@@ -106,12 +106,25 @@ var (
errHTTPBadRequestNotAPaidUser = &errHTTP{40027, http.StatusBadRequest, "invalid request: not a paid user", "", nil}
errHTTPBadRequestBillingRequestInvalid = &errHTTP{40028, http.StatusBadRequest, "invalid request: not a valid billing request", "", nil}
errHTTPBadRequestBillingSubscriptionExists = &errHTTP{40029, http.StatusBadRequest, "invalid request: billing subscription already exists", "", nil}
errHTTPBadRequestTierInvalid = &errHTTP{40030, http.StatusBadRequest, "invalid request: tier does not exist", "", nil}
errHTTPBadRequestUserNotFound = &errHTTP{40031, http.StatusBadRequest, "invalid request: user does not exist", "", nil}
errHTTPBadRequestPhoneCallsDisabled = &errHTTP{40032, http.StatusBadRequest, "invalid request: calling is disabled", "https://ntfy.sh/docs/config/#phone-calls", nil}
errHTTPBadRequestPhoneNumberInvalid = &errHTTP{40033, http.StatusBadRequest, "invalid request: phone number invalid", "https://ntfy.sh/docs/publish/#phone-calls", nil}
errHTTPBadRequestPhoneNumberNotVerified = &errHTTP{40034, http.StatusBadRequest, "invalid request: phone number not verified, or no matching verified numbers found", "https://ntfy.sh/docs/publish/#phone-calls", nil}
errHTTPBadRequestAnonymousCallsNotAllowed = &errHTTP{40035, http.StatusBadRequest, "invalid request: anonymous phone calls are not allowed", "https://ntfy.sh/docs/publish/#phone-calls", nil}
errHTTPBadRequestPhoneNumberVerifyChannelInvalid = &errHTTP{40036, http.StatusBadRequest, "invalid request: verification channel must be 'sms' or 'call'", "https://ntfy.sh/docs/publish/#phone-calls", nil}
errHTTPBadRequestDelayNoCall = &errHTTP{40037, http.StatusBadRequest, "delayed call notifications are not supported", "", nil}
errHTTPBadRequestWebPushSubscriptionInvalid = &errHTTP{40038, http.StatusBadRequest, "invalid request: web push payload malformed", "", nil}
errHTTPBadRequestWebPushEndpointUnknown = &errHTTP{40039, http.StatusBadRequest, "invalid request: web push endpoint unknown", "", nil}
errHTTPBadRequestWebPushTopicCountTooHigh = &errHTTP{40040, http.StatusBadRequest, "invalid request: too many web push topic subscriptions", "", nil}
errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", "", nil}
errHTTPUnauthorized = &errHTTP{40101, http.StatusUnauthorized, "unauthorized", "https://ntfy.sh/docs/publish/#authentication", nil}
errHTTPForbidden = &errHTTP{40301, http.StatusForbidden, "forbidden", "https://ntfy.sh/docs/publish/#authentication", nil}
errHTTPConflictUserExists = &errHTTP{40901, http.StatusConflict, "conflict: user already exists", "", nil}
errHTTPConflictTopicReserved = &errHTTP{40902, http.StatusConflict, "conflict: access control entry for topic or topic pattern already exists", "", nil}
errHTTPConflictSubscriptionExists = &errHTTP{40903, http.StatusConflict, "conflict: topic subscription already exists", "", nil}
errHTTPConflictPhoneNumberExists = &errHTTP{40904, http.StatusConflict, "conflict: phone number already exists", "", nil}
errHTTPGonePhoneVerificationExpired = &errHTTP{41001, http.StatusGone, "phone number verification expired or does not exist", "", nil}
errHTTPEntityTooLargeAttachment = &errHTTP{41301, http.StatusRequestEntityTooLarge, "attachment too large, or bandwidth limit reached", "https://ntfy.sh/docs/publish/#limitations", nil}
errHTTPEntityTooLargeMatrixRequest = &errHTTP{41302, http.StatusRequestEntityTooLarge, "Matrix request is larger than the max allowed length", "", nil}
errHTTPEntityTooLargeJSONBody = &errHTTP{41303, http.StatusRequestEntityTooLarge, "JSON body too large", "", nil}
@@ -124,8 +137,10 @@ var (
errHTTPTooManyRequestsLimitReservations = &errHTTP{42907, http.StatusTooManyRequests, "limit reached: too many topic reservations for this user", "", nil}
errHTTPTooManyRequestsLimitMessages = &errHTTP{42908, http.StatusTooManyRequests, "limit reached: daily message quota reached", "https://ntfy.sh/docs/publish/#limitations", nil}
errHTTPTooManyRequestsLimitAuthFailure = &errHTTP{42909, http.StatusTooManyRequests, "limit reached: too many auth failures", "https://ntfy.sh/docs/publish/#limitations", nil} // FIXME document limit
errHTTPTooManyRequestsLimitCalls = &errHTTP{42910, http.StatusTooManyRequests, "limit reached: daily phone call quota reached", "https://ntfy.sh/docs/publish/#limitations", nil}
errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", "", nil}
errHTTPInternalErrorInvalidPath = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid path", "", nil}
errHTTPInternalErrorMissingBaseURL = &errHTTP{50003, http.StatusInternalServerError, "internal server error: base-url must be be configured for this feature", "https://ntfy.sh/docs/config/", nil}
errHTTPInternalErrorWebPushUnableToPublish = &errHTTP{50004, http.StatusInternalServerError, "internal server error: unable to publish web push message", "", nil}
errHTTPInsufficientStorageUnifiedPush = &errHTTP{50701, http.StatusInsufficientStorage, "cannot publish to UnifiedPush topic without previously active subscriber", "", nil}
)

View File

@@ -20,6 +20,7 @@ const (
tagFirebase = "firebase"
tagSMTP = "smtp" // Receive email
tagEmail = "email" // Send email
tagTwilio = "twilio"
tagFileCache = "file_cache"
tagMessageCache = "message_cache"
tagStripe = "stripe"
@@ -28,6 +29,7 @@ const (
tagResetter = "resetter"
tagWebsocket = "websocket"
tagMatrix = "matrix"
tagWebPush = "webpush"
)
var (

File diff suppressed because one or more lines are too long

1857
server/mailer_emoji_map.json Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -17,6 +17,7 @@ import (
var (
errUnexpectedMessageType = errors.New("unexpected message type")
errMessageNotFound = errors.New("message not found")
errNoRows = errors.New("no rows found")
)
// Messages cache
@@ -44,6 +45,7 @@ const (
attachment_deleted INT NOT NULL,
sender TEXT NOT NULL,
user TEXT NOT NULL,
content_type TEXT NOT NULL,
encoding TEXT NOT NULL,
published INT NOT NULL
);
@@ -54,46 +56,51 @@ const (
CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender);
CREATE INDEX IF NOT EXISTS idx_user ON messages (user);
CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
CREATE TABLE IF NOT EXISTS stats (
key TEXT PRIMARY KEY,
value INT
);
INSERT INTO stats (key, value) VALUES ('messages', 0);
COMMIT;
`
insertMessageQuery = `
INSERT INTO messages (mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user, encoding, published)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
INSERT INTO messages (mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user, content_type, encoding, published)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
deleteMessageQuery = `DELETE FROM messages WHERE mid = ?`
updateMessagesForTopicExpiryQuery = `UPDATE messages SET expires = ? WHERE topic = ?`
selectRowIDFromMessageID = `SELECT id FROM messages WHERE mid = ?` // Do not include topic, see #336 and TestServer_PollSinceID_MultipleTopics
selectMessagesByIDQuery = `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE mid = ?
`
selectMessagesSinceTimeQuery = `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE topic = ? AND time >= ? AND published = 1
ORDER BY time, id
`
selectMessagesSinceTimeIncludeScheduledQuery = `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE topic = ? AND time >= ?
ORDER BY time, id
`
selectMessagesSinceIDQuery = `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE topic = ? AND id > ? AND published = 1
ORDER BY time, id
`
selectMessagesSinceIDIncludeScheduledQuery = `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE topic = ? AND (id > ? OR published = 0)
ORDER BY time, id
`
selectMessagesDueQuery = `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE time <= ? AND published = 0
ORDER BY time, id
@@ -108,11 +115,14 @@ const (
selectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires <= ? AND attachment_deleted = 0`
selectAttachmentsSizeBySenderQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = '' AND sender = ? AND attachment_expires >= ?`
selectAttachmentsSizeByUserIDQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?`
selectStatsQuery = `SELECT value FROM stats WHERE key = 'messages'`
updateStatsQuery = `UPDATE stats SET value = ? WHERE key = 'messages'`
)
// Schema management queries
const (
currentSchemaVersion = 10
currentSchemaVersion = 12
createSchemaVersionTableQuery = `
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
@@ -222,20 +232,36 @@ const (
CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
`
migrate9To10UpdateMessageExpiryQuery = `UPDATE messages SET expires = time + ?`
// 10 -> 11
migrate10To11AlterMessagesTableQuery = `
CREATE TABLE IF NOT EXISTS stats (
key TEXT PRIMARY KEY,
value INT
);
INSERT INTO stats (key, value) VALUES ('messages', 0);
`
// 11 -> 12
migrate11To12AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN content_type TEXT NOT NULL DEFAULT('');
`
)
var (
migrations = map[int]func(db *sql.DB, cacheDuration time.Duration) error{
0: migrateFrom0,
1: migrateFrom1,
2: migrateFrom2,
3: migrateFrom3,
4: migrateFrom4,
5: migrateFrom5,
6: migrateFrom6,
7: migrateFrom7,
8: migrateFrom8,
9: migrateFrom9,
0: migrateFrom0,
1: migrateFrom1,
2: migrateFrom2,
3: migrateFrom3,
4: migrateFrom4,
5: migrateFrom5,
6: migrateFrom6,
7: migrateFrom7,
8: migrateFrom8,
9: migrateFrom9,
10: migrateFrom10,
11: migrateFrom11,
}
)
@@ -251,7 +277,7 @@ func newSqliteCache(filename, startupQueries string, cacheDuration time.Duration
if err != nil {
return nil, err
}
if err := setupDB(db, startupQueries, cacheDuration); err != nil {
if err := setupMessagesDB(db, startupQueries, cacheDuration); err != nil {
return nil, err
}
var queue *util.BatchingQueue[*message]
@@ -365,6 +391,7 @@ func (c *messageCache) addMessages(ms []*message) error {
attachmentDeleted, // Always zero
sender,
m.User,
m.ContentType,
m.Encoding,
published,
)
@@ -637,7 +664,7 @@ func readMessages(rows *sql.Rows) ([]*message, error) {
func readMessage(rows *sql.Rows) (*message, error) {
var timestamp, expires, attachmentSize, attachmentExpires int64
var priority int
var id, topic, msg, title, tagsStr, click, icon, actionsStr, attachmentName, attachmentType, attachmentURL, sender, user, encoding string
var id, topic, msg, title, tagsStr, click, icon, actionsStr, attachmentName, attachmentType, attachmentURL, sender, user, contentType, encoding string
err := rows.Scan(
&id,
&timestamp,
@@ -657,6 +684,7 @@ func readMessage(rows *sql.Rows) (*message, error) {
&attachmentURL,
&sender,
&user,
&contentType,
&encoding,
)
if err != nil {
@@ -687,30 +715,51 @@ func readMessage(rows *sql.Rows) (*message, error) {
}
}
return &message{
ID: id,
Time: timestamp,
Expires: expires,
Event: messageEvent,
Topic: topic,
Message: msg,
Title: title,
Priority: priority,
Tags: tags,
Click: click,
Icon: icon,
Actions: actions,
Attachment: att,
Sender: senderIP, // Must parse assuming database must be correct
User: user,
Encoding: encoding,
ID: id,
Time: timestamp,
Expires: expires,
Event: messageEvent,
Topic: topic,
Message: msg,
Title: title,
Priority: priority,
Tags: tags,
Click: click,
Icon: icon,
Actions: actions,
Attachment: att,
Sender: senderIP, // Must parse assuming database must be correct
User: user,
ContentType: contentType,
Encoding: encoding,
}, nil
}
func (c *messageCache) UpdateStats(messages int64) error {
_, err := c.db.Exec(updateStatsQuery, messages)
return err
}
func (c *messageCache) Stats() (messages int64, err error) {
rows, err := c.db.Query(selectStatsQuery)
if err != nil {
return 0, err
}
defer rows.Close()
if !rows.Next() {
return 0, errNoRows
}
if err := rows.Scan(&messages); err != nil {
return 0, err
}
return messages, nil
}
func (c *messageCache) Close() error {
return c.db.Close()
}
func setupDB(db *sql.DB, startupQueries string, cacheDuration time.Duration) error {
func setupMessagesDB(db *sql.DB, startupQueries string, cacheDuration time.Duration) error {
// Run startup queries
if startupQueries != "" {
if _, err := db.Exec(startupQueries); err != nil {
@@ -889,3 +938,35 @@ func migrateFrom9(db *sql.DB, cacheDuration time.Duration) error {
}
return tx.Commit()
}
func migrateFrom10(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 10 to 11")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate10To11AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(updateSchemaVersion, 11); err != nil {
return err
}
return tx.Commit()
}
func migrateFrom11(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 11 to 12")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate11To12AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(updateSchemaVersion, 12); err != nil {
return err
}
return tx.Commit()
}

View File

@@ -9,13 +9,6 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/emersion/go-smtp"
"github.com/gorilla/websocket"
"github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/sync/errgroup"
"heckel.io/ntfy/log"
"heckel.io/ntfy/user"
"heckel.io/ntfy/util"
"io"
"net"
"net/http"
@@ -32,6 +25,14 @@ import (
"sync"
"time"
"unicode/utf8"
"github.com/emersion/go-smtp"
"github.com/gorilla/websocket"
"github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/sync/errgroup"
"heckel.io/ntfy/log"
"heckel.io/ntfy/user"
"heckel.io/ntfy/util"
)
// Server is the main server, providing the UI and API for ntfy
@@ -48,15 +49,17 @@ type Server struct {
topics map[string]*topic
visitors map[string]*visitor // ip:<ip> or user:<user>
firebaseClient *firebaseClient
messages int64
messages int64 // Total number of messages (persisted if messageCache enabled)
messagesHistory []int64 // Last n values of the messages counter, used to determine rate
userManager *user.Manager // Might be nil!
messageCache *messageCache // Database that stores the messages
webPush *webPushStore // Database that stores web push subscriptions
fileCache *fileCache // File system based cache that stores attachments
stripe stripeAPI // Stripe API, can be replaced with a mock
priceCache *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!)
metricsHandler http.Handler // Handles /metrics if enable-metrics set, and listen-metrics-http not set
closeChan chan bool
mu sync.Mutex
mu sync.RWMutex
}
// handleFunc extends the normal http.HandlerFunc to be able to easily return errors
@@ -75,17 +78,26 @@ var (
publishPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}/(publish|send|trigger)$`)
webConfigPath = "/config.js"
webManifestPath = "/manifest.webmanifest"
webRootHTMLPath = "/app.html"
webServiceWorkerPath = "/sw.js"
accountPath = "/account"
matrixPushPath = "/_matrix/push/v1/notify"
metricsPath = "/metrics"
apiHealthPath = "/v1/health"
apiTiers = "/v1/tiers"
apiStatsPath = "/v1/stats"
apiWebPushPath = "/v1/webpush"
apiTiersPath = "/v1/tiers"
apiUsersPath = "/v1/users"
apiUsersAccessPath = "/v1/users/access"
apiAccountPath = "/v1/account"
apiAccountTokenPath = "/v1/account/token"
apiAccountPasswordPath = "/v1/account/password"
apiAccountSettingsPath = "/v1/account/settings"
apiAccountSubscriptionPath = "/v1/account/subscription"
apiAccountReservationPath = "/v1/account/reservation"
apiAccountPhonePath = "/v1/account/phone"
apiAccountPhoneVerifyPath = "/v1/account/phone/verify"
apiAccountBillingPortalPath = "/v1/account/billing/portal"
apiAccountBillingWebhookPath = "/v1/account/billing/webhook"
apiAccountBillingSubscriptionPath = "/v1/account/billing/subscription"
@@ -96,13 +108,13 @@ var (
docsRegex = regexp.MustCompile(`^/docs(|/.*)$`)
fileRegex = regexp.MustCompile(`^/file/([-_A-Za-z0-9]{1,64})(?:\.[A-Za-z0-9]{1,16})?$`)
urlRegex = regexp.MustCompile(`^https?://`)
phoneNumberRegex = regexp.MustCompile(`^\+\d{1,100}$`)
//go:embed site
webFs embed.FS
webFsCached = &util.CachingEmbedFS{ModTime: time.Now(), FS: webFs}
webSiteDir = "/site"
webHomeIndex = "/home.html" // Landing page, only if "web-root: home"
webAppIndex = "/app.html" // React app
webFs embed.FS
webFsCached = &util.CachingEmbedFS{ModTime: time.Now(), FS: webFs}
webSiteDir = "/site"
webAppIndex = "/app.html" // React app
//go:embed docs
docsStaticFs embed.FS
@@ -116,9 +128,10 @@ const (
newMessageBody = "New message" // Used in poll requests as generic message
defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment
encodingBase64 = "base64" // Used mainly for binary UnifiedPush messages
jsonBodyBytesLimit = 16384
unifiedPushTopicPrefix = "up" // Temporarily, we rate limit all "up*" topics based on the subscriber
unifiedPushTopicLength = 14
jsonBodyBytesLimit = 16384 // Max number of bytes for a JSON request body
unifiedPushTopicPrefix = "up" // Temporarily, we rate limit all "up*" topics based on the subscriber
unifiedPushTopicLength = 14 // Length of UnifiedPush topics, including the "up" part
messagesHistoryMax = 10 // Number of message count values to keep in memory
)
// WebSocket constants
@@ -144,10 +157,21 @@ func New(conf *Config) (*Server, error) {
if err != nil {
return nil, err
}
var webPush *webPushStore
if conf.WebPushPublicKey != "" {
webPush, err = newWebPushStore(conf.WebPushFile, conf.WebPushStartupQueries)
if err != nil {
return nil, err
}
}
topics, err := messageCache.Topics()
if err != nil {
return nil, err
}
messages, err := messageCache.Stats()
if err != nil {
return nil, err
}
var fileCache *fileCache
if conf.AttachmentCacheDir != "" {
fileCache, err = newFileCache(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit)
@@ -177,15 +201,18 @@ func New(conf *Config) (*Server, error) {
firebaseClient = newFirebaseClient(sender, auther)
}
s := &Server{
config: conf,
messageCache: messageCache,
fileCache: fileCache,
firebaseClient: firebaseClient,
smtpSender: mailer,
topics: topics,
userManager: userManager,
visitors: make(map[string]*visitor),
stripe: stripe,
config: conf,
messageCache: messageCache,
webPush: webPush,
fileCache: fileCache,
firebaseClient: firebaseClient,
smtpSender: mailer,
topics: topics,
userManager: userManager,
messages: messages,
messagesHistory: []int64{messages},
visitors: make(map[string]*visitor),
stripe: stripe,
}
s.priceCache = util.NewLookupCache(s.fetchStripePrices, conf.StripePriceCacheDuration)
return s, nil
@@ -329,6 +356,9 @@ func (s *Server) closeDatabases() {
s.userManager.Close()
}
s.messageCache.Close()
if s.webPush != nil {
s.webPush.Close()
}
}
// handle is the main entry point for all HTTP requests
@@ -395,14 +425,26 @@ func (s *Server) handleError(w http.ResponseWriter, r *http.Request, v *visitor,
}
func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visitor) error {
if r.Method == http.MethodGet && r.URL.Path == "/" {
return s.ensureWebEnabled(s.handleHome)(w, r, v)
if r.Method == http.MethodGet && r.URL.Path == "/" && s.config.WebRoot == "/" {
return s.ensureWebEnabled(s.handleRoot)(w, r, v)
} else if r.Method == http.MethodHead && r.URL.Path == "/" {
return s.ensureWebEnabled(s.handleEmpty)(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == apiHealthPath {
return s.handleHealth(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == webConfigPath {
return s.ensureWebEnabled(s.handleWebConfig)(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == webManifestPath {
return s.ensureWebPushEnabled(s.handleWebManifest)(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == apiUsersPath {
return s.ensureAdmin(s.handleUsersGet)(w, r, v)
} else if r.Method == http.MethodPut && r.URL.Path == apiUsersPath {
return s.ensureAdmin(s.handleUsersAdd)(w, r, v)
} else if r.Method == http.MethodDelete && r.URL.Path == apiUsersPath {
return s.ensureAdmin(s.handleUsersDelete)(w, r, v)
} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && r.URL.Path == apiUsersAccessPath {
return s.ensureAdmin(s.handleAccessAllow)(w, r, v)
} else if r.Method == http.MethodDelete && r.URL.Path == apiUsersAccessPath {
return s.ensureAdmin(s.handleAccessReset)(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == apiAccountPath {
return s.ensureUserManager(s.handleAccountCreate)(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == apiAccountPath {
@@ -441,13 +483,25 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingPortalSessionCreate))(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == apiAccountBillingWebhookPath {
return s.ensurePaymentsEnabled(s.ensureUserManager(s.handleAccountBillingWebhook))(w, r, v) // This request comes from Stripe!
} else if r.Method == http.MethodGet && r.URL.Path == apiTiers {
} else if r.Method == http.MethodPut && r.URL.Path == apiAccountPhoneVerifyPath {
return s.ensureUser(s.ensureCallsEnabled(s.withAccountSync(s.handleAccountPhoneNumberVerify)))(w, r, v)
} else if r.Method == http.MethodPut && r.URL.Path == apiAccountPhonePath {
return s.ensureUser(s.ensureCallsEnabled(s.withAccountSync(s.handleAccountPhoneNumberAdd)))(w, r, v)
} else if r.Method == http.MethodDelete && r.URL.Path == apiAccountPhonePath {
return s.ensureUser(s.ensureCallsEnabled(s.withAccountSync(s.handleAccountPhoneNumberDelete)))(w, r, v)
} else if r.Method == http.MethodPost && apiWebPushPath == r.URL.Path {
return s.ensureWebPushEnabled(s.limitRequests(s.handleWebPushUpdate))(w, r, v)
} else if r.Method == http.MethodDelete && apiWebPushPath == r.URL.Path {
return s.ensureWebPushEnabled(s.limitRequests(s.handleWebPushDelete))(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == apiStatsPath {
return s.handleStats(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == apiTiersPath {
return s.ensurePaymentsEnabled(s.handleBillingTiersGet)(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == matrixPushPath {
return s.handleMatrixDiscovery(w)
} else if r.Method == http.MethodGet && r.URL.Path == metricsPath && s.metricsHandler != nil {
return s.handleMetrics(w, r, v)
} else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) {
} else if r.Method == http.MethodGet && (staticRegex.MatchString(r.URL.Path) || r.URL.Path == webServiceWorkerPath || r.URL.Path == webRootHTMLPath) {
return s.ensureWebEnabled(s.handleStatic)(w, r, v)
} else if r.Method == http.MethodGet && docsRegex.MatchString(r.URL.Path) {
return s.ensureWebEnabled(s.handleDocs)(w, r, v)
@@ -479,12 +533,8 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
return errHTTPNotFound
}
func (s *Server) handleHome(w http.ResponseWriter, r *http.Request, v *visitor) error {
if s.config.WebRootIsApp {
r.URL.Path = webAppIndex
} else {
r.URL.Path = webHomeIndex
}
func (s *Server) handleRoot(w http.ResponseWriter, r *http.Request, v *visitor) error {
r.URL.Path = webAppIndex
return s.handleStatic(w, r, v)
}
@@ -516,18 +566,18 @@ func (s *Server) handleHealth(w http.ResponseWriter, _ *http.Request, _ *visitor
}
func (s *Server) handleWebConfig(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
appRoot := "/"
if !s.config.WebRootIsApp {
appRoot = "/app"
}
response := &apiConfigResponse{
BaseURL: "", // Will translate to window.location.origin
AppRoot: appRoot,
AppRoot: s.config.WebRoot,
EnableLogin: s.config.EnableLogin,
EnableSignup: s.config.EnableSignup,
EnablePayments: s.config.StripeSecretKey != "",
EnableCalls: s.config.TwilioAccount != "",
EnableEmails: s.config.SMTPSenderFrom != "",
EnableReservations: s.config.EnableReservations,
EnableWebPush: s.config.WebPushPublicKey != "",
BillingContact: s.config.BillingContact,
WebPushPublicKey: s.config.WebPushPublicKey,
DisallowedTopics: s.config.DisallowedTopics,
}
b, err := json.MarshalIndent(response, "", " ")
@@ -539,6 +589,25 @@ func (s *Server) handleWebConfig(w http.ResponseWriter, _ *http.Request, _ *visi
return err
}
// handleWebManifest serves the web app manifest for the progressive web app (PWA)
func (s *Server) handleWebManifest(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
response := &webManifestResponse{
Name: "ntfy web",
Description: "ntfy lets you send push notifications via scripts from any computer or phone",
ShortName: "ntfy",
Scope: "/",
StartURL: s.config.WebRoot,
Display: "standalone",
BackgroundColor: "#ffffff",
ThemeColor: "#317f6f",
Icons: []*webManifestIcon{
{SRC: "/static/images/pwa-192x192.png", Sizes: "192x192", Type: "image/png"},
{SRC: "/static/images/pwa-512x512.png", Sizes: "512x512", Type: "image/png"},
},
}
return s.writeJSONWithContentType(w, response, "application/manifest+json")
}
// handleMetrics returns Prometheus metrics. This endpoint is only called if enable-metrics is set,
// and listen-metrics-http is not set.
func (s *Server) handleMetrics(w http.ResponseWriter, r *http.Request, _ *visitor) error {
@@ -546,17 +615,34 @@ func (s *Server) handleMetrics(w http.ResponseWriter, r *http.Request, _ *visito
return nil
}
// handleStatic returns all static resources (excluding the docs), including the web app
func (s *Server) handleStatic(w http.ResponseWriter, r *http.Request, _ *visitor) error {
r.URL.Path = webSiteDir + r.URL.Path
util.Gzip(http.FileServer(http.FS(webFsCached))).ServeHTTP(w, r)
return nil
}
// handleDocs returns static resources related to the docs
func (s *Server) handleDocs(w http.ResponseWriter, r *http.Request, _ *visitor) error {
util.Gzip(http.FileServer(http.FS(docsStaticCached))).ServeHTTP(w, r)
return nil
}
// handleStats returns the publicly available server stats
func (s *Server) handleStats(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
s.mu.RLock()
messages, n, rate := s.messages, len(s.messagesHistory), float64(0)
if n > 1 {
rate = float64(s.messagesHistory[n-1]-s.messagesHistory[0]) / (float64(n-1) * s.config.ManagerInterval.Seconds())
}
s.mu.RUnlock()
response := &apiStatsResponse{
Messages: messages,
MessagesRate: rate,
}
return s.writeJSON(w, response)
}
// handleFile processes the download of attachment files. The method handles GET and HEAD requests against a file.
// Before streaming the file to a client, it locates uploader (m.Sender or m.User) in the message cache, so it
// can associate the download bandwidth with the uploader.
@@ -623,6 +709,9 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor)
return err
}
defer f.Close()
if m.Attachment.Name != "" {
w.Header().Set("Content-Disposition", "attachment; filename="+strconv.Quote(m.Attachment.Name))
}
_, err = io.Copy(util.NewContentTypeWriter(w, r.URL.Path), f)
return err
}
@@ -649,7 +738,7 @@ func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*message, e
return nil, err
}
m := newDefaultMessage(t.ID, "")
cache, firebase, email, unifiedpush, e := s.parsePublishParams(r, m)
cache, firebase, email, call, unifiedpush, e := s.parsePublishParams(r, m)
if e != nil {
return nil, e.With(t)
}
@@ -663,6 +752,14 @@ func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*message, e
return nil, errHTTPTooManyRequestsLimitMessages.With(t)
} else if email != "" && !vrate.EmailAllowed() {
return nil, errHTTPTooManyRequestsLimitEmails.With(t)
} else if call != "" {
var httpErr *errHTTP
call, httpErr = s.convertPhoneNumber(v.User(), call)
if httpErr != nil {
return nil, httpErr.With(t)
} else if !vrate.CallAllowed() {
return nil, errHTTPTooManyRequestsLimitCalls.With(t)
}
}
if m.PollID != "" {
m = newPollRequestMessage(t.ID, m.PollID)
@@ -687,6 +784,7 @@ func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*message, e
"message_firebase": firebase,
"message_unifiedpush": unifiedpush,
"message_email": email,
"message_call": call,
})
if ev.IsTrace() {
ev.Field("message_body", util.MaybeMarshalJSON(m)).Trace("Received message")
@@ -703,9 +801,15 @@ func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*message, e
if s.smtpSender != nil && email != "" {
go s.sendEmail(v, m, email)
}
if s.config.UpstreamBaseURL != "" {
if s.config.TwilioAccount != "" && call != "" {
go s.callPhone(v, r, m, call)
}
if s.config.UpstreamBaseURL != "" && !unifiedpush { // UP messages are not sent to upstream
go s.forwardPollRequest(v, m)
}
if s.config.WebPushPublicKey != "" {
go s.publishToWebPushEndpoints(v, m)
}
} else {
logvrm(v, r, m).Tag(tagPublish).Debug("Message delayed, will process later")
}
@@ -798,7 +902,11 @@ func (s *Server) forwardPollRequest(v *visitor, m *message) {
logvm(v, m).Err(err).Warn("Unable to publish poll request")
return
}
req.Header.Set("User-Agent", "ntfy/"+s.config.Version)
req.Header.Set("X-Poll-ID", m.ID)
if s.config.UpstreamAccessToken != "" {
req.Header.Set("Authorization", util.BearerAuth(s.config.UpstreamAccessToken))
}
var httpClient = &http.Client{
Timeout: time.Second * 10,
}
@@ -807,12 +915,16 @@ func (s *Server) forwardPollRequest(v *visitor, m *message) {
logvm(v, m).Err(err).Warn("Unable to publish poll request")
return
} else if response.StatusCode != http.StatusOK {
logvm(v, m).Err(err).Warn("Unable to publish poll request, unexpected HTTP status: %d", response.StatusCode)
if response.StatusCode == http.StatusTooManyRequests {
logvm(v, m).Err(err).Warn("Unable to publish poll request, the upstream server %s responded with HTTP %s; you may solve this by sending fewer daily messages, or by configuring upstream-access-token (assuming you have an account with higher rate limits) ", s.config.UpstreamBaseURL, response.Status)
} else {
logvm(v, m).Err(err).Warn("Unable to publish poll request, the upstream server %s responded with HTTP %s", s.config.UpstreamBaseURL, response.Status)
}
return
}
}
func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, firebase bool, email string, unifiedpush bool, err *errHTTP) {
func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, firebase bool, email, call string, unifiedpush bool, err *errHTTP) {
cache = readBoolParam(r, true, "x-cache", "cache")
firebase = readBoolParam(r, true, "x-firebase", "firebase")
m.Title = readParam(r, "x-title", "title", "t")
@@ -828,7 +940,7 @@ func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, fi
}
if attach != "" {
if !urlRegex.MatchString(attach) {
return false, false, "", false, errHTTPBadRequestAttachmentURLInvalid
return false, false, "", "", false, errHTTPBadRequestAttachmentURLInvalid
}
m.Attachment.URL = attach
if m.Attachment.Name == "" {
@@ -846,13 +958,19 @@ func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, fi
}
if icon != "" {
if !urlRegex.MatchString(icon) {
return false, false, "", false, errHTTPBadRequestIconURLInvalid
return false, false, "", "", false, errHTTPBadRequestIconURLInvalid
}
m.Icon = icon
}
email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e")
if s.smtpSender == nil && email != "" {
return false, false, "", false, errHTTPBadRequestEmailDisabled
return false, false, "", "", false, errHTTPBadRequestEmailDisabled
}
call = readParam(r, "x-call", "call")
if call != "" && (s.config.TwilioAccount == "" || s.userManager == nil) {
return false, false, "", "", false, errHTTPBadRequestPhoneCallsDisabled
} else if call != "" && !isBoolValue(call) && !phoneNumberRegex.MatchString(call) {
return false, false, "", "", false, errHTTPBadRequestPhoneNumberInvalid
}
messageStr := strings.ReplaceAll(readParam(r, "x-message", "message", "m"), "\\n", "\n")
if messageStr != "" {
@@ -861,24 +979,27 @@ func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, fi
var e error
m.Priority, e = util.ParsePriority(readParam(r, "x-priority", "priority", "prio", "p"))
if e != nil {
return false, false, "", false, errHTTPBadRequestPriorityInvalid
return false, false, "", "", false, errHTTPBadRequestPriorityInvalid
}
m.Tags = readCommaSeparatedParam(r, "x-tags", "tags", "tag", "ta")
delayStr := readParam(r, "x-delay", "delay", "x-at", "at", "x-in", "in")
if delayStr != "" {
if !cache {
return false, false, "", false, errHTTPBadRequestDelayNoCache
return false, false, "", "", false, errHTTPBadRequestDelayNoCache
}
if email != "" {
return false, false, "", false, errHTTPBadRequestDelayNoEmail // we cannot store the email address (yet)
return false, false, "", "", false, errHTTPBadRequestDelayNoEmail // we cannot store the email address (yet)
}
if call != "" {
return false, false, "", "", false, errHTTPBadRequestDelayNoCall // we cannot store the phone number (yet)
}
delay, err := util.ParseFutureTime(delayStr, time.Now())
if err != nil {
return false, false, "", false, errHTTPBadRequestDelayCannotParse
return false, false, "", "", false, errHTTPBadRequestDelayCannotParse
} else if delay.Unix() < time.Now().Add(s.config.MinDelay).Unix() {
return false, false, "", false, errHTTPBadRequestDelayTooSmall
return false, false, "", "", false, errHTTPBadRequestDelayTooSmall
} else if delay.Unix() > time.Now().Add(s.config.MaxDelay).Unix() {
return false, false, "", false, errHTTPBadRequestDelayTooLarge
return false, false, "", "", false, errHTTPBadRequestDelayTooLarge
}
m.Time = delay.Unix()
}
@@ -886,9 +1007,13 @@ func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, fi
if actionsStr != "" {
m.Actions, e = parseActions(actionsStr)
if e != nil {
return false, false, "", false, errHTTPBadRequestActionsInvalid.Wrap(e.Error())
return false, false, "", "", false, errHTTPBadRequestActionsInvalid.Wrap(e.Error())
}
}
contentType, markdown := readParam(r, "content-type", "content_type"), readBoolParam(r, false, "x-markdown", "markdown", "md")
if markdown || strings.ToLower(contentType) == "text/markdown" {
m.ContentType = "text/markdown"
}
unifiedpush = readBoolParam(r, false, "x-unifiedpush", "unifiedpush", "up") // see GET too!
if unifiedpush {
firebase = false
@@ -900,7 +1025,7 @@ func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, fi
cache = false
email = ""
}
return cache, firebase, email, unifiedpush, nil
return cache, firebase, email, call, unifiedpush, nil
}
// handlePublishBody consumes the PUT/POST body and decides whether the body is an attachment or the message.
@@ -1170,7 +1295,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
}
defer conn.Close()
// Subscription connections can be canceled externally, see topic.CancelSubscribers
// Subscription connections can be canceled externally, see topic.CancelSubscribersExceptUser
cancelCtx, cancel := context.WithCancel(context.Background())
defer cancel()
@@ -1412,6 +1537,7 @@ func (s *Server) handleOptions(w http.ResponseWriter, _ *http.Request, _ *visito
return nil
}
// topicFromPath returns the topic from a root path (e.g. /mytopic), creating it if it doesn't exist.
func (s *Server) topicFromPath(path string) (*topic, error) {
parts := strings.Split(path, "/")
if len(parts) < 2 {
@@ -1420,6 +1546,7 @@ func (s *Server) topicFromPath(path string) (*topic, error) {
return s.topicFromID(parts[1])
}
// topicsFromPath returns the topic from a root path (e.g. /mytopic,mytopic2), creating it if it doesn't exist.
func (s *Server) topicsFromPath(path string) ([]*topic, string, error) {
parts := strings.Split(path, "/")
if len(parts) < 2 {
@@ -1433,6 +1560,7 @@ func (s *Server) topicsFromPath(path string) ([]*topic, string, error) {
return topics, parts[1], nil
}
// topicsFromIDs returns the topics with the given IDs, creating them if they don't exist.
func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -1452,6 +1580,7 @@ func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) {
return topics, nil
}
// topicFromID returns the topic with the given ID, creating it if it doesn't exist.
func (s *Server) topicFromID(id string) (*topic, error) {
topics, err := s.topicsFromIDs(id)
if err != nil {
@@ -1460,6 +1589,23 @@ func (s *Server) topicFromID(id string) (*topic, error) {
return topics[0], nil
}
// topicsFromPattern returns a list of topics matching the given pattern, but it does not create them.
func (s *Server) topicsFromPattern(pattern string) ([]*topic, error) {
s.mu.RLock()
defer s.mu.RUnlock()
patternRegexp, err := regexp.Compile("^" + strings.ReplaceAll(pattern, "*", ".*") + "$")
if err != nil {
return nil, err
}
topics := make([]*topic, 0)
for _, t := range s.topics {
if patternRegexp.MatchString(t.ID) {
topics = append(topics, t)
}
}
return topics, nil
}
func (s *Server) runSMTPServer() error {
s.smtpServerBackend = newMailBackend(s.config, s.handle)
s.smtpServer = smtp.NewServer(s.smtpServerBackend)
@@ -1580,9 +1726,9 @@ func (s *Server) sendDelayedMessages() error {
func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
logvm(v, m).Debug("Sending delayed message")
s.mu.Lock()
s.mu.RLock()
t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published
s.mu.Unlock()
s.mu.RUnlock()
if ok {
go func() {
// We do not rate-limit messages here, since we've rate limited them in the PUT/POST handler
@@ -1597,6 +1743,9 @@ func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
if s.config.UpstreamBaseURL != "" {
go s.forwardPollRequest(v, m)
}
if s.config.WebPushPublicKey != "" {
go s.publishToWebPushEndpoints(v, m)
}
if err := s.messageCache.MarkPublished(m); err != nil {
return err
}
@@ -1640,6 +1789,9 @@ func (s *Server) transformBodyJSON(next handleFunc) handleFunc {
if m.Icon != "" {
r.Header.Set("X-Icon", m.Icon)
}
if m.Markdown {
r.Header.Set("X-Markdown", "yes")
}
if len(m.Actions) > 0 {
actionsStr, err := json.Marshal(m.Actions)
if err != nil {
@@ -1653,6 +1805,9 @@ func (s *Server) transformBodyJSON(next handleFunc) handleFunc {
if m.Delay != "" {
r.Header.Set("X-Delay", m.Delay)
}
if m.Call != "" {
r.Header.Set("X-Call", m.Call)
}
return next(w, r, v)
}
}
@@ -1814,10 +1969,28 @@ func (s *Server) visitor(ip netip.Addr, user *user.User) *visitor {
}
func (s *Server) writeJSON(w http.ResponseWriter, v any) error {
w.Header().Set("Content-Type", "application/json")
return s.writeJSONWithContentType(w, v, "application/json")
}
func (s *Server) writeJSONWithContentType(w http.ResponseWriter, v any, contentType string) error {
w.Header().Set("Content-Type", contentType)
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
if err := json.NewEncoder(w).Encode(v); err != nil {
return err
}
return nil
}
func (s *Server) updateAndWriteStats(messagesCount int64) {
s.mu.Lock()
s.messagesHistory = append(s.messagesHistory, messagesCount)
if len(s.messagesHistory) > messagesHistoryMax {
s.messagesHistory = s.messagesHistory[1:]
}
s.mu.Unlock()
go func() {
if err := s.messageCache.UpdateStats(messagesCount); err != nil {
log.Tag(tagManager).Err(err).Warn("Cannot write messages stats")
}
}()
}

View File

@@ -144,6 +144,39 @@
# smtp-server-domain:
# smtp-server-addr-prefix:
# Web Push support (background notifications for browsers)
#
# If enabled, allows ntfy to receive push notifications, even when the ntfy web app is closed. When enabled, users
# can enable background notifications in the web app. Once enabled, ntfy will forward published messages to the push
# endpoint, which will then forward it to the browser.
#
# You must configure web-push-public/private key, web-push-file, and web-push-email-address below to enable Web Push.
# Run "ntfy webpush keys" to generate the keys.
#
# - web-push-public-key is the generated VAPID public key, e.g. AA1234BBCCddvveekaabcdfqwertyuiopasdfghjklzxcvbnm1234567890
# - web-push-private-key is the generated VAPID private key, e.g. AA2BB1234567890abcdefzxcvbnm1234567890
# - web-push-file is a database file to keep track of browser subscription endpoints, e.g. `/var/cache/ntfy/webpush.db`
# - web-push-email-address is the admin email address send to the push provider, e.g. `sysadmin@example.com`
# - web-push-startup-queries is an optional list of queries to run on startup`
#
# web-push-public-key:
# web-push-private-key:
# web-push-file:
# web-push-email-address:
# web-push-startup-queries:
# If enabled, ntfy can perform voice calls via Twilio via the "X-Call" header.
#
# - twilio-account is the Twilio account SID, e.g. AC12345beefbeef67890beefbeef122586
# - twilio-auth-token is the Twilio auth token, e.g. affebeef258625862586258625862586
# - twilio-phone-number is the outgoing phone number you purchased, e.g. +18775132586
# - twilio-verify-service is the Twilio Verify service SID, e.g. VA12345beefbeef67890beefbeef122586
#
# twilio-account:
# twilio-auth-token:
# twilio-phone-number:
# twilio-verify-service:
# Interval in which keepalive messages are sent to the client. This is to prevent
# intermediaries closing the connection for inactivity.
#
@@ -167,11 +200,13 @@
#
# disallowed-topics:
# Defines if the root route (/) is pointing to the landing page (as on ntfy.sh) or the
# web app. If you self-host, you don't want to change this.
# Can be "app" (default), "home" or "disable" to disable the web app entirely.
# Defines the root path of the web app, or disables the web app entirely.
#
# web-root: app
# Can be any simple path, e.g. "/", "/app", or "/ntfy". For backwards-compatibility reasons,
# the values "app" (maps to "/"), "home" (maps to "/app"), or "disable" (maps to "") to disable
# the web app entirely.
#
# web-root: /
# Various feature flags used to control the web app, and API access, mainly around user and
# account management.
@@ -194,7 +229,12 @@
# the message ID of the original message, instructing the iOS app to poll this server for the actual message contents.
# This is to prevent the upstream server and Firebase/APNS from being able to read the message.
#
# - upstream-base-url is the base URL of the upstream server. Should be "https://ntfy.sh".
# - upstream-access-token is the token used to authenticate with the upstream server. This is only required
# if you exceed the upstream rate limits, or the uptream server requires authentication.
#
# upstream-base-url:
# upstream-access-token:
# Rate limiting: Total number of topics before the server rejects new topics.
#
@@ -302,6 +342,10 @@
# - "field -> level" to match any value, e.g. "time_taken_ms -> debug"
# Warning: Using log-level-overrides has a performance penalty. Only use it for temporary debugging.
#
# Check your permissions:
# If you are running ntfy with systemd, make sure this log file is owned by the
# ntfy user and group by running: chown ntfy.ntfy <filename>.
#
# Example (good for production):
# log-level: info
# log-format: json

View File

@@ -56,6 +56,7 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, r *http.Request, v *vis
Messages: limits.MessageLimit,
MessagesExpiryDuration: int64(limits.MessageExpiryDuration.Seconds()),
Emails: limits.EmailLimit,
Calls: limits.CallLimit,
Reservations: limits.ReservationsLimit,
AttachmentTotalSize: limits.AttachmentTotalSizeLimit,
AttachmentFileSize: limits.AttachmentFileSizeLimit,
@@ -67,6 +68,8 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, r *http.Request, v *vis
MessagesRemaining: stats.MessagesRemaining,
Emails: stats.Emails,
EmailsRemaining: stats.EmailsRemaining,
Calls: stats.Calls,
CallsRemaining: stats.CallsRemaining,
Reservations: stats.Reservations,
ReservationsRemaining: stats.ReservationsRemaining,
AttachmentTotalSize: stats.AttachmentTotalSize,
@@ -105,17 +108,19 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, r *http.Request, v *vis
CancelAt: u.Billing.StripeSubscriptionCancelAt.Unix(),
}
}
reservations, err := s.userManager.Reservations(u.Name)
if err != nil {
return err
}
if len(reservations) > 0 {
response.Reservations = make([]*apiAccountReservation, 0)
for _, r := range reservations {
response.Reservations = append(response.Reservations, &apiAccountReservation{
Topic: r.Topic,
Everyone: r.Everyone.String(),
})
if s.config.EnableReservations {
reservations, err := s.userManager.Reservations(u.Name)
if err != nil {
return err
}
if len(reservations) > 0 {
response.Reservations = make([]*apiAccountReservation, 0)
for _, r := range reservations {
response.Reservations = append(response.Reservations, &apiAccountReservation{
Topic: r.Topic,
Everyone: r.Everyone.String(),
})
}
}
}
tokens, err := s.userManager.Tokens(u.ID)
@@ -138,6 +143,15 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, r *http.Request, v *vis
})
}
}
if s.config.TwilioAccount != "" {
phoneNumbers, err := s.userManager.PhoneNumbers(u.ID)
if err != nil {
return err
}
if len(phoneNumbers) > 0 {
response.PhoneNumbers = phoneNumbers
}
}
} else {
response.Username = user.Everyone
response.Role = string(user.RoleAnonymous)
@@ -156,6 +170,11 @@ func (s *Server) handleAccountDelete(w http.ResponseWriter, r *http.Request, v *
if _, err := s.userManager.Authenticate(u.Name, req.Password); err != nil {
return errHTTPBadRequestIncorrectPasswordConfirmation
}
if s.webPush != nil && u.ID != "" {
if err := s.webPush.RemoveSubscriptionsByUserID(u.ID); err != nil {
logvr(v, r).Err(err).Warn("Error removing web push subscriptions for %s", u.Name)
}
}
if u.Billing.StripeSubscriptionID != "" {
logvr(v, r).Tag(tagStripe).Info("Canceling billing subscription for user %s", u.Name)
if _, err := s.stripe.CancelSubscription(u.Billing.StripeSubscriptionID); err != nil {
@@ -444,7 +463,7 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
if err != nil {
return err
}
t.CancelSubscribers(u.ID)
t.CancelSubscribersExceptUser(u.ID)
return s.writeJSON(w, newSuccessResponse())
}
@@ -511,6 +530,72 @@ func (s *Server) maybeRemoveMessagesAndExcessReservations(r *http.Request, v *vi
return nil
}
func (s *Server) handleAccountPhoneNumberVerify(w http.ResponseWriter, r *http.Request, v *visitor) error {
u := v.User()
req, err := readJSONWithLimit[apiAccountPhoneNumberVerifyRequest](r.Body, jsonBodyBytesLimit, false)
if err != nil {
return err
} else if !phoneNumberRegex.MatchString(req.Number) {
return errHTTPBadRequestPhoneNumberInvalid
} else if req.Channel != "sms" && req.Channel != "call" {
return errHTTPBadRequestPhoneNumberVerifyChannelInvalid
}
// Check user is allowed to add phone numbers
if u == nil || (u.IsUser() && u.Tier == nil) {
return errHTTPUnauthorized
} else if u.IsUser() && u.Tier.CallLimit == 0 {
return errHTTPUnauthorized
}
// Check if phone number exists
phoneNumbers, err := s.userManager.PhoneNumbers(u.ID)
if err != nil {
return err
} else if util.Contains(phoneNumbers, req.Number) {
return errHTTPConflictPhoneNumberExists
}
// Actually add the unverified number, and send verification
logvr(v, r).Tag(tagAccount).Field("phone_number", req.Number).Debug("Sending phone number verification")
if err := s.verifyPhoneNumber(v, r, req.Number, req.Channel); err != nil {
return err
}
return s.writeJSON(w, newSuccessResponse())
}
func (s *Server) handleAccountPhoneNumberAdd(w http.ResponseWriter, r *http.Request, v *visitor) error {
u := v.User()
req, err := readJSONWithLimit[apiAccountPhoneNumberAddRequest](r.Body, jsonBodyBytesLimit, false)
if err != nil {
return err
}
if !phoneNumberRegex.MatchString(req.Number) {
return errHTTPBadRequestPhoneNumberInvalid
}
if err := s.verifyPhoneNumberCheck(v, r, req.Number, req.Code); err != nil {
return err
}
logvr(v, r).Tag(tagAccount).Field("phone_number", req.Number).Debug("Adding phone number as verified")
if err := s.userManager.AddPhoneNumber(u.ID, req.Number); err != nil {
return err
}
return s.writeJSON(w, newSuccessResponse())
}
func (s *Server) handleAccountPhoneNumberDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
u := v.User()
req, err := readJSONWithLimit[apiAccountPhoneNumberAddRequest](r.Body, jsonBodyBytesLimit, false)
if err != nil {
return err
}
if !phoneNumberRegex.MatchString(req.Number) {
return errHTTPBadRequestPhoneNumberInvalid
}
logvr(v, r).Tag(tagAccount).Field("phone_number", req.Number).Debug("Deleting phone number")
if err := s.userManager.RemovePhoneNumber(u.ID, req.Number); err != nil {
return err
}
return s.writeJSON(w, newSuccessResponse())
}
// publishSyncEventAsync kicks of a Go routine to publish a sync message to the user's sync topic
func (s *Server) publishSyncEventAsync(v *visitor) {
go func() {

View File

@@ -151,6 +151,8 @@ func TestAccount_Get_Anonymous(t *testing.T) {
require.Equal(t, int64(1004), account.Stats.MessagesRemaining)
require.Equal(t, int64(0), account.Stats.Emails)
require.Equal(t, int64(24), account.Stats.EmailsRemaining)
require.Equal(t, int64(0), account.Stats.Calls)
require.Equal(t, int64(0), account.Stats.CallsRemaining)
rr = request(t, s, "POST", "/mytopic", "", nil)
require.Equal(t, 200, rr.Code)
@@ -498,6 +500,8 @@ func TestAccount_Reservation_AddAdminSuccess(t *testing.T) {
func TestAccount_Reservation_AddRemoveUserWithTierSuccess(t *testing.T) {
conf := newTestConfigWithAuthFile(t)
conf.EnableSignup = true
conf.EnableReservations = true
conf.TwilioAccount = "dummy"
s := newTestServer(t, conf)
// Create user
@@ -510,6 +514,7 @@ func TestAccount_Reservation_AddRemoveUserWithTierSuccess(t *testing.T) {
MessageLimit: 123,
MessageExpiryDuration: 86400 * time.Second,
EmailLimit: 32,
CallLimit: 10,
ReservationLimit: 2,
AttachmentFileSizeLimit: 1231231,
AttachmentTotalSizeLimit: 123123,
@@ -551,6 +556,7 @@ func TestAccount_Reservation_AddRemoveUserWithTierSuccess(t *testing.T) {
require.Equal(t, int64(123), account.Limits.Messages)
require.Equal(t, int64(86400), account.Limits.MessagesExpiryDuration)
require.Equal(t, int64(32), account.Limits.Emails)
require.Equal(t, int64(10), account.Limits.Calls)
require.Equal(t, int64(2), account.Limits.Reservations)
require.Equal(t, int64(1231231), account.Limits.AttachmentFileSize)
require.Equal(t, int64(123123), account.Limits.AttachmentTotalSize)

143
server/server_admin.go Normal file
View File

@@ -0,0 +1,143 @@
package server
import (
"heckel.io/ntfy/user"
"net/http"
)
func (s *Server) handleUsersGet(w http.ResponseWriter, r *http.Request, v *visitor) error {
users, err := s.userManager.Users()
if err != nil {
return err
}
grants, err := s.userManager.AllGrants()
if err != nil {
return err
}
usersResponse := make([]*apiUserResponse, len(users))
for i, u := range users {
tier := ""
if u.Tier != nil {
tier = u.Tier.Code
}
userGrants := make([]*apiUserGrantResponse, len(grants[u.ID]))
for i, g := range grants[u.ID] {
userGrants[i] = &apiUserGrantResponse{
Topic: g.TopicPattern,
Permission: g.Allow.String(),
}
}
usersResponse[i] = &apiUserResponse{
Username: u.Name,
Role: string(u.Role),
Tier: tier,
Grants: userGrants,
}
}
return s.writeJSON(w, usersResponse)
}
func (s *Server) handleUsersAdd(w http.ResponseWriter, r *http.Request, v *visitor) error {
req, err := readJSONWithLimit[apiUserAddRequest](r.Body, jsonBodyBytesLimit, false)
if err != nil {
return err
} else if !user.AllowedUsername(req.Username) || req.Password == "" {
return errHTTPBadRequest.Wrap("username invalid, or password missing")
}
u, err := s.userManager.User(req.Username)
if err != nil && err != user.ErrUserNotFound {
return err
} else if u != nil {
return errHTTPConflictUserExists
}
var tier *user.Tier
if req.Tier != "" {
tier, err = s.userManager.Tier(req.Tier)
if err == user.ErrTierNotFound {
return errHTTPBadRequestTierInvalid
} else if err != nil {
return err
}
}
if err := s.userManager.AddUser(req.Username, req.Password, user.RoleUser); err != nil {
return err
}
if tier != nil {
if err := s.userManager.ChangeTier(req.Username, req.Tier); err != nil {
return err
}
}
return s.writeJSON(w, newSuccessResponse())
}
func (s *Server) handleUsersDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
req, err := readJSONWithLimit[apiUserDeleteRequest](r.Body, jsonBodyBytesLimit, false)
if err != nil {
return err
}
u, err := s.userManager.User(req.Username)
if err == user.ErrUserNotFound {
return errHTTPBadRequestUserNotFound
} else if err != nil {
return err
} else if !u.IsUser() {
return errHTTPUnauthorized.Wrap("can only remove regular users from API")
}
if err := s.userManager.RemoveUser(req.Username); err != nil {
return err
}
if err := s.killUserSubscriber(u, "*"); err != nil { // FIXME super inefficient
return err
}
return s.writeJSON(w, newSuccessResponse())
}
func (s *Server) handleAccessAllow(w http.ResponseWriter, r *http.Request, v *visitor) error {
req, err := readJSONWithLimit[apiAccessAllowRequest](r.Body, jsonBodyBytesLimit, false)
if err != nil {
return err
}
_, err = s.userManager.User(req.Username)
if err == user.ErrUserNotFound {
return errHTTPBadRequestUserNotFound
} else if err != nil {
return err
}
permission, err := user.ParsePermission(req.Permission)
if err != nil {
return errHTTPBadRequestPermissionInvalid
}
if err := s.userManager.AllowAccess(req.Username, req.Topic, permission); err != nil {
return err
}
return s.writeJSON(w, newSuccessResponse())
}
func (s *Server) handleAccessReset(w http.ResponseWriter, r *http.Request, v *visitor) error {
req, err := readJSONWithLimit[apiAccessResetRequest](r.Body, jsonBodyBytesLimit, false)
if err != nil {
return err
}
u, err := s.userManager.User(req.Username)
if err != nil {
return err
}
if err := s.userManager.ResetAccess(req.Username, req.Topic); err != nil {
return err
}
if err := s.killUserSubscriber(u, req.Topic); err != nil { // This may be a pattern
return err
}
return s.writeJSON(w, newSuccessResponse())
}
func (s *Server) killUserSubscriber(u *user.User, topicPattern string) error {
topics, err := s.topicsFromPattern(topicPattern)
if err != nil {
return err
}
for _, t := range topics {
t.CancelSubscriberUser(u.ID)
}
return nil
}

181
server/server_admin_test.go Normal file
View File

@@ -0,0 +1,181 @@
package server
import (
"github.com/stretchr/testify/require"
"heckel.io/ntfy/user"
"heckel.io/ntfy/util"
"sync/atomic"
"testing"
"time"
)
func TestUser_AddRemove(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
// Create admin, tier
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "tier1",
}))
// Create user via API
rr := request(t, s, "PUT", "/v1/users", `{"username": "ben", "password":"ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Create user with tier via API
rr = request(t, s, "PUT", "/v1/users", `{"username": "emma", "password":"emma", "tier": "tier1"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Check users
users, err := s.userManager.Users()
require.Nil(t, err)
require.Equal(t, 4, len(users))
require.Equal(t, "phil", users[0].Name)
require.Equal(t, "ben", users[1].Name)
require.Equal(t, user.RoleUser, users[1].Role)
require.Nil(t, users[1].Tier)
require.Equal(t, "emma", users[2].Name)
require.Equal(t, user.RoleUser, users[2].Role)
require.Equal(t, "tier1", users[2].Tier.Code)
require.Equal(t, user.Everyone, users[3].Name)
// Delete user via API
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
}
func TestUser_AddRemove_Failures(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
// Create admin
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
// Cannot create user with invalid username
rr := request(t, s, "PUT", "/v1/users", `{"username": "not valid", "password":"ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 400, rr.Code)
// Cannot create user if user already exists
rr = request(t, s, "PUT", "/v1/users", `{"username": "phil", "password":"phil"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 40901, toHTTPError(t, rr.Body.String()).Code)
// Cannot create user with invalid tier
rr = request(t, s, "PUT", "/v1/users", `{"username": "emma", "password":"emma", "tier": "invalid"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 40030, toHTTPError(t, rr.Body.String()).Code)
// Cannot delete user as non-admin
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 401, rr.Code)
// Delete user via API
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
}
func TestAccess_AllowReset(t *testing.T) {
c := newTestConfigWithAuthFile(t)
c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c)
defer s.closeDatabases()
// User and admin
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
// Subscribing not allowed
rr := request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 403, rr.Code)
// Grant access
rr = request(t, s, "POST", "/v1/users/access", `{"username": "ben", "topic":"gold", "permission":"ro"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Now subscribing is allowed
rr = request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 200, rr.Code)
// Reset access
rr = request(t, s, "DELETE", "/v1/users/access", `{"username": "ben", "topic":"gold"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Subscribing not allowed (again)
rr = request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 403, rr.Code)
}
func TestAccess_AllowReset_NonAdminAttempt(t *testing.T) {
c := newTestConfigWithAuthFile(t)
c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c)
defer s.closeDatabases()
// User
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
// Grant access fails, because non-admin
rr := request(t, s, "POST", "/v1/users/access", `{"username": "ben", "topic":"gold", "permission":"ro"}`, map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 401, rr.Code)
}
func TestAccess_AllowReset_KillConnection(t *testing.T) {
c := newTestConfigWithAuthFile(t)
c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c)
defer s.closeDatabases()
// User and admin, grant access to "gol*" topics
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, s.userManager.AllowAccess("ben", "gol*", user.PermissionRead)) // Wildcard!
start, timeTaken := time.Now(), atomic.Int64{}
go func() {
rr := request(t, s, "GET", "/gold/json", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 200, rr.Code)
timeTaken.Store(time.Since(start).Milliseconds())
}()
time.Sleep(500 * time.Millisecond)
// Reset access
rr := request(t, s, "DELETE", "/v1/users/access", `{"username": "ben", "topic":"gol*"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Wait for connection to be killed; this will fail if the connection is never killed
waitFor(t, func() bool {
return timeTaken.Load() >= 500
})
}

View File

@@ -144,17 +144,18 @@ func toFirebaseMessage(m *message, auther user.Auther) (*messaging.Message, erro
}
if allowForward {
data = map[string]string{
"id": m.ID,
"time": fmt.Sprintf("%d", m.Time),
"event": m.Event,
"topic": m.Topic,
"priority": fmt.Sprintf("%d", m.Priority),
"tags": strings.Join(m.Tags, ","),
"click": m.Click,
"icon": m.Icon,
"title": m.Title,
"message": m.Message,
"encoding": m.Encoding,
"id": m.ID,
"time": fmt.Sprintf("%d", m.Time),
"event": m.Event,
"topic": m.Topic,
"priority": fmt.Sprintf("%d", m.Priority),
"tags": strings.Join(m.Tags, ","),
"click": m.Click,
"icon": m.Icon,
"title": m.Title,
"message": m.Message,
"content_type": m.ContentType,
"encoding": m.Encoding,
}
if len(m.Actions) > 0 {
actions, err := json.Marshal(m.Actions)

View File

@@ -182,6 +182,7 @@ func TestToFirebaseMessage_Message_Normal_Allowed(t *testing.T) {
"title": "some title",
"message": "this is a message",
"actions": `[{"id":"123","action":"view","label":"Open page","clear":true,"url":"https://ntfy.sh"},{"id":"456","action":"http","label":"Close door","clear":false,"url":"https://door.com/close","method":"PUT","headers":{"really":"yes"}}]`,
"content_type": "",
"encoding": "",
"attachment_name": "some file.jpg",
"attachment_type": "image/jpeg",
@@ -203,6 +204,7 @@ func TestToFirebaseMessage_Message_Normal_Allowed(t *testing.T) {
"title": "some title",
"message": "this is a message",
"actions": `[{"id":"123","action":"view","label":"Open page","clear":true,"url":"https://ntfy.sh"},{"id":"456","action":"http","label":"Close door","clear":false,"url":"https://door.com/close","method":"PUT","headers":{"really":"yes"}}]`,
"content_type": "",
"encoding": "",
"attachment_name": "some file.jpg",
"attachment_type": "image/jpeg",

View File

@@ -15,6 +15,7 @@ func (s *Server) execManager() {
s.pruneTokens()
s.pruneAttachments()
s.pruneMessages()
s.pruneAndNotifyWebPushSubscriptions()
// Message count per topic
var messagesCached int
@@ -73,9 +74,14 @@ func (s *Server) execManager() {
}
// Print stats
s.mu.Lock()
s.mu.RLock()
messagesCount, topicsCount, visitorsCount := s.messages, len(s.topics), len(s.visitors)
s.mu.Unlock()
s.mu.RUnlock()
// Update stats
s.updateAndWriteStats(messagesCount)
// Log stats
log.
Tag(tagManager).
Fields(log.Context{

View File

@@ -15,6 +15,8 @@ var (
metricEmailsPublishedFailure prometheus.Counter
metricEmailsReceivedSuccess prometheus.Counter
metricEmailsReceivedFailure prometheus.Counter
metricCallsMadeSuccess prometheus.Counter
metricCallsMadeFailure prometheus.Counter
metricUnifiedPushPublishedSuccess prometheus.Counter
metricMatrixPublishedSuccess prometheus.Counter
metricMatrixPublishedFailure prometheus.Counter
@@ -57,6 +59,12 @@ func initMetrics() {
metricEmailsReceivedFailure = prometheus.NewCounter(prometheus.CounterOpts{
Name: "ntfy_emails_received_failure",
})
metricCallsMadeSuccess = prometheus.NewCounter(prometheus.CounterOpts{
Name: "ntfy_calls_made_success",
})
metricCallsMadeFailure = prometheus.NewCounter(prometheus.CounterOpts{
Name: "ntfy_calls_made_failure",
})
metricUnifiedPushPublishedSuccess = prometheus.NewCounter(prometheus.CounterOpts{
Name: "ntfy_unifiedpush_published_success",
})
@@ -95,6 +103,8 @@ func initMetrics() {
metricEmailsPublishedFailure,
metricEmailsReceivedSuccess,
metricEmailsReceivedFailure,
metricCallsMadeSuccess,
metricCallsMadeFailure,
metricUnifiedPushPublishedSuccess,
metricMatrixPublishedSuccess,
metricMatrixPublishedFailure,

View File

@@ -51,7 +51,16 @@ func (s *Server) limitRequestsWithTopic(next handleFunc) handleFunc {
func (s *Server) ensureWebEnabled(next handleFunc) handleFunc {
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if !s.config.EnableWeb {
if s.config.WebRoot == "" {
return errHTTPNotFound
}
return next(w, r, v)
}
}
func (s *Server) ensureWebPushEnabled(next handleFunc) handleFunc {
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if s.config.WebRoot == "" || s.config.WebPushPublicKey == "" {
return errHTTPNotFound
}
return next(w, r, v)
@@ -76,6 +85,24 @@ func (s *Server) ensureUser(next handleFunc) handleFunc {
})
}
func (s *Server) ensureAdmin(next handleFunc) handleFunc {
return s.ensureUserManager(func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if !v.User().IsAdmin() {
return errHTTPUnauthorized
}
return next(w, r, v)
})
}
func (s *Server) ensureCallsEnabled(next handleFunc) handleFunc {
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if s.config.TwilioAccount == "" || s.userManager == nil {
return errHTTPNotFound
}
return next(w, r, v)
}
}
func (s *Server) ensurePaymentsEnabled(next handleFunc) handleFunc {
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if s.config.StripeSecretKey == "" || s.stripe == nil {

View File

@@ -68,6 +68,7 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _
Messages: freeTier.MessageLimit,
MessagesExpiryDuration: int64(freeTier.MessageExpiryDuration.Seconds()),
Emails: freeTier.EmailLimit,
Calls: freeTier.CallLimit,
Reservations: freeTier.ReservationsLimit,
AttachmentTotalSize: freeTier.AttachmentTotalSizeLimit,
AttachmentFileSize: freeTier.AttachmentFileSizeLimit,
@@ -96,6 +97,7 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _
Messages: tier.MessageLimit,
MessagesExpiryDuration: int64(tier.MessageExpiryDuration.Seconds()),
Emails: tier.EmailLimit,
Calls: tier.CallLimit,
Reservations: tier.ReservationLimit,
AttachmentTotalSize: tier.AttachmentTotalSizeLimit,
AttachmentFileSize: tier.AttachmentFileSizeLimit,

View File

@@ -18,11 +18,11 @@ import (
"runtime/debug"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/SherClockHolmes/webpush-go"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/log"
"heckel.io/ntfy/util"
@@ -220,11 +220,7 @@ func TestServer_StaticSites(t *testing.T) {
rr = request(t, s, "GET", "/mytopic", "", nil)
require.Equal(t, 200, rr.Code)
require.Contains(t, rr.Body.String(), `<meta name="robots" content="noindex, nofollow"/>`)
rr = request(t, s, "GET", "/static/css/home.css", "", nil)
require.Equal(t, 200, rr.Code)
require.Contains(t, rr.Body.String(), `/* general styling */`)
require.Contains(t, rr.Body.String(), `<meta name="robots" content="noindex, nofollow" />`)
rr = request(t, s, "GET", "/docs", "", nil)
require.Equal(t, 301, rr.Code)
@@ -234,7 +230,7 @@ func TestServer_StaticSites(t *testing.T) {
func TestServer_WebEnabled(t *testing.T) {
conf := newTestConfig(t)
conf.EnableWeb = false
conf.WebRoot = "" // Disable web app
s := newTestServer(t, conf)
rr := request(t, s, "GET", "/", "", nil)
@@ -243,11 +239,17 @@ func TestServer_WebEnabled(t *testing.T) {
rr = request(t, s, "GET", "/config.js", "", nil)
require.Equal(t, 404, rr.Code)
rr = request(t, s, "GET", "/sw.js", "", nil)
require.Equal(t, 404, rr.Code)
rr = request(t, s, "GET", "/app.html", "", nil)
require.Equal(t, 404, rr.Code)
rr = request(t, s, "GET", "/static/css/home.css", "", nil)
require.Equal(t, 404, rr.Code)
conf2 := newTestConfig(t)
conf2.EnableWeb = true
conf2.WebRoot = "/"
s2 := newTestServer(t, conf2)
rr = request(t, s2, "GET", "/", "", nil)
@@ -256,8 +258,34 @@ func TestServer_WebEnabled(t *testing.T) {
rr = request(t, s2, "GET", "/config.js", "", nil)
require.Equal(t, 200, rr.Code)
rr = request(t, s2, "GET", "/static/css/home.css", "", nil)
rr = request(t, s2, "GET", "/sw.js", "", nil)
require.Equal(t, 200, rr.Code)
rr = request(t, s2, "GET", "/app.html", "", nil)
require.Equal(t, 200, rr.Code)
}
func TestServer_WebPushEnabled(t *testing.T) {
conf := newTestConfig(t)
conf.WebRoot = "" // Disable web app
s := newTestServer(t, conf)
rr := request(t, s, "GET", "/manifest.webmanifest", "", nil)
require.Equal(t, 404, rr.Code)
conf2 := newTestConfig(t)
s2 := newTestServer(t, conf2)
rr = request(t, s2, "GET", "/manifest.webmanifest", "", nil)
require.Equal(t, 404, rr.Code)
conf3 := newTestConfigWithWebPush(t)
s3 := newTestServer(t, conf3)
rr = request(t, s3, "GET", "/manifest.webmanifest", "", nil)
require.Equal(t, 200, rr.Code)
require.Equal(t, "application/manifest+json", rr.Header().Get("Content-Type"))
}
func TestServer_PublishLargeMessage(t *testing.T) {
@@ -301,6 +329,27 @@ func TestServer_PublishPriority(t *testing.T) {
require.Equal(t, 40007, toHTTPError(t, response.Body.String()).Code)
}
func TestServer_PublishPriority_SpecialHTTPHeader(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
"Priority": "u=4",
"X-Priority": "5",
})
require.Equal(t, 5, toMessage(t, response.Body.String()).Priority)
response = request(t, s, "POST", "/mytopic?priority=4", "test", map[string]string{
"Priority": "u=9",
})
require.Equal(t, 4, toMessage(t, response.Body.String()).Priority)
response = request(t, s, "POST", "/mytopic", "test", map[string]string{
"p": "2",
"priority": "u=9, i",
})
require.Equal(t, 2, toMessage(t, response.Body.String()).Priority)
}
func TestServer_PublishGETOnlyOneTopic(t *testing.T) {
// This tests a bug that allowed publishing topics with a comma in the name (no ticket)
@@ -463,6 +512,8 @@ func TestServer_PublishAtAndPrune(t *testing.T) {
messages := toMessages(t, response.Body.String())
require.Equal(t, 1, len(messages)) // Not affected by pruning
require.Equal(t, "a message", messages[0].Message)
time.Sleep(time.Second) // FIXME CI failing not sure why
}
func TestServer_PublishAndMultiPoll(t *testing.T) {
@@ -1199,7 +1250,20 @@ func TestServer_PublishDelayedEmail_Fail(t *testing.T) {
"E-Mail": "test@example.com",
"Delay": "20 min",
})
require.Equal(t, 400, response.Code)
require.Equal(t, 40003, toHTTPError(t, response.Body.String()).Code)
}
func TestServer_PublishDelayedCall_Fail(t *testing.T) {
c := newTestConfigWithAuthFile(t)
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
c.TwilioPhoneNumber = "+1234567890"
s := newTestServer(t, c)
response := request(t, s, "PUT", "/mytopic", "fail", map[string]string{
"Call": "yes",
"Delay": "20 min",
})
require.Equal(t, 40037, toHTTPError(t, response.Body.String()).Code)
}
func TestServer_PublishEmailNoMailer_Fail(t *testing.T) {
@@ -1477,6 +1541,39 @@ func TestServer_PublishActions_AndPoll(t *testing.T) {
require.Equal(t, "target_temp_f=65", m.Actions[1].Body)
}
func TestServer_PublishMarkdown(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
response := request(t, s, "PUT", "/mytopic", "**make this bold**", map[string]string{
"Content-Type": "text/markdown",
})
require.Equal(t, 200, response.Code)
m := toMessage(t, response.Body.String())
require.Equal(t, "**make this bold**", m.Message)
require.Equal(t, "text/markdown", m.ContentType)
}
func TestServer_PublishMarkdown_QueryParam(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
response := request(t, s, "PUT", "/mytopic?md=1", "**make this bold**", nil)
require.Equal(t, 200, response.Code)
m := toMessage(t, response.Body.String())
require.Equal(t, "**make this bold**", m.Message)
require.Equal(t, "text/markdown", m.ContentType)
}
func TestServer_PublishMarkdown_NotMarkdown(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
response := request(t, s, "PUT", "/mytopic", "**make this bold**", map[string]string{
"Content-Type": "not-markdown",
})
require.Equal(t, 200, response.Code)
m := toMessage(t, response.Body.String())
require.Equal(t, "", m.ContentType)
}
func TestServer_PublishAsJSON(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
body := `{"topic":"mytopic","message":"A message","title":"a title\nwith lines","tags":["tag1","tag 2"],` +
@@ -1494,12 +1591,25 @@ func TestServer_PublishAsJSON(t *testing.T) {
require.Equal(t, "google.pdf", m.Attachment.Name)
require.Equal(t, "http://ntfy.sh", m.Click)
require.Equal(t, "https://ntfy.sh/static/img/ntfy.png", m.Icon)
require.Equal(t, "", m.ContentType)
require.Equal(t, 4, m.Priority)
require.True(t, m.Time > time.Now().Unix()+29*60)
require.True(t, m.Time < time.Now().Unix()+31*60)
}
func TestServer_PublishAsJSON_Markdown(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
body := `{"topic":"mytopic","message":"**This is bold**","markdown":true}`
response := request(t, s, "PUT", "/", body, nil)
require.Equal(t, 200, response.Code)
m := toMessage(t, response.Body.String())
require.Equal(t, "mytopic", m.Topic)
require.Equal(t, "**This is bold**", m.Message)
require.Equal(t, "text/markdown", m.ContentType)
}
func TestServer_PublishAsJSON_RateLimit_MessageDailyLimit(t *testing.T) {
// Publishing as JSON follows a different path. This ensures that rate
// limiting works for this endpoint as well
@@ -2106,8 +2216,8 @@ func TestServer_PublishWhileUpdatingStatsWithLotsOfMessages(t *testing.T) {
start = time.Now()
response := request(t, s, "PUT", "/mytopic", "some body", nil)
m := toMessage(t, response.Body.String())
assert.Equal(t, "some body", m.Message)
assert.True(t, time.Since(start) < 100*time.Millisecond)
require.Equal(t, "some body", m.Message)
require.True(t, time.Since(start) < 100*time.Millisecond)
log.Info("Done: Publishing message; took %s", time.Since(start).Round(time.Millisecond))
// Wait for all goroutines
@@ -2399,6 +2509,184 @@ func TestServer_SubscriberRateLimiting_ProtectedTopics_WithDefaultReadWrite(t *t
require.Nil(t, s.topics["announcements"].rateVisitor)
}
func TestServer_MessageHistoryAndStatsEndpoint(t *testing.T) {
c := newTestConfig(t)
c.ManagerInterval = 2 * time.Second
s := newTestServer(t, c)
// Publish some messages, and get stats
for i := 0; i < 5; i++ {
response := request(t, s, "POST", "/mytopic", "some message", nil)
require.Equal(t, 200, response.Code)
}
require.Equal(t, int64(5), s.messages)
require.Equal(t, []int64{0}, s.messagesHistory)
response := request(t, s, "GET", "/v1/stats", "", nil)
require.Equal(t, 200, response.Code)
require.Equal(t, `{"messages":5,"messages_rate":0}`+"\n", response.Body.String())
// Run manager and see message history update
s.execManager()
require.Equal(t, []int64{0, 5}, s.messagesHistory)
response = request(t, s, "GET", "/v1/stats", "", nil)
require.Equal(t, 200, response.Code)
require.Equal(t, `{"messages":5,"messages_rate":2.5}`+"\n", response.Body.String()) // 5 messages in 2 seconds = 2.5 messages per second
// Publish some more messages
for i := 0; i < 10; i++ {
response := request(t, s, "POST", "/mytopic", "some message", nil)
require.Equal(t, 200, response.Code)
}
require.Equal(t, int64(15), s.messages)
require.Equal(t, []int64{0, 5}, s.messagesHistory)
response = request(t, s, "GET", "/v1/stats", "", nil)
require.Equal(t, 200, response.Code)
require.Equal(t, `{"messages":15,"messages_rate":2.5}`+"\n", response.Body.String()) // Rate did not update yet
// Run manager and see message history update
s.execManager()
require.Equal(t, []int64{0, 5, 15}, s.messagesHistory)
response = request(t, s, "GET", "/v1/stats", "", nil)
require.Equal(t, 200, response.Code)
require.Equal(t, `{"messages":15,"messages_rate":3.75}`+"\n", response.Body.String()) // 15 messages in 4 seconds = 3.75 messages per second
}
func TestServer_MessageHistoryMaxSize(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
for i := 0; i < 20; i++ {
s.messages = int64(i)
s.execManager()
}
require.Equal(t, []int64{10, 11, 12, 13, 14, 15, 16, 17, 18, 19}, s.messagesHistory)
}
func TestServer_MessageCountPersistence(t *testing.T) {
c := newTestConfig(t)
s := newTestServer(t, c)
s.messages = 1234
s.execManager()
waitFor(t, func() bool {
messages, err := s.messageCache.Stats()
require.Nil(t, err)
return messages == 1234
})
s = newTestServer(t, c)
require.Equal(t, int64(1234), s.messages)
}
func TestServer_PublishWithUTF8MimeHeader(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
response := request(t, s, "POST", "/mytopic", "some attachment", map[string]string{
"X-Filename": "some =?UTF-8?q?=C3=A4?=ttachment.txt",
"X-Message": "=?UTF-8?B?8J+HqfCfh6o=?=",
"X-Title": "=?UTF-8?B?bnRmeSDlvojmo5I=?=, no really I mean it! =?UTF-8?Q?This is q=C3=BC=C3=B6ted-print=C3=A4ble.?=",
"X-Tags": "=?UTF-8?B?8J+HqfCfh6o=?=, =?UTF-8?B?bnRmeSDlvojmo5I=?=",
"X-Click": "=?uTf-8?b?aHR0cHM6Ly/wn5KpLmxh?=",
"X-Actions": "http, \"=?utf-8?q?Mettre =C3=A0 jour?=\", \"https://my.tld/webhook/netbird-update\"; =?utf-8?b?aHR0cCwg6L+Z5piv5LiA5Liq5qCH562+LCBodHRwczovL/CfkqkubGE=?=",
})
require.Equal(t, 200, response.Code)
m := toMessage(t, response.Body.String())
require.Equal(t, "🇩🇪", m.Message)
require.Equal(t, "ntfy 很棒, no really I mean it! This is qüöted-printäble.", m.Title)
require.Equal(t, "some ättachment.txt", m.Attachment.Name)
require.Equal(t, "🇩🇪", m.Tags[0])
require.Equal(t, "ntfy 很棒", m.Tags[1])
require.Equal(t, "https://💩.la", m.Click)
require.Equal(t, "Mettre à jour", m.Actions[0].Label)
require.Equal(t, "http", m.Actions[1].Action)
require.Equal(t, "这是一个标签", m.Actions[1].Label)
require.Equal(t, "https://💩.la", m.Actions[1].URL)
}
func TestServer_UpstreamBaseURL_Success(t *testing.T) {
var pollID atomic.Pointer[string]
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
require.Nil(t, err)
require.Equal(t, "/87c9cddf7b0105f5fe849bf084c6e600be0fde99be3223335199b4965bd7b735", r.URL.Path)
require.Equal(t, "", string(body))
require.NotEmpty(t, r.Header.Get("X-Poll-ID"))
pollID.Store(util.String(r.Header.Get("X-Poll-ID")))
}))
defer upstreamServer.Close()
c := newTestConfigWithAuthFile(t)
c.BaseURL = "http://myserver.internal"
c.UpstreamBaseURL = upstreamServer.URL
s := newTestServer(t, c)
// Send message, and wait for upstream server to receive it
response := request(t, s, "PUT", "/mytopic", `hi there`, nil)
require.Equal(t, 200, response.Code)
m := toMessage(t, response.Body.String())
require.NotEmpty(t, m.ID)
require.Equal(t, "hi there", m.Message)
waitFor(t, func() bool {
pID := pollID.Load()
return pID != nil && *pID == m.ID
})
}
func TestServer_UpstreamBaseURL_With_Access_Token_Success(t *testing.T) {
var pollID atomic.Pointer[string]
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
require.Nil(t, err)
require.Equal(t, "/a1c72bcb4daf5af54d13ef86aea8f76c11e8b88320d55f1811d5d7b173bcc1df", r.URL.Path)
require.Equal(t, "Bearer tk_1234567890", r.Header.Get("Authorization"))
require.Equal(t, "", string(body))
require.NotEmpty(t, r.Header.Get("X-Poll-ID"))
pollID.Store(util.String(r.Header.Get("X-Poll-ID")))
}))
defer upstreamServer.Close()
c := newTestConfigWithAuthFile(t)
c.BaseURL = "http://myserver.internal"
c.UpstreamBaseURL = upstreamServer.URL
c.UpstreamAccessToken = "tk_1234567890"
s := newTestServer(t, c)
// Send message, and wait for upstream server to receive it
response := request(t, s, "PUT", "/mytopic1", `hi there`, nil)
require.Equal(t, 200, response.Code)
m := toMessage(t, response.Body.String())
require.NotEmpty(t, m.ID)
require.Equal(t, "hi there", m.Message)
waitFor(t, func() bool {
pID := pollID.Load()
return pID != nil && *pID == m.ID
})
}
func TestServer_UpstreamBaseURL_DoNotForwardUnifiedPush(t *testing.T) {
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatal("UnifiedPush messages should not be forwarded")
}))
defer upstreamServer.Close()
c := newTestConfigWithAuthFile(t)
c.BaseURL = "http://myserver.internal"
c.UpstreamBaseURL = upstreamServer.URL
s := newTestServer(t, c)
// Send UP message, this should not forward to upstream server
response := request(t, s, "PUT", "/mytopic?up=1", `hi there`, nil)
require.Equal(t, 200, response.Code)
m := toMessage(t, response.Body.String())
require.NotEmpty(t, m.ID)
require.Equal(t, "hi there", m.Message)
// Forwarding is done asynchronously, so wait a bit.
// This ensures that the t.Fatal above is actually not triggered.
time.Sleep(500 * time.Millisecond)
}
func newTestConfig(t *testing.T) *Config {
conf := NewConfig()
conf.BaseURL = "http://127.0.0.1:12345"
@@ -2408,19 +2696,33 @@ func newTestConfig(t *testing.T) *Config {
return conf
}
func newTestConfigWithAuthFile(t *testing.T) *Config {
conf := newTestConfig(t)
func configureAuth(t *testing.T, conf *Config) *Config {
conf.AuthFile = filepath.Join(t.TempDir(), "user.db")
conf.AuthStartupQueries = "pragma journal_mode = WAL; pragma synchronous = normal; pragma temp_store = memory;"
conf.AuthBcryptCost = bcrypt.MinCost // This speeds up tests a lot
return conf
}
func newTestConfigWithAuthFile(t *testing.T) *Config {
conf := newTestConfig(t)
conf = configureAuth(t, conf)
return conf
}
func newTestConfigWithWebPush(t *testing.T) *Config {
conf := newTestConfig(t)
privateKey, publicKey, err := webpush.GenerateVAPIDKeys()
require.Nil(t, err)
conf.WebPushFile = filepath.Join(t.TempDir(), "webpush.db")
conf.WebPushEmailAddress = "testing@example.com"
conf.WebPushPrivateKey = privateKey
conf.WebPushPublicKey = publicKey
return conf
}
func newTestServer(t *testing.T, config *Config) *Server {
server, err := New(config)
if err != nil {
t.Fatal(err)
}
require.Nil(t, err)
return server
}
@@ -2500,7 +2802,7 @@ func waitForWithMaxWait(t *testing.T, maxWait time.Duration, f func() bool) {
if f() {
return
}
time.Sleep(100 * time.Millisecond)
time.Sleep(50 * time.Millisecond)
}
t.Fatalf("Function f did not succeed after %v: %v", maxWait, string(debug.Stack()))
}

176
server/server_twilio.go Normal file
View File

@@ -0,0 +1,176 @@
package server
import (
"bytes"
"encoding/xml"
"fmt"
"heckel.io/ntfy/log"
"heckel.io/ntfy/user"
"heckel.io/ntfy/util"
"io"
"net/http"
"net/url"
"strings"
)
const (
twilioCallFormat = `
<Response>
<Pause length="1"/>
<Say loop="3">
You have a message from notify on topic %s. Message:
<break time="1s"/>
%s
<break time="1s"/>
End of message.
<break time="1s"/>
This message was sent by user %s. It will be repeated three times.
To unsubscribe from calls like this, remove your phone number in the notify web app.
<break time="3s"/>
</Say>
<Say>Goodbye.</Say>
</Response>`
)
// convertPhoneNumber checks if the given phone number is verified for the given user, and if so, returns the verified
// phone number. It also converts a boolean string ("yes", "1", "true") to the first verified phone number.
// If the user is anonymous, it will return an error.
func (s *Server) convertPhoneNumber(u *user.User, phoneNumber string) (string, *errHTTP) {
if u == nil {
return "", errHTTPBadRequestAnonymousCallsNotAllowed
}
phoneNumbers, err := s.userManager.PhoneNumbers(u.ID)
if err != nil {
return "", errHTTPInternalError
} else if len(phoneNumbers) == 0 {
return "", errHTTPBadRequestPhoneNumberNotVerified
}
if toBool(phoneNumber) {
return phoneNumbers[0], nil
} else if util.Contains(phoneNumbers, phoneNumber) {
return phoneNumber, nil
}
for _, p := range phoneNumbers {
if p == phoneNumber {
return phoneNumber, nil
}
}
return "", errHTTPBadRequestPhoneNumberNotVerified
}
// callPhone calls the Twilio API to make a phone call to the given phone number, using the given message.
// Failures will be logged, but not returned to the caller.
func (s *Server) callPhone(v *visitor, r *http.Request, m *message, to string) {
u, sender := v.User(), m.Sender.String()
if u != nil {
sender = u.Name
}
body := fmt.Sprintf(twilioCallFormat, xmlEscapeText(m.Topic), xmlEscapeText(m.Message), xmlEscapeText(sender))
data := url.Values{}
data.Set("From", s.config.TwilioPhoneNumber)
data.Set("To", to)
data.Set("Twiml", body)
ev := logvrm(v, r, m).Tag(tagTwilio).Field("twilio_to", to).FieldIf("twilio_body", body, log.TraceLevel).Debug("Sending Twilio request")
response, err := s.callPhoneInternal(data)
if err != nil {
ev.Field("twilio_response", response).Err(err).Warn("Error sending Twilio request")
minc(metricCallsMadeFailure)
return
}
ev.FieldIf("twilio_response", response, log.TraceLevel).Debug("Received successful Twilio response")
minc(metricCallsMadeSuccess)
}
func (s *Server) callPhoneInternal(data url.Values) (string, error) {
requestURL := fmt.Sprintf("%s/2010-04-01/Accounts/%s/Calls.json", s.config.TwilioCallsBaseURL, s.config.TwilioAccount)
req, err := http.NewRequest(http.MethodPost, requestURL, strings.NewReader(data.Encode()))
if err != nil {
return "", err
}
req.Header.Set("User-Agent", "ntfy/"+s.config.Version)
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Authorization", util.BasicAuth(s.config.TwilioAccount, s.config.TwilioAuthToken))
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", err
}
response, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
return string(response), nil
}
func (s *Server) verifyPhoneNumber(v *visitor, r *http.Request, phoneNumber, channel string) error {
ev := logvr(v, r).Tag(tagTwilio).Field("twilio_to", phoneNumber).Field("twilio_channel", channel).Debug("Sending phone verification")
data := url.Values{}
data.Set("To", phoneNumber)
data.Set("Channel", channel)
requestURL := fmt.Sprintf("%s/v2/Services/%s/Verifications", s.config.TwilioVerifyBaseURL, s.config.TwilioVerifyService)
req, err := http.NewRequest(http.MethodPost, requestURL, strings.NewReader(data.Encode()))
if err != nil {
return err
}
req.Header.Set("User-Agent", "ntfy/"+s.config.Version)
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Authorization", util.BasicAuth(s.config.TwilioAccount, s.config.TwilioAuthToken))
resp, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
response, err := io.ReadAll(resp.Body)
if err != nil {
ev.Err(err).Warn("Error sending Twilio phone verification request")
return err
}
ev.FieldIf("twilio_response", string(response), log.TraceLevel).Debug("Received Twilio phone verification response")
return nil
}
func (s *Server) verifyPhoneNumberCheck(v *visitor, r *http.Request, phoneNumber, code string) error {
ev := logvr(v, r).Tag(tagTwilio).Field("twilio_to", phoneNumber).Debug("Checking phone verification")
data := url.Values{}
data.Set("To", phoneNumber)
data.Set("Code", code)
requestURL := fmt.Sprintf("%s/v2/Services/%s/VerificationCheck", s.config.TwilioVerifyBaseURL, s.config.TwilioVerifyService)
req, err := http.NewRequest(http.MethodPost, requestURL, strings.NewReader(data.Encode()))
if err != nil {
return err
}
req.Header.Set("User-Agent", "ntfy/"+s.config.Version)
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Authorization", util.BasicAuth(s.config.TwilioAccount, s.config.TwilioAuthToken))
resp, err := http.DefaultClient.Do(req)
if err != nil {
return err
} else if resp.StatusCode != http.StatusOK {
if ev.IsTrace() {
response, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
ev.Field("twilio_response", string(response))
}
ev.Warn("Twilio phone verification failed with status code %d", resp.StatusCode)
if resp.StatusCode == http.StatusNotFound {
return errHTTPGonePhoneVerificationExpired
}
return errHTTPInternalError
}
response, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
if ev.IsTrace() {
ev.Field("twilio_response", string(response)).Trace("Received successful Twilio phone verification response")
} else if ev.IsDebug() {
ev.Debug("Received successful Twilio phone verification response")
}
return nil
}
func xmlEscapeText(text string) string {
var buf bytes.Buffer
_ = xml.EscapeText(&buf, []byte(text))
return buf.String()
}

View File

@@ -0,0 +1,264 @@
package server
import (
"github.com/stretchr/testify/require"
"heckel.io/ntfy/user"
"heckel.io/ntfy/util"
"io"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
)
func TestServer_Twilio_Call_Add_Verify_Call_Delete_Success(t *testing.T) {
var called, verified atomic.Bool
var code atomic.Pointer[string]
twilioVerifyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
require.Nil(t, err)
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
if r.URL.Path == "/v2/Services/VA1234567890/Verifications" {
if code.Load() != nil {
t.Fatal("Should be only called once")
}
require.Equal(t, "Channel=sms&To=%2B12223334444", string(body))
code.Store(util.String("123456"))
} else if r.URL.Path == "/v2/Services/VA1234567890/VerificationCheck" {
if verified.Load() {
t.Fatal("Should be only called once")
}
require.Equal(t, "Code=123456&To=%2B12223334444", string(body))
verified.Store(true)
} else {
t.Fatal("Unexpected path:", r.URL.Path)
}
}))
defer twilioVerifyServer.Close()
twilioCallsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if called.Load() {
t.Fatal("Should be only called once")
}
body, err := io.ReadAll(r.Body)
require.Nil(t, err)
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
require.Equal(t, "From=%2B1234567890&To=%2B12223334444&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+loop%3D%223%22%3E%0A%09%09You+have+a+message+from+notify+on+topic+mytopic.+Message%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09End+of+message.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09This+message+was+sent+by+user+phil.+It+will+be+repeated+three+times.%0A%09%09To+unsubscribe+from+calls+like+this%2C+remove+your+phone+number+in+the+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay%3EGoodbye.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
called.Store(true)
}))
defer twilioCallsServer.Close()
c := newTestConfigWithAuthFile(t)
c.TwilioVerifyBaseURL = twilioVerifyServer.URL
c.TwilioCallsBaseURL = twilioCallsServer.URL
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
c.TwilioPhoneNumber = "+1234567890"
c.TwilioVerifyService = "VA1234567890"
s := newTestServer(t, c)
// Add tier and user
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "pro",
MessageLimit: 10,
CallLimit: 1,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
u, err := s.userManager.User("phil")
require.Nil(t, err)
// Send verification code for phone number
response := request(t, s, "PUT", "/v1/account/phone/verify", `{"number":"+12223334444","channel":"sms"}`, map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, response.Code)
waitFor(t, func() bool {
return *code.Load() == "123456"
})
// Add phone number with code
response = request(t, s, "PUT", "/v1/account/phone", `{"number":"+12223334444","code":"123456"}`, map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, response.Code)
waitFor(t, func() bool {
return verified.Load()
})
phoneNumbers, err := s.userManager.PhoneNumbers(u.ID)
require.Nil(t, err)
require.Equal(t, 1, len(phoneNumbers))
require.Equal(t, "+12223334444", phoneNumbers[0])
// Do the thing
response = request(t, s, "POST", "/mytopic", "hi there", map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
"x-call": "yes",
})
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
waitFor(t, func() bool {
return called.Load()
})
// Remove the phone number
response = request(t, s, "DELETE", "/v1/account/phone", `{"number":"+12223334444"}`, map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, response.Code)
// Verify the phone number is gone from the DB
phoneNumbers, err = s.userManager.PhoneNumbers(u.ID)
require.Nil(t, err)
require.Equal(t, 0, len(phoneNumbers))
}
func TestServer_Twilio_Call_Success(t *testing.T) {
var called atomic.Bool
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if called.Load() {
t.Fatal("Should be only called once")
}
body, err := io.ReadAll(r.Body)
require.Nil(t, err)
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
require.Equal(t, "From=%2B1234567890&To=%2B11122233344&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+loop%3D%223%22%3E%0A%09%09You+have+a+message+from+notify+on+topic+mytopic.+Message%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09End+of+message.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09This+message+was+sent+by+user+phil.+It+will+be+repeated+three+times.%0A%09%09To+unsubscribe+from+calls+like+this%2C+remove+your+phone+number+in+the+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay%3EGoodbye.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
called.Store(true)
}))
defer twilioServer.Close()
c := newTestConfigWithAuthFile(t)
c.TwilioCallsBaseURL = twilioServer.URL
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
c.TwilioPhoneNumber = "+1234567890"
s := newTestServer(t, c)
// Add tier and user
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "pro",
MessageLimit: 10,
CallLimit: 1,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
u, err := s.userManager.User("phil")
require.Nil(t, err)
require.Nil(t, s.userManager.AddPhoneNumber(u.ID, "+11122233344"))
// Do the thing
response := request(t, s, "POST", "/mytopic", "hi there", map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
"x-call": "+11122233344",
})
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
waitFor(t, func() bool {
return called.Load()
})
}
func TestServer_Twilio_Call_Success_With_Yes(t *testing.T) {
var called atomic.Bool
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if called.Load() {
t.Fatal("Should be only called once")
}
body, err := io.ReadAll(r.Body)
require.Nil(t, err)
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
require.Equal(t, "From=%2B1234567890&To=%2B11122233344&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+loop%3D%223%22%3E%0A%09%09You+have+a+message+from+notify+on+topic+mytopic.+Message%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09End+of+message.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09This+message+was+sent+by+user+phil.+It+will+be+repeated+three+times.%0A%09%09To+unsubscribe+from+calls+like+this%2C+remove+your+phone+number+in+the+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay%3EGoodbye.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
called.Store(true)
}))
defer twilioServer.Close()
c := newTestConfigWithAuthFile(t)
c.TwilioCallsBaseURL = twilioServer.URL
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
c.TwilioPhoneNumber = "+1234567890"
s := newTestServer(t, c)
// Add tier and user
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "pro",
MessageLimit: 10,
CallLimit: 1,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
u, err := s.userManager.User("phil")
require.Nil(t, err)
require.Nil(t, s.userManager.AddPhoneNumber(u.ID, "+11122233344"))
// Do the thing
response := request(t, s, "POST", "/mytopic", "hi there", map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
"x-call": "yes", // <<<------
})
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
waitFor(t, func() bool {
return called.Load()
})
}
func TestServer_Twilio_Call_UnverifiedNumber(t *testing.T) {
c := newTestConfigWithAuthFile(t)
c.TwilioCallsBaseURL = "http://dummy.invalid"
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
c.TwilioPhoneNumber = "+1234567890"
s := newTestServer(t, c)
// Add tier and user
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "pro",
MessageLimit: 10,
CallLimit: 1,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
// Do the thing
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
"x-call": "+11122233344",
})
require.Equal(t, 40034, toHTTPError(t, response.Body.String()).Code)
}
func TestServer_Twilio_Call_InvalidNumber(t *testing.T) {
c := newTestConfigWithAuthFile(t)
c.TwilioCallsBaseURL = "https://127.0.0.1"
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
c.TwilioPhoneNumber = "+1234567890"
s := newTestServer(t, c)
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
"x-call": "+invalid",
})
require.Equal(t, 40033, toHTTPError(t, response.Body.String()).Code)
}
func TestServer_Twilio_Call_Anonymous(t *testing.T) {
c := newTestConfigWithAuthFile(t)
c.TwilioCallsBaseURL = "https://127.0.0.1"
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
c.TwilioPhoneNumber = "+1234567890"
s := newTestServer(t, c)
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
"x-call": "+123123",
})
require.Equal(t, 40035, toHTTPError(t, response.Body.String()).Code)
}
func TestServer_Twilio_Call_Unconfigured(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
"x-call": "+1234",
})
require.Equal(t, 40032, toHTTPError(t, response.Body.String()).Code)
}

171
server/server_webpush.go Normal file
View File

@@ -0,0 +1,171 @@
package server
import (
"encoding/json"
"fmt"
"net/http"
"regexp"
"strings"
"github.com/SherClockHolmes/webpush-go"
"heckel.io/ntfy/log"
"heckel.io/ntfy/user"
)
const (
webPushTopicSubscribeLimit = 50
)
var (
webPushAllowedEndpointsPatterns = []string{
"https://*.google.com/",
"https://*.googleapis.com/",
"https://*.mozilla.com/",
"https://*.mozaws.net/",
"https://*.windows.com/",
"https://*.microsoft.com/",
"https://*.apple.com/",
}
webPushAllowedEndpointsRegex *regexp.Regexp
)
func init() {
for i, pattern := range webPushAllowedEndpointsPatterns {
webPushAllowedEndpointsPatterns[i] = strings.ReplaceAll(strings.ReplaceAll(pattern, ".", "\\."), "*", ".+")
}
allPatterns := fmt.Sprintf("^(%s)", strings.Join(webPushAllowedEndpointsPatterns, "|"))
webPushAllowedEndpointsRegex = regexp.MustCompile(allPatterns)
}
func (s *Server) handleWebPushUpdate(w http.ResponseWriter, r *http.Request, v *visitor) error {
req, err := readJSONWithLimit[apiWebPushUpdateSubscriptionRequest](r.Body, jsonBodyBytesLimit, false)
if err != nil || req.Endpoint == "" || req.P256dh == "" || req.Auth == "" {
return errHTTPBadRequestWebPushSubscriptionInvalid
} else if !webPushAllowedEndpointsRegex.MatchString(req.Endpoint) {
return errHTTPBadRequestWebPushEndpointUnknown
} else if len(req.Topics) > webPushTopicSubscribeLimit {
return errHTTPBadRequestWebPushTopicCountTooHigh
}
topics, err := s.topicsFromIDs(req.Topics...)
if err != nil {
return err
}
if s.userManager != nil {
u := v.User()
for _, t := range topics {
if err := s.userManager.Authorize(u, t.ID, user.PermissionRead); err != nil {
logvr(v, r).With(t).Err(err).Debug("Access to topic %s not authorized", t.ID)
return errHTTPForbidden.With(t)
}
}
}
if err := s.webPush.UpsertSubscription(req.Endpoint, req.Auth, req.P256dh, v.MaybeUserID(), v.IP(), req.Topics); err != nil {
return err
}
return s.writeJSON(w, newSuccessResponse())
}
func (s *Server) handleWebPushDelete(w http.ResponseWriter, r *http.Request, _ *visitor) error {
req, err := readJSONWithLimit[apiWebPushUpdateSubscriptionRequest](r.Body, jsonBodyBytesLimit, false)
if err != nil || req.Endpoint == "" {
return errHTTPBadRequestWebPushSubscriptionInvalid
}
if err := s.webPush.RemoveSubscriptionsByEndpoint(req.Endpoint); err != nil {
return err
}
return s.writeJSON(w, newSuccessResponse())
}
func (s *Server) publishToWebPushEndpoints(v *visitor, m *message) {
subscriptions, err := s.webPush.SubscriptionsForTopic(m.Topic)
if err != nil {
logvm(v, m).Err(err).With(v, m).Warn("Unable to publish web push messages")
return
}
log.Tag(tagWebPush).With(v, m).Debug("Publishing web push message to %d subscribers", len(subscriptions))
payload, err := json.Marshal(newWebPushPayload(fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic), m))
if err != nil {
log.Tag(tagWebPush).Err(err).With(v, m).Warn("Unable to marshal expiring payload")
return
}
for _, subscription := range subscriptions {
if err := s.sendWebPushNotification(subscription, payload, v, m); err != nil {
log.Tag(tagWebPush).Err(err).With(v, m, subscription).Warn("Unable to publish web push message")
}
}
}
func (s *Server) pruneAndNotifyWebPushSubscriptions() {
if s.config.WebPushPublicKey == "" {
return
}
go func() {
if err := s.pruneAndNotifyWebPushSubscriptionsInternal(); err != nil {
log.Tag(tagWebPush).Err(err).Warn("Unable to prune or notify web push subscriptions")
}
}()
}
func (s *Server) pruneAndNotifyWebPushSubscriptionsInternal() error {
// Expire old subscriptions
if err := s.webPush.RemoveExpiredSubscriptions(s.config.WebPushExpiryDuration); err != nil {
return err
}
// Notify subscriptions that will expire soon
subscriptions, err := s.webPush.SubscriptionsExpiring(s.config.WebPushExpiryWarningDuration)
if err != nil {
return err
} else if len(subscriptions) == 0 {
return nil
}
payload, err := json.Marshal(newWebPushSubscriptionExpiringPayload())
if err != nil {
return err
}
warningSent := make([]*webPushSubscription, 0)
for _, subscription := range subscriptions {
if err := s.sendWebPushNotification(subscription, payload); err != nil {
log.Tag(tagWebPush).Err(err).With(subscription).Warn("Unable to publish expiry imminent warning")
continue
}
warningSent = append(warningSent, subscription)
}
if err := s.webPush.MarkExpiryWarningSent(warningSent); err != nil {
return err
}
log.Tag(tagWebPush).Debug("Expired old subscriptions and published %d expiry imminent warnings", len(subscriptions))
return nil
}
func (s *Server) sendWebPushNotification(sub *webPushSubscription, message []byte, contexters ...log.Contexter) error {
log.Tag(tagWebPush).With(sub).With(contexters...).Debug("Sending web push message")
payload := &webpush.Subscription{
Endpoint: sub.Endpoint,
Keys: webpush.Keys{
Auth: sub.Auth,
P256dh: sub.P256dh,
},
}
resp, err := webpush.SendNotification(message, payload, &webpush.Options{
Subscriber: s.config.WebPushEmailAddress,
VAPIDPublicKey: s.config.WebPushPublicKey,
VAPIDPrivateKey: s.config.WebPushPrivateKey,
Urgency: webpush.UrgencyHigh, // iOS requires this to ensure delivery
TTL: int(s.config.CacheDuration.Seconds()),
})
if err != nil {
log.Tag(tagWebPush).With(sub).With(contexters...).Err(err).Debug("Unable to publish web push message, removing endpoint")
if err := s.webPush.RemoveSubscriptionsByEndpoint(sub.Endpoint); err != nil {
return err
}
return err
}
if (resp.StatusCode < 200 || resp.StatusCode > 299) && resp.StatusCode != 429 {
log.Tag(tagWebPush).With(sub).With(contexters...).Field("response_code", resp.StatusCode).Debug("Unable to publish web push message, unexpected response")
if err := s.webPush.RemoveSubscriptionsByEndpoint(sub.Endpoint); err != nil {
return err
}
return errHTTPInternalErrorWebPushUnableToPublish.With(sub).With(contexters...)
}
return nil
}

View File

@@ -0,0 +1,256 @@
package server
import (
"encoding/json"
"fmt"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/user"
"heckel.io/ntfy/util"
"io"
"net/http"
"net/http/httptest"
"net/netip"
"strings"
"sync/atomic"
"testing"
"time"
)
const (
testWebPushEndpoint = "https://updates.push.services.mozilla.com/wpush/v1/AAABBCCCDDEEEFFF"
)
func TestServer_WebPush_Disabled(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
require.Equal(t, 404, response.Code)
}
func TestServer_WebPush_TopicAdd(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
require.Equal(t, 200, response.Code)
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
require.Nil(t, err)
require.Len(t, subs, 1)
require.Equal(t, subs[0].Endpoint, testWebPushEndpoint)
require.Equal(t, subs[0].P256dh, "p256dh-key")
require.Equal(t, subs[0].Auth, "auth-key")
require.Equal(t, subs[0].UserID, "")
}
func TestServer_WebPush_TopicAdd_InvalidEndpoint(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, "https://ddos-target.example.com/webpush"), nil)
require.Equal(t, 400, response.Code)
require.Equal(t, `{"code":40039,"http":400,"error":"invalid request: web push endpoint unknown"}`+"\n", response.Body.String())
}
func TestServer_WebPush_TopicAdd_TooManyTopics(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
topicList := make([]string, 51)
for i := range topicList {
topicList[i] = util.RandomString(5)
}
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, topicList, testWebPushEndpoint), nil)
require.Equal(t, 400, response.Code)
require.Equal(t, `{"code":40040,"http":400,"error":"invalid request: too many web push topic subscriptions"}`+"\n", response.Body.String())
}
func TestServer_WebPush_TopicUnsubscribe(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
addSubscription(t, s, testWebPushEndpoint, "test-topic")
requireSubscriptionCount(t, s, "test-topic", 1)
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{}, testWebPushEndpoint), nil)
require.Equal(t, 200, response.Code)
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
requireSubscriptionCount(t, s, "test-topic", 0)
}
func TestServer_WebPush_Delete(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
addSubscription(t, s, testWebPushEndpoint, "test-topic")
requireSubscriptionCount(t, s, "test-topic", 1)
response := request(t, s, "DELETE", "/v1/webpush", fmt.Sprintf(`{"endpoint":"%s"}`, testWebPushEndpoint), nil)
require.Equal(t, 200, response.Code)
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
requireSubscriptionCount(t, s, "test-topic", 0)
}
func TestServer_WebPush_TopicSubscribeProtected_Allowed(t *testing.T) {
config := configureAuth(t, newTestConfigWithWebPush(t))
config.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, config)
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, s.userManager.AllowAccess("ben", "test-topic", user.PermissionReadWrite))
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 200, response.Code)
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
require.Nil(t, err)
require.Len(t, subs, 1)
require.True(t, strings.HasPrefix(subs[0].UserID, "u_"))
}
func TestServer_WebPush_TopicSubscribeProtected_Denied(t *testing.T) {
config := configureAuth(t, newTestConfigWithWebPush(t))
config.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, config)
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
require.Equal(t, 403, response.Code)
requireSubscriptionCount(t, s, "test-topic", 0)
}
func TestServer_WebPush_DeleteAccountUnsubscribe(t *testing.T) {
config := configureAuth(t, newTestConfigWithWebPush(t))
s := newTestServer(t, config)
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, s.userManager.AllowAccess("ben", "test-topic", user.PermissionReadWrite))
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 200, response.Code)
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
requireSubscriptionCount(t, s, "test-topic", 1)
request(t, s, "DELETE", "/v1/account", `{"password":"ben"}`, map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
// should've been deleted with the account
requireSubscriptionCount(t, s, "test-topic", 0)
}
func TestServer_WebPush_Publish(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
var received atomic.Bool
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := io.ReadAll(r.Body)
require.Nil(t, err)
require.Equal(t, "/push-receive", r.URL.Path)
require.Equal(t, "high", r.Header.Get("Urgency"))
require.Equal(t, "", r.Header.Get("Topic"))
received.Store(true)
}))
defer pushService.Close()
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic")
request(t, s, "POST", "/test-topic", "web push test", nil)
waitFor(t, func() bool {
return received.Load()
})
}
func TestServer_WebPush_Publish_RemoveOnError(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
var received atomic.Bool
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := io.ReadAll(r.Body)
require.Nil(t, err)
w.WriteHeader(http.StatusGone)
received.Store(true)
}))
defer pushService.Close()
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic", "test-topic-abc")
requireSubscriptionCount(t, s, "test-topic", 1)
requireSubscriptionCount(t, s, "test-topic-abc", 1)
request(t, s, "POST", "/test-topic", "web push test", nil)
waitFor(t, func() bool {
return received.Load()
})
// Receiving the 410 should've caused the publisher to expire all subscriptions on the endpoint
requireSubscriptionCount(t, s, "test-topic", 0)
requireSubscriptionCount(t, s, "test-topic-abc", 0)
}
func TestServer_WebPush_Expiry(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
var received atomic.Bool
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := io.ReadAll(r.Body)
require.Nil(t, err)
w.WriteHeader(200)
w.Write([]byte(``))
received.Store(true)
}))
defer pushService.Close()
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic")
requireSubscriptionCount(t, s, "test-topic", 1)
_, err := s.webPush.db.Exec("UPDATE subscription SET updated_at = ?", time.Now().Add(-7*24*time.Hour).Unix())
require.Nil(t, err)
s.pruneAndNotifyWebPushSubscriptions()
requireSubscriptionCount(t, s, "test-topic", 1)
waitFor(t, func() bool {
return received.Load()
})
_, err = s.webPush.db.Exec("UPDATE subscription SET updated_at = ?", time.Now().Add(-9*24*time.Hour).Unix())
require.Nil(t, err)
s.pruneAndNotifyWebPushSubscriptions()
waitFor(t, func() bool {
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
require.Nil(t, err)
return len(subs) == 0
})
}
func payloadForTopics(t *testing.T, topics []string, endpoint string) string {
topicsJSON, err := json.Marshal(topics)
require.Nil(t, err)
return fmt.Sprintf(`{
"topics": %s,
"endpoint": "%s",
"p256dh": "p256dh-key",
"auth": "auth-key"
}`, topicsJSON, endpoint)
}
func addSubscription(t *testing.T, s *Server, endpoint string, topics ...string) {
require.Nil(t, s.webPush.UpsertSubscription(endpoint, "kSC3T8aN1JCQxxPdrFLrZg", "BMKKbxdUU_xLS7G1Wh5AN8PvWOjCzkCuKZYb8apcqYrDxjOF_2piggBnoJLQYx9IeSD70fNuwawI3e9Y8m3S3PE", "u_123", netip.MustParseAddr("1.2.3.4"), topics)) // Test auth and p256dh
}
func requireSubscriptionCount(t *testing.T, s *Server, topic string, expectedLength int) {
subs, err := s.webPush.SubscriptionsForTopic(topic)
require.Nil(t, err)
require.Len(t, subs, expectedLength)
}

View File

@@ -4,14 +4,15 @@ import (
_ "embed" // required by go:embed
"encoding/json"
"fmt"
"heckel.io/ntfy/log"
"heckel.io/ntfy/util"
"mime"
"net"
"net/smtp"
"strings"
"sync"
"time"
"heckel.io/ntfy/log"
"heckel.io/ntfy/util"
)
type mailer interface {
@@ -131,31 +132,23 @@ This message was sent by {ip} at {time} via {topicURL}`
}
var (
//go:embed "mailer_emoji.json"
//go:embed "mailer_emoji_map.json"
emojisJSON string
)
type emoji struct {
Emoji string `json:"emoji"`
Aliases []string `json:"aliases"`
}
func toEmojis(tags []string) (emojisOut []string, tagsOut []string, err error) {
var emojis []emoji
if err = json.Unmarshal([]byte(emojisJSON), &emojis); err != nil {
var emojiMap map[string]string
if err = json.Unmarshal([]byte(emojisJSON), &emojiMap); err != nil {
return nil, nil, err
}
tagsOut = make([]string, 0)
emojisOut = make([]string, 0)
nextTag:
for _, t := range tags { // TODO Super inefficient; we should just create a .json file with a map
for _, e := range emojis {
if util.Contains(e.Aliases, t) {
emojisOut = append(emojisOut, e.Emoji)
continue nextTag
}
for _, t := range tags {
if emoji, ok := emojiMap[t]; ok {
emojisOut = append(emojisOut, emoji)
} else {
tagsOut = append(tagsOut, t)
}
tagsOut = append(tagsOut, t)
}
return
}

View File

@@ -10,6 +10,7 @@ import (
"io"
"mime"
"mime/multipart"
"mime/quotedprintable"
"net"
"net/http"
"net/http/httptest"
@@ -296,6 +297,8 @@ func readTextMailBody(reader io.Reader, contentType, transferEncoding string) (s
func readPlainTextMailBody(reader io.Reader, transferEncoding string) (string, error) {
if strings.ToLower(transferEncoding) == "base64" {
reader = base64.NewDecoder(base64.StdEncoding, reader)
} else if strings.ToLower(transferEncoding) == "quoted-printable" {
reader = quotedprintable.NewReader(reader)
}
body, err := io.ReadAll(reader)
if err != nil {

View File

@@ -303,6 +303,39 @@ BBBBBBBBBBBBBBBBBBBBBBBBB`
writeAndReadUntilLine(t, email, c, scanner, "250 2.0.0 OK: queued")
}
func TestSmtpBackend_Plaintext_QuotedPrintable(t *testing.T) {
email := `EHLO example.com
MAIL FROM: phil@example.com
RCPT TO: mytopic@ntfy.sh
DATA
Date: Tue, 28 Dec 2021 00:30:10 +0100
Message-ID: <CAAvm79YP0C=Rt1N=KWmSUBB87KK2rRChmdzKqF1vCwMEUiVzLQ@mail.gmail.com>
Subject: and one more
From: Phil <phil@example.com>
To: mytopic@ntfy.sh
Content-Type: text/plain; charset="UTF-8"
Content-Transfer-Encoding: quoted-printable
what's
=C3=A0&=C3=A9"'(-=C3=A8_=C3=A7=C3=A0)
=3D=3D=3D=3D=3D
up
.
`
s, c, conf, scanner := newTestSMTPServer(t, func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/mytopic", r.URL.Path)
require.Equal(t, "and one more", r.Header.Get("Title"))
require.Equal(t, `what's
à&é"'(-è_çà)
=====
up`, readAll(t, r.Body))
})
conf.SMTPServerAddrPrefix = ""
defer s.Close()
defer c.Close()
writeAndReadUntilLine(t, email, c, scanner, "250 2.0.0 OK: queued")
}
func TestSmtpBackend_Unsupported(t *testing.T) {
email := `EHLO example.com
MAIL FROM: phil@example.com
@@ -390,6 +423,49 @@ L0VOIj4KClRoaXMgaXMgYSB0ZXN0IG1lc3NhZ2UgZnJvbSBUcnVlTkFTIENPUkUuCg==
writeAndReadUntilLine(t, email, c, scanner, "250 2.0.0 OK: queued")
}
func TestSmtpBackend_MultipartQuotedPrintable(t *testing.T) {
email := `EHLO example.com
MAIL FROM: phil@example.com
RCPT TO: ntfy-mytopic@ntfy.sh
DATA
MIME-Version: 1.0
Date: Tue, 28 Dec 2021 00:30:10 +0100
Message-ID: <CAAvm79YP0C=Rt1N=KWmSUBB87KK2rRChmdzKqF1vCwMEUiVzLQ@mail.gmail.com>
Subject: and one more
From: Phil <phil@example.com>
To: ntfy-mytopic@ntfy.sh
Content-Type: multipart/alternative; boundary="000000000000f3320b05d42915c9"
--000000000000f3320b05d42915c9
Content-Type: text/html; charset="UTF-8"
html, ignore me
--000000000000f3320b05d42915c9
Content-Type: text/plain; charset="UTF-8"
Content-Transfer-Encoding: quoted-printable
what's
=C3=A0&=C3=A9"'(-=C3=A8_=C3=A7=C3=A0)
=3D=3D=3D=3D=3D
up
--000000000000f3320b05d42915c9--
.
`
s, c, _, scanner := newTestSMTPServer(t, func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/mytopic", r.URL.Path)
require.Equal(t, "and one more", r.Header.Get("Title"))
require.Equal(t, `what's
à&é"'(-è_çà)
=====
up`, readAll(t, r.Body))
})
defer s.Close()
defer c.Close()
writeAndReadUntilLine(t, email, c, scanner, "250 2.0.0 OK: queued")
}
func TestSmtpBackend_NestedMultipartBase64(t *testing.T) {
email := `EHLO example.com
MAIL FROM: test@mydomain.me

View File

@@ -1,11 +1,12 @@
package server
import (
"heckel.io/ntfy/log"
"heckel.io/ntfy/util"
"math/rand"
"sync"
"time"
"heckel.io/ntfy/log"
"heckel.io/ntfy/util"
)
const (
@@ -44,10 +45,16 @@ func newTopic(id string) *topic {
}
// Subscribe subscribes to this topic
func (t *topic) Subscribe(s subscriber, userID string, cancel func()) int {
func (t *topic) Subscribe(s subscriber, userID string, cancel func()) (subscriberID int) {
t.mu.Lock()
defer t.mu.Unlock()
subscriberID := rand.Int()
for i := 0; i < 5; i++ { // Best effort retry
subscriberID = rand.Int()
_, exists := t.subscribers[subscriberID]
if !exists {
break
}
}
t.subscribers[subscriberID] = &topicSubscriber{
userID: userID, // May be empty
subscriber: s,
@@ -134,24 +141,40 @@ func (t *topic) Keepalive() {
t.lastAccess = time.Now()
}
// CancelSubscribers calls the cancel function for all subscribers, forcing
func (t *topic) CancelSubscribers(exceptUserID string) {
// CancelSubscribersExceptUser calls the cancel function for all subscribers, forcing
func (t *topic) CancelSubscribersExceptUser(exceptUserID string) {
t.mu.Lock()
defer t.mu.Unlock()
for _, s := range t.subscribers {
if s.userID != exceptUserID {
log.
Tag(tagSubscribe).
With(t).
Fields(log.Context{
"user_id": s.userID,
}).
Debug("Canceling subscriber %s", s.userID)
s.cancel()
t.cancelUserSubscriber(s)
}
}
}
// CancelSubscriberUser kills the subscriber with the given user ID
func (t *topic) CancelSubscriberUser(userID string) {
t.mu.RLock()
defer t.mu.RUnlock()
for _, s := range t.subscribers {
if s.userID == userID {
t.cancelUserSubscriber(s)
return
}
}
}
func (t *topic) cancelUserSubscriber(s *topicSubscriber) {
log.
Tag(tagSubscribe).
With(t).
Fields(log.Context{
"user_id": s.userID,
}).
Debug("Canceling subscriber with user ID %s", s.userID)
s.cancel()
}
func (t *topic) Context() log.Context {
t.mu.RLock()
defer t.mu.RUnlock()

View File

@@ -1,13 +1,15 @@
package server
import (
"github.com/stretchr/testify/require"
"math/rand"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestTopic_CancelSubscribers(t *testing.T) {
func TestTopic_CancelSubscribersExceptUser(t *testing.T) {
t.Parallel()
subFn := func(v *visitor, msg *message) error {
@@ -25,11 +27,34 @@ func TestTopic_CancelSubscribers(t *testing.T) {
to.Subscribe(subFn, "", cancelFn1)
to.Subscribe(subFn, "u_phil", cancelFn2)
to.CancelSubscribers("u_phil")
to.CancelSubscribersExceptUser("u_phil")
require.True(t, canceled1.Load())
require.False(t, canceled2.Load())
}
func TestTopic_CancelSubscribersUser(t *testing.T) {
t.Parallel()
subFn := func(v *visitor, msg *message) error {
return nil
}
canceled1 := atomic.Bool{}
cancelFn1 := func() {
canceled1.Store(true)
}
canceled2 := atomic.Bool{}
cancelFn2 := func() {
canceled2.Store(true)
}
to := newTopic("mytopic")
to.Subscribe(subFn, "u_another", cancelFn1)
to.Subscribe(subFn, "u_phil", cancelFn2)
to.CancelSubscriberUser("u_phil")
require.False(t, canceled1.Load())
require.True(t, canceled2.Load())
}
func TestTopic_Keepalive(t *testing.T) {
t.Parallel()
@@ -39,3 +64,29 @@ func TestTopic_Keepalive(t *testing.T) {
require.True(t, to.LastAccess().Unix() >= time.Now().Unix()-2)
require.True(t, to.LastAccess().Unix() <= time.Now().Unix()+2)
}
func TestTopic_Subscribe_DuplicateID(t *testing.T) {
t.Parallel()
to := newTopic("mytopic")
// Fix random seed to force same number generation
rand.Seed(1)
a := rand.Int()
to.subscribers[a] = &topicSubscriber{
userID: "a",
subscriber: nil,
cancel: func() {},
}
subFn := func(v *visitor, msg *message) error {
return nil
}
// Force rand.Int to generate the same id once more
rand.Seed(1)
id := to.Subscribe(subFn, "b", func() {})
res := to.subscribers[id]
require.NotEqual(t, id, a)
require.Equal(t, "b", res.userID, "b")
}

View File

@@ -1,12 +1,13 @@
package server
import (
"heckel.io/ntfy/log"
"heckel.io/ntfy/user"
"net/http"
"net/netip"
"time"
"heckel.io/ntfy/log"
"heckel.io/ntfy/user"
"heckel.io/ntfy/util"
)
@@ -24,23 +25,24 @@ const (
// message represents a message published to a topic
type message struct {
ID string `json:"id"` // Random message ID
Time int64 `json:"time"` // Unix time in seconds
Expires int64 `json:"expires,omitempty"` // Unix time in seconds (not required for open/keepalive)
Event string `json:"event"` // One of the above
Topic string `json:"topic"`
Title string `json:"title,omitempty"`
Message string `json:"message,omitempty"`
Priority int `json:"priority,omitempty"`
Tags []string `json:"tags,omitempty"`
Click string `json:"click,omitempty"`
Icon string `json:"icon,omitempty"`
Actions []*action `json:"actions,omitempty"`
Attachment *attachment `json:"attachment,omitempty"`
PollID string `json:"poll_id,omitempty"`
Encoding string `json:"encoding,omitempty"` // empty for raw UTF-8, or "base64" for encoded bytes
Sender netip.Addr `json:"-"` // IP address of uploader, used for rate limiting
User string `json:"-"` // Username of the uploader, used to associated attachments
ID string `json:"id"` // Random message ID
Time int64 `json:"time"` // Unix time in seconds
Expires int64 `json:"expires,omitempty"` // Unix time in seconds (not required for open/keepalive)
Event string `json:"event"` // One of the above
Topic string `json:"topic"`
Title string `json:"title,omitempty"`
Message string `json:"message,omitempty"`
Priority int `json:"priority,omitempty"`
Tags []string `json:"tags,omitempty"`
Click string `json:"click,omitempty"`
Icon string `json:"icon,omitempty"`
Actions []*action `json:"actions,omitempty"`
Attachment *attachment `json:"attachment,omitempty"`
PollID string `json:"poll_id,omitempty"`
ContentType string `json:"content_type,omitempty"` // text/plain by default (if empty), or text/markdown
Encoding string `json:"encoding,omitempty"` // empty for raw UTF-8, or "base64" for encoded bytes
Sender netip.Addr `json:"-"` // IP address of uploader, used for rate limiting
User string `json:"-"` // UserID of the uploader, used to associated attachments
}
func (m *message) Context() log.Context {
@@ -99,8 +101,10 @@ type publishMessage struct {
Icon string `json:"icon"`
Actions []action `json:"actions"`
Attach string `json:"attach"`
Markdown bool `json:"markdown"`
Filename string `json:"filename"`
Email string `json:"email"`
Call string `json:"call"`
Delay string `json:"delay"`
}
@@ -239,6 +243,45 @@ type apiHealthResponse struct {
Healthy bool `json:"healthy"`
}
type apiStatsResponse struct {
Messages int64 `json:"messages"`
MessagesRate float64 `json:"messages_rate"` // Average number of messages per second
}
type apiUserAddRequest struct {
Username string `json:"username"`
Password string `json:"password"`
Tier string `json:"tier"`
// Do not add 'role' here. We don't want to add admins via the API.
}
type apiUserResponse struct {
Username string `json:"username"`
Role string `json:"role"`
Tier string `json:"tier,omitempty"`
Grants []*apiUserGrantResponse `json:"grants,omitempty"`
}
type apiUserGrantResponse struct {
Topic string `json:"topic"` // This may be a pattern
Permission string `json:"permission"`
}
type apiUserDeleteRequest struct {
Username string `json:"username"`
}
type apiAccessAllowRequest struct {
Username string `json:"username"`
Topic string `json:"topic"` // This may be a pattern
Permission string `json:"permission"`
}
type apiAccessResetRequest struct {
Username string `json:"username"`
Topic string `json:"topic"`
}
type apiAccountCreateRequest struct {
Username string `json:"username"`
Password string `json:"password"`
@@ -272,6 +315,16 @@ type apiAccountTokenResponse struct {
Expires int64 `json:"expires,omitempty"` // Unix timestamp
}
type apiAccountPhoneNumberVerifyRequest struct {
Number string `json:"number"`
Channel string `json:"channel"`
}
type apiAccountPhoneNumberAddRequest struct {
Number string `json:"number"`
Code string `json:"code"` // Only set when adding a phone number
}
type apiAccountTier struct {
Code string `json:"code"`
Name string `json:"name"`
@@ -282,6 +335,7 @@ type apiAccountLimits struct {
Messages int64 `json:"messages"`
MessagesExpiryDuration int64 `json:"messages_expiry_duration"`
Emails int64 `json:"emails"`
Calls int64 `json:"calls"`
Reservations int64 `json:"reservations"`
AttachmentTotalSize int64 `json:"attachment_total_size"`
AttachmentFileSize int64 `json:"attachment_file_size"`
@@ -294,6 +348,8 @@ type apiAccountStats struct {
MessagesRemaining int64 `json:"messages_remaining"`
Emails int64 `json:"emails"`
EmailsRemaining int64 `json:"emails_remaining"`
Calls int64 `json:"calls"`
CallsRemaining int64 `json:"calls_remaining"`
Reservations int64 `json:"reservations"`
ReservationsRemaining int64 `json:"reservations_remaining"`
AttachmentTotalSize int64 `json:"attachment_total_size"`
@@ -323,6 +379,7 @@ type apiAccountResponse struct {
Subscriptions []*user.Subscription `json:"subscriptions,omitempty"`
Reservations []*apiAccountReservation `json:"reservations,omitempty"`
Tokens []*apiAccountTokenResponse `json:"tokens,omitempty"`
PhoneNumbers []string `json:"phone_numbers,omitempty"`
Tier *apiAccountTier `json:"tier,omitempty"`
Limits *apiAccountLimits `json:"limits,omitempty"`
Stats *apiAccountStats `json:"stats,omitempty"`
@@ -340,8 +397,12 @@ type apiConfigResponse struct {
EnableLogin bool `json:"enable_login"`
EnableSignup bool `json:"enable_signup"`
EnablePayments bool `json:"enable_payments"`
EnableCalls bool `json:"enable_calls"`
EnableEmails bool `json:"enable_emails"`
EnableReservations bool `json:"enable_reservations"`
EnableWebPush bool `json:"enable_web_push"`
BillingContact string `json:"billing_contact"`
WebPushPublicKey string `json:"web_push_public_key"`
DisallowedTopics []string `json:"disallowed_topics"`
}
@@ -406,3 +467,75 @@ type apiStripeSubscriptionDeletedEvent struct {
ID string `json:"id"`
Customer string `json:"customer"`
}
type apiWebPushUpdateSubscriptionRequest struct {
Endpoint string `json:"endpoint"`
Auth string `json:"auth"`
P256dh string `json:"p256dh"`
Topics []string `json:"topics"`
}
// List of possible Web Push events (see sw.js)
const (
webPushMessageEvent = "message"
webPushExpiringEvent = "subscription_expiring"
)
type webPushPayload struct {
Event string `json:"event"`
SubscriptionID string `json:"subscription_id"`
Message *message `json:"message"`
}
func newWebPushPayload(subscriptionID string, message *message) *webPushPayload {
return &webPushPayload{
Event: webPushMessageEvent,
SubscriptionID: subscriptionID,
Message: message,
}
}
type webPushControlMessagePayload struct {
Event string `json:"event"`
}
func newWebPushSubscriptionExpiringPayload() *webPushControlMessagePayload {
return &webPushControlMessagePayload{
Event: webPushExpiringEvent,
}
}
type webPushSubscription struct {
ID string
Endpoint string
Auth string
P256dh string
UserID string
}
func (w *webPushSubscription) Context() log.Context {
return map[string]any{
"web_push_subscription_id": w.ID,
"web_push_subscription_user_id": w.UserID,
"web_push_subscription_endpoint": w.Endpoint,
}
}
// https://developer.mozilla.org/en-US/docs/Web/Manifest
type webManifestResponse struct {
Name string `json:"name"`
Description string `json:"description"`
ShortName string `json:"short_name"`
Scope string `json:"scope"`
StartURL string `json:"start_url"`
Display string `json:"display"`
BackgroundColor string `json:"background_color"`
ThemeColor string `json:"theme_color"`
Icons []*webManifestIcon `json:"icons"`
}
type webManifestIcon struct {
SRC string `json:"src"`
Sizes string `json:"sizes"`
Type string `json:"type"`
}

View File

@@ -5,16 +5,31 @@ import (
"fmt"
"heckel.io/ntfy/util"
"io"
"mime"
"net/http"
"net/netip"
"regexp"
"strings"
)
var (
mimeDecoder mime.WordDecoder
priorityHeaderIgnoreRegex = regexp.MustCompile(`^u=\d,\s*(i|\d)$|^u=\d$`)
)
func readBoolParam(r *http.Request, defaultValue bool, names ...string) bool {
value := strings.ToLower(readParam(r, names...))
if value == "" {
return defaultValue
}
return toBool(value)
}
func isBoolValue(value string) bool {
return value == "1" || value == "yes" || value == "true" || value == "0" || value == "no" || value == "false"
}
func toBool(value string) bool {
return value == "1" || value == "yes" || value == "true"
}
@@ -39,9 +54,9 @@ func readParam(r *http.Request, names ...string) string {
func readHeaderParam(r *http.Request, names ...string) string {
for _, name := range names {
value := r.Header.Get(name)
value := strings.TrimSpace(maybeDecodeHeader(name, r.Header.Get(name)))
if value != "" {
return strings.TrimSpace(value)
return value
}
}
return ""
@@ -114,3 +129,27 @@ func fromContext[T any](r *http.Request, key contextKey) (T, error) {
}
return t, nil
}
// maybeDecodeHeader decodes the given header value if it is MIME encoded, e.g. "=?utf-8?q?Hello_World?=",
// or returns the original header value if it is not MIME encoded. It also calls maybeIgnoreSpecialHeader
// to ignore new HTTP "Priority" header.
func maybeDecodeHeader(name, value string) string {
decoded, err := mimeDecoder.DecodeHeader(value)
if err != nil {
return maybeIgnoreSpecialHeader(name, value)
}
return maybeIgnoreSpecialHeader(name, decoded)
}
// maybeIgnoreSpecialHeader ignores new HTTP "Priority" header (see https://datatracker.ietf.org/doc/html/draft-ietf-httpbis-priority)
//
// Cloudflare (and potentially other providers) add this to requests when forwarding to the backend (ntfy),
// so we just ignore it. If the "Priority" header is set to "u=*, i" or "u=*" (by Cloudflare), the header will be ignored.
// Returning an empty string will allow the rest of the logic to continue searching for another header (x-priority, prio, p),
// or in the Query parameters.
func maybeIgnoreSpecialHeader(name, value string) string {
if strings.ToLower(name) == "priority" && priorityHeaderIgnoreRegex.MatchString(strings.TrimSpace(value)) {
return ""
}
return value
}

View File

@@ -2,9 +2,9 @@ package server
import (
"bytes"
"crypto/rand"
"fmt"
"github.com/stretchr/testify/require"
"math/rand"
"net/http"
"strings"
"testing"
@@ -75,3 +75,16 @@ Accept: */*
(peeked bytes not UTF-8, peek limit of 4096 bytes reached, hex: ` + fmt.Sprintf("%x", body[:4096]) + ` ...)`
require.Equal(t, expected, renderHTTPRequest(r))
}
func TestMaybeIgnoreSpecialHeader(t *testing.T) {
require.Empty(t, maybeIgnoreSpecialHeader("priority", "u=1"))
require.Empty(t, maybeIgnoreSpecialHeader("Priority", "u=1"))
require.Empty(t, maybeIgnoreSpecialHeader("Priority", "u=1, i"))
}
func TestMaybeDecodeHeaders(t *testing.T) {
r, _ := http.NewRequest("GET", "http://ntfy.sh/mytopic/json?since=all", nil)
r.Header.Set("Priority", "u=1") // Cloudflare priority header
r.Header.Set("X-Priority", "5") // ntfy priority header
require.Equal(t, "5", readHeaderParam(r, "x-priority", "priority", "p"))
}

View File

@@ -24,6 +24,10 @@ const (
// visitorDefaultReservationsLimit is the amount of topic names a user without a tier is allowed to reserve.
// This number is zero, and changing it may have unintended consequences in the web app, or otherwise
visitorDefaultReservationsLimit = int64(0)
// visitorDefaultCallsLimit is the amount of calls a user without a tier is allowed to make.
// This number is zero, because phone numbers have to be verified first.
visitorDefaultCallsLimit = int64(0)
)
// Constants used to convert a tier-user's MessageLimit (see user.Tier) into adequate request limiter
@@ -56,6 +60,7 @@ type visitor struct {
requestLimiter *rate.Limiter // Rate limiter for (almost) all requests (including messages)
messagesLimiter *util.FixedLimiter // Rate limiter for messages
emailsLimiter *util.RateLimiter // Rate limiter for emails
callsLimiter *util.FixedLimiter // Rate limiter for calls
subscriptionLimiter *util.FixedLimiter // Fixed limiter for active subscriptions (ongoing connections)
bandwidthLimiter *util.RateLimiter // Limiter for attachment bandwidth downloads
accountLimiter *rate.Limiter // Rate limiter for account creation, may be nil
@@ -79,6 +84,7 @@ type visitorLimits struct {
EmailLimit int64
EmailLimitBurst int
EmailLimitReplenish rate.Limit
CallLimit int64
ReservationsLimit int64
AttachmentTotalSizeLimit int64
AttachmentFileSizeLimit int64
@@ -91,6 +97,8 @@ type visitorStats struct {
MessagesRemaining int64
Emails int64
EmailsRemaining int64
Calls int64
CallsRemaining int64
Reservations int64
ReservationsRemaining int64
AttachmentTotalSize int64
@@ -107,10 +115,11 @@ const (
)
func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor {
var messages, emails int64
var messages, emails, calls int64
if user != nil {
messages = user.Stats.Messages
emails = user.Stats.Emails
calls = user.Stats.Calls
}
v := &visitor{
config: conf,
@@ -124,11 +133,12 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana
requestLimiter: nil, // Set in resetLimiters
messagesLimiter: nil, // Set in resetLimiters, may be nil
emailsLimiter: nil, // Set in resetLimiters
callsLimiter: nil, // Set in resetLimiters, may be nil
bandwidthLimiter: nil, // Set in resetLimiters
accountLimiter: nil, // Set in resetLimiters, may be nil
authLimiter: nil, // Set in resetLimiters, may be nil
}
v.resetLimitersNoLock(messages, emails, false)
v.resetLimitersNoLock(messages, emails, calls, false)
return v
}
@@ -147,12 +157,19 @@ func (v *visitor) contextNoLock() log.Context {
"visitor_messages": info.Stats.Messages,
"visitor_messages_limit": info.Limits.MessageLimit,
"visitor_messages_remaining": info.Stats.MessagesRemaining,
"visitor_emails": info.Stats.Emails,
"visitor_emails_limit": info.Limits.EmailLimit,
"visitor_emails_remaining": info.Stats.EmailsRemaining,
"visitor_request_limiter_limit": v.requestLimiter.Limit(),
"visitor_request_limiter_tokens": v.requestLimiter.Tokens(),
}
if v.config.SMTPSenderFrom != "" {
fields["visitor_emails"] = info.Stats.Emails
fields["visitor_emails_limit"] = info.Limits.EmailLimit
fields["visitor_emails_remaining"] = info.Stats.EmailsRemaining
}
if v.config.TwilioAccount != "" {
fields["visitor_calls"] = info.Stats.Calls
fields["visitor_calls_limit"] = info.Limits.CallLimit
fields["visitor_calls_remaining"] = info.Stats.CallsRemaining
}
if v.authLimiter != nil {
fields["visitor_auth_limiter_limit"] = v.authLimiter.Limit()
fields["visitor_auth_limiter_tokens"] = v.authLimiter.Tokens()
@@ -216,6 +233,12 @@ func (v *visitor) EmailAllowed() bool {
return v.emailsLimiter.Allow()
}
func (v *visitor) CallAllowed() bool {
v.mu.RLock() // limiters could be replaced!
defer v.mu.RUnlock()
return v.callsLimiter.Allow()
}
func (v *visitor) SubscriptionAllowed() bool {
v.mu.RLock() // limiters could be replaced!
defer v.mu.RUnlock()
@@ -296,6 +319,7 @@ func (v *visitor) Stats() *user.Stats {
return &user.Stats{
Messages: v.messagesLimiter.Value(),
Emails: v.emailsLimiter.Value(),
Calls: v.callsLimiter.Value(),
}
}
@@ -304,6 +328,7 @@ func (v *visitor) ResetStats() {
defer v.mu.RUnlock()
v.emailsLimiter.Reset()
v.messagesLimiter.Reset()
v.callsLimiter.Reset()
}
// User returns the visitor user, or nil if there is none
@@ -334,11 +359,11 @@ func (v *visitor) SetUser(u *user.User) {
shouldResetLimiters := v.user.TierID() != u.TierID() // TierID works with nil receiver
v.user = u // u may be nil!
if shouldResetLimiters {
var messages, emails int64
var messages, emails, calls int64
if u != nil {
messages, emails = u.Stats.Messages, u.Stats.Emails
messages, emails, calls = u.Stats.Messages, u.Stats.Emails, u.Stats.Calls
}
v.resetLimitersNoLock(messages, emails, true)
v.resetLimitersNoLock(messages, emails, calls, true)
}
}
@@ -353,11 +378,12 @@ func (v *visitor) MaybeUserID() string {
return ""
}
func (v *visitor) resetLimitersNoLock(messages, emails int64, enqueueUpdate bool) {
func (v *visitor) resetLimitersNoLock(messages, emails, calls int64, enqueueUpdate bool) {
limits := v.limitsNoLock()
v.requestLimiter = rate.NewLimiter(limits.RequestLimitReplenish, limits.RequestLimitBurst)
v.messagesLimiter = util.NewFixedLimiterWithValue(limits.MessageLimit, messages)
v.emailsLimiter = util.NewRateLimiterWithValue(limits.EmailLimitReplenish, limits.EmailLimitBurst, emails)
v.callsLimiter = util.NewFixedLimiterWithValue(limits.CallLimit, calls)
v.bandwidthLimiter = util.NewBytesLimiter(int(limits.AttachmentBandwidthLimit), oneDay)
if v.user == nil {
v.accountLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAccountCreationLimitReplenish), v.config.VisitorAccountCreationLimitBurst)
@@ -370,6 +396,7 @@ func (v *visitor) resetLimitersNoLock(messages, emails int64, enqueueUpdate bool
go v.userManager.EnqueueUserStats(v.user.ID, &user.Stats{
Messages: messages,
Emails: emails,
Calls: calls,
})
}
log.Fields(v.contextNoLock()).Debug("Rate limiters reset for visitor") // Must be after function, because contextNoLock() describes rate limiters
@@ -398,6 +425,7 @@ func tierBasedVisitorLimits(conf *Config, tier *user.Tier) *visitorLimits {
EmailLimit: tier.EmailLimit,
EmailLimitBurst: util.MinMax(int(float64(tier.EmailLimit)*visitorEmailLimitBurstRate), conf.VisitorEmailLimitBurst, visitorEmailLimitBurstMax),
EmailLimitReplenish: dailyLimitToRate(tier.EmailLimit),
CallLimit: tier.CallLimit,
ReservationsLimit: tier.ReservationLimit,
AttachmentTotalSizeLimit: tier.AttachmentTotalSizeLimit,
AttachmentFileSizeLimit: tier.AttachmentFileSizeLimit,
@@ -420,6 +448,7 @@ func configBasedVisitorLimits(conf *Config) *visitorLimits {
EmailLimit: replenishDurationToDailyLimit(conf.VisitorEmailLimitReplenish), // Approximation!
EmailLimitBurst: conf.VisitorEmailLimitBurst,
EmailLimitReplenish: rate.Every(conf.VisitorEmailLimitReplenish),
CallLimit: visitorDefaultCallsLimit,
ReservationsLimit: visitorDefaultReservationsLimit,
AttachmentTotalSizeLimit: conf.VisitorAttachmentTotalSizeLimit,
AttachmentFileSizeLimit: conf.AttachmentFileSizeLimit,
@@ -465,12 +494,15 @@ func (v *visitor) Info() (*visitorInfo, error) {
func (v *visitor) infoLightNoLock() *visitorInfo {
messages := v.messagesLimiter.Value()
emails := v.emailsLimiter.Value()
calls := v.callsLimiter.Value()
limits := v.limitsNoLock()
stats := &visitorStats{
Messages: messages,
MessagesRemaining: zeroIfNegative(limits.MessageLimit - messages),
Emails: emails,
EmailsRemaining: zeroIfNegative(limits.EmailLimit - emails),
Calls: calls,
CallsRemaining: zeroIfNegative(limits.CallLimit - calls),
}
return &visitorInfo{
Limits: limits,

280
server/webpush_store.go Normal file
View File

@@ -0,0 +1,280 @@
package server
import (
"database/sql"
"errors"
"heckel.io/ntfy/util"
"net/netip"
"time"
_ "github.com/mattn/go-sqlite3" // SQLite driver
)
const (
subscriptionIDPrefix = "wps_"
subscriptionIDLength = 10
subscriptionEndpointLimitPerSubscriberIP = 10
)
var (
errWebPushNoRows = errors.New("no rows found")
errWebPushTooManySubscriptions = errors.New("too many subscriptions")
errWebPushUserIDCannotBeEmpty = errors.New("user ID cannot be empty")
)
const (
createWebPushSubscriptionsTableQuery = `
BEGIN;
CREATE TABLE IF NOT EXISTS subscription (
id TEXT PRIMARY KEY,
endpoint TEXT NOT NULL,
key_auth TEXT NOT NULL,
key_p256dh TEXT NOT NULL,
user_id TEXT NOT NULL,
subscriber_ip TEXT NOT NULL,
updated_at INT NOT NULL,
warned_at INT NOT NULL DEFAULT 0
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_endpoint ON subscription (endpoint);
CREATE INDEX IF NOT EXISTS idx_subscriber_ip ON subscription (subscriber_ip);
CREATE TABLE IF NOT EXISTS subscription_topic (
subscription_id TEXT NOT NULL,
topic TEXT NOT NULL,
PRIMARY KEY (subscription_id, topic),
FOREIGN KEY (subscription_id) REFERENCES subscription (id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_topic ON subscription_topic (topic);
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
COMMIT;
`
builtinStartupQueries = `
PRAGMA foreign_keys = ON;
`
selectWebPushSubscriptionIDByEndpoint = `SELECT id FROM subscription WHERE endpoint = ?`
selectWebPushSubscriptionCountBySubscriberIP = `SELECT COUNT(*) FROM subscription WHERE subscriber_ip = ?`
selectWebPushSubscriptionsForTopicQuery = `
SELECT id, endpoint, key_auth, key_p256dh, user_id
FROM subscription_topic st
JOIN subscription s ON s.id = st.subscription_id
WHERE st.topic = ?
ORDER BY endpoint
`
selectWebPushSubscriptionsExpiringSoonQuery = `
SELECT id, endpoint, key_auth, key_p256dh, user_id
FROM subscription
WHERE warned_at = 0 AND updated_at <= ?
`
insertWebPushSubscriptionQuery = `
INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT (endpoint)
DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at
`
updateWebPushSubscriptionWarningSentQuery = `UPDATE subscription SET warned_at = ? WHERE id = ?`
deleteWebPushSubscriptionByEndpointQuery = `DELETE FROM subscription WHERE endpoint = ?`
deleteWebPushSubscriptionByUserIDQuery = `DELETE FROM subscription WHERE user_id = ?`
deleteWebPushSubscriptionByAgeQuery = `DELETE FROM subscription WHERE updated_at <= ?` // Full table scan!
insertWebPushSubscriptionTopicQuery = `INSERT INTO subscription_topic (subscription_id, topic) VALUES (?, ?)`
deleteWebPushSubscriptionTopicAllQuery = `DELETE FROM subscription_topic WHERE subscription_id = ?`
)
// Schema management queries
const (
currentWebPushSchemaVersion = 1
insertWebPushSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
selectWebPushSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
)
type webPushStore struct {
db *sql.DB
}
func newWebPushStore(filename, startupQueries string) (*webPushStore, error) {
db, err := sql.Open("sqlite3", filename)
if err != nil {
return nil, err
}
if err := setupWebPushDB(db); err != nil {
return nil, err
}
if err := runWebPushStartupQueries(db, startupQueries); err != nil {
return nil, err
}
return &webPushStore{
db: db,
}, nil
}
func setupWebPushDB(db *sql.DB) error {
// If 'schemaVersion' table does not exist, this must be a new database
rows, err := db.Query(selectWebPushSchemaVersionQuery)
if err != nil {
return setupNewWebPushDB(db)
}
return rows.Close()
}
func setupNewWebPushDB(db *sql.DB) error {
if _, err := db.Exec(createWebPushSubscriptionsTableQuery); err != nil {
return err
}
if _, err := db.Exec(insertWebPushSchemaVersion, currentWebPushSchemaVersion); err != nil {
return err
}
return nil
}
func runWebPushStartupQueries(db *sql.DB, startupQueries string) error {
if _, err := db.Exec(startupQueries); err != nil {
return err
}
if _, err := db.Exec(builtinStartupQueries); err != nil {
return err
}
return nil
}
// UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID. It always first deletes all
// existing entries for a given endpoint.
func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error {
tx, err := c.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
// Read number of subscriptions for subscriber IP address
rowsCount, err := tx.Query(selectWebPushSubscriptionCountBySubscriberIP, subscriberIP.String())
if err != nil {
return err
}
defer rowsCount.Close()
var subscriptionCount int
if !rowsCount.Next() {
return errWebPushNoRows
}
if err := rowsCount.Scan(&subscriptionCount); err != nil {
return err
}
if err := rowsCount.Close(); err != nil {
return err
}
// Read existing subscription ID for endpoint (or create new ID)
rows, err := tx.Query(selectWebPushSubscriptionIDByEndpoint, endpoint)
if err != nil {
return err
}
defer rows.Close()
var subscriptionID string
if rows.Next() {
if err := rows.Scan(&subscriptionID); err != nil {
return err
}
} else {
if subscriptionCount >= subscriptionEndpointLimitPerSubscriberIP {
return errWebPushTooManySubscriptions
}
subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
}
if err := rows.Close(); err != nil {
return err
}
// Insert or update subscription
updatedAt, warnedAt := time.Now().Unix(), 0
if _, err = tx.Exec(insertWebPushSubscriptionQuery, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil {
return err
}
// Replace all subscription topics
if _, err := tx.Exec(deleteWebPushSubscriptionTopicAllQuery, subscriptionID); err != nil {
return err
}
for _, topic := range topics {
if _, err = tx.Exec(insertWebPushSubscriptionTopicQuery, subscriptionID, topic); err != nil {
return err
}
}
return tx.Commit()
}
// SubscriptionsForTopic returns all subscriptions for the given topic
func (c *webPushStore) SubscriptionsForTopic(topic string) ([]*webPushSubscription, error) {
rows, err := c.db.Query(selectWebPushSubscriptionsForTopicQuery, topic)
if err != nil {
return nil, err
}
defer rows.Close()
return c.subscriptionsFromRows(rows)
}
// SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period
func (c *webPushStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*webPushSubscription, error) {
rows, err := c.db.Query(selectWebPushSubscriptionsExpiringSoonQuery, time.Now().Add(-warnAfter).Unix())
if err != nil {
return nil, err
}
defer rows.Close()
return c.subscriptionsFromRows(rows)
}
// MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon
func (c *webPushStore) MarkExpiryWarningSent(subscriptions []*webPushSubscription) error {
tx, err := c.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
for _, subscription := range subscriptions {
if _, err := tx.Exec(updateWebPushSubscriptionWarningSentQuery, time.Now().Unix(), subscription.ID); err != nil {
return err
}
}
return tx.Commit()
}
func (c *webPushStore) subscriptionsFromRows(rows *sql.Rows) ([]*webPushSubscription, error) {
subscriptions := make([]*webPushSubscription, 0)
for rows.Next() {
var id, endpoint, auth, p256dh, userID string
if err := rows.Scan(&id, &endpoint, &auth, &p256dh, &userID); err != nil {
return nil, err
}
subscriptions = append(subscriptions, &webPushSubscription{
ID: id,
Endpoint: endpoint,
Auth: auth,
P256dh: p256dh,
UserID: userID,
})
}
return subscriptions, nil
}
// RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint
func (c *webPushStore) RemoveSubscriptionsByEndpoint(endpoint string) error {
_, err := c.db.Exec(deleteWebPushSubscriptionByEndpointQuery, endpoint)
return err
}
// RemoveSubscriptionsByUserID removes all subscriptions for the given user ID
func (c *webPushStore) RemoveSubscriptionsByUserID(userID string) error {
if userID == "" {
return errWebPushUserIDCannotBeEmpty
}
_, err := c.db.Exec(deleteWebPushSubscriptionByUserIDQuery, userID)
return err
}
// RemoveExpiredSubscriptions removes all subscriptions that have not been updated for a given time period
func (c *webPushStore) RemoveExpiredSubscriptions(expireAfter time.Duration) error {
_, err := c.db.Exec(deleteWebPushSubscriptionByAgeQuery, time.Now().Add(-expireAfter).Unix())
return err
}
// Close closes the underlying database connection
func (c *webPushStore) Close() error {
return c.db.Close()
}

View File

@@ -0,0 +1,199 @@
package server
import (
"fmt"
"github.com/stretchr/testify/require"
"net/netip"
"path/filepath"
"testing"
"time"
)
func TestWebPushStore_UpsertSubscription_SubscriptionsForTopic(t *testing.T) {
webPush := newTestWebPushStore(t)
defer webPush.Close()
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
subs, err := webPush.SubscriptionsForTopic("test-topic")
require.Nil(t, err)
require.Len(t, subs, 1)
require.Equal(t, subs[0].Endpoint, testWebPushEndpoint)
require.Equal(t, subs[0].P256dh, "p256dh-key")
require.Equal(t, subs[0].Auth, "auth-key")
require.Equal(t, subs[0].UserID, "u_1234")
subs2, err := webPush.SubscriptionsForTopic("mytopic")
require.Nil(t, err)
require.Len(t, subs2, 1)
require.Equal(t, subs[0].Endpoint, subs2[0].Endpoint)
}
func TestWebPushStore_UpsertSubscription_SubscriberIPLimitReached(t *testing.T) {
webPush := newTestWebPushStore(t)
defer webPush.Close()
// Insert 10 subscriptions with the same IP address
for i := 0; i < 10; i++ {
endpoint := fmt.Sprintf(testWebPushEndpoint+"%d", i)
require.Nil(t, webPush.UpsertSubscription(endpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
}
// Another one for the same endpoint should be fine
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
// But with a different endpoint it should fail
require.Equal(t, errWebPushTooManySubscriptions, webPush.UpsertSubscription(testWebPushEndpoint+"11", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
// But with a different IP address it should be fine again
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"99", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("9.9.9.9"), []string{"test-topic", "mytopic"}))
}
func TestWebPushStore_UpsertSubscription_UpdateTopics(t *testing.T) {
webPush := newTestWebPushStore(t)
defer webPush.Close()
// Insert subscription with two topics, and another with one topic
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"1", "auth-key", "p256dh-key", "", netip.MustParseAddr("9.9.9.9"), []string{"topic1"}))
subs, err := webPush.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 2)
require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint)
require.Equal(t, testWebPushEndpoint+"1", subs[1].Endpoint)
subs, err = webPush.SubscriptionsForTopic("topic2")
require.Nil(t, err)
require.Len(t, subs, 1)
require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint)
// Update the first subscription to have only one topic
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1"}))
subs, err = webPush.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 2)
require.Equal(t, testWebPushEndpoint+"0", subs[0].Endpoint)
subs, err = webPush.SubscriptionsForTopic("topic2")
require.Nil(t, err)
require.Len(t, subs, 0)
}
func TestWebPushStore_RemoveSubscriptionsByEndpoint(t *testing.T) {
webPush := newTestWebPushStore(t)
defer webPush.Close()
// Insert subscription with two topics
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
subs, err := webPush.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 1)
// And remove it again
require.Nil(t, webPush.RemoveSubscriptionsByEndpoint(testWebPushEndpoint))
subs, err = webPush.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 0)
}
func TestWebPushStore_RemoveSubscriptionsByUserID(t *testing.T) {
webPush := newTestWebPushStore(t)
defer webPush.Close()
// Insert subscription with two topics
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
subs, err := webPush.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 1)
// And remove it again
require.Nil(t, webPush.RemoveSubscriptionsByUserID("u_1234"))
subs, err = webPush.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 0)
}
func TestWebPushStore_RemoveSubscriptionsByUserID_Empty(t *testing.T) {
webPush := newTestWebPushStore(t)
defer webPush.Close()
require.Equal(t, errWebPushUserIDCannotBeEmpty, webPush.RemoveSubscriptionsByUserID(""))
}
func TestWebPushStore_MarkExpiryWarningSent(t *testing.T) {
webPush := newTestWebPushStore(t)
defer webPush.Close()
// Insert subscription with two topics
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
subs, err := webPush.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 1)
// Mark them as warning sent
require.Nil(t, webPush.MarkExpiryWarningSent(subs))
rows, err := webPush.db.Query("SELECT endpoint FROM subscription WHERE warned_at > 0")
require.Nil(t, err)
defer rows.Close()
var endpoint string
require.True(t, rows.Next())
require.Nil(t, rows.Scan(&endpoint))
require.Nil(t, err)
require.Equal(t, testWebPushEndpoint, endpoint)
require.False(t, rows.Next())
}
func TestWebPushStore_SubscriptionsExpiring(t *testing.T) {
webPush := newTestWebPushStore(t)
defer webPush.Close()
// Insert subscription with two topics
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
subs, err := webPush.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 1)
// Fake-mark them as soon-to-expire
_, err = webPush.db.Exec("UPDATE subscription SET updated_at = ? WHERE endpoint = ?", time.Now().Add(-8*24*time.Hour).Unix(), testWebPushEndpoint)
require.Nil(t, err)
// Should not be cleaned up yet
require.Nil(t, webPush.RemoveExpiredSubscriptions(9*24*time.Hour))
// Run expiration
subs, err = webPush.SubscriptionsExpiring(7 * 24 * time.Hour)
require.Nil(t, err)
require.Len(t, subs, 1)
require.Equal(t, testWebPushEndpoint, subs[0].Endpoint)
}
func TestWebPushStore_RemoveExpiredSubscriptions(t *testing.T) {
webPush := newTestWebPushStore(t)
defer webPush.Close()
// Insert subscription with two topics
require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"topic1", "topic2"}))
subs, err := webPush.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 1)
// Fake-mark them as expired
_, err = webPush.db.Exec("UPDATE subscription SET updated_at = ? WHERE endpoint = ?", time.Now().Add(-10*24*time.Hour).Unix(), testWebPushEndpoint)
require.Nil(t, err)
// Run expiration
require.Nil(t, webPush.RemoveExpiredSubscriptions(9*24*time.Hour))
// List again, should be 0
subs, err = webPush.SubscriptionsForTopic("topic1")
require.Nil(t, err)
require.Len(t, subs, 0)
}
func newTestWebPushStore(t *testing.T) *webPushStore {
webPush, err := newWebPushStore(filepath.Join(t.TempDir(), "webpush.db"), "")
require.Nil(t, err)
return webPush
}