Merge client ip header

This commit is contained in:
binwiederhier
2025-05-31 15:33:21 -04:00
parent a49cafbadb
commit 2cb4d089ab
72 changed files with 5585 additions and 3157 deletions

View File

@@ -26,8 +26,8 @@ const (
// Defines default Web Push settings
const (
DefaultWebPushExpiryWarningDuration = 7 * 24 * time.Hour
DefaultWebPushExpiryDuration = 9 * 24 * time.Hour
DefaultWebPushExpiryWarningDuration = 55 * 24 * time.Hour
DefaultWebPushExpiryDuration = 60 * 24 * time.Hour
)
// Defines all global and per-visitor limits
@@ -144,7 +144,7 @@ type Config struct {
VisitorStatsResetTime time.Time // Time of the day at which to reset visitor stats
VisitorSubscriberRateLimiting bool // Enable subscriber-based rate limiting for UnifiedPush topics
BehindProxy bool
ProxyClientIPHeader string
ProxyForwardedHeader string
StripeSecretKey string
StripeWebhookKey string
StripePriceCacheDuration time.Duration
@@ -233,8 +233,8 @@ func NewConfig() *Config {
VisitorAuthFailureLimitReplenish: DefaultVisitorAuthFailureLimitReplenish,
VisitorStatsResetTime: DefaultVisitorStatsResetTime,
VisitorSubscriberRateLimiting: false,
BehindProxy: false,
ProxyClientIPHeader: "",
BehindProxy: false, // If true, the server will trust the proxy client IP header to determine the client IP address
ProxyForwardedHeader: "X-Forwarded-For", // Default header for reverse proxy client IPs
StripeSecretKey: "",
StripeWebhookKey: "",
StripePriceCacheDuration: DefaultStripePriceCacheDuration,

View File

@@ -99,6 +99,13 @@ const (
WHERE topic = ? AND (id > ? OR published = 0)
ORDER BY time, id
`
selectMessagesLatestQuery = `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
WHERE topic = ? AND published = 1
ORDER BY time DESC, id DESC
LIMIT 1
`
selectMessagesDueQuery = `
SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, content_type, encoding
FROM messages
@@ -416,6 +423,8 @@ func (c *messageCache) addMessages(ms []*message) error {
func (c *messageCache) Messages(topic string, since sinceMarker, scheduled bool) ([]*message, error) {
if since.IsNone() {
return make([]*message, 0), nil
} else if since.IsLatest() {
return c.messagesLatest(topic)
} else if since.IsID() {
return c.messagesSinceID(topic, since, scheduled)
}
@@ -462,6 +471,14 @@ func (c *messageCache) messagesSinceID(topic string, since sinceMarker, schedule
return readMessages(rows)
}
func (c *messageCache) messagesLatest(topic string) ([]*message, error) {
rows, err := c.db.Query(selectMessagesLatestQuery, topic)
if err != nil {
return nil, err
}
return readMessages(rows)
}
func (c *messageCache) MessagesDue() ([]*message, error) {
rows, err := c.db.Query(selectMessagesDueQuery, time.Now().Unix())
if err != nil {

View File

@@ -66,6 +66,11 @@ func testCacheMessages(t *testing.T, c *messageCache) {
require.Equal(t, 1, len(messages))
require.Equal(t, "my other message", messages[0].Message)
// mytopic: latest
messages, _ = c.Messages("mytopic", sinceLatestMessage, false)
require.Equal(t, 1, len(messages))
require.Equal(t, "my other message", messages[0].Message)
// example: count
counts, err = c.MessageCounts()
require.Nil(t, err)

View File

@@ -413,7 +413,8 @@ func (s *Server) handleError(w http.ResponseWriter, r *http.Request, v *visitor,
} else {
ev.Info("WebSocket error: %s", err.Error())
}
return // Do not attempt to write to upgraded connection
w.WriteHeader(httpErr.HTTPCode)
return // Do not attempt to write any body to upgraded connection
}
if isNormalError {
ev.Debug("Connection closed with HTTP %d (ntfy error %d)", httpErr.HTTPCode, httpErr.Code)
@@ -445,8 +446,10 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
return s.ensureWebPushEnabled(s.handleWebManifest)(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == apiUsersPath {
return s.ensureAdmin(s.handleUsersGet)(w, r, v)
} else if r.Method == http.MethodPut && r.URL.Path == apiUsersPath {
} else if r.Method == http.MethodPost && r.URL.Path == apiUsersPath {
return s.ensureAdmin(s.handleUsersAdd)(w, r, v)
} else if r.Method == http.MethodPut && r.URL.Path == apiUsersPath {
return s.ensureAdmin(s.handleUsersUpdate)(w, r, v)
} else if r.Method == http.MethodDelete && r.URL.Path == apiUsersPath {
return s.ensureAdmin(s.handleUsersDelete)(w, r, v)
} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && r.URL.Path == apiUsersAccessPath {
@@ -1016,7 +1019,7 @@ func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, fi
if actionsStr != "" {
m.Actions, e = parseActions(actionsStr)
if e != nil {
return false, false, "", "", false, false, errHTTPBadRequestActionsInvalid.Wrap(e.Error())
return false, false, "", "", false, false, errHTTPBadRequestActionsInvalid.Wrap("%s", e.Error())
}
}
contentType, markdown := readParam(r, "content-type", "content_type"), readBoolParam(r, false, "x-markdown", "markdown", "md")
@@ -1025,7 +1028,8 @@ func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, fi
}
template = readBoolParam(r, false, "x-template", "template", "tpl")
unifiedpush = readBoolParam(r, false, "x-unifiedpush", "unifiedpush", "up") // see GET too!
if unifiedpush {
contentEncoding := readParam(r, "content-encoding")
if unifiedpush || contentEncoding == "aes128gcm" {
firebase = false
unifiedpush = true
}
@@ -1556,8 +1560,8 @@ func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled b
// parseSince returns a timestamp identifying the time span from which cached messages should be received.
//
// Values in the "since=..." parameter can be either a unix timestamp or a duration (e.g. 12h), or
// "all" for all messages.
// Values in the "since=..." parameter can be either a unix timestamp or a duration (e.g. 12h),
// "all" for all messages, or "latest" for the most recent message for a topic
func parseSince(r *http.Request, poll bool) (sinceMarker, error) {
since := readParam(r, "x-since", "since", "si")
@@ -1569,6 +1573,8 @@ func parseSince(r *http.Request, poll bool) (sinceMarker, error) {
return sinceNoMessages, nil
} else if since == "all" {
return sinceAllMessages, nil
} else if since == "latest" {
return sinceLatestMessage, nil
} else if since == "none" {
return sinceNoMessages, nil
}
@@ -1862,6 +1868,12 @@ func (s *Server) transformBodyJSON(next handleFunc) handleFunc {
if m.Call != "" {
r.Header.Set("X-Call", m.Call)
}
if m.Cache != "" {
r.Header.Set("X-Cache", m.Cache)
}
if m.Firebase != "" {
r.Header.Set("X-Firebase", m.Firebase)
}
return next(w, r, v)
}
}
@@ -1885,14 +1897,14 @@ func (s *Server) transformMatrixJSON(next handleFunc) handleFunc {
}
func (s *Server) authorizeTopicWrite(next handleFunc) handleFunc {
return s.autorizeTopic(next, user.PermissionWrite)
return s.authorizeTopic(next, user.PermissionWrite)
}
func (s *Server) authorizeTopicRead(next handleFunc) handleFunc {
return s.autorizeTopic(next, user.PermissionRead)
return s.authorizeTopic(next, user.PermissionRead)
}
func (s *Server) autorizeTopic(next handleFunc, perm user.Permission) handleFunc {
func (s *Server) authorizeTopic(next handleFunc, perm user.Permission) handleFunc {
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if s.userManager == nil {
return next(w, r, v)
@@ -1925,7 +1937,7 @@ func (s *Server) autorizeTopic(next handleFunc, perm user.Permission) handleFunc
// that subsequent logging calls still have a visitor context.
func (s *Server) maybeAuthenticate(r *http.Request) (*visitor, error) {
// Read "Authorization" header value, and exit out early if it's not set
ip := extractIPAddress(r, s.config.BehindProxy, s.config.ProxyClientIPHeader)
ip := extractIPAddress(r, s.config.BehindProxy, s.config.ProxyForwardedHeader)
vip := s.visitor(ip, nil)
if s.userManager == nil {
return vip, nil
@@ -2000,7 +2012,7 @@ func (s *Server) authenticateBearerAuth(r *http.Request, token string) (*user.Us
if err != nil {
return nil, err
}
ip := extractIPAddress(r, s.config.BehindProxy, s.config.ProxyClientIPHeader)
ip := extractIPAddress(r, s.config.BehindProxy, s.config.ProxyForwardedHeader)
go s.userManager.EnqueueTokenUpdate(token, &user.TokenUpdate{
LastAccess: time.Now(),
LastOrigin: ip,

View File

@@ -146,7 +146,7 @@
# Web Push support (background notifications for browsers)
#
# If enabled, allows ntfy to receive push notifications, even when the ntfy web app is closed. When enabled, users
# If enabled, allows the ntfy web app to receive push notifications, even when the web app is closed. When enabled, users
# can enable background notifications in the web app. Once enabled, ntfy will forward published messages to the push
# endpoint, which will then forward it to the browser.
#
@@ -155,15 +155,19 @@
#
# - web-push-public-key is the generated VAPID public key, e.g. AA1234BBCCddvveekaabcdfqwertyuiopasdfghjklzxcvbnm1234567890
# - web-push-private-key is the generated VAPID private key, e.g. AA2BB1234567890abcdefzxcvbnm1234567890
# - web-push-file is a database file to keep track of browser subscription endpoints, e.g. `/var/cache/ntfy/webpush.db`
# - web-push-email-address is the admin email address send to the push provider, e.g. `sysadmin@example.com`
# - web-push-file is a database file to keep track of browser subscription endpoints, e.g. /var/cache/ntfy/webpush.db
# - web-push-email-address is the admin email address send to the push provider, e.g. sysadmin@example.com
# - web-push-startup-queries is an optional list of queries to run on startup`
# - web-push-expiry-warning-duration defines the duration after which unused subscriptions are sent a warning (default is 55d`)
# - web-push-expiry-duration defines the duration after which unused subscriptions will expire (default is 60d)
#
# web-push-public-key:
# web-push-private-key:
# web-push-file:
# web-push-email-address:
# web-push-startup-queries:
# web-push-expiry-warning-duration: "55d"
# web-push-expiry-duration: "60d"
# If enabled, ntfy can perform voice calls via Twilio via the "X-Call" header.
#

View File

@@ -37,7 +37,7 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *
return errHTTPConflictUserExists
}
logvr(v, r).Tag(tagAccount).Field("user_name", newAccount.Username).Info("Creating user %s", newAccount.Username)
if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser); err != nil {
if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser, false); err != nil {
if errors.Is(err, user.ErrInvalidArgument) {
return errHTTPBadRequestInvalidUsername
}
@@ -207,7 +207,7 @@ func (s *Server) handleAccountPasswordChange(w http.ResponseWriter, r *http.Requ
return errHTTPBadRequestIncorrectPasswordConfirmation
}
logvr(v, r).Tag(tagAccount).Debug("Changing password for user %s", u.Name)
if err := s.userManager.ChangePassword(u.Name, req.NewPassword); err != nil {
if err := s.userManager.ChangePassword(u.Name, req.NewPassword, false); err != nil {
return err
}
return s.writeJSON(w, newSuccessResponse())

View File

@@ -87,9 +87,9 @@ func TestAccount_Signup_AsUser(t *testing.T) {
defer s.closeDatabases()
log.Info("1")
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
log.Info("2")
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
log.Info("3")
rr := request(t, s, "POST", "/v1/account", `{"username":"emma", "password":"emma"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
@@ -174,7 +174,7 @@ func TestAccount_ChangeSettings(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
u, _ := s.userManager.User("phil")
token, _ := s.userManager.CreateToken(u.ID, "", time.Unix(0, 0), netip.IPv4Unspecified())
@@ -203,7 +203,7 @@ func TestAccount_Subscription_AddUpdateDelete(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
rr := request(t, s, "POST", "/v1/account/subscription", `{"base_url": "http://abc.com", "topic": "def"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
@@ -254,7 +254,7 @@ func TestAccount_ChangePassword(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
rr := request(t, s, "POST", "/v1/account/password", `{"password": "WRONG", "new_password": ""}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
@@ -296,7 +296,7 @@ func TestAccount_ExtendToken(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
rr := request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
@@ -332,7 +332,7 @@ func TestAccount_ExtendToken_NoTokenProvided(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
rr := request(t, s, "PATCH", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"), // Not Bearer!
@@ -345,7 +345,7 @@ func TestAccount_DeleteToken(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
rr := request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
@@ -455,14 +455,14 @@ func TestAccount_Reservation_AddAdminSuccess(t *testing.T) {
Code: "pro",
ReservationLimit: 2,
}))
require.Nil(t, s.userManager.AddUser("noadmin1", "pass", user.RoleUser))
require.Nil(t, s.userManager.AddUser("noadmin1", "pass", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("noadmin1", "pro"))
require.Nil(t, s.userManager.AddReservation("noadmin1", "mytopic", user.PermissionDenyAll))
require.Nil(t, s.userManager.AddUser("noadmin2", "pass", user.RoleUser))
require.Nil(t, s.userManager.AddUser("noadmin2", "pass", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("noadmin2", "pro"))
require.Nil(t, s.userManager.AddUser("phil", "adminpass", user.RoleAdmin))
require.Nil(t, s.userManager.AddUser("phil", "adminpass", user.RoleAdmin, false))
// Admin can reserve topic
rr := request(t, s, "POST", "/v1/account/reservation", `{"topic":"sometopic","everyone":"deny-all"}`, map[string]string{
@@ -624,7 +624,7 @@ func TestAccount_Reservation_Delete_Messages_And_Attachments(t *testing.T) {
s := newTestServer(t, conf)
// Create user with tier
require.Nil(t, s.userManager.AddUser("phil", "mypass", user.RoleUser))
require.Nil(t, s.userManager.AddUser("phil", "mypass", user.RoleUser, false))
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "pro",
MessageLimit: 20,

View File

@@ -39,11 +39,11 @@ func (s *Server) handleUsersGet(w http.ResponseWriter, r *http.Request, v *visit
}
func (s *Server) handleUsersAdd(w http.ResponseWriter, r *http.Request, v *visitor) error {
req, err := readJSONWithLimit[apiUserAddRequest](r.Body, jsonBodyBytesLimit, false)
req, err := readJSONWithLimit[apiUserAddOrUpdateRequest](r.Body, jsonBodyBytesLimit, false)
if err != nil {
return err
} else if !user.AllowedUsername(req.Username) || req.Password == "" {
return errHTTPBadRequest.Wrap("username invalid, or password missing")
} else if !user.AllowedUsername(req.Username) || (req.Password == "" && req.Hash == "") {
return errHTTPBadRequest.Wrap("username invalid, or password/password_hash missing")
}
u, err := s.userManager.User(req.Username)
if err != nil && !errors.Is(err, user.ErrUserNotFound) {
@@ -60,7 +60,11 @@ func (s *Server) handleUsersAdd(w http.ResponseWriter, r *http.Request, v *visit
return err
}
}
if err := s.userManager.AddUser(req.Username, req.Password, user.RoleUser); err != nil {
password, hashed := req.Password, false
if req.Hash != "" {
password, hashed = req.Hash, true
}
if err := s.userManager.AddUser(req.Username, password, user.RoleUser, hashed); err != nil {
return err
}
if tier != nil {
@@ -71,6 +75,53 @@ func (s *Server) handleUsersAdd(w http.ResponseWriter, r *http.Request, v *visit
return s.writeJSON(w, newSuccessResponse())
}
func (s *Server) handleUsersUpdate(w http.ResponseWriter, r *http.Request, v *visitor) error {
req, err := readJSONWithLimit[apiUserAddOrUpdateRequest](r.Body, jsonBodyBytesLimit, false)
if err != nil {
return err
} else if !user.AllowedUsername(req.Username) {
return errHTTPBadRequest.Wrap("username invalid")
} else if req.Password == "" && req.Hash == "" && req.Tier == "" {
return errHTTPBadRequest.Wrap("need to provide at least one of \"password\", \"password_hash\" or \"tier\"")
}
u, err := s.userManager.User(req.Username)
if err != nil && !errors.Is(err, user.ErrUserNotFound) {
return err
} else if u != nil {
if u.IsAdmin() {
return errHTTPForbidden
}
if req.Hash != "" {
if err := s.userManager.ChangePassword(req.Username, req.Hash, true); err != nil {
return err
}
} else if req.Password != "" {
if err := s.userManager.ChangePassword(req.Username, req.Password, false); err != nil {
return err
}
}
} else {
password, hashed := req.Password, false
if req.Hash != "" {
password, hashed = req.Hash, true
}
if err := s.userManager.AddUser(req.Username, password, user.RoleUser, hashed); err != nil {
return err
}
}
if req.Tier != "" {
if _, err = s.userManager.Tier(req.Tier); errors.Is(err, user.ErrTierNotFound) {
return errHTTPBadRequestTierInvalid
} else if err != nil {
return err
}
if err := s.userManager.ChangeTier(req.Username, req.Tier); err != nil {
return err
}
}
return s.writeJSON(w, newSuccessResponse())
}
func (s *Server) handleUsersDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
req, err := readJSONWithLimit[apiUserDeleteRequest](r.Body, jsonBodyBytesLimit, false)
if err != nil {

View File

@@ -14,13 +14,13 @@ func TestUser_AddRemove(t *testing.T) {
defer s.closeDatabases()
// Create admin, tier
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "tier1",
}))
// Create user via API
rr := request(t, s, "PUT", "/v1/users", `{"username": "ben", "password":"ben"}`, map[string]string{
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
@@ -49,6 +49,226 @@ func TestUser_AddRemove(t *testing.T) {
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Check user was deleted
users, err = s.userManager.Users()
require.Nil(t, err)
require.Equal(t, 3, len(users))
require.Equal(t, "phil", users[0].Name)
require.Equal(t, "emma", users[1].Name)
require.Equal(t, user.Everyone, users[2].Name)
// Reject invalid user change
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 400, rr.Code)
}
func TestUser_AddWithPasswordHash(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
// Create admin
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
// Create user via API
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "hash":"$2a$04$2aPIIqPXQU16OfkSUZH1XOzpu1gsPRKkrfVdFLgWQ.tqb.vtTCuVe"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Check that user can login with password
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 200, rr.Code)
// Check users
users, err := s.userManager.Users()
require.Nil(t, err)
require.Equal(t, 3, len(users))
require.Equal(t, "phil", users[0].Name)
require.Equal(t, user.RoleAdmin, users[0].Role)
require.Equal(t, "ben", users[1].Name)
require.Equal(t, user.RoleUser, users[1].Role)
}
func TestUser_ChangeUserPassword(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
// Create admin
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
// Create user via API
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password": "ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Try to login with first password
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 200, rr.Code)
// Change password via API
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "password": "ben-two"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Make sure first password fails
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 401, rr.Code)
// Try to login with second password
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben-two"),
})
require.Equal(t, 200, rr.Code)
}
func TestUser_ChangeUserTier(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
// Create admin, tier
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "tier1",
}))
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "tier2",
}))
// Create user with tier via API
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"ben", "tier": "tier1"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Check users
users, err := s.userManager.Users()
require.Nil(t, err)
require.Equal(t, 3, len(users))
require.Equal(t, "phil", users[0].Name)
require.Equal(t, "ben", users[1].Name)
require.Equal(t, user.RoleUser, users[1].Role)
require.Equal(t, "tier1", users[1].Tier.Code)
// Change user tier via API
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "tier": "tier2"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Check users again
users, err = s.userManager.Users()
require.Nil(t, err)
require.Equal(t, "tier2", users[1].Tier.Code)
}
func TestUser_ChangeUserPasswordAndTier(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
// Create admin, tier
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "tier1",
}))
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "tier2",
}))
// Create user with tier via API
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"ben", "tier": "tier1"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Check users
users, err := s.userManager.Users()
require.Nil(t, err)
require.Equal(t, 3, len(users))
require.Equal(t, "phil", users[0].Name)
require.Equal(t, "ben", users[1].Name)
require.Equal(t, user.RoleUser, users[1].Role)
require.Equal(t, "tier1", users[1].Tier.Code)
// Change user password and tier via API
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "password":"ben-two", "tier": "tier2"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Make sure first password fails
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 401, rr.Code)
// Try to login with second password
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben-two"),
})
require.Equal(t, 200, rr.Code)
// Check new tier
users, err = s.userManager.Users()
require.Nil(t, err)
require.Equal(t, "tier2", users[1].Tier.Code)
}
func TestUser_ChangeUserPasswordWithHash(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
// Create admin
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
// Create user with tier via API
rr := request(t, s, "POST", "/v1/users", `{"username": "ben", "password":"not-ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Try to login with first password
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "not-ben"),
})
require.Equal(t, 200, rr.Code)
// Change user password and tier via API
rr = request(t, s, "PUT", "/v1/users", `{"username": "ben", "hash":"$2a$04$2aPIIqPXQU16OfkSUZH1XOzpu1gsPRKkrfVdFLgWQ.tqb.vtTCuVe"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Try to login with second password
rr = request(t, s, "POST", "/v1/account/token", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 200, rr.Code)
}
func TestUser_DontChangeAdminPassword(t *testing.T) {
s := newTestServer(t, newTestConfigWithAuthFile(t))
defer s.closeDatabases()
// Create admin
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AddUser("admin", "admin", user.RoleAdmin, false))
// Try to change password via API
rr := request(t, s, "PUT", "/v1/users", `{"username": "admin", "password": "admin-new"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 403, rr.Code)
}
func TestUser_AddRemove_Failures(t *testing.T) {
@@ -56,23 +276,23 @@ func TestUser_AddRemove_Failures(t *testing.T) {
defer s.closeDatabases()
// Create admin
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
// Cannot create user with invalid username
rr := request(t, s, "PUT", "/v1/users", `{"username": "not valid", "password":"ben"}`, map[string]string{
rr := request(t, s, "POST", "/v1/users", `{"username": "not valid", "password":"ben"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 400, rr.Code)
// Cannot create user if user already exists
rr = request(t, s, "PUT", "/v1/users", `{"username": "phil", "password":"phil"}`, map[string]string{
rr = request(t, s, "POST", "/v1/users", `{"username": "phil", "password":"phil"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 40901, toHTTPError(t, rr.Body.String()).Code)
// Cannot create user with invalid tier
rr = request(t, s, "PUT", "/v1/users", `{"username": "emma", "password":"emma", "tier": "invalid"}`, map[string]string{
rr = request(t, s, "POST", "/v1/users", `{"username": "emma", "password":"emma", "tier": "invalid"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 40030, toHTTPError(t, rr.Body.String()).Code)
@@ -97,8 +317,8 @@ func TestAccess_AllowReset(t *testing.T) {
defer s.closeDatabases()
// User and admin
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
// Subscribing not allowed
rr := request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
@@ -138,7 +358,7 @@ func TestAccess_AllowReset_NonAdminAttempt(t *testing.T) {
defer s.closeDatabases()
// User
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
// Grant access fails, because non-admin
rr := request(t, s, "POST", "/v1/users/access", `{"username": "ben", "topic":"gold", "permission":"ro"}`, map[string]string{
@@ -154,8 +374,8 @@ func TestAccess_AllowReset_KillConnection(t *testing.T) {
defer s.closeDatabases()
// User and admin, grant access to "gol*" topics
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
require.Nil(t, s.userManager.AllowAccess("ben", "gol*", user.PermissionRead)) // Wildcard!
start, timeTaken := time.Now(), atomic.Int64{}

View File

@@ -50,7 +50,7 @@ func (c *firebaseClient) Send(v *visitor, m *message) error {
ev.Field("firebase_message", util.MaybeMarshalJSON(fbm)).Trace("Firebase message")
}
err = c.sender.Send(fbm)
if err == errFirebaseQuotaExceeded {
if errors.Is(err, errFirebaseQuotaExceeded) {
logvm(v, m).
Tag(tagFirebase).
Err(err).
@@ -133,56 +133,55 @@ func toFirebaseMessage(m *message, auther user.Auther) (*messaging.Message, erro
"time": fmt.Sprintf("%d", m.Time),
"event": m.Event,
"topic": m.Topic,
"message": m.Message,
"message": newMessageBody,
"poll_id": m.PollID,
}
apnsConfig = createAPNSAlertConfig(m, data)
case messageEvent:
allowForward := true
if auther != nil {
allowForward = auther.Authorize(nil, m.Topic, user.PermissionRead) == nil
}
if allowForward {
data = map[string]string{
"id": m.ID,
"time": fmt.Sprintf("%d", m.Time),
"event": m.Event,
"topic": m.Topic,
"priority": fmt.Sprintf("%d", m.Priority),
"tags": strings.Join(m.Tags, ","),
"click": m.Click,
"icon": m.Icon,
"title": m.Title,
"message": m.Message,
"content_type": m.ContentType,
"encoding": m.Encoding,
}
if len(m.Actions) > 0 {
actions, err := json.Marshal(m.Actions)
if err != nil {
return nil, err
}
data["actions"] = string(actions)
}
if m.Attachment != nil {
data["attachment_name"] = m.Attachment.Name
data["attachment_type"] = m.Attachment.Type
data["attachment_size"] = fmt.Sprintf("%d", m.Attachment.Size)
data["attachment_expires"] = fmt.Sprintf("%d", m.Attachment.Expires)
data["attachment_url"] = m.Attachment.URL
}
apnsConfig = createAPNSAlertConfig(m, data)
} else {
// If anonymous read for a topic is not allowed, we cannot send the message along
// If "anonymous read" for a topic is not allowed, we cannot send the message along
// via Firebase. Instead, we send a "poll_request" message, asking the client to poll.
data = map[string]string{
"id": m.ID,
"time": fmt.Sprintf("%d", m.Time),
"event": pollRequestEvent,
"topic": m.Topic,
//
// The data map needs to contain all the fields for it to function properly. If not all
// fields are set, the iOS app fails to decode the message.
//
// See https://github.com/binwiederhier/ntfy/pull/1345
if err := auther.Authorize(nil, m.Topic, user.PermissionRead); err != nil {
m = toPollRequest(m)
}
// TODO Handle APNS?
}
data = map[string]string{
"id": m.ID,
"time": fmt.Sprintf("%d", m.Time),
"event": m.Event,
"topic": m.Topic,
"priority": fmt.Sprintf("%d", m.Priority),
"tags": strings.Join(m.Tags, ","),
"click": m.Click,
"icon": m.Icon,
"title": m.Title,
"message": m.Message,
"content_type": m.ContentType,
"encoding": m.Encoding,
}
if len(m.Actions) > 0 {
actions, err := json.Marshal(m.Actions)
if err != nil {
return nil, err
}
data["actions"] = string(actions)
}
if m.Attachment != nil {
data["attachment_name"] = m.Attachment.Name
data["attachment_type"] = m.Attachment.Type
data["attachment_size"] = fmt.Sprintf("%d", m.Attachment.Size)
data["attachment_expires"] = fmt.Sprintf("%d", m.Attachment.Expires)
data["attachment_url"] = m.Attachment.URL
}
if m.PollID != "" {
data["poll_id"] = m.PollID
}
apnsConfig = createAPNSAlertConfig(m, data)
}
var androidConfig *messaging.AndroidConfig
if m.Priority >= 4 {
@@ -276,3 +275,17 @@ func maybeTruncateAPNSBodyMessage(s string) string {
}
return s
}
// toPollRequest converts a message to a poll request message.
//
// This empties all the fields that are not needed for a poll request and just sets the required fields,
// most importantly, the PollID.
func toPollRequest(m *message) *message {
pr := newPollRequestMessage(m.Topic, m.ID)
pr.ID = m.ID
pr.Time = m.Time
pr.Priority = m.Priority // Keep priority
pr.ContentType = m.ContentType
pr.Encoding = m.Encoding
return pr
}

View File

@@ -223,14 +223,25 @@ func TestToFirebaseMessage_Message_Normal_Not_Allowed(t *testing.T) {
require.Equal(t, &messaging.AndroidConfig{
Priority: "high",
}, fbm.Android)
require.Equal(t, "", fbm.Data["message"])
require.Equal(t, "", fbm.Data["priority"])
require.Equal(t, "New message", fbm.Data["message"])
require.Equal(t, "5", fbm.Data["priority"])
require.Equal(t, map[string]string{
"id": m.ID,
"time": fmt.Sprintf("%d", m.Time),
"event": "poll_request",
"topic": "mytopic",
"id": m.ID,
"time": fmt.Sprintf("%d", m.Time),
"event": "poll_request",
"topic": "mytopic",
"message": "New message",
"title": "",
"tags": "",
"click": "",
"icon": "",
"priority": "5",
"encoding": "",
"content_type": "",
"poll_id": m.ID,
}, fbm.Data)
require.Equal(t, "", fbm.APNS.Payload.Aps.Alert.Title)
require.Equal(t, "New message", fbm.APNS.Payload.Aps.Alert.Body)
}
func TestToFirebaseMessage_PollRequest(t *testing.T) {

View File

@@ -148,7 +148,7 @@ func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
Code: "pro",
StripeMonthlyPriceID: "price_123",
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
// Create subscription
response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro", "interval": "month"}`, map[string]string{
@@ -184,7 +184,7 @@ func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
Code: "pro",
StripeMonthlyPriceID: "price_123",
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
u, err := s.userManager.User("phil")
require.Nil(t, err)
@@ -226,7 +226,7 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
Code: "pro",
StripeMonthlyPriceID: "price_123",
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
u, err := s.userManager.User("phil")
require.Nil(t, err)
@@ -280,7 +280,7 @@ func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *tes
MessageLimit: 220, // 220 * 5% = 11 requests before rate limiting kicks in
MessageExpiryDuration: time.Hour,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) // No tier
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false)) // No tier
u, err := s.userManager.User("phil")
require.Nil(t, err)
@@ -461,7 +461,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
AttachmentTotalSizeLimit: 1000000,
AttachmentBandwidthLimit: 1000000,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll))
require.Nil(t, s.userManager.AddReservation("phil", "ztopic", user.PermissionDenyAll))
@@ -570,7 +570,7 @@ func TestPayments_Webhook_Subscription_Deleted(t *testing.T) {
StripeMonthlyPriceID: "price_1234",
ReservationLimit: 1,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll))
@@ -658,7 +658,7 @@ func TestPayments_Subscription_Update_Different_Tier(t *testing.T) {
StripeMonthlyPriceID: "price_456",
StripeYearlyPriceID: "price_457",
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
require.Nil(t, s.userManager.ChangeBilling("phil", &user.Billing{
StripeCustomerID: "acct_123",
@@ -690,7 +690,7 @@ func TestPayments_Subscription_Delete_At_Period_End(t *testing.T) {
Return(&stripe.Subscription{}, nil)
// Create user
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeBilling("phil", &user.Billing{
StripeCustomerID: "acct_123",
StripeSubscriptionID: "sub_123",
@@ -724,7 +724,7 @@ func TestPayments_CreatePortalSession(t *testing.T) {
}, nil)
// Create user
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeBilling("phil", &user.Billing{
StripeCustomerID: "acct_123",
StripeSubscriptionID: "sub_123",

View File

@@ -84,6 +84,22 @@ func TestServer_PublishWithFirebase(t *testing.T) {
require.Equal(t, "my first message", sender.Messages()[0].APNS.Payload.CustomData["message"])
}
func TestServer_PublishWithoutFirebase(t *testing.T) {
sender := newTestFirebaseSender(10)
s := newTestServer(t, newTestConfig(t))
s.firebaseClient = newFirebaseClient(sender, &testAuther{Allow: true})
response := request(t, s, "PUT", "/mytopic", "my first message", map[string]string{
"firebase": "no",
})
msg1 := toMessage(t, response.Body.String())
require.NotEmpty(t, msg1.ID)
require.Equal(t, "my first message", msg1.Message)
time.Sleep(100 * time.Millisecond) // Firebase publishing happens
require.Equal(t, 0, len(sender.Messages()))
}
func TestServer_PublishWithFirebase_WithoutUsers_AndWithoutPanic(t *testing.T) {
// This tests issue #641, which used to panic before the fix
@@ -411,7 +427,7 @@ func TestServer_PublishAt_FromUser(t *testing.T) {
t.Parallel()
s := newTestServer(t, newTestConfigWithAuthFile(t))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
response := request(t, s, "PUT", "/mytopic", "a message", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
"In": "1h",
@@ -594,6 +610,11 @@ func TestServer_PublishAndPollSince(t *testing.T) {
require.Equal(t, 1, len(messages))
require.Equal(t, "test 2", messages[0].Message)
response = request(t, s, "GET", "/mytopic/json?poll=1&since=latest", "", nil)
messages = toMessages(t, response.Body.String())
require.Equal(t, 1, len(messages))
require.Equal(t, "test 2", messages[0].Message)
response = request(t, s, "GET", "/mytopic/json?poll=1&since=INVALID", "", nil)
require.Equal(t, 40008, toHTTPError(t, response.Body.String()).Code)
}
@@ -781,7 +802,7 @@ func TestServer_Auth_Success_Admin(t *testing.T) {
c := newTestConfigWithAuthFile(t)
s := newTestServer(t, c)
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
@@ -795,7 +816,7 @@ func TestServer_Auth_Success_User(t *testing.T) {
c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c)
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
require.Nil(t, s.userManager.AllowAccess("ben", "mytopic", user.PermissionReadWrite))
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
@@ -809,7 +830,7 @@ func TestServer_Auth_Success_User_MultipleTopics(t *testing.T) {
c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c)
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
require.Nil(t, s.userManager.AllowAccess("ben", "mytopic", user.PermissionReadWrite))
require.Nil(t, s.userManager.AllowAccess("ben", "anothertopic", user.PermissionReadWrite))
@@ -830,7 +851,7 @@ func TestServer_Auth_Fail_InvalidPass(t *testing.T) {
c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c)
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
"Authorization": util.BasicAuth("phil", "INVALID"),
@@ -843,7 +864,7 @@ func TestServer_Auth_Fail_Unauthorized(t *testing.T) {
c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c)
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
require.Nil(t, s.userManager.AllowAccess("ben", "sometopic", user.PermissionReadWrite)) // Not mytopic!
response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
@@ -857,7 +878,7 @@ func TestServer_Auth_Fail_CannotPublish(t *testing.T) {
c.AuthDefault = user.PermissionReadWrite // Open by default
s := newTestServer(t, c)
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin, false))
require.Nil(t, s.userManager.AllowAccess(user.Everyone, "private", user.PermissionDenyAll))
require.Nil(t, s.userManager.AllowAccess(user.Everyone, "announcements", user.PermissionRead))
@@ -906,7 +927,7 @@ func TestServer_Auth_ViaQuery(t *testing.T) {
c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c)
require.Nil(t, s.userManager.AddUser("ben", "some pass", user.RoleAdmin))
require.Nil(t, s.userManager.AddUser("ben", "some pass", user.RoleAdmin, false))
u := fmt.Sprintf("/mytopic/json?poll=1&auth=%s", base64.RawURLEncoding.EncodeToString([]byte(util.BasicAuth("ben", "some pass"))))
response := request(t, s, "GET", u, "", nil)
@@ -954,8 +975,8 @@ func TestServer_StatsResetter(t *testing.T) {
MessageLimit: 5,
MessageExpiryDuration: -5 * time.Second, // Second, what a hack!
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.AddUser("tieruser", "tieruser", user.RoleUser))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
require.Nil(t, s.userManager.AddUser("tieruser", "tieruser", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("tieruser", "test"))
// Send an anonymous message
@@ -1099,7 +1120,7 @@ func TestServer_DailyMessageQuotaFromDatabase(t *testing.T) {
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "test",
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("phil", "test"))
u, err := s.userManager.User("phil")
@@ -1679,6 +1700,35 @@ func TestServer_PublishAsJSON_WithActions(t *testing.T) {
require.Equal(t, "target_temp_f=65", m.Actions[1].Body)
}
func TestServer_PublishAsJSON_NoCache(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
body := `{"topic":"mytopic","message": "this message is not cached","cache":"no"}`
response := request(t, s, "PUT", "/", body, nil)
msg := toMessage(t, response.Body.String())
require.NotEmpty(t, msg.ID)
require.Equal(t, "this message is not cached", msg.Message)
require.Equal(t, int64(0), msg.Expires)
response = request(t, s, "GET", "/mytopic/json?poll=1", "", nil)
messages := toMessages(t, response.Body.String())
require.Empty(t, messages)
}
func TestServer_PublishAsJSON_WithoutFirebase(t *testing.T) {
sender := newTestFirebaseSender(10)
s := newTestServer(t, newTestConfig(t))
s.firebaseClient = newFirebaseClient(sender, &testAuther{Allow: true})
body := `{"topic":"mytopic","message": "my first message","firebase":"no"}`
response := request(t, s, "PUT", "/", body, nil)
msg1 := toMessage(t, response.Body.String())
require.NotEmpty(t, msg1.ID)
require.Equal(t, "my first message", msg1.Message)
time.Sleep(100 * time.Millisecond) // Firebase publishing happens
require.Equal(t, 0, len(sender.Messages()))
}
func TestServer_PublishAsJSON_Invalid(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
body := `{"topic":"mytopic",INVALID`
@@ -1696,7 +1746,7 @@ func TestServer_PublishWithTierBasedMessageLimitAndExpiry(t *testing.T) {
MessageLimit: 5,
MessageExpiryDuration: -5 * time.Second, // Second, what a hack!
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("phil", "test"))
// Publish to reach message limit
@@ -1932,7 +1982,7 @@ func TestServer_PublishAttachmentWithTierBasedExpiry(t *testing.T) {
AttachmentExpiryDuration: sevenDays, // 7 days
AttachmentBandwidthLimit: 100000,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("phil", "test"))
// Publish and make sure we can retrieve it
@@ -1977,7 +2027,7 @@ func TestServer_PublishAttachmentWithTierBasedBandwidthLimit(t *testing.T) {
AttachmentExpiryDuration: time.Hour,
AttachmentBandwidthLimit: 14000, // < 3x5000 bytes -> enough for one upload, one download
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("phil", "test"))
// Publish and make sure we can retrieve it
@@ -2015,7 +2065,7 @@ func TestServer_PublishAttachmentWithTierBasedLimits(t *testing.T) {
AttachmentExpiryDuration: 30 * time.Second,
AttachmentBandwidthLimit: 1000000,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("phil", "test"))
// Publish small file as anonymous
@@ -2184,7 +2234,7 @@ func TestServer_Visitor_XForwardedFor_Multiple(t *testing.T) {
func TestServer_Visitor_Custom_ClientIP_Header(t *testing.T) {
c := newTestConfig(t)
c.BehindProxy = true
c.ProxyClientIPHeader = "X-Client-IP"
c.ProxyForwardedHeader = "X-Client-IP"
s := newTestServer(t, c)
r, _ := http.NewRequest("GET", "/bla", nil)
r.RemoteAddr = "8.9.10.11"
@@ -2250,7 +2300,7 @@ func TestServer_AnonymousUser_And_NonTierUser_Are_Same_Visitor(t *testing.T) {
defer s.closeDatabases()
// Create user without tier
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
// Publish a message (anonymous user)
rr := request(t, s, "POST", "/mytopic", "hi", nil)

View File

@@ -63,7 +63,7 @@ func TestServer_Twilio_Call_Add_Verify_Call_Delete_Success(t *testing.T) {
MessageLimit: 10,
CallLimit: 1,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
u, err := s.userManager.User("phil")
require.Nil(t, err)
@@ -140,7 +140,7 @@ func TestServer_Twilio_Call_Success(t *testing.T) {
MessageLimit: 10,
CallLimit: 1,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
u, err := s.userManager.User("phil")
require.Nil(t, err)
@@ -185,7 +185,7 @@ func TestServer_Twilio_Call_Success_With_Yes(t *testing.T) {
MessageLimit: 10,
CallLimit: 1,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
u, err := s.userManager.User("phil")
require.Nil(t, err)
@@ -216,7 +216,7 @@ func TestServer_Twilio_Call_UnverifiedNumber(t *testing.T) {
MessageLimit: 10,
CallLimit: 1,
}))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
// Do the thing

View File

@@ -96,7 +96,7 @@ func TestServer_WebPush_TopicSubscribeProtected_Allowed(t *testing.T) {
config.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, config)
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
require.Nil(t, s.userManager.AllowAccess("ben", "test-topic", user.PermissionReadWrite))
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), map[string]string{
@@ -126,7 +126,7 @@ func TestServer_WebPush_DeleteAccountUnsubscribe(t *testing.T) {
config := configureAuth(t, newTestConfigWithWebPush(t))
s := newTestServer(t, config)
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser, false))
require.Nil(t, s.userManager.AllowAccess("ben", "test-topic", user.PermissionReadWrite))
response := request(t, s, "POST", "/v1/webpush", payloadForTopics(t, []string{"test-topic"}, testWebPushEndpoint), map[string]string{
@@ -212,7 +212,7 @@ func TestServer_WebPush_Expiry(t *testing.T) {
addSubscription(t, s, pushService.URL+"/push-receive", "test-topic")
requireSubscriptionCount(t, s, "test-topic", 1)
_, err := s.webPush.db.Exec("UPDATE subscription SET updated_at = ?", time.Now().Add(-7*24*time.Hour).Unix())
_, err := s.webPush.db.Exec("UPDATE subscription SET updated_at = ?", time.Now().Add(-55*24*time.Hour).Unix())
require.Nil(t, err)
s.pruneAndNotifyWebPushSubscriptions()
@@ -222,7 +222,7 @@ func TestServer_WebPush_Expiry(t *testing.T) {
return received.Load()
})
_, err = s.webPush.db.Exec("UPDATE subscription SET updated_at = ?", time.Now().Add(-9*24*time.Hour).Unix())
_, err = s.webPush.db.Exec("UPDATE subscription SET updated_at = ?", time.Now().Add(-60*24*time.Hour).Unix())
require.Nil(t, err)
s.pruneAndNotifyWebPushSubscriptions()

View File

@@ -105,6 +105,8 @@ type publishMessage struct {
Filename string `json:"filename"`
Email string `json:"email"`
Call string `json:"call"`
Cache string `json:"cache"` // use string as it defaults to true (or use &bool instead)
Firebase string `json:"firebase"` // use string as it defaults to true (or use &bool instead)
Delay string `json:"delay"`
}
@@ -169,8 +171,12 @@ func (t sinceMarker) IsNone() bool {
return t == sinceNoMessages
}
func (t sinceMarker) IsLatest() bool {
return t == sinceLatestMessage
}
func (t sinceMarker) IsID() bool {
return t.id != ""
return t.id != "" && t.id != "latest"
}
func (t sinceMarker) Time() time.Time {
@@ -182,8 +188,9 @@ func (t sinceMarker) ID() string {
}
var (
sinceAllMessages = sinceMarker{time.Unix(0, 0), ""}
sinceNoMessages = sinceMarker{time.Unix(1, 0), ""}
sinceAllMessages = sinceMarker{time.Unix(0, 0), ""}
sinceNoMessages = sinceMarker{time.Unix(1, 0), ""}
sinceLatestMessage = sinceMarker{time.Unix(0, 0), "latest"}
)
type queryFilter struct {
@@ -248,9 +255,10 @@ type apiStatsResponse struct {
MessagesRate float64 `json:"messages_rate"` // Average number of messages per second
}
type apiUserAddRequest struct {
type apiUserAddOrUpdateRequest struct {
Username string `json:"username"`
Password string `json:"password"`
Hash string `json:"hash"`
Tier string `json:"tier"`
// Do not add 'role' here. We don't want to add admins via the API.
}

View File

@@ -73,61 +73,33 @@ func readQueryParam(r *http.Request, names ...string) string {
return ""
}
func extractIPAddress(r *http.Request, behindProxy bool, proxyClientIPHeader string) netip.Addr {
logr(r).Debug("Starting IP extraction")
func extractIPAddress(r *http.Request, behindProxy bool, proxyForwardedHeader string) netip.Addr {
remoteAddr := r.RemoteAddr
logr(r).Debug("RemoteAddr: %s", remoteAddr)
addrPort, err := netip.ParseAddrPort(remoteAddr)
ip := addrPort.Addr()
if err != nil {
logr(r).Warn("Failed to parse RemoteAddr as AddrPort: %v", err)
// This should not happen in real life; only in tests. So, using falling back to 0.0.0.0 if address unspecified
ip, err = netip.ParseAddr(remoteAddr)
if err != nil {
ip = netip.IPv4Unspecified()
logr(r).Error("Failed to parse RemoteAddr as IP: %v, defaulting to 0.0.0.0", err)
if remoteAddr != "@" && !behindProxy { // RemoteAddr is @ when unix socket is used
logr(r).Err(err).Warn("unable to parse IP (%s), new visitor with unspecified IP (0.0.0.0) created", remoteAddr)
}
}
}
// Log initial IP before further processing
logr(r).Debug("Initial IP after RemoteAddr parsing: %s", ip)
if proxyClientIPHeader != "" {
logr(r).Debug("Using ProxyClientIPHeader: %s", proxyClientIPHeader)
if customHeaderIP := r.Header.Get(proxyClientIPHeader); customHeaderIP != "" {
logr(r).Debug("Custom header %s value: %s", proxyClientIPHeader, customHeaderIP)
realIP, err := netip.ParseAddr(customHeaderIP)
if err != nil {
logr(r).Error("Invalid IP in %s header: %s, error: %v", proxyClientIPHeader, customHeaderIP, err)
} else {
logr(r).Debug("Successfully parsed IP from custom header: %s", realIP)
ip = realIP
}
if behindProxy && strings.TrimSpace(r.Header.Get(proxyForwardedHeader)) != "" {
// X-Forwarded-For can contain multiple addresses (see #328). If we are behind a proxy,
// only the right-most address can be trusted (as this is the one added by our proxy server).
// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For for details.
ips := util.SplitNoEmpty(r.Header.Get(proxyForwardedHeader), ",")
realIP, err := netip.ParseAddr(strings.TrimSpace(util.LastString(ips, remoteAddr)))
if err != nil {
logr(r).Err(err).Error("invalid IP address %s received in %s header", ip, proxyForwardedHeader)
// Fall back to the regular remote address if X-Forwarded-For is damaged
} else {
logr(r).Warn("Custom header %s is empty or missing", proxyClientIPHeader)
ip = realIP
}
} else if behindProxy {
logr(r).Debug("No ProxyClientIPHeader set, checking X-Forwarded-For")
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
logr(r).Debug("X-Forwarded-For value: %s", xff)
ips := util.SplitNoEmpty(xff, ",")
realIP, err := netip.ParseAddr(strings.TrimSpace(util.LastString(ips, remoteAddr)))
if err != nil {
logr(r).Error("Invalid IP in X-Forwarded-For header: %s, error: %v", xff, err)
} else {
logr(r).Debug("Successfully parsed IP from X-Forwarded-For: %s", realIP)
ip = realIP
}
} else {
logr(r).Debug("X-Forwarded-For header is empty or missing")
}
} else {
logr(r).Debug("Behind proxy is false, skipping proxy headers")
}
// Final resolved IP
logr(r).Debug("Final resolved IP: %s", ip)
return ip
}

View File

@@ -79,8 +79,9 @@ const (
deleteWebPushSubscriptionByUserIDQuery = `DELETE FROM subscription WHERE user_id = ?`
deleteWebPushSubscriptionByAgeQuery = `DELETE FROM subscription WHERE updated_at <= ?` // Full table scan!
insertWebPushSubscriptionTopicQuery = `INSERT INTO subscription_topic (subscription_id, topic) VALUES (?, ?)`
deleteWebPushSubscriptionTopicAllQuery = `DELETE FROM subscription_topic WHERE subscription_id = ?`
insertWebPushSubscriptionTopicQuery = `INSERT INTO subscription_topic (subscription_id, topic) VALUES (?, ?)`
deleteWebPushSubscriptionTopicAllQuery = `DELETE FROM subscription_topic WHERE subscription_id = ?`
deleteWebPushSubscriptionTopicWithoutSubscription = `DELETE FROM subscription_topic WHERE subscription_id NOT IN (SELECT id FROM subscription)`
)
// Schema management queries
@@ -271,6 +272,10 @@ func (c *webPushStore) RemoveSubscriptionsByUserID(userID string) error {
// RemoveExpiredSubscriptions removes all subscriptions that have not been updated for a given time period
func (c *webPushStore) RemoveExpiredSubscriptions(expireAfter time.Duration) error {
_, err := c.db.Exec(deleteWebPushSubscriptionByAgeQuery, time.Now().Add(-expireAfter).Unix())
if err != nil {
return err
}
_, err = c.db.Exec(deleteWebPushSubscriptionTopicWithoutSubscription)
return err
}