Extract message cache into message/ package with model/ types
Move message cache from server/message_cache.go into a dedicated message/ package with Store interface, SQLite and PostgreSQL implementations. Extract shared types into model/ package.
This commit is contained in:
620
message/store.go
Normal file
620
message/store.go
Normal file
@@ -0,0 +1,620 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
const (
|
||||
tagMessageCache = "message_cache"
|
||||
)
|
||||
|
||||
var errNoRows = errors.New("no rows found")
|
||||
|
||||
// Store is the interface for a message cache store
|
||||
type Store interface {
|
||||
AddMessage(m *model.Message) error
|
||||
AddMessages(ms []*model.Message) error
|
||||
DB() *sql.DB
|
||||
Messages(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error)
|
||||
MessagesDue() ([]*model.Message, error)
|
||||
MessagesExpired() ([]string, error)
|
||||
Message(id string) (*model.Message, error)
|
||||
MarkPublished(m *model.Message) error
|
||||
MessageCounts() (map[string]int, error)
|
||||
Topics() ([]string, error)
|
||||
DeleteMessages(ids ...string) error
|
||||
DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, error)
|
||||
ExpireMessages(topics ...string) error
|
||||
AttachmentsExpired() ([]string, error)
|
||||
MarkAttachmentsDeleted(ids ...string) error
|
||||
AttachmentBytesUsedBySender(sender string) (int64, error)
|
||||
AttachmentBytesUsedByUser(userID string) (int64, error)
|
||||
UpdateStats(messages int64) error
|
||||
Stats() (int64, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
// storeQueries holds the database-specific SQL queries
|
||||
type storeQueries struct {
|
||||
insertMessage string
|
||||
deleteMessage string
|
||||
selectScheduledMessageIDsBySeqID string
|
||||
deleteScheduledBySequenceID string
|
||||
updateMessagesForTopicExpiry string
|
||||
selectRowIDFromMessageID string
|
||||
selectMessagesByID string
|
||||
selectMessagesSinceTime string
|
||||
selectMessagesSinceTimeScheduled string
|
||||
selectMessagesSinceID string
|
||||
selectMessagesSinceIDScheduled string
|
||||
selectMessagesLatest string
|
||||
selectMessagesDue string
|
||||
selectMessagesExpired string
|
||||
updateMessagePublished string
|
||||
selectMessagesCount string
|
||||
selectMessageCountPerTopic string
|
||||
selectTopics string
|
||||
updateAttachmentDeleted string
|
||||
selectAttachmentsExpired string
|
||||
selectAttachmentsSizeBySender string
|
||||
selectAttachmentsSizeByUserID string
|
||||
selectStats string
|
||||
updateStats string
|
||||
}
|
||||
|
||||
// commonStore implements store operations that are identical across database backends
|
||||
type commonStore struct {
|
||||
db *sql.DB
|
||||
queue *util.BatchingQueue[*model.Message]
|
||||
nop bool
|
||||
mu sync.Mutex
|
||||
queries storeQueries
|
||||
}
|
||||
|
||||
func newCommonStore(db *sql.DB, queries storeQueries, batchSize int, batchTimeout time.Duration, nop bool) *commonStore {
|
||||
var queue *util.BatchingQueue[*model.Message]
|
||||
if batchSize > 0 || batchTimeout > 0 {
|
||||
queue = util.NewBatchingQueue[*model.Message](batchSize, batchTimeout)
|
||||
}
|
||||
c := &commonStore{
|
||||
db: db,
|
||||
queue: queue,
|
||||
nop: nop,
|
||||
queries: queries,
|
||||
}
|
||||
go c.processMessageBatches()
|
||||
return c
|
||||
}
|
||||
|
||||
// DB returns the underlying database connection
|
||||
func (c *commonStore) DB() *sql.DB {
|
||||
return c.db
|
||||
}
|
||||
|
||||
// AddMessage stores a message to the message cache synchronously, or queues it to be stored at a later date asynchronously.
|
||||
func (c *commonStore) AddMessage(m *model.Message) error {
|
||||
if c.queue != nil {
|
||||
c.queue.Enqueue(m)
|
||||
return nil
|
||||
}
|
||||
return c.addMessages([]*model.Message{m})
|
||||
}
|
||||
|
||||
// AddMessages synchronously stores a batch of messages to the message cache
|
||||
func (c *commonStore) AddMessages(ms []*model.Message) error {
|
||||
return c.addMessages(ms)
|
||||
}
|
||||
|
||||
func (c *commonStore) addMessages(ms []*model.Message) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.nop {
|
||||
return nil
|
||||
}
|
||||
if len(ms) == 0 {
|
||||
return nil
|
||||
}
|
||||
start := time.Now()
|
||||
tx, err := c.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
stmt, err := tx.Prepare(c.queries.insertMessage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
for _, m := range ms {
|
||||
if m.Event != model.MessageEvent && m.Event != model.MessageDeleteEvent && m.Event != model.MessageClearEvent {
|
||||
return model.ErrUnexpectedMessageType
|
||||
}
|
||||
published := m.Time <= time.Now().Unix()
|
||||
tags := strings.Join(m.Tags, ",")
|
||||
var attachmentName, attachmentType, attachmentURL string
|
||||
var attachmentSize, attachmentExpires int64
|
||||
var attachmentDeleted bool
|
||||
if m.Attachment != nil {
|
||||
attachmentName = m.Attachment.Name
|
||||
attachmentType = m.Attachment.Type
|
||||
attachmentSize = m.Attachment.Size
|
||||
attachmentExpires = m.Attachment.Expires
|
||||
attachmentURL = m.Attachment.URL
|
||||
}
|
||||
var actionsStr string
|
||||
if len(m.Actions) > 0 {
|
||||
actionsBytes, err := json.Marshal(m.Actions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
actionsStr = string(actionsBytes)
|
||||
}
|
||||
var sender string
|
||||
if m.Sender.IsValid() {
|
||||
sender = m.Sender.String()
|
||||
}
|
||||
_, err := stmt.Exec(
|
||||
m.ID,
|
||||
m.SequenceID,
|
||||
m.Time,
|
||||
m.Event,
|
||||
m.Expires,
|
||||
m.Topic,
|
||||
m.Message,
|
||||
m.Title,
|
||||
m.Priority,
|
||||
tags,
|
||||
m.Click,
|
||||
m.Icon,
|
||||
actionsStr,
|
||||
attachmentName,
|
||||
attachmentType,
|
||||
attachmentSize,
|
||||
attachmentExpires,
|
||||
attachmentURL,
|
||||
attachmentDeleted, // Always zero
|
||||
sender,
|
||||
m.User,
|
||||
m.ContentType,
|
||||
m.Encoding,
|
||||
published,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
log.Tag(tagMessageCache).Err(err).Error("Writing %d message(s) failed (took %v)", len(ms), time.Since(start))
|
||||
return err
|
||||
}
|
||||
log.Tag(tagMessageCache).Debug("Wrote %d message(s) in %v", len(ms), time.Since(start))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *commonStore) Messages(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
|
||||
if since.IsNone() {
|
||||
return make([]*model.Message, 0), nil
|
||||
} else if since.IsLatest() {
|
||||
return c.messagesLatest(topic)
|
||||
} else if since.IsID() {
|
||||
return c.messagesSinceID(topic, since, scheduled)
|
||||
}
|
||||
return c.messagesSinceTime(topic, since, scheduled)
|
||||
}
|
||||
|
||||
func (c *commonStore) messagesSinceTime(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
if scheduled {
|
||||
rows, err = c.db.Query(c.queries.selectMessagesSinceTimeScheduled, topic, since.Time().Unix())
|
||||
} else {
|
||||
rows, err = c.db.Query(c.queries.selectMessagesSinceTime, topic, since.Time().Unix())
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return readMessages(rows)
|
||||
}
|
||||
|
||||
func (c *commonStore) messagesSinceID(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
|
||||
idrows, err := c.db.Query(c.queries.selectRowIDFromMessageID, since.ID())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer idrows.Close()
|
||||
if !idrows.Next() {
|
||||
return c.messagesSinceTime(topic, model.SinceAllMessages, scheduled)
|
||||
}
|
||||
var rowID int64
|
||||
if err := idrows.Scan(&rowID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
idrows.Close()
|
||||
var rows *sql.Rows
|
||||
if scheduled {
|
||||
rows, err = c.db.Query(c.queries.selectMessagesSinceIDScheduled, topic, rowID)
|
||||
} else {
|
||||
rows, err = c.db.Query(c.queries.selectMessagesSinceID, topic, rowID)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return readMessages(rows)
|
||||
}
|
||||
|
||||
func (c *commonStore) messagesLatest(topic string) ([]*model.Message, error) {
|
||||
rows, err := c.db.Query(c.queries.selectMessagesLatest, topic)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return readMessages(rows)
|
||||
}
|
||||
|
||||
func (c *commonStore) MessagesDue() ([]*model.Message, error) {
|
||||
rows, err := c.db.Query(c.queries.selectMessagesDue, time.Now().Unix())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return readMessages(rows)
|
||||
}
|
||||
|
||||
// MessagesExpired returns a list of IDs for messages that have expired (should be deleted)
|
||||
func (c *commonStore) MessagesExpired() ([]string, error) {
|
||||
rows, err := c.db.Query(c.queries.selectMessagesExpired, time.Now().Unix())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
ids := make([]string, 0)
|
||||
for rows.Next() {
|
||||
var id string
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (c *commonStore) Message(id string) (*model.Message, error) {
|
||||
rows, err := c.db.Query(c.queries.selectMessagesByID, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !rows.Next() {
|
||||
return nil, model.ErrMessageNotFound
|
||||
}
|
||||
defer rows.Close()
|
||||
return readMessage(rows)
|
||||
}
|
||||
|
||||
func (c *commonStore) MarkPublished(m *model.Message) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
_, err := c.db.Exec(c.queries.updateMessagePublished, m.ID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *commonStore) MessageCounts() (map[string]int, error) {
|
||||
rows, err := c.db.Query(c.queries.selectMessageCountPerTopic)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var topic string
|
||||
var count int
|
||||
counts := make(map[string]int)
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(&topic, &count); err != nil {
|
||||
return nil, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts[topic] = count
|
||||
}
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
func (c *commonStore) Topics() ([]string, error) {
|
||||
rows, err := c.db.Query(c.queries.selectTopics)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
topics := make([]string, 0)
|
||||
for rows.Next() {
|
||||
var id string
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
topics = append(topics, id)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return topics, nil
|
||||
}
|
||||
|
||||
func (c *commonStore) DeleteMessages(ids ...string) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
tx, err := c.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
for _, id := range ids {
|
||||
if _, err := tx.Exec(c.queries.deleteMessage, id); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// DeleteScheduledBySequenceID deletes unpublished (scheduled) messages with the given topic and sequence ID.
|
||||
// It returns the message IDs of the deleted messages, which can be used to clean up attachment files.
|
||||
func (c *commonStore) DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
tx, err := c.db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
rows, err := tx.Query(c.queries.selectScheduledMessageIDsBySeqID, topic, sequenceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
ids := make([]string, 0)
|
||||
for rows.Next() {
|
||||
var id string
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rows.Close()
|
||||
if _, err := tx.Exec(c.queries.deleteScheduledBySequenceID, topic, sequenceID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (c *commonStore) ExpireMessages(topics ...string) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
tx, err := c.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
for _, t := range topics {
|
||||
if _, err := tx.Exec(c.queries.updateMessagesForTopicExpiry, time.Now().Unix()-1, t); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (c *commonStore) AttachmentsExpired() ([]string, error) {
|
||||
rows, err := c.db.Query(c.queries.selectAttachmentsExpired, time.Now().Unix())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
ids := make([]string, 0)
|
||||
for rows.Next() {
|
||||
var id string
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (c *commonStore) MarkAttachmentsDeleted(ids ...string) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
tx, err := c.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
for _, id := range ids {
|
||||
if _, err := tx.Exec(c.queries.updateAttachmentDeleted, id); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (c *commonStore) AttachmentBytesUsedBySender(sender string) (int64, error) {
|
||||
rows, err := c.db.Query(c.queries.selectAttachmentsSizeBySender, sender, time.Now().Unix())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return c.readAttachmentBytesUsed(rows)
|
||||
}
|
||||
|
||||
func (c *commonStore) AttachmentBytesUsedByUser(userID string) (int64, error) {
|
||||
rows, err := c.db.Query(c.queries.selectAttachmentsSizeByUserID, userID, time.Now().Unix())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return c.readAttachmentBytesUsed(rows)
|
||||
}
|
||||
|
||||
func (c *commonStore) readAttachmentBytesUsed(rows *sql.Rows) (int64, error) {
|
||||
defer rows.Close()
|
||||
var size int64
|
||||
if !rows.Next() {
|
||||
return 0, errors.New("no rows found")
|
||||
}
|
||||
if err := rows.Scan(&size); err != nil {
|
||||
return 0, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return size, nil
|
||||
}
|
||||
|
||||
func (c *commonStore) UpdateStats(messages int64) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
_, err := c.db.Exec(c.queries.updateStats, messages)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *commonStore) Stats() (messages int64, err error) {
|
||||
rows, err := c.db.Query(c.queries.selectStats)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return 0, errNoRows
|
||||
}
|
||||
if err := rows.Scan(&messages); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func (c *commonStore) Close() error {
|
||||
return c.db.Close()
|
||||
}
|
||||
|
||||
func (c *commonStore) processMessageBatches() {
|
||||
if c.queue == nil {
|
||||
return
|
||||
}
|
||||
for messages := range c.queue.Dequeue() {
|
||||
if err := c.addMessages(messages); err != nil {
|
||||
log.Tag(tagMessageCache).Err(err).Error("Cannot write message batch")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func readMessages(rows *sql.Rows) ([]*model.Message, error) {
|
||||
defer rows.Close()
|
||||
messages := make([]*model.Message, 0)
|
||||
for rows.Next() {
|
||||
m, err := readMessage(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
messages = append(messages, m)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func readMessage(rows *sql.Rows) (*model.Message, error) {
|
||||
var timestamp, expires, attachmentSize, attachmentExpires int64
|
||||
var priority int
|
||||
var id, sequenceID, event, topic, msg, title, tagsStr, click, icon, actionsStr, attachmentName, attachmentType, attachmentURL, sender, user, contentType, encoding string
|
||||
err := rows.Scan(
|
||||
&id,
|
||||
&sequenceID,
|
||||
×tamp,
|
||||
&event,
|
||||
&expires,
|
||||
&topic,
|
||||
&msg,
|
||||
&title,
|
||||
&priority,
|
||||
&tagsStr,
|
||||
&click,
|
||||
&icon,
|
||||
&actionsStr,
|
||||
&attachmentName,
|
||||
&attachmentType,
|
||||
&attachmentSize,
|
||||
&attachmentExpires,
|
||||
&attachmentURL,
|
||||
&sender,
|
||||
&user,
|
||||
&contentType,
|
||||
&encoding,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var tags []string
|
||||
if tagsStr != "" {
|
||||
tags = strings.Split(tagsStr, ",")
|
||||
}
|
||||
var actions []*model.Action
|
||||
if actionsStr != "" {
|
||||
if err := json.Unmarshal([]byte(actionsStr), &actions); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
senderIP, err := netip.ParseAddr(sender)
|
||||
if err != nil {
|
||||
senderIP = netip.Addr{} // if no IP stored in database, return invalid address
|
||||
}
|
||||
var att *model.Attachment
|
||||
if attachmentName != "" && attachmentURL != "" {
|
||||
att = &model.Attachment{
|
||||
Name: attachmentName,
|
||||
Type: attachmentType,
|
||||
Size: attachmentSize,
|
||||
Expires: attachmentExpires,
|
||||
URL: attachmentURL,
|
||||
}
|
||||
}
|
||||
return &model.Message{
|
||||
ID: id,
|
||||
SequenceID: sequenceID,
|
||||
Time: timestamp,
|
||||
Expires: expires,
|
||||
Event: event,
|
||||
Topic: topic,
|
||||
Message: msg,
|
||||
Title: title,
|
||||
Priority: priority,
|
||||
Tags: tags,
|
||||
Click: click,
|
||||
Icon: icon,
|
||||
Actions: actions,
|
||||
Attachment: att,
|
||||
Sender: senderIP,
|
||||
User: user,
|
||||
ContentType: contentType,
|
||||
Encoding: encoding,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Ensure commonStore implements Store
|
||||
var _ Store = (*commonStore)(nil)
|
||||
|
||||
// Needed by store.go but not part of Store interface; unused import guard
|
||||
var _ = fmt.Sprintf
|
||||
118
message/store_postgres.go
Normal file
118
message/store_postgres.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
|
||||
)
|
||||
|
||||
// PostgreSQL runtime query constants
|
||||
const (
|
||||
pgInsertMessageQuery = `
|
||||
INSERT INTO message (mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user_id, content_type, encoding, published)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24)
|
||||
`
|
||||
pgDeleteMessageQuery = `DELETE FROM message WHERE mid = $1`
|
||||
pgSelectScheduledMessageIDsBySeqIDQuery = `SELECT mid FROM message WHERE topic = $1 AND sequence_id = $2 AND published = FALSE`
|
||||
pgDeleteScheduledBySequenceIDQuery = `DELETE FROM message WHERE topic = $1 AND sequence_id = $2 AND published = FALSE`
|
||||
pgUpdateMessagesForTopicExpiryQuery = `UPDATE message SET expires = $1 WHERE topic = $2`
|
||||
pgSelectRowIDFromMessageID = `SELECT id FROM message WHERE mid = $1`
|
||||
pgSelectMessagesByIDQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
||||
FROM message
|
||||
WHERE mid = $1
|
||||
`
|
||||
pgSelectMessagesSinceTimeQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
||||
FROM message
|
||||
WHERE topic = $1 AND time >= $2 AND published = TRUE
|
||||
ORDER BY time, id
|
||||
`
|
||||
pgSelectMessagesSinceTimeIncludeScheduledQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
||||
FROM message
|
||||
WHERE topic = $1 AND time >= $2
|
||||
ORDER BY time, id
|
||||
`
|
||||
pgSelectMessagesSinceIDQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
||||
FROM message
|
||||
WHERE topic = $1 AND id > $2 AND published = TRUE
|
||||
ORDER BY time, id
|
||||
`
|
||||
pgSelectMessagesSinceIDIncludeScheduledQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
||||
FROM message
|
||||
WHERE topic = $1 AND (id > $2 OR published = FALSE)
|
||||
ORDER BY time, id
|
||||
`
|
||||
pgSelectMessagesLatestQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
||||
FROM message
|
||||
WHERE topic = $1 AND published = TRUE
|
||||
ORDER BY time DESC, id DESC
|
||||
LIMIT 1
|
||||
`
|
||||
pgSelectMessagesDueQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
||||
FROM message
|
||||
WHERE time <= $1 AND published = FALSE
|
||||
ORDER BY time, id
|
||||
`
|
||||
pgSelectMessagesExpiredQuery = `SELECT mid FROM message WHERE expires <= $1 AND published = TRUE`
|
||||
pgUpdateMessagePublishedQuery = `UPDATE message SET published = TRUE WHERE mid = $1`
|
||||
pgSelectMessagesCountQuery = `SELECT COUNT(*) FROM message`
|
||||
pgSelectMessageCountPerTopicQuery = `SELECT topic, COUNT(*) FROM message GROUP BY topic`
|
||||
pgSelectTopicsQuery = `SELECT topic FROM message GROUP BY topic`
|
||||
|
||||
pgUpdateAttachmentDeleted = `UPDATE message SET attachment_deleted = TRUE WHERE mid = $1`
|
||||
pgSelectAttachmentsExpiredQuery = `SELECT mid FROM message WHERE attachment_expires > 0 AND attachment_expires <= $1 AND attachment_deleted = FALSE`
|
||||
pgSelectAttachmentsSizeBySenderQuery = `SELECT COALESCE(SUM(attachment_size), 0) FROM message WHERE user_id = '' AND sender = $1 AND attachment_expires >= $2`
|
||||
pgSelectAttachmentsSizeByUserIDQuery = `SELECT COALESCE(SUM(attachment_size), 0) FROM message WHERE user_id = $1 AND attachment_expires >= $2`
|
||||
|
||||
pgSelectStatsQuery = `SELECT value FROM message_stats WHERE key = 'messages'`
|
||||
pgUpdateStatsQuery = `UPDATE message_stats SET value = $1 WHERE key = 'messages'`
|
||||
)
|
||||
|
||||
var pgQueries = storeQueries{
|
||||
insertMessage: pgInsertMessageQuery,
|
||||
deleteMessage: pgDeleteMessageQuery,
|
||||
selectScheduledMessageIDsBySeqID: pgSelectScheduledMessageIDsBySeqIDQuery,
|
||||
deleteScheduledBySequenceID: pgDeleteScheduledBySequenceIDQuery,
|
||||
updateMessagesForTopicExpiry: pgUpdateMessagesForTopicExpiryQuery,
|
||||
selectRowIDFromMessageID: pgSelectRowIDFromMessageID,
|
||||
selectMessagesByID: pgSelectMessagesByIDQuery,
|
||||
selectMessagesSinceTime: pgSelectMessagesSinceTimeQuery,
|
||||
selectMessagesSinceTimeScheduled: pgSelectMessagesSinceTimeIncludeScheduledQuery,
|
||||
selectMessagesSinceID: pgSelectMessagesSinceIDQuery,
|
||||
selectMessagesSinceIDScheduled: pgSelectMessagesSinceIDIncludeScheduledQuery,
|
||||
selectMessagesLatest: pgSelectMessagesLatestQuery,
|
||||
selectMessagesDue: pgSelectMessagesDueQuery,
|
||||
selectMessagesExpired: pgSelectMessagesExpiredQuery,
|
||||
updateMessagePublished: pgUpdateMessagePublishedQuery,
|
||||
selectMessagesCount: pgSelectMessagesCountQuery,
|
||||
selectMessageCountPerTopic: pgSelectMessageCountPerTopicQuery,
|
||||
selectTopics: pgSelectTopicsQuery,
|
||||
updateAttachmentDeleted: pgUpdateAttachmentDeleted,
|
||||
selectAttachmentsExpired: pgSelectAttachmentsExpiredQuery,
|
||||
selectAttachmentsSizeBySender: pgSelectAttachmentsSizeBySenderQuery,
|
||||
selectAttachmentsSizeByUserID: pgSelectAttachmentsSizeByUserIDQuery,
|
||||
selectStats: pgSelectStatsQuery,
|
||||
updateStats: pgUpdateStatsQuery,
|
||||
}
|
||||
|
||||
// NewPostgresStore creates a new PostgreSQL-backed message cache store.
|
||||
func NewPostgresStore(dsn string, batchSize int, batchTimeout time.Duration) (Store, error) {
|
||||
db, err := sql.Open("pgx", dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := db.Ping(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := setupPostgresDB(db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newCommonStore(db, pgQueries, batchSize, batchTimeout, false), nil
|
||||
}
|
||||
90
message/store_postgres_schema.go
Normal file
90
message/store_postgres_schema.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Initial PostgreSQL schema
|
||||
const (
|
||||
pgCreateTablesQuery = `
|
||||
CREATE TABLE IF NOT EXISTS message (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
mid TEXT NOT NULL,
|
||||
sequence_id TEXT NOT NULL,
|
||||
time BIGINT NOT NULL,
|
||||
event TEXT NOT NULL,
|
||||
expires BIGINT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
title TEXT NOT NULL,
|
||||
priority INT NOT NULL,
|
||||
tags TEXT NOT NULL,
|
||||
click TEXT NOT NULL,
|
||||
icon TEXT NOT NULL,
|
||||
actions TEXT NOT NULL,
|
||||
attachment_name TEXT NOT NULL,
|
||||
attachment_type TEXT NOT NULL,
|
||||
attachment_size BIGINT NOT NULL,
|
||||
attachment_expires BIGINT NOT NULL,
|
||||
attachment_url TEXT NOT NULL,
|
||||
attachment_deleted BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
sender TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
content_type TEXT NOT NULL,
|
||||
encoding TEXT NOT NULL,
|
||||
published BOOLEAN NOT NULL DEFAULT FALSE
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_message_mid ON message (mid);
|
||||
CREATE INDEX IF NOT EXISTS idx_message_sequence_id ON message (sequence_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_message_time ON message (time);
|
||||
CREATE INDEX IF NOT EXISTS idx_message_topic ON message (topic);
|
||||
CREATE INDEX IF NOT EXISTS idx_message_expires ON message (expires);
|
||||
CREATE INDEX IF NOT EXISTS idx_message_sender ON message (sender);
|
||||
CREATE INDEX IF NOT EXISTS idx_message_user_id ON message (user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_message_attachment_expires ON message (attachment_expires);
|
||||
CREATE TABLE IF NOT EXISTS message_stats (
|
||||
key TEXT PRIMARY KEY,
|
||||
value BIGINT
|
||||
);
|
||||
INSERT INTO message_stats (key, value) VALUES ('messages', 0);
|
||||
CREATE TABLE IF NOT EXISTS schema_version (
|
||||
store TEXT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
`
|
||||
)
|
||||
|
||||
// PostgreSQL schema management queries
|
||||
const (
|
||||
pgCurrentSchemaVersion = 14
|
||||
pgInsertSchemaVersion = `INSERT INTO schema_version (store, version) VALUES ('message', $1)`
|
||||
pgSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE store = 'message'`
|
||||
)
|
||||
|
||||
func setupPostgresDB(db *sql.DB) error {
|
||||
var schemaVersion int
|
||||
err := db.QueryRow(pgSelectSchemaVersionQuery).Scan(&schemaVersion)
|
||||
if err != nil {
|
||||
return setupNewPostgresDB(db)
|
||||
}
|
||||
if schemaVersion > pgCurrentSchemaVersion {
|
||||
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, pgCurrentSchemaVersion)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupNewPostgresDB(db *sql.DB) error {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(pgCreateTablesQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(pgInsertSchemaVersion, pgCurrentSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
120
message/store_postgres_test.go
Normal file
120
message/store_postgres_test.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package message_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/message"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
func newTestPostgresStore(t *testing.T) message.Store {
|
||||
dsn := os.Getenv("NTFY_TEST_DATABASE_URL")
|
||||
if dsn == "" {
|
||||
t.Skip("NTFY_TEST_DATABASE_URL not set, skipping PostgreSQL tests")
|
||||
}
|
||||
// Create a unique schema for this test
|
||||
schema := fmt.Sprintf("test_%s", util.RandomString(10))
|
||||
setupDB, err := sql.Open("pgx", dsn)
|
||||
require.Nil(t, err)
|
||||
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, setupDB.Close())
|
||||
// Open store with search_path set to the new schema
|
||||
u, err := url.Parse(dsn)
|
||||
require.Nil(t, err)
|
||||
q := u.Query()
|
||||
q.Set("search_path", schema)
|
||||
u.RawQuery = q.Encode()
|
||||
store, err := message.NewPostgresStore(u.String(), 0, 0)
|
||||
require.Nil(t, err)
|
||||
t.Cleanup(func() {
|
||||
store.Close()
|
||||
cleanDB, err := sql.Open("pgx", dsn)
|
||||
if err == nil {
|
||||
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
|
||||
cleanDB.Close()
|
||||
}
|
||||
})
|
||||
return store
|
||||
}
|
||||
|
||||
func TestPostgresStore_Messages(t *testing.T) {
|
||||
testCacheMessages(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStore_MessagesLock(t *testing.T) {
|
||||
testCacheMessagesLock(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStore_MessagesScheduled(t *testing.T) {
|
||||
testCacheMessagesScheduled(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStore_Topics(t *testing.T) {
|
||||
testCacheTopics(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStore_MessagesTagsPrioAndTitle(t *testing.T) {
|
||||
testCacheMessagesTagsPrioAndTitle(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStore_MessagesSinceID(t *testing.T) {
|
||||
testCacheMessagesSinceID(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStore_Prune(t *testing.T) {
|
||||
testCachePrune(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStore_Attachments(t *testing.T) {
|
||||
testCacheAttachments(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStore_AttachmentsExpired(t *testing.T) {
|
||||
testCacheAttachmentsExpired(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStore_Sender(t *testing.T) {
|
||||
testSender(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStore_DeleteScheduledBySequenceID(t *testing.T) {
|
||||
testDeleteScheduledBySequenceID(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStore_MessageByID(t *testing.T) {
|
||||
testMessageByID(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStore_MarkPublished(t *testing.T) {
|
||||
testMarkPublished(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStore_ExpireMessages(t *testing.T) {
|
||||
testExpireMessages(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStore_MarkAttachmentsDeleted(t *testing.T) {
|
||||
testMarkAttachmentsDeleted(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStore_Stats(t *testing.T) {
|
||||
testStats(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStore_AddMessages(t *testing.T) {
|
||||
testAddMessages(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStore_MessagesDue(t *testing.T) {
|
||||
testMessagesDue(t, newTestPostgresStore(t))
|
||||
}
|
||||
|
||||
func TestPostgresStore_MessageFieldRoundTrip(t *testing.T) {
|
||||
testMessageFieldRoundTrip(t, newTestPostgresStore(t))
|
||||
}
|
||||
138
message/store_sqlite.go
Normal file
138
message/store_sqlite.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
// SQLite runtime query constants
|
||||
const (
|
||||
sqliteInsertMessageQuery = `
|
||||
INSERT INTO messages (mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user, content_type, encoding, published)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
sqliteDeleteMessageQuery = `DELETE FROM messages WHERE mid = ?`
|
||||
sqliteSelectScheduledMessageIDsBySeqIDQuery = `SELECT mid FROM messages WHERE topic = ? AND sequence_id = ? AND published = 0`
|
||||
sqliteDeleteScheduledBySequenceIDQuery = `DELETE FROM messages WHERE topic = ? AND sequence_id = ? AND published = 0`
|
||||
sqliteUpdateMessagesForTopicExpiryQuery = `UPDATE messages SET expires = ? WHERE topic = ?`
|
||||
sqliteSelectRowIDFromMessageID = `SELECT id FROM messages WHERE mid = ?`
|
||||
sqliteSelectMessagesByIDQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE mid = ?
|
||||
`
|
||||
sqliteSelectMessagesSinceTimeQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE topic = ? AND time >= ? AND published = 1
|
||||
ORDER BY time, id
|
||||
`
|
||||
sqliteSelectMessagesSinceTimeIncludeScheduledQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE topic = ? AND time >= ?
|
||||
ORDER BY time, id
|
||||
`
|
||||
sqliteSelectMessagesSinceIDQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE topic = ? AND id > ? AND published = 1
|
||||
ORDER BY time, id
|
||||
`
|
||||
sqliteSelectMessagesSinceIDIncludeScheduledQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE topic = ? AND (id > ? OR published = 0)
|
||||
ORDER BY time, id
|
||||
`
|
||||
sqliteSelectMessagesLatestQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE topic = ? AND published = 1
|
||||
ORDER BY time DESC, id DESC
|
||||
LIMIT 1
|
||||
`
|
||||
sqliteSelectMessagesDueQuery = `
|
||||
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||
FROM messages
|
||||
WHERE time <= ? AND published = 0
|
||||
ORDER BY time, id
|
||||
`
|
||||
sqliteSelectMessagesExpiredQuery = `SELECT mid FROM messages WHERE expires <= ? AND published = 1`
|
||||
sqliteUpdateMessagePublishedQuery = `UPDATE messages SET published = 1 WHERE mid = ?`
|
||||
sqliteSelectMessagesCountQuery = `SELECT COUNT(*) FROM messages`
|
||||
sqliteSelectMessageCountPerTopicQuery = `SELECT topic, COUNT(*) FROM messages GROUP BY topic`
|
||||
sqliteSelectTopicsQuery = `SELECT topic FROM messages GROUP BY topic`
|
||||
|
||||
sqliteUpdateAttachmentDeleted = `UPDATE messages SET attachment_deleted = 1 WHERE mid = ?`
|
||||
sqliteSelectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires <= ? AND attachment_deleted = 0`
|
||||
sqliteSelectAttachmentsSizeBySenderQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = '' AND sender = ? AND attachment_expires >= ?`
|
||||
sqliteSelectAttachmentsSizeByUserIDQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?`
|
||||
|
||||
sqliteSelectStatsQuery = `SELECT value FROM stats WHERE key = 'messages'`
|
||||
sqliteUpdateStatsQuery = `UPDATE stats SET value = ? WHERE key = 'messages'`
|
||||
)
|
||||
|
||||
var sqliteQueries = storeQueries{
|
||||
insertMessage: sqliteInsertMessageQuery,
|
||||
deleteMessage: sqliteDeleteMessageQuery,
|
||||
selectScheduledMessageIDsBySeqID: sqliteSelectScheduledMessageIDsBySeqIDQuery,
|
||||
deleteScheduledBySequenceID: sqliteDeleteScheduledBySequenceIDQuery,
|
||||
updateMessagesForTopicExpiry: sqliteUpdateMessagesForTopicExpiryQuery,
|
||||
selectRowIDFromMessageID: sqliteSelectRowIDFromMessageID,
|
||||
selectMessagesByID: sqliteSelectMessagesByIDQuery,
|
||||
selectMessagesSinceTime: sqliteSelectMessagesSinceTimeQuery,
|
||||
selectMessagesSinceTimeScheduled: sqliteSelectMessagesSinceTimeIncludeScheduledQuery,
|
||||
selectMessagesSinceID: sqliteSelectMessagesSinceIDQuery,
|
||||
selectMessagesSinceIDScheduled: sqliteSelectMessagesSinceIDIncludeScheduledQuery,
|
||||
selectMessagesLatest: sqliteSelectMessagesLatestQuery,
|
||||
selectMessagesDue: sqliteSelectMessagesDueQuery,
|
||||
selectMessagesExpired: sqliteSelectMessagesExpiredQuery,
|
||||
updateMessagePublished: sqliteUpdateMessagePublishedQuery,
|
||||
selectMessagesCount: sqliteSelectMessagesCountQuery,
|
||||
selectMessageCountPerTopic: sqliteSelectMessageCountPerTopicQuery,
|
||||
selectTopics: sqliteSelectTopicsQuery,
|
||||
updateAttachmentDeleted: sqliteUpdateAttachmentDeleted,
|
||||
selectAttachmentsExpired: sqliteSelectAttachmentsExpiredQuery,
|
||||
selectAttachmentsSizeBySender: sqliteSelectAttachmentsSizeBySenderQuery,
|
||||
selectAttachmentsSizeByUserID: sqliteSelectAttachmentsSizeByUserIDQuery,
|
||||
selectStats: sqliteSelectStatsQuery,
|
||||
updateStats: sqliteUpdateStatsQuery,
|
||||
}
|
||||
|
||||
// NewSQLiteStore creates a SQLite file-backed cache
|
||||
func NewSQLiteStore(filename, startupQueries string, cacheDuration time.Duration, batchSize int, batchTimeout time.Duration, nop bool) (Store, error) {
|
||||
parentDir := filepath.Dir(filename)
|
||||
if !util.FileExists(parentDir) {
|
||||
return nil, fmt.Errorf("cache database directory %s does not exist or is not accessible", parentDir)
|
||||
}
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := setupSQLite(db, startupQueries, cacheDuration); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newCommonStore(db, sqliteQueries, batchSize, batchTimeout, nop), nil
|
||||
}
|
||||
|
||||
// NewMemStore creates an in-memory cache
|
||||
func NewMemStore() (Store, error) {
|
||||
return NewSQLiteStore(createMemoryFilename(), "", 0, 0, 0, false)
|
||||
}
|
||||
|
||||
// NewNopStore creates an in-memory cache that discards all messages;
|
||||
// it is always empty and can be used if caching is entirely disabled
|
||||
func NewNopStore() (Store, error) {
|
||||
return NewSQLiteStore(createMemoryFilename(), "", 0, 0, 0, true)
|
||||
}
|
||||
|
||||
// createMemoryFilename creates a unique memory filename to use for the SQLite backend.
|
||||
func createMemoryFilename() string {
|
||||
return fmt.Sprintf("file:%s?mode=memory&cache=shared", util.RandomString(10))
|
||||
}
|
||||
466
message/store_sqlite_schema.go
Normal file
466
message/store_sqlite_schema.go
Normal file
@@ -0,0 +1,466 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"heckel.io/ntfy/v2/log"
|
||||
)
|
||||
|
||||
// Initial SQLite schema
|
||||
const (
|
||||
sqliteCreateMessagesTableQuery = `
|
||||
BEGIN;
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
mid TEXT NOT NULL,
|
||||
sequence_id TEXT NOT NULL,
|
||||
time INT NOT NULL,
|
||||
event TEXT NOT NULL,
|
||||
expires INT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
title TEXT NOT NULL,
|
||||
priority INT NOT NULL,
|
||||
tags TEXT NOT NULL,
|
||||
click TEXT NOT NULL,
|
||||
icon TEXT NOT NULL,
|
||||
actions TEXT NOT NULL,
|
||||
attachment_name TEXT NOT NULL,
|
||||
attachment_type TEXT NOT NULL,
|
||||
attachment_size INT NOT NULL,
|
||||
attachment_expires INT NOT NULL,
|
||||
attachment_url TEXT NOT NULL,
|
||||
attachment_deleted INT NOT NULL,
|
||||
sender TEXT NOT NULL,
|
||||
user TEXT NOT NULL,
|
||||
content_type TEXT NOT NULL,
|
||||
encoding TEXT NOT NULL,
|
||||
published INT NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_mid ON messages (mid);
|
||||
CREATE INDEX IF NOT EXISTS idx_sequence_id ON messages (sequence_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
|
||||
CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires);
|
||||
CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender);
|
||||
CREATE INDEX IF NOT EXISTS idx_user ON messages (user);
|
||||
CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
|
||||
CREATE TABLE IF NOT EXISTS stats (
|
||||
key TEXT PRIMARY KEY,
|
||||
value INT
|
||||
);
|
||||
INSERT INTO stats (key, value) VALUES ('messages', 0);
|
||||
COMMIT;
|
||||
`
|
||||
)
|
||||
|
||||
// Schema version management for SQLite
|
||||
const (
|
||||
sqliteCurrentSchemaVersion = 14
|
||||
sqliteCreateSchemaVersionTableQuery = `
|
||||
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||
id INT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
`
|
||||
sqliteInsertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
|
||||
sqliteUpdateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
|
||||
sqliteSelectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
|
||||
)
|
||||
|
||||
// Schema migrations for SQLite
|
||||
const (
|
||||
// 0 -> 1
|
||||
sqliteMigrate0To1AlterMessagesTableQuery = `
|
||||
BEGIN;
|
||||
ALTER TABLE messages ADD COLUMN title TEXT NOT NULL DEFAULT('');
|
||||
ALTER TABLE messages ADD COLUMN priority INT NOT NULL DEFAULT(0);
|
||||
ALTER TABLE messages ADD COLUMN tags TEXT NOT NULL DEFAULT('');
|
||||
COMMIT;
|
||||
`
|
||||
|
||||
// 1 -> 2
|
||||
sqliteMigrate1To2AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages ADD COLUMN published INT NOT NULL DEFAULT(1);
|
||||
`
|
||||
|
||||
// 2 -> 3
|
||||
sqliteMigrate2To3AlterMessagesTableQuery = `
|
||||
BEGIN;
|
||||
ALTER TABLE messages ADD COLUMN click TEXT NOT NULL DEFAULT('');
|
||||
ALTER TABLE messages ADD COLUMN attachment_name TEXT NOT NULL DEFAULT('');
|
||||
ALTER TABLE messages ADD COLUMN attachment_type TEXT NOT NULL DEFAULT('');
|
||||
ALTER TABLE messages ADD COLUMN attachment_size INT NOT NULL DEFAULT('0');
|
||||
ALTER TABLE messages ADD COLUMN attachment_expires INT NOT NULL DEFAULT('0');
|
||||
ALTER TABLE messages ADD COLUMN attachment_owner TEXT NOT NULL DEFAULT('');
|
||||
ALTER TABLE messages ADD COLUMN attachment_url TEXT NOT NULL DEFAULT('');
|
||||
COMMIT;
|
||||
`
|
||||
// 3 -> 4
|
||||
sqliteMigrate3To4AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages ADD COLUMN encoding TEXT NOT NULL DEFAULT('');
|
||||
`
|
||||
|
||||
// 4 -> 5
|
||||
sqliteMigrate4To5AlterMessagesTableQuery = `
|
||||
BEGIN;
|
||||
CREATE TABLE IF NOT EXISTS messages_new (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
mid TEXT NOT NULL,
|
||||
time INT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
title TEXT NOT NULL,
|
||||
priority INT NOT NULL,
|
||||
tags TEXT NOT NULL,
|
||||
click TEXT NOT NULL,
|
||||
attachment_name TEXT NOT NULL,
|
||||
attachment_type TEXT NOT NULL,
|
||||
attachment_size INT NOT NULL,
|
||||
attachment_expires INT NOT NULL,
|
||||
attachment_url TEXT NOT NULL,
|
||||
attachment_owner TEXT NOT NULL,
|
||||
encoding TEXT NOT NULL,
|
||||
published INT NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_mid ON messages_new (mid);
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON messages_new (topic);
|
||||
INSERT
|
||||
INTO messages_new (
|
||||
mid, time, topic, message, title, priority, tags, click, attachment_name, attachment_type,
|
||||
attachment_size, attachment_expires, attachment_url, attachment_owner, encoding, published)
|
||||
SELECT
|
||||
id, time, topic, message, title, priority, tags, click, attachment_name, attachment_type,
|
||||
attachment_size, attachment_expires, attachment_url, attachment_owner, encoding, published
|
||||
FROM messages;
|
||||
DROP TABLE messages;
|
||||
ALTER TABLE messages_new RENAME TO messages;
|
||||
COMMIT;
|
||||
`
|
||||
|
||||
// 5 -> 6
|
||||
sqliteMigrate5To6AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages ADD COLUMN actions TEXT NOT NULL DEFAULT('');
|
||||
`
|
||||
|
||||
// 6 -> 7
|
||||
sqliteMigrate6To7AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages RENAME COLUMN attachment_owner TO sender;
|
||||
`
|
||||
|
||||
// 7 -> 8
|
||||
sqliteMigrate7To8AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages ADD COLUMN icon TEXT NOT NULL DEFAULT('');
|
||||
`
|
||||
|
||||
// 8 -> 9
|
||||
sqliteMigrate8To9AlterMessagesTableQuery = `
|
||||
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
|
||||
`
|
||||
|
||||
// 9 -> 10
|
||||
sqliteMigrate9To10AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages ADD COLUMN user TEXT NOT NULL DEFAULT('');
|
||||
ALTER TABLE messages ADD COLUMN attachment_deleted INT NOT NULL DEFAULT('0');
|
||||
ALTER TABLE messages ADD COLUMN expires INT NOT NULL DEFAULT('0');
|
||||
CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires);
|
||||
CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender);
|
||||
CREATE INDEX IF NOT EXISTS idx_user ON messages (user);
|
||||
CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
|
||||
`
|
||||
sqliteMigrate9To10UpdateMessageExpiryQuery = `UPDATE messages SET expires = time + ?`
|
||||
|
||||
// 10 -> 11
|
||||
sqliteMigrate10To11AlterMessagesTableQuery = `
|
||||
CREATE TABLE IF NOT EXISTS stats (
|
||||
key TEXT PRIMARY KEY,
|
||||
value INT
|
||||
);
|
||||
INSERT INTO stats (key, value) VALUES ('messages', 0);
|
||||
`
|
||||
|
||||
// 11 -> 12
|
||||
sqliteMigrate11To12AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages ADD COLUMN content_type TEXT NOT NULL DEFAULT('');
|
||||
`
|
||||
|
||||
// 12 -> 13
|
||||
sqliteMigrate12To13AlterMessagesTableQuery = `
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
|
||||
`
|
||||
|
||||
// 13 -> 14
|
||||
sqliteMigrate13To14AlterMessagesTableQuery = `
|
||||
ALTER TABLE messages ADD COLUMN sequence_id TEXT NOT NULL DEFAULT('');
|
||||
ALTER TABLE messages ADD COLUMN event TEXT NOT NULL DEFAULT('message');
|
||||
CREATE INDEX IF NOT EXISTS idx_sequence_id ON messages (sequence_id);
|
||||
`
|
||||
)
|
||||
|
||||
var (
|
||||
sqliteMigrations = map[int]func(db *sql.DB, cacheDuration time.Duration) error{
|
||||
0: sqliteMigrateFrom0,
|
||||
1: sqliteMigrateFrom1,
|
||||
2: sqliteMigrateFrom2,
|
||||
3: sqliteMigrateFrom3,
|
||||
4: sqliteMigrateFrom4,
|
||||
5: sqliteMigrateFrom5,
|
||||
6: sqliteMigrateFrom6,
|
||||
7: sqliteMigrateFrom7,
|
||||
8: sqliteMigrateFrom8,
|
||||
9: sqliteMigrateFrom9,
|
||||
10: sqliteMigrateFrom10,
|
||||
11: sqliteMigrateFrom11,
|
||||
12: sqliteMigrateFrom12,
|
||||
13: sqliteMigrateFrom13,
|
||||
}
|
||||
)
|
||||
|
||||
func setupSQLite(db *sql.DB, startupQueries string, cacheDuration time.Duration) error {
|
||||
if err := runSQLiteStartupQueries(db, startupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
// If 'messages' table does not exist, this must be a new database
|
||||
rowsMC, err := db.Query(sqliteSelectMessagesCountQuery)
|
||||
if err != nil {
|
||||
return setupNewSQLite(db)
|
||||
}
|
||||
rowsMC.Close()
|
||||
// If 'messages' table exists, check 'schemaVersion' table
|
||||
schemaVersion := 0
|
||||
rowsSV, err := db.Query(sqliteSelectSchemaVersionQuery)
|
||||
if err == nil {
|
||||
defer rowsSV.Close()
|
||||
if !rowsSV.Next() {
|
||||
return fmt.Errorf("cannot determine schema version: cache file may be corrupt")
|
||||
}
|
||||
if err := rowsSV.Scan(&schemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
rowsSV.Close()
|
||||
}
|
||||
// Do migrations
|
||||
if schemaVersion == sqliteCurrentSchemaVersion {
|
||||
return nil
|
||||
} else if schemaVersion > sqliteCurrentSchemaVersion {
|
||||
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, sqliteCurrentSchemaVersion)
|
||||
}
|
||||
for i := schemaVersion; i < sqliteCurrentSchemaVersion; i++ {
|
||||
fn, ok := sqliteMigrations[i]
|
||||
if !ok {
|
||||
return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1)
|
||||
} else if err := fn(db, cacheDuration); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupNewSQLite(db *sql.DB) error {
|
||||
if _, err := db.Exec(sqliteCreateMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(sqliteCreateSchemaVersionTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(sqliteInsertSchemaVersion, sqliteCurrentSchemaVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func runSQLiteStartupQueries(db *sql.DB, startupQueries string) error {
|
||||
if startupQueries != "" {
|
||||
if _, err := db.Exec(startupQueries); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom0(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 0 to 1")
|
||||
if _, err := db.Exec(sqliteMigrate0To1AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(sqliteCreateSchemaVersionTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(sqliteInsertSchemaVersion, 1); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom1(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 1 to 2")
|
||||
if _, err := db.Exec(sqliteMigrate1To2AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(sqliteUpdateSchemaVersion, 2); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom2(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 2 to 3")
|
||||
if _, err := db.Exec(sqliteMigrate2To3AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(sqliteUpdateSchemaVersion, 3); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom3(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 3 to 4")
|
||||
if _, err := db.Exec(sqliteMigrate3To4AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(sqliteUpdateSchemaVersion, 4); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom4(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 4 to 5")
|
||||
if _, err := db.Exec(sqliteMigrate4To5AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(sqliteUpdateSchemaVersion, 5); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom5(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 5 to 6")
|
||||
if _, err := db.Exec(sqliteMigrate5To6AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(sqliteUpdateSchemaVersion, 6); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom6(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 6 to 7")
|
||||
if _, err := db.Exec(sqliteMigrate6To7AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(sqliteUpdateSchemaVersion, 7); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom7(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 7 to 8")
|
||||
if _, err := db.Exec(sqliteMigrate7To8AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(sqliteUpdateSchemaVersion, 8); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom8(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 8 to 9")
|
||||
if _, err := db.Exec(sqliteMigrate8To9AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := db.Exec(sqliteUpdateSchemaVersion, 9); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom9(db *sql.DB, cacheDuration time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 9 to 10")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(sqliteMigrate9To10AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteMigrate9To10UpdateMessageExpiryQuery, int64(cacheDuration.Seconds())); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 10); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom10(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 10 to 11")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(sqliteMigrate10To11AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 11); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom11(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 11 to 12")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(sqliteMigrate11To12AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 12); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom12(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 12 to 13")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(sqliteMigrate12To13AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 13); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func sqliteMigrateFrom13(db *sql.DB, _ time.Duration) error {
|
||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 13 to 14")
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
if _, err := tx.Exec(sqliteMigrate13To14AlterMessagesTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 14); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
459
message/store_sqlite_test.go
Normal file
459
message/store_sqlite_test.go
Normal file
@@ -0,0 +1,459 @@
|
||||
package message_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/message"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
)
|
||||
|
||||
func TestSqliteStore_Messages(t *testing.T) {
|
||||
testCacheMessages(t, newSqliteTestStore(t))
|
||||
}
|
||||
|
||||
func TestMemStore_Messages(t *testing.T) {
|
||||
testCacheMessages(t, newMemTestStore(t))
|
||||
}
|
||||
|
||||
func TestSqliteStore_MessagesLock(t *testing.T) {
|
||||
testCacheMessagesLock(t, newSqliteTestStore(t))
|
||||
}
|
||||
|
||||
func TestMemStore_MessagesLock(t *testing.T) {
|
||||
testCacheMessagesLock(t, newMemTestStore(t))
|
||||
}
|
||||
|
||||
func TestSqliteStore_MessagesScheduled(t *testing.T) {
|
||||
testCacheMessagesScheduled(t, newSqliteTestStore(t))
|
||||
}
|
||||
|
||||
func TestMemStore_MessagesScheduled(t *testing.T) {
|
||||
testCacheMessagesScheduled(t, newMemTestStore(t))
|
||||
}
|
||||
|
||||
func TestSqliteStore_Topics(t *testing.T) {
|
||||
testCacheTopics(t, newSqliteTestStore(t))
|
||||
}
|
||||
|
||||
func TestMemStore_Topics(t *testing.T) {
|
||||
testCacheTopics(t, newMemTestStore(t))
|
||||
}
|
||||
|
||||
func TestSqliteStore_MessagesTagsPrioAndTitle(t *testing.T) {
|
||||
testCacheMessagesTagsPrioAndTitle(t, newSqliteTestStore(t))
|
||||
}
|
||||
|
||||
func TestMemStore_MessagesTagsPrioAndTitle(t *testing.T) {
|
||||
testCacheMessagesTagsPrioAndTitle(t, newMemTestStore(t))
|
||||
}
|
||||
|
||||
func TestSqliteStore_MessagesSinceID(t *testing.T) {
|
||||
testCacheMessagesSinceID(t, newSqliteTestStore(t))
|
||||
}
|
||||
|
||||
func TestMemStore_MessagesSinceID(t *testing.T) {
|
||||
testCacheMessagesSinceID(t, newMemTestStore(t))
|
||||
}
|
||||
|
||||
func TestSqliteStore_Prune(t *testing.T) {
|
||||
testCachePrune(t, newSqliteTestStore(t))
|
||||
}
|
||||
|
||||
func TestMemStore_Prune(t *testing.T) {
|
||||
testCachePrune(t, newMemTestStore(t))
|
||||
}
|
||||
|
||||
func TestSqliteStore_Attachments(t *testing.T) {
|
||||
testCacheAttachments(t, newSqliteTestStore(t))
|
||||
}
|
||||
|
||||
func TestMemStore_Attachments(t *testing.T) {
|
||||
testCacheAttachments(t, newMemTestStore(t))
|
||||
}
|
||||
|
||||
func TestSqliteStore_AttachmentsExpired(t *testing.T) {
|
||||
testCacheAttachmentsExpired(t, newSqliteTestStore(t))
|
||||
}
|
||||
|
||||
func TestMemStore_AttachmentsExpired(t *testing.T) {
|
||||
testCacheAttachmentsExpired(t, newMemTestStore(t))
|
||||
}
|
||||
|
||||
func TestSqliteStore_Sender(t *testing.T) {
|
||||
testSender(t, newSqliteTestStore(t))
|
||||
}
|
||||
|
||||
func TestMemStore_Sender(t *testing.T) {
|
||||
testSender(t, newMemTestStore(t))
|
||||
}
|
||||
|
||||
func TestSqliteStore_DeleteScheduledBySequenceID(t *testing.T) {
|
||||
testDeleteScheduledBySequenceID(t, newSqliteTestStore(t))
|
||||
}
|
||||
|
||||
func TestMemStore_DeleteScheduledBySequenceID(t *testing.T) {
|
||||
testDeleteScheduledBySequenceID(t, newMemTestStore(t))
|
||||
}
|
||||
|
||||
func TestSqliteStore_MessageByID(t *testing.T) {
|
||||
testMessageByID(t, newSqliteTestStore(t))
|
||||
}
|
||||
|
||||
func TestMemStore_MessageByID(t *testing.T) {
|
||||
testMessageByID(t, newMemTestStore(t))
|
||||
}
|
||||
|
||||
func TestSqliteStore_MarkPublished(t *testing.T) {
|
||||
testMarkPublished(t, newSqliteTestStore(t))
|
||||
}
|
||||
|
||||
func TestMemStore_MarkPublished(t *testing.T) {
|
||||
testMarkPublished(t, newMemTestStore(t))
|
||||
}
|
||||
|
||||
func TestSqliteStore_ExpireMessages(t *testing.T) {
|
||||
testExpireMessages(t, newSqliteTestStore(t))
|
||||
}
|
||||
|
||||
func TestMemStore_ExpireMessages(t *testing.T) {
|
||||
testExpireMessages(t, newMemTestStore(t))
|
||||
}
|
||||
|
||||
func TestSqliteStore_MarkAttachmentsDeleted(t *testing.T) {
|
||||
testMarkAttachmentsDeleted(t, newSqliteTestStore(t))
|
||||
}
|
||||
|
||||
func TestMemStore_MarkAttachmentsDeleted(t *testing.T) {
|
||||
testMarkAttachmentsDeleted(t, newMemTestStore(t))
|
||||
}
|
||||
|
||||
func TestSqliteStore_Stats(t *testing.T) {
|
||||
testStats(t, newSqliteTestStore(t))
|
||||
}
|
||||
|
||||
func TestMemStore_Stats(t *testing.T) {
|
||||
testStats(t, newMemTestStore(t))
|
||||
}
|
||||
|
||||
func TestSqliteStore_AddMessages(t *testing.T) {
|
||||
testAddMessages(t, newSqliteTestStore(t))
|
||||
}
|
||||
|
||||
func TestMemStore_AddMessages(t *testing.T) {
|
||||
testAddMessages(t, newMemTestStore(t))
|
||||
}
|
||||
|
||||
func TestSqliteStore_MessagesDue(t *testing.T) {
|
||||
testMessagesDue(t, newSqliteTestStore(t))
|
||||
}
|
||||
|
||||
func TestMemStore_MessagesDue(t *testing.T) {
|
||||
testMessagesDue(t, newMemTestStore(t))
|
||||
}
|
||||
|
||||
func TestSqliteStore_MessageFieldRoundTrip(t *testing.T) {
|
||||
testMessageFieldRoundTrip(t, newSqliteTestStore(t))
|
||||
}
|
||||
|
||||
func TestMemStore_MessageFieldRoundTrip(t *testing.T) {
|
||||
testMessageFieldRoundTrip(t, newMemTestStore(t))
|
||||
}
|
||||
|
||||
func TestSqliteStore_Migration_From0(t *testing.T) {
|
||||
filename := newSqliteTestStoreFile(t)
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Create "version 0" schema
|
||||
_, err = db.Exec(`
|
||||
BEGIN;
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id VARCHAR(20) PRIMARY KEY,
|
||||
time INT NOT NULL,
|
||||
topic VARCHAR(64) NOT NULL,
|
||||
message VARCHAR(1024) NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
|
||||
COMMIT;
|
||||
`)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Insert a bunch of messages
|
||||
for i := 0; i < 10; i++ {
|
||||
_, err = db.Exec(`INSERT INTO messages (id, time, topic, message) VALUES (?, ?, ?, ?)`,
|
||||
fmt.Sprintf("abcd%d", i), time.Now().Unix(), "mytopic", fmt.Sprintf("some message %d", i))
|
||||
require.Nil(t, err)
|
||||
}
|
||||
require.Nil(t, db.Close())
|
||||
|
||||
// Create store to trigger migration
|
||||
s := newSqliteTestStoreFromFile(t, filename, "")
|
||||
checkSqliteSchemaVersion(t, filename)
|
||||
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 10, len(messages))
|
||||
require.Equal(t, "some message 5", messages[5].Message)
|
||||
require.Equal(t, "", messages[5].Title)
|
||||
require.Nil(t, messages[5].Tags)
|
||||
require.Equal(t, 0, messages[5].Priority)
|
||||
}
|
||||
|
||||
func TestSqliteStore_Migration_From1(t *testing.T) {
|
||||
filename := newSqliteTestStoreFile(t)
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Create "version 1" schema
|
||||
_, err = db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id VARCHAR(20) PRIMARY KEY,
|
||||
time INT NOT NULL,
|
||||
topic VARCHAR(64) NOT NULL,
|
||||
message VARCHAR(512) NOT NULL,
|
||||
title VARCHAR(256) NOT NULL,
|
||||
priority INT NOT NULL,
|
||||
tags VARCHAR(256) NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
|
||||
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||
id INT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
INSERT INTO schemaVersion (id, version) VALUES (1, 1);
|
||||
`)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Insert a bunch of messages
|
||||
for i := 0; i < 10; i++ {
|
||||
_, err = db.Exec(`INSERT INTO messages (id, time, topic, message, title, priority, tags) VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
||||
fmt.Sprintf("abcd%d", i), time.Now().Unix(), "mytopic", fmt.Sprintf("some message %d", i), "", 0, "")
|
||||
require.Nil(t, err)
|
||||
}
|
||||
require.Nil(t, db.Close())
|
||||
|
||||
// Create store to trigger migration
|
||||
s := newSqliteTestStoreFromFile(t, filename, "")
|
||||
checkSqliteSchemaVersion(t, filename)
|
||||
|
||||
// Add delayed message
|
||||
delayedMessage := model.NewDefaultMessage("mytopic", "some delayed message")
|
||||
delayedMessage.Time = time.Now().Add(time.Minute).Unix()
|
||||
require.Nil(t, s.AddMessage(delayedMessage))
|
||||
|
||||
// 10, not 11!
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 10, len(messages))
|
||||
|
||||
// 11!
|
||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, true)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 11, len(messages))
|
||||
|
||||
// Check that index "idx_topic" exists
|
||||
verifyDB, err := sql.Open("sqlite3", filename)
|
||||
require.Nil(t, err)
|
||||
defer verifyDB.Close()
|
||||
rows, err := verifyDB.Query(`SELECT name FROM sqlite_master WHERE type='index' AND name='idx_topic'`)
|
||||
require.Nil(t, err)
|
||||
require.True(t, rows.Next())
|
||||
var indexName string
|
||||
require.Nil(t, rows.Scan(&indexName))
|
||||
require.Equal(t, "idx_topic", indexName)
|
||||
require.Nil(t, rows.Close())
|
||||
}
|
||||
|
||||
func TestSqliteStore_Migration_From9(t *testing.T) {
|
||||
// This primarily tests the awkward migration that introduces the "expires" column.
|
||||
// The migration logic has to update the column, using the existing "cache-duration" value.
|
||||
|
||||
filename := newSqliteTestStoreFile(t)
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Create "version 9" schema
|
||||
_, err = db.Exec(`
|
||||
BEGIN;
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
mid TEXT NOT NULL,
|
||||
time INT NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
title TEXT NOT NULL,
|
||||
priority INT NOT NULL,
|
||||
tags TEXT NOT NULL,
|
||||
click TEXT NOT NULL,
|
||||
icon TEXT NOT NULL,
|
||||
actions TEXT NOT NULL,
|
||||
attachment_name TEXT NOT NULL,
|
||||
attachment_type TEXT NOT NULL,
|
||||
attachment_size INT NOT NULL,
|
||||
attachment_expires INT NOT NULL,
|
||||
attachment_url TEXT NOT NULL,
|
||||
sender TEXT NOT NULL,
|
||||
encoding TEXT NOT NULL,
|
||||
published INT NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_mid ON messages (mid);
|
||||
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
|
||||
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
|
||||
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||
id INT PRIMARY KEY,
|
||||
version INT NOT NULL
|
||||
);
|
||||
INSERT INTO schemaVersion (id, version) VALUES (1, 9);
|
||||
COMMIT;
|
||||
`)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Insert a bunch of messages
|
||||
insertQuery := `
|
||||
INSERT INTO messages (mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding, published)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
for i := 0; i < 10; i++ {
|
||||
_, err = db.Exec(
|
||||
insertQuery,
|
||||
fmt.Sprintf("abcd%d", i),
|
||||
time.Now().Unix(),
|
||||
"mytopic",
|
||||
fmt.Sprintf("some message %d", i),
|
||||
"", // title
|
||||
0, // priority
|
||||
"", // tags
|
||||
"", // click
|
||||
"", // icon
|
||||
"", // actions
|
||||
"", // attachment_name
|
||||
"", // attachment_type
|
||||
0, // attachment_size
|
||||
0, // attachment_expires
|
||||
"", // attachment_url
|
||||
"9.9.9.9", // sender
|
||||
"", // encoding
|
||||
1, // published
|
||||
)
|
||||
require.Nil(t, err)
|
||||
}
|
||||
require.Nil(t, db.Close())
|
||||
|
||||
// Create store to trigger migration
|
||||
cacheDuration := 17 * time.Hour
|
||||
s, err := message.NewSQLiteStore(filename, "", cacheDuration, 0, 0, false)
|
||||
require.Nil(t, err)
|
||||
t.Cleanup(func() { s.Close() })
|
||||
checkSqliteSchemaVersion(t, filename)
|
||||
|
||||
// Check version
|
||||
verifyDB, err := sql.Open("sqlite3", filename)
|
||||
require.Nil(t, err)
|
||||
defer verifyDB.Close()
|
||||
rows, err := verifyDB.Query(`SELECT version FROM schemaVersion WHERE id = 1`)
|
||||
require.Nil(t, err)
|
||||
require.True(t, rows.Next())
|
||||
var version int
|
||||
require.Nil(t, rows.Scan(&version))
|
||||
require.Equal(t, 14, version)
|
||||
require.Nil(t, rows.Close())
|
||||
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 10, len(messages))
|
||||
for _, m := range messages {
|
||||
require.True(t, m.Expires > time.Now().Add(cacheDuration-5*time.Second).Unix())
|
||||
require.True(t, m.Expires < time.Now().Add(cacheDuration+5*time.Second).Unix())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqliteStore_StartupQueries_WAL(t *testing.T) {
|
||||
filename := newSqliteTestStoreFile(t)
|
||||
startupQueries := `pragma journal_mode = WAL;
|
||||
pragma synchronous = normal;
|
||||
pragma temp_store = memory;`
|
||||
s, err := message.NewSQLiteStore(filename, startupQueries, time.Hour, 0, 0, false)
|
||||
require.Nil(t, err)
|
||||
t.Cleanup(func() { s.Close() })
|
||||
require.Nil(t, s.AddMessage(model.NewDefaultMessage("mytopic", "some message")))
|
||||
require.FileExists(t, filename)
|
||||
require.FileExists(t, filename+"-wal")
|
||||
require.FileExists(t, filename+"-shm")
|
||||
}
|
||||
|
||||
func TestSqliteStore_StartupQueries_None(t *testing.T) {
|
||||
filename := newSqliteTestStoreFile(t)
|
||||
s, err := message.NewSQLiteStore(filename, "", time.Hour, 0, 0, false)
|
||||
require.Nil(t, err)
|
||||
t.Cleanup(func() { s.Close() })
|
||||
require.Nil(t, s.AddMessage(model.NewDefaultMessage("mytopic", "some message")))
|
||||
require.FileExists(t, filename)
|
||||
require.NoFileExists(t, filename+"-wal")
|
||||
require.NoFileExists(t, filename+"-shm")
|
||||
}
|
||||
|
||||
func TestSqliteStore_StartupQueries_Fail(t *testing.T) {
|
||||
filename := newSqliteTestStoreFile(t)
|
||||
_, err := message.NewSQLiteStore(filename, `xx error`, time.Hour, 0, 0, false)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestNopStore(t *testing.T) {
|
||||
s, err := message.NewNopStore()
|
||||
require.Nil(t, err)
|
||||
t.Cleanup(func() { s.Close() })
|
||||
require.Nil(t, s.AddMessage(model.NewDefaultMessage("mytopic", "my message")))
|
||||
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Empty(t, messages)
|
||||
|
||||
topics, err := s.Topics()
|
||||
require.Nil(t, err)
|
||||
require.Empty(t, topics)
|
||||
}
|
||||
|
||||
func newSqliteTestStore(t *testing.T) message.Store {
|
||||
filename := filepath.Join(t.TempDir(), "cache.db")
|
||||
s, err := message.NewSQLiteStore(filename, "", time.Hour, 0, 0, false)
|
||||
require.Nil(t, err)
|
||||
t.Cleanup(func() { s.Close() })
|
||||
return s
|
||||
}
|
||||
|
||||
func newSqliteTestStoreFile(t *testing.T) string {
|
||||
return filepath.Join(t.TempDir(), "cache.db")
|
||||
}
|
||||
|
||||
func newSqliteTestStoreFromFile(t *testing.T, filename, startupQueries string) message.Store {
|
||||
s, err := message.NewSQLiteStore(filename, startupQueries, time.Hour, 0, 0, false)
|
||||
require.Nil(t, err)
|
||||
t.Cleanup(func() { s.Close() })
|
||||
return s
|
||||
}
|
||||
|
||||
func newMemTestStore(t *testing.T) message.Store {
|
||||
s, err := message.NewMemStore()
|
||||
require.Nil(t, err)
|
||||
t.Cleanup(func() { s.Close() })
|
||||
return s
|
||||
}
|
||||
|
||||
func checkSqliteSchemaVersion(t *testing.T, filename string) {
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
require.Nil(t, err)
|
||||
defer db.Close()
|
||||
rows, err := db.Query(`SELECT version FROM schemaVersion`)
|
||||
require.Nil(t, err)
|
||||
require.True(t, rows.Next())
|
||||
var schemaVersion int
|
||||
require.Nil(t, rows.Scan(&schemaVersion))
|
||||
require.Equal(t, 14, schemaVersion)
|
||||
require.Nil(t, rows.Close())
|
||||
}
|
||||
767
message/store_test.go
Normal file
767
message/store_test.go
Normal file
@@ -0,0 +1,767 @@
|
||||
package message_test
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"sort"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/message"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
)
|
||||
|
||||
func testCacheMessages(t *testing.T, s message.Store) {
|
||||
m1 := model.NewDefaultMessage("mytopic", "my message")
|
||||
m1.Time = 1
|
||||
|
||||
m2 := model.NewDefaultMessage("mytopic", "my other message")
|
||||
m2.Time = 2
|
||||
|
||||
require.Nil(t, s.AddMessage(m1))
|
||||
require.Nil(t, s.AddMessage(model.NewDefaultMessage("example", "my example message")))
|
||||
require.Nil(t, s.AddMessage(m2))
|
||||
|
||||
// Adding invalid
|
||||
require.Equal(t, model.ErrUnexpectedMessageType, s.AddMessage(model.NewKeepaliveMessage("mytopic"))) // These should not be added!
|
||||
require.Equal(t, model.ErrUnexpectedMessageType, s.AddMessage(model.NewOpenMessage("example"))) // These should not be added!
|
||||
|
||||
// mytopic: count
|
||||
counts, err := s.MessageCounts()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, counts["mytopic"])
|
||||
|
||||
// mytopic: since all
|
||||
messages, _ := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Equal(t, 2, len(messages))
|
||||
require.Equal(t, "my message", messages[0].Message)
|
||||
require.Equal(t, "mytopic", messages[0].Topic)
|
||||
require.Equal(t, model.MessageEvent, messages[0].Event)
|
||||
require.Equal(t, "", messages[0].Title)
|
||||
require.Equal(t, 0, messages[0].Priority)
|
||||
require.Nil(t, messages[0].Tags)
|
||||
require.Equal(t, "my other message", messages[1].Message)
|
||||
|
||||
// mytopic: since none
|
||||
messages, _ = s.Messages("mytopic", model.SinceNoMessages, false)
|
||||
require.Empty(t, messages)
|
||||
|
||||
// mytopic: since m1 (by ID)
|
||||
messages, _ = s.Messages("mytopic", model.NewSinceID(m1.ID), false)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, m2.ID, messages[0].ID)
|
||||
require.Equal(t, "my other message", messages[0].Message)
|
||||
require.Equal(t, "mytopic", messages[0].Topic)
|
||||
|
||||
// mytopic: since 2
|
||||
messages, _ = s.Messages("mytopic", model.NewSinceTime(2), false)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "my other message", messages[0].Message)
|
||||
|
||||
// mytopic: latest
|
||||
messages, _ = s.Messages("mytopic", model.SinceLatestMessage, false)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "my other message", messages[0].Message)
|
||||
|
||||
// example: count
|
||||
counts, err = s.MessageCounts()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, counts["example"])
|
||||
|
||||
// example: since all
|
||||
messages, _ = s.Messages("example", model.SinceAllMessages, false)
|
||||
require.Equal(t, "my example message", messages[0].Message)
|
||||
|
||||
// non-existing: count
|
||||
counts, err = s.MessageCounts()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 0, counts["doesnotexist"])
|
||||
|
||||
// non-existing: since all
|
||||
messages, _ = s.Messages("doesnotexist", model.SinceAllMessages, false)
|
||||
require.Empty(t, messages)
|
||||
}
|
||||
|
||||
func testCacheMessagesLock(t *testing.T, s message.Store) {
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 5000; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
assert.Nil(t, s.AddMessage(model.NewDefaultMessage("mytopic", "test message")))
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func testCacheMessagesScheduled(t *testing.T, s message.Store) {
|
||||
m1 := model.NewDefaultMessage("mytopic", "message 1")
|
||||
m2 := model.NewDefaultMessage("mytopic", "message 2")
|
||||
m2.Time = time.Now().Add(time.Hour).Unix()
|
||||
m3 := model.NewDefaultMessage("mytopic", "message 3")
|
||||
m3.Time = time.Now().Add(time.Minute).Unix() // earlier than m2!
|
||||
m4 := model.NewDefaultMessage("mytopic2", "message 4")
|
||||
m4.Time = time.Now().Add(time.Minute).Unix()
|
||||
require.Nil(t, s.AddMessage(m1))
|
||||
require.Nil(t, s.AddMessage(m2))
|
||||
require.Nil(t, s.AddMessage(m3))
|
||||
|
||||
messages, _ := s.Messages("mytopic", model.SinceAllMessages, false) // exclude scheduled
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "message 1", messages[0].Message)
|
||||
|
||||
messages, _ = s.Messages("mytopic", model.SinceAllMessages, true) // include scheduled
|
||||
require.Equal(t, 3, len(messages))
|
||||
require.Equal(t, "message 1", messages[0].Message)
|
||||
require.Equal(t, "message 3", messages[1].Message) // Order!
|
||||
require.Equal(t, "message 2", messages[2].Message)
|
||||
|
||||
messages, _ = s.MessagesDue()
|
||||
require.Empty(t, messages)
|
||||
}
|
||||
|
||||
func testCacheTopics(t *testing.T, s message.Store) {
|
||||
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic1", "my example message")))
|
||||
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 1")))
|
||||
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 2")))
|
||||
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 3")))
|
||||
|
||||
topics, err := s.Topics()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.Equal(t, 2, len(topics))
|
||||
require.Contains(t, topics, "topic1")
|
||||
require.Contains(t, topics, "topic2")
|
||||
}
|
||||
|
||||
func testCacheMessagesTagsPrioAndTitle(t *testing.T, s message.Store) {
|
||||
m := model.NewDefaultMessage("mytopic", "some message")
|
||||
m.Tags = []string{"tag1", "tag2"}
|
||||
m.Priority = 5
|
||||
m.Title = "some title"
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
|
||||
messages, _ := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Equal(t, []string{"tag1", "tag2"}, messages[0].Tags)
|
||||
require.Equal(t, 5, messages[0].Priority)
|
||||
require.Equal(t, "some title", messages[0].Title)
|
||||
}
|
||||
|
||||
func testCacheMessagesSinceID(t *testing.T, s message.Store) {
|
||||
m1 := model.NewDefaultMessage("mytopic", "message 1")
|
||||
m1.Time = 100
|
||||
m2 := model.NewDefaultMessage("mytopic", "message 2")
|
||||
m2.Time = 200
|
||||
m3 := model.NewDefaultMessage("mytopic", "message 3")
|
||||
m3.Time = time.Now().Add(time.Hour).Unix() // Scheduled, in the future, later than m7 and m5
|
||||
m4 := model.NewDefaultMessage("mytopic", "message 4")
|
||||
m4.Time = 400
|
||||
m5 := model.NewDefaultMessage("mytopic", "message 5")
|
||||
m5.Time = time.Now().Add(time.Minute).Unix() // Scheduled, in the future, later than m7
|
||||
m6 := model.NewDefaultMessage("mytopic", "message 6")
|
||||
m6.Time = 600
|
||||
m7 := model.NewDefaultMessage("mytopic", "message 7")
|
||||
m7.Time = 700
|
||||
|
||||
require.Nil(t, s.AddMessage(m1))
|
||||
require.Nil(t, s.AddMessage(m2))
|
||||
require.Nil(t, s.AddMessage(m3))
|
||||
require.Nil(t, s.AddMessage(m4))
|
||||
require.Nil(t, s.AddMessage(m5))
|
||||
require.Nil(t, s.AddMessage(m6))
|
||||
require.Nil(t, s.AddMessage(m7))
|
||||
|
||||
// Case 1: Since ID exists, exclude scheduled
|
||||
messages, _ := s.Messages("mytopic", model.NewSinceID(m2.ID), false)
|
||||
require.Equal(t, 3, len(messages))
|
||||
require.Equal(t, "message 4", messages[0].Message)
|
||||
require.Equal(t, "message 6", messages[1].Message) // Not scheduled m3/m5!
|
||||
require.Equal(t, "message 7", messages[2].Message)
|
||||
|
||||
// Case 2: Since ID exists, include scheduled
|
||||
messages, _ = s.Messages("mytopic", model.NewSinceID(m2.ID), true)
|
||||
require.Equal(t, 5, len(messages))
|
||||
require.Equal(t, "message 4", messages[0].Message)
|
||||
require.Equal(t, "message 6", messages[1].Message)
|
||||
require.Equal(t, "message 7", messages[2].Message)
|
||||
require.Equal(t, "message 5", messages[3].Message) // Order!
|
||||
require.Equal(t, "message 3", messages[4].Message) // Order!
|
||||
|
||||
// Case 3: Since ID does not exist (-> Return all messages), include scheduled
|
||||
messages, _ = s.Messages("mytopic", model.NewSinceID("doesntexist"), true)
|
||||
require.Equal(t, 7, len(messages))
|
||||
require.Equal(t, "message 1", messages[0].Message)
|
||||
require.Equal(t, "message 2", messages[1].Message)
|
||||
require.Equal(t, "message 4", messages[2].Message)
|
||||
require.Equal(t, "message 6", messages[3].Message)
|
||||
require.Equal(t, "message 7", messages[4].Message)
|
||||
require.Equal(t, "message 5", messages[5].Message) // Order!
|
||||
require.Equal(t, "message 3", messages[6].Message) // Order!
|
||||
|
||||
// Case 4: Since ID exists and is last message (-> Return no messages), exclude scheduled
|
||||
messages, _ = s.Messages("mytopic", model.NewSinceID(m7.ID), false)
|
||||
require.Equal(t, 0, len(messages))
|
||||
|
||||
// Case 5: Since ID exists and is last message (-> Return no messages), include scheduled
|
||||
messages, _ = s.Messages("mytopic", model.NewSinceID(m7.ID), true)
|
||||
require.Equal(t, 2, len(messages))
|
||||
require.Equal(t, "message 5", messages[0].Message)
|
||||
require.Equal(t, "message 3", messages[1].Message)
|
||||
}
|
||||
|
||||
func testCachePrune(t *testing.T, s message.Store) {
|
||||
now := time.Now().Unix()
|
||||
|
||||
m1 := model.NewDefaultMessage("mytopic", "my message")
|
||||
m1.Time = now - 10
|
||||
m1.Expires = now - 5
|
||||
|
||||
m2 := model.NewDefaultMessage("mytopic", "my other message")
|
||||
m2.Time = now - 5
|
||||
m2.Expires = now + 5 // In the future
|
||||
|
||||
m3 := model.NewDefaultMessage("another_topic", "and another one")
|
||||
m3.Time = now - 12
|
||||
m3.Expires = now - 2
|
||||
|
||||
require.Nil(t, s.AddMessage(m1))
|
||||
require.Nil(t, s.AddMessage(m2))
|
||||
require.Nil(t, s.AddMessage(m3))
|
||||
|
||||
counts, err := s.MessageCounts()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, counts["mytopic"])
|
||||
require.Equal(t, 1, counts["another_topic"])
|
||||
|
||||
expiredMessageIDs, err := s.MessagesExpired()
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, s.DeleteMessages(expiredMessageIDs...))
|
||||
|
||||
counts, err = s.MessageCounts()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, counts["mytopic"])
|
||||
require.Equal(t, 0, counts["another_topic"])
|
||||
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "my other message", messages[0].Message)
|
||||
}
|
||||
|
||||
func testCacheAttachments(t *testing.T, s message.Store) {
|
||||
expires1 := time.Now().Add(-4 * time.Hour).Unix() // Expired
|
||||
m := model.NewDefaultMessage("mytopic", "flower for you")
|
||||
m.ID = "m1"
|
||||
m.SequenceID = "m1"
|
||||
m.Sender = netip.MustParseAddr("1.2.3.4")
|
||||
m.Attachment = &model.Attachment{
|
||||
Name: "flower.jpg",
|
||||
Type: "image/jpeg",
|
||||
Size: 5000,
|
||||
Expires: expires1,
|
||||
URL: "https://ntfy.sh/file/AbDeFgJhal.jpg",
|
||||
}
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
|
||||
expires2 := time.Now().Add(2 * time.Hour).Unix() // Future
|
||||
m = model.NewDefaultMessage("mytopic", "sending you a car")
|
||||
m.ID = "m2"
|
||||
m.SequenceID = "m2"
|
||||
m.Sender = netip.MustParseAddr("1.2.3.4")
|
||||
m.Attachment = &model.Attachment{
|
||||
Name: "car.jpg",
|
||||
Type: "image/jpeg",
|
||||
Size: 10000,
|
||||
Expires: expires2,
|
||||
URL: "https://ntfy.sh/file/aCaRURL.jpg",
|
||||
}
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
|
||||
expires3 := time.Now().Add(1 * time.Hour).Unix() // Future
|
||||
m = model.NewDefaultMessage("another-topic", "sending you another car")
|
||||
m.ID = "m3"
|
||||
m.SequenceID = "m3"
|
||||
m.User = "u_BAsbaAa"
|
||||
m.Sender = netip.MustParseAddr("5.6.7.8")
|
||||
m.Attachment = &model.Attachment{
|
||||
Name: "another-car.jpg",
|
||||
Type: "image/jpeg",
|
||||
Size: 20000,
|
||||
Expires: expires3,
|
||||
URL: "https://ntfy.sh/file/zakaDHFW.jpg",
|
||||
}
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, len(messages))
|
||||
|
||||
require.Equal(t, "flower for you", messages[0].Message)
|
||||
require.Equal(t, "flower.jpg", messages[0].Attachment.Name)
|
||||
require.Equal(t, "image/jpeg", messages[0].Attachment.Type)
|
||||
require.Equal(t, int64(5000), messages[0].Attachment.Size)
|
||||
require.Equal(t, expires1, messages[0].Attachment.Expires)
|
||||
require.Equal(t, "https://ntfy.sh/file/AbDeFgJhal.jpg", messages[0].Attachment.URL)
|
||||
require.Equal(t, "1.2.3.4", messages[0].Sender.String())
|
||||
|
||||
require.Equal(t, "sending you a car", messages[1].Message)
|
||||
require.Equal(t, "car.jpg", messages[1].Attachment.Name)
|
||||
require.Equal(t, "image/jpeg", messages[1].Attachment.Type)
|
||||
require.Equal(t, int64(10000), messages[1].Attachment.Size)
|
||||
require.Equal(t, expires2, messages[1].Attachment.Expires)
|
||||
require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL)
|
||||
require.Equal(t, "1.2.3.4", messages[1].Sender.String())
|
||||
|
||||
size, err := s.AttachmentBytesUsedBySender("1.2.3.4")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(10000), size)
|
||||
|
||||
size, err = s.AttachmentBytesUsedBySender("5.6.7.8")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(0), size) // Accounted to the user, not the IP!
|
||||
|
||||
size, err = s.AttachmentBytesUsedByUser("u_BAsbaAa")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(20000), size)
|
||||
}
|
||||
|
||||
func testCacheAttachmentsExpired(t *testing.T, s message.Store) {
|
||||
m := model.NewDefaultMessage("mytopic", "flower for you")
|
||||
m.ID = "m1"
|
||||
m.SequenceID = "m1"
|
||||
m.Expires = time.Now().Add(time.Hour).Unix()
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
|
||||
m = model.NewDefaultMessage("mytopic", "message with attachment")
|
||||
m.ID = "m2"
|
||||
m.SequenceID = "m2"
|
||||
m.Expires = time.Now().Add(2 * time.Hour).Unix()
|
||||
m.Attachment = &model.Attachment{
|
||||
Name: "car.jpg",
|
||||
Type: "image/jpeg",
|
||||
Size: 10000,
|
||||
Expires: time.Now().Add(2 * time.Hour).Unix(),
|
||||
URL: "https://ntfy.sh/file/aCaRURL.jpg",
|
||||
}
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
|
||||
m = model.NewDefaultMessage("mytopic", "message with external attachment")
|
||||
m.ID = "m3"
|
||||
m.SequenceID = "m3"
|
||||
m.Expires = time.Now().Add(2 * time.Hour).Unix()
|
||||
m.Attachment = &model.Attachment{
|
||||
Name: "car.jpg",
|
||||
Type: "image/jpeg",
|
||||
Expires: 0, // Unknown!
|
||||
URL: "https://somedomain.com/car.jpg",
|
||||
}
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
|
||||
m = model.NewDefaultMessage("mytopic2", "message with expired attachment")
|
||||
m.ID = "m4"
|
||||
m.SequenceID = "m4"
|
||||
m.Expires = time.Now().Add(2 * time.Hour).Unix()
|
||||
m.Attachment = &model.Attachment{
|
||||
Name: "expired-car.jpg",
|
||||
Type: "image/jpeg",
|
||||
Size: 20000,
|
||||
Expires: time.Now().Add(-1 * time.Hour).Unix(),
|
||||
URL: "https://ntfy.sh/file/aCaRURL.jpg",
|
||||
}
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
|
||||
ids, err := s.AttachmentsExpired()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(ids))
|
||||
require.Equal(t, "m4", ids[0])
|
||||
}
|
||||
|
||||
func testSender(t *testing.T, s message.Store) {
|
||||
m1 := model.NewDefaultMessage("mytopic", "mymessage")
|
||||
m1.Sender = netip.MustParseAddr("1.2.3.4")
|
||||
require.Nil(t, s.AddMessage(m1))
|
||||
|
||||
m2 := model.NewDefaultMessage("mytopic", "mymessage without sender")
|
||||
require.Nil(t, s.AddMessage(m2))
|
||||
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, len(messages))
|
||||
require.Equal(t, messages[0].Sender, netip.MustParseAddr("1.2.3.4"))
|
||||
require.Equal(t, messages[1].Sender, netip.Addr{})
|
||||
}
|
||||
|
||||
func testDeleteScheduledBySequenceID(t *testing.T, s message.Store) {
|
||||
// Create a scheduled (unpublished) message
|
||||
scheduledMsg := model.NewDefaultMessage("mytopic", "scheduled message")
|
||||
scheduledMsg.ID = "scheduled1"
|
||||
scheduledMsg.SequenceID = "seq123"
|
||||
scheduledMsg.Time = time.Now().Add(time.Hour).Unix() // Future time makes it scheduled
|
||||
require.Nil(t, s.AddMessage(scheduledMsg))
|
||||
|
||||
// Create a published message with different sequence ID
|
||||
publishedMsg := model.NewDefaultMessage("mytopic", "published message")
|
||||
publishedMsg.ID = "published1"
|
||||
publishedMsg.SequenceID = "seq456"
|
||||
publishedMsg.Time = time.Now().Add(-time.Hour).Unix() // Past time makes it published
|
||||
require.Nil(t, s.AddMessage(publishedMsg))
|
||||
|
||||
// Create a scheduled message in a different topic
|
||||
otherTopicMsg := model.NewDefaultMessage("othertopic", "other scheduled")
|
||||
otherTopicMsg.ID = "other1"
|
||||
otherTopicMsg.SequenceID = "seq123" // Same sequence ID as scheduledMsg
|
||||
otherTopicMsg.Time = time.Now().Add(time.Hour).Unix()
|
||||
require.Nil(t, s.AddMessage(otherTopicMsg))
|
||||
|
||||
// Verify all messages exist (including scheduled)
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, true)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, len(messages))
|
||||
|
||||
messages, err = s.Messages("othertopic", model.SinceAllMessages, true)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
|
||||
// Delete scheduled message by sequence ID and verify returned IDs
|
||||
deletedIDs, err := s.DeleteScheduledBySequenceID("mytopic", "seq123")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(deletedIDs))
|
||||
require.Equal(t, "scheduled1", deletedIDs[0])
|
||||
|
||||
// Verify scheduled message is deleted
|
||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, true)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "published message", messages[0].Message)
|
||||
|
||||
// Verify other topic's message still exists (topic-scoped deletion)
|
||||
messages, err = s.Messages("othertopic", model.SinceAllMessages, true)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "other scheduled", messages[0].Message)
|
||||
|
||||
// Deleting non-existent sequence ID should return empty list
|
||||
deletedIDs, err = s.DeleteScheduledBySequenceID("mytopic", "nonexistent")
|
||||
require.Nil(t, err)
|
||||
require.Empty(t, deletedIDs)
|
||||
|
||||
// Deleting published message should not affect it (only deletes unpublished)
|
||||
deletedIDs, err = s.DeleteScheduledBySequenceID("mytopic", "seq456")
|
||||
require.Nil(t, err)
|
||||
require.Empty(t, deletedIDs)
|
||||
|
||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, true)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "published message", messages[0].Message)
|
||||
}
|
||||
|
||||
func testMessageByID(t *testing.T, s message.Store) {
|
||||
// Add a message
|
||||
m := model.NewDefaultMessage("mytopic", "some message")
|
||||
m.Title = "some title"
|
||||
m.Priority = 4
|
||||
m.Tags = []string{"tag1", "tag2"}
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
|
||||
// Retrieve by ID
|
||||
retrieved, err := s.Message(m.ID)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, m.ID, retrieved.ID)
|
||||
require.Equal(t, "mytopic", retrieved.Topic)
|
||||
require.Equal(t, "some message", retrieved.Message)
|
||||
require.Equal(t, "some title", retrieved.Title)
|
||||
require.Equal(t, 4, retrieved.Priority)
|
||||
require.Equal(t, []string{"tag1", "tag2"}, retrieved.Tags)
|
||||
|
||||
// Non-existent ID returns ErrMessageNotFound
|
||||
_, err = s.Message("doesnotexist")
|
||||
require.Equal(t, model.ErrMessageNotFound, err)
|
||||
}
|
||||
|
||||
func testMarkPublished(t *testing.T, s message.Store) {
|
||||
// Add a scheduled message (future time → unpublished)
|
||||
m := model.NewDefaultMessage("mytopic", "scheduled message")
|
||||
m.Time = time.Now().Add(time.Hour).Unix()
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
|
||||
// Verify it does not appear in non-scheduled queries
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 0, len(messages))
|
||||
|
||||
// Verify it does appear in scheduled queries
|
||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, true)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
|
||||
// Mark as published
|
||||
require.Nil(t, s.MarkPublished(m))
|
||||
|
||||
// Now it should appear in non-scheduled queries too
|
||||
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "scheduled message", messages[0].Message)
|
||||
}
|
||||
|
||||
func testExpireMessages(t *testing.T, s message.Store) {
|
||||
// Add messages to two topics
|
||||
m1 := model.NewDefaultMessage("topic1", "message 1")
|
||||
m1.Expires = time.Now().Add(time.Hour).Unix()
|
||||
m2 := model.NewDefaultMessage("topic1", "message 2")
|
||||
m2.Expires = time.Now().Add(time.Hour).Unix()
|
||||
m3 := model.NewDefaultMessage("topic2", "message 3")
|
||||
m3.Expires = time.Now().Add(time.Hour).Unix()
|
||||
require.Nil(t, s.AddMessage(m1))
|
||||
require.Nil(t, s.AddMessage(m2))
|
||||
require.Nil(t, s.AddMessage(m3))
|
||||
|
||||
// Verify all messages exist
|
||||
messages, err := s.Messages("topic1", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, len(messages))
|
||||
messages, err = s.Messages("topic2", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
|
||||
// Expire topic1 messages
|
||||
require.Nil(t, s.ExpireMessages("topic1"))
|
||||
|
||||
// topic1 messages should now be expired (expires set to past)
|
||||
expiredIDs, err := s.MessagesExpired()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, len(expiredIDs))
|
||||
sort.Strings(expiredIDs)
|
||||
expectedIDs := []string{m1.ID, m2.ID}
|
||||
sort.Strings(expectedIDs)
|
||||
require.Equal(t, expectedIDs, expiredIDs)
|
||||
|
||||
// topic2 should be unaffected
|
||||
messages, err = s.Messages("topic2", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "message 3", messages[0].Message)
|
||||
}
|
||||
|
||||
func testMarkAttachmentsDeleted(t *testing.T, s message.Store) {
|
||||
// Add a message with an expired attachment (file needs cleanup)
|
||||
m1 := model.NewDefaultMessage("mytopic", "old file")
|
||||
m1.ID = "msg1"
|
||||
m1.SequenceID = "msg1"
|
||||
m1.Expires = time.Now().Add(time.Hour).Unix()
|
||||
m1.Attachment = &model.Attachment{
|
||||
Name: "old.pdf",
|
||||
Type: "application/pdf",
|
||||
Size: 50000,
|
||||
Expires: time.Now().Add(-time.Hour).Unix(), // Expired
|
||||
URL: "https://ntfy.sh/file/old.pdf",
|
||||
}
|
||||
require.Nil(t, s.AddMessage(m1))
|
||||
|
||||
// Add a message with another expired attachment
|
||||
m2 := model.NewDefaultMessage("mytopic", "another old file")
|
||||
m2.ID = "msg2"
|
||||
m2.SequenceID = "msg2"
|
||||
m2.Expires = time.Now().Add(time.Hour).Unix()
|
||||
m2.Attachment = &model.Attachment{
|
||||
Name: "another.pdf",
|
||||
Type: "application/pdf",
|
||||
Size: 30000,
|
||||
Expires: time.Now().Add(-time.Hour).Unix(), // Expired
|
||||
URL: "https://ntfy.sh/file/another.pdf",
|
||||
}
|
||||
require.Nil(t, s.AddMessage(m2))
|
||||
|
||||
// Both should show as expired attachments needing cleanup
|
||||
ids, err := s.AttachmentsExpired()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, len(ids))
|
||||
|
||||
// Mark msg1's attachment as deleted (file cleaned up)
|
||||
require.Nil(t, s.MarkAttachmentsDeleted("msg1"))
|
||||
|
||||
// Now only msg2 should show as needing cleanup
|
||||
ids, err = s.AttachmentsExpired()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(ids))
|
||||
require.Equal(t, "msg2", ids[0])
|
||||
|
||||
// Mark msg2 too
|
||||
require.Nil(t, s.MarkAttachmentsDeleted("msg2"))
|
||||
|
||||
// No more expired attachments to clean up
|
||||
ids, err = s.AttachmentsExpired()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 0, len(ids))
|
||||
|
||||
// Messages themselves still exist
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, len(messages))
|
||||
}
|
||||
|
||||
func testStats(t *testing.T, s message.Store) {
|
||||
// Initial stats should be zero
|
||||
messages, err := s.Stats()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(0), messages)
|
||||
|
||||
// Update stats
|
||||
require.Nil(t, s.UpdateStats(42))
|
||||
messages, err = s.Stats()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(42), messages)
|
||||
|
||||
// Update again (overwrites)
|
||||
require.Nil(t, s.UpdateStats(100))
|
||||
messages, err = s.Stats()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(100), messages)
|
||||
}
|
||||
|
||||
func testAddMessages(t *testing.T, s message.Store) {
|
||||
// Batch add multiple messages
|
||||
msgs := []*model.Message{
|
||||
model.NewDefaultMessage("mytopic", "batch 1"),
|
||||
model.NewDefaultMessage("mytopic", "batch 2"),
|
||||
model.NewDefaultMessage("othertopic", "batch 3"),
|
||||
}
|
||||
require.Nil(t, s.AddMessages(msgs))
|
||||
|
||||
// Verify all were inserted
|
||||
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 2, len(messages))
|
||||
|
||||
messages, err = s.Messages("othertopic", model.SinceAllMessages, false)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(messages))
|
||||
require.Equal(t, "batch 3", messages[0].Message)
|
||||
|
||||
// Empty batch should succeed
|
||||
require.Nil(t, s.AddMessages([]*model.Message{}))
|
||||
|
||||
// Batch with invalid event type should fail
|
||||
badMsgs := []*model.Message{
|
||||
model.NewKeepaliveMessage("mytopic"),
|
||||
}
|
||||
require.NotNil(t, s.AddMessages(badMsgs))
|
||||
}
|
||||
|
||||
func testMessagesDue(t *testing.T, s message.Store) {
|
||||
// Add a message scheduled in the past (i.e. it's due now)
|
||||
m1 := model.NewDefaultMessage("mytopic", "due message")
|
||||
m1.Time = time.Now().Add(-time.Second).Unix()
|
||||
// Set expires in the future so it doesn't get pruned
|
||||
m1.Expires = time.Now().Add(time.Hour).Unix()
|
||||
require.Nil(t, s.AddMessage(m1))
|
||||
|
||||
// Add a message scheduled in the future (not due)
|
||||
m2 := model.NewDefaultMessage("mytopic", "future message")
|
||||
m2.Time = time.Now().Add(time.Hour).Unix()
|
||||
require.Nil(t, s.AddMessage(m2))
|
||||
|
||||
// Mark m1 as published so it won't be "due"
|
||||
// (MessagesDue returns unpublished messages whose time <= now)
|
||||
// m1 is auto-published (time <= now), so it should not be due
|
||||
// m2 is unpublished (time in future), not due yet
|
||||
due, err := s.MessagesDue()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 0, len(due))
|
||||
|
||||
// Add a message that was explicitly scheduled in the past but time has "arrived"
|
||||
// We need to manipulate the database to create a truly "due" message:
|
||||
// a message with published=false and time <= now
|
||||
m3 := model.NewDefaultMessage("mytopic", "truly due message")
|
||||
m3.Time = time.Now().Add(2 * time.Second).Unix() // 2 seconds from now
|
||||
require.Nil(t, s.AddMessage(m3))
|
||||
|
||||
// Not due yet
|
||||
due, err = s.MessagesDue()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 0, len(due))
|
||||
|
||||
// Wait for it to become due
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
due, err = s.MessagesDue()
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, len(due))
|
||||
require.Equal(t, "truly due message", due[0].Message)
|
||||
}
|
||||
|
||||
func testMessageFieldRoundTrip(t *testing.T, s message.Store) {
|
||||
// Create a message with all fields populated
|
||||
m := model.NewDefaultMessage("mytopic", "hello world")
|
||||
m.SequenceID = "custom_seq_id"
|
||||
m.Title = "A Title"
|
||||
m.Priority = 4
|
||||
m.Tags = []string{"warning", "srv01"}
|
||||
m.Click = "https://example.com/click"
|
||||
m.Icon = "https://example.com/icon.png"
|
||||
m.Actions = []*model.Action{
|
||||
{
|
||||
ID: "action1",
|
||||
Action: "view",
|
||||
Label: "Open Site",
|
||||
URL: "https://example.com",
|
||||
Clear: true,
|
||||
},
|
||||
{
|
||||
ID: "action2",
|
||||
Action: "http",
|
||||
Label: "Call Webhook",
|
||||
URL: "https://example.com/hook",
|
||||
Method: "PUT",
|
||||
Headers: map[string]string{"X-Token": "secret"},
|
||||
Body: `{"key":"value"}`,
|
||||
},
|
||||
}
|
||||
m.ContentType = "text/markdown"
|
||||
m.Encoding = "base64"
|
||||
m.Sender = netip.MustParseAddr("9.8.7.6")
|
||||
m.User = "u_TestUser123"
|
||||
require.Nil(t, s.AddMessage(m))
|
||||
|
||||
// Retrieve and verify every field
|
||||
retrieved, err := s.Message(m.ID)
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Equal(t, m.ID, retrieved.ID)
|
||||
require.Equal(t, "custom_seq_id", retrieved.SequenceID)
|
||||
require.Equal(t, m.Time, retrieved.Time)
|
||||
require.Equal(t, m.Expires, retrieved.Expires)
|
||||
require.Equal(t, model.MessageEvent, retrieved.Event)
|
||||
require.Equal(t, "mytopic", retrieved.Topic)
|
||||
require.Equal(t, "hello world", retrieved.Message)
|
||||
require.Equal(t, "A Title", retrieved.Title)
|
||||
require.Equal(t, 4, retrieved.Priority)
|
||||
require.Equal(t, []string{"warning", "srv01"}, retrieved.Tags)
|
||||
require.Equal(t, "https://example.com/click", retrieved.Click)
|
||||
require.Equal(t, "https://example.com/icon.png", retrieved.Icon)
|
||||
require.Equal(t, "text/markdown", retrieved.ContentType)
|
||||
require.Equal(t, "base64", retrieved.Encoding)
|
||||
require.Equal(t, netip.MustParseAddr("9.8.7.6"), retrieved.Sender)
|
||||
require.Equal(t, "u_TestUser123", retrieved.User)
|
||||
|
||||
// Verify actions round-trip
|
||||
require.Equal(t, 2, len(retrieved.Actions))
|
||||
|
||||
require.Equal(t, "action1", retrieved.Actions[0].ID)
|
||||
require.Equal(t, "view", retrieved.Actions[0].Action)
|
||||
require.Equal(t, "Open Site", retrieved.Actions[0].Label)
|
||||
require.Equal(t, "https://example.com", retrieved.Actions[0].URL)
|
||||
require.Equal(t, true, retrieved.Actions[0].Clear)
|
||||
|
||||
require.Equal(t, "action2", retrieved.Actions[1].ID)
|
||||
require.Equal(t, "http", retrieved.Actions[1].Action)
|
||||
require.Equal(t, "Call Webhook", retrieved.Actions[1].Label)
|
||||
require.Equal(t, "https://example.com/hook", retrieved.Actions[1].URL)
|
||||
require.Equal(t, "PUT", retrieved.Actions[1].Method)
|
||||
require.Equal(t, "secret", retrieved.Actions[1].Headers["X-Token"])
|
||||
require.Equal(t, `{"key":"value"}`, retrieved.Actions[1].Body)
|
||||
}
|
||||
Reference in New Issue
Block a user