diff --git a/.github/README.md b/.github/README.md index a345bce..92aca0f 100644 --- a/.github/README.md +++ b/.github/README.md @@ -158,6 +158,8 @@ There's also some additional optional environment variables that you may set: - `ADGUARD_PROTOCOL` - The protocol to use when connecting to AdGuard (defaults to `http`) - `ADGUARD_UPDATE_INTERVAL` - The rate at which to refresh the UI in seconds (defaults to `2`) +- `ADGUARD_TIMEOUT` - The per-request timeout when contacting AdGuard, in seconds (defaults to `5`) +- `ADGUARD_QUERYLOG_LIMIT` - The number of query log entries to fetch per update (defaults to `100`)
Examples diff --git a/quick-start.sh b/quick-start.sh index 7cdfc5a..000336d 100755 --- a/quick-start.sh +++ b/quick-start.sh @@ -41,10 +41,33 @@ function print_info { print_heading "Checking system type" if [[ "$OSTYPE" == "linux-gnu"* ]]; then print_info "System type: Linux" - bin_target="adguardian-linux" + case "$(uname -m)" in + x86_64|amd64) + bin_target="adguardian-linux" + ;; + aarch64|arm64) + bin_target="adguardian-linux-arm64" + ;; + armv7l|armv7*) + bin_target="adguardian-linux-armv7" + ;; + *) + exit_script "Unsupported Linux architecture: $(uname -m)" + ;; + esac elif [[ "$OSTYPE" == "darwin"* ]]; then print_info "System type: Apple OS X" - bin_target="adguardian-macos" + case "$(uname -m)" in + arm64) + bin_target="adguardian-macos" + ;; + x86_64) + bin_target="adguardian-macos-x86_64" + ;; + *) + exit_script "Unsupported macOS architecture: $(uname -m)" + ;; + esac elif [[ "$OSTYPE" == "cygwin" ]]; then print_info "System type: Windows/Cygwin" bin_target="adguardian-windows.exe" @@ -52,32 +75,26 @@ else exit_script "Unsupported System" fi -# Make the download link to latest binary for users system type -download_link="$upstream_repo/releases/$adguardian_version/download/$bin_target" - # Check if the binary already exists print_heading "Preparing to download" +download_link="$upstream_repo/releases/$adguardian_version/download/$bin_target" if [ -f "$download_location" ]; then print_info "File already exists, skipping download." +elif hash "curl" 2> /dev/null; then + print_info "Downloading to $download_location (with curl) from $download_link" + curl --fail --location --output "$download_location" "$download_link" \ + || { rm -f "$download_location"; exit_script "Unable to download a binary for your system"; } +elif hash "wget" 2> /dev/null; then + print_info "Downloading to $download_location (with wget) from $download_link" + wget --no-verbose --show-progress --progress=dot:mega -q -S -O "$download_location" "$download_link" \ + || { rm -f "$download_location"; exit_script "Unable to download a binary for your system"; } else - # Download with either curl or wget, depending on what is installed - if hash "curl" 2> /dev/null; then - print_info "Downloading to $download_location (with curl)" - curl -L -o $download_location $download_link - elif hash "wget" 2> /dev/null; then - print_info "Downloading to $download_location (with wget)" - wget \ - --no-verbose --show-progress \ - --progress=dot:mega -q -S \ - -O $download_location $download_link - else - exit_script "Neither curl nor wget were found on your system" - fi + exit_script "Neither curl nor wget were found on your system" fi # Make the binary executable, then run the application print_heading "Preparing to run" print_info "Updating permissions for $download_location" -chmod +x $download_location +chmod +x "$download_location" print_info "Starting AdGuardian....\n\n" -$download_location +"$download_location" diff --git a/src/fetch/fetch_filters.rs b/src/fetch/fetch_filters.rs index 18ed770..e35ac41 100644 --- a/src/fetch/fetch_filters.rs +++ b/src/fetch/fetch_filters.rs @@ -28,6 +28,12 @@ pub async fn fetch_adguard_filter_list( headers.insert("Authorization", auth_header_value.parse()?); let res: Response = client.get(&url).headers(headers).send().await?; + if !res.status().is_success() { + return Err(anyhow::anyhow!( + "Request failed with status code {}", + res.status() + )); + } let status: AdGuardFilteringStatus = res.json().await?; Ok(status) diff --git a/src/fetch/fetch_query_log.rs b/src/fetch/fetch_query_log.rs index 2ee0972..c2ab7a9 100644 --- a/src/fetch/fetch_query_log.rs +++ b/src/fetch/fetch_query_log.rs @@ -2,12 +2,14 @@ use base64::{engine::general_purpose::STANDARD, Engine as _}; use reqwest::header::{HeaderValue, AUTHORIZATION, CONTENT_LENGTH}; use serde::Deserialize; -#[derive(Deserialize)] +#[derive(Default, Deserialize)] +#[serde(default)] pub struct QueryResponse { pub data: Vec, } -#[derive(Deserialize)] +#[derive(Default, Deserialize)] +#[serde(default)] pub struct Query { pub cached: bool, pub client: String, @@ -19,7 +21,8 @@ pub struct Query { pub time: String, } -#[derive(Deserialize)] +#[derive(Default, Deserialize)] +#[serde(default)] pub struct Question { pub class: String, pub name: String, @@ -52,3 +55,24 @@ pub async fn fetch_adguard_query_log( let data = response.json().await?; Ok(data) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::fetch::fetch_stats::StatsResponse; + use crate::fetch::fetch_status::StatusResponse; + + // Missing or partial fields get decoded to defaults, instead of erroring + #[test] + fn empty_and_partial_json_decode_to_defaults() { + serde_json::from_str::("{}").unwrap(); + serde_json::from_str::("{}").unwrap(); + serde_json::from_str::("{}").unwrap(); + serde_json::from_str::(r#"{"num_dns_queries":5}"#).unwrap(); + + // A blocked query has no `upstream`, default to empty + let q = r#"{"cached":false,"client":"1.2.3.4","elapsedMs":"0.1", + "question":{"class":"IN","name":"x.com","type":"A"},"reason":"x","time":"t"}"#; + assert_eq!(serde_json::from_str::(q).unwrap().upstream, ""); + } +} diff --git a/src/fetch/fetch_stats.rs b/src/fetch/fetch_stats.rs index c19bc09..4943c5d 100644 --- a/src/fetch/fetch_stats.rs +++ b/src/fetch/fetch_stats.rs @@ -12,7 +12,8 @@ pub struct DomainData { pub count: i32, } -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Default, Deserialize, Clone)] +#[serde(default)] pub struct StatsResponse { pub num_dns_queries: u64, pub num_blocked_filtering: u64, diff --git a/src/fetch/fetch_status.rs b/src/fetch/fetch_status.rs index f3c25f5..bee758c 100644 --- a/src/fetch/fetch_status.rs +++ b/src/fetch/fetch_status.rs @@ -28,7 +28,8 @@ use serde::Deserialize; /// * `protection_enabled` - Whether or not protection is currently enabled. /// * `dhcp_available` - Whether or not DHCP is available. /// * `running` - Whether or not the AdGuard Home instance is currently running. -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Default, Deserialize, Clone)] +#[serde(default)] pub struct StatusResponse { pub version: String, pub dns_port: u16, diff --git a/src/main.rs b/src/main.rs index 57f85b7..8b0d9ee 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,7 +10,7 @@ use tokio::time::{interval, MissedTickBehavior}; use ui::draw_ui; use fetch::{ - fetch_filters::fetch_adguard_filter_list, + fetch_filters::{fetch_adguard_filter_list, AdGuardFilteringStatus}, fetch_query_log::{fetch_adguard_query_log, Query}, fetch_stats::{fetch_adguard_stats, StatsResponse}, fetch_status::{fetch_adguard_status, StatusResponse}, @@ -33,8 +33,14 @@ async fn fetch_all( } async fn run() -> anyhow::Result<()> { - // Create a reqwest client - let client = Client::new(); + // Per-request timeout (seconds), clamped to at least 1, so no request can hang + let timeout_secs: u64 = env::var("ADGUARD_TIMEOUT") + .unwrap_or_else(|_| "5".into()) + .parse::()? + .max(1); + let client = Client::builder() + .timeout(Duration::from_secs(timeout_secs)) + .build()?; // AdGuard instance details, from env vars (verified in welcome.rs) let ip = env::var("ADGUARD_IP")?; @@ -44,8 +50,18 @@ async fn run() -> anyhow::Result<()> { let username = env::var("ADGUARD_USERNAME")?; let password = env::var("ADGUARD_PASSWORD")?; - // Fetch data that doesn't require updates - let filters = fetch_adguard_filter_list(&client, &hostname, &username, &password).await?; + // Fetch the filter list, use empty list on failures is fine + let filters = welcome::with_retries( + 3, + Duration::from_secs(5), + "Fetching AdGuard filters", + || fetch_adguard_filter_list(&client, &hostname, &username, &password), + ) + .await + .unwrap_or_else(|e| { + eprintln!("Could not fetch filter list, starting without it: {}", e); + AdGuardFilteringStatus { filters: None } + }); // Open channels for data fetching where updates are required let (queries_tx, queries_rx) = tokio::sync::mpsc::channel(1); @@ -64,10 +80,11 @@ async fn run() -> anyhow::Result<()> { shutdown_tx, )); - // Get update interval (in seconds) + // Get update interval (in seconds), clamped to at least 1 (interval() panics on zero) let interval_secs: u64 = env::var("ADGUARD_UPDATE_INTERVAL") .unwrap_or_else(|_| "2".into()) - .parse()?; + .parse::()? + .max(1); let mut interval = interval(Duration::from_secs(interval_secs)); interval.set_missed_tick_behavior(MissedTickBehavior::Skip); @@ -106,7 +123,7 @@ async fn run() -> anyhow::Result<()> { } fn main() { - let rt = tokio::runtime::Runtime::new().unwrap(); + let rt = tokio::runtime::Runtime::new().expect("failed to start async runtime"); rt.block_on(async { welcome::welcome().await.unwrap_or_else(|e| { eprintln!("Failed to initialize: {}", e); diff --git a/src/welcome.rs b/src/welcome.rs index 826818f..cbc0023 100644 --- a/src/welcome.rs +++ b/src/welcome.rs @@ -1,9 +1,15 @@ use base64::{engine::general_purpose::STANDARD, Engine as _}; use colored::*; +use crossterm::{ + event::{self, Event, KeyCode, KeyModifiers}, + terminal::{disable_raw_mode, enable_raw_mode}, +}; use reqwest::{Client, Error}; use std::{ + cmp::Ordering, env, - io::{self, Write}, + fmt::Display, + io::{self, IsTerminal, Write}, time::Duration, }; @@ -43,7 +49,7 @@ fn print_ascii_art() { } /// Print error message, along with (optional) stack trace, then exit -fn print_error(message: &str, sub_message: &str, error: Option<&Error>) { +fn print_error(message: &str, sub_message: &str, error: Option<&Error>) -> ! { eprintln!( "{}{}{}", message.red(), @@ -118,14 +124,54 @@ fn check_version(version: Option<&str>) { } } -/// With the users specified AdGuard details, verify the connection (exit on fail) +/// Run an async operation, retrying on error up to `attempts` times, `delay` apart. +/// Each failure is reported; the last error is returned once attempts are exhausted. +pub async fn with_retries( + attempts: u32, + delay: Duration, + label: &str, + mut operation: F, +) -> Result +where + F: FnMut() -> Fut, + Fut: std::future::Future>, + E: Display, +{ + let mut attempt = 1; + loop { + match operation().await { + Ok(value) => return Ok(value), + Err(e) if attempt < attempts => { + println!( + "{}", + format!( + "{} failed (attempt {}/{}): {}\nRetrying in {}s...", + label, + attempt, + attempts, + e, + delay.as_secs() + ) + .yellow() + ); + tokio::time::sleep(delay).await; + attempt += 1; + } + Err(e) => return Err(e), + } + } +} + +/// With the users specified AdGuard details, verify the connection. +/// Returns `Err` on a failed connection (so the caller can retry); exits on +/// rejected auth or an unsupported version, which retrying wouldn't fix. async fn verify_connection( client: &Client, - ip: String, - port: String, - protocol: String, - username: String, - password: String, + ip: &str, + port: &str, + protocol: &str, + username: &str, + password: &str, ) -> Result<(), Box> { println!( "{}", @@ -159,23 +205,13 @@ async fn verify_connection( Ok(()) } // Connection failed to authenticate. Print error and exit - Ok(_) => { - print_error( - &format!("Authentication with AdGuard at {}:{} failed", ip, port), - "Please check your environmental variables and try again.", - None, - ); - Ok(()) - } - // Connection failed to establish. Print error and exit - Err(e) => { - print_error( - &format!("Failed to connect to AdGuard at: {}:{}", ip, port), - "Please check your environmental variables and try again.", - Some(&e), - ); - Ok(()) - } + Ok(_) => print_error( + &format!("Authentication with AdGuard at {}:{} failed", ip, port), + "Check the credentials you passed as environmental variables and try again.", + None, + ), + // Connection failed to establish - return so the caller can retry + Err(e) => Err(e.into()), } } @@ -200,6 +236,7 @@ async fn get_latest_version(crate_name: &str) -> Result println!( "{}", format!( "A new version of AdGuardian is available.\nUpdate from {} to {} for the best experience", @@ -243,27 +280,124 @@ async fn check_for_updates() { latest_version.to_string().bold() ) .yellow() - ); - } else if current_version == latest_version { - println!( + ), + Ordering::Equal => println!( "{}", format!( "AdGuardian is up-to-date, running version {}", current_version.to_string().bold() ) .green() - ); - } else if current_version > latest_version { - println!( + ), + Ordering::Greater => println!( "{}", format!( "Running a pre-released edition of AdGuardian, version {}", current_version.to_string().bold() ) .green() - ); + ), + } +} + +/// The value to pre-fill for a field's interactive prompt, where a sensible one exists +fn default_for(key: &str) -> Option<&'static str> { + match key { + "ADGUARD_IP" => Some("127.0.0.1"), + "ADGUARD_PORT" => Some("3000"), + _ => None, + } +} + +/// Read a line from the terminal in raw mode, echoing nothing. Ctrl-C cancels. +fn read_masked() -> io::Result { + enable_raw_mode()?; + let result = masked_loop(); + let _ = disable_raw_mode(); + if result.is_ok() { + println!(); + } + result +} + +fn masked_loop() -> io::Result { + let mut value = String::new(); + loop { + if let Event::Key(key) = event::read()? { + let ctrl = key.modifiers.contains(KeyModifiers::CONTROL); + match key.code { + KeyCode::Enter => return Ok(value), + KeyCode::Char('c') if ctrl => return Err(io::ErrorKind::Interrupted.into()), + KeyCode::Char(c) if !ctrl => value.push(c), + KeyCode::Backspace => { + value.pop(); + } + _ => {} + } + } + } +} + +/// Print the prompt and read a value, masking secret fields on an interactive terminal +fn read_field(prompt: &ColoredString, secret: bool) -> io::Result { + print!("{}", prompt); + io::stdout().flush()?; + if secret && io::stdin().is_terminal() { + read_masked() } else { - println!("{}", "Unable to check for updates".yellow()); + let mut value = String::new(); + io::stdin().read_line(&mut value)?; + Ok(value) + } +} + +/// Read a field off the async runtime threads +async fn read_input(prompt: ColoredString, secret: bool) -> io::Result { + tokio::task::spawn_blocking(move || read_field(&prompt, secret)) + .await + .expect("input task panicked") +} + +/// Print the cancellation notice and exit cleanly +fn exit_interrupted() -> ! { + println!( + "{}", + "\n\nAdGuardian setup interrupted by user, exiting...".yellow() + ); + std::process::exit(0); +} + +/// Prompt for a single field, re-prompting until the input is valid. +/// Masks passwords, applies the field's default on empty input, validates the +/// port is numeric, and exits cleanly if the user interrupts with Ctrl-C. +async fn prompt_for(key: &str) -> Result> { + let default = default_for(key); + let secret = key.contains("PASSWORD"); + loop { + let hint = default.map(|d| format!(" [{}]", d)).unwrap_or_default(); + let prompt = format!("› Enter a value for {}{}: ", key, hint) + .blue() + .bold(); + + let input = tokio::select! { + res = read_input(prompt, secret) => match res { + Ok(value) => value, + Err(e) if e.kind() == io::ErrorKind::Interrupted => exit_interrupted(), + Err(e) => return Err(e.into()), + }, + _ = tokio::signal::ctrl_c() => exit_interrupted(), + }; + + let value = match input.trim() { + "" => default.unwrap_or_default(), + trimmed => trimmed, + }; + + if key == "ADGUARD_PORT" && value.parse::().is_err() { + println!("{}", "Port must be a number, and a valid port".yellow()); + continue; + } + return Ok(value.to_string()); } } @@ -305,7 +439,7 @@ pub async fn welcome() -> Result<(), Box> { while let Some(arg) = args.next() { for &(flag, var) in &flags { if arg == flag { - if let Some(value) = args.peek() { + if let Some(value) = args.peek().filter(|v| !v.starts_with("--")) { env::set_var(var, value); args.next(); } @@ -325,12 +459,7 @@ pub async fn welcome() -> Result<(), Box> { "{}", format!("The {} environmental variable is not yet set", key.bold()).yellow() ); - print!("{}", format!("› Enter a value for {}: ", key).blue().bold()); - io::stdout().flush()?; - - let mut value = String::new(); - io::stdin().read_line(&mut value)?; - env::set_var(key, value.trim()); + env::set_var(key, prompt_for(key).await?); } } @@ -341,8 +470,22 @@ pub async fn welcome() -> Result<(), Box> { let username = get_env("ADGUARD_USERNAME")?; let password = get_env("ADGUARD_PASSWORD")?; - // Verify that we can connect, authenticate, and that version is supported (exit on failure) - verify_connection(&client, ip, port, protocol, username, password).await?; + // Verify we can connect, authenticate, and that the version is supported + let connected = with_retries(3, Duration::from_secs(5), "AdGuard connection", || { + verify_connection(&client, &ip, &port, &protocol, &username, &password) + }) + .await; + + if connected.is_err() { + print_error( + &format!( + "Could not reach AdGuard at {}:{} after 3 attempts", + ip, port + ), + "Please check that AdGuard Home is running and your settings are correct.", + None, + ); + } Ok(()) } diff --git a/src/widgets/table.rs b/src/widgets/table.rs index 0607b0b..cdab5d0 100644 --- a/src/widgets/table.rs +++ b/src/widgets/table.rs @@ -14,7 +14,7 @@ pub fn make_query_table(data: &[Query], width: u16) -> Table<'_> { let time = Cell::from(time_ago(query.time.as_str()).unwrap_or("unknown".to_string())) .style(Style::default().fg(Color::Gray)); - let question = Cell::from(make_request_cell(&query.question).unwrap()) + let question = Cell::from(make_request_cell(&query.question)) .style(Style::default().add_modifier(Modifier::BOLD)); let client = Cell::from(query.client.as_str()).style(Style::default().fg(Color::Blue)); @@ -95,8 +95,8 @@ fn time_ago(timestamp: &str) -> Result { } // Return cell showing info about the request made in a given query -fn make_request_cell(q: &Question) -> Result { - Ok(format!("[{}] {} - {}", q.class, q.question_type, q.name)) +fn make_request_cell(q: &Question) -> String { + format!("[{}] {} - {}", q.class, q.question_type, q.name) } // Return a cell showing the time taken for a query, and a color based on time