diff --git a/arxignis-main/services/threat-ai-rs/Cargo.toml b/arxignis-main/services/threat-ai-rs/Cargo.toml new file mode 100644 index 0000000..4d52bca --- /dev/null +++ b/arxignis-main/services/threat-ai-rs/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "threat-ai-rs" +version = "0.1.0" +edition = "2021" + +[dependencies] +anyhow = "1" +axum = { version = "0.8", features = ["json"] } +clap = { version = "4.5", features = ["derive"] } +md5 = "0.7" +r2d2 = "0.8" +rand = "0.9" +reqwest = { version = "0.12", features = ["json", "rustls-tls"] } +rusqlite = { version = "0.32", features = ["bundled", "load_extension"] } +serde = { version = "1", features = ["derive"] } +serde_json = "1" +serde_yaml = "0.9" +thiserror = "2" +tokio = { version = "1", features = ["macros", "rt-multi-thread", "time", "fs"] } +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] } +tokenizers = "0.21" +safetensors = "0.4" +memmap2 = "0.9" +half = "2" +bytemuck = "1" diff --git a/arxignis-main/services/threat-ai-rs/data/ja4_db.db b/arxignis-main/services/threat-ai-rs/data/ja4_db.db new file mode 100644 index 0000000..b88de74 Binary files /dev/null and b/arxignis-main/services/threat-ai-rs/data/ja4_db.db differ diff --git a/arxignis-main/services/threat-ai-rs/src/bootstrap.rs b/arxignis-main/services/threat-ai-rs/src/bootstrap.rs new file mode 100644 index 0000000..f381173 --- /dev/null +++ b/arxignis-main/services/threat-ai-rs/src/bootstrap.rs @@ -0,0 +1,92 @@ +use crate::config::Config; +use crate::database::Database; +use crate::models::TrafficRecordInput; +use crate::scoring::Scorer; +use crate::vectorization::Vectorizer; +use anyhow::{anyhow, Context, Result}; +use rand::seq::SliceRandom; +use serde_json::Value; +use std::fs; +use std::path::PathBuf; +use tracing::{info, warn}; + +pub async fn bootstrap_database( + config: Config, + verified_path: PathBuf, + dataset_path: Option, +) -> Result<()> { + let db = Database::new( + &config.database_url, + 
config.sqlite_vector_path.as_ref().map(PathBuf::from), + config.database_pool_max_size, + )?; + + let vectorizer = Vectorizer::new(config.vector.clone())?; + let scorer = Scorer::new(config.score.clone()); + + let verified_values = read_json_array(&verified_path) + .with_context(|| format!("reading verified data from {}", verified_path.display()))?; + + let mut verified_records = Vec::new(); + for value in verified_values { + let mut record = TrafficRecordInput::from_value(value)?; + if !record.has_fingerprints() { + continue; + } + vectorizer.vectorize(&mut record).await?; + record.threat_score = record + .raw + .get("threat_score") + .and_then(|v| v.as_i64()) + .unwrap_or(50) as i32; + verified_records.push(record); + } + + info!("Inserting {} verified records", verified_records.len()); + db.insert_batch(&verified_records)?; + + if let Some(dataset) = dataset_path { + let mut dataset_values = read_json_array(&dataset) + .with_context(|| format!("reading dataset from {}", dataset.display()))?; + let mut rng = rand::rng(); + dataset_values.shuffle(&mut rng); + + for value in dataset_values { + let mut record = TrafficRecordInput::from_value(value)?; + if record.verified || !record.has_fingerprints() { + continue; + } + + vectorizer.vectorize(&mut record).await?; + if let Some(vec) = record.ja4combined_vec.as_ref() { + let neighbors = db.search_similar(vec, &config.search)?; + record.threat_score = scorer.calculate_threat_score(&record, &neighbors); + } else { + record.threat_score = 50; + } + + if let Some(existing) = db.get_record_by_raw(&record.raw)? 
{
                // Already known: bump the observation counter and refresh the score.
                let mut updated = existing;
                updated.observation_count = updated.observation_count.saturating_add(1);
                updated.threat_score = record.threat_score;
                db.update_record(&updated)?;
            } else if let Err(err) = db.insert_record(&record) {
                // A single bad record must not abort the whole bootstrap run.
                warn!("Insert failed, skipping record: {err:?}");
            }
        }
    }

    Ok(())
}

/// Read a file expected to contain a top-level JSON array and return its
/// elements. Errors if the file cannot be read, is not valid JSON, or the
/// top-level value is not an array.
///
/// Takes `&Path` (callers' `&PathBuf` arguments deref-coerce) instead of
/// `&PathBuf`; fully qualified to avoid touching the module's import list.
fn read_json_array(path: &std::path::Path) -> Result<Vec<Value>> {
    let contents = fs::read_to_string(path)?;
    let value: Value = serde_json::from_str(&contents)?;
    match value {
        Value::Array(arr) => Ok(arr),
        _ => Err(anyhow!("expected JSON array in {}", path.display())),
    }
}
diff --git a/arxignis-main/services/threat-ai-rs/src/config.rs b/arxignis-main/services/threat-ai-rs/src/config.rs
new file mode 100644
index 0000000..1f94613
--- /dev/null
+++ b/arxignis-main/services/threat-ai-rs/src/config.rs
@@ -0,0 +1,78 @@
use serde::Deserialize;
use std::{collections::HashMap, env, fs, path::Path};

const DEFAULT_BATCH_SIZE: usize = 100;
const DEFAULT_FLUSH_INTERVAL_SECS: u64 = 600;
const DEFAULT_POOL_MIN: usize = 1;
const DEFAULT_POOL_MAX: usize = 10;

/// Top-level service configuration, deserialized from YAML.
/// `DATABASE_URL` in the environment overrides the file value (see `load`).
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
    pub database_url: String,
    #[serde(default = "default_batch_size")]
    pub batch_size: usize,
    #[serde(default = "default_flush_interval")]
    pub flush_interval: u64,
    #[serde(default = "default_pool_min")]
    pub database_pool_min_size: usize,
    #[serde(default = "default_pool_max")]
    pub database_pool_max_size: usize,
    // Path to the sqlite-vector loadable extension, if any.
    #[serde(default)]
    pub sqlite_vector_path: Option<String>,
    pub score: ScoreConfig,
    pub search: SearchConfig,
    pub vector: VectorConfig,
}

/// Weights used by the threat scorer.
#[derive(Debug, Clone, Deserialize)]
pub struct ScoreConfig {
    // Weight applied to neighbors flagged as verified.
    #[serde(rename = "verified_weight")]
    pub verified_weight: f32,
    // NOTE(review): value type assumed f32 (the generic arguments were lost in
    // this patch's encoding) — confirm against the YAML schema.
    #[serde(default)]
    pub field_weights: HashMap<String, f32>,
}

/// Parameters for the nearest-neighbour similarity search.
#[derive(Debug, Clone, Deserialize)]
pub struct SearchConfig {
    pub limit: usize,
    pub similarity_threshold: f32,
    pub observation_threshold: usize,
}

