package message import ( "database/sql" "encoding/json" "errors" "fmt" "net/netip" "strings" "sync" "time" "heckel.io/ntfy/v2/log" "heckel.io/ntfy/v2/model" "heckel.io/ntfy/v2/util" ) const ( tagMessageCache = "message_cache" ) var errNoRows = errors.New("no rows found") // Store is the interface for a message cache store type Store interface { AddMessage(m *model.Message) error AddMessages(ms []*model.Message) error DB() *sql.DB Message(id string) (*model.Message, error) MessageCounts() (map[string]int, error) Messages(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) MessagesDue() ([]*model.Message, error) MessagesExpired() ([]string, error) MarkPublished(m *model.Message) error UpdateMessageTime(messageID string, timestamp int64) error Topics() ([]string, error) DeleteMessages(ids ...string) error DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, error) ExpireMessages(topics ...string) error AttachmentsExpired() ([]string, error) MarkAttachmentsDeleted(ids ...string) error AttachmentBytesUsedBySender(sender string) (int64, error) AttachmentBytesUsedByUser(userID string) (int64, error) UpdateStats(messages int64) error Stats() (int64, error) Close() error } // storeQueries holds the database-specific SQL queries type storeQueries struct { insertMessage string deleteMessage string selectScheduledMessageIDsBySeqID string deleteScheduledBySequenceID string updateMessagesForTopicExpiry string selectRowIDFromMessageID string selectMessagesByID string selectMessagesSinceTime string selectMessagesSinceTimeScheduled string selectMessagesSinceID string selectMessagesSinceIDScheduled string selectMessagesLatest string selectMessagesDue string selectMessagesExpired string updateMessagePublished string selectMessagesCount string selectMessageCountPerTopic string selectTopics string updateAttachmentDeleted string selectAttachmentsExpired string selectAttachmentsSizeBySender string selectAttachmentsSizeByUserID string selectStats string updateStats string updateMessageTime string } // commonStore implements store operations that are identical across database backends type commonStore struct { db *sql.DB queue *util.BatchingQueue[*model.Message] nop bool mu sync.Mutex queries storeQueries } func newCommonStore(db *sql.DB, queries storeQueries, batchSize int, batchTimeout time.Duration, nop bool) *commonStore { var queue *util.BatchingQueue[*model.Message] if batchSize > 0 || batchTimeout > 0 { queue = util.NewBatchingQueue[*model.Message](batchSize, batchTimeout) } c := &commonStore{ db: db, queue: queue, nop: nop, queries: queries, } go c.processMessageBatches() return c } // DB returns the underlying database connection func (c *commonStore) DB() *sql.DB { return c.db } // AddMessage stores a message to the message cache synchronously, or queues it to be stored at a later date asynchronously. func (c *commonStore) AddMessage(m *model.Message) error { if c.queue != nil { c.queue.Enqueue(m) return nil } return c.addMessages([]*model.Message{m}) } // AddMessages synchronously stores a batch of messages to the message cache func (c *commonStore) AddMessages(ms []*model.Message) error { return c.addMessages(ms) } func (c *commonStore) addMessages(ms []*model.Message) error { c.mu.Lock() defer c.mu.Unlock() if c.nop { return nil } if len(ms) == 0 { return nil } start := time.Now() tx, err := c.db.Begin() if err != nil { return err } defer tx.Rollback() stmt, err := tx.Prepare(c.queries.insertMessage) if err != nil { return err } defer stmt.Close() for _, m := range ms { if m.Event != model.MessageEvent && m.Event != model.MessageDeleteEvent && m.Event != model.MessageClearEvent { return model.ErrUnexpectedMessageType } published := m.Time <= time.Now().Unix() tags := strings.Join(m.Tags, ",") var attachmentName, attachmentType, attachmentURL string var attachmentSize, attachmentExpires int64 var attachmentDeleted bool if m.Attachment != nil { attachmentName = m.Attachment.Name attachmentType = m.Attachment.Type attachmentSize = m.Attachment.Size attachmentExpires = m.Attachment.Expires attachmentURL = m.Attachment.URL } var actionsStr string if len(m.Actions) > 0 { actionsBytes, err := json.Marshal(m.Actions) if err != nil { return err } actionsStr = string(actionsBytes) } var sender string if m.Sender.IsValid() { sender = m.Sender.String() } _, err := stmt.Exec( m.ID, m.SequenceID, m.Time, m.Event, m.Expires, m.Topic, m.Message, m.Title, m.Priority, tags, m.Click, m.Icon, actionsStr, attachmentName, attachmentType, attachmentSize, attachmentExpires, attachmentURL, attachmentDeleted, // Always zero sender, m.User, m.ContentType, m.Encoding, published, ) if err != nil { return err } } if err := tx.Commit(); err != nil { log.Tag(tagMessageCache).Err(err).Error("Writing %d message(s) failed (took %v)", len(ms), time.Since(start)) return err } log.Tag(tagMessageCache).Debug("Wrote %d message(s) in %v", len(ms), time.Since(start)) return nil } func (c *commonStore) Messages(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) { if since.IsNone() { return make([]*model.Message, 0), nil } else if since.IsLatest() { return c.messagesLatest(topic) } else if since.IsID() { return c.messagesSinceID(topic, since, scheduled) } return c.messagesSinceTime(topic, since, scheduled) } func (c *commonStore) messagesSinceTime(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) { var rows *sql.Rows var err error if scheduled { rows, err = c.db.Query(c.queries.selectMessagesSinceTimeScheduled, topic, since.Time().Unix()) } else { rows, err = c.db.Query(c.queries.selectMessagesSinceTime, topic, since.Time().Unix()) } if err != nil { return nil, err } return readMessages(rows) } func (c *commonStore) messagesSinceID(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) { idrows, err := c.db.Query(c.queries.selectRowIDFromMessageID, since.ID()) if err != nil { return nil, err } defer idrows.Close() if !idrows.Next() { return c.messagesSinceTime(topic, model.SinceAllMessages, scheduled) } var rowID int64 if err := idrows.Scan(&rowID); err != nil { return nil, err } idrows.Close() var rows *sql.Rows if scheduled { rows, err = c.db.Query(c.queries.selectMessagesSinceIDScheduled, topic, rowID) } else { rows, err = c.db.Query(c.queries.selectMessagesSinceID, topic, rowID) } if err != nil { return nil, err } return readMessages(rows) } func (c *commonStore) messagesLatest(topic string) ([]*model.Message, error) { rows, err := c.db.Query(c.queries.selectMessagesLatest, topic) if err != nil { return nil, err } return readMessages(rows) } func (c *commonStore) MessagesDue() ([]*model.Message, error) { rows, err := c.db.Query(c.queries.selectMessagesDue, time.Now().Unix()) if err != nil { return nil, err } return readMessages(rows) } // MessagesExpired returns a list of IDs for messages that have expired (should be deleted) func (c *commonStore) MessagesExpired() ([]string, error) { rows, err := c.db.Query(c.queries.selectMessagesExpired, time.Now().Unix()) if err != nil { return nil, err } defer rows.Close() ids := make([]string, 0) for rows.Next() { var id string if err := rows.Scan(&id); err != nil { return nil, err } ids = append(ids, id) } if err := rows.Err(); err != nil { return nil, err } return ids, nil } func (c *commonStore) Message(id string) (*model.Message, error) { rows, err := c.db.Query(c.queries.selectMessagesByID, id) if err != nil { return nil, err } if !rows.Next() { return nil, model.ErrMessageNotFound } defer rows.Close() return readMessage(rows) } // UpdateMessageTime updates the time column for a message by ID. This is only used for testing. func (c *commonStore) UpdateMessageTime(messageID string, timestamp int64) error { _, err := c.db.Exec(c.queries.updateMessageTime, timestamp, messageID) return err } func (c *commonStore) MarkPublished(m *model.Message) error { c.mu.Lock() defer c.mu.Unlock() _, err := c.db.Exec(c.queries.updateMessagePublished, m.ID) return err } func (c *commonStore) MessageCounts() (map[string]int, error) { rows, err := c.db.Query(c.queries.selectMessageCountPerTopic) if err != nil { return nil, err } defer rows.Close() var topic string var count int counts := make(map[string]int) for rows.Next() { if err := rows.Scan(&topic, &count); err != nil { return nil, err } else if err := rows.Err(); err != nil { return nil, err } counts[topic] = count } return counts, nil } func (c *commonStore) Topics() ([]string, error) { rows, err := c.db.Query(c.queries.selectTopics) if err != nil { return nil, err } defer rows.Close() topics := make([]string, 0) for rows.Next() { var id string if err := rows.Scan(&id); err != nil { return nil, err } topics = append(topics, id) } if err := rows.Err(); err != nil { return nil, err } return topics, nil } func (c *commonStore) DeleteMessages(ids ...string) error { c.mu.Lock() defer c.mu.Unlock() tx, err := c.db.Begin() if err != nil { return err } defer tx.Rollback() for _, id := range ids { if _, err := tx.Exec(c.queries.deleteMessage, id); err != nil { return err } } return tx.Commit() } // DeleteScheduledBySequenceID deletes unpublished (scheduled) messages with the given topic and sequence ID. // It returns the message IDs of the deleted messages, which can be used to clean up attachment files. func (c *commonStore) DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, error) { c.mu.Lock() defer c.mu.Unlock() tx, err := c.db.Begin() if err != nil { return nil, err } defer tx.Rollback() rows, err := tx.Query(c.queries.selectScheduledMessageIDsBySeqID, topic, sequenceID) if err != nil { return nil, err } defer rows.Close() ids := make([]string, 0) for rows.Next() { var id string if err := rows.Scan(&id); err != nil { return nil, err } ids = append(ids, id) } if err := rows.Err(); err != nil { return nil, err } rows.Close() if _, err := tx.Exec(c.queries.deleteScheduledBySequenceID, topic, sequenceID); err != nil { return nil, err } if err := tx.Commit(); err != nil { return nil, err } return ids, nil } func (c *commonStore) ExpireMessages(topics ...string) error { c.mu.Lock() defer c.mu.Unlock() tx, err := c.db.Begin() if err != nil { return err } defer tx.Rollback() for _, t := range topics { if _, err := tx.Exec(c.queries.updateMessagesForTopicExpiry, time.Now().Unix()-1, t); err != nil { return err } } return tx.Commit() } func (c *commonStore) AttachmentsExpired() ([]string, error) { rows, err := c.db.Query(c.queries.selectAttachmentsExpired, time.Now().Unix()) if err != nil { return nil, err } defer rows.Close() ids := make([]string, 0) for rows.Next() { var id string if err := rows.Scan(&id); err != nil { return nil, err } ids = append(ids, id) } if err := rows.Err(); err != nil { return nil, err } return ids, nil } func (c *commonStore) MarkAttachmentsDeleted(ids ...string) error { c.mu.Lock() defer c.mu.Unlock() tx, err := c.db.Begin() if err != nil { return err } defer tx.Rollback() for _, id := range ids { if _, err := tx.Exec(c.queries.updateAttachmentDeleted, id); err != nil { return err } } return tx.Commit() } func (c *commonStore) AttachmentBytesUsedBySender(sender string) (int64, error) { rows, err := c.db.Query(c.queries.selectAttachmentsSizeBySender, sender, time.Now().Unix()) if err != nil { return 0, err } return c.readAttachmentBytesUsed(rows) } func (c *commonStore) AttachmentBytesUsedByUser(userID string) (int64, error) { rows, err := c.db.Query(c.queries.selectAttachmentsSizeByUserID, userID, time.Now().Unix()) if err != nil { return 0, err } return c.readAttachmentBytesUsed(rows) } func (c *commonStore) readAttachmentBytesUsed(rows *sql.Rows) (int64, error) { defer rows.Close() var size int64 if !rows.Next() { return 0, errors.New("no rows found") } if err := rows.Scan(&size); err != nil { return 0, err } else if err := rows.Err(); err != nil { return 0, err } return size, nil } func (c *commonStore) UpdateStats(messages int64) error { c.mu.Lock() defer c.mu.Unlock() _, err := c.db.Exec(c.queries.updateStats, messages) return err } func (c *commonStore) Stats() (messages int64, err error) { rows, err := c.db.Query(c.queries.selectStats) if err != nil { return 0, err } defer rows.Close() if !rows.Next() { return 0, errNoRows } if err := rows.Scan(&messages); err != nil { return 0, err } return messages, nil } func (c *commonStore) Close() error { return c.db.Close() } func (c *commonStore) processMessageBatches() { if c.queue == nil { return } for messages := range c.queue.Dequeue() { if err := c.addMessages(messages); err != nil { log.Tag(tagMessageCache).Err(err).Error("Cannot write message batch") } } } func readMessages(rows *sql.Rows) ([]*model.Message, error) { defer rows.Close() messages := make([]*model.Message, 0) for rows.Next() { m, err := readMessage(rows) if err != nil { return nil, err } messages = append(messages, m) } if err := rows.Err(); err != nil { return nil, err } return messages, nil } func readMessage(rows *sql.Rows) (*model.Message, error) { var timestamp, expires, attachmentSize, attachmentExpires int64 var priority int var id, sequenceID, event, topic, msg, title, tagsStr, click, icon, actionsStr, attachmentName, attachmentType, attachmentURL, sender, user, contentType, encoding string err := rows.Scan( &id, &sequenceID, ×tamp, &event, &expires, &topic, &msg, &title, &priority, &tagsStr, &click, &icon, &actionsStr, &attachmentName, &attachmentType, &attachmentSize, &attachmentExpires, &attachmentURL, &sender, &user, &contentType, &encoding, ) if err != nil { return nil, err } var tags []string if tagsStr != "" { tags = strings.Split(tagsStr, ",") } var actions []*model.Action if actionsStr != "" { if err := json.Unmarshal([]byte(actionsStr), &actions); err != nil { return nil, err } } senderIP, err := netip.ParseAddr(sender) if err != nil { senderIP = netip.Addr{} // if no IP stored in database, return invalid address } var att *model.Attachment if attachmentName != "" && attachmentURL != "" { att = &model.Attachment{ Name: attachmentName, Type: attachmentType, Size: attachmentSize, Expires: attachmentExpires, URL: attachmentURL, } } return &model.Message{ ID: id, SequenceID: sequenceID, Time: timestamp, Expires: expires, Event: event, Topic: topic, Message: msg, Title: title, Priority: priority, Tags: tags, Click: click, Icon: icon, Actions: actions, Attachment: att, Sender: senderIP, User: user, ContentType: contentType, Encoding: encoding, }, nil } // Ensure commonStore implements Store var _ Store = (*commonStore)(nil) // Needed by store.go but not part of Store interface; unused import guard var _ = fmt.Sprintf