Merge remote-tracking branch 'timofej673/main' into message-cache-lock

This commit is contained in:
binwiederhier
2025-08-08 16:06:27 -04:00

View File

@@ -9,6 +9,7 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"time" "time"
"sync"
_ "github.com/mattn/go-sqlite3" // SQLite driver _ "github.com/mattn/go-sqlite3" // SQLite driver
"heckel.io/ntfy/v2/log" "heckel.io/ntfy/v2/log"
@@ -106,7 +107,7 @@ const (
WHERE topic = ? AND published = 1 WHERE topic = ? AND published = 1
ORDER BY time DESC, id DESC ORDER BY time DESC, id DESC
LIMIT 1 LIMIT 1
` `
selectMessagesDueQuery = ` selectMessagesDueQuery = `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages FROM messages
@@ -282,6 +283,7 @@ var (
type messageCache struct { type messageCache struct {
db *sql.DB db *sql.DB
queue *util.BatchingQueue[*message] queue *util.BatchingQueue[*message]
mu sync.Mutex
nop bool nop bool
} }
@@ -347,6 +349,8 @@ func (c *messageCache) AddMessage(m *message) error {
// addMessages synchronously stores a match of messages. If the database is locked, the transaction waits until // addMessages synchronously stores a match of messages. If the database is locked, the transaction waits until
// SQLite's busy_timeout is exceeded before erroring out. // SQLite's busy_timeout is exceeded before erroring out.
func (c *messageCache) addMessages(ms []*message) error { func (c *messageCache) addMessages(ms []*message) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.nop { if c.nop {
return nil return nil
} }
@@ -528,6 +532,8 @@ func (c *messageCache) Message(id string) (*message, error) {
} }
func (c *messageCache) MarkPublished(m *message) error { func (c *messageCache) MarkPublished(m *message) error {
c.mu.Lock()
defer c.mu.Unlock()
_, err := c.db.Exec(updateMessagePublishedQuery, m.ID) _, err := c.db.Exec(updateMessagePublishedQuery, m.ID)
return err return err
} }
@@ -573,6 +579,8 @@ func (c *messageCache) Topics() (map[string]*topic, error) {
} }
func (c *messageCache) DeleteMessages(ids ...string) error { func (c *messageCache) DeleteMessages(ids ...string) error {
c.mu.Lock()
defer c.mu.Unlock()
tx, err := c.db.Begin() tx, err := c.db.Begin()
if err != nil { if err != nil {
return err return err
@@ -587,6 +595,8 @@ func (c *messageCache) DeleteMessages(ids ...string) error {
} }
func (c *messageCache) ExpireMessages(topics ...string) error { func (c *messageCache) ExpireMessages(topics ...string) error {
c.mu.Lock()
defer c.mu.Unlock()
tx, err := c.db.Begin() tx, err := c.db.Begin()
if err != nil { if err != nil {
return err return err
@@ -621,6 +631,8 @@ func (c *messageCache) AttachmentsExpired() ([]string, error) {
} }
func (c *messageCache) MarkAttachmentsDeleted(ids ...string) error { func (c *messageCache) MarkAttachmentsDeleted(ids ...string) error {
c.mu.Lock()
defer c.mu.Unlock()
tx, err := c.db.Begin() tx, err := c.db.Begin()
if err != nil { if err != nil {
return err return err
@@ -766,6 +778,8 @@ func readMessage(rows *sql.Rows) (*message, error) {
} }
func (c *messageCache) UpdateStats(messages int64) error { func (c *messageCache) UpdateStats(messages int64) error {
c.mu.Lock()
defer c.mu.Unlock()
_, err := c.db.Exec(updateStatsQuery, messages) _, err := c.db.Exec(updateStatsQuery, messages)
return err return err
} }