diff --git a/src/blocklist.rs b/src/blocklist.rs index e5caa99..ef865c4 100644 --- a/src/blocklist.rs +++ b/src/blocklist.rs @@ -81,66 +81,70 @@ impl BlocklistStore { if !self.enabled { return false; } - if let Some(until) = self.paused_until { if Instant::now() < until { return false; } } - - if self.allowlist.contains(domain) { + let domain = Self::normalize(domain); + if Self::find_in_set(&domain, &self.allowlist).is_some() { return false; } - - if self.domains.contains(domain) { - return true; - } - - // Walk up: ads.tracker.example.com → tracker.example.com → example.com - let mut d = domain; - while let Some(dot) = d.find('.') { - d = &d[dot + 1..]; - if self.allowlist.contains(d) { - return false; - } - if self.domains.contains(d) { - return true; - } - } - - false + Self::find_in_set(&domain, &self.domains).is_some() } - /// Check if a domain is blocked and return the reason. pub fn check(&self, domain: &str) -> BlockCheckResult { - let domain = domain.to_lowercase(); - if !self.enabled { return BlockCheckResult::disabled(); } - if self.allowlist.contains(&domain) { - return BlockCheckResult::allowed(&domain, "exact match in allowlist"); + if let Some(until) = self.paused_until { + if Instant::now() < until { + return BlockCheckResult::disabled(); + } } - if self.domains.contains(&domain) { - return BlockCheckResult::blocked(&domain, "exact match in blocklist"); + let domain = Self::normalize(domain); + + if let Some(matched) = Self::find_in_set(&domain, &self.allowlist) { + let reason = if matched == domain { + "exact match in allowlist" + } else { + "parent domain in allowlist" + }; + return BlockCheckResult::allowed(matched, reason); } - let mut d = domain.as_str(); - while let Some(dot) = d.find('.') { - d = &d[dot + 1..]; - if self.allowlist.contains(d) { - return BlockCheckResult::allowed(d, "parent domain in allowlist"); - } - if self.domains.contains(d) { - return BlockCheckResult::blocked(d, "parent domain in blocklist"); - } + if let Some(matched) = Self::find_in_set(&domain, &self.domains) { + let reason = if matched == domain { + "exact match in blocklist" + } else { + "parent domain in blocklist" + }; + return BlockCheckResult::blocked(matched, reason); } BlockCheckResult::not_blocked() } + fn normalize(domain: &str) -> String { + domain.to_lowercase().trim_end_matches('.').to_string() + } + + fn find_in_set<'a>(domain: &'a str, set: &HashSet) -> Option<&'a str> { + if set.contains(domain) { + return Some(domain); + } + let mut d = domain; + while let Some(dot) = d.find('.') { + d = &d[dot + 1..]; + if set.contains(d) { + return Some(d); + } + } + None + } + /// Atomically swap in a new domain set. Build the set outside the lock, /// then call this to swap — keeps lock hold time sub-microsecond. pub fn swap_domains(&mut self, domains: HashSet, sources: Vec) { @@ -172,11 +176,11 @@ impl BlocklistStore { } pub fn add_to_allowlist(&mut self, domain: &str) { - self.allowlist.insert(domain.to_lowercase()); + self.allowlist.insert(Self::normalize(domain)); } pub fn remove_from_allowlist(&mut self, domain: &str) -> bool { - self.allowlist.remove(&domain.to_lowercase()) + self.allowlist.remove(&Self::normalize(domain)) } pub fn allowlist(&self) -> Vec { @@ -247,6 +251,97 @@ pub fn parse_blocklist(text: &str) -> HashSet { mod tests { use super::*; + fn store_with(domains: &[&str], allowlist: &[&str]) -> BlocklistStore { + let mut store = BlocklistStore::new(); + store.swap_domains(domains.iter().map(|s| s.to_string()).collect(), vec![]); + for d in allowlist { + store.add_to_allowlist(d); + } + store + } + + #[test] + fn exact_block() { + let store = store_with(&["ads.example.com"], &[]); + assert!(store.is_blocked("ads.example.com")); + assert!(!store.is_blocked("example.com")); + } + + #[test] + fn parent_block_covers_subdomain() { + let store = store_with(&["tracker.com"], &[]); + assert!(store.is_blocked("tracker.com")); + assert!(store.is_blocked("www.tracker.com")); + assert!(store.is_blocked("deep.sub.tracker.com")); + } + + #[test] + fn exact_allowlist_unblocks() { + let store = store_with(&["ads.example.com"], &["ads.example.com"]); + assert!(!store.is_blocked("ads.example.com")); + } + + #[test] + fn parent_allowlist_unblocks_subdomain() { + let store = store_with(&["example.com", "www.example.com"], &["example.com"]); + assert!(!store.is_blocked("example.com")); + assert!(!store.is_blocked("www.example.com")); + assert!(!store.is_blocked("sub.deep.example.com")); + } + + #[test] + fn allowlist_does_not_unblock_sibling() { + let store = store_with( + &["www.example.com", "ads.example.com"], + &["www.example.com"], + ); + assert!(!store.is_blocked("www.example.com")); + assert!(store.is_blocked("ads.example.com")); + } + + #[test] + fn check_reports_parent_allowlist() { + let store = store_with( + &["goatcounter.com", "www.goatcounter.com"], + &["goatcounter.com"], + ); + let result = store.check("www.goatcounter.com"); + assert!(!result.blocked); + assert_eq!(result.matched_rule.as_deref(), Some("goatcounter.com")); + } + + #[test] + fn disabled_never_blocks() { + let mut store = store_with(&["ads.example.com"], &[]); + store.set_enabled(false); + assert!(!store.is_blocked("ads.example.com")); + } + + #[test] + fn trailing_dot_normalized() { + let store = store_with(&["ads.example.com"], &["safe.example.com"]); + assert!(store.is_blocked("ads.example.com.")); + assert!(!store.is_blocked("safe.example.com.")); + let result = store.check("ads.example.com."); + assert!(result.blocked); + } + + #[test] + fn case_insensitive() { + let store = store_with(&["ads.example.com"], &["safe.example.com"]); + assert!(store.is_blocked("ADS.Example.COM")); + assert!(!store.is_blocked("Safe.Example.COM")); + } + + #[test] + fn domain_in_neither_list() { + let store = store_with(&["ads.example.com"], &[]); + let result = store.check("clean.example.org"); + assert!(!result.blocked); + assert_eq!(result.reason, "not in blocklist"); + assert!(result.matched_rule.is_none()); + } + #[test] fn heap_bytes_grows_with_domains() { let mut store = BlocklistStore::new();