diff --git a/server/message_cache.go b/server/message_cache.go index 03cb4969..eabbcf63 100644 --- a/server/message_cache.go +++ b/server/message_cache.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "net/netip" + "net/url" "path/filepath" "strings" "time" @@ -64,6 +65,10 @@ const ( INSERT INTO stats (key, value) VALUES ('messages', 0); COMMIT; ` + builtinMessageCacheStartupQueries = ` + PRAGMA foreign_keys = ON; + PRAGMA busy_timeout = 50000; -- Wait up to 5 seconds for a lock to be released + ` insertMessageQuery = ` INSERT INTO messages (mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user, content_type, encoding, published) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) @@ -287,13 +292,18 @@ type messageCache struct { // newSqliteCache creates a SQLite file-backed cache func newSqliteCache(filename, startupQueries string, cacheDuration time.Duration, batchSize int, batchTimeout time.Duration, nop bool) (*messageCache, error) { + // Parse the filename + file, datasource, err := parseSqliteFile(filename) + if err != nil { + return nil, fmt.Errorf("cannot parse cache database filename %s: %w", filename, err) + } // Check the parent directory of the database file (makes for friendly error messages) parentDir := filepath.Dir(filename) if !util.FileExists(parentDir) { return nil, fmt.Errorf("cache database directory %s does not exist or is not accessible", parentDir) } // Open database - db, err := sql.Open("sqlite3", filename) + db, err := sql.Open("sqlite3", fmt.Sprintf("file:%s?_busy_timeout=50000", filename)) if err != nil { return nil, err } @@ -789,8 +799,21 @@ func (c *messageCache) Close() error { return c.db.Close() } +func parseSqliteFile(filename string) (file string, datasource string, err error) { + f, err := url.Parse(filename) + if err != nil { + return "", "", fmt.Errorf("cannot parse cache database filename %s: %w", filename, err) + } else if f.Scheme != "file" { + return f.Path, filename, nil + } + return filename, filename, nil +} + func setupMessagesDB(db *sql.DB, startupQueries string, cacheDuration time.Duration) error { // Run startup queries + if _, err := db.Exec(builtinMessageCacheStartupQueries); err != nil { + return err + } if startupQueries != "" { if _, err := db.Exec(startupQueries); err != nil { return err diff --git a/server/message_cache_test.go b/server/message_cache_test.go index 778f28fe..10d065db 100644 --- a/server/message_cache_test.go +++ b/server/message_cache_test.go @@ -3,8 +3,11 @@ package server import ( "database/sql" "fmt" + "github.com/stretchr/testify/assert" "net/netip" + "net/url" "path/filepath" + "sync" "testing" "time" @@ -90,6 +93,26 @@ func testCacheMessages(t *testing.T, c *messageCache) { require.Empty(t, messages) } +func TestSqliteCache_MessagesLock(t *testing.T) { + testCacheMessagesLock(t, newSqliteTestCache(t)) +} + +func TestMemCache_MessagesLock(t *testing.T) { + testCacheMessagesLock(t, newMemTestCache(t)) +} + +func testCacheMessagesLock(t *testing.T, c *messageCache) { + var wg sync.WaitGroup + for i := 0; i < 3000; i++ { + wg.Add(1) + go func() { + assert.Nil(t, c.AddMessage(newDefaultMessage("mytopic", "test message"))) + wg.Done() + }() + } + wg.Wait() +} + func TestSqliteCache_MessagesScheduled(t *testing.T) { testCacheMessagesScheduled(t, newSqliteTestCache(t)) } @@ -685,6 +708,35 @@ func checkSchemaVersion(t *testing.T, db *sql.DB) { require.Nil(t, rows.Close()) } +func TestURL(t *testing.T) { + u, _ := url.Parse("file:mem?_busy_timeout=1000&_journal_mode=WAL&_synchronous=normal&_temp_store=memory") + fmt.Printf("opaque: %+v\n", u.Opaque) + fmt.Printf("scheme: %+v\n", u.Scheme) + fmt.Printf("host: %+v\n", u.Host) + fmt.Printf("path: %+v\n", u.Path) + fmt.Printf("raw path: %+v\n", u.RawPath) + fmt.Printf("raw query: %+v\n", u.RawQuery) + fmt.Printf("query: %+v\n", u.Query()) + fmt.Println("----------") + u, _ = url.Parse("myfile.db") + fmt.Printf("opaque: %+v\n", u.Opaque) + fmt.Printf("scheme: %+v\n", u.Scheme) + fmt.Printf("host: %+v\n", u.Host) + fmt.Printf("path: %+v\n", u.Path) + fmt.Printf("raw path: %+v\n", u.RawPath) + fmt.Printf("raw query: %+v\n", u.RawQuery) + fmt.Printf("query: %+v\n", u.Query()) + fmt.Println("----------") + u, _ = url.Parse("htttps://abc.com/myfile.db") + fmt.Printf("opaque: %+v\n", u.Opaque) + fmt.Printf("scheme: %+v\n", u.Scheme) + fmt.Printf("host: %+v\n", u.Host) + fmt.Printf("path: %+v\n", u.Path) + fmt.Printf("raw path: %+v\n", u.RawPath) + fmt.Printf("raw query: %+v\n", u.RawQuery) + fmt.Printf("query: %+v\n", u.Query()) + +} func TestMemCache_NopCache(t *testing.T) { c, _ := newNopCache() require.Nil(t, c.AddMessage(newDefaultMessage("mytopic", "my message"))) diff --git a/server/webpush_store.go b/server/webpush_store.go index db0304be..cdba26a8 100644 --- a/server/webpush_store.go +++ b/server/webpush_store.go @@ -50,7 +50,7 @@ const ( ); COMMIT; ` - builtinStartupQueries = ` + builtinWebPushStartupQueries = ` PRAGMA foreign_keys = ON; ` @@ -134,7 +134,7 @@ func runWebPushStartupQueries(db *sql.DB, startupQueries string) error { if _, err := db.Exec(startupQueries); err != nil { return err } - if _, err := db.Exec(builtinStartupQueries); err != nil { + if _, err := db.Exec(builtinWebPushStartupQueries); err != nil { return err } return nil