#[derive(Debug, Clone, Deserialize)]
pub 
struct VectorConfig { + pub model: Option, + #[serde(default)] + pub model_cache_dir: Option, + #[serde(default)] + pub vector_weights: HashMap, +} + +impl Config { + pub fn load(path: impl AsRef) -> anyhow::Result { + let contents = fs::read_to_string(path.as_ref())?; + let mut parsed: Config = serde_yaml::from_str(&contents)?; + + if let Ok(db_url) = env::var("DATABASE_URL") { + parsed.database_url = db_url; + } + + Ok(parsed) + } +} + +const fn default_batch_size() -> usize { + DEFAULT_BATCH_SIZE +} + +const fn default_flush_interval() -> u64 { + DEFAULT_FLUSH_INTERVAL_SECS +} + +const fn default_pool_min() -> usize { + DEFAULT_POOL_MIN +} + +const fn default_pool_max() -> usize { + DEFAULT_POOL_MAX +} diff --git a/arxignis-main/services/threat-ai-rs/src/database.rs b/arxignis-main/services/threat-ai-rs/src/database.rs new file mode 100644 index 0000000..43a4e76 --- /dev/null +++ b/arxignis-main/services/threat-ai-rs/src/database.rs @@ -0,0 +1,482 @@ +use crate::config::SearchConfig; +use crate::models::{SimilarTrafficRecord, TrafficRecord, TrafficRecordInput, VECTOR_DIM}; +use anyhow::{Context, Result}; +use r2d2::{ManageConnection, Pool}; +use rusqlite::{params, Connection, OptionalExtension, Row}; +use serde_json::Value; +use std::path::{Path, PathBuf}; + +const VECTOR_COLUMNS: &[&str] = &[ + "ja4_vec", + "ja4s_vec", + "ja4h_vec", + "ja4x_vec", + "ja4t_vec", + "ja4ts_vec", + "ja4tscan_vec", + "ja4set_vec", + "ua_vec", + "ja4combined_vec", +]; + +#[derive(Clone)] +struct SqliteManager { + path: String, + extension_path: Option, +} + +impl SqliteManager { + fn new(path: String, extension_path: Option) -> Self { + Self { path, extension_path } + } +} + +impl ManageConnection for SqliteManager { + type Connection = Connection; + type Error = rusqlite::Error; + + fn connect(&self) -> std::result::Result { + let conn = Connection::open(&self.path)?; + if let Some(ext) = &self.extension_path { + unsafe { + conn.load_extension_enable()?; + conn.load_extension(ext, 
None)?; + conn.load_extension_disable()?; + } + } + Ok(conn) + } + + fn is_valid(&self, conn: &mut Connection) -> std::result::Result<(), Self::Error> { + conn.query_row("SELECT 1", [], |_| Ok(()))?; + Ok(()) + } + + fn has_broken(&self, _: &mut Connection) -> bool { + false + } +} + +#[derive(Clone)] +pub struct Database { + pool: Pool, +} + +impl Database { + pub fn new( + db_path: impl AsRef, + extension_path: Option, + pool_size: usize, + ) -> Result { + let manager = SqliteManager::new( + db_path.as_ref().to_string_lossy().to_string(), + extension_path, + ); + + let pool = Pool::builder() + .max_size(pool_size as u32) + .build(manager) + .context("building SQLite pool")?; + + let db = Self { pool }; + db.initialize()?; + Ok(db) + } + + fn initialize(&self) -> Result<()> { + let mut conn = self.pool.get().context("getting connection for init")?; + self.create_tables(&mut conn)?; + self.init_vector_columns(&mut conn)?; + // Best-effort quantize to ensure the quantization table exists even on empty DBs. 
+ if let Err(err) = quantize_combined(&mut conn) { + tracing::warn!("Initial quantize skipped: {err}"); + } + Ok(()) + } + + fn create_tables(&self, conn: &mut Connection) -> Result<()> { + conn.execute_batch( + r#" + CREATE TABLE IF NOT EXISTS traffic ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + threat_score INTEGER NOT NULL, + observation_count INTEGER NOT NULL, + first_seen TEXT DEFAULT (datetime('now')), + last_seen TEXT DEFAULT (datetime('now')), + raw TEXT NOT NULL, + application TEXT, + library TEXT, + device TEXT, + os TEXT, + user_agent_string TEXT, + certificate_authority TEXT, + verified BOOLEAN, + ja4 TEXT, + ja4s TEXT, + ja4h TEXT, + ja4x TEXT, + ja4t TEXT, + ja4ts TEXT, + ja4tscan TEXT, + ja4_fingerprint_string TEXT, + ja4_vec BLOB, + ja4s_vec BLOB, + ja4h_vec BLOB, + ja4x_vec BLOB, + ja4t_vec BLOB, + ja4ts_vec BLOB, + ja4tscan_vec BLOB, + ja4set_vec BLOB, + ua_vec BLOB, + ja4combined_vec BLOB + ); + "#, + )?; + Ok(()) + } + + fn init_vector_columns(&self, conn: &mut Connection) -> Result<()> { + for name in VECTOR_COLUMNS { + conn.query_row( + "SELECT vector_init('traffic', ?, 'type=FLOAT32,dimension=128,distance=cosine');", + params![name], + |_| Ok(()), + ) + .with_context(|| format!("initializing vector column {name}"))?; + } + + let has_vectors: Option = conn + .query_row( + "SELECT 1 FROM traffic WHERE ja4combined_vec IS NOT NULL LIMIT 1;", + [], + |row| row.get(0), + ) + .optional()?; + + if has_vectors.is_some() { + quantize_combined(conn).context("quantizing ja4combined_vec")?; + } + Ok(()) + } + + pub fn insert_record(&self, record: &TrafficRecordInput) -> Result { + let conn = self.pool.get().context("acquiring connection")?; + let mut stmt = conn.prepare_cached( + r#" + INSERT INTO traffic ( + threat_score, observation_count, raw, + application, library, device, os, + user_agent_string, certificate_authority, + verified, + ja4, ja4s, ja4h, ja4x, ja4t, ja4ts, ja4tscan, + ja4_fingerprint_string, ja4_vec, ja4s_vec, + ja4h_vec, ja4x_vec, 
ja4t_vec, ja4ts_vec, + ja4tscan_vec, ja4set_vec, ua_vec, ja4combined_vec + ) + VALUES ( + :threat_score, :observation_count, :raw, + :application, :library, :device, :os, + :user_agent_string, :certificate_authority, + :verified, + :ja4, :ja4s, :ja4h, :ja4x, :ja4t, :ja4ts, :ja4tscan, + :ja4_fingerprint_string, + CASE WHEN :ja4_vec IS NULL THEN NULL ELSE vector_as_f32(:ja4_vec) END, + CASE WHEN :ja4s_vec IS NULL THEN NULL ELSE vector_as_f32(:ja4s_vec) END, + CASE WHEN :ja4h_vec IS NULL THEN NULL ELSE vector_as_f32(:ja4h_vec) END, + CASE WHEN :ja4x_vec IS NULL THEN NULL ELSE vector_as_f32(:ja4x_vec) END, + CASE WHEN :ja4t_vec IS NULL THEN NULL ELSE vector_as_f32(:ja4t_vec) END, + CASE WHEN :ja4ts_vec IS NULL THEN NULL ELSE vector_as_f32(:ja4ts_vec) END, + CASE WHEN :ja4tscan_vec IS NULL THEN NULL ELSE vector_as_f32(:ja4tscan_vec) END, + CASE WHEN :ja4set_vec IS NULL THEN NULL ELSE vector_as_f32(:ja4set_vec) END, + CASE WHEN :ua_vec IS NULL THEN NULL ELSE vector_as_f32(:ua_vec) END, + CASE WHEN :ja4combined_vec IS NULL THEN NULL ELSE vector_as_f32(:ja4combined_vec) END + ) + RETURNING id; + "#, + )?; + + let id: i64 = stmt + .query_row( + rusqlite::named_params! 
{ + ":threat_score": record.threat_score, + ":observation_count": record.observation_count, + ":raw": serde_json::to_string(&record.raw).unwrap_or_else(|_| "{}".to_string()), + ":application": record.application, + ":library": record.library, + ":device": record.device, + ":os": record.os, + ":user_agent_string": record.user_agent_string, + ":certificate_authority": record.certificate_authority, + ":verified": record.verified, + ":ja4": record.ja4, + ":ja4s": record.ja4s, + ":ja4h": record.ja4h, + ":ja4x": record.ja4x, + ":ja4t": record.ja4t, + ":ja4ts": record.ja4ts, + ":ja4tscan": record.ja4tscan, + ":ja4_fingerprint_string": record.ja4_fingerprint_string, + ":ja4_vec": record.ja4_vec.as_ref().map(to_json), + ":ja4s_vec": record.ja4s_vec.as_ref().map(to_json), + ":ja4h_vec": record.ja4h_vec.as_ref().map(to_json), + ":ja4x_vec": record.ja4x_vec.as_ref().map(to_json), + ":ja4t_vec": record.ja4t_vec.as_ref().map(to_json), + ":ja4ts_vec": record.ja4ts_vec.as_ref().map(to_json), + ":ja4tscan_vec": record.ja4tscan_vec.as_ref().map(to_json), + ":ja4set_vec": record.ja4set_vec.as_ref().map(to_json), + ":ua_vec": record.ua_vec.as_ref().map(to_json), + ":ja4combined_vec": record.ja4combined_vec.as_ref().map(to_json), + }, + |row| row.get(0), + ) + .context("inserting record")?; + + conn.query_row( + "SELECT vector_quantize('traffic', 'ja4combined_vec');", + [], + |_| Ok(()), + ) + .context("quantizing after insert")?; + + let _ = conn.query_row( + "SELECT vector_quantize_preload('traffic', 'ja4combined_vec');", + [], + |_| Ok(()), + ); + Ok(id) + } + + pub fn insert_batch(&self, records: &[TrafficRecordInput]) -> Result> { + let mut ids = Vec::with_capacity(records.len()); + let mut conn = self.pool.get().context("acquiring connection")?; + let tx = conn.transaction()?; + { + let mut stmt = tx.prepare( + r#" + INSERT INTO traffic ( + threat_score, observation_count, raw, + application, library, device, os, + user_agent_string, certificate_authority, + verified, + ja4, 
ja4s, ja4h, ja4x, ja4t, ja4ts, ja4tscan, + ja4_fingerprint_string, ja4_vec, ja4s_vec, + ja4h_vec, ja4x_vec, ja4t_vec, ja4ts_vec, + ja4tscan_vec, ja4set_vec, ua_vec, ja4combined_vec + ) + VALUES ( + :threat_score, :observation_count, :raw, + :application, :library, :device, :os, + :user_agent_string, :certificate_authority, + :verified, + :ja4, :ja4s, :ja4h, :ja4x, :ja4t, :ja4ts, :ja4tscan, + :ja4_fingerprint_string, + CASE WHEN :ja4_vec IS NULL THEN NULL ELSE vector_as_f32(:ja4_vec) END, + CASE WHEN :ja4s_vec IS NULL THEN NULL ELSE vector_as_f32(:ja4s_vec) END, + CASE WHEN :ja4h_vec IS NULL THEN NULL ELSE vector_as_f32(:ja4h_vec) END, + CASE WHEN :ja4x_vec IS NULL THEN NULL ELSE vector_as_f32(:ja4x_vec) END, + CASE WHEN :ja4t_vec IS NULL THEN NULL ELSE vector_as_f32(:ja4t_vec) END, + CASE WHEN :ja4ts_vec IS NULL THEN NULL ELSE vector_as_f32(:ja4ts_vec) END, + CASE WHEN :ja4tscan_vec IS NULL THEN NULL ELSE vector_as_f32(:ja4tscan_vec) END, + CASE WHEN :ja4set_vec IS NULL THEN NULL ELSE vector_as_f32(:ja4set_vec) END, + CASE WHEN :ua_vec IS NULL THEN NULL ELSE vector_as_f32(:ua_vec) END, + CASE WHEN :ja4combined_vec IS NULL THEN NULL ELSE vector_as_f32(:ja4combined_vec) END + ) + RETURNING id; + "#, + )?; + + for record in records { + let id: i64 = stmt.query_row( + rusqlite::named_params! 
{ + ":threat_score": record.threat_score, + ":observation_count": record.observation_count, + ":raw": serde_json::to_string(&record.raw).unwrap_or_else(|_| "{}".to_string()), + ":application": record.application, + ":library": record.library, + ":device": record.device, + ":os": record.os, + ":user_agent_string": record.user_agent_string, + ":certificate_authority": record.certificate_authority, + ":verified": record.verified, + ":ja4": record.ja4, + ":ja4s": record.ja4s, + ":ja4h": record.ja4h, + ":ja4x": record.ja4x, + ":ja4t": record.ja4t, + ":ja4ts": record.ja4ts, + ":ja4tscan": record.ja4tscan, + ":ja4_fingerprint_string": record.ja4_fingerprint_string, + ":ja4_vec": record.ja4_vec.as_ref().map(to_json), + ":ja4s_vec": record.ja4s_vec.as_ref().map(to_json), + ":ja4h_vec": record.ja4h_vec.as_ref().map(to_json), + ":ja4x_vec": record.ja4x_vec.as_ref().map(to_json), + ":ja4t_vec": record.ja4t_vec.as_ref().map(to_json), + ":ja4ts_vec": record.ja4ts_vec.as_ref().map(to_json), + ":ja4tscan_vec": record.ja4tscan_vec.as_ref().map(to_json), + ":ja4set_vec": record.ja4set_vec.as_ref().map(to_json), + ":ua_vec": record.ua_vec.as_ref().map(to_json), + ":ja4combined_vec": record.ja4combined_vec.as_ref().map(to_json), + }, + |row| row.get(0), + )?; + ids.push(id); + } + } + + tx.query_row( + "SELECT vector_quantize('traffic', 'ja4combined_vec');", + [], + |_| Ok(()), + )?; + let _ = tx.query_row( + "SELECT vector_quantize_preload('traffic', 'ja4combined_vec');", + [], + |_| Ok(()), + ); + tx.commit()?; + Ok(ids) + } + + pub fn get_record_by_raw(&self, raw: &Value) -> Result> { + let conn = self.pool.get().context("acquiring connection")?; + let raw_json = serde_json::to_string(raw).unwrap_or_else(|_| "{}".to_string()); + let mut stmt = conn.prepare( + "SELECT * FROM traffic WHERE json(raw) = json(?) 
LIMIT 1;", + )?; + + let result = stmt.query_row([raw_json], |row| to_traffic_record(row)).optional()?; + Ok(result) + } + + pub fn update_record(&self, record: &TrafficRecord) -> Result<()> { + let conn = self.pool.get().context("acquiring connection")?; + let _ = conn.execute( + r#" + UPDATE traffic + SET threat_score = ?1, + observation_count = ?2, + last_seen = datetime('now') + WHERE id = ?3; + "#, + params![record.threat_score, record.observation_count, record.record_id], + )?; + Ok(()) + } + + pub fn search_similar( + &self, + query_vec: &[f32], + search: &SearchConfig, + ) -> Result> { + if query_vec.is_empty() { + return Ok(Vec::new()); + } + + let mut conn = self.pool.get().context("acquiring connection")?; + + // Ensure quantization table exists before scanning; ignore failures silently here. + let _ = quantize_combined(&mut conn); + + let query_json = serde_json::to_string(query_vec)?; + + let mut stmt = conn.prepare( + r#" + WITH candidates AS ( + SELECT + t.*, + q.distance, + 1.0 - q.distance AS similarity + FROM vector_quantize_scan('traffic', 'ja4combined_vec', vector_as_f32(?1), 200) AS q + JOIN traffic AS t ON t.rowid = q.rowid + WHERE t.ja4combined_vec IS NOT NULL + ) + SELECT * + FROM candidates + WHERE similarity >= ?2 + AND (verified = 1 OR observation_count >= ?3) + ORDER BY distance ASC, verified DESC, observation_count DESC + LIMIT ?4; + "#, + )?; + + let rows = stmt + .query_map( + params![ + query_json, + search.similarity_threshold, + search.observation_threshold as i64, + search.limit as i64 + ], + |row| { + let similarity: f32 = row.get("similarity")?; + let record = to_traffic_record(row)?; + Ok(SimilarTrafficRecord { record, similarity }) + }, + )? 
.collect::<Result<Vec<_>, _>>()?;

        Ok(rows)
    }
}

