Compare commits

...

5 Commits

Author SHA1 Message Date
binwiederhier
9e4a48b058 Make server tests also run against postgres 2026-02-19 20:48:01 -05:00
binwiederhier
939b3d1117 Fix lint, make pipeline use psotgres 2026-02-18 21:07:31 -05:00
binwiederhier
9cc9891f49 Add postgres to pipeline 2026-02-18 20:55:03 -05:00
binwiederhier
0d1f3444f2 fmt 2026-02-18 20:48:41 -05:00
binwiederhier
2716ede6e1 Extract message cache into message/ package with model/ types
Move message cache from server/message_cache.go into a dedicated
message/ package with Store interface, SQLite and PostgreSQL
implementations. Extract shared types into model/ package.
2026-02-18 20:22:44 -05:00
39 changed files with 8566 additions and 7071 deletions

View File

@@ -6,6 +6,22 @@ on:
jobs:
release:
runs-on: ubuntu-latest
services:
postgres:
image: postgres:17
env:
POSTGRES_USER: ntfy
POSTGRES_PASSWORD: ntfy
POSTGRES_DB: ntfy_test
ports:
- 5432:5432
options: >-
--health-cmd "pg_isready -U ntfy"
--health-interval 10s
--health-timeout 5s
--health-retries 5
env:
NTFY_TEST_DATABASE_URL: "postgres://ntfy:ntfy@localhost:5432/ntfy_test?sslmode=disable"
steps:
- name: Checkout code
uses: actions/checkout@v3

View File

@@ -3,6 +3,22 @@ on: [ push, pull_request ]
jobs:
test:
runs-on: ubuntu-latest
services:
postgres:
image: postgres:17
env:
POSTGRES_USER: ntfy
POSTGRES_PASSWORD: ntfy
POSTGRES_DB: ntfy_test
ports:
- 5432:5432
options: >-
--health-cmd "pg_isready -U ntfy"
--health-interval 10s
--health-timeout 5s
--health-retries 5
env:
NTFY_TEST_DATABASE_URL: "postgres://ntfy:ntfy@localhost:5432/ntfy_test?sslmode=disable"
steps:
- name: Checkout code
uses: actions/checkout@v3
@@ -23,7 +39,7 @@ jobs:
- name: Build web app (required for tests)
run: make web
- name: Run tests, formatting, vetting and linting
run: make check
run: make checkv
- name: Run coverage
run: make coverage
- name: Upload coverage to codecov.io

View File

@@ -265,6 +265,8 @@ cli-build-results:
check: test web-fmt-check fmt-check vet web-lint lint staticcheck
checkv: testv web-fmt-check fmt-check vet web-lint lint staticcheck
test: .PHONY
go test $(shell go list ./... | grep -vE 'ntfy/(test|examples|tools)')

628
message/store.go Normal file
View File

@@ -0,0 +1,628 @@
package message
import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"net/netip"
"strings"
"sync"
"time"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/util"
)
const (
tagMessageCache = "message_cache"
)
var errNoRows = errors.New("no rows found")
// Store is the interface for a message cache store
type Store interface {
AddMessage(m *model.Message) error
AddMessages(ms []*model.Message) error
DB() *sql.DB
Message(id string) (*model.Message, error)
MessageCounts() (map[string]int, error)
Messages(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error)
MessagesDue() ([]*model.Message, error)
MessagesExpired() ([]string, error)
MarkPublished(m *model.Message) error
UpdateMessageTime(messageID string, timestamp int64) error
Topics() ([]string, error)
DeleteMessages(ids ...string) error
DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, error)
ExpireMessages(topics ...string) error
AttachmentsExpired() ([]string, error)
MarkAttachmentsDeleted(ids ...string) error
AttachmentBytesUsedBySender(sender string) (int64, error)
AttachmentBytesUsedByUser(userID string) (int64, error)
UpdateStats(messages int64) error
Stats() (int64, error)
Close() error
}
// storeQueries holds the database-specific SQL queries
type storeQueries struct {
insertMessage string
deleteMessage string
selectScheduledMessageIDsBySeqID string
deleteScheduledBySequenceID string
updateMessagesForTopicExpiry string
selectRowIDFromMessageID string
selectMessagesByID string
selectMessagesSinceTime string
selectMessagesSinceTimeScheduled string
selectMessagesSinceID string
selectMessagesSinceIDScheduled string
selectMessagesLatest string
selectMessagesDue string
selectMessagesExpired string
updateMessagePublished string
selectMessagesCount string
selectMessageCountPerTopic string
selectTopics string
updateAttachmentDeleted string
selectAttachmentsExpired string
selectAttachmentsSizeBySender string
selectAttachmentsSizeByUserID string
selectStats string
updateStats string
updateMessageTime string
}
// commonStore implements store operations that are identical across database backends
type commonStore struct {
db *sql.DB
queue *util.BatchingQueue[*model.Message]
nop bool
mu sync.Mutex
queries storeQueries
}
func newCommonStore(db *sql.DB, queries storeQueries, batchSize int, batchTimeout time.Duration, nop bool) *commonStore {
var queue *util.BatchingQueue[*model.Message]
if batchSize > 0 || batchTimeout > 0 {
queue = util.NewBatchingQueue[*model.Message](batchSize, batchTimeout)
}
c := &commonStore{
db: db,
queue: queue,
nop: nop,
queries: queries,
}
go c.processMessageBatches()
return c
}
// DB returns the underlying database connection
func (c *commonStore) DB() *sql.DB {
return c.db
}
// AddMessage stores a message to the message cache synchronously, or queues it to be stored at a later date asynchronously.
func (c *commonStore) AddMessage(m *model.Message) error {
if c.queue != nil {
c.queue.Enqueue(m)
return nil
}
return c.addMessages([]*model.Message{m})
}
// AddMessages synchronously stores a batch of messages to the message cache
func (c *commonStore) AddMessages(ms []*model.Message) error {
return c.addMessages(ms)
}
func (c *commonStore) addMessages(ms []*model.Message) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.nop {
return nil
}
if len(ms) == 0 {
return nil
}
start := time.Now()
tx, err := c.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
stmt, err := tx.Prepare(c.queries.insertMessage)
if err != nil {
return err
}
defer stmt.Close()
for _, m := range ms {
if m.Event != model.MessageEvent && m.Event != model.MessageDeleteEvent && m.Event != model.MessageClearEvent {
return model.ErrUnexpectedMessageType
}
published := m.Time <= time.Now().Unix()
tags := strings.Join(m.Tags, ",")
var attachmentName, attachmentType, attachmentURL string
var attachmentSize, attachmentExpires int64
var attachmentDeleted bool
if m.Attachment != nil {
attachmentName = m.Attachment.Name
attachmentType = m.Attachment.Type
attachmentSize = m.Attachment.Size
attachmentExpires = m.Attachment.Expires
attachmentURL = m.Attachment.URL
}
var actionsStr string
if len(m.Actions) > 0 {
actionsBytes, err := json.Marshal(m.Actions)
if err != nil {
return err
}
actionsStr = string(actionsBytes)
}
var sender string
if m.Sender.IsValid() {
sender = m.Sender.String()
}
_, err := stmt.Exec(
m.ID,
m.SequenceID,
m.Time,
m.Event,
m.Expires,
m.Topic,
m.Message,
m.Title,
m.Priority,
tags,
m.Click,
m.Icon,
actionsStr,
attachmentName,
attachmentType,
attachmentSize,
attachmentExpires,
attachmentURL,
attachmentDeleted, // Always zero
sender,
m.User,
m.ContentType,
m.Encoding,
published,
)
if err != nil {
return err
}
}
if err := tx.Commit(); err != nil {
log.Tag(tagMessageCache).Err(err).Error("Writing %d message(s) failed (took %v)", len(ms), time.Since(start))
return err
}
log.Tag(tagMessageCache).Debug("Wrote %d message(s) in %v", len(ms), time.Since(start))
return nil
}
func (c *commonStore) Messages(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
if since.IsNone() {
return make([]*model.Message, 0), nil
} else if since.IsLatest() {
return c.messagesLatest(topic)
} else if since.IsID() {
return c.messagesSinceID(topic, since, scheduled)
}
return c.messagesSinceTime(topic, since, scheduled)
}
func (c *commonStore) messagesSinceTime(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
var rows *sql.Rows
var err error
if scheduled {
rows, err = c.db.Query(c.queries.selectMessagesSinceTimeScheduled, topic, since.Time().Unix())
} else {
rows, err = c.db.Query(c.queries.selectMessagesSinceTime, topic, since.Time().Unix())
}
if err != nil {
return nil, err
}
return readMessages(rows)
}
func (c *commonStore) messagesSinceID(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
idrows, err := c.db.Query(c.queries.selectRowIDFromMessageID, since.ID())
if err != nil {
return nil, err
}
defer idrows.Close()
if !idrows.Next() {
return c.messagesSinceTime(topic, model.SinceAllMessages, scheduled)
}
var rowID int64
if err := idrows.Scan(&rowID); err != nil {
return nil, err
}
idrows.Close()
var rows *sql.Rows
if scheduled {
rows, err = c.db.Query(c.queries.selectMessagesSinceIDScheduled, topic, rowID)
} else {
rows, err = c.db.Query(c.queries.selectMessagesSinceID, topic, rowID)
}
if err != nil {
return nil, err
}
return readMessages(rows)
}
func (c *commonStore) messagesLatest(topic string) ([]*model.Message, error) {
rows, err := c.db.Query(c.queries.selectMessagesLatest, topic)
if err != nil {
return nil, err
}
return readMessages(rows)
}
func (c *commonStore) MessagesDue() ([]*model.Message, error) {
rows, err := c.db.Query(c.queries.selectMessagesDue, time.Now().Unix())
if err != nil {
return nil, err
}
return readMessages(rows)
}
// MessagesExpired returns a list of IDs for messages that have expired (should be deleted)
func (c *commonStore) MessagesExpired() ([]string, error) {
rows, err := c.db.Query(c.queries.selectMessagesExpired, time.Now().Unix())
if err != nil {
return nil, err
}
defer rows.Close()
ids := make([]string, 0)
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
return nil, err
}
ids = append(ids, id)
}
if err := rows.Err(); err != nil {
return nil, err
}
return ids, nil
}
func (c *commonStore) Message(id string) (*model.Message, error) {
rows, err := c.db.Query(c.queries.selectMessagesByID, id)
if err != nil {
return nil, err
}
if !rows.Next() {
return nil, model.ErrMessageNotFound
}
defer rows.Close()
return readMessage(rows)
}
// UpdateMessageTime updates the time column for a message by ID. This is only used for testing.
func (c *commonStore) UpdateMessageTime(messageID string, timestamp int64) error {
_, err := c.db.Exec(c.queries.updateMessageTime, timestamp, messageID)
return err
}
func (c *commonStore) MarkPublished(m *model.Message) error {
c.mu.Lock()
defer c.mu.Unlock()
_, err := c.db.Exec(c.queries.updateMessagePublished, m.ID)
return err
}
func (c *commonStore) MessageCounts() (map[string]int, error) {
rows, err := c.db.Query(c.queries.selectMessageCountPerTopic)
if err != nil {
return nil, err
}
defer rows.Close()
var topic string
var count int
counts := make(map[string]int)
for rows.Next() {
if err := rows.Scan(&topic, &count); err != nil {
return nil, err
} else if err := rows.Err(); err != nil {
return nil, err
}
counts[topic] = count
}
return counts, nil
}
func (c *commonStore) Topics() ([]string, error) {
rows, err := c.db.Query(c.queries.selectTopics)
if err != nil {
return nil, err
}
defer rows.Close()
topics := make([]string, 0)
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
return nil, err
}
topics = append(topics, id)
}
if err := rows.Err(); err != nil {
return nil, err
}
return topics, nil
}
func (c *commonStore) DeleteMessages(ids ...string) error {
c.mu.Lock()
defer c.mu.Unlock()
tx, err := c.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
for _, id := range ids {
if _, err := tx.Exec(c.queries.deleteMessage, id); err != nil {
return err
}
}
return tx.Commit()
}
// DeleteScheduledBySequenceID deletes unpublished (scheduled) messages with the given topic and sequence ID.
// It returns the message IDs of the deleted messages, which can be used to clean up attachment files.
func (c *commonStore) DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, error) {
c.mu.Lock()
defer c.mu.Unlock()
tx, err := c.db.Begin()
if err != nil {
return nil, err
}
defer tx.Rollback()
rows, err := tx.Query(c.queries.selectScheduledMessageIDsBySeqID, topic, sequenceID)
if err != nil {
return nil, err
}
defer rows.Close()
ids := make([]string, 0)
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
return nil, err
}
ids = append(ids, id)
}
if err := rows.Err(); err != nil {
return nil, err
}
rows.Close()
if _, err := tx.Exec(c.queries.deleteScheduledBySequenceID, topic, sequenceID); err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
return ids, nil
}
func (c *commonStore) ExpireMessages(topics ...string) error {
c.mu.Lock()
defer c.mu.Unlock()
tx, err := c.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
for _, t := range topics {
if _, err := tx.Exec(c.queries.updateMessagesForTopicExpiry, time.Now().Unix()-1, t); err != nil {
return err
}
}
return tx.Commit()
}
func (c *commonStore) AttachmentsExpired() ([]string, error) {
rows, err := c.db.Query(c.queries.selectAttachmentsExpired, time.Now().Unix())
if err != nil {
return nil, err
}
defer rows.Close()
ids := make([]string, 0)
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
return nil, err
}
ids = append(ids, id)
}
if err := rows.Err(); err != nil {
return nil, err
}
return ids, nil
}
func (c *commonStore) MarkAttachmentsDeleted(ids ...string) error {
c.mu.Lock()
defer c.mu.Unlock()
tx, err := c.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
for _, id := range ids {
if _, err := tx.Exec(c.queries.updateAttachmentDeleted, id); err != nil {
return err
}
}
return tx.Commit()
}
func (c *commonStore) AttachmentBytesUsedBySender(sender string) (int64, error) {
rows, err := c.db.Query(c.queries.selectAttachmentsSizeBySender, sender, time.Now().Unix())
if err != nil {
return 0, err
}
return c.readAttachmentBytesUsed(rows)
}
func (c *commonStore) AttachmentBytesUsedByUser(userID string) (int64, error) {
rows, err := c.db.Query(c.queries.selectAttachmentsSizeByUserID, userID, time.Now().Unix())
if err != nil {
return 0, err
}
return c.readAttachmentBytesUsed(rows)
}
func (c *commonStore) readAttachmentBytesUsed(rows *sql.Rows) (int64, error) {
defer rows.Close()
var size int64
if !rows.Next() {
return 0, errors.New("no rows found")
}
if err := rows.Scan(&size); err != nil {
return 0, err
} else if err := rows.Err(); err != nil {
return 0, err
}
return size, nil
}
func (c *commonStore) UpdateStats(messages int64) error {
c.mu.Lock()
defer c.mu.Unlock()
_, err := c.db.Exec(c.queries.updateStats, messages)
return err
}
func (c *commonStore) Stats() (messages int64, err error) {
rows, err := c.db.Query(c.queries.selectStats)
if err != nil {
return 0, err
}
defer rows.Close()
if !rows.Next() {
return 0, errNoRows
}
if err := rows.Scan(&messages); err != nil {
return 0, err
}
return messages, nil
}
func (c *commonStore) Close() error {
return c.db.Close()
}
func (c *commonStore) processMessageBatches() {
if c.queue == nil {
return
}
for messages := range c.queue.Dequeue() {
if err := c.addMessages(messages); err != nil {
log.Tag(tagMessageCache).Err(err).Error("Cannot write message batch")
}
}
}
func readMessages(rows *sql.Rows) ([]*model.Message, error) {
defer rows.Close()
messages := make([]*model.Message, 0)
for rows.Next() {
m, err := readMessage(rows)
if err != nil {
return nil, err
}
messages = append(messages, m)
}
if err := rows.Err(); err != nil {
return nil, err
}
return messages, nil
}
func readMessage(rows *sql.Rows) (*model.Message, error) {
var timestamp, expires, attachmentSize, attachmentExpires int64
var priority int
var id, sequenceID, event, topic, msg, title, tagsStr, click, icon, actionsStr, attachmentName, attachmentType, attachmentURL, sender, user, contentType, encoding string
err := rows.Scan(
&id,
&sequenceID,
&timestamp,
&event,
&expires,
&topic,
&msg,
&title,
&priority,
&tagsStr,
&click,
&icon,
&actionsStr,
&attachmentName,
&attachmentType,
&attachmentSize,
&attachmentExpires,
&attachmentURL,
&sender,
&user,
&contentType,
&encoding,
)
if err != nil {
return nil, err
}
var tags []string
if tagsStr != "" {
tags = strings.Split(tagsStr, ",")
}
var actions []*model.Action
if actionsStr != "" {
if err := json.Unmarshal([]byte(actionsStr), &actions); err != nil {
return nil, err
}
}
senderIP, err := netip.ParseAddr(sender)
if err != nil {
senderIP = netip.Addr{} // if no IP stored in database, return invalid address
}
var att *model.Attachment
if attachmentName != "" && attachmentURL != "" {
att = &model.Attachment{
Name: attachmentName,
Type: attachmentType,
Size: attachmentSize,
Expires: attachmentExpires,
URL: attachmentURL,
}
}
return &model.Message{
ID: id,
SequenceID: sequenceID,
Time: timestamp,
Expires: expires,
Event: event,
Topic: topic,
Message: msg,
Title: title,
Priority: priority,
Tags: tags,
Click: click,
Icon: icon,
Actions: actions,
Attachment: att,
Sender: senderIP,
User: user,
ContentType: contentType,
Encoding: encoding,
}, nil
}
// Ensure commonStore implements Store
var _ Store = (*commonStore)(nil)
// Needed by store.go but not part of Store interface; unused import guard
var _ = fmt.Sprintf

120
message/store_postgres.go Normal file
View File

@@ -0,0 +1,120 @@
package message
import (
"database/sql"
"time"
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
)
// PostgreSQL runtime query constants
const (
pgInsertMessageQuery = `
INSERT INTO message (mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user_id, content_type, encoding, published)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24)
`
pgDeleteMessageQuery = `DELETE FROM message WHERE mid = $1`
pgSelectScheduledMessageIDsBySeqIDQuery = `SELECT mid FROM message WHERE topic = $1 AND sequence_id = $2 AND published = FALSE`
pgDeleteScheduledBySequenceIDQuery = `DELETE FROM message WHERE topic = $1 AND sequence_id = $2 AND published = FALSE`
pgUpdateMessagesForTopicExpiryQuery = `UPDATE message SET expires = $1 WHERE topic = $2`
pgSelectRowIDFromMessageID = `SELECT id FROM message WHERE mid = $1`
pgSelectMessagesByIDQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
FROM message
WHERE mid = $1
`
pgSelectMessagesSinceTimeQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
FROM message
WHERE topic = $1 AND time >= $2 AND published = TRUE
ORDER BY time, id
`
pgSelectMessagesSinceTimeIncludeScheduledQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
FROM message
WHERE topic = $1 AND time >= $2
ORDER BY time, id
`
pgSelectMessagesSinceIDQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
FROM message
WHERE topic = $1 AND id > $2 AND published = TRUE
ORDER BY time, id
`
pgSelectMessagesSinceIDIncludeScheduledQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
FROM message
WHERE topic = $1 AND (id > $2 OR published = FALSE)
ORDER BY time, id
`
pgSelectMessagesLatestQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
FROM message
WHERE topic = $1 AND published = TRUE
ORDER BY time DESC, id DESC
LIMIT 1
`
pgSelectMessagesDueQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
FROM message
WHERE time <= $1 AND published = FALSE
ORDER BY time, id
`
pgSelectMessagesExpiredQuery = `SELECT mid FROM message WHERE expires <= $1 AND published = TRUE`
pgUpdateMessagePublishedQuery = `UPDATE message SET published = TRUE WHERE mid = $1`
pgSelectMessagesCountQuery = `SELECT COUNT(*) FROM message`
pgSelectMessageCountPerTopicQuery = `SELECT topic, COUNT(*) FROM message GROUP BY topic`
pgSelectTopicsQuery = `SELECT topic FROM message GROUP BY topic`
pgUpdateAttachmentDeleted = `UPDATE message SET attachment_deleted = TRUE WHERE mid = $1`
pgSelectAttachmentsExpiredQuery = `SELECT mid FROM message WHERE attachment_expires > 0 AND attachment_expires <= $1 AND attachment_deleted = FALSE`
pgSelectAttachmentsSizeBySenderQuery = `SELECT COALESCE(SUM(attachment_size), 0) FROM message WHERE user_id = '' AND sender = $1 AND attachment_expires >= $2`
pgSelectAttachmentsSizeByUserIDQuery = `SELECT COALESCE(SUM(attachment_size), 0) FROM message WHERE user_id = $1 AND attachment_expires >= $2`
pgSelectStatsQuery = `SELECT value FROM message_stats WHERE key = 'messages'`
pgUpdateStatsQuery = `UPDATE message_stats SET value = $1 WHERE key = 'messages'`
pgUpdateMessageTimesQuery = `UPDATE message SET time = $1 WHERE mid = $2`
)
var pgQueries = storeQueries{
insertMessage: pgInsertMessageQuery,
deleteMessage: pgDeleteMessageQuery,
selectScheduledMessageIDsBySeqID: pgSelectScheduledMessageIDsBySeqIDQuery,
deleteScheduledBySequenceID: pgDeleteScheduledBySequenceIDQuery,
updateMessagesForTopicExpiry: pgUpdateMessagesForTopicExpiryQuery,
selectRowIDFromMessageID: pgSelectRowIDFromMessageID,
selectMessagesByID: pgSelectMessagesByIDQuery,
selectMessagesSinceTime: pgSelectMessagesSinceTimeQuery,
selectMessagesSinceTimeScheduled: pgSelectMessagesSinceTimeIncludeScheduledQuery,
selectMessagesSinceID: pgSelectMessagesSinceIDQuery,
selectMessagesSinceIDScheduled: pgSelectMessagesSinceIDIncludeScheduledQuery,
selectMessagesLatest: pgSelectMessagesLatestQuery,
selectMessagesDue: pgSelectMessagesDueQuery,
selectMessagesExpired: pgSelectMessagesExpiredQuery,
updateMessagePublished: pgUpdateMessagePublishedQuery,
selectMessagesCount: pgSelectMessagesCountQuery,
selectMessageCountPerTopic: pgSelectMessageCountPerTopicQuery,
selectTopics: pgSelectTopicsQuery,
updateAttachmentDeleted: pgUpdateAttachmentDeleted,
selectAttachmentsExpired: pgSelectAttachmentsExpiredQuery,
selectAttachmentsSizeBySender: pgSelectAttachmentsSizeBySenderQuery,
selectAttachmentsSizeByUserID: pgSelectAttachmentsSizeByUserIDQuery,
selectStats: pgSelectStatsQuery,
updateStats: pgUpdateStatsQuery,
updateMessageTime: pgUpdateMessageTimesQuery,
}
// NewPostgresStore creates a new PostgreSQL-backed message cache store.
func NewPostgresStore(dsn string, batchSize int, batchTimeout time.Duration) (Store, error) {
db, err := sql.Open("pgx", dsn)
if err != nil {
return nil, err
}
if err := db.Ping(); err != nil {
return nil, err
}
if err := setupPostgresDB(db); err != nil {
return nil, err
}
return newCommonStore(db, pgQueries, batchSize, batchTimeout, false), nil
}

View File

@@ -0,0 +1,90 @@
package message
import (
"database/sql"
"fmt"
)
// Initial PostgreSQL schema
const (
pgCreateTablesQuery = `
CREATE TABLE IF NOT EXISTS message (
id BIGSERIAL PRIMARY KEY,
mid TEXT NOT NULL,
sequence_id TEXT NOT NULL,
time BIGINT NOT NULL,
event TEXT NOT NULL,
expires BIGINT NOT NULL,
topic TEXT NOT NULL,
message TEXT NOT NULL,
title TEXT NOT NULL,
priority INT NOT NULL,
tags TEXT NOT NULL,
click TEXT NOT NULL,
icon TEXT NOT NULL,
actions TEXT NOT NULL,
attachment_name TEXT NOT NULL,
attachment_type TEXT NOT NULL,
attachment_size BIGINT NOT NULL,
attachment_expires BIGINT NOT NULL,
attachment_url TEXT NOT NULL,
attachment_deleted BOOLEAN NOT NULL DEFAULT FALSE,
sender TEXT NOT NULL,
user_id TEXT NOT NULL,
content_type TEXT NOT NULL,
encoding TEXT NOT NULL,
published BOOLEAN NOT NULL DEFAULT FALSE
);
CREATE INDEX IF NOT EXISTS idx_message_mid ON message (mid);
CREATE INDEX IF NOT EXISTS idx_message_sequence_id ON message (sequence_id);
CREATE INDEX IF NOT EXISTS idx_message_time ON message (time);
CREATE INDEX IF NOT EXISTS idx_message_topic ON message (topic);
CREATE INDEX IF NOT EXISTS idx_message_expires ON message (expires);
CREATE INDEX IF NOT EXISTS idx_message_sender ON message (sender);
CREATE INDEX IF NOT EXISTS idx_message_user_id ON message (user_id);
CREATE INDEX IF NOT EXISTS idx_message_attachment_expires ON message (attachment_expires);
CREATE TABLE IF NOT EXISTS message_stats (
key TEXT PRIMARY KEY,
value BIGINT
);
INSERT INTO message_stats (key, value) VALUES ('messages', 0);
CREATE TABLE IF NOT EXISTS schema_version (
store TEXT PRIMARY KEY,
version INT NOT NULL
);
`
)
// PostgreSQL schema management queries
const (
pgCurrentSchemaVersion = 14
pgInsertSchemaVersion = `INSERT INTO schema_version (store, version) VALUES ('message', $1)`
pgSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE store = 'message'`
)
func setupPostgresDB(db *sql.DB) error {
var schemaVersion int
err := db.QueryRow(pgSelectSchemaVersionQuery).Scan(&schemaVersion)
if err != nil {
return setupNewPostgresDB(db)
}
if schemaVersion > pgCurrentSchemaVersion {
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, pgCurrentSchemaVersion)
}
return nil
}
func setupNewPostgresDB(db *sql.DB) error {
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(pgCreateTablesQuery); err != nil {
return err
}
if _, err := tx.Exec(pgInsertSchemaVersion, pgCurrentSchemaVersion); err != nil {
return err
}
return tx.Commit()
}

View File

