From 2716ede6e1dc92f07719aff48c2a6fa1415009e2 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Wed, 18 Feb 2026 20:22:44 -0500 Subject: [PATCH] Extract message cache into message/ package with model/ types Move message cache from server/message_cache.go into a dedicated message/ package with Store interface, SQLite and PostgreSQL implementations. Extract shared types into model/ package. --- message/store.go | 620 +++++++++++++++++ message/store_postgres.go | 118 ++++ message/store_postgres_schema.go | 90 +++ message/store_postgres_test.go | 120 ++++ message/store_sqlite.go | 138 ++++ message/store_sqlite_schema.go | 466 +++++++++++++ message/store_sqlite_test.go | 459 +++++++++++++ message/store_test.go | 767 +++++++++++++++++++++ model/model.go | 204 ++++++ server/actions.go | 17 +- server/log.go | 5 +- server/message_cache.go | 1104 ------------------------------ server/message_cache_test.go | 825 ---------------------- server/server.go | 78 ++- server/server_firebase.go | 9 +- server/server_firebase_dummy.go | 3 +- server/server_firebase_test.go | 13 +- server/server_test.go | 32 +- server/server_twilio.go | 3 +- server/server_webpush.go | 5 +- server/server_webpush_dummy.go | 4 +- server/smtp_sender.go | 9 +- server/smtp_sender_test.go | 16 +- server/smtp_server.go | 3 +- server/topic.go | 5 +- server/topic_test.go | 7 +- server/types.go | 218 +----- server/visitor.go | 5 +- user/manager_test.go | 1 - 29 files changed, 3142 insertions(+), 2202 deletions(-) create mode 100644 message/store.go create mode 100644 message/store_postgres.go create mode 100644 message/store_postgres_schema.go create mode 100644 message/store_postgres_test.go create mode 100644 message/store_sqlite.go create mode 100644 message/store_sqlite_schema.go create mode 100644 message/store_sqlite_test.go create mode 100644 message/store_test.go create mode 100644 model/model.go delete mode 100644 server/message_cache.go delete mode 100644 server/message_cache_test.go diff --git 
a/message/store.go b/message/store.go new file mode 100644 index 00000000..c51f732f --- /dev/null +++ b/message/store.go @@ -0,0 +1,620 @@ +package message + +import ( + "database/sql" + "encoding/json" + "errors" + "fmt" + "net/netip" + "strings" + "sync" + "time" + + "heckel.io/ntfy/v2/log" + "heckel.io/ntfy/v2/model" + "heckel.io/ntfy/v2/util" +) + +const ( + tagMessageCache = "message_cache" +) + +var errNoRows = errors.New("no rows found") + +// Store is the interface for a message cache store +type Store interface { + AddMessage(m *model.Message) error + AddMessages(ms []*model.Message) error + DB() *sql.DB + Messages(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) + MessagesDue() ([]*model.Message, error) + MessagesExpired() ([]string, error) + Message(id string) (*model.Message, error) + MarkPublished(m *model.Message) error + MessageCounts() (map[string]int, error) + Topics() ([]string, error) + DeleteMessages(ids ...string) error + DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, error) + ExpireMessages(topics ...string) error + AttachmentsExpired() ([]string, error) + MarkAttachmentsDeleted(ids ...string) error + AttachmentBytesUsedBySender(sender string) (int64, error) + AttachmentBytesUsedByUser(userID string) (int64, error) + UpdateStats(messages int64) error + Stats() (int64, error) + Close() error +} + +// storeQueries holds the database-specific SQL queries +type storeQueries struct { + insertMessage string + deleteMessage string + selectScheduledMessageIDsBySeqID string + deleteScheduledBySequenceID string + updateMessagesForTopicExpiry string + selectRowIDFromMessageID string + selectMessagesByID string + selectMessagesSinceTime string + selectMessagesSinceTimeScheduled string + selectMessagesSinceID string + selectMessagesSinceIDScheduled string + selectMessagesLatest string + selectMessagesDue string + selectMessagesExpired string + updateMessagePublished string + selectMessagesCount string + 
selectMessageCountPerTopic string + selectTopics string + updateAttachmentDeleted string + selectAttachmentsExpired string + selectAttachmentsSizeBySender string + selectAttachmentsSizeByUserID string + selectStats string + updateStats string +} + +// commonStore implements store operations that are identical across database backends +type commonStore struct { + db *sql.DB + queue *util.BatchingQueue[*model.Message] + nop bool + mu sync.Mutex + queries storeQueries +} + +func newCommonStore(db *sql.DB, queries storeQueries, batchSize int, batchTimeout time.Duration, nop bool) *commonStore { + var queue *util.BatchingQueue[*model.Message] + if batchSize > 0 || batchTimeout > 0 { + queue = util.NewBatchingQueue[*model.Message](batchSize, batchTimeout) + } + c := &commonStore{ + db: db, + queue: queue, + nop: nop, + queries: queries, + } + go c.processMessageBatches() + return c +} + +// DB returns the underlying database connection +func (c *commonStore) DB() *sql.DB { + return c.db +} + +// AddMessage stores a message to the message cache synchronously, or queues it to be stored at a later date asynchronously. 
+func (c *commonStore) AddMessage(m *model.Message) error { + if c.queue != nil { + c.queue.Enqueue(m) + return nil + } + return c.addMessages([]*model.Message{m}) +} + +// AddMessages synchronously stores a batch of messages to the message cache +func (c *commonStore) AddMessages(ms []*model.Message) error { + return c.addMessages(ms) +} + +func (c *commonStore) addMessages(ms []*model.Message) error { + c.mu.Lock() + defer c.mu.Unlock() + if c.nop { + return nil + } + if len(ms) == 0 { + return nil + } + start := time.Now() + tx, err := c.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + stmt, err := tx.Prepare(c.queries.insertMessage) + if err != nil { + return err + } + defer stmt.Close() + for _, m := range ms { + if m.Event != model.MessageEvent && m.Event != model.MessageDeleteEvent && m.Event != model.MessageClearEvent { + return model.ErrUnexpectedMessageType + } + published := m.Time <= time.Now().Unix() + tags := strings.Join(m.Tags, ",") + var attachmentName, attachmentType, attachmentURL string + var attachmentSize, attachmentExpires int64 + var attachmentDeleted bool + if m.Attachment != nil { + attachmentName = m.Attachment.Name + attachmentType = m.Attachment.Type + attachmentSize = m.Attachment.Size + attachmentExpires = m.Attachment.Expires + attachmentURL = m.Attachment.URL + } + var actionsStr string + if len(m.Actions) > 0 { + actionsBytes, err := json.Marshal(m.Actions) + if err != nil { + return err + } + actionsStr = string(actionsBytes) + } + var sender string + if m.Sender.IsValid() { + sender = m.Sender.String() + } + _, err := stmt.Exec( + m.ID, + m.SequenceID, + m.Time, + m.Event, + m.Expires, + m.Topic, + m.Message, + m.Title, + m.Priority, + tags, + m.Click, + m.Icon, + actionsStr, + attachmentName, + attachmentType, + attachmentSize, + attachmentExpires, + attachmentURL, + attachmentDeleted, // Always zero + sender, + m.User, + m.ContentType, + m.Encoding, + published, + ) + if err != nil { + return err + } + } + 
if err := tx.Commit(); err != nil { + log.Tag(tagMessageCache).Err(err).Error("Writing %d message(s) failed (took %v)", len(ms), time.Since(start)) + return err + } + log.Tag(tagMessageCache).Debug("Wrote %d message(s) in %v", len(ms), time.Since(start)) + return nil +} + +func (c *commonStore) Messages(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) { + if since.IsNone() { + return make([]*model.Message, 0), nil + } else if since.IsLatest() { + return c.messagesLatest(topic) + } else if since.IsID() { + return c.messagesSinceID(topic, since, scheduled) + } + return c.messagesSinceTime(topic, since, scheduled) +} + +func (c *commonStore) messagesSinceTime(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) { + var rows *sql.Rows + var err error + if scheduled { + rows, err = c.db.Query(c.queries.selectMessagesSinceTimeScheduled, topic, since.Time().Unix()) + } else { + rows, err = c.db.Query(c.queries.selectMessagesSinceTime, topic, since.Time().Unix()) + } + if err != nil { + return nil, err + } + return readMessages(rows) +} + +func (c *commonStore) messagesSinceID(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) { + idrows, err := c.db.Query(c.queries.selectRowIDFromMessageID, since.ID()) + if err != nil { + return nil, err + } + defer idrows.Close() + if !idrows.Next() { + return c.messagesSinceTime(topic, model.SinceAllMessages, scheduled) + } + var rowID int64 + if err := idrows.Scan(&rowID); err != nil { + return nil, err + } + idrows.Close() + var rows *sql.Rows + if scheduled { + rows, err = c.db.Query(c.queries.selectMessagesSinceIDScheduled, topic, rowID) + } else { + rows, err = c.db.Query(c.queries.selectMessagesSinceID, topic, rowID) + } + if err != nil { + return nil, err + } + return readMessages(rows) +} + +func (c *commonStore) messagesLatest(topic string) ([]*model.Message, error) { + rows, err := c.db.Query(c.queries.selectMessagesLatest, topic) + if 
err != nil { + return nil, err + } + return readMessages(rows) +} + +func (c *commonStore) MessagesDue() ([]*model.Message, error) { + rows, err := c.db.Query(c.queries.selectMessagesDue, time.Now().Unix()) + if err != nil { + return nil, err + } + return readMessages(rows) +} + +// MessagesExpired returns a list of IDs for messages that have expired (should be deleted) +func (c *commonStore) MessagesExpired() ([]string, error) { + rows, err := c.db.Query(c.queries.selectMessagesExpired, time.Now().Unix()) + if err != nil { + return nil, err + } + defer rows.Close() + ids := make([]string, 0) + for rows.Next() { + var id string + if err := rows.Scan(&id); err != nil { + return nil, err + } + ids = append(ids, id) + } + if err := rows.Err(); err != nil { + return nil, err + } + return ids, nil +} + +func (c *commonStore) Message(id string) (*model.Message, error) { + rows, err := c.db.Query(c.queries.selectMessagesByID, id) + if err != nil { + return nil, err + } + if !rows.Next() { + return nil, model.ErrMessageNotFound + } + defer rows.Close() + return readMessage(rows) +} + +func (c *commonStore) MarkPublished(m *model.Message) error { + c.mu.Lock() + defer c.mu.Unlock() + _, err := c.db.Exec(c.queries.updateMessagePublished, m.ID) + return err +} + +func (c *commonStore) MessageCounts() (map[string]int, error) { + rows, err := c.db.Query(c.queries.selectMessageCountPerTopic) + if err != nil { + return nil, err + } + defer rows.Close() + var topic string + var count int + counts := make(map[string]int) + for rows.Next() { + if err := rows.Scan(&topic, &count); err != nil { + return nil, err + } else if err := rows.Err(); err != nil { + return nil, err + } + counts[topic] = count + } + return counts, nil +} + +func (c *commonStore) Topics() ([]string, error) { + rows, err := c.db.Query(c.queries.selectTopics) + if err != nil { + return nil, err + } + defer rows.Close() + topics := make([]string, 0) + for rows.Next() { + var id string + if err := rows.Scan(&id); 
err != nil { + return nil, err + } + topics = append(topics, id) + } + if err := rows.Err(); err != nil { + return nil, err + } + return topics, nil +} + +func (c *commonStore) DeleteMessages(ids ...string) error { + c.mu.Lock() + defer c.mu.Unlock() + tx, err := c.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + for _, id := range ids { + if _, err := tx.Exec(c.queries.deleteMessage, id); err != nil { + return err + } + } + return tx.Commit() +} + +// DeleteScheduledBySequenceID deletes unpublished (scheduled) messages with the given topic and sequence ID. +// It returns the message IDs of the deleted messages, which can be used to clean up attachment files. +func (c *commonStore) DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, error) { + c.mu.Lock() + defer c.mu.Unlock() + tx, err := c.db.Begin() + if err != nil { + return nil, err + } + defer tx.Rollback() + rows, err := tx.Query(c.queries.selectScheduledMessageIDsBySeqID, topic, sequenceID) + if err != nil { + return nil, err + } + defer rows.Close() + ids := make([]string, 0) + for rows.Next() { + var id string + if err := rows.Scan(&id); err != nil { + return nil, err + } + ids = append(ids, id) + } + if err := rows.Err(); err != nil { + return nil, err + } + rows.Close() + if _, err := tx.Exec(c.queries.deleteScheduledBySequenceID, topic, sequenceID); err != nil { + return nil, err + } + if err := tx.Commit(); err != nil { + return nil, err + } + return ids, nil +} + +func (c *commonStore) ExpireMessages(topics ...string) error { + c.mu.Lock() + defer c.mu.Unlock() + tx, err := c.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + for _, t := range topics { + if _, err := tx.Exec(c.queries.updateMessagesForTopicExpiry, time.Now().Unix()-1, t); err != nil { + return err + } + } + return tx.Commit() +} + +func (c *commonStore) AttachmentsExpired() ([]string, error) { + rows, err := c.db.Query(c.queries.selectAttachmentsExpired, time.Now().Unix()) + if err 
!= nil { + return nil, err + } + defer rows.Close() + ids := make([]string, 0) + for rows.Next() { + var id string + if err := rows.Scan(&id); err != nil { + return nil, err + } + ids = append(ids, id) + } + if err := rows.Err(); err != nil { + return nil, err + } + return ids, nil +} + +func (c *commonStore) MarkAttachmentsDeleted(ids ...string) error { + c.mu.Lock() + defer c.mu.Unlock() + tx, err := c.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + for _, id := range ids { + if _, err := tx.Exec(c.queries.updateAttachmentDeleted, id); err != nil { + return err + } + } + return tx.Commit() +} + +func (c *commonStore) AttachmentBytesUsedBySender(sender string) (int64, error) { + rows, err := c.db.Query(c.queries.selectAttachmentsSizeBySender, sender, time.Now().Unix()) + if err != nil { + return 0, err + } + return c.readAttachmentBytesUsed(rows) +} + +func (c *commonStore) AttachmentBytesUsedByUser(userID string) (int64, error) { + rows, err := c.db.Query(c.queries.selectAttachmentsSizeByUserID, userID, time.Now().Unix()) + if err != nil { + return 0, err + } + return c.readAttachmentBytesUsed(rows) +} + +func (c *commonStore) readAttachmentBytesUsed(rows *sql.Rows) (int64, error) { + defer rows.Close() + var size int64 + if !rows.Next() { + return 0, errors.New("no rows found") + } + if err := rows.Scan(&size); err != nil { + return 0, err + } else if err := rows.Err(); err != nil { + return 0, err + } + return size, nil +} + +func (c *commonStore) UpdateStats(messages int64) error { + c.mu.Lock() + defer c.mu.Unlock() + _, err := c.db.Exec(c.queries.updateStats, messages) + return err +} + +func (c *commonStore) Stats() (messages int64, err error) { + rows, err := c.db.Query(c.queries.selectStats) + if err != nil { + return 0, err + } + defer rows.Close() + if !rows.Next() { + return 0, errNoRows + } + if err := rows.Scan(&messages); err != nil { + return 0, err + } + return messages, nil +} + +func (c *commonStore) Close() error { + 
return c.db.Close() +} + +func (c *commonStore) processMessageBatches() { + if c.queue == nil { + return + } + for messages := range c.queue.Dequeue() { + if err := c.addMessages(messages); err != nil { + log.Tag(tagMessageCache).Err(err).Error("Cannot write message batch") + } + } +} + +func readMessages(rows *sql.Rows) ([]*model.Message, error) { + defer rows.Close() + messages := make([]*model.Message, 0) + for rows.Next() { + m, err := readMessage(rows) + if err != nil { + return nil, err + } + messages = append(messages, m) + } + if err := rows.Err(); err != nil { + return nil, err + } + return messages, nil +} + +func readMessage(rows *sql.Rows) (*model.Message, error) { + var timestamp, expires, attachmentSize, attachmentExpires int64 + var priority int + var id, sequenceID, event, topic, msg, title, tagsStr, click, icon, actionsStr, attachmentName, attachmentType, attachmentURL, sender, user, contentType, encoding string + err := rows.Scan( + &id, + &sequenceID, + ×tamp, + &event, + &expires, + &topic, + &msg, + &title, + &priority, + &tagsStr, + &click, + &icon, + &actionsStr, + &attachmentName, + &attachmentType, + &attachmentSize, + &attachmentExpires, + &attachmentURL, + &sender, + &user, + &contentType, + &encoding, + ) + if err != nil { + return nil, err + } + var tags []string + if tagsStr != "" { + tags = strings.Split(tagsStr, ",") + } + var actions []*model.Action + if actionsStr != "" { + if err := json.Unmarshal([]byte(actionsStr), &actions); err != nil { + return nil, err + } + } + senderIP, err := netip.ParseAddr(sender) + if err != nil { + senderIP = netip.Addr{} // if no IP stored in database, return invalid address + } + var att *model.Attachment + if attachmentName != "" && attachmentURL != "" { + att = &model.Attachment{ + Name: attachmentName, + Type: attachmentType, + Size: attachmentSize, + Expires: attachmentExpires, + URL: attachmentURL, + } + } + return &model.Message{ + ID: id, + SequenceID: sequenceID, + Time: timestamp, + 
Expires: expires, + Event: event, + Topic: topic, + Message: msg, + Title: title, + Priority: priority, + Tags: tags, + Click: click, + Icon: icon, + Actions: actions, + Attachment: att, + Sender: senderIP, + User: user, + ContentType: contentType, + Encoding: encoding, + }, nil +} + +// Ensure commonStore implements Store +var _ Store = (*commonStore)(nil) + +// Needed by store.go but not part of Store interface; unused import guard +var _ = fmt.Sprintf diff --git a/message/store_postgres.go b/message/store_postgres.go new file mode 100644 index 00000000..3efab244 --- /dev/null +++ b/message/store_postgres.go @@ -0,0 +1,118 @@ +package message + +import ( + "database/sql" + "time" + + _ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver +) + +// PostgreSQL runtime query constants +const ( + pgInsertMessageQuery = ` + INSERT INTO message (mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user_id, content_type, encoding, published) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24) + ` + pgDeleteMessageQuery = `DELETE FROM message WHERE mid = $1` + pgSelectScheduledMessageIDsBySeqIDQuery = `SELECT mid FROM message WHERE topic = $1 AND sequence_id = $2 AND published = FALSE` + pgDeleteScheduledBySequenceIDQuery = `DELETE FROM message WHERE topic = $1 AND sequence_id = $2 AND published = FALSE` + pgUpdateMessagesForTopicExpiryQuery = `UPDATE message SET expires = $1 WHERE topic = $2` + pgSelectRowIDFromMessageID = `SELECT id FROM message WHERE mid = $1` + pgSelectMessagesByIDQuery = ` + SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding + FROM message + 
WHERE mid = $1 + ` + pgSelectMessagesSinceTimeQuery = ` + SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding + FROM message + WHERE topic = $1 AND time >= $2 AND published = TRUE + ORDER BY time, id + ` + pgSelectMessagesSinceTimeIncludeScheduledQuery = ` + SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding + FROM message + WHERE topic = $1 AND time >= $2 + ORDER BY time, id + ` + pgSelectMessagesSinceIDQuery = ` + SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding + FROM message + WHERE topic = $1 AND id > $2 AND published = TRUE + ORDER BY time, id + ` + pgSelectMessagesSinceIDIncludeScheduledQuery = ` + SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding + FROM message + WHERE topic = $1 AND (id > $2 OR published = FALSE) + ORDER BY time, id + ` + pgSelectMessagesLatestQuery = ` + SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding + FROM message + WHERE topic = $1 AND published = TRUE + ORDER BY time DESC, id DESC + LIMIT 1 + ` + pgSelectMessagesDueQuery = ` + SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, 
actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding + FROM message + WHERE time <= $1 AND published = FALSE + ORDER BY time, id + ` + pgSelectMessagesExpiredQuery = `SELECT mid FROM message WHERE expires <= $1 AND published = TRUE` + pgUpdateMessagePublishedQuery = `UPDATE message SET published = TRUE WHERE mid = $1` + pgSelectMessagesCountQuery = `SELECT COUNT(*) FROM message` + pgSelectMessageCountPerTopicQuery = `SELECT topic, COUNT(*) FROM message GROUP BY topic` + pgSelectTopicsQuery = `SELECT topic FROM message GROUP BY topic` + + pgUpdateAttachmentDeleted = `UPDATE message SET attachment_deleted = TRUE WHERE mid = $1` + pgSelectAttachmentsExpiredQuery = `SELECT mid FROM message WHERE attachment_expires > 0 AND attachment_expires <= $1 AND attachment_deleted = FALSE` + pgSelectAttachmentsSizeBySenderQuery = `SELECT COALESCE(SUM(attachment_size), 0) FROM message WHERE user_id = '' AND sender = $1 AND attachment_expires >= $2` + pgSelectAttachmentsSizeByUserIDQuery = `SELECT COALESCE(SUM(attachment_size), 0) FROM message WHERE user_id = $1 AND attachment_expires >= $2` + + pgSelectStatsQuery = `SELECT value FROM message_stats WHERE key = 'messages'` + pgUpdateStatsQuery = `UPDATE message_stats SET value = $1 WHERE key = 'messages'` +) + +var pgQueries = storeQueries{ + insertMessage: pgInsertMessageQuery, + deleteMessage: pgDeleteMessageQuery, + selectScheduledMessageIDsBySeqID: pgSelectScheduledMessageIDsBySeqIDQuery, + deleteScheduledBySequenceID: pgDeleteScheduledBySequenceIDQuery, + updateMessagesForTopicExpiry: pgUpdateMessagesForTopicExpiryQuery, + selectRowIDFromMessageID: pgSelectRowIDFromMessageID, + selectMessagesByID: pgSelectMessagesByIDQuery, + selectMessagesSinceTime: pgSelectMessagesSinceTimeQuery, + selectMessagesSinceTimeScheduled: pgSelectMessagesSinceTimeIncludeScheduledQuery, + selectMessagesSinceID: pgSelectMessagesSinceIDQuery, + 
selectMessagesSinceIDScheduled: pgSelectMessagesSinceIDIncludeScheduledQuery, + selectMessagesLatest: pgSelectMessagesLatestQuery, + selectMessagesDue: pgSelectMessagesDueQuery, + selectMessagesExpired: pgSelectMessagesExpiredQuery, + updateMessagePublished: pgUpdateMessagePublishedQuery, + selectMessagesCount: pgSelectMessagesCountQuery, + selectMessageCountPerTopic: pgSelectMessageCountPerTopicQuery, + selectTopics: pgSelectTopicsQuery, + updateAttachmentDeleted: pgUpdateAttachmentDeleted, + selectAttachmentsExpired: pgSelectAttachmentsExpiredQuery, + selectAttachmentsSizeBySender: pgSelectAttachmentsSizeBySenderQuery, + selectAttachmentsSizeByUserID: pgSelectAttachmentsSizeByUserIDQuery, + selectStats: pgSelectStatsQuery, + updateStats: pgUpdateStatsQuery, +} + +// NewPostgresStore creates a new PostgreSQL-backed message cache store. +func NewPostgresStore(dsn string, batchSize int, batchTimeout time.Duration) (Store, error) { + db, err := sql.Open("pgx", dsn) + if err != nil { + return nil, err + } + if err := db.Ping(); err != nil { + return nil, err + } + if err := setupPostgresDB(db); err != nil { + return nil, err + } + return newCommonStore(db, pgQueries, batchSize, batchTimeout, false), nil +} diff --git a/message/store_postgres_schema.go b/message/store_postgres_schema.go new file mode 100644 index 00000000..e7aa3fe8 --- /dev/null +++ b/message/store_postgres_schema.go @@ -0,0 +1,90 @@ +package message + +import ( + "database/sql" + "fmt" +) + +// Initial PostgreSQL schema +const ( + pgCreateTablesQuery = ` + CREATE TABLE IF NOT EXISTS message ( + id BIGSERIAL PRIMARY KEY, + mid TEXT NOT NULL, + sequence_id TEXT NOT NULL, + time BIGINT NOT NULL, + event TEXT NOT NULL, + expires BIGINT NOT NULL, + topic TEXT NOT NULL, + message TEXT NOT NULL, + title TEXT NOT NULL, + priority INT NOT NULL, + tags TEXT NOT NULL, + click TEXT NOT NULL, + icon TEXT NOT NULL, + actions TEXT NOT NULL, + attachment_name TEXT NOT NULL, + attachment_type TEXT NOT NULL, + 
attachment_size BIGINT NOT NULL, + attachment_expires BIGINT NOT NULL, + attachment_url TEXT NOT NULL, + attachment_deleted BOOLEAN NOT NULL DEFAULT FALSE, + sender TEXT NOT NULL, + user_id TEXT NOT NULL, + content_type TEXT NOT NULL, + encoding TEXT NOT NULL, + published BOOLEAN NOT NULL DEFAULT FALSE + ); + CREATE INDEX IF NOT EXISTS idx_message_mid ON message (mid); + CREATE INDEX IF NOT EXISTS idx_message_sequence_id ON message (sequence_id); + CREATE INDEX IF NOT EXISTS idx_message_time ON message (time); + CREATE INDEX IF NOT EXISTS idx_message_topic ON message (topic); + CREATE INDEX IF NOT EXISTS idx_message_expires ON message (expires); + CREATE INDEX IF NOT EXISTS idx_message_sender ON message (sender); + CREATE INDEX IF NOT EXISTS idx_message_user_id ON message (user_id); + CREATE INDEX IF NOT EXISTS idx_message_attachment_expires ON message (attachment_expires); + CREATE TABLE IF NOT EXISTS message_stats ( + key TEXT PRIMARY KEY, + value BIGINT + ); + INSERT INTO message_stats (key, value) VALUES ('messages', 0); + CREATE TABLE IF NOT EXISTS schema_version ( + store TEXT PRIMARY KEY, + version INT NOT NULL + ); + ` +) + +// PostgreSQL schema management queries +const ( + pgCurrentSchemaVersion = 14 + pgInsertSchemaVersion = `INSERT INTO schema_version (store, version) VALUES ('message', $1)` + pgSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE store = 'message'` +) + +func setupPostgresDB(db *sql.DB) error { + var schemaVersion int + err := db.QueryRow(pgSelectSchemaVersionQuery).Scan(&schemaVersion) + if err != nil { + return setupNewPostgresDB(db) + } + if schemaVersion > pgCurrentSchemaVersion { + return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, pgCurrentSchemaVersion) + } + return nil +} + +func setupNewPostgresDB(db *sql.DB) error { + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + if _, err := tx.Exec(pgCreateTablesQuery); err != nil { + 
return err + } + if _, err := tx.Exec(pgInsertSchemaVersion, pgCurrentSchemaVersion); err != nil { + return err + } + return tx.Commit() +} diff --git a/message/store_postgres_test.go b/message/store_postgres_test.go new file mode 100644 index 00000000..930d700d --- /dev/null +++ b/message/store_postgres_test.go @@ -0,0 +1,120 @@ +package message_test + +import ( + "database/sql" + "fmt" + "net/url" + "os" + "testing" + + "github.com/stretchr/testify/require" + "heckel.io/ntfy/v2/message" + "heckel.io/ntfy/v2/util" +) + +func newTestPostgresStore(t *testing.T) message.Store { + dsn := os.Getenv("NTFY_TEST_DATABASE_URL") + if dsn == "" { + t.Skip("NTFY_TEST_DATABASE_URL not set, skipping PostgreSQL tests") + } + // Create a unique schema for this test + schema := fmt.Sprintf("test_%s", util.RandomString(10)) + setupDB, err := sql.Open("pgx", dsn) + require.Nil(t, err) + _, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema)) + require.Nil(t, err) + require.Nil(t, setupDB.Close()) + // Open store with search_path set to the new schema + u, err := url.Parse(dsn) + require.Nil(t, err) + q := u.Query() + q.Set("search_path", schema) + u.RawQuery = q.Encode() + store, err := message.NewPostgresStore(u.String(), 0, 0) + require.Nil(t, err) + t.Cleanup(func() { + store.Close() + cleanDB, err := sql.Open("pgx", dsn) + if err == nil { + cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema)) + cleanDB.Close() + } + }) + return store +} + +func TestPostgresStore_Messages(t *testing.T) { + testCacheMessages(t, newTestPostgresStore(t)) +} + +func TestPostgresStore_MessagesLock(t *testing.T) { + testCacheMessagesLock(t, newTestPostgresStore(t)) +} + +func TestPostgresStore_MessagesScheduled(t *testing.T) { + testCacheMessagesScheduled(t, newTestPostgresStore(t)) +} + +func TestPostgresStore_Topics(t *testing.T) { + testCacheTopics(t, newTestPostgresStore(t)) +} + +func TestPostgresStore_MessagesTagsPrioAndTitle(t *testing.T) { + testCacheMessagesTagsPrioAndTitle(t, 
newTestPostgresStore(t)) +} + +func TestPostgresStore_MessagesSinceID(t *testing.T) { + testCacheMessagesSinceID(t, newTestPostgresStore(t)) +} + +func TestPostgresStore_Prune(t *testing.T) { + testCachePrune(t, newTestPostgresStore(t)) +} + +func TestPostgresStore_Attachments(t *testing.T) { + testCacheAttachments(t, newTestPostgresStore(t)) +} + +func TestPostgresStore_AttachmentsExpired(t *testing.T) { + testCacheAttachmentsExpired(t, newTestPostgresStore(t)) +} + +func TestPostgresStore_Sender(t *testing.T) { + testSender(t, newTestPostgresStore(t)) +} + +func TestPostgresStore_DeleteScheduledBySequenceID(t *testing.T) { + testDeleteScheduledBySequenceID(t, newTestPostgresStore(t)) +} + +func TestPostgresStore_MessageByID(t *testing.T) { + testMessageByID(t, newTestPostgresStore(t)) +} + +func TestPostgresStore_MarkPublished(t *testing.T) { + testMarkPublished(t, newTestPostgresStore(t)) +} + +func TestPostgresStore_ExpireMessages(t *testing.T) { + testExpireMessages(t, newTestPostgresStore(t)) +} + +func TestPostgresStore_MarkAttachmentsDeleted(t *testing.T) { + testMarkAttachmentsDeleted(t, newTestPostgresStore(t)) +} + +func TestPostgresStore_Stats(t *testing.T) { + testStats(t, newTestPostgresStore(t)) +} + +func TestPostgresStore_AddMessages(t *testing.T) { + testAddMessages(t, newTestPostgresStore(t)) +} + +func TestPostgresStore_MessagesDue(t *testing.T) { + testMessagesDue(t, newTestPostgresStore(t)) +} + +func TestPostgresStore_MessageFieldRoundTrip(t *testing.T) { + testMessageFieldRoundTrip(t, newTestPostgresStore(t)) +} diff --git a/message/store_sqlite.go b/message/store_sqlite.go new file mode 100644 index 00000000..275fa43d --- /dev/null +++ b/message/store_sqlite.go @@ -0,0 +1,138 @@ +package message + +import ( + "database/sql" + "fmt" + "path/filepath" + "time" + + _ "github.com/mattn/go-sqlite3" // SQLite driver + "heckel.io/ntfy/v2/util" +) + +// SQLite runtime query constants +const ( + sqliteInsertMessageQuery = ` + INSERT INTO messages 
(mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user, content_type, encoding, published) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ` + sqliteDeleteMessageQuery = `DELETE FROM messages WHERE mid = ?` + sqliteSelectScheduledMessageIDsBySeqIDQuery = `SELECT mid FROM messages WHERE topic = ? AND sequence_id = ? AND published = 0` + sqliteDeleteScheduledBySequenceIDQuery = `DELETE FROM messages WHERE topic = ? AND sequence_id = ? AND published = 0` + sqliteUpdateMessagesForTopicExpiryQuery = `UPDATE messages SET expires = ? WHERE topic = ?` + sqliteSelectRowIDFromMessageID = `SELECT id FROM messages WHERE mid = ?` + sqliteSelectMessagesByIDQuery = ` + SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding + FROM messages + WHERE mid = ? + ` + sqliteSelectMessagesSinceTimeQuery = ` + SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding + FROM messages + WHERE topic = ? AND time >= ? AND published = 1 + ORDER BY time, id + ` + sqliteSelectMessagesSinceTimeIncludeScheduledQuery = ` + SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding + FROM messages + WHERE topic = ? AND time >= ? 
+ ORDER BY time, id + ` + sqliteSelectMessagesSinceIDQuery = ` + SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding + FROM messages + WHERE topic = ? AND id > ? AND published = 1 + ORDER BY time, id + ` + sqliteSelectMessagesSinceIDIncludeScheduledQuery = ` + SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding + FROM messages + WHERE topic = ? AND (id > ? OR published = 0) + ORDER BY time, id + ` + sqliteSelectMessagesLatestQuery = ` + SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding + FROM messages + WHERE topic = ? AND published = 1 + ORDER BY time DESC, id DESC + LIMIT 1 + ` + sqliteSelectMessagesDueQuery = ` + SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding + FROM messages + WHERE time <= ? AND published = 0 + ORDER BY time, id + ` + sqliteSelectMessagesExpiredQuery = `SELECT mid FROM messages WHERE expires <= ? 
AND published = 1` + sqliteUpdateMessagePublishedQuery = `UPDATE messages SET published = 1 WHERE mid = ?` + sqliteSelectMessagesCountQuery = `SELECT COUNT(*) FROM messages` + sqliteSelectMessageCountPerTopicQuery = `SELECT topic, COUNT(*) FROM messages GROUP BY topic` + sqliteSelectTopicsQuery = `SELECT topic FROM messages GROUP BY topic` + + sqliteUpdateAttachmentDeleted = `UPDATE messages SET attachment_deleted = 1 WHERE mid = ?` + sqliteSelectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires <= ? AND attachment_deleted = 0` + sqliteSelectAttachmentsSizeBySenderQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = '' AND sender = ? AND attachment_expires >= ?` + sqliteSelectAttachmentsSizeByUserIDQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?` + + sqliteSelectStatsQuery = `SELECT value FROM stats WHERE key = 'messages'` + sqliteUpdateStatsQuery = `UPDATE stats SET value = ? 
WHERE key = 'messages'` +) + +var sqliteQueries = storeQueries{ + insertMessage: sqliteInsertMessageQuery, + deleteMessage: sqliteDeleteMessageQuery, + selectScheduledMessageIDsBySeqID: sqliteSelectScheduledMessageIDsBySeqIDQuery, + deleteScheduledBySequenceID: sqliteDeleteScheduledBySequenceIDQuery, + updateMessagesForTopicExpiry: sqliteUpdateMessagesForTopicExpiryQuery, + selectRowIDFromMessageID: sqliteSelectRowIDFromMessageID, + selectMessagesByID: sqliteSelectMessagesByIDQuery, + selectMessagesSinceTime: sqliteSelectMessagesSinceTimeQuery, + selectMessagesSinceTimeScheduled: sqliteSelectMessagesSinceTimeIncludeScheduledQuery, + selectMessagesSinceID: sqliteSelectMessagesSinceIDQuery, + selectMessagesSinceIDScheduled: sqliteSelectMessagesSinceIDIncludeScheduledQuery, + selectMessagesLatest: sqliteSelectMessagesLatestQuery, + selectMessagesDue: sqliteSelectMessagesDueQuery, + selectMessagesExpired: sqliteSelectMessagesExpiredQuery, + updateMessagePublished: sqliteUpdateMessagePublishedQuery, + selectMessagesCount: sqliteSelectMessagesCountQuery, + selectMessageCountPerTopic: sqliteSelectMessageCountPerTopicQuery, + selectTopics: sqliteSelectTopicsQuery, + updateAttachmentDeleted: sqliteUpdateAttachmentDeleted, + selectAttachmentsExpired: sqliteSelectAttachmentsExpiredQuery, + selectAttachmentsSizeBySender: sqliteSelectAttachmentsSizeBySenderQuery, + selectAttachmentsSizeByUserID: sqliteSelectAttachmentsSizeByUserIDQuery, + selectStats: sqliteSelectStatsQuery, + updateStats: sqliteUpdateStatsQuery, +} + +// NewSQLiteStore creates a SQLite file-backed cache +func NewSQLiteStore(filename, startupQueries string, cacheDuration time.Duration, batchSize int, batchTimeout time.Duration, nop bool) (Store, error) { + parentDir := filepath.Dir(filename) + if !util.FileExists(parentDir) { + return nil, fmt.Errorf("cache database directory %s does not exist or is not accessible", parentDir) + } + db, err := sql.Open("sqlite3", filename) + if err != nil { + return nil, err + 
} + if err := setupSQLite(db, startupQueries, cacheDuration); err != nil { + return nil, err + } + return newCommonStore(db, sqliteQueries, batchSize, batchTimeout, nop), nil +} + +// NewMemStore creates an in-memory cache +func NewMemStore() (Store, error) { + return NewSQLiteStore(createMemoryFilename(), "", 0, 0, 0, false) +} + +// NewNopStore creates an in-memory cache that discards all messages; +// it is always empty and can be used if caching is entirely disabled +func NewNopStore() (Store, error) { + return NewSQLiteStore(createMemoryFilename(), "", 0, 0, 0, true) +} + +// createMemoryFilename creates a unique memory filename to use for the SQLite backend. +func createMemoryFilename() string { + return fmt.Sprintf("file:%s?mode=memory&cache=shared", util.RandomString(10)) +} diff --git a/message/store_sqlite_schema.go b/message/store_sqlite_schema.go new file mode 100644 index 00000000..cd01555f --- /dev/null +++ b/message/store_sqlite_schema.go @@ -0,0 +1,466 @@ +package message + +import ( + "database/sql" + "fmt" + "time" + + "heckel.io/ntfy/v2/log" +) + +// Initial SQLite schema +const ( + sqliteCreateMessagesTableQuery = ` + BEGIN; + CREATE TABLE IF NOT EXISTS messages ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + mid TEXT NOT NULL, + sequence_id TEXT NOT NULL, + time INT NOT NULL, + event TEXT NOT NULL, + expires INT NOT NULL, + topic TEXT NOT NULL, + message TEXT NOT NULL, + title TEXT NOT NULL, + priority INT NOT NULL, + tags TEXT NOT NULL, + click TEXT NOT NULL, + icon TEXT NOT NULL, + actions TEXT NOT NULL, + attachment_name TEXT NOT NULL, + attachment_type TEXT NOT NULL, + attachment_size INT NOT NULL, + attachment_expires INT NOT NULL, + attachment_url TEXT NOT NULL, + attachment_deleted INT NOT NULL, + sender TEXT NOT NULL, + user TEXT NOT NULL, + content_type TEXT NOT NULL, + encoding TEXT NOT NULL, + published INT NOT NULL + ); + CREATE INDEX IF NOT EXISTS idx_mid ON messages (mid); + CREATE INDEX IF NOT EXISTS idx_sequence_id ON messages 
(sequence_id); + CREATE INDEX IF NOT EXISTS idx_time ON messages (time); + CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic); + CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires); + CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender); + CREATE INDEX IF NOT EXISTS idx_user ON messages (user); + CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires); + CREATE TABLE IF NOT EXISTS stats ( + key TEXT PRIMARY KEY, + value INT + ); + INSERT INTO stats (key, value) VALUES ('messages', 0); + COMMIT; + ` +) + +// Schema version management for SQLite +const ( + sqliteCurrentSchemaVersion = 14 + sqliteCreateSchemaVersionTableQuery = ` + CREATE TABLE IF NOT EXISTS schemaVersion ( + id INT PRIMARY KEY, + version INT NOT NULL + ); + ` + sqliteInsertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)` + sqliteUpdateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1` + sqliteSelectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1` +) + +// Schema migrations for SQLite +const ( + // 0 -> 1 + sqliteMigrate0To1AlterMessagesTableQuery = ` + BEGIN; + ALTER TABLE messages ADD COLUMN title TEXT NOT NULL DEFAULT(''); + ALTER TABLE messages ADD COLUMN priority INT NOT NULL DEFAULT(0); + ALTER TABLE messages ADD COLUMN tags TEXT NOT NULL DEFAULT(''); + COMMIT; + ` + + // 1 -> 2 + sqliteMigrate1To2AlterMessagesTableQuery = ` + ALTER TABLE messages ADD COLUMN published INT NOT NULL DEFAULT(1); + ` + + // 2 -> 3 + sqliteMigrate2To3AlterMessagesTableQuery = ` + BEGIN; + ALTER TABLE messages ADD COLUMN click TEXT NOT NULL DEFAULT(''); + ALTER TABLE messages ADD COLUMN attachment_name TEXT NOT NULL DEFAULT(''); + ALTER TABLE messages ADD COLUMN attachment_type TEXT NOT NULL DEFAULT(''); + ALTER TABLE messages ADD COLUMN attachment_size INT NOT NULL DEFAULT('0'); + ALTER TABLE messages ADD COLUMN attachment_expires INT NOT NULL DEFAULT('0'); + ALTER TABLE messages ADD COLUMN attachment_owner TEXT NOT NULL 
DEFAULT(''); + ALTER TABLE messages ADD COLUMN attachment_url TEXT NOT NULL DEFAULT(''); + COMMIT; + ` + // 3 -> 4 + sqliteMigrate3To4AlterMessagesTableQuery = ` + ALTER TABLE messages ADD COLUMN encoding TEXT NOT NULL DEFAULT(''); + ` + + // 4 -> 5 + sqliteMigrate4To5AlterMessagesTableQuery = ` + BEGIN; + CREATE TABLE IF NOT EXISTS messages_new ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + mid TEXT NOT NULL, + time INT NOT NULL, + topic TEXT NOT NULL, + message TEXT NOT NULL, + title TEXT NOT NULL, + priority INT NOT NULL, + tags TEXT NOT NULL, + click TEXT NOT NULL, + attachment_name TEXT NOT NULL, + attachment_type TEXT NOT NULL, + attachment_size INT NOT NULL, + attachment_expires INT NOT NULL, + attachment_url TEXT NOT NULL, + attachment_owner TEXT NOT NULL, + encoding TEXT NOT NULL, + published INT NOT NULL + ); + CREATE INDEX IF NOT EXISTS idx_mid ON messages_new (mid); + CREATE INDEX IF NOT EXISTS idx_topic ON messages_new (topic); + INSERT + INTO messages_new ( + mid, time, topic, message, title, priority, tags, click, attachment_name, attachment_type, + attachment_size, attachment_expires, attachment_url, attachment_owner, encoding, published) + SELECT + id, time, topic, message, title, priority, tags, click, attachment_name, attachment_type, + attachment_size, attachment_expires, attachment_url, attachment_owner, encoding, published + FROM messages; + DROP TABLE messages; + ALTER TABLE messages_new RENAME TO messages; + COMMIT; + ` + + // 5 -> 6 + sqliteMigrate5To6AlterMessagesTableQuery = ` + ALTER TABLE messages ADD COLUMN actions TEXT NOT NULL DEFAULT(''); + ` + + // 6 -> 7 + sqliteMigrate6To7AlterMessagesTableQuery = ` + ALTER TABLE messages RENAME COLUMN attachment_owner TO sender; + ` + + // 7 -> 8 + sqliteMigrate7To8AlterMessagesTableQuery = ` + ALTER TABLE messages ADD COLUMN icon TEXT NOT NULL DEFAULT(''); + ` + + // 8 -> 9 + sqliteMigrate8To9AlterMessagesTableQuery = ` + CREATE INDEX IF NOT EXISTS idx_time ON messages (time); + ` + + // 9 -> 10 + 
sqliteMigrate9To10AlterMessagesTableQuery = ` + ALTER TABLE messages ADD COLUMN user TEXT NOT NULL DEFAULT(''); + ALTER TABLE messages ADD COLUMN attachment_deleted INT NOT NULL DEFAULT('0'); + ALTER TABLE messages ADD COLUMN expires INT NOT NULL DEFAULT('0'); + CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires); + CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender); + CREATE INDEX IF NOT EXISTS idx_user ON messages (user); + CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires); + ` + sqliteMigrate9To10UpdateMessageExpiryQuery = `UPDATE messages SET expires = time + ?` + + // 10 -> 11 + sqliteMigrate10To11AlterMessagesTableQuery = ` + CREATE TABLE IF NOT EXISTS stats ( + key TEXT PRIMARY KEY, + value INT + ); + INSERT INTO stats (key, value) VALUES ('messages', 0); + ` + + // 11 -> 12 + sqliteMigrate11To12AlterMessagesTableQuery = ` + ALTER TABLE messages ADD COLUMN content_type TEXT NOT NULL DEFAULT(''); + ` + + // 12 -> 13 + sqliteMigrate12To13AlterMessagesTableQuery = ` + CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic); + ` + + // 13 -> 14 + sqliteMigrate13To14AlterMessagesTableQuery = ` + ALTER TABLE messages ADD COLUMN sequence_id TEXT NOT NULL DEFAULT(''); + ALTER TABLE messages ADD COLUMN event TEXT NOT NULL DEFAULT('message'); + CREATE INDEX IF NOT EXISTS idx_sequence_id ON messages (sequence_id); + ` +) + +var ( + sqliteMigrations = map[int]func(db *sql.DB, cacheDuration time.Duration) error{ + 0: sqliteMigrateFrom0, + 1: sqliteMigrateFrom1, + 2: sqliteMigrateFrom2, + 3: sqliteMigrateFrom3, + 4: sqliteMigrateFrom4, + 5: sqliteMigrateFrom5, + 6: sqliteMigrateFrom6, + 7: sqliteMigrateFrom7, + 8: sqliteMigrateFrom8, + 9: sqliteMigrateFrom9, + 10: sqliteMigrateFrom10, + 11: sqliteMigrateFrom11, + 12: sqliteMigrateFrom12, + 13: sqliteMigrateFrom13, + } +) + +func setupSQLite(db *sql.DB, startupQueries string, cacheDuration time.Duration) error { + if err := runSQLiteStartupQueries(db, startupQueries); err 
!= nil { + return err + } + // If 'messages' table does not exist, this must be a new database + rowsMC, err := db.Query(sqliteSelectMessagesCountQuery) + if err != nil { + return setupNewSQLite(db) + } + rowsMC.Close() + // If 'messages' table exists, check 'schemaVersion' table + schemaVersion := 0 + rowsSV, err := db.Query(sqliteSelectSchemaVersionQuery) + if err == nil { + defer rowsSV.Close() + if !rowsSV.Next() { + return fmt.Errorf("cannot determine schema version: cache file may be corrupt") + } + if err := rowsSV.Scan(&schemaVersion); err != nil { + return err + } + rowsSV.Close() + } + // Do migrations + if schemaVersion == sqliteCurrentSchemaVersion { + return nil + } else if schemaVersion > sqliteCurrentSchemaVersion { + return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, sqliteCurrentSchemaVersion) + } + for i := schemaVersion; i < sqliteCurrentSchemaVersion; i++ { + fn, ok := sqliteMigrations[i] + if !ok { + return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1) + } else if err := fn(db, cacheDuration); err != nil { + return err + } + } + return nil +} + +func setupNewSQLite(db *sql.DB) error { + if _, err := db.Exec(sqliteCreateMessagesTableQuery); err != nil { + return err + } + if _, err := db.Exec(sqliteCreateSchemaVersionTableQuery); err != nil { + return err + } + if _, err := db.Exec(sqliteInsertSchemaVersion, sqliteCurrentSchemaVersion); err != nil { + return err + } + return nil +} + +func runSQLiteStartupQueries(db *sql.DB, startupQueries string) error { + if startupQueries != "" { + if _, err := db.Exec(startupQueries); err != nil { + return err + } + } + return nil +} + +func sqliteMigrateFrom0(db *sql.DB, _ time.Duration) error { + log.Tag(tagMessageCache).Info("Migrating cache database schema: from 0 to 1") + if _, err := db.Exec(sqliteMigrate0To1AlterMessagesTableQuery); err != nil { + return err + } + if _, err := 
db.Exec(sqliteCreateSchemaVersionTableQuery); err != nil { + return err + } + if _, err := db.Exec(sqliteInsertSchemaVersion, 1); err != nil { + return err + } + return nil +} + +func sqliteMigrateFrom1(db *sql.DB, _ time.Duration) error { + log.Tag(tagMessageCache).Info("Migrating cache database schema: from 1 to 2") + if _, err := db.Exec(sqliteMigrate1To2AlterMessagesTableQuery); err != nil { + return err + } + if _, err := db.Exec(sqliteUpdateSchemaVersion, 2); err != nil { + return err + } + return nil +} + +func sqliteMigrateFrom2(db *sql.DB, _ time.Duration) error { + log.Tag(tagMessageCache).Info("Migrating cache database schema: from 2 to 3") + if _, err := db.Exec(sqliteMigrate2To3AlterMessagesTableQuery); err != nil { + return err + } + if _, err := db.Exec(sqliteUpdateSchemaVersion, 3); err != nil { + return err + } + return nil +} + +func sqliteMigrateFrom3(db *sql.DB, _ time.Duration) error { + log.Tag(tagMessageCache).Info("Migrating cache database schema: from 3 to 4") + if _, err := db.Exec(sqliteMigrate3To4AlterMessagesTableQuery); err != nil { + return err + } + if _, err := db.Exec(sqliteUpdateSchemaVersion, 4); err != nil { + return err + } + return nil +} + +func sqliteMigrateFrom4(db *sql.DB, _ time.Duration) error { + log.Tag(tagMessageCache).Info("Migrating cache database schema: from 4 to 5") + if _, err := db.Exec(sqliteMigrate4To5AlterMessagesTableQuery); err != nil { + return err + } + if _, err := db.Exec(sqliteUpdateSchemaVersion, 5); err != nil { + return err + } + return nil +} + +func sqliteMigrateFrom5(db *sql.DB, _ time.Duration) error { + log.Tag(tagMessageCache).Info("Migrating cache database schema: from 5 to 6") + if _, err := db.Exec(sqliteMigrate5To6AlterMessagesTableQuery); err != nil { + return err + } + if _, err := db.Exec(sqliteUpdateSchemaVersion, 6); err != nil { + return err + } + return nil +} + +func sqliteMigrateFrom6(db *sql.DB, _ time.Duration) error { + log.Tag(tagMessageCache).Info("Migrating cache database 
schema: from 6 to 7") + if _, err := db.Exec(sqliteMigrate6To7AlterMessagesTableQuery); err != nil { + return err + } + if _, err := db.Exec(sqliteUpdateSchemaVersion, 7); err != nil { + return err + } + return nil +} + +func sqliteMigrateFrom7(db *sql.DB, _ time.Duration) error { + log.Tag(tagMessageCache).Info("Migrating cache database schema: from 7 to 8") + if _, err := db.Exec(sqliteMigrate7To8AlterMessagesTableQuery); err != nil { + return err + } + if _, err := db.Exec(sqliteUpdateSchemaVersion, 8); err != nil { + return err + } + return nil +} + +func sqliteMigrateFrom8(db *sql.DB, _ time.Duration) error { + log.Tag(tagMessageCache).Info("Migrating cache database schema: from 8 to 9") + if _, err := db.Exec(sqliteMigrate8To9AlterMessagesTableQuery); err != nil { + return err + } + if _, err := db.Exec(sqliteUpdateSchemaVersion, 9); err != nil { + return err + } + return nil +} + +func sqliteMigrateFrom9(db *sql.DB, cacheDuration time.Duration) error { + log.Tag(tagMessageCache).Info("Migrating cache database schema: from 9 to 10") + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + if _, err := tx.Exec(sqliteMigrate9To10AlterMessagesTableQuery); err != nil { + return err + } + if _, err := tx.Exec(sqliteMigrate9To10UpdateMessageExpiryQuery, int64(cacheDuration.Seconds())); err != nil { + return err + } + if _, err := tx.Exec(sqliteUpdateSchemaVersion, 10); err != nil { + return err + } + return tx.Commit() +} + +func sqliteMigrateFrom10(db *sql.DB, _ time.Duration) error { + log.Tag(tagMessageCache).Info("Migrating cache database schema: from 10 to 11") + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + if _, err := tx.Exec(sqliteMigrate10To11AlterMessagesTableQuery); err != nil { + return err + } + if _, err := tx.Exec(sqliteUpdateSchemaVersion, 11); err != nil { + return err + } + return tx.Commit() +} + +func sqliteMigrateFrom11(db *sql.DB, _ time.Duration) error { + 
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 11 to 12") + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + if _, err := tx.Exec(sqliteMigrate11To12AlterMessagesTableQuery); err != nil { + return err + } + if _, err := tx.Exec(sqliteUpdateSchemaVersion, 12); err != nil { + return err + } + return tx.Commit() +} + +func sqliteMigrateFrom12(db *sql.DB, _ time.Duration) error { + log.Tag(tagMessageCache).Info("Migrating cache database schema: from 12 to 13") + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + if _, err := tx.Exec(sqliteMigrate12To13AlterMessagesTableQuery); err != nil { + return err + } + if _, err := tx.Exec(sqliteUpdateSchemaVersion, 13); err != nil { + return err + } + return tx.Commit() +} + +func sqliteMigrateFrom13(db *sql.DB, _ time.Duration) error { + log.Tag(tagMessageCache).Info("Migrating cache database schema: from 13 to 14") + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + if _, err := tx.Exec(sqliteMigrate13To14AlterMessagesTableQuery); err != nil { + return err + } + if _, err := tx.Exec(sqliteUpdateSchemaVersion, 14); err != nil { + return err + } + return tx.Commit() +} diff --git a/message/store_sqlite_test.go b/message/store_sqlite_test.go new file mode 100644 index 00000000..aa102044 --- /dev/null +++ b/message/store_sqlite_test.go @@ -0,0 +1,459 @@ +package message_test + +import ( + "database/sql" + "fmt" + "path/filepath" + "testing" + "time" + + _ "github.com/mattn/go-sqlite3" // SQLite driver + "github.com/stretchr/testify/require" + "heckel.io/ntfy/v2/message" + "heckel.io/ntfy/v2/model" +) + +func TestSqliteStore_Messages(t *testing.T) { + testCacheMessages(t, newSqliteTestStore(t)) +} + +func TestMemStore_Messages(t *testing.T) { + testCacheMessages(t, newMemTestStore(t)) +} + +func TestSqliteStore_MessagesLock(t *testing.T) { + testCacheMessagesLock(t, newSqliteTestStore(t)) +} + +func 
TestMemStore_MessagesLock(t *testing.T) { + testCacheMessagesLock(t, newMemTestStore(t)) +} + +func TestSqliteStore_MessagesScheduled(t *testing.T) { + testCacheMessagesScheduled(t, newSqliteTestStore(t)) +} + +func TestMemStore_MessagesScheduled(t *testing.T) { + testCacheMessagesScheduled(t, newMemTestStore(t)) +} + +func TestSqliteStore_Topics(t *testing.T) { + testCacheTopics(t, newSqliteTestStore(t)) +} + +func TestMemStore_Topics(t *testing.T) { + testCacheTopics(t, newMemTestStore(t)) +} + +func TestSqliteStore_MessagesTagsPrioAndTitle(t *testing.T) { + testCacheMessagesTagsPrioAndTitle(t, newSqliteTestStore(t)) +} + +func TestMemStore_MessagesTagsPrioAndTitle(t *testing.T) { + testCacheMessagesTagsPrioAndTitle(t, newMemTestStore(t)) +} + +func TestSqliteStore_MessagesSinceID(t *testing.T) { + testCacheMessagesSinceID(t, newSqliteTestStore(t)) +} + +func TestMemStore_MessagesSinceID(t *testing.T) { + testCacheMessagesSinceID(t, newMemTestStore(t)) +} + +func TestSqliteStore_Prune(t *testing.T) { + testCachePrune(t, newSqliteTestStore(t)) +} + +func TestMemStore_Prune(t *testing.T) { + testCachePrune(t, newMemTestStore(t)) +} + +func TestSqliteStore_Attachments(t *testing.T) { + testCacheAttachments(t, newSqliteTestStore(t)) +} + +func TestMemStore_Attachments(t *testing.T) { + testCacheAttachments(t, newMemTestStore(t)) +} + +func TestSqliteStore_AttachmentsExpired(t *testing.T) { + testCacheAttachmentsExpired(t, newSqliteTestStore(t)) +} + +func TestMemStore_AttachmentsExpired(t *testing.T) { + testCacheAttachmentsExpired(t, newMemTestStore(t)) +} + +func TestSqliteStore_Sender(t *testing.T) { + testSender(t, newSqliteTestStore(t)) +} + +func TestMemStore_Sender(t *testing.T) { + testSender(t, newMemTestStore(t)) +} + +func TestSqliteStore_DeleteScheduledBySequenceID(t *testing.T) { + testDeleteScheduledBySequenceID(t, newSqliteTestStore(t)) +} + +func TestMemStore_DeleteScheduledBySequenceID(t *testing.T) { + testDeleteScheduledBySequenceID(t, 
newMemTestStore(t)) +} + +func TestSqliteStore_MessageByID(t *testing.T) { + testMessageByID(t, newSqliteTestStore(t)) +} + +func TestMemStore_MessageByID(t *testing.T) { + testMessageByID(t, newMemTestStore(t)) +} + +func TestSqliteStore_MarkPublished(t *testing.T) { + testMarkPublished(t, newSqliteTestStore(t)) +} + +func TestMemStore_MarkPublished(t *testing.T) { + testMarkPublished(t, newMemTestStore(t)) +} + +func TestSqliteStore_ExpireMessages(t *testing.T) { + testExpireMessages(t, newSqliteTestStore(t)) +} + +func TestMemStore_ExpireMessages(t *testing.T) { + testExpireMessages(t, newMemTestStore(t)) +} + +func TestSqliteStore_MarkAttachmentsDeleted(t *testing.T) { + testMarkAttachmentsDeleted(t, newSqliteTestStore(t)) +} + +func TestMemStore_MarkAttachmentsDeleted(t *testing.T) { + testMarkAttachmentsDeleted(t, newMemTestStore(t)) +} + +func TestSqliteStore_Stats(t *testing.T) { + testStats(t, newSqliteTestStore(t)) +} + +func TestMemStore_Stats(t *testing.T) { + testStats(t, newMemTestStore(t)) +} + +func TestSqliteStore_AddMessages(t *testing.T) { + testAddMessages(t, newSqliteTestStore(t)) +} + +func TestMemStore_AddMessages(t *testing.T) { + testAddMessages(t, newMemTestStore(t)) +} + +func TestSqliteStore_MessagesDue(t *testing.T) { + testMessagesDue(t, newSqliteTestStore(t)) +} + +func TestMemStore_MessagesDue(t *testing.T) { + testMessagesDue(t, newMemTestStore(t)) +} + +func TestSqliteStore_MessageFieldRoundTrip(t *testing.T) { + testMessageFieldRoundTrip(t, newSqliteTestStore(t)) +} + +func TestMemStore_MessageFieldRoundTrip(t *testing.T) { + testMessageFieldRoundTrip(t, newMemTestStore(t)) +} + +func TestSqliteStore_Migration_From0(t *testing.T) { + filename := newSqliteTestStoreFile(t) + db, err := sql.Open("sqlite3", filename) + require.Nil(t, err) + + // Create "version 0" schema + _, err = db.Exec(` + BEGIN; + CREATE TABLE IF NOT EXISTS messages ( + id VARCHAR(20) PRIMARY KEY, + time INT NOT NULL, + topic VARCHAR(64) NOT NULL, + message 
VARCHAR(1024) NOT NULL + ); + CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic); + COMMIT; + `) + require.Nil(t, err) + + // Insert a bunch of messages + for i := 0; i < 10; i++ { + _, err = db.Exec(`INSERT INTO messages (id, time, topic, message) VALUES (?, ?, ?, ?)`, + fmt.Sprintf("abcd%d", i), time.Now().Unix(), "mytopic", fmt.Sprintf("some message %d", i)) + require.Nil(t, err) + } + require.Nil(t, db.Close()) + + // Create store to trigger migration + s := newSqliteTestStoreFromFile(t, filename, "") + checkSqliteSchemaVersion(t, filename) + + messages, err := s.Messages("mytopic", model.SinceAllMessages, false) + require.Nil(t, err) + require.Equal(t, 10, len(messages)) + require.Equal(t, "some message 5", messages[5].Message) + require.Equal(t, "", messages[5].Title) + require.Nil(t, messages[5].Tags) + require.Equal(t, 0, messages[5].Priority) +} + +func TestSqliteStore_Migration_From1(t *testing.T) { + filename := newSqliteTestStoreFile(t) + db, err := sql.Open("sqlite3", filename) + require.Nil(t, err) + + // Create "version 1" schema + _, err = db.Exec(` + CREATE TABLE IF NOT EXISTS messages ( + id VARCHAR(20) PRIMARY KEY, + time INT NOT NULL, + topic VARCHAR(64) NOT NULL, + message VARCHAR(512) NOT NULL, + title VARCHAR(256) NOT NULL, + priority INT NOT NULL, + tags VARCHAR(256) NOT NULL + ); + CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic); + CREATE TABLE IF NOT EXISTS schemaVersion ( + id INT PRIMARY KEY, + version INT NOT NULL + ); + INSERT INTO schemaVersion (id, version) VALUES (1, 1); + `) + require.Nil(t, err) + + // Insert a bunch of messages + for i := 0; i < 10; i++ { + _, err = db.Exec(`INSERT INTO messages (id, time, topic, message, title, priority, tags) VALUES (?, ?, ?, ?, ?, ?, ?)`, + fmt.Sprintf("abcd%d", i), time.Now().Unix(), "mytopic", fmt.Sprintf("some message %d", i), "", 0, "") + require.Nil(t, err) + } + require.Nil(t, db.Close()) + + // Create store to trigger migration + s := newSqliteTestStoreFromFile(t, 
filename, "") + checkSqliteSchemaVersion(t, filename) + + // Add delayed message + delayedMessage := model.NewDefaultMessage("mytopic", "some delayed message") + delayedMessage.Time = time.Now().Add(time.Minute).Unix() + require.Nil(t, s.AddMessage(delayedMessage)) + + // 10, not 11! + messages, err := s.Messages("mytopic", model.SinceAllMessages, false) + require.Nil(t, err) + require.Equal(t, 10, len(messages)) + + // 11! + messages, err = s.Messages("mytopic", model.SinceAllMessages, true) + require.Nil(t, err) + require.Equal(t, 11, len(messages)) + + // Check that index "idx_topic" exists + verifyDB, err := sql.Open("sqlite3", filename) + require.Nil(t, err) + defer verifyDB.Close() + rows, err := verifyDB.Query(`SELECT name FROM sqlite_master WHERE type='index' AND name='idx_topic'`) + require.Nil(t, err) + require.True(t, rows.Next()) + var indexName string + require.Nil(t, rows.Scan(&indexName)) + require.Equal(t, "idx_topic", indexName) + require.Nil(t, rows.Close()) +} + +func TestSqliteStore_Migration_From9(t *testing.T) { + // This primarily tests the awkward migration that introduces the "expires" column. + // The migration logic has to update the column, using the existing "cache-duration" value. 
+ + filename := newSqliteTestStoreFile(t) + db, err := sql.Open("sqlite3", filename) + require.Nil(t, err) + + // Create "version 9" schema + _, err = db.Exec(` + BEGIN; + CREATE TABLE IF NOT EXISTS messages ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + mid TEXT NOT NULL, + time INT NOT NULL, + topic TEXT NOT NULL, + message TEXT NOT NULL, + title TEXT NOT NULL, + priority INT NOT NULL, + tags TEXT NOT NULL, + click TEXT NOT NULL, + icon TEXT NOT NULL, + actions TEXT NOT NULL, + attachment_name TEXT NOT NULL, + attachment_type TEXT NOT NULL, + attachment_size INT NOT NULL, + attachment_expires INT NOT NULL, + attachment_url TEXT NOT NULL, + sender TEXT NOT NULL, + encoding TEXT NOT NULL, + published INT NOT NULL + ); + CREATE INDEX IF NOT EXISTS idx_mid ON messages (mid); + CREATE INDEX IF NOT EXISTS idx_time ON messages (time); + CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic); + CREATE TABLE IF NOT EXISTS schemaVersion ( + id INT PRIMARY KEY, + version INT NOT NULL + ); + INSERT INTO schemaVersion (id, version) VALUES (1, 9); + COMMIT; + `) + require.Nil(t, err) + + // Insert a bunch of messages + insertQuery := ` + INSERT INTO messages (mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding, published) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ ` + for i := 0; i < 10; i++ { + _, err = db.Exec( + insertQuery, + fmt.Sprintf("abcd%d", i), + time.Now().Unix(), + "mytopic", + fmt.Sprintf("some message %d", i), + "", // title + 0, // priority + "", // tags + "", // click + "", // icon + "", // actions + "", // attachment_name + "", // attachment_type + 0, // attachment_size + 0, // attachment_expires + "", // attachment_url + "9.9.9.9", // sender + "", // encoding + 1, // published + ) + require.Nil(t, err) + } + require.Nil(t, db.Close()) + + // Create store to trigger migration + cacheDuration := 17 * time.Hour + s, err := message.NewSQLiteStore(filename, "", cacheDuration, 0, 0, false) + require.Nil(t, err) + t.Cleanup(func() { s.Close() }) + checkSqliteSchemaVersion(t, filename) + + // Check version + verifyDB, err := sql.Open("sqlite3", filename) + require.Nil(t, err) + defer verifyDB.Close() + rows, err := verifyDB.Query(`SELECT version FROM schemaVersion WHERE id = 1`) + require.Nil(t, err) + require.True(t, rows.Next()) + var version int + require.Nil(t, rows.Scan(&version)) + require.Equal(t, 14, version) + require.Nil(t, rows.Close()) + + messages, err := s.Messages("mytopic", model.SinceAllMessages, false) + require.Nil(t, err) + require.Equal(t, 10, len(messages)) + for _, m := range messages { + require.True(t, m.Expires > time.Now().Add(cacheDuration-5*time.Second).Unix()) + require.True(t, m.Expires < time.Now().Add(cacheDuration+5*time.Second).Unix()) + } +} + +func TestSqliteStore_StartupQueries_WAL(t *testing.T) { + filename := newSqliteTestStoreFile(t) + startupQueries := `pragma journal_mode = WAL; +pragma synchronous = normal; +pragma temp_store = memory;` + s, err := message.NewSQLiteStore(filename, startupQueries, time.Hour, 0, 0, false) + require.Nil(t, err) + t.Cleanup(func() { s.Close() }) + require.Nil(t, s.AddMessage(model.NewDefaultMessage("mytopic", "some message"))) + require.FileExists(t, filename) + require.FileExists(t, filename+"-wal") + require.FileExists(t, 
filename+"-shm") +} + +func TestSqliteStore_StartupQueries_None(t *testing.T) { + filename := newSqliteTestStoreFile(t) + s, err := message.NewSQLiteStore(filename, "", time.Hour, 0, 0, false) + require.Nil(t, err) + t.Cleanup(func() { s.Close() }) + require.Nil(t, s.AddMessage(model.NewDefaultMessage("mytopic", "some message"))) + require.FileExists(t, filename) + require.NoFileExists(t, filename+"-wal") + require.NoFileExists(t, filename+"-shm") +} + +func TestSqliteStore_StartupQueries_Fail(t *testing.T) { + filename := newSqliteTestStoreFile(t) + _, err := message.NewSQLiteStore(filename, `xx error`, time.Hour, 0, 0, false) + require.Error(t, err) +} + +func TestNopStore(t *testing.T) { + s, err := message.NewNopStore() + require.Nil(t, err) + t.Cleanup(func() { s.Close() }) + require.Nil(t, s.AddMessage(model.NewDefaultMessage("mytopic", "my message"))) + + messages, err := s.Messages("mytopic", model.SinceAllMessages, false) + require.Nil(t, err) + require.Empty(t, messages) + + topics, err := s.Topics() + require.Nil(t, err) + require.Empty(t, topics) +} + +func newSqliteTestStore(t *testing.T) message.Store { + filename := filepath.Join(t.TempDir(), "cache.db") + s, err := message.NewSQLiteStore(filename, "", time.Hour, 0, 0, false) + require.Nil(t, err) + t.Cleanup(func() { s.Close() }) + return s +} + +func newSqliteTestStoreFile(t *testing.T) string { + return filepath.Join(t.TempDir(), "cache.db") +} + +func newSqliteTestStoreFromFile(t *testing.T, filename, startupQueries string) message.Store { + s, err := message.NewSQLiteStore(filename, startupQueries, time.Hour, 0, 0, false) + require.Nil(t, err) + t.Cleanup(func() { s.Close() }) + return s +} + +func newMemTestStore(t *testing.T) message.Store { + s, err := message.NewMemStore() + require.Nil(t, err) + t.Cleanup(func() { s.Close() }) + return s +} + +func checkSqliteSchemaVersion(t *testing.T, filename string) { + db, err := sql.Open("sqlite3", filename) + require.Nil(t, err) + defer db.Close() + 
rows, err := db.Query(`SELECT version FROM schemaVersion`) + require.Nil(t, err) + require.True(t, rows.Next()) + var schemaVersion int + require.Nil(t, rows.Scan(&schemaVersion)) + require.Equal(t, 14, schemaVersion) + require.Nil(t, rows.Close()) +} diff --git a/message/store_test.go b/message/store_test.go new file mode 100644 index 00000000..e3297c96 --- /dev/null +++ b/message/store_test.go @@ -0,0 +1,767 @@ +package message_test + +import ( + "net/netip" + "sort" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "heckel.io/ntfy/v2/message" + "heckel.io/ntfy/v2/model" +) + +func testCacheMessages(t *testing.T, s message.Store) { + m1 := model.NewDefaultMessage("mytopic", "my message") + m1.Time = 1 + + m2 := model.NewDefaultMessage("mytopic", "my other message") + m2.Time = 2 + + require.Nil(t, s.AddMessage(m1)) + require.Nil(t, s.AddMessage(model.NewDefaultMessage("example", "my example message"))) + require.Nil(t, s.AddMessage(m2)) + + // Adding invalid + require.Equal(t, model.ErrUnexpectedMessageType, s.AddMessage(model.NewKeepaliveMessage("mytopic"))) // These should not be added! + require.Equal(t, model.ErrUnexpectedMessageType, s.AddMessage(model.NewOpenMessage("example"))) // These should not be added! 
+ + // mytopic: count + counts, err := s.MessageCounts() + require.Nil(t, err) + require.Equal(t, 2, counts["mytopic"]) + + // mytopic: since all + messages, _ := s.Messages("mytopic", model.SinceAllMessages, false) + require.Equal(t, 2, len(messages)) + require.Equal(t, "my message", messages[0].Message) + require.Equal(t, "mytopic", messages[0].Topic) + require.Equal(t, model.MessageEvent, messages[0].Event) + require.Equal(t, "", messages[0].Title) + require.Equal(t, 0, messages[0].Priority) + require.Nil(t, messages[0].Tags) + require.Equal(t, "my other message", messages[1].Message) + + // mytopic: since none + messages, _ = s.Messages("mytopic", model.SinceNoMessages, false) + require.Empty(t, messages) + + // mytopic: since m1 (by ID) + messages, _ = s.Messages("mytopic", model.NewSinceID(m1.ID), false) + require.Equal(t, 1, len(messages)) + require.Equal(t, m2.ID, messages[0].ID) + require.Equal(t, "my other message", messages[0].Message) + require.Equal(t, "mytopic", messages[0].Topic) + + // mytopic: since 2 + messages, _ = s.Messages("mytopic", model.NewSinceTime(2), false) + require.Equal(t, 1, len(messages)) + require.Equal(t, "my other message", messages[0].Message) + + // mytopic: latest + messages, _ = s.Messages("mytopic", model.SinceLatestMessage, false) + require.Equal(t, 1, len(messages)) + require.Equal(t, "my other message", messages[0].Message) + + // example: count + counts, err = s.MessageCounts() + require.Nil(t, err) + require.Equal(t, 1, counts["example"]) + + // example: since all + messages, _ = s.Messages("example", model.SinceAllMessages, false) + require.Equal(t, "my example message", messages[0].Message) + + // non-existing: count + counts, err = s.MessageCounts() + require.Nil(t, err) + require.Equal(t, 0, counts["doesnotexist"]) + + // non-existing: since all + messages, _ = s.Messages("doesnotexist", model.SinceAllMessages, false) + require.Empty(t, messages) +} + +func testCacheMessagesLock(t *testing.T, s message.Store) { + 
var wg sync.WaitGroup + for i := 0; i < 5000; i++ { + wg.Add(1) + go func() { + assert.Nil(t, s.AddMessage(model.NewDefaultMessage("mytopic", "test message"))) + wg.Done() + }() + } + wg.Wait() +} + +func testCacheMessagesScheduled(t *testing.T, s message.Store) { + m1 := model.NewDefaultMessage("mytopic", "message 1") + m2 := model.NewDefaultMessage("mytopic", "message 2") + m2.Time = time.Now().Add(time.Hour).Unix() + m3 := model.NewDefaultMessage("mytopic", "message 3") + m3.Time = time.Now().Add(time.Minute).Unix() // earlier than m2! + m4 := model.NewDefaultMessage("mytopic2", "message 4") + m4.Time = time.Now().Add(time.Minute).Unix() + require.Nil(t, s.AddMessage(m1)) + require.Nil(t, s.AddMessage(m2)) + require.Nil(t, s.AddMessage(m3)) + + messages, _ := s.Messages("mytopic", model.SinceAllMessages, false) // exclude scheduled + require.Equal(t, 1, len(messages)) + require.Equal(t, "message 1", messages[0].Message) + + messages, _ = s.Messages("mytopic", model.SinceAllMessages, true) // include scheduled + require.Equal(t, 3, len(messages)) + require.Equal(t, "message 1", messages[0].Message) + require.Equal(t, "message 3", messages[1].Message) // Order! 
+ require.Equal(t, "message 2", messages[2].Message) + + messages, _ = s.MessagesDue() + require.Empty(t, messages) +} + +func testCacheTopics(t *testing.T, s message.Store) { + require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic1", "my example message"))) + require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 1"))) + require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 2"))) + require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 3"))) + + topics, err := s.Topics() + if err != nil { + t.Fatal(err) + } + require.Equal(t, 2, len(topics)) + require.Contains(t, topics, "topic1") + require.Contains(t, topics, "topic2") +} + +func testCacheMessagesTagsPrioAndTitle(t *testing.T, s message.Store) { + m := model.NewDefaultMessage("mytopic", "some message") + m.Tags = []string{"tag1", "tag2"} + m.Priority = 5 + m.Title = "some title" + require.Nil(t, s.AddMessage(m)) + + messages, _ := s.Messages("mytopic", model.SinceAllMessages, false) + require.Equal(t, []string{"tag1", "tag2"}, messages[0].Tags) + require.Equal(t, 5, messages[0].Priority) + require.Equal(t, "some title", messages[0].Title) +} + +func testCacheMessagesSinceID(t *testing.T, s message.Store) { + m1 := model.NewDefaultMessage("mytopic", "message 1") + m1.Time = 100 + m2 := model.NewDefaultMessage("mytopic", "message 2") + m2.Time = 200 + m3 := model.NewDefaultMessage("mytopic", "message 3") + m3.Time = time.Now().Add(time.Hour).Unix() // Scheduled, in the future, later than m7 and m5 + m4 := model.NewDefaultMessage("mytopic", "message 4") + m4.Time = 400 + m5 := model.NewDefaultMessage("mytopic", "message 5") + m5.Time = time.Now().Add(time.Minute).Unix() // Scheduled, in the future, later than m7 + m6 := model.NewDefaultMessage("mytopic", "message 6") + m6.Time = 600 + m7 := model.NewDefaultMessage("mytopic", "message 7") + m7.Time = 700 + + require.Nil(t, s.AddMessage(m1)) + require.Nil(t, s.AddMessage(m2)) + require.Nil(t, s.AddMessage(m3)) + 
require.Nil(t, s.AddMessage(m4)) + require.Nil(t, s.AddMessage(m5)) + require.Nil(t, s.AddMessage(m6)) + require.Nil(t, s.AddMessage(m7)) + + // Case 1: Since ID exists, exclude scheduled + messages, _ := s.Messages("mytopic", model.NewSinceID(m2.ID), false) + require.Equal(t, 3, len(messages)) + require.Equal(t, "message 4", messages[0].Message) + require.Equal(t, "message 6", messages[1].Message) // Not scheduled m3/m5! + require.Equal(t, "message 7", messages[2].Message) + + // Case 2: Since ID exists, include scheduled + messages, _ = s.Messages("mytopic", model.NewSinceID(m2.ID), true) + require.Equal(t, 5, len(messages)) + require.Equal(t, "message 4", messages[0].Message) + require.Equal(t, "message 6", messages[1].Message) + require.Equal(t, "message 7", messages[2].Message) + require.Equal(t, "message 5", messages[3].Message) // Order! + require.Equal(t, "message 3", messages[4].Message) // Order! + + // Case 3: Since ID does not exist (-> Return all messages), include scheduled + messages, _ = s.Messages("mytopic", model.NewSinceID("doesntexist"), true) + require.Equal(t, 7, len(messages)) + require.Equal(t, "message 1", messages[0].Message) + require.Equal(t, "message 2", messages[1].Message) + require.Equal(t, "message 4", messages[2].Message) + require.Equal(t, "message 6", messages[3].Message) + require.Equal(t, "message 7", messages[4].Message) + require.Equal(t, "message 5", messages[5].Message) // Order! + require.Equal(t, "message 3", messages[6].Message) // Order! 
+ + // Case 4: Since ID exists and is last message (-> Return no messages), exclude scheduled + messages, _ = s.Messages("mytopic", model.NewSinceID(m7.ID), false) + require.Equal(t, 0, len(messages)) + + // Case 5: Since ID exists and is last message (-> Return no messages), include scheduled + messages, _ = s.Messages("mytopic", model.NewSinceID(m7.ID), true) + require.Equal(t, 2, len(messages)) + require.Equal(t, "message 5", messages[0].Message) + require.Equal(t, "message 3", messages[1].Message) +} + +func testCachePrune(t *testing.T, s message.Store) { + now := time.Now().Unix() + + m1 := model.NewDefaultMessage("mytopic", "my message") + m1.Time = now - 10 + m1.Expires = now - 5 + + m2 := model.NewDefaultMessage("mytopic", "my other message") + m2.Time = now - 5 + m2.Expires = now + 5 // In the future + + m3 := model.NewDefaultMessage("another_topic", "and another one") + m3.Time = now - 12 + m3.Expires = now - 2 + + require.Nil(t, s.AddMessage(m1)) + require.Nil(t, s.AddMessage(m2)) + require.Nil(t, s.AddMessage(m3)) + + counts, err := s.MessageCounts() + require.Nil(t, err) + require.Equal(t, 2, counts["mytopic"]) + require.Equal(t, 1, counts["another_topic"]) + + expiredMessageIDs, err := s.MessagesExpired() + require.Nil(t, err) + require.Nil(t, s.DeleteMessages(expiredMessageIDs...)) + + counts, err = s.MessageCounts() + require.Nil(t, err) + require.Equal(t, 1, counts["mytopic"]) + require.Equal(t, 0, counts["another_topic"]) + + messages, err := s.Messages("mytopic", model.SinceAllMessages, false) + require.Nil(t, err) + require.Equal(t, 1, len(messages)) + require.Equal(t, "my other message", messages[0].Message) +} + +func testCacheAttachments(t *testing.T, s message.Store) { + expires1 := time.Now().Add(-4 * time.Hour).Unix() // Expired + m := model.NewDefaultMessage("mytopic", "flower for you") + m.ID = "m1" + m.SequenceID = "m1" + m.Sender = netip.MustParseAddr("1.2.3.4") + m.Attachment = &model.Attachment{ + Name: "flower.jpg", + Type: 
"image/jpeg", + Size: 5000, + Expires: expires1, + URL: "https://ntfy.sh/file/AbDeFgJhal.jpg", + } + require.Nil(t, s.AddMessage(m)) + + expires2 := time.Now().Add(2 * time.Hour).Unix() // Future + m = model.NewDefaultMessage("mytopic", "sending you a car") + m.ID = "m2" + m.SequenceID = "m2" + m.Sender = netip.MustParseAddr("1.2.3.4") + m.Attachment = &model.Attachment{ + Name: "car.jpg", + Type: "image/jpeg", + Size: 10000, + Expires: expires2, + URL: "https://ntfy.sh/file/aCaRURL.jpg", + } + require.Nil(t, s.AddMessage(m)) + + expires3 := time.Now().Add(1 * time.Hour).Unix() // Future + m = model.NewDefaultMessage("another-topic", "sending you another car") + m.ID = "m3" + m.SequenceID = "m3" + m.User = "u_BAsbaAa" + m.Sender = netip.MustParseAddr("5.6.7.8") + m.Attachment = &model.Attachment{ + Name: "another-car.jpg", + Type: "image/jpeg", + Size: 20000, + Expires: expires3, + URL: "https://ntfy.sh/file/zakaDHFW.jpg", + } + require.Nil(t, s.AddMessage(m)) + + messages, err := s.Messages("mytopic", model.SinceAllMessages, false) + require.Nil(t, err) + require.Equal(t, 2, len(messages)) + + require.Equal(t, "flower for you", messages[0].Message) + require.Equal(t, "flower.jpg", messages[0].Attachment.Name) + require.Equal(t, "image/jpeg", messages[0].Attachment.Type) + require.Equal(t, int64(5000), messages[0].Attachment.Size) + require.Equal(t, expires1, messages[0].Attachment.Expires) + require.Equal(t, "https://ntfy.sh/file/AbDeFgJhal.jpg", messages[0].Attachment.URL) + require.Equal(t, "1.2.3.4", messages[0].Sender.String()) + + require.Equal(t, "sending you a car", messages[1].Message) + require.Equal(t, "car.jpg", messages[1].Attachment.Name) + require.Equal(t, "image/jpeg", messages[1].Attachment.Type) + require.Equal(t, int64(10000), messages[1].Attachment.Size) + require.Equal(t, expires2, messages[1].Attachment.Expires) + require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL) + require.Equal(t, "1.2.3.4", 
messages[1].Sender.String()) + + size, err := s.AttachmentBytesUsedBySender("1.2.3.4") + require.Nil(t, err) + require.Equal(t, int64(10000), size) + + size, err = s.AttachmentBytesUsedBySender("5.6.7.8") + require.Nil(t, err) + require.Equal(t, int64(0), size) // Accounted to the user, not the IP! + + size, err = s.AttachmentBytesUsedByUser("u_BAsbaAa") + require.Nil(t, err) + require.Equal(t, int64(20000), size) +} + +func testCacheAttachmentsExpired(t *testing.T, s message.Store) { + m := model.NewDefaultMessage("mytopic", "flower for you") + m.ID = "m1" + m.SequenceID = "m1" + m.Expires = time.Now().Add(time.Hour).Unix() + require.Nil(t, s.AddMessage(m)) + + m = model.NewDefaultMessage("mytopic", "message with attachment") + m.ID = "m2" + m.SequenceID = "m2" + m.Expires = time.Now().Add(2 * time.Hour).Unix() + m.Attachment = &model.Attachment{ + Name: "car.jpg", + Type: "image/jpeg", + Size: 10000, + Expires: time.Now().Add(2 * time.Hour).Unix(), + URL: "https://ntfy.sh/file/aCaRURL.jpg", + } + require.Nil(t, s.AddMessage(m)) + + m = model.NewDefaultMessage("mytopic", "message with external attachment") + m.ID = "m3" + m.SequenceID = "m3" + m.Expires = time.Now().Add(2 * time.Hour).Unix() + m.Attachment = &model.Attachment{ + Name: "car.jpg", + Type: "image/jpeg", + Expires: 0, // Unknown! 
+ URL: "https://somedomain.com/car.jpg", + } + require.Nil(t, s.AddMessage(m)) + + m = model.NewDefaultMessage("mytopic2", "message with expired attachment") + m.ID = "m4" + m.SequenceID = "m4" + m.Expires = time.Now().Add(2 * time.Hour).Unix() + m.Attachment = &model.Attachment{ + Name: "expired-car.jpg", + Type: "image/jpeg", + Size: 20000, + Expires: time.Now().Add(-1 * time.Hour).Unix(), + URL: "https://ntfy.sh/file/aCaRURL.jpg", + } + require.Nil(t, s.AddMessage(m)) + + ids, err := s.AttachmentsExpired() + require.Nil(t, err) + require.Equal(t, 1, len(ids)) + require.Equal(t, "m4", ids[0]) +} + +func testSender(t *testing.T, s message.Store) { + m1 := model.NewDefaultMessage("mytopic", "mymessage") + m1.Sender = netip.MustParseAddr("1.2.3.4") + require.Nil(t, s.AddMessage(m1)) + + m2 := model.NewDefaultMessage("mytopic", "mymessage without sender") + require.Nil(t, s.AddMessage(m2)) + + messages, err := s.Messages("mytopic", model.SinceAllMessages, false) + require.Nil(t, err) + require.Equal(t, 2, len(messages)) + require.Equal(t, messages[0].Sender, netip.MustParseAddr("1.2.3.4")) + require.Equal(t, messages[1].Sender, netip.Addr{}) +} + +func testDeleteScheduledBySequenceID(t *testing.T, s message.Store) { + // Create a scheduled (unpublished) message + scheduledMsg := model.NewDefaultMessage("mytopic", "scheduled message") + scheduledMsg.ID = "scheduled1" + scheduledMsg.SequenceID = "seq123" + scheduledMsg.Time = time.Now().Add(time.Hour).Unix() // Future time makes it scheduled + require.Nil(t, s.AddMessage(scheduledMsg)) + + // Create a published message with different sequence ID + publishedMsg := model.NewDefaultMessage("mytopic", "published message") + publishedMsg.ID = "published1" + publishedMsg.SequenceID = "seq456" + publishedMsg.Time = time.Now().Add(-time.Hour).Unix() // Past time makes it published + require.Nil(t, s.AddMessage(publishedMsg)) + + // Create a scheduled message in a different topic + otherTopicMsg := 
model.NewDefaultMessage("othertopic", "other scheduled") + otherTopicMsg.ID = "other1" + otherTopicMsg.SequenceID = "seq123" // Same sequence ID as scheduledMsg + otherTopicMsg.Time = time.Now().Add(time.Hour).Unix() + require.Nil(t, s.AddMessage(otherTopicMsg)) + + // Verify all messages exist (including scheduled) + messages, err := s.Messages("mytopic", model.SinceAllMessages, true) + require.Nil(t, err) + require.Equal(t, 2, len(messages)) + + messages, err = s.Messages("othertopic", model.SinceAllMessages, true) + require.Nil(t, err) + require.Equal(t, 1, len(messages)) + + // Delete scheduled message by sequence ID and verify returned IDs + deletedIDs, err := s.DeleteScheduledBySequenceID("mytopic", "seq123") + require.Nil(t, err) + require.Equal(t, 1, len(deletedIDs)) + require.Equal(t, "scheduled1", deletedIDs[0]) + + // Verify scheduled message is deleted + messages, err = s.Messages("mytopic", model.SinceAllMessages, true) + require.Nil(t, err) + require.Equal(t, 1, len(messages)) + require.Equal(t, "published message", messages[0].Message) + + // Verify other topic's message still exists (topic-scoped deletion) + messages, err = s.Messages("othertopic", model.SinceAllMessages, true) + require.Nil(t, err) + require.Equal(t, 1, len(messages)) + require.Equal(t, "other scheduled", messages[0].Message) + + // Deleting non-existent sequence ID should return empty list + deletedIDs, err = s.DeleteScheduledBySequenceID("mytopic", "nonexistent") + require.Nil(t, err) + require.Empty(t, deletedIDs) + + // Deleting published message should not affect it (only deletes unpublished) + deletedIDs, err = s.DeleteScheduledBySequenceID("mytopic", "seq456") + require.Nil(t, err) + require.Empty(t, deletedIDs) + + messages, err = s.Messages("mytopic", model.SinceAllMessages, true) + require.Nil(t, err) + require.Equal(t, 1, len(messages)) + require.Equal(t, "published message", messages[0].Message) +} + +func testMessageByID(t *testing.T, s message.Store) { + // Add a 
message + m := model.NewDefaultMessage("mytopic", "some message") + m.Title = "some title" + m.Priority = 4 + m.Tags = []string{"tag1", "tag2"} + require.Nil(t, s.AddMessage(m)) + + // Retrieve by ID + retrieved, err := s.Message(m.ID) + require.Nil(t, err) + require.Equal(t, m.ID, retrieved.ID) + require.Equal(t, "mytopic", retrieved.Topic) + require.Equal(t, "some message", retrieved.Message) + require.Equal(t, "some title", retrieved.Title) + require.Equal(t, 4, retrieved.Priority) + require.Equal(t, []string{"tag1", "tag2"}, retrieved.Tags) + + // Non-existent ID returns ErrMessageNotFound + _, err = s.Message("doesnotexist") + require.Equal(t, model.ErrMessageNotFound, err) +} + +func testMarkPublished(t *testing.T, s message.Store) { + // Add a scheduled message (future time → unpublished) + m := model.NewDefaultMessage("mytopic", "scheduled message") + m.Time = time.Now().Add(time.Hour).Unix() + require.Nil(t, s.AddMessage(m)) + + // Verify it does not appear in non-scheduled queries + messages, err := s.Messages("mytopic", model.SinceAllMessages, false) + require.Nil(t, err) + require.Equal(t, 0, len(messages)) + + // Verify it does appear in scheduled queries + messages, err = s.Messages("mytopic", model.SinceAllMessages, true) + require.Nil(t, err) + require.Equal(t, 1, len(messages)) + + // Mark as published + require.Nil(t, s.MarkPublished(m)) + + // Now it should appear in non-scheduled queries too + messages, err = s.Messages("mytopic", model.SinceAllMessages, false) + require.Nil(t, err) + require.Equal(t, 1, len(messages)) + require.Equal(t, "scheduled message", messages[0].Message) +} + +func testExpireMessages(t *testing.T, s message.Store) { + // Add messages to two topics + m1 := model.NewDefaultMessage("topic1", "message 1") + m1.Expires = time.Now().Add(time.Hour).Unix() + m2 := model.NewDefaultMessage("topic1", "message 2") + m2.Expires = time.Now().Add(time.Hour).Unix() + m3 := model.NewDefaultMessage("topic2", "message 3") + m3.Expires = 
time.Now().Add(time.Hour).Unix() + require.Nil(t, s.AddMessage(m1)) + require.Nil(t, s.AddMessage(m2)) + require.Nil(t, s.AddMessage(m3)) + + // Verify all messages exist + messages, err := s.Messages("topic1", model.SinceAllMessages, false) + require.Nil(t, err) + require.Equal(t, 2, len(messages)) + messages, err = s.Messages("topic2", model.SinceAllMessages, false) + require.Nil(t, err) + require.Equal(t, 1, len(messages)) + + // Expire topic1 messages + require.Nil(t, s.ExpireMessages("topic1")) + + // topic1 messages should now be expired (expires set to past) + expiredIDs, err := s.MessagesExpired() + require.Nil(t, err) + require.Equal(t, 2, len(expiredIDs)) + sort.Strings(expiredIDs) + expectedIDs := []string{m1.ID, m2.ID} + sort.Strings(expectedIDs) + require.Equal(t, expectedIDs, expiredIDs) + + // topic2 should be unaffected + messages, err = s.Messages("topic2", model.SinceAllMessages, false) + require.Nil(t, err) + require.Equal(t, 1, len(messages)) + require.Equal(t, "message 3", messages[0].Message) +} + +func testMarkAttachmentsDeleted(t *testing.T, s message.Store) { + // Add a message with an expired attachment (file needs cleanup) + m1 := model.NewDefaultMessage("mytopic", "old file") + m1.ID = "msg1" + m1.SequenceID = "msg1" + m1.Expires = time.Now().Add(time.Hour).Unix() + m1.Attachment = &model.Attachment{ + Name: "old.pdf", + Type: "application/pdf", + Size: 50000, + Expires: time.Now().Add(-time.Hour).Unix(), // Expired + URL: "https://ntfy.sh/file/old.pdf", + } + require.Nil(t, s.AddMessage(m1)) + + // Add a message with another expired attachment + m2 := model.NewDefaultMessage("mytopic", "another old file") + m2.ID = "msg2" + m2.SequenceID = "msg2" + m2.Expires = time.Now().Add(time.Hour).Unix() + m2.Attachment = &model.Attachment{ + Name: "another.pdf", + Type: "application/pdf", + Size: 30000, + Expires: time.Now().Add(-time.Hour).Unix(), // Expired + URL: "https://ntfy.sh/file/another.pdf", + } + require.Nil(t, s.AddMessage(m2)) + + // 
Both should show as expired attachments needing cleanup + ids, err := s.AttachmentsExpired() + require.Nil(t, err) + require.Equal(t, 2, len(ids)) + + // Mark msg1's attachment as deleted (file cleaned up) + require.Nil(t, s.MarkAttachmentsDeleted("msg1")) + + // Now only msg2 should show as needing cleanup + ids, err = s.AttachmentsExpired() + require.Nil(t, err) + require.Equal(t, 1, len(ids)) + require.Equal(t, "msg2", ids[0]) + + // Mark msg2 too + require.Nil(t, s.MarkAttachmentsDeleted("msg2")) + + // No more expired attachments to clean up + ids, err = s.AttachmentsExpired() + require.Nil(t, err) + require.Equal(t, 0, len(ids)) + + // Messages themselves still exist + messages, err := s.Messages("mytopic", model.SinceAllMessages, false) + require.Nil(t, err) + require.Equal(t, 2, len(messages)) +} + +func testStats(t *testing.T, s message.Store) { + // Initial stats should be zero + messages, err := s.Stats() + require.Nil(t, err) + require.Equal(t, int64(0), messages) + + // Update stats + require.Nil(t, s.UpdateStats(42)) + messages, err = s.Stats() + require.Nil(t, err) + require.Equal(t, int64(42), messages) + + // Update again (overwrites) + require.Nil(t, s.UpdateStats(100)) + messages, err = s.Stats() + require.Nil(t, err) + require.Equal(t, int64(100), messages) +} + +func testAddMessages(t *testing.T, s message.Store) { + // Batch add multiple messages + msgs := []*model.Message{ + model.NewDefaultMessage("mytopic", "batch 1"), + model.NewDefaultMessage("mytopic", "batch 2"), + model.NewDefaultMessage("othertopic", "batch 3"), + } + require.Nil(t, s.AddMessages(msgs)) + + // Verify all were inserted + messages, err := s.Messages("mytopic", model.SinceAllMessages, false) + require.Nil(t, err) + require.Equal(t, 2, len(messages)) + + messages, err = s.Messages("othertopic", model.SinceAllMessages, false) + require.Nil(t, err) + require.Equal(t, 1, len(messages)) + require.Equal(t, "batch 3", messages[0].Message) + + // Empty batch should succeed + 
require.Nil(t, s.AddMessages([]*model.Message{})) + + // Batch with invalid event type should fail + badMsgs := []*model.Message{ + model.NewKeepaliveMessage("mytopic"), + } + require.NotNil(t, s.AddMessages(badMsgs)) +} + +func testMessagesDue(t *testing.T, s message.Store) { + // Add a message scheduled in the past (i.e. it's due now) + m1 := model.NewDefaultMessage("mytopic", "due message") + m1.Time = time.Now().Add(-time.Second).Unix() + // Set expires in the future so it doesn't get pruned + m1.Expires = time.Now().Add(time.Hour).Unix() + require.Nil(t, s.AddMessage(m1)) + + // Add a message scheduled in the future (not due) + m2 := model.NewDefaultMessage("mytopic", "future message") + m2.Time = time.Now().Add(time.Hour).Unix() + require.Nil(t, s.AddMessage(m2)) + + // Mark m1 as published so it won't be "due" + // (MessagesDue returns unpublished messages whose time <= now) + // m1 is auto-published (time <= now), so it should not be due + // m2 is unpublished (time in future), not due yet + due, err := s.MessagesDue() + require.Nil(t, err) + require.Equal(t, 0, len(due)) + + // Add a message that was explicitly scheduled in the past but time has "arrived" + // We need to manipulate the database to create a truly "due" message: + // a message with published=false and time <= now + m3 := model.NewDefaultMessage("mytopic", "truly due message") + m3.Time = time.Now().Add(2 * time.Second).Unix() // 2 seconds from now + require.Nil(t, s.AddMessage(m3)) + + // Not due yet + due, err = s.MessagesDue() + require.Nil(t, err) + require.Equal(t, 0, len(due)) + + // Wait for it to become due + time.Sleep(3 * time.Second) + + due, err = s.MessagesDue() + require.Nil(t, err) + require.Equal(t, 1, len(due)) + require.Equal(t, "truly due message", due[0].Message) +} + +func testMessageFieldRoundTrip(t *testing.T, s message.Store) { + // Create a message with all fields populated + m := model.NewDefaultMessage("mytopic", "hello world") + m.SequenceID = "custom_seq_id" + 
m.Title = "A Title" + m.Priority = 4 + m.Tags = []string{"warning", "srv01"} + m.Click = "https://example.com/click" + m.Icon = "https://example.com/icon.png" + m.Actions = []*model.Action{ + { + ID: "action1", + Action: "view", + Label: "Open Site", + URL: "https://example.com", + Clear: true, + }, + { + ID: "action2", + Action: "http", + Label: "Call Webhook", + URL: "https://example.com/hook", + Method: "PUT", + Headers: map[string]string{"X-Token": "secret"}, + Body: `{"key":"value"}`, + }, + } + m.ContentType = "text/markdown" + m.Encoding = "base64" + m.Sender = netip.MustParseAddr("9.8.7.6") + m.User = "u_TestUser123" + require.Nil(t, s.AddMessage(m)) + + // Retrieve and verify every field + retrieved, err := s.Message(m.ID) + require.Nil(t, err) + + require.Equal(t, m.ID, retrieved.ID) + require.Equal(t, "custom_seq_id", retrieved.SequenceID) + require.Equal(t, m.Time, retrieved.Time) + require.Equal(t, m.Expires, retrieved.Expires) + require.Equal(t, model.MessageEvent, retrieved.Event) + require.Equal(t, "mytopic", retrieved.Topic) + require.Equal(t, "hello world", retrieved.Message) + require.Equal(t, "A Title", retrieved.Title) + require.Equal(t, 4, retrieved.Priority) + require.Equal(t, []string{"warning", "srv01"}, retrieved.Tags) + require.Equal(t, "https://example.com/click", retrieved.Click) + require.Equal(t, "https://example.com/icon.png", retrieved.Icon) + require.Equal(t, "text/markdown", retrieved.ContentType) + require.Equal(t, "base64", retrieved.Encoding) + require.Equal(t, netip.MustParseAddr("9.8.7.6"), retrieved.Sender) + require.Equal(t, "u_TestUser123", retrieved.User) + + // Verify actions round-trip + require.Equal(t, 2, len(retrieved.Actions)) + + require.Equal(t, "action1", retrieved.Actions[0].ID) + require.Equal(t, "view", retrieved.Actions[0].Action) + require.Equal(t, "Open Site", retrieved.Actions[0].Label) + require.Equal(t, "https://example.com", retrieved.Actions[0].URL) + require.Equal(t, true, retrieved.Actions[0].Clear) 
+ + require.Equal(t, "action2", retrieved.Actions[1].ID) + require.Equal(t, "http", retrieved.Actions[1].Action) + require.Equal(t, "Call Webhook", retrieved.Actions[1].Label) + require.Equal(t, "https://example.com/hook", retrieved.Actions[1].URL) + require.Equal(t, "PUT", retrieved.Actions[1].Method) + require.Equal(t, "secret", retrieved.Actions[1].Headers["X-Token"]) + require.Equal(t, `{"key":"value"}`, retrieved.Actions[1].Body) +} diff --git a/model/model.go b/model/model.go new file mode 100644 index 00000000..8431554d --- /dev/null +++ b/model/model.go @@ -0,0 +1,204 @@ +package model + +import ( + "errors" + "net/netip" + "time" + + "heckel.io/ntfy/v2/log" + "heckel.io/ntfy/v2/util" +) + +// List of possible events +const ( + OpenEvent = "open" + KeepaliveEvent = "keepalive" + MessageEvent = "message" + MessageDeleteEvent = "message_delete" + MessageClearEvent = "message_clear" + PollRequestEvent = "poll_request" +) + +const ( + MessageIDLength = 12 +) + +var ( + ErrUnexpectedMessageType = errors.New("unexpected message type") + ErrMessageNotFound = errors.New("message not found") +) + +// Message represents a message published to a topic +type Message struct { + ID string `json:"id"` // Random message ID + SequenceID string `json:"sequence_id,omitempty"` // Message sequence ID for updating message contents (omitted if same as ID) + Time int64 `json:"time"` // Unix time in seconds + Expires int64 `json:"expires,omitempty"` // Unix time in seconds (not required for open/keepalive) + Event string `json:"event"` // One of the above + Topic string `json:"topic"` + Title string `json:"title,omitempty"` + Message string `json:"message,omitempty"` + Priority int `json:"priority,omitempty"` + Tags []string `json:"tags,omitempty"` + Click string `json:"click,omitempty"` + Icon string `json:"icon,omitempty"` + Actions []*Action `json:"actions,omitempty"` + Attachment *Attachment `json:"attachment,omitempty"` + PollID string `json:"poll_id,omitempty"` + ContentType 
string `json:"content_type,omitempty"` // text/plain by default (if empty), or text/markdown + Encoding string `json:"encoding,omitempty"` // Empty for raw UTF-8, or "base64" for encoded bytes + Sender netip.Addr `json:"-"` // IP address of uploader, used for rate limiting + User string `json:"-"` // UserID of the uploader, used to associate attachments +} + +// Context returns a log context for the message +func (m *Message) Context() log.Context { + fields := map[string]any{ + "topic": m.Topic, + "message_id": m.ID, + "message_sequence_id": m.SequenceID, + "message_time": m.Time, + "message_event": m.Event, + "message_body_size": len(m.Message), + } + if m.Sender.IsValid() { + fields["message_sender"] = m.Sender.String() + } + if m.User != "" { + fields["message_user"] = m.User + } + return fields +} + +// ForJSON returns a copy of the message suitable for JSON output. +// It clears the SequenceID if it equals the ID to reduce redundancy. +func (m *Message) ForJSON() *Message { + if m.SequenceID == m.ID { + clone := *m + clone.SequenceID = "" + return &clone + } + return m +} + +// Attachment represents a file attachment on a message +type Attachment struct { + Name string `json:"name"` + Type string `json:"type,omitempty"` + Size int64 `json:"size,omitempty"` + Expires int64 `json:"expires,omitempty"` + URL string `json:"url"` +} + +// Action represents a user-defined action on a message +type Action struct { + ID string `json:"id"` + Action string `json:"action"` // "view", "broadcast", "http", or "copy" + Label string `json:"label"` // action button label + Clear bool `json:"clear"` // clear notification after successful execution + URL string `json:"url,omitempty"` // used in "view" and "http" actions + Method string `json:"method,omitempty"` // used in "http" action, default is POST (!) 
+ Headers map[string]string `json:"headers,omitempty"` // used in "http" action + Body string `json:"body,omitempty"` // used in "http" action + Intent string `json:"intent,omitempty"` // used in "broadcast" action + Extras map[string]string `json:"extras,omitempty"` // used in "broadcast" action + Value string `json:"value,omitempty"` // used in "copy" action +} + +// NewAction creates a new action with initialized maps +func NewAction() *Action { + return &Action{ + Headers: make(map[string]string), + Extras: make(map[string]string), + } +} + +// NewMessage creates a new message with the current timestamp +func NewMessage(event, topic, msg string) *Message { + return &Message{ + ID: util.RandomString(MessageIDLength), + Time: time.Now().Unix(), + Event: event, + Topic: topic, + Message: msg, + } +} + +// NewOpenMessage is a convenience method to create an open message +func NewOpenMessage(topic string) *Message { + return NewMessage(OpenEvent, topic, "") +} + +// NewKeepaliveMessage is a convenience method to create a keepalive message +func NewKeepaliveMessage(topic string) *Message { + return NewMessage(KeepaliveEvent, topic, "") +} + +// NewDefaultMessage is a convenience method to create a notification message +func NewDefaultMessage(topic, msg string) *Message { + return NewMessage(MessageEvent, topic, msg) +} + +// NewActionMessage creates a new action message (message_delete or message_clear) +func NewActionMessage(event, topic, sequenceID string) *Message { + m := NewMessage(event, topic, "") + m.SequenceID = sequenceID + return m +} + +// ValidMessageID returns true if the given string is a valid message ID +func ValidMessageID(s string) bool { + return util.ValidRandomString(s, MessageIDLength) +} + +// SinceMarker represents a point in time or message ID from which to retrieve messages +type SinceMarker struct { + time time.Time + id string +} + +// NewSinceTime creates a new SinceMarker from a Unix timestamp +func NewSinceTime(timestamp int64) 
SinceMarker { + return SinceMarker{time.Unix(timestamp, 0), ""} +} + +// NewSinceID creates a new SinceMarker from a message ID +func NewSinceID(id string) SinceMarker { + return SinceMarker{time.Unix(0, 0), id} +} + +// IsAll returns true if this is the "all messages" marker +func (t SinceMarker) IsAll() bool { + return t == SinceAllMessages +} + +// IsNone returns true if this is the "no messages" marker +func (t SinceMarker) IsNone() bool { + return t == SinceNoMessages +} + +// IsLatest returns true if this is the "latest message" marker +func (t SinceMarker) IsLatest() bool { + return t == SinceLatestMessage +} + +// IsID returns true if this marker references a specific message ID +func (t SinceMarker) IsID() bool { + return t.id != "" && t.id != "latest" +} + +// Time returns the time component of the marker +func (t SinceMarker) Time() time.Time { + return t.time +} + +// ID returns the message ID component of the marker +func (t SinceMarker) ID() string { + return t.id +} + +var ( + SinceAllMessages = SinceMarker{time.Unix(0, 0), ""} + SinceNoMessages = SinceMarker{time.Unix(1, 0), ""} + SinceLatestMessage = SinceMarker{time.Unix(0, 0), "latest"} +) diff --git a/server/actions.go b/server/actions.go index 7bd4c903..dc4a6b43 100644 --- a/server/actions.go +++ b/server/actions.go @@ -8,6 +8,7 @@ import ( "strings" "unicode/utf8" + "heckel.io/ntfy/v2/model" "heckel.io/ntfy/v2/util" ) @@ -39,7 +40,7 @@ type actionParser struct { // parseActions parses the actions string as described in https://ntfy.sh/docs/publish/#action-buttons. // It supports both a JSON representation (if the string begins with "[", see parseActionsFromJSON), // and the "simple" format, which is more human-readable, but harder to parse (see parseActionsFromSimple). 
-func parseActions(s string) (actions []*action, err error) { +func parseActions(s string) (actions []*model.Action, err error) { // Parse JSON or simple format s = strings.TrimSpace(s) if strings.HasPrefix(s, "[") { @@ -80,8 +81,8 @@ func parseActions(s string) (actions []*action, err error) { } // parseActionsFromJSON converts a JSON array into an array of actions -func parseActionsFromJSON(s string) ([]*action, error) { - actions := make([]*action, 0) +func parseActionsFromJSON(s string) ([]*model.Action, error) { + actions := make([]*model.Action, 0) if err := json.Unmarshal([]byte(s), &actions); err != nil { return nil, fmt.Errorf("JSON error: %w", err) } @@ -107,7 +108,7 @@ func parseActionsFromJSON(s string) ([]*action, error) { // https://github.com/adampresley/sample-ini-parser/blob/master/services/lexer/lexer/Lexer.go // https://github.com/benbjohnson/sql-parser/blob/master/scanner.go // https://blog.gopheracademy.com/advent-2014/parsers-lexers/ -func parseActionsFromSimple(s string) ([]*action, error) { +func parseActionsFromSimple(s string) ([]*model.Action, error) { if !utf8.ValidString(s) { return nil, errors.New("invalid utf-8 string") } @@ -119,8 +120,8 @@ func parseActionsFromSimple(s string) ([]*action, error) { } // Parse loops trough parseAction() until the end of the string is reached -func (p *actionParser) Parse() ([]*action, error) { - actions := make([]*action, 0) +func (p *actionParser) Parse() ([]*model.Action, error) { + actions := make([]*model.Action, 0) for !p.eof() { a, err := p.parseAction() if err != nil { @@ -134,7 +135,7 @@ func (p *actionParser) Parse() ([]*action, error) { // parseAction parses the individual sections of an action using parseSection into key/value pairs, // and then uses populateAction to interpret the keys/values. The function terminates // when EOF or ";" is reached. 
-func (p *actionParser) parseAction() (*action, error) { +func (p *actionParser) parseAction() (*model.Action, error) { a := newAction() section := 0 for { @@ -155,7 +156,7 @@ func (p *actionParser) parseAction() (*action, error) { // populateAction is the "business logic" of the parser. It applies the key/value // pair to the action instance. -func populateAction(newAction *action, section int, key, value string) error { +func populateAction(newAction *model.Action, section int, key, value string) error { // Auto-expand keys based on their index if key == "" && section == 0 { key = "action" diff --git a/server/log.go b/server/log.go index 72eaf4d1..03600c0d 100644 --- a/server/log.go +++ b/server/log.go @@ -10,6 +10,7 @@ import ( "github.com/emersion/go-smtp" "github.com/gorilla/websocket" "heckel.io/ntfy/v2/log" + "heckel.io/ntfy/v2/model" "heckel.io/ntfy/v2/util" ) @@ -55,12 +56,12 @@ func logvr(v *visitor, r *http.Request) *log.Event { } // logvrm creates a new log event with HTTP request, visitor fields and message fields -func logvrm(v *visitor, r *http.Request, m *message) *log.Event { +func logvrm(v *visitor, r *http.Request, m *model.Message) *log.Event { return logvr(v, r).With(m) } // logvrm creates a new log event with visitor fields and message fields -func logvm(v *visitor, m *message) *log.Event { +func logvm(v *visitor, m *model.Message) *log.Event { return logv(v).With(m) } diff --git a/server/message_cache.go b/server/message_cache.go deleted file mode 100644 index 84083aee..00000000 --- a/server/message_cache.go +++ /dev/null @@ -1,1104 +0,0 @@ -package server - -import ( - "database/sql" - "encoding/json" - "errors" - "fmt" - "net/netip" - "path/filepath" - "strings" - "sync" - "time" - - _ "github.com/mattn/go-sqlite3" // SQLite driver - "heckel.io/ntfy/v2/log" - "heckel.io/ntfy/v2/util" -) - -var ( - errUnexpectedMessageType = errors.New("unexpected message type") - errMessageNotFound = errors.New("message not found") - errNoRows = 
errors.New("no rows found") -) - -// Messages cache -const ( - createMessagesTableQuery = ` - BEGIN; - CREATE TABLE IF NOT EXISTS messages ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - mid TEXT NOT NULL, - sequence_id TEXT NOT NULL, - time INT NOT NULL, - event TEXT NOT NULL, - expires INT NOT NULL, - topic TEXT NOT NULL, - message TEXT NOT NULL, - title TEXT NOT NULL, - priority INT NOT NULL, - tags TEXT NOT NULL, - click TEXT NOT NULL, - icon TEXT NOT NULL, - actions TEXT NOT NULL, - attachment_name TEXT NOT NULL, - attachment_type TEXT NOT NULL, - attachment_size INT NOT NULL, - attachment_expires INT NOT NULL, - attachment_url TEXT NOT NULL, - attachment_deleted INT NOT NULL, - sender TEXT NOT NULL, - user TEXT NOT NULL, - content_type TEXT NOT NULL, - encoding TEXT NOT NULL, - published INT NOT NULL - ); - CREATE INDEX IF NOT EXISTS idx_mid ON messages (mid); - CREATE INDEX IF NOT EXISTS idx_sequence_id ON messages (sequence_id); - CREATE INDEX IF NOT EXISTS idx_time ON messages (time); - CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic); - CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires); - CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender); - CREATE INDEX IF NOT EXISTS idx_user ON messages (user); - CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires); - CREATE TABLE IF NOT EXISTS stats ( - key TEXT PRIMARY KEY, - value INT - ); - INSERT INTO stats (key, value) VALUES ('messages', 0); - COMMIT; - ` - insertMessageQuery = ` - INSERT INTO messages (mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user, content_type, encoding, published) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
- ` - deleteMessageQuery = `DELETE FROM messages WHERE mid = ?` - selectScheduledMessageIDsBySeqIDQuery = `SELECT mid FROM messages WHERE topic = ? AND sequence_id = ? AND published = 0` - deleteScheduledBySequenceIDQuery = `DELETE FROM messages WHERE topic = ? AND sequence_id = ? AND published = 0` - updateMessagesForTopicExpiryQuery = `UPDATE messages SET expires = ? WHERE topic = ?` - selectRowIDFromMessageID = `SELECT id FROM messages WHERE mid = ?` // Do not include topic, see #336 and TestServer_PollSinceID_MultipleTopics - selectMessagesByIDQuery = ` - SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding - FROM messages - WHERE mid = ? - ` - selectMessagesSinceTimeQuery = ` - SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding - FROM messages - WHERE topic = ? AND time >= ? AND published = 1 - ORDER BY time, id - ` - selectMessagesSinceTimeIncludeScheduledQuery = ` - SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding - FROM messages - WHERE topic = ? AND time >= ? - ORDER BY time, id - ` - selectMessagesSinceIDQuery = ` - SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding - FROM messages - WHERE topic = ? AND id > ? 
AND published = 1 - ORDER BY time, id - ` - selectMessagesSinceIDIncludeScheduledQuery = ` - SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding - FROM messages - WHERE topic = ? AND (id > ? OR published = 0) - ORDER BY time, id - ` - selectMessagesLatestQuery = ` - SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding - FROM messages - WHERE topic = ? AND published = 1 - ORDER BY time DESC, id DESC - LIMIT 1 - ` - selectMessagesDueQuery = ` - SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding - FROM messages - WHERE time <= ? AND published = 0 - ORDER BY time, id - ` - selectMessagesExpiredQuery = `SELECT mid FROM messages WHERE expires <= ? AND published = 1` - updateMessagePublishedQuery = `UPDATE messages SET published = 1 WHERE mid = ?` - selectMessagesCountQuery = `SELECT COUNT(*) FROM messages` - selectMessageCountPerTopicQuery = `SELECT topic, COUNT(*) FROM messages GROUP BY topic` - selectTopicsQuery = `SELECT topic FROM messages GROUP BY topic` - - updateAttachmentDeleted = `UPDATE messages SET attachment_deleted = 1 WHERE mid = ?` - selectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires <= ? AND attachment_deleted = 0` - selectAttachmentsSizeBySenderQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = '' AND sender = ? 
AND attachment_expires >= ?` - selectAttachmentsSizeByUserIDQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?` - - selectStatsQuery = `SELECT value FROM stats WHERE key = 'messages'` - updateStatsQuery = `UPDATE stats SET value = ? WHERE key = 'messages'` -) - -// Schema management queries -const ( - currentSchemaVersion = 14 - createSchemaVersionTableQuery = ` - CREATE TABLE IF NOT EXISTS schemaVersion ( - id INT PRIMARY KEY, - version INT NOT NULL - ); - ` - insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)` - updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1` - selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1` - - // 0 -> 1 - migrate0To1AlterMessagesTableQuery = ` - BEGIN; - ALTER TABLE messages ADD COLUMN title TEXT NOT NULL DEFAULT(''); - ALTER TABLE messages ADD COLUMN priority INT NOT NULL DEFAULT(0); - ALTER TABLE messages ADD COLUMN tags TEXT NOT NULL DEFAULT(''); - COMMIT; - ` - - // 1 -> 2 - migrate1To2AlterMessagesTableQuery = ` - ALTER TABLE messages ADD COLUMN published INT NOT NULL DEFAULT(1); - ` - - // 2 -> 3 - migrate2To3AlterMessagesTableQuery = ` - BEGIN; - ALTER TABLE messages ADD COLUMN click TEXT NOT NULL DEFAULT(''); - ALTER TABLE messages ADD COLUMN attachment_name TEXT NOT NULL DEFAULT(''); - ALTER TABLE messages ADD COLUMN attachment_type TEXT NOT NULL DEFAULT(''); - ALTER TABLE messages ADD COLUMN attachment_size INT NOT NULL DEFAULT('0'); - ALTER TABLE messages ADD COLUMN attachment_expires INT NOT NULL DEFAULT('0'); - ALTER TABLE messages ADD COLUMN attachment_owner TEXT NOT NULL DEFAULT(''); - ALTER TABLE messages ADD COLUMN attachment_url TEXT NOT NULL DEFAULT(''); - COMMIT; - ` - // 3 -> 4 - migrate3To4AlterMessagesTableQuery = ` - ALTER TABLE messages ADD COLUMN encoding TEXT NOT NULL DEFAULT(''); - ` - - // 4 -> 5 - migrate4To5AlterMessagesTableQuery = ` - BEGIN; - CREATE TABLE IF NOT EXISTS messages_new ( - id 
INTEGER PRIMARY KEY AUTOINCREMENT, - mid TEXT NOT NULL, - time INT NOT NULL, - topic TEXT NOT NULL, - message TEXT NOT NULL, - title TEXT NOT NULL, - priority INT NOT NULL, - tags TEXT NOT NULL, - click TEXT NOT NULL, - attachment_name TEXT NOT NULL, - attachment_type TEXT NOT NULL, - attachment_size INT NOT NULL, - attachment_expires INT NOT NULL, - attachment_url TEXT NOT NULL, - attachment_owner TEXT NOT NULL, - encoding TEXT NOT NULL, - published INT NOT NULL - ); - CREATE INDEX IF NOT EXISTS idx_mid ON messages_new (mid); - CREATE INDEX IF NOT EXISTS idx_topic ON messages_new (topic); - INSERT - INTO messages_new ( - mid, time, topic, message, title, priority, tags, click, attachment_name, attachment_type, - attachment_size, attachment_expires, attachment_url, attachment_owner, encoding, published) - SELECT - id, time, topic, message, title, priority, tags, click, attachment_name, attachment_type, - attachment_size, attachment_expires, attachment_url, attachment_owner, encoding, published - FROM messages; - DROP TABLE messages; - ALTER TABLE messages_new RENAME TO messages; - COMMIT; - ` - - // 5 -> 6 - migrate5To6AlterMessagesTableQuery = ` - ALTER TABLE messages ADD COLUMN actions TEXT NOT NULL DEFAULT(''); - ` - - // 6 -> 7 - migrate6To7AlterMessagesTableQuery = ` - ALTER TABLE messages RENAME COLUMN attachment_owner TO sender; - ` - - // 7 -> 8 - migrate7To8AlterMessagesTableQuery = ` - ALTER TABLE messages ADD COLUMN icon TEXT NOT NULL DEFAULT(''); - ` - - // 8 -> 9 - migrate8To9AlterMessagesTableQuery = ` - CREATE INDEX IF NOT EXISTS idx_time ON messages (time); - ` - - // 9 -> 10 - migrate9To10AlterMessagesTableQuery = ` - ALTER TABLE messages ADD COLUMN user TEXT NOT NULL DEFAULT(''); - ALTER TABLE messages ADD COLUMN attachment_deleted INT NOT NULL DEFAULT('0'); - ALTER TABLE messages ADD COLUMN expires INT NOT NULL DEFAULT('0'); - CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires); - CREATE INDEX IF NOT EXISTS idx_sender ON messages 
(sender); - CREATE INDEX IF NOT EXISTS idx_user ON messages (user); - CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires); - ` - migrate9To10UpdateMessageExpiryQuery = `UPDATE messages SET expires = time + ?` - - // 10 -> 11 - migrate10To11AlterMessagesTableQuery = ` - CREATE TABLE IF NOT EXISTS stats ( - key TEXT PRIMARY KEY, - value INT - ); - INSERT INTO stats (key, value) VALUES ('messages', 0); - ` - - // 11 -> 12 - migrate11To12AlterMessagesTableQuery = ` - ALTER TABLE messages ADD COLUMN content_type TEXT NOT NULL DEFAULT(''); - ` - - // 12 -> 13 - migrate12To13AlterMessagesTableQuery = ` - CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic); - ` - - //13 -> 14 - migrate13To14AlterMessagesTableQuery = ` - ALTER TABLE messages ADD COLUMN sequence_id TEXT NOT NULL DEFAULT(''); - ALTER TABLE messages ADD COLUMN event TEXT NOT NULL DEFAULT('message'); - CREATE INDEX IF NOT EXISTS idx_sequence_id ON messages (sequence_id); - ` -) - -var ( - migrations = map[int]func(db *sql.DB, cacheDuration time.Duration) error{ - 0: migrateFrom0, - 1: migrateFrom1, - 2: migrateFrom2, - 3: migrateFrom3, - 4: migrateFrom4, - 5: migrateFrom5, - 6: migrateFrom6, - 7: migrateFrom7, - 8: migrateFrom8, - 9: migrateFrom9, - 10: migrateFrom10, - 11: migrateFrom11, - 12: migrateFrom12, - 13: migrateFrom13, - } -) - -type messageCache struct { - db *sql.DB - queue *util.BatchingQueue[*message] - nop bool - mu sync.Mutex -} - -// newSqliteCache creates a SQLite file-backed cache -func newSqliteCache(filename, startupQueries string, cacheDuration time.Duration, batchSize int, batchTimeout time.Duration, nop bool) (*messageCache, error) { - // Check the parent directory of the database file (makes for friendly error messages) - parentDir := filepath.Dir(filename) - if !util.FileExists(parentDir) { - return nil, fmt.Errorf("cache database directory %s does not exist or is not accessible", parentDir) - } - // Open database - db, err := sql.Open("sqlite3", 
filename) - if err != nil { - return nil, err - } - if err := setupMessagesDB(db, startupQueries, cacheDuration); err != nil { - return nil, err - } - var queue *util.BatchingQueue[*message] - if batchSize > 0 || batchTimeout > 0 { - queue = util.NewBatchingQueue[*message](batchSize, batchTimeout) - } - cache := &messageCache{ - db: db, - queue: queue, - nop: nop, - } - go cache.processMessageBatches() - return cache, nil -} - -// newMemCache creates an in-memory cache -func newMemCache() (*messageCache, error) { - return newSqliteCache(createMemoryFilename(), "", 0, 0, 0, false) -} - -// newNopCache creates an in-memory cache that discards all messages; -// it is always empty and can be used if caching is entirely disabled -func newNopCache() (*messageCache, error) { - return newSqliteCache(createMemoryFilename(), "", 0, 0, 0, true) -} - -// createMemoryFilename creates a unique memory filename to use for the SQLite backend. -// From mattn/go-sqlite3: "Each connection to ":memory:" opens a brand new in-memory -// sql database, so if the stdlib's sql engine happens to open another connection and -// you've only specified ":memory:", that connection will see a brand new database. -// A workaround is to use "file::memory:?cache=shared" (or "file:foobar?mode=memory&cache=shared"). -// Every connection to this string will point to the same in-memory database." -func createMemoryFilename() string { - return fmt.Sprintf("file:%s?mode=memory&cache=shared", util.RandomString(10)) -} - -// AddMessage stores a message to the message cache synchronously, or queues it to be stored at a later date asyncronously. -// The message is queued only if "batchSize" or "batchTimeout" are passed to the constructor. -func (c *messageCache) AddMessage(m *message) error { - if c.queue != nil { - c.queue.Enqueue(m) - return nil - } - return c.addMessages([]*message{m}) -} - -// addMessages synchronously stores a match of messages. 
If the database is locked, the transaction waits until -// SQLite's busy_timeout is exceeded before erroring out. -func (c *messageCache) addMessages(ms []*message) error { - c.mu.Lock() - defer c.mu.Unlock() - if c.nop { - return nil - } - if len(ms) == 0 { - return nil - } - start := time.Now() - tx, err := c.db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - stmt, err := tx.Prepare(insertMessageQuery) - if err != nil { - return err - } - defer stmt.Close() - for _, m := range ms { - if m.Event != messageEvent && m.Event != messageDeleteEvent && m.Event != messageClearEvent { - return errUnexpectedMessageType - } - published := m.Time <= time.Now().Unix() - tags := strings.Join(m.Tags, ",") - var attachmentName, attachmentType, attachmentURL string - var attachmentSize, attachmentExpires, attachmentDeleted int64 - if m.Attachment != nil { - attachmentName = m.Attachment.Name - attachmentType = m.Attachment.Type - attachmentSize = m.Attachment.Size - attachmentExpires = m.Attachment.Expires - attachmentURL = m.Attachment.URL - } - var actionsStr string - if len(m.Actions) > 0 { - actionsBytes, err := json.Marshal(m.Actions) - if err != nil { - return err - } - actionsStr = string(actionsBytes) - } - var sender string - if m.Sender.IsValid() { - sender = m.Sender.String() - } - _, err := stmt.Exec( - m.ID, - m.SequenceID, - m.Time, - m.Event, - m.Expires, - m.Topic, - m.Message, - m.Title, - m.Priority, - tags, - m.Click, - m.Icon, - actionsStr, - attachmentName, - attachmentType, - attachmentSize, - attachmentExpires, - attachmentURL, - attachmentDeleted, // Always zero - sender, - m.User, - m.ContentType, - m.Encoding, - published, - ) - if err != nil { - return err - } - } - if err := tx.Commit(); err != nil { - log.Tag(tagMessageCache).Err(err).Error("Writing %d message(s) failed (took %v)", len(ms), time.Since(start)) - return err - } - log.Tag(tagMessageCache).Debug("Wrote %d message(s) in %v", len(ms), time.Since(start)) - return nil -} - 
-func (c *messageCache) Messages(topic string, since sinceMarker, scheduled bool) ([]*message, error) { - if since.IsNone() { - return make([]*message, 0), nil - } else if since.IsLatest() { - return c.messagesLatest(topic) - } else if since.IsID() { - return c.messagesSinceID(topic, since, scheduled) - } - return c.messagesSinceTime(topic, since, scheduled) -} - -func (c *messageCache) messagesSinceTime(topic string, since sinceMarker, scheduled bool) ([]*message, error) { - var rows *sql.Rows - var err error - if scheduled { - rows, err = c.db.Query(selectMessagesSinceTimeIncludeScheduledQuery, topic, since.Time().Unix()) - } else { - rows, err = c.db.Query(selectMessagesSinceTimeQuery, topic, since.Time().Unix()) - } - if err != nil { - return nil, err - } - return readMessages(rows) -} - -func (c *messageCache) messagesSinceID(topic string, since sinceMarker, scheduled bool) ([]*message, error) { - idrows, err := c.db.Query(selectRowIDFromMessageID, since.ID()) - if err != nil { - return nil, err - } - defer idrows.Close() - if !idrows.Next() { - return c.messagesSinceTime(topic, sinceAllMessages, scheduled) - } - var rowID int64 - if err := idrows.Scan(&rowID); err != nil { - return nil, err - } - idrows.Close() - var rows *sql.Rows - if scheduled { - rows, err = c.db.Query(selectMessagesSinceIDIncludeScheduledQuery, topic, rowID) - } else { - rows, err = c.db.Query(selectMessagesSinceIDQuery, topic, rowID) - } - if err != nil { - return nil, err - } - return readMessages(rows) -} - -func (c *messageCache) messagesLatest(topic string) ([]*message, error) { - rows, err := c.db.Query(selectMessagesLatestQuery, topic) - if err != nil { - return nil, err - } - return readMessages(rows) -} - -func (c *messageCache) MessagesDue() ([]*message, error) { - rows, err := c.db.Query(selectMessagesDueQuery, time.Now().Unix()) - if err != nil { - return nil, err - } - return readMessages(rows) -} - -// MessagesExpired returns a list of IDs for messages that have expires 
(should be deleted) -func (c *messageCache) MessagesExpired() ([]string, error) { - rows, err := c.db.Query(selectMessagesExpiredQuery, time.Now().Unix()) - if err != nil { - return nil, err - } - defer rows.Close() - ids := make([]string, 0) - for rows.Next() { - var id string - if err := rows.Scan(&id); err != nil { - return nil, err - } - ids = append(ids, id) - } - if err := rows.Err(); err != nil { - return nil, err - } - return ids, nil -} - -func (c *messageCache) Message(id string) (*message, error) { - rows, err := c.db.Query(selectMessagesByIDQuery, id) - if err != nil { - return nil, err - } - if !rows.Next() { - return nil, errMessageNotFound - } - defer rows.Close() - return readMessage(rows) -} - -func (c *messageCache) MarkPublished(m *message) error { - c.mu.Lock() - defer c.mu.Unlock() - _, err := c.db.Exec(updateMessagePublishedQuery, m.ID) - return err -} - -func (c *messageCache) MessageCounts() (map[string]int, error) { - rows, err := c.db.Query(selectMessageCountPerTopicQuery) - if err != nil { - return nil, err - } - defer rows.Close() - var topic string - var count int - counts := make(map[string]int) - for rows.Next() { - if err := rows.Scan(&topic, &count); err != nil { - return nil, err - } else if err := rows.Err(); err != nil { - return nil, err - } - counts[topic] = count - } - return counts, nil -} - -func (c *messageCache) Topics() (map[string]*topic, error) { - rows, err := c.db.Query(selectTopicsQuery) - if err != nil { - return nil, err - } - defer rows.Close() - topics := make(map[string]*topic) - for rows.Next() { - var id string - if err := rows.Scan(&id); err != nil { - return nil, err - } - topics[id] = newTopic(id) - } - if err := rows.Err(); err != nil { - return nil, err - } - return topics, nil -} - -func (c *messageCache) DeleteMessages(ids ...string) error { - c.mu.Lock() - defer c.mu.Unlock() - tx, err := c.db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - for _, id := range ids { - if _, err := 
tx.Exec(deleteMessageQuery, id); err != nil { - return err - } - } - return tx.Commit() -} - -// DeleteScheduledBySequenceID deletes unpublished (scheduled) messages with the given topic and sequence ID. -// It returns the message IDs of the deleted messages, which can be used to clean up attachment files. -func (c *messageCache) DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, error) { - c.mu.Lock() - defer c.mu.Unlock() - tx, err := c.db.Begin() - if err != nil { - return nil, err - } - defer tx.Rollback() - // First, get the message IDs of scheduled messages to be deleted - rows, err := tx.Query(selectScheduledMessageIDsBySeqIDQuery, topic, sequenceID) - if err != nil { - return nil, err - } - defer rows.Close() - ids := make([]string, 0) - for rows.Next() { - var id string - if err := rows.Scan(&id); err != nil { - return nil, err - } - ids = append(ids, id) - } - if err := rows.Err(); err != nil { - return nil, err - } - rows.Close() // Close rows before executing delete in same transaction - // Then delete the messages - if _, err := tx.Exec(deleteScheduledBySequenceIDQuery, topic, sequenceID); err != nil { - return nil, err - } - if err := tx.Commit(); err != nil { - return nil, err - } - return ids, nil -} - -func (c *messageCache) ExpireMessages(topics ...string) error { - c.mu.Lock() - defer c.mu.Unlock() - tx, err := c.db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - for _, t := range topics { - if _, err := tx.Exec(updateMessagesForTopicExpiryQuery, time.Now().Unix()-1, t); err != nil { - return err - } - } - return tx.Commit() -} - -func (c *messageCache) AttachmentsExpired() ([]string, error) { - rows, err := c.db.Query(selectAttachmentsExpiredQuery, time.Now().Unix()) - if err != nil { - return nil, err - } - defer rows.Close() - ids := make([]string, 0) - for rows.Next() { - var id string - if err := rows.Scan(&id); err != nil { - return nil, err - } - ids = append(ids, id) - } - if err := rows.Err(); err != nil 
{ - return nil, err - } - return ids, nil -} - -func (c *messageCache) MarkAttachmentsDeleted(ids ...string) error { - c.mu.Lock() - defer c.mu.Unlock() - tx, err := c.db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - for _, id := range ids { - if _, err := tx.Exec(updateAttachmentDeleted, id); err != nil { - return err - } - } - return tx.Commit() -} - -func (c *messageCache) AttachmentBytesUsedBySender(sender string) (int64, error) { - rows, err := c.db.Query(selectAttachmentsSizeBySenderQuery, sender, time.Now().Unix()) - if err != nil { - return 0, err - } - return c.readAttachmentBytesUsed(rows) -} - -func (c *messageCache) AttachmentBytesUsedByUser(userID string) (int64, error) { - rows, err := c.db.Query(selectAttachmentsSizeByUserIDQuery, userID, time.Now().Unix()) - if err != nil { - return 0, err - } - return c.readAttachmentBytesUsed(rows) -} - -func (c *messageCache) readAttachmentBytesUsed(rows *sql.Rows) (int64, error) { - defer rows.Close() - var size int64 - if !rows.Next() { - return 0, errors.New("no rows found") - } - if err := rows.Scan(&size); err != nil { - return 0, err - } else if err := rows.Err(); err != nil { - return 0, err - } - return size, nil -} - -func (c *messageCache) processMessageBatches() { - if c.queue == nil { - return - } - for messages := range c.queue.Dequeue() { - if err := c.addMessages(messages); err != nil { - log.Tag(tagMessageCache).Err(err).Error("Cannot write message batch") - } - } -} - -func readMessages(rows *sql.Rows) ([]*message, error) { - defer rows.Close() - messages := make([]*message, 0) - for rows.Next() { - m, err := readMessage(rows) - if err != nil { - return nil, err - } - messages = append(messages, m) - } - if err := rows.Err(); err != nil { - return nil, err - } - return messages, nil -} - -func readMessage(rows *sql.Rows) (*message, error) { - var timestamp, expires, attachmentSize, attachmentExpires int64 - var priority int - var id, sequenceID, event, topic, msg, title, 
tagsStr, click, icon, actionsStr, attachmentName, attachmentType, attachmentURL, sender, user, contentType, encoding string - err := rows.Scan( - &id, - &sequenceID, - ×tamp, - &event, - &expires, - &topic, - &msg, - &title, - &priority, - &tagsStr, - &click, - &icon, - &actionsStr, - &attachmentName, - &attachmentType, - &attachmentSize, - &attachmentExpires, - &attachmentURL, - &sender, - &user, - &contentType, - &encoding, - ) - if err != nil { - return nil, err - } - var tags []string - if tagsStr != "" { - tags = strings.Split(tagsStr, ",") - } - var actions []*action - if actionsStr != "" { - if err := json.Unmarshal([]byte(actionsStr), &actions); err != nil { - return nil, err - } - } - senderIP, err := netip.ParseAddr(sender) - if err != nil { - senderIP = netip.Addr{} // if no IP stored in database, return invalid address - } - var att *attachment - if attachmentName != "" && attachmentURL != "" { - att = &attachment{ - Name: attachmentName, - Type: attachmentType, - Size: attachmentSize, - Expires: attachmentExpires, - URL: attachmentURL, - } - } - return &message{ - ID: id, - SequenceID: sequenceID, - Time: timestamp, - Expires: expires, - Event: event, - Topic: topic, - Message: msg, - Title: title, - Priority: priority, - Tags: tags, - Click: click, - Icon: icon, - Actions: actions, - Attachment: att, - Sender: senderIP, // Must parse assuming database must be correct - User: user, - ContentType: contentType, - Encoding: encoding, - }, nil -} - -func (c *messageCache) UpdateStats(messages int64) error { - c.mu.Lock() - defer c.mu.Unlock() - _, err := c.db.Exec(updateStatsQuery, messages) - return err -} - -func (c *messageCache) Stats() (messages int64, err error) { - rows, err := c.db.Query(selectStatsQuery) - if err != nil { - return 0, err - } - defer rows.Close() - if !rows.Next() { - return 0, errNoRows - } - if err := rows.Scan(&messages); err != nil { - return 0, err - } - return messages, nil -} - -func (c *messageCache) Close() error { - 
return c.db.Close() -} - -func setupMessagesDB(db *sql.DB, startupQueries string, cacheDuration time.Duration) error { - // Run startup queries - if startupQueries != "" { - if _, err := db.Exec(startupQueries); err != nil { - return err - } - } - - // If 'messages' table does not exist, this must be a new database - rowsMC, err := db.Query(selectMessagesCountQuery) - if err != nil { - return setupNewCacheDB(db) - } - rowsMC.Close() - - // If 'messages' table exists, check 'schemaVersion' table - schemaVersion := 0 - rowsSV, err := db.Query(selectSchemaVersionQuery) - if err == nil { - defer rowsSV.Close() - if !rowsSV.Next() { - return errors.New("cannot determine schema version: cache file may be corrupt") - } - if err := rowsSV.Scan(&schemaVersion); err != nil { - return err - } - rowsSV.Close() - } - - // Do migrations - if schemaVersion == currentSchemaVersion { - return nil - } else if schemaVersion > currentSchemaVersion { - return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, currentSchemaVersion) - } - for i := schemaVersion; i < currentSchemaVersion; i++ { - fn, ok := migrations[i] - if !ok { - return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1) - } else if err := fn(db, cacheDuration); err != nil { - return err - } - } - return nil -} - -func setupNewCacheDB(db *sql.DB) error { - if _, err := db.Exec(createMessagesTableQuery); err != nil { - return err - } - if _, err := db.Exec(createSchemaVersionTableQuery); err != nil { - return err - } - if _, err := db.Exec(insertSchemaVersion, currentSchemaVersion); err != nil { - return err - } - return nil -} - -func migrateFrom0(db *sql.DB, _ time.Duration) error { - log.Tag(tagMessageCache).Info("Migrating cache database schema: from 0 to 1") - if _, err := db.Exec(migrate0To1AlterMessagesTableQuery); err != nil { - return err - } - if _, err := db.Exec(createSchemaVersionTableQuery); err != nil { - return err - } - if _, 
err := db.Exec(insertSchemaVersion, 1); err != nil { - return err - } - return nil -} - -func migrateFrom1(db *sql.DB, _ time.Duration) error { - log.Tag(tagMessageCache).Info("Migrating cache database schema: from 1 to 2") - if _, err := db.Exec(migrate1To2AlterMessagesTableQuery); err != nil { - return err - } - if _, err := db.Exec(updateSchemaVersion, 2); err != nil { - return err - } - return nil -} - -func migrateFrom2(db *sql.DB, _ time.Duration) error { - log.Tag(tagMessageCache).Info("Migrating cache database schema: from 2 to 3") - if _, err := db.Exec(migrate2To3AlterMessagesTableQuery); err != nil { - return err - } - if _, err := db.Exec(updateSchemaVersion, 3); err != nil { - return err - } - return nil -} - -func migrateFrom3(db *sql.DB, _ time.Duration) error { - log.Tag(tagMessageCache).Info("Migrating cache database schema: from 3 to 4") - if _, err := db.Exec(migrate3To4AlterMessagesTableQuery); err != nil { - return err - } - if _, err := db.Exec(updateSchemaVersion, 4); err != nil { - return err - } - return nil -} - -func migrateFrom4(db *sql.DB, _ time.Duration) error { - log.Tag(tagMessageCache).Info("Migrating cache database schema: from 4 to 5") - if _, err := db.Exec(migrate4To5AlterMessagesTableQuery); err != nil { - return err - } - if _, err := db.Exec(updateSchemaVersion, 5); err != nil { - return err - } - return nil -} - -func migrateFrom5(db *sql.DB, _ time.Duration) error { - log.Tag(tagMessageCache).Info("Migrating cache database schema: from 5 to 6") - if _, err := db.Exec(migrate5To6AlterMessagesTableQuery); err != nil { - return err - } - if _, err := db.Exec(updateSchemaVersion, 6); err != nil { - return err - } - return nil -} - -func migrateFrom6(db *sql.DB, _ time.Duration) error { - log.Tag(tagMessageCache).Info("Migrating cache database schema: from 6 to 7") - if _, err := db.Exec(migrate6To7AlterMessagesTableQuery); err != nil { - return err - } - if _, err := db.Exec(updateSchemaVersion, 7); err != nil { - return err - 
} - return nil -} - -func migrateFrom7(db *sql.DB, _ time.Duration) error { - log.Tag(tagMessageCache).Info("Migrating cache database schema: from 7 to 8") - if _, err := db.Exec(migrate7To8AlterMessagesTableQuery); err != nil { - return err - } - if _, err := db.Exec(updateSchemaVersion, 8); err != nil { - return err - } - return nil -} - -func migrateFrom8(db *sql.DB, _ time.Duration) error { - log.Tag(tagMessageCache).Info("Migrating cache database schema: from 8 to 9") - if _, err := db.Exec(migrate8To9AlterMessagesTableQuery); err != nil { - return err - } - if _, err := db.Exec(updateSchemaVersion, 9); err != nil { - return err - } - return nil -} - -func migrateFrom9(db *sql.DB, cacheDuration time.Duration) error { - log.Tag(tagMessageCache).Info("Migrating cache database schema: from 9 to 10") - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - if _, err := tx.Exec(migrate9To10AlterMessagesTableQuery); err != nil { - return err - } - if _, err := tx.Exec(migrate9To10UpdateMessageExpiryQuery, int64(cacheDuration.Seconds())); err != nil { - return err - } - if _, err := tx.Exec(updateSchemaVersion, 10); err != nil { - return err - } - return tx.Commit() -} - -func migrateFrom10(db *sql.DB, _ time.Duration) error { - log.Tag(tagMessageCache).Info("Migrating cache database schema: from 10 to 11") - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - if _, err := tx.Exec(migrate10To11AlterMessagesTableQuery); err != nil { - return err - } - if _, err := tx.Exec(updateSchemaVersion, 11); err != nil { - return err - } - return tx.Commit() -} - -func migrateFrom11(db *sql.DB, _ time.Duration) error { - log.Tag(tagMessageCache).Info("Migrating cache database schema: from 11 to 12") - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - if _, err := tx.Exec(migrate11To12AlterMessagesTableQuery); err != nil { - return err - } - if _, err := tx.Exec(updateSchemaVersion, 12); err != 
nil { - return err - } - return tx.Commit() -} - -func migrateFrom12(db *sql.DB, _ time.Duration) error { - log.Tag(tagMessageCache).Info("Migrating cache database schema: from 12 to 13") - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - if _, err := tx.Exec(migrate12To13AlterMessagesTableQuery); err != nil { - return err - } - if _, err := tx.Exec(updateSchemaVersion, 13); err != nil { - return err - } - return tx.Commit() -} - -func migrateFrom13(db *sql.DB, _ time.Duration) error { - log.Tag(tagMessageCache).Info("Migrating cache database schema: from 13 to 14") - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - if _, err := tx.Exec(migrate13To14AlterMessagesTableQuery); err != nil { - return err - } - if _, err := tx.Exec(updateSchemaVersion, 14); err != nil { - return err - } - return tx.Commit() -} diff --git a/server/message_cache_test.go b/server/message_cache_test.go deleted file mode 100644 index 672f91b0..00000000 --- a/server/message_cache_test.go +++ /dev/null @@ -1,825 +0,0 @@ -package server - -import ( - "database/sql" - "fmt" - "github.com/stretchr/testify/assert" - "net/netip" - "path/filepath" - "sync" - "testing" - "time" - - "github.com/stretchr/testify/require" -) - -func TestSqliteCache_Messages(t *testing.T) { - testCacheMessages(t, newSqliteTestCache(t)) -} - -func TestMemCache_Messages(t *testing.T) { - testCacheMessages(t, newMemTestCache(t)) -} - -func testCacheMessages(t *testing.T, c *messageCache) { - m1 := newDefaultMessage("mytopic", "my message") - m1.Time = 1 - - m2 := newDefaultMessage("mytopic", "my other message") - m2.Time = 2 - - require.Nil(t, c.AddMessage(m1)) - require.Nil(t, c.AddMessage(newDefaultMessage("example", "my example message"))) - require.Nil(t, c.AddMessage(m2)) - - // Adding invalid - require.Equal(t, errUnexpectedMessageType, c.AddMessage(newKeepaliveMessage("mytopic"))) // These should not be added! 
- require.Equal(t, errUnexpectedMessageType, c.AddMessage(newOpenMessage("example"))) // These should not be added! - - // mytopic: count - counts, err := c.MessageCounts() - require.Nil(t, err) - require.Equal(t, 2, counts["mytopic"]) - - // mytopic: since all - messages, _ := c.Messages("mytopic", sinceAllMessages, false) - require.Equal(t, 2, len(messages)) - require.Equal(t, "my message", messages[0].Message) - require.Equal(t, "mytopic", messages[0].Topic) - require.Equal(t, messageEvent, messages[0].Event) - require.Equal(t, "", messages[0].Title) - require.Equal(t, 0, messages[0].Priority) - require.Nil(t, messages[0].Tags) - require.Equal(t, "my other message", messages[1].Message) - - // mytopic: since none - messages, _ = c.Messages("mytopic", sinceNoMessages, false) - require.Empty(t, messages) - - // mytopic: since m1 (by ID) - messages, _ = c.Messages("mytopic", newSinceID(m1.ID), false) - require.Equal(t, 1, len(messages)) - require.Equal(t, m2.ID, messages[0].ID) - require.Equal(t, "my other message", messages[0].Message) - require.Equal(t, "mytopic", messages[0].Topic) - - // mytopic: since 2 - messages, _ = c.Messages("mytopic", newSinceTime(2), false) - require.Equal(t, 1, len(messages)) - require.Equal(t, "my other message", messages[0].Message) - - // mytopic: latest - messages, _ = c.Messages("mytopic", sinceLatestMessage, false) - require.Equal(t, 1, len(messages)) - require.Equal(t, "my other message", messages[0].Message) - - // example: count - counts, err = c.MessageCounts() - require.Nil(t, err) - require.Equal(t, 1, counts["example"]) - - // example: since all - messages, _ = c.Messages("example", sinceAllMessages, false) - require.Equal(t, "my example message", messages[0].Message) - - // non-existing: count - counts, err = c.MessageCounts() - require.Nil(t, err) - require.Equal(t, 0, counts["doesnotexist"]) - - // non-existing: since all - messages, _ = c.Messages("doesnotexist", sinceAllMessages, false) - require.Empty(t, messages) -} 
- -func TestSqliteCache_MessagesLock(t *testing.T) { - testCacheMessagesLock(t, newSqliteTestCache(t)) -} - -func TestMemCache_MessagesLock(t *testing.T) { - testCacheMessagesLock(t, newMemTestCache(t)) -} - -func testCacheMessagesLock(t *testing.T, c *messageCache) { - var wg sync.WaitGroup - for i := 0; i < 5000; i++ { - wg.Add(1) - go func() { - assert.Nil(t, c.AddMessage(newDefaultMessage("mytopic", "test message"))) - wg.Done() - }() - } - wg.Wait() -} - -func TestSqliteCache_MessagesScheduled(t *testing.T) { - testCacheMessagesScheduled(t, newSqliteTestCache(t)) -} - -func TestMemCache_MessagesScheduled(t *testing.T) { - testCacheMessagesScheduled(t, newMemTestCache(t)) -} - -func testCacheMessagesScheduled(t *testing.T, c *messageCache) { - m1 := newDefaultMessage("mytopic", "message 1") - m2 := newDefaultMessage("mytopic", "message 2") - m2.Time = time.Now().Add(time.Hour).Unix() - m3 := newDefaultMessage("mytopic", "message 3") - m3.Time = time.Now().Add(time.Minute).Unix() // earlier than m2! - m4 := newDefaultMessage("mytopic2", "message 4") - m4.Time = time.Now().Add(time.Minute).Unix() - require.Nil(t, c.AddMessage(m1)) - require.Nil(t, c.AddMessage(m2)) - require.Nil(t, c.AddMessage(m3)) - - messages, _ := c.Messages("mytopic", sinceAllMessages, false) // exclude scheduled - require.Equal(t, 1, len(messages)) - require.Equal(t, "message 1", messages[0].Message) - - messages, _ = c.Messages("mytopic", sinceAllMessages, true) // include scheduled - require.Equal(t, 3, len(messages)) - require.Equal(t, "message 1", messages[0].Message) - require.Equal(t, "message 3", messages[1].Message) // Order! 
- require.Equal(t, "message 2", messages[2].Message) - - messages, _ = c.MessagesDue() - require.Empty(t, messages) -} - -func TestSqliteCache_Topics(t *testing.T) { - testCacheTopics(t, newSqliteTestCache(t)) -} - -func TestMemCache_Topics(t *testing.T) { - testCacheTopics(t, newMemTestCache(t)) -} - -func testCacheTopics(t *testing.T, c *messageCache) { - require.Nil(t, c.AddMessage(newDefaultMessage("topic1", "my example message"))) - require.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 1"))) - require.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 2"))) - require.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 3"))) - - topics, err := c.Topics() - if err != nil { - t.Fatal(err) - } - require.Equal(t, 2, len(topics)) - require.Equal(t, "topic1", topics["topic1"].ID) - require.Equal(t, "topic2", topics["topic2"].ID) -} - -func TestSqliteCache_MessagesTagsPrioAndTitle(t *testing.T) { - testCacheMessagesTagsPrioAndTitle(t, newSqliteTestCache(t)) -} - -func TestMemCache_MessagesTagsPrioAndTitle(t *testing.T) { - testCacheMessagesTagsPrioAndTitle(t, newMemTestCache(t)) -} - -func testCacheMessagesTagsPrioAndTitle(t *testing.T, c *messageCache) { - m := newDefaultMessage("mytopic", "some message") - m.Tags = []string{"tag1", "tag2"} - m.Priority = 5 - m.Title = "some title" - require.Nil(t, c.AddMessage(m)) - - messages, _ := c.Messages("mytopic", sinceAllMessages, false) - require.Equal(t, []string{"tag1", "tag2"}, messages[0].Tags) - require.Equal(t, 5, messages[0].Priority) - require.Equal(t, "some title", messages[0].Title) -} - -func TestSqliteCache_MessagesSinceID(t *testing.T) { - testCacheMessagesSinceID(t, newSqliteTestCache(t)) -} - -func TestMemCache_MessagesSinceID(t *testing.T) { - testCacheMessagesSinceID(t, newMemTestCache(t)) -} - -func testCacheMessagesSinceID(t *testing.T, c *messageCache) { - m1 := newDefaultMessage("mytopic", "message 1") - m1.Time = 100 - m2 := newDefaultMessage("mytopic", "message 2") - m2.Time = 
200 - m3 := newDefaultMessage("mytopic", "message 3") - m3.Time = time.Now().Add(time.Hour).Unix() // Scheduled, in the future, later than m7 and m5 - m4 := newDefaultMessage("mytopic", "message 4") - m4.Time = 400 - m5 := newDefaultMessage("mytopic", "message 5") - m5.Time = time.Now().Add(time.Minute).Unix() // Scheduled, in the future, later than m7 - m6 := newDefaultMessage("mytopic", "message 6") - m6.Time = 600 - m7 := newDefaultMessage("mytopic", "message 7") - m7.Time = 700 - - require.Nil(t, c.AddMessage(m1)) - require.Nil(t, c.AddMessage(m2)) - require.Nil(t, c.AddMessage(m3)) - require.Nil(t, c.AddMessage(m4)) - require.Nil(t, c.AddMessage(m5)) - require.Nil(t, c.AddMessage(m6)) - require.Nil(t, c.AddMessage(m7)) - - // Case 1: Since ID exists, exclude scheduled - messages, _ := c.Messages("mytopic", newSinceID(m2.ID), false) - require.Equal(t, 3, len(messages)) - require.Equal(t, "message 4", messages[0].Message) - require.Equal(t, "message 6", messages[1].Message) // Not scheduled m3/m5! - require.Equal(t, "message 7", messages[2].Message) - - // Case 2: Since ID exists, include scheduled - messages, _ = c.Messages("mytopic", newSinceID(m2.ID), true) - require.Equal(t, 5, len(messages)) - require.Equal(t, "message 4", messages[0].Message) - require.Equal(t, "message 6", messages[1].Message) - require.Equal(t, "message 7", messages[2].Message) - require.Equal(t, "message 5", messages[3].Message) // Order! - require.Equal(t, "message 3", messages[4].Message) // Order! 
- - // Case 3: Since ID does not exist (-> Return all messages), include scheduled - messages, _ = c.Messages("mytopic", newSinceID("doesntexist"), true) - require.Equal(t, 7, len(messages)) - require.Equal(t, "message 1", messages[0].Message) - require.Equal(t, "message 2", messages[1].Message) - require.Equal(t, "message 4", messages[2].Message) - require.Equal(t, "message 6", messages[3].Message) - require.Equal(t, "message 7", messages[4].Message) - require.Equal(t, "message 5", messages[5].Message) // Order! - require.Equal(t, "message 3", messages[6].Message) // Order! - - // Case 4: Since ID exists and is last message (-> Return no messages), exclude scheduled - messages, _ = c.Messages("mytopic", newSinceID(m7.ID), false) - require.Equal(t, 0, len(messages)) - - // Case 5: Since ID exists and is last message (-> Return no messages), include scheduled - messages, _ = c.Messages("mytopic", newSinceID(m7.ID), true) - require.Equal(t, 2, len(messages)) - require.Equal(t, "message 5", messages[0].Message) - require.Equal(t, "message 3", messages[1].Message) -} - -func TestSqliteCache_Prune(t *testing.T) { - testCachePrune(t, newSqliteTestCache(t)) -} - -func TestMemCache_Prune(t *testing.T) { - testCachePrune(t, newMemTestCache(t)) -} - -func testCachePrune(t *testing.T, c *messageCache) { - now := time.Now().Unix() - - m1 := newDefaultMessage("mytopic", "my message") - m1.Time = now - 10 - m1.Expires = now - 5 - - m2 := newDefaultMessage("mytopic", "my other message") - m2.Time = now - 5 - m2.Expires = now + 5 // In the future - - m3 := newDefaultMessage("another_topic", "and another one") - m3.Time = now - 12 - m3.Expires = now - 2 - - require.Nil(t, c.AddMessage(m1)) - require.Nil(t, c.AddMessage(m2)) - require.Nil(t, c.AddMessage(m3)) - - counts, err := c.MessageCounts() - require.Nil(t, err) - require.Equal(t, 2, counts["mytopic"]) - require.Equal(t, 1, counts["another_topic"]) - - expiredMessageIDs, err := c.MessagesExpired() - require.Nil(t, err) - 
require.Nil(t, c.DeleteMessages(expiredMessageIDs...)) - - counts, err = c.MessageCounts() - require.Nil(t, err) - require.Equal(t, 1, counts["mytopic"]) - require.Equal(t, 0, counts["another_topic"]) - - messages, err := c.Messages("mytopic", sinceAllMessages, false) - require.Nil(t, err) - require.Equal(t, 1, len(messages)) - require.Equal(t, "my other message", messages[0].Message) -} - -func TestSqliteCache_Attachments(t *testing.T) { - testCacheAttachments(t, newSqliteTestCache(t)) -} - -func TestMemCache_Attachments(t *testing.T) { - testCacheAttachments(t, newMemTestCache(t)) -} - -func testCacheAttachments(t *testing.T, c *messageCache) { - expires1 := time.Now().Add(-4 * time.Hour).Unix() // Expired - m := newDefaultMessage("mytopic", "flower for you") - m.ID = "m1" - m.SequenceID = "m1" - m.Sender = netip.MustParseAddr("1.2.3.4") - m.Attachment = &attachment{ - Name: "flower.jpg", - Type: "image/jpeg", - Size: 5000, - Expires: expires1, - URL: "https://ntfy.sh/file/AbDeFgJhal.jpg", - } - require.Nil(t, c.AddMessage(m)) - - expires2 := time.Now().Add(2 * time.Hour).Unix() // Future - m = newDefaultMessage("mytopic", "sending you a car") - m.ID = "m2" - m.SequenceID = "m2" - m.Sender = netip.MustParseAddr("1.2.3.4") - m.Attachment = &attachment{ - Name: "car.jpg", - Type: "image/jpeg", - Size: 10000, - Expires: expires2, - URL: "https://ntfy.sh/file/aCaRURL.jpg", - } - require.Nil(t, c.AddMessage(m)) - - expires3 := time.Now().Add(1 * time.Hour).Unix() // Future - m = newDefaultMessage("another-topic", "sending you another car") - m.ID = "m3" - m.SequenceID = "m3" - m.User = "u_BAsbaAa" - m.Sender = netip.MustParseAddr("5.6.7.8") - m.Attachment = &attachment{ - Name: "another-car.jpg", - Type: "image/jpeg", - Size: 20000, - Expires: expires3, - URL: "https://ntfy.sh/file/zakaDHFW.jpg", - } - require.Nil(t, c.AddMessage(m)) - - messages, err := c.Messages("mytopic", sinceAllMessages, false) - require.Nil(t, err) - require.Equal(t, 2, len(messages)) - - 
require.Equal(t, "flower for you", messages[0].Message) - require.Equal(t, "flower.jpg", messages[0].Attachment.Name) - require.Equal(t, "image/jpeg", messages[0].Attachment.Type) - require.Equal(t, int64(5000), messages[0].Attachment.Size) - require.Equal(t, expires1, messages[0].Attachment.Expires) - require.Equal(t, "https://ntfy.sh/file/AbDeFgJhal.jpg", messages[0].Attachment.URL) - require.Equal(t, "1.2.3.4", messages[0].Sender.String()) - - require.Equal(t, "sending you a car", messages[1].Message) - require.Equal(t, "car.jpg", messages[1].Attachment.Name) - require.Equal(t, "image/jpeg", messages[1].Attachment.Type) - require.Equal(t, int64(10000), messages[1].Attachment.Size) - require.Equal(t, expires2, messages[1].Attachment.Expires) - require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL) - require.Equal(t, "1.2.3.4", messages[1].Sender.String()) - - size, err := c.AttachmentBytesUsedBySender("1.2.3.4") - require.Nil(t, err) - require.Equal(t, int64(10000), size) - - size, err = c.AttachmentBytesUsedBySender("5.6.7.8") - require.Nil(t, err) - require.Equal(t, int64(0), size) // Accounted to the user, not the IP! 
- - size, err = c.AttachmentBytesUsedByUser("u_BAsbaAa") - require.Nil(t, err) - require.Equal(t, int64(20000), size) -} - -func TestSqliteCache_Attachments_Expired(t *testing.T) { - testCacheAttachmentsExpired(t, newSqliteTestCache(t)) -} - -func TestMemCache_Attachments_Expired(t *testing.T) { - testCacheAttachmentsExpired(t, newMemTestCache(t)) -} - -func testCacheAttachmentsExpired(t *testing.T, c *messageCache) { - m := newDefaultMessage("mytopic", "flower for you") - m.ID = "m1" - m.SequenceID = "m1" - m.Expires = time.Now().Add(time.Hour).Unix() - require.Nil(t, c.AddMessage(m)) - - m = newDefaultMessage("mytopic", "message with attachment") - m.ID = "m2" - m.SequenceID = "m2" - m.Expires = time.Now().Add(2 * time.Hour).Unix() - m.Attachment = &attachment{ - Name: "car.jpg", - Type: "image/jpeg", - Size: 10000, - Expires: time.Now().Add(2 * time.Hour).Unix(), - URL: "https://ntfy.sh/file/aCaRURL.jpg", - } - require.Nil(t, c.AddMessage(m)) - - m = newDefaultMessage("mytopic", "message with external attachment") - m.ID = "m3" - m.SequenceID = "m3" - m.Expires = time.Now().Add(2 * time.Hour).Unix() - m.Attachment = &attachment{ - Name: "car.jpg", - Type: "image/jpeg", - Expires: 0, // Unknown! 
- URL: "https://somedomain.com/car.jpg", - } - require.Nil(t, c.AddMessage(m)) - - m = newDefaultMessage("mytopic2", "message with expired attachment") - m.ID = "m4" - m.SequenceID = "m4" - m.Expires = time.Now().Add(2 * time.Hour).Unix() - m.Attachment = &attachment{ - Name: "expired-car.jpg", - Type: "image/jpeg", - Size: 20000, - Expires: time.Now().Add(-1 * time.Hour).Unix(), - URL: "https://ntfy.sh/file/aCaRURL.jpg", - } - require.Nil(t, c.AddMessage(m)) - - ids, err := c.AttachmentsExpired() - require.Nil(t, err) - require.Equal(t, 1, len(ids)) - require.Equal(t, "m4", ids[0]) -} - -func TestSqliteCache_Migration_From0(t *testing.T) { - filename := newSqliteTestCacheFile(t) - db, err := sql.Open("sqlite3", filename) - require.Nil(t, err) - - // Create "version 0" schema - _, err = db.Exec(` - BEGIN; - CREATE TABLE IF NOT EXISTS messages ( - id VARCHAR(20) PRIMARY KEY, - time INT NOT NULL, - topic VARCHAR(64) NOT NULL, - message VARCHAR(1024) NOT NULL - ); - CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic); - COMMIT; - `) - require.Nil(t, err) - - // Insert a bunch of messages - for i := 0; i < 10; i++ { - _, err = db.Exec(`INSERT INTO messages (id, time, topic, message) VALUES (?, ?, ?, ?)`, - fmt.Sprintf("abcd%d", i), time.Now().Unix(), "mytopic", fmt.Sprintf("some message %d", i)) - require.Nil(t, err) - } - require.Nil(t, db.Close()) - - // Create cache to trigger migration - c := newSqliteTestCacheFromFile(t, filename, "") - checkSchemaVersion(t, c.db) - - messages, err := c.Messages("mytopic", sinceAllMessages, false) - require.Nil(t, err) - require.Equal(t, 10, len(messages)) - require.Equal(t, "some message 5", messages[5].Message) - require.Equal(t, "", messages[5].Title) - require.Nil(t, messages[5].Tags) - require.Equal(t, 0, messages[5].Priority) -} - -func TestSqliteCache_Migration_From1(t *testing.T) { - filename := newSqliteTestCacheFile(t) - db, err := sql.Open("sqlite3", filename) - require.Nil(t, err) - - // Create "version 1" schema 
- _, err = db.Exec(` - CREATE TABLE IF NOT EXISTS messages ( - id VARCHAR(20) PRIMARY KEY, - time INT NOT NULL, - topic VARCHAR(64) NOT NULL, - message VARCHAR(512) NOT NULL, - title VARCHAR(256) NOT NULL, - priority INT NOT NULL, - tags VARCHAR(256) NOT NULL - ); - CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic); - CREATE TABLE IF NOT EXISTS schemaVersion ( - id INT PRIMARY KEY, - version INT NOT NULL - ); - INSERT INTO schemaVersion (id, version) VALUES (1, 1); - `) - require.Nil(t, err) - - // Insert a bunch of messages - for i := 0; i < 10; i++ { - _, err = db.Exec(`INSERT INTO messages (id, time, topic, message, title, priority, tags) VALUES (?, ?, ?, ?, ?, ?, ?)`, - fmt.Sprintf("abcd%d", i), time.Now().Unix(), "mytopic", fmt.Sprintf("some message %d", i), "", 0, "") - require.Nil(t, err) - } - require.Nil(t, db.Close()) - - // Create cache to trigger migration - c := newSqliteTestCacheFromFile(t, filename, "") - checkSchemaVersion(t, c.db) - - // Add delayed message - delayedMessage := newDefaultMessage("mytopic", "some delayed message") - delayedMessage.Time = time.Now().Add(time.Minute).Unix() - require.Nil(t, c.AddMessage(delayedMessage)) - - // 10, not 11! - messages, err := c.Messages("mytopic", sinceAllMessages, false) - require.Nil(t, err) - require.Equal(t, 10, len(messages)) - - // 11! - messages, err = c.Messages("mytopic", sinceAllMessages, true) - require.Nil(t, err) - require.Equal(t, 11, len(messages)) - - // Check that index "idx_topic" exists - rows, err := c.db.Query(`SELECT name FROM sqlite_master WHERE type='index' AND name='idx_topic'`) - require.Nil(t, err) - require.True(t, rows.Next()) - var indexName string - require.Nil(t, rows.Scan(&indexName)) - require.Equal(t, "idx_topic", indexName) -} - -func TestSqliteCache_Migration_From9(t *testing.T) { - // This primarily tests the awkward migration that introduces the "expires" column. - // The migration logic has to update the column, using the existing "cache-duration" value. 
- - filename := newSqliteTestCacheFile(t) - db, err := sql.Open("sqlite3", filename) - require.Nil(t, err) - - // Create "version 8" schema - _, err = db.Exec(` - BEGIN; - CREATE TABLE IF NOT EXISTS messages ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - mid TEXT NOT NULL, - time INT NOT NULL, - topic TEXT NOT NULL, - message TEXT NOT NULL, - title TEXT NOT NULL, - priority INT NOT NULL, - tags TEXT NOT NULL, - click TEXT NOT NULL, - icon TEXT NOT NULL, - actions TEXT NOT NULL, - attachment_name TEXT NOT NULL, - attachment_type TEXT NOT NULL, - attachment_size INT NOT NULL, - attachment_expires INT NOT NULL, - attachment_url TEXT NOT NULL, - sender TEXT NOT NULL, - encoding TEXT NOT NULL, - published INT NOT NULL - ); - CREATE INDEX IF NOT EXISTS idx_mid ON messages (mid); - CREATE INDEX IF NOT EXISTS idx_time ON messages (time); - CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic); - CREATE TABLE IF NOT EXISTS schemaVersion ( - id INT PRIMARY KEY, - version INT NOT NULL - ); - INSERT INTO schemaVersion (id, version) VALUES (1, 9); - COMMIT; - `) - require.Nil(t, err) - - // Insert a bunch of messages - insertQuery := ` - INSERT INTO messages (mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding, published) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
- ` - for i := 0; i < 10; i++ { - _, err = db.Exec( - insertQuery, - fmt.Sprintf("abcd%d", i), - time.Now().Unix(), - "mytopic", - fmt.Sprintf("some message %d", i), - "", // title - 0, // priority - "", // tags - "", // click - "", // icon - "", // actions - "", // attachment_name - "", // attachment_type - 0, // attachment_size - 0, // attachment_type - "", // attachment_url - "9.9.9.9", // sender - "", // encoding - 1, // published - ) - require.Nil(t, err) - } - - // Create cache to trigger migration - cacheDuration := 17 * time.Hour - c, err := newSqliteCache(filename, "", cacheDuration, 0, 0, false) - require.Nil(t, err) - checkSchemaVersion(t, c.db) - - // Check version - rows, err := db.Query(`SELECT version FROM main.schemaVersion WHERE id = 1`) - require.Nil(t, err) - require.True(t, rows.Next()) - var version int - require.Nil(t, rows.Scan(&version)) - require.Equal(t, currentSchemaVersion, version) - - messages, err := c.Messages("mytopic", sinceAllMessages, false) - require.Nil(t, err) - require.Equal(t, 10, len(messages)) - for _, m := range messages { - require.True(t, m.Expires > time.Now().Add(cacheDuration-5*time.Second).Unix()) - require.True(t, m.Expires < time.Now().Add(cacheDuration+5*time.Second).Unix()) - } -} - -func TestSqliteCache_StartupQueries_WAL(t *testing.T) { - filename := newSqliteTestCacheFile(t) - startupQueries := `pragma journal_mode = WAL; -pragma synchronous = normal; -pragma temp_store = memory;` - db, err := newSqliteCache(filename, startupQueries, time.Hour, 0, 0, false) - require.Nil(t, err) - require.Nil(t, db.AddMessage(newDefaultMessage("mytopic", "some message"))) - require.FileExists(t, filename) - require.FileExists(t, filename+"-wal") - require.FileExists(t, filename+"-shm") -} - -func TestSqliteCache_StartupQueries_None(t *testing.T) { - filename := newSqliteTestCacheFile(t) - startupQueries := "" - db, err := newSqliteCache(filename, startupQueries, time.Hour, 0, 0, false) - require.Nil(t, err) - require.Nil(t, 
db.AddMessage(newDefaultMessage("mytopic", "some message"))) - require.FileExists(t, filename) - require.NoFileExists(t, filename+"-wal") - require.NoFileExists(t, filename+"-shm") -} - -func TestSqliteCache_StartupQueries_Fail(t *testing.T) { - filename := newSqliteTestCacheFile(t) - startupQueries := `xx error` - _, err := newSqliteCache(filename, startupQueries, time.Hour, 0, 0, false) - require.Error(t, err) -} - -func TestSqliteCache_Sender(t *testing.T) { - testSender(t, newSqliteTestCache(t)) -} - -func TestMemCache_Sender(t *testing.T) { - testSender(t, newMemTestCache(t)) -} - -func testSender(t *testing.T, c *messageCache) { - m1 := newDefaultMessage("mytopic", "mymessage") - m1.Sender = netip.MustParseAddr("1.2.3.4") - require.Nil(t, c.AddMessage(m1)) - - m2 := newDefaultMessage("mytopic", "mymessage without sender") - require.Nil(t, c.AddMessage(m2)) - - messages, err := c.Messages("mytopic", sinceAllMessages, false) - require.Nil(t, err) - require.Equal(t, 2, len(messages)) - require.Equal(t, messages[0].Sender, netip.MustParseAddr("1.2.3.4")) - require.Equal(t, messages[1].Sender, netip.Addr{}) -} - -func TestSqliteCache_DeleteScheduledBySequenceID(t *testing.T) { - testDeleteScheduledBySequenceID(t, newSqliteTestCache(t)) -} - -func TestMemCache_DeleteScheduledBySequenceID(t *testing.T) { - testDeleteScheduledBySequenceID(t, newMemTestCache(t)) -} - -func testDeleteScheduledBySequenceID(t *testing.T, c *messageCache) { - // Create a scheduled (unpublished) message - scheduledMsg := newDefaultMessage("mytopic", "scheduled message") - scheduledMsg.ID = "scheduled1" - scheduledMsg.SequenceID = "seq123" - scheduledMsg.Time = time.Now().Add(time.Hour).Unix() // Future time makes it scheduled - require.Nil(t, c.AddMessage(scheduledMsg)) - - // Create a published message with different sequence ID - publishedMsg := newDefaultMessage("mytopic", "published message") - publishedMsg.ID = "published1" - publishedMsg.SequenceID = "seq456" - publishedMsg.Time = 
time.Now().Add(-time.Hour).Unix() // Past time makes it published - require.Nil(t, c.AddMessage(publishedMsg)) - - // Create a scheduled message in a different topic - otherTopicMsg := newDefaultMessage("othertopic", "other scheduled") - otherTopicMsg.ID = "other1" - otherTopicMsg.SequenceID = "seq123" // Same sequence ID as scheduledMsg - otherTopicMsg.Time = time.Now().Add(time.Hour).Unix() - require.Nil(t, c.AddMessage(otherTopicMsg)) - - // Verify all messages exist (including scheduled) - messages, err := c.Messages("mytopic", sinceAllMessages, true) - require.Nil(t, err) - require.Equal(t, 2, len(messages)) - - messages, err = c.Messages("othertopic", sinceAllMessages, true) - require.Nil(t, err) - require.Equal(t, 1, len(messages)) - - // Delete scheduled message by sequence ID and verify returned IDs - deletedIDs, err := c.DeleteScheduledBySequenceID("mytopic", "seq123") - require.Nil(t, err) - require.Equal(t, 1, len(deletedIDs)) - require.Equal(t, "scheduled1", deletedIDs[0]) - - // Verify scheduled message is deleted - messages, err = c.Messages("mytopic", sinceAllMessages, true) - require.Nil(t, err) - require.Equal(t, 1, len(messages)) - require.Equal(t, "published message", messages[0].Message) - - // Verify other topic's message still exists (topic-scoped deletion) - messages, err = c.Messages("othertopic", sinceAllMessages, true) - require.Nil(t, err) - require.Equal(t, 1, len(messages)) - require.Equal(t, "other scheduled", messages[0].Message) - - // Deleting non-existent sequence ID should return empty list - deletedIDs, err = c.DeleteScheduledBySequenceID("mytopic", "nonexistent") - require.Nil(t, err) - require.Empty(t, deletedIDs) - - // Deleting published message should not affect it (only deletes unpublished) - deletedIDs, err = c.DeleteScheduledBySequenceID("mytopic", "seq456") - require.Nil(t, err) - require.Empty(t, deletedIDs) - - messages, err = c.Messages("mytopic", sinceAllMessages, true) - require.Nil(t, err) - require.Equal(t, 1, 
len(messages)) - require.Equal(t, "published message", messages[0].Message) -} - -func checkSchemaVersion(t *testing.T, db *sql.DB) { - rows, err := db.Query(`SELECT version FROM schemaVersion`) - require.Nil(t, err) - require.True(t, rows.Next()) - - var schemaVersion int - require.Nil(t, rows.Scan(&schemaVersion)) - require.Equal(t, currentSchemaVersion, schemaVersion) - require.Nil(t, rows.Close()) -} - -func TestMemCache_NopCache(t *testing.T) { - c, _ := newNopCache() - require.Nil(t, c.AddMessage(newDefaultMessage("mytopic", "my message"))) - - messages, err := c.Messages("mytopic", sinceAllMessages, false) - require.Nil(t, err) - require.Empty(t, messages) - - topics, err := c.Topics() - require.Nil(t, err) - require.Empty(t, topics) -} - -func newSqliteTestCache(t *testing.T) *messageCache { - c, err := newSqliteCache(newSqliteTestCacheFile(t), "", time.Hour, 0, 0, false) - if err != nil { - t.Fatal(err) - } - return c -} - -func newSqliteTestCacheFile(t *testing.T) string { - return filepath.Join(t.TempDir(), "cache.db") -} - -func newSqliteTestCacheFromFile(t *testing.T, filename, startupQueries string) *messageCache { - c, err := newSqliteCache(filename, startupQueries, time.Hour, 0, 0, false) - require.Nil(t, err) - return c -} - -func newMemTestCache(t *testing.T) *messageCache { - c, err := newMemCache() - require.Nil(t, err) - return c -} diff --git a/server/server.go b/server/server.go index ea03355b..80f367c2 100644 --- a/server/server.go +++ b/server/server.go @@ -33,6 +33,8 @@ import ( "golang.org/x/sync/errgroup" "gopkg.in/yaml.v2" "heckel.io/ntfy/v2/log" + "heckel.io/ntfy/v2/message" + "heckel.io/ntfy/v2/model" "heckel.io/ntfy/v2/payments" "heckel.io/ntfy/v2/user" "heckel.io/ntfy/v2/util" @@ -57,7 +59,7 @@ type Server struct { messages int64 // Total number of messages (persisted if messageCache enabled) messagesHistory []int64 // Last n values of the messages counter, used to determine rate userManager *user.Manager // Might be nil! 
- messageCache *messageCache // Database that stores the messages + messageCache message.Store // Database that stores the messages webPush webpush.Store // Database that stores web push subscriptions fileCache *fileCache // File system based cache that stores attachments stripe stripeAPI // Stripe API, can be replaced with a mock @@ -188,10 +190,14 @@ func New(conf *Config) (*Server, error) { return nil, err } } - topics, err := messageCache.Topics() + topicIDs, err := messageCache.Topics() if err != nil { return nil, err } + topics := make(map[string]*topic, len(topicIDs)) + for _, id := range topicIDs { + topics[id] = newTopic(id) + } messages, err := messageCache.Stats() if err != nil { return nil, err @@ -263,13 +269,15 @@ func New(conf *Config) (*Server, error) { return s, nil } -func createMessageCache(conf *Config) (*messageCache, error) { +func createMessageCache(conf *Config) (message.Store, error) { if conf.CacheDuration == 0 { - return newNopCache() + return message.NewNopStore() + } else if conf.DatabaseURL != "" { + return message.NewPostgresStore(conf.DatabaseURL, conf.CacheBatchSize, conf.CacheBatchTimeout) } else if conf.CacheFile != "" { - return newSqliteCache(conf.CacheFile, conf.CacheStartupQueries, conf.CacheDuration, conf.CacheBatchSize, conf.CacheBatchTimeout, false) + return message.NewSQLiteStore(conf.CacheFile, conf.CacheStartupQueries, conf.CacheDuration, conf.CacheBatchSize, conf.CacheBatchTimeout, false) } - return newMemCache() + return message.NewMemStore() } // Run executes the main server. 
It listens on HTTP (+ HTTPS, if configured), and starts @@ -750,7 +758,7 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor) if s.config.CacheBatchTimeout > 0 { // Strange edge case: If we immediately after upload request the file (the web app does this for images), // and messages are persisted asynchronously, retry fetching from the database - m, err = util.Retry(func() (*message, error) { + m, err = util.Retry(func() (*model.Message, error) { return s.messageCache.Message(messageID) }, s.config.CacheBatchTimeout, 100*time.Millisecond, 300*time.Millisecond, 600*time.Millisecond) } @@ -796,7 +804,7 @@ func (s *Server) handleMatrixDiscovery(w http.ResponseWriter) error { return writeMatrixDiscoveryResponse(w) } -func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*message, error) { +func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*model.Message, error) { start := time.Now() t, err := fromContext[*topic](r, contextTopic) if err != nil { @@ -924,7 +932,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito return err } minc(metricMessagesPublishedSuccess) - return s.writeJSON(w, m.forJSON()) + return s.writeJSON(w, m.ForJSON()) } func (s *Server) handlePublishMatrix(w http.ResponseWriter, r *http.Request, v *visitor) error { @@ -1014,10 +1022,10 @@ func (s *Server) handleActionMessage(w http.ResponseWriter, r *http.Request, v * s.mu.Lock() s.messages++ s.mu.Unlock() - return s.writeJSON(w, m.forJSON()) + return s.writeJSON(w, m.ForJSON()) } -func (s *Server) sendToFirebase(v *visitor, m *message) { +func (s *Server) sendToFirebase(v *visitor, m *model.Message) { logvm(v, m).Tag(tagFirebase).Debug("Publishing to Firebase") if err := s.firebaseClient.Send(v, m); err != nil { minc(metricFirebasePublishedFailure) @@ -1031,7 +1039,7 @@ func (s *Server) sendToFirebase(v *visitor, m *message) { minc(metricFirebasePublishedSuccess) } -func (s *Server) sendEmail(v *visitor, m 
*message, email string) { +func (s *Server) sendEmail(v *visitor, m *model.Message, email string) { logvm(v, m).Tag(tagEmail).Field("email", email).Debug("Sending email to %s", email) if err := s.smtpSender.Send(v, m, email); err != nil { logvm(v, m).Tag(tagEmail).Field("email", email).Err(err).Warn("Unable to send email to %s: %v", email, err.Error()) @@ -1041,7 +1049,7 @@ func (s *Server) sendEmail(v *visitor, m *message, email string) { minc(metricEmailsPublishedSuccess) } -func (s *Server) forwardPollRequest(v *visitor, m *message) { +func (s *Server) forwardPollRequest(v *visitor, m *model.Message) { topicURL := fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic) topicHash := fmt.Sprintf("%x", sha256.Sum256([]byte(topicURL))) forwardURL := fmt.Sprintf("%s/%s", s.config.UpstreamBaseURL, topicHash) @@ -1073,7 +1081,7 @@ func (s *Server) forwardPollRequest(v *visitor, m *message) { } } -func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, firebase bool, email, call string, template templateMode, unifiedpush bool, priorityStr string, err *errHTTP) { +func (s *Server) parsePublishParams(r *http.Request, m *model.Message) (cache bool, firebase bool, email, call string, template templateMode, unifiedpush bool, priorityStr string, err *errHTTP) { if r.Method != http.MethodGet && updatePathRegex.MatchString(r.URL.Path) { pathSequenceID, err := s.sequenceIDFromPath(r.URL.Path) if err != nil { @@ -1100,7 +1108,7 @@ func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, fi filename := readParam(r, "x-filename", "filename", "file", "f") attach := readParam(r, "x-attach", "attach", "a") if attach != "" || filename != "" { - m.Attachment = &attachment{} + m.Attachment = &model.Attachment{} } if filename != "" { m.Attachment.Name = filename @@ -1221,7 +1229,7 @@ func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, fi // If file.txt is <= 4096 (message limit) and valid UTF-8, treat it as a message // 7. 
curl -T file.txt ntfy.sh/mytopic // In all other cases, mostly if file.txt is > message limit, treat it as an attachment -func (s *Server) handlePublishBody(r *http.Request, v *visitor, m *message, body *util.PeekedReadCloser, template templateMode, unifiedpush bool, priorityStr string) error { +func (s *Server) handlePublishBody(r *http.Request, v *visitor, m *model.Message, body *util.PeekedReadCloser, template templateMode, unifiedpush bool, priorityStr string) error { if m.Event == pollRequestEvent { // Case 1 return s.handleBodyDiscard(body) } else if unifiedpush { @@ -1244,7 +1252,7 @@ func (s *Server) handleBodyDiscard(body *util.PeekedReadCloser) error { return err } -func (s *Server) handleBodyAsMessageAutoDetect(m *message, body *util.PeekedReadCloser) error { +func (s *Server) handleBodyAsMessageAutoDetect(m *model.Message, body *util.PeekedReadCloser) error { if utf8.Valid(body.PeekedBytes) { m.Message = string(body.PeekedBytes) // Do not trim } else { @@ -1254,7 +1262,7 @@ func (s *Server) handleBodyAsMessageAutoDetect(m *message, body *util.PeekedRead return nil } -func (s *Server) handleBodyAsTextMessage(m *message, body *util.PeekedReadCloser) error { +func (s *Server) handleBodyAsTextMessage(m *model.Message, body *util.PeekedReadCloser) error { if !utf8.Valid(body.PeekedBytes) { return errHTTPBadRequestMessageNotUTF8.With(m) } @@ -1267,7 +1275,7 @@ func (s *Server) handleBodyAsTextMessage(m *message, body *util.PeekedReadCloser return nil } -func (s *Server) handleBodyAsTemplatedTextMessage(m *message, template templateMode, body *util.PeekedReadCloser, priorityStr string) error { +func (s *Server) handleBodyAsTemplatedTextMessage(m *model.Message, template templateMode, body *util.PeekedReadCloser, priorityStr string) error { body, err := util.Peek(body, max(s.config.MessageSizeLimit, jsonBodyBytesLimit)) if err != nil { return err @@ -1292,7 +1300,7 @@ func (s *Server) handleBodyAsTemplatedTextMessage(m *message, template templateM // 
renderTemplateFromFile transforms the JSON message body according to a template from the filesystem. // The template file must be in the templates directory, or in the configured template directory. -func (s *Server) renderTemplateFromFile(m *message, templateName, peekedBody string) error { +func (s *Server) renderTemplateFromFile(m *model.Message, templateName, peekedBody string) error { if !templateNameRegex.MatchString(templateName) { return errHTTPBadRequestTemplateFileNotFound } @@ -1334,7 +1342,7 @@ func (s *Server) renderTemplateFromFile(m *message, templateName, peekedBody str // renderTemplateFromParams transforms the JSON message body according to the inline template in the // message, title, and priority parameters. -func (s *Server) renderTemplateFromParams(m *message, peekedBody string, priorityStr string) error { +func (s *Server) renderTemplateFromParams(m *model.Message, peekedBody string, priorityStr string) error { var err error if m.Message, err = s.renderTemplate("priority query parameter", m.Message, peekedBody); err != nil { return err @@ -1375,7 +1383,7 @@ func (s *Server) renderTemplate(name, tpl, source string) (string, error) { return strings.TrimSpace(strings.ReplaceAll(buf.String(), "\\n", "\n")), nil // replace any remaining "\n" (those outside of template curly braces) with newlines } -func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message, body *util.PeekedReadCloser) error { +func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *model.Message, body *util.PeekedReadCloser) error { if s.fileCache == nil || s.config.BaseURL == "" || s.config.AttachmentCacheDir == "" { return errHTTPBadRequestAttachmentsDisallowed.With(m) } @@ -1399,7 +1407,7 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message, } } if m.Attachment == nil { - m.Attachment = &attachment{} + m.Attachment = &model.Attachment{} } var ext string m.Attachment.Expires = attachmentExpiry @@ -1426,9 +1434,9 
@@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message, } func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request, v *visitor) error { - encoder := func(msg *message) (string, error) { + encoder := func(msg *model.Message) (string, error) { var buf bytes.Buffer - if err := json.NewEncoder(&buf).Encode(msg.forJSON()); err != nil { + if err := json.NewEncoder(&buf).Encode(msg.ForJSON()); err != nil { return "", err } return buf.String(), nil @@ -1437,9 +1445,9 @@ func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request, v * } func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request, v *visitor) error { - encoder := func(msg *message) (string, error) { + encoder := func(msg *model.Message) (string, error) { var buf bytes.Buffer - if err := json.NewEncoder(&buf).Encode(msg.forJSON()); err != nil { + if err := json.NewEncoder(&buf).Encode(msg.ForJSON()); err != nil { return "", err } if msg.Event != messageEvent && msg.Event != messageDeleteEvent && msg.Event != messageClearEvent { @@ -1451,7 +1459,7 @@ func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request, v *v } func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request, v *visitor) error { - encoder := func(msg *message) (string, error) { + encoder := func(msg *model.Message) (string, error) { if msg.Event == messageEvent { // only handle default events return strings.ReplaceAll(msg.Message, "\n", " ") + "\n", nil } @@ -1487,7 +1495,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v * closed = true wlock.Unlock() }() - sub := func(v *visitor, msg *message) error { + sub := func(v *visitor, msg *model.Message) error { if !filters.Pass(msg) { return nil } @@ -1649,7 +1657,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi } } }) - sub := func(v *visitor, msg *message) error { + sub := func(v *visitor, msg *model.Message) error { if 
!filters.Pass(msg) { return nil } @@ -1696,7 +1704,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi return nil } -func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, scheduled bool, filters *queryFilter, err error) { +func parseSubscribeParams(r *http.Request) (poll bool, since model.SinceMarker, scheduled bool, filters *queryFilter, err error) { poll = readBoolParam(r, false, "x-poll", "poll", "po") scheduled = readBoolParam(r, false, "x-scheduled", "scheduled", "sched") since, err = parseSince(r, poll) @@ -1777,11 +1785,11 @@ func (s *Server) setRateVisitors(r *http.Request, v *visitor, rateTopics []*topi // sendOldMessages selects old messages from the messageCache and calls sub for each of them. It uses since as the // marker, returning only messages that are newer than the marker. -func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, v *visitor, sub subscriber) error { +func (s *Server) sendOldMessages(topics []*topic, since model.SinceMarker, scheduled bool, v *visitor, sub subscriber) error { if since.IsNone() { return nil } - messages := make([]*message, 0) + messages := make([]*model.Message, 0) for _, t := range topics { topicMessages, err := s.messageCache.Messages(t.ID, since, scheduled) if err != nil { @@ -1804,7 +1812,7 @@ func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled b // // Values in the "since=..." parameter can be either a unix timestamp or a duration (e.g. 
12h), // "all" for all messages, or "latest" for the most recent message for a topic -func parseSince(r *http.Request, poll bool) (sinceMarker, error) { +func parseSince(r *http.Request, poll bool) (model.SinceMarker, error) { since := readParam(r, "x-since", "since", "si") // Easy cases (empty, all, none) @@ -2035,7 +2043,7 @@ func (s *Server) sendDelayedMessages() error { return nil } -func (s *Server) sendDelayedMessage(v *visitor, m *message) error { +func (s *Server) sendDelayedMessage(v *visitor, m *model.Message) error { logvm(v, m).Debug("Sending delayed message") s.mu.RLock() t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published diff --git a/server/server_firebase.go b/server/server_firebase.go index 9fde63a3..ad206e13 100644 --- a/server/server_firebase.go +++ b/server/server_firebase.go @@ -10,6 +10,7 @@ import ( "firebase.google.com/go/v4/messaging" "fmt" "google.golang.org/api/option" + "heckel.io/ntfy/v2/model" "heckel.io/ntfy/v2/user" "heckel.io/ntfy/v2/util" "strings" @@ -43,7 +44,7 @@ func newFirebaseClient(sender firebaseSender, auther user.Auther) *firebaseClien } } -func (c *firebaseClient) Send(v *visitor, m *message) error { +func (c *firebaseClient) Send(v *visitor, m *model.Message) error { if !v.FirebaseAllowed() { return errFirebaseTemporarilyBanned } @@ -121,7 +122,7 @@ func (c *firebaseSenderImpl) Send(m *messaging.Message) error { // On Android, this will trigger the app to poll the topic and thereby displaying new messages. // - If UpstreamBaseURL is set, messages are forwarded as poll requests to an upstream server and then forwarded // to Firebase here. This is mainly for iOS to support self-hosted servers. 
-func toFirebaseMessage(m *message, auther user.Auther) (*messaging.Message, error) { +func toFirebaseMessage(m *model.Message, auther user.Auther) (*messaging.Message, error) { var data map[string]string // Mostly matches https://ntfy.sh/docs/subscribe/api/#json-message-format var apnsConfig *messaging.APNSConfig switch m.Event { @@ -235,7 +236,7 @@ func maybeTruncateFCMMessage(m *messaging.Message) *messaging.Message { // createAPNSAlertConfig creates an APNS config for iOS notifications that show up as an alert (only relevant for iOS). // We must set the Alert struct ("alert"), and we need to set MutableContent ("mutable-content"), so the Notification Service // Extension in iOS can modify the message. -func createAPNSAlertConfig(m *message, data map[string]string) *messaging.APNSConfig { +func createAPNSAlertConfig(m *model.Message, data map[string]string) *messaging.APNSConfig { apnsData := make(map[string]any) for k, v := range data { apnsData[k] = v @@ -296,7 +297,7 @@ func maybeTruncateAPNSBodyMessage(s string) string { // // This empties all the fields that are not needed for a poll request and just sets the required fields, // most importantly, the PollID. 
-func toPollRequest(m *message) *message { +func toPollRequest(m *model.Message) *model.Message { pr := newPollRequestMessage(m.Topic, m.ID) pr.ID = m.ID pr.Time = m.Time diff --git a/server/server_firebase_dummy.go b/server/server_firebase_dummy.go index bddceff1..6b026075 100644 --- a/server/server_firebase_dummy.go +++ b/server/server_firebase_dummy.go @@ -4,6 +4,7 @@ package server import ( "errors" + "heckel.io/ntfy/v2/model" "heckel.io/ntfy/v2/user" ) @@ -21,7 +22,7 @@ var ( type firebaseClient struct { } -func (c *firebaseClient) Send(v *visitor, m *message) error { +func (c *firebaseClient) Send(v *visitor, m *model.Message) error { return errFirebaseNotAvailable } diff --git a/server/server_firebase_test.go b/server/server_firebase_test.go index c98f528f..ab4b494c 100644 --- a/server/server_firebase_test.go +++ b/server/server_firebase_test.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "heckel.io/ntfy/v2/model" "heckel.io/ntfy/v2/user" "net/netip" "strings" @@ -131,7 +132,7 @@ func TestToFirebaseMessage_Message_Normal_Allowed(t *testing.T) { m.Click = "https://google.com" m.Icon = "https://ntfy.sh/static/img/ntfy.png" m.Title = "some title" - m.Actions = []*action{ + m.Actions = []*model.Action{ { ID: "123", Action: "view", @@ -150,7 +151,7 @@ func TestToFirebaseMessage_Message_Normal_Allowed(t *testing.T) { }, }, } - m.Attachment = &attachment{ + m.Attachment = &model.Attachment{ Name: "some file.jpg", Type: "image/jpeg", Size: 12345, @@ -346,16 +347,16 @@ func TestToFirebaseSender_Abuse(t *testing.T) { client := newFirebaseClient(sender, &testAuther{}) visitor := newVisitor(newTestConfig(t), newMemTestCache(t), nil, netip.MustParseAddr("1.2.3.4"), nil) - require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"})) + require.Nil(t, client.Send(visitor, &model.Message{Topic: "mytopic"})) require.Equal(t, 1, len(sender.Messages())) - require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"})) + require.Nil(t, client.Send(visitor, 
&model.Message{Topic: "mytopic"})) require.Equal(t, 2, len(sender.Messages())) - require.Equal(t, errFirebaseQuotaExceeded, client.Send(visitor, &message{Topic: "mytopic"})) + require.Equal(t, errFirebaseQuotaExceeded, client.Send(visitor, &model.Message{Topic: "mytopic"})) require.Equal(t, 2, len(sender.Messages())) sender.messages = make([]*messaging.Message, 0) // Reset to test that time limit is working - require.Equal(t, errFirebaseTemporarilyBanned, client.Send(visitor, &message{Topic: "mytopic"})) + require.Equal(t, errFirebaseTemporarilyBanned, client.Send(visitor, &model.Message{Topic: "mytopic"})) require.Equal(t, 0, len(sender.Messages())) } diff --git a/server/server_test.go b/server/server_test.go index a44d3880..23f72d1c 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -25,6 +25,8 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/crypto/bcrypt" "heckel.io/ntfy/v2/log" + "heckel.io/ntfy/v2/message" + "heckel.io/ntfy/v2/model" "heckel.io/ntfy/v2/user" "heckel.io/ntfy/v2/util" ) @@ -382,7 +384,7 @@ func TestServer_PublishAt(t *testing.T) { // Update message time to the past fakeTime := time.Now().Add(-10 * time.Second).Unix() - _, err := s.messageCache.db.Exec(`UPDATE messages SET time=?`, fakeTime) + _, err := s.messageCache.DB().Exec(`UPDATE messages SET time=?`, fakeTime) require.Nil(t, err) // Trigger delayed message sending @@ -418,7 +420,7 @@ func TestServer_PublishAt_FromUser(t *testing.T) { // Update message time to the past fakeTime := time.Now().Add(-10 * time.Second).Unix() - _, err := s.messageCache.db.Exec(`UPDATE messages SET time=?`, fakeTime) + _, err := s.messageCache.DB().Exec(`UPDATE messages SET time=?`, fakeTime) require.Nil(t, err) // Trigger delayed message sending @@ -596,8 +598,8 @@ func TestServer_PublishAndPollSince(t *testing.T) { require.Equal(t, 40008, toHTTPError(t, response.Body.String()).Code) } -func newMessageWithTimestamp(topic, message string, timestamp int64) *message { - m := 
newDefaultMessage(topic, message) +func newMessageWithTimestamp(topic, msg string, timestamp int64) *model.Message { + m := newDefaultMessage(topic, msg) m.Time = timestamp return m } @@ -1209,7 +1211,7 @@ type testMailer struct { mu sync.Mutex } -func (t *testMailer) Send(v *visitor, m *message, to string) error { +func (t *testMailer) Send(v *visitor, m *model.Message, to string) error { t.mu.Lock() defer t.mu.Unlock() t.count++ @@ -1414,7 +1416,7 @@ func TestServer_PublishAndExpungeTopicAfter16Hours(t *testing.T) { s := newTestServer(t, newTestConfig(t)) defer s.messageCache.Close() - subFn := func(v *visitor, msg *message) error { + subFn := func(v *visitor, msg *model.Message) error { return nil } @@ -2410,14 +2412,14 @@ func TestServer_PublishWhileUpdatingStatsWithLotsOfMessages(t *testing.T) { // Add lots of messages log.Info("Adding %d messages", count) start := time.Now() - messages := make([]*message, 0) + messages := make([]*model.Message, 0) for i := 0; i < count; i++ { topicID := fmt.Sprintf("topic%d", i) _, err := s.topicsFromIDs(topicID) // Add topic to internal s.topics array require.Nil(t, err) messages = append(messages, newDefaultMessage(topicID, "some message")) } - require.Nil(t, s.messageCache.addMessages(messages)) + require.Nil(t, s.messageCache.AddMessages(messages)) log.Info("Done: Adding %d messages; took %s", count, time.Since(start).Round(time.Millisecond)) // Update stats @@ -3763,6 +3765,12 @@ func TestServer_DeleteScheduledMessage_WithAttachment(t *testing.T) { require.NoFileExists(t, attachmentFile) } +func newMemTestCache(t *testing.T) message.Store { + c, err := message.NewMemStore() + require.Nil(t, err) + return c +} + func newTestConfig(t *testing.T) *Config { conf := NewConfig() conf.BaseURL = "http://127.0.0.1:12345" @@ -3829,8 +3837,8 @@ func subscribe(t *testing.T, s *Server, url string, rr *httptest.ResponseRecorde return cancelAndWaitForDone } -func toMessages(t *testing.T, s string) []*message { - messages := 
make([]*message, 0) +func toMessages(t *testing.T, s string) []*model.Message { + messages := make([]*model.Message, 0) scanner := bufio.NewScanner(strings.NewReader(s)) for scanner.Scan() { messages = append(messages, toMessage(t, scanner.Text())) @@ -3838,8 +3846,8 @@ func toMessages(t *testing.T, s string) []*message { return messages } -func toMessage(t *testing.T, s string) *message { - var m message +func toMessage(t *testing.T, s string) *model.Message { + var m model.Message require.Nil(t, json.NewDecoder(strings.NewReader(s)).Decode(&m)) return &m } diff --git a/server/server_twilio.go b/server/server_twilio.go index c1761613..8bce2b90 100644 --- a/server/server_twilio.go +++ b/server/server_twilio.go @@ -11,6 +11,7 @@ import ( "text/template" "heckel.io/ntfy/v2/log" + "heckel.io/ntfy/v2/model" "heckel.io/ntfy/v2/user" "heckel.io/ntfy/v2/util" ) @@ -76,7 +77,7 @@ func (s *Server) convertPhoneNumber(u *user.User, phoneNumber string) (string, * // callPhone calls the Twilio API to make a phone call to the given phone number, using the given message. // Failures will be logged, but not returned to the caller. 
-func (s *Server) callPhone(v *visitor, r *http.Request, m *message, to string) { +func (s *Server) callPhone(v *visitor, r *http.Request, m *model.Message, to string) { u, sender := v.User(), m.Sender.String() if u != nil { sender = u.Name diff --git a/server/server_webpush.go b/server/server_webpush.go index 11e37f66..f98b8e91 100644 --- a/server/server_webpush.go +++ b/server/server_webpush.go @@ -11,6 +11,7 @@ import ( "github.com/SherClockHolmes/webpush-go" "heckel.io/ntfy/v2/log" + "heckel.io/ntfy/v2/model" "heckel.io/ntfy/v2/user" wpush "heckel.io/ntfy/v2/webpush" ) @@ -83,14 +84,14 @@ func (s *Server) handleWebPushDelete(w http.ResponseWriter, r *http.Request, _ * return s.writeJSON(w, newSuccessResponse()) } -func (s *Server) publishToWebPushEndpoints(v *visitor, m *message) { +func (s *Server) publishToWebPushEndpoints(v *visitor, m *model.Message) { subscriptions, err := s.webPush.SubscriptionsForTopic(m.Topic) if err != nil { logvm(v, m).Err(err).With(v, m).Warn("Unable to publish web push messages") return } log.Tag(tagWebPush).With(v, m).Debug("Publishing web push message to %d subscribers", len(subscriptions)) - payload, err := json.Marshal(newWebPushPayload(fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic), m.forJSON())) + payload, err := json.Marshal(newWebPushPayload(fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic), m.ForJSON())) if err != nil { log.Tag(tagWebPush).Err(err).With(v, m).Warn("Unable to marshal expiring payload") return diff --git a/server/server_webpush_dummy.go b/server/server_webpush_dummy.go index 425b22a6..70c2df22 100644 --- a/server/server_webpush_dummy.go +++ b/server/server_webpush_dummy.go @@ -4,6 +4,8 @@ package server import ( "net/http" + + "heckel.io/ntfy/v2/model" ) const ( @@ -20,7 +22,7 @@ func (s *Server) handleWebPushDelete(w http.ResponseWriter, r *http.Request, _ * return errHTTPNotFound } -func (s *Server) publishToWebPushEndpoints(v *visitor, m *message) { +func (s *Server) publishToWebPushEndpoints(v *visitor, m 
*model.Message) { // Nothing to see here } diff --git a/server/smtp_sender.go b/server/smtp_sender.go index 0f798030..4e5988ba 100644 --- a/server/smtp_sender.go +++ b/server/smtp_sender.go @@ -12,11 +12,12 @@ import ( "time" "heckel.io/ntfy/v2/log" + "heckel.io/ntfy/v2/model" "heckel.io/ntfy/v2/util" ) type mailer interface { - Send(v *visitor, m *message, to string) error + Send(v *visitor, m *model.Message, to string) error Counts() (total int64, success int64, failure int64) } @@ -27,7 +28,7 @@ type smtpSender struct { mu sync.Mutex } -func (s *smtpSender) Send(v *visitor, m *message, to string) error { +func (s *smtpSender) Send(v *visitor, m *model.Message, to string) error { return s.withCount(v, m, func() error { host, _, err := net.SplitHostPort(s.config.SMTPSenderAddr) if err != nil { @@ -63,7 +64,7 @@ func (s *smtpSender) Counts() (total int64, success int64, failure int64) { return s.success + s.failure, s.success, s.failure } -func (s *smtpSender) withCount(v *visitor, m *message, fn func() error) error { +func (s *smtpSender) withCount(v *visitor, m *model.Message, fn func() error) error { err := fn() s.mu.Lock() defer s.mu.Unlock() @@ -76,7 +77,7 @@ func (s *smtpSender) withCount(v *visitor, m *message, fn func() error) error { return err } -func formatMail(baseURL, senderIP, from, to string, m *message) (string, error) { +func formatMail(baseURL, senderIP, from, to string, m *model.Message) (string, error) { topicURL := baseURL + "/" + m.Topic subject := m.Title if subject == "" { diff --git a/server/smtp_sender_test.go b/server/smtp_sender_test.go index 782c0f97..4f97b128 100644 --- a/server/smtp_sender_test.go +++ b/server/smtp_sender_test.go @@ -1,12 +1,14 @@ package server import ( - "github.com/stretchr/testify/require" "testing" + + "github.com/stretchr/testify/require" + "heckel.io/ntfy/v2/model" ) func TestFormatMail_Basic(t *testing.T) { - actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &message{ + 
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &model.Message{ ID: "abc", Time: 1640382204, Event: "message", @@ -27,7 +29,7 @@ This message was sent by 1.2.3.4 at Fri, 24 Dec 2021 21:43:24 UTC via https://nt } func TestFormatMail_JustEmojis(t *testing.T) { - actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &message{ + actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &model.Message{ ID: "abc", Time: 1640382204, Event: "message", @@ -49,7 +51,7 @@ This message was sent by 1.2.3.4 at Fri, 24 Dec 2021 21:43:24 UTC via https://nt } func TestFormatMail_JustOtherTags(t *testing.T) { - actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &message{ + actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &model.Message{ ID: "abc", Time: 1640382204, Event: "message", @@ -73,7 +75,7 @@ This message was sent by 1.2.3.4 at Fri, 24 Dec 2021 21:43:24 UTC via https://nt } func TestFormatMail_JustPriority(t *testing.T) { - actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &message{ + actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &model.Message{ ID: "abc", Time: 1640382204, Event: "message", @@ -97,7 +99,7 @@ This message was sent by 1.2.3.4 at Fri, 24 Dec 2021 21:43:24 UTC via https://nt } func TestFormatMail_UTF8Subject(t *testing.T) { - actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &message{ + actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &model.Message{ ID: "abc", Time: 1640382204, Event: "message", @@ -119,7 +121,7 @@ This message was sent by 1.2.3.4 at Fri, 24 Dec 2021 21:43:24 UTC via https://nt } func TestFormatMail_WithAllTheThings(t *testing.T) { - actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", 
"ntfy@ntfy.sh", "phil@example.com", &message{ + actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &model.Message{ ID: "abc", Time: 1640382204, Event: "message", diff --git a/server/smtp_server.go b/server/smtp_server.go index ee28efc2..e342e678 100644 --- a/server/smtp_server.go +++ b/server/smtp_server.go @@ -19,6 +19,7 @@ import ( "github.com/emersion/go-smtp" "github.com/microcosm-cc/bluemonday" + "heckel.io/ntfy/v2/model" ) var ( @@ -183,7 +184,7 @@ func (s *smtpSession) Data(r io.Reader) error { }) } -func (s *smtpSession) publishMessage(m *message) error { +func (s *smtpSession) publishMessage(m *model.Message) error { // Extract remote address (for rate limiting) remoteAddr, _, err := net.SplitHostPort(s.conn.Conn().RemoteAddr().String()) if err != nil { diff --git a/server/topic.go b/server/topic.go index 49def94b..f373a5e6 100644 --- a/server/topic.go +++ b/server/topic.go @@ -6,6 +6,7 @@ import ( "time" "heckel.io/ntfy/v2/log" + "heckel.io/ntfy/v2/model" "heckel.io/ntfy/v2/util" ) @@ -33,7 +34,7 @@ type topicSubscriber struct { } // subscriber is a function that is called for every new message on a topic -type subscriber func(v *visitor, msg *message) error +type subscriber func(v *visitor, msg *model.Message) error // newTopic creates a new topic func newTopic(id string) *topic { @@ -103,7 +104,7 @@ func (t *topic) Unsubscribe(id int) { } // Publish asynchronously publishes to all subscribers -func (t *topic) Publish(v *visitor, m *message) error { +func (t *topic) Publish(v *visitor, m *model.Message) error { go func() { // We want to lock the topic as short as possible, so we make a shallow copy of the // subscribers map here. Actually sending out the messages then doesn't have to lock. 
diff --git a/server/topic_test.go b/server/topic_test.go index 0376c942..f526921d 100644 --- a/server/topic_test.go +++ b/server/topic_test.go @@ -7,10 +7,11 @@ import ( "time" "github.com/stretchr/testify/require" + "heckel.io/ntfy/v2/model" ) func TestTopic_CancelSubscribersExceptUser(t *testing.T) { - subFn := func(v *visitor, msg *message) error { + subFn := func(v *visitor, msg *model.Message) error { return nil } canceled1 := atomic.Bool{} @@ -33,7 +34,7 @@ func TestTopic_CancelSubscribersExceptUser(t *testing.T) { func TestTopic_CancelSubscribersUser(t *testing.T) { t.Parallel() - subFn := func(v *visitor, msg *message) error { + subFn := func(v *visitor, msg *model.Message) error { return nil } canceled1 := atomic.Bool{} @@ -76,7 +77,7 @@ func TestTopic_Subscribe_DuplicateID(t *testing.T) { cancel: func() {}, } - subFn := func(v *visitor, msg *message) error { + subFn := func(v *visitor, msg *model.Message) error { return nil } diff --git a/server/types.go b/server/types.go index dfd797e5..ade191a2 100644 --- a/server/types.go +++ b/server/types.go @@ -2,109 +2,52 @@ package server import ( "net/http" - "net/netip" - "time" - "heckel.io/ntfy/v2/log" + "heckel.io/ntfy/v2/model" "heckel.io/ntfy/v2/user" "heckel.io/ntfy/v2/util" ) -// List of possible events +// Event constants const ( - openEvent = "open" - keepaliveEvent = "keepalive" - messageEvent = "message" - messageDeleteEvent = "message_delete" - messageClearEvent = "message_clear" - pollRequestEvent = "poll_request" + openEvent = model.OpenEvent + keepaliveEvent = model.KeepaliveEvent + messageEvent = model.MessageEvent + messageDeleteEvent = model.MessageDeleteEvent + messageClearEvent = model.MessageClearEvent + pollRequestEvent = model.PollRequestEvent + messageIDLength = model.MessageIDLength ) -const ( - messageIDLength = 12 +// Sentinel values and errors +var ( + sinceAllMessages = model.SinceAllMessages + sinceNoMessages = model.SinceNoMessages + sinceLatestMessage = model.SinceLatestMessage + 
errUnexpectedMessageType = model.ErrUnexpectedMessageType + errMessageNotFound = model.ErrMessageNotFound ) -// message represents a message published to a topic -type message struct { - ID string `json:"id"` // Random message ID - SequenceID string `json:"sequence_id,omitempty"` // Message sequence ID for updating message contents (omitted if same as ID) - Time int64 `json:"time"` // Unix time in seconds - Expires int64 `json:"expires,omitempty"` // Unix time in seconds (not required for open/keepalive) - Event string `json:"event"` // One of the above - Topic string `json:"topic"` - Title string `json:"title,omitempty"` - Message string `json:"message,omitempty"` - Priority int `json:"priority,omitempty"` - Tags []string `json:"tags,omitempty"` - Click string `json:"click,omitempty"` - Icon string `json:"icon,omitempty"` - Actions []*action `json:"actions,omitempty"` - Attachment *attachment `json:"attachment,omitempty"` - PollID string `json:"poll_id,omitempty"` - ContentType string `json:"content_type,omitempty"` // text/plain by default (if empty), or text/markdown - Encoding string `json:"encoding,omitempty"` // Empty for raw UTF-8, or "base64" for encoded bytes - Sender netip.Addr `json:"-"` // IP address of uploader, used for rate limiting - User string `json:"-"` // UserID of the uploader, used to associated attachments -} +// Constructors and helpers +var ( + newMessage = model.NewMessage + newDefaultMessage = model.NewDefaultMessage + newOpenMessage = model.NewOpenMessage + newKeepaliveMessage = model.NewKeepaliveMessage + newActionMessage = model.NewActionMessage + newAction = model.NewAction + newSinceTime = model.NewSinceTime + newSinceID = model.NewSinceID + validMessageID = model.ValidMessageID +) -func (m *message) Context() log.Context { - fields := map[string]any{ - "topic": m.Topic, - "message_id": m.ID, - "message_sequence_id": m.SequenceID, - "message_time": m.Time, - "message_event": m.Event, - "message_body_size": len(m.Message), - } - if 
m.Sender.IsValid() { - fields["message_sender"] = m.Sender.String() - } - if m.User != "" { - fields["message_user"] = m.User - } - return fields -} - -// forJSON returns a copy of the message suitable for JSON output. -// It clears the SequenceID if it equals the ID to reduce redundancy. -func (m *message) forJSON() *message { - if m.SequenceID == m.ID { - clone := *m - clone.SequenceID = "" - return &clone - } +// newPollRequestMessage is a convenience method to create a poll request message +func newPollRequestMessage(topic, pollID string) *model.Message { + m := newMessage(pollRequestEvent, topic, newMessageBody) + m.PollID = pollID return m } -type attachment struct { - Name string `json:"name"` - Type string `json:"type,omitempty"` - Size int64 `json:"size,omitempty"` - Expires int64 `json:"expires,omitempty"` - URL string `json:"url"` -} - -type action struct { - ID string `json:"id"` - Action string `json:"action"` // "view", "broadcast", "http", or "copy" - Label string `json:"label"` // action button label - Clear bool `json:"clear"` // clear notification after successful execution - URL string `json:"url,omitempty"` // used in "view" and "http" actions - Method string `json:"method,omitempty"` // used in "http" action, default is POST (!) 
- Headers map[string]string `json:"headers,omitempty"` // used in "http" action - Body string `json:"body,omitempty"` // used in "http" action - Intent string `json:"intent,omitempty"` // used in "broadcast" action - Extras map[string]string `json:"extras,omitempty"` // used in "broadcast" action - Value string `json:"value,omitempty"` // used in "copy" action -} - -func newAction() *action { - return &action{ - Headers: make(map[string]string), - Extras: make(map[string]string), - } -} - // publishMessage is used as input when publishing as JSON type publishMessage struct { Topic string `json:"topic"` @@ -115,7 +58,7 @@ type publishMessage struct { Tags []string `json:"tags"` Click string `json:"click"` Icon string `json:"icon"` - Actions []action `json:"actions"` + Actions []model.Action `json:"actions"` Attach string `json:"attach"` Markdown bool `json:"markdown"` Filename string `json:"filename"` @@ -127,94 +70,7 @@ type publishMessage struct { } // messageEncoder is a function that knows how to encode a message -type messageEncoder func(msg *message) (string, error) - -// newMessage creates a new message with the current timestamp -func newMessage(event, topic, msg string) *message { - return &message{ - ID: util.RandomString(messageIDLength), - Time: time.Now().Unix(), - Event: event, - Topic: topic, - Message: msg, - } -} - -// newOpenMessage is a convenience method to create an open message -func newOpenMessage(topic string) *message { - return newMessage(openEvent, topic, "") -} - -// newKeepaliveMessage is a convenience method to create a keepalive message -func newKeepaliveMessage(topic string) *message { - return newMessage(keepaliveEvent, topic, "") -} - -// newDefaultMessage is a convenience method to create a notification message -func newDefaultMessage(topic, msg string) *message { - return newMessage(messageEvent, topic, msg) -} - -// newPollRequestMessage is a convenience method to create a poll request message -func newPollRequestMessage(topic, 
pollID string) *message { - m := newMessage(pollRequestEvent, topic, newMessageBody) - m.PollID = pollID - return m -} - -// newActionMessage creates a new action message (message_delete or message_clear) -func newActionMessage(event, topic, sequenceID string) *message { - m := newMessage(event, topic, "") - m.SequenceID = sequenceID - return m -} - -func validMessageID(s string) bool { - return util.ValidRandomString(s, messageIDLength) -} - -type sinceMarker struct { - time time.Time - id string -} - -func newSinceTime(timestamp int64) sinceMarker { - return sinceMarker{time.Unix(timestamp, 0), ""} -} - -func newSinceID(id string) sinceMarker { - return sinceMarker{time.Unix(0, 0), id} -} - -func (t sinceMarker) IsAll() bool { - return t == sinceAllMessages -} - -func (t sinceMarker) IsNone() bool { - return t == sinceNoMessages -} - -func (t sinceMarker) IsLatest() bool { - return t == sinceLatestMessage -} - -func (t sinceMarker) IsID() bool { - return t.id != "" && t.id != "latest" -} - -func (t sinceMarker) Time() time.Time { - return t.time -} - -func (t sinceMarker) ID() string { - return t.id -} - -var ( - sinceAllMessages = sinceMarker{time.Unix(0, 0), ""} - sinceNoMessages = sinceMarker{time.Unix(1, 0), ""} - sinceLatestMessage = sinceMarker{time.Unix(0, 0), "latest"} -) +type messageEncoder func(msg *model.Message) (string, error) type queryFilter struct { ID string @@ -246,7 +102,7 @@ func parseQueryFilters(r *http.Request) (*queryFilter, error) { }, nil } -func (q *queryFilter) Pass(msg *message) bool { +func (q *queryFilter) Pass(msg *model.Message) bool { if msg.Event != messageEvent && msg.Event != messageDeleteEvent && msg.Event != messageClearEvent { return true // filters only apply to messages } else if q.ID != "" && msg.ID != q.ID { @@ -572,10 +428,10 @@ const ( type webPushPayload struct { Event string `json:"event"` SubscriptionID string `json:"subscription_id"` - Message *message `json:"message"` + Message *model.Message `json:"message"` } 
-func newWebPushPayload(subscriptionID string, message *message) *webPushPayload { +func newWebPushPayload(subscriptionID string, message *model.Message) *webPushPayload { return &webPushPayload{ Event: webPushMessageEvent, SubscriptionID: subscriptionID, diff --git a/server/visitor.go b/server/visitor.go index f26729f1..12217be6 100644 --- a/server/visitor.go +++ b/server/visitor.go @@ -8,6 +8,7 @@ import ( "golang.org/x/time/rate" "heckel.io/ntfy/v2/log" + "heckel.io/ntfy/v2/message" "heckel.io/ntfy/v2/user" "heckel.io/ntfy/v2/util" ) @@ -53,7 +54,7 @@ const ( // visitor represents an API user, and its associated rate.Limiter used for rate limiting type visitor struct { config *Config - messageCache *messageCache + messageCache message.Store userManager *user.Manager // May be nil ip netip.Addr // Visitor IP address user *user.User // Only set if authenticated user, otherwise nil @@ -114,7 +115,7 @@ const ( visitorLimitBasisTier = visitorLimitBasis("tier") ) -func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor { +func newVisitor(conf *Config, messageCache message.Store, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor { var messages, emails, calls int64 if user != nil { messages = user.Stats.Messages diff --git a/user/manager_test.go b/user/manager_test.go index 4b929d39..13cff861 100644 --- a/user/manager_test.go +++ b/user/manager_test.go @@ -691,7 +691,6 @@ func TestManager_Token_Expire(t *testing.T) { // But the token row should still exist tokens, err := a.Tokens(u.ID) require.Nil(t, err) - require.Equal(t, token1.Value, tokens[0].Value) require.Equal(t, 2, len(tokens)) // Expire tokens and check that token1 is gone