Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 100 additions & 119 deletions Cargo.lock

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ h3-quinn = { version = "0.0.10" }
http = "1.4.0"
backon = "1.6.0"
openraft = { version = "0.9.24", features = ["storage-v2", "serde"] }
tokio-postgres-rustls = "0.13.0"
tokio-postgres-rustls = "0.14.0"
rustls = { version = "0.23.40" }
rustls-platform-verifier = "0.7.0"
toml = "1.1.2"
Expand All @@ -56,10 +56,10 @@ serde_json = "1.0.149"
rmp-serde = "1.3.1"
anyhow = "1.0.102"
bytes = "1.11.1"
rcgen = { version = "0.14.7", features = ["pem"] }
rcgen = { version = "0.14.8", features = ["pem"] }
foldhash = "0.2.0"
portable-atomic = "1.13.1"
scc = "3.7.0"
scc = "3.7.1"
sdd = "4.8.6"
schnorrkel = "0.11.5"
base64-simd = "0.8"
Expand All @@ -68,7 +68,7 @@ bs58 = "0.5.1"
prometheus = "0.14.0"
blake3 = "1.8.5"
moka = { version = "0.12.15", features = ["sync"] }
mimalloc = { version = "0.1.50", default-features = false }
mimalloc = { version = "0.1.51", default-features = false }
regex = "1.12.3"
image = "0.25.10"
chrono = { version = "0.4", features = ["serde", "clock"] }
Expand Down
2 changes: 2 additions & 0 deletions dev-env/config/config-single.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ get_timeout_sec = 2
max_idle_timeout_sec = 4
keep_alive_interval_sec = 1
rate_limit_whitelist = []
# Trust X-Forwarded-For only from the immediate proxy ranges that can reach the gateway.
trusted_proxy_cidrs = []
worker_whitelist = []
distributed_rate_limiter_max_capacity = 4096
allowed_origins = ["discord", "unity", "blender", "gen.404.xyz", "api"]
Expand Down
2 changes: 2 additions & 0 deletions dev-env/config/config1.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ get_timeout_sec = 2
max_idle_timeout_sec = 4
keep_alive_interval_sec = 1
rate_limit_whitelist = []
# Trust X-Forwarded-For only from the immediate proxy ranges that can reach the gateway.
trusted_proxy_cidrs = []
worker_whitelist = []
distributed_rate_limiter_max_capacity = 4096
allowed_origins = ["discord", "unity", "blender", "gen.404.xyz", "api"]
Expand Down
2 changes: 2 additions & 0 deletions dev-env/config/config2.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ get_timeout_sec = 2
max_idle_timeout_sec = 4
keep_alive_interval_sec = 1
rate_limit_whitelist = []
# Trust X-Forwarded-For only from the immediate proxy ranges that can reach the gateway.
trusted_proxy_cidrs = []
worker_whitelist = []
distributed_rate_limiter_max_capacity = 4096
allowed_origins = ["discord", "unity", "blender", "gen.404.xyz", "api"]
Expand Down
2 changes: 2 additions & 0 deletions dev-env/config/config3.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ get_timeout_sec = 2
max_idle_timeout_sec = 4
keep_alive_interval_sec = 1
rate_limit_whitelist = []
# Trust X-Forwarded-For only from the immediate proxy ranges that can reach the gateway.
trusted_proxy_cidrs = []
worker_whitelist = []
distributed_rate_limiter_max_capacity = 4096
allowed_origins = ["discord", "unity", "blender", "gen.404.xyz", "api"]
Expand Down
160 changes: 159 additions & 1 deletion src/config.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use anyhow::{Result, anyhow};
use foldhash::HashSet;
use serde::{Deserialize, Deserializer, Serialize};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::str::FromStr;
use std::{fmt, path::Path, path::PathBuf};
use tracing::Level;
Expand Down Expand Up @@ -137,6 +138,10 @@ pub struct HTTPConfig {
pub basic_rate_limit: usize,
pub add_task_unauthorized_per_ip_daily_rate_limit: usize,
pub rate_limit_whitelist: HashSet<String>,
/// CIDR ranges or literal IPs whose X-Forwarded-For headers are trusted.
/// Keep empty unless the gateway is only reachable through those proxies.
#[serde(default)]
pub trusted_proxy_cidrs: Vec<String>,
#[serde(default = "default_distributed_rate_limiter_max_capacity")]
pub distributed_rate_limiter_max_capacity: usize,
pub worker_per_minute_rate_limit: usize,
Expand Down Expand Up @@ -452,10 +457,122 @@ pub fn validate_node_config(config: &NodeConfig) -> Result<()> {
}
tracing::warn!("TLS server verification is disabled; only use this in local development");
}
validate_trusted_proxy_cidrs(&config.http.trusted_proxy_cidrs)?;
validate_raft_dns_config(config)?;
Ok(())
}