@@ -0,0 +1,120 @@
package message_test
import (
"database/sql"
"fmt"
"net/url"
"os"
"testing"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/message"
"heckel.io/ntfy/v2/util"
)
func newTestPostgresStore(t *testing.T) message.Store {
dsn := os.Getenv("NTFY_TEST_DATABASE_URL")
if dsn == "" {
t.Skip("NTFY_TEST_DATABASE_URL not set, skipping PostgreSQL tests")
}
// Create a unique schema for this test
schema := fmt.Sprintf("test_%s", util.RandomString(10))
setupDB, err := sql.Open("pgx", dsn)
require.Nil(t, err)
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
require.Nil(t, err)
require.Nil(t, setupDB.Close())
// Open store with search_path set to the new schema
u, err := url.Parse(dsn)
require.Nil(t, err)
q := u.Query()
q.Set("search_path", schema)
u.RawQuery = q.Encode()
store, err := message.NewPostgresStore(u.String(), 0, 0)
require.Nil(t, err)
t.Cleanup(func() {
store.Close()
cleanDB, err := sql.Open("pgx", dsn)
if err == nil {
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
cleanDB.Close()
}
})
return store
}
func TestPostgresStore_Messages(t *testing.T) {
testCacheMessages(t, newTestPostgresStore(t))
}
func TestPostgresStore_MessagesLock(t *testing.T) {
testCacheMessagesLock(t, newTestPostgresStore(t))
}
func TestPostgresStore_MessagesScheduled(t *testing.T) {
testCacheMessagesScheduled(t, newTestPostgresStore(t))
}
func TestPostgresStore_Topics(t *testing.T) {
testCacheTopics(t, newTestPostgresStore(t))
}
func TestPostgresStore_MessagesTagsPrioAndTitle(t *testing.T) {
testCacheMessagesTagsPrioAndTitle(t, newTestPostgresStore(t))
}
func TestPostgresStore_MessagesSinceID(t *testing.T) {
testCacheMessagesSinceID(t, newTestPostgresStore(t))
}
func TestPostgresStore_Prune(t *testing.T) {
testCachePrune(t, newTestPostgresStore(t))
}
func TestPostgresStore_Attachments(t *testing.T) {
testCacheAttachments(t, newTestPostgresStore(t))
}
func TestPostgresStore_AttachmentsExpired(t *testing.T) {
testCacheAttachmentsExpired(t, newTestPostgresStore(t))
}
func TestPostgresStore_Sender(t *testing.T) {
testSender(t, newTestPostgresStore(t))
}
func TestPostgresStore_DeleteScheduledBySequenceID(t *testing.T) {
testDeleteScheduledBySequenceID(t, newTestPostgresStore(t))
}
func TestPostgresStore_MessageByID(t *testing.T) {
testMessageByID(t, newTestPostgresStore(t))
}
func TestPostgresStore_MarkPublished(t *testing.T) {
testMarkPublished(t, newTestPostgresStore(t))
}
func TestPostgresStore_ExpireMessages(t *testing.T) {
testExpireMessages(t, newTestPostgresStore(t))
}
func TestPostgresStore_MarkAttachmentsDeleted(t *testing.T) {
testMarkAttachmentsDeleted(t, newTestPostgresStore(t))
}
func TestPostgresStore_Stats(t *testing.T) {
testStats(t, newTestPostgresStore(t))
}
func TestPostgresStore_AddMessages(t *testing.T) {
testAddMessages(t, newTestPostgresStore(t))
}
func TestPostgresStore_MessagesDue(t *testing.T) {
testMessagesDue(t, newTestPostgresStore(t))
}
func TestPostgresStore_MessageFieldRoundTrip(t *testing.T) {
testMessageFieldRoundTrip(t, newTestPostgresStore(t))
}

140
message/store_sqlite.go Normal file
View File

@@ -0,0 +1,140 @@
package message
import (
"database/sql"
"fmt"
"path/filepath"
"time"
_ "github.com/mattn/go-sqlite3" // SQLite driver
"heckel.io/ntfy/v2/util"
)
// SQLite runtime query constants
const (
sqliteInsertMessageQuery = `
INSERT INTO messages (mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user, content_type, encoding, published)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
sqliteDeleteMessageQuery = `DELETE FROM messages WHERE mid = ?`
sqliteSelectScheduledMessageIDsBySeqIDQuery = `SELECT mid FROM messages WHERE topic = ? AND sequence_id = ? AND published = 0`
sqliteDeleteScheduledBySequenceIDQuery = `DELETE FROM messages WHERE topic = ? AND sequence_id = ? AND published = 0`
sqliteUpdateMessagesForTopicExpiryQuery = `UPDATE messages SET expires = ? WHERE topic = ?`
sqliteSelectRowIDFromMessageID = `SELECT id FROM messages WHERE mid = ?`
sqliteSelectMessagesByIDQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE mid = ?
`
sqliteSelectMessagesSinceTimeQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE topic = ? AND time >= ? AND published = 1
ORDER BY time, id
`
sqliteSelectMessagesSinceTimeIncludeScheduledQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE topic = ? AND time >= ?
ORDER BY time, id
`
sqliteSelectMessagesSinceIDQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE topic = ? AND id > ? AND published = 1
ORDER BY time, id
`
sqliteSelectMessagesSinceIDIncludeScheduledQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE topic = ? AND (id > ? OR published = 0)
ORDER BY time, id
`
sqliteSelectMessagesLatestQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE topic = ? AND published = 1
ORDER BY time DESC, id DESC
LIMIT 1
`
sqliteSelectMessagesDueQuery = `
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE time <= ? AND published = 0
ORDER BY time, id
`
sqliteSelectMessagesExpiredQuery = `SELECT mid FROM messages WHERE expires <= ? AND published = 1`
sqliteUpdateMessagePublishedQuery = `UPDATE messages SET published = 1 WHERE mid = ?`
sqliteSelectMessagesCountQuery = `SELECT COUNT(*) FROM messages`
sqliteSelectMessageCountPerTopicQuery = `SELECT topic, COUNT(*) FROM messages GROUP BY topic`
sqliteSelectTopicsQuery = `SELECT topic FROM messages GROUP BY topic`
sqliteUpdateAttachmentDeleted = `UPDATE messages SET attachment_deleted = 1 WHERE mid = ?`
sqliteSelectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires <= ? AND attachment_deleted = 0`
sqliteSelectAttachmentsSizeBySenderQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = '' AND sender = ? AND attachment_expires >= ?`
sqliteSelectAttachmentsSizeByUserIDQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?`
sqliteSelectStatsQuery = `SELECT value FROM stats WHERE key = 'messages'`
sqliteUpdateStatsQuery = `UPDATE stats SET value = ? WHERE key = 'messages'`
sqliteUpdateMessageTimeQuery = `UPDATE messages SET time = ? WHERE mid = ?`
)
var sqliteQueries = storeQueries{
insertMessage: sqliteInsertMessageQuery,
deleteMessage: sqliteDeleteMessageQuery,
selectScheduledMessageIDsBySeqID: sqliteSelectScheduledMessageIDsBySeqIDQuery,
deleteScheduledBySequenceID: sqliteDeleteScheduledBySequenceIDQuery,
updateMessagesForTopicExpiry: sqliteUpdateMessagesForTopicExpiryQuery,
selectRowIDFromMessageID: sqliteSelectRowIDFromMessageID,
selectMessagesByID: sqliteSelectMessagesByIDQuery,
selectMessagesSinceTime: sqliteSelectMessagesSinceTimeQuery,
selectMessagesSinceTimeScheduled: sqliteSelectMessagesSinceTimeIncludeScheduledQuery,
selectMessagesSinceID: sqliteSelectMessagesSinceIDQuery,
selectMessagesSinceIDScheduled: sqliteSelectMessagesSinceIDIncludeScheduledQuery,
selectMessagesLatest: sqliteSelectMessagesLatestQuery,
selectMessagesDue: sqliteSelectMessagesDueQuery,
selectMessagesExpired: sqliteSelectMessagesExpiredQuery,
updateMessagePublished: sqliteUpdateMessagePublishedQuery,
selectMessagesCount: sqliteSelectMessagesCountQuery,
selectMessageCountPerTopic: sqliteSelectMessageCountPerTopicQuery,
selectTopics: sqliteSelectTopicsQuery,
updateAttachmentDeleted: sqliteUpdateAttachmentDeleted,
selectAttachmentsExpired: sqliteSelectAttachmentsExpiredQuery,
selectAttachmentsSizeBySender: sqliteSelectAttachmentsSizeBySenderQuery,
selectAttachmentsSizeByUserID: sqliteSelectAttachmentsSizeByUserIDQuery,
selectStats: sqliteSelectStatsQuery,
updateStats: sqliteUpdateStatsQuery,
updateMessageTime: sqliteUpdateMessageTimeQuery,
}
// NewSQLiteStore creates a SQLite file-backed cache
func NewSQLiteStore(filename, startupQueries string, cacheDuration time.Duration, batchSize int, batchTimeout time.Duration, nop bool) (Store, error) {
parentDir := filepath.Dir(filename)
if !util.FileExists(parentDir) {
return nil, fmt.Errorf("cache database directory %s does not exist or is not accessible", parentDir)
}
db, err := sql.Open("sqlite3", filename)
if err != nil {
return nil, err
}
if err := setupSQLite(db, startupQueries, cacheDuration); err != nil {
return nil, err
}
return newCommonStore(db, sqliteQueries, batchSize, batchTimeout, nop), nil
}
// NewMemStore creates an in-memory cache
func NewMemStore() (Store, error) {
return NewSQLiteStore(createMemoryFilename(), "", 0, 0, 0, false)
}
// NewNopStore creates an in-memory cache that discards all messages;
// it is always empty and can be used if caching is entirely disabled
func NewNopStore() (Store, error) {
return NewSQLiteStore(createMemoryFilename(), "", 0, 0, 0, true)
}
// createMemoryFilename creates a unique memory filename to use for the SQLite backend.
func createMemoryFilename() string {
return fmt.Sprintf("file:%s?mode=memory&cache=shared", util.RandomString(10))
}

View File

