607 lines
15 KiB
Go
607 lines
15 KiB
Go
package message
|
|
|
|
import (
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"net/netip"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"heckel.io/ntfy/v2/log"
|
|
"heckel.io/ntfy/v2/model"
|
|
"heckel.io/ntfy/v2/util"
|
|
)
|
|
|
|
// tagMessageCache is the log tag attached to every log line emitted by the message cache.
const tagMessageCache = "message_cache"
|
|
|
|
var (
	// errNoRows is returned when a query that is expected to yield a single row
	// (such as the message count or stats query) yields none.
	errNoRows = errors.New("no rows found")
)
|
|
|
|
// queries holds the database-specific SQL statements used by Cache, so that the
// same Cache logic can run against different SQL dialects.
type queries struct {
	// Inserting, deleting and expiring messages
	insertMessage                    string
	deleteMessage                    string
	selectScheduledMessageIDsBySeqID string
	deleteScheduledBySequenceID      string
	updateMessagesForTopicExpiry     string

	// Querying and publishing messages
	selectMessagesByID               string
	selectMessagesSinceTime          string
	selectMessagesSinceTimeScheduled string
	selectMessagesSinceID            string
	selectMessagesSinceIDScheduled   string
	selectMessagesLatest             string
	selectMessagesDue                string
	selectMessagesExpired            string
	updateMessagePublished           string
	selectMessagesCount              string
	selectTopics                     string

	// Attachments
	updateAttachmentDeleted       string
	selectAttachmentsExpired      string
	selectAttachmentsSizeBySender string
	selectAttachmentsSizeByUserID string

	// Statistics and test helpers
	selectStats       string
	updateStats       string
	updateMessageTime string
}
|
|
|
|
// Cache stores published messages
|
|
type Cache struct {
|
|
db *sql.DB
|
|
queue *util.BatchingQueue[*model.Message]
|
|
nop bool
|
|
mu *sync.Mutex // nil for PostgreSQL (concurrent writes supported), set for SQLite (single writer)
|
|
queries queries
|
|
}
|
|
|
|
func newCache(db *sql.DB, queries queries, mu *sync.Mutex, batchSize int, batchTimeout time.Duration, nop bool) *Cache {
|
|
var queue *util.BatchingQueue[*model.Message]
|
|
if batchSize > 0 || batchTimeout > 0 {
|
|
queue = util.NewBatchingQueue[*model.Message](batchSize, batchTimeout)
|
|
}
|
|
c := &Cache{
|
|
db: db,
|
|
queue: queue,
|
|
nop: nop,
|
|
mu: mu,
|
|
queries: queries,
|
|
}
|
|
go c.processMessageBatches()
|
|
return c
|
|
}
|
|
|
|
func (c *Cache) maybeLock() {
|
|
if c.mu != nil {
|
|
c.mu.Lock()
|
|
}
|
|
}
|
|
|
|
func (c *Cache) maybeUnlock() {
|
|
if c.mu != nil {
|
|
c.mu.Unlock()
|
|
}
|
|
}
|
|
|
|
// AddMessage stores a message to the message cache synchronously, or queues it to be stored at a later date asynchronously.
|
|
// The message is queued only if "batchSize" or "batchTimeout" are passed to the constructor.
|
|
func (c *Cache) AddMessage(m *model.Message) error {
|
|
if c.queue != nil {
|
|
c.queue.Enqueue(m)
|
|
return nil
|
|
}
|
|
return c.addMessages([]*model.Message{m})
|
|
}
|
|
|
|
// AddMessages synchronously stores a batch of messages to the message cache
|
|
func (c *Cache) AddMessages(ms []*model.Message) error {
|
|
return c.addMessages(ms)
|
|
}
|
|
|
|
func (c *Cache) addMessages(ms []*model.Message) error {
|
|
c.maybeLock()
|
|
defer c.maybeUnlock()
|
|
if c.nop {
|
|
return nil
|
|
}
|
|
if len(ms) == 0 {
|
|
return nil
|
|
}
|
|
start := time.Now()
|
|
tx, err := c.db.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.Rollback()
|
|
stmt, err := tx.Prepare(c.queries.insertMessage)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer stmt.Close()
|
|
for _, m := range ms {
|
|
if m.Event != model.MessageEvent && m.Event != model.MessageDeleteEvent && m.Event != model.MessageClearEvent {
|
|
return model.ErrUnexpectedMessageType
|
|
}
|
|
published := m.Time <= time.Now().Unix()
|
|
tags := strings.Join(m.Tags, ",")
|
|
var attachmentName, attachmentType, attachmentURL string
|
|
var attachmentSize, attachmentExpires int64
|
|
var attachmentDeleted bool
|
|
if m.Attachment != nil {
|
|
attachmentName = m.Attachment.Name
|
|
attachmentType = m.Attachment.Type
|
|
attachmentSize = m.Attachment.Size
|
|
attachmentExpires = m.Attachment.Expires
|
|
attachmentURL = m.Attachment.URL
|
|
}
|
|
var actionsStr string
|
|
if len(m.Actions) > 0 {
|
|
actionsBytes, err := json.Marshal(m.Actions)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
actionsStr = string(actionsBytes)
|
|
}
|
|
var sender string
|
|
if m.Sender.IsValid() {
|
|
sender = m.Sender.String()
|
|
}
|
|
_, err := stmt.Exec(
|
|
m.ID,
|
|
m.SequenceID,
|
|
m.Time,
|
|
m.Event,
|
|
m.Expires,
|
|
m.Topic,
|
|
m.Message,
|
|
m.Title,
|
|
m.Priority,
|
|
tags,
|
|
m.Click,
|
|
m.Icon,
|
|
actionsStr,
|
|
attachmentName,
|
|
attachmentType,
|
|
attachmentSize,
|
|
attachmentExpires,
|
|
attachmentURL,
|
|
attachmentDeleted, // Always zero
|
|
sender,
|
|
m.User,
|
|
m.ContentType,
|
|
m.Encoding,
|
|
published,
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
if err := tx.Commit(); err != nil {
|
|
log.Tag(tagMessageCache).Err(err).Error("Writing %d message(s) failed (took %v)", len(ms), time.Since(start))
|
|
return err
|
|
}
|
|
log.Tag(tagMessageCache).Debug("Wrote %d message(s) in %v", len(ms), time.Since(start))
|
|
return nil
|
|
}
|
|
|
|
// Messages returns messages for a topic since the given marker, optionally including scheduled messages
|
|
func (c *Cache) Messages(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
|
|
if since.IsNone() {
|
|
return make([]*model.Message, 0), nil
|
|
} else if since.IsLatest() {
|
|
return c.messagesLatest(topic)
|
|
} else if since.IsID() {
|
|
return c.messagesSinceID(topic, since, scheduled)
|
|
}
|
|
return c.messagesSinceTime(topic, since, scheduled)
|
|
}
|
|
|
|
func (c *Cache) messagesSinceTime(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
|
|
var rows *sql.Rows
|
|
var err error
|
|
if scheduled {
|
|
rows, err = c.db.Query(c.queries.selectMessagesSinceTimeScheduled, topic, since.Time().Unix())
|
|
} else {
|
|
rows, err = c.db.Query(c.queries.selectMessagesSinceTime, topic, since.Time().Unix())
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return readMessages(rows)
|
|
}
|
|
|
|
func (c *Cache) messagesSinceID(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
|
|
var rows *sql.Rows
|
|
var err error
|
|
if scheduled {
|
|
rows, err = c.db.Query(c.queries.selectMessagesSinceIDScheduled, topic, since.ID())
|
|
} else {
|
|
rows, err = c.db.Query(c.queries.selectMessagesSinceID, topic, since.ID())
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return readMessages(rows)
|
|
}
|
|
|
|
func (c *Cache) messagesLatest(topic string) ([]*model.Message, error) {
|
|
rows, err := c.db.Query(c.queries.selectMessagesLatest, topic)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return readMessages(rows)
|
|
}
|
|
|
|
// MessagesDue returns all messages that are due for publishing
|
|
func (c *Cache) MessagesDue() ([]*model.Message, error) {
|
|
rows, err := c.db.Query(c.queries.selectMessagesDue, time.Now().Unix())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return readMessages(rows)
|
|
}
|
|
|
|
// MessagesExpired returns a list of IDs for messages that have expired (should be deleted)
|
|
func (c *Cache) MessagesExpired() ([]string, error) {
|
|
rows, err := c.db.Query(c.queries.selectMessagesExpired, time.Now().Unix())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
ids := make([]string, 0)
|
|
for rows.Next() {
|
|
var id string
|
|
if err := rows.Scan(&id); err != nil {
|
|
return nil, err
|
|
}
|
|
ids = append(ids, id)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
return ids, nil
|
|
}
|
|
|
|
// Message returns the message with the given ID, or ErrMessageNotFound if not found
|
|
func (c *Cache) Message(id string) (*model.Message, error) {
|
|
rows, err := c.db.Query(c.queries.selectMessagesByID, id)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if !rows.Next() {
|
|
return nil, model.ErrMessageNotFound
|
|
}
|
|
defer rows.Close()
|
|
return readMessage(rows)
|
|
}
|
|
|
|
// UpdateMessageTime updates the time column for a message by ID. This is only used for testing.
|
|
func (c *Cache) UpdateMessageTime(messageID string, timestamp int64) error {
|
|
c.maybeLock()
|
|
defer c.maybeUnlock()
|
|
_, err := c.db.Exec(c.queries.updateMessageTime, timestamp, messageID)
|
|
return err
|
|
}
|
|
|
|
// MarkPublished marks a message as published
|
|
func (c *Cache) MarkPublished(m *model.Message) error {
|
|
c.maybeLock()
|
|
defer c.maybeUnlock()
|
|
_, err := c.db.Exec(c.queries.updateMessagePublished, m.ID)
|
|
return err
|
|
}
|
|
|
|
// MessagesCount returns the total number of messages in the cache
|
|
func (c *Cache) MessagesCount() (int, error) {
|
|
rows, err := c.db.Query(c.queries.selectMessagesCount)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
defer rows.Close()
|
|
if !rows.Next() {
|
|
return 0, errNoRows
|
|
}
|
|
var count int
|
|
if err := rows.Scan(&count); err != nil {
|
|
return 0, err
|
|
}
|
|
return count, nil
|
|
}
|
|
|
|
// Topics returns a list of all topics with messages in the cache
|
|
func (c *Cache) Topics() ([]string, error) {
|
|
rows, err := c.db.Query(c.queries.selectTopics)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
topics := make([]string, 0)
|
|
for rows.Next() {
|
|
var id string
|
|
if err := rows.Scan(&id); err != nil {
|
|
return nil, err
|
|
}
|
|
topics = append(topics, id)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
return topics, nil
|
|
}
|
|
|
|
// DeleteMessages deletes the messages with the given IDs
|
|
func (c *Cache) DeleteMessages(ids ...string) error {
|
|
c.maybeLock()
|
|
defer c.maybeUnlock()
|
|
tx, err := c.db.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.Rollback()
|
|
for _, id := range ids {
|
|
if _, err := tx.Exec(c.queries.deleteMessage, id); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return tx.Commit()
|
|
}
|
|
|
|
// DeleteScheduledBySequenceID deletes unpublished (scheduled) messages with the given topic and sequence ID.
|
|
// It returns the message IDs of the deleted messages, which can be used to clean up attachment files.
|
|
func (c *Cache) DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, error) {
|
|
c.maybeLock()
|
|
defer c.maybeUnlock()
|
|
tx, err := c.db.Begin()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer tx.Rollback()
|
|
// First, get the message IDs of scheduled messages to be deleted
|
|
rows, err := tx.Query(c.queries.selectScheduledMessageIDsBySeqID, topic, sequenceID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
ids := make([]string, 0)
|
|
for rows.Next() {
|
|
var id string
|
|
if err := rows.Scan(&id); err != nil {
|
|
return nil, err
|
|
}
|
|
ids = append(ids, id)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
rows.Close() // Close rows before executing delete in same transaction
|
|
// Then delete the messages
|
|
if _, err := tx.Exec(c.queries.deleteScheduledBySequenceID, topic, sequenceID); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := tx.Commit(); err != nil {
|
|
return nil, err
|
|
}
|
|
return ids, nil
|
|
}
|
|
|
|
// ExpireMessages marks messages in the given topics as expired
|
|
func (c *Cache) ExpireMessages(topics ...string) error {
|
|
c.maybeLock()
|
|
defer c.maybeUnlock()
|
|
tx, err := c.db.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.Rollback()
|
|
for _, t := range topics {
|
|
if _, err := tx.Exec(c.queries.updateMessagesForTopicExpiry, time.Now().Unix()-1, t); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return tx.Commit()
|
|
}
|
|
|
|
// AttachmentsExpired returns message IDs with expired attachments that have not been deleted
|
|
func (c *Cache) AttachmentsExpired() ([]string, error) {
|
|
rows, err := c.db.Query(c.queries.selectAttachmentsExpired, time.Now().Unix())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
ids := make([]string, 0)
|
|
for rows.Next() {
|
|
var id string
|
|
if err := rows.Scan(&id); err != nil {
|
|
return nil, err
|
|
}
|
|
ids = append(ids, id)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
return ids, nil
|
|
}
|
|
|
|
// MarkAttachmentsDeleted marks the attachments for the given message IDs as deleted
|
|
func (c *Cache) MarkAttachmentsDeleted(ids ...string) error {
|
|
c.maybeLock()
|
|
defer c.maybeUnlock()
|
|
tx, err := c.db.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.Rollback()
|
|
for _, id := range ids {
|
|
if _, err := tx.Exec(c.queries.updateAttachmentDeleted, id); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return tx.Commit()
|
|
}
|
|
|
|
// AttachmentBytesUsedBySender returns the total size of active attachments sent by the given sender
|
|
func (c *Cache) AttachmentBytesUsedBySender(sender string) (int64, error) {
|
|
rows, err := c.db.Query(c.queries.selectAttachmentsSizeBySender, sender, time.Now().Unix())
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return c.readAttachmentBytesUsed(rows)
|
|
}
|
|
|
|
// AttachmentBytesUsedByUser returns the total size of active attachments for the given user
|
|
func (c *Cache) AttachmentBytesUsedByUser(userID string) (int64, error) {
|
|
rows, err := c.db.Query(c.queries.selectAttachmentsSizeByUserID, userID, time.Now().Unix())
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return c.readAttachmentBytesUsed(rows)
|
|
}
|
|
|
|
func (c *Cache) readAttachmentBytesUsed(rows *sql.Rows) (int64, error) {
|
|
defer rows.Close()
|
|
var size int64
|
|
if !rows.Next() {
|
|
return 0, errors.New("no rows found")
|
|
}
|
|
if err := rows.Scan(&size); err != nil {
|
|
return 0, err
|
|
} else if err := rows.Err(); err != nil {
|
|
return 0, err
|
|
}
|
|
return size, nil
|
|
}
|
|
|
|
// UpdateStats updates the total message count statistic
|
|
func (c *Cache) UpdateStats(messages int64) error {
|
|
c.maybeLock()
|
|
defer c.maybeUnlock()
|
|
_, err := c.db.Exec(c.queries.updateStats, messages)
|
|
return err
|
|
}
|
|
|
|
// Stats returns the total message count statistic
|
|
func (c *Cache) Stats() (messages int64, err error) {
|
|
rows, err := c.db.Query(c.queries.selectStats)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
defer rows.Close()
|
|
if !rows.Next() {
|
|
return 0, errNoRows
|
|
}
|
|
if err := rows.Scan(&messages); err != nil {
|
|
return 0, err
|
|
}
|
|
return messages, nil
|
|
}
|
|
|
|
// Close closes the underlying database connection
|
|
func (c *Cache) Close() error {
|
|
return c.db.Close()
|
|
}
|
|
|
|
func (c *Cache) processMessageBatches() {
|
|
if c.queue == nil {
|
|
return
|
|
}
|
|
for messages := range c.queue.Dequeue() {
|
|
if err := c.addMessages(messages); err != nil {
|
|
log.Tag(tagMessageCache).Err(err).Error("Cannot write message batch")
|
|
}
|
|
}
|
|
}
|
|
|
|
func readMessages(rows *sql.Rows) ([]*model.Message, error) {
|
|
defer rows.Close()
|
|
messages := make([]*model.Message, 0)
|
|
for rows.Next() {
|
|
m, err := readMessage(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
messages = append(messages, m)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
return messages, nil
|
|
}
|
|
|
|
func readMessage(rows *sql.Rows) (*model.Message, error) {
|
|
var timestamp, expires, attachmentSize, attachmentExpires int64
|
|
var priority int
|
|
var id, sequenceID, event, topic, msg, title, tagsStr, click, icon, actionsStr, attachmentName, attachmentType, attachmentURL, sender, user, contentType, encoding string
|
|
err := rows.Scan(
|
|
&id,
|
|
&sequenceID,
|
|
×tamp,
|
|
&event,
|
|
&expires,
|
|
&topic,
|
|
&msg,
|
|
&title,
|
|
&priority,
|
|
&tagsStr,
|
|
&click,
|
|
&icon,
|
|
&actionsStr,
|
|
&attachmentName,
|
|
&attachmentType,
|
|
&attachmentSize,
|
|
&attachmentExpires,
|
|
&attachmentURL,
|
|
&sender,
|
|
&user,
|
|
&contentType,
|
|
&encoding,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var tags []string
|
|
if tagsStr != "" {
|
|
tags = strings.Split(tagsStr, ",")
|
|
}
|
|
var actions []*model.Action
|
|
if actionsStr != "" {
|
|
if err := json.Unmarshal([]byte(actionsStr), &actions); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
senderIP, err := netip.ParseAddr(sender)
|
|
if err != nil {
|
|
senderIP = netip.Addr{} // if no IP stored in database, return invalid address
|
|
}
|
|
var att *model.Attachment
|
|
if attachmentName != "" && attachmentURL != "" {
|
|
att = &model.Attachment{
|
|
Name: attachmentName,
|
|
Type: attachmentType,
|
|
Size: attachmentSize,
|
|
Expires: attachmentExpires,
|
|
URL: attachmentURL,
|
|
}
|
|
}
|
|
return &model.Message{
|
|
ID: id,
|
|
SequenceID: sequenceID,
|
|
Time: timestamp,
|
|
Expires: expires,
|
|
Event: event,
|
|
Topic: topic,
|
|
Message: msg,
|
|
Title: title,
|
|
Priority: priority,
|
|
Tags: tags,
|
|
Click: click,
|
|
Icon: icon,
|
|
Actions: actions,
|
|
Attachment: att,
|
|
Sender: senderIP,
|
|
User: user,
|
|
ContentType: contentType,
|
|
Encoding: encoding,
|
|
}, nil
|
|
}
|