Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 147 additions & 12 deletions crates/core/src/observability/atof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,36 @@ pub enum AtofEndpointTransport {
Ndjson,
}

/// Field name transformation policy used before sending events to an endpoint.
///
/// Serialized in `snake_case` (`preserve`, `replace_dots`), which matches the
/// strings accepted by [`AtofEndpointFieldNamePolicy::parse`] and produced by
/// [`AtofEndpointFieldNamePolicy::as_str`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AtofEndpointFieldNamePolicy {
/// Preserve canonical ATOF field names exactly. This is the default policy.
#[default]
Preserve,
/// Replace dots in JSON object keys with underscores, recursively.
///
/// Nested objects and arrays are rewritten as well. When a rewritten key
/// collides with a key already present in the same object, a numeric suffix
/// (`_2`, `_3`, ...) is appended so that no entry is dropped.
ReplaceDots,
}

impl AtofEndpointFieldNamePolicy {
    /// Look up the policy whose configuration spelling equals `value`.
    ///
    /// Matching is exact and case-sensitive; any other input yields `None`.
    pub fn parse(value: &str) -> Option<Self> {
        // Derive the accepted spellings from `as_str` so both directions stay
        // in sync. A new variant must be added to this list to become parseable.
        [Self::Preserve, Self::ReplaceDots]
            .iter()
            .copied()
            .find(|policy| policy.as_str() == value)
    }

    /// Stable textual name of the policy, as used by configuration and bindings.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::ReplaceDots => "replace_dots",
            Self::Preserve => "preserve",
        }
    }
}

impl AtofEndpointTransport {
/// Parse a string transport used by configuration and bindings.
pub fn parse(value: &str) -> Option<Self> {
Expand Down Expand Up @@ -150,6 +180,9 @@ pub struct AtofEndpointConfig {
/// Per-endpoint timeout in milliseconds.
#[serde(default = "default_endpoint_timeout_millis")]
pub timeout_millis: u64,
/// Field name transformation policy applied before sending events.
#[serde(default)]
pub field_name_policy: AtofEndpointFieldNamePolicy,
}

impl AtofEndpointConfig {
Expand All @@ -160,6 +193,7 @@ impl AtofEndpointConfig {
transport,
headers: HashMap::new(),
timeout_millis: default_endpoint_timeout_millis(),
field_name_policy: AtofEndpointFieldNamePolicy::Preserve,
}
}

Expand All @@ -174,6 +208,15 @@ impl AtofEndpointConfig {
self.timeout_millis = timeout_millis;
self
}

/// Override the endpoint field name policy.
pub fn with_field_name_policy(
mut self,
field_name_policy: AtofEndpointFieldNamePolicy,
) -> Self {
self.field_name_policy = field_name_policy;
self
}
}

/// Configuration for [`AtofExporter`].
Expand Down Expand Up @@ -564,6 +607,7 @@ fn run_endpoint_worker(
config: AtofEndpointConfig,
rx: tokio::sync::mpsc::UnboundedReceiver<EndpointMessage>,
) {
install_rustls_crypto_provider();
let runtime = match tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
Expand All @@ -583,6 +627,11 @@ fn run_endpoint_worker(
});
}

#[cfg(all(feature = "atof-streaming", not(target_arch = "wasm32")))]
fn install_rustls_crypto_provider() {
Comment thread
bbednarski9 marked this conversation as resolved.
let _ = rustls::crypto::ring::default_provider().install_default();
}