@@ -0,0 +1,466 @@
package message
import (
"database/sql"
"fmt"
"time"
"heckel.io/ntfy/v2/log"
)
// Initial SQLite schema
const (
sqliteCreateMessagesTableQuery = `
BEGIN;
CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
mid TEXT NOT NULL,
sequence_id TEXT NOT NULL,
time INT NOT NULL,
event TEXT NOT NULL,
expires INT NOT NULL,
topic TEXT NOT NULL,
message TEXT NOT NULL,
title TEXT NOT NULL,
priority INT NOT NULL,
tags TEXT NOT NULL,
click TEXT NOT NULL,
icon TEXT NOT NULL,
actions TEXT NOT NULL,
attachment_name TEXT NOT NULL,
attachment_type TEXT NOT NULL,
attachment_size INT NOT NULL,
attachment_expires INT NOT NULL,
attachment_url TEXT NOT NULL,
attachment_deleted INT NOT NULL,
sender TEXT NOT NULL,
user TEXT NOT NULL,
content_type TEXT NOT NULL,
encoding TEXT NOT NULL,
published INT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_mid ON messages (mid);
CREATE INDEX IF NOT EXISTS idx_sequence_id ON messages (sequence_id);
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires);
CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender);
CREATE INDEX IF NOT EXISTS idx_user ON messages (user);
CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
CREATE TABLE IF NOT EXISTS stats (
key TEXT PRIMARY KEY,
value INT
);
INSERT INTO stats (key, value) VALUES ('messages', 0);
COMMIT;
`
)
// Schema version management for SQLite
const (
sqliteCurrentSchemaVersion = 14
sqliteCreateSchemaVersionTableQuery = `
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
`
sqliteInsertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
sqliteUpdateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
sqliteSelectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
)
// Schema migrations for SQLite
const (
// 0 -> 1
sqliteMigrate0To1AlterMessagesTableQuery = `
BEGIN;
ALTER TABLE messages ADD COLUMN title TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN priority INT NOT NULL DEFAULT(0);
ALTER TABLE messages ADD COLUMN tags TEXT NOT NULL DEFAULT('');
COMMIT;
`
// 1 -> 2
sqliteMigrate1To2AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN published INT NOT NULL DEFAULT(1);
`
// 2 -> 3
sqliteMigrate2To3AlterMessagesTableQuery = `
BEGIN;
ALTER TABLE messages ADD COLUMN click TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN attachment_name TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN attachment_type TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN attachment_size INT NOT NULL DEFAULT('0');
ALTER TABLE messages ADD COLUMN attachment_expires INT NOT NULL DEFAULT('0');
ALTER TABLE messages ADD COLUMN attachment_owner TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN attachment_url TEXT NOT NULL DEFAULT('');
COMMIT;
`
// 3 -> 4
sqliteMigrate3To4AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN encoding TEXT NOT NULL DEFAULT('');
`
// 4 -> 5
sqliteMigrate4To5AlterMessagesTableQuery = `
BEGIN;
CREATE TABLE IF NOT EXISTS messages_new (
id INTEGER PRIMARY KEY AUTOINCREMENT,
mid TEXT NOT NULL,
time INT NOT NULL,
topic TEXT NOT NULL,
message TEXT NOT NULL,
title TEXT NOT NULL,
priority INT NOT NULL,
tags TEXT NOT NULL,
click TEXT NOT NULL,
attachment_name TEXT NOT NULL,
attachment_type TEXT NOT NULL,
attachment_size INT NOT NULL,
attachment_expires INT NOT NULL,
attachment_url TEXT NOT NULL,
attachment_owner TEXT NOT NULL,
encoding TEXT NOT NULL,
published INT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_mid ON messages_new (mid);
CREATE INDEX IF NOT EXISTS idx_topic ON messages_new (topic);
INSERT
INTO messages_new (
mid, time, topic, message, title, priority, tags, click, attachment_name, attachment_type,
attachment_size, attachment_expires, attachment_url, attachment_owner, encoding, published)
SELECT
id, time, topic, message, title, priority, tags, click, attachment_name, attachment_type,
attachment_size, attachment_expires, attachment_url, attachment_owner, encoding, published
FROM messages;
DROP TABLE messages;
ALTER TABLE messages_new RENAME TO messages;
COMMIT;
`
// 5 -> 6
sqliteMigrate5To6AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN actions TEXT NOT NULL DEFAULT('');
`
// 6 -> 7
sqliteMigrate6To7AlterMessagesTableQuery = `
ALTER TABLE messages RENAME COLUMN attachment_owner TO sender;
`
// 7 -> 8
sqliteMigrate7To8AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN icon TEXT NOT NULL DEFAULT('');
`
// 8 -> 9
sqliteMigrate8To9AlterMessagesTableQuery = `
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
`
// 9 -> 10
sqliteMigrate9To10AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN user TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN attachment_deleted INT NOT NULL DEFAULT('0');
ALTER TABLE messages ADD COLUMN expires INT NOT NULL DEFAULT('0');
CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires);
CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender);
CREATE INDEX IF NOT EXISTS idx_user ON messages (user);
CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
`
sqliteMigrate9To10UpdateMessageExpiryQuery = `UPDATE messages SET expires = time + ?`
// 10 -> 11
sqliteMigrate10To11AlterMessagesTableQuery = `
CREATE TABLE IF NOT EXISTS stats (
key TEXT PRIMARY KEY,
value INT
);
INSERT INTO stats (key, value) VALUES ('messages', 0);
`
// 11 -> 12
sqliteMigrate11To12AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN content_type TEXT NOT NULL DEFAULT('');
`
// 12 -> 13
sqliteMigrate12To13AlterMessagesTableQuery = `
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
`
// 13 -> 14
sqliteMigrate13To14AlterMessagesTableQuery = `
ALTER TABLE messages ADD COLUMN sequence_id TEXT NOT NULL DEFAULT('');
ALTER TABLE messages ADD COLUMN event TEXT NOT NULL DEFAULT('message');
CREATE INDEX IF NOT EXISTS idx_sequence_id ON messages (sequence_id);
`
)
var (
sqliteMigrations = map[int]func(db *sql.DB, cacheDuration time.Duration) error{
0: sqliteMigrateFrom0,
1: sqliteMigrateFrom1,
2: sqliteMigrateFrom2,
3: sqliteMigrateFrom3,
4: sqliteMigrateFrom4,
5: sqliteMigrateFrom5,
6: sqliteMigrateFrom6,
7: sqliteMigrateFrom7,
8: sqliteMigrateFrom8,
9: sqliteMigrateFrom9,
10: sqliteMigrateFrom10,
11: sqliteMigrateFrom11,
12: sqliteMigrateFrom12,
13: sqliteMigrateFrom13,
}
)
func setupSQLite(db *sql.DB, startupQueries string, cacheDuration time.Duration) error {
if err := runSQLiteStartupQueries(db, startupQueries); err != nil {
return err
}
// If 'messages' table does not exist, this must be a new database
rowsMC, err := db.Query(sqliteSelectMessagesCountQuery)
if err != nil {
return setupNewSQLite(db)
}
rowsMC.Close()
// If 'messages' table exists, check 'schemaVersion' table
schemaVersion := 0
rowsSV, err := db.Query(sqliteSelectSchemaVersionQuery)
if err == nil {
defer rowsSV.Close()
if !rowsSV.Next() {
return fmt.Errorf("cannot determine schema version: cache file may be corrupt")
}
if err := rowsSV.Scan(&schemaVersion); err != nil {
return err
}
rowsSV.Close()
}
// Do migrations
if schemaVersion == sqliteCurrentSchemaVersion {
return nil
} else if schemaVersion > sqliteCurrentSchemaVersion {
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, sqliteCurrentSchemaVersion)
}
for i := schemaVersion; i < sqliteCurrentSchemaVersion; i++ {
fn, ok := sqliteMigrations[i]
if !ok {
return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1)
} else if err := fn(db, cacheDuration); err != nil {
return err
}
}
return nil
}
func setupNewSQLite(db *sql.DB) error {
if _, err := db.Exec(sqliteCreateMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(sqliteCreateSchemaVersionTableQuery); err != nil {
return err
}
if _, err := db.Exec(sqliteInsertSchemaVersion, sqliteCurrentSchemaVersion); err != nil {
return err
}
return nil
}
func runSQLiteStartupQueries(db *sql.DB, startupQueries string) error {
if startupQueries != "" {
if _, err := db.Exec(startupQueries); err != nil {
return err
}
}
return nil
}
func sqliteMigrateFrom0(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 0 to 1")
if _, err := db.Exec(sqliteMigrate0To1AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(sqliteCreateSchemaVersionTableQuery); err != nil {
return err
}
if _, err := db.Exec(sqliteInsertSchemaVersion, 1); err != nil {
return err
}
return nil
}
func sqliteMigrateFrom1(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 1 to 2")
if _, err := db.Exec(sqliteMigrate1To2AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(sqliteUpdateSchemaVersion, 2); err != nil {
return err
}
return nil
}
func sqliteMigrateFrom2(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 2 to 3")
if _, err := db.Exec(sqliteMigrate2To3AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(sqliteUpdateSchemaVersion, 3); err != nil {
return err
}
return nil
}
func sqliteMigrateFrom3(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 3 to 4")
if _, err := db.Exec(sqliteMigrate3To4AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(sqliteUpdateSchemaVersion, 4); err != nil {
return err
}
return nil
}
func sqliteMigrateFrom4(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 4 to 5")
if _, err := db.Exec(sqliteMigrate4To5AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(sqliteUpdateSchemaVersion, 5); err != nil {
return err
}
return nil
}
func sqliteMigrateFrom5(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 5 to 6")
if _, err := db.Exec(sqliteMigrate5To6AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(sqliteUpdateSchemaVersion, 6); err != nil {
return err
}
return nil
}
func sqliteMigrateFrom6(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 6 to 7")
if _, err := db.Exec(sqliteMigrate6To7AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(sqliteUpdateSchemaVersion, 7); err != nil {
return err
}
return nil
}
func sqliteMigrateFrom7(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 7 to 8")
if _, err := db.Exec(sqliteMigrate7To8AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(sqliteUpdateSchemaVersion, 8); err != nil {
return err
}
return nil
}
func sqliteMigrateFrom8(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 8 to 9")
if _, err := db.Exec(sqliteMigrate8To9AlterMessagesTableQuery); err != nil {
return err
}
if _, err := db.Exec(sqliteUpdateSchemaVersion, 9); err != nil {
return err
}
return nil
}
func sqliteMigrateFrom9(db *sql.DB, cacheDuration time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 9 to 10")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(sqliteMigrate9To10AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteMigrate9To10UpdateMessageExpiryQuery, int64(cacheDuration.Seconds())); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 10); err != nil {
return err
}
return tx.Commit()
}
func sqliteMigrateFrom10(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 10 to 11")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(sqliteMigrate10To11AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 11); err != nil {
return err
}
return tx.Commit()
}
func sqliteMigrateFrom11(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 11 to 12")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(sqliteMigrate11To12AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 12); err != nil {
return err
}
return tx.Commit()
}
func sqliteMigrateFrom12(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 12 to 13")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(sqliteMigrate12To13AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 13); err != nil {
return err
}
return tx.Commit()
}
func sqliteMigrateFrom13(db *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 13 to 14")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(sqliteMigrate13To14AlterMessagesTableQuery); err != nil {
return err
}
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 14); err != nil {
return err
}
return tx.Commit()
}

View File

@@ -0,0 +1,459 @@
package message_test
import (
"database/sql"
"fmt"
"path/filepath"
"testing"
"time"
_ "github.com/mattn/go-sqlite3" // SQLite driver
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/message"
"heckel.io/ntfy/v2/model"
)
func TestSqliteStore_Messages(t *testing.T) {
testCacheMessages(t, newSqliteTestStore(t))
}
func TestMemStore_Messages(t *testing.T) {
testCacheMessages(t, newMemTestStore(t))
}
func TestSqliteStore_MessagesLock(t *testing.T) {
testCacheMessagesLock(t, newSqliteTestStore(t))
}
func TestMemStore_MessagesLock(t *testing.T) {
testCacheMessagesLock(t, newMemTestStore(t))
}
func TestSqliteStore_MessagesScheduled(t *testing.T) {
testCacheMessagesScheduled(t, newSqliteTestStore(t))
}
func TestMemStore_MessagesScheduled(t *testing.T) {
testCacheMessagesScheduled(t, newMemTestStore(t))
}
func TestSqliteStore_Topics(t *testing.T) {
testCacheTopics(t, newSqliteTestStore(t))
}
func TestMemStore_Topics(t *testing.T) {
testCacheTopics(t, newMemTestStore(t))
}
func TestSqliteStore_MessagesTagsPrioAndTitle(t *testing.T) {
testCacheMessagesTagsPrioAndTitle(t, newSqliteTestStore(t))
}
func TestMemStore_MessagesTagsPrioAndTitle(t *testing.T) {
testCacheMessagesTagsPrioAndTitle(t, newMemTestStore(t))
}
func TestSqliteStore_MessagesSinceID(t *testing.T) {
testCacheMessagesSinceID(t, newSqliteTestStore(t))
}
func TestMemStore_MessagesSinceID(t *testing.T) {
testCacheMessagesSinceID(t, newMemTestStore(t))
}
func TestSqliteStore_Prune(t *testing.T) {
testCachePrune(t, newSqliteTestStore(t))
}
func TestMemStore_Prune(t *testing.T) {
testCachePrune(t, newMemTestStore(t))
}
func TestSqliteStore_Attachments(t *testing.T) {
testCacheAttachments(t, newSqliteTestStore(t))
}
func TestMemStore_Attachments(t *testing.T) {
testCacheAttachments(t, newMemTestStore(t))
}
func TestSqliteStore_AttachmentsExpired(t *testing.T) {
testCacheAttachmentsExpired(t, newSqliteTestStore(t))
}
func TestMemStore_AttachmentsExpired(t *testing.T) {
testCacheAttachmentsExpired(t, newMemTestStore(t))
}
func TestSqliteStore_Sender(t *testing.T) {
testSender(t, newSqliteTestStore(t))
}
func TestMemStore_Sender(t *testing.T) {
testSender(t, newMemTestStore(t))
}
func TestSqliteStore_DeleteScheduledBySequenceID(t *testing.T) {
testDeleteScheduledBySequenceID(t, newSqliteTestStore(t))
}
func TestMemStore_DeleteScheduledBySequenceID(t *testing.T) {
testDeleteScheduledBySequenceID(t, newMemTestStore(t))
}
func TestSqliteStore_MessageByID(t *testing.T) {
testMessageByID(t, newSqliteTestStore(t))
}
func TestMemStore_MessageByID(t *testing.T) {
testMessageByID(t, newMemTestStore(t))
}
func TestSqliteStore_MarkPublished(t *testing.T) {
testMarkPublished(t, newSqliteTestStore(t))
}
func TestMemStore_MarkPublished(t *testing.T) {
testMarkPublished(t, newMemTestStore(t))
}
func TestSqliteStore_ExpireMessages(t *testing.T) {
testExpireMessages(t, newSqliteTestStore(t))
}
func TestMemStore_ExpireMessages(t *testing.T) {
testExpireMessages(t, newMemTestStore(t))
}
func TestSqliteStore_MarkAttachmentsDeleted(t *testing.T) {
testMarkAttachmentsDeleted(t, newSqliteTestStore(t))
}
func TestMemStore_MarkAttachmentsDeleted(t *testing.T) {
testMarkAttachmentsDeleted(t, newMemTestStore(t))
}
func TestSqliteStore_Stats(t *testing.T) {
testStats(t, newSqliteTestStore(t))
}
func TestMemStore_Stats(t *testing.T) {
testStats(t, newMemTestStore(t))
}
func TestSqliteStore_AddMessages(t *testing.T) {
testAddMessages(t, newSqliteTestStore(t))
}
func TestMemStore_AddMessages(t *testing.T) {
testAddMessages(t, newMemTestStore(t))
}
func TestSqliteStore_MessagesDue(t *testing.T) {
testMessagesDue(t, newSqliteTestStore(t))
}
func TestMemStore_MessagesDue(t *testing.T) {
testMessagesDue(t, newMemTestStore(t))
}
func TestSqliteStore_MessageFieldRoundTrip(t *testing.T) {
testMessageFieldRoundTrip(t, newSqliteTestStore(t))
}
func TestMemStore_MessageFieldRoundTrip(t *testing.T) {
testMessageFieldRoundTrip(t, newMemTestStore(t))
}
func TestSqliteStore_Migration_From0(t *testing.T) {
filename := newSqliteTestStoreFile(t)
db, err := sql.Open("sqlite3", filename)
require.Nil(t, err)
// Create "version 0" schema
_, err = db.Exec(`
BEGIN;
CREATE TABLE IF NOT EXISTS messages (
id VARCHAR(20) PRIMARY KEY,
time INT NOT NULL,
topic VARCHAR(64) NOT NULL,
message VARCHAR(1024) NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
COMMIT;
`)
require.Nil(t, err)
// Insert a bunch of messages
for i := 0; i < 10; i++ {
_, err = db.Exec(`INSERT INTO messages (id, time, topic, message) VALUES (?, ?, ?, ?)`,
fmt.Sprintf("abcd%d", i), time.Now().Unix(), "mytopic", fmt.Sprintf("some message %d", i))
require.Nil(t, err)
}
require.Nil(t, db.Close())
// Create store to trigger migration
s := newSqliteTestStoreFromFile(t, filename, "")
checkSqliteSchemaVersion(t, filename)
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 10, len(messages))
require.Equal(t, "some message 5", messages[5].Message)
require.Equal(t, "", messages[5].Title)
require.Nil(t, messages[5].Tags)
require.Equal(t, 0, messages[5].Priority)
}
func TestSqliteStore_Migration_From1(t *testing.T) {
filename := newSqliteTestStoreFile(t)
db, err := sql.Open("sqlite3", filename)
require.Nil(t, err)
// Create "version 1" schema
_, err = db.Exec(`
CREATE TABLE IF NOT EXISTS messages (
id VARCHAR(20) PRIMARY KEY,
time INT NOT NULL,
topic VARCHAR(64) NOT NULL,
message VARCHAR(512) NOT NULL,
title VARCHAR(256) NOT NULL,
priority INT NOT NULL,
tags VARCHAR(256) NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
INSERT INTO schemaVersion (id, version) VALUES (1, 1);
`)
require.Nil(t, err)
// Insert a bunch of messages
for i := 0; i < 10; i++ {
_, err = db.Exec(`INSERT INTO messages (id, time, topic, message, title, priority, tags) VALUES (?, ?, ?, ?, ?, ?, ?)`,
fmt.Sprintf("abcd%d", i), time.Now().Unix(), "mytopic", fmt.Sprintf("some message %d", i), "", 0, "")
require.Nil(t, err)
}
require.Nil(t, db.Close())
// Create store to trigger migration
s := newSqliteTestStoreFromFile(t, filename, "")
checkSqliteSchemaVersion(t, filename)
// Add delayed message
delayedMessage := model.NewDefaultMessage("mytopic", "some delayed message")
delayedMessage.Time = time.Now().Add(time.Minute).Unix()
require.Nil(t, s.AddMessage(delayedMessage))
// 10, not 11!
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 10, len(messages))
// 11!
messages, err = s.Messages("mytopic", model.SinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 11, len(messages))
// Check that index "idx_topic" exists
verifyDB, err := sql.Open("sqlite3", filename)
require.Nil(t, err)
defer verifyDB.Close()
rows, err := verifyDB.Query(`SELECT name FROM sqlite_master WHERE type='index' AND name='idx_topic'`)
require.Nil(t, err)
require.True(t, rows.Next())
var indexName string
require.Nil(t, rows.Scan(&indexName))
require.Equal(t, "idx_topic", indexName)
require.Nil(t, rows.Close())
}
func TestSqliteStore_Migration_From9(t *testing.T) {
// This primarily tests the awkward migration that introduces the "expires" column.
// The migration logic has to update the column, using the existing "cache-duration" value.
filename := newSqliteTestStoreFile(t)
db, err := sql.Open("sqlite3", filename)
require.Nil(t, err)
// Create "version 9" schema
_, err = db.Exec(`
BEGIN;
CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
mid TEXT NOT NULL,
time INT NOT NULL,
topic TEXT NOT NULL,
message TEXT NOT NULL,
title TEXT NOT NULL,
priority INT NOT NULL,
tags TEXT NOT NULL,
click TEXT NOT NULL,
icon TEXT NOT NULL,
actions TEXT NOT NULL,
attachment_name TEXT NOT NULL,
attachment_type TEXT NOT NULL,
attachment_size INT NOT NULL,
attachment_expires INT NOT NULL,
attachment_url TEXT NOT NULL,
sender TEXT NOT NULL,
encoding TEXT NOT NULL,
published INT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_mid ON messages (mid);
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
INSERT INTO schemaVersion (id, version) VALUES (1, 9);
COMMIT;
`)
require.Nil(t, err)
// Insert a bunch of messages
insertQuery := `
INSERT INTO messages (mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding, published)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
for i := 0; i < 10; i++ {
_, err = db.Exec(
insertQuery,
fmt.Sprintf("abcd%d", i),
time.Now().Unix(),
"mytopic",
fmt.Sprintf("some message %d", i),
"", // title
0, // priority
"", // tags
"", // click
"", // icon
"", // actions
"", // attachment_name
"", // attachment_type
0, // attachment_size
0, // attachment_expires
"", // attachment_url
"9.9.9.9", // sender
"", // encoding
1, // published
)
require.Nil(t, err)
}
require.Nil(t, db.Close())
// Create store to trigger migration
cacheDuration := 17 * time.Hour
s, err := message.NewSQLiteStore(filename, "", cacheDuration, 0, 0, false)
require.Nil(t, err)
t.Cleanup(func() { s.Close() })
checkSqliteSchemaVersion(t, filename)
// Check version
verifyDB, err := sql.Open("sqlite3", filename)
require.Nil(t, err)
defer verifyDB.Close()
rows, err := verifyDB.Query(`SELECT version FROM schemaVersion WHERE id = 1`)
require.Nil(t, err)
require.True(t, rows.Next())
var version int
require.Nil(t, rows.Scan(&version))
require.Equal(t, 14, version)
require.Nil(t, rows.Close())
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 10, len(messages))
for _, m := range messages {
require.True(t, m.Expires > time.Now().Add(cacheDuration-5*time.Second).Unix())
require.True(t, m.Expires < time.Now().Add(cacheDuration+5*time.Second).Unix())
}
}
func TestSqliteStore_StartupQueries_WAL(t *testing.T) {
filename := newSqliteTestStoreFile(t)
startupQueries := `pragma journal_mode = WAL;
pragma synchronous = normal;
pragma temp_store = memory;`
s, err := message.NewSQLiteStore(filename, startupQueries, time.Hour, 0, 0, false)
require.Nil(t, err)
t.Cleanup(func() { s.Close() })
require.Nil(t, s.AddMessage(model.NewDefaultMessage("mytopic", "some message")))
require.FileExists(t, filename)
require.FileExists(t, filename+"-wal")
require.FileExists(t, filename+"-shm")
}
func TestSqliteStore_StartupQueries_None(t *testing.T) {
filename := newSqliteTestStoreFile(t)
s, err := message.NewSQLiteStore(filename, "", time.Hour, 0, 0, false)
require.Nil(t, err)
t.Cleanup(func() { s.Close() })
require.Nil(t, s.AddMessage(model.NewDefaultMessage("mytopic", "some message")))
require.FileExists(t, filename)
require.NoFileExists(t, filename+"-wal")
require.NoFileExists(t, filename+"-shm")
}
func TestSqliteStore_StartupQueries_Fail(t *testing.T) {
filename := newSqliteTestStoreFile(t)
_, err := message.NewSQLiteStore(filename, `xx error`, time.Hour, 0, 0, false)
require.Error(t, err)
}
func TestNopStore(t *testing.T) {
s, err := message.NewNopStore()
require.Nil(t, err)
t.Cleanup(func() { s.Close() })
require.Nil(t, s.AddMessage(model.NewDefaultMessage("mytopic", "my message")))
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Empty(t, messages)
topics, err := s.Topics()
require.Nil(t, err)
require.Empty(t, topics)
}
func newSqliteTestStore(t *testing.T) message.Store {
filename := filepath.Join(t.TempDir(), "cache.db")
s, err := message.NewSQLiteStore(filename, "", time.Hour, 0, 0, false)
require.Nil(t, err)
t.Cleanup(func() { s.Close() })
return s
}
func newSqliteTestStoreFile(t *testing.T) string {
return filepath.Join(t.TempDir(), "cache.db")
}
func newSqliteTestStoreFromFile(t *testing.T, filename, startupQueries string) message.Store {
s, err := message.NewSQLiteStore(filename, startupQueries, time.Hour, 0, 0, false)
require.Nil(t, err)
t.Cleanup(func() { s.Close() })
return s
}
func newMemTestStore(t *testing.T) message.Store {
s, err := message.NewMemStore()
require.Nil(t, err)
t.Cleanup(func() { s.Close() })
return s
}
func checkSqliteSchemaVersion(t *testing.T, filename string) {
db, err := sql.Open("sqlite3", filename)
require.Nil(t, err)
defer db.Close()
rows, err := db.Query(`SELECT version FROM schemaVersion`)
require.Nil(t, err)
require.True(t, rows.Next())
var schemaVersion int
require.Nil(t, rows.Scan(&schemaVersion))
require.Equal(t, 14, schemaVersion)
require.Nil(t, rows.Close())
}

767
message/store_test.go Normal file
View File

@@ -0,0 +1,767 @@
package message_test
import (
"net/netip"
"sort"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/message"
"heckel.io/ntfy/v2/model"
)
func testCacheMessages(t *testing.T, s message.Store) {
m1 := model.NewDefaultMessage("mytopic", "my message")
m1.Time = 1
m2 := model.NewDefaultMessage("mytopic", "my other message")
m2.Time = 2
require.Nil(t, s.AddMessage(m1))
require.Nil(t, s.AddMessage(model.NewDefaultMessage("example", "my example message")))
require.Nil(t, s.AddMessage(m2))
// Adding invalid
require.Equal(t, model.ErrUnexpectedMessageType, s.AddMessage(model.NewKeepaliveMessage("mytopic"))) // These should not be added!
require.Equal(t, model.ErrUnexpectedMessageType, s.AddMessage(model.NewOpenMessage("example"))) // These should not be added!
// mytopic: count
counts, err := s.MessageCounts()
require.Nil(t, err)
require.Equal(t, 2, counts["mytopic"])
// mytopic: since all
messages, _ := s.Messages("mytopic", model.SinceAllMessages, false)
require.Equal(t, 2, len(messages))
require.Equal(t, "my message", messages[0].Message)
require.Equal(t, "mytopic", messages[0].Topic)
require.Equal(t, model.MessageEvent, messages[0].Event)
require.Equal(t, "", messages[0].Title)
require.Equal(t, 0, messages[0].Priority)
require.Nil(t, messages[0].Tags)
require.Equal(t, "my other message", messages[1].Message)
// mytopic: since none
messages, _ = s.Messages("mytopic", model.SinceNoMessages, false)
require.Empty(t, messages)
// mytopic: since m1 (by ID)
messages, _ = s.Messages("mytopic", model.NewSinceID(m1.ID), false)
require.Equal(t, 1, len(messages))
require.Equal(t, m2.ID, messages[0].ID)
require.Equal(t, "my other message", messages[0].Message)
require.Equal(t, "mytopic", messages[0].Topic)
// mytopic: since 2
messages, _ = s.Messages("mytopic", model.NewSinceTime(2), false)
require.Equal(t, 1, len(messages))
require.Equal(t, "my other message", messages[0].Message)
// mytopic: latest
messages, _ = s.Messages("mytopic", model.SinceLatestMessage, false)
require.Equal(t, 1, len(messages))
require.Equal(t, "my other message", messages[0].Message)
// example: count
counts, err = s.MessageCounts()
require.Nil(t, err)
require.Equal(t, 1, counts["example"])
// example: since all
messages, _ = s.Messages("example", model.SinceAllMessages, false)
require.Equal(t, "my example message", messages[0].Message)
// non-existing: count
counts, err = s.MessageCounts()
require.Nil(t, err)
require.Equal(t, 0, counts["doesnotexist"])
// non-existing: since all
messages, _ = s.Messages("doesnotexist", model.SinceAllMessages, false)
require.Empty(t, messages)
}
func testCacheMessagesLock(t *testing.T, s message.Store) {
var wg sync.WaitGroup
for i := 0; i < 5000; i++ {
wg.Add(1)
go func() {
assert.Nil(t, s.AddMessage(model.NewDefaultMessage("mytopic", "test message")))
wg.Done()
}()
}
wg.Wait()
}
func testCacheMessagesScheduled(t *testing.T, s message.Store) {
m1 := model.NewDefaultMessage("mytopic", "message 1")
m2 := model.NewDefaultMessage("mytopic", "message 2")
m2.Time = time.Now().Add(time.Hour).Unix()
m3 := model.NewDefaultMessage("mytopic", "message 3")
m3.Time = time.Now().Add(time.Minute).Unix() // earlier than m2!
m4 := model.NewDefaultMessage("mytopic2", "message 4")
m4.Time = time.Now().Add(time.Minute).Unix()
require.Nil(t, s.AddMessage(m1))
require.Nil(t, s.AddMessage(m2))
require.Nil(t, s.AddMessage(m3))
messages, _ := s.Messages("mytopic", model.SinceAllMessages, false) // exclude scheduled
require.Equal(t, 1, len(messages))
require.Equal(t, "message 1", messages[0].Message)
messages, _ = s.Messages("mytopic", model.SinceAllMessages, true) // include scheduled
require.Equal(t, 3, len(messages))
require.Equal(t, "message 1", messages[0].Message)
require.Equal(t, "message 3", messages[1].Message) // Order!
require.Equal(t, "message 2", messages[2].Message)
messages, _ = s.MessagesDue()
require.Empty(t, messages)
}
func testCacheTopics(t *testing.T, s message.Store) {
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic1", "my example message")))
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 1")))
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 2")))
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 3")))
topics, err := s.Topics()
if err != nil {
t.Fatal(err)
}
require.Equal(t, 2, len(topics))
require.Contains(t, topics, "topic1")
require.Contains(t, topics, "topic2")
}
func testCacheMessagesTagsPrioAndTitle(t *testing.T, s message.Store) {
m := model.NewDefaultMessage("mytopic", "some message")
m.Tags = []string{"tag1", "tag2"}
m.Priority = 5
m.Title = "some title"
require.Nil(t, s.AddMessage(m))
messages, _ := s.Messages("mytopic", model.SinceAllMessages, false)
require.Equal(t, []string{"tag1", "tag2"}, messages[0].Tags)
require.Equal(t, 5, messages[0].Priority)
require.Equal(t, "some title", messages[0].Title)
}
func testCacheMessagesSinceID(t *testing.T, s message.Store) {
m1 := model.NewDefaultMessage("mytopic", "message 1")
m1.Time = 100
m2 := model.NewDefaultMessage("mytopic", "message 2")
m2.Time = 200
m3 := model.NewDefaultMessage("mytopic", "message 3")
m3.Time = time.Now().Add(time.Hour).Unix() // Scheduled, in the future, later than m7 and m5
m4 := model.NewDefaultMessage("mytopic", "message 4")
m4.Time = 400
m5 := model.NewDefaultMessage("mytopic", "message 5")
m5.Time = time.Now().Add(time.Minute).Unix() // Scheduled, in the future, later than m7
m6 := model.NewDefaultMessage("mytopic", "message 6")
m6.Time = 600
m7 := model.NewDefaultMessage("mytopic", "message 7")
m7.Time = 700
require.Nil(t, s.AddMessage(m1))
require.Nil(t, s.AddMessage(m2))
require.Nil(t, s.AddMessage(m3))
require.Nil(t, s.AddMessage(m4))
require.Nil(t, s.AddMessage(m5))
require.Nil(t, s.AddMessage(m6))
require.Nil(t, s.AddMessage(m7))
// Case 1: Since ID exists, exclude scheduled
messages, _ := s.Messages("mytopic", model.NewSinceID(m2.ID), false)
require.Equal(t, 3, len(messages))
require.Equal(t, "message 4", messages[0].Message)
require.Equal(t, "message 6", messages[1].Message) // Not scheduled m3/m5!
require.Equal(t, "message 7", messages[2].Message)
// Case 2: Since ID exists, include scheduled
messages, _ = s.Messages("mytopic", model.NewSinceID(m2.ID), true)
require.Equal(t, 5, len(messages))
require.Equal(t, "message 4", messages[0].Message)
require.Equal(t, "message 6", messages[1].Message)
require.Equal(t, "message 7", messages[2].Message)
require.Equal(t, "message 5", messages[3].Message) // Order!
require.Equal(t, "message 3", messages[4].Message) // Order!
// Case 3: Since ID does not exist (-> Return all messages), include scheduled
messages, _ = s.Messages("mytopic", model.NewSinceID("doesntexist"), true)
require.Equal(t, 7, len(messages))
require.Equal(t, "message 1", messages[0].Message)
require.Equal(t, "message 2", messages[1].Message)
require.Equal(t, "message 4", messages[2].Message)
require.Equal(t, "message 6", messages[3].Message)
require.Equal(t, "message 7", messages[4].Message)
require.Equal(t, "message 5", messages[5].Message) // Order!
require.Equal(t, "message 3", messages[6].Message) // Order!
// Case 4: Since ID exists and is last message (-> Return no messages), exclude scheduled
messages, _ = s.Messages("mytopic", model.NewSinceID(m7.ID), false)
require.Equal(t, 0, len(messages))
// Case 5: Since ID exists and is last message (-> Return no messages), include scheduled
messages, _ = s.Messages("mytopic", model.NewSinceID(m7.ID), true)
require.Equal(t, 2, len(messages))
require.Equal(t, "message 5", messages[0].Message)
require.Equal(t, "message 3", messages[1].Message)
}
func testCachePrune(t *testing.T, s message.Store) {
now := time.Now().Unix()
m1 := model.NewDefaultMessage("mytopic", "my message")
m1.Time = now - 10
m1.Expires = now - 5
m2 := model.NewDefaultMessage("mytopic", "my other message")
m2.Time = now - 5
m2.Expires = now + 5 // In the future
m3 := model.NewDefaultMessage("another_topic", "and another one")
m3.Time = now - 12
m3.Expires = now - 2
require.Nil(t, s.AddMessage(m1))
require.Nil(t, s.AddMessage(m2))
require.Nil(t, s.AddMessage(m3))
counts, err := s.MessageCounts()
require.Nil(t, err)
require.Equal(t, 2, counts["mytopic"])
require.Equal(t, 1, counts["another_topic"])
expiredMessageIDs, err := s.MessagesExpired()
require.Nil(t, err)
require.Nil(t, s.DeleteMessages(expiredMessageIDs...))
counts, err = s.MessageCounts()
require.Nil(t, err)
require.Equal(t, 1, counts["mytopic"])
require.Equal(t, 0, counts["another_topic"])
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "my other message", messages[0].Message)
}
func testCacheAttachments(t *testing.T, s message.Store) {
expires1 := time.Now().Add(-4 * time.Hour).Unix() // Expired
m := model.NewDefaultMessage("mytopic", "flower for you")
m.ID = "m1"
m.SequenceID = "m1"
m.Sender = netip.MustParseAddr("1.2.3.4")
m.Attachment = &model.Attachment{
Name: "flower.jpg",
Type: "image/jpeg",
Size: 5000,
Expires: expires1,
URL: "https://ntfy.sh/file/AbDeFgJhal.jpg",
}
require.Nil(t, s.AddMessage(m))
expires2 := time.Now().Add(2 * time.Hour).Unix() // Future
m = model.NewDefaultMessage("mytopic", "sending you a car")
m.ID = "m2"
m.SequenceID = "m2"
m.Sender = netip.MustParseAddr("1.2.3.4")
m.Attachment = &model.Attachment{
Name: "car.jpg",
Type: "image/jpeg",
Size: 10000,
Expires: expires2,
URL: "https://ntfy.sh/file/aCaRURL.jpg",
}
require.Nil(t, s.AddMessage(m))
expires3 := time.Now().Add(1 * time.Hour).Unix() // Future
m = model.NewDefaultMessage("another-topic", "sending you another car")
m.ID = "m3"
m.SequenceID = "m3"
m.User = "u_BAsbaAa"
m.Sender = netip.MustParseAddr("5.6.7.8")
m.Attachment = &model.Attachment{
Name: "another-car.jpg",
Type: "image/jpeg",
Size: 20000,
Expires: expires3,
URL: "https://ntfy.sh/file/zakaDHFW.jpg",
}
require.Nil(t, s.AddMessage(m))
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 2, len(messages))
require.Equal(t, "flower for you", messages[0].Message)
require.Equal(t, "flower.jpg", messages[0].Attachment.Name)
require.Equal(t, "image/jpeg", messages[0].Attachment.Type)
require.Equal(t, int64(5000), messages[0].Attachment.Size)
require.Equal(t, expires1, messages[0].Attachment.Expires)
require.Equal(t, "https://ntfy.sh/file/AbDeFgJhal.jpg", messages[0].Attachment.URL)
require.Equal(t, "1.2.3.4", messages[0].Sender.String())
require.Equal(t, "sending you a car", messages[1].Message)
require.Equal(t, "car.jpg", messages[1].Attachment.Name)
require.Equal(t, "image/jpeg", messages[1].Attachment.Type)
require.Equal(t, int64(10000), messages[1].Attachment.Size)
require.Equal(t, expires2, messages[1].Attachment.Expires)
require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL)
require.Equal(t, "1.2.3.4", messages[1].Sender.String())
size, err := s.AttachmentBytesUsedBySender("1.2.3.4")
require.Nil(t, err)
require.Equal(t, int64(10000), size)
size, err = s.AttachmentBytesUsedBySender("5.6.7.8")
require.Nil(t, err)
require.Equal(t, int64(0), size) // Accounted to the user, not the IP!
size, err = s.AttachmentBytesUsedByUser("u_BAsbaAa")
require.Nil(t, err)
require.Equal(t, int64(20000), size)
}
func testCacheAttachmentsExpired(t *testing.T, s message.Store) {
m := model.NewDefaultMessage("mytopic", "flower for you")
m.ID = "m1"
m.SequenceID = "m1"
m.Expires = time.Now().Add(time.Hour).Unix()
require.Nil(t, s.AddMessage(m))
m = model.NewDefaultMessage("mytopic", "message with attachment")
m.ID = "m2"
m.SequenceID = "m2"
m.Expires = time.Now().Add(2 * time.Hour).Unix()
m.Attachment = &model.Attachment{
Name: "car.jpg",
Type: "image/jpeg",
Size: 10000,
Expires: time.Now().Add(2 * time.Hour).Unix(),
URL: "https://ntfy.sh/file/aCaRURL.jpg",
}
require.Nil(t, s.AddMessage(m))
m = model.NewDefaultMessage("mytopic", "message with external attachment")
m.ID = "m3"
m.SequenceID = "m3"
m.Expires = time.Now().Add(2 * time.Hour).Unix()
m.Attachment = &model.Attachment{
Name: "car.jpg",
Type: "image/jpeg",
Expires: 0, // Unknown!
URL: "https://somedomain.com/car.jpg",
}
require.Nil(t, s.AddMessage(m))
m = model.NewDefaultMessage("mytopic2", "message with expired attachment")
m.ID = "m4"
m.SequenceID = "m4"
m.Expires = time.Now().Add(2 * time.Hour).Unix()
m.Attachment = &model.Attachment{
Name: "expired-car.jpg",
Type: "image/jpeg",
Size: 20000,
Expires: time.Now().Add(-1 * time.Hour).Unix(),
URL: "https://ntfy.sh/file/aCaRURL.jpg",
}
require.Nil(t, s.AddMessage(m))
ids, err := s.AttachmentsExpired()
require.Nil(t, err)
require.Equal(t, 1, len(ids))
require.Equal(t, "m4", ids[0])
}
func testSender(t *testing.T, s message.Store) {
m1 := model.NewDefaultMessage("mytopic", "mymessage")
m1.Sender = netip.MustParseAddr("1.2.3.4")
require.Nil(t, s.AddMessage(m1))
m2 := model.NewDefaultMessage("mytopic", "mymessage without sender")
require.Nil(t, s.AddMessage(m2))
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 2, len(messages))
require.Equal(t, messages[0].Sender, netip.MustParseAddr("1.2.3.4"))
require.Equal(t, messages[1].Sender, netip.Addr{})
}
func testDeleteScheduledBySequenceID(t *testing.T, s message.Store) {
// Create a scheduled (unpublished) message
scheduledMsg := model.NewDefaultMessage("mytopic", "scheduled message")
scheduledMsg.ID = "scheduled1"
scheduledMsg.SequenceID = "seq123"
scheduledMsg.Time = time.Now().Add(time.Hour).Unix() // Future time makes it scheduled
require.Nil(t, s.AddMessage(scheduledMsg))
// Create a published message with different sequence ID
publishedMsg := model.NewDefaultMessage("mytopic", "published message")
publishedMsg.ID = "published1"
publishedMsg.SequenceID = "seq456"
publishedMsg.Time = time.Now().Add(-time.Hour).Unix() // Past time makes it published
require.Nil(t, s.AddMessage(publishedMsg))
// Create a scheduled message in a different topic
otherTopicMsg := model.NewDefaultMessage("othertopic", "other scheduled")
otherTopicMsg.ID = "other1"
otherTopicMsg.SequenceID = "seq123" // Same sequence ID as scheduledMsg
otherTopicMsg.Time = time.Now().Add(time.Hour).Unix()
require.Nil(t, s.AddMessage(otherTopicMsg))
// Verify all messages exist (including scheduled)
messages, err := s.Messages("mytopic", model.SinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 2, len(messages))
messages, err = s.Messages("othertopic", model.SinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
// Delete scheduled message by sequence ID and verify returned IDs
deletedIDs, err := s.DeleteScheduledBySequenceID("mytopic", "seq123")
require.Nil(t, err)
require.Equal(t, 1, len(deletedIDs))
require.Equal(t, "scheduled1", deletedIDs[0])
// Verify scheduled message is deleted
messages, err = s.Messages("mytopic", model.SinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "published message", messages[0].Message)
// Verify other topic's message still exists (topic-scoped deletion)
messages, err = s.Messages("othertopic", model.SinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "other scheduled", messages[0].Message)
// Deleting non-existent sequence ID should return empty list
deletedIDs, err = s.DeleteScheduledBySequenceID("mytopic", "nonexistent")
require.Nil(t, err)
require.Empty(t, deletedIDs)
// Deleting published message should not affect it (only deletes unpublished)
deletedIDs, err = s.DeleteScheduledBySequenceID("mytopic", "seq456")
require.Nil(t, err)
require.Empty(t, deletedIDs)
messages, err = s.Messages("mytopic", model.SinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "published message", messages[0].Message)
}
func testMessageByID(t *testing.T, s message.Store) {
// Add a message
m := model.NewDefaultMessage("mytopic", "some message")
m.Title = "some title"
m.Priority = 4
m.Tags = []string{"tag1", "tag2"}
require.Nil(t, s.AddMessage(m))
// Retrieve by ID
retrieved, err := s.Message(m.ID)
require.Nil(t, err)
require.Equal(t, m.ID, retrieved.ID)
require.Equal(t, "mytopic", retrieved.Topic)
require.Equal(t, "some message", retrieved.Message)
require.Equal(t, "some title", retrieved.Title)
require.Equal(t, 4, retrieved.Priority)
require.Equal(t, []string{"tag1", "tag2"}, retrieved.Tags)
// Non-existent ID returns ErrMessageNotFound
_, err = s.Message("doesnotexist")
require.Equal(t, model.ErrMessageNotFound, err)
}
func testMarkPublished(t *testing.T, s message.Store) {
// Add a scheduled message (future time → unpublished)
m := model.NewDefaultMessage("mytopic", "scheduled message")
m.Time = time.Now().Add(time.Hour).Unix()
require.Nil(t, s.AddMessage(m))
// Verify it does not appear in non-scheduled queries
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 0, len(messages))
// Verify it does appear in scheduled queries
messages, err = s.Messages("mytopic", model.SinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
// Mark as published
require.Nil(t, s.MarkPublished(m))
// Now it should appear in non-scheduled queries too
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "scheduled message", messages[0].Message)
}
func testExpireMessages(t *testing.T, s message.Store) {
// Add messages to two topics
m1 := model.NewDefaultMessage("topic1", "message 1")
m1.Expires = time.Now().Add(time.Hour).Unix()
m2 := model.NewDefaultMessage("topic1", "message 2")
m2.Expires = time.Now().Add(time.Hour).Unix()
m3 := model.NewDefaultMessage("topic2", "message 3")
m3.Expires = time.Now().Add(time.Hour).Unix()
require.Nil(t, s.AddMessage(m1))
require.Nil(t, s.AddMessage(m2))
require.Nil(t, s.AddMessage(m3))
// Verify all messages exist
messages, err := s.Messages("topic1", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 2, len(messages))
messages, err = s.Messages("topic2", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
// Expire topic1 messages
require.Nil(t, s.ExpireMessages("topic1"))
// topic1 messages should now be expired (expires set to past)
expiredIDs, err := s.MessagesExpired()
require.Nil(t, err)
require.Equal(t, 2, len(expiredIDs))
sort.Strings(expiredIDs)
expectedIDs := []string{m1.ID, m2.ID}
sort.Strings(expectedIDs)
require.Equal(t, expectedIDs, expiredIDs)
// topic2 should be unaffected
messages, err = s.Messages("topic2", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "message 3", messages[0].Message)
}
func testMarkAttachmentsDeleted(t *testing.T, s message.Store) {
// Add a message with an expired attachment (file needs cleanup)
m1 := model.NewDefaultMessage("mytopic", "old file")
m1.ID = "msg1"
m1.SequenceID = "msg1"
m1.Expires = time.Now().Add(time.Hour).Unix()
m1.Attachment = &model.Attachment{
Name: "old.pdf",
Type: "application/pdf",
Size: 50000,
Expires: time.Now().Add(-time.Hour).Unix(), // Expired
URL: "https://ntfy.sh/file/old.pdf",
}
require.Nil(t, s.AddMessage(m1))
// Add a message with another expired attachment
m2 := model.NewDefaultMessage("mytopic", "another old file")
m2.ID = "msg2"
m2.SequenceID = "msg2"
m2.Expires = time.Now().Add(time.Hour).Unix()
m2.Attachment = &model.Attachment{
Name: "another.pdf",
Type: "application/pdf",
Size: 30000,
Expires: time.Now().Add(-time.Hour).Unix(), // Expired
URL: "https://ntfy.sh/file/another.pdf",
}
require.Nil(t, s.AddMessage(m2))
// Both should show as expired attachments needing cleanup
ids, err := s.AttachmentsExpired()
require.Nil(t, err)
require.Equal(t, 2, len(ids))
// Mark msg1's attachment as deleted (file cleaned up)
require.Nil(t, s.MarkAttachmentsDeleted("msg1"))
// Now only msg2 should show as needing cleanup
ids, err = s.AttachmentsExpired()
require.Nil(t, err)
require.Equal(t, 1, len(ids))
require.Equal(t, "msg2", ids[0])
// Mark msg2 too
require.Nil(t, s.MarkAttachmentsDeleted("msg2"))
// No more expired attachments to clean up
ids, err = s.AttachmentsExpired()
require.Nil(t, err)
require.Equal(t, 0, len(ids))
// Messages themselves still exist
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 2, len(messages))
}
func testStats(t *testing.T, s message.Store) {
// Initial stats should be zero
messages, err := s.Stats()
require.Nil(t, err)
require.Equal(t, int64(0), messages)
// Update stats
require.Nil(t, s.UpdateStats(42))
messages, err = s.Stats()
require.Nil(t, err)
require.Equal(t, int64(42), messages)
// Update again (overwrites)
require.Nil(t, s.UpdateStats(100))
messages, err = s.Stats()
require.Nil(t, err)
require.Equal(t, int64(100), messages)
}
func testAddMessages(t *testing.T, s message.Store) {
// Batch add multiple messages
msgs := []*model.Message{
model.NewDefaultMessage("mytopic", "batch 1"),
model.NewDefaultMessage("mytopic", "batch 2"),
model.NewDefaultMessage("othertopic", "batch 3"),
}
require.Nil(t, s.AddMessages(msgs))
// Verify all were inserted
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 2, len(messages))
messages, err = s.Messages("othertopic", model.SinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "batch 3", messages[0].Message)
// Empty batch should succeed
require.Nil(t, s.AddMessages([]*model.Message{}))
// Batch with invalid event type should fail
badMsgs := []*model.Message{
model.NewKeepaliveMessage("mytopic"),
}
require.NotNil(t, s.AddMessages(badMsgs))
}
func testMessagesDue(t *testing.T, s message.Store) {
// Add a message scheduled in the past (i.e. it's due now)
m1 := model.NewDefaultMessage("mytopic", "due message")
m1.Time = time.Now().Add(-time.Second).Unix()
// Set expires in the future so it doesn't get pruned
m1.Expires = time.Now().Add(time.Hour).Unix()
require.Nil(t, s.AddMessage(m1))
// Add a message scheduled in the future (not due)
m2 := model.NewDefaultMessage("mytopic", "future message")
m2.Time = time.Now().Add(time.Hour).Unix()
require.Nil(t, s.AddMessage(m2))
// Mark m1 as published so it won't be "due"
// (MessagesDue returns unpublished messages whose time <= now)
// m1 is auto-published (time <= now), so it should not be due
// m2 is unpublished (time in future), not due yet
due, err := s.MessagesDue()
require.Nil(t, err)
require.Equal(t, 0, len(due))
// Add a message that was explicitly scheduled in the past but time has "arrived"
// We need to manipulate the database to create a truly "due" message:
// a message with published=false and time <= now
m3 := model.NewDefaultMessage("mytopic", "truly due message")
m3.Time = time.Now().Add(2 * time.Second).Unix() // 2 seconds from now
require.Nil(t, s.AddMessage(m3))
// Not due yet
due, err = s.MessagesDue()
require.Nil(t, err)
require.Equal(t, 0, len(due))
// Wait for it to become due
time.Sleep(3 * time.Second)
due, err = s.MessagesDue()
require.Nil(t, err)
require.Equal(t, 1, len(due))
require.Equal(t, "truly due message", due[0].Message)
}
func testMessageFieldRoundTrip(t *testing.T, s message.Store) {
// Create a message with all fields populated
m := model.NewDefaultMessage("mytopic", "hello world")
m.SequenceID = "custom_seq_id"
m.Title = "A Title"
m.Priority = 4
m.Tags = []string{"warning", "srv01"}
m.Click = "https://example.com/click"
m.Icon = "https://example.com/icon.png"
m.Actions = []*model.Action{
{
ID: "action1",
Action: "view",
Label: "Open Site",
URL: "https://example.com",
Clear: true,
},
{
ID: "action2",
Action: "http",
Label: "Call Webhook",
URL: "https://example.com/hook",
Method: "PUT",
Headers: map[string]string{"X-Token": "secret"},
Body: `{"key":"value"}`,
},
}
m.ContentType = "text/markdown"
m.Encoding = "base64"
m.Sender = netip.MustParseAddr("9.8.7.6")
m.User = "u_TestUser123"
require.Nil(t, s.AddMessage(m))
// Retrieve and verify every field
retrieved, err := s.Message(m.ID)
require.Nil(t, err)
require.Equal(t, m.ID, retrieved.ID)
require.Equal(t, "custom_seq_id", retrieved.SequenceID)
require.Equal(t, m.Time, retrieved.Time)
require.Equal(t, m.Expires, retrieved.Expires)
require.Equal(t, model.MessageEvent, retrieved.Event)
require.Equal(t, "mytopic", retrieved.Topic)
require.Equal(t, "hello world", retrieved.Message)
require.Equal(t, "A Title", retrieved.Title)
require.Equal(t, 4, retrieved.Priority)
require.Equal(t, []string{"warning", "srv01"}, retrieved.Tags)
require.Equal(t, "https://example.com/click", retrieved.Click)
require.Equal(t, "https://example.com/icon.png", retrieved.Icon)
require.Equal(t, "text/markdown", retrieved.ContentType)
require.Equal(t, "base64", retrieved.Encoding)
require.Equal(t, netip.MustParseAddr("9.8.7.6"), retrieved.Sender)
require.Equal(t, "u_TestUser123", retrieved.User)
// Verify actions round-trip
require.Equal(t, 2, len(retrieved.Actions))
require.Equal(t, "action1", retrieved.Actions[0].ID)
require.Equal(t, "view", retrieved.Actions[0].Action)
require.Equal(t, "Open Site", retrieved.Actions[0].Label)
require.Equal(t, "https://example.com", retrieved.Actions[0].URL)
require.Equal(t, true, retrieved.Actions[0].Clear)
require.Equal(t, "action2", retrieved.Actions[1].ID)
require.Equal(t, "http", retrieved.Actions[1].Action)
require.Equal(t, "Call Webhook", retrieved.Actions[1].Label)
require.Equal(t, "https://example.com/hook", retrieved.Actions[1].URL)
require.Equal(t, "PUT", retrieved.Actions[1].Method)
require.Equal(t, "secret", retrieved.Actions[1].Headers["X-Token"])
require.Equal(t, `{"key":"value"}`, retrieved.Actions[1].Body)
}

205
model/model.go Normal file
View File

@@ -0,0 +1,205 @@
package model
import (
"errors"
"net/netip"
"time"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/util"
)
// List of possible events
const (
OpenEvent = "open"
KeepaliveEvent = "keepalive"
MessageEvent = "message"
MessageDeleteEvent = "message_delete"
MessageClearEvent = "message_clear"
PollRequestEvent = "poll_request"
)
// MessageIDLength is the length of a randomly generated message ID
const MessageIDLength = 12
// Errors for message operations
var (
ErrUnexpectedMessageType = errors.New("unexpected message type")
ErrMessageNotFound = errors.New("message not found")
)
// Message represents a message published to a topic
type Message struct {
ID string `json:"id"` // Random message ID
SequenceID string `json:"sequence_id,omitempty"` // Message sequence ID for updating message contents (omitted if same as ID)
Time int64 `json:"time"` // Unix time in seconds
Expires int64 `json:"expires,omitempty"` // Unix time in seconds (not required for open/keepalive)
Event string `json:"event"` // One of the above
Topic string `json:"topic"`
Title string `json:"title,omitempty"`
Message string `json:"message,omitempty"`
Priority int `json:"priority,omitempty"`
Tags []string `json:"tags,omitempty"`
Click string `json:"click,omitempty"`
Icon string `json:"icon,omitempty"`
Actions []*Action `json:"actions,omitempty"`
Attachment *Attachment `json:"attachment,omitempty"`
PollID string `json:"poll_id,omitempty"`
ContentType string `json:"content_type,omitempty"` // text/plain by default (if empty), or text/markdown
Encoding string `json:"encoding,omitempty"` // Empty for raw UTF-8, or "base64" for encoded bytes
Sender netip.Addr `json:"-"` // IP address of uploader, used for rate limiting
User string `json:"-"` // UserID of the uploader, used to associated attachments
}
// Context returns a log context for the message
func (m *Message) Context() log.Context {
fields := map[string]any{
"topic": m.Topic,
"message_id": m.ID,
"message_sequence_id": m.SequenceID,
"message_time": m.Time,
"message_event": m.Event,
"message_body_size": len(m.Message),
}
if m.Sender.IsValid() {
fields["message_sender"] = m.Sender.String()
}
if m.User != "" {
fields["message_user"] = m.User
}
return fields
}
// ForJSON returns a copy of the message suitable for JSON output.
// It clears the SequenceID if it equals the ID to reduce redundancy.
func (m *Message) ForJSON() *Message {
if m.SequenceID == m.ID {
clone := *m
clone.SequenceID = ""
return &clone
}
return m
}
// Attachment represents a file attachment on a message
type Attachment struct {
Name string `json:"name"`
Type string `json:"type,omitempty"`
Size int64 `json:"size,omitempty"`
Expires int64 `json:"expires,omitempty"`
URL string `json:"url"`
}
// Action represents a user-defined action on a message
type Action struct {
ID string `json:"id"`
Action string `json:"action"` // "view", "broadcast", "http", or "copy"
Label string `json:"label"` // action button label
Clear bool `json:"clear"` // clear notification after successful execution
URL string `json:"url,omitempty"` // used in "view" and "http" actions
Method string `json:"method,omitempty"` // used in "http" action, default is POST (!)
Headers map[string]string `json:"headers,omitempty"` // used in "http" action
Body string `json:"body,omitempty"` // used in "http" action
Intent string `json:"intent,omitempty"` // used in "broadcast" action
Extras map[string]string `json:"extras,omitempty"` // used in "broadcast" action
Value string `json:"value,omitempty"` // used in "copy" action
}
// NewAction creates a new action with initialized maps
func NewAction() *Action {
return &Action{
Headers: make(map[string]string),
Extras: make(map[string]string),
}
}
// NewMessage creates a new message with the current timestamp
func NewMessage(event, topic, msg string) *Message {
return &Message{
ID: util.RandomString(MessageIDLength),
Time: time.Now().Unix(),
Event: event,
Topic: topic,
Message: msg,
}
}
// NewOpenMessage is a convenience method to create an open message
func NewOpenMessage(topic string) *Message {
return NewMessage(OpenEvent, topic, "")
}
// NewKeepaliveMessage is a convenience method to create a keepalive message
func NewKeepaliveMessage(topic string) *Message {
return NewMessage(KeepaliveEvent, topic, "")
}
// NewDefaultMessage is a convenience method to create a notification message
func NewDefaultMessage(topic, msg string) *Message {
return NewMessage(MessageEvent, topic, msg)
}
// NewActionMessage creates a new action message (message_delete or message_clear)
func NewActionMessage(event, topic, sequenceID string) *Message {
m := NewMessage(event, topic, "")
m.SequenceID = sequenceID
return m
}
// ValidMessageID returns true if the given string is a valid message ID
func ValidMessageID(s string) bool {
return util.ValidRandomString(s, MessageIDLength)
}
// SinceMarker represents a point in time or message ID from which to retrieve messages
type SinceMarker struct {
time time.Time
id string
}
// NewSinceTime creates a new SinceMarker from a Unix timestamp
func NewSinceTime(timestamp int64) SinceMarker {
return SinceMarker{time.Unix(timestamp, 0), ""}
}
// NewSinceID creates a new SinceMarker from a message ID
func NewSinceID(id string) SinceMarker {
return SinceMarker{time.Unix(0, 0), id}
}
// IsAll returns true if this is the "all messages" marker
func (t SinceMarker) IsAll() bool {
return t == SinceAllMessages
}
// IsNone returns true if this is the "no messages" marker
func (t SinceMarker) IsNone() bool {
return t == SinceNoMessages
}
// IsLatest returns true if this is the "latest message" marker
func (t SinceMarker) IsLatest() bool {
return t == SinceLatestMessage
}
// IsID returns true if this marker references a specific message ID
func (t SinceMarker) IsID() bool {
return t.id != "" && t.id != "latest"
}
// Time returns the time component of the marker
func (t SinceMarker) Time() time.Time {
return t.time
}
// ID returns the message ID component of the marker
func (t SinceMarker) ID() string {
return t.id
}
// Common SinceMarker values for subscribing to messages
var (
SinceAllMessages = SinceMarker{time.Unix(0, 0), ""}
SinceNoMessages = SinceMarker{time.Unix(1, 0), ""}
SinceLatestMessage = SinceMarker{time.Unix(0, 0), "latest"}
)

View File

@@ -8,6 +8,7 @@ import (
"strings"
"unicode/utf8"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/util"
)
@@ -39,7 +40,7 @@ type actionParser struct {
// parseActions parses the actions string as described in https://ntfy.sh/docs/publish/#action-buttons.
// It supports both a JSON representation (if the string begins with "[", see parseActionsFromJSON),
// and the "simple" format, which is more human-readable, but harder to parse (see parseActionsFromSimple).
func parseActions(s string) (actions []*action, err error) {
func parseActions(s string) (actions []*model.Action, err error) {
// Parse JSON or simple format
s = strings.TrimSpace(s)
if strings.HasPrefix(s, "[") {
@@ -80,8 +81,8 @@ func parseActions(s string) (actions []*action, err error) {
}
// parseActionsFromJSON converts a JSON array into an array of actions
func parseActionsFromJSON(s string) ([]*action, error) {
actions := make([]*action, 0)
func parseActionsFromJSON(s string) ([]*model.Action, error) {
actions := make([]*model.Action, 0)
if err := json.Unmarshal([]byte(s), &actions); err != nil {
return nil, fmt.Errorf("JSON error: %w", err)
}
@@ -107,7 +108,7 @@ func parseActionsFromJSON(s string) ([]*action, error) {
// https://github.com/adampresley/sample-ini-parser/blob/master/services/lexer/lexer/Lexer.go
// https://github.com/benbjohnson/sql-parser/blob/master/scanner.go
// https://blog.gopheracademy.com/advent-2014/parsers-lexers/
func parseActionsFromSimple(s string) ([]*action, error) {
func parseActionsFromSimple(s string) ([]*model.Action, error) {
if !utf8.ValidString(s) {
return nil, errors.New("invalid utf-8 string")
}
@@ -119,8 +120,8 @@ func parseActionsFromSimple(s string) ([]*action, error) {
}
// Parse loops trough parseAction() until the end of the string is reached
func (p *actionParser) Parse() ([]*action, error) {
actions := make([]*action, 0)
func (p *actionParser) Parse() ([]*model.Action, error) {
actions := make([]*model.Action, 0)
for !p.eof() {
a, err := p.parseAction()
if err != nil {
@@ -134,7 +135,7 @@ func (p *actionParser) Parse() ([]*action, error) {
// parseAction parses the individual sections of an action using parseSection into key/value pairs,
// and then uses populateAction to interpret the keys/values. The function terminates
// when EOF or ";" is reached.
func (p *actionParser) parseAction() (*action, error) {
func (p *actionParser) parseAction() (*model.Action, error) {
a := newAction()
section := 0
for {
@@ -155,7 +156,7 @@ func (p *actionParser) parseAction() (*action, error) {
// populateAction is the "business logic" of the parser. It applies the key/value
// pair to the action instance.
func populateAction(newAction *action, section int, key, value string) error {
func populateAction(newAction *model.Action, section int, key, value string) error {
// Auto-expand keys based on their index
if key == "" && section == 0 {
key = "action"

View File

@@ -10,6 +10,7 @@ import (
"github.com/emersion/go-smtp"
"github.com/gorilla/websocket"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/util"
)
@@ -55,12 +56,12 @@ func logvr(v *visitor, r *http.Request) *log.Event {
}
// logvrm creates a new log event with HTTP request, visitor fields and message fields
func logvrm(v *visitor, r *http.Request, m *message) *log.Event {
func logvrm(v *visitor, r *http.Request, m *model.Message) *log.Event {
return logvr(v, r).With(m)
}
// logvrm creates a new log event with visitor fields and message fields
func logvm(v *visitor, m *message) *log.Event {
func logvm(v *visitor, m *model.Message) *log.Event {
return logv(v).With(m)
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,825 +0,0 @@
package server
import (
"database/sql"
"fmt"
"github.com/stretchr/testify/assert"
"net/netip"
"path/filepath"
"sync"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestSqliteCache_Messages(t *testing.T) {
testCacheMessages(t, newSqliteTestCache(t))
}
func TestMemCache_Messages(t *testing.T) {
testCacheMessages(t, newMemTestCache(t))
}
func testCacheMessages(t *testing.T, c *messageCache) {
m1 := newDefaultMessage("mytopic", "my message")
m1.Time = 1
m2 := newDefaultMessage("mytopic", "my other message")
m2.Time = 2
require.Nil(t, c.AddMessage(m1))
require.Nil(t, c.AddMessage(newDefaultMessage("example", "my example message")))
require.Nil(t, c.AddMessage(m2))
// Adding invalid
require.Equal(t, errUnexpectedMessageType, c.AddMessage(newKeepaliveMessage("mytopic"))) // These should not be added!
require.Equal(t, errUnexpectedMessageType, c.AddMessage(newOpenMessage("example"))) // These should not be added!
// mytopic: count
counts, err := c.MessageCounts()
require.Nil(t, err)
require.Equal(t, 2, counts["mytopic"])
// mytopic: since all
messages, _ := c.Messages("mytopic", sinceAllMessages, false)
require.Equal(t, 2, len(messages))
require.Equal(t, "my message", messages[0].Message)
require.Equal(t, "mytopic", messages[0].Topic)
require.Equal(t, messageEvent, messages[0].Event)
require.Equal(t, "", messages[0].Title)
require.Equal(t, 0, messages[0].Priority)
require.Nil(t, messages[0].Tags)
require.Equal(t, "my other message", messages[1].Message)
// mytopic: since none
messages, _ = c.Messages("mytopic", sinceNoMessages, false)
require.Empty(t, messages)
// mytopic: since m1 (by ID)
messages, _ = c.Messages("mytopic", newSinceID(m1.ID), false)
require.Equal(t, 1, len(messages))
require.Equal(t, m2.ID, messages[0].ID)
require.Equal(t, "my other message", messages[0].Message)
require.Equal(t, "mytopic", messages[0].Topic)
// mytopic: since 2
messages, _ = c.Messages("mytopic", newSinceTime(2), false)
require.Equal(t, 1, len(messages))
require.Equal(t, "my other message", messages[0].Message)
// mytopic: latest
messages, _ = c.Messages("mytopic", sinceLatestMessage, false)
require.Equal(t, 1, len(messages))
require.Equal(t, "my other message", messages[0].Message)
// example: count
counts, err = c.MessageCounts()
require.Nil(t, err)
require.Equal(t, 1, counts["example"])
// example: since all
messages, _ = c.Messages("example", sinceAllMessages, false)
require.Equal(t, "my example message", messages[0].Message)
// non-existing: count
counts, err = c.MessageCounts()
require.Nil(t, err)
require.Equal(t, 0, counts["doesnotexist"])
// non-existing: since all
messages, _ = c.Messages("doesnotexist", sinceAllMessages, false)
require.Empty(t, messages)
}
func TestSqliteCache_MessagesLock(t *testing.T) {
testCacheMessagesLock(t, newSqliteTestCache(t))
}
func TestMemCache_MessagesLock(t *testing.T) {
testCacheMessagesLock(t, newMemTestCache(t))
}
func testCacheMessagesLock(t *testing.T, c *messageCache) {
var wg sync.WaitGroup
for i := 0; i < 5000; i++ {
wg.Add(1)
go func() {
assert.Nil(t, c.AddMessage(newDefaultMessage("mytopic", "test message")))
wg.Done()
}()
}
wg.Wait()
}
func TestSqliteCache_MessagesScheduled(t *testing.T) {
testCacheMessagesScheduled(t, newSqliteTestCache(t))
}
func TestMemCache_MessagesScheduled(t *testing.T) {
testCacheMessagesScheduled(t, newMemTestCache(t))
}
func testCacheMessagesScheduled(t *testing.T, c *messageCache) {
m1 := newDefaultMessage("mytopic", "message 1")
m2 := newDefaultMessage("mytopic", "message 2")
m2.Time = time.Now().Add(time.Hour).Unix()
m3 := newDefaultMessage("mytopic", "message 3")
m3.Time = time.Now().Add(time.Minute).Unix() // earlier than m2!
m4 := newDefaultMessage("mytopic2", "message 4")
m4.Time = time.Now().Add(time.Minute).Unix()
require.Nil(t, c.AddMessage(m1))
require.Nil(t, c.AddMessage(m2))
require.Nil(t, c.AddMessage(m3))
messages, _ := c.Messages("mytopic", sinceAllMessages, false) // exclude scheduled
require.Equal(t, 1, len(messages))
require.Equal(t, "message 1", messages[0].Message)
messages, _ = c.Messages("mytopic", sinceAllMessages, true) // include scheduled
require.Equal(t, 3, len(messages))
require.Equal(t, "message 1", messages[0].Message)
require.Equal(t, "message 3", messages[1].Message) // Order!
require.Equal(t, "message 2", messages[2].Message)
messages, _ = c.MessagesDue()
require.Empty(t, messages)
}
func TestSqliteCache_Topics(t *testing.T) {
testCacheTopics(t, newSqliteTestCache(t))
}
func TestMemCache_Topics(t *testing.T) {
testCacheTopics(t, newMemTestCache(t))
}
func testCacheTopics(t *testing.T, c *messageCache) {
require.Nil(t, c.AddMessage(newDefaultMessage("topic1", "my example message")))
require.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 1")))
require.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 2")))
require.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 3")))
topics, err := c.Topics()
if err != nil {
t.Fatal(err)
}
require.Equal(t, 2, len(topics))
require.Equal(t, "topic1", topics["topic1"].ID)
require.Equal(t, "topic2", topics["topic2"].ID)
}
func TestSqliteCache_MessagesTagsPrioAndTitle(t *testing.T) {
testCacheMessagesTagsPrioAndTitle(t, newSqliteTestCache(t))
}
func TestMemCache_MessagesTagsPrioAndTitle(t *testing.T) {
testCacheMessagesTagsPrioAndTitle(t, newMemTestCache(t))
}
func testCacheMessagesTagsPrioAndTitle(t *testing.T, c *messageCache) {
m := newDefaultMessage("mytopic", "some message")
m.Tags = []string{"tag1", "tag2"}
m.Priority = 5
m.Title = "some title"
require.Nil(t, c.AddMessage(m))
messages, _ := c.Messages("mytopic", sinceAllMessages, false)
require.Equal(t, []string{"tag1", "tag2"}, messages[0].Tags)
require.Equal(t, 5, messages[0].Priority)
require.Equal(t, "some title", messages[0].Title)
}
func TestSqliteCache_MessagesSinceID(t *testing.T) {
testCacheMessagesSinceID(t, newSqliteTestCache(t))
}
func TestMemCache_MessagesSinceID(t *testing.T) {
testCacheMessagesSinceID(t, newMemTestCache(t))
}
func testCacheMessagesSinceID(t *testing.T, c *messageCache) {
m1 := newDefaultMessage("mytopic", "message 1")
m1.Time = 100
m2 := newDefaultMessage("mytopic", "message 2")
m2.Time = 200
m3 := newDefaultMessage("mytopic", "message 3")
m3.Time = time.Now().Add(time.Hour).Unix() // Scheduled, in the future, later than m7 and m5
m4 := newDefaultMessage("mytopic", "message 4")
m4.Time = 400
m5 := newDefaultMessage("mytopic", "message 5")
m5.Time = time.Now().Add(time.Minute).Unix() // Scheduled, in the future, later than m7
m6 := newDefaultMessage("mytopic", "message 6")
m6.Time = 600
m7 := newDefaultMessage("mytopic", "message 7")
m7.Time = 700
require.Nil(t, c.AddMessage(m1))
require.Nil(t, c.AddMessage(m2))
require.Nil(t, c.AddMessage(m3))
require.Nil(t, c.AddMessage(m4))
require.Nil(t, c.AddMessage(m5))
require.Nil(t, c.AddMessage(m6))
require.Nil(t, c.AddMessage(m7))
// Case 1: Since ID exists, exclude scheduled
messages, _ := c.Messages("mytopic", newSinceID(m2.ID), false)
require.Equal(t, 3, len(messages))
require.Equal(t, "message 4", messages[0].Message)
require.Equal(t, "message 6", messages[1].Message) // Not scheduled m3/m5!
require.Equal(t, "message 7", messages[2].Message)
// Case 2: Since ID exists, include scheduled
messages, _ = c.Messages("mytopic", newSinceID(m2.ID), true)
require.Equal(t, 5, len(messages))
require.Equal(t, "message 4", messages[0].Message)
require.Equal(t, "message 6", messages[1].Message)
require.Equal(t, "message 7", messages[2].Message)
require.Equal(t, "message 5", messages[3].Message) // Order!
require.Equal(t, "message 3", messages[4].Message) // Order!
// Case 3: Since ID does not exist (-> Return all messages), include scheduled
messages, _ = c.Messages("mytopic", newSinceID("doesntexist"), true)
require.Equal(t, 7, len(messages))
require.Equal(t, "message 1", messages[0].Message)
require.Equal(t, "message 2", messages[1].Message)
require.Equal(t, "message 4", messages[2].Message)
require.Equal(t, "message 6", messages[3].Message)
require.Equal(t, "message 7", messages[4].Message)
require.Equal(t, "message 5", messages[5].Message) // Order!
require.Equal(t, "message 3", messages[6].Message) // Order!
// Case 4: Since ID exists and is last message (-> Return no messages), exclude scheduled
messages, _ = c.Messages("mytopic", newSinceID(m7.ID), false)
require.Equal(t, 0, len(messages))
// Case 5: Since ID exists and is last message (-> Return no messages), include scheduled
messages, _ = c.Messages("mytopic", newSinceID(m7.ID), true)
require.Equal(t, 2, len(messages))
require.Equal(t, "message 5", messages[0].Message)
require.Equal(t, "message 3", messages[1].Message)
}
func TestSqliteCache_Prune(t *testing.T) {
testCachePrune(t, newSqliteTestCache(t))
}
func TestMemCache_Prune(t *testing.T) {
testCachePrune(t, newMemTestCache(t))
}
func testCachePrune(t *testing.T, c *messageCache) {
now := time.Now().Unix()
m1 := newDefaultMessage("mytopic", "my message")
m1.Time = now - 10
m1.Expires = now - 5
m2 := newDefaultMessage("mytopic", "my other message")
m2.Time = now - 5
m2.Expires = now + 5 // In the future
m3 := newDefaultMessage("another_topic", "and another one")
m3.Time = now - 12
m3.Expires = now - 2
require.Nil(t, c.AddMessage(m1))
require.Nil(t, c.AddMessage(m2))
require.Nil(t, c.AddMessage(m3))
counts, err := c.MessageCounts()
require.Nil(t, err)
require.Equal(t, 2, counts["mytopic"])
require.Equal(t, 1, counts["another_topic"])
expiredMessageIDs, err := c.MessagesExpired()
require.Nil(t, err)
require.Nil(t, c.DeleteMessages(expiredMessageIDs...))
counts, err = c.MessageCounts()
require.Nil(t, err)
require.Equal(t, 1, counts["mytopic"])
require.Equal(t, 0, counts["another_topic"])
messages, err := c.Messages("mytopic", sinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "my other message", messages[0].Message)
}
func TestSqliteCache_Attachments(t *testing.T) {
testCacheAttachments(t, newSqliteTestCache(t))
}
func TestMemCache_Attachments(t *testing.T) {
testCacheAttachments(t, newMemTestCache(t))
}
func testCacheAttachments(t *testing.T, c *messageCache) {
expires1 := time.Now().Add(-4 * time.Hour).Unix() // Expired
m := newDefaultMessage("mytopic", "flower for you")
m.ID = "m1"
m.SequenceID = "m1"
m.Sender = netip.MustParseAddr("1.2.3.4")
m.Attachment = &attachment{
Name: "flower.jpg",
Type: "image/jpeg",
Size: 5000,
Expires: expires1,
URL: "https://ntfy.sh/file/AbDeFgJhal.jpg",
}
require.Nil(t, c.AddMessage(m))
expires2 := time.Now().Add(2 * time.Hour).Unix() // Future
m = newDefaultMessage("mytopic", "sending you a car")
m.ID = "m2"
m.SequenceID = "m2"
m.Sender = netip.MustParseAddr("1.2.3.4")
m.Attachment = &attachment{
Name: "car.jpg",
Type: "image/jpeg",
Size: 10000,
Expires: expires2,
URL: "https://ntfy.sh/file/aCaRURL.jpg",
}
require.Nil(t, c.AddMessage(m))
expires3 := time.Now().Add(1 * time.Hour).Unix() // Future
m = newDefaultMessage("another-topic", "sending you another car")
m.ID = "m3"
m.SequenceID = "m3"
m.User = "u_BAsbaAa"
m.Sender = netip.MustParseAddr("5.6.7.8")
m.Attachment = &attachment{
Name: "another-car.jpg",
Type: "image/jpeg",
Size: 20000,
Expires: expires3,
URL: "https://ntfy.sh/file/zakaDHFW.jpg",
}
require.Nil(t, c.AddMessage(m))
messages, err := c.Messages("mytopic", sinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 2, len(messages))
require.Equal(t, "flower for you", messages[0].Message)
require.Equal(t, "flower.jpg", messages[0].Attachment.Name)
require.Equal(t, "image/jpeg", messages[0].Attachment.Type)
require.Equal(t, int64(5000), messages[0].Attachment.Size)
require.Equal(t, expires1, messages[0].Attachment.Expires)
require.Equal(t, "https://ntfy.sh/file/AbDeFgJhal.jpg", messages[0].Attachment.URL)
require.Equal(t, "1.2.3.4", messages[0].Sender.String())
require.Equal(t, "sending you a car", messages[1].Message)
require.Equal(t, "car.jpg", messages[1].Attachment.Name)
require.Equal(t, "image/jpeg", messages[1].Attachment.Type)
require.Equal(t, int64(10000), messages[1].Attachment.Size)
require.Equal(t, expires2, messages[1].Attachment.Expires)
require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL)
require.Equal(t, "1.2.3.4", messages[1].Sender.String())
size, err := c.AttachmentBytesUsedBySender("1.2.3.4")
require.Nil(t, err)
require.Equal(t, int64(10000), size)
size, err = c.AttachmentBytesUsedBySender("5.6.7.8")
require.Nil(t, err)
require.Equal(t, int64(0), size) // Accounted to the user, not the IP!
size, err = c.AttachmentBytesUsedByUser("u_BAsbaAa")
require.Nil(t, err)
require.Equal(t, int64(20000), size)
}
func TestSqliteCache_Attachments_Expired(t *testing.T) {
testCacheAttachmentsExpired(t, newSqliteTestCache(t))
}
func TestMemCache_Attachments_Expired(t *testing.T) {
testCacheAttachmentsExpired(t, newMemTestCache(t))
}
func testCacheAttachmentsExpired(t *testing.T, c *messageCache) {
m := newDefaultMessage("mytopic", "flower for you")
m.ID = "m1"
m.SequenceID = "m1"
m.Expires = time.Now().Add(time.Hour).Unix()
require.Nil(t, c.AddMessage(m))
m = newDefaultMessage("mytopic", "message with attachment")
m.ID = "m2"
m.SequenceID = "m2"
m.Expires = time.Now().Add(2 * time.Hour).Unix()
m.Attachment = &attachment{
Name: "car.jpg",
Type: "image/jpeg",
Size: 10000,
Expires: time.Now().Add(2 * time.Hour).Unix(),
URL: "https://ntfy.sh/file/aCaRURL.jpg",
}
require.Nil(t, c.AddMessage(m))
m = newDefaultMessage("mytopic", "message with external attachment")
m.ID = "m3"
m.SequenceID = "m3"
m.Expires = time.Now().Add(2 * time.Hour).Unix()
m.Attachment = &attachment{
Name: "car.jpg",
Type: "image/jpeg",
Expires: 0, // Unknown!
URL: "https://somedomain.com/car.jpg",
}
require.Nil(t, c.AddMessage(m))
m = newDefaultMessage("mytopic2", "message with expired attachment")
m.ID = "m4"
m.SequenceID = "m4"
m.Expires = time.Now().Add(2 * time.Hour).Unix()
m.Attachment = &attachment{
Name: "expired-car.jpg",
Type: "image/jpeg",
Size: 20000,
Expires: time.Now().Add(-1 * time.Hour).Unix(),
URL: "https://ntfy.sh/file/aCaRURL.jpg",
}
require.Nil(t, c.AddMessage(m))
ids, err := c.AttachmentsExpired()
require.Nil(t, err)
require.Equal(t, 1, len(ids))
require.Equal(t, "m4", ids[0])
}
func TestSqliteCache_Migration_From0(t *testing.T) {
filename := newSqliteTestCacheFile(t)
db, err := sql.Open("sqlite3", filename)
require.Nil(t, err)
// Create "version 0" schema
_, err = db.Exec(`
BEGIN;
CREATE TABLE IF NOT EXISTS messages (
id VARCHAR(20) PRIMARY KEY,
time INT NOT NULL,
topic VARCHAR(64) NOT NULL,
message VARCHAR(1024) NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
COMMIT;
`)
require.Nil(t, err)
// Insert a bunch of messages
for i := 0; i < 10; i++ {
_, err = db.Exec(`INSERT INTO messages (id, time, topic, message) VALUES (?, ?, ?, ?)`,
fmt.Sprintf("abcd%d", i), time.Now().Unix(), "mytopic", fmt.Sprintf("some message %d", i))
require.Nil(t, err)
}
require.Nil(t, db.Close())
// Create cache to trigger migration
c := newSqliteTestCacheFromFile(t, filename, "")
checkSchemaVersion(t, c.db)
messages, err := c.Messages("mytopic", sinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 10, len(messages))
require.Equal(t, "some message 5", messages[5].Message)
require.Equal(t, "", messages[5].Title)
require.Nil(t, messages[5].Tags)
require.Equal(t, 0, messages[5].Priority)
}
func TestSqliteCache_Migration_From1(t *testing.T) {
filename := newSqliteTestCacheFile(t)
db, err := sql.Open("sqlite3", filename)
require.Nil(t, err)
// Create "version 1" schema
_, err = db.Exec(`
CREATE TABLE IF NOT EXISTS messages (
id VARCHAR(20) PRIMARY KEY,
time INT NOT NULL,
topic VARCHAR(64) NOT NULL,
message VARCHAR(512) NOT NULL,
title VARCHAR(256) NOT NULL,
priority INT NOT NULL,
tags VARCHAR(256) NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
INSERT INTO schemaVersion (id, version) VALUES (1, 1);
`)
require.Nil(t, err)
// Insert a bunch of messages
for i := 0; i < 10; i++ {
_, err = db.Exec(`INSERT INTO messages (id, time, topic, message, title, priority, tags) VALUES (?, ?, ?, ?, ?, ?, ?)`,
fmt.Sprintf("abcd%d", i), time.Now().Unix(), "mytopic", fmt.Sprintf("some message %d", i), "", 0, "")
require.Nil(t, err)
}
require.Nil(t, db.Close())
// Create cache to trigger migration
c := newSqliteTestCacheFromFile(t, filename, "")
checkSchemaVersion(t, c.db)
// Add delayed message
delayedMessage := newDefaultMessage("mytopic", "some delayed message")
delayedMessage.Time = time.Now().Add(time.Minute).Unix()
require.Nil(t, c.AddMessage(delayedMessage))
// 10, not 11!
messages, err := c.Messages("mytopic", sinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 10, len(messages))
// 11!
messages, err = c.Messages("mytopic", sinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 11, len(messages))
// Check that index "idx_topic" exists
rows, err := c.db.Query(`SELECT name FROM sqlite_master WHERE type='index' AND name='idx_topic'`)
require.Nil(t, err)
require.True(t, rows.Next())
var indexName string
require.Nil(t, rows.Scan(&indexName))
require.Equal(t, "idx_topic", indexName)
}
func TestSqliteCache_Migration_From9(t *testing.T) {
// This primarily tests the awkward migration that introduces the "expires" column.
// The migration logic has to update the column, using the existing "cache-duration" value.
filename := newSqliteTestCacheFile(t)
db, err := sql.Open("sqlite3", filename)
require.Nil(t, err)
// Create "version 8" schema
_, err = db.Exec(`
BEGIN;
CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
mid TEXT NOT NULL,
time INT NOT NULL,
topic TEXT NOT NULL,
message TEXT NOT NULL,
title TEXT NOT NULL,
priority INT NOT NULL,
tags TEXT NOT NULL,
click TEXT NOT NULL,
icon TEXT NOT NULL,
actions TEXT NOT NULL,
attachment_name TEXT NOT NULL,
attachment_type TEXT NOT NULL,
attachment_size INT NOT NULL,
attachment_expires INT NOT NULL,
attachment_url TEXT NOT NULL,
sender TEXT NOT NULL,
encoding TEXT NOT NULL,
published INT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_mid ON messages (mid);
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
INSERT INTO schemaVersion (id, version) VALUES (1, 9);
COMMIT;
`)
require.Nil(t, err)
// Insert a bunch of messages
insertQuery := `
INSERT INTO messages (mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding, published)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
for i := 0; i < 10; i++ {
_, err = db.Exec(
insertQuery,
fmt.Sprintf("abcd%d", i),
time.Now().Unix(),
"mytopic",
fmt.Sprintf("some message %d", i),
"", // title
0, // priority
"", // tags
"", // click
"", // icon
"", // actions
"", // attachment_name
"", // attachment_type
0, // attachment_size
0, // attachment_type
"", // attachment_url
"9.9.9.9", // sender
"", // encoding
1, // published
)
require.Nil(t, err)
}
// Create cache to trigger migration
cacheDuration := 17 * time.Hour
c, err := newSqliteCache(filename, "", cacheDuration, 0, 0, false)
require.Nil(t, err)
checkSchemaVersion(t, c.db)
// Check version
rows, err := db.Query(`SELECT version FROM main.schemaVersion WHERE id = 1`)
require.Nil(t, err)
require.True(t, rows.Next())
var version int
require.Nil(t, rows.Scan(&version))
require.Equal(t, currentSchemaVersion, version)
messages, err := c.Messages("mytopic", sinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 10, len(messages))
for _, m := range messages {
require.True(t, m.Expires > time.Now().Add(cacheDuration-5*time.Second).Unix())
require.True(t, m.Expires < time.Now().Add(cacheDuration+5*time.Second).Unix())
}
}
func TestSqliteCache_StartupQueries_WAL(t *testing.T) {
filename := newSqliteTestCacheFile(t)
startupQueries := `pragma journal_mode = WAL;
pragma synchronous = normal;
pragma temp_store = memory;`
db, err := newSqliteCache(filename, startupQueries, time.Hour, 0, 0, false)
require.Nil(t, err)
require.Nil(t, db.AddMessage(newDefaultMessage("mytopic", "some message")))
require.FileExists(t, filename)
require.FileExists(t, filename+"-wal")
require.FileExists(t, filename+"-shm")
}
func TestSqliteCache_StartupQueries_None(t *testing.T) {
filename := newSqliteTestCacheFile(t)
startupQueries := ""
db, err := newSqliteCache(filename, startupQueries, time.Hour, 0, 0, false)
require.Nil(t, err)
require.Nil(t, db.AddMessage(newDefaultMessage("mytopic", "some message")))
require.FileExists(t, filename)
require.NoFileExists(t, filename+"-wal")
require.NoFileExists(t, filename+"-shm")
}
func TestSqliteCache_StartupQueries_Fail(t *testing.T) {
filename := newSqliteTestCacheFile(t)
startupQueries := `xx error`
_, err := newSqliteCache(filename, startupQueries, time.Hour, 0, 0, false)
require.Error(t, err)
}
func TestSqliteCache_Sender(t *testing.T) {
testSender(t, newSqliteTestCache(t))
}
func TestMemCache_Sender(t *testing.T) {
testSender(t, newMemTestCache(t))
}
func testSender(t *testing.T, c *messageCache) {
m1 := newDefaultMessage("mytopic", "mymessage")
m1.Sender = netip.MustParseAddr("1.2.3.4")
require.Nil(t, c.AddMessage(m1))
m2 := newDefaultMessage("mytopic", "mymessage without sender")
require.Nil(t, c.AddMessage(m2))
messages, err := c.Messages("mytopic", sinceAllMessages, false)
require.Nil(t, err)
require.Equal(t, 2, len(messages))
require.Equal(t, messages[0].Sender, netip.MustParseAddr("1.2.3.4"))
require.Equal(t, messages[1].Sender, netip.Addr{})
}
func TestSqliteCache_DeleteScheduledBySequenceID(t *testing.T) {
testDeleteScheduledBySequenceID(t, newSqliteTestCache(t))
}
func TestMemCache_DeleteScheduledBySequenceID(t *testing.T) {
testDeleteScheduledBySequenceID(t, newMemTestCache(t))
}
func testDeleteScheduledBySequenceID(t *testing.T, c *messageCache) {
// Create a scheduled (unpublished) message
scheduledMsg := newDefaultMessage("mytopic", "scheduled message")
scheduledMsg.ID = "scheduled1"
scheduledMsg.SequenceID = "seq123"
scheduledMsg.Time = time.Now().Add(time.Hour).Unix() // Future time makes it scheduled
require.Nil(t, c.AddMessage(scheduledMsg))
// Create a published message with different sequence ID
publishedMsg := newDefaultMessage("mytopic", "published message")
publishedMsg.ID = "published1"
publishedMsg.SequenceID = "seq456"
publishedMsg.Time = time.Now().Add(-time.Hour).Unix() // Past time makes it published
require.Nil(t, c.AddMessage(publishedMsg))
// Create a scheduled message in a different topic
otherTopicMsg := newDefaultMessage("othertopic", "other scheduled")
otherTopicMsg.ID = "other1"
otherTopicMsg.SequenceID = "seq123" // Same sequence ID as scheduledMsg
otherTopicMsg.Time = time.Now().Add(time.Hour).Unix()
require.Nil(t, c.AddMessage(otherTopicMsg))
// Verify all messages exist (including scheduled)
messages, err := c.Messages("mytopic", sinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 2, len(messages))
messages, err = c.Messages("othertopic", sinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
// Delete scheduled message by sequence ID and verify returned IDs
deletedIDs, err := c.DeleteScheduledBySequenceID("mytopic", "seq123")
require.Nil(t, err)
require.Equal(t, 1, len(deletedIDs))
require.Equal(t, "scheduled1", deletedIDs[0])
// Verify scheduled message is deleted
messages, err = c.Messages("mytopic", sinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "published message", messages[0].Message)
// Verify other topic's message still exists (topic-scoped deletion)
messages, err = c.Messages("othertopic", sinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "other scheduled", messages[0].Message)
// Deleting non-existent sequence ID should return empty list
deletedIDs, err = c.DeleteScheduledBySequenceID("mytopic", "nonexistent")
require.Nil(t, err)
require.Empty(t, deletedIDs)
// Deleting published message should not affect it (only deletes unpublished)
deletedIDs, err = c.DeleteScheduledBySequenceID("mytopic", "seq456")
require.Nil(t, err)
require.Empty(t, deletedIDs)
messages, err = c.Messages("mytopic", sinceAllMessages, true)
require.Nil(t, err)
require.Equal(t, 1, len(messages))
require.Equal(t, "published message", messages[0].Message)
}
func checkSchemaVersion(t *testing.T, db *sql.DB) {
rows, err := db.Query(`SELECT version FROM schemaVersion`)
require.Nil(t, err)
require.True(t, rows.Next())
var schemaVersion int
require.Nil(t, rows.Scan(&schemaVersion))
require.Equal(t, currentSchemaVersion, schemaVersion)
require.Nil(t, rows.Close())
}
func TestMemCache_NopCache(t *testing.T) {
c, _ := newNopCache()
require.Nil(t, c.AddMessage(newDefaultMessage("mytopic", "my message")))
messages, err := c.Messages("mytopic", sinceAllMessages, false)
require.Nil(t, err)
require.Empty(t, messages)
topics, err := c.Topics()
require.Nil(t, err)
require.Empty(t, topics)
}
func newSqliteTestCache(t *testing.T) *messageCache {
c, err := newSqliteCache(newSqliteTestCacheFile(t), "", time.Hour, 0, 0, false)
if err != nil {
t.Fatal(err)
}
return c
}
func newSqliteTestCacheFile(t *testing.T) string {
return filepath.Join(t.TempDir(), "cache.db")
}
func newSqliteTestCacheFromFile(t *testing.T, filename, startupQueries string) *messageCache {
c, err := newSqliteCache(filename, startupQueries, time.Hour, 0, 0, false)
require.Nil(t, err)
return c
}
func newMemTestCache(t *testing.T) *messageCache {
c, err := newMemCache()
require.Nil(t, err)
return c
}

View File

@@ -33,6 +33,8 @@ import (
"golang.org/x/sync/errgroup"
"gopkg.in/yaml.v2"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/message"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/payments"
"heckel.io/ntfy/v2/user"
"heckel.io/ntfy/v2/util"
@@ -57,7 +59,7 @@ type Server struct {
messages int64 // Total number of messages (persisted if messageCache enabled)
messagesHistory []int64 // Last n values of the messages counter, used to determine rate
userManager *user.Manager // Might be nil!
messageCache *messageCache // Database that stores the messages
messageCache message.Store // Database that stores the messages
webPush webpush.Store // Database that stores web push subscriptions
fileCache *fileCache // File system based cache that stores attachments
stripe stripeAPI // Stripe API, can be replaced with a mock
@@ -188,10 +190,14 @@ func New(conf *Config) (*Server, error) {
return nil, err
}
}
topics, err := messageCache.Topics()
topicIDs, err := messageCache.Topics()
if err != nil {
return nil, err
}
topics := make(map[string]*topic, len(topicIDs))
for _, id := range topicIDs {
topics[id] = newTopic(id)
}
messages, err := messageCache.Stats()
if err != nil {
return nil, err
@@ -263,13 +269,15 @@ func New(conf *Config) (*Server, error) {
return s, nil
}
func createMessageCache(conf *Config) (*messageCache, error) {
func createMessageCache(conf *Config) (message.Store, error) {
if conf.CacheDuration == 0 {
return newNopCache()
return message.NewNopStore()
} else if conf.DatabaseURL != "" {
return message.NewPostgresStore(conf.DatabaseURL, conf.CacheBatchSize, conf.CacheBatchTimeout)
} else if conf.CacheFile != "" {
return newSqliteCache(conf.CacheFile, conf.CacheStartupQueries, conf.CacheDuration, conf.CacheBatchSize, conf.CacheBatchTimeout, false)
return message.NewSQLiteStore(conf.CacheFile, conf.CacheStartupQueries, conf.CacheDuration, conf.CacheBatchSize, conf.CacheBatchTimeout, false)
}
return newMemCache()
return message.NewMemStore()
}
// Run executes the main server. It listens on HTTP (+ HTTPS, if configured), and starts
@@ -750,7 +758,7 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor)
if s.config.CacheBatchTimeout > 0 {
// Strange edge case: If we immediately after upload request the file (the web app does this for images),
// and messages are persisted asynchronously, retry fetching from the database
m, err = util.Retry(func() (*message, error) {
m, err = util.Retry(func() (*model.Message, error) {
return s.messageCache.Message(messageID)
}, s.config.CacheBatchTimeout, 100*time.Millisecond, 300*time.Millisecond, 600*time.Millisecond)
}
@@ -796,7 +804,7 @@ func (s *Server) handleMatrixDiscovery(w http.ResponseWriter) error {
return writeMatrixDiscoveryResponse(w)
}
func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*message, error) {
func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*model.Message, error) {
start := time.Now()
t, err := fromContext[*topic](r, contextTopic)
if err != nil {
@@ -924,7 +932,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
return err
}
minc(metricMessagesPublishedSuccess)
return s.writeJSON(w, m.forJSON())
return s.writeJSON(w, m.ForJSON())
}
func (s *Server) handlePublishMatrix(w http.ResponseWriter, r *http.Request, v *visitor) error {
@@ -1014,10 +1022,10 @@ func (s *Server) handleActionMessage(w http.ResponseWriter, r *http.Request, v *
s.mu.Lock()
s.messages++
s.mu.Unlock()
return s.writeJSON(w, m.forJSON())
return s.writeJSON(w, m.ForJSON())
}
func (s *Server) sendToFirebase(v *visitor, m *message) {
func (s *Server) sendToFirebase(v *visitor, m *model.Message) {
logvm(v, m).Tag(tagFirebase).Debug("Publishing to Firebase")
if err := s.firebaseClient.Send(v, m); err != nil {
minc(metricFirebasePublishedFailure)
@@ -1031,7 +1039,7 @@ func (s *Server) sendToFirebase(v *visitor, m *message) {
minc(metricFirebasePublishedSuccess)
}
func (s *Server) sendEmail(v *visitor, m *message, email string) {
func (s *Server) sendEmail(v *visitor, m *model.Message, email string) {
logvm(v, m).Tag(tagEmail).Field("email", email).Debug("Sending email to %s", email)
if err := s.smtpSender.Send(v, m, email); err != nil {
logvm(v, m).Tag(tagEmail).Field("email", email).Err(err).Warn("Unable to send email to %s: %v", email, err.Error())
@@ -1041,7 +1049,7 @@ func (s *Server) sendEmail(v *visitor, m *message, email string) {
minc(metricEmailsPublishedSuccess)
}
func (s *Server) forwardPollRequest(v *visitor, m *message) {
func (s *Server) forwardPollRequest(v *visitor, m *model.Message) {
topicURL := fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic)
topicHash := fmt.Sprintf("%x", sha256.Sum256([]byte(topicURL)))
forwardURL := fmt.Sprintf("%s/%s", s.config.UpstreamBaseURL, topicHash)
@@ -1073,7 +1081,7 @@ func (s *Server) forwardPollRequest(v *visitor, m *message) {
}
}
func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, firebase bool, email, call string, template templateMode, unifiedpush bool, priorityStr string, err *errHTTP) {
func (s *Server) parsePublishParams(r *http.Request, m *model.Message) (cache bool, firebase bool, email, call string, template templateMode, unifiedpush bool, priorityStr string, err *errHTTP) {
if r.Method != http.MethodGet && updatePathRegex.MatchString(r.URL.Path) {
pathSequenceID, err := s.sequenceIDFromPath(r.URL.Path)
if err != nil {
@@ -1100,7 +1108,7 @@ func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, fi
filename := readParam(r, "x-filename", "filename", "file", "f")
attach := readParam(r, "x-attach", "attach", "a")
if attach != "" || filename != "" {
m.Attachment = &attachment{}
m.Attachment = &model.Attachment{}
}
if filename != "" {
m.Attachment.Name = filename
@@ -1221,7 +1229,7 @@ func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, fi
// If file.txt is <= 4096 (message limit) and valid UTF-8, treat it as a message
// 7. curl -T file.txt ntfy.sh/mytopic
// In all other cases, mostly if file.txt is > message limit, treat it as an attachment
func (s *Server) handlePublishBody(r *http.Request, v *visitor, m *message, body *util.PeekedReadCloser, template templateMode, unifiedpush bool, priorityStr string) error {
func (s *Server) handlePublishBody(r *http.Request, v *visitor, m *model.Message, body *util.PeekedReadCloser, template templateMode, unifiedpush bool, priorityStr string) error {
if m.Event == pollRequestEvent { // Case 1
return s.handleBodyDiscard(body)
} else if unifiedpush {
@@ -1244,7 +1252,7 @@ func (s *Server) handleBodyDiscard(body *util.PeekedReadCloser) error {
return err
}
func (s *Server) handleBodyAsMessageAutoDetect(m *message, body *util.PeekedReadCloser) error {
func (s *Server) handleBodyAsMessageAutoDetect(m *model.Message, body *util.PeekedReadCloser) error {
if utf8.Valid(body.PeekedBytes) {
m.Message = string(body.PeekedBytes) // Do not trim
} else {
@@ -1254,7 +1262,7 @@ func (s *Server) handleBodyAsMessageAutoDetect(m *message, body *util.PeekedRead
return nil
}
func (s *Server) handleBodyAsTextMessage(m *message, body *util.PeekedReadCloser) error {
func (s *Server) handleBodyAsTextMessage(m *model.Message, body *util.PeekedReadCloser) error {
if !utf8.Valid(body.PeekedBytes) {
return errHTTPBadRequestMessageNotUTF8.With(m)
}
@@ -1267,7 +1275,7 @@ func (s *Server) handleBodyAsTextMessage(m *message, body *util.PeekedReadCloser
return nil
}
func (s *Server) handleBodyAsTemplatedTextMessage(m *message, template templateMode, body *util.PeekedReadCloser, priorityStr string) error {
func (s *Server) handleBodyAsTemplatedTextMessage(m *model.Message, template templateMode, body *util.PeekedReadCloser, priorityStr string) error {
body, err := util.Peek(body, max(s.config.MessageSizeLimit, jsonBodyBytesLimit))
if err != nil {
return err
@@ -1292,7 +1300,7 @@ func (s *Server) handleBodyAsTemplatedTextMessage(m *message, template templateM
// renderTemplateFromFile transforms the JSON message body according to a template from the filesystem.
// The template file must be in the templates directory, or in the configured template directory.
func (s *Server) renderTemplateFromFile(m *message, templateName, peekedBody string) error {
func (s *Server) renderTemplateFromFile(m *model.Message, templateName, peekedBody string) error {
if !templateNameRegex.MatchString(templateName) {
return errHTTPBadRequestTemplateFileNotFound
}
@@ -1334,7 +1342,7 @@ func (s *Server) renderTemplateFromFile(m *message, templateName, peekedBody str
// renderTemplateFromParams transforms the JSON message body according to the inline template in the
// message, title, and priority parameters.
func (s *Server) renderTemplateFromParams(m *message, peekedBody string, priorityStr string) error {
func (s *Server) renderTemplateFromParams(m *model.Message, peekedBody string, priorityStr string) error {
var err error
if m.Message, err = s.renderTemplate("priority query parameter", m.Message, peekedBody); err != nil {
return err
@@ -1375,7 +1383,7 @@ func (s *Server) renderTemplate(name, tpl, source string) (string, error) {
return strings.TrimSpace(strings.ReplaceAll(buf.String(), "\\n", "\n")), nil // replace any remaining "\n" (those outside of template curly braces) with newlines
}
func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message, body *util.PeekedReadCloser) error {
func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *model.Message, body *util.PeekedReadCloser) error {
if s.fileCache == nil || s.config.BaseURL == "" || s.config.AttachmentCacheDir == "" {
return errHTTPBadRequestAttachmentsDisallowed.With(m)
}
@@ -1399,7 +1407,7 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message,
}
}
if m.Attachment == nil {
m.Attachment = &attachment{}
m.Attachment = &model.Attachment{}
}
var ext string
m.Attachment.Expires = attachmentExpiry
@@ -1426,9 +1434,9 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message,
}
func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request, v *visitor) error {
encoder := func(msg *message) (string, error) {
encoder := func(msg *model.Message) (string, error) {
var buf bytes.Buffer
if err := json.NewEncoder(&buf).Encode(msg.forJSON()); err != nil {
if err := json.NewEncoder(&buf).Encode(msg.ForJSON()); err != nil {
return "", err
}
return buf.String(), nil
@@ -1437,9 +1445,9 @@ func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request, v *
}
func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request, v *visitor) error {
encoder := func(msg *message) (string, error) {
encoder := func(msg *model.Message) (string, error) {
var buf bytes.Buffer
if err := json.NewEncoder(&buf).Encode(msg.forJSON()); err != nil {
if err := json.NewEncoder(&buf).Encode(msg.ForJSON()); err != nil {
return "", err
}
if msg.Event != messageEvent && msg.Event != messageDeleteEvent && msg.Event != messageClearEvent {
@@ -1451,7 +1459,7 @@ func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request, v *v
}
func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request, v *visitor) error {
encoder := func(msg *message) (string, error) {
encoder := func(msg *model.Message) (string, error) {
if msg.Event == messageEvent { // only handle default events
return strings.ReplaceAll(msg.Message, "\n", " ") + "\n", nil
}
@@ -1487,7 +1495,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
closed = true
wlock.Unlock()
}()
sub := func(v *visitor, msg *message) error {
sub := func(v *visitor, msg *model.Message) error {
if !filters.Pass(msg) {
return nil
}
@@ -1649,7 +1657,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
}
}
})
sub := func(v *visitor, msg *message) error {
sub := func(v *visitor, msg *model.Message) error {
if !filters.Pass(msg) {
return nil
}
@@ -1696,7 +1704,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
return nil
}
func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, scheduled bool, filters *queryFilter, err error) {
func parseSubscribeParams(r *http.Request) (poll bool, since model.SinceMarker, scheduled bool, filters *queryFilter, err error) {
poll = readBoolParam(r, false, "x-poll", "poll", "po")
scheduled = readBoolParam(r, false, "x-scheduled", "scheduled", "sched")
since, err = parseSince(r, poll)
@@ -1777,11 +1785,11 @@ func (s *Server) setRateVisitors(r *http.Request, v *visitor, rateTopics []*topi
// sendOldMessages selects old messages from the messageCache and calls sub for each of them. It uses since as the
// marker, returning only messages that are newer than the marker.
func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, v *visitor, sub subscriber) error {
func (s *Server) sendOldMessages(topics []*topic, since model.SinceMarker, scheduled bool, v *visitor, sub subscriber) error {
if since.IsNone() {
return nil
}
messages := make([]*message, 0)
messages := make([]*model.Message, 0)
for _, t := range topics {
topicMessages, err := s.messageCache.Messages(t.ID, since, scheduled)
if err != nil {
@@ -1804,7 +1812,7 @@ func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled b
//
// Values in the "since=..." parameter can be either a unix timestamp or a duration (e.g. 12h),
// "all" for all messages, or "latest" for the most recent message for a topic
func parseSince(r *http.Request, poll bool) (sinceMarker, error) {
func parseSince(r *http.Request, poll bool) (model.SinceMarker, error) {
since := readParam(r, "x-since", "since", "si")
// Easy cases (empty, all, none)
@@ -2035,7 +2043,7 @@ func (s *Server) sendDelayedMessages() error {
return nil
}
func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
func (s *Server) sendDelayedMessage(v *visitor, m *model.Message) error {
logvm(v, m).Debug("Sending delayed message")
s.mu.RLock()
t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published

File diff suppressed because it is too large Load Diff

View File

@@ -11,427 +11,451 @@ import (
)
func TestVersion_Admin(t *testing.T) {
c := newTestConfigWithAuthFile(t)
c.BuildVersion = "1.2.3"
c.BuildCommit = "abcdef0"
c.BuildDate = "2026-02-08T00:00:00Z"
s := newTestServer(t, c)
defer s.closeDatabases()
forEachBackend(t, func(t *testing.T) {
c := newTestConfigWithAuthFile(t)
c.BuildVersion = "1.2.3"
c.BuildCommit = "abcdef0"
c.BuildDate = "2026-02-08T00:00:00Z"
s := newTestServer(t, c)
defer s.closeDatabases()
// Create admin and regular user
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
// Create admin and regular user
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
// Admin can access /v1/version
rr := request(t, s, "GET", "/v1/version", "", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
// Admin can access /v1/version
rr := request(t, s, "GET", "/v1/version", "", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
var versionResponse apiVersionResponse
require.Nil(t, json.NewDecoder(rr.Body).Decode(&versionResponse))
require.Equal(t, "1.2.3", versionResponse.Version)
require.Equal(t, "abcdef0", versionResponse.Commit)
require.Equal(t, "2026-02-08T00:00:00Z", versionResponse.Date)
// Non-admin user cannot access /v1/version
rr = request(t, s, "GET", "/v1/version", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 401, rr.Code)
// Unauthenticated user cannot access /v1/version
rr = request(t, s, "GET", "/v1/version", "", nil)
require.Equal(t, 401, rr.Code)
})
require.Equal(t, 200, rr.Code)
var versionResponse apiVersionResponse
require.Nil(t, json.NewDecoder(rr.Body).Decode(&versionResponse))
require.Equal(t, "1.2.3", versionResponse.Version)
require.Equal(t, "abcdef0", versionResponse.Commit)
require.Equal(t, "2026-02-08T00:00:00Z", versionResponse.Date)
// Non-admin user cannot access /v1/version
rr = request(t, s, "GET", "/v1/version", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 401, rr.Code)
// Unauthenticated user cannot access /v1/version
rr = request(t, s, "GET", "/v1/version", "", nil)
require.Equal(t, 401, rr.Code)
}
func TestUser_AddRemove(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
// Create admin, tier
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "tier1",
}))
// Create admin, tier
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "tier1",
}))
// Create user via API
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
// Create user via API
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Create user with tier via API
rr = request(t, s, "PUT", "/v1/users", `{"username": "emma", "password":"emma", "tier": "tier1"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Check users
users, err := s.userManager.Users()
require.Nil(t, err)
require.Equal(t, 4, len(users))
require.Equal(t, "phil", users[0].Name)
require.Equal(t, "ben", users[1].Name)
require.Equal(t, user.RoleUser, users[1].Role)
require.Nil(t, users[1].Tier)
require.Equal(t, "emma", users[2].Name)
require.Equal(t, user.RoleUser, users[2].Role)
require.Equal(t, "tier1", users[2].Tier.Code)
require.Equal(t, user.Everyone, users[3].Name)
// Delete user via API
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Check user was deleted
users, err = s.userManager.Users()
require.Nil(t, err)
require.Equal(t, 3, len(users))
require.Equal(t, "phil", users[0].Name)
require.Equal(t, "emma", users[1].Name)
require.Equal(t, user.Everyone, users[2].Name)
// Reject invalid user change
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 400, rr.Code)
})
require.Equal(t, 200, rr.Code)
// Create user with tier via API
rr = request(t, s, "PUT", "/v1/users", `{"username": "emma", "password":"emma", "tier": "tier1"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Check users
users, err := s.userManager.Users()
require.Nil(t, err)
require.Equal(t, 4, len(users))
require.Equal(t, "phil", users[0].Name)
require.Equal(t, "ben", users[1].Name)
require.Equal(t, user.RoleUser, users[1].Role)
require.Nil(t, users[1].Tier)
require.Equal(t, "emma", users[2].Name)
require.Equal(t, user.RoleUser, users[2].Role)
require.Equal(t, "tier1", users[2].Tier.Code)
require.Equal(t, user.Everyone, users[3].Name)
// Delete user via API
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Check user was deleted
users, err = s.userManager.Users()
require.Nil(t, err)
require.Equal(t, 3, len(users))
require.Equal(t, "phil", users[0].Name)
require.Equal(t, "emma", users[1].Name)
require.Equal(t, user.Everyone, users[2].Name)
// Reject invalid user change
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 400, rr.Code)
}
func TestUser_AddWithPasswordHash(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
// Create admin
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
// Create admin
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
// Create user via API
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "hash":"$2a$04$2aPIIqPXQU16OfkSUZH1XOzpu1gsPRKkrfVdFLgWQ.tqb.vtTCuVe"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Create user via API
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "hash":"$2a$04$2aPIIqPXQU16OfkSUZH1XOzpu1gsPRKkrfVdFLgWQ.tqb.vtTCuVe"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Check that user can login with password
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 200, rr.Code)
// Check users
users, err := s.userManager.Users()
require.Nil(t, err)
require.Equal(t, 3, len(users))
require.Equal(t, "phil", users[0].Name)
require.Equal(t, user.RoleAdmin, users[0].Role)
require.Equal(t, "ben", users[1].Name)
require.Equal(t, user.RoleUser, users[1].Role)
}
func TestUser_ChangeUserPassword(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
// Create admin
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
// Create user via API
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password": "ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Try to login with first password
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 200, rr.Code)
// Change password via API
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "password": "ben-two"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Make sure first password fails
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 401, rr.Code)
// Try to login with second password
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben-two"),
})
require.Equal(t, 200, rr.Code)
}
func TestUser_ChangeUserTier(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
// Create admin, tier
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "tier1",
}))
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "tier2",
}))
// Create user with tier via API
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"ben", "tier": "tier1"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Check users
users, err := s.userManager.Users()
require.Nil(t, err)
require.Equal(t, 3, len(users))
require.Equal(t, "phil", users[0].Name)
require.Equal(t, "ben", users[1].Name)
require.Equal(t, user.RoleUser, users[1].Role)
require.Equal(t, "tier1", users[1].Tier.Code)
// Change user tier via API
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "tier": "tier2"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Check users again
users, err = s.userManager.Users()
require.Nil(t, err)
require.Equal(t, "tier2", users[1].Tier.Code)
}
func TestUser_ChangeUserPasswordAndTier(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
// Create admin, tier
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "tier1",
}))
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "tier2",
}))
// Create user with tier via API
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"ben", "tier": "tier1"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Check users
users, err := s.userManager.Users()
require.Nil(t, err)
require.Equal(t, 3, len(users))
require.Equal(t, "phil", users[0].Name)
require.Equal(t, "ben", users[1].Name)
require.Equal(t, user.RoleUser, users[1].Role)
require.Equal(t, "tier1", users[1].Tier.Code)
// Change user password and tier via API
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "password":"ben-two", "tier": "tier2"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Make sure first password fails
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 401, rr.Code)
// Try to login with second password
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben-two"),
})
require.Equal(t, 200, rr.Code)
// Check new tier
users, err = s.userManager.Users()
require.Nil(t, err)
require.Equal(t, "tier2", users[1].Tier.Code)
}
func TestUser_ChangeUserPasswordWithHash(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
// Create admin
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
// Create user with tier via API
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"not-ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Try to login with first password
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "not-ben"),
})
require.Equal(t, 200, rr.Code)
// Change user password and tier via API
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "hash":"$2a$04$2aPIIqPXQU16OfkSUZH1XOzpu1gsPRKkrfVdFLgWQ.tqb.vtTCuVe"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Try to login with second password
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 200, rr.Code)
}
func TestUser_DontChangeAdminPassword(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
// Create admin
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AddUser("admin", "admin", user.RoleAdmin, false))
// Try to change password via API
rr := request(t, s, "PUT", "/v1/users", `{"username": "admin", "password": "admin-new"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 403, rr.Code)
}
func TestUser_AddRemove_Failures(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
// Create admin
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
// Cannot create user with invalid username
rr := request(t, s, "POST", "/v1/users", `{"username": "not valid", "password":"ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 400, rr.Code)
// Cannot create user if user already exists
rr = request(t, s, "POST", "/v1/users", `{"username": "phil", "password":"phil"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 40901, toHTTPError(t, rr.Body.String()).Code)
// Cannot create user with invalid tier
rr = request(t, s, "POST", "/v1/users", `{"username": "emma", "password":"emma", "tier": "invalid"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 40030, toHTTPError(t, rr.Body.String()).Code)
// Cannot delete user as non-admin
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 401, rr.Code)
// Delete user via API
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
}
func TestAccess_AllowReset(t *testing.T) {
c := newTestConfigWithAuthFile(t)
c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c)
defer s.closeDatabases()
// User and admin
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
// Subscribing not allowed
rr := request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 403, rr.Code)
// Grant access
rr = request(t, s, "POST", "/v1/users/access", `{"username": "ben", "topic":"gold", "permission":"ro"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Now subscribing is allowed
rr = request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 200, rr.Code)
// Reset access
rr = request(t, s, "DELETE", "/v1/users/access", `{"username": "ben", "topic":"gold"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Subscribing not allowed (again)
rr = request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 403, rr.Code)
}
func TestAccess_AllowReset_NonAdminAttempt(t *testing.T) {
c := newTestConfigWithAuthFile(t)
c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c)
defer s.closeDatabases()
// User
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
// Grant access fails, because non-admin
rr := request(t, s, "POST", "/v1/users/access", `{"username": "ben", "topic":"gold", "permission":"ro"}`, map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 401, rr.Code)
}
func TestAccess_AllowReset_KillConnection(t *testing.T) {
c := newTestConfigWithAuthFile(t)
c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c)
defer s.closeDatabases()
// User and admin, grant access to "gol*" topics
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
require.Nil(t, s.userManager.AllowAccess("ben", "gol*", user.PermissionRead)) // Wildcard!
start, timeTaken := time.Now(), atomic.Int64{}
go func() {
rr := request(t, s, "GET", "/gold/json", "", map[string]string{
// Check that user can login with password
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 200, rr.Code)
timeTaken.Store(time.Since(start).Milliseconds())
}()
time.Sleep(500 * time.Millisecond)
// Reset access
rr := request(t, s, "DELETE", "/v1/users/access", `{"username": "ben", "topic":"gol*"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Wait for connection to be killed; this will fail if the connection is never killed
waitFor(t, func() bool {
return timeTaken.Load() >= 500
// Check users
users, err := s.userManager.Users()
require.Nil(t, err)
require.Equal(t, 3, len(users))
require.Equal(t, "phil", users[0].Name)
require.Equal(t, user.RoleAdmin, users[0].Role)
require.Equal(t, "ben", users[1].Name)
require.Equal(t, user.RoleUser, users[1].Role)
})
}
func TestUser_ChangeUserPassword(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
// Create admin
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
// Create user via API
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password": "ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Try to login with first password
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 200, rr.Code)
// Change password via API
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "password": "ben-two"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Make sure first password fails
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 401, rr.Code)
// Try to login with second password
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben-two"),
})
require.Equal(t, 200, rr.Code)
})
}
func TestUser_ChangeUserTier(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
// Create admin, tier
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "tier1",
}))
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "tier2",
}))
// Create user with tier via API
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"ben", "tier": "tier1"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Check users
users, err := s.userManager.Users()
require.Nil(t, err)
require.Equal(t, 3, len(users))
require.Equal(t, "phil", users[0].Name)
require.Equal(t, "ben", users[1].Name)
require.Equal(t, user.RoleUser, users[1].Role)
require.Equal(t, "tier1", users[1].Tier.Code)
// Change user tier via API
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "tier": "tier2"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Check users again
users, err = s.userManager.Users()
require.Nil(t, err)
require.Equal(t, "tier2", users[1].Tier.Code)
})
}
func TestUser_ChangeUserPasswordAndTier(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
// Create admin, tier
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "tier1",
}))
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "tier2",
}))
// Create user with tier via API
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"ben", "tier": "tier1"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Check users
users, err := s.userManager.Users()
require.Nil(t, err)
require.Equal(t, 3, len(users))
require.Equal(t, "phil", users[0].Name)
require.Equal(t, "ben", users[1].Name)
require.Equal(t, user.RoleUser, users[1].Role)
require.Equal(t, "tier1", users[1].Tier.Code)
// Change user password and tier via API
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "password":"ben-two", "tier": "tier2"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Make sure first password fails
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 401, rr.Code)
// Try to login with second password
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben-two"),
})
require.Equal(t, 200, rr.Code)
// Check new tier
users, err = s.userManager.Users()
require.Nil(t, err)
require.Equal(t, "tier2", users[1].Tier.Code)
})
}
func TestUser_ChangeUserPasswordWithHash(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
// Create admin
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
// Create user with tier via API
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"not-ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Try to login with first password
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "not-ben"),
})
require.Equal(t, 200, rr.Code)
// Change user password and tier via API
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "hash":"$2a$04$2aPIIqPXQU16OfkSUZH1XOzpu1gsPRKkrfVdFLgWQ.tqb.vtTCuVe"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Try to login with second password
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 200, rr.Code)
})
}
func TestUser_DontChangeAdminPassword(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
// Create admin
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AddUser("admin", "admin", user.RoleAdmin, false))
// Try to change password via API
rr := request(t, s, "PUT", "/v1/users", `{"username": "admin", "password": "admin-new"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 403, rr.Code)
})
}
func TestUser_AddRemove_Failures(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
// Create admin
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
// Cannot create user with invalid username
rr := request(t, s, "POST", "/v1/users", `{"username": "not valid", "password":"ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 400, rr.Code)
// Cannot create user if user already exists
rr = request(t, s, "POST", "/v1/users", `{"username": "phil", "password":"phil"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 40901, toHTTPError(t, rr.Body.String()).Code)
// Cannot create user with invalid tier
rr = request(t, s, "POST", "/v1/users", `{"username": "emma", "password":"emma", "tier": "invalid"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 40030, toHTTPError(t, rr.Body.String()).Code)
// Cannot delete user as non-admin
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 401, rr.Code)
// Delete user via API
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
})
}
func TestAccess_AllowReset(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
c := newTestConfigWithAuthFile(t)
c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c)
defer s.closeDatabases()
// User and admin
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
// Subscribing not allowed
rr := request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 403, rr.Code)
// Grant access
rr = request(t, s, "POST", "/v1/users/access", `{"username": "ben", "topic":"gold", "permission":"ro"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Now subscribing is allowed
rr = request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 200, rr.Code)
// Reset access
rr = request(t, s, "DELETE", "/v1/users/access", `{"username": "ben", "topic":"gold"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Subscribing not allowed (again)
rr = request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 403, rr.Code)
})
}
func TestAccess_AllowReset_NonAdminAttempt(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
c := newTestConfigWithAuthFile(t)
c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c)
defer s.closeDatabases()
// User
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
// Grant access fails, because non-admin
rr := request(t, s, "POST", "/v1/users/access", `{"username": "ben", "topic":"gold", "permission":"ro"}`, map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 401, rr.Code)
})
}
func TestAccess_AllowReset_KillConnection(t *testing.T) {
forEachBackend(t, func(t *testing.T) {
c := newTestConfigWithAuthFile(t)
c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c)
defer s.closeDatabases()
// User and admin, grant access to "gol*" topics
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
require.Nil(t, s.userManager.AllowAccess("ben", "gol*", user.PermissionRead)) // Wildcard!
start, timeTaken := time.Now(), atomic.Int64{}
go func() {
rr := request(t, s, "GET", "/gold/json", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 200, rr.Code)
timeTaken.Store(time.Since(start).Milliseconds())
}()
time.Sleep(500 * time.Millisecond)
// Reset access
rr := request(t, s, "DELETE", "/v1/users/access", `{"username": "ben", "topic":"gol*"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Wait for connection to be killed; this will fail if the connection is never killed
waitFor(t, func() bool {
return timeTaken.Load() >= 500
})
})
}

View File

@@ -10,6 +10,7 @@ import (
"firebase.google.com/go/v4/messaging"
"fmt"
"google.golang.org/api/option"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/user"
"heckel.io/ntfy/v2/util"
"strings"
@@ -43,7 +44,7 @@ func newFirebaseClient(sender firebaseSender, auther user.Auther) *firebaseClien
}
}
func (c *firebaseClient) Send(v *visitor, m *message) error {
func (c *firebaseClient) Send(v *visitor, m *model.Message) error {
if !v.FirebaseAllowed() {
return errFirebaseTemporarilyBanned
}
@@ -121,7 +122,7 @@ func (c *firebaseSenderImpl) Send(m *messaging.Message) error {
// On Android, this will trigger the app to poll the topic and thereby displaying new messages.
// - If UpstreamBaseURL is set, messages are forwarded as poll requests to an upstream server and then forwarded
// to Firebase here. This is mainly for iOS to support self-hosted servers.
func toFirebaseMessage(m *message, auther user.Auther) (*messaging.Message, error) {
func toFirebaseMessage(m *model.Message, auther user.Auther) (*messaging.Message, error) {
var data map[string]string // Mostly matches https://ntfy.sh/docs/subscribe/api/#json-message-format
var apnsConfig *messaging.APNSConfig
switch m.Event {
@@ -235,7 +236,7 @@ func maybeTruncateFCMMessage(m *messaging.Message) *messaging.Message {
// createAPNSAlertConfig creates an APNS config for iOS notifications that show up as an alert (only relevant for iOS).
// We must set the Alert struct ("alert"), and we need to set MutableContent ("mutable-content"), so the Notification Service
// Extension in iOS can modify the message.
func createAPNSAlertConfig(m *message, data map[string]string) *messaging.APNSConfig {
func createAPNSAlertConfig(m *model.Message, data map[string]string) *messaging.APNSConfig {
apnsData := make(map[string]any)
for k, v := range data {
apnsData[k] = v
@@ -296,7 +297,7 @@ func maybeTruncateAPNSBodyMessage(s string) string {
//
// This empties all the fields that are not needed for a poll request and just sets the required fields,
// most importantly, the PollID.
func toPollRequest(m *message) *message {
func toPollRequest(m *model.Message) *model.Message {
pr := newPollRequestMessage(m.Topic, m.ID)
pr.ID = m.ID
pr.Time = m.Time

View File

@@ -4,6 +4,7 @@ package server
import (
"errors"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/user"
)
@@ -21,7 +22,7 @@ var (
type firebaseClient struct {
}
func (c *firebaseClient) Send(v *visitor, m *message) error {
func (c *firebaseClient) Send(v *visitor, m *model.Message) error {
return errFirebaseNotAvailable
}

View File

@@ -6,6 +6,7 @@ import (
"encoding/json"
"errors"
"fmt"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/user"
"net/netip"
"strings"
@@ -131,7 +132,7 @@ func TestToFirebaseMessage_Message_Normal_Allowed(t *testing.T) {
m.Click = "https://google.com"
m.Icon = "https://ntfy.sh/static/img/ntfy.png"
m.Title = "some title"
m.Actions = []*action{
m.Actions = []*model.Action{
{
ID: "123",
Action: "view",
@@ -150,7 +151,7 @@ func TestToFirebaseMessage_Message_Normal_Allowed(t *testing.T) {
},
},
}
m.Attachment = &attachment{
m.Attachment = &model.Attachment{
Name: "some file.jpg",
Type: "image/jpeg",
Size: 12345,
@@ -346,16 +347,16 @@ func TestToFirebaseSender_Abuse(t *testing.T) {
client := newFirebaseClient(sender, &testAuther{})
visitor := newVisitor(newTestConfig(t), newMemTestCache(t), nil, netip.MustParseAddr("1.2.3.4"), nil)
require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"}))
require.Nil(t, client.Send(visitor, &model.Message{Topic: "mytopic"}))
require.Equal(t, 1, len(sender.Messages()))
require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"}))
require.Nil(t, client.Send(visitor, &model.Message{Topic: "mytopic"}))
require.Equal(t, 2, len(sender.Messages()))
require.Equal(t, errFirebaseQuotaExceeded, client.Send(visitor, &message{Topic: "mytopic"}))
require.Equal(t, errFirebaseQuotaExceeded, client.Send(visitor, &model.Message{Topic: "mytopic"}))
require.Equal(t, 2, len(sender.Messages()))
sender.messages = make([]*messaging.Message, 0) // Reset to test that time limit is working
require.Equal(t, errFirebaseTemporarilyBanned, client.Send(visitor, &message{Topic: "mytopic"}))
require.Equal(t, errFirebaseTemporarilyBanned, client.Send(visitor, &model.Message{Topic: "mytopic"}))
require.Equal(t, 0, len(sender.Messages()))
}

View File

@@ -6,23 +6,25 @@ import (
)
func TestServer_Manager_Prune_Messages_Without_Attachments_DoesNotPanic(t *testing.T) {
// Tests that the manager runs without attachment-cache-dir set, see #617
c := newTestConfig(t)
c.AttachmentCacheDir = ""
s := newTestServer(t, c)
forEachBackend(t, func(t *testing.T) {
// Tests that the manager runs without attachment-cache-dir set, see #617
c := newTestConfig(t)
c.AttachmentCacheDir = ""
s := newTestServer(t, c)
// Publish a message
rr := request(t, s, "POST", "/mytopic", "hi", nil)
require.Equal(t, 200, rr.Code)
m := toMessage(t, rr.Body.String())
// Publish a message
rr := request(t, s, "POST", "/mytopic", "hi", nil)
require.Equal(t, 200, rr.Code)
m := toMessage(t, rr.Body.String())
// Expire message
require.Nil(t, s.messageCache.ExpireMessages("mytopic"))
// Expire message
require.Nil(t, s.messageCache.ExpireMessages("mytopic"))
// Does not panic
s.pruneMessages()
// Does not panic
s.pruneMessages()
// Actually deleted
_, err := s.messageCache.Message(m.ID)
require.Equal(t, errMessageNotFound, err)
// Actually deleted
_, err := s.messageCache.Message(m.ID)
require.Equal(t, errMessageNotFound, err)
})
}

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

View File

@@ -11,6 +11,7 @@ import (
"text/template"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/user"
"heckel.io/ntfy/v2/util"
)
@@ -76,7 +77,7 @@ func (s *Server) convertPhoneNumber(u *user.User, phoneNumber string) (string, *
// callPhone calls the Twilio API to make a phone call to the given phone number, using the given message.
// Failures will be logged, but not returned to the caller.
func (s *Server) callPhone(v *visitor, r *http.Request, m *message, to string) {
func (s *Server) callPhone(v *visitor, r *http.Request, m *model.Message, to string) {
u, sender := v.User(), m.Sender.String()
if u != nil {
sender = u.Name

View File

@@ -14,217 +14,224 @@ import (
)
func TestServer_Twilio_Call_Add_Verify_Call_Delete_Success(t *testing.T) {
var called, verified atomic.Bool
var code atomic.Pointer[string]
twilioVerifyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
require.Nil(t, err)
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
if r.URL.Path == "/v2/Services/VA1234567890/Verifications" {
if code.Load() != nil {
forEachBackend(t, func(t *testing.T) {
var called, verified atomic.Bool
var code atomic.Pointer[string]
twilioVerifyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
require.Nil(t, err)
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
if r.URL.Path == "/v2/Services/VA1234567890/Verifications" {
if code.Load() != nil {
t.Fatal("Should be only called once")
}
require.Equal(t, "Channel=sms&To=%2B12223334444", string(body))
code.Store(util.String("123456"))
} else if r.URL.Path == "/v2/Services/VA1234567890/VerificationCheck" {
if verified.Load() {
t.Fatal("Should be only called once")
}
require.Equal(t, "Code=123456&To=%2B12223334444", string(body))
verified.Store(true)
} else {
t.Fatal("Unexpected path:", r.URL.Path)
}
}))
defer twilioVerifyServer.Close()
twilioCallsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if called.Load() {
t.Fatal("Should be only called once")
}
require.Equal(t, "Channel=sms&To=%2B12223334444", string(body))
code.Store(util.String("123456"))
} else if r.URL.Path == "/v2/Services/VA1234567890/VerificationCheck" {
if verified.Load() {
t.Fatal("Should be only called once")
}
require.Equal(t, "Code=123456&To=%2B12223334444", string(body))
verified.Store(true)
} else {
t.Fatal("Unexpected path:", r.URL.Path)
}
}))
defer twilioVerifyServer.Close()
twilioCallsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if called.Load() {
t.Fatal("Should be only called once")
}
body, err := io.ReadAll(r.Body)
body, err := io.ReadAll(r.Body)
require.Nil(t, err)
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
require.Equal(t, "From=%2B1234567890&To=%2B12223334444&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+loop%3D%223%22%3E%0A%09%09You+have+a+message+from+notify+on+topic+mytopic.+Message%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09End+of+message.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09This+message+was+sent+by+user+phil.+It+will+be+repeated+three+times.%0A%09%09To+unsubscribe+from+calls+like+this%2C+remove+your+phone+number+in+the+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay%3EGoodbye.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
called.Store(true)
}))
defer twilioCallsServer.Close()
c := newTestConfigWithAuthFile(t)
c.TwilioVerifyBaseURL = twilioVerifyServer.URL
c.TwilioCallsBaseURL = twilioCallsServer.URL
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
c.TwilioPhoneNumber = "+1234567890"
c.TwilioVerifyService = "VA1234567890"
s := newTestServer(t, c)
// Add tier and user
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "pro",
MessageLimit: 10,
CallLimit: 1,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
u, err := s.userManager.User("phil")
require.Nil(t, err)
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
require.Equal(t, "From=%2B1234567890&To=%2B12223334444&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+loop%3D%223%22%3E%0A%09%09You+have+a+message+from+notify+on+topic+mytopic.+Message%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09End+of+message.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09This+message+was+sent+by+user+phil.+It+will+be+repeated+three+times.%0A%09%09To+unsubscribe+from+calls+like+this%2C+remove+your+phone+number+in+the+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay%3EGoodbye.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
called.Store(true)
}))
defer twilioCallsServer.Close()
c := newTestConfigWithAuthFile(t)
c.TwilioVerifyBaseURL = twilioVerifyServer.URL
c.TwilioCallsBaseURL = twilioCallsServer.URL
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
c.TwilioPhoneNumber = "+1234567890"
c.TwilioVerifyService = "VA1234567890"
s := newTestServer(t, c)
// Send verification code for phone number
response := request(t, s, "PUT", "/v1/account/phone/verify", `{"number":"+12223334444","channel":"sms"}`, map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, response.Code)
waitFor(t, func() bool {
return *code.Load() == "123456"
})
// Add tier and user
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "pro",
MessageLimit: 10,
CallLimit: 1,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
u, err := s.userManager.User("phil")
require.Nil(t, err)
// Add phone number with code
response = request(t, s, "PUT", "/v1/account/phone", `{"number":"+12223334444","code":"123456"}`, map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, response.Code)
waitFor(t, func() bool {
return verified.Load()
})
phoneNumbers, err := s.userManager.PhoneNumbers(u.ID)
require.Nil(t, err)
require.Equal(t, 1, len(phoneNumbers))
require.Equal(t, "+12223334444", phoneNumbers[0])
// Send verification code for phone number
response := request(t, s, "PUT", "/v1/account/phone/verify", `{"number":"+12223334444","channel":"sms"}`, map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, response.Code)
waitFor(t, func() bool {
return *code.Load() == "123456"
})
// Do the thing
response = request(t, s, "POST", "/mytopic", "hi there", map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
"x-call": "yes",
})
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
waitFor(t, func() bool {
return called.Load()
})
// Add phone number with code
response = request(t, s, "PUT", "/v1/account/phone", `{"number":"+12223334444","code":"123456"}`, map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, response.Code)
waitFor(t, func() bool {
return verified.Load()
})
phoneNumbers, err := s.userManager.PhoneNumbers(u.ID)
require.Nil(t, err)
require.Equal(t, 1, len(phoneNumbers))
require.Equal(t, "+12223334444", phoneNumbers[0])
// Remove the phone number
response = request(t, s, "DELETE", "/v1/account/phone", `{"number":"+12223334444"}`, map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, response.Code)
// Do the thing
response = request(t, s, "POST", "/mytopic", "hi there", map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
"x-call": "yes",
// Verify the phone number is gone from the DB
phoneNumbers, err = s.userManager.PhoneNumbers(u.ID)
require.Nil(t, err)
require.Equal(t, 0, len(phoneNumbers))
})
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
waitFor(t, func() bool {
return called.Load()
})
// Remove the phone number
response = request(t, s, "DELETE", "/v1/account/phone", `{"number":"+12223334444"}`, map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, response.Code)
// Verify the phone number is gone from the DB
phoneNumbers, err = s.userManager.PhoneNumbers(u.ID)
require.Nil(t, err)
require.Equal(t, 0, len(phoneNumbers))
}
func TestServer_Twilio_Call_Success(t *testing.T) {
var called atomic.Bool
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if called.Load() {
t.Fatal("Should be only called once")
}
body, err := io.ReadAll(r.Body)
forEachBackend(t, func(t *testing.T) {
var called atomic.Bool
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if called.Load() {
t.Fatal("Should be only called once")
}
body, err := io.ReadAll(r.Body)
require.Nil(t, err)
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
require.Equal(t, "From=%2B1234567890&To=%2B11122233344&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+loop%3D%223%22%3E%0A%09%09You+have+a+message+from+notify+on+topic+mytopic.+Message%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09End+of+message.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09This+message+was+sent+by+user+phil.+It+will+be+repeated+three+times.%0A%09%09To+unsubscribe+from+calls+like+this%2C+remove+your+phone+number+in+the+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay%3EGoodbye.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
called.Store(true)
}))
defer twilioServer.Close()
c := newTestConfigWithAuthFile(t)
c.TwilioCallsBaseURL = twilioServer.URL
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
c.TwilioPhoneNumber = "+1234567890"
s := newTestServer(t, c)
// Add tier and user
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "pro",
MessageLimit: 10,
CallLimit: 1,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
u, err := s.userManager.User("phil")
require.Nil(t, err)
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
require.Equal(t, "From=%2B1234567890&To=%2B11122233344&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+loop%3D%223%22%3E%0A%09%09You+have+a+message+from+notify+on+topic+mytopic.+Message%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09End+of+message.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09This+message+was+sent+by+user+phil.+It+will+be+repeated+three+times.%0A%09%09To+unsubscribe+from+calls+like+this%2C+remove+your+phone+number+in+the+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay%3EGoodbye.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
called.Store(true)
}))
defer twilioServer.Close()
require.Nil(t, s.userManager.AddPhoneNumber(u.ID, "+11122233344"))
c := newTestConfigWithAuthFile(t)
c.TwilioCallsBaseURL = twilioServer.URL
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
c.TwilioPhoneNumber = "+1234567890"
s := newTestServer(t, c)
// Add tier and user
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "pro",
MessageLimit: 10,
CallLimit: 1,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
u, err := s.userManager.User("phil")
require.Nil(t, err)
require.Nil(t, s.userManager.AddPhoneNumber(u.ID, "+11122233344"))
// Do the thing
response := request(t, s, "POST", "/mytopic", "hi there", map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
"x-call": "+11122233344",
})
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
waitFor(t, func() bool {
return called.Load()
// Do the thing
response := request(t, s, "POST", "/mytopic", "hi there", map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
"x-call": "+11122233344",
})
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
waitFor(t, func() bool {
return called.Load()
})
})
}
func TestServer_Twilio_Call_Success_With_Yes(t *testing.T) {
var called atomic.Bool
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if called.Load() {
t.Fatal("Should be only called once")
}
body, err := io.ReadAll(r.Body)
forEachBackend(t, func(t *testing.T) {
var called atomic.Bool
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if called.Load() {
t.Fatal("Should be only called once")
}
body, err := io.ReadAll(r.Body)
require.Nil(t, err)
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
require.Equal(t, "From=%2B1234567890&To=%2B11122233344&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+loop%3D%223%22%3E%0A%09%09You+have+a+message+from+notify+on+topic+mytopic.+Message%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09End+of+message.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09This+message+was+sent+by+user+phil.+It+will+be+repeated+three+times.%0A%09%09To+unsubscribe+from+calls+like+this%2C+remove+your+phone+number+in+the+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay%3EGoodbye.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
called.Store(true)
}))
defer twilioServer.Close()
c := newTestConfigWithAuthFile(t)
c.TwilioCallsBaseURL = twilioServer.URL
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
c.TwilioPhoneNumber = "+1234567890"
s := newTestServer(t, c)
// Add tier and user
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "pro",
MessageLimit: 10,
CallLimit: 1,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
u, err := s.userManager.User("phil")
require.Nil(t, err)
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
require.Equal(t, "From=%2B1234567890&To=%2B11122233344&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+loop%3D%223%22%3E%0A%09%09You+have+a+message+from+notify+on+topic+mytopic.+Message%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09End+of+message.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09This+message+was+sent+by+user+phil.+It+will+be+repeated+three+times.%0A%09%09To+unsubscribe+from+calls+like+this%2C+remove+your+phone+number+in+the+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay%3EGoodbye.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
called.Store(true)
}))
defer twilioServer.Close()
require.Nil(t, s.userManager.AddPhoneNumber(u.ID, "+11122233344"))
c := newTestConfigWithAuthFile(t)
c.TwilioCallsBaseURL = twilioServer.URL
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
c.TwilioPhoneNumber = "+1234567890"
s := newTestServer(t, c)
// Add tier and user
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "pro",
MessageLimit: 10,
CallLimit: 1,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
u, err := s.userManager.User("phil")
require.Nil(t, err)
require.Nil(t, s.userManager.AddPhoneNumber(u.ID, "+11122233344"))
// Do the thing
response := request(t, s, "POST", "/mytopic", "hi there", map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
"x-call": "yes", // <<<------
})
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
waitFor(t, func() bool {
return called.Load()
// Do the thing
response := request(t, s, "POST", "/mytopic", "hi there", map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
"x-call": "yes", // <<<------
})
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
waitFor(t, func() bool {
return called.Load()
})
})
}
func TestServer_Twilio_Call_Success_with_custom_twiml(t *testing.T) {
var called atomic.Bool
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if called.Load() {
t.Fatal("Should be only called once")
}
body, err := io.ReadAll(r.Body)
require.Nil(t, err)
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
require.Equal(t, "From=%2B1234567890&To=%2B11122233344&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+language%3D%22de-DE%22+loop%3D%223%22%3E%0A%09%09Du+hast+eine+Nachricht+von+notify+im+Thema+mytopic.+Nachricht%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09Ende+der+Nachricht.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09Diese+Nachricht+wurde+von+Benutzer+phil+gesendet.+Sie+wird+drei+Mal+wiederholt.%0A%09%09Um+dich+von+Anrufen+wie+diesen+abzumelden%2C+entferne+deine+Telefonnummer+in+der+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay+language%3D%22de-DE%22%3EAuf+Wiederh%C3%B6ren.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
called.Store(true)
}))
defer twilioServer.Close()
forEachBackend(t, func(t *testing.T) {
var called atomic.Bool
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if called.Load() {
t.Fatal("Should be only called once")
}
body, err := io.ReadAll(r.Body)
require.Nil(t, err)
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
require.Equal(t, "From=%2B1234567890&To=%2B11122233344&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+language%3D%22de-DE%22+loop%3D%223%22%3E%0A%09%09Du+hast+eine+Nachricht+von+notify+im+Thema+mytopic.+Nachricht%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09Ende+der+Nachricht.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09Diese+Nachricht+wurde+von+Benutzer+phil+gesendet.+Sie+wird+drei+Mal+wiederholt.%0A%09%09Um+dich+von+Anrufen+wie+diesen+abzumelden%2C+entferne+deine+Telefonnummer+in+der+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay+language%3D%22de-DE%22%3EAuf+Wiederh%C3%B6ren.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
called.Store(true)
}))
defer twilioServer.Close()
c := newTestConfigWithAuthFile(t)
c.TwilioCallsBaseURL = twilioServer.URL
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
c.TwilioPhoneNumber = "+1234567890"
c.TwilioCallFormat = template.Must(template.New("twiml").Parse(`
c := newTestConfigWithAuthFile(t)
c.TwilioCallsBaseURL = twilioServer.URL
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
c.TwilioPhoneNumber = "+1234567890"
c.TwilioCallFormat = template.Must(template.New("twiml").Parse(`
<Response>
<Pause length="1"/>
<Say language="de-DE" loop="3">
@@ -240,88 +247,97 @@ func TestServer_Twilio_Call_Success_with_custom_twiml(t *testing.T) {
</Say>
<Say language="de-DE">Auf Wiederhören.</Say>
</Response>`))
s := newTestServer(t, c)
s := newTestServer(t, c)
// Add tier and user
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "pro",
MessageLimit: 10,
CallLimit: 1,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
u, err := s.userManager.User("phil")
require.Nil(t, err)
require.Nil(t, s.userManager.AddPhoneNumber(u.ID, "+11122233344"))
// Add tier and user
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "pro",
MessageLimit: 10,
CallLimit: 1,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
u, err := s.userManager.User("phil")
require.Nil(t, err)
require.Nil(t, s.userManager.AddPhoneNumber(u.ID, "+11122233344"))
// Do the thing
response := request(t, s, "POST", "/mytopic", "hi there", map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
"x-call": "+11122233344",
})
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
waitFor(t, func() bool {
return called.Load()
// Do the thing
response := request(t, s, "POST", "/mytopic", "hi there", map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
"x-call": "+11122233344",
})
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
waitFor(t, func() bool {
return called.Load()
})
})
}
func TestServer_Twilio_Call_UnverifiedNumber(t *testing.T) {
c := newTestConfigWithAuthFile(t)
c.TwilioCallsBaseURL = "http://dummy.invalid"
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
c.TwilioPhoneNumber = "+1234567890"
s := newTestServer(t, c)
forEachBackend(t, func(t *testing.T) {
c := newTestConfigWithAuthFile(t)
c.TwilioCallsBaseURL = "http://dummy.invalid"
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
c.TwilioPhoneNumber = "+1234567890"
s := newTestServer(t, c)
// Add tier and user
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "pro",
MessageLimit: 10,
CallLimit: 1,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
// Add tier and user
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "pro",
MessageLimit: 10,
CallLimit: 1,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
// Do the thing
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
"x-call": "+11122233344",
// Do the thing
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
"authorization": util.BasicAuth("phil", "phil"),
"x-call": "+11122233344",
})
require.Equal(t, 40034, toHTTPError(t, response.Body.String()).Code)
})
require.Equal(t, 40034, toHTTPError(t, response.Body.String()).Code)
}
func TestServer_Twilio_Call_InvalidNumber(t *testing.T) {
c := newTestConfigWithAuthFile(t)
c.TwilioCallsBaseURL = "https://127.0.0.1"
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
c.TwilioPhoneNumber = "+1234567890"
s := newTestServer(t, c)
forEachBackend(t, func(t *testing.T) {
c := newTestConfigWithAuthFile(t)
c.TwilioCallsBaseURL = "https://127.0.0.1"
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
c.TwilioPhoneNumber = "+1234567890"
s := newTestServer(t, c)
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
"x-call": "+invalid",
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
"x-call": "+invalid",
})
require.Equal(t, 40033, toHTTPError(t, response.Body.String()).Code)
})
require.Equal(t, 40033, toHTTPError(t, response.Body.String()).Code)
}
func TestServer_Twilio_Call_Anonymous(t *testing.T) {
c := newTestConfigWithAuthFile(t)
c.TwilioCallsBaseURL = "https://127.0.0.1"
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
c.TwilioPhoneNumber = "+1234567890"
s := newTestServer(t, c)
forEachBackend(t, func(t *testing.T) {
c := newTestConfigWithAuthFile(t)
c.TwilioCallsBaseURL = "https://127.0.0.1"
c.TwilioAccount = "AC1234567890"
c.TwilioAuthToken = "AAEAA1234567890"
c.TwilioPhoneNumber = "+1234567890"
s := newTestServer(t, c)
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
"x-call": "+123123",
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
"x-call": "+123123",
})
require.Equal(t, 40035, toHTTPError(t, response.Body.String()).Code)
})
require.Equal(t, 40035, toHTTPError(t, response.Body.String()).Code)
}
func TestServer_Twilio_Call_Unconfigured(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
"x-call": "+1234",
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
"x-call": "+1234",
})
require.Equal(t, 40032, toHTTPError(t, response.Body.String()).Code)
})
require.Equal(t, 40032, toHTTPError(t, response.Body.String()).Code)
}

View File

@@ -11,6 +11,7 @@ import (
"github.com/SherClockHolmes/webpush-go"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/user"
wpush "heckel.io/ntfy/v2/webpush"
)
@@ -83,14 +84,14 @@ func (s *Server) handleWebPushDelete(w http.ResponseWriter, r *http.Request, _ *
return s.writeJSON(w, newSuccessResponse())
}
func (s *Server) publishToWebPushEndpoints(v *visitor, m *message) {
func (s *Server) publishToWebPushEndpoints(v *visitor, m *model.Message) {
subscriptions, err := s.webPush.SubscriptionsForTopic(m.Topic)
if err != nil {
logvm(v, m).Err(err).With(v, m).Warn("Unable to publish web push messages")
return
}
log.Tag(tagWebPush).With(v, m).Debug("Publishing web push message to %d subscribers", len(subscriptions))
payload, err := json.Marshal(newWebPushPayload(fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic), m.forJSON()))
payload, err := json.Marshal(newWebPushPayload(fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic), m.ForJSON()))
if err != nil {
log.Tag(tagWebPush).Err(err).With(v, m).Warn("Unable to marshal expiring payload")
return

View File

@@ -4,6 +4,8 @@ package server
import (
"net/http"
"heckel.io/ntfy/v2/model"
)
const (
@@ -20,7 +22,7 @@ func (s *Server) handleWebPushDelete(w http.ResponseWriter, r *http.Request, _ *
return errHTTPNotFound
}
func (s *Server) publishToWebPushEndpoints(v *visitor, m *message) {
func (s *Server) publishToWebPushEndpoints(v *visitor, m *model.Message) {
// Nothing to see here
}

View File

@@ -26,236 +26,262 @@ const (
)
func TestServer_WebPush_Enabled(t *testing.T) {
conf := newTestConfig(t)
conf.WebRoot = "" // Disable web app
s := newTestServer(t, conf)
forEachBackend(t, func(t *testing.T) {
conf := newTestConfig(t)
conf.WebRoot = "" // Disable web app
s := newTestServer(t, conf)
rr := request(t, s, "GET", "/manifest.webmanifest", "", nil)
require.Equal(t, 404, rr.Code)
rr := request(t, s, "GET", "/manifest.webmanifest", "", nil)
require.Equal(t, 404, rr.Code)
conf2 := newTestConfig(t)
s2 := newTestServer(t, conf2)
conf2 := newTestConfig(t)
s2 := newTestServer(t, conf2)
rr = request(t, s2, "GET", "/manifest.webmanifest", "", nil)
require.Equal(t, 404, rr.Code)
rr = request(t, s2, "GET", "/manifest.webmanifest", "", nil)
require.Equal(t, 404, rr.Code)
conf3 := newTestConfigWithWebPush(t)
s3 := newTestServer(t, conf3)
conf3 := newTestConfigWithWebPush(t)
s3 := newTestServer(t, conf3)
rr = request(t, s3, "GET", "/manifest.webmanifest", "", nil)
require.Equal(t, 200, rr.Code)
require.Equal(t, "application/manifest+json", rr.Header().Get("Content-Type"))
rr = request(t, s3, "GET", "/manifest.webmanifest", "", nil)
require.Equal(t, 200, rr.Code)
require.Equal(t, "application/manifest+json", rr.Header().Get("Content-Type"))
})
}
func TestServer_WebPush_Disabled(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
require.Equal(t, 404, response.Code)
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
require.Equal(t, 404, response.Code)
})
}
func TestServer_WebPush_TopicAdd(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
require.Equal(t, 200, response.Code)
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
require.Equal(t, 200, response.Code)
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
require.Nil(t, err)
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
require.Nil(t, err)
require.Len(t, subs, 1)
require.Equal(t, subs[0].Endpoint, testWebPushEndpoint)
require.Equal(t, subs[0].P256dh, "p256dh-key")
require.Equal(t, subs[0].Auth, "auth-key")
require.Equal(t, subs[0].UserID, "")
require.Len(t, subs, 1)
require.Equal(t, subs[0].Endpoint, testWebPushEndpoint)
require.Equal(t, subs[0].P256dh, "p256dh-key")
require.Equal(t, subs[0].Auth, "auth-key")
require.Equal(t, subs[0].UserID, "")
})
}
func TestServer_WebPush_TopicAdd_InvalidEndpoint(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, "https://ddos-target.example.com/webpush"), nil)
require.Equal(t, 400, response.Code)
require.Equal(t, `{"code":40039,"http":400,"error":"invalid request: web push endpoint unknown"}`+"\n", response.Body.String())
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, "https://ddos-target.example.com/webpush"), nil)
require.Equal(t, 400, response.Code)
require.Equal(t, `{"code":40039,"http":400,"error":"invalid request: web push endpoint unknown"}`+"\n", response.Body.String())
})
}
func TestServer_WebPush_TopicAdd_TooManyTopics(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
topicList := make([]string, 51)
for i := range topicList {
topicList[i] = util.RandomString(5)
}
topicList := make([]string, 51)
for i := range topicList {
topicList[i] = util.RandomString(5)
}
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, topicList, testWebPushEndpoint), nil)
require.Equal(t, 400, response.Code)
require.Equal(t, `{"code":40040,"http":400,"error":"invalid request: too many web push topic subscriptions"}`+"\n", response.Body.String())
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, topicList, testWebPushEndpoint), nil)
require.Equal(t, 400, response.Code)
require.Equal(t, `{"code":40040,"http":400,"error":"invalid request: too many web push topic subscriptions"}`+"\n", response.Body.String())
})
}
func TestServer_WebPush_TopicUnsubscribe(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
addSubscription(t, s, testWebPushEndpoint, "test-topic")
requireSubscriptionCount(t, s, "test-topic", 1)
addSubscription(t, s, testWebPushEndpoint, "test-topic")
requireSubscriptionCount(t, s, "test-topic", 1)
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{}, testWebPushEndpoint), nil)
require.Equal(t, 200, response.Code)
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{}, testWebPushEndpoint), nil)
require.Equal(t, 200, response.Code)
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
requireSubscriptionCount(t, s, "test-topic", 0)
requireSubscriptionCount(t, s, "test-topic", 0)
})
}
func TestServer_WebPush_Delete(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
addSubscription(t, s, testWebPushEndpoint, "test-topic")
requireSubscriptionCount(t, s, "test-topic", 1)
addSubscription(t, s, testWebPushEndpoint, "test-topic")
requireSubscriptionCount(t, s, "test-topic", 1)
response := request(t, s, "DELETE", "/v1/webpush", fmt.Sprintf(`{"endpoint":"%s"}`, testWebPushEndpoint), nil)
require.Equal(t, 200, response.Code)
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
response := request(t, s, "DELETE", "/v1/webpush", fmt.Sprintf(`{"endpoint":"%s"}`, testWebPushEndpoint), nil)
require.Equal(t, 200, response.Code)
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
requireSubscriptionCount(t, s, "test-topic", 0)
requireSubscriptionCount(t, s, "test-topic", 0)
})
}
func TestServer_WebPush_TopicSubscribeProtected_Allowed(t *testing.T) {
config := configureAuth(t, newTestConfigWithWebPush(t))
config.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, config)
forEachBackend(t, func(t *testing.T) {
config := configureAuth(t, newTestConfigWithWebPush(t))
config.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, config)
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
require.Nil(t, s.userManager.AllowAccess("ben", "test-topic", user.PermissionReadWrite))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
require.Nil(t, s.userManager.AllowAccess("ben", "test-topic", user.PermissionReadWrite))
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 200, response.Code)
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
require.Nil(t, err)
require.Len(t, subs, 1)
require.True(t, strings.HasPrefix(subs[0].UserID, "u_"))
})
require.Equal(t, 200, response.Code)
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
require.Nil(t, err)
require.Len(t, subs, 1)
require.True(t, strings.HasPrefix(subs[0].UserID, "u_"))
}
func TestServer_WebPush_TopicSubscribeProtected_Denied(t *testing.T) {
config := configureAuth(t, newTestConfigWithWebPush(t))
config.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, config)
forEachBackend(t, func(t *testing.T) {
config := configureAuth(t, newTestConfigWithWebPush(t))
config.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, config)
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
require.Equal(t, 403, response.Code)
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
require.Equal(t, 403, response.Code)
requireSubscriptionCount(t, s, "test-topic", 0)
requireSubscriptionCount(t, s, "test-topic", 0)
})
}
func TestServer_WebPush_DeleteAccountUnsubscribe(t *testing.T) {
config := configureAuth(t, newTestConfigWithWebPush(t))
s := newTestServer(t, config)
forEachBackend(t, func(t *testing.T) {
config := configureAuth(t, newTestConfigWithWebPush(t))
s := newTestServer(t, config)
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
require.Nil(t, s.userManager.AllowAccess("ben", "test-topic", user.PermissionReadWrite))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
require.Nil(t, s.userManager.AllowAccess("ben", "test-topic", user.PermissionReadWrite))
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 200, response.Code)
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
requireSubscriptionCount(t, s, "test-topic", 1)
request(t, s, "DELETE", "/v1/account", `{"password":"ben"}`, map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
// should've been deleted with the account
requireSubscriptionCount(t, s, "test-topic", 0)
})
require.Equal(t, 200, response.Code)
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
requireSubscriptionCount(t, s, "test-topic", 1)
request(t, s, "DELETE", "/v1/account", `{"password":"ben"}`, map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
// should've been deleted with the account
requireSubscriptionCount(t, s, "test-topic", 0)
}
func TestServer_WebPush_Publish(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
var received atomic.Bool
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := io.ReadAll(r.Body)
require.Nil(t, err)
require.Equal(t, "/push-receive", r.URL.Path)
require.Equal(t, "high", r.Header.Get("Urgency"))
require.Equal(t, "", r.Header.Get("Topic"))
received.Store(true)
}))
defer pushService.Close()
var received atomic.Bool
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := io.ReadAll(r.Body)
require.Nil(t, err)
require.Equal(t, "/push-receive", r.URL.Path)
require.Equal(t, "high", r.Header.Get("Urgency"))
require.Equal(t, "", r.Header.Get("Topic"))
received.Store(true)
}))
defer pushService.Close()
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic")
request(t, s, "POST", "/test-topic", "web push test", nil)
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic")
request(t, s, "POST", "/test-topic", "web push test", nil)
waitFor(t, func() bool {
return received.Load()
waitFor(t, func() bool {
return received.Load()
})
})
}
func TestServer_WebPush_Publish_RemoveOnError(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
var received atomic.Bool
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := io.ReadAll(r.Body)
require.Nil(t, err)
w.WriteHeader(http.StatusGone)
received.Store(true)
}))
defer pushService.Close()
var received atomic.Bool
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := io.ReadAll(r.Body)
require.Nil(t, err)
w.WriteHeader(http.StatusGone)
received.Store(true)
}))
defer pushService.Close()
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic", "test-topic-abc")
requireSubscriptionCount(t, s, "test-topic", 1)
requireSubscriptionCount(t, s, "test-topic-abc", 1)
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic", "test-topic-abc")
requireSubscriptionCount(t, s, "test-topic", 1)
requireSubscriptionCount(t, s, "test-topic-abc", 1)
request(t, s, "POST", "/test-topic", "web push test", nil)
request(t, s, "POST", "/test-topic", "web push test", nil)
waitFor(t, func() bool {
return received.Load()
waitFor(t, func() bool {
return received.Load()
})
// Receiving the 410 should've caused the publisher to expire all subscriptions on the endpoint
requireSubscriptionCount(t, s, "test-topic", 0)
requireSubscriptionCount(t, s, "test-topic-abc", 0)
})
// Receiving the 410 should've caused the publisher to expire all subscriptions on the endpoint
requireSubscriptionCount(t, s, "test-topic", 0)
requireSubscriptionCount(t, s, "test-topic-abc", 0)
}
func TestServer_WebPush_Expiry(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
forEachBackend(t, func(t *testing.T) {
s := newTestServer(t, newTestConfigWithWebPush(t))
var received atomic.Bool
var received atomic.Bool
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := io.ReadAll(r.Body)
require.Nil(t, err)
w.WriteHeader(200)
w.Write([]byte(``))
received.Store(true)
}))
defer pushService.Close()
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := io.ReadAll(r.Body)
require.Nil(t, err)
w.WriteHeader(200)
w.Write([]byte(``))
received.Store(true)
}))
defer pushService.Close()
endpoint := pushService.URL + "/push-receive"
addSubscription(t, s, endpoint, "test-topic")
requireSubscriptionCount(t, s, "test-topic", 1)
endpoint := pushService.URL + "/push-receive"
addSubscription(t, s, endpoint, "test-topic")
requireSubscriptionCount(t, s, "test-topic", 1)
require.Nil(t, s.webPush.SetSubscriptionUpdatedAt(endpoint, time.Now().Add(-55*24*time.Hour).Unix()))
require.Nil(t, s.webPush.SetSubscriptionUpdatedAt(endpoint, time.Now().Add(-55*24*time.Hour).Unix()))
s.pruneAndNotifyWebPushSubscriptions()
requireSubscriptionCount(t, s, "test-topic", 1)
s.pruneAndNotifyWebPushSubscriptions()
requireSubscriptionCount(t, s, "test-topic", 1)
waitFor(t, func() bool {
return received.Load()
})
waitFor(t, func() bool {
return received.Load()
})
require.Nil(t, s.webPush.SetSubscriptionUpdatedAt(endpoint, time.Now().Add(-60*24*time.Hour).Unix()))
require.Nil(t, s.webPush.SetSubscriptionUpdatedAt(endpoint, time.Now().Add(-60*24*time.Hour).Unix()))
s.pruneAndNotifyWebPushSubscriptions()
waitFor(t, func() bool {
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
require.Nil(t, err)
return len(subs) == 0
s.pruneAndNotifyWebPushSubscriptions()
waitFor(t, func() bool {
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
require.Nil(t, err)
return len(subs) == 0
})
})
}
@@ -285,7 +311,9 @@ func newTestConfigWithWebPush(t *testing.T) *Config {
conf := newTestConfig(t)
privateKey, publicKey, err := webpush.GenerateVAPIDKeys()
require.Nil(t, err)
conf.WebPushFile = filepath.Join(t.TempDir(), "webpush.db")
if conf.DatabaseURL == "" {
conf.WebPushFile = filepath.Join(t.TempDir(), "webpush.db")
}
conf.WebPushEmailAddress = "testing@example.com"
conf.WebPushPrivateKey = privateKey
conf.WebPushPublicKey = publicKey

View File

@@ -12,11 +12,12 @@ import (
"time"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/util"
)
type mailer interface {
Send(v *visitor, m *message, to string) error
Send(v *visitor, m *model.Message, to string) error
Counts() (total int64, success int64, failure int64)
}
@@ -27,7 +28,7 @@ type smtpSender struct {
mu sync.Mutex
}
func (s *smtpSender) Send(v *visitor, m *message, to string) error {
func (s *smtpSender) Send(v *visitor, m *model.Message, to string) error {
return s.withCount(v, m, func() error {
host, _, err := net.SplitHostPort(s.config.SMTPSenderAddr)
if err != nil {
@@ -63,7 +64,7 @@ func (s *smtpSender) Counts() (total int64, success int64, failure int64) {
return s.success + s.failure, s.success, s.failure
}
func (s *smtpSender) withCount(v *visitor, m *message, fn func() error) error {
func (s *smtpSender) withCount(v *visitor, m *model.Message, fn func() error) error {
err := fn()
s.mu.Lock()
defer s.mu.Unlock()
@@ -76,7 +77,7 @@ func (s *smtpSender) withCount(v *visitor, m *message, fn func() error) error {
return err
}
func formatMail(baseURL, senderIP, from, to string, m *message) (string, error) {
func formatMail(baseURL, senderIP, from, to string, m *model.Message) (string, error) {
topicURL := baseURL + "/" + m.Topic
subject := m.Title
if subject == "" {

View File

@@ -1,12 +1,14 @@
package server
import (
"github.com/stretchr/testify/require"
"testing"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/model"
)
func TestFormatMail_Basic(t *testing.T) {
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &message{
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &model.Message{
ID: "abc",
Time: 1640382204,
Event: "message",
@@ -27,7 +29,7 @@ This message was sent by 1.2.3.4 at Fri, 24 Dec 2021 21:43:24 UTC via https://nt
}
func TestFormatMail_JustEmojis(t *testing.T) {
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &message{
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &model.Message{
ID: "abc",
Time: 1640382204,
Event: "message",
@@ -49,7 +51,7 @@ This message was sent by 1.2.3.4 at Fri, 24 Dec 2021 21:43:24 UTC via https://nt
}
func TestFormatMail_JustOtherTags(t *testing.T) {
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &message{
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &model.Message{
ID: "abc",
Time: 1640382204,
Event: "message",
@@ -73,7 +75,7 @@ This message was sent by 1.2.3.4 at Fri, 24 Dec 2021 21:43:24 UTC via https://nt
}
func TestFormatMail_JustPriority(t *testing.T) {
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &message{
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &model.Message{
ID: "abc",
Time: 1640382204,
Event: "message",
@@ -97,7 +99,7 @@ This message was sent by 1.2.3.4 at Fri, 24 Dec 2021 21:43:24 UTC via https://nt
}
func TestFormatMail_UTF8Subject(t *testing.T) {
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &message{
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &model.Message{
ID: "abc",
Time: 1640382204,
Event: "message",
@@ -119,7 +121,7 @@ This message was sent by 1.2.3.4 at Fri, 24 Dec 2021 21:43:24 UTC via https://nt
}
func TestFormatMail_WithAllTheThings(t *testing.T) {
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &message{
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &model.Message{
ID: "abc",
Time: 1640382204,
Event: "message",

View File

@@ -19,6 +19,7 @@ import (
"github.com/emersion/go-smtp"
"github.com/microcosm-cc/bluemonday"
"heckel.io/ntfy/v2/model"
)
var (
@@ -183,7 +184,7 @@ func (s *smtpSession) Data(r io.Reader) error {
})
}
func (s *smtpSession) publishMessage(m *message) error {
func (s *smtpSession) publishMessage(m *model.Message) error {
// Extract remote address (for rate limiting)
remoteAddr, _, err := net.SplitHostPort(s.conn.Conn().RemoteAddr().String())
if err != nil {

View File

@@ -6,6 +6,7 @@ import (
"time"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/util"
)
@@ -33,7 +34,7 @@ type topicSubscriber struct {
}
// subscriber is a function that is called for every new message on a topic
type subscriber func(v *visitor, msg *message) error
type subscriber func(v *visitor, msg *model.Message) error
// newTopic creates a new topic
func newTopic(id string) *topic {
@@ -103,7 +104,7 @@ func (t *topic) Unsubscribe(id int) {
}
// Publish asynchronously publishes to all subscribers
func (t *topic) Publish(v *visitor, m *message) error {
func (t *topic) Publish(v *visitor, m *model.Message) error {
go func() {
// We want to lock the topic as short as possible, so we make a shallow copy of the
// subscribers map here. Actually sending out the messages then doesn't have to lock.

View File

@@ -7,10 +7,11 @@ import (
"time"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/model"
)
func TestTopic_CancelSubscribersExceptUser(t *testing.T) {
subFn := func(v *visitor, msg *message) error {
subFn := func(v *visitor, msg *model.Message) error {
return nil
}
canceled1 := atomic.Bool{}
@@ -33,7 +34,7 @@ func TestTopic_CancelSubscribersExceptUser(t *testing.T) {
func TestTopic_CancelSubscribersUser(t *testing.T) {
t.Parallel()
subFn := func(v *visitor, msg *message) error {
subFn := func(v *visitor, msg *model.Message) error {
return nil
}
canceled1 := atomic.Bool{}
@@ -76,7 +77,7 @@ func TestTopic_Subscribe_DuplicateID(t *testing.T) {
cancel: func() {},
}
subFn := func(v *visitor, msg *message) error {
subFn := func(v *visitor, msg *model.Message) error {
return nil
}

View File

@@ -2,219 +2,78 @@ package server
import (
"net/http"
"net/netip"
"time"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/user"
"heckel.io/ntfy/v2/util"
)
// List of possible events
// Event constants
const (
openEvent = "open"
keepaliveEvent = "keepalive"
messageEvent = "message"
messageDeleteEvent = "message_delete"
messageClearEvent = "message_clear"
pollRequestEvent = "poll_request"
openEvent = model.OpenEvent
keepaliveEvent = model.KeepaliveEvent
messageEvent = model.MessageEvent
messageDeleteEvent = model.MessageDeleteEvent
messageClearEvent = model.MessageClearEvent
pollRequestEvent = model.PollRequestEvent
messageIDLength = model.MessageIDLength
)
const (
messageIDLength = 12
// SinceMarker aliases
var (
sinceAllMessages = model.SinceAllMessages
sinceNoMessages = model.SinceNoMessages
sinceLatestMessage = model.SinceLatestMessage
)
// message represents a message published to a topic
type message struct {
ID string `json:"id"` // Random message ID
SequenceID string `json:"sequence_id,omitempty"` // Message sequence ID for updating message contents (omitted if same as ID)
Time int64 `json:"time"` // Unix time in seconds
Expires int64 `json:"expires,omitempty"` // Unix time in seconds (not required for open/keepalive)
Event string `json:"event"` // One of the above
Topic string `json:"topic"`
Title string `json:"title,omitempty"`
Message string `json:"message,omitempty"`
Priority int `json:"priority,omitempty"`
Tags []string `json:"tags,omitempty"`
Click string `json:"click,omitempty"`
Icon string `json:"icon,omitempty"`
Actions []*action `json:"actions,omitempty"`
Attachment *attachment `json:"attachment,omitempty"`
PollID string `json:"poll_id,omitempty"`
ContentType string `json:"content_type,omitempty"` // text/plain by default (if empty), or text/markdown
Encoding string `json:"encoding,omitempty"` // Empty for raw UTF-8, or "base64" for encoded bytes
Sender netip.Addr `json:"-"` // IP address of uploader, used for rate limiting
User string `json:"-"` // UserID of the uploader, used to associated attachments
}
// Error aliases
var (
errMessageNotFound = model.ErrMessageNotFound
)
func (m *message) Context() log.Context {
fields := map[string]any{
"topic": m.Topic,
"message_id": m.ID,
"message_sequence_id": m.SequenceID,
"message_time": m.Time,
"message_event": m.Event,
"message_body_size": len(m.Message),
}
if m.Sender.IsValid() {
fields["message_sender"] = m.Sender.String()
}
if m.User != "" {
fields["message_user"] = m.User
}
return fields
}
// forJSON returns a copy of the message suitable for JSON output.
// It clears the SequenceID if it equals the ID to reduce redundancy.
func (m *message) forJSON() *message {
if m.SequenceID == m.ID {
clone := *m
clone.SequenceID = ""
return &clone
}
return m
}
type attachment struct {
Name string `json:"name"`
Type string `json:"type,omitempty"`
Size int64 `json:"size,omitempty"`
Expires int64 `json:"expires,omitempty"`
URL string `json:"url"`
}
type action struct {
ID string `json:"id"`
Action string `json:"action"` // "view", "broadcast", "http", or "copy"
Label string `json:"label"` // action button label
Clear bool `json:"clear"` // clear notification after successful execution
URL string `json:"url,omitempty"` // used in "view" and "http" actions
Method string `json:"method,omitempty"` // used in "http" action, default is POST (!)
Headers map[string]string `json:"headers,omitempty"` // used in "http" action
Body string `json:"body,omitempty"` // used in "http" action
Intent string `json:"intent,omitempty"` // used in "broadcast" action
Extras map[string]string `json:"extras,omitempty"` // used in "broadcast" action
Value string `json:"value,omitempty"` // used in "copy" action
}
func newAction() *action {
return &action{
Headers: make(map[string]string),
Extras: make(map[string]string),
}
}
// publishMessage is used as input when publishing as JSON
type publishMessage struct {
Topic string `json:"topic"`
SequenceID string `json:"sequence_id"`
Title string `json:"title"`
Message string `json:"message"`
Priority int `json:"priority"`
Tags []string `json:"tags"`
Click string `json:"click"`
Icon string `json:"icon"`
Actions []action `json:"actions"`
Attach string `json:"attach"`
Markdown bool `json:"markdown"`
Filename string `json:"filename"`
Email string `json:"email"`
Call string `json:"call"`
Cache string `json:"cache"` // use string as it defaults to true (or use &bool instead)
Firebase string `json:"firebase"` // use string as it defaults to true (or use &bool instead)
Delay string `json:"delay"`
}
// messageEncoder is a function that knows how to encode a message
type messageEncoder func(msg *message) (string, error)
// newMessage creates a new message with the current timestamp
func newMessage(event, topic, msg string) *message {
return &message{
ID: util.RandomString(messageIDLength),
Time: time.Now().Unix(),
Event: event,
Topic: topic,
Message: msg,
}
}
// newOpenMessage is a convenience method to create an open message
func newOpenMessage(topic string) *message {
return newMessage(openEvent, topic, "")
}
// newKeepaliveMessage is a convenience method to create a keepalive message
func newKeepaliveMessage(topic string) *message {
return newMessage(keepaliveEvent, topic, "")
}
// newDefaultMessage is a convenience method to create a notification message
func newDefaultMessage(topic, msg string) *message {
return newMessage(messageEvent, topic, msg)
}
// Constructors and helpers
var (
newMessage = model.NewMessage
newDefaultMessage = model.NewDefaultMessage
newOpenMessage = model.NewOpenMessage
newKeepaliveMessage = model.NewKeepaliveMessage
newActionMessage = model.NewActionMessage
newAction = model.NewAction
newSinceTime = model.NewSinceTime
newSinceID = model.NewSinceID
validMessageID = model.ValidMessageID
)
// newPollRequestMessage is a convenience method to create a poll request message
func newPollRequestMessage(topic, pollID string) *message {
func newPollRequestMessage(topic, pollID string) *model.Message {
m := newMessage(pollRequestEvent, topic, newMessageBody)
m.PollID = pollID
return m
}
// newActionMessage creates a new action message (message_delete or message_clear)
func newActionMessage(event, topic, sequenceID string) *message {
m := newMessage(event, topic, "")
m.SequenceID = sequenceID
return m
// publishMessage is used as input when publishing as JSON
type publishMessage struct {
Topic string `json:"topic"`
SequenceID string `json:"sequence_id"`
Title string `json:"title"`
Message string `json:"message"`
Priority int `json:"priority"`
Tags []string `json:"tags"`
Click string `json:"click"`
Icon string `json:"icon"`
Actions []model.Action `json:"actions"`
Attach string `json:"attach"`
Markdown bool `json:"markdown"`
Filename string `json:"filename"`
Email string `json:"email"`
Call string `json:"call"`
Cache string `json:"cache"` // use string as it defaults to true (or use &bool instead)
Firebase string `json:"firebase"` // use string as it defaults to true (or use &bool instead)
Delay string `json:"delay"`
}
func validMessageID(s string) bool {
return util.ValidRandomString(s, messageIDLength)
}
type sinceMarker struct {
time time.Time
id string
}
func newSinceTime(timestamp int64) sinceMarker {
return sinceMarker{time.Unix(timestamp, 0), ""}
}
func newSinceID(id string) sinceMarker {
return sinceMarker{time.Unix(0, 0), id}
}
func (t sinceMarker) IsAll() bool {
return t == sinceAllMessages
}
func (t sinceMarker) IsNone() bool {
return t == sinceNoMessages
}
func (t sinceMarker) IsLatest() bool {
return t == sinceLatestMessage
}
func (t sinceMarker) IsID() bool {
return t.id != "" && t.id != "latest"
}
func (t sinceMarker) Time() time.Time {
return t.time
}
func (t sinceMarker) ID() string {
return t.id
}
var (
sinceAllMessages = sinceMarker{time.Unix(0, 0), ""}
sinceNoMessages = sinceMarker{time.Unix(1, 0), ""}
sinceLatestMessage = sinceMarker{time.Unix(0, 0), "latest"}
)
// messageEncoder is a function that knows how to encode a message
type messageEncoder func(msg *model.Message) (string, error)
type queryFilter struct {
ID string
@@ -246,7 +105,7 @@ func parseQueryFilters(r *http.Request) (*queryFilter, error) {
}, nil
}
func (q *queryFilter) Pass(msg *message) bool {
func (q *queryFilter) Pass(msg *model.Message) bool {
if msg.Event != messageEvent && msg.Event != messageDeleteEvent && msg.Event != messageClearEvent {
return true // filters only apply to messages
} else if q.ID != "" && msg.ID != q.ID {
@@ -570,12 +429,12 @@ const (
)
type webPushPayload struct {
Event string `json:"event"`
SubscriptionID string `json:"subscription_id"`
Message *message `json:"message"`
Event string `json:"event"`
SubscriptionID string `json:"subscription_id"`
Message *model.Message `json:"message"`
}
func newWebPushPayload(subscriptionID string, message *message) *webPushPayload {
func newWebPushPayload(subscriptionID string, message *model.Message) *webPushPayload {
return &webPushPayload{
Event: webPushMessageEvent,
SubscriptionID: subscriptionID,

View File

@@ -8,6 +8,7 @@ import (
"golang.org/x/time/rate"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/message"
"heckel.io/ntfy/v2/user"
"heckel.io/ntfy/v2/util"
)
@@ -53,7 +54,7 @@ const (
// visitor represents an API user, and its associated rate.Limiter used for rate limiting
type visitor struct {
config *Config
messageCache *messageCache
messageCache message.Store
userManager *user.Manager // May be nil
ip netip.Addr // Visitor IP address
user *user.User // Only set if authenticated user, otherwise nil
@@ -114,7 +115,7 @@ const (
visitorLimitBasisTier = visitorLimitBasis("tier")
)
func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor {
func newVisitor(conf *Config, messageCache message.Store, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor {
var messages, emails, calls int64
if user != nil {
messages = user.Stats.Messages

View File

@@ -691,7 +691,6 @@ func TestManager_Token_Expire(t *testing.T) {
// But the token row should still exist
tokens, err := a.Tokens(u.ID)
require.Nil(t, err)
require.Equal(t, token1.Value, tokens[0].Value)
require.Equal(t, 2, len(tokens))
// Expire tokens and check that token1 is gone

View File

@@ -94,38 +94,6 @@ func nullInt64(v int64) sql.NullInt64 {
return sql.NullInt64{Int64: v, Valid: true}
}
// execTx executes a function in a transaction. If the function returns an error, the transaction is rolled back.
func execTx(db *sql.DB, f func(tx *sql.Tx) error) error {
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if err := f(tx); err != nil {
return err
}
return tx.Commit()
}
// queryTx executes a function in a transaction and returns the result. If the function
// returns an error, the transaction is rolled back.
func queryTx[T any](db *sql.DB, f func(tx *sql.Tx) (T, error)) (T, error) {
tx, err := db.Begin()
if err != nil {
var zero T
return zero, err
}
defer tx.Rollback()
t, err := f(tx)
if err != nil {
return t, err
}
if err := tx.Commit(); err != nil {
return t, err
}
return t, nil
}
// toSQLWildcard converts a wildcard string to a SQL wildcard string. It only allows '*' as wildcards,
// and escapes '_', assuming '\' as escape character.
func toSQLWildcard(s string) string {