From 59d4ddb7194bbbebabf7c689047e2dc67ed49858 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 20 Jan 2026 00:04:44 +0400 Subject: [PATCH] fix(server): add optional TLS and block insecure binds --- Cargo.toml | 1 + src/cli/commands.rs | 39 ++++++++++++++++++++++-- src/server/api.rs | 61 ++++++++++++++++++++++++++++++++++---- src/ui/console.rs | 7 +++-- tests/integration_tests.rs | 27 +++++++++++++++++ 5 files changed, 124 insertions(+), 11 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index d912c66..816c469 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ hf-hub = "0.4" clap = { version = "4.5", features = ["derive", "env"] } tokio = { version = "1", features = ["full"] } axum = "0.8" +axum-server = { version = "0.8", features = ["tls-rustls"] } tower-http = { version = "0.6", features = ["cors"] } rusqlite = { version = "0.32", features = ["bundled"] } ratatui = "0.29" diff --git a/src/cli/commands.rs b/src/cli/commands.rs index f128499..655fa90 100644 --- a/src/cli/commands.rs +++ b/src/cli/commands.rs @@ -40,6 +40,18 @@ pub struct Cli { /// Server port to use #[arg(short = 'p', long, global = true, env = "VGREP_PORT")] port: Option, + + /// TLS certificate path (PEM) + #[arg(long, global = true, env = "VGREP_TLS_CERT")] + tls_cert: Option, + + /// TLS private key path (PEM) + #[arg(long, global = true, env = "VGREP_TLS_KEY")] + tls_key: Option, + + /// Allow serving plain HTTP on non-loopback interfaces (INSECURE) + #[arg(long, global = true, env = "VGREP_ALLOW_INSECURE_HTTP")] + allow_insecure_http: bool, } #[derive(Subcommand)] @@ -314,7 +326,14 @@ impl Cli { Some(Commands::Serve { host, port }) => { let host = host.unwrap_or_else(|| config.server_host.clone()); let port = port.unwrap_or(config.server_port); - run_serve(&config, host, port) + run_serve( + &config, + host, + port, + self.tls_cert, + self.tls_key, + self.allow_insecure_http, + ) } Some(Commands::Status) => run_status(&config), Some(Commands::Models { action }) => run_models(action, &mut config), @@ -748,9 +767,23 @@ fn run_watch(config: &Config, path: PathBuf) -> Result<()> { watcher.watch() } -fn run_serve(config: &Config, host: String, port: u16) -> Result<()> { +fn run_serve( + config: &Config, + host: String, + port: u16, + tls_cert_path: Option, + tls_key_path: Option, + allow_insecure_http: bool, +) -> Result<()> { let rt = tokio::runtime::Runtime::new()?; - rt.block_on(server::run_server(config, &host, port)) + rt.block_on(server::run_server( + config, + &host, + port, + tls_cert_path, + tls_key_path, + allow_insecure_http, + )) } fn run_status(config: &Config) -> Result<()> { diff --git a/src/server/api.rs b/src/server/api.rs index bd5cc8a..09c1226 100644 --- a/src/server/api.rs +++ b/src/server/api.rs @@ -1,4 +1,4 @@ -use anyhow::Result; +use anyhow::{Context, Result}; use axum::{ extract::State, http::StatusCode, @@ -7,7 +7,7 @@ use axum::{ Json, Router, }; use serde::{Deserialize, Serialize}; -use std::net::SocketAddr; +use std::net::{IpAddr, SocketAddr}; use std::path::PathBuf; use std::sync::{Arc, Mutex}; use tower_http::cors::{Any, CorsLayer}; @@ -84,7 +84,43 @@ pub struct StatusResponse { type SharedState = Arc; -pub async fn run_server(config: &Config, host: &str, port: u16) -> Result<()> { +fn is_loopback_host(host: &str) -> bool { + if host.eq_ignore_ascii_case("localhost") { + return true; + } + + host.parse::().is_ok_and(|ip| ip.is_loopback()) +} + +pub async fn run_server( + config: &Config, + host: &str, + port: u16, + tls_cert_path: Option, + tls_key_path: Option, + allow_insecure_http: bool, +) -> Result<()> { + let tls_paths = match (tls_cert_path, tls_key_path) { + (Some(cert), Some(key)) => Some((cert, key)), + (None, None) => None, + (Some(_), None) => { + anyhow::bail!( + "TLS certificate provided without TLS key. Provide --tls-key or set VGREP_TLS_KEY." + ) + } + (None, Some(_)) => { + anyhow::bail!( + "TLS key provided without TLS certificate. Provide --tls-cert or set VGREP_TLS_CERT." + ) + } + }; + + if !is_loopback_host(host) && tls_paths.is_none() && !allow_insecure_http { + anyhow::bail!( + "Refusing to bind to non-loopback address '{host}' without TLS. Configure TLS with --tls-cert/--tls-key (VGREP_TLS_CERT/VGREP_TLS_KEY) or set VGREP_ALLOW_INSECURE_HTTP=true to override." + ); + } + let config = config.clone(); if !config.has_embedding_model() { @@ -120,10 +156,23 @@ pub async fn run_server(config: &Config, host: &str, port: u16) -> Result<()> { let addr: SocketAddr = format!("{}:{}", host, port).parse()?; - crate::ui::print_server_banner(host, port); + crate::ui::print_server_banner(host, port, tls_paths.is_some()); - let listener = tokio::net::TcpListener::bind(addr).await?; - axum::serve(listener, app).await?; + match tls_paths { + Some((cert, key)) => { + let tls_config = axum_server::tls_rustls::RustlsConfig::from_pem_file(cert, key) + .await + .context("Failed to load TLS certificate/key")?; + + axum_server::bind_rustls(addr, tls_config) + .serve(app.into_make_service()) + .await?; + } + None => { + let listener = tokio::net::TcpListener::bind(addr).await?; + axum::serve(listener, app).await?; + } + } Ok(()) } diff --git a/src/ui/console.rs b/src/ui/console.rs index b910218..7d99f3b 100644 --- a/src/ui/console.rs +++ b/src/ui/console.rs @@ -34,12 +34,15 @@ pub fn print_banner() { println!(); } -pub fn print_server_banner(host: &str, port: u16) { +pub fn print_server_banner(host: &str, port: u16, use_tls: bool) { print_banner(); + let scheme = if use_tls { "https" } else { "http" }; println!( " {}Server listening on {}", SERVER, - style(format!("http://{}:{}", host, port)).green().bold() + style(format!("{}://{}:{}", scheme, host, port)) + .green() + .bold() ); println!(); println!(" {}Endpoints:", GEAR); diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index efb71d9..fe15015 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -59,3 +59,30 @@ fn test_config_show() { .success() .stdout(predicate::str::contains("Chunk size")); } + +#[test] +fn test_serve_refuses_insecure_non_loopback_without_tls_or_override() { + let home = tempfile::tempdir().unwrap(); + + vgrep() + .env("HOME", home.path()) + .env("VGREP_HOST", "0.0.0.0") + .arg("serve") + .assert() + .failure() + .stderr(predicate::str::contains("Refusing to bind to non-loopback")); +} + +#[test] +fn test_serve_allows_insecure_override() { + let home = tempfile::tempdir().unwrap(); + + vgrep() + .env("HOME", home.path()) + .env("VGREP_HOST", "0.0.0.0") + .env("VGREP_ALLOW_INSECURE_HTTP", "true") + .arg("serve") + .assert() + .failure() + .stderr(predicate::str::contains("Embedding model not found")); +}