/// A single trusted-proxy network, stored as a masked base address plus a
/// CIDR prefix length.
///
/// The constructors in this module clear the host bits of `network`, and
/// `contains` compares a masked candidate address against `network`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrustedProxyRange {
/// IPv4 network: `network` is the address as a big-endian `u32`,
/// `prefix` is the number of leading network bits.
V4 { network: u32, prefix: u8 },
/// IPv6 network: `network` is the address as a `u128`,
/// `prefix` is the number of leading network bits.
V6 { network: u128, prefix: u8 },
}

impl TrustedProxyRange {
    /// Reports whether `ip` lies inside this range.
    ///
    /// Addresses are compared within their own family. In addition, an
    /// IPv4-mapped IPv6 address (`::ffff:a.b.c.d`) is matched against IPv4
    /// ranges through its embedded IPv4 address: dual-stack sockets report
    /// IPv4 peers in that form, and without this an IPv4 entry would never
    /// match such a peer. A plain IPv4 address never matches an IPv6 range.
    pub fn contains(&self, ip: IpAddr) -> bool {
        match (*self, ip) {
            (Self::V4 { network, prefix }, IpAddr::V4(addr)) => {
                u32::from(addr) & ipv4_prefix_mask(prefix) == network
            }
            (Self::V6 { network, prefix }, IpAddr::V6(addr)) => {
                u128::from(addr) & ipv6_prefix_mask(prefix) == network
            }
            // Only `::ffff:0:0/96` addresses unwrap to IPv4; every other
            // IPv6 address is outside any IPv4 range.
            (Self::V4 { network, prefix }, IpAddr::V6(addr)) => addr
                .to_ipv4_mapped()
                .is_some_and(|v4| u32::from(v4) & ipv4_prefix_mask(prefix) == network),
            (Self::V6 { .. }, IpAddr::V4(_)) => false,
        }
    }

    /// Builds a single-host range (`/32` or `/128`) from a literal address.
    fn from_ip(ip: IpAddr) -> Self {
        match ip {
            IpAddr::V4(ip) => Self::from_ipv4_cidr(ip, 32),
            IpAddr::V6(ip) => Self::from_ipv6_cidr(ip, 128),
        }
    }

    /// Builds an IPv4 range, clearing any host bits set in `ip` so that the
    /// stored network is canonical for `prefix`.
    fn from_ipv4_cidr(ip: Ipv4Addr, prefix: u8) -> Self {
        let network = u32::from(ip) & ipv4_prefix_mask(prefix);
        Self::V4 { network, prefix }
    }

    /// Builds an IPv6 range, clearing any host bits set in `ip` so that the
    /// stored network is canonical for `prefix`.
    fn from_ipv6_cidr(ip: Ipv6Addr, prefix: u8) -> Self {
        let network = u128::from(ip) & ipv6_prefix_mask(prefix);
        Self::V6 { network, prefix }
    }
}

/// Returns the IPv4 netmask with the top `prefix` bits set.
///
/// `prefix == 0` yields `0` (match everything). A `prefix` above 32 is
/// clamped to a full `/32` mask rather than underflowing: an out-of-range
/// prefix must never produce a shorter mask, which would widen the range.
fn ipv4_prefix_mask(prefix: u8) -> u32 {
    let host_bits = u32::BITS.saturating_sub(u32::from(prefix));
    // Shifting by the full width (prefix 0) is not representable with `<<`;
    // `checked_shl` reports it as `None`, which is the all-zero mask.
    u32::MAX.checked_shl(host_bits).unwrap_or(0)
}

/// Returns the IPv6 netmask with the top `prefix` bits set.
///
/// `prefix == 0` yields `0` (match everything). A `prefix` above 128 is
/// clamped to a full `/128` mask rather than underflowing: an out-of-range
/// prefix must never produce a shorter mask, which would widen the range.
fn ipv6_prefix_mask(prefix: u8) -> u128 {
    let host_bits = u128::BITS.saturating_sub(u32::from(prefix));
    // Shifting by the full width (prefix 0) is not representable with `<<`;
    // `checked_shl` reports it as `None`, which is the all-zero mask.
    u128::MAX.checked_shl(host_bits).unwrap_or(0)
}