/// Rebuild the quantization index for the combined JA4 vector column, then
/// best-effort preload it (preload failure is logged, not fatal).
fn quantize_combined(conn: &mut Connection) -> Result<()> {
    conn.query_row(
        "SELECT vector_quantize('traffic', 'ja4combined_vec');",
        [],
        |_| Ok(()),
    )
    .context("vector_quantize failed")?;

    if let Err(err) = conn.query_row(
        "SELECT vector_quantize_preload('traffic', 'ja4combined_vec');",
        [],
        |_| Ok(()),
    ) {
        tracing::warn!("vector_quantize_preload skipped: {err}");
    }
    Ok(())
}

/// Serialize a float vector to the JSON-array text form consumed by
/// `vector_as_f32` in the INSERT statements; `"[]"` on serialization failure.
fn to_json(vec: &Vec<f32>) -> String {
    serde_json::to_string(vec).unwrap_or_else(|_| "[]".to_string())
}

/// Map a `SELECT * FROM traffic` row into a `TrafficRecord`.
fn to_traffic_record(row: &Row<'_>) -> rusqlite::Result<TrafficRecord> {
    let ua_vec_blob: Option<Vec<u8>> = row.get("ua_vec")?;
    let combined_blob: Option<Vec<u8>> = row.get("ja4combined_vec")?;
    let raw_text: Option<String> = row.get("raw")?;

    Ok(TrafficRecord {
        record_id: row.get("id")?,
        threat_score: row.get("threat_score")?,
        observation_count: row.get::<_, i64>("observation_count")? as u32,
        // Unparseable raw JSON degrades to None rather than failing the row.
        raw: raw_text.and_then(|txt| serde_json::from_str(&txt).ok()),
        application: row.get("application")?,
        library: row.get("library")?,
        device: row.get("device")?,
        os: row.get("os")?,
        user_agent_string: row.get("user_agent_string")?,
        certificate_authority: row.get("certificate_authority")?,
        // The column is declared nullable; rows written outside this service
        // may hold NULL, which must read as "not verified" instead of erroring
        // the whole row (a plain i64 get fails on NULL).
        verified: row.get::<_, Option<i64>>("verified")?.unwrap_or(0) == 1,
        ja4: row.get("ja4")?,
        ja4s: row.get("ja4s")?,
        ja4h: row.get("ja4h")?,
        ja4x: row.get("ja4x")?,
        ja4t: row.get("ja4t")?,
        ja4ts: row.get("ja4ts")?,
        ja4tscan: row.get("ja4tscan")?,
        ja4_fingerprint_string: row.get("ja4_fingerprint_string")?,
        ua_vec: ua_vec_blob.as_deref().map(blob_to_f32_vec),
        ja4combined_vec: combined_blob.as_deref().map(blob_to_f32_vec),
    })
}

