26 changes: 26 additions & 0 deletions arxignis-main/services/threat-ai-rs/Cargo.toml
@@ -0,0 +1,26 @@
[package]
name = "threat-ai-rs"
version = "0.1.0"
edition = "2021"

[dependencies]
anyhow = "1"
axum = { version = "0.8", features = ["json"] }
clap = { version = "4.5", features = ["derive"] }
md5 = "0.7"
r2d2 = "0.8"
rand = "0.9"
reqwest = { version = "0.12", features = ["json", "rustls-tls"] }
rusqlite = { version = "0.32", features = ["bundled", "load_extension"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
serde_yaml = "0.9"
thiserror = "2"
tokio = { version = "1", features = ["macros", "rt-multi-thread", "time", "fs"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] }
tokenizers = "0.21"
safetensors = "0.4"
memmap2 = "0.9"
half = "2"
bytemuck = "1"
Binary file not shown.
92 changes: 92 additions & 0 deletions arxignis-main/services/threat-ai-rs/src/bootstrap.rs
@@ -0,0 +1,92 @@
use crate::config::Config;
use crate::database::Database;
use crate::models::TrafficRecordInput;
use crate::scoring::Scorer;
use crate::vectorization::Vectorizer;
use anyhow::{anyhow, Context, Result};
use rand::seq::SliceRandom;
use serde_json::Value;
use std::fs;
use std::path::PathBuf;
use tracing::{info, warn};

pub async fn bootstrap_database(
    config: Config,
    verified_path: PathBuf,
    dataset_path: Option<PathBuf>,
) -> Result<()> {
    let db = Database::new(
        &config.database_url,
        config.sqlite_vector_path.as_ref().map(PathBuf::from),
        config.database_pool_max_size,
    )?;

    let vectorizer = Vectorizer::new(config.vector.clone())?;
    let scorer = Scorer::new(config.score.clone());

    let verified_values = read_json_array(&verified_path)
        .with_context(|| format!("reading verified data from {}", verified_path.display()))?;

    let mut verified_records = Vec::new();
    for value in verified_values {
        let mut record = TrafficRecordInput::from_value(value)?;
        if !record.has_fingerprints() {
            continue;
        }
        vectorizer.vectorize(&mut record).await?;
        record.threat_score = record
            .raw
            .get("threat_score")
            .and_then(|v| v.as_i64())
            .unwrap_or(50) as i32;
        verified_records.push(record);
    }

    info!("Inserting {} verified records", verified_records.len());
    db.insert_batch(&verified_records)?;

    if let Some(dataset) = dataset_path {
        let mut dataset_values = read_json_array(&dataset)
            .with_context(|| format!("reading dataset from {}", dataset.display()))?;
        let mut rng = rand::rng();
        dataset_values.shuffle(&mut rng);

        for value in dataset_values {
            let mut record = TrafficRecordInput::from_value(value)?;
            if record.verified || !record.has_fingerprints() {
                continue;
            }

            vectorizer.vectorize(&mut record).await?;
            if let Some(vec) = record.ja4combined_vec.as_ref() {
                let neighbors = db.search_similar(vec, &config.search)?;
                record.threat_score = scorer.calculate_threat_score(&record, &neighbors);
            } else {
                record.threat_score = 50;
            }

            if let Some(existing) = db.get_record_by_raw(&record.raw)? {
                let mut updated = existing;
                updated.observation_count = updated.observation_count.saturating_add(1);
                updated.threat_score = record.threat_score;
                db.update_record(&updated)?;
            } else if let Err(err) = db.insert_record(&record) {
                warn!("Insert failed, skipping record: {err:?}");
            }
        }
    }

    Ok(())
}

fn read_json_array(path: &PathBuf) -> Result<Vec<Value>> {
    let contents = fs::read_to_string(path)?;
    let value: Value = serde_json::from_str(&contents)?;
    match value {
        Value::Array(arr) => Ok(arr),
        _ => Err(anyhow!("expected JSON array in {}", path.display())),
    }
}
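
For orientation only (not part of this diff): a sketch of how bootstrap_database above might be wired to a clap-based entry point. The Args struct, flag names, and module declarations here are hypothetical; only Config::load and bootstrap_database come from the files shown in this PR, and the sibling modules are assumed to be added elsewhere in the changeset.

    // Hypothetical main.rs wiring; everything except Config::load and
    // bootstrap_database is illustrative.
    mod bootstrap;
    mod config;
    mod database;
    mod models;
    mod scoring;
    mod vectorization;

    use clap::Parser;
    use std::path::PathBuf;

    #[derive(Parser)]
    struct Args {
        /// YAML config consumed by config::Config::load
        #[arg(long, default_value = "config.yaml")]
        config: PathBuf,
        /// JSON array of verified traffic records
        #[arg(long)]
        verified: PathBuf,
        /// Optional unverified dataset to score against the verified baseline
        #[arg(long)]
        dataset: Option<PathBuf>,
    }

    #[tokio::main]
    async fn main() -> anyhow::Result<()> {
        let args = Args::parse();
        let cfg = config::Config::load(&args.config)?;
        bootstrap::bootstrap_database(cfg, args.verified, args.dataset).await
    }
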
78 changes: 78 additions & 0 deletions arxignis-main/services/threat-ai-rs/src/config.rs
@@ -0,0 +1,78 @@
use serde::Deserialize;
use std::{collections::HashMap, env, fs, path::Path};

const DEFAULT_BATCH_SIZE: usize = 100;
const DEFAULT_FLUSH_INTERVAL_SECS: u64 = 600;
const DEFAULT_POOL_MIN: usize = 1;
const DEFAULT_POOL_MAX: usize = 10;

#[derive(Debug, Clone, Deserialize)]
pub struct Config {
    pub database_url: String,
    #[serde(default = "default_batch_size")]
    pub batch_size: usize,
    #[serde(default = "default_flush_interval")]
    pub flush_interval: u64,
    #[serde(default = "default_pool_min")]
    pub database_pool_min_size: usize,
    #[serde(default = "default_pool_max")]
    pub database_pool_max_size: usize,
    #[serde(default)]
    pub sqlite_vector_path: Option<String>,
    pub score: ScoreConfig,
    pub search: SearchConfig,
    pub vector: VectorConfig,
}

#[derive(Debug, Clone, Deserialize)]
pub struct ScoreConfig {
    #[serde(rename = "verified_weight")]
    pub verified_weight: f32,
    #[serde(default)]
    pub field_weights: HashMap<String, f32>,
}

#[derive(Debug, Clone, Deserialize)]
pub struct SearchConfig {
    pub limit: usize,
    pub similarity_threshold: f32,
    pub observation_threshold: usize,
}

#[derive(Debug, Clone, Deserialize)]
pub struct VectorConfig {
    pub model: Option<String>,
    #[serde(default)]
    pub model_cache_dir: Option<String>,
    #[serde(default)]
    pub vector_weights: HashMap<String, f32>,
}

impl Config {
    pub fn load(path: impl AsRef<Path>) -> anyhow::Result<Self> {
        let contents = fs::read_to_string(path.as_ref())?;
        let mut parsed: Config = serde_yaml::from_str(&contents)?;

        if let Ok(db_url) = env::var("DATABASE_URL") {
            parsed.database_url = db_url;
        }

        Ok(parsed)
    }
}

const fn default_batch_size() -> usize {
    DEFAULT_BATCH_SIZE
}

const fn default_flush_interval() -> u64 {
    DEFAULT_FLUSH_INTERVAL_SECS
}

const fn default_pool_min() -> usize {
    DEFAULT_POOL_MIN
}

const fn default_pool_max() -> usize {
    DEFAULT_POOL_MAX
}