/// Parses one `http.trusted_proxy_cidrs` entry into a [`TrustedProxyRange`].
///
/// Accepted forms, after trimming surrounding whitespace:
/// - a literal IPv4/IPv6 address, treated as a single-host range;
/// - `address/prefix`, where `prefix` is at most 32 for IPv4 and at most 128
///   for IPv6. Host bits set in `address` are cleared.
///
/// # Errors
///
/// Returns an error naming the offending entry when it is empty, is not an
/// address or CIDR block, has a non-numeric prefix, or has a prefix longer
/// than the address family allows.
pub fn parse_trusted_proxy_cidr(entry: &str) -> Result<TrustedProxyRange> {
    // Shared "this is not an address or CIDR block" error.
    let malformed = || {
        anyhow!(
            "http.trusted_proxy_cidrs entry '{}' must be an IP address or CIDR block",
            entry
        )
    };

    let candidate = entry.trim();
    if candidate.is_empty() {
        return Err(anyhow!(
            "http.trusted_proxy_cidrs must not contain empty entries"
        ));
    }

    // A bare address is the most common form; try it before splitting.
    if let Ok(host) = candidate.parse::<IpAddr>() {
        return Ok(TrustedProxyRange::from_ip(host));
    }

    let Some((addr_part, prefix_part)) = candidate.split_once('/') else {
        return Err(malformed());
    };
    // `split_once` only consumes the first slash; a second one is malformed.
    if prefix_part.contains('/') {
        return Err(malformed());
    }

    let base = addr_part
        .trim()
        .parse::<IpAddr>()
        .map_err(|_| malformed())?;
    let bits = prefix_part.trim().parse::<u8>().map_err(|_| {
        anyhow!(
            "http.trusted_proxy_cidrs entry '{}' has an invalid CIDR prefix length",
            entry
        )
    })?;

    match base {
        IpAddr::V4(v4) => {
            if bits > 32 {
                return Err(anyhow!(
                    "http.trusted_proxy_cidrs entry '{}' has IPv4 prefix length greater than 32",
                    entry
                ));
            }
            Ok(TrustedProxyRange::from_ipv4_cidr(v4, bits))
        }
        IpAddr::V6(v6) => {
            if bits > 128 {
                return Err(anyhow!(
                    "http.trusted_proxy_cidrs entry '{}' has IPv6 prefix length greater than 128",
                    entry
                ));
            }
            Ok(TrustedProxyRange::from_ipv6_cidr(v6, bits))
        }
    }
}

/// Checks that every configured trusted-proxy entry parses, returning the
/// first parse error encountered. The parsed ranges are discarded.
fn validate_trusted_proxy_cidrs(entries: &[String]) -> Result<()> {
    entries
        .iter()
        .try_for_each(|raw| parse_trusted_proxy_cidr(raw).map(|_| ()))
}