/// Decode a little-endian f32 BLOB; any trailing partial chunk (< 4 bytes)
/// is ignored.
fn blob_to_f32_vec(blob: &[u8]) -> Vec<f32> {
    let mut out = Vec::with_capacity(VECTOR_DIM);
    for chunk in blob.chunks_exact(4) {
        out.push(f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]));
    }
    out
}
diff --git a/arxignis-main/services/threat-ai-rs/src/main.rs b/arxignis-main/services/threat-ai-rs/src/main.rs
new file mode 100644
index 0000000..677de4e
--- /dev/null
+++ b/arxignis-main/services/threat-ai-rs/src/main.rs
@@ -0,0 +1,127 @@
mod bootstrap;
mod config;
mod database;
mod models;
mod scoring;
mod service;
mod ua_embedding;
mod vectorization;

use crate::bootstrap::bootstrap_database;
use crate::config::Config;
use crate::service::ThreatService;
use anyhow::Context;
use axum::{
    extract::State,
    http::StatusCode,
    response::IntoResponse,
    routing::{get, post},
    Json, Router,
};
use clap::{Parser, Subcommand};
use serde::Serialize;
use serde_json::Value;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::net::TcpListener;
use tracing::{error, info};

#[derive(Parser, Debug)]
#[command(author, version, about)]
struct Cli {
    /// Path to YAML config
    #[arg(short, long, default_value = "config.yaml", global = true)]
    config: PathBuf,

    #[command(subcommand)]
    command: Option<Command>,
}

