diff --git a/src/blocklist.rs b/src/blocklist.rs
index ef865c4..20ac95d 100644
--- a/src/blocklist.rs
+++ b/src/blocklist.rs
@@ -1,5 +1,5 @@
 use std::collections::HashSet;
-use std::time::Instant;
+use std::time::{Duration, Instant};
 
 use log::{info, warn};
 
@@ -355,27 +355,164 @@ mod tests {
     }
 }
 
+/// Seconds to wait before each retry of a failed blocklist download; a fetch
+/// is attempted `RETRY_DELAYS_SECS.len() + 1` times in total.
+const RETRY_DELAYS_SECS: &[u64] = &[2, 10, 30];
+
+/// Downloads every blocklist in `lists` concurrently.
+///
+/// Returns `(url, body)` pairs, in input order, for the lists that were
+/// fetched successfully. Lists that still fail after all retries are logged
+/// by the retry helper and omitted from the result.
 pub async fn download_blocklists(lists: &[String]) -> Vec<(String, String)> {
     let client = reqwest::Client::builder()
-        .timeout(std::time::Duration::from_secs(30))
+        .timeout(Duration::from_secs(30))
         .gzip(true)
         .build()
         .unwrap_or_default();
 
-    let mut results = Vec::new();
+    let fetches = lists.iter().map(|url| {
+        let client = &client;
+        async move {
+            let text = fetch_with_retry(client, url).await?;
+            info!("downloaded blocklist: {} ({} bytes)", url, text.len());
+            Some((url.clone(), text))
+        }
+    });
+    futures::future::join_all(fetches)
+        .await
+        .into_iter()
+        .flatten()
+        .collect()
+}
 
-    for url in lists {
-        match client.get(url).send().await {
-            Ok(resp) => match resp.text().await {
-                Ok(text) => {
-                    info!("downloaded blocklist: {} ({} bytes)", url, text.len());
-                    results.push((url.clone(), text));
-                }
-                Err(e) => warn!("failed to read blocklist body {}: {}", url, e),
-            },
-            Err(e) => warn!("failed to download blocklist {}: {}", url, e),
+/// Fetches `url` using the default [`RETRY_DELAYS_SECS`] back-off schedule.
+async fn fetch_with_retry(client: &reqwest::Client, url: &str) -> Option<String> {
+    fetch_with_retry_delays(client, url, RETRY_DELAYS_SECS).await
+}
+
+/// Fetches `url`, retrying after each entry of `delays` (in seconds).
+///
+/// Makes `delays.len() + 1` attempts, logging a warning for every failure.
+/// Returns the body of the first successful attempt, or `None` once all
+/// attempts have failed.
+async fn fetch_with_retry_delays(
+    client: &reqwest::Client,
+    url: &str,
+    delays: &[u64],
+) -> Option<String> {
+    let total = delays.len() + 1;
+    for attempt in 1..=total {
+        match fetch_once(client, url).await {
+            Ok(text) => return Some(text),
+            Err(msg) if attempt < total => {
+                let delay = delays[attempt - 1];
+                warn!(
+                    "blocklist {} attempt {}/{} failed: {} — retrying in {}s",
+                    url, attempt, total, msg, delay
+                );
+                tokio::time::sleep(Duration::from_secs(delay)).await;
+            }
+            Err(msg) => {
+                warn!(
+                    "blocklist {} attempt {}/{} failed: {} — giving up",
+                    url, attempt, total, msg
+                );
+            }
         }
     }
-
-    results
+    None
+}
+
+/// Performs one GET of `url` and returns the response body.
+///
+/// Transport failures, non-success (4xx/5xx) HTTP statuses and body read
+/// errors are all reported as `Err` carrying the flattened error chain, so an
+/// error page is never mistaken for blocklist content.
+async fn fetch_once(client: &reqwest::Client, url: &str) -> Result<String, String> {
+    let resp = client
+        .get(url)
+        .send()
+        .await
+        .map_err(|e| format_error_chain(&e))?
+        .error_for_status()
+        .map_err(|e| format_error_chain(&e))?;
+    resp.text().await.map_err(|e| format_error_chain(&e))
+}
+
+/// Renders `e` followed by each of its `source()` causes, joined with ": ".
+fn format_error_chain(e: &(dyn std::error::Error + 'static)) -> String {
+    let mut parts = vec![e.to_string()];
+    let mut src = e.source();
+    while let Some(s) = src {
+        parts.push(s.to_string());
+        src = s.source();
+    }
+    parts.join(": ")
+}
+
+#[cfg(test)]
+mod retry_tests {
+    use super::*;
+    use std::net::SocketAddr;
+    use tokio::io::{AsyncReadExt, AsyncWriteExt};
+    use tokio::net::TcpListener;
+
+    /// Starts a local HTTP server that drops the first `drop_first_n`
+    /// connections without responding, then answers every later connection
+    /// with a `200 OK` carrying `body`. Returns the address it listens on.
+    async fn flaky_http_server(drop_first_n: usize, body: &'static str) -> SocketAddr {
+        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+        let addr = listener.local_addr().unwrap();
+        tokio::spawn(async move {
+            for _ in 0..drop_first_n {
+                if let Ok((sock, _)) = listener.accept().await {
+                    drop(sock);
+                }
+            }
+            loop {
+                let Ok((mut sock, _)) = listener.accept().await else {
+                    return;
+                };
+                tokio::spawn(async move {
+                    let mut buf = [0u8; 2048];
+                    let _ = sock.read(&mut buf).await;
+                    let response = format!(
+                        "HTTP/1.1 200 OK\r\nContent-Length: {}\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n{}",
+                        body.len(),
+                        body,
+                    );
+                    let _ = sock.write_all(response.as_bytes()).await;
+                    let _ = sock.shutdown().await;
+                });
+            }
+        });
+        addr
+    }
+
+    /// One zero-second delay per configured retry, so tests do not wait between attempts.
+    fn zero_delays() -> Vec<u64> {
+        vec![0; RETRY_DELAYS_SECS.len()]
+    }
+
+    #[tokio::test]
+    async fn retry_succeeds_on_final_attempt() {
+        let body = "ads.example.com\ntracker.example.net\n";
+        let delays = zero_delays();
+        let addr = flaky_http_server(delays.len(), body).await;
+        let client = reqwest::Client::new();
+        let url = format!("http://{addr}/");
+        let result = fetch_with_retry_delays(&client, &url, &delays).await;
+        assert_eq!(result.as_deref(), Some(body));
+    }
+
+    #[tokio::test]
+    async fn retry_gives_up_when_all_attempts_fail() {
+        let delays = zero_delays();
+        let addr = flaky_http_server(delays.len() + 2, "unreachable").await;
+        let client = reqwest::Client::new();
+        let url = format!("http://{addr}/");
+        let result = fetch_with_retry_delays(&client, &url, &delays).await;
+        assert_eq!(result, None);
+    }
 }