From 54514454bf5f65c7610edbc3f3003ccdaa30c4f9 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Fri, 4 Jul 2025 10:16:49 +0200 Subject: [PATCH] Works --- cmd/serve.go | 16 ++++++-- server/config.go | 6 +++ server/server.go | 2 +- server/server_test.go | 92 +++++++++++++++++++++++++++++++++++++++++-- server/util.go | 11 ++++-- server/visitor.go | 18 ++++----- 6 files changed, 125 insertions(+), 20 deletions(-) diff --git a/cmd/serve.go b/cmd/serve.go index 373ba69e..27ae0fcb 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -80,6 +80,7 @@ var flagsServe = append( altsrc.NewStringFlag(&cli.StringFlag{Name: "message-delay-limit", Aliases: []string{"message_delay_limit"}, EnvVars: []string{"NTFY_MESSAGE_DELAY_LIMIT"}, Value: util.FormatDuration(server.DefaultMessageDelayMax), Usage: "max duration a message can be scheduled into the future"}), altsrc.NewIntFlag(&cli.IntFlag{Name: "global-topic-limit", Aliases: []string{"global_topic_limit", "T"}, EnvVars: []string{"NTFY_GLOBAL_TOPIC_LIMIT"}, Value: server.DefaultTotalTopicLimit, Usage: "total number of topics allowed"}), altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-subscription-limit", Aliases: []string{"visitor_subscription_limit"}, EnvVars: []string{"NTFY_VISITOR_SUBSCRIPTION_LIMIT"}, Value: server.DefaultVisitorSubscriptionLimit, Usage: "number of subscriptions per visitor"}), + altsrc.NewBoolFlag(&cli.BoolFlag{Name: "visitor-subscriber-rate-limiting", Aliases: []string{"visitor_subscriber_rate_limiting"}, EnvVars: []string{"NTFY_VISITOR_SUBSCRIBER_RATE_LIMITING"}, Value: false, Usage: "enables subscriber-based rate limiting"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "visitor-attachment-total-size-limit", Aliases: []string{"visitor_attachment_total_size_limit"}, EnvVars: []string{"NTFY_VISITOR_ATTACHMENT_TOTAL_SIZE_LIMIT"}, Value: util.FormatSize(server.DefaultVisitorAttachmentTotalSizeLimit), Usage: "total storage limit used for attachments per visitor"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "visitor-attachment-daily-bandwidth-limit", Aliases: []string{"visitor_attachment_daily_bandwidth_limit"}, EnvVars: []string{"NTFY_VISITOR_ATTACHMENT_DAILY_BANDWIDTH_LIMIT"}, Value: "500M", Usage: "total daily attachment download/upload bandwidth limit per visitor"}), altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-request-limit-burst", Aliases: []string{"visitor_request_limit_burst"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_BURST"}, Value: server.DefaultVisitorRequestLimitBurst, Usage: "initial limit of requests per visitor"}), @@ -88,7 +89,8 @@ var flagsServe = append( altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-message-daily-limit", Aliases: []string{"visitor_message_daily_limit"}, EnvVars: []string{"NTFY_VISITOR_MESSAGE_DAILY_LIMIT"}, Value: server.DefaultVisitorMessageDailyLimit, Usage: "max messages per visitor per day, derived from request limit if unset"}), altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-email-limit-burst", Aliases: []string{"visitor_email_limit_burst"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_BURST"}, Value: server.DefaultVisitorEmailLimitBurst, Usage: "initial limit of e-mails per visitor"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "visitor-email-limit-replenish", Aliases: []string{"visitor_email_limit_replenish"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_REPLENISH"}, Value: util.FormatDuration(server.DefaultVisitorEmailLimitReplenish), Usage: "interval at which burst limit is replenished (one per x)"}), - altsrc.NewBoolFlag(&cli.BoolFlag{Name: "visitor-subscriber-rate-limiting", Aliases: []string{"visitor_subscriber_rate_limiting"}, EnvVars: []string{"NTFY_VISITOR_SUBSCRIBER_RATE_LIMITING"}, Value: false, Usage: "enables subscriber-based rate limiting"}), + altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-prefix-bits-ipv4", Aliases: []string{"visitor_prefix_bits_ipv4"}, EnvVars: []string{"NTFY_VISITOR_PREFIX_BITS_IPV4"}, Value: server.DefaultVisitorPrefixBitsIPv4, Usage: "number of bits of the IPv4 address to use for rate limiting (default: 32, full address)"}), + altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-prefix-bits-ipv6", Aliases: []string{"visitor_prefix_bits_ipv6"}, EnvVars: []string{"NTFY_VISITOR_PREFIX_BITS_IPV6"}, Value: server.DefaultVisitorPrefixBitsIPv6, Usage: "number of bits of the IPv6 address to use for rate limiting (default: 64, /64 subnet)"}), altsrc.NewBoolFlag(&cli.BoolFlag{Name: "behind-proxy", Aliases: []string{"behind_proxy", "P"}, EnvVars: []string{"NTFY_BEHIND_PROXY"}, Value: false, Usage: "if set, use forwarded header (e.g. X-Forwarded-For, X-Client-IP) to determine visitor IP address (for rate limiting)"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "proxy-forwarded-header", Aliases: []string{"proxy_forwarded_header"}, EnvVars: []string{"NTFY_PROXY_FORWARDED_HEADER"}, Value: "X-Forwarded-For", Usage: "use specified header to determine visitor IP address (for rate limiting)"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "proxy-trusted-addresses", Aliases: []string{"proxy_trusted_addresses"}, EnvVars: []string{"NTFY_PROXY_TRUSTED_ADDRESSES"}, Value: "", Usage: "comma-separated list of trusted IP addresses to remove from forwarded header"}), @@ -192,6 +194,8 @@ func execServe(c *cli.Context) error { visitorMessageDailyLimit := c.Int("visitor-message-daily-limit") visitorEmailLimitBurst := c.Int("visitor-email-limit-burst") visitorEmailLimitReplenishStr := c.String("visitor-email-limit-replenish") + visitorPrefixBitsIPv4 := c.Int("visitor-prefix-bits-ipv4") + visitorPrefixBitsIPv6 := c.Int("visitor-prefix-bits-ipv6") behindProxy := c.Bool("behind-proxy") proxyForwardedHeader := c.String("proxy-forwarded-header") proxyTrustedAddresses := util.SplitNoEmpty(c.String("proxy-trusted-addresses"), ",") @@ -325,6 +329,10 @@ func execServe(c *cli.Context) error { return errors.New("web push expiry warning duration cannot be higher than web push expiry duration") } else if behindProxy && proxyForwardedHeader == "" { return errors.New("if behind-proxy is set, proxy-forwarded-header must also be set") + } else if visitorPrefixBitsIPv4 < 1 || visitorPrefixBitsIPv4 > 32 { + return errors.New("visitor-prefix-bits-ipv4 must be between 1 and 32") + } else if visitorPrefixBitsIPv6 < 1 || visitorPrefixBitsIPv6 > 128 { + return errors.New("visitor-prefix-bits-ipv6 must be between 1 and 128") } // Backwards compatibility @@ -413,6 +421,7 @@ func execServe(c *cli.Context) error { conf.MessageDelayMax = messageDelayLimit conf.TotalTopicLimit = totalTopicLimit conf.VisitorSubscriptionLimit = visitorSubscriptionLimit + conf.VisitorSubscriberRateLimiting = visitorSubscriberRateLimiting conf.VisitorAttachmentTotalSizeLimit = visitorAttachmentTotalSizeLimit conf.VisitorAttachmentDailyBandwidthLimit = visitorAttachmentDailyBandwidthLimit conf.VisitorRequestLimitBurst = visitorRequestLimitBurst @@ -421,7 +430,8 @@ func execServe(c *cli.Context) error { conf.VisitorMessageDailyLimit = visitorMessageDailyLimit conf.VisitorEmailLimitBurst = visitorEmailLimitBurst conf.VisitorEmailLimitReplenish = visitorEmailLimitReplenish - conf.VisitorSubscriberRateLimiting = visitorSubscriberRateLimiting + conf.VisitorPrefixBitsIPv4 = visitorPrefixBitsIPv4 + conf.VisitorPrefixBitsIPv6 = visitorPrefixBitsIPv6 conf.BehindProxy = behindProxy conf.ProxyForwardedHeader = proxyForwardedHeader conf.ProxyTrustedAddresses = proxyTrustedAddresses @@ -434,7 +444,6 @@ func execServe(c *cli.Context) error { conf.EnableMetrics = enableMetrics conf.MetricsListenHTTP = metricsListenHTTP conf.ProfileListenHTTP = profileListenHTTP - conf.Version = c.App.Version conf.WebPushPrivateKey = webPushPrivateKey conf.WebPushPublicKey = webPushPublicKey conf.WebPushFile = webPushFile @@ -442,6 +451,7 @@ func execServe(c *cli.Context) error { conf.WebPushStartupQueries = webPushStartupQueries conf.WebPushExpiryDuration = webPushExpiryDuration conf.WebPushExpiryWarningDuration = webPushExpiryWarningDuration + conf.Version = c.App.Version // Set up hot-reloading of config go sigHandlerConfigReload(config) diff --git a/server/config.go b/server/config.go index f3320c5b..c351ba85 100644 --- a/server/config.go +++ b/server/config.go @@ -61,6 +61,8 @@ const ( DefaultVisitorAuthFailureLimitReplenish = time.Minute DefaultVisitorAttachmentTotalSizeLimit = 100 * 1024 * 1024 // 100 MB DefaultVisitorAttachmentDailyBandwidthLimit = 500 * 1024 * 1024 // 500 MB + DefaultVisitorPrefixBitsIPv4 = 32 // Use the entire IPv4 address for rate limiting + DefaultVisitorPrefixBitsIPv6 = 64 // Use /64 for IPv6 rate limiting ) var ( @@ -143,6 +145,8 @@ type Config struct { VisitorAuthFailureLimitReplenish time.Duration VisitorStatsResetTime time.Time // Time of the day at which to reset visitor stats VisitorSubscriberRateLimiting bool // Enable subscriber-based rate limiting for UnifiedPush topics + VisitorPrefixBitsIPv4 int // Number of bits for IPv4 rate limiting (default: 32) + VisitorPrefixBitsIPv6 int // Number of bits for IPv6 rate limiting (default: 64) BehindProxy bool // If true, the server will trust the proxy client IP header to determine the client IP address (IPv4 and IPv6 supported) ProxyForwardedHeader string // The header field to read the real/client IP address from, if BehindProxy is true, defaults to "X-Forwarded-For" (IPv4 and IPv6 supported) ProxyTrustedAddresses []string // List of trusted proxy addresses (IPv4 or IPv6) that will be stripped from the Forwarded header if BehindProxy is true @@ -234,6 +238,8 @@ func NewConfig() *Config { VisitorAuthFailureLimitReplenish: DefaultVisitorAuthFailureLimitReplenish, VisitorStatsResetTime: DefaultVisitorStatsResetTime, VisitorSubscriberRateLimiting: false, + VisitorPrefixBitsIPv4: 32, // Default: use full IPv4 address + VisitorPrefixBitsIPv6: 64, // Default: use /64 for IPv6 BehindProxy: false, // If true, the server will trust the proxy client IP header to determine the client IP address ProxyForwardedHeader: "X-Forwarded-For", // Default header for reverse proxy client IPs StripeSecretKey: "", diff --git a/server/server.go b/server/server.go index e1126757..8d33d396 100644 --- a/server/server.go +++ b/server/server.go @@ -2023,7 +2023,7 @@ func (s *Server) authenticateBearerAuth(r *http.Request, token string) (*user.Us func (s *Server) visitor(ip netip.Addr, user *user.User) *visitor { s.mu.Lock() defer s.mu.Unlock() - id := visitorID(ip, user) + id := visitorID(ip, user, s.config) v, exists := s.visitors[id] if !exists { s.visitors[id] = newVisitor(s.config, s.messageCache, s.userManager, ip, user) diff --git a/server/server_test.go b/server/server_test.go index be0610ac..0a5bcc08 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1169,7 +1169,7 @@ func (t *testMailer) Count() int { return t.count } -func TestServer_PublishTooRequests_Defaults(t *testing.T) { +func TestServer_PublishTooManyRequests_Defaults(t *testing.T) { s := newTestServer(t, newTestConfig(t)) for i := 0; i < 60; i++ { response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), nil) @@ -1179,7 +1179,50 @@ func TestServer_PublishTooRequests_Defaults(t *testing.T) { require.Equal(t, 429, response.Code) } -func TestServer_PublishTooRequests_Defaults_ExemptHosts(t *testing.T) { +func TestServer_PublishTooManyRequests_Defaults_IPv6(t *testing.T) { + s := newTestServer(t, newTestConfig(t)) + overrideRemoteAddr1 := func(r *http.Request) { + r.RemoteAddr = "[2001:db8:9999:8888:1::1]:1234" + } + overrideRemoteAddr2 := func(r *http.Request) { + r.RemoteAddr = "[2001:db8:9999:8888:2::1]:1234" // Same /64 + } + for i := 0; i < 30; i++ { + response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), nil, overrideRemoteAddr1) + require.Equal(t, 200, response.Code) + } + for i := 0; i < 30; i++ { + response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), nil, overrideRemoteAddr2) + require.Equal(t, 200, response.Code) + } + response := request(t, s, "PUT", "/mytopic", "message", nil, overrideRemoteAddr1) + require.Equal(t, 429, response.Code) +} + +func TestServer_PublishTooManyRequests_IPv6_Slash48(t *testing.T) { + c := newTestConfig(t) + c.VisitorRequestLimitBurst = 6 + c.VisitorPrefixBitsIPv6 = 48 // Use /48 for IPv6 prefixes + s := newTestServer(t, c) + overrideRemoteAddr1 := func(r *http.Request) { + r.RemoteAddr = "[2001:db8:9999::1]:1234" + } + overrideRemoteAddr2 := func(r *http.Request) { + r.RemoteAddr = "[2001:db8:9999::2]:1234" // Same /48 + } + for i := 0; i < 3; i++ { + response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), nil, overrideRemoteAddr1) + require.Equal(t, 200, response.Code) + } + for i := 0; i < 3; i++ { + response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), nil, overrideRemoteAddr2) + require.Equal(t, 200, response.Code) + } + response := request(t, s, "PUT", "/mytopic", "message", nil, overrideRemoteAddr1) + require.Equal(t, 429, response.Code) +} + +func TestServer_PublishTooManyRequests_Defaults_ExemptHosts(t *testing.T) { c := newTestConfig(t) c.VisitorRequestLimitBurst = 3 c.VisitorRequestExemptIPAddrs = []netip.Prefix{netip.MustParsePrefix("9.9.9.9/32")} // see request() @@ -1190,7 +1233,21 @@ func TestServer_PublishTooRequests_Defaults_ExemptHosts(t *testing.T) { } } -func TestServer_PublishTooRequests_Defaults_ExemptHosts_MessageDailyLimit(t *testing.T) { +func TestServer_PublishTooManyRequests_Defaults_ExemptHosts_IPv6(t *testing.T) { + c := newTestConfig(t) + c.VisitorRequestLimitBurst = 3 + c.VisitorRequestExemptIPAddrs = []netip.Prefix{netip.MustParsePrefix("2001:db8:9999::/48")} + s := newTestServer(t, c) + overrideRemoteAddr := func(r *http.Request) { + r.RemoteAddr = "[2001:db8:9999::1]:1234" + } + for i := 0; i < 5; i++ { // > 3 + response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), nil, overrideRemoteAddr) + require.Equal(t, 200, response.Code) + } +} + +func TestServer_PublishTooManyRequests_Defaults_ExemptHosts_MessageDailyLimit(t *testing.T) { c := newTestConfig(t) c.VisitorRequestLimitBurst = 10 c.VisitorMessageDailyLimit = 4 @@ -1202,7 +1259,7 @@ func TestServer_PublishTooRequests_Defaults_ExemptHosts_MessageDailyLimit(t *tes } } -func TestServer_PublishTooRequests_ShortReplenish(t *testing.T) { +func TestServer_PublishTooManyRequests_ShortReplenish(t *testing.T) { t.Parallel() c := newTestConfig(t) c.VisitorRequestLimitBurst = 60 @@ -2244,6 +2301,19 @@ func TestServer_Visitor_Custom_ClientIP_Header(t *testing.T) { require.Equal(t, "1.2.3.4", v.ip.String()) } +func TestServer_Visitor_Custom_ClientIP_Header_IPv6(t *testing.T) { + c := newTestConfig(t) + c.BehindProxy = true + c.ProxyForwardedHeader = "X-Client-IP" + s := newTestServer(t, c) + r, _ := http.NewRequest("GET", "/bla", nil) + r.RemoteAddr = "[2001:db8:9999::1]:1234" + r.Header.Set("X-Client-IP", "2001:db8:7777::1") + v, err := s.maybeAuthenticate(r) + require.Nil(t, err) + require.Equal(t, "2001:db8:7777::1", v.ip.String()) +} + func TestServer_Visitor_Custom_Forwarded_Header(t *testing.T) { c := newTestConfig(t) c.BehindProxy = true @@ -2258,6 +2328,20 @@ func TestServer_Visitor_Custom_Forwarded_Header(t *testing.T) { require.Equal(t, "5.6.7.8", v.ip.String()) } +func TestServer_Visitor_Custom_Forwarded_Header_IPv6(t *testing.T) { + c := newTestConfig(t) + c.BehindProxy = true + c.ProxyForwardedHeader = "Forwarded" + c.ProxyTrustedAddresses = []string{"2001:db8:1111::1"} + s := newTestServer(t, c) + r, _ := http.NewRequest("GET", "/bla", nil) + r.RemoteAddr = "[2001:db8:2222::1]:1234" + r.Header.Set("Forwarded", " for=[2001:db8:1111::1], by=example.com;for=[2001:db8:3333::1]") + v, err := s.maybeAuthenticate(r) + require.Nil(t, err) + require.Equal(t, "2001:db8:3333::1", v.ip.String()) +} + func TestServer_PublishWhileUpdatingStatsWithLotsOfMessages(t *testing.T) { t.Parallel() count := 50000 diff --git a/server/util.go b/server/util.go index 54c1851b..687e7d0e 100644 --- a/server/util.go +++ b/server/util.go @@ -22,8 +22,13 @@ var ( priorityHeaderIgnoreRegex = regexp.MustCompile(`^u=\d,\s*(i|\d)$|^u=\d$`) // forwardedHeaderRegex parses IPv4 and IPv6 addresses from the "Forwarded" header (RFC 7239) - // IPv6 addresses in Forwarded header are enclosed in square brackets, e.g. for="[2001:db8::1]" - forwardedHeaderRegex = regexp.MustCompile(`(?i)\\bfor=\"?((?:[0-9]{1,3}\.){3}[0-9]{1,3}|\[[0-9a-fA-F:]+\])\"?`) + // IPv6 addresses in Forwarded header are enclosed in square brackets. The port is optional. + // + // Examples: + // for="1.2.3.4" + // for="[2001:db8::1]"; for=1.2.3.4:8080, by=phil + // for="1.2.3.4:8080" + forwardedHeaderRegex = regexp.MustCompile(`(?i)\bfor="?(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}|\[[0-9a-f:]+])(?::\d+)?"?`) ) func readBoolParam(r *http.Request, defaultValue bool, names ...string) bool { @@ -105,7 +110,7 @@ func extractIPAddress(r *http.Request, behindProxy bool, proxyForwardedHeader st // then take the right-most address in the list (as this is the one added by our proxy server). // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For for details. func extractIPAddressFromHeader(r *http.Request, forwardedHeader string, trustedAddresses []string) (netip.Addr, error) { - value := strings.TrimSpace(r.Header.Get(forwardedHeader)) + value := strings.TrimSpace(strings.ToLower(r.Header.Get(forwardedHeader))) if value == "" { return netip.IPv4Unspecified(), fmt.Errorf("no %s header found", forwardedHeader) } diff --git a/server/visitor.go b/server/visitor.go index 155d7be0..f26729f1 100644 --- a/server/visitor.go +++ b/server/visitor.go @@ -2,13 +2,13 @@ package server import ( "fmt" - "heckel.io/ntfy/v2/log" - "heckel.io/ntfy/v2/user" "net/netip" "sync" "time" "golang.org/x/time/rate" + "heckel.io/ntfy/v2/log" + "heckel.io/ntfy/v2/user" "heckel.io/ntfy/v2/util" ) @@ -151,7 +151,7 @@ func (v *visitor) Context() log.Context { func (v *visitor) contextNoLock() log.Context { info := v.infoLightNoLock() fields := log.Context{ - "visitor_id": visitorID(v.ip, v.user), + "visitor_id": visitorID(v.ip, v.user, v.config), "visitor_ip": v.ip.String(), "visitor_seen": util.FormatTime(v.seen), "visitor_messages": info.Stats.Messages, @@ -524,15 +524,15 @@ func dailyLimitToRate(limit int64) rate.Limit { return rate.Limit(limit) * rate.Every(oneDay) } -func visitorID(ip netip.Addr, u *user.User) string { +// visitorID returns a unique identifier for a visitor based on user or IP, using configurable prefix bits for IPv4/IPv6 +func visitorID(ip netip.Addr, u *user.User, conf *Config) string { if u != nil && u.Tier != nil { return fmt.Sprintf("user:%s", u.ID) } - if ip.Is6() { - // IPv6 addresses are too long to be used as visitor IDs, so we use the first 8 bytes - ip = netip.PrefixFrom(ip, 64).Masked().Addr() - } else if ip.Is4() { - ip = netip.PrefixFrom(ip, 20).Masked().Addr() + if ip.Is4() { + ip = netip.PrefixFrom(ip, conf.VisitorPrefixBitsIPv4).Masked().Addr() + } else if ip.Is6() { + ip = netip.PrefixFrom(ip, conf.VisitorPrefixBitsIPv6).Masked().Addr() } return fmt.Sprintf("ip:%s", ip.String()) }