#[derive(Subcommand, Debug)]
enum Command {
    /// Start the HTTP scoring service
    Serve {
        /// Address to bind, e.g. 
0.0.0.0:8000 + #[arg(long, default_value = "0.0.0.0:8000")] + listen: String, + }, + /// Bootstrap the SQLite database from JSON dumps + Bootstrap { + /// Path to verified_db.json + #[arg(long)] + verified: PathBuf, + /// Optional JA4+ dataset JSON + #[arg(long)] + data: Option, + }, +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + init_tracing(); + + let cli = Cli::parse(); + let config = Config::load(&cli.config).context("loading config")?; + + match cli.command.unwrap_or(Command::Serve { + listen: "0.0.0.0:8000".to_string(), + }) { + Command::Serve { listen } => serve(config, listen).await?, + Command::Bootstrap { verified, data } => { + bootstrap_database(config, verified, data).await? + } + } + + Ok(()) +} + +async fn serve(config: Config, listen: String) -> anyhow::Result<()> { + let service = Arc::new(ThreatService::new(config)?); + service.warmup().await?; + let state = AppState { service }; + + let app = Router::new() + .route("/", get(root)) + .route("/score", post(score)) + .with_state(state); + + let addr: SocketAddr = listen.parse().context("parsing listen address")?; + info!("Starting Threat AI Rust service on {addr}"); + let listener = TcpListener::bind(addr).await?; + axum::serve(listener, app.into_make_service()).await?; + Ok(()) +} + +#[derive(Clone)] +struct AppState { + service: Arc, +} + +async fn root() -> impl IntoResponse { + Json(serde_json::json!({ "message": "Threat AI Rust", "status": "running" })) +} + +#[derive(Serialize)] +struct ScoreResponse { + threat_score: i32, +} + +async fn score( + State(state): State, + Json(payload): Json, +) -> Result, impl IntoResponse> { + match state.service.process_and_score(payload).await { + Ok(score) => Ok(Json(ScoreResponse { threat_score: score })), + Err(err) => { + error!("failed to score record: {err:?}"); + Err((StatusCode::INTERNAL_SERVER_ERROR, "scoring failed")) + } + } +} + +fn init_tracing() { + let env = tracing_subscriber::EnvFilter::try_from_default_env() + 
.unwrap_or_else(|_| "info".into()); + tracing_subscriber::fmt().with_env_filter(env).init(); +} diff --git a/arxignis-main/services/threat-ai-rs/src/models.rs b/arxignis-main/services/threat-ai-rs/src/models.rs new file mode 100644 index 0000000..d15722a --- /dev/null +++ b/arxignis-main/services/threat-ai-rs/src/models.rs @@ -0,0 +1,122 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +pub const VECTOR_DIM: usize = 128; + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct TrafficRecordInput { + #[serde(skip)] + pub raw: Value, + #[serde(default = "default_observation_count")] + pub observation_count: u32, + #[serde(default)] + pub threat_score: i32, + #[serde(default)] + pub application: Option, + #[serde(default)] + pub library: Option, + #[serde(default)] + pub device: Option, + #[serde(default)] + pub os: Option, + #[serde(default)] + pub user_agent_string: Option, + #[serde(default)] + pub certificate_authority: Option, + #[serde(default)] + pub verified: bool, + #[serde(rename = "ja4_fingerprint", default)] + pub ja4: Option, + #[serde(rename = "ja4s_fingerprint", default)] + pub ja4s: Option, + #[serde(rename = "ja4h_fingerprint", default)] + pub ja4h: Option, + #[serde(rename = "ja4x_fingerprint", default)] + pub ja4x: Option, + #[serde(rename = "ja4t_fingerprint", default)] + pub ja4t: Option, + #[serde(rename = "ja4ts_fingerprint", default)] + pub ja4ts: Option, + #[serde(rename = "ja4tscan_fingerprint", default)] + pub ja4tscan: Option, + #[serde(default)] + pub ja4_fingerprint_string: Option, + #[serde(skip)] + pub ja4_vec: Option>, + #[serde(skip)] + pub ja4s_vec: Option>, + #[serde(skip)] + pub ja4h_vec: Option>, + #[serde(skip)] + pub ja4x_vec: Option>, + #[serde(skip)] + pub ja4t_vec: Option>, + #[serde(skip)] + pub ja4ts_vec: Option>, + #[serde(skip)] + pub ja4tscan_vec: Option>, + #[serde(skip)] + pub ja4set_vec: Option>, + #[serde(skip)] + pub ua_vec: Option>, + #[serde(skip)] + pub ja4combined_vec: Option>, +} + +impl 
TrafficRecordInput { + pub fn from_value(value: Value) -> anyhow::Result { + let mut record: TrafficRecordInput = serde_json::from_value(value.clone())?; + record.raw = value; + if record.observation_count == 0 { + record.observation_count = 1; + } + Ok(record) + } + + pub fn has_fingerprints(&self) -> bool { + self.user_agent_string.is_some() + || self.ja4.is_some() + || self.ja4s.is_some() + || self.ja4h.is_some() + || self.ja4x.is_some() + || self.ja4t.is_some() + || self.ja4ts.is_some() + || self.ja4tscan.is_some() + || self.ja4_fingerprint_string.is_some() + } +} + +#[derive(Debug, Clone)] +pub struct TrafficRecord { + pub record_id: i64, + pub threat_score: i32, + pub observation_count: u32, + pub raw: Option, + pub application: Option, + pub library: Option, + pub device: Option, + pub os: Option, + pub user_agent_string: Option, + pub certificate_authority: Option, + pub verified: bool, + pub ja4: Option, + pub ja4s: Option, + pub ja4h: Option, + pub ja4x: Option, + pub ja4t: Option, + pub ja4ts: Option, + pub ja4tscan: Option, + pub ja4_fingerprint_string: Option, + pub ua_vec: Option>, + pub ja4combined_vec: Option>, +} + +#[derive(Debug, Clone)] +pub struct SimilarTrafficRecord { + pub record: TrafficRecord, + pub similarity: f32, +} + +const fn default_observation_count() -> u32 { + 1 +} diff --git a/arxignis-main/services/threat-ai-rs/src/scoring.rs b/arxignis-main/services/threat-ai-rs/src/scoring.rs new file mode 100644 index 0000000..d76912d --- /dev/null +++ b/arxignis-main/services/threat-ai-rs/src/scoring.rs @@ -0,0 +1,85 @@ +use crate::config::ScoreConfig; +use crate::models::{SimilarTrafficRecord, TrafficRecordInput}; + +pub struct Scorer { + config: ScoreConfig, +} + +impl Scorer { + pub fn new(config: ScoreConfig) -> Self { + Self { config } + } + + pub fn calculate_threat_score( + &self, + record: &TrafficRecordInput, + similar_records: &[SimilarTrafficRecord], + ) -> i32 { + if similar_records.is_empty() { + return 50; + } + + let mut 
weighted_scores = 0.0f32;
        let mut weight_sum = 0.0f32;

        // Weight each neighbor by trust (configured verified weight, or a
        // log-scaled observation count for unverified rows), scaled by vector
        // similarity.
        for neighbor in similar_records {
            let base_weight = if neighbor.record.verified {
                self.config.verified_weight
            } else {
                // ln(1 + n) / ln(100): ~0 for singletons, ~1 around 100 obs.
                ((1.0 + neighbor.record.observation_count as f32).ln() / 100f32.ln())
                    .max(0.0)
            };

            let weight = base_weight * neighbor.similarity;
            weighted_scores += neighbor.record.threat_score as f32 * weight;
            weight_sum += weight;
        }

        if weight_sum == 0.0 {
            return 50;
        }

        let mut base_score = weighted_scores / weight_sum;

        // Secondary signal: if the record's UA embedding disagrees with
        // well-observed neighbors (obs > 5), boost the score — the user agent
        // string likely does not match the fingerprint cluster.
        let ua_sims: Vec<f32> = similar_records
            .iter()
            .filter_map(|r| {
                if r.record.observation_count > 5 {
                    match (&record.ua_vec, &r.record.ua_vec) {
                        (Some(a), Some(b)) => Some(cosine_similarity(a, b)),
                        _ => None,
                    }
                } else {
                    None
                }
            })
            .collect();

        if !ua_sims.is_empty() {
            let avg = ua_sims.iter().copied().sum::<f32>() / ua_sims.len() as f32;
            if avg < 0.5 {
                // Up to +30% boost as average UA similarity falls from 0.5 to 0.
                let boost = 1.0 + (0.3 * (0.5 - avg) / 0.5);
                base_score *= boost;
            }
        }

        // Clamp into the valid [0, 99] range before truncation: the original
        // only capped the top, so a negative weighted average (possible with a
        // negative configured weight or negative similarities) could underflow.
        base_score.clamp(0.0, 99.0) as i32
    }
}