pub fn validate_multi_node_raft_dns_config(config: &NodeConfig) -> Result<()> {
if config.raft.dns_name.trim().is_empty() {
return Err(anyhow!(
Expand Down Expand Up @@ -555,7 +672,10 @@ pub fn resolve_config_path(path: Option<&String>) -> Result<PathBuf> {

#[cfg(test)]
mod tests {
use super::{NodeConfig, validate_multi_node_raft_dns_config, validate_node_config};
use super::{
NodeConfig, parse_trusted_proxy_cidr, validate_multi_node_raft_dns_config,
validate_node_config,
};

fn read_config_single() -> String {
let path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
Expand All @@ -582,6 +702,44 @@ mod tests {
assert_eq!(config.http.generic_key_concurrent_limit, 3);
}

#[test]
fn http_config_defaults_trusted_proxy_cidrs_when_missing() {
    // Strip the key line by line instead of with an exact-string replace:
    // a fixture with different line endings or spacing would otherwise keep
    // the key, and the test would pass without exercising the serde default.
    let config_text = read_config_single()
        .lines()
        .filter(|line| !line.trim_start().starts_with("trusted_proxy_cidrs"))
        .collect::<Vec<_>>()
        .join("\n");
    let config: NodeConfig =
        toml::from_str(&config_text).expect("parse config without trusted proxies");
    assert!(config.http.trusted_proxy_cidrs.is_empty());
    validate_node_config(&config).expect("empty trusted proxy cidrs are valid");
}

#[test]
fn http_config_rejects_invalid_trusted_proxy_cidrs() {
    // Swap the empty list in the fixture for one malformed entry.
    let fixture = read_config_single();
    let patched = fixture.replace(
        "trusted_proxy_cidrs = []",
        "trusted_proxy_cidrs = [\"not-a-cidr\"]",
    );
    let parsed = toml::from_str::<NodeConfig>(&patched).expect("parse config");

    // Validation must fail and the message must name the offending setting.
    let message = validate_node_config(&parsed)
        .expect_err("reject invalid proxy cidr")
        .to_string();
    assert!(message.contains("trusted_proxy_cidrs"));
}

#[test]
fn trusted_proxy_cidr_parser_matches_ipv4_and_ipv6_ranges() {
    let addr = |text: &str| text.parse::<std::net::IpAddr>().unwrap();

    // One address inside the /16 and one just past its upper bound.
    let v4_block = parse_trusted_proxy_cidr("35.191.0.0/16").expect("ipv4 cidr");
    for (candidate, expected) in [("35.191.22.10", true), ("35.192.0.1", false)] {
        assert_eq!(v4_block.contains(addr(candidate)), expected);
    }

    // Same shape for IPv6: inside the /64, then the neighbouring /64.
    let v6_block = parse_trusted_proxy_cidr("2600:2d00:1:b029::/64").expect("ipv6 cidr");
    for (candidate, expected) in [("2600:2d00:1:b029::1", true), ("2600:2d00:1:b02a::1", false)] {
        assert_eq!(v6_block.contains(addr(candidate)), expected);
    }
}

#[test]
fn trusted_proxy_cidr_parser_accepts_literal_ip_as_single_host() {
    // A bare address must behave like a single-host range: it matches itself
    // and rejects its immediate neighbour.
    let single_host = parse_trusted_proxy_cidr("127.0.0.1").expect("literal ip");
    let same: std::net::IpAddr = "127.0.0.1".parse().unwrap();
    let neighbour: std::net::IpAddr = "127.0.0.2".parse().unwrap();
    assert!(single_host.contains(same));
    assert!(!single_host.contains(neighbour));
}

#[test]
fn network_config_defaults_cluster_peer_egress_ips_when_missing() {
let config_text = read_config_single();
Expand Down
16 changes: 14 additions & 2 deletions src/config_runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ use tokio::sync::Notify;
use tracing::{info, warn};

use crate::config::{
HTTPConfig, ImageConfig, ModelParamsConfig, NodeConfig, PromptConfig, read_config_from_path,
validate_node_config,
HTTPConfig, ImageConfig, ModelParamsConfig, NodeConfig, PromptConfig, TrustedProxyRange,
parse_trusted_proxy_cidr, read_config_from_path, validate_node_config,
};
use crate::http3::rate_limits::RateLimiters;
use crate::http3::upload_limiter::ImageUploadLimiter;
Expand All @@ -29,6 +29,7 @@ pub struct RuntimeConfigSnapshot {
pub prompt_regex: Regex,
pub rate_limit_whitelist: RateLimitWhitelist,
pub cluster_ips: HashSet<IpAddr>,
pub trusted_proxy_cidrs: Vec<TrustedProxyRange>,
pub rate_limiters: RateLimiters,
pub rate_limit_service: RateLimitService,
pub image_upload_limiter: ImageUploadLimiter,
Expand Down Expand Up @@ -72,6 +73,10 @@ impl RuntimeConfigView {
&self.snapshot.cluster_ips
}

/// Borrows the trusted proxy ranges stored in this view's snapshot.
pub fn trusted_proxy_cidrs(&self) -> &[TrustedProxyRange] {
    self.snapshot.trusted_proxy_cidrs.as_slice()
}

/// Borrows the rate-limit service stored in this view's snapshot.
pub fn rate_limits(&self) -> &RateLimitService {
    let snapshot = &self.snapshot;
    &snapshot.rate_limit_service
}
Expand Down Expand Up @@ -347,6 +352,12 @@ async fn build_runtime_snapshot(config: NodeConfig) -> Result<RuntimeConfigSnaps
let whitelist_ips = resolve_rate_limit_whitelist(&config.http.rate_limit_whitelist).await;

let cluster_ips = resolve_egress_ips(&config.network.cluster_peer_egress_ips).await;
let trusted_proxy_cidrs = config
.http
.trusted_proxy_cidrs
.iter()
.map(|entry| parse_trusted_proxy_cidr(entry))
.collect::<Result<Vec<_>>>()?;

let rate_limit_service = RateLimitService::new(&config.http);
let rate_limiters = RateLimiters::new(&config.http);
Expand All @@ -359,6 +370,7 @@ async fn build_runtime_snapshot(config: NodeConfig) -> Result<RuntimeConfigSnaps
ips: Arc::new(whitelist_ips),
},
cluster_ips,
trusted_proxy_cidrs,
rate_limiters,
rate_limit_service,
image_upload_limiter,
Expand Down
Loading
Loading