#[cfg(all(feature = "atof-streaming", not(target_arch = "wasm32")))]
async fn run_http_post_endpoint(
index: usize,
Expand Down Expand Up @@ -611,7 +660,7 @@ async fn run_http_post_endpoint(
while let Some(message) = rx.recv().await {
match message {
EndpointMessage::Event(raw_json) => {
let body = format!("{raw_json}\n");
let body = format!("{}\n", endpoint_event_json(&config, raw_json));
let result = client
.post(&config.url)
.header(reqwest::header::CONTENT_TYPE, "application/x-ndjson")
Expand All @@ -620,10 +669,7 @@ async fn run_http_post_endpoint(
.await;
match result {
Ok(response) if response.status().is_success() => {}
Ok(response) => eprintln!(
"nemo_relay: ATOF endpoint[{index}] HTTP status {}",
response.status()
),
Ok(response) => log_http_error(index, "HTTP", response).await,
Err(error) => {
eprintln!("nemo_relay: ATOF endpoint[{index}] send failed: {error}")
}
Expand Down Expand Up @@ -657,7 +703,7 @@ async fn run_websocket_endpoint(
while let Some(message) = rx.recv().await {
match message {
EndpointMessage::Event(raw_json) => {
pending.push_back(raw_json);
pending.push_back(endpoint_event_json(&config, raw_json));
let _ = drain_websocket_pending(index, &config, &mut socket, &mut pending).await;
}
EndpointMessage::Flush(done) => {
Expand Down Expand Up @@ -788,9 +834,10 @@ async fn run_ndjson_endpoint(
};

let (body_tx, body) = ndjson_body_channel();
let url = config.url.clone();
let request = tokio::spawn(async move {
client
.post(config.url)
.post(url)
.header(reqwest::header::CONTENT_TYPE, "application/x-ndjson")
.body(body)
.send()
Expand All @@ -800,7 +847,9 @@ async fn run_ndjson_endpoint(

while let Some(message) = rx.recv().await {
match message {
EndpointMessage::Event(raw_json) => send_ndjson_event(index, &body_tx, raw_json),
EndpointMessage::Event(raw_json) => {
send_ndjson_event(index, &body_tx, endpoint_event_json(&config, raw_json))
}
EndpointMessage::Flush(done) => send_ndjson_flush(index, &body_tx, done),
EndpointMessage::Close(done) => {
drop(body_tx);
Expand Down Expand Up @@ -879,10 +928,7 @@ async fn finish_ndjson_upload(
) {
match tokio::time::timeout(close_timeout, request).await {
Ok(Ok(Ok(response))) if response.status().is_success() => {}
Ok(Ok(Ok(response))) => eprintln!(
"nemo_relay: ATOF endpoint[{index}] NDJSON HTTP status {}",
response.status()
),
Ok(Ok(Ok(response))) => log_http_error(index, "NDJSON HTTP", response).await,
Ok(Ok(Err(error))) => {
eprintln!("nemo_relay: ATOF endpoint[{index}] NDJSON upload failed: {error}")
}
Expand Down Expand Up @@ -910,6 +956,95 @@ async fn drain_closed(mut rx: tokio::sync::mpsc::UnboundedReceiver<EndpointMessa
}
}

#[cfg(all(feature = "atof-streaming", not(target_arch = "wasm32")))]
/// Produces the JSON text that is actually sent to an endpoint.
///
/// `raw_json` is taken by value so the `Preserve` policy can hand the string
/// back without copying; `ReplaceDots` rewrites the object keys first.
fn endpoint_event_json(config: &AtofEndpointConfig, raw_json: String) -> String {
    let policy = config.field_name_policy;
    match policy {
        AtofEndpointFieldNamePolicy::ReplaceDots => replace_dotted_field_names(&raw_json),
        AtofEndpointFieldNamePolicy::Preserve => raw_json,
    }
}

#[cfg(all(feature = "atof-streaming", not(target_arch = "wasm32")))]
/// Re-serializes `raw_json` with every dot in an object key replaced by `_`.
///
/// This is best effort: if the input is not valid JSON, or the rewritten value
/// cannot be serialized again, the original text is returned unchanged.
fn replace_dotted_field_names(raw_json: &str) -> String {
    match serde_json::from_str::<Json>(raw_json) {
        Ok(mut parsed) => {
            replace_dotted_value_keys(&mut parsed);
            serde_json::to_string(&parsed).unwrap_or_else(|_| raw_json.to_string())
        }
        Err(_) => raw_json.to_string(),
    }
}

#[cfg(all(feature = "atof-streaming", not(target_arch = "wasm32")))]
/// Walks a JSON value in place and sanitizes the keys of every object in it.
///
/// Objects are handed to `replace_dotted_object_keys`, arrays are visited
/// element by element, and scalar values are left untouched.
fn replace_dotted_value_keys(value: &mut Json) {
    if let Json::Object(map) = value {
        replace_dotted_object_keys(map);
    } else if let Json::Array(elements) = value {
        elements.iter_mut().for_each(replace_dotted_value_keys);
    }
}

#[cfg(all(feature = "atof-streaming", not(target_arch = "wasm32")))]
/// Rebuilds `object` so that no key contains a dot, recursing into the values.
///
/// Entries whose key contains a dot are re-inserted first, so they claim the
/// plain underscore form of their name. Entries that had no dot are inserted
/// afterwards, and any of them that now collides receives a numeric suffix
/// from `collision_free_key`. Within each group the original order is kept.
fn replace_dotted_object_keys(object: &mut serde_json::Map<String, Json>) {
    // Empty the map up front; entries are put back under their final names.
    let entries = std::mem::take(object);
    let (dotted, plain): (Vec<_>, Vec<_>) = entries
        .into_iter()
        .partition(|(key, _)| key.contains('.'));

    for (key, mut value) in dotted.into_iter().chain(plain) {
        replace_dotted_value_keys(&mut value);
        let renamed = collision_free_key(object, key.replace('.', "_"));
        object.insert(renamed, value);
    }
}

#[cfg(all(feature = "atof-streaming", not(target_arch = "wasm32")))]
/// Returns a key based on `key` that is not yet present in `object`.
///
/// `key` itself is returned when it is free. Otherwise the first unused name
/// of the form `{key}_2`, `{key}_3`, ... is chosen.
fn collision_free_key(object: &serde_json::Map<String, Json>, key: String) -> String {
    if !object.contains_key(&key) {
        return key;
    }
    (2..)
        .map(|suffix| format!("{key}_{suffix}"))
        .find(|candidate| !object.contains_key(candidate))
        .unwrap_or_else(|| unreachable!("unbounded suffix search must find a key"))
}

#[cfg(all(feature = "atof-streaming", not(target_arch = "wasm32")))]
/// Logs a non-success HTTP response for endpoint `index` to stderr.
///
/// The status is always reported. When the response body can be read and is
/// not blank, a truncated copy of it is appended; when reading the body
/// fails, that error is reported alongside the status instead.
async fn log_http_error(index: usize, label: &str, response: reqwest::Response) {
    // Capture the status first: reading the body consumes the response.
    let status = response.status();
    let outcome = response.text().await;
    match outcome {
        Err(error) => eprintln!(
            "nemo_relay: ATOF endpoint[{index}] {label} status {status}; failed to read response body: {error}"
        ),
        Ok(body) => {
            if body.trim().is_empty() {
                eprintln!("nemo_relay: ATOF endpoint[{index}] {label} status {status}");
            } else {
                eprintln!(
                    "nemo_relay: ATOF endpoint[{index}] {label} status {status}: {}",
                    truncate_log_body(&body)
                );
            }
        }
    }
}

#[cfg(all(feature = "atof-streaming", not(target_arch = "wasm32")))]
/// Trims `body` and caps it at 1024 characters for logging.
///
/// Text within the limit is returned as is (after trimming). Longer text is
/// cut after the 1024th character and marked with `... <truncated>`. The cut
/// is made on a character boundary, so multi-byte text is never split.
fn truncate_log_body(body: &str) -> String {
    const LIMIT: usize = 1024;
    let trimmed = body.trim();
    // The character at index LIMIT exists only when the text is over the
    // limit; its byte offset is exactly where the kept prefix ends.
    match trimmed.char_indices().nth(LIMIT) {
        None => trimmed.to_owned(),
        Some((cut, _)) => {
            let mut shortened = String::from(&trimmed[..cut]);
            shortened.push_str("... <truncated>");
            shortened
        }
    }
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
Expand Down
33 changes: 30 additions & 3 deletions crates/core/src/observability/plugin_component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@ use crate::api::subscriber::{
use crate::error::FlowError;
use crate::observability::atif::{AtifAgentInfo, AtifExporter};
use crate::observability::atof::{
AtofEndpointConfig as CoreAtofEndpointConfig, AtofEndpointTransport, AtofExporter,
AtofExporterConfig as CoreAtofExporterConfig, AtofExporterMode,
AtofEndpointConfig as CoreAtofEndpointConfig, AtofEndpointFieldNamePolicy,
AtofEndpointTransport, AtofExporter, AtofExporterConfig as CoreAtofExporterConfig,
AtofExporterMode,
};
#[cfg(feature = "openinference")]
use crate::observability::openinference::{
Expand Down Expand Up @@ -199,6 +200,9 @@ pub struct AtofEndpointSectionConfig {
/// Per-endpoint timeout in milliseconds.
#[serde(default = "default_timeout_millis")]
pub timeout_millis: u64,
/// Field name policy applied before sending events: `preserve` or `replace_dots`.
#[serde(default = "default_atof_endpoint_field_name_policy")]
pub field_name_policy: String,
}

/// Per-trajectory ATIF exporter config.
Expand Down Expand Up @@ -660,8 +664,15 @@ fn build_atof_endpoint_config(
"ATOF endpoints[{index}].transport must be 'http_post', 'websocket', or 'ndjson'"
))
})?;
let field_name_policy = AtofEndpointFieldNamePolicy::parse(&endpoint.field_name_policy)
.ok_or_else(|| {
PluginError::InvalidConfig(format!(
"ATOF endpoints[{index}].field_name_policy must be 'preserve' or 'replace_dots'"
))
})?;
let mut config = CoreAtofEndpointConfig::new(endpoint.url, transport)
.with_timeout_millis(endpoint.timeout_millis);
.with_timeout_millis(endpoint.timeout_millis)
.with_field_name_policy(field_name_policy);
for (key, value) in endpoint.headers {
config = config.with_header(key, value);
}
Expand Down Expand Up @@ -1848,6 +1859,18 @@ fn validate_atof_endpoint_values(
format!("ATOF endpoints[{index}].timeout_millis must be greater than 0"),
);
}
if AtofEndpointFieldNamePolicy::parse(&endpoint.field_name_policy).is_none() {
push_policy_diag(
diagnostics,
policy.unsupported_value,
"observability.unsupported_value",
Some("atof".to_string()),
Some(format!("endpoints[{index}].field_name_policy")),
format!(
"ATOF endpoints[{index}].field_name_policy must be 'preserve' or 'replace_dots'"
),
);
}
}

#[cfg(all(feature = "atof-streaming", not(target_arch = "wasm32")))]
Expand Down Expand Up @@ -2176,6 +2199,10 @@ fn default_atof_endpoint_transport() -> String {
"http_post".to_string()
}

/// Serde default for an endpoint's `field_name_policy`: keep field names as is.
fn default_atof_endpoint_field_name_policy() -> String {
    String::from("preserve")
}

/// Agent name used when the configuration does not supply one.
fn default_agent_name() -> String {
    String::from("NeMo Relay")
}
Expand Down
Loading
Loading