/// Cosine similarity of two equal-length vectors; returns 0.0 for empty,
/// length-mismatched, or zero-norm inputs.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || b.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}
diff --git a/arxignis-main/services/threat-ai-rs/src/service.rs b/arxignis-main/services/threat-ai-rs/src/service.rs
new file mode 100644
index 0000000..27664a8
--- /dev/null
+++ b/arxignis-main/services/threat-ai-rs/src/service.rs
@@ -0,0 +1,72 @@
use crate::config::Config;
use crate::database::Database;
use crate::models::TrafficRecordInput;
use crate::scoring::Scorer;
use crate::vectorization::Vectorizer;
use anyhow::{Context, Result};
use serde_json::Value;
use std::path::PathBuf;
use std::sync::Arc;

pub struct 
ThreatService { + pub config: Arc, + db: Database, + vectorizer: Vectorizer, + scorer: Scorer, +} + +impl ThreatService { + pub fn new(config: Config) -> Result { + let db = Database::new( + &config.database_url, + config.sqlite_vector_path.as_ref().map(PathBuf::from), + config.database_pool_max_size, + )?; + + let vectorizer = Vectorizer::new(config.vector.clone())?; + let scorer = Scorer::new(config.score.clone()); + + Ok(Self { + config: Arc::new(config), + db, + vectorizer, + scorer, + }) + } + + pub async fn warmup(&self) -> Result<()> { + self.vectorizer.warmup().await?; + Ok(()) + } + + pub async fn process_and_score(&self, payload: Value) -> Result { + let mut record = TrafficRecordInput::from_value(payload)?; + if !record.has_fingerprints() { + return Ok(50); + } + + self.vectorizer.vectorize(&mut record).await?; + + let combined_vec = record.ja4combined_vec.clone(); + let neighbors = if let Some(vec) = combined_vec { + let db = self.db.clone(); + let search_cfg = self.config.search.clone(); + tokio::task::spawn_blocking(move || db.search_similar(&vec, &search_cfg)) + .await + .context("search task join error")?? 
+ } else { + Vec::new() + }; + + let score = self.scorer.calculate_threat_score(&record, &neighbors); + record.threat_score = score; + + let db = self.db.clone(); + let record_for_insert = record.clone(); + tokio::task::spawn_blocking(move || db.insert_record(&record_for_insert)) + .await + .context("insert task join error")??; + + Ok(score) + } +} diff --git a/arxignis-main/services/threat-ai-rs/src/ua_embedding.rs b/arxignis-main/services/threat-ai-rs/src/ua_embedding.rs new file mode 100644 index 0000000..13a1669 --- /dev/null +++ b/arxignis-main/services/threat-ai-rs/src/ua_embedding.rs @@ -0,0 +1,248 @@ +use anyhow::{anyhow, Context, Result}; +use bytemuck::cast_slice; +use half::f16; +use memmap2::Mmap; +use reqwest::Client; +use safetensors::{Dtype, SafeTensors}; +use std::path::PathBuf; +use std::{env, fs, sync::Arc}; +use tokenizers::Tokenizer; +use tokio::sync::OnceCell; + +const HF_BASE: &str = "https://huggingface.co"; + +const TOKENIZER_REL: &str = "0_StaticEmbedding/tokenizer.json"; +const WEIGHTS_REL: &str = "0_StaticEmbedding/model.safetensors"; + +#[derive(Clone)] +pub struct UaEmbedder { + client: Client, + model_id: String, + cache_dir: PathBuf, + truncate_dim: usize, + inner: Arc>, +} + +struct Loaded { + tokenizer: Tokenizer, + embeddings: EmbeddingTable, +} + +struct EmbeddingTable { + vocab_size: usize, + dim: usize, + data: Vec, // row-major: vocab_size * dim (dim == truncate_dim) +} + +impl UaEmbedder { + pub fn new(client: Client, model_id: String, cache_dir: PathBuf, truncate_dim: usize) -> Self { + Self { + client, + model_id, + cache_dir, + truncate_dim, + inner: Arc::new(OnceCell::new()), + } + } + + pub async fn warmup(&self) -> Result<()> { + self.load().await.map(|_| ()) + } + + pub async fn embed(&self, text: &str) -> Result> { + let loaded = self.load().await?; + + let encoding = loaded + .tokenizer + .encode(text, true) + .map_err(|e| anyhow!("tokenization failed: {e}"))?; + + let ids = encoding.get_ids(); + let mask = 
// embed() body: mean-pools the embedding-table rows of attention-masked tokens; ids whose
// row lookup returns None (out-of-range token ids) are skipped. Empty tokenization is an
// error, but an input where every token is masked or unknown silently yields the zero
// vector (count == 0 skips the divide) — NOTE(review): confirm the zero vector is an
// acceptable "no embedding" sentinel downstream.
// load(): get_or_try_init runs the download + tokenizer load + safetensors parse exactly once.
// NOTE(review): fs::create_dir_all and the safetensors parse are blocking work on the async
// runtime; consider spawn_blocking if init latency under load matters.
// download_file: per-model cache keyed by model id with '/' flattened to "__"; an existing
// file is trusted as a cache hit; bearer auth is attached only when an HF token env var is set.
encoding.get_attention_mask(); + if ids.is_empty() || mask.is_empty() { + return Err(anyhow!("tokenizer produced empty input")); + } + + let mut out = vec![0.0f32; self.truncate_dim]; + let mut count = 0.0f32; + + for (&id_u32, &m) in ids.iter().zip(mask.iter()) { + if m == 0 { + continue; + } + let id = id_u32 as usize; + if let Some(row) = loaded.embeddings.row(id) { + for (dst, src) in out.iter_mut().zip(row.iter()) { + *dst += *src; + } + count += 1.0; + } + } + + if count > 0.0 { + for v in &mut out { + *v /= count; + } + } + + Ok(out) + } + + async fn load(&self) -> Result<&Loaded> { + self.inner + .get_or_try_init(|| async { + fs::create_dir_all(&self.cache_dir)?; + + let tokenizer_path = self.download_file(TOKENIZER_REL).await?; + let weights_path = self.download_file(WEIGHTS_REL).await?; + + let tokenizer = Tokenizer::from_file(&tokenizer_path) + .map_err(|e| anyhow!("failed to load tokenizer: {e}"))?; + + let embeddings = + load_embedding_table(&weights_path, self.truncate_dim).with_context(|| { + format!("loading StaticEmbedding weights from {}", weights_path.display()) + })?; + + Ok(Loaded { + tokenizer, + embeddings, + }) + }) + .await + } + + async fn download_file(&self, rel: &str) -> Result { + let safe_model_id = self.model_id.replace('/', "__"); + let dest = self.cache_dir.join(safe_model_id).join(rel); + if dest.exists() { + return Ok(dest); + } + + if let Some(parent) = dest.parent() { + fs::create_dir_all(parent)?; + } + + let url = format!("{HF_BASE}/{}/resolve/main/{}", self.model_id, rel); + let mut req = self.client.get(url); + if let Some(token) = read_hf_token() { + req = req.bearer_auth(token); + } + + let resp = req + .send() + .await + .context("downloading HF model artifact")?
// download_file tail: the body is buffered fully in memory and written to the cache path.
// NOTE(review): the write is NOT atomic — a crash or kill mid-write leaves a truncated file
// that the earlier `dest.exists()` check will treat as a valid cache hit forever. Write to a
// temp file and rename into place.
// EmbeddingTable::row: checked row lookup; callers rely on None for out-of-range token ids.
// load_embedding_table: mmaps the safetensors file, picks the FIRST 2-D tensor found as the
// embedding matrix (NOTE(review): tensor-name order is map-ordered, not file-ordered —
// fragile if the file ever contains more than one 2-D tensor; match by name instead), then
// copies the first `truncate_dim` columns of each row into an owned f32 table.
// NOTE(review): the F32 branch calls bytemuck::cast_slice on the raw mmap'd tensor bytes;
// cast_slice PANICS unless the slice is 4-byte aligned, and safetensors data offsets are not
// guaranteed to give that alignment — TODO confirm, or decode per-element from LE bytes like
// the F16 branch does.
+ .error_for_status() + .context("HF download returned error")?; + + let bytes = resp.bytes().await?; + tokio::fs::write(&dest, &bytes).await?; + Ok(dest) + } +} + +impl EmbeddingTable { + fn row(&self, token_id: usize) -> Option<&[f32]> { + if token_id >= self.vocab_size { + return None; + } + let start = token_id * self.dim; + let end = start + self.dim; + Some(&self.data[start..end]) + } +} + +fn load_embedding_table(path: &PathBuf, truncate_dim: usize) -> Result { + let file = fs::File::open(path)?; + let mmap = unsafe { Mmap::map(&file)? }; + let safetensors = SafeTensors::deserialize(&mmap)?; + + let mut chosen_name = None; + let mut chosen_shape = None; + for name in safetensors.names() { + let t = safetensors.tensor(name)?; + let shape = t.shape(); + if shape.len() == 2 { + chosen_name = Some(name.to_string()); + chosen_shape = Some((shape[0], shape[1])); + break; + } + } + + let name = chosen_name.ok_or_else(|| anyhow!("no 2D tensor found in safetensors"))?; + let (vocab_size, full_dim) = + chosen_shape.ok_or_else(|| anyhow!("missing tensor shape for {name}"))?; + + if full_dim == 0 || vocab_size == 0 { + return Err(anyhow!("invalid embedding tensor shape {vocab_size}x{full_dim}")); + } + if truncate_dim > full_dim { + return Err(anyhow!( + "truncate_dim {} exceeds model dim {}", + truncate_dim, + full_dim + )); + } + + let tensor = safetensors.tensor(&name)?; + let bytes = tensor.data(); + + let dim = truncate_dim; + let mut data = vec![0.0f32; vocab_size * dim]; + + match tensor.dtype() { + Dtype::F32 => { + let floats: &[f32] = cast_slice(bytes); + if floats.len() != vocab_size * full_dim { + return Err(anyhow!( + "unexpected f32 tensor length {}, expected {}", + floats.len(), + vocab_size * full_dim + )); + } + for row in 0..vocab_size { + let src = &floats[row * full_dim..row * full_dim + dim]; + let dst = &mut data[row * dim..row * dim + dim]; + dst.copy_from_slice(src); + } + } + Dtype::F16 => { + if bytes.len() != vocab_size * full_dim * 2 { +
// F16 branch: decodes little-endian half floats element-by-element (unaligned-safe, unlike
// the F32 cast above) and keeps only the first `dim` columns of each `full_dim`-wide row —
// i.e. Matryoshka-style truncation of the embedding.
// read_hf_token: first non-empty of HF_TOKEN / HUGGINGFACE_TOKEN / UA_EMBEDDING_TOKEN wins.
// --- src/vectorization.rs ---
// JA4_FIELDS pairs each output vector field name with its hashing fn (all share
// minhash_ja4_field); order here MUST match the positional idx->source match in vectorize().
// Vectorizer::new: the UA embedder is enabled only when BOTH `model` and `model_cache_dir`
// are configured; otherwise UA strings are silently not embedded.
return Err(anyhow!( + "unexpected f16 tensor byte length {}, expected {}", + bytes.len(), + vocab_size * full_dim * 2 + )); + } + for row in 0..vocab_size { + let dst = &mut data[row * dim..row * dim + dim]; + for i in 0..dim { + let offset = (row * full_dim + i) * 2; + let b0 = bytes[offset]; + let b1 = bytes[offset + 1]; + let v = f16::from_bits(u16::from_le_bytes([b0, b1])); + dst[i] = v.to_f32(); + } + } + } + other => return Err(anyhow!("unsupported safetensors dtype {other:?}")), + } + + Ok(EmbeddingTable { + vocab_size, + dim, + data, + }) +} + +fn read_hf_token() -> Option { + for key in ["HF_TOKEN", "HUGGINGFACE_TOKEN", "UA_EMBEDDING_TOKEN"] { + if let Ok(v) = env::var(key) { + if !v.trim().is_empty() { + return Some(v); + } + } + } + None +} diff --git a/arxignis-main/services/threat-ai-rs/src/vectorization.rs b/arxignis-main/services/threat-ai-rs/src/vectorization.rs new file mode 100644 index 0000000..776cdab --- /dev/null +++ b/arxignis-main/services/threat-ai-rs/src/vectorization.rs @@ -0,0 +1,234 @@ +use crate::config::VectorConfig; +use crate::models::{TrafficRecordInput, VECTOR_DIM}; +use anyhow::{Context, Result}; +use reqwest::Client; +use std::collections::HashSet; +use tracing::warn; + +use crate::ua_embedding::UaEmbedder; + +const JA4_FIELDS: &[(&str, fn(&str) -> Vec)] = &[ + ("ja4_vec", minhash_ja4_field), + ("ja4s_vec", minhash_ja4_field), + ("ja4h_vec", minhash_ja4_field), + ("ja4x_vec", minhash_ja4_field), + ("ja4t_vec", minhash_ja4_field), + ("ja4ts_vec", minhash_ja4_field), + ("ja4tscan_vec", minhash_ja4_field), +]; + +pub struct Vectorizer { + config: VectorConfig, + ua_embedder: Option, +} + +impl Vectorizer { + pub fn new(config: VectorConfig) -> Result { + let client = Client::builder().build().context("building HTTP client")?; + let ua_embedder = match (&config.model, &config.model_cache_dir) { + (Some(model_id), Some(cache_dir)) => Some(UaEmbedder::new( + client.clone(), + model_id.clone(), + cache_dir.into(), + VECTOR_DIM, + )), +
// vectorize(): fills per-fingerprint MinHash vectors, the ja4set_vec from the combined
// fingerprint string, an optional UA embedding, then the weighted combined vector.
// NOTE(review): the idx->record-field match is POSITIONAL against JA4_FIELDS — reordering
// that const silently vectorizes the wrong source into the wrong field. A table of
// accessor fns alongside the names would remove the coupling.
// UA-embedding failures are logged via warn! and skipped, so an HF outage never blocks
// scoring — deliberate best-effort.
// set_vec_field: string-keyed dispatch mirroring JA4_FIELDS names; unknown names are ignored.
// minhash_ja4_field: tokenizes on '_'; empty or literal "null" input yields the zero vector.
_ => None, + }; + + Ok(Self { + config, + ua_embedder, + }) + } + + pub async fn warmup(&self) -> Result<()> { + if let Some(ua) = &self.ua_embedder { + ua.warmup().await?; + } + Ok(()) + } + + pub async fn vectorize(&self, record: &mut TrafficRecordInput) -> Result<()> { + for (idx, (vec_name, func)) in JA4_FIELDS.iter().enumerate() { + let source = match idx { + 0 => &record.ja4, + 1 => &record.ja4s, + 2 => &record.ja4h, + 3 => &record.ja4x, + 4 => &record.ja4t, + 5 => &record.ja4ts, + _ => &record.ja4tscan, + }; + + if let Some(value) = source.as_ref().filter(|s| !s.trim().is_empty()) { + let vec = func(value); + set_vec_field(record, vec_name, vec); + } + } + + if let Some(fp) = record + .ja4_fingerprint_string + .as_ref() + .filter(|s| !s.trim().is_empty()) + { + record.ja4set_vec = Some(minhash_ja4_fingerprint(fp)); + } + + if let (Some(ua_embedder), Some(ua)) = (&self.ua_embedder, &record.user_agent_string) { + match ua_embedder.embed(ua).await { + Ok(v) => record.ua_vec = Some(v), + Err(err) => warn!("UA embedding failed, continuing without UA vec: {err:?}"), + }; + } + + record.ja4combined_vec = combine_vectors(&self.config, record); + Ok(()) + } +} + +fn set_vec_field(record: &mut TrafficRecordInput, name: &str, vec: Vec) { + match name { + "ja4_vec" => record.ja4_vec = Some(vec), + "ja4s_vec" => record.ja4s_vec = Some(vec), + "ja4h_vec" => record.ja4h_vec = Some(vec), + "ja4x_vec" => record.ja4x_vec = Some(vec), + "ja4t_vec" => record.ja4t_vec = Some(vec), + "ja4ts_vec" => record.ja4ts_vec = Some(vec), + "ja4tscan_vec" => record.ja4tscan_vec = Some(vec), + _ => {} + } +} + +fn minhash_ja4_field(value: &str) -> Vec { + if value.trim().is_empty() || value.eq_ignore_ascii_case("null") { + return vec![0.0; VECTOR_DIM]; + } + + let tokens: HashSet = value + .split('_') + .filter_map(|p| { + let trimmed = p.trim(); + (!trimmed.is_empty()).then(|| trimmed.to_string()) + }) + .collect(); + + minhash_tokens(tokens) +} + +fn minhash_ja4_fingerprint(value: &str) ->
// minhash_ja4_fingerprint body: tokenizes on both '_' and ',' (vs '_' only for single
// fields), then shares minhash_tokens.
// minhash_tokens: classic MinHash — for each of VECTOR_DIM slots, take the minimum MD5
// digest (interpreted big-endian as u128) over all tokens salted with the slot index.
// MD5 is used purely as a cheap deterministic hash here, not for security. Cost is
// O(VECTOR_DIM * tokens) digests per field — fine for short JA4 strings.
// normalize_hashes: min/max-scales the u128 signature into [0,1] f32s; an all-equal
// signature maps to 0.5, an empty slice to zeros. NOTE(review): scaling is per-signature,
// so component magnitudes are not comparable across records — presumably acceptable for
// the similarity metric the DB search uses; verify against search_similar's distance.
Vec { + if value.trim().is_empty() { + return vec![0.0; VECTOR_DIM]; + } + + let mut tokens = HashSet::new(); + for part in value.split('_') { + for token in part.split(',') { + let trimmed = token.trim(); + if !trimmed.is_empty() { + tokens.insert(trimmed.to_string()); + } + } + } + + minhash_tokens(tokens) +} + +fn minhash_tokens(tokens: HashSet) -> Vec { + if tokens.is_empty() { + return vec![0.0; VECTOR_DIM]; + } + + let mut mins = Vec::with_capacity(VECTOR_DIM); + for i in 0..VECTOR_DIM { + let mut min_hash: u128 = u128::MAX; + for token in &tokens { + let input = format!("{token}_{i}"); + let digest = md5::compute(input.as_bytes()); + let val = u128::from_be_bytes(digest.0); + if val < min_hash { + min_hash = val; + } + } + mins.push(min_hash); + } + + normalize_hashes(&mins) +} + +fn normalize_hashes(values: &[u128]) -> Vec { + let (&min_val, &max_val) = match (values.iter().min(), values.iter().max()) { + (Some(min), Some(max)) => (min, max), + _ => return vec![0.0; VECTOR_DIM], + }; + + if max_val == min_val { + return vec![0.5; VECTOR_DIM]; + } + + let range = (max_val - min_val) as f64; + values + .iter() + .map(|v| { + let normalized = (*v - min_val) as f64 / range; + let val = normalized as f32; + if val.is_finite() { + val + } else { + 0.0 + } + }) + .collect() +} + +fn combine_vectors(config: &VectorConfig, record: &TrafficRecordInput) -> Option> { + if config.vector_weights.is_empty() { + return None; + } + + let mut vectors = Vec::new(); + let mut weights = Vec::new(); + + for (name, weight) in &config.vector_weights { + let vec_opt = match name.as_str() { + "ja4_vec" => record.ja4_vec.as_ref(), + "ja4s_vec" => record.ja4s_vec.as_ref(), + "ja4h_vec" => record.ja4h_vec.as_ref(), + "ja4x_vec" => record.ja4x_vec.as_ref(), + "ja4t_vec" => record.ja4t_vec.as_ref(), + "ja4ts_vec" => record.ja4ts_vec.as_ref(), + "ja4tscan_vec" => record.ja4tscan_vec.as_ref(), + "ja4set_vec" => record.ja4set_vec.as_ref(), + "ua_vec" => record.ua_vec.as_ref(), + _ => { +
// Tail of combine_vectors: unknown weight keys are warned about and skipped; only present
// vectors of exactly VECTOR_DIM length participate. Weights are normalized to sum to 1,
// with a uniform fallback when the configured weights sum to zero; the survivors are then
// weight-averaged element-wise into the single combined vector. Returns None when nothing
// qualified, which leaves record.ja4combined_vec unset and skips the similarity search.
// NOTE(review): `vectors.push(vec.clone())` copies each candidate vector only to read it
// once below — collecting `&Vec`/slices would avoid the allocations.
warn!("Unknown vector weight key {name}, skipping"); + None + } + }; + + if let Some(vec) = vec_opt { + if vec.len() == VECTOR_DIM { + vectors.push(vec.clone()); + weights.push(*weight); + } + } + } + + if vectors.is_empty() { + None + } else { + let weight_sum: f32 = weights.iter().sum(); + let normalized: Vec = if weight_sum == 0.0 { + vec![1.0 / vectors.len() as f32; vectors.len()] + } else { + weights.iter().map(|w| w / weight_sum).collect() + }; + + let mut combined = vec![0.0f32; VECTOR_DIM]; + for (vec, w) in vectors.iter().zip(normalized.iter()) { + for (i, val) in vec.iter().enumerate() { + combined[i] += val * w; + } + } + Some(combined) + } +}