Compare commits
5 Commits
postgres-w
...
postgres-w
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9e4a48b058 | ||
|
|
939b3d1117 | ||
|
|
9cc9891f49 | ||
|
|
0d1f3444f2 | ||
|
|
2716ede6e1 |
16
.github/workflows/release.yaml
vendored
16
.github/workflows/release.yaml
vendored
@@ -6,6 +6,22 @@ on:
|
|||||||
jobs:
|
jobs:
|
||||||
release:
|
release:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
services:
|
||||||
|
postgres:
|
||||||
|
image: postgres:17
|
||||||
|
env:
|
||||||
|
POSTGRES_USER: ntfy
|
||||||
|
POSTGRES_PASSWORD: ntfy
|
||||||
|
POSTGRES_DB: ntfy_test
|
||||||
|
ports:
|
||||||
|
- 5432:5432
|
||||||
|
options: >-
|
||||||
|
--health-cmd "pg_isready -U ntfy"
|
||||||
|
--health-interval 10s
|
||||||
|
--health-timeout 5s
|
||||||
|
--health-retries 5
|
||||||
|
env:
|
||||||
|
NTFY_TEST_DATABASE_URL: "postgres://ntfy:ntfy@localhost:5432/ntfy_test?sslmode=disable"
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
|
|||||||
18
.github/workflows/test.yaml
vendored
18
.github/workflows/test.yaml
vendored
@@ -3,6 +3,22 @@ on: [ push, pull_request ]
|
|||||||
jobs:
|
jobs:
|
||||||
test:
|
test:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
services:
|
||||||
|
postgres:
|
||||||
|
image: postgres:17
|
||||||
|
env:
|
||||||
|
POSTGRES_USER: ntfy
|
||||||
|
POSTGRES_PASSWORD: ntfy
|
||||||
|
POSTGRES_DB: ntfy_test
|
||||||
|
ports:
|
||||||
|
- 5432:5432
|
||||||
|
options: >-
|
||||||
|
--health-cmd "pg_isready -U ntfy"
|
||||||
|
--health-interval 10s
|
||||||
|
--health-timeout 5s
|
||||||
|
--health-retries 5
|
||||||
|
env:
|
||||||
|
NTFY_TEST_DATABASE_URL: "postgres://ntfy:ntfy@localhost:5432/ntfy_test?sslmode=disable"
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
@@ -23,7 +39,7 @@ jobs:
|
|||||||
- name: Build web app (required for tests)
|
- name: Build web app (required for tests)
|
||||||
run: make web
|
run: make web
|
||||||
- name: Run tests, formatting, vetting and linting
|
- name: Run tests, formatting, vetting and linting
|
||||||
run: make check
|
run: make checkv
|
||||||
- name: Run coverage
|
- name: Run coverage
|
||||||
run: make coverage
|
run: make coverage
|
||||||
- name: Upload coverage to codecov.io
|
- name: Upload coverage to codecov.io
|
||||||
|
|||||||
2
Makefile
2
Makefile
@@ -265,6 +265,8 @@ cli-build-results:
|
|||||||
|
|
||||||
check: test web-fmt-check fmt-check vet web-lint lint staticcheck
|
check: test web-fmt-check fmt-check vet web-lint lint staticcheck
|
||||||
|
|
||||||
|
checkv: testv web-fmt-check fmt-check vet web-lint lint staticcheck
|
||||||
|
|
||||||
test: .PHONY
|
test: .PHONY
|
||||||
go test $(shell go list ./... | grep -vE 'ntfy/(test|examples|tools)')
|
go test $(shell go list ./... | grep -vE 'ntfy/(test|examples|tools)')
|
||||||
|
|
||||||
|
|||||||
628
message/store.go
Normal file
628
message/store.go
Normal file
@@ -0,0 +1,628 @@
|
|||||||
|
package message
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"heckel.io/ntfy/v2/log"
|
||||||
|
"heckel.io/ntfy/v2/model"
|
||||||
|
"heckel.io/ntfy/v2/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
tagMessageCache = "message_cache"
|
||||||
|
)
|
||||||
|
|
||||||
|
var errNoRows = errors.New("no rows found")
|
||||||
|
|
||||||
|
// Store is the interface for a message cache store
|
||||||
|
type Store interface {
|
||||||
|
AddMessage(m *model.Message) error
|
||||||
|
AddMessages(ms []*model.Message) error
|
||||||
|
DB() *sql.DB
|
||||||
|
Message(id string) (*model.Message, error)
|
||||||
|
MessageCounts() (map[string]int, error)
|
||||||
|
Messages(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error)
|
||||||
|
MessagesDue() ([]*model.Message, error)
|
||||||
|
MessagesExpired() ([]string, error)
|
||||||
|
MarkPublished(m *model.Message) error
|
||||||
|
UpdateMessageTime(messageID string, timestamp int64) error
|
||||||
|
Topics() ([]string, error)
|
||||||
|
DeleteMessages(ids ...string) error
|
||||||
|
DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, error)
|
||||||
|
ExpireMessages(topics ...string) error
|
||||||
|
AttachmentsExpired() ([]string, error)
|
||||||
|
MarkAttachmentsDeleted(ids ...string) error
|
||||||
|
AttachmentBytesUsedBySender(sender string) (int64, error)
|
||||||
|
AttachmentBytesUsedByUser(userID string) (int64, error)
|
||||||
|
UpdateStats(messages int64) error
|
||||||
|
Stats() (int64, error)
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// storeQueries holds the database-specific SQL queries
|
||||||
|
type storeQueries struct {
|
||||||
|
insertMessage string
|
||||||
|
deleteMessage string
|
||||||
|
selectScheduledMessageIDsBySeqID string
|
||||||
|
deleteScheduledBySequenceID string
|
||||||
|
updateMessagesForTopicExpiry string
|
||||||
|
selectRowIDFromMessageID string
|
||||||
|
selectMessagesByID string
|
||||||
|
selectMessagesSinceTime string
|
||||||
|
selectMessagesSinceTimeScheduled string
|
||||||
|
selectMessagesSinceID string
|
||||||
|
selectMessagesSinceIDScheduled string
|
||||||
|
selectMessagesLatest string
|
||||||
|
selectMessagesDue string
|
||||||
|
selectMessagesExpired string
|
||||||
|
updateMessagePublished string
|
||||||
|
selectMessagesCount string
|
||||||
|
selectMessageCountPerTopic string
|
||||||
|
selectTopics string
|
||||||
|
updateAttachmentDeleted string
|
||||||
|
selectAttachmentsExpired string
|
||||||
|
selectAttachmentsSizeBySender string
|
||||||
|
selectAttachmentsSizeByUserID string
|
||||||
|
selectStats string
|
||||||
|
updateStats string
|
||||||
|
updateMessageTime string
|
||||||
|
}
|
||||||
|
|
||||||
|
// commonStore implements store operations that are identical across database backends
|
||||||
|
type commonStore struct {
|
||||||
|
db *sql.DB
|
||||||
|
queue *util.BatchingQueue[*model.Message]
|
||||||
|
nop bool
|
||||||
|
mu sync.Mutex
|
||||||
|
queries storeQueries
|
||||||
|
}
|
||||||
|
|
||||||
|
func newCommonStore(db *sql.DB, queries storeQueries, batchSize int, batchTimeout time.Duration, nop bool) *commonStore {
|
||||||
|
var queue *util.BatchingQueue[*model.Message]
|
||||||
|
if batchSize > 0 || batchTimeout > 0 {
|
||||||
|
queue = util.NewBatchingQueue[*model.Message](batchSize, batchTimeout)
|
||||||
|
}
|
||||||
|
c := &commonStore{
|
||||||
|
db: db,
|
||||||
|
queue: queue,
|
||||||
|
nop: nop,
|
||||||
|
queries: queries,
|
||||||
|
}
|
||||||
|
go c.processMessageBatches()
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// DB returns the underlying database connection
|
||||||
|
func (c *commonStore) DB() *sql.DB {
|
||||||
|
return c.db
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddMessage stores a message to the message cache synchronously, or queues it to be stored at a later date asynchronously.
|
||||||
|
func (c *commonStore) AddMessage(m *model.Message) error {
|
||||||
|
if c.queue != nil {
|
||||||
|
c.queue.Enqueue(m)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.addMessages([]*model.Message{m})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddMessages synchronously stores a batch of messages to the message cache
|
||||||
|
func (c *commonStore) AddMessages(ms []*model.Message) error {
|
||||||
|
return c.addMessages(ms)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *commonStore) addMessages(ms []*model.Message) error {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
if c.nop {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if len(ms) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
start := time.Now()
|
||||||
|
tx, err := c.db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
stmt, err := tx.Prepare(c.queries.insertMessage)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer stmt.Close()
|
||||||
|
for _, m := range ms {
|
||||||
|
if m.Event != model.MessageEvent && m.Event != model.MessageDeleteEvent && m.Event != model.MessageClearEvent {
|
||||||
|
return model.ErrUnexpectedMessageType
|
||||||
|
}
|
||||||
|
published := m.Time <= time.Now().Unix()
|
||||||
|
tags := strings.Join(m.Tags, ",")
|
||||||
|
var attachmentName, attachmentType, attachmentURL string
|
||||||
|
var attachmentSize, attachmentExpires int64
|
||||||
|
var attachmentDeleted bool
|
||||||
|
if m.Attachment != nil {
|
||||||
|
attachmentName = m.Attachment.Name
|
||||||
|
attachmentType = m.Attachment.Type
|
||||||
|
attachmentSize = m.Attachment.Size
|
||||||
|
attachmentExpires = m.Attachment.Expires
|
||||||
|
attachmentURL = m.Attachment.URL
|
||||||
|
}
|
||||||
|
var actionsStr string
|
||||||
|
if len(m.Actions) > 0 {
|
||||||
|
actionsBytes, err := json.Marshal(m.Actions)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
actionsStr = string(actionsBytes)
|
||||||
|
}
|
||||||
|
var sender string
|
||||||
|
if m.Sender.IsValid() {
|
||||||
|
sender = m.Sender.String()
|
||||||
|
}
|
||||||
|
_, err := stmt.Exec(
|
||||||
|
m.ID,
|
||||||
|
m.SequenceID,
|
||||||
|
m.Time,
|
||||||
|
m.Event,
|
||||||
|
m.Expires,
|
||||||
|
m.Topic,
|
||||||
|
m.Message,
|
||||||
|
m.Title,
|
||||||
|
m.Priority,
|
||||||
|
tags,
|
||||||
|
m.Click,
|
||||||
|
m.Icon,
|
||||||
|
actionsStr,
|
||||||
|
attachmentName,
|
||||||
|
attachmentType,
|
||||||
|
attachmentSize,
|
||||||
|
attachmentExpires,
|
||||||
|
attachmentURL,
|
||||||
|
attachmentDeleted, // Always zero
|
||||||
|
sender,
|
||||||
|
m.User,
|
||||||
|
m.ContentType,
|
||||||
|
m.Encoding,
|
||||||
|
published,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
log.Tag(tagMessageCache).Err(err).Error("Writing %d message(s) failed (took %v)", len(ms), time.Since(start))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Tag(tagMessageCache).Debug("Wrote %d message(s) in %v", len(ms), time.Since(start))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *commonStore) Messages(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
|
||||||
|
if since.IsNone() {
|
||||||
|
return make([]*model.Message, 0), nil
|
||||||
|
} else if since.IsLatest() {
|
||||||
|
return c.messagesLatest(topic)
|
||||||
|
} else if since.IsID() {
|
||||||
|
return c.messagesSinceID(topic, since, scheduled)
|
||||||
|
}
|
||||||
|
return c.messagesSinceTime(topic, since, scheduled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *commonStore) messagesSinceTime(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
|
||||||
|
var rows *sql.Rows
|
||||||
|
var err error
|
||||||
|
if scheduled {
|
||||||
|
rows, err = c.db.Query(c.queries.selectMessagesSinceTimeScheduled, topic, since.Time().Unix())
|
||||||
|
} else {
|
||||||
|
rows, err = c.db.Query(c.queries.selectMessagesSinceTime, topic, since.Time().Unix())
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return readMessages(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *commonStore) messagesSinceID(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
|
||||||
|
idrows, err := c.db.Query(c.queries.selectRowIDFromMessageID, since.ID())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer idrows.Close()
|
||||||
|
if !idrows.Next() {
|
||||||
|
return c.messagesSinceTime(topic, model.SinceAllMessages, scheduled)
|
||||||
|
}
|
||||||
|
var rowID int64
|
||||||
|
if err := idrows.Scan(&rowID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
idrows.Close()
|
||||||
|
var rows *sql.Rows
|
||||||
|
if scheduled {
|
||||||
|
rows, err = c.db.Query(c.queries.selectMessagesSinceIDScheduled, topic, rowID)
|
||||||
|
} else {
|
||||||
|
rows, err = c.db.Query(c.queries.selectMessagesSinceID, topic, rowID)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return readMessages(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *commonStore) messagesLatest(topic string) ([]*model.Message, error) {
|
||||||
|
rows, err := c.db.Query(c.queries.selectMessagesLatest, topic)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return readMessages(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *commonStore) MessagesDue() ([]*model.Message, error) {
|
||||||
|
rows, err := c.db.Query(c.queries.selectMessagesDue, time.Now().Unix())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return readMessages(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MessagesExpired returns a list of IDs for messages that have expired (should be deleted)
|
||||||
|
func (c *commonStore) MessagesExpired() ([]string, error) {
|
||||||
|
rows, err := c.db.Query(c.queries.selectMessagesExpired, time.Now().Unix())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
ids := make([]string, 0)
|
||||||
|
for rows.Next() {
|
||||||
|
var id string
|
||||||
|
if err := rows.Scan(&id); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ids = append(ids, id)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return ids, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *commonStore) Message(id string) (*model.Message, error) {
|
||||||
|
rows, err := c.db.Query(c.queries.selectMessagesByID, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !rows.Next() {
|
||||||
|
return nil, model.ErrMessageNotFound
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
return readMessage(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateMessageTime updates the time column for a message by ID. This is only used for testing.
|
||||||
|
func (c *commonStore) UpdateMessageTime(messageID string, timestamp int64) error {
|
||||||
|
_, err := c.db.Exec(c.queries.updateMessageTime, timestamp, messageID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *commonStore) MarkPublished(m *model.Message) error {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
_, err := c.db.Exec(c.queries.updateMessagePublished, m.ID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *commonStore) MessageCounts() (map[string]int, error) {
|
||||||
|
rows, err := c.db.Query(c.queries.selectMessageCountPerTopic)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
var topic string
|
||||||
|
var count int
|
||||||
|
counts := make(map[string]int)
|
||||||
|
for rows.Next() {
|
||||||
|
if err := rows.Scan(&topic, &count); err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
counts[topic] = count
|
||||||
|
}
|
||||||
|
return counts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *commonStore) Topics() ([]string, error) {
|
||||||
|
rows, err := c.db.Query(c.queries.selectTopics)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
topics := make([]string, 0)
|
||||||
|
for rows.Next() {
|
||||||
|
var id string
|
||||||
|
if err := rows.Scan(&id); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
topics = append(topics, id)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return topics, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *commonStore) DeleteMessages(ids ...string) error {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
tx, err := c.db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
for _, id := range ids {
|
||||||
|
if _, err := tx.Exec(c.queries.deleteMessage, id); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteScheduledBySequenceID deletes unpublished (scheduled) messages with the given topic and sequence ID.
|
||||||
|
// It returns the message IDs of the deleted messages, which can be used to clean up attachment files.
|
||||||
|
func (c *commonStore) DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, error) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
tx, err := c.db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
rows, err := tx.Query(c.queries.selectScheduledMessageIDsBySeqID, topic, sequenceID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
ids := make([]string, 0)
|
||||||
|
for rows.Next() {
|
||||||
|
var id string
|
||||||
|
if err := rows.Scan(&id); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ids = append(ids, id)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
rows.Close()
|
||||||
|
if _, err := tx.Exec(c.queries.deleteScheduledBySequenceID, topic, sequenceID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return ids, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *commonStore) ExpireMessages(topics ...string) error {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
tx, err := c.db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
for _, t := range topics {
|
||||||
|
if _, err := tx.Exec(c.queries.updateMessagesForTopicExpiry, time.Now().Unix()-1, t); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *commonStore) AttachmentsExpired() ([]string, error) {
|
||||||
|
rows, err := c.db.Query(c.queries.selectAttachmentsExpired, time.Now().Unix())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
ids := make([]string, 0)
|
||||||
|
for rows.Next() {
|
||||||
|
var id string
|
||||||
|
if err := rows.Scan(&id); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ids = append(ids, id)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return ids, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *commonStore) MarkAttachmentsDeleted(ids ...string) error {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
tx, err := c.db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
for _, id := range ids {
|
||||||
|
if _, err := tx.Exec(c.queries.updateAttachmentDeleted, id); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *commonStore) AttachmentBytesUsedBySender(sender string) (int64, error) {
|
||||||
|
rows, err := c.db.Query(c.queries.selectAttachmentsSizeBySender, sender, time.Now().Unix())
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return c.readAttachmentBytesUsed(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *commonStore) AttachmentBytesUsedByUser(userID string) (int64, error) {
|
||||||
|
rows, err := c.db.Query(c.queries.selectAttachmentsSizeByUserID, userID, time.Now().Unix())
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return c.readAttachmentBytesUsed(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *commonStore) readAttachmentBytesUsed(rows *sql.Rows) (int64, error) {
|
||||||
|
defer rows.Close()
|
||||||
|
var size int64
|
||||||
|
if !rows.Next() {
|
||||||
|
return 0, errors.New("no rows found")
|
||||||
|
}
|
||||||
|
if err := rows.Scan(&size); err != nil {
|
||||||
|
return 0, err
|
||||||
|
} else if err := rows.Err(); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return size, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *commonStore) UpdateStats(messages int64) error {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
_, err := c.db.Exec(c.queries.updateStats, messages)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *commonStore) Stats() (messages int64, err error) {
|
||||||
|
rows, err := c.db.Query(c.queries.selectStats)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
if !rows.Next() {
|
||||||
|
return 0, errNoRows
|
||||||
|
}
|
||||||
|
if err := rows.Scan(&messages); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return messages, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *commonStore) Close() error {
|
||||||
|
return c.db.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *commonStore) processMessageBatches() {
|
||||||
|
if c.queue == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for messages := range c.queue.Dequeue() {
|
||||||
|
if err := c.addMessages(messages); err != nil {
|
||||||
|
log.Tag(tagMessageCache).Err(err).Error("Cannot write message batch")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func readMessages(rows *sql.Rows) ([]*model.Message, error) {
|
||||||
|
defer rows.Close()
|
||||||
|
messages := make([]*model.Message, 0)
|
||||||
|
for rows.Next() {
|
||||||
|
m, err := readMessage(rows)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
messages = append(messages, m)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return messages, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func readMessage(rows *sql.Rows) (*model.Message, error) {
|
||||||
|
var timestamp, expires, attachmentSize, attachmentExpires int64
|
||||||
|
var priority int
|
||||||
|
var id, sequenceID, event, topic, msg, title, tagsStr, click, icon, actionsStr, attachmentName, attachmentType, attachmentURL, sender, user, contentType, encoding string
|
||||||
|
err := rows.Scan(
|
||||||
|
&id,
|
||||||
|
&sequenceID,
|
||||||
|
×tamp,
|
||||||
|
&event,
|
||||||
|
&expires,
|
||||||
|
&topic,
|
||||||
|
&msg,
|
||||||
|
&title,
|
||||||
|
&priority,
|
||||||
|
&tagsStr,
|
||||||
|
&click,
|
||||||
|
&icon,
|
||||||
|
&actionsStr,
|
||||||
|
&attachmentName,
|
||||||
|
&attachmentType,
|
||||||
|
&attachmentSize,
|
||||||
|
&attachmentExpires,
|
||||||
|
&attachmentURL,
|
||||||
|
&sender,
|
||||||
|
&user,
|
||||||
|
&contentType,
|
||||||
|
&encoding,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var tags []string
|
||||||
|
if tagsStr != "" {
|
||||||
|
tags = strings.Split(tagsStr, ",")
|
||||||
|
}
|
||||||
|
var actions []*model.Action
|
||||||
|
if actionsStr != "" {
|
||||||
|
if err := json.Unmarshal([]byte(actionsStr), &actions); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
senderIP, err := netip.ParseAddr(sender)
|
||||||
|
if err != nil {
|
||||||
|
senderIP = netip.Addr{} // if no IP stored in database, return invalid address
|
||||||
|
}
|
||||||
|
var att *model.Attachment
|
||||||
|
if attachmentName != "" && attachmentURL != "" {
|
||||||
|
att = &model.Attachment{
|
||||||
|
Name: attachmentName,
|
||||||
|
Type: attachmentType,
|
||||||
|
Size: attachmentSize,
|
||||||
|
Expires: attachmentExpires,
|
||||||
|
URL: attachmentURL,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &model.Message{
|
||||||
|
ID: id,
|
||||||
|
SequenceID: sequenceID,
|
||||||
|
Time: timestamp,
|
||||||
|
Expires: expires,
|
||||||
|
Event: event,
|
||||||
|
Topic: topic,
|
||||||
|
Message: msg,
|
||||||
|
Title: title,
|
||||||
|
Priority: priority,
|
||||||
|
Tags: tags,
|
||||||
|
Click: click,
|
||||||
|
Icon: icon,
|
||||||
|
Actions: actions,
|
||||||
|
Attachment: att,
|
||||||
|
Sender: senderIP,
|
||||||
|
User: user,
|
||||||
|
ContentType: contentType,
|
||||||
|
Encoding: encoding,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure commonStore implements Store
|
||||||
|
var _ Store = (*commonStore)(nil)
|
||||||
|
|
||||||
|
// Needed by store.go but not part of Store interface; unused import guard
|
||||||
|
var _ = fmt.Sprintf
|
||||||
120
message/store_postgres.go
Normal file
120
message/store_postgres.go
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
package message
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
|
||||||
|
)
|
||||||
|
|
||||||
|
// PostgreSQL runtime query constants
|
||||||
|
const (
|
||||||
|
pgInsertMessageQuery = `
|
||||||
|
INSERT INTO message (mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user_id, content_type, encoding, published)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24)
|
||||||
|
`
|
||||||
|
pgDeleteMessageQuery = `DELETE FROM message WHERE mid = $1`
|
||||||
|
pgSelectScheduledMessageIDsBySeqIDQuery = `SELECT mid FROM message WHERE topic = $1 AND sequence_id = $2 AND published = FALSE`
|
||||||
|
pgDeleteScheduledBySequenceIDQuery = `DELETE FROM message WHERE topic = $1 AND sequence_id = $2 AND published = FALSE`
|
||||||
|
pgUpdateMessagesForTopicExpiryQuery = `UPDATE message SET expires = $1 WHERE topic = $2`
|
||||||
|
pgSelectRowIDFromMessageID = `SELECT id FROM message WHERE mid = $1`
|
||||||
|
pgSelectMessagesByIDQuery = `
|
||||||
|
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
||||||
|
FROM message
|
||||||
|
WHERE mid = $1
|
||||||
|
`
|
||||||
|
pgSelectMessagesSinceTimeQuery = `
|
||||||
|
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
||||||
|
FROM message
|
||||||
|
WHERE topic = $1 AND time >= $2 AND published = TRUE
|
||||||
|
ORDER BY time, id
|
||||||
|
`
|
||||||
|
pgSelectMessagesSinceTimeIncludeScheduledQuery = `
|
||||||
|
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
||||||
|
FROM message
|
||||||
|
WHERE topic = $1 AND time >= $2
|
||||||
|
ORDER BY time, id
|
||||||
|
`
|
||||||
|
pgSelectMessagesSinceIDQuery = `
|
||||||
|
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
||||||
|
FROM message
|
||||||
|
WHERE topic = $1 AND id > $2 AND published = TRUE
|
||||||
|
ORDER BY time, id
|
||||||
|
`
|
||||||
|
pgSelectMessagesSinceIDIncludeScheduledQuery = `
|
||||||
|
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
||||||
|
FROM message
|
||||||
|
WHERE topic = $1 AND (id > $2 OR published = FALSE)
|
||||||
|
ORDER BY time, id
|
||||||
|
`
|
||||||
|
pgSelectMessagesLatestQuery = `
|
||||||
|
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
||||||
|
FROM message
|
||||||
|
WHERE topic = $1 AND published = TRUE
|
||||||
|
ORDER BY time DESC, id DESC
|
||||||
|
LIMIT 1
|
||||||
|
`
|
||||||
|
pgSelectMessagesDueQuery = `
|
||||||
|
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user_id, content_type, encoding
|
||||||
|
FROM message
|
||||||
|
WHERE time <= $1 AND published = FALSE
|
||||||
|
ORDER BY time, id
|
||||||
|
`
|
||||||
|
pgSelectMessagesExpiredQuery = `SELECT mid FROM message WHERE expires <= $1 AND published = TRUE`
|
||||||
|
pgUpdateMessagePublishedQuery = `UPDATE message SET published = TRUE WHERE mid = $1`
|
||||||
|
pgSelectMessagesCountQuery = `SELECT COUNT(*) FROM message`
|
||||||
|
pgSelectMessageCountPerTopicQuery = `SELECT topic, COUNT(*) FROM message GROUP BY topic`
|
||||||
|
pgSelectTopicsQuery = `SELECT topic FROM message GROUP BY topic`
|
||||||
|
|
||||||
|
pgUpdateAttachmentDeleted = `UPDATE message SET attachment_deleted = TRUE WHERE mid = $1`
|
||||||
|
pgSelectAttachmentsExpiredQuery = `SELECT mid FROM message WHERE attachment_expires > 0 AND attachment_expires <= $1 AND attachment_deleted = FALSE`
|
||||||
|
pgSelectAttachmentsSizeBySenderQuery = `SELECT COALESCE(SUM(attachment_size), 0) FROM message WHERE user_id = '' AND sender = $1 AND attachment_expires >= $2`
|
||||||
|
pgSelectAttachmentsSizeByUserIDQuery = `SELECT COALESCE(SUM(attachment_size), 0) FROM message WHERE user_id = $1 AND attachment_expires >= $2`
|
||||||
|
|
||||||
|
pgSelectStatsQuery = `SELECT value FROM message_stats WHERE key = 'messages'`
|
||||||
|
pgUpdateStatsQuery = `UPDATE message_stats SET value = $1 WHERE key = 'messages'`
|
||||||
|
pgUpdateMessageTimesQuery = `UPDATE message SET time = $1 WHERE mid = $2`
|
||||||
|
)
|
||||||
|
|
||||||
|
var pgQueries = storeQueries{
|
||||||
|
insertMessage: pgInsertMessageQuery,
|
||||||
|
deleteMessage: pgDeleteMessageQuery,
|
||||||
|
selectScheduledMessageIDsBySeqID: pgSelectScheduledMessageIDsBySeqIDQuery,
|
||||||
|
deleteScheduledBySequenceID: pgDeleteScheduledBySequenceIDQuery,
|
||||||
|
updateMessagesForTopicExpiry: pgUpdateMessagesForTopicExpiryQuery,
|
||||||
|
selectRowIDFromMessageID: pgSelectRowIDFromMessageID,
|
||||||
|
selectMessagesByID: pgSelectMessagesByIDQuery,
|
||||||
|
selectMessagesSinceTime: pgSelectMessagesSinceTimeQuery,
|
||||||
|
selectMessagesSinceTimeScheduled: pgSelectMessagesSinceTimeIncludeScheduledQuery,
|
||||||
|
selectMessagesSinceID: pgSelectMessagesSinceIDQuery,
|
||||||
|
selectMessagesSinceIDScheduled: pgSelectMessagesSinceIDIncludeScheduledQuery,
|
||||||
|
selectMessagesLatest: pgSelectMessagesLatestQuery,
|
||||||
|
selectMessagesDue: pgSelectMessagesDueQuery,
|
||||||
|
selectMessagesExpired: pgSelectMessagesExpiredQuery,
|
||||||
|
updateMessagePublished: pgUpdateMessagePublishedQuery,
|
||||||
|
selectMessagesCount: pgSelectMessagesCountQuery,
|
||||||
|
selectMessageCountPerTopic: pgSelectMessageCountPerTopicQuery,
|
||||||
|
selectTopics: pgSelectTopicsQuery,
|
||||||
|
updateAttachmentDeleted: pgUpdateAttachmentDeleted,
|
||||||
|
selectAttachmentsExpired: pgSelectAttachmentsExpiredQuery,
|
||||||
|
selectAttachmentsSizeBySender: pgSelectAttachmentsSizeBySenderQuery,
|
||||||
|
selectAttachmentsSizeByUserID: pgSelectAttachmentsSizeByUserIDQuery,
|
||||||
|
selectStats: pgSelectStatsQuery,
|
||||||
|
updateStats: pgUpdateStatsQuery,
|
||||||
|
updateMessageTime: pgUpdateMessageTimesQuery,
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPostgresStore creates a new PostgreSQL-backed message cache store.
|
||||||
|
func NewPostgresStore(dsn string, batchSize int, batchTimeout time.Duration) (Store, error) {
|
||||||
|
db, err := sql.Open("pgx", dsn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := db.Ping(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := setupPostgresDB(db); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return newCommonStore(db, pgQueries, batchSize, batchTimeout, false), nil
|
||||||
|
}
|
||||||
90
message/store_postgres_schema.go
Normal file
90
message/store_postgres_schema.go
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
package message
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Initial PostgreSQL schema
|
||||||
|
const (
|
||||||
|
pgCreateTablesQuery = `
|
||||||
|
CREATE TABLE IF NOT EXISTS message (
|
||||||
|
id BIGSERIAL PRIMARY KEY,
|
||||||
|
mid TEXT NOT NULL,
|
||||||
|
sequence_id TEXT NOT NULL,
|
||||||
|
time BIGINT NOT NULL,
|
||||||
|
event TEXT NOT NULL,
|
||||||
|
expires BIGINT NOT NULL,
|
||||||
|
topic TEXT NOT NULL,
|
||||||
|
message TEXT NOT NULL,
|
||||||
|
title TEXT NOT NULL,
|
||||||
|
priority INT NOT NULL,
|
||||||
|
tags TEXT NOT NULL,
|
||||||
|
click TEXT NOT NULL,
|
||||||
|
icon TEXT NOT NULL,
|
||||||
|
actions TEXT NOT NULL,
|
||||||
|
attachment_name TEXT NOT NULL,
|
||||||
|
attachment_type TEXT NOT NULL,
|
||||||
|
attachment_size BIGINT NOT NULL,
|
||||||
|
attachment_expires BIGINT NOT NULL,
|
||||||
|
attachment_url TEXT NOT NULL,
|
||||||
|
attachment_deleted BOOLEAN NOT NULL DEFAULT FALSE,
|
||||||
|
sender TEXT NOT NULL,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
content_type TEXT NOT NULL,
|
||||||
|
encoding TEXT NOT NULL,
|
||||||
|
published BOOLEAN NOT NULL DEFAULT FALSE
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_message_mid ON message (mid);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_message_sequence_id ON message (sequence_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_message_time ON message (time);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_message_topic ON message (topic);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_message_expires ON message (expires);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_message_sender ON message (sender);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_message_user_id ON message (user_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_message_attachment_expires ON message (attachment_expires);
|
||||||
|
CREATE TABLE IF NOT EXISTS message_stats (
|
||||||
|
key TEXT PRIMARY KEY,
|
||||||
|
value BIGINT
|
||||||
|
);
|
||||||
|
INSERT INTO message_stats (key, value) VALUES ('messages', 0);
|
||||||
|
CREATE TABLE IF NOT EXISTS schema_version (
|
||||||
|
store TEXT PRIMARY KEY,
|
||||||
|
version INT NOT NULL
|
||||||
|
);
|
||||||
|
`
|
||||||
|
)
|
||||||
|
|
||||||
|
// PostgreSQL schema management queries
|
||||||
|
const (
|
||||||
|
pgCurrentSchemaVersion = 14
|
||||||
|
pgInsertSchemaVersion = `INSERT INTO schema_version (store, version) VALUES ('message', $1)`
|
||||||
|
pgSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE store = 'message'`
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupPostgresDB(db *sql.DB) error {
|
||||||
|
var schemaVersion int
|
||||||
|
err := db.QueryRow(pgSelectSchemaVersionQuery).Scan(&schemaVersion)
|
||||||
|
if err != nil {
|
||||||
|
return setupNewPostgresDB(db)
|
||||||
|
}
|
||||||
|
if schemaVersion > pgCurrentSchemaVersion {
|
||||||
|
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, pgCurrentSchemaVersion)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupNewPostgresDB(db *sql.DB) error {
|
||||||
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
if _, err := tx.Exec(pgCreateTablesQuery); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := tx.Exec(pgInsertSchemaVersion, pgCurrentSchemaVersion); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return tx.Commit()
|
||||||
|
}
|
||||||
120
message/store_postgres_test.go
Normal file
120
message/store_postgres_test.go
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
package message_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"heckel.io/ntfy/v2/message"
|
||||||
|
"heckel.io/ntfy/v2/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTestPostgresStore(t *testing.T) message.Store {
|
||||||
|
dsn := os.Getenv("NTFY_TEST_DATABASE_URL")
|
||||||
|
if dsn == "" {
|
||||||
|
t.Skip("NTFY_TEST_DATABASE_URL not set, skipping PostgreSQL tests")
|
||||||
|
}
|
||||||
|
// Create a unique schema for this test
|
||||||
|
schema := fmt.Sprintf("test_%s", util.RandomString(10))
|
||||||
|
setupDB, err := sql.Open("pgx", dsn)
|
||||||
|
require.Nil(t, err)
|
||||||
|
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Nil(t, setupDB.Close())
|
||||||
|
// Open store with search_path set to the new schema
|
||||||
|
u, err := url.Parse(dsn)
|
||||||
|
require.Nil(t, err)
|
||||||
|
q := u.Query()
|
||||||
|
q.Set("search_path", schema)
|
||||||
|
u.RawQuery = q.Encode()
|
||||||
|
store, err := message.NewPostgresStore(u.String(), 0, 0)
|
||||||
|
require.Nil(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
store.Close()
|
||||||
|
cleanDB, err := sql.Open("pgx", dsn)
|
||||||
|
if err == nil {
|
||||||
|
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
|
||||||
|
cleanDB.Close()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return store
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStore_Messages(t *testing.T) {
|
||||||
|
testCacheMessages(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStore_MessagesLock(t *testing.T) {
|
||||||
|
testCacheMessagesLock(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStore_MessagesScheduled(t *testing.T) {
|
||||||
|
testCacheMessagesScheduled(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStore_Topics(t *testing.T) {
|
||||||
|
testCacheTopics(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStore_MessagesTagsPrioAndTitle(t *testing.T) {
|
||||||
|
testCacheMessagesTagsPrioAndTitle(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStore_MessagesSinceID(t *testing.T) {
|
||||||
|
testCacheMessagesSinceID(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStore_Prune(t *testing.T) {
|
||||||
|
testCachePrune(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStore_Attachments(t *testing.T) {
|
||||||
|
testCacheAttachments(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStore_AttachmentsExpired(t *testing.T) {
|
||||||
|
testCacheAttachmentsExpired(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStore_Sender(t *testing.T) {
|
||||||
|
testSender(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStore_DeleteScheduledBySequenceID(t *testing.T) {
|
||||||
|
testDeleteScheduledBySequenceID(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStore_MessageByID(t *testing.T) {
|
||||||
|
testMessageByID(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStore_MarkPublished(t *testing.T) {
|
||||||
|
testMarkPublished(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStore_ExpireMessages(t *testing.T) {
|
||||||
|
testExpireMessages(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStore_MarkAttachmentsDeleted(t *testing.T) {
|
||||||
|
testMarkAttachmentsDeleted(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStore_Stats(t *testing.T) {
|
||||||
|
testStats(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStore_AddMessages(t *testing.T) {
|
||||||
|
testAddMessages(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStore_MessagesDue(t *testing.T) {
|
||||||
|
testMessagesDue(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresStore_MessageFieldRoundTrip(t *testing.T) {
|
||||||
|
testMessageFieldRoundTrip(t, newTestPostgresStore(t))
|
||||||
|
}
|
||||||
140
message/store_sqlite.go
Normal file
140
message/store_sqlite.go
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
package message
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"path/filepath"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||||
|
"heckel.io/ntfy/v2/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SQLite runtime query constants
|
||||||
|
const (
|
||||||
|
sqliteInsertMessageQuery = `
|
||||||
|
INSERT INTO messages (mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, attachment_deleted, sender, user, content_type, encoding, published)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
`
|
||||||
|
sqliteDeleteMessageQuery = `DELETE FROM messages WHERE mid = ?`
|
||||||
|
sqliteSelectScheduledMessageIDsBySeqIDQuery = `SELECT mid FROM messages WHERE topic = ? AND sequence_id = ? AND published = 0`
|
||||||
|
sqliteDeleteScheduledBySequenceIDQuery = `DELETE FROM messages WHERE topic = ? AND sequence_id = ? AND published = 0`
|
||||||
|
sqliteUpdateMessagesForTopicExpiryQuery = `UPDATE messages SET expires = ? WHERE topic = ?`
|
||||||
|
sqliteSelectRowIDFromMessageID = `SELECT id FROM messages WHERE mid = ?`
|
||||||
|
sqliteSelectMessagesByIDQuery = `
|
||||||
|
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||||
|
FROM messages
|
||||||
|
WHERE mid = ?
|
||||||
|
`
|
||||||
|
sqliteSelectMessagesSinceTimeQuery = `
|
||||||
|
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||||
|
FROM messages
|
||||||
|
WHERE topic = ? AND time >= ? AND published = 1
|
||||||
|
ORDER BY time, id
|
||||||
|
`
|
||||||
|
sqliteSelectMessagesSinceTimeIncludeScheduledQuery = `
|
||||||
|
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||||
|
FROM messages
|
||||||
|
WHERE topic = ? AND time >= ?
|
||||||
|
ORDER BY time, id
|
||||||
|
`
|
||||||
|
sqliteSelectMessagesSinceIDQuery = `
|
||||||
|
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||||
|
FROM messages
|
||||||
|
WHERE topic = ? AND id > ? AND published = 1
|
||||||
|
ORDER BY time, id
|
||||||
|
`
|
||||||
|
sqliteSelectMessagesSinceIDIncludeScheduledQuery = `
|
||||||
|
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||||
|
FROM messages
|
||||||
|
WHERE topic = ? AND (id > ? OR published = 0)
|
||||||
|
ORDER BY time, id
|
||||||
|
`
|
||||||
|
sqliteSelectMessagesLatestQuery = `
|
||||||
|
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||||
|
FROM messages
|
||||||
|
WHERE topic = ? AND published = 1
|
||||||
|
ORDER BY time DESC, id DESC
|
||||||
|
LIMIT 1
|
||||||
|
`
|
||||||
|
sqliteSelectMessagesDueQuery = `
|
||||||
|
SELECT mid, sequence_id, time, event, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
|
||||||
|
FROM messages
|
||||||
|
WHERE time <= ? AND published = 0
|
||||||
|
ORDER BY time, id
|
||||||
|
`
|
||||||
|
sqliteSelectMessagesExpiredQuery = `SELECT mid FROM messages WHERE expires <= ? AND published = 1`
|
||||||
|
sqliteUpdateMessagePublishedQuery = `UPDATE messages SET published = 1 WHERE mid = ?`
|
||||||
|
sqliteSelectMessagesCountQuery = `SELECT COUNT(*) FROM messages`
|
||||||
|
sqliteSelectMessageCountPerTopicQuery = `SELECT topic, COUNT(*) FROM messages GROUP BY topic`
|
||||||
|
sqliteSelectTopicsQuery = `SELECT topic FROM messages GROUP BY topic`
|
||||||
|
|
||||||
|
sqliteUpdateAttachmentDeleted = `UPDATE messages SET attachment_deleted = 1 WHERE mid = ?`
|
||||||
|
sqliteSelectAttachmentsExpiredQuery = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires <= ? AND attachment_deleted = 0`
|
||||||
|
sqliteSelectAttachmentsSizeBySenderQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = '' AND sender = ? AND attachment_expires >= ?`
|
||||||
|
sqliteSelectAttachmentsSizeByUserIDQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?`
|
||||||
|
|
||||||
|
sqliteSelectStatsQuery = `SELECT value FROM stats WHERE key = 'messages'`
|
||||||
|
sqliteUpdateStatsQuery = `UPDATE stats SET value = ? WHERE key = 'messages'`
|
||||||
|
sqliteUpdateMessageTimeQuery = `UPDATE messages SET time = ? WHERE mid = ?`
|
||||||
|
)
|
||||||
|
|
||||||
|
var sqliteQueries = storeQueries{
|
||||||
|
insertMessage: sqliteInsertMessageQuery,
|
||||||
|
deleteMessage: sqliteDeleteMessageQuery,
|
||||||
|
selectScheduledMessageIDsBySeqID: sqliteSelectScheduledMessageIDsBySeqIDQuery,
|
||||||
|
deleteScheduledBySequenceID: sqliteDeleteScheduledBySequenceIDQuery,
|
||||||
|
updateMessagesForTopicExpiry: sqliteUpdateMessagesForTopicExpiryQuery,
|
||||||
|
selectRowIDFromMessageID: sqliteSelectRowIDFromMessageID,
|
||||||
|
selectMessagesByID: sqliteSelectMessagesByIDQuery,
|
||||||
|
selectMessagesSinceTime: sqliteSelectMessagesSinceTimeQuery,
|
||||||
|
selectMessagesSinceTimeScheduled: sqliteSelectMessagesSinceTimeIncludeScheduledQuery,
|
||||||
|
selectMessagesSinceID: sqliteSelectMessagesSinceIDQuery,
|
||||||
|
selectMessagesSinceIDScheduled: sqliteSelectMessagesSinceIDIncludeScheduledQuery,
|
||||||
|
selectMessagesLatest: sqliteSelectMessagesLatestQuery,
|
||||||
|
selectMessagesDue: sqliteSelectMessagesDueQuery,
|
||||||
|
selectMessagesExpired: sqliteSelectMessagesExpiredQuery,
|
||||||
|
updateMessagePublished: sqliteUpdateMessagePublishedQuery,
|
||||||
|
selectMessagesCount: sqliteSelectMessagesCountQuery,
|
||||||
|
selectMessageCountPerTopic: sqliteSelectMessageCountPerTopicQuery,
|
||||||
|
selectTopics: sqliteSelectTopicsQuery,
|
||||||
|
updateAttachmentDeleted: sqliteUpdateAttachmentDeleted,
|
||||||
|
selectAttachmentsExpired: sqliteSelectAttachmentsExpiredQuery,
|
||||||
|
selectAttachmentsSizeBySender: sqliteSelectAttachmentsSizeBySenderQuery,
|
||||||
|
selectAttachmentsSizeByUserID: sqliteSelectAttachmentsSizeByUserIDQuery,
|
||||||
|
selectStats: sqliteSelectStatsQuery,
|
||||||
|
updateStats: sqliteUpdateStatsQuery,
|
||||||
|
updateMessageTime: sqliteUpdateMessageTimeQuery,
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSQLiteStore creates a SQLite file-backed cache
|
||||||
|
func NewSQLiteStore(filename, startupQueries string, cacheDuration time.Duration, batchSize int, batchTimeout time.Duration, nop bool) (Store, error) {
|
||||||
|
parentDir := filepath.Dir(filename)
|
||||||
|
if !util.FileExists(parentDir) {
|
||||||
|
return nil, fmt.Errorf("cache database directory %s does not exist or is not accessible", parentDir)
|
||||||
|
}
|
||||||
|
db, err := sql.Open("sqlite3", filename)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := setupSQLite(db, startupQueries, cacheDuration); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return newCommonStore(db, sqliteQueries, batchSize, batchTimeout, nop), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMemStore creates an in-memory cache
|
||||||
|
func NewMemStore() (Store, error) {
|
||||||
|
return NewSQLiteStore(createMemoryFilename(), "", 0, 0, 0, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewNopStore creates an in-memory cache that discards all messages;
|
||||||
|
// it is always empty and can be used if caching is entirely disabled
|
||||||
|
func NewNopStore() (Store, error) {
|
||||||
|
return NewSQLiteStore(createMemoryFilename(), "", 0, 0, 0, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// createMemoryFilename creates a unique memory filename to use for the SQLite backend.
|
||||||
|
func createMemoryFilename() string {
|
||||||
|
return fmt.Sprintf("file:%s?mode=memory&cache=shared", util.RandomString(10))
|
||||||
|
}
|
||||||
466
message/store_sqlite_schema.go
Normal file
466
message/store_sqlite_schema.go
Normal file
@@ -0,0 +1,466 @@
|
|||||||
|
package message
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"heckel.io/ntfy/v2/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Initial SQLite schema
|
||||||
|
const (
|
||||||
|
sqliteCreateMessagesTableQuery = `
|
||||||
|
BEGIN;
|
||||||
|
CREATE TABLE IF NOT EXISTS messages (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
mid TEXT NOT NULL,
|
||||||
|
sequence_id TEXT NOT NULL,
|
||||||
|
time INT NOT NULL,
|
||||||
|
event TEXT NOT NULL,
|
||||||
|
expires INT NOT NULL,
|
||||||
|
topic TEXT NOT NULL,
|
||||||
|
message TEXT NOT NULL,
|
||||||
|
title TEXT NOT NULL,
|
||||||
|
priority INT NOT NULL,
|
||||||
|
tags TEXT NOT NULL,
|
||||||
|
click TEXT NOT NULL,
|
||||||
|
icon TEXT NOT NULL,
|
||||||
|
actions TEXT NOT NULL,
|
||||||
|
attachment_name TEXT NOT NULL,
|
||||||
|
attachment_type TEXT NOT NULL,
|
||||||
|
attachment_size INT NOT NULL,
|
||||||
|
attachment_expires INT NOT NULL,
|
||||||
|
attachment_url TEXT NOT NULL,
|
||||||
|
attachment_deleted INT NOT NULL,
|
||||||
|
sender TEXT NOT NULL,
|
||||||
|
user TEXT NOT NULL,
|
||||||
|
content_type TEXT NOT NULL,
|
||||||
|
encoding TEXT NOT NULL,
|
||||||
|
published INT NOT NULL
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_mid ON messages (mid);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_sequence_id ON messages (sequence_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_user ON messages (user);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
|
||||||
|
CREATE TABLE IF NOT EXISTS stats (
|
||||||
|
key TEXT PRIMARY KEY,
|
||||||
|
value INT
|
||||||
|
);
|
||||||
|
INSERT INTO stats (key, value) VALUES ('messages', 0);
|
||||||
|
COMMIT;
|
||||||
|
`
|
||||||
|
)
|
||||||
|
|
||||||
|
// Schema version management for SQLite
|
||||||
|
const (
|
||||||
|
sqliteCurrentSchemaVersion = 14
|
||||||
|
sqliteCreateSchemaVersionTableQuery = `
|
||||||
|
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||||
|
id INT PRIMARY KEY,
|
||||||
|
version INT NOT NULL
|
||||||
|
);
|
||||||
|
`
|
||||||
|
sqliteInsertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
|
||||||
|
sqliteUpdateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
|
||||||
|
sqliteSelectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
|
||||||
|
)
|
||||||
|
|
||||||
|
// Schema migrations for SQLite
|
||||||
|
const (
|
||||||
|
// 0 -> 1
|
||||||
|
sqliteMigrate0To1AlterMessagesTableQuery = `
|
||||||
|
BEGIN;
|
||||||
|
ALTER TABLE messages ADD COLUMN title TEXT NOT NULL DEFAULT('');
|
||||||
|
ALTER TABLE messages ADD COLUMN priority INT NOT NULL DEFAULT(0);
|
||||||
|
ALTER TABLE messages ADD COLUMN tags TEXT NOT NULL DEFAULT('');
|
||||||
|
COMMIT;
|
||||||
|
`
|
||||||
|
|
||||||
|
// 1 -> 2
|
||||||
|
sqliteMigrate1To2AlterMessagesTableQuery = `
|
||||||
|
ALTER TABLE messages ADD COLUMN published INT NOT NULL DEFAULT(1);
|
||||||
|
`
|
||||||
|
|
||||||
|
// 2 -> 3
|
||||||
|
sqliteMigrate2To3AlterMessagesTableQuery = `
|
||||||
|
BEGIN;
|
||||||
|
ALTER TABLE messages ADD COLUMN click TEXT NOT NULL DEFAULT('');
|
||||||
|
ALTER TABLE messages ADD COLUMN attachment_name TEXT NOT NULL DEFAULT('');
|
||||||
|
ALTER TABLE messages ADD COLUMN attachment_type TEXT NOT NULL DEFAULT('');
|
||||||
|
ALTER TABLE messages ADD COLUMN attachment_size INT NOT NULL DEFAULT('0');
|
||||||
|
ALTER TABLE messages ADD COLUMN attachment_expires INT NOT NULL DEFAULT('0');
|
||||||
|
ALTER TABLE messages ADD COLUMN attachment_owner TEXT NOT NULL DEFAULT('');
|
||||||
|
ALTER TABLE messages ADD COLUMN attachment_url TEXT NOT NULL DEFAULT('');
|
||||||
|
COMMIT;
|
||||||
|
`
|
||||||
|
// 3 -> 4
|
||||||
|
sqliteMigrate3To4AlterMessagesTableQuery = `
|
||||||
|
ALTER TABLE messages ADD COLUMN encoding TEXT NOT NULL DEFAULT('');
|
||||||
|
`
|
||||||
|
|
||||||
|
// 4 -> 5
|
||||||
|
sqliteMigrate4To5AlterMessagesTableQuery = `
|
||||||
|
BEGIN;
|
||||||
|
CREATE TABLE IF NOT EXISTS messages_new (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
mid TEXT NOT NULL,
|
||||||
|
time INT NOT NULL,
|
||||||
|
topic TEXT NOT NULL,
|
||||||
|
message TEXT NOT NULL,
|
||||||
|
title TEXT NOT NULL,
|
||||||
|
priority INT NOT NULL,
|
||||||
|
tags TEXT NOT NULL,
|
||||||
|
click TEXT NOT NULL,
|
||||||
|
attachment_name TEXT NOT NULL,
|
||||||
|
attachment_type TEXT NOT NULL,
|
||||||
|
attachment_size INT NOT NULL,
|
||||||
|
attachment_expires INT NOT NULL,
|
||||||
|
attachment_url TEXT NOT NULL,
|
||||||
|
attachment_owner TEXT NOT NULL,
|
||||||
|
encoding TEXT NOT NULL,
|
||||||
|
published INT NOT NULL
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_mid ON messages_new (mid);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_topic ON messages_new (topic);
|
||||||
|
INSERT
|
||||||
|
INTO messages_new (
|
||||||
|
mid, time, topic, message, title, priority, tags, click, attachment_name, attachment_type,
|
||||||
|
attachment_size, attachment_expires, attachment_url, attachment_owner, encoding, published)
|
||||||
|
SELECT
|
||||||
|
id, time, topic, message, title, priority, tags, click, attachment_name, attachment_type,
|
||||||
|
attachment_size, attachment_expires, attachment_url, attachment_owner, encoding, published
|
||||||
|
FROM messages;
|
||||||
|
DROP TABLE messages;
|
||||||
|
ALTER TABLE messages_new RENAME TO messages;
|
||||||
|
COMMIT;
|
||||||
|
`
|
||||||
|
|
||||||
|
// 5 -> 6
|
||||||
|
sqliteMigrate5To6AlterMessagesTableQuery = `
|
||||||
|
ALTER TABLE messages ADD COLUMN actions TEXT NOT NULL DEFAULT('');
|
||||||
|
`
|
||||||
|
|
||||||
|
// 6 -> 7
|
||||||
|
sqliteMigrate6To7AlterMessagesTableQuery = `
|
||||||
|
ALTER TABLE messages RENAME COLUMN attachment_owner TO sender;
|
||||||
|
`
|
||||||
|
|
||||||
|
// 7 -> 8
|
||||||
|
sqliteMigrate7To8AlterMessagesTableQuery = `
|
||||||
|
ALTER TABLE messages ADD COLUMN icon TEXT NOT NULL DEFAULT('');
|
||||||
|
`
|
||||||
|
|
||||||
|
// 8 -> 9
|
||||||
|
sqliteMigrate8To9AlterMessagesTableQuery = `
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
|
||||||
|
`
|
||||||
|
|
||||||
|
// 9 -> 10
|
||||||
|
sqliteMigrate9To10AlterMessagesTableQuery = `
|
||||||
|
ALTER TABLE messages ADD COLUMN user TEXT NOT NULL DEFAULT('');
|
||||||
|
ALTER TABLE messages ADD COLUMN attachment_deleted INT NOT NULL DEFAULT('0');
|
||||||
|
ALTER TABLE messages ADD COLUMN expires INT NOT NULL DEFAULT('0');
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_expires ON messages (expires);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_user ON messages (user);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
|
||||||
|
`
|
||||||
|
sqliteMigrate9To10UpdateMessageExpiryQuery = `UPDATE messages SET expires = time + ?`
|
||||||
|
|
||||||
|
// 10 -> 11
|
||||||
|
sqliteMigrate10To11AlterMessagesTableQuery = `
|
||||||
|
CREATE TABLE IF NOT EXISTS stats (
|
||||||
|
key TEXT PRIMARY KEY,
|
||||||
|
value INT
|
||||||
|
);
|
||||||
|
INSERT INTO stats (key, value) VALUES ('messages', 0);
|
||||||
|
`
|
||||||
|
|
||||||
|
// 11 -> 12
|
||||||
|
sqliteMigrate11To12AlterMessagesTableQuery = `
|
||||||
|
ALTER TABLE messages ADD COLUMN content_type TEXT NOT NULL DEFAULT('');
|
||||||
|
`
|
||||||
|
|
||||||
|
// 12 -> 13
|
||||||
|
sqliteMigrate12To13AlterMessagesTableQuery = `
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
|
||||||
|
`
|
||||||
|
|
||||||
|
// 13 -> 14
|
||||||
|
sqliteMigrate13To14AlterMessagesTableQuery = `
|
||||||
|
ALTER TABLE messages ADD COLUMN sequence_id TEXT NOT NULL DEFAULT('');
|
||||||
|
ALTER TABLE messages ADD COLUMN event TEXT NOT NULL DEFAULT('message');
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_sequence_id ON messages (sequence_id);
|
||||||
|
`
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
sqliteMigrations = map[int]func(db *sql.DB, cacheDuration time.Duration) error{
|
||||||
|
0: sqliteMigrateFrom0,
|
||||||
|
1: sqliteMigrateFrom1,
|
||||||
|
2: sqliteMigrateFrom2,
|
||||||
|
3: sqliteMigrateFrom3,
|
||||||
|
4: sqliteMigrateFrom4,
|
||||||
|
5: sqliteMigrateFrom5,
|
||||||
|
6: sqliteMigrateFrom6,
|
||||||
|
7: sqliteMigrateFrom7,
|
||||||
|
8: sqliteMigrateFrom8,
|
||||||
|
9: sqliteMigrateFrom9,
|
||||||
|
10: sqliteMigrateFrom10,
|
||||||
|
11: sqliteMigrateFrom11,
|
||||||
|
12: sqliteMigrateFrom12,
|
||||||
|
13: sqliteMigrateFrom13,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupSQLite(db *sql.DB, startupQueries string, cacheDuration time.Duration) error {
|
||||||
|
if err := runSQLiteStartupQueries(db, startupQueries); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// If 'messages' table does not exist, this must be a new database
|
||||||
|
rowsMC, err := db.Query(sqliteSelectMessagesCountQuery)
|
||||||
|
if err != nil {
|
||||||
|
return setupNewSQLite(db)
|
||||||
|
}
|
||||||
|
rowsMC.Close()
|
||||||
|
// If 'messages' table exists, check 'schemaVersion' table
|
||||||
|
schemaVersion := 0
|
||||||
|
rowsSV, err := db.Query(sqliteSelectSchemaVersionQuery)
|
||||||
|
if err == nil {
|
||||||
|
defer rowsSV.Close()
|
||||||
|
if !rowsSV.Next() {
|
||||||
|
return fmt.Errorf("cannot determine schema version: cache file may be corrupt")
|
||||||
|
}
|
||||||
|
if err := rowsSV.Scan(&schemaVersion); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
rowsSV.Close()
|
||||||
|
}
|
||||||
|
// Do migrations
|
||||||
|
if schemaVersion == sqliteCurrentSchemaVersion {
|
||||||
|
return nil
|
||||||
|
} else if schemaVersion > sqliteCurrentSchemaVersion {
|
||||||
|
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, sqliteCurrentSchemaVersion)
|
||||||
|
}
|
||||||
|
for i := schemaVersion; i < sqliteCurrentSchemaVersion; i++ {
|
||||||
|
fn, ok := sqliteMigrations[i]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1)
|
||||||
|
} else if err := fn(db, cacheDuration); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupNewSQLite(db *sql.DB) error {
|
||||||
|
if _, err := db.Exec(sqliteCreateMessagesTableQuery); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := db.Exec(sqliteCreateSchemaVersionTableQuery); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := db.Exec(sqliteInsertSchemaVersion, sqliteCurrentSchemaVersion); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func runSQLiteStartupQueries(db *sql.DB, startupQueries string) error {
|
||||||
|
if startupQueries != "" {
|
||||||
|
if _, err := db.Exec(startupQueries); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func sqliteMigrateFrom0(db *sql.DB, _ time.Duration) error {
|
||||||
|
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 0 to 1")
|
||||||
|
if _, err := db.Exec(sqliteMigrate0To1AlterMessagesTableQuery); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := db.Exec(sqliteCreateSchemaVersionTableQuery); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := db.Exec(sqliteInsertSchemaVersion, 1); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func sqliteMigrateFrom1(db *sql.DB, _ time.Duration) error {
|
||||||
|
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 1 to 2")
|
||||||
|
if _, err := db.Exec(sqliteMigrate1To2AlterMessagesTableQuery); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := db.Exec(sqliteUpdateSchemaVersion, 2); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func sqliteMigrateFrom2(db *sql.DB, _ time.Duration) error {
|
||||||
|
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 2 to 3")
|
||||||
|
if _, err := db.Exec(sqliteMigrate2To3AlterMessagesTableQuery); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := db.Exec(sqliteUpdateSchemaVersion, 3); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func sqliteMigrateFrom3(db *sql.DB, _ time.Duration) error {
|
||||||
|
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 3 to 4")
|
||||||
|
if _, err := db.Exec(sqliteMigrate3To4AlterMessagesTableQuery); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := db.Exec(sqliteUpdateSchemaVersion, 4); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func sqliteMigrateFrom4(db *sql.DB, _ time.Duration) error {
|
||||||
|
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 4 to 5")
|
||||||
|
if _, err := db.Exec(sqliteMigrate4To5AlterMessagesTableQuery); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := db.Exec(sqliteUpdateSchemaVersion, 5); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func sqliteMigrateFrom5(db *sql.DB, _ time.Duration) error {
|
||||||
|
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 5 to 6")
|
||||||
|
if _, err := db.Exec(sqliteMigrate5To6AlterMessagesTableQuery); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := db.Exec(sqliteUpdateSchemaVersion, 6); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func sqliteMigrateFrom6(db *sql.DB, _ time.Duration) error {
|
||||||
|
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 6 to 7")
|
||||||
|
if _, err := db.Exec(sqliteMigrate6To7AlterMessagesTableQuery); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := db.Exec(sqliteUpdateSchemaVersion, 7); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func sqliteMigrateFrom7(db *sql.DB, _ time.Duration) error {
|
||||||
|
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 7 to 8")
|
||||||
|
if _, err := db.Exec(sqliteMigrate7To8AlterMessagesTableQuery); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := db.Exec(sqliteUpdateSchemaVersion, 8); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func sqliteMigrateFrom8(db *sql.DB, _ time.Duration) error {
|
||||||
|
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 8 to 9")
|
||||||
|
if _, err := db.Exec(sqliteMigrate8To9AlterMessagesTableQuery); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := db.Exec(sqliteUpdateSchemaVersion, 9); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func sqliteMigrateFrom9(db *sql.DB, cacheDuration time.Duration) error {
|
||||||
|
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 9 to 10")
|
||||||
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
if _, err := tx.Exec(sqliteMigrate9To10AlterMessagesTableQuery); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := tx.Exec(sqliteMigrate9To10UpdateMessageExpiryQuery, int64(cacheDuration.Seconds())); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 10); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
func sqliteMigrateFrom10(db *sql.DB, _ time.Duration) error {
|
||||||
|
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 10 to 11")
|
||||||
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
if _, err := tx.Exec(sqliteMigrate10To11AlterMessagesTableQuery); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 11); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
func sqliteMigrateFrom11(db *sql.DB, _ time.Duration) error {
|
||||||
|
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 11 to 12")
|
||||||
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
if _, err := tx.Exec(sqliteMigrate11To12AlterMessagesTableQuery); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 12); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
func sqliteMigrateFrom12(db *sql.DB, _ time.Duration) error {
|
||||||
|
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 12 to 13")
|
||||||
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
if _, err := tx.Exec(sqliteMigrate12To13AlterMessagesTableQuery); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 13); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
func sqliteMigrateFrom13(db *sql.DB, _ time.Duration) error {
|
||||||
|
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 13 to 14")
|
||||||
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
if _, err := tx.Exec(sqliteMigrate13To14AlterMessagesTableQuery); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := tx.Exec(sqliteUpdateSchemaVersion, 14); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return tx.Commit()
|
||||||
|
}
|
||||||
459
message/store_sqlite_test.go
Normal file
459
message/store_sqlite_test.go
Normal file
@@ -0,0 +1,459 @@
|
|||||||
|
package message_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"heckel.io/ntfy/v2/message"
|
||||||
|
"heckel.io/ntfy/v2/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSqliteStore_Messages(t *testing.T) {
|
||||||
|
testCacheMessages(t, newSqliteTestStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemStore_Messages(t *testing.T) {
|
||||||
|
testCacheMessages(t, newMemTestStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqliteStore_MessagesLock(t *testing.T) {
|
||||||
|
testCacheMessagesLock(t, newSqliteTestStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemStore_MessagesLock(t *testing.T) {
|
||||||
|
testCacheMessagesLock(t, newMemTestStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqliteStore_MessagesScheduled(t *testing.T) {
|
||||||
|
testCacheMessagesScheduled(t, newSqliteTestStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemStore_MessagesScheduled(t *testing.T) {
|
||||||
|
testCacheMessagesScheduled(t, newMemTestStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqliteStore_Topics(t *testing.T) {
|
||||||
|
testCacheTopics(t, newSqliteTestStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemStore_Topics(t *testing.T) {
|
||||||
|
testCacheTopics(t, newMemTestStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqliteStore_MessagesTagsPrioAndTitle(t *testing.T) {
|
||||||
|
testCacheMessagesTagsPrioAndTitle(t, newSqliteTestStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemStore_MessagesTagsPrioAndTitle(t *testing.T) {
|
||||||
|
testCacheMessagesTagsPrioAndTitle(t, newMemTestStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqliteStore_MessagesSinceID(t *testing.T) {
|
||||||
|
testCacheMessagesSinceID(t, newSqliteTestStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemStore_MessagesSinceID(t *testing.T) {
|
||||||
|
testCacheMessagesSinceID(t, newMemTestStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqliteStore_Prune(t *testing.T) {
|
||||||
|
testCachePrune(t, newSqliteTestStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemStore_Prune(t *testing.T) {
|
||||||
|
testCachePrune(t, newMemTestStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqliteStore_Attachments(t *testing.T) {
|
||||||
|
testCacheAttachments(t, newSqliteTestStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemStore_Attachments(t *testing.T) {
|
||||||
|
testCacheAttachments(t, newMemTestStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqliteStore_AttachmentsExpired(t *testing.T) {
|
||||||
|
testCacheAttachmentsExpired(t, newSqliteTestStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemStore_AttachmentsExpired(t *testing.T) {
|
||||||
|
testCacheAttachmentsExpired(t, newMemTestStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqliteStore_Sender(t *testing.T) {
|
||||||
|
testSender(t, newSqliteTestStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemStore_Sender(t *testing.T) {
|
||||||
|
testSender(t, newMemTestStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqliteStore_DeleteScheduledBySequenceID(t *testing.T) {
|
||||||
|
testDeleteScheduledBySequenceID(t, newSqliteTestStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemStore_DeleteScheduledBySequenceID(t *testing.T) {
|
||||||
|
testDeleteScheduledBySequenceID(t, newMemTestStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqliteStore_MessageByID(t *testing.T) {
|
||||||
|
testMessageByID(t, newSqliteTestStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemStore_MessageByID(t *testing.T) {
|
||||||
|
testMessageByID(t, newMemTestStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqliteStore_MarkPublished(t *testing.T) {
|
||||||
|
testMarkPublished(t, newSqliteTestStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemStore_MarkPublished(t *testing.T) {
|
||||||
|
testMarkPublished(t, newMemTestStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqliteStore_ExpireMessages(t *testing.T) {
|
||||||
|
testExpireMessages(t, newSqliteTestStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemStore_ExpireMessages(t *testing.T) {
|
||||||
|
testExpireMessages(t, newMemTestStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqliteStore_MarkAttachmentsDeleted(t *testing.T) {
|
||||||
|
testMarkAttachmentsDeleted(t, newSqliteTestStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemStore_MarkAttachmentsDeleted(t *testing.T) {
|
||||||
|
testMarkAttachmentsDeleted(t, newMemTestStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqliteStore_Stats(t *testing.T) {
|
||||||
|
testStats(t, newSqliteTestStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemStore_Stats(t *testing.T) {
|
||||||
|
testStats(t, newMemTestStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqliteStore_AddMessages(t *testing.T) {
|
||||||
|
testAddMessages(t, newSqliteTestStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemStore_AddMessages(t *testing.T) {
|
||||||
|
testAddMessages(t, newMemTestStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqliteStore_MessagesDue(t *testing.T) {
|
||||||
|
testMessagesDue(t, newSqliteTestStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemStore_MessagesDue(t *testing.T) {
|
||||||
|
testMessagesDue(t, newMemTestStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqliteStore_MessageFieldRoundTrip(t *testing.T) {
|
||||||
|
testMessageFieldRoundTrip(t, newSqliteTestStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemStore_MessageFieldRoundTrip(t *testing.T) {
|
||||||
|
testMessageFieldRoundTrip(t, newMemTestStore(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqliteStore_Migration_From0(t *testing.T) {
|
||||||
|
filename := newSqliteTestStoreFile(t)
|
||||||
|
db, err := sql.Open("sqlite3", filename)
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
// Create "version 0" schema
|
||||||
|
_, err = db.Exec(`
|
||||||
|
BEGIN;
|
||||||
|
CREATE TABLE IF NOT EXISTS messages (
|
||||||
|
id VARCHAR(20) PRIMARY KEY,
|
||||||
|
time INT NOT NULL,
|
||||||
|
topic VARCHAR(64) NOT NULL,
|
||||||
|
message VARCHAR(1024) NOT NULL
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
|
||||||
|
COMMIT;
|
||||||
|
`)
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
// Insert a bunch of messages
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
_, err = db.Exec(`INSERT INTO messages (id, time, topic, message) VALUES (?, ?, ?, ?)`,
|
||||||
|
fmt.Sprintf("abcd%d", i), time.Now().Unix(), "mytopic", fmt.Sprintf("some message %d", i))
|
||||||
|
require.Nil(t, err)
|
||||||
|
}
|
||||||
|
require.Nil(t, db.Close())
|
||||||
|
|
||||||
|
// Create store to trigger migration
|
||||||
|
s := newSqliteTestStoreFromFile(t, filename, "")
|
||||||
|
checkSqliteSchemaVersion(t, filename)
|
||||||
|
|
||||||
|
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 10, len(messages))
|
||||||
|
require.Equal(t, "some message 5", messages[5].Message)
|
||||||
|
require.Equal(t, "", messages[5].Title)
|
||||||
|
require.Nil(t, messages[5].Tags)
|
||||||
|
require.Equal(t, 0, messages[5].Priority)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqliteStore_Migration_From1(t *testing.T) {
|
||||||
|
filename := newSqliteTestStoreFile(t)
|
||||||
|
db, err := sql.Open("sqlite3", filename)
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
// Create "version 1" schema
|
||||||
|
_, err = db.Exec(`
|
||||||
|
CREATE TABLE IF NOT EXISTS messages (
|
||||||
|
id VARCHAR(20) PRIMARY KEY,
|
||||||
|
time INT NOT NULL,
|
||||||
|
topic VARCHAR(64) NOT NULL,
|
||||||
|
message VARCHAR(512) NOT NULL,
|
||||||
|
title VARCHAR(256) NOT NULL,
|
||||||
|
priority INT NOT NULL,
|
||||||
|
tags VARCHAR(256) NOT NULL
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
|
||||||
|
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||||
|
id INT PRIMARY KEY,
|
||||||
|
version INT NOT NULL
|
||||||
|
);
|
||||||
|
INSERT INTO schemaVersion (id, version) VALUES (1, 1);
|
||||||
|
`)
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
// Insert a bunch of messages
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
_, err = db.Exec(`INSERT INTO messages (id, time, topic, message, title, priority, tags) VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
||||||
|
fmt.Sprintf("abcd%d", i), time.Now().Unix(), "mytopic", fmt.Sprintf("some message %d", i), "", 0, "")
|
||||||
|
require.Nil(t, err)
|
||||||
|
}
|
||||||
|
require.Nil(t, db.Close())
|
||||||
|
|
||||||
|
// Create store to trigger migration
|
||||||
|
s := newSqliteTestStoreFromFile(t, filename, "")
|
||||||
|
checkSqliteSchemaVersion(t, filename)
|
||||||
|
|
||||||
|
// Add delayed message
|
||||||
|
delayedMessage := model.NewDefaultMessage("mytopic", "some delayed message")
|
||||||
|
delayedMessage.Time = time.Now().Add(time.Minute).Unix()
|
||||||
|
require.Nil(t, s.AddMessage(delayedMessage))
|
||||||
|
|
||||||
|
// 10, not 11!
|
||||||
|
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 10, len(messages))
|
||||||
|
|
||||||
|
// 11!
|
||||||
|
messages, err = s.Messages("mytopic", model.SinceAllMessages, true)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 11, len(messages))
|
||||||
|
|
||||||
|
// Check that index "idx_topic" exists
|
||||||
|
verifyDB, err := sql.Open("sqlite3", filename)
|
||||||
|
require.Nil(t, err)
|
||||||
|
defer verifyDB.Close()
|
||||||
|
rows, err := verifyDB.Query(`SELECT name FROM sqlite_master WHERE type='index' AND name='idx_topic'`)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.True(t, rows.Next())
|
||||||
|
var indexName string
|
||||||
|
require.Nil(t, rows.Scan(&indexName))
|
||||||
|
require.Equal(t, "idx_topic", indexName)
|
||||||
|
require.Nil(t, rows.Close())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqliteStore_Migration_From9(t *testing.T) {
|
||||||
|
// This primarily tests the awkward migration that introduces the "expires" column.
|
||||||
|
// The migration logic has to update the column, using the existing "cache-duration" value.
|
||||||
|
|
||||||
|
filename := newSqliteTestStoreFile(t)
|
||||||
|
db, err := sql.Open("sqlite3", filename)
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
// Create "version 9" schema
|
||||||
|
_, err = db.Exec(`
|
||||||
|
BEGIN;
|
||||||
|
CREATE TABLE IF NOT EXISTS messages (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
mid TEXT NOT NULL,
|
||||||
|
time INT NOT NULL,
|
||||||
|
topic TEXT NOT NULL,
|
||||||
|
message TEXT NOT NULL,
|
||||||
|
title TEXT NOT NULL,
|
||||||
|
priority INT NOT NULL,
|
||||||
|
tags TEXT NOT NULL,
|
||||||
|
click TEXT NOT NULL,
|
||||||
|
icon TEXT NOT NULL,
|
||||||
|
actions TEXT NOT NULL,
|
||||||
|
attachment_name TEXT NOT NULL,
|
||||||
|
attachment_type TEXT NOT NULL,
|
||||||
|
attachment_size INT NOT NULL,
|
||||||
|
attachment_expires INT NOT NULL,
|
||||||
|
attachment_url TEXT NOT NULL,
|
||||||
|
sender TEXT NOT NULL,
|
||||||
|
encoding TEXT NOT NULL,
|
||||||
|
published INT NOT NULL
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_mid ON messages (mid);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
|
||||||
|
CREATE TABLE IF NOT EXISTS schemaVersion (
|
||||||
|
id INT PRIMARY KEY,
|
||||||
|
version INT NOT NULL
|
||||||
|
);
|
||||||
|
INSERT INTO schemaVersion (id, version) VALUES (1, 9);
|
||||||
|
COMMIT;
|
||||||
|
`)
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
// Insert a bunch of messages
|
||||||
|
insertQuery := `
|
||||||
|
INSERT INTO messages (mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding, published)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
`
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
_, err = db.Exec(
|
||||||
|
insertQuery,
|
||||||
|
fmt.Sprintf("abcd%d", i),
|
||||||
|
time.Now().Unix(),
|
||||||
|
"mytopic",
|
||||||
|
fmt.Sprintf("some message %d", i),
|
||||||
|
"", // title
|
||||||
|
0, // priority
|
||||||
|
"", // tags
|
||||||
|
"", // click
|
||||||
|
"", // icon
|
||||||
|
"", // actions
|
||||||
|
"", // attachment_name
|
||||||
|
"", // attachment_type
|
||||||
|
0, // attachment_size
|
||||||
|
0, // attachment_expires
|
||||||
|
"", // attachment_url
|
||||||
|
"9.9.9.9", // sender
|
||||||
|
"", // encoding
|
||||||
|
1, // published
|
||||||
|
)
|
||||||
|
require.Nil(t, err)
|
||||||
|
}
|
||||||
|
require.Nil(t, db.Close())
|
||||||
|
|
||||||
|
// Create store to trigger migration
|
||||||
|
cacheDuration := 17 * time.Hour
|
||||||
|
s, err := message.NewSQLiteStore(filename, "", cacheDuration, 0, 0, false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
t.Cleanup(func() { s.Close() })
|
||||||
|
checkSqliteSchemaVersion(t, filename)
|
||||||
|
|
||||||
|
// Check version
|
||||||
|
verifyDB, err := sql.Open("sqlite3", filename)
|
||||||
|
require.Nil(t, err)
|
||||||
|
defer verifyDB.Close()
|
||||||
|
rows, err := verifyDB.Query(`SELECT version FROM schemaVersion WHERE id = 1`)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.True(t, rows.Next())
|
||||||
|
var version int
|
||||||
|
require.Nil(t, rows.Scan(&version))
|
||||||
|
require.Equal(t, 14, version)
|
||||||
|
require.Nil(t, rows.Close())
|
||||||
|
|
||||||
|
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 10, len(messages))
|
||||||
|
for _, m := range messages {
|
||||||
|
require.True(t, m.Expires > time.Now().Add(cacheDuration-5*time.Second).Unix())
|
||||||
|
require.True(t, m.Expires < time.Now().Add(cacheDuration+5*time.Second).Unix())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqliteStore_StartupQueries_WAL(t *testing.T) {
|
||||||
|
filename := newSqliteTestStoreFile(t)
|
||||||
|
startupQueries := `pragma journal_mode = WAL;
|
||||||
|
pragma synchronous = normal;
|
||||||
|
pragma temp_store = memory;`
|
||||||
|
s, err := message.NewSQLiteStore(filename, startupQueries, time.Hour, 0, 0, false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
t.Cleanup(func() { s.Close() })
|
||||||
|
require.Nil(t, s.AddMessage(model.NewDefaultMessage("mytopic", "some message")))
|
||||||
|
require.FileExists(t, filename)
|
||||||
|
require.FileExists(t, filename+"-wal")
|
||||||
|
require.FileExists(t, filename+"-shm")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqliteStore_StartupQueries_None(t *testing.T) {
|
||||||
|
filename := newSqliteTestStoreFile(t)
|
||||||
|
s, err := message.NewSQLiteStore(filename, "", time.Hour, 0, 0, false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
t.Cleanup(func() { s.Close() })
|
||||||
|
require.Nil(t, s.AddMessage(model.NewDefaultMessage("mytopic", "some message")))
|
||||||
|
require.FileExists(t, filename)
|
||||||
|
require.NoFileExists(t, filename+"-wal")
|
||||||
|
require.NoFileExists(t, filename+"-shm")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqliteStore_StartupQueries_Fail(t *testing.T) {
|
||||||
|
filename := newSqliteTestStoreFile(t)
|
||||||
|
_, err := message.NewSQLiteStore(filename, `xx error`, time.Hour, 0, 0, false)
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNopStore(t *testing.T) {
|
||||||
|
s, err := message.NewNopStore()
|
||||||
|
require.Nil(t, err)
|
||||||
|
t.Cleanup(func() { s.Close() })
|
||||||
|
require.Nil(t, s.AddMessage(model.NewDefaultMessage("mytopic", "my message")))
|
||||||
|
|
||||||
|
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Empty(t, messages)
|
||||||
|
|
||||||
|
topics, err := s.Topics()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Empty(t, topics)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSqliteTestStore(t *testing.T) message.Store {
|
||||||
|
filename := filepath.Join(t.TempDir(), "cache.db")
|
||||||
|
s, err := message.NewSQLiteStore(filename, "", time.Hour, 0, 0, false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
t.Cleanup(func() { s.Close() })
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSqliteTestStoreFile(t *testing.T) string {
|
||||||
|
return filepath.Join(t.TempDir(), "cache.db")
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSqliteTestStoreFromFile(t *testing.T, filename, startupQueries string) message.Store {
|
||||||
|
s, err := message.NewSQLiteStore(filename, startupQueries, time.Hour, 0, 0, false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
t.Cleanup(func() { s.Close() })
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMemTestStore(t *testing.T) message.Store {
|
||||||
|
s, err := message.NewMemStore()
|
||||||
|
require.Nil(t, err)
|
||||||
|
t.Cleanup(func() { s.Close() })
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkSqliteSchemaVersion(t *testing.T, filename string) {
|
||||||
|
db, err := sql.Open("sqlite3", filename)
|
||||||
|
require.Nil(t, err)
|
||||||
|
defer db.Close()
|
||||||
|
rows, err := db.Query(`SELECT version FROM schemaVersion`)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.True(t, rows.Next())
|
||||||
|
var schemaVersion int
|
||||||
|
require.Nil(t, rows.Scan(&schemaVersion))
|
||||||
|
require.Equal(t, 14, schemaVersion)
|
||||||
|
require.Nil(t, rows.Close())
|
||||||
|
}
|
||||||
767
message/store_test.go
Normal file
767
message/store_test.go
Normal file
@@ -0,0 +1,767 @@
|
|||||||
|
package message_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"sort"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"heckel.io/ntfy/v2/message"
|
||||||
|
"heckel.io/ntfy/v2/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func testCacheMessages(t *testing.T, s message.Store) {
|
||||||
|
m1 := model.NewDefaultMessage("mytopic", "my message")
|
||||||
|
m1.Time = 1
|
||||||
|
|
||||||
|
m2 := model.NewDefaultMessage("mytopic", "my other message")
|
||||||
|
m2.Time = 2
|
||||||
|
|
||||||
|
require.Nil(t, s.AddMessage(m1))
|
||||||
|
require.Nil(t, s.AddMessage(model.NewDefaultMessage("example", "my example message")))
|
||||||
|
require.Nil(t, s.AddMessage(m2))
|
||||||
|
|
||||||
|
// Adding invalid
|
||||||
|
require.Equal(t, model.ErrUnexpectedMessageType, s.AddMessage(model.NewKeepaliveMessage("mytopic"))) // These should not be added!
|
||||||
|
require.Equal(t, model.ErrUnexpectedMessageType, s.AddMessage(model.NewOpenMessage("example"))) // These should not be added!
|
||||||
|
|
||||||
|
// mytopic: count
|
||||||
|
counts, err := s.MessageCounts()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 2, counts["mytopic"])
|
||||||
|
|
||||||
|
// mytopic: since all
|
||||||
|
messages, _ := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||||
|
require.Equal(t, 2, len(messages))
|
||||||
|
require.Equal(t, "my message", messages[0].Message)
|
||||||
|
require.Equal(t, "mytopic", messages[0].Topic)
|
||||||
|
require.Equal(t, model.MessageEvent, messages[0].Event)
|
||||||
|
require.Equal(t, "", messages[0].Title)
|
||||||
|
require.Equal(t, 0, messages[0].Priority)
|
||||||
|
require.Nil(t, messages[0].Tags)
|
||||||
|
require.Equal(t, "my other message", messages[1].Message)
|
||||||
|
|
||||||
|
// mytopic: since none
|
||||||
|
messages, _ = s.Messages("mytopic", model.SinceNoMessages, false)
|
||||||
|
require.Empty(t, messages)
|
||||||
|
|
||||||
|
// mytopic: since m1 (by ID)
|
||||||
|
messages, _ = s.Messages("mytopic", model.NewSinceID(m1.ID), false)
|
||||||
|
require.Equal(t, 1, len(messages))
|
||||||
|
require.Equal(t, m2.ID, messages[0].ID)
|
||||||
|
require.Equal(t, "my other message", messages[0].Message)
|
||||||
|
require.Equal(t, "mytopic", messages[0].Topic)
|
||||||
|
|
||||||
|
// mytopic: since 2
|
||||||
|
messages, _ = s.Messages("mytopic", model.NewSinceTime(2), false)
|
||||||
|
require.Equal(t, 1, len(messages))
|
||||||
|
require.Equal(t, "my other message", messages[0].Message)
|
||||||
|
|
||||||
|
// mytopic: latest
|
||||||
|
messages, _ = s.Messages("mytopic", model.SinceLatestMessage, false)
|
||||||
|
require.Equal(t, 1, len(messages))
|
||||||
|
require.Equal(t, "my other message", messages[0].Message)
|
||||||
|
|
||||||
|
// example: count
|
||||||
|
counts, err = s.MessageCounts()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 1, counts["example"])
|
||||||
|
|
||||||
|
// example: since all
|
||||||
|
messages, _ = s.Messages("example", model.SinceAllMessages, false)
|
||||||
|
require.Equal(t, "my example message", messages[0].Message)
|
||||||
|
|
||||||
|
// non-existing: count
|
||||||
|
counts, err = s.MessageCounts()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 0, counts["doesnotexist"])
|
||||||
|
|
||||||
|
// non-existing: since all
|
||||||
|
messages, _ = s.Messages("doesnotexist", model.SinceAllMessages, false)
|
||||||
|
require.Empty(t, messages)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testCacheMessagesLock(t *testing.T, s message.Store) {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < 5000; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
assert.Nil(t, s.AddMessage(model.NewDefaultMessage("mytopic", "test message")))
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func testCacheMessagesScheduled(t *testing.T, s message.Store) {
|
||||||
|
m1 := model.NewDefaultMessage("mytopic", "message 1")
|
||||||
|
m2 := model.NewDefaultMessage("mytopic", "message 2")
|
||||||
|
m2.Time = time.Now().Add(time.Hour).Unix()
|
||||||
|
m3 := model.NewDefaultMessage("mytopic", "message 3")
|
||||||
|
m3.Time = time.Now().Add(time.Minute).Unix() // earlier than m2!
|
||||||
|
m4 := model.NewDefaultMessage("mytopic2", "message 4")
|
||||||
|
m4.Time = time.Now().Add(time.Minute).Unix()
|
||||||
|
require.Nil(t, s.AddMessage(m1))
|
||||||
|
require.Nil(t, s.AddMessage(m2))
|
||||||
|
require.Nil(t, s.AddMessage(m3))
|
||||||
|
|
||||||
|
messages, _ := s.Messages("mytopic", model.SinceAllMessages, false) // exclude scheduled
|
||||||
|
require.Equal(t, 1, len(messages))
|
||||||
|
require.Equal(t, "message 1", messages[0].Message)
|
||||||
|
|
||||||
|
messages, _ = s.Messages("mytopic", model.SinceAllMessages, true) // include scheduled
|
||||||
|
require.Equal(t, 3, len(messages))
|
||||||
|
require.Equal(t, "message 1", messages[0].Message)
|
||||||
|
require.Equal(t, "message 3", messages[1].Message) // Order!
|
||||||
|
require.Equal(t, "message 2", messages[2].Message)
|
||||||
|
|
||||||
|
messages, _ = s.MessagesDue()
|
||||||
|
require.Empty(t, messages)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testCacheTopics(t *testing.T, s message.Store) {
|
||||||
|
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic1", "my example message")))
|
||||||
|
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 1")))
|
||||||
|
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 2")))
|
||||||
|
require.Nil(t, s.AddMessage(model.NewDefaultMessage("topic2", "message 3")))
|
||||||
|
|
||||||
|
topics, err := s.Topics()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
require.Equal(t, 2, len(topics))
|
||||||
|
require.Contains(t, topics, "topic1")
|
||||||
|
require.Contains(t, topics, "topic2")
|
||||||
|
}
|
||||||
|
|
||||||
|
func testCacheMessagesTagsPrioAndTitle(t *testing.T, s message.Store) {
|
||||||
|
m := model.NewDefaultMessage("mytopic", "some message")
|
||||||
|
m.Tags = []string{"tag1", "tag2"}
|
||||||
|
m.Priority = 5
|
||||||
|
m.Title = "some title"
|
||||||
|
require.Nil(t, s.AddMessage(m))
|
||||||
|
|
||||||
|
messages, _ := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||||
|
require.Equal(t, []string{"tag1", "tag2"}, messages[0].Tags)
|
||||||
|
require.Equal(t, 5, messages[0].Priority)
|
||||||
|
require.Equal(t, "some title", messages[0].Title)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testCacheMessagesSinceID(t *testing.T, s message.Store) {
|
||||||
|
m1 := model.NewDefaultMessage("mytopic", "message 1")
|
||||||
|
m1.Time = 100
|
||||||
|
m2 := model.NewDefaultMessage("mytopic", "message 2")
|
||||||
|
m2.Time = 200
|
||||||
|
m3 := model.NewDefaultMessage("mytopic", "message 3")
|
||||||
|
m3.Time = time.Now().Add(time.Hour).Unix() // Scheduled, in the future, later than m7 and m5
|
||||||
|
m4 := model.NewDefaultMessage("mytopic", "message 4")
|
||||||
|
m4.Time = 400
|
||||||
|
m5 := model.NewDefaultMessage("mytopic", "message 5")
|
||||||
|
m5.Time = time.Now().Add(time.Minute).Unix() // Scheduled, in the future, later than m7
|
||||||
|
m6 := model.NewDefaultMessage("mytopic", "message 6")
|
||||||
|
m6.Time = 600
|
||||||
|
m7 := model.NewDefaultMessage("mytopic", "message 7")
|
||||||
|
m7.Time = 700
|
||||||
|
|
||||||
|
require.Nil(t, s.AddMessage(m1))
|
||||||
|
require.Nil(t, s.AddMessage(m2))
|
||||||
|
require.Nil(t, s.AddMessage(m3))
|
||||||
|
require.Nil(t, s.AddMessage(m4))
|
||||||
|
require.Nil(t, s.AddMessage(m5))
|
||||||
|
require.Nil(t, s.AddMessage(m6))
|
||||||
|
require.Nil(t, s.AddMessage(m7))
|
||||||
|
|
||||||
|
// Case 1: Since ID exists, exclude scheduled
|
||||||
|
messages, _ := s.Messages("mytopic", model.NewSinceID(m2.ID), false)
|
||||||
|
require.Equal(t, 3, len(messages))
|
||||||
|
require.Equal(t, "message 4", messages[0].Message)
|
||||||
|
require.Equal(t, "message 6", messages[1].Message) // Not scheduled m3/m5!
|
||||||
|
require.Equal(t, "message 7", messages[2].Message)
|
||||||
|
|
||||||
|
// Case 2: Since ID exists, include scheduled
|
||||||
|
messages, _ = s.Messages("mytopic", model.NewSinceID(m2.ID), true)
|
||||||
|
require.Equal(t, 5, len(messages))
|
||||||
|
require.Equal(t, "message 4", messages[0].Message)
|
||||||
|
require.Equal(t, "message 6", messages[1].Message)
|
||||||
|
require.Equal(t, "message 7", messages[2].Message)
|
||||||
|
require.Equal(t, "message 5", messages[3].Message) // Order!
|
||||||
|
require.Equal(t, "message 3", messages[4].Message) // Order!
|
||||||
|
|
||||||
|
// Case 3: Since ID does not exist (-> Return all messages), include scheduled
|
||||||
|
messages, _ = s.Messages("mytopic", model.NewSinceID("doesntexist"), true)
|
||||||
|
require.Equal(t, 7, len(messages))
|
||||||
|
require.Equal(t, "message 1", messages[0].Message)
|
||||||
|
require.Equal(t, "message 2", messages[1].Message)
|
||||||
|
require.Equal(t, "message 4", messages[2].Message)
|
||||||
|
require.Equal(t, "message 6", messages[3].Message)
|
||||||
|
require.Equal(t, "message 7", messages[4].Message)
|
||||||
|
require.Equal(t, "message 5", messages[5].Message) // Order!
|
||||||
|
require.Equal(t, "message 3", messages[6].Message) // Order!
|
||||||
|
|
||||||
|
// Case 4: Since ID exists and is last message (-> Return no messages), exclude scheduled
|
||||||
|
messages, _ = s.Messages("mytopic", model.NewSinceID(m7.ID), false)
|
||||||
|
require.Equal(t, 0, len(messages))
|
||||||
|
|
||||||
|
// Case 5: Since ID exists and is last message (-> Return no messages), include scheduled
|
||||||
|
messages, _ = s.Messages("mytopic", model.NewSinceID(m7.ID), true)
|
||||||
|
require.Equal(t, 2, len(messages))
|
||||||
|
require.Equal(t, "message 5", messages[0].Message)
|
||||||
|
require.Equal(t, "message 3", messages[1].Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testCachePrune(t *testing.T, s message.Store) {
|
||||||
|
now := time.Now().Unix()
|
||||||
|
|
||||||
|
m1 := model.NewDefaultMessage("mytopic", "my message")
|
||||||
|
m1.Time = now - 10
|
||||||
|
m1.Expires = now - 5
|
||||||
|
|
||||||
|
m2 := model.NewDefaultMessage("mytopic", "my other message")
|
||||||
|
m2.Time = now - 5
|
||||||
|
m2.Expires = now + 5 // In the future
|
||||||
|
|
||||||
|
m3 := model.NewDefaultMessage("another_topic", "and another one")
|
||||||
|
m3.Time = now - 12
|
||||||
|
m3.Expires = now - 2
|
||||||
|
|
||||||
|
require.Nil(t, s.AddMessage(m1))
|
||||||
|
require.Nil(t, s.AddMessage(m2))
|
||||||
|
require.Nil(t, s.AddMessage(m3))
|
||||||
|
|
||||||
|
counts, err := s.MessageCounts()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 2, counts["mytopic"])
|
||||||
|
require.Equal(t, 1, counts["another_topic"])
|
||||||
|
|
||||||
|
expiredMessageIDs, err := s.MessagesExpired()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Nil(t, s.DeleteMessages(expiredMessageIDs...))
|
||||||
|
|
||||||
|
counts, err = s.MessageCounts()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 1, counts["mytopic"])
|
||||||
|
require.Equal(t, 0, counts["another_topic"])
|
||||||
|
|
||||||
|
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 1, len(messages))
|
||||||
|
require.Equal(t, "my other message", messages[0].Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testCacheAttachments(t *testing.T, s message.Store) {
|
||||||
|
expires1 := time.Now().Add(-4 * time.Hour).Unix() // Expired
|
||||||
|
m := model.NewDefaultMessage("mytopic", "flower for you")
|
||||||
|
m.ID = "m1"
|
||||||
|
m.SequenceID = "m1"
|
||||||
|
m.Sender = netip.MustParseAddr("1.2.3.4")
|
||||||
|
m.Attachment = &model.Attachment{
|
||||||
|
Name: "flower.jpg",
|
||||||
|
Type: "image/jpeg",
|
||||||
|
Size: 5000,
|
||||||
|
Expires: expires1,
|
||||||
|
URL: "https://ntfy.sh/file/AbDeFgJhal.jpg",
|
||||||
|
}
|
||||||
|
require.Nil(t, s.AddMessage(m))
|
||||||
|
|
||||||
|
expires2 := time.Now().Add(2 * time.Hour).Unix() // Future
|
||||||
|
m = model.NewDefaultMessage("mytopic", "sending you a car")
|
||||||
|
m.ID = "m2"
|
||||||
|
m.SequenceID = "m2"
|
||||||
|
m.Sender = netip.MustParseAddr("1.2.3.4")
|
||||||
|
m.Attachment = &model.Attachment{
|
||||||
|
Name: "car.jpg",
|
||||||
|
Type: "image/jpeg",
|
||||||
|
Size: 10000,
|
||||||
|
Expires: expires2,
|
||||||
|
URL: "https://ntfy.sh/file/aCaRURL.jpg",
|
||||||
|
}
|
||||||
|
require.Nil(t, s.AddMessage(m))
|
||||||
|
|
||||||
|
expires3 := time.Now().Add(1 * time.Hour).Unix() // Future
|
||||||
|
m = model.NewDefaultMessage("another-topic", "sending you another car")
|
||||||
|
m.ID = "m3"
|
||||||
|
m.SequenceID = "m3"
|
||||||
|
m.User = "u_BAsbaAa"
|
||||||
|
m.Sender = netip.MustParseAddr("5.6.7.8")
|
||||||
|
m.Attachment = &model.Attachment{
|
||||||
|
Name: "another-car.jpg",
|
||||||
|
Type: "image/jpeg",
|
||||||
|
Size: 20000,
|
||||||
|
Expires: expires3,
|
||||||
|
URL: "https://ntfy.sh/file/zakaDHFW.jpg",
|
||||||
|
}
|
||||||
|
require.Nil(t, s.AddMessage(m))
|
||||||
|
|
||||||
|
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 2, len(messages))
|
||||||
|
|
||||||
|
require.Equal(t, "flower for you", messages[0].Message)
|
||||||
|
require.Equal(t, "flower.jpg", messages[0].Attachment.Name)
|
||||||
|
require.Equal(t, "image/jpeg", messages[0].Attachment.Type)
|
||||||
|
require.Equal(t, int64(5000), messages[0].Attachment.Size)
|
||||||
|
require.Equal(t, expires1, messages[0].Attachment.Expires)
|
||||||
|
require.Equal(t, "https://ntfy.sh/file/AbDeFgJhal.jpg", messages[0].Attachment.URL)
|
||||||
|
require.Equal(t, "1.2.3.4", messages[0].Sender.String())
|
||||||
|
|
||||||
|
require.Equal(t, "sending you a car", messages[1].Message)
|
||||||
|
require.Equal(t, "car.jpg", messages[1].Attachment.Name)
|
||||||
|
require.Equal(t, "image/jpeg", messages[1].Attachment.Type)
|
||||||
|
require.Equal(t, int64(10000), messages[1].Attachment.Size)
|
||||||
|
require.Equal(t, expires2, messages[1].Attachment.Expires)
|
||||||
|
require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL)
|
||||||
|
require.Equal(t, "1.2.3.4", messages[1].Sender.String())
|
||||||
|
|
||||||
|
size, err := s.AttachmentBytesUsedBySender("1.2.3.4")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, int64(10000), size)
|
||||||
|
|
||||||
|
size, err = s.AttachmentBytesUsedBySender("5.6.7.8")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, int64(0), size) // Accounted to the user, not the IP!
|
||||||
|
|
||||||
|
size, err = s.AttachmentBytesUsedByUser("u_BAsbaAa")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, int64(20000), size)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testCacheAttachmentsExpired(t *testing.T, s message.Store) {
|
||||||
|
m := model.NewDefaultMessage("mytopic", "flower for you")
|
||||||
|
m.ID = "m1"
|
||||||
|
m.SequenceID = "m1"
|
||||||
|
m.Expires = time.Now().Add(time.Hour).Unix()
|
||||||
|
require.Nil(t, s.AddMessage(m))
|
||||||
|
|
||||||
|
m = model.NewDefaultMessage("mytopic", "message with attachment")
|
||||||
|
m.ID = "m2"
|
||||||
|
m.SequenceID = "m2"
|
||||||
|
m.Expires = time.Now().Add(2 * time.Hour).Unix()
|
||||||
|
m.Attachment = &model.Attachment{
|
||||||
|
Name: "car.jpg",
|
||||||
|
Type: "image/jpeg",
|
||||||
|
Size: 10000,
|
||||||
|
Expires: time.Now().Add(2 * time.Hour).Unix(),
|
||||||
|
URL: "https://ntfy.sh/file/aCaRURL.jpg",
|
||||||
|
}
|
||||||
|
require.Nil(t, s.AddMessage(m))
|
||||||
|
|
||||||
|
m = model.NewDefaultMessage("mytopic", "message with external attachment")
|
||||||
|
m.ID = "m3"
|
||||||
|
m.SequenceID = "m3"
|
||||||
|
m.Expires = time.Now().Add(2 * time.Hour).Unix()
|
||||||
|
m.Attachment = &model.Attachment{
|
||||||
|
Name: "car.jpg",
|
||||||
|
Type: "image/jpeg",
|
||||||
|
Expires: 0, // Unknown!
|
||||||
|
URL: "https://somedomain.com/car.jpg",
|
||||||
|
}
|
||||||
|
require.Nil(t, s.AddMessage(m))
|
||||||
|
|
||||||
|
m = model.NewDefaultMessage("mytopic2", "message with expired attachment")
|
||||||
|
m.ID = "m4"
|
||||||
|
m.SequenceID = "m4"
|
||||||
|
m.Expires = time.Now().Add(2 * time.Hour).Unix()
|
||||||
|
m.Attachment = &model.Attachment{
|
||||||
|
Name: "expired-car.jpg",
|
||||||
|
Type: "image/jpeg",
|
||||||
|
Size: 20000,
|
||||||
|
Expires: time.Now().Add(-1 * time.Hour).Unix(),
|
||||||
|
URL: "https://ntfy.sh/file/aCaRURL.jpg",
|
||||||
|
}
|
||||||
|
require.Nil(t, s.AddMessage(m))
|
||||||
|
|
||||||
|
ids, err := s.AttachmentsExpired()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 1, len(ids))
|
||||||
|
require.Equal(t, "m4", ids[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
func testSender(t *testing.T, s message.Store) {
|
||||||
|
m1 := model.NewDefaultMessage("mytopic", "mymessage")
|
||||||
|
m1.Sender = netip.MustParseAddr("1.2.3.4")
|
||||||
|
require.Nil(t, s.AddMessage(m1))
|
||||||
|
|
||||||
|
m2 := model.NewDefaultMessage("mytopic", "mymessage without sender")
|
||||||
|
require.Nil(t, s.AddMessage(m2))
|
||||||
|
|
||||||
|
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 2, len(messages))
|
||||||
|
require.Equal(t, messages[0].Sender, netip.MustParseAddr("1.2.3.4"))
|
||||||
|
require.Equal(t, messages[1].Sender, netip.Addr{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func testDeleteScheduledBySequenceID(t *testing.T, s message.Store) {
|
||||||
|
// Create a scheduled (unpublished) message
|
||||||
|
scheduledMsg := model.NewDefaultMessage("mytopic", "scheduled message")
|
||||||
|
scheduledMsg.ID = "scheduled1"
|
||||||
|
scheduledMsg.SequenceID = "seq123"
|
||||||
|
scheduledMsg.Time = time.Now().Add(time.Hour).Unix() // Future time makes it scheduled
|
||||||
|
require.Nil(t, s.AddMessage(scheduledMsg))
|
||||||
|
|
||||||
|
// Create a published message with different sequence ID
|
||||||
|
publishedMsg := model.NewDefaultMessage("mytopic", "published message")
|
||||||
|
publishedMsg.ID = "published1"
|
||||||
|
publishedMsg.SequenceID = "seq456"
|
||||||
|
publishedMsg.Time = time.Now().Add(-time.Hour).Unix() // Past time makes it published
|
||||||
|
require.Nil(t, s.AddMessage(publishedMsg))
|
||||||
|
|
||||||
|
// Create a scheduled message in a different topic
|
||||||
|
otherTopicMsg := model.NewDefaultMessage("othertopic", "other scheduled")
|
||||||
|
otherTopicMsg.ID = "other1"
|
||||||
|
otherTopicMsg.SequenceID = "seq123" // Same sequence ID as scheduledMsg
|
||||||
|
otherTopicMsg.Time = time.Now().Add(time.Hour).Unix()
|
||||||
|
require.Nil(t, s.AddMessage(otherTopicMsg))
|
||||||
|
|
||||||
|
// Verify all messages exist (including scheduled)
|
||||||
|
messages, err := s.Messages("mytopic", model.SinceAllMessages, true)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 2, len(messages))
|
||||||
|
|
||||||
|
messages, err = s.Messages("othertopic", model.SinceAllMessages, true)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 1, len(messages))
|
||||||
|
|
||||||
|
// Delete scheduled message by sequence ID and verify returned IDs
|
||||||
|
deletedIDs, err := s.DeleteScheduledBySequenceID("mytopic", "seq123")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 1, len(deletedIDs))
|
||||||
|
require.Equal(t, "scheduled1", deletedIDs[0])
|
||||||
|
|
||||||
|
// Verify scheduled message is deleted
|
||||||
|
messages, err = s.Messages("mytopic", model.SinceAllMessages, true)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 1, len(messages))
|
||||||
|
require.Equal(t, "published message", messages[0].Message)
|
||||||
|
|
||||||
|
// Verify other topic's message still exists (topic-scoped deletion)
|
||||||
|
messages, err = s.Messages("othertopic", model.SinceAllMessages, true)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 1, len(messages))
|
||||||
|
require.Equal(t, "other scheduled", messages[0].Message)
|
||||||
|
|
||||||
|
// Deleting non-existent sequence ID should return empty list
|
||||||
|
deletedIDs, err = s.DeleteScheduledBySequenceID("mytopic", "nonexistent")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Empty(t, deletedIDs)
|
||||||
|
|
||||||
|
// Deleting published message should not affect it (only deletes unpublished)
|
||||||
|
deletedIDs, err = s.DeleteScheduledBySequenceID("mytopic", "seq456")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Empty(t, deletedIDs)
|
||||||
|
|
||||||
|
messages, err = s.Messages("mytopic", model.SinceAllMessages, true)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 1, len(messages))
|
||||||
|
require.Equal(t, "published message", messages[0].Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testMessageByID(t *testing.T, s message.Store) {
|
||||||
|
// Add a message
|
||||||
|
m := model.NewDefaultMessage("mytopic", "some message")
|
||||||
|
m.Title = "some title"
|
||||||
|
m.Priority = 4
|
||||||
|
m.Tags = []string{"tag1", "tag2"}
|
||||||
|
require.Nil(t, s.AddMessage(m))
|
||||||
|
|
||||||
|
// Retrieve by ID
|
||||||
|
retrieved, err := s.Message(m.ID)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, m.ID, retrieved.ID)
|
||||||
|
require.Equal(t, "mytopic", retrieved.Topic)
|
||||||
|
require.Equal(t, "some message", retrieved.Message)
|
||||||
|
require.Equal(t, "some title", retrieved.Title)
|
||||||
|
require.Equal(t, 4, retrieved.Priority)
|
||||||
|
require.Equal(t, []string{"tag1", "tag2"}, retrieved.Tags)
|
||||||
|
|
||||||
|
// Non-existent ID returns ErrMessageNotFound
|
||||||
|
_, err = s.Message("doesnotexist")
|
||||||
|
require.Equal(t, model.ErrMessageNotFound, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testMarkPublished(t *testing.T, s message.Store) {
|
||||||
|
// Add a scheduled message (future time → unpublished)
|
||||||
|
m := model.NewDefaultMessage("mytopic", "scheduled message")
|
||||||
|
m.Time = time.Now().Add(time.Hour).Unix()
|
||||||
|
require.Nil(t, s.AddMessage(m))
|
||||||
|
|
||||||
|
// Verify it does not appear in non-scheduled queries
|
||||||
|
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 0, len(messages))
|
||||||
|
|
||||||
|
// Verify it does appear in scheduled queries
|
||||||
|
messages, err = s.Messages("mytopic", model.SinceAllMessages, true)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 1, len(messages))
|
||||||
|
|
||||||
|
// Mark as published
|
||||||
|
require.Nil(t, s.MarkPublished(m))
|
||||||
|
|
||||||
|
// Now it should appear in non-scheduled queries too
|
||||||
|
messages, err = s.Messages("mytopic", model.SinceAllMessages, false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 1, len(messages))
|
||||||
|
require.Equal(t, "scheduled message", messages[0].Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testExpireMessages(t *testing.T, s message.Store) {
|
||||||
|
// Add messages to two topics
|
||||||
|
m1 := model.NewDefaultMessage("topic1", "message 1")
|
||||||
|
m1.Expires = time.Now().Add(time.Hour).Unix()
|
||||||
|
m2 := model.NewDefaultMessage("topic1", "message 2")
|
||||||
|
m2.Expires = time.Now().Add(time.Hour).Unix()
|
||||||
|
m3 := model.NewDefaultMessage("topic2", "message 3")
|
||||||
|
m3.Expires = time.Now().Add(time.Hour).Unix()
|
||||||
|
require.Nil(t, s.AddMessage(m1))
|
||||||
|
require.Nil(t, s.AddMessage(m2))
|
||||||
|
require.Nil(t, s.AddMessage(m3))
|
||||||
|
|
||||||
|
// Verify all messages exist
|
||||||
|
messages, err := s.Messages("topic1", model.SinceAllMessages, false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 2, len(messages))
|
||||||
|
messages, err = s.Messages("topic2", model.SinceAllMessages, false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 1, len(messages))
|
||||||
|
|
||||||
|
// Expire topic1 messages
|
||||||
|
require.Nil(t, s.ExpireMessages("topic1"))
|
||||||
|
|
||||||
|
// topic1 messages should now be expired (expires set to past)
|
||||||
|
expiredIDs, err := s.MessagesExpired()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 2, len(expiredIDs))
|
||||||
|
sort.Strings(expiredIDs)
|
||||||
|
expectedIDs := []string{m1.ID, m2.ID}
|
||||||
|
sort.Strings(expectedIDs)
|
||||||
|
require.Equal(t, expectedIDs, expiredIDs)
|
||||||
|
|
||||||
|
// topic2 should be unaffected
|
||||||
|
messages, err = s.Messages("topic2", model.SinceAllMessages, false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 1, len(messages))
|
||||||
|
require.Equal(t, "message 3", messages[0].Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testMarkAttachmentsDeleted(t *testing.T, s message.Store) {
|
||||||
|
// Add a message with an expired attachment (file needs cleanup)
|
||||||
|
m1 := model.NewDefaultMessage("mytopic", "old file")
|
||||||
|
m1.ID = "msg1"
|
||||||
|
m1.SequenceID = "msg1"
|
||||||
|
m1.Expires = time.Now().Add(time.Hour).Unix()
|
||||||
|
m1.Attachment = &model.Attachment{
|
||||||
|
Name: "old.pdf",
|
||||||
|
Type: "application/pdf",
|
||||||
|
Size: 50000,
|
||||||
|
Expires: time.Now().Add(-time.Hour).Unix(), // Expired
|
||||||
|
URL: "https://ntfy.sh/file/old.pdf",
|
||||||
|
}
|
||||||
|
require.Nil(t, s.AddMessage(m1))
|
||||||
|
|
||||||
|
// Add a message with another expired attachment
|
||||||
|
m2 := model.NewDefaultMessage("mytopic", "another old file")
|
||||||
|
m2.ID = "msg2"
|
||||||
|
m2.SequenceID = "msg2"
|
||||||
|
m2.Expires = time.Now().Add(time.Hour).Unix()
|
||||||
|
m2.Attachment = &model.Attachment{
|
||||||
|
Name: "another.pdf",
|
||||||
|
Type: "application/pdf",
|
||||||
|
Size: 30000,
|
||||||
|
Expires: time.Now().Add(-time.Hour).Unix(), // Expired
|
||||||
|
URL: "https://ntfy.sh/file/another.pdf",
|
||||||
|
}
|
||||||
|
require.Nil(t, s.AddMessage(m2))
|
||||||
|
|
||||||
|
// Both should show as expired attachments needing cleanup
|
||||||
|
ids, err := s.AttachmentsExpired()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 2, len(ids))
|
||||||
|
|
||||||
|
// Mark msg1's attachment as deleted (file cleaned up)
|
||||||
|
require.Nil(t, s.MarkAttachmentsDeleted("msg1"))
|
||||||
|
|
||||||
|
// Now only msg2 should show as needing cleanup
|
||||||
|
ids, err = s.AttachmentsExpired()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 1, len(ids))
|
||||||
|
require.Equal(t, "msg2", ids[0])
|
||||||
|
|
||||||
|
// Mark msg2 too
|
||||||
|
require.Nil(t, s.MarkAttachmentsDeleted("msg2"))
|
||||||
|
|
||||||
|
// No more expired attachments to clean up
|
||||||
|
ids, err = s.AttachmentsExpired()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 0, len(ids))
|
||||||
|
|
||||||
|
// Messages themselves still exist
|
||||||
|
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 2, len(messages))
|
||||||
|
}
|
||||||
|
|
||||||
|
func testStats(t *testing.T, s message.Store) {
|
||||||
|
// Initial stats should be zero
|
||||||
|
messages, err := s.Stats()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, int64(0), messages)
|
||||||
|
|
||||||
|
// Update stats
|
||||||
|
require.Nil(t, s.UpdateStats(42))
|
||||||
|
messages, err = s.Stats()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, int64(42), messages)
|
||||||
|
|
||||||
|
// Update again (overwrites)
|
||||||
|
require.Nil(t, s.UpdateStats(100))
|
||||||
|
messages, err = s.Stats()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, int64(100), messages)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testAddMessages(t *testing.T, s message.Store) {
|
||||||
|
// Batch add multiple messages
|
||||||
|
msgs := []*model.Message{
|
||||||
|
model.NewDefaultMessage("mytopic", "batch 1"),
|
||||||
|
model.NewDefaultMessage("mytopic", "batch 2"),
|
||||||
|
model.NewDefaultMessage("othertopic", "batch 3"),
|
||||||
|
}
|
||||||
|
require.Nil(t, s.AddMessages(msgs))
|
||||||
|
|
||||||
|
// Verify all were inserted
|
||||||
|
messages, err := s.Messages("mytopic", model.SinceAllMessages, false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 2, len(messages))
|
||||||
|
|
||||||
|
messages, err = s.Messages("othertopic", model.SinceAllMessages, false)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 1, len(messages))
|
||||||
|
require.Equal(t, "batch 3", messages[0].Message)
|
||||||
|
|
||||||
|
// Empty batch should succeed
|
||||||
|
require.Nil(t, s.AddMessages([]*model.Message{}))
|
||||||
|
|
||||||
|
// Batch with invalid event type should fail
|
||||||
|
badMsgs := []*model.Message{
|
||||||
|
model.NewKeepaliveMessage("mytopic"),
|
||||||
|
}
|
||||||
|
require.NotNil(t, s.AddMessages(badMsgs))
|
||||||
|
}
|
||||||
|
|
||||||
|
func testMessagesDue(t *testing.T, s message.Store) {
|
||||||
|
// Add a message scheduled in the past (i.e. it's due now)
|
||||||
|
m1 := model.NewDefaultMessage("mytopic", "due message")
|
||||||
|
m1.Time = time.Now().Add(-time.Second).Unix()
|
||||||
|
// Set expires in the future so it doesn't get pruned
|
||||||
|
m1.Expires = time.Now().Add(time.Hour).Unix()
|
||||||
|
require.Nil(t, s.AddMessage(m1))
|
||||||
|
|
||||||
|
// Add a message scheduled in the future (not due)
|
||||||
|
m2 := model.NewDefaultMessage("mytopic", "future message")
|
||||||
|
m2.Time = time.Now().Add(time.Hour).Unix()
|
||||||
|
require.Nil(t, s.AddMessage(m2))
|
||||||
|
|
||||||
|
// Mark m1 as published so it won't be "due"
|
||||||
|
// (MessagesDue returns unpublished messages whose time <= now)
|
||||||
|
// m1 is auto-published (time <= now), so it should not be due
|
||||||
|
// m2 is unpublished (time in future), not due yet
|
||||||
|
due, err := s.MessagesDue()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 0, len(due))
|
||||||
|
|
||||||
|
// Add a message that was explicitly scheduled in the past but time has "arrived"
|
||||||
|
// We need to manipulate the database to create a truly "due" message:
|
||||||
|
// a message with published=false and time <= now
|
||||||
|
m3 := model.NewDefaultMessage("mytopic", "truly due message")
|
||||||
|
m3.Time = time.Now().Add(2 * time.Second).Unix() // 2 seconds from now
|
||||||
|
require.Nil(t, s.AddMessage(m3))
|
||||||
|
|
||||||
|
// Not due yet
|
||||||
|
due, err = s.MessagesDue()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 0, len(due))
|
||||||
|
|
||||||
|
// Wait for it to become due
|
||||||
|
time.Sleep(3 * time.Second)
|
||||||
|
|
||||||
|
due, err = s.MessagesDue()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 1, len(due))
|
||||||
|
require.Equal(t, "truly due message", due[0].Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testMessageFieldRoundTrip(t *testing.T, s message.Store) {
|
||||||
|
// Create a message with all fields populated
|
||||||
|
m := model.NewDefaultMessage("mytopic", "hello world")
|
||||||
|
m.SequenceID = "custom_seq_id"
|
||||||
|
m.Title = "A Title"
|
||||||
|
m.Priority = 4
|
||||||
|
m.Tags = []string{"warning", "srv01"}
|
||||||
|
m.Click = "https://example.com/click"
|
||||||
|
m.Icon = "https://example.com/icon.png"
|
||||||
|
m.Actions = []*model.Action{
|
||||||
|
{
|
||||||
|
ID: "action1",
|
||||||
|
Action: "view",
|
||||||
|
Label: "Open Site",
|
||||||
|
URL: "https://example.com",
|
||||||
|
Clear: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "action2",
|
||||||
|
Action: "http",
|
||||||
|
Label: "Call Webhook",
|
||||||
|
URL: "https://example.com/hook",
|
||||||
|
Method: "PUT",
|
||||||
|
Headers: map[string]string{"X-Token": "secret"},
|
||||||
|
Body: `{"key":"value"}`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
m.ContentType = "text/markdown"
|
||||||
|
m.Encoding = "base64"
|
||||||
|
m.Sender = netip.MustParseAddr("9.8.7.6")
|
||||||
|
m.User = "u_TestUser123"
|
||||||
|
require.Nil(t, s.AddMessage(m))
|
||||||
|
|
||||||
|
// Retrieve and verify every field
|
||||||
|
retrieved, err := s.Message(m.ID)
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, m.ID, retrieved.ID)
|
||||||
|
require.Equal(t, "custom_seq_id", retrieved.SequenceID)
|
||||||
|
require.Equal(t, m.Time, retrieved.Time)
|
||||||
|
require.Equal(t, m.Expires, retrieved.Expires)
|
||||||
|
require.Equal(t, model.MessageEvent, retrieved.Event)
|
||||||
|
require.Equal(t, "mytopic", retrieved.Topic)
|
||||||
|
require.Equal(t, "hello world", retrieved.Message)
|
||||||
|
require.Equal(t, "A Title", retrieved.Title)
|
||||||
|
require.Equal(t, 4, retrieved.Priority)
|
||||||
|
require.Equal(t, []string{"warning", "srv01"}, retrieved.Tags)
|
||||||
|
require.Equal(t, "https://example.com/click", retrieved.Click)
|
||||||
|
require.Equal(t, "https://example.com/icon.png", retrieved.Icon)
|
||||||
|
require.Equal(t, "text/markdown", retrieved.ContentType)
|
||||||
|
require.Equal(t, "base64", retrieved.Encoding)
|
||||||
|
require.Equal(t, netip.MustParseAddr("9.8.7.6"), retrieved.Sender)
|
||||||
|
require.Equal(t, "u_TestUser123", retrieved.User)
|
||||||
|
|
||||||
|
// Verify actions round-trip
|
||||||
|
require.Equal(t, 2, len(retrieved.Actions))
|
||||||
|
|
||||||
|
require.Equal(t, "action1", retrieved.Actions[0].ID)
|
||||||
|
require.Equal(t, "view", retrieved.Actions[0].Action)
|
||||||
|
require.Equal(t, "Open Site", retrieved.Actions[0].Label)
|
||||||
|
require.Equal(t, "https://example.com", retrieved.Actions[0].URL)
|
||||||
|
require.Equal(t, true, retrieved.Actions[0].Clear)
|
||||||
|
|
||||||
|
require.Equal(t, "action2", retrieved.Actions[1].ID)
|
||||||
|
require.Equal(t, "http", retrieved.Actions[1].Action)
|
||||||
|
require.Equal(t, "Call Webhook", retrieved.Actions[1].Label)
|
||||||
|
require.Equal(t, "https://example.com/hook", retrieved.Actions[1].URL)
|
||||||
|
require.Equal(t, "PUT", retrieved.Actions[1].Method)
|
||||||
|
require.Equal(t, "secret", retrieved.Actions[1].Headers["X-Token"])
|
||||||
|
require.Equal(t, `{"key":"value"}`, retrieved.Actions[1].Body)
|
||||||
|
}
|
||||||
205
model/model.go
Normal file
205
model/model.go
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"heckel.io/ntfy/v2/log"
|
||||||
|
"heckel.io/ntfy/v2/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// List of possible events
|
||||||
|
const (
|
||||||
|
OpenEvent = "open"
|
||||||
|
KeepaliveEvent = "keepalive"
|
||||||
|
MessageEvent = "message"
|
||||||
|
MessageDeleteEvent = "message_delete"
|
||||||
|
MessageClearEvent = "message_clear"
|
||||||
|
PollRequestEvent = "poll_request"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MessageIDLength is the length of a randomly generated message ID
|
||||||
|
const MessageIDLength = 12
|
||||||
|
|
||||||
|
// Errors for message operations
|
||||||
|
var (
|
||||||
|
ErrUnexpectedMessageType = errors.New("unexpected message type")
|
||||||
|
ErrMessageNotFound = errors.New("message not found")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Message represents a message published to a topic
|
||||||
|
type Message struct {
|
||||||
|
ID string `json:"id"` // Random message ID
|
||||||
|
SequenceID string `json:"sequence_id,omitempty"` // Message sequence ID for updating message contents (omitted if same as ID)
|
||||||
|
Time int64 `json:"time"` // Unix time in seconds
|
||||||
|
Expires int64 `json:"expires,omitempty"` // Unix time in seconds (not required for open/keepalive)
|
||||||
|
Event string `json:"event"` // One of the above
|
||||||
|
Topic string `json:"topic"`
|
||||||
|
Title string `json:"title,omitempty"`
|
||||||
|
Message string `json:"message,omitempty"`
|
||||||
|
Priority int `json:"priority,omitempty"`
|
||||||
|
Tags []string `json:"tags,omitempty"`
|
||||||
|
Click string `json:"click,omitempty"`
|
||||||
|
Icon string `json:"icon,omitempty"`
|
||||||
|
Actions []*Action `json:"actions,omitempty"`
|
||||||
|
Attachment *Attachment `json:"attachment,omitempty"`
|
||||||
|
PollID string `json:"poll_id,omitempty"`
|
||||||
|
ContentType string `json:"content_type,omitempty"` // text/plain by default (if empty), or text/markdown
|
||||||
|
Encoding string `json:"encoding,omitempty"` // Empty for raw UTF-8, or "base64" for encoded bytes
|
||||||
|
Sender netip.Addr `json:"-"` // IP address of uploader, used for rate limiting
|
||||||
|
User string `json:"-"` // UserID of the uploader, used to associated attachments
|
||||||
|
}
|
||||||
|
|
||||||
|
// Context returns a log context for the message
|
||||||
|
func (m *Message) Context() log.Context {
|
||||||
|
fields := map[string]any{
|
||||||
|
"topic": m.Topic,
|
||||||
|
"message_id": m.ID,
|
||||||
|
"message_sequence_id": m.SequenceID,
|
||||||
|
"message_time": m.Time,
|
||||||
|
"message_event": m.Event,
|
||||||
|
"message_body_size": len(m.Message),
|
||||||
|
}
|
||||||
|
if m.Sender.IsValid() {
|
||||||
|
fields["message_sender"] = m.Sender.String()
|
||||||
|
}
|
||||||
|
if m.User != "" {
|
||||||
|
fields["message_user"] = m.User
|
||||||
|
}
|
||||||
|
return fields
|
||||||
|
}
|
||||||
|
|
||||||
|
// ForJSON returns a copy of the message suitable for JSON output.
|
||||||
|
// It clears the SequenceID if it equals the ID to reduce redundancy.
|
||||||
|
func (m *Message) ForJSON() *Message {
|
||||||
|
if m.SequenceID == m.ID {
|
||||||
|
clone := *m
|
||||||
|
clone.SequenceID = ""
|
||||||
|
return &clone
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attachment represents a file attachment on a message
|
||||||
|
type Attachment struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Type string `json:"type,omitempty"`
|
||||||
|
Size int64 `json:"size,omitempty"`
|
||||||
|
Expires int64 `json:"expires,omitempty"`
|
||||||
|
URL string `json:"url"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Action represents a user-defined action on a message
|
||||||
|
type Action struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Action string `json:"action"` // "view", "broadcast", "http", or "copy"
|
||||||
|
Label string `json:"label"` // action button label
|
||||||
|
Clear bool `json:"clear"` // clear notification after successful execution
|
||||||
|
URL string `json:"url,omitempty"` // used in "view" and "http" actions
|
||||||
|
Method string `json:"method,omitempty"` // used in "http" action, default is POST (!)
|
||||||
|
Headers map[string]string `json:"headers,omitempty"` // used in "http" action
|
||||||
|
Body string `json:"body,omitempty"` // used in "http" action
|
||||||
|
Intent string `json:"intent,omitempty"` // used in "broadcast" action
|
||||||
|
Extras map[string]string `json:"extras,omitempty"` // used in "broadcast" action
|
||||||
|
Value string `json:"value,omitempty"` // used in "copy" action
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAction creates a new action with initialized maps
|
||||||
|
func NewAction() *Action {
|
||||||
|
return &Action{
|
||||||
|
Headers: make(map[string]string),
|
||||||
|
Extras: make(map[string]string),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMessage creates a new message with the current timestamp
|
||||||
|
func NewMessage(event, topic, msg string) *Message {
|
||||||
|
return &Message{
|
||||||
|
ID: util.RandomString(MessageIDLength),
|
||||||
|
Time: time.Now().Unix(),
|
||||||
|
Event: event,
|
||||||
|
Topic: topic,
|
||||||
|
Message: msg,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewOpenMessage is a convenience method to create an open message
|
||||||
|
func NewOpenMessage(topic string) *Message {
|
||||||
|
return NewMessage(OpenEvent, topic, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewKeepaliveMessage is a convenience method to create a keepalive message
|
||||||
|
func NewKeepaliveMessage(topic string) *Message {
|
||||||
|
return NewMessage(KeepaliveEvent, topic, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDefaultMessage is a convenience method to create a notification message
|
||||||
|
func NewDefaultMessage(topic, msg string) *Message {
|
||||||
|
return NewMessage(MessageEvent, topic, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewActionMessage creates a new action message (message_delete or message_clear)
|
||||||
|
func NewActionMessage(event, topic, sequenceID string) *Message {
|
||||||
|
m := NewMessage(event, topic, "")
|
||||||
|
m.SequenceID = sequenceID
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidMessageID returns true if the given string is a valid message ID
|
||||||
|
func ValidMessageID(s string) bool {
|
||||||
|
return util.ValidRandomString(s, MessageIDLength)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SinceMarker represents a point in time or message ID from which to retrieve messages
|
||||||
|
type SinceMarker struct {
|
||||||
|
time time.Time
|
||||||
|
id string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSinceTime creates a new SinceMarker from a Unix timestamp
|
||||||
|
func NewSinceTime(timestamp int64) SinceMarker {
|
||||||
|
return SinceMarker{time.Unix(timestamp, 0), ""}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSinceID creates a new SinceMarker from a message ID
|
||||||
|
func NewSinceID(id string) SinceMarker {
|
||||||
|
return SinceMarker{time.Unix(0, 0), id}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAll returns true if this is the "all messages" marker
|
||||||
|
func (t SinceMarker) IsAll() bool {
|
||||||
|
return t == SinceAllMessages
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsNone returns true if this is the "no messages" marker
|
||||||
|
func (t SinceMarker) IsNone() bool {
|
||||||
|
return t == SinceNoMessages
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsLatest returns true if this is the "latest message" marker
|
||||||
|
func (t SinceMarker) IsLatest() bool {
|
||||||
|
return t == SinceLatestMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsID returns true if this marker references a specific message ID
|
||||||
|
func (t SinceMarker) IsID() bool {
|
||||||
|
return t.id != "" && t.id != "latest"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Time returns the time component of the marker
|
||||||
|
func (t SinceMarker) Time() time.Time {
|
||||||
|
return t.time
|
||||||
|
}
|
||||||
|
|
||||||
|
// ID returns the message ID component of the marker
|
||||||
|
func (t SinceMarker) ID() string {
|
||||||
|
return t.id
|
||||||
|
}
|
||||||
|
|
||||||
|
// Common SinceMarker values for subscribing to messages
|
||||||
|
var (
|
||||||
|
SinceAllMessages = SinceMarker{time.Unix(0, 0), ""}
|
||||||
|
SinceNoMessages = SinceMarker{time.Unix(1, 0), ""}
|
||||||
|
SinceLatestMessage = SinceMarker{time.Unix(0, 0), "latest"}
|
||||||
|
)
|
||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
|
||||||
|
"heckel.io/ntfy/v2/model"
|
||||||
"heckel.io/ntfy/v2/util"
|
"heckel.io/ntfy/v2/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -39,7 +40,7 @@ type actionParser struct {
|
|||||||
// parseActions parses the actions string as described in https://ntfy.sh/docs/publish/#action-buttons.
|
// parseActions parses the actions string as described in https://ntfy.sh/docs/publish/#action-buttons.
|
||||||
// It supports both a JSON representation (if the string begins with "[", see parseActionsFromJSON),
|
// It supports both a JSON representation (if the string begins with "[", see parseActionsFromJSON),
|
||||||
// and the "simple" format, which is more human-readable, but harder to parse (see parseActionsFromSimple).
|
// and the "simple" format, which is more human-readable, but harder to parse (see parseActionsFromSimple).
|
||||||
func parseActions(s string) (actions []*action, err error) {
|
func parseActions(s string) (actions []*model.Action, err error) {
|
||||||
// Parse JSON or simple format
|
// Parse JSON or simple format
|
||||||
s = strings.TrimSpace(s)
|
s = strings.TrimSpace(s)
|
||||||
if strings.HasPrefix(s, "[") {
|
if strings.HasPrefix(s, "[") {
|
||||||
@@ -80,8 +81,8 @@ func parseActions(s string) (actions []*action, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// parseActionsFromJSON converts a JSON array into an array of actions
|
// parseActionsFromJSON converts a JSON array into an array of actions
|
||||||
func parseActionsFromJSON(s string) ([]*action, error) {
|
func parseActionsFromJSON(s string) ([]*model.Action, error) {
|
||||||
actions := make([]*action, 0)
|
actions := make([]*model.Action, 0)
|
||||||
if err := json.Unmarshal([]byte(s), &actions); err != nil {
|
if err := json.Unmarshal([]byte(s), &actions); err != nil {
|
||||||
return nil, fmt.Errorf("JSON error: %w", err)
|
return nil, fmt.Errorf("JSON error: %w", err)
|
||||||
}
|
}
|
||||||
@@ -107,7 +108,7 @@ func parseActionsFromJSON(s string) ([]*action, error) {
|
|||||||
// https://github.com/adampresley/sample-ini-parser/blob/master/services/lexer/lexer/Lexer.go
|
// https://github.com/adampresley/sample-ini-parser/blob/master/services/lexer/lexer/Lexer.go
|
||||||
// https://github.com/benbjohnson/sql-parser/blob/master/scanner.go
|
// https://github.com/benbjohnson/sql-parser/blob/master/scanner.go
|
||||||
// https://blog.gopheracademy.com/advent-2014/parsers-lexers/
|
// https://blog.gopheracademy.com/advent-2014/parsers-lexers/
|
||||||
func parseActionsFromSimple(s string) ([]*action, error) {
|
func parseActionsFromSimple(s string) ([]*model.Action, error) {
|
||||||
if !utf8.ValidString(s) {
|
if !utf8.ValidString(s) {
|
||||||
return nil, errors.New("invalid utf-8 string")
|
return nil, errors.New("invalid utf-8 string")
|
||||||
}
|
}
|
||||||
@@ -119,8 +120,8 @@ func parseActionsFromSimple(s string) ([]*action, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Parse loops trough parseAction() until the end of the string is reached
|
// Parse loops trough parseAction() until the end of the string is reached
|
||||||
func (p *actionParser) Parse() ([]*action, error) {
|
func (p *actionParser) Parse() ([]*model.Action, error) {
|
||||||
actions := make([]*action, 0)
|
actions := make([]*model.Action, 0)
|
||||||
for !p.eof() {
|
for !p.eof() {
|
||||||
a, err := p.parseAction()
|
a, err := p.parseAction()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -134,7 +135,7 @@ func (p *actionParser) Parse() ([]*action, error) {
|
|||||||
// parseAction parses the individual sections of an action using parseSection into key/value pairs,
|
// parseAction parses the individual sections of an action using parseSection into key/value pairs,
|
||||||
// and then uses populateAction to interpret the keys/values. The function terminates
|
// and then uses populateAction to interpret the keys/values. The function terminates
|
||||||
// when EOF or ";" is reached.
|
// when EOF or ";" is reached.
|
||||||
func (p *actionParser) parseAction() (*action, error) {
|
func (p *actionParser) parseAction() (*model.Action, error) {
|
||||||
a := newAction()
|
a := newAction()
|
||||||
section := 0
|
section := 0
|
||||||
for {
|
for {
|
||||||
@@ -155,7 +156,7 @@ func (p *actionParser) parseAction() (*action, error) {
|
|||||||
|
|
||||||
// populateAction is the "business logic" of the parser. It applies the key/value
|
// populateAction is the "business logic" of the parser. It applies the key/value
|
||||||
// pair to the action instance.
|
// pair to the action instance.
|
||||||
func populateAction(newAction *action, section int, key, value string) error {
|
func populateAction(newAction *model.Action, section int, key, value string) error {
|
||||||
// Auto-expand keys based on their index
|
// Auto-expand keys based on their index
|
||||||
if key == "" && section == 0 {
|
if key == "" && section == 0 {
|
||||||
key = "action"
|
key = "action"
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"github.com/emersion/go-smtp"
|
"github.com/emersion/go-smtp"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"heckel.io/ntfy/v2/log"
|
"heckel.io/ntfy/v2/log"
|
||||||
|
"heckel.io/ntfy/v2/model"
|
||||||
"heckel.io/ntfy/v2/util"
|
"heckel.io/ntfy/v2/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -55,12 +56,12 @@ func logvr(v *visitor, r *http.Request) *log.Event {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// logvrm creates a new log event with HTTP request, visitor fields and message fields
|
// logvrm creates a new log event with HTTP request, visitor fields and message fields
|
||||||
func logvrm(v *visitor, r *http.Request, m *message) *log.Event {
|
func logvrm(v *visitor, r *http.Request, m *model.Message) *log.Event {
|
||||||
return logvr(v, r).With(m)
|
return logvr(v, r).With(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
// logvrm creates a new log event with visitor fields and message fields
|
// logvrm creates a new log event with visitor fields and message fields
|
||||||
func logvm(v *visitor, m *message) *log.Event {
|
func logvm(v *visitor, m *model.Message) *log.Event {
|
||||||
return logv(v).With(m)
|
return logv(v).With(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,825 +0,0 @@
|
|||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"net/netip"
|
|
||||||
"path/filepath"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestSqliteCache_Messages(t *testing.T) {
|
|
||||||
testCacheMessages(t, newSqliteTestCache(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMemCache_Messages(t *testing.T) {
|
|
||||||
testCacheMessages(t, newMemTestCache(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func testCacheMessages(t *testing.T, c *messageCache) {
|
|
||||||
m1 := newDefaultMessage("mytopic", "my message")
|
|
||||||
m1.Time = 1
|
|
||||||
|
|
||||||
m2 := newDefaultMessage("mytopic", "my other message")
|
|
||||||
m2.Time = 2
|
|
||||||
|
|
||||||
require.Nil(t, c.AddMessage(m1))
|
|
||||||
require.Nil(t, c.AddMessage(newDefaultMessage("example", "my example message")))
|
|
||||||
require.Nil(t, c.AddMessage(m2))
|
|
||||||
|
|
||||||
// Adding invalid
|
|
||||||
require.Equal(t, errUnexpectedMessageType, c.AddMessage(newKeepaliveMessage("mytopic"))) // These should not be added!
|
|
||||||
require.Equal(t, errUnexpectedMessageType, c.AddMessage(newOpenMessage("example"))) // These should not be added!
|
|
||||||
|
|
||||||
// mytopic: count
|
|
||||||
counts, err := c.MessageCounts()
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 2, counts["mytopic"])
|
|
||||||
|
|
||||||
// mytopic: since all
|
|
||||||
messages, _ := c.Messages("mytopic", sinceAllMessages, false)
|
|
||||||
require.Equal(t, 2, len(messages))
|
|
||||||
require.Equal(t, "my message", messages[0].Message)
|
|
||||||
require.Equal(t, "mytopic", messages[0].Topic)
|
|
||||||
require.Equal(t, messageEvent, messages[0].Event)
|
|
||||||
require.Equal(t, "", messages[0].Title)
|
|
||||||
require.Equal(t, 0, messages[0].Priority)
|
|
||||||
require.Nil(t, messages[0].Tags)
|
|
||||||
require.Equal(t, "my other message", messages[1].Message)
|
|
||||||
|
|
||||||
// mytopic: since none
|
|
||||||
messages, _ = c.Messages("mytopic", sinceNoMessages, false)
|
|
||||||
require.Empty(t, messages)
|
|
||||||
|
|
||||||
// mytopic: since m1 (by ID)
|
|
||||||
messages, _ = c.Messages("mytopic", newSinceID(m1.ID), false)
|
|
||||||
require.Equal(t, 1, len(messages))
|
|
||||||
require.Equal(t, m2.ID, messages[0].ID)
|
|
||||||
require.Equal(t, "my other message", messages[0].Message)
|
|
||||||
require.Equal(t, "mytopic", messages[0].Topic)
|
|
||||||
|
|
||||||
// mytopic: since 2
|
|
||||||
messages, _ = c.Messages("mytopic", newSinceTime(2), false)
|
|
||||||
require.Equal(t, 1, len(messages))
|
|
||||||
require.Equal(t, "my other message", messages[0].Message)
|
|
||||||
|
|
||||||
// mytopic: latest
|
|
||||||
messages, _ = c.Messages("mytopic", sinceLatestMessage, false)
|
|
||||||
require.Equal(t, 1, len(messages))
|
|
||||||
require.Equal(t, "my other message", messages[0].Message)
|
|
||||||
|
|
||||||
// example: count
|
|
||||||
counts, err = c.MessageCounts()
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 1, counts["example"])
|
|
||||||
|
|
||||||
// example: since all
|
|
||||||
messages, _ = c.Messages("example", sinceAllMessages, false)
|
|
||||||
require.Equal(t, "my example message", messages[0].Message)
|
|
||||||
|
|
||||||
// non-existing: count
|
|
||||||
counts, err = c.MessageCounts()
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 0, counts["doesnotexist"])
|
|
||||||
|
|
||||||
// non-existing: since all
|
|
||||||
messages, _ = c.Messages("doesnotexist", sinceAllMessages, false)
|
|
||||||
require.Empty(t, messages)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqliteCache_MessagesLock(t *testing.T) {
|
|
||||||
testCacheMessagesLock(t, newSqliteTestCache(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMemCache_MessagesLock(t *testing.T) {
|
|
||||||
testCacheMessagesLock(t, newMemTestCache(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func testCacheMessagesLock(t *testing.T, c *messageCache) {
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
for i := 0; i < 5000; i++ {
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
assert.Nil(t, c.AddMessage(newDefaultMessage("mytopic", "test message")))
|
|
||||||
wg.Done()
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqliteCache_MessagesScheduled(t *testing.T) {
|
|
||||||
testCacheMessagesScheduled(t, newSqliteTestCache(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMemCache_MessagesScheduled(t *testing.T) {
|
|
||||||
testCacheMessagesScheduled(t, newMemTestCache(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func testCacheMessagesScheduled(t *testing.T, c *messageCache) {
|
|
||||||
m1 := newDefaultMessage("mytopic", "message 1")
|
|
||||||
m2 := newDefaultMessage("mytopic", "message 2")
|
|
||||||
m2.Time = time.Now().Add(time.Hour).Unix()
|
|
||||||
m3 := newDefaultMessage("mytopic", "message 3")
|
|
||||||
m3.Time = time.Now().Add(time.Minute).Unix() // earlier than m2!
|
|
||||||
m4 := newDefaultMessage("mytopic2", "message 4")
|
|
||||||
m4.Time = time.Now().Add(time.Minute).Unix()
|
|
||||||
require.Nil(t, c.AddMessage(m1))
|
|
||||||
require.Nil(t, c.AddMessage(m2))
|
|
||||||
require.Nil(t, c.AddMessage(m3))
|
|
||||||
|
|
||||||
messages, _ := c.Messages("mytopic", sinceAllMessages, false) // exclude scheduled
|
|
||||||
require.Equal(t, 1, len(messages))
|
|
||||||
require.Equal(t, "message 1", messages[0].Message)
|
|
||||||
|
|
||||||
messages, _ = c.Messages("mytopic", sinceAllMessages, true) // include scheduled
|
|
||||||
require.Equal(t, 3, len(messages))
|
|
||||||
require.Equal(t, "message 1", messages[0].Message)
|
|
||||||
require.Equal(t, "message 3", messages[1].Message) // Order!
|
|
||||||
require.Equal(t, "message 2", messages[2].Message)
|
|
||||||
|
|
||||||
messages, _ = c.MessagesDue()
|
|
||||||
require.Empty(t, messages)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqliteCache_Topics(t *testing.T) {
|
|
||||||
testCacheTopics(t, newSqliteTestCache(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMemCache_Topics(t *testing.T) {
|
|
||||||
testCacheTopics(t, newMemTestCache(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func testCacheTopics(t *testing.T, c *messageCache) {
|
|
||||||
require.Nil(t, c.AddMessage(newDefaultMessage("topic1", "my example message")))
|
|
||||||
require.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 1")))
|
|
||||||
require.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 2")))
|
|
||||||
require.Nil(t, c.AddMessage(newDefaultMessage("topic2", "message 3")))
|
|
||||||
|
|
||||||
topics, err := c.Topics()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
require.Equal(t, 2, len(topics))
|
|
||||||
require.Equal(t, "topic1", topics["topic1"].ID)
|
|
||||||
require.Equal(t, "topic2", topics["topic2"].ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqliteCache_MessagesTagsPrioAndTitle(t *testing.T) {
|
|
||||||
testCacheMessagesTagsPrioAndTitle(t, newSqliteTestCache(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMemCache_MessagesTagsPrioAndTitle(t *testing.T) {
|
|
||||||
testCacheMessagesTagsPrioAndTitle(t, newMemTestCache(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func testCacheMessagesTagsPrioAndTitle(t *testing.T, c *messageCache) {
|
|
||||||
m := newDefaultMessage("mytopic", "some message")
|
|
||||||
m.Tags = []string{"tag1", "tag2"}
|
|
||||||
m.Priority = 5
|
|
||||||
m.Title = "some title"
|
|
||||||
require.Nil(t, c.AddMessage(m))
|
|
||||||
|
|
||||||
messages, _ := c.Messages("mytopic", sinceAllMessages, false)
|
|
||||||
require.Equal(t, []string{"tag1", "tag2"}, messages[0].Tags)
|
|
||||||
require.Equal(t, 5, messages[0].Priority)
|
|
||||||
require.Equal(t, "some title", messages[0].Title)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqliteCache_MessagesSinceID(t *testing.T) {
|
|
||||||
testCacheMessagesSinceID(t, newSqliteTestCache(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMemCache_MessagesSinceID(t *testing.T) {
|
|
||||||
testCacheMessagesSinceID(t, newMemTestCache(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func testCacheMessagesSinceID(t *testing.T, c *messageCache) {
|
|
||||||
m1 := newDefaultMessage("mytopic", "message 1")
|
|
||||||
m1.Time = 100
|
|
||||||
m2 := newDefaultMessage("mytopic", "message 2")
|
|
||||||
m2.Time = 200
|
|
||||||
m3 := newDefaultMessage("mytopic", "message 3")
|
|
||||||
m3.Time = time.Now().Add(time.Hour).Unix() // Scheduled, in the future, later than m7 and m5
|
|
||||||
m4 := newDefaultMessage("mytopic", "message 4")
|
|
||||||
m4.Time = 400
|
|
||||||
m5 := newDefaultMessage("mytopic", "message 5")
|
|
||||||
m5.Time = time.Now().Add(time.Minute).Unix() // Scheduled, in the future, later than m7
|
|
||||||
m6 := newDefaultMessage("mytopic", "message 6")
|
|
||||||
m6.Time = 600
|
|
||||||
m7 := newDefaultMessage("mytopic", "message 7")
|
|
||||||
m7.Time = 700
|
|
||||||
|
|
||||||
require.Nil(t, c.AddMessage(m1))
|
|
||||||
require.Nil(t, c.AddMessage(m2))
|
|
||||||
require.Nil(t, c.AddMessage(m3))
|
|
||||||
require.Nil(t, c.AddMessage(m4))
|
|
||||||
require.Nil(t, c.AddMessage(m5))
|
|
||||||
require.Nil(t, c.AddMessage(m6))
|
|
||||||
require.Nil(t, c.AddMessage(m7))
|
|
||||||
|
|
||||||
// Case 1: Since ID exists, exclude scheduled
|
|
||||||
messages, _ := c.Messages("mytopic", newSinceID(m2.ID), false)
|
|
||||||
require.Equal(t, 3, len(messages))
|
|
||||||
require.Equal(t, "message 4", messages[0].Message)
|
|
||||||
require.Equal(t, "message 6", messages[1].Message) // Not scheduled m3/m5!
|
|
||||||
require.Equal(t, "message 7", messages[2].Message)
|
|
||||||
|
|
||||||
// Case 2: Since ID exists, include scheduled
|
|
||||||
messages, _ = c.Messages("mytopic", newSinceID(m2.ID), true)
|
|
||||||
require.Equal(t, 5, len(messages))
|
|
||||||
require.Equal(t, "message 4", messages[0].Message)
|
|
||||||
require.Equal(t, "message 6", messages[1].Message)
|
|
||||||
require.Equal(t, "message 7", messages[2].Message)
|
|
||||||
require.Equal(t, "message 5", messages[3].Message) // Order!
|
|
||||||
require.Equal(t, "message 3", messages[4].Message) // Order!
|
|
||||||
|
|
||||||
// Case 3: Since ID does not exist (-> Return all messages), include scheduled
|
|
||||||
messages, _ = c.Messages("mytopic", newSinceID("doesntexist"), true)
|
|
||||||
require.Equal(t, 7, len(messages))
|
|
||||||
require.Equal(t, "message 1", messages[0].Message)
|
|
||||||
require.Equal(t, "message 2", messages[1].Message)
|
|
||||||
require.Equal(t, "message 4", messages[2].Message)
|
|
||||||
require.Equal(t, "message 6", messages[3].Message)
|
|
||||||
require.Equal(t, "message 7", messages[4].Message)
|
|
||||||
require.Equal(t, "message 5", messages[5].Message) // Order!
|
|
||||||
require.Equal(t, "message 3", messages[6].Message) // Order!
|
|
||||||
|
|
||||||
// Case 4: Since ID exists and is last message (-> Return no messages), exclude scheduled
|
|
||||||
messages, _ = c.Messages("mytopic", newSinceID(m7.ID), false)
|
|
||||||
require.Equal(t, 0, len(messages))
|
|
||||||
|
|
||||||
// Case 5: Since ID exists and is last message (-> Return no messages), include scheduled
|
|
||||||
messages, _ = c.Messages("mytopic", newSinceID(m7.ID), true)
|
|
||||||
require.Equal(t, 2, len(messages))
|
|
||||||
require.Equal(t, "message 5", messages[0].Message)
|
|
||||||
require.Equal(t, "message 3", messages[1].Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqliteCache_Prune(t *testing.T) {
|
|
||||||
testCachePrune(t, newSqliteTestCache(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMemCache_Prune(t *testing.T) {
|
|
||||||
testCachePrune(t, newMemTestCache(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func testCachePrune(t *testing.T, c *messageCache) {
|
|
||||||
now := time.Now().Unix()
|
|
||||||
|
|
||||||
m1 := newDefaultMessage("mytopic", "my message")
|
|
||||||
m1.Time = now - 10
|
|
||||||
m1.Expires = now - 5
|
|
||||||
|
|
||||||
m2 := newDefaultMessage("mytopic", "my other message")
|
|
||||||
m2.Time = now - 5
|
|
||||||
m2.Expires = now + 5 // In the future
|
|
||||||
|
|
||||||
m3 := newDefaultMessage("another_topic", "and another one")
|
|
||||||
m3.Time = now - 12
|
|
||||||
m3.Expires = now - 2
|
|
||||||
|
|
||||||
require.Nil(t, c.AddMessage(m1))
|
|
||||||
require.Nil(t, c.AddMessage(m2))
|
|
||||||
require.Nil(t, c.AddMessage(m3))
|
|
||||||
|
|
||||||
counts, err := c.MessageCounts()
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 2, counts["mytopic"])
|
|
||||||
require.Equal(t, 1, counts["another_topic"])
|
|
||||||
|
|
||||||
expiredMessageIDs, err := c.MessagesExpired()
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Nil(t, c.DeleteMessages(expiredMessageIDs...))
|
|
||||||
|
|
||||||
counts, err = c.MessageCounts()
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 1, counts["mytopic"])
|
|
||||||
require.Equal(t, 0, counts["another_topic"])
|
|
||||||
|
|
||||||
messages, err := c.Messages("mytopic", sinceAllMessages, false)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 1, len(messages))
|
|
||||||
require.Equal(t, "my other message", messages[0].Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqliteCache_Attachments(t *testing.T) {
|
|
||||||
testCacheAttachments(t, newSqliteTestCache(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMemCache_Attachments(t *testing.T) {
|
|
||||||
testCacheAttachments(t, newMemTestCache(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func testCacheAttachments(t *testing.T, c *messageCache) {
|
|
||||||
expires1 := time.Now().Add(-4 * time.Hour).Unix() // Expired
|
|
||||||
m := newDefaultMessage("mytopic", "flower for you")
|
|
||||||
m.ID = "m1"
|
|
||||||
m.SequenceID = "m1"
|
|
||||||
m.Sender = netip.MustParseAddr("1.2.3.4")
|
|
||||||
m.Attachment = &attachment{
|
|
||||||
Name: "flower.jpg",
|
|
||||||
Type: "image/jpeg",
|
|
||||||
Size: 5000,
|
|
||||||
Expires: expires1,
|
|
||||||
URL: "https://ntfy.sh/file/AbDeFgJhal.jpg",
|
|
||||||
}
|
|
||||||
require.Nil(t, c.AddMessage(m))
|
|
||||||
|
|
||||||
expires2 := time.Now().Add(2 * time.Hour).Unix() // Future
|
|
||||||
m = newDefaultMessage("mytopic", "sending you a car")
|
|
||||||
m.ID = "m2"
|
|
||||||
m.SequenceID = "m2"
|
|
||||||
m.Sender = netip.MustParseAddr("1.2.3.4")
|
|
||||||
m.Attachment = &attachment{
|
|
||||||
Name: "car.jpg",
|
|
||||||
Type: "image/jpeg",
|
|
||||||
Size: 10000,
|
|
||||||
Expires: expires2,
|
|
||||||
URL: "https://ntfy.sh/file/aCaRURL.jpg",
|
|
||||||
}
|
|
||||||
require.Nil(t, c.AddMessage(m))
|
|
||||||
|
|
||||||
expires3 := time.Now().Add(1 * time.Hour).Unix() // Future
|
|
||||||
m = newDefaultMessage("another-topic", "sending you another car")
|
|
||||||
m.ID = "m3"
|
|
||||||
m.SequenceID = "m3"
|
|
||||||
m.User = "u_BAsbaAa"
|
|
||||||
m.Sender = netip.MustParseAddr("5.6.7.8")
|
|
||||||
m.Attachment = &attachment{
|
|
||||||
Name: "another-car.jpg",
|
|
||||||
Type: "image/jpeg",
|
|
||||||
Size: 20000,
|
|
||||||
Expires: expires3,
|
|
||||||
URL: "https://ntfy.sh/file/zakaDHFW.jpg",
|
|
||||||
}
|
|
||||||
require.Nil(t, c.AddMessage(m))
|
|
||||||
|
|
||||||
messages, err := c.Messages("mytopic", sinceAllMessages, false)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 2, len(messages))
|
|
||||||
|
|
||||||
require.Equal(t, "flower for you", messages[0].Message)
|
|
||||||
require.Equal(t, "flower.jpg", messages[0].Attachment.Name)
|
|
||||||
require.Equal(t, "image/jpeg", messages[0].Attachment.Type)
|
|
||||||
require.Equal(t, int64(5000), messages[0].Attachment.Size)
|
|
||||||
require.Equal(t, expires1, messages[0].Attachment.Expires)
|
|
||||||
require.Equal(t, "https://ntfy.sh/file/AbDeFgJhal.jpg", messages[0].Attachment.URL)
|
|
||||||
require.Equal(t, "1.2.3.4", messages[0].Sender.String())
|
|
||||||
|
|
||||||
require.Equal(t, "sending you a car", messages[1].Message)
|
|
||||||
require.Equal(t, "car.jpg", messages[1].Attachment.Name)
|
|
||||||
require.Equal(t, "image/jpeg", messages[1].Attachment.Type)
|
|
||||||
require.Equal(t, int64(10000), messages[1].Attachment.Size)
|
|
||||||
require.Equal(t, expires2, messages[1].Attachment.Expires)
|
|
||||||
require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL)
|
|
||||||
require.Equal(t, "1.2.3.4", messages[1].Sender.String())
|
|
||||||
|
|
||||||
size, err := c.AttachmentBytesUsedBySender("1.2.3.4")
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, int64(10000), size)
|
|
||||||
|
|
||||||
size, err = c.AttachmentBytesUsedBySender("5.6.7.8")
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, int64(0), size) // Accounted to the user, not the IP!
|
|
||||||
|
|
||||||
size, err = c.AttachmentBytesUsedByUser("u_BAsbaAa")
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, int64(20000), size)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqliteCache_Attachments_Expired(t *testing.T) {
|
|
||||||
testCacheAttachmentsExpired(t, newSqliteTestCache(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMemCache_Attachments_Expired(t *testing.T) {
|
|
||||||
testCacheAttachmentsExpired(t, newMemTestCache(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func testCacheAttachmentsExpired(t *testing.T, c *messageCache) {
|
|
||||||
m := newDefaultMessage("mytopic", "flower for you")
|
|
||||||
m.ID = "m1"
|
|
||||||
m.SequenceID = "m1"
|
|
||||||
m.Expires = time.Now().Add(time.Hour).Unix()
|
|
||||||
require.Nil(t, c.AddMessage(m))
|
|
||||||
|
|
||||||
m = newDefaultMessage("mytopic", "message with attachment")
|
|
||||||
m.ID = "m2"
|
|
||||||
m.SequenceID = "m2"
|
|
||||||
m.Expires = time.Now().Add(2 * time.Hour).Unix()
|
|
||||||
m.Attachment = &attachment{
|
|
||||||
Name: "car.jpg",
|
|
||||||
Type: "image/jpeg",
|
|
||||||
Size: 10000,
|
|
||||||
Expires: time.Now().Add(2 * time.Hour).Unix(),
|
|
||||||
URL: "https://ntfy.sh/file/aCaRURL.jpg",
|
|
||||||
}
|
|
||||||
require.Nil(t, c.AddMessage(m))
|
|
||||||
|
|
||||||
m = newDefaultMessage("mytopic", "message with external attachment")
|
|
||||||
m.ID = "m3"
|
|
||||||
m.SequenceID = "m3"
|
|
||||||
m.Expires = time.Now().Add(2 * time.Hour).Unix()
|
|
||||||
m.Attachment = &attachment{
|
|
||||||
Name: "car.jpg",
|
|
||||||
Type: "image/jpeg",
|
|
||||||
Expires: 0, // Unknown!
|
|
||||||
URL: "https://somedomain.com/car.jpg",
|
|
||||||
}
|
|
||||||
require.Nil(t, c.AddMessage(m))
|
|
||||||
|
|
||||||
m = newDefaultMessage("mytopic2", "message with expired attachment")
|
|
||||||
m.ID = "m4"
|
|
||||||
m.SequenceID = "m4"
|
|
||||||
m.Expires = time.Now().Add(2 * time.Hour).Unix()
|
|
||||||
m.Attachment = &attachment{
|
|
||||||
Name: "expired-car.jpg",
|
|
||||||
Type: "image/jpeg",
|
|
||||||
Size: 20000,
|
|
||||||
Expires: time.Now().Add(-1 * time.Hour).Unix(),
|
|
||||||
URL: "https://ntfy.sh/file/aCaRURL.jpg",
|
|
||||||
}
|
|
||||||
require.Nil(t, c.AddMessage(m))
|
|
||||||
|
|
||||||
ids, err := c.AttachmentsExpired()
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 1, len(ids))
|
|
||||||
require.Equal(t, "m4", ids[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqliteCache_Migration_From0(t *testing.T) {
|
|
||||||
filename := newSqliteTestCacheFile(t)
|
|
||||||
db, err := sql.Open("sqlite3", filename)
|
|
||||||
require.Nil(t, err)
|
|
||||||
|
|
||||||
// Create "version 0" schema
|
|
||||||
_, err = db.Exec(`
|
|
||||||
BEGIN;
|
|
||||||
CREATE TABLE IF NOT EXISTS messages (
|
|
||||||
id VARCHAR(20) PRIMARY KEY,
|
|
||||||
time INT NOT NULL,
|
|
||||||
topic VARCHAR(64) NOT NULL,
|
|
||||||
message VARCHAR(1024) NOT NULL
|
|
||||||
);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
|
|
||||||
COMMIT;
|
|
||||||
`)
|
|
||||||
require.Nil(t, err)
|
|
||||||
|
|
||||||
// Insert a bunch of messages
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
_, err = db.Exec(`INSERT INTO messages (id, time, topic, message) VALUES (?, ?, ?, ?)`,
|
|
||||||
fmt.Sprintf("abcd%d", i), time.Now().Unix(), "mytopic", fmt.Sprintf("some message %d", i))
|
|
||||||
require.Nil(t, err)
|
|
||||||
}
|
|
||||||
require.Nil(t, db.Close())
|
|
||||||
|
|
||||||
// Create cache to trigger migration
|
|
||||||
c := newSqliteTestCacheFromFile(t, filename, "")
|
|
||||||
checkSchemaVersion(t, c.db)
|
|
||||||
|
|
||||||
messages, err := c.Messages("mytopic", sinceAllMessages, false)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 10, len(messages))
|
|
||||||
require.Equal(t, "some message 5", messages[5].Message)
|
|
||||||
require.Equal(t, "", messages[5].Title)
|
|
||||||
require.Nil(t, messages[5].Tags)
|
|
||||||
require.Equal(t, 0, messages[5].Priority)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqliteCache_Migration_From1(t *testing.T) {
|
|
||||||
filename := newSqliteTestCacheFile(t)
|
|
||||||
db, err := sql.Open("sqlite3", filename)
|
|
||||||
require.Nil(t, err)
|
|
||||||
|
|
||||||
// Create "version 1" schema
|
|
||||||
_, err = db.Exec(`
|
|
||||||
CREATE TABLE IF NOT EXISTS messages (
|
|
||||||
id VARCHAR(20) PRIMARY KEY,
|
|
||||||
time INT NOT NULL,
|
|
||||||
topic VARCHAR(64) NOT NULL,
|
|
||||||
message VARCHAR(512) NOT NULL,
|
|
||||||
title VARCHAR(256) NOT NULL,
|
|
||||||
priority INT NOT NULL,
|
|
||||||
tags VARCHAR(256) NOT NULL
|
|
||||||
);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
|
|
||||||
CREATE TABLE IF NOT EXISTS schemaVersion (
|
|
||||||
id INT PRIMARY KEY,
|
|
||||||
version INT NOT NULL
|
|
||||||
);
|
|
||||||
INSERT INTO schemaVersion (id, version) VALUES (1, 1);
|
|
||||||
`)
|
|
||||||
require.Nil(t, err)
|
|
||||||
|
|
||||||
// Insert a bunch of messages
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
_, err = db.Exec(`INSERT INTO messages (id, time, topic, message, title, priority, tags) VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
|
||||||
fmt.Sprintf("abcd%d", i), time.Now().Unix(), "mytopic", fmt.Sprintf("some message %d", i), "", 0, "")
|
|
||||||
require.Nil(t, err)
|
|
||||||
}
|
|
||||||
require.Nil(t, db.Close())
|
|
||||||
|
|
||||||
// Create cache to trigger migration
|
|
||||||
c := newSqliteTestCacheFromFile(t, filename, "")
|
|
||||||
checkSchemaVersion(t, c.db)
|
|
||||||
|
|
||||||
// Add delayed message
|
|
||||||
delayedMessage := newDefaultMessage("mytopic", "some delayed message")
|
|
||||||
delayedMessage.Time = time.Now().Add(time.Minute).Unix()
|
|
||||||
require.Nil(t, c.AddMessage(delayedMessage))
|
|
||||||
|
|
||||||
// 10, not 11!
|
|
||||||
messages, err := c.Messages("mytopic", sinceAllMessages, false)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 10, len(messages))
|
|
||||||
|
|
||||||
// 11!
|
|
||||||
messages, err = c.Messages("mytopic", sinceAllMessages, true)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 11, len(messages))
|
|
||||||
|
|
||||||
// Check that index "idx_topic" exists
|
|
||||||
rows, err := c.db.Query(`SELECT name FROM sqlite_master WHERE type='index' AND name='idx_topic'`)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.True(t, rows.Next())
|
|
||||||
var indexName string
|
|
||||||
require.Nil(t, rows.Scan(&indexName))
|
|
||||||
require.Equal(t, "idx_topic", indexName)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqliteCache_Migration_From9(t *testing.T) {
|
|
||||||
// This primarily tests the awkward migration that introduces the "expires" column.
|
|
||||||
// The migration logic has to update the column, using the existing "cache-duration" value.
|
|
||||||
|
|
||||||
filename := newSqliteTestCacheFile(t)
|
|
||||||
db, err := sql.Open("sqlite3", filename)
|
|
||||||
require.Nil(t, err)
|
|
||||||
|
|
||||||
// Create "version 8" schema
|
|
||||||
_, err = db.Exec(`
|
|
||||||
BEGIN;
|
|
||||||
CREATE TABLE IF NOT EXISTS messages (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
mid TEXT NOT NULL,
|
|
||||||
time INT NOT NULL,
|
|
||||||
topic TEXT NOT NULL,
|
|
||||||
message TEXT NOT NULL,
|
|
||||||
title TEXT NOT NULL,
|
|
||||||
priority INT NOT NULL,
|
|
||||||
tags TEXT NOT NULL,
|
|
||||||
click TEXT NOT NULL,
|
|
||||||
icon TEXT NOT NULL,
|
|
||||||
actions TEXT NOT NULL,
|
|
||||||
attachment_name TEXT NOT NULL,
|
|
||||||
attachment_type TEXT NOT NULL,
|
|
||||||
attachment_size INT NOT NULL,
|
|
||||||
attachment_expires INT NOT NULL,
|
|
||||||
attachment_url TEXT NOT NULL,
|
|
||||||
sender TEXT NOT NULL,
|
|
||||||
encoding TEXT NOT NULL,
|
|
||||||
published INT NOT NULL
|
|
||||||
);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_mid ON messages (mid);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_time ON messages (time);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_topic ON messages (topic);
|
|
||||||
CREATE TABLE IF NOT EXISTS schemaVersion (
|
|
||||||
id INT PRIMARY KEY,
|
|
||||||
version INT NOT NULL
|
|
||||||
);
|
|
||||||
INSERT INTO schemaVersion (id, version) VALUES (1, 9);
|
|
||||||
COMMIT;
|
|
||||||
`)
|
|
||||||
require.Nil(t, err)
|
|
||||||
|
|
||||||
// Insert a bunch of messages
|
|
||||||
insertQuery := `
|
|
||||||
INSERT INTO messages (mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding, published)
|
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
||||||
`
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
_, err = db.Exec(
|
|
||||||
insertQuery,
|
|
||||||
fmt.Sprintf("abcd%d", i),
|
|
||||||
time.Now().Unix(),
|
|
||||||
"mytopic",
|
|
||||||
fmt.Sprintf("some message %d", i),
|
|
||||||
"", // title
|
|
||||||
0, // priority
|
|
||||||
"", // tags
|
|
||||||
"", // click
|
|
||||||
"", // icon
|
|
||||||
"", // actions
|
|
||||||
"", // attachment_name
|
|
||||||
"", // attachment_type
|
|
||||||
0, // attachment_size
|
|
||||||
0, // attachment_type
|
|
||||||
"", // attachment_url
|
|
||||||
"9.9.9.9", // sender
|
|
||||||
"", // encoding
|
|
||||||
1, // published
|
|
||||||
)
|
|
||||||
require.Nil(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create cache to trigger migration
|
|
||||||
cacheDuration := 17 * time.Hour
|
|
||||||
c, err := newSqliteCache(filename, "", cacheDuration, 0, 0, false)
|
|
||||||
require.Nil(t, err)
|
|
||||||
checkSchemaVersion(t, c.db)
|
|
||||||
|
|
||||||
// Check version
|
|
||||||
rows, err := db.Query(`SELECT version FROM main.schemaVersion WHERE id = 1`)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.True(t, rows.Next())
|
|
||||||
var version int
|
|
||||||
require.Nil(t, rows.Scan(&version))
|
|
||||||
require.Equal(t, currentSchemaVersion, version)
|
|
||||||
|
|
||||||
messages, err := c.Messages("mytopic", sinceAllMessages, false)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 10, len(messages))
|
|
||||||
for _, m := range messages {
|
|
||||||
require.True(t, m.Expires > time.Now().Add(cacheDuration-5*time.Second).Unix())
|
|
||||||
require.True(t, m.Expires < time.Now().Add(cacheDuration+5*time.Second).Unix())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqliteCache_StartupQueries_WAL(t *testing.T) {
|
|
||||||
filename := newSqliteTestCacheFile(t)
|
|
||||||
startupQueries := `pragma journal_mode = WAL;
|
|
||||||
pragma synchronous = normal;
|
|
||||||
pragma temp_store = memory;`
|
|
||||||
db, err := newSqliteCache(filename, startupQueries, time.Hour, 0, 0, false)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Nil(t, db.AddMessage(newDefaultMessage("mytopic", "some message")))
|
|
||||||
require.FileExists(t, filename)
|
|
||||||
require.FileExists(t, filename+"-wal")
|
|
||||||
require.FileExists(t, filename+"-shm")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqliteCache_StartupQueries_None(t *testing.T) {
|
|
||||||
filename := newSqliteTestCacheFile(t)
|
|
||||||
startupQueries := ""
|
|
||||||
db, err := newSqliteCache(filename, startupQueries, time.Hour, 0, 0, false)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Nil(t, db.AddMessage(newDefaultMessage("mytopic", "some message")))
|
|
||||||
require.FileExists(t, filename)
|
|
||||||
require.NoFileExists(t, filename+"-wal")
|
|
||||||
require.NoFileExists(t, filename+"-shm")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqliteCache_StartupQueries_Fail(t *testing.T) {
|
|
||||||
filename := newSqliteTestCacheFile(t)
|
|
||||||
startupQueries := `xx error`
|
|
||||||
_, err := newSqliteCache(filename, startupQueries, time.Hour, 0, 0, false)
|
|
||||||
require.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqliteCache_Sender(t *testing.T) {
|
|
||||||
testSender(t, newSqliteTestCache(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMemCache_Sender(t *testing.T) {
|
|
||||||
testSender(t, newMemTestCache(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func testSender(t *testing.T, c *messageCache) {
|
|
||||||
m1 := newDefaultMessage("mytopic", "mymessage")
|
|
||||||
m1.Sender = netip.MustParseAddr("1.2.3.4")
|
|
||||||
require.Nil(t, c.AddMessage(m1))
|
|
||||||
|
|
||||||
m2 := newDefaultMessage("mytopic", "mymessage without sender")
|
|
||||||
require.Nil(t, c.AddMessage(m2))
|
|
||||||
|
|
||||||
messages, err := c.Messages("mytopic", sinceAllMessages, false)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 2, len(messages))
|
|
||||||
require.Equal(t, messages[0].Sender, netip.MustParseAddr("1.2.3.4"))
|
|
||||||
require.Equal(t, messages[1].Sender, netip.Addr{})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqliteCache_DeleteScheduledBySequenceID(t *testing.T) {
|
|
||||||
testDeleteScheduledBySequenceID(t, newSqliteTestCache(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMemCache_DeleteScheduledBySequenceID(t *testing.T) {
|
|
||||||
testDeleteScheduledBySequenceID(t, newMemTestCache(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
func testDeleteScheduledBySequenceID(t *testing.T, c *messageCache) {
|
|
||||||
// Create a scheduled (unpublished) message
|
|
||||||
scheduledMsg := newDefaultMessage("mytopic", "scheduled message")
|
|
||||||
scheduledMsg.ID = "scheduled1"
|
|
||||||
scheduledMsg.SequenceID = "seq123"
|
|
||||||
scheduledMsg.Time = time.Now().Add(time.Hour).Unix() // Future time makes it scheduled
|
|
||||||
require.Nil(t, c.AddMessage(scheduledMsg))
|
|
||||||
|
|
||||||
// Create a published message with different sequence ID
|
|
||||||
publishedMsg := newDefaultMessage("mytopic", "published message")
|
|
||||||
publishedMsg.ID = "published1"
|
|
||||||
publishedMsg.SequenceID = "seq456"
|
|
||||||
publishedMsg.Time = time.Now().Add(-time.Hour).Unix() // Past time makes it published
|
|
||||||
require.Nil(t, c.AddMessage(publishedMsg))
|
|
||||||
|
|
||||||
// Create a scheduled message in a different topic
|
|
||||||
otherTopicMsg := newDefaultMessage("othertopic", "other scheduled")
|
|
||||||
otherTopicMsg.ID = "other1"
|
|
||||||
otherTopicMsg.SequenceID = "seq123" // Same sequence ID as scheduledMsg
|
|
||||||
otherTopicMsg.Time = time.Now().Add(time.Hour).Unix()
|
|
||||||
require.Nil(t, c.AddMessage(otherTopicMsg))
|
|
||||||
|
|
||||||
// Verify all messages exist (including scheduled)
|
|
||||||
messages, err := c.Messages("mytopic", sinceAllMessages, true)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 2, len(messages))
|
|
||||||
|
|
||||||
messages, err = c.Messages("othertopic", sinceAllMessages, true)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 1, len(messages))
|
|
||||||
|
|
||||||
// Delete scheduled message by sequence ID and verify returned IDs
|
|
||||||
deletedIDs, err := c.DeleteScheduledBySequenceID("mytopic", "seq123")
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 1, len(deletedIDs))
|
|
||||||
require.Equal(t, "scheduled1", deletedIDs[0])
|
|
||||||
|
|
||||||
// Verify scheduled message is deleted
|
|
||||||
messages, err = c.Messages("mytopic", sinceAllMessages, true)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 1, len(messages))
|
|
||||||
require.Equal(t, "published message", messages[0].Message)
|
|
||||||
|
|
||||||
// Verify other topic's message still exists (topic-scoped deletion)
|
|
||||||
messages, err = c.Messages("othertopic", sinceAllMessages, true)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 1, len(messages))
|
|
||||||
require.Equal(t, "other scheduled", messages[0].Message)
|
|
||||||
|
|
||||||
// Deleting non-existent sequence ID should return empty list
|
|
||||||
deletedIDs, err = c.DeleteScheduledBySequenceID("mytopic", "nonexistent")
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Empty(t, deletedIDs)
|
|
||||||
|
|
||||||
// Deleting published message should not affect it (only deletes unpublished)
|
|
||||||
deletedIDs, err = c.DeleteScheduledBySequenceID("mytopic", "seq456")
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Empty(t, deletedIDs)
|
|
||||||
|
|
||||||
messages, err = c.Messages("mytopic", sinceAllMessages, true)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 1, len(messages))
|
|
||||||
require.Equal(t, "published message", messages[0].Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkSchemaVersion(t *testing.T, db *sql.DB) {
|
|
||||||
rows, err := db.Query(`SELECT version FROM schemaVersion`)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.True(t, rows.Next())
|
|
||||||
|
|
||||||
var schemaVersion int
|
|
||||||
require.Nil(t, rows.Scan(&schemaVersion))
|
|
||||||
require.Equal(t, currentSchemaVersion, schemaVersion)
|
|
||||||
require.Nil(t, rows.Close())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMemCache_NopCache(t *testing.T) {
|
|
||||||
c, _ := newNopCache()
|
|
||||||
require.Nil(t, c.AddMessage(newDefaultMessage("mytopic", "my message")))
|
|
||||||
|
|
||||||
messages, err := c.Messages("mytopic", sinceAllMessages, false)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Empty(t, messages)
|
|
||||||
|
|
||||||
topics, err := c.Topics()
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Empty(t, topics)
|
|
||||||
}
|
|
||||||
|
|
||||||
func newSqliteTestCache(t *testing.T) *messageCache {
|
|
||||||
c, err := newSqliteCache(newSqliteTestCacheFile(t), "", time.Hour, 0, 0, false)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
func newSqliteTestCacheFile(t *testing.T) string {
|
|
||||||
return filepath.Join(t.TempDir(), "cache.db")
|
|
||||||
}
|
|
||||||
|
|
||||||
func newSqliteTestCacheFromFile(t *testing.T, filename, startupQueries string) *messageCache {
|
|
||||||
c, err := newSqliteCache(filename, startupQueries, time.Hour, 0, 0, false)
|
|
||||||
require.Nil(t, err)
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
func newMemTestCache(t *testing.T) *messageCache {
|
|
||||||
c, err := newMemCache()
|
|
||||||
require.Nil(t, err)
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
@@ -33,6 +33,8 @@ import (
|
|||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
"gopkg.in/yaml.v2"
|
"gopkg.in/yaml.v2"
|
||||||
"heckel.io/ntfy/v2/log"
|
"heckel.io/ntfy/v2/log"
|
||||||
|
"heckel.io/ntfy/v2/message"
|
||||||
|
"heckel.io/ntfy/v2/model"
|
||||||
"heckel.io/ntfy/v2/payments"
|
"heckel.io/ntfy/v2/payments"
|
||||||
"heckel.io/ntfy/v2/user"
|
"heckel.io/ntfy/v2/user"
|
||||||
"heckel.io/ntfy/v2/util"
|
"heckel.io/ntfy/v2/util"
|
||||||
@@ -57,7 +59,7 @@ type Server struct {
|
|||||||
messages int64 // Total number of messages (persisted if messageCache enabled)
|
messages int64 // Total number of messages (persisted if messageCache enabled)
|
||||||
messagesHistory []int64 // Last n values of the messages counter, used to determine rate
|
messagesHistory []int64 // Last n values of the messages counter, used to determine rate
|
||||||
userManager *user.Manager // Might be nil!
|
userManager *user.Manager // Might be nil!
|
||||||
messageCache *messageCache // Database that stores the messages
|
messageCache message.Store // Database that stores the messages
|
||||||
webPush webpush.Store // Database that stores web push subscriptions
|
webPush webpush.Store // Database that stores web push subscriptions
|
||||||
fileCache *fileCache // File system based cache that stores attachments
|
fileCache *fileCache // File system based cache that stores attachments
|
||||||
stripe stripeAPI // Stripe API, can be replaced with a mock
|
stripe stripeAPI // Stripe API, can be replaced with a mock
|
||||||
@@ -188,10 +190,14 @@ func New(conf *Config) (*Server, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
topics, err := messageCache.Topics()
|
topicIDs, err := messageCache.Topics()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
topics := make(map[string]*topic, len(topicIDs))
|
||||||
|
for _, id := range topicIDs {
|
||||||
|
topics[id] = newTopic(id)
|
||||||
|
}
|
||||||
messages, err := messageCache.Stats()
|
messages, err := messageCache.Stats()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -263,13 +269,15 @@ func New(conf *Config) (*Server, error) {
|
|||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createMessageCache(conf *Config) (*messageCache, error) {
|
func createMessageCache(conf *Config) (message.Store, error) {
|
||||||
if conf.CacheDuration == 0 {
|
if conf.CacheDuration == 0 {
|
||||||
return newNopCache()
|
return message.NewNopStore()
|
||||||
|
} else if conf.DatabaseURL != "" {
|
||||||
|
return message.NewPostgresStore(conf.DatabaseURL, conf.CacheBatchSize, conf.CacheBatchTimeout)
|
||||||
} else if conf.CacheFile != "" {
|
} else if conf.CacheFile != "" {
|
||||||
return newSqliteCache(conf.CacheFile, conf.CacheStartupQueries, conf.CacheDuration, conf.CacheBatchSize, conf.CacheBatchTimeout, false)
|
return message.NewSQLiteStore(conf.CacheFile, conf.CacheStartupQueries, conf.CacheDuration, conf.CacheBatchSize, conf.CacheBatchTimeout, false)
|
||||||
}
|
}
|
||||||
return newMemCache()
|
return message.NewMemStore()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run executes the main server. It listens on HTTP (+ HTTPS, if configured), and starts
|
// Run executes the main server. It listens on HTTP (+ HTTPS, if configured), and starts
|
||||||
@@ -750,7 +758,7 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor)
|
|||||||
if s.config.CacheBatchTimeout > 0 {
|
if s.config.CacheBatchTimeout > 0 {
|
||||||
// Strange edge case: If we immediately after upload request the file (the web app does this for images),
|
// Strange edge case: If we immediately after upload request the file (the web app does this for images),
|
||||||
// and messages are persisted asynchronously, retry fetching from the database
|
// and messages are persisted asynchronously, retry fetching from the database
|
||||||
m, err = util.Retry(func() (*message, error) {
|
m, err = util.Retry(func() (*model.Message, error) {
|
||||||
return s.messageCache.Message(messageID)
|
return s.messageCache.Message(messageID)
|
||||||
}, s.config.CacheBatchTimeout, 100*time.Millisecond, 300*time.Millisecond, 600*time.Millisecond)
|
}, s.config.CacheBatchTimeout, 100*time.Millisecond, 300*time.Millisecond, 600*time.Millisecond)
|
||||||
}
|
}
|
||||||
@@ -796,7 +804,7 @@ func (s *Server) handleMatrixDiscovery(w http.ResponseWriter) error {
|
|||||||
return writeMatrixDiscoveryResponse(w)
|
return writeMatrixDiscoveryResponse(w)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*message, error) {
|
func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*model.Message, error) {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
t, err := fromContext[*topic](r, contextTopic)
|
t, err := fromContext[*topic](r, contextTopic)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -924,7 +932,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
minc(metricMessagesPublishedSuccess)
|
minc(metricMessagesPublishedSuccess)
|
||||||
return s.writeJSON(w, m.forJSON())
|
return s.writeJSON(w, m.ForJSON())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handlePublishMatrix(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
func (s *Server) handlePublishMatrix(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
@@ -1014,10 +1022,10 @@ func (s *Server) handleActionMessage(w http.ResponseWriter, r *http.Request, v *
|
|||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
s.messages++
|
s.messages++
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
return s.writeJSON(w, m.forJSON())
|
return s.writeJSON(w, m.ForJSON())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) sendToFirebase(v *visitor, m *message) {
|
func (s *Server) sendToFirebase(v *visitor, m *model.Message) {
|
||||||
logvm(v, m).Tag(tagFirebase).Debug("Publishing to Firebase")
|
logvm(v, m).Tag(tagFirebase).Debug("Publishing to Firebase")
|
||||||
if err := s.firebaseClient.Send(v, m); err != nil {
|
if err := s.firebaseClient.Send(v, m); err != nil {
|
||||||
minc(metricFirebasePublishedFailure)
|
minc(metricFirebasePublishedFailure)
|
||||||
@@ -1031,7 +1039,7 @@ func (s *Server) sendToFirebase(v *visitor, m *message) {
|
|||||||
minc(metricFirebasePublishedSuccess)
|
minc(metricFirebasePublishedSuccess)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) sendEmail(v *visitor, m *message, email string) {
|
func (s *Server) sendEmail(v *visitor, m *model.Message, email string) {
|
||||||
logvm(v, m).Tag(tagEmail).Field("email", email).Debug("Sending email to %s", email)
|
logvm(v, m).Tag(tagEmail).Field("email", email).Debug("Sending email to %s", email)
|
||||||
if err := s.smtpSender.Send(v, m, email); err != nil {
|
if err := s.smtpSender.Send(v, m, email); err != nil {
|
||||||
logvm(v, m).Tag(tagEmail).Field("email", email).Err(err).Warn("Unable to send email to %s: %v", email, err.Error())
|
logvm(v, m).Tag(tagEmail).Field("email", email).Err(err).Warn("Unable to send email to %s: %v", email, err.Error())
|
||||||
@@ -1041,7 +1049,7 @@ func (s *Server) sendEmail(v *visitor, m *message, email string) {
|
|||||||
minc(metricEmailsPublishedSuccess)
|
minc(metricEmailsPublishedSuccess)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) forwardPollRequest(v *visitor, m *message) {
|
func (s *Server) forwardPollRequest(v *visitor, m *model.Message) {
|
||||||
topicURL := fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic)
|
topicURL := fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic)
|
||||||
topicHash := fmt.Sprintf("%x", sha256.Sum256([]byte(topicURL)))
|
topicHash := fmt.Sprintf("%x", sha256.Sum256([]byte(topicURL)))
|
||||||
forwardURL := fmt.Sprintf("%s/%s", s.config.UpstreamBaseURL, topicHash)
|
forwardURL := fmt.Sprintf("%s/%s", s.config.UpstreamBaseURL, topicHash)
|
||||||
@@ -1073,7 +1081,7 @@ func (s *Server) forwardPollRequest(v *visitor, m *message) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, firebase bool, email, call string, template templateMode, unifiedpush bool, priorityStr string, err *errHTTP) {
|
func (s *Server) parsePublishParams(r *http.Request, m *model.Message) (cache bool, firebase bool, email, call string, template templateMode, unifiedpush bool, priorityStr string, err *errHTTP) {
|
||||||
if r.Method != http.MethodGet && updatePathRegex.MatchString(r.URL.Path) {
|
if r.Method != http.MethodGet && updatePathRegex.MatchString(r.URL.Path) {
|
||||||
pathSequenceID, err := s.sequenceIDFromPath(r.URL.Path)
|
pathSequenceID, err := s.sequenceIDFromPath(r.URL.Path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1100,7 +1108,7 @@ func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, fi
|
|||||||
filename := readParam(r, "x-filename", "filename", "file", "f")
|
filename := readParam(r, "x-filename", "filename", "file", "f")
|
||||||
attach := readParam(r, "x-attach", "attach", "a")
|
attach := readParam(r, "x-attach", "attach", "a")
|
||||||
if attach != "" || filename != "" {
|
if attach != "" || filename != "" {
|
||||||
m.Attachment = &attachment{}
|
m.Attachment = &model.Attachment{}
|
||||||
}
|
}
|
||||||
if filename != "" {
|
if filename != "" {
|
||||||
m.Attachment.Name = filename
|
m.Attachment.Name = filename
|
||||||
@@ -1221,7 +1229,7 @@ func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, fi
|
|||||||
// If file.txt is <= 4096 (message limit) and valid UTF-8, treat it as a message
|
// If file.txt is <= 4096 (message limit) and valid UTF-8, treat it as a message
|
||||||
// 7. curl -T file.txt ntfy.sh/mytopic
|
// 7. curl -T file.txt ntfy.sh/mytopic
|
||||||
// In all other cases, mostly if file.txt is > message limit, treat it as an attachment
|
// In all other cases, mostly if file.txt is > message limit, treat it as an attachment
|
||||||
func (s *Server) handlePublishBody(r *http.Request, v *visitor, m *message, body *util.PeekedReadCloser, template templateMode, unifiedpush bool, priorityStr string) error {
|
func (s *Server) handlePublishBody(r *http.Request, v *visitor, m *model.Message, body *util.PeekedReadCloser, template templateMode, unifiedpush bool, priorityStr string) error {
|
||||||
if m.Event == pollRequestEvent { // Case 1
|
if m.Event == pollRequestEvent { // Case 1
|
||||||
return s.handleBodyDiscard(body)
|
return s.handleBodyDiscard(body)
|
||||||
} else if unifiedpush {
|
} else if unifiedpush {
|
||||||
@@ -1244,7 +1252,7 @@ func (s *Server) handleBodyDiscard(body *util.PeekedReadCloser) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleBodyAsMessageAutoDetect(m *message, body *util.PeekedReadCloser) error {
|
func (s *Server) handleBodyAsMessageAutoDetect(m *model.Message, body *util.PeekedReadCloser) error {
|
||||||
if utf8.Valid(body.PeekedBytes) {
|
if utf8.Valid(body.PeekedBytes) {
|
||||||
m.Message = string(body.PeekedBytes) // Do not trim
|
m.Message = string(body.PeekedBytes) // Do not trim
|
||||||
} else {
|
} else {
|
||||||
@@ -1254,7 +1262,7 @@ func (s *Server) handleBodyAsMessageAutoDetect(m *message, body *util.PeekedRead
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleBodyAsTextMessage(m *message, body *util.PeekedReadCloser) error {
|
func (s *Server) handleBodyAsTextMessage(m *model.Message, body *util.PeekedReadCloser) error {
|
||||||
if !utf8.Valid(body.PeekedBytes) {
|
if !utf8.Valid(body.PeekedBytes) {
|
||||||
return errHTTPBadRequestMessageNotUTF8.With(m)
|
return errHTTPBadRequestMessageNotUTF8.With(m)
|
||||||
}
|
}
|
||||||
@@ -1267,7 +1275,7 @@ func (s *Server) handleBodyAsTextMessage(m *message, body *util.PeekedReadCloser
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleBodyAsTemplatedTextMessage(m *message, template templateMode, body *util.PeekedReadCloser, priorityStr string) error {
|
func (s *Server) handleBodyAsTemplatedTextMessage(m *model.Message, template templateMode, body *util.PeekedReadCloser, priorityStr string) error {
|
||||||
body, err := util.Peek(body, max(s.config.MessageSizeLimit, jsonBodyBytesLimit))
|
body, err := util.Peek(body, max(s.config.MessageSizeLimit, jsonBodyBytesLimit))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -1292,7 +1300,7 @@ func (s *Server) handleBodyAsTemplatedTextMessage(m *message, template templateM
|
|||||||
|
|
||||||
// renderTemplateFromFile transforms the JSON message body according to a template from the filesystem.
|
// renderTemplateFromFile transforms the JSON message body according to a template from the filesystem.
|
||||||
// The template file must be in the templates directory, or in the configured template directory.
|
// The template file must be in the templates directory, or in the configured template directory.
|
||||||
func (s *Server) renderTemplateFromFile(m *message, templateName, peekedBody string) error {
|
func (s *Server) renderTemplateFromFile(m *model.Message, templateName, peekedBody string) error {
|
||||||
if !templateNameRegex.MatchString(templateName) {
|
if !templateNameRegex.MatchString(templateName) {
|
||||||
return errHTTPBadRequestTemplateFileNotFound
|
return errHTTPBadRequestTemplateFileNotFound
|
||||||
}
|
}
|
||||||
@@ -1334,7 +1342,7 @@ func (s *Server) renderTemplateFromFile(m *message, templateName, peekedBody str
|
|||||||
|
|
||||||
// renderTemplateFromParams transforms the JSON message body according to the inline template in the
|
// renderTemplateFromParams transforms the JSON message body according to the inline template in the
|
||||||
// message, title, and priority parameters.
|
// message, title, and priority parameters.
|
||||||
func (s *Server) renderTemplateFromParams(m *message, peekedBody string, priorityStr string) error {
|
func (s *Server) renderTemplateFromParams(m *model.Message, peekedBody string, priorityStr string) error {
|
||||||
var err error
|
var err error
|
||||||
if m.Message, err = s.renderTemplate("priority query parameter", m.Message, peekedBody); err != nil {
|
if m.Message, err = s.renderTemplate("priority query parameter", m.Message, peekedBody); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -1375,7 +1383,7 @@ func (s *Server) renderTemplate(name, tpl, source string) (string, error) {
|
|||||||
return strings.TrimSpace(strings.ReplaceAll(buf.String(), "\\n", "\n")), nil // replace any remaining "\n" (those outside of template curly braces) with newlines
|
return strings.TrimSpace(strings.ReplaceAll(buf.String(), "\\n", "\n")), nil // replace any remaining "\n" (those outside of template curly braces) with newlines
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message, body *util.PeekedReadCloser) error {
|
func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *model.Message, body *util.PeekedReadCloser) error {
|
||||||
if s.fileCache == nil || s.config.BaseURL == "" || s.config.AttachmentCacheDir == "" {
|
if s.fileCache == nil || s.config.BaseURL == "" || s.config.AttachmentCacheDir == "" {
|
||||||
return errHTTPBadRequestAttachmentsDisallowed.With(m)
|
return errHTTPBadRequestAttachmentsDisallowed.With(m)
|
||||||
}
|
}
|
||||||
@@ -1399,7 +1407,7 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if m.Attachment == nil {
|
if m.Attachment == nil {
|
||||||
m.Attachment = &attachment{}
|
m.Attachment = &model.Attachment{}
|
||||||
}
|
}
|
||||||
var ext string
|
var ext string
|
||||||
m.Attachment.Expires = attachmentExpiry
|
m.Attachment.Expires = attachmentExpiry
|
||||||
@@ -1426,9 +1434,9 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
encoder := func(msg *message) (string, error) {
|
encoder := func(msg *model.Message) (string, error) {
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
if err := json.NewEncoder(&buf).Encode(msg.forJSON()); err != nil {
|
if err := json.NewEncoder(&buf).Encode(msg.ForJSON()); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
return buf.String(), nil
|
return buf.String(), nil
|
||||||
@@ -1437,9 +1445,9 @@ func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request, v *
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
encoder := func(msg *message) (string, error) {
|
encoder := func(msg *model.Message) (string, error) {
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
if err := json.NewEncoder(&buf).Encode(msg.forJSON()); err != nil {
|
if err := json.NewEncoder(&buf).Encode(msg.ForJSON()); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
if msg.Event != messageEvent && msg.Event != messageDeleteEvent && msg.Event != messageClearEvent {
|
if msg.Event != messageEvent && msg.Event != messageDeleteEvent && msg.Event != messageClearEvent {
|
||||||
@@ -1451,7 +1459,7 @@ func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request, v *v
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
encoder := func(msg *message) (string, error) {
|
encoder := func(msg *model.Message) (string, error) {
|
||||||
if msg.Event == messageEvent { // only handle default events
|
if msg.Event == messageEvent { // only handle default events
|
||||||
return strings.ReplaceAll(msg.Message, "\n", " ") + "\n", nil
|
return strings.ReplaceAll(msg.Message, "\n", " ") + "\n", nil
|
||||||
}
|
}
|
||||||
@@ -1487,7 +1495,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
|||||||
closed = true
|
closed = true
|
||||||
wlock.Unlock()
|
wlock.Unlock()
|
||||||
}()
|
}()
|
||||||
sub := func(v *visitor, msg *message) error {
|
sub := func(v *visitor, msg *model.Message) error {
|
||||||
if !filters.Pass(msg) {
|
if !filters.Pass(msg) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -1649,7 +1657,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
sub := func(v *visitor, msg *message) error {
|
sub := func(v *visitor, msg *model.Message) error {
|
||||||
if !filters.Pass(msg) {
|
if !filters.Pass(msg) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -1696,7 +1704,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, scheduled bool, filters *queryFilter, err error) {
|
func parseSubscribeParams(r *http.Request) (poll bool, since model.SinceMarker, scheduled bool, filters *queryFilter, err error) {
|
||||||
poll = readBoolParam(r, false, "x-poll", "poll", "po")
|
poll = readBoolParam(r, false, "x-poll", "poll", "po")
|
||||||
scheduled = readBoolParam(r, false, "x-scheduled", "scheduled", "sched")
|
scheduled = readBoolParam(r, false, "x-scheduled", "scheduled", "sched")
|
||||||
since, err = parseSince(r, poll)
|
since, err = parseSince(r, poll)
|
||||||
@@ -1777,11 +1785,11 @@ func (s *Server) setRateVisitors(r *http.Request, v *visitor, rateTopics []*topi
|
|||||||
|
|
||||||
// sendOldMessages selects old messages from the messageCache and calls sub for each of them. It uses since as the
|
// sendOldMessages selects old messages from the messageCache and calls sub for each of them. It uses since as the
|
||||||
// marker, returning only messages that are newer than the marker.
|
// marker, returning only messages that are newer than the marker.
|
||||||
func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, v *visitor, sub subscriber) error {
|
func (s *Server) sendOldMessages(topics []*topic, since model.SinceMarker, scheduled bool, v *visitor, sub subscriber) error {
|
||||||
if since.IsNone() {
|
if since.IsNone() {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
messages := make([]*message, 0)
|
messages := make([]*model.Message, 0)
|
||||||
for _, t := range topics {
|
for _, t := range topics {
|
||||||
topicMessages, err := s.messageCache.Messages(t.ID, since, scheduled)
|
topicMessages, err := s.messageCache.Messages(t.ID, since, scheduled)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1804,7 +1812,7 @@ func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled b
|
|||||||
//
|
//
|
||||||
// Values in the "since=..." parameter can be either a unix timestamp or a duration (e.g. 12h),
|
// Values in the "since=..." parameter can be either a unix timestamp or a duration (e.g. 12h),
|
||||||
// "all" for all messages, or "latest" for the most recent message for a topic
|
// "all" for all messages, or "latest" for the most recent message for a topic
|
||||||
func parseSince(r *http.Request, poll bool) (sinceMarker, error) {
|
func parseSince(r *http.Request, poll bool) (model.SinceMarker, error) {
|
||||||
since := readParam(r, "x-since", "since", "si")
|
since := readParam(r, "x-since", "since", "si")
|
||||||
|
|
||||||
// Easy cases (empty, all, none)
|
// Easy cases (empty, all, none)
|
||||||
@@ -2035,7 +2043,7 @@ func (s *Server) sendDelayedMessages() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
|
func (s *Server) sendDelayedMessage(v *visitor, m *model.Message) error {
|
||||||
logvm(v, m).Debug("Sending delayed message")
|
logvm(v, m).Debug("Sending delayed message")
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published
|
t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -11,427 +11,451 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestVersion_Admin(t *testing.T) {
|
func TestVersion_Admin(t *testing.T) {
|
||||||
c := newTestConfigWithAuthFile(t)
|
forEachBackend(t, func(t *testing.T) {
|
||||||
c.BuildVersion = "1.2.3"
|
c := newTestConfigWithAuthFile(t)
|
||||||
c.BuildCommit = "abcdef0"
|
c.BuildVersion = "1.2.3"
|
||||||
c.BuildDate = "2026-02-08T00:00:00Z"
|
c.BuildCommit = "abcdef0"
|
||||||
s := newTestServer(t, c)
|
c.BuildDate = "2026-02-08T00:00:00Z"
|
||||||
defer s.closeDatabases()
|
s := newTestServer(t, c)
|
||||||
|
defer s.closeDatabases()
|
||||||
|
|
||||||
// Create admin and regular user
|
// Create admin and regular user
|
||||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
||||||
|
|
||||||
// Admin can access /v1/version
|
// Admin can access /v1/version
|
||||||
rr := request(t, s, "GET", "/v1/version", "", map[string]string{
|
rr := request(t, s, "GET", "/v1/version", "", map[string]string{
|
||||||
"Authorization": util.BasicAuth("phil", "phil"),
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, rr.Code)
|
||||||
|
|
||||||
|
var versionResponse apiVersionResponse
|
||||||
|
require.Nil(t, json.NewDecoder(rr.Body).Decode(&versionResponse))
|
||||||
|
require.Equal(t, "1.2.3", versionResponse.Version)
|
||||||
|
require.Equal(t, "abcdef0", versionResponse.Commit)
|
||||||
|
require.Equal(t, "2026-02-08T00:00:00Z", versionResponse.Date)
|
||||||
|
|
||||||
|
// Non-admin user cannot access /v1/version
|
||||||
|
rr = request(t, s, "GET", "/v1/version", "", map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("ben", "ben"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 401, rr.Code)
|
||||||
|
|
||||||
|
// Unauthenticated user cannot access /v1/version
|
||||||
|
rr = request(t, s, "GET", "/v1/version", "", nil)
|
||||||
|
require.Equal(t, 401, rr.Code)
|
||||||
})
|
})
|
||||||
require.Equal(t, 200, rr.Code)
|
|
||||||
|
|
||||||
var versionResponse apiVersionResponse
|
|
||||||
require.Nil(t, json.NewDecoder(rr.Body).Decode(&versionResponse))
|
|
||||||
require.Equal(t, "1.2.3", versionResponse.Version)
|
|
||||||
require.Equal(t, "abcdef0", versionResponse.Commit)
|
|
||||||
require.Equal(t, "2026-02-08T00:00:00Z", versionResponse.Date)
|
|
||||||
|
|
||||||
// Non-admin user cannot access /v1/version
|
|
||||||
rr = request(t, s, "GET", "/v1/version", "", map[string]string{
|
|
||||||
"Authorization": util.BasicAuth("ben", "ben"),
|
|
||||||
})
|
|
||||||
require.Equal(t, 401, rr.Code)
|
|
||||||
|
|
||||||
// Unauthenticated user cannot access /v1/version
|
|
||||||
rr = request(t, s, "GET", "/v1/version", "", nil)
|
|
||||||
require.Equal(t, 401, rr.Code)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUser_AddRemove(t *testing.T) {
|
func TestUser_AddRemove(t *testing.T) {
|
||||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
forEachBackend(t, func(t *testing.T) {
|
||||||
defer s.closeDatabases()
|
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||||
|
defer s.closeDatabases()
|
||||||
|
|
||||||
// Create admin, tier
|
// Create admin, tier
|
||||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||||
Code: "tier1",
|
Code: "tier1",
|
||||||
}))
|
}))
|
||||||
|
|
||||||
// Create user via API
|
// Create user via API
|
||||||
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"ben"}`, map[string]string{
|
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"ben"}`, map[string]string{
|
||||||
"Authorization": util.BasicAuth("phil", "phil"),
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, rr.Code)
|
||||||
|
|
||||||
|
// Create user with tier via API
|
||||||
|
rr = request(t, s, "PUT", "/v1/users", `{"username": "emma", "password":"emma", "tier": "tier1"}`, map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, rr.Code)
|
||||||
|
|
||||||
|
// Check users
|
||||||
|
users, err := s.userManager.Users()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 4, len(users))
|
||||||
|
require.Equal(t, "phil", users[0].Name)
|
||||||
|
require.Equal(t, "ben", users[1].Name)
|
||||||
|
require.Equal(t, user.RoleUser, users[1].Role)
|
||||||
|
require.Nil(t, users[1].Tier)
|
||||||
|
require.Equal(t, "emma", users[2].Name)
|
||||||
|
require.Equal(t, user.RoleUser, users[2].Role)
|
||||||
|
require.Equal(t, "tier1", users[2].Tier.Code)
|
||||||
|
require.Equal(t, user.Everyone, users[3].Name)
|
||||||
|
|
||||||
|
// Delete user via API
|
||||||
|
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, rr.Code)
|
||||||
|
|
||||||
|
// Check user was deleted
|
||||||
|
users, err = s.userManager.Users()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 3, len(users))
|
||||||
|
require.Equal(t, "phil", users[0].Name)
|
||||||
|
require.Equal(t, "emma", users[1].Name)
|
||||||
|
require.Equal(t, user.Everyone, users[2].Name)
|
||||||
|
|
||||||
|
// Reject invalid user change
|
||||||
|
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben"}`, map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 400, rr.Code)
|
||||||
})
|
})
|
||||||
require.Equal(t, 200, rr.Code)
|
|
||||||
|
|
||||||
// Create user with tier via API
|
|
||||||
rr = request(t, s, "PUT", "/v1/users", `{"username": "emma", "password":"emma", "tier": "tier1"}`, map[string]string{
|
|
||||||
"Authorization": util.BasicAuth("phil", "phil"),
|
|
||||||
})
|
|
||||||
require.Equal(t, 200, rr.Code)
|
|
||||||
|
|
||||||
// Check users
|
|
||||||
users, err := s.userManager.Users()
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 4, len(users))
|
|
||||||
require.Equal(t, "phil", users[0].Name)
|
|
||||||
require.Equal(t, "ben", users[1].Name)
|
|
||||||
require.Equal(t, user.RoleUser, users[1].Role)
|
|
||||||
require.Nil(t, users[1].Tier)
|
|
||||||
require.Equal(t, "emma", users[2].Name)
|
|
||||||
require.Equal(t, user.RoleUser, users[2].Role)
|
|
||||||
require.Equal(t, "tier1", users[2].Tier.Code)
|
|
||||||
require.Equal(t, user.Everyone, users[3].Name)
|
|
||||||
|
|
||||||
// Delete user via API
|
|
||||||
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
|
|
||||||
"Authorization": util.BasicAuth("phil", "phil"),
|
|
||||||
})
|
|
||||||
require.Equal(t, 200, rr.Code)
|
|
||||||
|
|
||||||
// Check user was deleted
|
|
||||||
users, err = s.userManager.Users()
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 3, len(users))
|
|
||||||
require.Equal(t, "phil", users[0].Name)
|
|
||||||
require.Equal(t, "emma", users[1].Name)
|
|
||||||
require.Equal(t, user.Everyone, users[2].Name)
|
|
||||||
|
|
||||||
// Reject invalid user change
|
|
||||||
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben"}`, map[string]string{
|
|
||||||
"Authorization": util.BasicAuth("phil", "phil"),
|
|
||||||
})
|
|
||||||
require.Equal(t, 400, rr.Code)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUser_AddWithPasswordHash(t *testing.T) {
|
func TestUser_AddWithPasswordHash(t *testing.T) {
|
||||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
forEachBackend(t, func(t *testing.T) {
|
||||||
defer s.closeDatabases()
|
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||||
|
defer s.closeDatabases()
|
||||||
|
|
||||||
// Create admin
|
// Create admin
|
||||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||||
|
|
||||||
// Create user via API
|
// Create user via API
|
||||||
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "hash":"$2a$04$2aPIIqPXQU16OfkSUZH1XOzpu1gsPRKkrfVdFLgWQ.tqb.vtTCuVe"}`, map[string]string{
|
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "hash":"$2a$04$2aPIIqPXQU16OfkSUZH1XOzpu1gsPRKkrfVdFLgWQ.tqb.vtTCuVe"}`, map[string]string{
|
||||||
"Authorization": util.BasicAuth("phil", "phil"),
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
})
|
})
|
||||||
require.Equal(t, 200, rr.Code)
|
require.Equal(t, 200, rr.Code)
|
||||||
|
|
||||||
// Check that user can login with password
|
// Check that user can login with password
|
||||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||||
"Authorization": util.BasicAuth("ben", "ben"),
|
|
||||||
})
|
|
||||||
require.Equal(t, 200, rr.Code)
|
|
||||||
|
|
||||||
// Check users
|
|
||||||
users, err := s.userManager.Users()
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 3, len(users))
|
|
||||||
require.Equal(t, "phil", users[0].Name)
|
|
||||||
require.Equal(t, user.RoleAdmin, users[0].Role)
|
|
||||||
require.Equal(t, "ben", users[1].Name)
|
|
||||||
require.Equal(t, user.RoleUser, users[1].Role)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUser_ChangeUserPassword(t *testing.T) {
|
|
||||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
|
||||||
defer s.closeDatabases()
|
|
||||||
|
|
||||||
// Create admin
|
|
||||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
|
||||||
|
|
||||||
// Create user via API
|
|
||||||
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password": "ben"}`, map[string]string{
|
|
||||||
"Authorization": util.BasicAuth("phil", "phil"),
|
|
||||||
})
|
|
||||||
require.Equal(t, 200, rr.Code)
|
|
||||||
|
|
||||||
// Try to login with first password
|
|
||||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
|
||||||
"Authorization": util.BasicAuth("ben", "ben"),
|
|
||||||
})
|
|
||||||
require.Equal(t, 200, rr.Code)
|
|
||||||
|
|
||||||
// Change password via API
|
|
||||||
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "password": "ben-two"}`, map[string]string{
|
|
||||||
"Authorization": util.BasicAuth("phil", "phil"),
|
|
||||||
})
|
|
||||||
require.Equal(t, 200, rr.Code)
|
|
||||||
|
|
||||||
// Make sure first password fails
|
|
||||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
|
||||||
"Authorization": util.BasicAuth("ben", "ben"),
|
|
||||||
})
|
|
||||||
require.Equal(t, 401, rr.Code)
|
|
||||||
|
|
||||||
// Try to login with second password
|
|
||||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
|
||||||
"Authorization": util.BasicAuth("ben", "ben-two"),
|
|
||||||
})
|
|
||||||
require.Equal(t, 200, rr.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUser_ChangeUserTier(t *testing.T) {
|
|
||||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
|
||||||
defer s.closeDatabases()
|
|
||||||
|
|
||||||
// Create admin, tier
|
|
||||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
|
||||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
|
||||||
Code: "tier1",
|
|
||||||
}))
|
|
||||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
|
||||||
Code: "tier2",
|
|
||||||
}))
|
|
||||||
|
|
||||||
// Create user with tier via API
|
|
||||||
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"ben", "tier": "tier1"}`, map[string]string{
|
|
||||||
"Authorization": util.BasicAuth("phil", "phil"),
|
|
||||||
})
|
|
||||||
require.Equal(t, 200, rr.Code)
|
|
||||||
|
|
||||||
// Check users
|
|
||||||
users, err := s.userManager.Users()
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 3, len(users))
|
|
||||||
require.Equal(t, "phil", users[0].Name)
|
|
||||||
require.Equal(t, "ben", users[1].Name)
|
|
||||||
require.Equal(t, user.RoleUser, users[1].Role)
|
|
||||||
require.Equal(t, "tier1", users[1].Tier.Code)
|
|
||||||
|
|
||||||
// Change user tier via API
|
|
||||||
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "tier": "tier2"}`, map[string]string{
|
|
||||||
"Authorization": util.BasicAuth("phil", "phil"),
|
|
||||||
})
|
|
||||||
require.Equal(t, 200, rr.Code)
|
|
||||||
|
|
||||||
// Check users again
|
|
||||||
users, err = s.userManager.Users()
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, "tier2", users[1].Tier.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUser_ChangeUserPasswordAndTier(t *testing.T) {
|
|
||||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
|
||||||
defer s.closeDatabases()
|
|
||||||
|
|
||||||
// Create admin, tier
|
|
||||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
|
||||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
|
||||||
Code: "tier1",
|
|
||||||
}))
|
|
||||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
|
||||||
Code: "tier2",
|
|
||||||
}))
|
|
||||||
|
|
||||||
// Create user with tier via API
|
|
||||||
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"ben", "tier": "tier1"}`, map[string]string{
|
|
||||||
"Authorization": util.BasicAuth("phil", "phil"),
|
|
||||||
})
|
|
||||||
require.Equal(t, 200, rr.Code)
|
|
||||||
|
|
||||||
// Check users
|
|
||||||
users, err := s.userManager.Users()
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 3, len(users))
|
|
||||||
require.Equal(t, "phil", users[0].Name)
|
|
||||||
require.Equal(t, "ben", users[1].Name)
|
|
||||||
require.Equal(t, user.RoleUser, users[1].Role)
|
|
||||||
require.Equal(t, "tier1", users[1].Tier.Code)
|
|
||||||
|
|
||||||
// Change user password and tier via API
|
|
||||||
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "password":"ben-two", "tier": "tier2"}`, map[string]string{
|
|
||||||
"Authorization": util.BasicAuth("phil", "phil"),
|
|
||||||
})
|
|
||||||
require.Equal(t, 200, rr.Code)
|
|
||||||
|
|
||||||
// Make sure first password fails
|
|
||||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
|
||||||
"Authorization": util.BasicAuth("ben", "ben"),
|
|
||||||
})
|
|
||||||
require.Equal(t, 401, rr.Code)
|
|
||||||
|
|
||||||
// Try to login with second password
|
|
||||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
|
||||||
"Authorization": util.BasicAuth("ben", "ben-two"),
|
|
||||||
})
|
|
||||||
require.Equal(t, 200, rr.Code)
|
|
||||||
|
|
||||||
// Check new tier
|
|
||||||
users, err = s.userManager.Users()
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, "tier2", users[1].Tier.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUser_ChangeUserPasswordWithHash(t *testing.T) {
|
|
||||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
|
||||||
defer s.closeDatabases()
|
|
||||||
|
|
||||||
// Create admin
|
|
||||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
|
||||||
|
|
||||||
// Create user with tier via API
|
|
||||||
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"not-ben"}`, map[string]string{
|
|
||||||
"Authorization": util.BasicAuth("phil", "phil"),
|
|
||||||
})
|
|
||||||
require.Equal(t, 200, rr.Code)
|
|
||||||
|
|
||||||
// Try to login with first password
|
|
||||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
|
||||||
"Authorization": util.BasicAuth("ben", "not-ben"),
|
|
||||||
})
|
|
||||||
require.Equal(t, 200, rr.Code)
|
|
||||||
|
|
||||||
// Change user password and tier via API
|
|
||||||
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "hash":"$2a$04$2aPIIqPXQU16OfkSUZH1XOzpu1gsPRKkrfVdFLgWQ.tqb.vtTCuVe"}`, map[string]string{
|
|
||||||
"Authorization": util.BasicAuth("phil", "phil"),
|
|
||||||
})
|
|
||||||
require.Equal(t, 200, rr.Code)
|
|
||||||
|
|
||||||
// Try to login with second password
|
|
||||||
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
|
||||||
"Authorization": util.BasicAuth("ben", "ben"),
|
|
||||||
})
|
|
||||||
require.Equal(t, 200, rr.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUser_DontChangeAdminPassword(t *testing.T) {
|
|
||||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
|
||||||
defer s.closeDatabases()
|
|
||||||
|
|
||||||
// Create admin
|
|
||||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
|
||||||
require.Nil(t, s.userManager.AddUser("admin", "admin", user.RoleAdmin, false))
|
|
||||||
|
|
||||||
// Try to change password via API
|
|
||||||
rr := request(t, s, "PUT", "/v1/users", `{"username": "admin", "password": "admin-new"}`, map[string]string{
|
|
||||||
"Authorization": util.BasicAuth("phil", "phil"),
|
|
||||||
})
|
|
||||||
require.Equal(t, 403, rr.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUser_AddRemove_Failures(t *testing.T) {
|
|
||||||
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
|
||||||
defer s.closeDatabases()
|
|
||||||
|
|
||||||
// Create admin
|
|
||||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
|
||||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
|
||||||
|
|
||||||
// Cannot create user with invalid username
|
|
||||||
rr := request(t, s, "POST", "/v1/users", `{"username": "not valid", "password":"ben"}`, map[string]string{
|
|
||||||
"Authorization": util.BasicAuth("phil", "phil"),
|
|
||||||
})
|
|
||||||
require.Equal(t, 400, rr.Code)
|
|
||||||
|
|
||||||
// Cannot create user if user already exists
|
|
||||||
rr = request(t, s, "POST", "/v1/users", `{"username": "phil", "password":"phil"}`, map[string]string{
|
|
||||||
"Authorization": util.BasicAuth("phil", "phil"),
|
|
||||||
})
|
|
||||||
require.Equal(t, 40901, toHTTPError(t, rr.Body.String()).Code)
|
|
||||||
|
|
||||||
// Cannot create user with invalid tier
|
|
||||||
rr = request(t, s, "POST", "/v1/users", `{"username": "emma", "password":"emma", "tier": "invalid"}`, map[string]string{
|
|
||||||
"Authorization": util.BasicAuth("phil", "phil"),
|
|
||||||
})
|
|
||||||
require.Equal(t, 40030, toHTTPError(t, rr.Body.String()).Code)
|
|
||||||
|
|
||||||
// Cannot delete user as non-admin
|
|
||||||
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
|
|
||||||
"Authorization": util.BasicAuth("ben", "ben"),
|
|
||||||
})
|
|
||||||
require.Equal(t, 401, rr.Code)
|
|
||||||
|
|
||||||
// Delete user via API
|
|
||||||
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
|
|
||||||
"Authorization": util.BasicAuth("phil", "phil"),
|
|
||||||
})
|
|
||||||
require.Equal(t, 200, rr.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAccess_AllowReset(t *testing.T) {
|
|
||||||
c := newTestConfigWithAuthFile(t)
|
|
||||||
c.AuthDefault = user.PermissionDenyAll
|
|
||||||
s := newTestServer(t, c)
|
|
||||||
defer s.closeDatabases()
|
|
||||||
|
|
||||||
// User and admin
|
|
||||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
|
||||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
|
||||||
|
|
||||||
// Subscribing not allowed
|
|
||||||
rr := request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
|
|
||||||
"Authorization": util.BasicAuth("ben", "ben"),
|
|
||||||
})
|
|
||||||
require.Equal(t, 403, rr.Code)
|
|
||||||
|
|
||||||
// Grant access
|
|
||||||
rr = request(t, s, "POST", "/v1/users/access", `{"username": "ben", "topic":"gold", "permission":"ro"}`, map[string]string{
|
|
||||||
"Authorization": util.BasicAuth("phil", "phil"),
|
|
||||||
})
|
|
||||||
require.Equal(t, 200, rr.Code)
|
|
||||||
|
|
||||||
// Now subscribing is allowed
|
|
||||||
rr = request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
|
|
||||||
"Authorization": util.BasicAuth("ben", "ben"),
|
|
||||||
})
|
|
||||||
require.Equal(t, 200, rr.Code)
|
|
||||||
|
|
||||||
// Reset access
|
|
||||||
rr = request(t, s, "DELETE", "/v1/users/access", `{"username": "ben", "topic":"gold"}`, map[string]string{
|
|
||||||
"Authorization": util.BasicAuth("phil", "phil"),
|
|
||||||
})
|
|
||||||
require.Equal(t, 200, rr.Code)
|
|
||||||
|
|
||||||
// Subscribing not allowed (again)
|
|
||||||
rr = request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
|
|
||||||
"Authorization": util.BasicAuth("ben", "ben"),
|
|
||||||
})
|
|
||||||
require.Equal(t, 403, rr.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAccess_AllowReset_NonAdminAttempt(t *testing.T) {
|
|
||||||
c := newTestConfigWithAuthFile(t)
|
|
||||||
c.AuthDefault = user.PermissionDenyAll
|
|
||||||
s := newTestServer(t, c)
|
|
||||||
defer s.closeDatabases()
|
|
||||||
|
|
||||||
// User
|
|
||||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
|
||||||
|
|
||||||
// Grant access fails, because non-admin
|
|
||||||
rr := request(t, s, "POST", "/v1/users/access", `{"username": "ben", "topic":"gold", "permission":"ro"}`, map[string]string{
|
|
||||||
"Authorization": util.BasicAuth("ben", "ben"),
|
|
||||||
})
|
|
||||||
require.Equal(t, 401, rr.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAccess_AllowReset_KillConnection(t *testing.T) {
|
|
||||||
c := newTestConfigWithAuthFile(t)
|
|
||||||
c.AuthDefault = user.PermissionDenyAll
|
|
||||||
s := newTestServer(t, c)
|
|
||||||
defer s.closeDatabases()
|
|
||||||
|
|
||||||
// User and admin, grant access to "gol*" topics
|
|
||||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
|
||||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
|
||||||
require.Nil(t, s.userManager.AllowAccess("ben", "gol*", user.PermissionRead)) // Wildcard!
|
|
||||||
|
|
||||||
start, timeTaken := time.Now(), atomic.Int64{}
|
|
||||||
go func() {
|
|
||||||
rr := request(t, s, "GET", "/gold/json", "", map[string]string{
|
|
||||||
"Authorization": util.BasicAuth("ben", "ben"),
|
"Authorization": util.BasicAuth("ben", "ben"),
|
||||||
})
|
})
|
||||||
require.Equal(t, 200, rr.Code)
|
require.Equal(t, 200, rr.Code)
|
||||||
timeTaken.Store(time.Since(start).Milliseconds())
|
|
||||||
}()
|
|
||||||
time.Sleep(500 * time.Millisecond)
|
|
||||||
|
|
||||||
// Reset access
|
// Check users
|
||||||
rr := request(t, s, "DELETE", "/v1/users/access", `{"username": "ben", "topic":"gol*"}`, map[string]string{
|
users, err := s.userManager.Users()
|
||||||
"Authorization": util.BasicAuth("phil", "phil"),
|
require.Nil(t, err)
|
||||||
})
|
require.Equal(t, 3, len(users))
|
||||||
require.Equal(t, 200, rr.Code)
|
require.Equal(t, "phil", users[0].Name)
|
||||||
|
require.Equal(t, user.RoleAdmin, users[0].Role)
|
||||||
// Wait for connection to be killed; this will fail if the connection is never killed
|
require.Equal(t, "ben", users[1].Name)
|
||||||
waitFor(t, func() bool {
|
require.Equal(t, user.RoleUser, users[1].Role)
|
||||||
return timeTaken.Load() >= 500
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUser_ChangeUserPassword(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T) {
|
||||||
|
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||||
|
defer s.closeDatabases()
|
||||||
|
|
||||||
|
// Create admin
|
||||||
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||||
|
|
||||||
|
// Create user via API
|
||||||
|
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password": "ben"}`, map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, rr.Code)
|
||||||
|
|
||||||
|
// Try to login with first password
|
||||||
|
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("ben", "ben"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, rr.Code)
|
||||||
|
|
||||||
|
// Change password via API
|
||||||
|
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "password": "ben-two"}`, map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, rr.Code)
|
||||||
|
|
||||||
|
// Make sure first password fails
|
||||||
|
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("ben", "ben"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 401, rr.Code)
|
||||||
|
|
||||||
|
// Try to login with second password
|
||||||
|
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("ben", "ben-two"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, rr.Code)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUser_ChangeUserTier(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T) {
|
||||||
|
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||||
|
defer s.closeDatabases()
|
||||||
|
|
||||||
|
// Create admin, tier
|
||||||
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||||
|
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||||
|
Code: "tier1",
|
||||||
|
}))
|
||||||
|
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||||
|
Code: "tier2",
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Create user with tier via API
|
||||||
|
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"ben", "tier": "tier1"}`, map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, rr.Code)
|
||||||
|
|
||||||
|
// Check users
|
||||||
|
users, err := s.userManager.Users()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 3, len(users))
|
||||||
|
require.Equal(t, "phil", users[0].Name)
|
||||||
|
require.Equal(t, "ben", users[1].Name)
|
||||||
|
require.Equal(t, user.RoleUser, users[1].Role)
|
||||||
|
require.Equal(t, "tier1", users[1].Tier.Code)
|
||||||
|
|
||||||
|
// Change user tier via API
|
||||||
|
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "tier": "tier2"}`, map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, rr.Code)
|
||||||
|
|
||||||
|
// Check users again
|
||||||
|
users, err = s.userManager.Users()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, "tier2", users[1].Tier.Code)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUser_ChangeUserPasswordAndTier(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T) {
|
||||||
|
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||||
|
defer s.closeDatabases()
|
||||||
|
|
||||||
|
// Create admin, tier
|
||||||
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||||
|
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||||
|
Code: "tier1",
|
||||||
|
}))
|
||||||
|
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||||
|
Code: "tier2",
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Create user with tier via API
|
||||||
|
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"ben", "tier": "tier1"}`, map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, rr.Code)
|
||||||
|
|
||||||
|
// Check users
|
||||||
|
users, err := s.userManager.Users()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 3, len(users))
|
||||||
|
require.Equal(t, "phil", users[0].Name)
|
||||||
|
require.Equal(t, "ben", users[1].Name)
|
||||||
|
require.Equal(t, user.RoleUser, users[1].Role)
|
||||||
|
require.Equal(t, "tier1", users[1].Tier.Code)
|
||||||
|
|
||||||
|
// Change user password and tier via API
|
||||||
|
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "password":"ben-two", "tier": "tier2"}`, map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, rr.Code)
|
||||||
|
|
||||||
|
// Make sure first password fails
|
||||||
|
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("ben", "ben"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 401, rr.Code)
|
||||||
|
|
||||||
|
// Try to login with second password
|
||||||
|
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("ben", "ben-two"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, rr.Code)
|
||||||
|
|
||||||
|
// Check new tier
|
||||||
|
users, err = s.userManager.Users()
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, "tier2", users[1].Tier.Code)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUser_ChangeUserPasswordWithHash(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T) {
|
||||||
|
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||||
|
defer s.closeDatabases()
|
||||||
|
|
||||||
|
// Create admin
|
||||||
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||||
|
|
||||||
|
// Create user with tier via API
|
||||||
|
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"not-ben"}`, map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, rr.Code)
|
||||||
|
|
||||||
|
// Try to login with first password
|
||||||
|
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("ben", "not-ben"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, rr.Code)
|
||||||
|
|
||||||
|
// Change user password and tier via API
|
||||||
|
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "hash":"$2a$04$2aPIIqPXQU16OfkSUZH1XOzpu1gsPRKkrfVdFLgWQ.tqb.vtTCuVe"}`, map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, rr.Code)
|
||||||
|
|
||||||
|
// Try to login with second password
|
||||||
|
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("ben", "ben"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, rr.Code)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUser_DontChangeAdminPassword(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T) {
|
||||||
|
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||||
|
defer s.closeDatabases()
|
||||||
|
|
||||||
|
// Create admin
|
||||||
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||||
|
require.Nil(t, s.userManager.AddUser("admin", "admin", user.RoleAdmin, false))
|
||||||
|
|
||||||
|
// Try to change password via API
|
||||||
|
rr := request(t, s, "PUT", "/v1/users", `{"username": "admin", "password": "admin-new"}`, map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 403, rr.Code)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUser_AddRemove_Failures(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T) {
|
||||||
|
s := newTestServer(t, newTestConfigWithAuthFile(t))
|
||||||
|
defer s.closeDatabases()
|
||||||
|
|
||||||
|
// Create admin
|
||||||
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||||
|
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
||||||
|
|
||||||
|
// Cannot create user with invalid username
|
||||||
|
rr := request(t, s, "POST", "/v1/users", `{"username": "not valid", "password":"ben"}`, map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 400, rr.Code)
|
||||||
|
|
||||||
|
// Cannot create user if user already exists
|
||||||
|
rr = request(t, s, "POST", "/v1/users", `{"username": "phil", "password":"phil"}`, map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 40901, toHTTPError(t, rr.Body.String()).Code)
|
||||||
|
|
||||||
|
// Cannot create user with invalid tier
|
||||||
|
rr = request(t, s, "POST", "/v1/users", `{"username": "emma", "password":"emma", "tier": "invalid"}`, map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 40030, toHTTPError(t, rr.Body.String()).Code)
|
||||||
|
|
||||||
|
// Cannot delete user as non-admin
|
||||||
|
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("ben", "ben"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 401, rr.Code)
|
||||||
|
|
||||||
|
// Delete user via API
|
||||||
|
rr = request(t, s, "DELETE", "/v1/users", `{"username": "ben"}`, map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, rr.Code)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccess_AllowReset(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T) {
|
||||||
|
c := newTestConfigWithAuthFile(t)
|
||||||
|
c.AuthDefault = user.PermissionDenyAll
|
||||||
|
s := newTestServer(t, c)
|
||||||
|
defer s.closeDatabases()
|
||||||
|
|
||||||
|
// User and admin
|
||||||
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||||
|
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
||||||
|
|
||||||
|
// Subscribing not allowed
|
||||||
|
rr := request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("ben", "ben"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 403, rr.Code)
|
||||||
|
|
||||||
|
// Grant access
|
||||||
|
rr = request(t, s, "POST", "/v1/users/access", `{"username": "ben", "topic":"gold", "permission":"ro"}`, map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, rr.Code)
|
||||||
|
|
||||||
|
// Now subscribing is allowed
|
||||||
|
rr = request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("ben", "ben"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, rr.Code)
|
||||||
|
|
||||||
|
// Reset access
|
||||||
|
rr = request(t, s, "DELETE", "/v1/users/access", `{"username": "ben", "topic":"gold"}`, map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, rr.Code)
|
||||||
|
|
||||||
|
// Subscribing not allowed (again)
|
||||||
|
rr = request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("ben", "ben"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 403, rr.Code)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccess_AllowReset_NonAdminAttempt(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T) {
|
||||||
|
c := newTestConfigWithAuthFile(t)
|
||||||
|
c.AuthDefault = user.PermissionDenyAll
|
||||||
|
s := newTestServer(t, c)
|
||||||
|
defer s.closeDatabases()
|
||||||
|
|
||||||
|
// User
|
||||||
|
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
||||||
|
|
||||||
|
// Grant access fails, because non-admin
|
||||||
|
rr := request(t, s, "POST", "/v1/users/access", `{"username": "ben", "topic":"gold", "permission":"ro"}`, map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("ben", "ben"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 401, rr.Code)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccess_AllowReset_KillConnection(t *testing.T) {
|
||||||
|
forEachBackend(t, func(t *testing.T) {
|
||||||
|
c := newTestConfigWithAuthFile(t)
|
||||||
|
c.AuthDefault = user.PermissionDenyAll
|
||||||
|
s := newTestServer(t, c)
|
||||||
|
defer s.closeDatabases()
|
||||||
|
|
||||||
|
// User and admin, grant access to "gol*" topics
|
||||||
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
|
||||||
|
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
||||||
|
require.Nil(t, s.userManager.AllowAccess("ben", "gol*", user.PermissionRead)) // Wildcard!
|
||||||
|
|
||||||
|
start, timeTaken := time.Now(), atomic.Int64{}
|
||||||
|
go func() {
|
||||||
|
rr := request(t, s, "GET", "/gold/json", "", map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("ben", "ben"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, rr.Code)
|
||||||
|
timeTaken.Store(time.Since(start).Milliseconds())
|
||||||
|
}()
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
|
// Reset access
|
||||||
|
rr := request(t, s, "DELETE", "/v1/users/access", `{"username": "ben", "topic":"gol*"}`, map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, rr.Code)
|
||||||
|
|
||||||
|
// Wait for connection to be killed; this will fail if the connection is never killed
|
||||||
|
waitFor(t, func() bool {
|
||||||
|
return timeTaken.Load() >= 500
|
||||||
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"firebase.google.com/go/v4/messaging"
|
"firebase.google.com/go/v4/messaging"
|
||||||
"fmt"
|
"fmt"
|
||||||
"google.golang.org/api/option"
|
"google.golang.org/api/option"
|
||||||
|
"heckel.io/ntfy/v2/model"
|
||||||
"heckel.io/ntfy/v2/user"
|
"heckel.io/ntfy/v2/user"
|
||||||
"heckel.io/ntfy/v2/util"
|
"heckel.io/ntfy/v2/util"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -43,7 +44,7 @@ func newFirebaseClient(sender firebaseSender, auther user.Auther) *firebaseClien
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *firebaseClient) Send(v *visitor, m *message) error {
|
func (c *firebaseClient) Send(v *visitor, m *model.Message) error {
|
||||||
if !v.FirebaseAllowed() {
|
if !v.FirebaseAllowed() {
|
||||||
return errFirebaseTemporarilyBanned
|
return errFirebaseTemporarilyBanned
|
||||||
}
|
}
|
||||||
@@ -121,7 +122,7 @@ func (c *firebaseSenderImpl) Send(m *messaging.Message) error {
|
|||||||
// On Android, this will trigger the app to poll the topic and thereby displaying new messages.
|
// On Android, this will trigger the app to poll the topic and thereby displaying new messages.
|
||||||
// - If UpstreamBaseURL is set, messages are forwarded as poll requests to an upstream server and then forwarded
|
// - If UpstreamBaseURL is set, messages are forwarded as poll requests to an upstream server and then forwarded
|
||||||
// to Firebase here. This is mainly for iOS to support self-hosted servers.
|
// to Firebase here. This is mainly for iOS to support self-hosted servers.
|
||||||
func toFirebaseMessage(m *message, auther user.Auther) (*messaging.Message, error) {
|
func toFirebaseMessage(m *model.Message, auther user.Auther) (*messaging.Message, error) {
|
||||||
var data map[string]string // Mostly matches https://ntfy.sh/docs/subscribe/api/#json-message-format
|
var data map[string]string // Mostly matches https://ntfy.sh/docs/subscribe/api/#json-message-format
|
||||||
var apnsConfig *messaging.APNSConfig
|
var apnsConfig *messaging.APNSConfig
|
||||||
switch m.Event {
|
switch m.Event {
|
||||||
@@ -235,7 +236,7 @@ func maybeTruncateFCMMessage(m *messaging.Message) *messaging.Message {
|
|||||||
// createAPNSAlertConfig creates an APNS config for iOS notifications that show up as an alert (only relevant for iOS).
|
// createAPNSAlertConfig creates an APNS config for iOS notifications that show up as an alert (only relevant for iOS).
|
||||||
// We must set the Alert struct ("alert"), and we need to set MutableContent ("mutable-content"), so the Notification Service
|
// We must set the Alert struct ("alert"), and we need to set MutableContent ("mutable-content"), so the Notification Service
|
||||||
// Extension in iOS can modify the message.
|
// Extension in iOS can modify the message.
|
||||||
func createAPNSAlertConfig(m *message, data map[string]string) *messaging.APNSConfig {
|
func createAPNSAlertConfig(m *model.Message, data map[string]string) *messaging.APNSConfig {
|
||||||
apnsData := make(map[string]any)
|
apnsData := make(map[string]any)
|
||||||
for k, v := range data {
|
for k, v := range data {
|
||||||
apnsData[k] = v
|
apnsData[k] = v
|
||||||
@@ -296,7 +297,7 @@ func maybeTruncateAPNSBodyMessage(s string) string {
|
|||||||
//
|
//
|
||||||
// This empties all the fields that are not needed for a poll request and just sets the required fields,
|
// This empties all the fields that are not needed for a poll request and just sets the required fields,
|
||||||
// most importantly, the PollID.
|
// most importantly, the PollID.
|
||||||
func toPollRequest(m *message) *message {
|
func toPollRequest(m *model.Message) *model.Message {
|
||||||
pr := newPollRequestMessage(m.Topic, m.ID)
|
pr := newPollRequestMessage(m.Topic, m.ID)
|
||||||
pr.ID = m.ID
|
pr.ID = m.ID
|
||||||
pr.Time = m.Time
|
pr.Time = m.Time
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"heckel.io/ntfy/v2/model"
|
||||||
"heckel.io/ntfy/v2/user"
|
"heckel.io/ntfy/v2/user"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -21,7 +22,7 @@ var (
|
|||||||
type firebaseClient struct {
|
type firebaseClient struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *firebaseClient) Send(v *visitor, m *message) error {
|
func (c *firebaseClient) Send(v *visitor, m *model.Message) error {
|
||||||
return errFirebaseNotAvailable
|
return errFirebaseNotAvailable
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"heckel.io/ntfy/v2/model"
|
||||||
"heckel.io/ntfy/v2/user"
|
"heckel.io/ntfy/v2/user"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -131,7 +132,7 @@ func TestToFirebaseMessage_Message_Normal_Allowed(t *testing.T) {
|
|||||||
m.Click = "https://google.com"
|
m.Click = "https://google.com"
|
||||||
m.Icon = "https://ntfy.sh/static/img/ntfy.png"
|
m.Icon = "https://ntfy.sh/static/img/ntfy.png"
|
||||||
m.Title = "some title"
|
m.Title = "some title"
|
||||||
m.Actions = []*action{
|
m.Actions = []*model.Action{
|
||||||
{
|
{
|
||||||
ID: "123",
|
ID: "123",
|
||||||
Action: "view",
|
Action: "view",
|
||||||
@@ -150,7 +151,7 @@ func TestToFirebaseMessage_Message_Normal_Allowed(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
m.Attachment = &attachment{
|
m.Attachment = &model.Attachment{
|
||||||
Name: "some file.jpg",
|
Name: "some file.jpg",
|
||||||
Type: "image/jpeg",
|
Type: "image/jpeg",
|
||||||
Size: 12345,
|
Size: 12345,
|
||||||
@@ -346,16 +347,16 @@ func TestToFirebaseSender_Abuse(t *testing.T) {
|
|||||||
client := newFirebaseClient(sender, &testAuther{})
|
client := newFirebaseClient(sender, &testAuther{})
|
||||||
visitor := newVisitor(newTestConfig(t), newMemTestCache(t), nil, netip.MustParseAddr("1.2.3.4"), nil)
|
visitor := newVisitor(newTestConfig(t), newMemTestCache(t), nil, netip.MustParseAddr("1.2.3.4"), nil)
|
||||||
|
|
||||||
require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"}))
|
require.Nil(t, client.Send(visitor, &model.Message{Topic: "mytopic"}))
|
||||||
require.Equal(t, 1, len(sender.Messages()))
|
require.Equal(t, 1, len(sender.Messages()))
|
||||||
|
|
||||||
require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"}))
|
require.Nil(t, client.Send(visitor, &model.Message{Topic: "mytopic"}))
|
||||||
require.Equal(t, 2, len(sender.Messages()))
|
require.Equal(t, 2, len(sender.Messages()))
|
||||||
|
|
||||||
require.Equal(t, errFirebaseQuotaExceeded, client.Send(visitor, &message{Topic: "mytopic"}))
|
require.Equal(t, errFirebaseQuotaExceeded, client.Send(visitor, &model.Message{Topic: "mytopic"}))
|
||||||
require.Equal(t, 2, len(sender.Messages()))
|
require.Equal(t, 2, len(sender.Messages()))
|
||||||
|
|
||||||
sender.messages = make([]*messaging.Message, 0) // Reset to test that time limit is working
|
sender.messages = make([]*messaging.Message, 0) // Reset to test that time limit is working
|
||||||
require.Equal(t, errFirebaseTemporarilyBanned, client.Send(visitor, &message{Topic: "mytopic"}))
|
require.Equal(t, errFirebaseTemporarilyBanned, client.Send(visitor, &model.Message{Topic: "mytopic"}))
|
||||||
require.Equal(t, 0, len(sender.Messages()))
|
require.Equal(t, 0, len(sender.Messages()))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,23 +6,25 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestServer_Manager_Prune_Messages_Without_Attachments_DoesNotPanic(t *testing.T) {
|
func TestServer_Manager_Prune_Messages_Without_Attachments_DoesNotPanic(t *testing.T) {
|
||||||
// Tests that the manager runs without attachment-cache-dir set, see #617
|
forEachBackend(t, func(t *testing.T) {
|
||||||
c := newTestConfig(t)
|
// Tests that the manager runs without attachment-cache-dir set, see #617
|
||||||
c.AttachmentCacheDir = ""
|
c := newTestConfig(t)
|
||||||
s := newTestServer(t, c)
|
c.AttachmentCacheDir = ""
|
||||||
|
s := newTestServer(t, c)
|
||||||
|
|
||||||
// Publish a message
|
// Publish a message
|
||||||
rr := request(t, s, "POST", "/mytopic", "hi", nil)
|
rr := request(t, s, "POST", "/mytopic", "hi", nil)
|
||||||
require.Equal(t, 200, rr.Code)
|
require.Equal(t, 200, rr.Code)
|
||||||
m := toMessage(t, rr.Body.String())
|
m := toMessage(t, rr.Body.String())
|
||||||
|
|
||||||
// Expire message
|
// Expire message
|
||||||
require.Nil(t, s.messageCache.ExpireMessages("mytopic"))
|
require.Nil(t, s.messageCache.ExpireMessages("mytopic"))
|
||||||
|
|
||||||
// Does not panic
|
// Does not panic
|
||||||
s.pruneMessages()
|
s.pruneMessages()
|
||||||
|
|
||||||
// Actually deleted
|
// Actually deleted
|
||||||
_, err := s.messageCache.Message(m.ID)
|
_, err := s.messageCache.Message(m.ID)
|
||||||
require.Equal(t, errMessageNotFound, err)
|
require.Equal(t, errMessageNotFound, err)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
@@ -11,6 +11,7 @@ import (
|
|||||||
"text/template"
|
"text/template"
|
||||||
|
|
||||||
"heckel.io/ntfy/v2/log"
|
"heckel.io/ntfy/v2/log"
|
||||||
|
"heckel.io/ntfy/v2/model"
|
||||||
"heckel.io/ntfy/v2/user"
|
"heckel.io/ntfy/v2/user"
|
||||||
"heckel.io/ntfy/v2/util"
|
"heckel.io/ntfy/v2/util"
|
||||||
)
|
)
|
||||||
@@ -76,7 +77,7 @@ func (s *Server) convertPhoneNumber(u *user.User, phoneNumber string) (string, *
|
|||||||
|
|
||||||
// callPhone calls the Twilio API to make a phone call to the given phone number, using the given message.
|
// callPhone calls the Twilio API to make a phone call to the given phone number, using the given message.
|
||||||
// Failures will be logged, but not returned to the caller.
|
// Failures will be logged, but not returned to the caller.
|
||||||
func (s *Server) callPhone(v *visitor, r *http.Request, m *message, to string) {
|
func (s *Server) callPhone(v *visitor, r *http.Request, m *model.Message, to string) {
|
||||||
u, sender := v.User(), m.Sender.String()
|
u, sender := v.User(), m.Sender.String()
|
||||||
if u != nil {
|
if u != nil {
|
||||||
sender = u.Name
|
sender = u.Name
|
||||||
|
|||||||
@@ -14,217 +14,224 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestServer_Twilio_Call_Add_Verify_Call_Delete_Success(t *testing.T) {
|
func TestServer_Twilio_Call_Add_Verify_Call_Delete_Success(t *testing.T) {
|
||||||
var called, verified atomic.Bool
|
forEachBackend(t, func(t *testing.T) {
|
||||||
var code atomic.Pointer[string]
|
var called, verified atomic.Bool
|
||||||
twilioVerifyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
var code atomic.Pointer[string]
|
||||||
body, err := io.ReadAll(r.Body)
|
twilioVerifyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
require.Nil(t, err)
|
body, err := io.ReadAll(r.Body)
|
||||||
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
|
require.Nil(t, err)
|
||||||
if r.URL.Path == "/v2/Services/VA1234567890/Verifications" {
|
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
|
||||||
if code.Load() != nil {
|
if r.URL.Path == "/v2/Services/VA1234567890/Verifications" {
|
||||||
|
if code.Load() != nil {
|
||||||
|
t.Fatal("Should be only called once")
|
||||||
|
}
|
||||||
|
require.Equal(t, "Channel=sms&To=%2B12223334444", string(body))
|
||||||
|
code.Store(util.String("123456"))
|
||||||
|
} else if r.URL.Path == "/v2/Services/VA1234567890/VerificationCheck" {
|
||||||
|
if verified.Load() {
|
||||||
|
t.Fatal("Should be only called once")
|
||||||
|
}
|
||||||
|
require.Equal(t, "Code=123456&To=%2B12223334444", string(body))
|
||||||
|
verified.Store(true)
|
||||||
|
} else {
|
||||||
|
t.Fatal("Unexpected path:", r.URL.Path)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer twilioVerifyServer.Close()
|
||||||
|
twilioCallsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if called.Load() {
|
||||||
t.Fatal("Should be only called once")
|
t.Fatal("Should be only called once")
|
||||||
}
|
}
|
||||||
require.Equal(t, "Channel=sms&To=%2B12223334444", string(body))
|
body, err := io.ReadAll(r.Body)
|
||||||
code.Store(util.String("123456"))
|
require.Nil(t, err)
|
||||||
} else if r.URL.Path == "/v2/Services/VA1234567890/VerificationCheck" {
|
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
|
||||||
if verified.Load() {
|
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
|
||||||
t.Fatal("Should be only called once")
|
require.Equal(t, "From=%2B1234567890&To=%2B12223334444&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+loop%3D%223%22%3E%0A%09%09You+have+a+message+from+notify+on+topic+mytopic.+Message%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09End+of+message.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09This+message+was+sent+by+user+phil.+It+will+be+repeated+three+times.%0A%09%09To+unsubscribe+from+calls+like+this%2C+remove+your+phone+number+in+the+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay%3EGoodbye.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
|
||||||
}
|
called.Store(true)
|
||||||
require.Equal(t, "Code=123456&To=%2B12223334444", string(body))
|
}))
|
||||||
verified.Store(true)
|
defer twilioCallsServer.Close()
|
||||||
} else {
|
|
||||||
t.Fatal("Unexpected path:", r.URL.Path)
|
c := newTestConfigWithAuthFile(t)
|
||||||
}
|
c.TwilioVerifyBaseURL = twilioVerifyServer.URL
|
||||||
}))
|
c.TwilioCallsBaseURL = twilioCallsServer.URL
|
||||||
defer twilioVerifyServer.Close()
|
c.TwilioAccount = "AC1234567890"
|
||||||
twilioCallsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
c.TwilioAuthToken = "AAEAA1234567890"
|
||||||
if called.Load() {
|
c.TwilioPhoneNumber = "+1234567890"
|
||||||
t.Fatal("Should be only called once")
|
c.TwilioVerifyService = "VA1234567890"
|
||||||
}
|
s := newTestServer(t, c)
|
||||||
body, err := io.ReadAll(r.Body)
|
|
||||||
|
// Add tier and user
|
||||||
|
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||||
|
Code: "pro",
|
||||||
|
MessageLimit: 10,
|
||||||
|
CallLimit: 1,
|
||||||
|
}))
|
||||||
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
||||||
|
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||||
|
u, err := s.userManager.User("phil")
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
|
|
||||||
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
|
|
||||||
require.Equal(t, "From=%2B1234567890&To=%2B12223334444&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+loop%3D%223%22%3E%0A%09%09You+have+a+message+from+notify+on+topic+mytopic.+Message%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09End+of+message.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09This+message+was+sent+by+user+phil.+It+will+be+repeated+three+times.%0A%09%09To+unsubscribe+from+calls+like+this%2C+remove+your+phone+number+in+the+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay%3EGoodbye.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
|
|
||||||
called.Store(true)
|
|
||||||
}))
|
|
||||||
defer twilioCallsServer.Close()
|
|
||||||
|
|
||||||
c := newTestConfigWithAuthFile(t)
|
// Send verification code for phone number
|
||||||
c.TwilioVerifyBaseURL = twilioVerifyServer.URL
|
response := request(t, s, "PUT", "/v1/account/phone/verify", `{"number":"+12223334444","channel":"sms"}`, map[string]string{
|
||||||
c.TwilioCallsBaseURL = twilioCallsServer.URL
|
"authorization": util.BasicAuth("phil", "phil"),
|
||||||
c.TwilioAccount = "AC1234567890"
|
})
|
||||||
c.TwilioAuthToken = "AAEAA1234567890"
|
require.Equal(t, 200, response.Code)
|
||||||
c.TwilioPhoneNumber = "+1234567890"
|
waitFor(t, func() bool {
|
||||||
c.TwilioVerifyService = "VA1234567890"
|
return *code.Load() == "123456"
|
||||||
s := newTestServer(t, c)
|
})
|
||||||
|
|
||||||
// Add tier and user
|
// Add phone number with code
|
||||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
response = request(t, s, "PUT", "/v1/account/phone", `{"number":"+12223334444","code":"123456"}`, map[string]string{
|
||||||
Code: "pro",
|
"authorization": util.BasicAuth("phil", "phil"),
|
||||||
MessageLimit: 10,
|
})
|
||||||
CallLimit: 1,
|
require.Equal(t, 200, response.Code)
|
||||||
}))
|
waitFor(t, func() bool {
|
||||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
return verified.Load()
|
||||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
})
|
||||||
u, err := s.userManager.User("phil")
|
phoneNumbers, err := s.userManager.PhoneNumbers(u.ID)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 1, len(phoneNumbers))
|
||||||
|
require.Equal(t, "+12223334444", phoneNumbers[0])
|
||||||
|
|
||||||
// Send verification code for phone number
|
// Do the thing
|
||||||
response := request(t, s, "PUT", "/v1/account/phone/verify", `{"number":"+12223334444","channel":"sms"}`, map[string]string{
|
response = request(t, s, "POST", "/mytopic", "hi there", map[string]string{
|
||||||
"authorization": util.BasicAuth("phil", "phil"),
|
"authorization": util.BasicAuth("phil", "phil"),
|
||||||
})
|
"x-call": "yes",
|
||||||
require.Equal(t, 200, response.Code)
|
})
|
||||||
waitFor(t, func() bool {
|
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
|
||||||
return *code.Load() == "123456"
|
waitFor(t, func() bool {
|
||||||
})
|
return called.Load()
|
||||||
|
})
|
||||||
|
|
||||||
// Add phone number with code
|
// Remove the phone number
|
||||||
response = request(t, s, "PUT", "/v1/account/phone", `{"number":"+12223334444","code":"123456"}`, map[string]string{
|
response = request(t, s, "DELETE", "/v1/account/phone", `{"number":"+12223334444"}`, map[string]string{
|
||||||
"authorization": util.BasicAuth("phil", "phil"),
|
"authorization": util.BasicAuth("phil", "phil"),
|
||||||
})
|
})
|
||||||
require.Equal(t, 200, response.Code)
|
require.Equal(t, 200, response.Code)
|
||||||
waitFor(t, func() bool {
|
|
||||||
return verified.Load()
|
|
||||||
})
|
|
||||||
phoneNumbers, err := s.userManager.PhoneNumbers(u.ID)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 1, len(phoneNumbers))
|
|
||||||
require.Equal(t, "+12223334444", phoneNumbers[0])
|
|
||||||
|
|
||||||
// Do the thing
|
// Verify the phone number is gone from the DB
|
||||||
response = request(t, s, "POST", "/mytopic", "hi there", map[string]string{
|
phoneNumbers, err = s.userManager.PhoneNumbers(u.ID)
|
||||||
"authorization": util.BasicAuth("phil", "phil"),
|
require.Nil(t, err)
|
||||||
"x-call": "yes",
|
require.Equal(t, 0, len(phoneNumbers))
|
||||||
})
|
})
|
||||||
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
|
|
||||||
waitFor(t, func() bool {
|
|
||||||
return called.Load()
|
|
||||||
})
|
|
||||||
|
|
||||||
// Remove the phone number
|
|
||||||
response = request(t, s, "DELETE", "/v1/account/phone", `{"number":"+12223334444"}`, map[string]string{
|
|
||||||
"authorization": util.BasicAuth("phil", "phil"),
|
|
||||||
})
|
|
||||||
require.Equal(t, 200, response.Code)
|
|
||||||
|
|
||||||
// Verify the phone number is gone from the DB
|
|
||||||
phoneNumbers, err = s.userManager.PhoneNumbers(u.ID)
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, 0, len(phoneNumbers))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_Twilio_Call_Success(t *testing.T) {
|
func TestServer_Twilio_Call_Success(t *testing.T) {
|
||||||
var called atomic.Bool
|
forEachBackend(t, func(t *testing.T) {
|
||||||
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
var called atomic.Bool
|
||||||
if called.Load() {
|
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
t.Fatal("Should be only called once")
|
if called.Load() {
|
||||||
}
|
t.Fatal("Should be only called once")
|
||||||
body, err := io.ReadAll(r.Body)
|
}
|
||||||
|
body, err := io.ReadAll(r.Body)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
|
||||||
|
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
|
||||||
|
require.Equal(t, "From=%2B1234567890&To=%2B11122233344&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+loop%3D%223%22%3E%0A%09%09You+have+a+message+from+notify+on+topic+mytopic.+Message%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09End+of+message.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09This+message+was+sent+by+user+phil.+It+will+be+repeated+three+times.%0A%09%09To+unsubscribe+from+calls+like+this%2C+remove+your+phone+number+in+the+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay%3EGoodbye.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
|
||||||
|
called.Store(true)
|
||||||
|
}))
|
||||||
|
defer twilioServer.Close()
|
||||||
|
|
||||||
|
c := newTestConfigWithAuthFile(t)
|
||||||
|
c.TwilioCallsBaseURL = twilioServer.URL
|
||||||
|
c.TwilioAccount = "AC1234567890"
|
||||||
|
c.TwilioAuthToken = "AAEAA1234567890"
|
||||||
|
c.TwilioPhoneNumber = "+1234567890"
|
||||||
|
s := newTestServer(t, c)
|
||||||
|
|
||||||
|
// Add tier and user
|
||||||
|
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||||
|
Code: "pro",
|
||||||
|
MessageLimit: 10,
|
||||||
|
CallLimit: 1,
|
||||||
|
}))
|
||||||
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
||||||
|
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||||
|
u, err := s.userManager.User("phil")
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
|
require.Nil(t, s.userManager.AddPhoneNumber(u.ID, "+11122233344"))
|
||||||
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
|
|
||||||
require.Equal(t, "From=%2B1234567890&To=%2B11122233344&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+loop%3D%223%22%3E%0A%09%09You+have+a+message+from+notify+on+topic+mytopic.+Message%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09End+of+message.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09This+message+was+sent+by+user+phil.+It+will+be+repeated+three+times.%0A%09%09To+unsubscribe+from+calls+like+this%2C+remove+your+phone+number+in+the+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay%3EGoodbye.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
|
|
||||||
called.Store(true)
|
|
||||||
}))
|
|
||||||
defer twilioServer.Close()
|
|
||||||
|
|
||||||
c := newTestConfigWithAuthFile(t)
|
// Do the thing
|
||||||
c.TwilioCallsBaseURL = twilioServer.URL
|
response := request(t, s, "POST", "/mytopic", "hi there", map[string]string{
|
||||||
c.TwilioAccount = "AC1234567890"
|
"authorization": util.BasicAuth("phil", "phil"),
|
||||||
c.TwilioAuthToken = "AAEAA1234567890"
|
"x-call": "+11122233344",
|
||||||
c.TwilioPhoneNumber = "+1234567890"
|
})
|
||||||
s := newTestServer(t, c)
|
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
|
||||||
|
waitFor(t, func() bool {
|
||||||
// Add tier and user
|
return called.Load()
|
||||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
})
|
||||||
Code: "pro",
|
|
||||||
MessageLimit: 10,
|
|
||||||
CallLimit: 1,
|
|
||||||
}))
|
|
||||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
|
||||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
|
||||||
u, err := s.userManager.User("phil")
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Nil(t, s.userManager.AddPhoneNumber(u.ID, "+11122233344"))
|
|
||||||
|
|
||||||
// Do the thing
|
|
||||||
response := request(t, s, "POST", "/mytopic", "hi there", map[string]string{
|
|
||||||
"authorization": util.BasicAuth("phil", "phil"),
|
|
||||||
"x-call": "+11122233344",
|
|
||||||
})
|
|
||||||
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
|
|
||||||
waitFor(t, func() bool {
|
|
||||||
return called.Load()
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_Twilio_Call_Success_With_Yes(t *testing.T) {
|
func TestServer_Twilio_Call_Success_With_Yes(t *testing.T) {
|
||||||
var called atomic.Bool
|
forEachBackend(t, func(t *testing.T) {
|
||||||
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
var called atomic.Bool
|
||||||
if called.Load() {
|
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
t.Fatal("Should be only called once")
|
if called.Load() {
|
||||||
}
|
t.Fatal("Should be only called once")
|
||||||
body, err := io.ReadAll(r.Body)
|
}
|
||||||
|
body, err := io.ReadAll(r.Body)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
|
||||||
|
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
|
||||||
|
require.Equal(t, "From=%2B1234567890&To=%2B11122233344&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+loop%3D%223%22%3E%0A%09%09You+have+a+message+from+notify+on+topic+mytopic.+Message%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09End+of+message.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09This+message+was+sent+by+user+phil.+It+will+be+repeated+three+times.%0A%09%09To+unsubscribe+from+calls+like+this%2C+remove+your+phone+number+in+the+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay%3EGoodbye.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
|
||||||
|
called.Store(true)
|
||||||
|
}))
|
||||||
|
defer twilioServer.Close()
|
||||||
|
|
||||||
|
c := newTestConfigWithAuthFile(t)
|
||||||
|
c.TwilioCallsBaseURL = twilioServer.URL
|
||||||
|
c.TwilioAccount = "AC1234567890"
|
||||||
|
c.TwilioAuthToken = "AAEAA1234567890"
|
||||||
|
c.TwilioPhoneNumber = "+1234567890"
|
||||||
|
s := newTestServer(t, c)
|
||||||
|
|
||||||
|
// Add tier and user
|
||||||
|
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||||
|
Code: "pro",
|
||||||
|
MessageLimit: 10,
|
||||||
|
CallLimit: 1,
|
||||||
|
}))
|
||||||
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
||||||
|
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||||
|
u, err := s.userManager.User("phil")
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
|
require.Nil(t, s.userManager.AddPhoneNumber(u.ID, "+11122233344"))
|
||||||
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
|
|
||||||
require.Equal(t, "From=%2B1234567890&To=%2B11122233344&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+loop%3D%223%22%3E%0A%09%09You+have+a+message+from+notify+on+topic+mytopic.+Message%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09End+of+message.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09This+message+was+sent+by+user+phil.+It+will+be+repeated+three+times.%0A%09%09To+unsubscribe+from+calls+like+this%2C+remove+your+phone+number+in+the+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay%3EGoodbye.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
|
|
||||||
called.Store(true)
|
|
||||||
}))
|
|
||||||
defer twilioServer.Close()
|
|
||||||
|
|
||||||
c := newTestConfigWithAuthFile(t)
|
// Do the thing
|
||||||
c.TwilioCallsBaseURL = twilioServer.URL
|
response := request(t, s, "POST", "/mytopic", "hi there", map[string]string{
|
||||||
c.TwilioAccount = "AC1234567890"
|
"authorization": util.BasicAuth("phil", "phil"),
|
||||||
c.TwilioAuthToken = "AAEAA1234567890"
|
"x-call": "yes", // <<<------
|
||||||
c.TwilioPhoneNumber = "+1234567890"
|
})
|
||||||
s := newTestServer(t, c)
|
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
|
||||||
|
waitFor(t, func() bool {
|
||||||
// Add tier and user
|
return called.Load()
|
||||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
})
|
||||||
Code: "pro",
|
|
||||||
MessageLimit: 10,
|
|
||||||
CallLimit: 1,
|
|
||||||
}))
|
|
||||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
|
||||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
|
||||||
u, err := s.userManager.User("phil")
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Nil(t, s.userManager.AddPhoneNumber(u.ID, "+11122233344"))
|
|
||||||
|
|
||||||
// Do the thing
|
|
||||||
response := request(t, s, "POST", "/mytopic", "hi there", map[string]string{
|
|
||||||
"authorization": util.BasicAuth("phil", "phil"),
|
|
||||||
"x-call": "yes", // <<<------
|
|
||||||
})
|
|
||||||
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
|
|
||||||
waitFor(t, func() bool {
|
|
||||||
return called.Load()
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_Twilio_Call_Success_with_custom_twiml(t *testing.T) {
|
func TestServer_Twilio_Call_Success_with_custom_twiml(t *testing.T) {
|
||||||
var called atomic.Bool
|
forEachBackend(t, func(t *testing.T) {
|
||||||
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
var called atomic.Bool
|
||||||
if called.Load() {
|
twilioServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
t.Fatal("Should be only called once")
|
if called.Load() {
|
||||||
}
|
t.Fatal("Should be only called once")
|
||||||
body, err := io.ReadAll(r.Body)
|
}
|
||||||
require.Nil(t, err)
|
body, err := io.ReadAll(r.Body)
|
||||||
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
|
require.Nil(t, err)
|
||||||
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
|
require.Equal(t, "/2010-04-01/Accounts/AC1234567890/Calls.json", r.URL.Path)
|
||||||
require.Equal(t, "From=%2B1234567890&To=%2B11122233344&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+language%3D%22de-DE%22+loop%3D%223%22%3E%0A%09%09Du+hast+eine+Nachricht+von+notify+im+Thema+mytopic.+Nachricht%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09Ende+der+Nachricht.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09Diese+Nachricht+wurde+von+Benutzer+phil+gesendet.+Sie+wird+drei+Mal+wiederholt.%0A%09%09Um+dich+von+Anrufen+wie+diesen+abzumelden%2C+entferne+deine+Telefonnummer+in+der+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay+language%3D%22de-DE%22%3EAuf+Wiederh%C3%B6ren.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
|
require.Equal(t, "Basic QUMxMjM0NTY3ODkwOkFBRUFBMTIzNDU2Nzg5MA==", r.Header.Get("Authorization"))
|
||||||
called.Store(true)
|
require.Equal(t, "From=%2B1234567890&To=%2B11122233344&Twiml=%0A%3CResponse%3E%0A%09%3CPause+length%3D%221%22%2F%3E%0A%09%3CSay+language%3D%22de-DE%22+loop%3D%223%22%3E%0A%09%09Du+hast+eine+Nachricht+von+notify+im+Thema+mytopic.+Nachricht%3A%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09hi+there%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09Ende+der+Nachricht.%0A%09%09%3Cbreak+time%3D%221s%22%2F%3E%0A%09%09Diese+Nachricht+wurde+von+Benutzer+phil+gesendet.+Sie+wird+drei+Mal+wiederholt.%0A%09%09Um+dich+von+Anrufen+wie+diesen+abzumelden%2C+entferne+deine+Telefonnummer+in+der+notify+web+app.%0A%09%09%3Cbreak+time%3D%223s%22%2F%3E%0A%09%3C%2FSay%3E%0A%09%3CSay+language%3D%22de-DE%22%3EAuf+Wiederh%C3%B6ren.%3C%2FSay%3E%0A%3C%2FResponse%3E", string(body))
|
||||||
}))
|
called.Store(true)
|
||||||
defer twilioServer.Close()
|
}))
|
||||||
|
defer twilioServer.Close()
|
||||||
|
|
||||||
c := newTestConfigWithAuthFile(t)
|
c := newTestConfigWithAuthFile(t)
|
||||||
c.TwilioCallsBaseURL = twilioServer.URL
|
c.TwilioCallsBaseURL = twilioServer.URL
|
||||||
c.TwilioAccount = "AC1234567890"
|
c.TwilioAccount = "AC1234567890"
|
||||||
c.TwilioAuthToken = "AAEAA1234567890"
|
c.TwilioAuthToken = "AAEAA1234567890"
|
||||||
c.TwilioPhoneNumber = "+1234567890"
|
c.TwilioPhoneNumber = "+1234567890"
|
||||||
c.TwilioCallFormat = template.Must(template.New("twiml").Parse(`
|
c.TwilioCallFormat = template.Must(template.New("twiml").Parse(`
|
||||||
<Response>
|
<Response>
|
||||||
<Pause length="1"/>
|
<Pause length="1"/>
|
||||||
<Say language="de-DE" loop="3">
|
<Say language="de-DE" loop="3">
|
||||||
@@ -240,88 +247,97 @@ func TestServer_Twilio_Call_Success_with_custom_twiml(t *testing.T) {
|
|||||||
</Say>
|
</Say>
|
||||||
<Say language="de-DE">Auf Wiederhören.</Say>
|
<Say language="de-DE">Auf Wiederhören.</Say>
|
||||||
</Response>`))
|
</Response>`))
|
||||||
s := newTestServer(t, c)
|
s := newTestServer(t, c)
|
||||||
|
|
||||||
// Add tier and user
|
// Add tier and user
|
||||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||||
Code: "pro",
|
Code: "pro",
|
||||||
MessageLimit: 10,
|
MessageLimit: 10,
|
||||||
CallLimit: 1,
|
CallLimit: 1,
|
||||||
}))
|
}))
|
||||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
||||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||||
u, err := s.userManager.User("phil")
|
u, err := s.userManager.User("phil")
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Nil(t, s.userManager.AddPhoneNumber(u.ID, "+11122233344"))
|
require.Nil(t, s.userManager.AddPhoneNumber(u.ID, "+11122233344"))
|
||||||
|
|
||||||
// Do the thing
|
// Do the thing
|
||||||
response := request(t, s, "POST", "/mytopic", "hi there", map[string]string{
|
response := request(t, s, "POST", "/mytopic", "hi there", map[string]string{
|
||||||
"authorization": util.BasicAuth("phil", "phil"),
|
"authorization": util.BasicAuth("phil", "phil"),
|
||||||
"x-call": "+11122233344",
|
"x-call": "+11122233344",
|
||||||
})
|
})
|
||||||
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
|
require.Equal(t, "hi there", toMessage(t, response.Body.String()).Message)
|
||||||
waitFor(t, func() bool {
|
waitFor(t, func() bool {
|
||||||
return called.Load()
|
return called.Load()
|
||||||
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_Twilio_Call_UnverifiedNumber(t *testing.T) {
|
func TestServer_Twilio_Call_UnverifiedNumber(t *testing.T) {
|
||||||
c := newTestConfigWithAuthFile(t)
|
forEachBackend(t, func(t *testing.T) {
|
||||||
c.TwilioCallsBaseURL = "http://dummy.invalid"
|
c := newTestConfigWithAuthFile(t)
|
||||||
c.TwilioAccount = "AC1234567890"
|
c.TwilioCallsBaseURL = "http://dummy.invalid"
|
||||||
c.TwilioAuthToken = "AAEAA1234567890"
|
c.TwilioAccount = "AC1234567890"
|
||||||
c.TwilioPhoneNumber = "+1234567890"
|
c.TwilioAuthToken = "AAEAA1234567890"
|
||||||
s := newTestServer(t, c)
|
c.TwilioPhoneNumber = "+1234567890"
|
||||||
|
s := newTestServer(t, c)
|
||||||
|
|
||||||
// Add tier and user
|
// Add tier and user
|
||||||
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||||
Code: "pro",
|
Code: "pro",
|
||||||
MessageLimit: 10,
|
MessageLimit: 10,
|
||||||
CallLimit: 1,
|
CallLimit: 1,
|
||||||
}))
|
}))
|
||||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
|
||||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||||
|
|
||||||
// Do the thing
|
// Do the thing
|
||||||
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
|
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
|
||||||
"authorization": util.BasicAuth("phil", "phil"),
|
"authorization": util.BasicAuth("phil", "phil"),
|
||||||
"x-call": "+11122233344",
|
"x-call": "+11122233344",
|
||||||
|
})
|
||||||
|
require.Equal(t, 40034, toHTTPError(t, response.Body.String()).Code)
|
||||||
})
|
})
|
||||||
require.Equal(t, 40034, toHTTPError(t, response.Body.String()).Code)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_Twilio_Call_InvalidNumber(t *testing.T) {
|
func TestServer_Twilio_Call_InvalidNumber(t *testing.T) {
|
||||||
c := newTestConfigWithAuthFile(t)
|
forEachBackend(t, func(t *testing.T) {
|
||||||
c.TwilioCallsBaseURL = "https://127.0.0.1"
|
c := newTestConfigWithAuthFile(t)
|
||||||
c.TwilioAccount = "AC1234567890"
|
c.TwilioCallsBaseURL = "https://127.0.0.1"
|
||||||
c.TwilioAuthToken = "AAEAA1234567890"
|
c.TwilioAccount = "AC1234567890"
|
||||||
c.TwilioPhoneNumber = "+1234567890"
|
c.TwilioAuthToken = "AAEAA1234567890"
|
||||||
s := newTestServer(t, c)
|
c.TwilioPhoneNumber = "+1234567890"
|
||||||
|
s := newTestServer(t, c)
|
||||||
|
|
||||||
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
|
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
|
||||||
"x-call": "+invalid",
|
"x-call": "+invalid",
|
||||||
|
})
|
||||||
|
require.Equal(t, 40033, toHTTPError(t, response.Body.String()).Code)
|
||||||
})
|
})
|
||||||
require.Equal(t, 40033, toHTTPError(t, response.Body.String()).Code)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_Twilio_Call_Anonymous(t *testing.T) {
|
func TestServer_Twilio_Call_Anonymous(t *testing.T) {
|
||||||
c := newTestConfigWithAuthFile(t)
|
forEachBackend(t, func(t *testing.T) {
|
||||||
c.TwilioCallsBaseURL = "https://127.0.0.1"
|
c := newTestConfigWithAuthFile(t)
|
||||||
c.TwilioAccount = "AC1234567890"
|
c.TwilioCallsBaseURL = "https://127.0.0.1"
|
||||||
c.TwilioAuthToken = "AAEAA1234567890"
|
c.TwilioAccount = "AC1234567890"
|
||||||
c.TwilioPhoneNumber = "+1234567890"
|
c.TwilioAuthToken = "AAEAA1234567890"
|
||||||
s := newTestServer(t, c)
|
c.TwilioPhoneNumber = "+1234567890"
|
||||||
|
s := newTestServer(t, c)
|
||||||
|
|
||||||
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
|
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
|
||||||
"x-call": "+123123",
|
"x-call": "+123123",
|
||||||
|
})
|
||||||
|
require.Equal(t, 40035, toHTTPError(t, response.Body.String()).Code)
|
||||||
})
|
})
|
||||||
require.Equal(t, 40035, toHTTPError(t, response.Body.String()).Code)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_Twilio_Call_Unconfigured(t *testing.T) {
|
func TestServer_Twilio_Call_Unconfigured(t *testing.T) {
|
||||||
s := newTestServer(t, newTestConfig(t))
|
forEachBackend(t, func(t *testing.T) {
|
||||||
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
|
s := newTestServer(t, newTestConfig(t))
|
||||||
"x-call": "+1234",
|
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
|
||||||
|
"x-call": "+1234",
|
||||||
|
})
|
||||||
|
require.Equal(t, 40032, toHTTPError(t, response.Body.String()).Code)
|
||||||
})
|
})
|
||||||
require.Equal(t, 40032, toHTTPError(t, response.Body.String()).Code)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
|
|
||||||
"github.com/SherClockHolmes/webpush-go"
|
"github.com/SherClockHolmes/webpush-go"
|
||||||
"heckel.io/ntfy/v2/log"
|
"heckel.io/ntfy/v2/log"
|
||||||
|
"heckel.io/ntfy/v2/model"
|
||||||
"heckel.io/ntfy/v2/user"
|
"heckel.io/ntfy/v2/user"
|
||||||
wpush "heckel.io/ntfy/v2/webpush"
|
wpush "heckel.io/ntfy/v2/webpush"
|
||||||
)
|
)
|
||||||
@@ -83,14 +84,14 @@ func (s *Server) handleWebPushDelete(w http.ResponseWriter, r *http.Request, _ *
|
|||||||
return s.writeJSON(w, newSuccessResponse())
|
return s.writeJSON(w, newSuccessResponse())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) publishToWebPushEndpoints(v *visitor, m *message) {
|
func (s *Server) publishToWebPushEndpoints(v *visitor, m *model.Message) {
|
||||||
subscriptions, err := s.webPush.SubscriptionsForTopic(m.Topic)
|
subscriptions, err := s.webPush.SubscriptionsForTopic(m.Topic)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logvm(v, m).Err(err).With(v, m).Warn("Unable to publish web push messages")
|
logvm(v, m).Err(err).With(v, m).Warn("Unable to publish web push messages")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Tag(tagWebPush).With(v, m).Debug("Publishing web push message to %d subscribers", len(subscriptions))
|
log.Tag(tagWebPush).With(v, m).Debug("Publishing web push message to %d subscribers", len(subscriptions))
|
||||||
payload, err := json.Marshal(newWebPushPayload(fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic), m.forJSON()))
|
payload, err := json.Marshal(newWebPushPayload(fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic), m.ForJSON()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Tag(tagWebPush).Err(err).With(v, m).Warn("Unable to marshal expiring payload")
|
log.Tag(tagWebPush).Err(err).With(v, m).Warn("Unable to marshal expiring payload")
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"heckel.io/ntfy/v2/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -20,7 +22,7 @@ func (s *Server) handleWebPushDelete(w http.ResponseWriter, r *http.Request, _ *
|
|||||||
return errHTTPNotFound
|
return errHTTPNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) publishToWebPushEndpoints(v *visitor, m *message) {
|
func (s *Server) publishToWebPushEndpoints(v *visitor, m *model.Message) {
|
||||||
// Nothing to see here
|
// Nothing to see here
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -26,236 +26,262 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestServer_WebPush_Enabled(t *testing.T) {
|
func TestServer_WebPush_Enabled(t *testing.T) {
|
||||||
conf := newTestConfig(t)
|
forEachBackend(t, func(t *testing.T) {
|
||||||
conf.WebRoot = "" // Disable web app
|
conf := newTestConfig(t)
|
||||||
s := newTestServer(t, conf)
|
conf.WebRoot = "" // Disable web app
|
||||||
|
s := newTestServer(t, conf)
|
||||||
|
|
||||||
rr := request(t, s, "GET", "/manifest.webmanifest", "", nil)
|
rr := request(t, s, "GET", "/manifest.webmanifest", "", nil)
|
||||||
require.Equal(t, 404, rr.Code)
|
require.Equal(t, 404, rr.Code)
|
||||||
|
|
||||||
conf2 := newTestConfig(t)
|
conf2 := newTestConfig(t)
|
||||||
s2 := newTestServer(t, conf2)
|
s2 := newTestServer(t, conf2)
|
||||||
|
|
||||||
rr = request(t, s2, "GET", "/manifest.webmanifest", "", nil)
|
rr = request(t, s2, "GET", "/manifest.webmanifest", "", nil)
|
||||||
require.Equal(t, 404, rr.Code)
|
require.Equal(t, 404, rr.Code)
|
||||||
|
|
||||||
conf3 := newTestConfigWithWebPush(t)
|
conf3 := newTestConfigWithWebPush(t)
|
||||||
s3 := newTestServer(t, conf3)
|
s3 := newTestServer(t, conf3)
|
||||||
|
|
||||||
rr = request(t, s3, "GET", "/manifest.webmanifest", "", nil)
|
rr = request(t, s3, "GET", "/manifest.webmanifest", "", nil)
|
||||||
require.Equal(t, 200, rr.Code)
|
require.Equal(t, 200, rr.Code)
|
||||||
require.Equal(t, "application/manifest+json", rr.Header().Get("Content-Type"))
|
require.Equal(t, "application/manifest+json", rr.Header().Get("Content-Type"))
|
||||||
|
|
||||||
|
})
|
||||||
}
|
}
|
||||||
func TestServer_WebPush_Disabled(t *testing.T) {
|
func TestServer_WebPush_Disabled(t *testing.T) {
|
||||||
s := newTestServer(t, newTestConfig(t))
|
forEachBackend(t, func(t *testing.T) {
|
||||||
|
s := newTestServer(t, newTestConfig(t))
|
||||||
|
|
||||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
|
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
|
||||||
require.Equal(t, 404, response.Code)
|
require.Equal(t, 404, response.Code)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_WebPush_TopicAdd(t *testing.T) {
|
func TestServer_WebPush_TopicAdd(t *testing.T) {
|
||||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
forEachBackend(t, func(t *testing.T) {
|
||||||
|
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||||
|
|
||||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
|
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
|
||||||
require.Equal(t, 200, response.Code)
|
require.Equal(t, 200, response.Code)
|
||||||
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
||||||
|
|
||||||
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
|
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
require.Len(t, subs, 1)
|
require.Len(t, subs, 1)
|
||||||
require.Equal(t, subs[0].Endpoint, testWebPushEndpoint)
|
require.Equal(t, subs[0].Endpoint, testWebPushEndpoint)
|
||||||
require.Equal(t, subs[0].P256dh, "p256dh-key")
|
require.Equal(t, subs[0].P256dh, "p256dh-key")
|
||||||
require.Equal(t, subs[0].Auth, "auth-key")
|
require.Equal(t, subs[0].Auth, "auth-key")
|
||||||
require.Equal(t, subs[0].UserID, "")
|
require.Equal(t, subs[0].UserID, "")
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_WebPush_TopicAdd_InvalidEndpoint(t *testing.T) {
|
func TestServer_WebPush_TopicAdd_InvalidEndpoint(t *testing.T) {
|
||||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
forEachBackend(t, func(t *testing.T) {
|
||||||
|
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||||
|
|
||||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, "https://ddos-target.example.com/webpush"), nil)
|
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, "https://ddos-target.example.com/webpush"), nil)
|
||||||
require.Equal(t, 400, response.Code)
|
require.Equal(t, 400, response.Code)
|
||||||
require.Equal(t, `{"code":40039,"http":400,"error":"invalid request: web push endpoint unknown"}`+"\n", response.Body.String())
|
require.Equal(t, `{"code":40039,"http":400,"error":"invalid request: web push endpoint unknown"}`+"\n", response.Body.String())
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_WebPush_TopicAdd_TooManyTopics(t *testing.T) {
|
func TestServer_WebPush_TopicAdd_TooManyTopics(t *testing.T) {
|
||||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
forEachBackend(t, func(t *testing.T) {
|
||||||
|
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||||
|
|
||||||
topicList := make([]string, 51)
|
topicList := make([]string, 51)
|
||||||
for i := range topicList {
|
for i := range topicList {
|
||||||
topicList[i] = util.RandomString(5)
|
topicList[i] = util.RandomString(5)
|
||||||
}
|
}
|
||||||
|
|
||||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, topicList, testWebPushEndpoint), nil)
|
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, topicList, testWebPushEndpoint), nil)
|
||||||
require.Equal(t, 400, response.Code)
|
require.Equal(t, 400, response.Code)
|
||||||
require.Equal(t, `{"code":40040,"http":400,"error":"invalid request: too many web push topic subscriptions"}`+"\n", response.Body.String())
|
require.Equal(t, `{"code":40040,"http":400,"error":"invalid request: too many web push topic subscriptions"}`+"\n", response.Body.String())
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_WebPush_TopicUnsubscribe(t *testing.T) {
|
func TestServer_WebPush_TopicUnsubscribe(t *testing.T) {
|
||||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
forEachBackend(t, func(t *testing.T) {
|
||||||
|
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||||
|
|
||||||
addSubscription(t, s, testWebPushEndpoint, "test-topic")
|
addSubscription(t, s, testWebPushEndpoint, "test-topic")
|
||||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||||
|
|
||||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{}, testWebPushEndpoint), nil)
|
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{}, testWebPushEndpoint), nil)
|
||||||
require.Equal(t, 200, response.Code)
|
require.Equal(t, 200, response.Code)
|
||||||
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
||||||
|
|
||||||
requireSubscriptionCount(t, s, "test-topic", 0)
|
requireSubscriptionCount(t, s, "test-topic", 0)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_WebPush_Delete(t *testing.T) {
|
func TestServer_WebPush_Delete(t *testing.T) {
|
||||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
forEachBackend(t, func(t *testing.T) {
|
||||||
|
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||||
|
|
||||||
addSubscription(t, s, testWebPushEndpoint, "test-topic")
|
addSubscription(t, s, testWebPushEndpoint, "test-topic")
|
||||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||||
|
|
||||||
response := request(t, s, "DELETE", "/v1/webpush", fmt.Sprintf(`{"endpoint":"%s"}`, testWebPushEndpoint), nil)
|
response := request(t, s, "DELETE", "/v1/webpush", fmt.Sprintf(`{"endpoint":"%s"}`, testWebPushEndpoint), nil)
|
||||||
require.Equal(t, 200, response.Code)
|
require.Equal(t, 200, response.Code)
|
||||||
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
||||||
|
|
||||||
requireSubscriptionCount(t, s, "test-topic", 0)
|
requireSubscriptionCount(t, s, "test-topic", 0)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_WebPush_TopicSubscribeProtected_Allowed(t *testing.T) {
|
func TestServer_WebPush_TopicSubscribeProtected_Allowed(t *testing.T) {
|
||||||
config := configureAuth(t, newTestConfigWithWebPush(t))
|
forEachBackend(t, func(t *testing.T) {
|
||||||
config.AuthDefault = user.PermissionDenyAll
|
config := configureAuth(t, newTestConfigWithWebPush(t))
|
||||||
s := newTestServer(t, config)
|
config.AuthDefault = user.PermissionDenyAll
|
||||||
|
s := newTestServer(t, config)
|
||||||
|
|
||||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
||||||
require.Nil(t, s.userManager.AllowAccess("ben", "test-topic", user.PermissionReadWrite))
|
require.Nil(t, s.userManager.AllowAccess("ben", "test-topic", user.PermissionReadWrite))
|
||||||
|
|
||||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), map[string]string{
|
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), map[string]string{
|
||||||
"Authorization": util.BasicAuth("ben", "ben"),
|
"Authorization": util.BasicAuth("ben", "ben"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, response.Code)
|
||||||
|
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
||||||
|
|
||||||
|
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Len(t, subs, 1)
|
||||||
|
require.True(t, strings.HasPrefix(subs[0].UserID, "u_"))
|
||||||
})
|
})
|
||||||
require.Equal(t, 200, response.Code)
|
|
||||||
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
|
||||||
|
|
||||||
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Len(t, subs, 1)
|
|
||||||
require.True(t, strings.HasPrefix(subs[0].UserID, "u_"))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_WebPush_TopicSubscribeProtected_Denied(t *testing.T) {
|
func TestServer_WebPush_TopicSubscribeProtected_Denied(t *testing.T) {
|
||||||
config := configureAuth(t, newTestConfigWithWebPush(t))
|
forEachBackend(t, func(t *testing.T) {
|
||||||
config.AuthDefault = user.PermissionDenyAll
|
config := configureAuth(t, newTestConfigWithWebPush(t))
|
||||||
s := newTestServer(t, config)
|
config.AuthDefault = user.PermissionDenyAll
|
||||||
|
s := newTestServer(t, config)
|
||||||
|
|
||||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
|
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), nil)
|
||||||
require.Equal(t, 403, response.Code)
|
require.Equal(t, 403, response.Code)
|
||||||
|
|
||||||
requireSubscriptionCount(t, s, "test-topic", 0)
|
requireSubscriptionCount(t, s, "test-topic", 0)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_WebPush_DeleteAccountUnsubscribe(t *testing.T) {
|
func TestServer_WebPush_DeleteAccountUnsubscribe(t *testing.T) {
|
||||||
config := configureAuth(t, newTestConfigWithWebPush(t))
|
forEachBackend(t, func(t *testing.T) {
|
||||||
s := newTestServer(t, config)
|
config := configureAuth(t, newTestConfigWithWebPush(t))
|
||||||
|
s := newTestServer(t, config)
|
||||||
|
|
||||||
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
|
||||||
require.Nil(t, s.userManager.AllowAccess("ben", "test-topic", user.PermissionReadWrite))
|
require.Nil(t, s.userManager.AllowAccess("ben", "test-topic", user.PermissionReadWrite))
|
||||||
|
|
||||||
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), map[string]string{
|
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), map[string]string{
|
||||||
"Authorization": util.BasicAuth("ben", "ben"),
|
"Authorization": util.BasicAuth("ben", "ben"),
|
||||||
|
})
|
||||||
|
|
||||||
|
require.Equal(t, 200, response.Code)
|
||||||
|
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
||||||
|
|
||||||
|
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||||
|
|
||||||
|
request(t, s, "DELETE", "/v1/account", `{"password":"ben"}`, map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("ben", "ben"),
|
||||||
|
})
|
||||||
|
// should've been deleted with the account
|
||||||
|
requireSubscriptionCount(t, s, "test-topic", 0)
|
||||||
})
|
})
|
||||||
|
|
||||||
require.Equal(t, 200, response.Code)
|
|
||||||
require.Equal(t, `{"success":true}`+"\n", response.Body.String())
|
|
||||||
|
|
||||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
|
||||||
|
|
||||||
request(t, s, "DELETE", "/v1/account", `{"password":"ben"}`, map[string]string{
|
|
||||||
"Authorization": util.BasicAuth("ben", "ben"),
|
|
||||||
})
|
|
||||||
// should've been deleted with the account
|
|
||||||
requireSubscriptionCount(t, s, "test-topic", 0)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_WebPush_Publish(t *testing.T) {
|
func TestServer_WebPush_Publish(t *testing.T) {
|
||||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
forEachBackend(t, func(t *testing.T) {
|
||||||
|
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||||
|
|
||||||
var received atomic.Bool
|
var received atomic.Bool
|
||||||
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
_, err := io.ReadAll(r.Body)
|
_, err := io.ReadAll(r.Body)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, "/push-receive", r.URL.Path)
|
require.Equal(t, "/push-receive", r.URL.Path)
|
||||||
require.Equal(t, "high", r.Header.Get("Urgency"))
|
require.Equal(t, "high", r.Header.Get("Urgency"))
|
||||||
require.Equal(t, "", r.Header.Get("Topic"))
|
require.Equal(t, "", r.Header.Get("Topic"))
|
||||||
received.Store(true)
|
received.Store(true)
|
||||||
}))
|
}))
|
||||||
defer pushService.Close()
|
defer pushService.Close()
|
||||||
|
|
||||||
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic")
|
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic")
|
||||||
request(t, s, "POST", "/test-topic", "web push test", nil)
|
request(t, s, "POST", "/test-topic", "web push test", nil)
|
||||||
|
|
||||||
waitFor(t, func() bool {
|
waitFor(t, func() bool {
|
||||||
return received.Load()
|
return received.Load()
|
||||||
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_WebPush_Publish_RemoveOnError(t *testing.T) {
|
func TestServer_WebPush_Publish_RemoveOnError(t *testing.T) {
|
||||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
forEachBackend(t, func(t *testing.T) {
|
||||||
|
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||||
|
|
||||||
var received atomic.Bool
|
var received atomic.Bool
|
||||||
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
_, err := io.ReadAll(r.Body)
|
_, err := io.ReadAll(r.Body)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
w.WriteHeader(http.StatusGone)
|
w.WriteHeader(http.StatusGone)
|
||||||
received.Store(true)
|
received.Store(true)
|
||||||
}))
|
}))
|
||||||
defer pushService.Close()
|
defer pushService.Close()
|
||||||
|
|
||||||
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic", "test-topic-abc")
|
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic", "test-topic-abc")
|
||||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||||
requireSubscriptionCount(t, s, "test-topic-abc", 1)
|
requireSubscriptionCount(t, s, "test-topic-abc", 1)
|
||||||
|
|
||||||
request(t, s, "POST", "/test-topic", "web push test", nil)
|
request(t, s, "POST", "/test-topic", "web push test", nil)
|
||||||
|
|
||||||
waitFor(t, func() bool {
|
waitFor(t, func() bool {
|
||||||
return received.Load()
|
return received.Load()
|
||||||
|
})
|
||||||
|
|
||||||
|
// Receiving the 410 should've caused the publisher to expire all subscriptions on the endpoint
|
||||||
|
|
||||||
|
requireSubscriptionCount(t, s, "test-topic", 0)
|
||||||
|
requireSubscriptionCount(t, s, "test-topic-abc", 0)
|
||||||
})
|
})
|
||||||
|
|
||||||
// Receiving the 410 should've caused the publisher to expire all subscriptions on the endpoint
|
|
||||||
|
|
||||||
requireSubscriptionCount(t, s, "test-topic", 0)
|
|
||||||
requireSubscriptionCount(t, s, "test-topic-abc", 0)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_WebPush_Expiry(t *testing.T) {
|
func TestServer_WebPush_Expiry(t *testing.T) {
|
||||||
s := newTestServer(t, newTestConfigWithWebPush(t))
|
forEachBackend(t, func(t *testing.T) {
|
||||||
|
s := newTestServer(t, newTestConfigWithWebPush(t))
|
||||||
|
|
||||||
var received atomic.Bool
|
var received atomic.Bool
|
||||||
|
|
||||||
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
pushService := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
_, err := io.ReadAll(r.Body)
|
_, err := io.ReadAll(r.Body)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
w.WriteHeader(200)
|
w.WriteHeader(200)
|
||||||
w.Write([]byte(``))
|
w.Write([]byte(``))
|
||||||
received.Store(true)
|
received.Store(true)
|
||||||
}))
|
}))
|
||||||
defer pushService.Close()
|
defer pushService.Close()
|
||||||
|
|
||||||
endpoint := pushService.URL + "/push-receive"
|
endpoint := pushService.URL + "/push-receive"
|
||||||
addSubscription(t, s, endpoint, "test-topic")
|
addSubscription(t, s, endpoint, "test-topic")
|
||||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||||
|
|
||||||
require.Nil(t, s.webPush.SetSubscriptionUpdatedAt(endpoint, time.Now().Add(-55*24*time.Hour).Unix()))
|
require.Nil(t, s.webPush.SetSubscriptionUpdatedAt(endpoint, time.Now().Add(-55*24*time.Hour).Unix()))
|
||||||
|
|
||||||
s.pruneAndNotifyWebPushSubscriptions()
|
s.pruneAndNotifyWebPushSubscriptions()
|
||||||
requireSubscriptionCount(t, s, "test-topic", 1)
|
requireSubscriptionCount(t, s, "test-topic", 1)
|
||||||
|
|
||||||
waitFor(t, func() bool {
|
waitFor(t, func() bool {
|
||||||
return received.Load()
|
return received.Load()
|
||||||
})
|
})
|
||||||
|
|
||||||
require.Nil(t, s.webPush.SetSubscriptionUpdatedAt(endpoint, time.Now().Add(-60*24*time.Hour).Unix()))
|
require.Nil(t, s.webPush.SetSubscriptionUpdatedAt(endpoint, time.Now().Add(-60*24*time.Hour).Unix()))
|
||||||
|
|
||||||
s.pruneAndNotifyWebPushSubscriptions()
|
s.pruneAndNotifyWebPushSubscriptions()
|
||||||
waitFor(t, func() bool {
|
waitFor(t, func() bool {
|
||||||
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
|
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
return len(subs) == 0
|
return len(subs) == 0
|
||||||
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -285,7 +311,9 @@ func newTestConfigWithWebPush(t *testing.T) *Config {
|
|||||||
conf := newTestConfig(t)
|
conf := newTestConfig(t)
|
||||||
privateKey, publicKey, err := webpush.GenerateVAPIDKeys()
|
privateKey, publicKey, err := webpush.GenerateVAPIDKeys()
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
conf.WebPushFile = filepath.Join(t.TempDir(), "webpush.db")
|
if conf.DatabaseURL == "" {
|
||||||
|
conf.WebPushFile = filepath.Join(t.TempDir(), "webpush.db")
|
||||||
|
}
|
||||||
conf.WebPushEmailAddress = "testing@example.com"
|
conf.WebPushEmailAddress = "testing@example.com"
|
||||||
conf.WebPushPrivateKey = privateKey
|
conf.WebPushPrivateKey = privateKey
|
||||||
conf.WebPushPublicKey = publicKey
|
conf.WebPushPublicKey = publicKey
|
||||||
|
|||||||
@@ -12,11 +12,12 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"heckel.io/ntfy/v2/log"
|
"heckel.io/ntfy/v2/log"
|
||||||
|
"heckel.io/ntfy/v2/model"
|
||||||
"heckel.io/ntfy/v2/util"
|
"heckel.io/ntfy/v2/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
type mailer interface {
|
type mailer interface {
|
||||||
Send(v *visitor, m *message, to string) error
|
Send(v *visitor, m *model.Message, to string) error
|
||||||
Counts() (total int64, success int64, failure int64)
|
Counts() (total int64, success int64, failure int64)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -27,7 +28,7 @@ type smtpSender struct {
|
|||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *smtpSender) Send(v *visitor, m *message, to string) error {
|
func (s *smtpSender) Send(v *visitor, m *model.Message, to string) error {
|
||||||
return s.withCount(v, m, func() error {
|
return s.withCount(v, m, func() error {
|
||||||
host, _, err := net.SplitHostPort(s.config.SMTPSenderAddr)
|
host, _, err := net.SplitHostPort(s.config.SMTPSenderAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -63,7 +64,7 @@ func (s *smtpSender) Counts() (total int64, success int64, failure int64) {
|
|||||||
return s.success + s.failure, s.success, s.failure
|
return s.success + s.failure, s.success, s.failure
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *smtpSender) withCount(v *visitor, m *message, fn func() error) error {
|
func (s *smtpSender) withCount(v *visitor, m *model.Message, fn func() error) error {
|
||||||
err := fn()
|
err := fn()
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
@@ -76,7 +77,7 @@ func (s *smtpSender) withCount(v *visitor, m *message, fn func() error) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func formatMail(baseURL, senderIP, from, to string, m *message) (string, error) {
|
func formatMail(baseURL, senderIP, from, to string, m *model.Message) (string, error) {
|
||||||
topicURL := baseURL + "/" + m.Topic
|
topicURL := baseURL + "/" + m.Topic
|
||||||
subject := m.Title
|
subject := m.Title
|
||||||
if subject == "" {
|
if subject == "" {
|
||||||
|
|||||||
@@ -1,12 +1,14 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"heckel.io/ntfy/v2/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestFormatMail_Basic(t *testing.T) {
|
func TestFormatMail_Basic(t *testing.T) {
|
||||||
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &message{
|
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &model.Message{
|
||||||
ID: "abc",
|
ID: "abc",
|
||||||
Time: 1640382204,
|
Time: 1640382204,
|
||||||
Event: "message",
|
Event: "message",
|
||||||
@@ -27,7 +29,7 @@ This message was sent by 1.2.3.4 at Fri, 24 Dec 2021 21:43:24 UTC via https://nt
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFormatMail_JustEmojis(t *testing.T) {
|
func TestFormatMail_JustEmojis(t *testing.T) {
|
||||||
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &message{
|
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &model.Message{
|
||||||
ID: "abc",
|
ID: "abc",
|
||||||
Time: 1640382204,
|
Time: 1640382204,
|
||||||
Event: "message",
|
Event: "message",
|
||||||
@@ -49,7 +51,7 @@ This message was sent by 1.2.3.4 at Fri, 24 Dec 2021 21:43:24 UTC via https://nt
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFormatMail_JustOtherTags(t *testing.T) {
|
func TestFormatMail_JustOtherTags(t *testing.T) {
|
||||||
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &message{
|
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &model.Message{
|
||||||
ID: "abc",
|
ID: "abc",
|
||||||
Time: 1640382204,
|
Time: 1640382204,
|
||||||
Event: "message",
|
Event: "message",
|
||||||
@@ -73,7 +75,7 @@ This message was sent by 1.2.3.4 at Fri, 24 Dec 2021 21:43:24 UTC via https://nt
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFormatMail_JustPriority(t *testing.T) {
|
func TestFormatMail_JustPriority(t *testing.T) {
|
||||||
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &message{
|
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &model.Message{
|
||||||
ID: "abc",
|
ID: "abc",
|
||||||
Time: 1640382204,
|
Time: 1640382204,
|
||||||
Event: "message",
|
Event: "message",
|
||||||
@@ -97,7 +99,7 @@ This message was sent by 1.2.3.4 at Fri, 24 Dec 2021 21:43:24 UTC via https://nt
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFormatMail_UTF8Subject(t *testing.T) {
|
func TestFormatMail_UTF8Subject(t *testing.T) {
|
||||||
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &message{
|
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &model.Message{
|
||||||
ID: "abc",
|
ID: "abc",
|
||||||
Time: 1640382204,
|
Time: 1640382204,
|
||||||
Event: "message",
|
Event: "message",
|
||||||
@@ -119,7 +121,7 @@ This message was sent by 1.2.3.4 at Fri, 24 Dec 2021 21:43:24 UTC via https://nt
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFormatMail_WithAllTheThings(t *testing.T) {
|
func TestFormatMail_WithAllTheThings(t *testing.T) {
|
||||||
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &message{
|
actual, _ := formatMail("https://ntfy.sh", "1.2.3.4", "ntfy@ntfy.sh", "phil@example.com", &model.Message{
|
||||||
ID: "abc",
|
ID: "abc",
|
||||||
Time: 1640382204,
|
Time: 1640382204,
|
||||||
Event: "message",
|
Event: "message",
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
|
|
||||||
"github.com/emersion/go-smtp"
|
"github.com/emersion/go-smtp"
|
||||||
"github.com/microcosm-cc/bluemonday"
|
"github.com/microcosm-cc/bluemonday"
|
||||||
|
"heckel.io/ntfy/v2/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -183,7 +184,7 @@ func (s *smtpSession) Data(r io.Reader) error {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *smtpSession) publishMessage(m *message) error {
|
func (s *smtpSession) publishMessage(m *model.Message) error {
|
||||||
// Extract remote address (for rate limiting)
|
// Extract remote address (for rate limiting)
|
||||||
remoteAddr, _, err := net.SplitHostPort(s.conn.Conn().RemoteAddr().String())
|
remoteAddr, _, err := net.SplitHostPort(s.conn.Conn().RemoteAddr().String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"heckel.io/ntfy/v2/log"
|
"heckel.io/ntfy/v2/log"
|
||||||
|
"heckel.io/ntfy/v2/model"
|
||||||
"heckel.io/ntfy/v2/util"
|
"heckel.io/ntfy/v2/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -33,7 +34,7 @@ type topicSubscriber struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// subscriber is a function that is called for every new message on a topic
|
// subscriber is a function that is called for every new message on a topic
|
||||||
type subscriber func(v *visitor, msg *message) error
|
type subscriber func(v *visitor, msg *model.Message) error
|
||||||
|
|
||||||
// newTopic creates a new topic
|
// newTopic creates a new topic
|
||||||
func newTopic(id string) *topic {
|
func newTopic(id string) *topic {
|
||||||
@@ -103,7 +104,7 @@ func (t *topic) Unsubscribe(id int) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Publish asynchronously publishes to all subscribers
|
// Publish asynchronously publishes to all subscribers
|
||||||
func (t *topic) Publish(v *visitor, m *message) error {
|
func (t *topic) Publish(v *visitor, m *model.Message) error {
|
||||||
go func() {
|
go func() {
|
||||||
// We want to lock the topic as short as possible, so we make a shallow copy of the
|
// We want to lock the topic as short as possible, so we make a shallow copy of the
|
||||||
// subscribers map here. Actually sending out the messages then doesn't have to lock.
|
// subscribers map here. Actually sending out the messages then doesn't have to lock.
|
||||||
|
|||||||
@@ -7,10 +7,11 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"heckel.io/ntfy/v2/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestTopic_CancelSubscribersExceptUser(t *testing.T) {
|
func TestTopic_CancelSubscribersExceptUser(t *testing.T) {
|
||||||
subFn := func(v *visitor, msg *message) error {
|
subFn := func(v *visitor, msg *model.Message) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
canceled1 := atomic.Bool{}
|
canceled1 := atomic.Bool{}
|
||||||
@@ -33,7 +34,7 @@ func TestTopic_CancelSubscribersExceptUser(t *testing.T) {
|
|||||||
func TestTopic_CancelSubscribersUser(t *testing.T) {
|
func TestTopic_CancelSubscribersUser(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
subFn := func(v *visitor, msg *message) error {
|
subFn := func(v *visitor, msg *model.Message) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
canceled1 := atomic.Bool{}
|
canceled1 := atomic.Bool{}
|
||||||
@@ -76,7 +77,7 @@ func TestTopic_Subscribe_DuplicateID(t *testing.T) {
|
|||||||
cancel: func() {},
|
cancel: func() {},
|
||||||
}
|
}
|
||||||
|
|
||||||
subFn := func(v *visitor, msg *message) error {
|
subFn := func(v *visitor, msg *model.Message) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
255
server/types.go
255
server/types.go
@@ -2,219 +2,78 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"heckel.io/ntfy/v2/log"
|
"heckel.io/ntfy/v2/model"
|
||||||
"heckel.io/ntfy/v2/user"
|
"heckel.io/ntfy/v2/user"
|
||||||
"heckel.io/ntfy/v2/util"
|
"heckel.io/ntfy/v2/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
// List of possible events
|
// Event constants
|
||||||
const (
|
const (
|
||||||
openEvent = "open"
|
openEvent = model.OpenEvent
|
||||||
keepaliveEvent = "keepalive"
|
keepaliveEvent = model.KeepaliveEvent
|
||||||
messageEvent = "message"
|
messageEvent = model.MessageEvent
|
||||||
messageDeleteEvent = "message_delete"
|
messageDeleteEvent = model.MessageDeleteEvent
|
||||||
messageClearEvent = "message_clear"
|
messageClearEvent = model.MessageClearEvent
|
||||||
pollRequestEvent = "poll_request"
|
pollRequestEvent = model.PollRequestEvent
|
||||||
|
messageIDLength = model.MessageIDLength
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
// SinceMarker aliases
|
||||||
messageIDLength = 12
|
var (
|
||||||
|
sinceAllMessages = model.SinceAllMessages
|
||||||
|
sinceNoMessages = model.SinceNoMessages
|
||||||
|
sinceLatestMessage = model.SinceLatestMessage
|
||||||
)
|
)
|
||||||
|
|
||||||
// message represents a message published to a topic
|
// Error aliases
|
||||||
type message struct {
|
var (
|
||||||
ID string `json:"id"` // Random message ID
|
errMessageNotFound = model.ErrMessageNotFound
|
||||||
SequenceID string `json:"sequence_id,omitempty"` // Message sequence ID for updating message contents (omitted if same as ID)
|
)
|
||||||
Time int64 `json:"time"` // Unix time in seconds
|
|
||||||
Expires int64 `json:"expires,omitempty"` // Unix time in seconds (not required for open/keepalive)
|
|
||||||
Event string `json:"event"` // One of the above
|
|
||||||
Topic string `json:"topic"`
|
|
||||||
Title string `json:"title,omitempty"`
|
|
||||||
Message string `json:"message,omitempty"`
|
|
||||||
Priority int `json:"priority,omitempty"`
|
|
||||||
Tags []string `json:"tags,omitempty"`
|
|
||||||
Click string `json:"click,omitempty"`
|
|
||||||
Icon string `json:"icon,omitempty"`
|
|
||||||
Actions []*action `json:"actions,omitempty"`
|
|
||||||
Attachment *attachment `json:"attachment,omitempty"`
|
|
||||||
PollID string `json:"poll_id,omitempty"`
|
|
||||||
ContentType string `json:"content_type,omitempty"` // text/plain by default (if empty), or text/markdown
|
|
||||||
Encoding string `json:"encoding,omitempty"` // Empty for raw UTF-8, or "base64" for encoded bytes
|
|
||||||
Sender netip.Addr `json:"-"` // IP address of uploader, used for rate limiting
|
|
||||||
User string `json:"-"` // UserID of the uploader, used to associated attachments
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *message) Context() log.Context {
|
// Constructors and helpers
|
||||||
fields := map[string]any{
|
var (
|
||||||
"topic": m.Topic,
|
newMessage = model.NewMessage
|
||||||
"message_id": m.ID,
|
newDefaultMessage = model.NewDefaultMessage
|
||||||
"message_sequence_id": m.SequenceID,
|
newOpenMessage = model.NewOpenMessage
|
||||||
"message_time": m.Time,
|
newKeepaliveMessage = model.NewKeepaliveMessage
|
||||||
"message_event": m.Event,
|
newActionMessage = model.NewActionMessage
|
||||||
"message_body_size": len(m.Message),
|
newAction = model.NewAction
|
||||||
}
|
newSinceTime = model.NewSinceTime
|
||||||
if m.Sender.IsValid() {
|
newSinceID = model.NewSinceID
|
||||||
fields["message_sender"] = m.Sender.String()
|
validMessageID = model.ValidMessageID
|
||||||
}
|
)
|
||||||
if m.User != "" {
|
|
||||||
fields["message_user"] = m.User
|
|
||||||
}
|
|
||||||
return fields
|
|
||||||
}
|
|
||||||
|
|
||||||
// forJSON returns a copy of the message suitable for JSON output.
|
|
||||||
// It clears the SequenceID if it equals the ID to reduce redundancy.
|
|
||||||
func (m *message) forJSON() *message {
|
|
||||||
if m.SequenceID == m.ID {
|
|
||||||
clone := *m
|
|
||||||
clone.SequenceID = ""
|
|
||||||
return &clone
|
|
||||||
}
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
type attachment struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Type string `json:"type,omitempty"`
|
|
||||||
Size int64 `json:"size,omitempty"`
|
|
||||||
Expires int64 `json:"expires,omitempty"`
|
|
||||||
URL string `json:"url"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type action struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Action string `json:"action"` // "view", "broadcast", "http", or "copy"
|
|
||||||
Label string `json:"label"` // action button label
|
|
||||||
Clear bool `json:"clear"` // clear notification after successful execution
|
|
||||||
URL string `json:"url,omitempty"` // used in "view" and "http" actions
|
|
||||||
Method string `json:"method,omitempty"` // used in "http" action, default is POST (!)
|
|
||||||
Headers map[string]string `json:"headers,omitempty"` // used in "http" action
|
|
||||||
Body string `json:"body,omitempty"` // used in "http" action
|
|
||||||
Intent string `json:"intent,omitempty"` // used in "broadcast" action
|
|
||||||
Extras map[string]string `json:"extras,omitempty"` // used in "broadcast" action
|
|
||||||
Value string `json:"value,omitempty"` // used in "copy" action
|
|
||||||
}
|
|
||||||
|
|
||||||
func newAction() *action {
|
|
||||||
return &action{
|
|
||||||
Headers: make(map[string]string),
|
|
||||||
Extras: make(map[string]string),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// publishMessage is used as input when publishing as JSON
|
|
||||||
type publishMessage struct {
|
|
||||||
Topic string `json:"topic"`
|
|
||||||
SequenceID string `json:"sequence_id"`
|
|
||||||
Title string `json:"title"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
Priority int `json:"priority"`
|
|
||||||
Tags []string `json:"tags"`
|
|
||||||
Click string `json:"click"`
|
|
||||||
Icon string `json:"icon"`
|
|
||||||
Actions []action `json:"actions"`
|
|
||||||
Attach string `json:"attach"`
|
|
||||||
Markdown bool `json:"markdown"`
|
|
||||||
Filename string `json:"filename"`
|
|
||||||
Email string `json:"email"`
|
|
||||||
Call string `json:"call"`
|
|
||||||
Cache string `json:"cache"` // use string as it defaults to true (or use &bool instead)
|
|
||||||
Firebase string `json:"firebase"` // use string as it defaults to true (or use &bool instead)
|
|
||||||
Delay string `json:"delay"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// messageEncoder is a function that knows how to encode a message
|
|
||||||
type messageEncoder func(msg *message) (string, error)
|
|
||||||
|
|
||||||
// newMessage creates a new message with the current timestamp
|
|
||||||
func newMessage(event, topic, msg string) *message {
|
|
||||||
return &message{
|
|
||||||
ID: util.RandomString(messageIDLength),
|
|
||||||
Time: time.Now().Unix(),
|
|
||||||
Event: event,
|
|
||||||
Topic: topic,
|
|
||||||
Message: msg,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// newOpenMessage is a convenience method to create an open message
|
|
||||||
func newOpenMessage(topic string) *message {
|
|
||||||
return newMessage(openEvent, topic, "")
|
|
||||||
}
|
|
||||||
|
|
||||||
// newKeepaliveMessage is a convenience method to create a keepalive message
|
|
||||||
func newKeepaliveMessage(topic string) *message {
|
|
||||||
return newMessage(keepaliveEvent, topic, "")
|
|
||||||
}
|
|
||||||
|
|
||||||
// newDefaultMessage is a convenience method to create a notification message
|
|
||||||
func newDefaultMessage(topic, msg string) *message {
|
|
||||||
return newMessage(messageEvent, topic, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
// newPollRequestMessage is a convenience method to create a poll request message
|
// newPollRequestMessage is a convenience method to create a poll request message
|
||||||
func newPollRequestMessage(topic, pollID string) *message {
|
func newPollRequestMessage(topic, pollID string) *model.Message {
|
||||||
m := newMessage(pollRequestEvent, topic, newMessageBody)
|
m := newMessage(pollRequestEvent, topic, newMessageBody)
|
||||||
m.PollID = pollID
|
m.PollID = pollID
|
||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
// newActionMessage creates a new action message (message_delete or message_clear)
|
// publishMessage is used as input when publishing as JSON
|
||||||
func newActionMessage(event, topic, sequenceID string) *message {
|
type publishMessage struct {
|
||||||
m := newMessage(event, topic, "")
|
Topic string `json:"topic"`
|
||||||
m.SequenceID = sequenceID
|
SequenceID string `json:"sequence_id"`
|
||||||
return m
|
Title string `json:"title"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Priority int `json:"priority"`
|
||||||
|
Tags []string `json:"tags"`
|
||||||
|
Click string `json:"click"`
|
||||||
|
Icon string `json:"icon"`
|
||||||
|
Actions []model.Action `json:"actions"`
|
||||||
|
Attach string `json:"attach"`
|
||||||
|
Markdown bool `json:"markdown"`
|
||||||
|
Filename string `json:"filename"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
Call string `json:"call"`
|
||||||
|
Cache string `json:"cache"` // use string as it defaults to true (or use &bool instead)
|
||||||
|
Firebase string `json:"firebase"` // use string as it defaults to true (or use &bool instead)
|
||||||
|
Delay string `json:"delay"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func validMessageID(s string) bool {
|
// messageEncoder is a function that knows how to encode a message
|
||||||
return util.ValidRandomString(s, messageIDLength)
|
type messageEncoder func(msg *model.Message) (string, error)
|
||||||
}
|
|
||||||
|
|
||||||
type sinceMarker struct {
|
|
||||||
time time.Time
|
|
||||||
id string
|
|
||||||
}
|
|
||||||
|
|
||||||
func newSinceTime(timestamp int64) sinceMarker {
|
|
||||||
return sinceMarker{time.Unix(timestamp, 0), ""}
|
|
||||||
}
|
|
||||||
|
|
||||||
func newSinceID(id string) sinceMarker {
|
|
||||||
return sinceMarker{time.Unix(0, 0), id}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t sinceMarker) IsAll() bool {
|
|
||||||
return t == sinceAllMessages
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t sinceMarker) IsNone() bool {
|
|
||||||
return t == sinceNoMessages
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t sinceMarker) IsLatest() bool {
|
|
||||||
return t == sinceLatestMessage
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t sinceMarker) IsID() bool {
|
|
||||||
return t.id != "" && t.id != "latest"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t sinceMarker) Time() time.Time {
|
|
||||||
return t.time
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t sinceMarker) ID() string {
|
|
||||||
return t.id
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
sinceAllMessages = sinceMarker{time.Unix(0, 0), ""}
|
|
||||||
sinceNoMessages = sinceMarker{time.Unix(1, 0), ""}
|
|
||||||
sinceLatestMessage = sinceMarker{time.Unix(0, 0), "latest"}
|
|
||||||
)
|
|
||||||
|
|
||||||
type queryFilter struct {
|
type queryFilter struct {
|
||||||
ID string
|
ID string
|
||||||
@@ -246,7 +105,7 @@ func parseQueryFilters(r *http.Request) (*queryFilter, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *queryFilter) Pass(msg *message) bool {
|
func (q *queryFilter) Pass(msg *model.Message) bool {
|
||||||
if msg.Event != messageEvent && msg.Event != messageDeleteEvent && msg.Event != messageClearEvent {
|
if msg.Event != messageEvent && msg.Event != messageDeleteEvent && msg.Event != messageClearEvent {
|
||||||
return true // filters only apply to messages
|
return true // filters only apply to messages
|
||||||
} else if q.ID != "" && msg.ID != q.ID {
|
} else if q.ID != "" && msg.ID != q.ID {
|
||||||
@@ -570,12 +429,12 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type webPushPayload struct {
|
type webPushPayload struct {
|
||||||
Event string `json:"event"`
|
Event string `json:"event"`
|
||||||
SubscriptionID string `json:"subscription_id"`
|
SubscriptionID string `json:"subscription_id"`
|
||||||
Message *message `json:"message"`
|
Message *model.Message `json:"message"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func newWebPushPayload(subscriptionID string, message *message) *webPushPayload {
|
func newWebPushPayload(subscriptionID string, message *model.Message) *webPushPayload {
|
||||||
return &webPushPayload{
|
return &webPushPayload{
|
||||||
Event: webPushMessageEvent,
|
Event: webPushMessageEvent,
|
||||||
SubscriptionID: subscriptionID,
|
SubscriptionID: subscriptionID,
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
|
|
||||||
"golang.org/x/time/rate"
|
"golang.org/x/time/rate"
|
||||||
"heckel.io/ntfy/v2/log"
|
"heckel.io/ntfy/v2/log"
|
||||||
|
"heckel.io/ntfy/v2/message"
|
||||||
"heckel.io/ntfy/v2/user"
|
"heckel.io/ntfy/v2/user"
|
||||||
"heckel.io/ntfy/v2/util"
|
"heckel.io/ntfy/v2/util"
|
||||||
)
|
)
|
||||||
@@ -53,7 +54,7 @@ const (
|
|||||||
// visitor represents an API user, and its associated rate.Limiter used for rate limiting
|
// visitor represents an API user, and its associated rate.Limiter used for rate limiting
|
||||||
type visitor struct {
|
type visitor struct {
|
||||||
config *Config
|
config *Config
|
||||||
messageCache *messageCache
|
messageCache message.Store
|
||||||
userManager *user.Manager // May be nil
|
userManager *user.Manager // May be nil
|
||||||
ip netip.Addr // Visitor IP address
|
ip netip.Addr // Visitor IP address
|
||||||
user *user.User // Only set if authenticated user, otherwise nil
|
user *user.User // Only set if authenticated user, otherwise nil
|
||||||
@@ -114,7 +115,7 @@ const (
|
|||||||
visitorLimitBasisTier = visitorLimitBasis("tier")
|
visitorLimitBasisTier = visitorLimitBasis("tier")
|
||||||
)
|
)
|
||||||
|
|
||||||
func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor {
|
func newVisitor(conf *Config, messageCache message.Store, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor {
|
||||||
var messages, emails, calls int64
|
var messages, emails, calls int64
|
||||||
if user != nil {
|
if user != nil {
|
||||||
messages = user.Stats.Messages
|
messages = user.Stats.Messages
|
||||||
|
|||||||
@@ -691,7 +691,6 @@ func TestManager_Token_Expire(t *testing.T) {
|
|||||||
// But the token row should still exist
|
// But the token row should still exist
|
||||||
tokens, err := a.Tokens(u.ID)
|
tokens, err := a.Tokens(u.ID)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, token1.Value, tokens[0].Value)
|
|
||||||
require.Equal(t, 2, len(tokens))
|
require.Equal(t, 2, len(tokens))
|
||||||
|
|
||||||
// Expire tokens and check that token1 is gone
|
// Expire tokens and check that token1 is gone
|
||||||
|
|||||||
32
user/util.go
32
user/util.go
@@ -94,38 +94,6 @@ func nullInt64(v int64) sql.NullInt64 {
|
|||||||
return sql.NullInt64{Int64: v, Valid: true}
|
return sql.NullInt64{Int64: v, Valid: true}
|
||||||
}
|
}
|
||||||
|
|
||||||
// execTx executes a function in a transaction. If the function returns an error, the transaction is rolled back.
|
|
||||||
func execTx(db *sql.DB, f func(tx *sql.Tx) error) error {
|
|
||||||
tx, err := db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer tx.Rollback()
|
|
||||||
if err := f(tx); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return tx.Commit()
|
|
||||||
}
|
|
||||||
|
|
||||||
// queryTx executes a function in a transaction and returns the result. If the function
|
|
||||||
// returns an error, the transaction is rolled back.
|
|
||||||
func queryTx[T any](db *sql.DB, f func(tx *sql.Tx) (T, error)) (T, error) {
|
|
||||||
tx, err := db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
var zero T
|
|
||||||
return zero, err
|
|
||||||
}
|
|
||||||
defer tx.Rollback()
|
|
||||||
t, err := f(tx)
|
|
||||||
if err != nil {
|
|
||||||
return t, err
|
|
||||||
}
|
|
||||||
if err := tx.Commit(); err != nil {
|
|
||||||
return t, err
|
|
||||||
}
|
|
||||||
return t, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// toSQLWildcard converts a wildcard string to a SQL wildcard string. It only allows '*' as wildcards,
|
// toSQLWildcard converts a wildcard string to a SQL wildcard string. It only allows '*' as wildcards,
|
||||||
// and escapes '_', assuming '\' as escape character.
|
// and escapes '_', assuming '\' as escape character.
|
||||||
func toSQLWildcard(s string) string {
|
func toSQLWildcard(s string) string {
|
||||||
|
|||||||
Reference in New Issue
Block a user