diff --git a/.github/workflows/node.yaml b/.github/workflows/node.yaml index 4d409652..eb6a6a64 100644 --- a/.github/workflows/node.yaml +++ b/.github/workflows/node.yaml @@ -111,7 +111,7 @@ jobs: yarn build --target aarch64-unknown-linux-musl /aarch64-linux-musl-cross/bin/aarch64-linux-musl-strip *.node - name: stable - ${{ matrix.settings.target }} - node@18 + name: stable - ${{ matrix.settings.target }} - node@22 runs-on: ${{ matrix.settings.host }} defaults: run: @@ -124,7 +124,7 @@ jobs: uses: actions/setup-node@v3 if: ${{ !matrix.settings.docker }} with: - node-version: 18 + node-version: 22 check-latest: true cache: yarn cache-dependency-path: 'bindings/nodejs/yarn.lock' @@ -191,7 +191,7 @@ jobs: target: 'aarch64-apple-darwin' - host: windows-latest target: 'x86_64-pc-windows-msvc' - node: [ '18', '20', '22' ] + node: [ '20', '22' ] runs-on: ${{ matrix.settings.host }} steps: @@ -235,7 +235,7 @@ jobs: strategy: fail-fast: false matrix: - node: [ '18', '20', '22' ] + node: [ '20', '22' ] runs-on: ubuntu-latest steps: @@ -286,7 +286,7 @@ jobs: strategy: fail-fast: false matrix: - node: [ '18', '20', '22' ] + node: [ '20', '22' ] runs-on: ubuntu-latest steps: @@ -337,7 +337,7 @@ jobs: strategy: fail-fast: false matrix: - node: [ '18', '20', '22' ] + node: [ '20', '22' ] runs-on: ubuntu-latest steps: @@ -457,7 +457,7 @@ jobs: - name: Setup node uses: actions/setup-node@v3 with: - node-version: 18 + node-version: 22 check-latest: true cache: yarn cache-dependency-path: 'bindings/nodejs/yarn.lock' diff --git a/bindings/c/src/custom_node.rs b/bindings/c/src/custom_node.rs index 26140484..e98e2758 100644 --- a/bindings/c/src/custom_node.rs +++ b/bindings/c/src/custom_node.rs @@ -49,11 +49,11 @@ impl ZenCustomNodeResult { if let Some(c_error) = maybe_error { let maybe_str = c_error.to_str().unwrap_or("unknown error"); - return Err(anyhow!("{maybe_str}")); + return Err(anyhow!("{maybe_str}").into()); } if self.content.is_null() { - return Err(anyhow!("response not provided")); + return Err(anyhow!("response not provided").into()); } let content_cstr = unsafe { CString::from_raw(self.content) }; diff --git a/bindings/c/src/languages/go.rs b/bindings/c/src/languages/go.rs index 8d48e4bd..1a0541d3 100644 --- a/bindings/c/src/languages/go.rs +++ b/bindings/c/src/languages/go.rs @@ -52,11 +52,11 @@ impl GoCustomNode { impl CustomNodeAdapter for GoCustomNode { async fn handle(&self, request: CustomNodeRequest) -> NodeResult { let Some(handler) = self.handler else { - return Err(anyhow!("go handler not found")); + return Err(anyhow!("go handler not found").into()); }; let Ok(request_value) = serde_json::to_string(&request) else { - return Err(anyhow!("failed to serialize request json")); + return Err(anyhow!("failed to serialize request json").into()); }; let c_request = unsafe { CString::from_vec_unchecked(request_value.into_bytes()) }; diff --git a/bindings/c/src/languages/native.rs b/bindings/c/src/languages/native.rs index e2a7368e..93de29ec 100644 --- a/bindings/c/src/languages/native.rs +++ b/bindings/c/src/languages/native.rs @@ -51,7 +51,7 @@ impl NativeCustomNode { impl CustomNodeAdapter for NativeCustomNode { async fn handle(&self, request: CustomNodeRequest) -> NodeResult { let Ok(request_value) = serde_json::to_string(&request) else { - return Err(anyhow!("failed to serialize request json")); + return Err(anyhow!("failed to serialize request json").into()); }; let c_request = unsafe { CString::from_vec_unchecked(request_value.into_bytes()) }; diff --git a/bindings/c/src/result.rs b/bindings/c/src/result.rs index 4ea82d8c..47644f80 100644 --- a/bindings/c/src/result.rs +++ b/bindings/c/src/result.rs @@ -71,9 +71,9 @@ impl From<&IsolateError> for ZenResult { } } -impl From<&Box> for ZenResult { - fn from(loader_error: &Box) -> Self { - match loader_error.as_ref() { +impl From<&LoaderError> for ZenResult { + fn from(loader_error: &LoaderError) -> Self { + match loader_error { LoaderError::NotFound(key) => { ZenResult::error(ZenError::LoaderKeyNotFound { key: key.clone() }) } diff --git a/bindings/nodejs/index.d.ts b/bindings/nodejs/index.d.ts index 0109f91e..111a624e 100644 --- a/bindings/nodejs/index.d.ts +++ b/bindings/nodejs/index.d.ts @@ -5,22 +5,23 @@ export interface ZenConfig { nodesInContext?: boolean + functionTimeoutMillis?: number } -export function overrideConfig(config: ZenConfig): void +export declare function overrideConfig(config: ZenConfig): void export interface ZenEvaluateOptions { maxDepth?: number - trace?: boolean + trace?: boolean | 'string' | 'reference' | 'referenceString' } export interface ZenEngineOptions { loader?: (key: string) => Promise customHandler?: (request: ZenEngineHandlerRequest) => Promise } -export function evaluateExpressionSync(expression: string, context?: any | undefined | null): any -export function evaluateUnaryExpressionSync(expression: string, context: any): boolean -export function renderTemplateSync(template: string, context: any): any -export function evaluateExpression(expression: string, context?: any | undefined | null): Promise -export function evaluateUnaryExpression(expression: string, context: any): Promise -export function renderTemplate(template: string, context: any): Promise +export declare function evaluateExpressionSync(expression: string, context?: any | undefined | null): any +export declare function evaluateUnaryExpressionSync(expression: string, context: any): boolean +export declare function renderTemplateSync(template: string, context: any): any +export declare function evaluateExpression(expression: string, context?: any | undefined | null): Promise +export declare function evaluateUnaryExpression(expression: string, context: any): Promise +export declare function renderTemplate(template: string, context: any): Promise export interface ZenEngineTrace { id: string name: string @@ -45,17 +46,17 @@ export interface DecisionNode { kind: string config: any } -export class ZenDecisionContent { +export declare class ZenDecisionContent { constructor(content: Buffer | object) toBuffer(): Buffer } -export class ZenDecision { +export declare class ZenDecision { constructor() evaluate(context: any, opts?: ZenEvaluateOptions | undefined | null): Promise safeEvaluate(context: any, opts?: ZenEvaluateOptions | undefined | null): Promise> validate(): void } -export class ZenEngine { +export declare class ZenEngine { constructor(options?: ZenEngineOptions | undefined | null) evaluate(key: string, context: any, opts?: ZenEvaluateOptions | undefined | null): Promise createDecision(content: ZenDecisionContent | Buffer | object): ZenDecision @@ -68,23 +69,10 @@ export class ZenEngine { */ dispose(): void } -export class ZenEngineHandlerRequest { +export declare class ZenEngineHandlerRequest { input: any node: DecisionNode constructor() getField(path: string): unknown getFieldRaw(path: string): unknown } - -// Custom definitions -type SafeResultSuccess = { - success: true; - data: T; -} - -type SafeResultError = { - success: false; - error: any; -} - -export type SafeResult = SafeResultSuccess | SafeResultError; \ No newline at end of file diff --git a/bindings/nodejs/package.json b/bindings/nodejs/package.json index 32ecffc3..d96112f7 100644 --- a/bindings/nodejs/package.json +++ b/bindings/nodejs/package.json @@ -55,6 +55,7 @@ "@types/express": "^5.0.1", "@types/node": "^22.14.1", "babel-jest": "^29.7.0", + "cross-env": "^10.0.0", "express": "^5.1.0", "jest": "^29.7.0", "lerna": "6", @@ -78,7 +79,7 @@ "build": "napi build --dts temp.d.ts --platform --release", "build:debug": "napi build --platform --js index.js --dts index.d.ts", "watch": "cargo watch --ignore '{index.js,index.d.ts}' -- npm run build:debug", - "test": "jest", + "test": "cross-env __ZEN_MOCK_UTC_TIME=2025-08-19T16:55:02.078Z jest", "artifacts": "napi artifacts -d ../../artifacts", "prepublishOnly": "napi prepublish", "version": "napi version" diff --git a/bindings/nodejs/src/custom_node.rs b/bindings/nodejs/src/custom_node.rs index dedec446..b3d11060 100644 --- a/bindings/nodejs/src/custom_node.rs +++ b/bindings/nodejs/src/custom_node.rs @@ -2,10 +2,10 @@ use napi::anyhow::anyhow; use napi::bindgen_prelude::Promise; use napi::threadsafe_function::{ErrorStrategy, ThreadsafeFunction}; +use crate::types::{ZenEngineHandlerRequest, ZenEngineHandlerResponse}; use zen_engine::handler::custom_node_adapter::{CustomNodeAdapter, CustomNodeRequest}; use zen_engine::handler::node::{NodeResponse, NodeResult}; - -use crate::types::{ZenEngineHandlerRequest, ZenEngineHandlerResponse}; +use zen_engine::Variable; #[derive(Default)] pub(crate) struct CustomNode { @@ -23,7 +23,7 @@ impl CustomNode { impl CustomNodeAdapter for CustomNode { async fn handle(&self, request: CustomNodeRequest) -> NodeResult { let Some(function) = &self.function else { - return Err(anyhow!("Custom function is undefined")); + return Err(anyhow!("Custom function is undefined").into()); }; let node_data = crate::types::DecisionNode::from(request.node); @@ -41,7 +41,7 @@ impl CustomNodeAdapter for CustomNode { Ok(NodeResponse { output: result.output.into(), - trace_data: result.trace_data, + trace_data: result.trace_data.map(Variable::from), }) } } diff --git a/bindings/nodejs/src/decision.rs b/bindings/nodejs/src/decision.rs index 0e7a0a18..98cc1924 100644 --- a/bindings/nodejs/src/decision.rs +++ b/bindings/nodejs/src/decision.rs @@ -3,12 +3,11 @@ use crate::engine::ZenEvaluateOptions; use crate::loader::DecisionLoader; use crate::mt::spawn_worker; use crate::safe_result::SafeResult; -use crate::types::ZenEngineResponse; use napi::anyhow::anyhow; use napi_derive::napi; use serde_json::Value; use std::sync::Arc; -use zen_engine::{Decision, EvaluationOptions}; +use zen_engine::{Decision, EvaluationSerializedOptions}; #[napi] pub struct ZenDecision(pub(crate) Arc>); @@ -26,34 +25,31 @@ impl ZenDecision { Err(anyhow!("Private constructor").into()) } - #[napi] + #[napi(ts_return_type = "Promise")] pub async fn evaluate( &self, context: Value, opts: Option, - ) -> napi::Result { + ) -> napi::Result { let decision = self.0.clone(); let result = spawn_worker(move || { let options = opts.unwrap_or_default(); async move { decision - .evaluate_with_opts( + .evaluate_serialized( context.into(), - EvaluationOptions { + EvaluationSerializedOptions { max_depth: options.max_depth, - trace: options.trace, + trace: options.trace.unwrap_or_default().0, }, ) .await - .map(ZenEngineResponse::from) } }) .await .map_err(|_| anyhow!("Hook timed out"))? - .map_err(|e| { - anyhow!(serde_json::to_string(e.as_ref()).unwrap_or_else(|_| e.to_string())) - })?; + .map_err(|e| anyhow!(e))?; Ok(result) } @@ -63,7 +59,7 @@ impl ZenDecision { &self, context: Value, opts: Option, - ) -> SafeResult { + ) -> SafeResult { self.evaluate(context, opts).await.into() } diff --git a/bindings/nodejs/src/engine.rs b/bindings/nodejs/src/engine.rs index 2b2cd702..88961d9c 100644 --- a/bindings/nodejs/src/engine.rs +++ b/bindings/nodejs/src/engine.rs @@ -1,22 +1,23 @@ +use std::str::FromStr; use std::sync::Arc; use napi::anyhow::{anyhow, Context}; -use napi::bindgen_prelude::{Buffer, Either3}; +use napi::bindgen_prelude::{Buffer, Either3, FromNapiValue, ToNapiValue}; +use napi::sys::{napi_env, napi_value}; use napi::threadsafe_function::{ErrorStrategy, ThreadSafeCallContext, ThreadsafeFunction}; -use napi::{Env, JsFunction, JsObject}; +use napi::{Env, JsFunction, JsObject, JsUnknown, NapiValue, ValueType}; use napi_derive::napi; use serde_json::Value; -use zen_engine::model::DecisionContent; -use zen_engine::{DecisionEngine, EvaluationOptions}; - use crate::content::ZenDecisionContent; use crate::custom_node::CustomNode; use crate::decision::ZenDecision; use crate::loader::DecisionLoader; use crate::mt::spawn_worker; use crate::safe_result::SafeResult; -use crate::types::{ZenEngineHandlerRequest, ZenEngineResponse}; +use crate::types::ZenEngineHandlerRequest; +use zen_engine::model::DecisionContent; +use zen_engine::{DecisionEngine, EvaluationSerializedOptions, EvaluationTraceKind}; #[napi] pub struct ZenEngine { @@ -26,17 +27,63 @@ pub struct ZenEngine { custom_handler_ref: Option>, } +#[derive(Debug, Default)] +pub struct JsEvaluationTraceKind(pub EvaluationTraceKind); + +impl FromNapiValue for JsEvaluationTraceKind { + unsafe fn from_napi_value(env: napi_env, napi_val: napi_value) -> napi::Result { + let js_value = JsUnknown::from_raw(env, napi_val)?; + + match js_value.get_type()? { + ValueType::Undefined | ValueType::Null => Ok(JsEvaluationTraceKind::default()), + ValueType::Boolean => { + let enabled = js_value.coerce_to_bool()?.get_value()?; + let kind = match enabled { + true => EvaluationTraceKind::Default, + false => EvaluationTraceKind::None, + }; + + Ok(JsEvaluationTraceKind(kind)) + } + ValueType::String => { + let kind_utf8 = js_value.coerce_to_string()?.into_utf8()?; + let kind_str = kind_utf8.as_str()?; + let kind = + EvaluationTraceKind::from_str(kind_str).context("invalid evaluation mode")?; + + Ok(JsEvaluationTraceKind(kind)) + } + _ => Err(anyhow!("Invalid trace setting").into()), + } + } +} + +impl ToNapiValue for JsEvaluationTraceKind { + unsafe fn to_napi_value(env: napi_env, val: Self) -> napi::Result { + match val.0 { + EvaluationTraceKind::None => ToNapiValue::to_napi_value(env, false), + EvaluationTraceKind::Default => ToNapiValue::to_napi_value(env, true), + _ => { + let mode_str: &'static str = val.0.into(); + ToNapiValue::to_napi_value(env, mode_str) + } + } + } +} + +#[derive(Debug)] #[napi(object)] pub struct ZenEvaluateOptions { pub max_depth: Option, - pub trace: Option, + #[napi(ts_type = "boolean | 'string' | 'reference' | 'referenceString'")] + pub trace: Option, } impl Default for ZenEvaluateOptions { fn default() -> Self { Self { max_depth: Some(5), - trace: Some(false), + trace: Some(JsEvaluationTraceKind::default()), } } } @@ -103,36 +150,33 @@ impl ZenEngine { }) } - #[napi] + #[napi(ts_return_type = "Promise")] pub async fn evaluate( &self, key: String, context: Value, opts: Option, - ) -> napi::Result { + ) -> napi::Result { let graph = self.graph.clone(); let result = spawn_worker(|| { let options = opts.unwrap_or_default(); async move { graph - .evaluate_with_opts( + .evaluate_serialized( key, context.into(), - EvaluationOptions { + EvaluationSerializedOptions { max_depth: options.max_depth, - trace: options.trace, + trace: options.trace.unwrap_or_default().0, }, ) .await - .map(ZenEngineResponse::from) } }) .await .map_err(|_| anyhow!("Hook timed out"))? - .map_err(|e| { - anyhow!(serde_json::to_string(e.as_ref()).unwrap_or_else(|_| e.to_string())) - })?; + .map_err(|e| anyhow!(e))?; Ok(result) } @@ -173,7 +217,7 @@ impl ZenEngine { key: String, context: Value, opts: Option, - ) -> SafeResult { + ) -> SafeResult { self.evaluate(key, context, opts).await.into() } diff --git a/bindings/nodejs/src/types.rs b/bindings/nodejs/src/types.rs index c55826ef..14a423ce 100644 --- a/bindings/nodejs/src/types.rs +++ b/bindings/nodejs/src/types.rs @@ -9,6 +9,7 @@ use zen_engine::handler::custom_node_adapter::CustomDecisionNode; use zen_engine::{DecisionGraphResponse, DecisionGraphTrace}; use zen_expression::Variable; +#[allow(dead_code)] #[napi(object)] pub struct ZenEngineTrace { pub id: String, @@ -28,12 +29,13 @@ impl From for ZenEngineTrace { input: value.input.to_value(), output: value.output.to_value(), performance: value.performance, - trace_data: value.trace_data, + trace_data: value.trace_data.map(Value::from), order: value.order, } } } +#[allow(dead_code)] #[napi(object)] pub struct ZenEngineResponse { pub performance: String, diff --git a/bindings/nodejs/test/decision.spec.ts b/bindings/nodejs/test/decision.spec.ts index 7b746449..1dccdec8 100644 --- a/bindings/nodejs/test/decision.spec.ts +++ b/bindings/nodejs/test/decision.spec.ts @@ -4,59 +4,28 @@ import { evaluateUnaryExpression, renderTemplate, evaluateExpressionSync, - evaluateUnaryExpressionSync, renderTemplateSync, ZenDecisionContent -} from "../index"; + evaluateUnaryExpressionSync, renderTemplateSync, ZenDecisionContent, +} from '../index'; import fs from 'fs/promises'; import path from 'path'; -import {describe, expect, it, jest} from "@jest/globals"; -import assert from "assert"; +import { describe, expect, it, jest } from '@jest/globals'; +import assert from 'assert'; const testDataRoot = path.join(__dirname, '../../../', 'test-data'); -const loader = async (key: string) => fs.readFile(path.join(testDataRoot, key)) +const loader = async (key: string) => fs.readFile(path.join(testDataRoot, key)); jest.useRealTimers(); -interface PropertyMatcher { - [key: string]: any; -} - -const defaultMatchers: PropertyMatcher = { - timestamp: expect.any(Number), - estimatedArrival: expect.any(Number), - approvalDate: expect.any(Number), -}; - -function addJestMatchers(obj: any, matchers: PropertyMatcher = defaultMatchers): any { - if (obj === null || typeof obj !== 'object') { - return obj; - } - - if (Array.isArray(obj)) { - return obj.map((item: any) => addJestMatchers(item, matchers)); - } - - const result: Record = {}; - for (const [key, value] of Object.entries(obj)) { - if (matchers[key]) { - result[key] = matchers[key]; - } else { - result[key] = addJestMatchers(value, matchers); - } - } - - return result; -} - describe('ZenEngine', () => { it('Evaluates decisions using loader', async () => { const engine = new ZenEngine({ - loader + loader, }); - const r1 = await engine.evaluate('function.json', {input: 5}); - const r2 = await engine.evaluate('table.json', {input: 2}); - const r3 = await engine.evaluate('table.json', {input: 12}); + const r1 = await engine.evaluate('function.json', { input: 5 }); + const r2 = await engine.evaluate('table.json', { input: 2 }); + const r3 = await engine.evaluate('table.json', { input: 12 }); expect(r1.result.output).toEqual(10); expect(r2.result.output).toEqual(0); @@ -73,9 +42,9 @@ describe('ZenEngine', () => { const functionDecision = await engine.getDecision('function.json'); const tableDecision = await engine.getDecision('table.json'); - const r1 = await functionDecision.evaluate({input: 10}); - const r2 = await tableDecision.evaluate({input: 5}); - const r3 = await tableDecision.evaluate({input: 12}); + const r1 = await functionDecision.evaluate({ input: 10 }); + const r2 = await tableDecision.evaluate({ input: 5 }); + const r3 = await tableDecision.evaluate({ input: 12 }); expect(r1.result.output).toEqual(20); expect(r2.result.output).toEqual(0); @@ -89,9 +58,9 @@ describe('ZenEngine', () => { const functionContent = await fs.readFile(path.join(testDataRoot, 'function.json')); const functionDecision = engine.createDecision(functionContent); - const r = await functionDecision.evaluate({input: 15}); + const r = await functionDecision.evaluate({ input: 15 }); expect(r.result.output).toEqual(30); - }, 10000) + }, 10000); it('Evaluate custom nodes with a handler', async () => { const engine = new ZenEngine({ @@ -101,20 +70,20 @@ describe('ZenEngine', () => { const prop1Raw = request.getFieldRaw('prop1'); expect(prop1).toEqual(15); - expect(prop1Raw).toEqual('{{ a + 10 }}') + expect(prop1Raw).toEqual('{{ a + 10 }}'); expect(request.node).toMatchObject({ id: '138b3b11-ff46-450f-9704-3f3c712067b2', name: 'customNode1', kind: 'sum', config: { - prop1: '{{ a + 10 }}' - } + prop1: '{{ a + 10 }}', + }, }); - return {output: {data: prop1 + 10}} - } + return { output: { data: prop1 + 10 } }; + }, }); - const r = await engine.evaluate('custom.json', {a: 5}); + const r = await engine.evaluate('custom.json', { a: 5 }); expect(r.result.data).toEqual(25); engine.dispose(); @@ -140,7 +109,7 @@ describe('ZenEngine', () => { const graphsRoot = path.join(testDataRoot, 'graphs'); const loader = async (key: string) => fs.readFile(path.join(graphsRoot, key)); - const engine = new ZenEngine({loader}); + const engine = new ZenEngine({ loader }); const entries = await fs.readdir(graphsRoot); for (const entry of entries) { @@ -159,28 +128,27 @@ describe('ZenEngine', () => { assert.ok(engineResponse.success, 'Engine response must be ok'); assert.ok(decisionResponse.success, 'Decision response must be ok'); - const expectedObject = addJestMatchers(testCase.output); - expect(engineResponse.data.result).toMatchObject(expectedObject); - expect(decisionResponse.data.result).toMatchObject(expectedObject); + expect(engineResponse.data.result).toMatchObject(testCase.output); + expect(decisionResponse.data.result).toMatchObject(testCase.output); } } engine.dispose(); - }) -}) + }); +}); describe('Expressions', () => { it('Evaluates standard expressions', async () => { const expressions = [ - {expression: '1 + 1', result: 2}, - {expression: 'a > b', context: {a: 5, b: 3}, result: true}, - {expression: 'sum(a)', context: {a: [1, 2, 3, 4]}, result: 10}, - {expression: 'contains("some", "none")', result: false}, - {expression: 'matches("test@email.com", "\\w+@\\w+\\.com")', result: true}, + { expression: '1 + 1', result: 2 }, + { expression: 'a > b', context: { a: 5, b: 3 }, result: true }, + { expression: 'sum(a)', context: { a: [1, 2, 3, 4] }, result: 10 }, + { expression: 'contains("some", "none")', result: false }, + { expression: 'matches("test@email.com", "\\w+@\\w+\\.com")', result: true }, ]; - for (const {expression, result, context} of expressions) { + for (const { expression, result, context } of expressions) { expect(await evaluateExpression(expression, context)).toEqual(result); expect(evaluateExpressionSync(expression, context)).toEqual(result); } @@ -188,13 +156,13 @@ describe('Expressions', () => { it('Evaluates unary expressions', async () => { const expressions = [ - {expression: '>= 5', context: {$: 5}, result: true}, - {expression: '< 5', context: {$: 5}, result: false}, - {expression: '"FR", "ES"', context: {$: 'GB'}, result: false}, - {expression: 'contains($, "some")', context: {$: 'some-string'}, result: true}, + { expression: '>= 5', context: { $: 5 }, result: true }, + { expression: '< 5', context: { $: 5 }, result: false }, + { expression: '"FR", "ES"', context: { $: 'GB' }, result: false }, + { expression: 'contains($, "some")', context: { $: 'some-string' }, result: true }, ]; - for (const {expression, result, context} of expressions) { + for (const { expression, result, context } of expressions) { expect(await evaluateUnaryExpression(expression, context)).toEqual(result); expect(evaluateUnaryExpressionSync(expression, context)).toEqual(result); } @@ -202,13 +170,13 @@ describe('Expressions', () => { it('Renders templates', async () => { const templateCases = [ - {template: '{{ a + 10 }}', context: {a: 10}, result: 20}, - {template: '{{ a + 10 }}', context: {a: 15}, result: 25}, - {template: '{{ a + 10 }}', context: {a: 20}, result: 30}, - {template: '{{ a + 10 }}', context: {a: 25}, result: 35}, + { template: '{{ a + 10 }}', context: { a: 10 }, result: 20 }, + { template: '{{ a + 10 }}', context: { a: 15 }, result: 25 }, + { template: '{{ a + 10 }}', context: { a: 20 }, result: 30 }, + { template: '{{ a + 10 }}', context: { a: 25 }, result: 35 }, ]; - for (const {template, context, result} of templateCases) { + for (const { template, context, result } of templateCases) { expect(await renderTemplate(template, context)).toEqual(result); expect(renderTemplateSync(template, context)).toEqual(result); } diff --git a/bindings/nodejs/yarn.lock b/bindings/nodejs/yarn.lock index 9897d069..8296aa14 100644 --- a/bindings/nodejs/yarn.lock +++ b/bindings/nodejs/yarn.lock @@ -285,6 +285,11 @@ dependencies: "@jridgewell/trace-mapping" "0.3.9" +"@epic-web/invariant@^1.0.0": + version "1.0.0" + resolved "https://registry.yarnpkg.com/@epic-web/invariant/-/invariant-1.0.0.tgz#1073e5dee6dd540410784990eb73e4acd25c9813" + integrity sha512-lrTPqgvfFQtR/eY/qkIzp98OGdNJu0m5ji3q/nJI8v3SXkRKEnWiOxMmbvcSoAIzv/cGiuvRy57k4suKQSAdwA== + "@gar/promisify@^1.1.3": version "1.1.3" resolved "https://registry.yarnpkg.com/@gar/promisify/-/promisify-1.1.3.tgz#555193ab2e3bb3b6adc3d551c9c030d9e860daf6" @@ -2167,6 +2172,14 @@ create-require@^1.1.0: resolved "https://registry.yarnpkg.com/create-require/-/create-require-1.1.1.tgz#c1d7e8f1e5f6cfc9ff65f9cd352d37348756c333" integrity sha512-dcKFX3jn0MpIaXjisoRvexIJVEKzaq7z2rZKxf+MSr9TkdmHmsU4m2lcLojrj/FHl8mk5VxMmYA+ftRkP/3oKQ== +cross-env@^10.0.0: + version "10.0.0" + resolved "https://registry.yarnpkg.com/cross-env/-/cross-env-10.0.0.tgz#ba25823cfa1ed6af293dcded8796fa16cd162456" + integrity sha512-aU8qlEK/nHYtVuN4p7UQgAwVljzMg8hB4YK5ThRqD2l/ziSnryncPNn7bMLt5cFYsKVKBh8HqLqyCoTupEUu7Q== + dependencies: + "@epic-web/invariant" "^1.0.0" + cross-spawn "^7.0.6" + cross-spawn@^7.0.3, cross-spawn@^7.0.6: version "7.0.6" resolved "https://registry.yarnpkg.com/cross-spawn/-/cross-spawn-7.0.6.tgz#8a58fe78f00dcd70c370451759dfbfaf03e8ee9f" diff --git a/bindings/python/src/custom_node.rs b/bindings/python/src/custom_node.rs index 75093c7c..cfe529b2 100644 --- a/bindings/python/src/custom_node.rs +++ b/bindings/python/src/custom_node.rs @@ -1,4 +1,4 @@ -use anyhow::anyhow; +use anyhow::{anyhow, Context}; use either::Either; use pyo3::types::PyDict; use pyo3::{Bound, IntoPyObjectExt, Py, PyAny, PyObject, PyResult, Python}; @@ -26,15 +26,17 @@ impl PyCustomNode { } fn extract_custom_node_response(py: Python<'_>, result: PyObject) -> NodeResult { - let dict = result.extract::>(py)?; - let response: NodeResponse = depythonize(&dict)?; + let dict = result + .extract::>(py) + .context("Failed to extract response")?; + let response: NodeResponse = depythonize(&dict).context("Failed to depythonize response")?; Ok(response) } impl CustomNodeAdapter for PyCustomNode { async fn handle(&self, request: CustomNodeRequest) -> NodeResult { let Some(callable) = &self.callback else { - return Err(anyhow!("Custom node handler not provided")); + return Err(anyhow!("Custom node handler not provided").into()); }; let maybe_result: PyResult<_> = Python::with_gil(|py| { @@ -57,10 +59,10 @@ impl CustomNodeAdapter for PyCustomNode { Ok(Either::Right(result_future)) }); - match maybe_result? { + match maybe_result.context("Failed to run custom node handler")? { Either::Left(result) => result, Either::Right(future) => { - let result = future.await?; + let result = future.await.context("Failed to run custom node handler")?; Python::with_gil(|py| extract_custom_node_response(py, result)) } } diff --git a/bindings/python/src/decision.rs b/bindings/python/src/decision.rs index 92cecfee..0a9e9ed8 100644 --- a/bindings/python/src/decision.rs +++ b/bindings/python/src/decision.rs @@ -77,12 +77,12 @@ impl PyZenDecision { ) .await .map(serde_json::to_value) + .map_err(|e| { + anyhow!(serde_json::to_string(&e).unwrap_or_else(|_| e.to_string())) + }) }) .await - .context("Failed to join threads")? - .map_err(|e| { - anyhow!(serde_json::to_string(e.as_ref()).unwrap_or_else(|_| e.to_string())) - })? + .context("Failed to join threads")?? .context("Failed to serialize result")?; Python::with_gil(|py| PyValue(value).into_py_any(py)) diff --git a/bindings/python/src/engine.rs b/bindings/python/src/engine.rs index 455229ec..838f31a6 100644 --- a/bindings/python/src/engine.rs +++ b/bindings/python/src/engine.rs @@ -152,12 +152,14 @@ impl PyZenEngine { ) .await .map(serde_json::to_value) + .map_err(|e| { + anyhow!( + serde_json::to_string(e.as_ref()).unwrap_or_else(|_| e.to_string()) + ) + }) }) .await - .context("Failed to join threads")? - .map_err(|e| { - anyhow!(serde_json::to_string(e.as_ref()).unwrap_or_else(|_| e.to_string())) - })? + .context("Failed to join threads")?? .context("Failed to serialize result")?; Python::with_gil(|py| PyValue(value).into_py_any(py)) diff --git a/bindings/python/src/types.rs b/bindings/python/src/types.rs index 9537ae14..ec945278 100644 --- a/bindings/python/src/types.rs +++ b/bindings/python/src/types.rs @@ -102,7 +102,7 @@ impl From for PyNodeResponse { fn from(value: NodeResponse) -> Self { Self { output: value.output.to_value(), - trace_data: value.trace_data, + trace_data: value.trace_data.map(|v| v.to_value()), } } } @@ -111,7 +111,7 @@ impl From for NodeResponse { fn from(value: PyNodeResponse) -> Self { Self { output: value.output.into(), - trace_data: value.trace_data, + trace_data: value.trace_data.map(|v| v.into()), } } } diff --git a/bindings/python/test_async.py b/bindings/python/test_async.py index eeddd29a..0efad8e7 100644 --- a/bindings/python/test_async.py +++ b/bindings/python/test_async.py @@ -4,9 +4,11 @@ import os.path import time import unittest +import os import zen +os.environ['__ZEN_MOCK_UTC_TIME'] = '2025-08-19T16:55:02.078Z' async def loader(key): with open("../../test-data/" + key, "r") as f: diff --git a/bindings/python/test_sync.py b/bindings/python/test_sync.py index b6f444c0..8472d6df 100644 --- a/bindings/python/test_sync.py +++ b/bindings/python/test_sync.py @@ -2,9 +2,11 @@ import os.path import unittest import glob +import os import zen +os.environ['__ZEN_MOCK_UTC_TIME'] = '2025-08-19T16:55:02.078Z' def loader(key): with open("../../test-data/" + key, "r") as f: @@ -111,8 +113,8 @@ def test_evaluate_graphs(self): decision = engine.get_decision(key) decision_response = decision.evaluate(test_case["input"]) - self.assertEqual(engine_response["result"], test_case["output"]) - self.assertEqual(decision_response["result"], test_case["output"]) + self.assertEqual(engine_response["result"], test_case["output"], key) + self.assertEqual(decision_response["result"], test_case["output"], key) if __name__ == '__main__': unittest.main() diff --git a/bindings/uniffi/src/custom_node.rs b/bindings/uniffi/src/custom_node.rs index 6f402e7c..8a287cbc 100644 --- a/bindings/uniffi/src/custom_node.rs +++ b/bindings/uniffi/src/custom_node.rs @@ -1,6 +1,5 @@ use crate::error::ZenError; use crate::types::{DecisionNode, ZenEngineHandlerRequest, ZenEngineHandlerResponse}; -use serde_json::Value; use uniffi::deps::anyhow::anyhow; use zen_engine::handler::custom_node_adapter::{CustomNodeAdapter, CustomNodeRequest}; use zen_engine::handler::node::{NodeResponse, NodeResult}; @@ -49,7 +48,8 @@ impl CustomNodeAdapter for ZenCustomNodeCallbackWrapper { .try_into() .map_err(|err: ZenError| anyhow!(err))?; - let trace_data: Option = result.trace_data.and_then(|trace| trace.try_into().ok()); + let trace_data: Option = + result.trace_data.and_then(|trace| trace.try_into().ok()); Ok(NodeResponse { output, trace_data }) } diff --git a/bindings/uniffi/src/loader.rs b/bindings/uniffi/src/loader.rs index e21b2e07..2beb4def 100644 --- a/bindings/uniffi/src/loader.rs +++ b/bindings/uniffi/src/loader.rs @@ -29,15 +29,15 @@ impl DecisionLoader for ZenDecisionLoaderCallbackWrapper { let maybe_json_buffer = match self.0.load(key.into()).await { Ok(r) => r, Err(error) => { - return Err(Box::new(LoaderError::Internal { + return Err(LoaderError::Internal { key: key.to_string(), source: anyhow!(error), - })); + }); } }; let Some(json_buffer) = maybe_json_buffer else { - return Err(Box::new(LoaderError::NotFound(key.to_string()))); + return Err(LoaderError::NotFound(key.to_string())); }; let decision_content: DecisionContent = diff --git a/core/engine/Cargo.toml b/core/engine/Cargo.toml index ef4d93ac..7c9cf892 100644 --- a/core/engine/Cargo.toml +++ b/core/engine/Cargo.toml @@ -17,6 +17,7 @@ thiserror = { workspace = true } petgraph = { workspace = true } serde_json = { workspace = true, features = ["arbitrary_precision"] } serde = { workspace = true, features = ["derive", "rc"] } +strum = { workspace = true, features = ["derive"] } once_cell = { workspace = true } json_dotpath = { workspace = true } rust_decimal = { workspace = true, features = ["maths-nopanic"] } @@ -33,7 +34,6 @@ chrono = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } criterion = { workspace = true, features = ["async_tokio"] } insta = { version = "1.43", features = ["yaml", "redactions"] } -zen-expression = { path = "../expression", version = "0.49.1", features = ["time-override"] } [[bench]] harness = false diff --git a/core/engine/src/decision.rs b/core/engine/src/decision.rs index 6a87a0a6..063a012a 100644 --- a/core/engine/src/decision.rs +++ b/core/engine/src/decision.rs @@ -1,10 +1,11 @@ -use crate::engine::EvaluationOptions; +use crate::engine::{EvaluationOptions, EvaluationSerializedOptions, EvaluationTraceKind}; use crate::handler::custom_node_adapter::{CustomNodeAdapter, NoopCustomNode}; use crate::handler::graph::{DecisionGraph, DecisionGraphConfig, DecisionGraphResponse}; use crate::loader::{CachedLoader, DecisionLoader, NoopLoader}; use crate::model::DecisionContent; use crate::util::validator_cache::ValidatorCache; use crate::{DecisionGraphValidationError, EvaluationError}; +use serde_json::Value; use std::sync::Arc; use zen_expression::variable::Variable; @@ -104,6 +105,31 @@ where Ok(response) } + pub async fn evaluate_serialized( + &self, + context: Variable, + options: EvaluationSerializedOptions, + ) -> Result { + let response = self + .evaluate_with_opts( + context, + EvaluationOptions { + trace: Some(options.trace != EvaluationTraceKind::None), + max_depth: options.max_depth, + }, + ) + .await; + + match response { + Ok(ok) => Ok(ok + .serialize_with_mode(serde_json::value::Serializer, options.trace) + .unwrap_or_default()), + Err(err) => Err(err + .serialize_with_mode(serde_json::value::Serializer, options.trace) + .unwrap_or_default()), + } + } + pub fn validate(&self) -> Result<(), DecisionGraphValidationError> { let decision_graph = DecisionGraph::try_new(DecisionGraphConfig { content: self.content.clone(), diff --git a/core/engine/src/engine.rs b/core/engine/src/engine.rs index ab9e2455..bd659921 100644 --- a/core/engine/src/engine.rs +++ b/core/engine/src/engine.rs @@ -1,12 +1,13 @@ -use std::future::Future; -use std::sync::Arc; - use crate::decision::Decision; use crate::handler::custom_node_adapter::{CustomNodeAdapter, NoopCustomNode}; use crate::handler::graph::DecisionGraphResponse; use crate::loader::{ClosureLoader, DecisionLoader, LoaderResponse, LoaderResult, NoopLoader}; use crate::model::DecisionContent; use crate::EvaluationError; +use serde_json::Value; +use std::future::Future; +use std::sync::Arc; +use strum::{EnumString, IntoStaticStr}; use zen_expression::variable::Variable; /// Structure used for generating and evaluating JDM decisions @@ -26,6 +27,41 @@ pub struct EvaluationOptions { pub max_depth: Option, } +#[derive(Debug, Default)] +pub struct EvaluationSerializedOptions { + pub trace: EvaluationTraceKind, + pub max_depth: Option, +} + +#[derive(Debug, Default, PartialEq, Eq, EnumString, IntoStaticStr)] +#[strum(serialize_all = "camelCase")] +pub enum EvaluationTraceKind { + #[default] + None, + Default, + String, + Reference, + ReferenceString, +} + +impl EvaluationTraceKind { + pub fn serialize_trace(&self, trace: &Variable) -> Value { + match self { + EvaluationTraceKind::None => Value::Null, + EvaluationTraceKind::Default => serde_json::to_value(&trace).unwrap_or_default(), + EvaluationTraceKind::String => { + Value::String(serde_json::to_string(&trace).unwrap_or_default()) + } + EvaluationTraceKind::Reference => { + serde_json::to_value(&trace.serialize_ref()).unwrap_or_default() + } + EvaluationTraceKind::ReferenceString => { + Value::String(serde_json::to_string(&trace.serialize_ref()).unwrap_or_default()) + } + } + } +} + impl Default for DecisionEngine { fn default() -> Self { Self { @@ -99,6 +135,25 @@ impl DecisionEngine decision.evaluate_with_opts(context, options).await } + pub async fn evaluate_serialized( + &self, + key: K, + context: Variable, + options: EvaluationSerializedOptions, + ) -> Result + where + K: AsRef, + { + let content = self + .loader + .load(key.as_ref()) + .await + .map_err(|err| Value::String(err.to_string()))?; + + let decision = self.create_decision(content); + decision.evaluate_serialized(context, options).await + } + /// Creates a decision from DecisionContent, exists for easier binding creation pub fn create_decision(&self, content: Arc) -> Decision { Decision::from(content) diff --git a/core/engine/src/error.rs b/core/engine/src/error.rs index 4cd79486..b49c4cc4 100644 --- a/core/engine/src/error.rs +++ b/core/engine/src/error.rs @@ -1,5 +1,6 @@ +use crate::engine::EvaluationTraceKind; use crate::handler::graph::DecisionGraphValidationError; -use crate::handler::node::NodeError; +pub use crate::handler::node::NodeError; use crate::loader::LoaderError; use jsonschema::{ErrorIterator, ValidationError}; use serde::ser::SerializeMap; @@ -11,43 +12,65 @@ use thiserror::Error; #[derive(Debug, Error)] pub enum EvaluationError { #[error("Loader error")] - LoaderError(Box), + LoaderError(LoaderError), #[error("Node error")] - NodeError(Box), + NodeError(NodeError), #[error("Depth limit exceeded")] DepthLimitExceeded, #[error("Invalid graph")] - InvalidGraph(Box), + InvalidGraph(DecisionGraphValidationError), #[error("Validation failed")] - Validation(Box), + Validation(Value), } -impl Serialize for EvaluationError { - fn serialize(&self, serializer: S) -> Result +impl EvaluationError { + pub fn serialize_with_mode( + &self, + serializer: S, + mode: EvaluationTraceKind, + ) -> Result where S: Serializer, { let mut map = serializer.serialize_map(None)?; + match self { EvaluationError::DepthLimitExceeded => { map.serialize_entry("type", "DepthLimitExceeded")?; } EvaluationError::NodeError(err) => { map.serialize_entry("type", "NodeError")?; - map.serialize_entry("nodeId", &err.node_id)?; - map.serialize_entry("source", &err.source.to_string())?; - if let Some(trace) = &err.trace { - map.serialize_entry("trace", &trace)?; + match err { + NodeError::Internal => map.serialize_entry("source", "Internal")?, + NodeError::Other(o) => map.serialize_entry("source", &o.to_string())?, + NodeError::Display(d) => map.serialize_entry("source", d.as_str())?, + NodeError::Node { + node_id, + source, + trace, + } => { + map.serialize_entry("nodeId", node_id.as_str())?; + map.serialize_entry("source", &source.to_string())?; + if let Some(trace) = &trace { + map.serialize_entry("trace", &mode.serialize_trace(trace))?; + } + } + NodeError::PartialTrace { trace, message } => { + map.serialize_entry("source", message.as_str())?; + if let Some(trace) = &trace { + map.serialize_entry("trace", &mode.serialize_trace(trace))?; + } + } } } EvaluationError::LoaderError(err) => { map.serialize_entry("type", "LoaderError")?; - match err.as_ref() { + match err { LoaderError::Internal { key, source } => { map.serialize_entry("key", key)?; map.serialize_entry("source", &source.to_string())?; @@ -71,27 +94,24 @@ impl Serialize for EvaluationError { } } -impl From for Box { - fn from(error: LoaderError) -> Self { - Box::new(EvaluationError::LoaderError(error.into())) +impl Serialize for EvaluationError { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + self.serialize_with_mode(serializer, Default::default()) } } -impl From> for Box { - fn from(error: Box) -> Self { - Box::new(EvaluationError::LoaderError(error)) +impl From for Box { + fn from(error: LoaderError) -> Self { + Box::new(EvaluationError::LoaderError(error.into())) } } impl From for Box { - fn from(error: NodeError) -> Self { - Box::new(EvaluationError::NodeError(error.into())) - } -} - -impl From> for Box { - fn from(error: Box) -> Self { - Box::new(EvaluationError::NodeError(error)) + fn from(value: NodeError) -> Self { + Box::new(EvaluationError::NodeError(value)) } } @@ -127,9 +147,7 @@ impl<'a> From> for Box { serde_json::to_value(errors).unwrap_or_default(), ); - Box::new(EvaluationError::Validation(Box::new(Value::Object( - json_map, - )))) + Box::new(EvaluationError::Validation(Value::Object(json_map))) } } diff --git a/core/engine/src/handler/custom_node_adapter.rs b/core/engine/src/handler/custom_node_adapter.rs index 17b29c78..9a0711f2 100644 --- a/core/engine/src/handler/custom_node_adapter.rs +++ b/core/engine/src/handler/custom_node_adapter.rs @@ -1,6 +1,5 @@ -use crate::handler::node::{NodeRequest, NodeResult}; +use crate::handler::node::{NodeError, NodeRequest, NodeResult}; use crate::model::{DecisionNode, DecisionNodeKind}; -use anyhow::anyhow; use json_dotpath::DotPaths; use serde::Serialize; use serde_json::Value; @@ -18,7 +17,9 @@ pub struct NoopCustomNode; impl CustomNodeAdapter for NoopCustomNode { async fn handle(&self, _: CustomNodeRequest) -> NodeResult { - Err(anyhow!("Custom node handler not provided")) + Err(NodeError::Display( + "Custom node handler not provided".to_string(), + )) } } diff --git a/core/engine/src/handler/decision.rs b/core/engine/src/handler/decision.rs index 89415a19..97735d70 100644 --- a/core/engine/src/handler/decision.rs +++ b/core/engine/src/handler/decision.rs @@ -1,7 +1,7 @@ use crate::handler::custom_node_adapter::CustomNodeAdapter; use crate::handler::function::function::Function; -use crate::handler::graph::{DecisionGraph, DecisionGraphConfig}; -use crate::handler::node::{NodeRequest, NodeResponse, NodeResult}; +use crate::handler::graph::{error_trace, DecisionGraph, DecisionGraphConfig}; +use crate::handler::node::{NodeError, NodeRequest, NodeResponse, NodeResult}; use crate::loader::DecisionLoader; use crate::model::DecisionNodeKind; use crate::util::validator_cache::ValidatorCache; @@ -54,7 +54,11 @@ impl DecisionHandle _ => Err(anyhow!("Unexpected node type")), }?; - let sub_decision = self.loader.load(&content.key).await?; + let sub_decision = self + .loader + .load(&content.key) + .await + .map_err(|err| NodeError::Display(err.to_string()))?; let sub_tree = DecisionGraph::try_new(DecisionGraphConfig { content: sub_decision, max_depth: self.max_depth, @@ -63,7 +67,8 @@ impl DecisionHandle iteration: request.iteration + 1, trace: self.trace, validator_cache: Some(self.validator_cache.clone()), - })? + }) + .map_err(|err| NodeError::Display(err.to_string()))? .with_function(self.js_function.clone()); let sub_tree_mutex = Arc::new(Mutex::new(sub_tree)); @@ -77,14 +82,10 @@ impl DecisionHandle let mut sub_tree_ref = sub_tree_mutex.lock().await; sub_tree_ref.reset_graph(); - sub_tree_ref - .evaluate(input) - .await - .map(|r| NodeResponse { - output: r.result, - trace_data: serde_json::to_value(r.trace).ok(), - }) - .map_err(|e| e.source) + sub_tree_ref.evaluate(input).await.map(|r| NodeResponse { + output: r.result, + trace_data: error_trace(&r.trace), + }) } }) .await diff --git a/core/engine/src/handler/expression/mod.rs b/core/engine/src/handler/expression/mod.rs index 29af5f6d..d6caf068 100644 --- a/core/engine/src/handler/expression/mod.rs +++ b/core/engine/src/handler/expression/mod.rs @@ -1,21 +1,23 @@ -use crate::handler::node::{NodeRequest, NodeResponse, NodeResult, PartialTraceError}; +use crate::handler::node::{NodeRequest, NodeResponse, NodeResult}; use crate::model::{DecisionNodeKind, ExpressionNodeContent}; use ahash::{HashMap, HashMapExt}; +use std::rc::Rc; use std::sync::Arc; -use anyhow::{anyhow, Context}; +use crate::handler::node::NodeError; +use anyhow::anyhow; use serde::Serialize; use tokio::sync::Mutex; -use zen_expression::variable::Variable; +use zen_expression::variable::{ToVariable, Variable}; use zen_expression::Isolate; pub struct ExpressionHandler { trace: bool, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, ToVariable)] struct ExpressionTrace { - result: serde_json::Value, + result: Variable, } impl ExpressionHandler { @@ -60,7 +62,9 @@ impl<'a> ExpressionHandlerInner<'a> { async fn handle(&mut self, input: Variable, content: &'a ExpressionNodeContent) -> NodeResult { let result = Variable::empty_object(); - let mut trace_map = self.trace.then(|| HashMap::<&str, ExpressionTrace>::new()); + let mut trace_map = self + .trace + .then(|| HashMap::, ExpressionTrace>::new()); self.isolate.set_environment(input.depth_clone(1)); for expression in &content.expressions { @@ -68,21 +72,17 @@ impl<'a> ExpressionHandlerInner<'a> { continue; } - let value = self - .isolate - .run_standard(&expression.value) - .with_context(|| PartialTraceError { - trace: trace_map - .as_ref() - .map(|s| serde_json::to_value(s).ok()) - .flatten(), + let value = self.isolate.run_standard(&expression.value).map_err(|_| { + NodeError::PartialTrace { + trace: trace_map.as_ref().map(|v| v.to_variable()), message: format!(r#"Failed to evaluate expression: "{}""#, &expression.value), - })?; + } + })?; if let Some(tmap) = &mut trace_map { tmap.insert( - &expression.key, + Rc::from(expression.key.as_str()), ExpressionTrace { - result: value.to_value(), + result: value.clone(), }, ); } @@ -101,7 +101,7 @@ impl<'a> ExpressionHandlerInner<'a> { Ok(NodeResponse { output: result, - trace_data: trace_map.map(|tm| serde_json::to_value(tm).ok()).flatten(), + trace_data: trace_map.as_ref().map(|v| v.to_variable()), }) } } diff --git a/core/engine/src/handler/function/function.rs b/core/engine/src/handler/function/function.rs index 3ff83ce0..78b8805e 100644 --- a/core/engine/src/handler/function/function.rs +++ b/core/engine/src/handler/function/function.rs @@ -9,7 +9,7 @@ use crate::handler::function::serde::JsValue; use rquickjs::promise::MaybePromise; use rquickjs::{async_with, AsyncContext, AsyncRuntime, CatchResultExt, Ctx, Module}; use serde::{Deserialize, Serialize}; -use zen_expression::variable::Variable; +use zen_expression::variable::{ToVariable, Variable}; pub struct FunctionConfig { pub(crate) listeners: Option>>, @@ -129,7 +129,7 @@ impl Function { } } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, ToVariable)] pub struct HandlerResponse { pub logs: Vec, pub data: Variable, diff --git a/core/engine/src/handler/function/mod.rs b/core/engine/src/handler/function/mod.rs index 2dfe47f3..98a737c1 100644 --- a/core/engine/src/handler/function/mod.rs +++ b/core/engine/src/handler/function/mod.rs @@ -2,18 +2,18 @@ use std::rc::Rc; use std::sync::atomic::Ordering; use std::time::Duration; -use ::serde::{Deserialize, Serialize}; -use anyhow::anyhow; -use rquickjs::{async_with, CatchResultExt, Object}; -use serde_json::json; - use crate::handler::function::error::FunctionResult; use crate::handler::function::function::{Function, HandlerResponse}; use crate::handler::function::module::console::Log; use crate::handler::function::serde::JsValue; -use crate::handler::node::{NodeRequest, NodeResponse, NodeResult, PartialTraceError}; +use crate::handler::node::{NodeError, NodeRequest, NodeResponse, NodeResult}; use crate::model::{DecisionNodeKind, FunctionNodeContent}; use crate::ZEN_CONFIG; +use ::serde::{Deserialize, Serialize}; +use anyhow::anyhow; +use rquickjs::{async_with, CatchResultExt, Object}; +use serde_json::json; +use zen_expression::variable::ToVariable; pub(crate) mod error; pub(crate) mod function; @@ -35,6 +35,12 @@ pub struct FunctionHandler { max_duration: Duration, } +#[derive(ToVariable)] +#[serde(rename_all = "camelCase")] +struct FunctionTrace { + pub log: Vec, +} + impl FunctionHandler { pub fn new(function: Rc, trace: bool, iteration: u8, max_depth: u8) -> Self { let max_duration_millis = ZEN_CONFIG.function_timeout_millis.load(Ordering::Relaxed); @@ -93,7 +99,12 @@ impl FunctionHandler { Ok(NodeResponse { output: response.data, - trace_data: self.trace.then(|| json!({ "log": response.logs })), + trace_data: self.trace.then(|| { + FunctionTrace { + log: response.logs.clone(), + } + .to_variable() + }), }) } Err(e) => { @@ -103,10 +114,10 @@ impl FunctionHandler { ms_since_run: start.elapsed().as_millis() as usize, }); - Err(anyhow!(PartialTraceError { + Err(NodeError::PartialTrace { message: e.to_string(), - trace: Some(json!({ "log": log })), - })) + trace: Some(FunctionTrace { log }.to_variable()), + }) } } } diff --git a/core/engine/src/handler/function/module/console.rs b/core/engine/src/handler/function/module/console.rs index 011314ab..59e8ca6b 100644 --- a/core/engine/src/handler/function/module/console.rs +++ b/core/engine/src/handler/function/module/console.rs @@ -8,6 +8,7 @@ use crate::handler::function::listener::{RuntimeEvent, RuntimeListener}; use rquickjs::prelude::Rest; use rquickjs::{Ctx, Object, Value}; use serde::{Deserialize, Serialize}; +use zen_expression::variable::ToVariable; pub(crate) struct ConsoleListener; @@ -28,7 +29,7 @@ impl RuntimeListener for ConsoleListener { } } -#[derive(Serialize, Deserialize, Clone)] +#[derive(Serialize, Deserialize, ToVariable, Clone)] #[serde(rename_all = "camelCase")] pub struct Log { pub lines: Vec, diff --git a/core/engine/src/handler/function_v1/mod.rs b/core/engine/src/handler/function_v1/mod.rs index 1dde7059..37062b46 100644 --- a/core/engine/src/handler/function_v1/mod.rs +++ b/core/engine/src/handler/function_v1/mod.rs @@ -5,7 +5,8 @@ use crate::handler::node::{NodeRequest, NodeResponse, NodeResult}; use crate::model::{DecisionNodeKind, FunctionNodeContent}; use anyhow::anyhow; use rquickjs::Runtime; -use serde_json::json; +use serde_json::Value; +use zen_expression::variable::ToVariable; pub(crate) mod runtime; mod script; @@ -42,8 +43,15 @@ impl FunctionHandler { let response = result_response?; Ok(NodeResponse { - output: response.output, - trace_data: self.trace.then(|| json!({ "log": response.log })), + output: response.output.clone(), + trace_data: self + .trace + .then(|| FunctionTrace { log: response.log }.to_variable()), }) } } + +#[derive(ToVariable)] +struct FunctionTrace { + log: Vec, +} diff --git a/core/engine/src/handler/function_v1/script.rs b/core/engine/src/handler/function_v1/script.rs index 817748fd..c6f4a00a 100644 --- a/core/engine/src/handler/function_v1/script.rs +++ b/core/engine/src/handler/function_v1/script.rs @@ -7,7 +7,7 @@ use std::fmt::Debug; use std::rc::Rc; use zen_expression::variable::Variable; -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] pub struct EvaluateResponse { pub output: Variable, diff --git a/core/engine/src/handler/graph.rs b/core/engine/src/handler/graph.rs index f74389dc..21e9808c 100644 --- a/core/engine/src/handler/graph.rs +++ b/core/engine/src/handler/graph.rs @@ -1,3 +1,4 @@ +use crate::engine::EvaluationTraceKind; use crate::handler::custom_node_adapter::{CustomNodeAdapter, CustomNodeRequest}; use crate::handler::decision::DecisionHandler; use crate::handler::expression::ExpressionHandler; @@ -7,7 +8,7 @@ use crate::handler::function::module::zen::ZenListener; use crate::handler::function::FunctionHandler; use crate::handler::function_v1; use crate::handler::function_v1::runtime::create_runtime; -use crate::handler::node::{NodeRequest, PartialTraceError}; +use crate::handler::node::NodeRequest; use crate::handler::table::zen::DecisionTableHandler; use crate::handler::traversal::{GraphWalker, StableDiDecisionGraph}; use crate::loader::DecisionLoader; @@ -25,7 +26,7 @@ use std::rc::Rc; use std::sync::Arc; use std::time::Instant; use thiserror::Error; -use zen_expression::variable::Variable; +use zen_expression::variable::{ToVariable, Variable}; pub struct DecisionGraph { initial_graph: StableDiDecisionGraph, @@ -148,16 +149,16 @@ impl DecisionGraph< ) -> Result { let root_start = Instant::now(); - self.validate().map_err(|e| NodeError { + self.validate().map_err(|e| NodeError::Node { node_id: "".to_string(), - source: anyhow!(e), + source: anyhow!(e).into(), trace: None, })?; if self.iteration >= self.max_depth { - return Err(NodeError { + return Err(NodeError::Node { node_id: "".to_string(), - source: anyhow!(EvaluationError::DepthLimitExceeded), + source: Box::new(NodeError::Display("Depth limit exceeded".to_string())), trace: None, }); } @@ -217,21 +218,24 @@ impl DecisionGraph< .validator_cache .get_or_insert(validator_key, &json_schema) .await - .map_err(|e| NodeError { - source: e.into(), + .map_err(|e| NodeError::Node { + source: NodeError::from(e.to_string()).into(), node_id: node.id.clone(), trace: error_trace(&node_traces), })?; let context_json = context.to_value(); - validator.validate(&context_json).map_err(|e| NodeError { - source: anyhow!(serde_json::to_value( - Into::>::into(e) - ) - .unwrap_or_default()), - node_id: node.id.clone(), - trace: error_trace(&node_traces), - })?; + validator + .validate(&context_json) + .map_err(|e| NodeError::Node { + source: anyhow!(serde_json::to_value( + Box::::from(e) + ) + .unwrap_or_default()) + .into(), + node_id: node.id.clone(), + trace: error_trace(&node_traces), + })?; } walker.set_node_data(nid, context.clone()); @@ -256,8 +260,8 @@ impl DecisionGraph< .validator_cache .get_or_insert(validator_key, &json_schema) .await - .map_err(|e| NodeError { - source: e.into(), + .map_err(|e| NodeError::Node { + source: NodeError::from(e.to_string()).into(), node_id: node.id.clone(), trace: error_trace(&node_traces), })?; @@ -265,11 +269,8 @@ impl DecisionGraph< let incoming_data_json = incoming_data.to_value(); validator .validate(&incoming_data_json) - .map_err(|e| NodeError { - source: anyhow!(serde_json::to_value( - Into::>::into(e) - ) - .unwrap_or_default()), + .map_err(|e| NodeError::Node { + source: NodeError::from(e.to_string()).into(), node_id: node.id.clone(), trace: error_trace(&node_traces), })?; @@ -287,11 +288,14 @@ impl DecisionGraph< walker.set_node_data(nid, input_data); } DecisionNodeKind::FunctionNode { content } => { - let function = self.get_or_insert_function().await.map_err(|e| NodeError { - source: e.into(), - node_id: node.id.clone(), - trace: error_trace(&node_traces), - })?; + let function = + self.get_or_insert_function() + .await + .map_err(|e| NodeError::Node { + source: e.into(), + node_id: node.id.clone(), + trace: error_trace(&node_traces), + })?; let node_request = NodeRequest { node: node.clone(), @@ -308,22 +312,22 @@ impl DecisionGraph< .handle(node_request.clone()) .await .map_err(|e| { - if let Some(detailed_err) = e.downcast_ref::() { + if let NodeError::PartialTrace { trace, .. } = &e { trace!({ input: node_request.input.clone(), output: Variable::Null, - trace_data: detailed_err.trace.clone(), + trace_data: trace.clone(), }); } - NodeError { + NodeError::Node { source: e.into(), node_id: node.id.clone(), trace: error_trace(&node_traces), } })?, FunctionNodeContent::Version1(_) => { - let runtime = create_runtime().map_err(|e| NodeError { + let runtime = create_runtime().map_err(|e| NodeError::Node { source: e.into(), node_id: node.id.clone(), trace: error_trace(&node_traces), @@ -332,7 +336,7 @@ impl DecisionGraph< function_v1::FunctionHandler::new(self.trace, runtime) .handle(node_request.clone()) .await - .map_err(|e| NodeError { + .map_err(|e| NodeError::Node { source: e.into(), node_id: node.id.clone(), trace: error_trace(&node_traces), @@ -367,7 +371,7 @@ impl DecisionGraph< ) .handle(node_request.clone()) .await - .map_err(|e| NodeError { + .map_err(|e| NodeError::Node { source: e.into(), node_id: node.id.to_string(), trace: error_trace(&node_traces), @@ -393,7 +397,7 @@ impl DecisionGraph< let res = DecisionTableHandler::new(self.trace) .handle(node_request.clone()) .await - .map_err(|e| NodeError { + .map_err(|e| NodeError::Node { node_id: node.id.clone(), source: e.into(), trace: error_trace(&node_traces), @@ -420,15 +424,15 @@ impl DecisionGraph< .handle(node_request.clone()) .await .map_err(|e| { - if let Some(detailed_err) = e.downcast_ref::() { + if let NodeError::PartialTrace { trace, .. } = &e { trace!({ input: node_request.input.clone(), output: Variable::Null, - trace_data: detailed_err.trace.clone(), + trace_data: trace.clone(), }); } - NodeError { + NodeError::Node { node_id: node.id.clone(), source: e.into(), trace: error_trace(&node_traces), @@ -456,7 +460,7 @@ impl DecisionGraph< .adapter .handle(CustomNodeRequest::try_from(node_request.clone()).unwrap()) .await - .map_err(|e| NodeError { + .map_err(|e| NodeError::Node { node_id: node.id.clone(), source: e.into(), trace: error_trace(&node_traces), @@ -536,7 +540,27 @@ pub struct DecisionGraphResponse { pub trace: Option>, } -#[derive(Debug, Clone, Serialize, Deserialize)] +impl DecisionGraphResponse { + pub fn serialize_with_mode( + &self, + serializer: S, + mode: EvaluationTraceKind, + ) -> Result + where + S: Serializer, + { + let mut map = serializer.serialize_map(None)?; + map.serialize_entry("performance", &self.performance)?; + map.serialize_entry("result", &self.result)?; + if let Some(trace) = &self.trace { + map.serialize_entry("trace", &mode.serialize_trace(&trace.to_variable()))?; + } + + map.end() + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, ToVariable)] #[serde(rename_all = "camelCase")] pub struct DecisionGraphTrace { pub input: Variable, @@ -544,15 +568,19 @@ pub struct DecisionGraphTrace { pub name: String, pub id: String, pub performance: Option, - pub trace_data: Option, + pub trace_data: Option, pub order: u32, } -pub(crate) fn error_trace(trace: &Option>) -> Option { - trace - .as_ref() - .map(|s| serde_json::to_value(s).ok()) - .flatten() +pub(crate) fn error_trace(trace: &Option>) -> Option { + trace.as_ref().map(|s| { + s.values().for_each(|v| { + v.input.dot_remove("$nodes"); + v.output.dot_remove("$nodes"); + }); + + s.to_variable() + }) } fn create_validator_cache_key(content: &Value) -> u64 { diff --git a/core/engine/src/handler/node.rs b/core/engine/src/handler/node.rs index ebc5438c..c28ef4f1 100644 --- a/core/engine/src/handler/node.rs +++ b/core/engine/src/handler/node.rs @@ -1,16 +1,14 @@ use crate::model::DecisionNode; use serde::{Deserialize, Serialize}; -use serde_json::Value; use std::fmt::{Display, Formatter}; use std::sync::Arc; -use thiserror::Error; use zen_expression::variable::Variable; #[derive(Debug, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] pub struct NodeResponse { pub output: Variable, - pub trace_data: Option, + pub trace_data: Option, } #[derive(Debug, Serialize, Clone)] @@ -20,30 +18,71 @@ pub struct NodeRequest { pub node: Arc, } -#[derive(Error, Debug)] -pub struct NodeError { - pub node_id: String, - pub trace: Option, - #[source] - pub source: anyhow::Error, +pub type NodeResult = Result; + +#[derive(Debug)] +pub enum NodeError { + Internal, + Other(Box), + Display(String), // For non-Error types that implement Display + Node { + node_id: String, + trace: Option, + source: Box, + }, + PartialTrace { + trace: Option, + message: String, + }, +} + +impl NodeError { + /// Convert any error type to NodError + pub fn from_error(error: E) -> Self { + Self::Other(Box::new(error)) + } + + /// Add context to this error + pub fn context(self, context: C) -> Self { + Self::Display(format!("{}: {}", context, self)) + } + + /// Add context to this error using a closure + pub fn with_context C>(self, f: F) -> Self { + Self::Display(format!("{}: {}", f(), self)) + } } impl Display for NodeError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self) + match self { + NodeError::Internal => write!(f, "Internal error occurred"), + NodeError::Other(err) => write!(f, "{}", err), + NodeError::Display(msg) => write!(f, "{}", msg), + NodeError::Node { source, .. } => { + write!(f, "{}", source) + } + NodeError::PartialTrace { message, .. } => { + write!(f, "{}", message) + } + } } } -#[derive(Debug)] -pub(crate) struct PartialTraceError { - pub trace: Option, - pub message: String, +impl From for Box { + fn from(value: anyhow::Error) -> Self { + Box::new(NodeError::Other(value.into())) + } } -impl Display for PartialTraceError { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.message) +impl From for NodeError { + fn from(value: anyhow::Error) -> Self { + Self::Other(value.into()) } } -pub type NodeResult = anyhow::Result; +impl From for NodeError { + fn from(value: String) -> Self { + Self::Display(value) + } +} diff --git a/core/engine/src/handler/table/mod.rs b/core/engine/src/handler/table/mod.rs index 58f00118..71967534 100644 --- a/core/engine/src/handler/table/mod.rs +++ b/core/engine/src/handler/table/mod.rs @@ -1,13 +1,13 @@ pub mod zen; -use zen_expression::variable::Variable; +use zen_expression::variable::{ToVariable, Variable}; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, ToVariable)] pub(crate) enum RowOutputKind { Variable(Variable), } -#[derive(Debug, Default)] +#[derive(Debug, Default, ToVariable)] pub(crate) struct RowOutput { output: OutputMap, } diff --git a/core/engine/src/handler/table/zen.rs b/core/engine/src/handler/table/zen.rs index 079682ff..b582e027 100644 --- a/core/engine/src/handler/table/zen.rs +++ b/core/engine/src/handler/table/zen.rs @@ -7,10 +7,10 @@ use crate::handler::table::{RowOutput, RowOutputKind}; use crate::model::{DecisionNodeKind, DecisionTableContent, DecisionTableHitPolicy}; use serde::Serialize; use tokio::sync::Mutex; -use zen_expression::variable::Variable; +use zen_expression::variable::{ToVariable, Variable}; use zen_expression::Isolate; -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, ToVariable)] struct RowResult { rule: Option>, reference_map: Option>, @@ -89,10 +89,7 @@ impl<'a> DecisionTableHandlerInner<'a> { if let Some(result) = self.evaluate_row(&content, i) { return Ok(NodeResponse { output: result.output.to_json().await, - trace_data: self - .trace - .then(|| serde_json::to_value(&result).ok()) - .flatten(), + trace_data: self.trace.then(|| result.to_variable()), }); } } @@ -118,10 +115,7 @@ impl<'a> DecisionTableHandlerInner<'a> { Ok(NodeResponse { output: Variable::from_array(outputs), - trace_data: self - .trace - .then(|| serde_json::to_value(&results).ok()) - .flatten(), + trace_data: self.trace.then(|| results.to_variable()), }) } diff --git a/core/engine/src/lib.rs b/core/engine/src/lib.rs index 80804090..35323827 100644 --- a/core/engine/src/lib.rs +++ b/core/engine/src/lib.rs @@ -125,7 +125,7 @@ mod config; mod decision; mod engine; -mod error; +pub mod error; pub mod handler; pub mod loader; #[path = "model/mod.rs"] @@ -134,7 +134,9 @@ mod util; pub use config::ZEN_CONFIG; pub use decision::Decision; -pub use engine::{DecisionEngine, EvaluationOptions}; +pub use engine::{ + DecisionEngine, EvaluationOptions, EvaluationSerializedOptions, EvaluationTraceKind, +}; pub use error::EvaluationError; pub use handler::graph::DecisionGraphResponse; pub use handler::graph::DecisionGraphTrace; diff --git a/core/engine/src/loader/mod.rs b/core/engine/src/loader/mod.rs index 4cb8a27b..cdf6f663 100644 --- a/core/engine/src/loader/mod.rs +++ b/core/engine/src/loader/mod.rs @@ -18,7 +18,7 @@ mod filesystem; mod memory; mod noop; -pub type LoaderResult = Result>; +pub type LoaderResult = Result; pub type LoaderResponse = LoaderResult>; /// Trait used for implementing a loader for decisions diff --git a/core/engine/src/util/transform_attribute.rs b/core/engine/src/util/transform_attribute.rs index b68bed1c..888476be 100644 --- a/core/engine/src/util/transform_attribute.rs +++ b/core/engine/src/util/transform_attribute.rs @@ -1,7 +1,6 @@ use crate::handler::node::{NodeResponse, NodeResult}; use crate::model::{TransformAttributes, TransformExecutionMode}; use anyhow::Context; -use serde_json::Value; use std::future::Future; use zen_expression::{Isolate, Variable}; @@ -16,7 +15,9 @@ impl TransformAttributes { Some(input_field) => { let mut isolate = Isolate::new(); isolate.set_environment(node_input.clone()); - let calculated_input = isolate.run_standard(input_field.as_str())?; + let calculated_input = isolate + .run_standard(input_field.as_str()) + .context("Failed to run standard")?; let nodes = node_input.dot("$nodes").unwrap_or(Variable::Null); match &calculated_input { @@ -42,7 +43,7 @@ impl TransformAttributes { } }; - let mut trace_data: Option = None; + let mut trace_data: Option = None; let mut output = match self.execution_mode { TransformExecutionMode::Single => { let response = evaluate(input).await?; @@ -73,7 +74,7 @@ impl TransformAttributes { output_array.push(response.output); } - trace_data.replace(Value::Array(trace_datum)); + trace_data.replace(Variable::from_array(trace_datum)); Variable::from_array(output_array) } }; diff --git a/core/engine/tests/decision.rs b/core/engine/tests/decision.rs index cc5609b0..a487fb6b 100644 --- a/core/engine/tests/decision.rs +++ b/core/engine/tests/decision.rs @@ -3,7 +3,7 @@ use serde_json::json; use std::ops::Deref; use std::sync::Arc; use tokio::runtime::Builder; -use zen_engine::{Decision, DecisionGraphValidationError, EvaluationError}; +use zen_engine::{Decision, DecisionGraphValidationError, EvaluationError, NodeError}; mod support; @@ -28,9 +28,11 @@ async fn decision_from_content_recursive() { let context = json!({}); let result = decision.evaluate(context.clone().into()).await; match result.unwrap_err().deref() { - EvaluationError::NodeError(e) => { - assert_eq!(e.node_id, "0b8dcf6b-fc04-47cb-bf82-bda764e6c09b"); - assert!(e.source.to_string().contains("Loader failed")); + EvaluationError::NodeError(NodeError::Node { + node_id, source, .. + }) => { + assert_eq!(node_id, "0b8dcf6b-fc04-47cb-bf82-bda764e6c09b"); + assert!(source.to_string().contains("Loader failed")); } _ => assert!(false, "Depth limit not exceeded"), } @@ -38,8 +40,8 @@ async fn decision_from_content_recursive() { let with_loader = decision.with_loader(Arc::new(create_fs_loader())); let new_result = with_loader.evaluate(context.clone().into()).await; match new_result.unwrap_err().deref() { - EvaluationError::NodeError(e) => { - assert_eq!(e.source.to_string(), "Depth limit exceeded") + EvaluationError::NodeError(NodeError::Node { source, .. }) => { + assert_eq!(source.to_string(), "Depth limit exceeded") } _ => assert!(false, "Depth limit not exceeded"), } diff --git a/core/engine/tests/engine.rs b/core/engine/tests/engine.rs index 61ad0502..31592e7e 100644 --- a/core/engine/tests/engine.rs +++ b/core/engine/tests/engine.rs @@ -10,9 +10,8 @@ use std::sync::Arc; use tokio::runtime::Builder; use zen_engine::loader::{LoaderError, MemoryLoader}; use zen_engine::model::{DecisionContent, DecisionNode, DecisionNodeKind, FunctionNodeContent}; -use zen_engine::Variable; use zen_engine::{DecisionEngine, EvaluationError, EvaluationOptions}; -use zen_expression::vm::UTC_OVERRIDE; +use zen_engine::{NodeError, Variable}; mod support; @@ -114,9 +113,11 @@ async fn engine_errors() { .evaluate("infinite-function.json", json!({}).into()) .await; match infinite_fn.unwrap_err().deref() { - EvaluationError::NodeError(e) => { - assert_eq!(e.node_id, "e0fd96d0-44dc-4f0e-b825-06e56b442d78"); - assert!(e.source.to_string().contains("interrupted")); + EvaluationError::NodeError(NodeError::Node { + node_id, source, .. + }) => { + assert_eq!(node_id, "e0fd96d0-44dc-4f0e-b825-06e56b442d78"); + assert!(source.to_string().contains("interrupted")); } _ => assert!(false, "Wrong error type"), } @@ -125,8 +126,8 @@ async fn engine_errors() { .evaluate("recursive-table1.json", json!({}).into()) .await; match recursive.unwrap_err().deref() { - EvaluationError::NodeError(e) => { - assert_eq!(e.source.to_string(), "Depth limit exceeded") + EvaluationError::NodeError(NodeError::Node { source, .. }) => { + assert_eq!(source.to_string(), "Depth limit exceeded") } _ => assert!(false, "Depth limit not exceeded"), } @@ -269,7 +270,7 @@ async fn engine_graph_tests() { } fn mock_datetime() { - *UTC_OVERRIDE.write().unwrap() = Some("2025-08-19T16:55:02.078Z".parse().unwrap()); + std::env::set_var("__ZEN_MOCK_UTC_TIME", "2025-08-19T16:55:02.078Z"); } #[tokio::test] diff --git a/core/expression/Cargo.toml b/core/expression/Cargo.toml index 75d55a05..edbbbcde 100644 --- a/core/expression/Cargo.toml +++ b/core/expression/Cargo.toml @@ -30,6 +30,9 @@ nohash-hasher = "0.2.0" strsim = "0.11" iana-time-zone = "0.1" +zen-macros = { path = "../macros" } +zen-types = { path = "../types" } + [dev-dependencies] criterion = { workspace = true } csv = "1" @@ -40,7 +43,6 @@ default = ["regex-deprecated", "stack-protection"] regex-lite = ["dep:regex-lite"] regex-deprecated = ["dep:regex"] stack-protection = ["dep:recursive"] -time-override = [] [[bench]] harness = false diff --git a/core/expression/src/functions/arguments.rs b/core/expression/src/functions/arguments.rs index 17b0b335..b3ff781e 100644 --- a/core/expression/src/functions/arguments.rs +++ b/core/expression/src/functions/arguments.rs @@ -1,13 +1,16 @@ -use crate::variable::{DynamicVariable, RcCell}; +use crate::variable::DynamicVariable; use crate::Variable; use ahash::HashMap; use anyhow::Context; use rust_decimal::Decimal; +use std::cell::RefCell; use std::ops::Deref; use std::rc::Rc; pub struct Arguments<'a>(pub &'a [Variable]); +type RcCell = Rc>; + impl<'a> Deref for Arguments<'a> { type Target = [Variable]; diff --git a/core/expression/src/functions/deprecated.rs b/core/expression/src/functions/deprecated.rs index f344c60a..12eef7a4 100644 --- a/core/expression/src/functions/deprecated.rs +++ b/core/expression/src/functions/deprecated.rs @@ -117,11 +117,29 @@ impl From<&DeprecatedFunction> for Rc { mod imp { use super::*; use crate::vm::helpers::DateUnit; + use crate::vm::VMError; + use zen_types::variable::Variable; fn __internal_convert_datetime(timestamp: &V) -> anyhow::Result { - timestamp - .try_into() - .context("Failed to convert value to date time") + match timestamp { + Variable::String(a) => date_time(a), + #[allow(deprecated)] + Variable::Number(a) => NaiveDateTime::from_timestamp_opt( + a.to_i64().ok_or_else(|| VMError::OpcodeErr { + opcode: "DateManipulation".into(), + message: "Failed to extract date".into(), + })?, + 0, + ) + .ok_or_else(|| VMError::ParseDateTimeErr { + timestamp: a.to_string(), + }), + _ => Err(VMError::OpcodeErr { + opcode: "DateManipulation".into(), + message: "Unsupported type".into(), + }), + } + .context("Failed to convert value to date time") } pub fn parse_date(args: Arguments) -> anyhow::Result { diff --git a/core/expression/src/functions/internal.rs b/core/expression/src/functions/internal.rs index 0e467ad2..7b0c1719 100644 --- a/core/expression/src/functions/internal.rs +++ b/core/expression/src/functions/internal.rs @@ -304,6 +304,7 @@ impl From<&InternalFunction> for Rc { pub(crate) mod imp { use crate::functions::arguments::Arguments; + use crate::vm::date::DynamicVariableExt; use crate::vm::VmDate; use crate::{Variable as V, Variable}; use anyhow::{anyhow, Context}; diff --git a/core/expression/src/intellisense/types/provider.rs b/core/expression/src/intellisense/types/provider.rs index bcea1f4e..454eedfb 100644 --- a/core/expression/src/intellisense/types/provider.rs +++ b/core/expression/src/intellisense/types/provider.rs @@ -5,8 +5,8 @@ use crate::intellisense::types::type_info::TypeInfo; use crate::lexer::{ArithmeticOperator, ComparisonOperator, LogicalOperator, Operator}; use crate::parser::Node; use crate::variable::VariableType; +use ahash::{HashMap, HashMapExt}; use std::cell::RefCell; -use std::collections::HashMap; use std::iter::once; use std::ops::Deref; use std::rc::Rc; @@ -151,7 +151,6 @@ impl TypesProvider { .root_data .dot_insert_detached(key_type.as_ref(), value_type.kind.shallow_clone()) { - println!("NewVar: {new_var:?}"); scope.root_data = new_var; }; diff --git a/core/expression/src/parser/ast.rs b/core/expression/src/parser/ast.rs index 42681a8e..a61f6401 100644 --- a/core/expression/src/parser/ast.rs +++ b/core/expression/src/parser/ast.rs @@ -154,7 +154,7 @@ impl<'a> Node<'a> { }; } - pub fn first_error(&self) -> Option { + pub fn first_error(&'a self) -> Option> { let error_cell = Cell::new(None); self.walk(|n| { if let Node::Error { error, .. } = n { diff --git a/core/expression/src/variable/mod.rs b/core/expression/src/variable/mod.rs index e2410ef4..ef41a48d 100644 --- a/core/expression/src/variable/mod.rs +++ b/core/expression/src/variable/mod.rs @@ -1,481 +1,5 @@ -use ahash::HashMap; -use rust_decimal::prelude::Zero; -use rust_decimal::Decimal; -use serde_json::Value; -use std::any::Any; -use std::cell::RefCell; -use std::collections::hash_map::Entry; -use std::fmt::{Debug, Display, Formatter}; -use std::ops::Deref; -use std::rc::Rc; +pub use zen_types::rcvalue::*; +pub use zen_types::variable::*; +pub use zen_types::variable_type::*; -mod conv; -mod de; -mod ser; -mod types; - -pub use de::VariableDeserializer; -pub use types::VariableType; - -pub(crate) type RcCell = Rc>; - -pub enum Variable { - Null, - Bool(bool), - Number(Decimal), - String(Rc), - Array(RcCell>), - Object(RcCell, Variable>>), - Dynamic(Rc), -} - -pub trait DynamicVariable: Display { - fn type_name(&self) -> &'static str; - - fn as_any(&self) -> &dyn Any; - - fn to_value(&self) -> Value; -} - -impl Variable { - pub fn from_array(arr: Vec) -> Self { - Self::Array(Rc::new(RefCell::new(arr))) - } - - pub fn from_object(obj: HashMap, Self>) -> Self { - Self::Object(Rc::new(RefCell::new(obj))) - } - - pub fn empty_object() -> Self { - Variable::Object(Default::default()) - } - - pub fn empty_array() -> Self { - Variable::Array(Default::default()) - } - - pub fn as_str(&self) -> Option<&str> { - match self { - Variable::String(s) => Some(s.as_ref()), - _ => None, - } - } - - pub fn as_rc_str(&self) -> Option> { - match self { - Variable::String(s) => Some(s.clone()), - _ => None, - } - } - - pub fn as_array(&self) -> Option>> { - match self { - Variable::Array(arr) => Some(arr.clone()), - _ => None, - } - } - - pub fn is_array(&self) -> bool { - match self { - Variable::Array(_) => true, - _ => false, - } - } - - pub fn as_object(&self) -> Option, Variable>>> { - match self { - Variable::Object(obj) => Some(obj.clone()), - _ => None, - } - } - - pub fn is_object(&self) -> bool { - match self { - Variable::Object(_) => true, - _ => false, - } - } - - pub fn as_bool(&self) -> Option { - match self { - Variable::Bool(b) => Some(*b), - _ => None, - } - } - - pub fn as_number(&self) -> Option { - match self { - Variable::Number(n) => Some(*n), - _ => None, - } - } - - pub fn type_name(&self) -> &'static str { - match self { - Variable::Null => "null", - Variable::Bool(_) => "bool", - Variable::Number(_) => "number", - Variable::String(_) => "string", - Variable::Array(_) => "array", - Variable::Object(_) => "object", - Variable::Dynamic(d) => d.type_name(), - } - } - - pub fn dynamic(&self) -> Option<&T> { - match self { - Variable::Dynamic(d) => d.as_any().downcast_ref::(), - _ => None, - } - } - - pub fn to_value(&self) -> Value { - Value::from(self.shallow_clone()) - } - - pub fn dot(&self, key: &str) -> Option { - key.split('.') - .try_fold(self.shallow_clone(), |var, part| match var { - Variable::Object(obj) => { - let reference = obj.borrow(); - reference.get(part).map(|v| v.shallow_clone()) - } - _ => None, - }) - } - - fn dot_head(&self, key: &str) -> Option { - let mut parts = Vec::from_iter(key.split('.')); - parts.pop(); - - parts - .iter() - .try_fold(self.shallow_clone(), |var, part| match var { - Variable::Object(obj) => { - let mut obj_ref = obj.borrow_mut(); - Some(match obj_ref.entry(Rc::from(*part)) { - Entry::Occupied(occ) => occ.get().shallow_clone(), - Entry::Vacant(vac) => vac.insert(Self::empty_object()).shallow_clone(), - }) - } - _ => None, - }) - } - - fn dot_head_detach(&self, key: &str) -> (Variable, Option) { - let mut parts = Vec::from_iter(key.split('.')); - parts.pop(); - - let cloned_self = self.depth_clone(1); - let head = parts - .iter() - .try_fold(cloned_self.shallow_clone(), |var, part| match var { - Variable::Object(obj) => { - let mut obj_ref = obj.borrow_mut(); - Some(match obj_ref.entry(Rc::from(*part)) { - Entry::Occupied(mut occ) => { - let var = occ.get(); - let new_obj = match var { - Variable::Object(_) => var.depth_clone(1), - _ => Variable::empty_object(), - }; - - occ.insert(new_obj.shallow_clone()); - new_obj - } - Entry::Vacant(vac) => vac.insert(Self::empty_object()).shallow_clone(), - }) - } - _ => None, - }); - - (cloned_self, head) - } - - pub fn dot_remove(&self, key: &str) -> Option { - let last_part = key.split('.').last()?; - let head = self.dot_head(key)?; - let Variable::Object(object_ref) = head else { - return None; - }; - - let mut object = object_ref.borrow_mut(); - object.remove(last_part) - } - - pub fn dot_insert(&self, key: &str, variable: Variable) -> Option { - let last_part = key.split('.').last()?; - let head = self.dot_head(key)?; - let Variable::Object(object_ref) = head else { - return None; - }; - - let mut object = object_ref.borrow_mut(); - object.insert(Rc::from(last_part), variable) - } - - pub fn dot_insert_detached(&self, key: &str, variable: Variable) -> Option { - let last_part = key.split('.').last()?; - let (new_var, head_opt) = self.dot_head_detach(key); - let head = head_opt?; - let Variable::Object(object_ref) = head else { - return None; - }; - - let mut object = object_ref.borrow_mut(); - object.insert(Rc::from(last_part), variable); - Some(new_var) - } - - pub fn merge(&mut self, patch: &Variable) -> Variable { - let _ = merge_variables(self, patch, true, MergeStrategy::InPlace); - - self.shallow_clone() - } - - pub fn merge_clone(&mut self, patch: &Variable) -> Variable { - let mut new_self = self.shallow_clone(); - - let _ = merge_variables(&mut new_self, patch, true, MergeStrategy::CloneOnWrite); - new_self - } - - pub fn shallow_clone(&self) -> Self { - match self { - Variable::Null => Variable::Null, - Variable::Bool(b) => Variable::Bool(*b), - Variable::Number(n) => Variable::Number(*n), - Variable::String(s) => Variable::String(s.clone()), - Variable::Array(a) => Variable::Array(a.clone()), - Variable::Object(o) => Variable::Object(o.clone()), - Variable::Dynamic(d) => Variable::Dynamic(d.clone()), - } - } - - pub fn deep_clone(&self) -> Self { - match self { - Variable::Array(a) => { - let arr = a.borrow(); - Variable::from_array(arr.iter().map(|v| v.deep_clone()).collect()) - } - Variable::Object(o) => { - let obj = o.borrow(); - Variable::from_object( - obj.iter() - .map(|(k, v)| (k.clone(), v.deep_clone())) - .collect(), - ) - } - _ => self.shallow_clone(), - } - } - - pub fn depth_clone(&self, depth: usize) -> Self { - match depth.is_zero() { - true => self.shallow_clone(), - false => match self { - Variable::Array(a) => { - let arr = a.borrow(); - Variable::from_array(arr.iter().map(|v| v.depth_clone(depth - 1)).collect()) - } - Variable::Object(o) => { - let obj = o.borrow(); - Variable::from_object( - obj.iter() - .map(|(k, v)| (k.clone(), v.depth_clone(depth - 1))) - .collect(), - ) - } - _ => self.shallow_clone(), - }, - } - } -} - -impl Clone for Variable { - fn clone(&self) -> Self { - self.shallow_clone() - } -} - -#[derive(Copy, Clone)] -enum MergeStrategy { - InPlace, - CloneOnWrite, -} - -fn merge_variables( - doc: &mut Variable, - patch: &Variable, - top_level: bool, - strategy: MergeStrategy, -) -> bool { - if patch.is_array() && top_level { - *doc = patch.shallow_clone(); - return true; - } - - if !patch.is_object() && top_level { - return false; - } - - if doc.is_object() && patch.is_object() { - let doc_ref = doc.as_object().unwrap(); - let patch_ref = patch.as_object().unwrap(); - if Rc::ptr_eq(&doc_ref, &patch_ref) { - return false; - } - - let patch = patch_ref.borrow(); - match strategy { - MergeStrategy::InPlace => { - let mut map = doc_ref.borrow_mut(); - for (key, value) in patch.deref() { - if value == &Variable::Null { - map.remove(key); - } else { - let entry = map.entry(key.clone()).or_insert(Variable::Null); - merge_variables(entry, value, false, strategy); - } - } - - return true; - } - MergeStrategy::CloneOnWrite => { - let mut changed = false; - let mut new_map = None; - - for (key, value) in patch.deref() { - // Get or create the new map if we haven't yet - let map = if let Some(ref mut m) = new_map { - m - } else { - let m = doc_ref.borrow().clone(); - new_map = Some(m); - new_map.as_mut().unwrap() - }; - - if value == &Variable::Null { - // Remove null values - if map.remove(key).is_some() { - changed = true; - } - } else { - // Handle nested merging - let entry = map.entry(key.clone()).or_insert(Variable::Null); - if merge_variables(entry, value, false, strategy) { - changed = true; - } - } - } - - // Only update doc if changes were made - if changed { - if let Some(new_map) = new_map { - *doc = Variable::Object(Rc::new(RefCell::new(new_map))); - } - return true; - } - - return false; - } - } - } else { - let new_value = patch.shallow_clone(); - if *doc != new_value { - *doc = new_value; - return true; - } - - return false; - } -} - -impl Display for Variable { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - Variable::Null => write!(f, "null"), - Variable::Bool(b) => match *b { - true => write!(f, "true"), - false => write!(f, "false"), - }, - Variable::Number(n) => write!(f, "{n}"), - Variable::String(s) => write!(f, "\"{s}\""), - Variable::Array(arr) => { - let arr = arr.borrow(); - let s = arr - .iter() - .map(|v| v.to_string()) - .collect::>() - .join(","); - write!(f, "[{s}]") - } - Variable::Object(obj) => { - let obj = obj.borrow(); - let s = obj - .iter() - .map(|(k, v)| format!("\"{k}\":{v}")) - .collect::>() - .join(","); - - write!(f, "{{{s}}}") - } - Variable::Dynamic(d) => write!(f, "{d}"), - } - } -} - -impl Debug for Variable { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self) - } -} - -impl PartialEq for Variable { - fn eq(&self, other: &Self) -> bool { - match (&self, &other) { - (Variable::Null, Variable::Null) => true, - (Variable::Bool(b1), Variable::Bool(b2)) => b1 == b2, - (Variable::Number(n1), Variable::Number(n2)) => n1 == n2, - (Variable::String(s1), Variable::String(s2)) => s1 == s2, - (Variable::Array(a1), Variable::Array(a2)) => a1 == a2, - (Variable::Object(obj1), Variable::Object(obj2)) => obj1 == obj2, - (Variable::Dynamic(d1), Variable::Dynamic(d2)) => Rc::ptr_eq(d1, d2), - _ => false, - } - } -} - -impl Eq for Variable {} - -#[cfg(test)] -mod tests { - use crate::Variable; - use rust_decimal_macros::dec; - use serde_json::json; - - #[test] - fn insert_detached() { - let some_data: Variable = json!({ "customer": { "firstName": "John" }}).into(); - - let a_a = some_data - .dot_insert_detached("a.a", Variable::Number(dec!(1))) - .unwrap(); - let a_b = a_a - .dot_insert_detached("a.b", Variable::Number(dec!(2))) - .unwrap(); - let a_c = a_b - .dot_insert_detached("a.c", Variable::Number(dec!(3))) - .unwrap(); - - assert_eq!(a_a.dot("a"), Some(Variable::from(json!({ "a": 1 })))); - assert_eq!( - a_b.dot("a"), - Some(Variable::from(json!({ "a": 1, "b": 2 }))) - ); - assert_eq!( - a_c.dot("a"), - Some(Variable::from(json!({ "a": 1, "b": 2, "c": 3 }))) - ); - } -} +pub use zen_macros::ToVariable; diff --git a/core/expression/src/vm/date/mod.rs b/core/expression/src/vm/date/mod.rs index c150e304..225c2709 100644 --- a/core/expression/src/vm/date/mod.rs +++ b/core/expression/src/vm/date/mod.rs @@ -7,7 +7,7 @@ use chrono_tz::Tz; use serde_json::Value; use std::any::Any; use std::fmt::{Display, Formatter}; -use std::sync::{LazyLock, RwLock}; +use std::sync::OnceLock; // Duration is a modified copy of `humantime` mod duration; @@ -176,7 +176,7 @@ impl VmDate { } mod helper { - use crate::vm::date::{utc_now, Duration, DurationUnit}; + use crate::vm::date::{utc_now, Duration, DurationUnit, DynamicVariableExt}; use crate::Variable; use chrono::{ DateTime, Datelike, Days, LocalResult, Month, Months, NaiveDate, NaiveDateTime, Offset, @@ -476,25 +476,26 @@ mod helper { } } -impl dyn DynamicVariable { - pub(crate) fn as_date(&self) -> Option<&VmDate> { +pub(crate) trait DynamicVariableExt { + fn as_date(&self) -> Option<&VmDate>; +} + +impl DynamicVariableExt for dyn DynamicVariable { + fn as_date(&self) -> Option<&VmDate> { self.as_any().downcast_ref::() } } -#[cfg(feature = "time-override")] -pub static UTC_OVERRIDE: LazyLock>>> = - LazyLock::new(|| RwLock::new(None)); - pub(crate) fn utc_now() -> DateTime { - #[cfg(feature = "time-override")] - { - if let Ok(override_time) = UTC_OVERRIDE.read() { - if let Some(time) = *override_time { - return time; - } - } - } + static CURRENT_DATE_VALUE: OnceLock>> = OnceLock::new(); - Utc::now() + CURRENT_DATE_VALUE + .get_or_init(|| match std::env::var("__ZEN_MOCK_UTC_TIME") { + Ok(v) => { + let now = v.parse::>().unwrap(); + Some(now) + } + Err(_) => None, + }) + .unwrap_or_else(|| Utc::now()) } diff --git a/core/expression/src/vm/helpers.rs b/core/expression/src/vm/helpers.rs index a58778ae..2ce36a3d 100644 --- a/core/expression/src/vm/helpers.rs +++ b/core/expression/src/vm/helpers.rs @@ -1,8 +1,6 @@ use crate::vm::date::utc_now; use crate::vm::error::{VMError, VMResult}; -use chrono::{ - DateTime, Datelike, Days, NaiveDate, NaiveDateTime, NaiveTime, Timelike, Utc, Weekday, -}; +use chrono::{DateTime, Datelike, Days, NaiveDate, NaiveDateTime, NaiveTime, Timelike, Weekday}; use once_cell::sync::Lazy; #[allow(clippy::unwrap_used)] diff --git a/core/expression/src/vm/mod.rs b/core/expression/src/vm/mod.rs index 19c6c4b9..ac3fa557 100644 --- a/core/expression/src/vm/mod.rs +++ b/core/expression/src/vm/mod.rs @@ -11,6 +11,3 @@ mod interval; mod vm; pub(crate) use date::VmDate; - -#[cfg(feature = "time-override")] -pub use date::UTC_OVERRIDE; diff --git a/core/expression/src/vm/vm.rs b/core/expression/src/vm/vm.rs index 79a50f16..01e2b1e3 100644 --- a/core/expression/src/vm/vm.rs +++ b/core/expression/src/vm/vm.rs @@ -4,6 +4,7 @@ use crate::functions::registry::FunctionRegistry; use crate::functions::{internal, MethodRegistry}; use crate::variable::Variable; use crate::variable::Variable::*; +use crate::vm::date::DynamicVariableExt; use crate::vm::error::VMError::*; use crate::vm::error::VMResult; use crate::vm::interval::{VmInterval, VmIntervalData}; diff --git a/core/macros/Cargo.toml b/core/macros/Cargo.toml new file mode 100644 index 00000000..99257dd8 --- /dev/null +++ b/core/macros/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "zen-macros" +version = "0.49.1" +edition = "2024" +publish = false + +[lib] +proc-macro = true + +[dependencies] +proc-macro2 = "1" +quote = "1" +syn = { version = "2", features = ["full"] } +serde_derive_internals = "0.29" diff --git a/core/macros/src/lib.rs b/core/macros/src/lib.rs new file mode 100644 index 00000000..109b12cf --- /dev/null +++ b/core/macros/src/lib.rs @@ -0,0 +1,8 @@ +use proc_macro::TokenStream; + +mod to_variable; + +#[proc_macro_derive(ToVariable, attributes(serde))] +pub fn derive_to_variable(input: TokenStream) -> TokenStream { + to_variable::to_variable_impl(input) +} diff --git a/core/macros/src/to_variable.rs b/core/macros/src/to_variable.rs new file mode 100644 index 00000000..aad1f08f --- /dev/null +++ b/core/macros/src/to_variable.rs @@ -0,0 +1,212 @@ +use proc_macro::TokenStream; +use quote::quote; +use syn::parse_macro_input; + +pub fn to_variable_impl(input: TokenStream) -> TokenStream { + let mut input = parse_macro_input!(input as syn::DeriveInput); + + serde_derive_internals::replace_receiver(&mut input); + + let ctxt = serde_derive_internals::Ctxt::new(); + let container = match serde_derive_internals::ast::Container::from_ast( + &ctxt, + &input, + serde_derive_internals::Derive::Serialize, + ) { + Some(container) => container, + None => return ctxt.check().unwrap_err().into_compile_error().into(), + }; + + if let Err(err) = ctxt.check() { + return err.into_compile_error().into(); + } + + let ident = &container.ident; + let (impl_generics, ty_generics, where_clause) = container.generics.split_for_impl(); + + let body = match &container.data { + serde_derive_internals::ast::Data::Struct(_, fields) => generate_struct_body(fields), + serde_derive_internals::ast::Data::Enum(variants) => { + generate_enum_body(variants, &container) + } + }; + + let impl_block = quote! { + #[automatically_derived] + impl #impl_generics _ToVariable for #ident #ty_generics #where_clause { + fn to_variable(&self) -> _Variable { + #body + } + } + }; + + quote! { + #[doc(hidden)] + #[allow(non_upper_case_globals, unused_attributes, unused_qualifications, clippy::absolute_paths)] + const _: () = { + extern crate zen_expression as _zen_expression; + + use _zen_expression::variable::{Variable as _Variable, VariableMap as _VariableMap, VariableMapExt, ToVariable as _ToVariable}; + use ::std::rc::Rc as _Rc; + + #impl_block + }; + }.into() +} + +fn generate_struct_body(fields: &[serde_derive_internals::ast::Field]) -> proc_macro2::TokenStream { + let active_fields: Vec<_> = fields + .iter() + .filter(|field| !field.attrs.skip_serializing()) + .collect(); + + let field_count = active_fields.len(); + + let field_mappings = active_fields.iter().map(|field| { + let field_ident = match &field.member { + syn::Member::Named(ident) => ident, + syn::Member::Unnamed(_) => panic!("ToVariable only supports named fields"), + }; + + let serialized_name = field.attrs.name().serialize_name(); + + quote! { + map.insert( + _Rc::from(#serialized_name), + self.#field_ident.to_variable() + ); + } + }); + + quote! { + let mut map = _VariableMap::with_capacity(#field_count); + #(#field_mappings)* + _Variable::from_object(map) + } +} + +fn generate_enum_body( + variants: &[serde_derive_internals::ast::Variant], + container: &serde_derive_internals::ast::Container, +) -> proc_macro2::TokenStream { + let enum_ident = &container.ident; + + let active_variants: Vec<_> = variants + .iter() + .filter(|variant| !variant.attrs.skip_serializing()) + .collect(); + + let variant_arms = active_variants + .iter() + .map(|variant| generate_variant_arm(enum_ident, variant, container)); + + quote! { + match self { + #(#variant_arms)* + } + } +} + +fn generate_variant_arm( + enum_ident: &syn::Ident, + variant: &serde_derive_internals::ast::Variant, + container: &serde_derive_internals::ast::Container, +) -> proc_macro2::TokenStream { + let variant_ident = &variant.ident; + + let variant_name = variant.attrs.name().serialize_name(); + let rename_rule = container.attrs.rename_all_rules().serialize; + let type_key = rename_rule.apply_to_field("type"); + let value_key = rename_rule.apply_to_field("value"); + + match variant.style { + serde_derive_internals::ast::Style::Unit => { + quote! { + #enum_ident::#variant_ident => { + _Variable::String(_Rc::from(#variant_name)) + } + } + } + + serde_derive_internals::ast::Style::Newtype => { + quote! { + #enum_ident::#variant_ident(value) => { + let mut map = _VariableMap::with_capacity(2); + map.insert(_Rc::from(#type_key), _Variable::String(_Rc::from(#variant_name))); + map.insert(_Rc::from(#value_key), value.to_variable()); + _Variable::from_object(map) + } + } + } + + serde_derive_internals::ast::Style::Tuple => { + let field_count = variant.fields.len(); + let field_patterns: Vec<_> = (0..field_count) + .map(|i| quote::format_ident!("field_{}", i)) + .collect(); + + if field_count == 1 { + quote! { + #enum_ident::#variant_ident(#(#field_patterns),*) => { + let mut map = _VariableMap::with_capacity(2); + map.insert(_Rc::from(#type_key), _Variable::String(_Rc::from(#variant_name))); + map.insert(_Rc::from(#value_key), (#(#field_patterns)*).to_variable()); + _Variable::from_object(map) + } + } + } else { + let field_mappings = field_patterns.iter().enumerate().map(|(i, pattern)| { + let field_key = rename_rule.apply_to_field(&format!("field_{}", i)); + quote! { + map.insert(_Rc::from(#field_key), (#pattern).to_variable()); + } + }); + + quote! { + #enum_ident::#variant_ident(#(#field_patterns),*) => { + let mut map = _VariableMap::with_capacity(#field_count + 1); + map.insert(_Rc::from(#type_key), _Variable::String(_Rc::from(#variant_name))); + #(#field_mappings)* + _Variable::from_object(map) + } + } + } + } + + serde_derive_internals::ast::Style::Struct => { + let active_fields: Vec<_> = variant + .fields + .iter() + .filter(|field| !field.attrs.skip_serializing()) + .collect(); + + let field_mappings = active_fields.iter().map(|field| { + let field_ident = match &field.member { + syn::Member::Named(ident) => ident, + syn::Member::Unnamed(_) => panic!("Unexpected unnamed field in struct variant"), + }; + + let field_name = field.attrs.name().serialize_name(); + + quote! { + map.insert(_Rc::from(#field_name), #field_ident.to_variable()); + } + }); + + let field_patterns = active_fields.iter().map(|field| match &field.member { + syn::Member::Named(ident) => quote! { #ident }, + syn::Member::Unnamed(_) => panic!("Unexpected unnamed field in struct variant"), + }); + + let field_count = active_fields.len() + 1; + quote! { + #enum_ident::#variant_ident { #(#field_patterns),* } => { + let mut map = _VariableMap::with_capacity(#field_count); + map.insert(_Rc::from(#type_key), _Variable::String(_Rc::from(#variant_name))); + #(#field_mappings)* + _Variable::from_object(map) + } + } + } + } +} diff --git a/core/types/Cargo.toml b/core/types/Cargo.toml new file mode 100644 index 00000000..dcd3047b --- /dev/null +++ b/core/types/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "zen-types" +version = "0.49.1" +edition = "2024" +publish = false + +[dependencies] +ahash = { workspace = true } +serde = { workspace = true, features = ["rc", "derive"] } +serde_json = { workspace = true, features = ["arbitrary_precision"] } +rust_decimal = { workspace = true, features = ["maths-nopanic"] } +rust_decimal_macros = { workspace = true } +thiserror = { workspace = true } +nohash-hasher = "0.2.0" diff --git a/core/types/src/constant.rs b/core/types/src/constant.rs new file mode 100644 index 00000000..d4ba663d --- /dev/null +++ b/core/types/src/constant.rs @@ -0,0 +1 @@ +pub(crate) const NUMBER_TOKEN: &str = "$serde_json::private::Number"; diff --git a/core/types/src/lib.rs b/core/types/src/lib.rs new file mode 100644 index 00000000..1772aca9 --- /dev/null +++ b/core/types/src/lib.rs @@ -0,0 +1,4 @@ +mod constant; +pub mod rcvalue; +pub mod variable; +pub mod variable_type; diff --git a/core/types/src/rcvalue/conv.rs b/core/types/src/rcvalue/conv.rs new file mode 100644 index 00000000..1cf1a758 --- /dev/null +++ b/core/types/src/rcvalue/conv.rs @@ -0,0 +1,75 @@ +use crate::rcvalue::RcValue; +use crate::variable::{ToVariable, Variable}; +use rust_decimal::Decimal; +use serde_json::Value; +use std::rc::Rc; + +impl ToVariable for RcValue { + fn to_variable(&self) -> Variable { + match self { + RcValue::Null => Variable::Null, + RcValue::Bool(b) => Variable::Bool(*b), + RcValue::Number(n) => Variable::Number(*n), + RcValue::String(s) => Variable::String(Rc::from(s.as_ref())), + RcValue::Array(arr) => { + Variable::from_array(arr.iter().map(|v| v.to_variable()).collect()) + } + RcValue::Object(obj) => Variable::from_object( + obj.iter() + .map(|(k, v)| (Rc::from(k.as_ref()), v.to_variable())) + .collect(), + ), + } + } +} + +impl From<&Variable> for RcValue { + fn from(value: &Variable) -> Self { + match value { + Variable::Null => RcValue::Null, + Variable::Bool(b) => RcValue::Bool(*b), + Variable::Number(n) => RcValue::Number(*n), + Variable::String(s) => RcValue::String(s.clone()), + Variable::Array(arr) => { + let arr = arr.borrow(); + RcValue::Array(arr.iter().map(RcValue::from).collect()) + } + Variable::Object(obj) => { + let obj = obj.borrow(); + RcValue::Object( + obj.iter() + .map(|(k, v)| (k.clone(), RcValue::from(v))) + .collect(), + ) + } + Variable::Dynamic(d) => RcValue::from(&d.to_value()), + } + } +} + +impl From for RcValue { + fn from(value: Variable) -> Self { + Self::from(&value) + } +} + +impl From<&Value> for RcValue { + fn from(value: &Value) -> Self { + match value { + Value::Null => RcValue::Null, + Value::Bool(b) => RcValue::Bool(*b), + Value::Number(n) => RcValue::Number( + Decimal::from_str_exact(n.as_str()) + .or_else(|_| Decimal::from_scientific(n.as_str())) + .expect("Allowed number"), + ), + Value::String(s) => RcValue::String(Rc::from(s.as_str())), + Value::Array(arr) => RcValue::Array(arr.iter().map(RcValue::from).collect()), + Value::Object(obj) => RcValue::Object( + obj.iter() + .map(|(k, v)| (Rc::from(k.as_str()), RcValue::from(v))) + .collect(), + ), + } + } +} diff --git a/core/types/src/rcvalue/de.rs b/core/types/src/rcvalue/de.rs new file mode 100644 index 00000000..087ef1b7 --- /dev/null +++ b/core/types/src/rcvalue/de.rs @@ -0,0 +1,133 @@ +use crate::constant::NUMBER_TOKEN; +use crate::rcvalue::RcValue; +use ahash::{HashMap, HashMapExt}; +use rust_decimal::Decimal; +use rust_decimal::prelude::FromPrimitive; +use serde::de::{DeserializeSeed, Error, MapAccess, SeqAccess, Unexpected, Visitor}; +use serde::{Deserialize, Deserializer}; +use std::fmt::Formatter; +use std::marker::PhantomData; +use std::ops::Deref; +use std::rc::Rc; + +struct RcValueVisitor; + +impl<'de> Visitor<'de> for RcValueVisitor { + type Value = RcValue; + + fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result { + formatter.write_str("A valid type") + } + + fn visit_bool(self, v: bool) -> Result + where + E: Error, + { + Ok(RcValue::Bool(v)) + } + + fn visit_i64(self, v: i64) -> Result + where + E: Error, + { + Ok(RcValue::Number(Decimal::from_i64(v).ok_or_else(|| { + Error::invalid_value(Unexpected::Signed(v), &self) + })?)) + } + + fn visit_u64(self, v: u64) -> Result + where + E: Error, + { + Ok(RcValue::Number(Decimal::from_u64(v).ok_or_else(|| { + Error::invalid_value(Unexpected::Unsigned(v), &self) + })?)) + } + + fn visit_f64(self, v: f64) -> Result + where + E: Error, + { + Ok(RcValue::Number(Decimal::from_f64(v).ok_or_else(|| { + Error::invalid_value(Unexpected::Float(v), &self) + })?)) + } + + fn visit_str(self, v: &str) -> Result + where + E: Error, + { + Ok(RcValue::String(Rc::from(v))) + } + + fn visit_unit(self) -> Result + where + E: Error, + { + Ok(RcValue::Null) + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: SeqAccess<'de>, + { + let mut vec = Vec::with_capacity(seq.size_hint().unwrap_or_default()); + while let Some(value) = seq.next_element_seed(RcValueDeserializer)? { + vec.push(value); + } + + Ok(RcValue::Array(vec)) + } + + fn visit_map(self, mut map: A) -> Result + where + A: MapAccess<'de>, + { + let mut m = HashMap::with_capacity(map.size_hint().unwrap_or_default()); + let mut first = true; + + while let Some((key, value)) = + map.next_entry_seed(PhantomData::>, RcValueDeserializer)? + { + if first && key.deref() == NUMBER_TOKEN { + let str = match &value { + RcValue::String(s) => s.as_ref(), + _ => return Err(Error::custom("failed to deserialize number")), + }; + + return Ok(RcValue::Number( + Decimal::from_str_exact(str) + .or_else(|_| Decimal::from_scientific(str)) + .map_err(|_| Error::custom("invalid number"))?, + )); + } + + m.insert(key, value); + first = false; + } + + Ok(RcValue::Object(m)) + } +} + +pub struct RcValueDeserializer; + +impl<'de> DeserializeSeed<'de> for RcValueDeserializer { + type Value = RcValue; + + fn deserialize(self, deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_any(RcValueVisitor) + } +} + +impl<'de> Deserialize<'de> for RcValue { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_any(RcValueVisitor) + } +} diff --git a/core/types/src/rcvalue/mod.rs b/core/types/src/rcvalue/mod.rs new file mode 100644 index 00000000..3e7dd912 --- /dev/null +++ b/core/types/src/rcvalue/mod.rs @@ -0,0 +1,19 @@ +mod conv; +mod de; +mod ser; + +use ahash::HashMap; +pub use de::RcValueDeserializer; +use rust_decimal::Decimal; +use std::rc::Rc; + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub enum RcValue { + #[default] + Null, + Bool(bool), + Number(Decimal), + String(Rc), + Array(Vec), + Object(HashMap, RcValue>), +} diff --git a/core/types/src/rcvalue/ser.rs b/core/types/src/rcvalue/ser.rs new file mode 100644 index 00000000..0cd6ca5a --- /dev/null +++ b/core/types/src/rcvalue/ser.rs @@ -0,0 +1,26 @@ +use crate::constant::NUMBER_TOKEN; +use crate::rcvalue::RcValue; +use serde::ser::SerializeStruct; +use serde::{Serialize, Serializer}; + +impl Serialize for RcValue { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + match self { + RcValue::Null => serializer.serialize_unit(), + RcValue::Bool(v) => serializer.serialize_bool(*v), + RcValue::Number(v) => { + let str = v.normalize().to_string(); + + let mut s = serializer.serialize_struct(NUMBER_TOKEN, 1)?; + s.serialize_field(NUMBER_TOKEN, &str)?; + s.end() + } + RcValue::String(v) => serializer.serialize_str(v), + RcValue::Array(v) => serializer.collect_seq(v.iter()), + RcValue::Object(v) => serializer.collect_map(v.iter()), + } + } +} diff --git a/core/expression/src/variable/conv.rs b/core/types/src/variable/conv.rs similarity index 76% rename from core/expression/src/variable/conv.rs rename to core/types/src/variable/conv.rs index 8caeb15b..b81137d1 100644 --- a/core/expression/src/variable/conv.rs +++ b/core/types/src/variable/conv.rs @@ -1,8 +1,4 @@ use crate::variable::Variable; -use crate::vm::helpers::date_time; -use crate::vm::VMError; -use chrono::NaiveDateTime; -use rust_decimal::prelude::ToPrimitive; use rust_decimal::Decimal; use serde_json::{Number, Value}; use std::rc::Rc; @@ -88,28 +84,3 @@ impl From for Value { } } } - -impl TryFrom<&Variable> for NaiveDateTime { - type Error = VMError; - - fn try_from(value: &Variable) -> Result { - match value { - Variable::String(a) => date_time(a), - #[allow(deprecated)] - Variable::Number(a) => NaiveDateTime::from_timestamp_opt( - a.to_i64().ok_or_else(|| VMError::OpcodeErr { - opcode: "DateManipulation".into(), - message: "Failed to extract date".into(), - })?, - 0, - ) - .ok_or_else(|| VMError::ParseDateTimeErr { - timestamp: a.to_string(), - }), - _ => Err(VMError::OpcodeErr { - opcode: "DateManipulation".into(), - message: "Unsupported type".into(), - }), - } - } -} diff --git a/core/expression/src/variable/de.rs b/core/types/src/variable/de.rs similarity index 98% rename from core/expression/src/variable/de.rs rename to core/types/src/variable/de.rs index 81066f37..6c3ad6fd 100644 --- a/core/expression/src/variable/de.rs +++ b/core/types/src/variable/de.rs @@ -1,7 +1,8 @@ +use crate::constant::NUMBER_TOKEN; use crate::variable::Variable; use ahash::{HashMap, HashMapExt}; -use rust_decimal::prelude::FromPrimitive; use rust_decimal::Decimal; +use rust_decimal::prelude::FromPrimitive; use serde::de::{DeserializeSeed, Error, MapAccess, SeqAccess, Unexpected, Visitor}; use serde::{Deserialize, Deserializer}; use std::fmt::Formatter; @@ -11,8 +12,6 @@ use std::rc::Rc; struct VariableVisitor; -pub(super) const NUMBER_TOKEN: &str = "$serde_json::private::Number"; - impl<'de> Visitor<'de> for VariableVisitor { type Value = Variable; diff --git a/core/types/src/variable/impls.rs b/core/types/src/variable/impls.rs new file mode 100644 index 00000000..3b6508b3 --- /dev/null +++ b/core/types/src/variable/impls.rs @@ -0,0 +1,196 @@ +use crate::variable::Variable; +use rust_decimal::Decimal; +use rust_decimal::prelude::FromPrimitive; +use serde_json::Value; +use std::collections::HashMap; +use std::rc::Rc; +use std::sync::Arc; + +pub trait ToVariable { + fn to_variable(&self) -> Variable; +} + +impl ToVariable for String { + fn to_variable(&self) -> Variable { + Variable::String(Rc::from(self.as_str())) + } +} + +impl ToVariable for str { + fn to_variable(&self) -> Variable { + Variable::String(Rc::from(self)) + } +} + +impl ToVariable for bool { + fn to_variable(&self) -> Variable { + Variable::Bool(*self) + } +} + +impl ToVariable for Decimal { + fn to_variable(&self) -> Variable { + Variable::Number(*self) + } +} + +impl ToVariable for Variable { + fn to_variable(&self) -> Variable { + self.clone() + } +} + +impl ToVariable for Value { + fn to_variable(&self) -> Variable { + Variable::from(self) + } +} + +macro_rules! impl_to_variable_numeric { + ($($t:ty),* $(,)?) => { + $( + impl ToVariable for $t { + fn to_variable(&self) -> Variable { + Variable::Number(Decimal::from(*self)) + } + } + )* + }; +} + +impl_to_variable_numeric!( + i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize +); + +impl ToVariable for f32 { + fn to_variable(&self) -> Variable { + Variable::Number(Decimal::from_f32(*self).unwrap_or_default()) + } +} + +impl ToVariable for f64 { + fn to_variable(&self) -> Variable { + Variable::Number(Decimal::from_f64(*self).unwrap_or_default()) + } +} + +impl ToVariable for Vec +where + T: ToVariable, +{ + fn to_variable(&self) -> Variable { + Variable::from_array(self.iter().map(|v| v.to_variable()).collect()) + } +} + +impl ToVariable for HashMap, V, S> +where + V: ToVariable, + S: std::hash::BuildHasher, +{ + fn to_variable(&self) -> Variable { + Variable::from_object( + self.iter() + .map(|(k, v)| (k.clone(), v.to_variable())) + .collect(), + ) + } +} + +impl ToVariable for HashMap +where + V: ToVariable, + S: std::hash::BuildHasher, +{ + fn to_variable(&self) -> Variable { + Variable::from_object( + self.iter() + .map(|(k, v)| (Rc::from(k.as_str()), v.to_variable())) + .collect(), + ) + } +} + +macro_rules! tuple_impls { + ( $( ($($T:ident),+) ),+ ) => { + $( + impl<$($T),+> ToVariable for ($($T,)+) + where + $($T: ToVariable,)+ + { + #[allow(non_snake_case)] + fn to_variable(&self) -> Variable { + let ($($T,)+) = self; + Variable::from_array(vec![ + $($T.to_variable(),)+ + ]) + } + } + )+ + }; +} + +tuple_impls! { + (T1), + (T1, T2), + (T1, T2, T3), + (T1, T2, T3, T4), + (T1, T2, T3, T4, T5) +} + +impl ToVariable for &T +where + T: ?Sized + ToVariable, +{ + fn to_variable(&self) -> Variable { + (**self).to_variable() + } +} + +impl ToVariable for &mut T +where + T: ?Sized + ToVariable, +{ + fn to_variable(&self) -> Variable { + (**self).to_variable() + } +} + +impl ToVariable for Option +where + T: ToVariable, +{ + fn to_variable(&self) -> Variable { + match self { + Some(value) => value.to_variable(), + None => Variable::Null, + } + } +} + +impl ToVariable for Box +where + T: ?Sized + ToVariable, +{ + fn to_variable(&self) -> Variable { + (**self).to_variable() + } +} + +impl ToVariable for Rc +where + T: ?Sized + ToVariable, +{ + fn to_variable(&self) -> Variable { + (**self).to_variable() + } +} + +impl ToVariable for Arc +where + T: ?Sized + ToVariable, +{ + fn to_variable(&self) -> Variable { + (**self).to_variable() + } +} diff --git a/core/types/src/variable/mod.rs b/core/types/src/variable/mod.rs new file mode 100644 index 00000000..644d6f8f --- /dev/null +++ b/core/types/src/variable/mod.rs @@ -0,0 +1,503 @@ +use crate::variable::ref_ser::RefSerializer; +use ahash::HashMap; +use rust_decimal::Decimal; +use rust_decimal::prelude::Zero; +use serde_json::Value; +use std::any::Any; +use std::cell::RefCell; +use std::collections::hash_map::Entry; +use std::fmt::{Debug, Display, Formatter}; +use std::ops::Deref; +use std::rc::Rc; + +use crate::rcvalue::RcValue; +pub use crate::variable::ref_deser::RefDeserializeError; +use crate::variable::ref_deser::RefDeserializer; +pub use de::VariableDeserializer; +pub use impls::ToVariable; + +mod conv; +mod de; +mod impls; +mod ref_deser; +mod ref_ser; +mod ser; + +pub(crate) type RcCell = Rc>; + +pub type VariableMap = HashMap, Variable>; + +pub enum Variable { + Null, + Bool(bool), + Number(Decimal), + String(Rc), + Array(RcCell>), + Object(RcCell), + Dynamic(Rc), +} + +pub trait DynamicVariable: Display { + fn type_name(&self) -> &'static str; + + fn as_any(&self) -> &dyn Any; + + fn to_value(&self) -> Value; +} + +impl Variable { + pub fn from_array(arr: Vec) -> Self { + Self::Array(Rc::new(RefCell::new(arr))) + } + + pub fn serialize_ref(&self) -> RcValue { + RefSerializer::new().serialize(self) + } + + pub fn deserialize_ref(serialized: RcValue) -> Result { + RefDeserializer::new().deserialize(serialized) + } + + pub fn from_object(obj: HashMap, Self>) -> Self { + Self::Object(Rc::new(RefCell::new(obj))) + } + + pub fn empty_object() -> Self { + Variable::Object(Default::default()) + } + + pub fn empty_array() -> Self { + Variable::Array(Default::default()) + } + + pub fn as_str(&self) -> Option<&str> { + match self { + Variable::String(s) => Some(s.as_ref()), + _ => None, + } + } + + pub fn as_rc_str(&self) -> Option> { + match self { + Variable::String(s) => Some(s.clone()), + _ => None, + } + } + + pub fn as_array(&self) -> Option>> { + match self { + Variable::Array(arr) => Some(arr.clone()), + _ => None, + } + } + + pub fn is_array(&self) -> bool { + match self { + Variable::Array(_) => true, + _ => false, + } + } + + pub fn as_object(&self) -> Option, Variable>>> { + match self { + Variable::Object(obj) => Some(obj.clone()), + _ => None, + } + } + + pub fn is_object(&self) -> bool { + match self { + Variable::Object(_) => true, + _ => false, + } + } + + pub fn as_bool(&self) -> Option { + match self { + Variable::Bool(b) => Some(*b), + _ => None, + } + } + + pub fn as_number(&self) -> Option { + match self { + Variable::Number(n) => Some(*n), + _ => None, + } + } + + pub fn type_name(&self) -> &'static str { + match self { + Variable::Null => "null", + Variable::Bool(_) => "bool", + Variable::Number(_) => "number", + Variable::String(_) => "string", + Variable::Array(_) => "array", + Variable::Object(_) => "object", + Variable::Dynamic(d) => d.type_name(), + } + } + + pub fn dynamic(&self) -> Option<&T> { + match self { + Variable::Dynamic(d) => d.as_any().downcast_ref::(), + _ => None, + } + } + + pub fn to_value(&self) -> Value { + Value::from(self.shallow_clone()) + } + + pub fn dot(&self, key: &str) -> Option { + key.split('.') + .try_fold(self.shallow_clone(), |var, part| match var { + Variable::Object(obj) => { + let reference = obj.borrow(); + reference.get(part).map(|v| v.shallow_clone()) + } + _ => None, + }) + } + + fn dot_head(&self, key: &str) -> Option { + let mut parts = Vec::from_iter(key.split('.')); + parts.pop(); + + parts + .iter() + .try_fold(self.shallow_clone(), |var, part| match var { + Variable::Object(obj) => { + let mut obj_ref = obj.borrow_mut(); + Some(match obj_ref.entry(Rc::from(*part)) { + Entry::Occupied(occ) => occ.get().shallow_clone(), + Entry::Vacant(vac) => vac.insert(Self::empty_object()).shallow_clone(), + }) + } + _ => None, + }) + } + + fn dot_head_detach(&self, key: &str) -> (Variable, Option) { + let mut parts = Vec::from_iter(key.split('.')); + parts.pop(); + + let cloned_self = self.depth_clone(1); + let head = parts + .iter() + .try_fold(cloned_self.shallow_clone(), |var, part| match var { + Variable::Object(obj) => { + let mut obj_ref = obj.borrow_mut(); + Some(match obj_ref.entry(Rc::from(*part)) { + Entry::Occupied(mut occ) => { + let var = occ.get(); + let new_obj = match var { + Variable::Object(_) => var.depth_clone(1), + _ => Variable::empty_object(), + }; + + occ.insert(new_obj.shallow_clone()); + new_obj + } + Entry::Vacant(vac) => vac.insert(Self::empty_object()).shallow_clone(), + }) + } + _ => None, + }); + + (cloned_self, head) + } + + pub fn dot_remove(&self, key: &str) -> Option { + let last_part = key.split('.').last()?; + let head = self.dot_head(key)?; + let Variable::Object(object_ref) = head else { + return None; + }; + + let mut object = object_ref.borrow_mut(); + object.remove(last_part) + } + + pub fn dot_insert(&self, key: &str, variable: Variable) -> Option { + let last_part = key.split('.').last()?; + let head = self.dot_head(key)?; + let Variable::Object(object_ref) = head else { + return None; + }; + + let mut object = object_ref.borrow_mut(); + object.insert(Rc::from(last_part), variable) + } + + pub fn dot_insert_detached(&self, key: &str, variable: Variable) -> Option { + let last_part = key.split('.').last()?; + let (new_var, head_opt) = self.dot_head_detach(key); + let head = head_opt?; + let Variable::Object(object_ref) = head else { + return None; + }; + + let mut object = object_ref.borrow_mut(); + object.insert(Rc::from(last_part), variable); + Some(new_var) + } + + pub fn merge(&mut self, patch: &Variable) -> Variable { + let _ = merge_variables(self, patch, true, MergeStrategy::InPlace); + + self.shallow_clone() + } + + pub fn merge_clone(&mut self, patch: &Variable) -> Variable { + let mut new_self = self.shallow_clone(); + + let _ = merge_variables(&mut new_self, patch, true, MergeStrategy::CloneOnWrite); + new_self + } + + pub fn shallow_clone(&self) -> Self { + match self { + Variable::Null => Variable::Null, + Variable::Bool(b) => Variable::Bool(*b), + Variable::Number(n) => Variable::Number(*n), + Variable::String(s) => Variable::String(s.clone()), + Variable::Array(a) => Variable::Array(a.clone()), + Variable::Object(o) => Variable::Object(o.clone()), + Variable::Dynamic(d) => Variable::Dynamic(d.clone()), + } + } + + pub fn deep_clone(&self) -> Self { + match self { + Variable::Array(a) => { + let arr = a.borrow(); + Variable::from_array(arr.iter().map(|v| v.deep_clone()).collect()) + } + Variable::Object(o) => { + let obj = o.borrow(); + Variable::from_object( + obj.iter() + .map(|(k, v)| (k.clone(), v.deep_clone())) + .collect(), + ) + } + _ => self.shallow_clone(), + } + } + + pub fn depth_clone(&self, depth: usize) -> Self { + match depth.is_zero() { + true => self.shallow_clone(), + false => match self { + Variable::Array(a) => { + let arr = a.borrow(); + Variable::from_array(arr.iter().map(|v| v.depth_clone(depth - 1)).collect()) + } + Variable::Object(o) => { + let obj = o.borrow(); + Variable::from_object( + obj.iter() + .map(|(k, v)| (k.clone(), v.depth_clone(depth - 1))) + .collect(), + ) + } + _ => self.shallow_clone(), + }, + } + } +} + +impl Clone for Variable { + fn clone(&self) -> Self { + self.shallow_clone() + } +} + +#[derive(Copy, Clone)] +enum MergeStrategy { + InPlace, + CloneOnWrite, +} + +fn merge_variables( + doc: &mut Variable, + patch: &Variable, + top_level: bool, + strategy: MergeStrategy, +) -> bool { + if patch.is_array() && top_level { + *doc = patch.shallow_clone(); + return true; + } + + if !patch.is_object() && top_level { + return false; + } + + if doc.is_object() && patch.is_object() { + let doc_ref = doc.as_object().unwrap(); + let patch_ref = patch.as_object().unwrap(); + if Rc::ptr_eq(&doc_ref, &patch_ref) { + return false; + } + + let patch = patch_ref.borrow(); + match strategy { + MergeStrategy::InPlace => { + let mut map = doc_ref.borrow_mut(); + for (key, value) in patch.deref() { + if value == &Variable::Null { + map.remove(key); + } else { + let entry = map.entry(key.clone()).or_insert(Variable::Null); + merge_variables(entry, value, false, strategy); + } + } + + return true; + } + MergeStrategy::CloneOnWrite => { + let mut changed = false; + let mut new_map = None; + + for (key, value) in patch.deref() { + // Get or create the new map if we haven't yet + let map = if let Some(ref mut m) = new_map { + m + } else { + let m = doc_ref.borrow().clone(); + new_map = Some(m); + new_map.as_mut().unwrap() + }; + + if value == &Variable::Null { + // Remove null values + if map.remove(key).is_some() { + changed = true; + } + } else { + // Handle nested merging + let entry = map.entry(key.clone()).or_insert(Variable::Null); + if merge_variables(entry, value, false, strategy) { + changed = true; + } + } + } + + // Only update doc if changes were made + if changed { + if let Some(new_map) = new_map { + *doc = Variable::Object(Rc::new(RefCell::new(new_map))); + } + return true; + } + + return false; + } + } + } else { + let new_value = patch.shallow_clone(); + if *doc != new_value { + *doc = new_value; + return true; + } + + return false; + } +} + +impl Default for Variable { + fn default() -> Self { + Variable::Null + } +} + +impl Display for Variable { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Variable::Null => write!(f, "null"), + Variable::Bool(b) => match *b { + true => write!(f, "true"), + false => write!(f, "false"), + }, + Variable::Number(n) => write!(f, "{n}"), + Variable::String(s) => write!(f, "\"{s}\""), + Variable::Array(arr) => { + let arr = arr.borrow(); + let s = arr + .iter() + .map(|v| v.to_string()) + .collect::>() + .join(","); + write!(f, "[{s}]") + } + Variable::Object(obj) => { + let obj = obj.borrow(); + let s = obj + .iter() + .map(|(k, v)| format!("\"{k}\":{v}")) + .collect::>() + .join(","); + + write!(f, "{{{s}}}") + } + Variable::Dynamic(d) => write!(f, "{d}"), + } + } +} + +impl Debug for Variable { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self) + } +} + +impl PartialEq for Variable { + fn eq(&self, other: &Self) -> bool { + match (&self, &other) { + (Variable::Null, Variable::Null) => true, + (Variable::Bool(b1), Variable::Bool(b2)) => b1 == b2, + (Variable::Number(n1), Variable::Number(n2)) => n1 == n2, + (Variable::String(s1), Variable::String(s2)) => s1 == s2, + (Variable::Array(a1), Variable::Array(a2)) => a1 == a2, + (Variable::Object(obj1), Variable::Object(obj2)) => obj1 == obj2, + (Variable::Dynamic(d1), Variable::Dynamic(d2)) => Rc::ptr_eq(d1, d2), + _ => false, + } + } +} + +impl Eq for Variable {} + +#[cfg(test)] +mod tests { + use crate::variable::Variable; + use rust_decimal_macros::dec; + use serde_json::json; + + #[test] + fn insert_detached() { + let some_data: Variable = json!({ "customer": { "firstName": "John" }}).into(); + + let a_a = some_data + .dot_insert_detached("a.a", Variable::Number(dec!(1))) + .unwrap(); + let a_b = a_a + .dot_insert_detached("a.b", Variable::Number(dec!(2))) + .unwrap(); + let a_c = a_b + .dot_insert_detached("a.c", Variable::Number(dec!(3))) + .unwrap(); + + assert_eq!(a_a.dot("a"), Some(Variable::from(json!({ "a": 1 })))); + assert_eq!( + a_b.dot("a"), + Some(Variable::from(json!({ "a": 1, "b": 2 }))) + ); + assert_eq!( + a_c.dot("a"), + Some(Variable::from(json!({ "a": 1, "b": 2, "c": 3 }))) + ); + } +} diff --git a/core/types/src/variable/ref_deser.rs b/core/types/src/variable/ref_deser.rs new file mode 100644 index 00000000..d3762141 --- /dev/null +++ b/core/types/src/variable/ref_deser.rs @@ -0,0 +1,152 @@ +use crate::rcvalue::RcValue; +use crate::variable::Variable; +use ahash::{HashMap, HashMapExt}; +use std::cell::RefCell; +use std::rc::Rc; +use thiserror::Error; + +pub struct RefDeserializer { + refs: Vec>, +} + +impl RefDeserializer { + pub fn new() -> Self { + Self { refs: Vec::new() } + } + + pub fn deserialize(&mut self, value: RcValue) -> Result { + let RcValue::Object(mut root_obj) = value else { + return Err(RefDeserializeError::InvalidFormat( + "Expected root object".into(), + )); + }; + + if let Some(RcValue::Array(refs_array)) = root_obj.remove(&Rc::from("$refs")) { + self.refs = vec![None; refs_array.len()]; + + for (i, _) in refs_array.iter().enumerate() { + match &refs_array[i] { + RcValue::Array(_) => { + self.refs[i] = Some(Variable::Array(Rc::new(RefCell::new(Vec::new())))); + } + RcValue::Object(_) => { + self.refs[i] = + Some(Variable::Object(Rc::new(RefCell::new(HashMap::default())))); + } + _ => { + self.refs[i] = Some(self.deserialize_value(&refs_array[i])?); + } + } + } + + for (i, ref_value) in refs_array.iter().enumerate() { + match ref_value { + RcValue::Array(arr) => { + if let Some(Variable::Array(target)) = &self.refs[i] { + let mut items = Vec::with_capacity(arr.len()); + for item in arr { + items.push(self.deserialize_value(item)?); + } + *target.borrow_mut() = items; + } + } + RcValue::Object(obj) => { + if let Some(Variable::Object(target)) = &self.refs[i] { + let mut map = HashMap::with_capacity(obj.len()); + for (key, value) in obj { + let key_var = self.deserialize_key(key)?; + let value_var = self.deserialize_value(value)?; + map.insert(key_var, value_var); + } + *target.borrow_mut() = map; + } + } + _ => {} + } + } + } + + let root_value = root_obj + .remove(&Rc::from("$root")) + .ok_or_else(|| RefDeserializeError::InvalidFormat("Missing $root".into()))?; + + self.deserialize_value(&root_value) + } + + fn deserialize_key(&self, key: &Rc) -> Result, RefDeserializeError> { + if let Some(ref_id) = parse_ref_id(key) { + if ref_id >= self.refs.len() { + return Err(RefDeserializeError::InvalidReference(ref_id)); + } + + match &self.refs[ref_id] { + Some(Variable::String(s)) => Ok(s.clone()), + Some(_) => Err(RefDeserializeError::InvalidFormat( + "Reference used as key must be a string".into(), + )), + None => Err(RefDeserializeError::UnresolvedReference(ref_id)), + } + } else { + Ok(unescape_at_string(key)) + } + } + + fn deserialize_value(&self, value: &RcValue) -> Result { + match value { + RcValue::Null => Ok(Variable::Null), + RcValue::Bool(b) => Ok(Variable::Bool(*b)), + RcValue::Number(n) => Ok(Variable::Number(*n)), + RcValue::String(s) => { + if let Some(ref_id) = parse_ref_id(s) { + if ref_id >= self.refs.len() { + return Err(RefDeserializeError::InvalidReference(ref_id)); + } + + self.refs[ref_id] + .clone() + .ok_or(RefDeserializeError::UnresolvedReference(ref_id)) + } else { + Ok(Variable::String(unescape_at_string(s))) + } + } + RcValue::Array(arr) => { + let mut items = Vec::with_capacity(arr.len()); + for item in arr { + items.push(self.deserialize_value(item)?); + } + Ok(Variable::Array(Rc::new(RefCell::new(items)))) + } + RcValue::Object(obj) => { + let mut map = HashMap::with_capacity(obj.len()); + for (key, value) in obj { + let key_var = self.deserialize_key(key)?; + let value_var = self.deserialize_value(value)?; + map.insert(key_var, value_var); + } + Ok(Variable::Object(Rc::new(RefCell::new(map)))) + } + } + } +} + +#[derive(Debug, Error)] +pub enum RefDeserializeError { + #[error("Invalid format: {0}")] + InvalidFormat(String), + #[error("Invalid reference: {0}")] + InvalidReference(usize), + #[error("UnresolvedReference: {0}")] + UnresolvedReference(usize), +} + +fn unescape_at_string(s: &Rc) -> Rc { + if s.starts_with("@@") { + Rc::from(&s[1..]) + } else { + s.clone() + } +} + +fn parse_ref_id(s: &str) -> Option { + s.strip_prefix('@')?.parse().ok() +} diff --git a/core/types/src/variable/ref_ser.rs b/core/types/src/variable/ref_ser.rs new file mode 100644 index 00000000..047e3d07 --- /dev/null +++ b/core/types/src/variable/ref_ser.rs @@ -0,0 +1,201 @@ +use crate::rcvalue::RcValue; +use crate::variable::Variable; +use ahash::AHashMap; +use nohash_hasher::BuildNoHashHasher; +use std::collections::HashMap; +use std::rc::Rc; + +pub struct RefSerializer { + ref_counts: HashMap>, + refs: HashMap), BuildNoHashHasher>, + string_intern: AHashMap, Rc>, + ref_data: Vec, + min_ref_count: usize, + min_str_len: usize, +} + +impl RefSerializer { + pub fn new() -> Self { + Self { + ref_counts: HashMap::default(), + refs: HashMap::default(), + string_intern: AHashMap::default(), + ref_data: Vec::new(), + min_ref_count: 2, + min_str_len: 5, + } + } + + fn escape_at_string(s: &Rc) -> Rc { + if s.starts_with('@') { + let string = format!("@{s}"); + Rc::from(string.as_str()) + } else { + s.clone() + } + } + + fn intern_string_addr(&mut self, s: &Rc) -> usize { + let reference = match self.string_intern.get(s) { + Some(interned) => interned, + None => { + self.string_intern.insert(s.clone(), s.clone()); + s + } + }; + + Rc::as_ptr(&reference) as *const () as usize + } + + pub fn serialize(mut self, var: &Variable) -> RcValue { + self.count_refs(var); + self.assign_ref_ids(); + + let data = self.serialize_with_refs(var); + + let mut result = HashMap::default(); + if !self.ref_data.is_empty() { + result.insert(Rc::from("$refs"), RcValue::Array(self.ref_data)); + } + + result.insert(Rc::from("$root"), data); + RcValue::Object(result) + } + + fn count_refs(&mut self, var: &Variable) { + match var { + Variable::String(s) => { + if s.len() < self.min_str_len { + return; + } + + let addr = self.intern_string_addr(s); + *self.ref_counts.entry(addr).or_insert(0) += 1; + } + Variable::Array(arr) => { + let addr = Rc::as_ptr(arr) as *const () as usize; + *self.ref_counts.entry(addr).or_insert(0) += 1; + + let borrowed = arr.borrow(); + for item in borrowed.iter() { + self.count_refs(item); + } + } + Variable::Object(obj) => { + let addr = Rc::as_ptr(obj) as *const () as usize; + *self.ref_counts.entry(addr).or_insert(0) += 1; + + let borrowed = obj.borrow(); + for (key, value) in borrowed.iter() { + let key_addr = self.intern_string_addr(key); + *self.ref_counts.entry(key_addr).or_insert(0) += 1; + self.count_refs(value); + } + } + Variable::Dynamic(_) => {} + _ => {} // Null, Bool, Number don't need ref counting + } + } + + fn assign_ref_ids(&mut self) { + let mut sorted_refs: Vec<_> = self + .ref_counts + .iter() + .filter(|&(_, &count)| count >= self.min_ref_count) + .collect(); + + sorted_refs.sort_by(|a, b| b.1.cmp(&a.1).then(b.0.cmp(&a.0))); + + self.refs.reserve(sorted_refs.len()); + self.ref_data.reserve(sorted_refs.len()); + + for (&addr, _) in sorted_refs { + let id = self.ref_data.len(); + let id_string = format!("@{id}"); + + self.refs.insert(addr, (id, Rc::from(id_string.as_str()))); + self.ref_data.push(RcValue::Null); + } + } + + fn serialize_with_refs(&mut self, var: &Variable) -> RcValue { + match var { + Variable::String(s) => { + let addr = self.intern_string_addr(s); + let Some((id, id_str)) = self.refs.get(&addr) else { + return RcValue::String(Self::escape_at_string(s)); + }; + + if self.ref_data[*id] == RcValue::Null { + self.ref_data[*id] = RcValue::String(Self::escape_at_string(s)); + } + + RcValue::String(id_str.clone()) + } + + Variable::Array(arr) => { + let addr = Rc::as_ptr(arr) as *const () as usize; + let data = { + let borrowed = arr.borrow(); + let items: Vec<_> = borrowed + .iter() + .map(|item| self.serialize_with_refs(item)) + .collect(); + + RcValue::Array(items) + }; + + let Some((id, id_str)) = self.refs.get(&addr) else { + return data; + }; + + if self.ref_data[*id] == RcValue::Null { + self.ref_data[*id] = data; + } + + RcValue::String(id_str.clone()) + } + + Variable::Object(obj) => { + let addr = Rc::as_ptr(obj) as *const () as usize; + let data = { + let borrowed = obj.borrow(); + let mut map = HashMap::with_capacity_and_hasher( + borrowed.len(), + ahash::RandomState::new(), + ); + + for (key, value) in borrowed.iter() { + let key_addr = self.intern_string_addr(key); + let key_str = if let Some((key_id, key_id_str)) = self.refs.get(&key_addr) { + if self.ref_data[*key_id] == RcValue::Null { + self.ref_data[*key_id] = + RcValue::String(Self::escape_at_string(key)); + } + + key_id_str.clone() + } else { + Self::escape_at_string(key) + }; + + map.insert(key_str, self.serialize_with_refs(value)); + } + + RcValue::Object(map) + }; + + let Some((id, id_str)) = self.refs.get(&addr) else { + return data; + }; + + if self.ref_data[*id] == RcValue::Null { + self.ref_data[*id] = data; + } + + RcValue::String(id_str.clone()) + } + + _ => RcValue::from(var), + } + } +} diff --git a/core/expression/src/variable/ser.rs b/core/types/src/variable/ser.rs similarity index 96% rename from core/expression/src/variable/ser.rs rename to core/types/src/variable/ser.rs index 6bbf1d2f..1f57b709 100644 --- a/core/expression/src/variable/ser.rs +++ b/core/types/src/variable/ser.rs @@ -1,4 +1,4 @@ -use crate::variable::de::NUMBER_TOKEN; +use crate::constant::NUMBER_TOKEN; use crate::variable::Variable; use serde::ser::SerializeStruct; use serde::{Serialize, Serializer}; diff --git a/core/expression/src/variable/types/conv.rs b/core/types/src/variable_type/conv.rs similarity index 84% rename from core/expression/src/variable/types/conv.rs rename to core/types/src/variable_type/conv.rs index 761d6a1f..943b5bda 100644 --- a/core/expression/src/variable/types/conv.rs +++ b/core/types/src/variable_type/conv.rs @@ -1,4 +1,4 @@ -use crate::variable::types::VariableType; +use crate::variable_type::VariableType; use serde_json::Value; use std::borrow::Cow; use std::cell::RefCell; @@ -79,17 +79,3 @@ impl From<&Vec> for VariableType { VariableType::Array(Rc::new(result_type.unwrap_or(VariableType::Any))) } } - -#[cfg(test)] -mod tests { - use super::*; - use serde_json::json; - - #[test] - fn test_value_to_value_kind() { - assert_eq!(VariableType::from(json!(null)), VariableType::Null); - assert_eq!(VariableType::from(json!(true)), VariableType::Bool); - assert_eq!(VariableType::from(json!(42)), VariableType::Number); - assert_eq!(VariableType::from(json!("hello")), VariableType::String); - } -} diff --git a/core/expression/src/variable/types/mod.rs b/core/types/src/variable_type/mod.rs similarity index 97% rename from core/expression/src/variable/types/mod.rs rename to core/types/src/variable_type/mod.rs index bdf0c446..dabcf776 100644 --- a/core/expression/src/variable/types/mod.rs +++ b/core/types/src/variable_type/mod.rs @@ -2,8 +2,9 @@ mod conv; mod util; use crate::variable::RcCell; +use ahash::HashMap; +pub use ahash::HashMapExt as VariableMapExt; use serde::{Deserialize, Serialize}; -use std::collections::HashMap; use std::fmt::{Display, Write}; use std::hash::{Hash, Hasher}; use std::rc::Rc; diff --git a/core/expression/src/variable/types/util.rs b/core/types/src/variable_type/util.rs similarity index 90% rename from core/expression/src/variable/types/util.rs rename to core/types/src/variable_type/util.rs index 031f4857..03e17d18 100644 --- a/core/expression/src/variable/types/util.rs +++ b/core/types/src/variable_type/util.rs @@ -1,8 +1,8 @@ -use crate::variable::types::VariableType; +use crate::variable_type::VariableType; +use ahash::{HashMap, HashMapExt}; use rust_decimal::prelude::Zero; use std::cell::RefCell; use std::collections::hash_map::Entry; -use std::collections::HashMap; use std::rc::Rc; impl VariableType { @@ -320,50 +320,3 @@ impl VariableType { }) } } - -#[cfg(test)] -mod tests { - use crate::variable::VariableType; - use std::rc::Rc; - - #[test] - fn merge_simple() { - assert_eq!( - VariableType::Number.merge(&VariableType::Number), - VariableType::Number - ); - assert_eq!( - VariableType::String.merge(&VariableType::String), - VariableType::String - ); - assert_eq!( - VariableType::Bool.merge(&VariableType::Bool), - VariableType::Bool - ); - assert_eq!( - VariableType::Null.merge(&VariableType::Null), - VariableType::Null - ); - assert_eq!( - VariableType::Any.merge(&VariableType::Any), - VariableType::Any - ); - } - - #[test] - fn merge_array() { - assert_eq!( - VariableType::Array(Rc::new(VariableType::Number)) - .merge(&VariableType::Array(Rc::new(VariableType::Number))), - VariableType::Array(Rc::new(VariableType::Number)) - ); - } - - #[test] - fn merge_mixed() { - assert_eq!( - VariableType::Number.merge(&VariableType::String), - VariableType::Any - ); - } -} diff --git a/core/types/tests/ref_ser.rs b/core/types/tests/ref_ser.rs new file mode 100644 index 00000000..5a6e02d8 --- /dev/null +++ b/core/types/tests/ref_ser.rs @@ -0,0 +1,260 @@ +use ahash::{HashMap, HashMapExt}; +use rust_decimal_macros::dec; +use serde_json::json; +use std::cell::RefCell; +use std::error::Error; +use std::rc::Rc; +use zen_types::rcvalue::RcValue; +use zen_types::variable::Variable; + +type TestResult = Result<(), Box>; + +#[test] +fn serialize_deserialize_simple() -> TestResult { + let var = Variable::from(json!({ + "name": "Alice", + "age": 30, + "active": true + })); + + let serialized = var.serialize_ref(); + let deserialized = Variable::deserialize_ref(serialized)?; + + assert_eq!(var, deserialized); + + Ok(()) +} + +#[test] +fn serialize_deserialize_with_refs() -> TestResult { + let shared_string = "shared_value"; + let var = Variable::from(json!({ + "user1": { + "name": shared_string, + "status": shared_string + }, + "user2": { + "name": shared_string, + "friend": shared_string + }, + "metadata": { + "type": shared_string + } + })); + + let serialized = var.serialize_ref(); + + // Check that refs were created + if let RcValue::Object(ref obj) = serialized { + assert!(obj.contains_key(&Rc::from("$refs"))); + assert!(obj.contains_key(&Rc::from("$root"))); + } else { + panic!("Expected object"); + } + + let deserialized = Variable::deserialize_ref(serialized)?; + assert_eq!(var, deserialized); + + Ok(()) +} + +#[test] +fn serialize_deserialize_array_refs() -> TestResult { + let shared_array = vec![1, 2, 3]; + let var = Variable::from(json!({ + "data1": shared_array, + "data2": shared_array, + "backup": shared_array + })); + + let serialized = var.serialize_ref(); + let deserialized = Variable::deserialize_ref(serialized)?; + + assert_eq!(var, deserialized); + + Ok(()) +} + +#[test] +fn serialize_deserialize_at_string_escaping() -> TestResult { + let var = Variable::from(json!({ + "normal": "hello", + "at_string": "@special", + "double_at": "@@escaped" + })); + + let serialized = var.serialize_ref(); + let deserialized = Variable::deserialize_ref(serialized)?; + + assert_eq!(var, deserialized); + + Ok(()) +} + +#[test] +fn serialize_deserialize_nested_structure() -> TestResult { + let var = Variable::from(json!({ + "level1": { + "level2": { + "level3": { + "data": "deep_value", + "numbers": [1, 2, 3, 4, 5] + } + }, + "shared": "common_string" + }, + "other": { + "ref": "common_string" + }, + "array": [ + {"shared": "common_string"}, + {"different": "unique"} + ] + })); + + let serialized = var.serialize_ref(); + let deserialized = Variable::deserialize_ref(serialized)?; + + assert_eq!(var, deserialized); + + Ok(()) +} + +#[test] +fn no_refs_when_below_threshold() -> TestResult { + // String too short, should not create refs + let var = Variable::from(json!({ + "a": "hi", + "b": "hi", + "c": "hi" + })); + + let serialized = var.serialize_ref(); + + // Should not have refs section + if let RcValue::Object(ref obj) = serialized { + assert!(!obj.contains_key(&Rc::from("$refs"))); + } + + let deserialized = Variable::deserialize_ref(serialized)?; + assert_eq!(var, deserialized); + + Ok(()) +} + +#[test] +fn serialize_circular_references() -> TestResult { + // Create a shared object that will be referenced multiple times + let shared_obj = Rc::new(RefCell::new({ + let mut map = HashMap::new(); + map.insert( + Rc::from("shared_data"), + Variable::String(Rc::from("important_value")), + ); + map.insert(Rc::from("id"), Variable::Number(dec!(42.0))); + map + })); + + // Create a structure where the same object appears in multiple places + let mut root_map = HashMap::new(); + root_map.insert(Rc::from("first_ref"), Variable::Object(shared_obj.clone())); + root_map.insert(Rc::from("second_ref"), Variable::Object(shared_obj.clone())); + root_map.insert(Rc::from("third_ref"), Variable::Object(shared_obj)); + + let var = Variable::Object(Rc::new(RefCell::new(root_map))); + + let serialized = var.serialize_ref(); + let deserialized = Variable::deserialize_ref(serialized)?; + + assert_eq!(var, deserialized); + + Ok(()) +} + +#[test] +fn serialize_same_array_multiple_locations() -> TestResult { + use std::cell::RefCell; + + // Create a shared array + let shared_array = Rc::new(RefCell::new(vec![ + Variable::String(Rc::from("item_one")), + Variable::String(Rc::from("item_two")), + Variable::Number(dec!(123.0)), + ])); + + let var = Variable::from(json!({ + "list1": shared_array.clone(), + "backup_list": shared_array.clone(), + "nested": { + "inner_list": shared_array + } + })); + + let serialized = var.serialize_ref(); + let deserialized = Variable::deserialize_ref(serialized)?; + + assert_eq!(var, deserialized); + + Ok(()) +} + +#[test] +fn serialize_mixed_shared_references() -> TestResult { + let shared_string: Rc = Rc::from("shared_between_key_and_value"); + + // Create an object where the same string is used as both key and value + let mut obj_map = HashMap::new(); + obj_map.insert( + shared_string.clone(), + Variable::String(shared_string.clone()), + ); + obj_map.insert( + Rc::from("other_key"), + Variable::String(shared_string.clone()), + ); + + let shared_obj = Rc::new(RefCell::new(obj_map)); + + // Use the shared object in multiple places + let var = Variable::from(json!({ + "container1": shared_obj.clone(), + "container2": shared_obj.clone(), + "metadata": { + "reference": shared_obj + } + })); + + let serialized = var.serialize_ref(); + let deserialized = Variable::deserialize_ref(serialized)?; + + assert_eq!(var, deserialized); + + Ok(()) +} + +#[test] +fn serialize_shared_array_with_shared_strings() -> TestResult { + let shared_string: Rc = Rc::from("shared_string_value"); + + // Create a shared array containing the shared string + let shared_array = Rc::new(RefCell::new(vec![ + Variable::String(shared_string.clone()), + Variable::Number(dec!(42.0)), + Variable::String(shared_string.clone()), + ])); + + // Use the shared array in multiple places + let mut root_map = HashMap::new(); + root_map.insert(Rc::from("array1"), Variable::Array(shared_array.clone())); + root_map.insert(Rc::from("array2"), Variable::Array(shared_array.clone())); + root_map.insert(Rc::from("array3"), Variable::Array(shared_array)); + + let var = Variable::Object(Rc::new(RefCell::new(root_map))); + + let serialized = var.serialize_ref(); + let deserialized = Variable::deserialize_ref(serialized)?; + + assert_eq!(var, deserialized); + + Ok(()) +} diff --git a/core/types/tests/variable.rs b/core/types/tests/variable.rs new file mode 100644 index 00000000..5715989c --- /dev/null +++ b/core/types/tests/variable.rs @@ -0,0 +1,206 @@ +use rust_decimal_macros::dec; +use serde_json::json; +use std::error::Error; +use std::rc::Rc; +use zen_types::variable::Variable; + +type TestResult = Result<(), Box>; +#[test] +fn dot_operations() -> TestResult { + let var = Variable::from(json!({ + "user": { + "profile": { + "name": "Alice", + "age": 30 + } + } + })); + + // Test dot get + assert_eq!( + var.dot("user.profile.name"), + Some(Variable::String(Rc::from("Alice"))) + ); + assert_eq!(var.dot("user.profile.nonexistent"), None); + assert_eq!(var.dot("nonexistent.path"), None); + + // Test dot insert + let updated = var.dot_insert( + "user.profile.email", + Variable::String(Rc::from("alice@example.com")), + ); + assert!(updated.is_none()); // Returns previous value (none) + assert_eq!( + var.dot("user.profile.email"), + Some(Variable::String(Rc::from("alice@example.com"))) + ); + + // Test dot insert detached + let new_var = var + .dot_insert_detached("settings.theme", Variable::String(Rc::from("dark"))) + .ok_or_else(|| "Failed to insert detached path".to_string())?; + assert_eq!( + new_var.dot("settings.theme"), + Some(Variable::String(Rc::from("dark"))) + ); + assert_eq!(var.dot("settings.theme"), None); // Original unchanged + + // Test dot remove + let removed = var.dot_remove("user.profile.age"); + assert_eq!(removed, Some(Variable::Number(dec!(30)))); + assert_eq!(var.dot("user.profile.age"), None); + + Ok(()) +} + +#[test] +fn clone_operations() -> TestResult { + let original = Variable::from(json!({ + "data": [1, 2, {"nested": "value"}], + "count": 42 + })); + + // Test shallow clone - shares references + let shallow = original.shallow_clone(); + if let (Variable::Array(orig_arr), Variable::Array(shallow_arr)) = (&original, &shallow) { + assert!(Rc::ptr_eq(orig_arr, shallow_arr)); + } + + // Test depth clone + let depth1 = original.depth_clone(1); + if let (Variable::Array(orig_arr), Variable::Array(depth_arr)) = (&original, &depth1) { + assert!(!Rc::ptr_eq(orig_arr, depth_arr)); // Different array refs + + let orig_nested = &orig_arr.borrow()[2]; + let depth_nested = &depth_arr.borrow()[2]; + if let (Variable::Object(orig_obj), Variable::Object(depth_obj)) = + (orig_nested, depth_nested) + { + assert!(Rc::ptr_eq(orig_obj, depth_obj)); // Nested still shared at depth 1 + } + } + + // Test deep clone - everything separate + let deep = original.deep_clone(); + if let (Variable::Array(orig_arr), Variable::Array(deep_arr)) = (&original, &deep) { + assert!(!Rc::ptr_eq(orig_arr, deep_arr)); + + let orig_nested = &orig_arr.borrow()[2]; + let deep_nested = &deep_arr.borrow()[2]; + if let (Variable::Object(orig_obj), Variable::Object(deep_obj)) = (orig_nested, deep_nested) + { + assert!(!Rc::ptr_eq(orig_obj, deep_obj)); // Nested also separate + } + } + + Ok(()) +} + +#[test] +fn merge_operations() -> TestResult { + let mut doc = Variable::from(json!({ + "user": {"name": "Alice", "age": 30}, + "settings": {"theme": "light"} + })); + + let patch = Variable::from(json!({ + "user": {"age": 31, "email": "alice@example.com"}, + "settings": {"notifications": true}, + "new_field": "value" + })); + + // Test merge clone (doesn't modify original) + let merged = doc.merge_clone(&patch); + assert_eq!(doc.dot("user.age"), Some(Variable::Number(dec!(30)))); // Original unchanged + assert_eq!(merged.dot("user.age"), Some(Variable::Number(dec!(31)))); // Merged updated + assert_eq!( + merged.dot("user.email"), + Some(Variable::String(Rc::from("alice@example.com"))) + ); + assert_eq!( + merged.dot("new_field"), + Some(Variable::String(Rc::from("value"))) + ); + + // Test in-place merge + doc.merge(&patch); + assert_eq!(doc.dot("user.age"), Some(Variable::Number(dec!(31)))); // Original now changed + assert_eq!( + doc.dot("user.name"), + Some(Variable::String(Rc::from("Alice"))) + ); // Preserved + assert_eq!( + doc.dot("settings.notifications"), + Some(Variable::Bool(true)) + ); // Added + + // Test null removal + let null_patch = Variable::from(json!({"user": {"name": null}})); + doc.merge(&null_patch); + assert_eq!(doc.dot("user.name"), None); // Removed by null + + Ok(()) +} + +#[test] +fn type_operations() -> TestResult { + let var = Variable::from(json!({ + "string": "hello", + "number": 42, + "bool": true, + "array": [1, 2, 3], + "object": {"key": "value"}, + "null": null + })); + + // Test type checks + assert!(var.dot("array").unwrap().is_array()); + assert!(var.dot("object").unwrap().is_object()); + assert!(!var.dot("string").unwrap().is_array()); + + // Test accessors + assert_eq!(var.dot("string").unwrap().as_str(), Some("hello")); + assert_eq!(var.dot("number").unwrap().as_number(), Some(dec!(42))); + assert_eq!(var.dot("bool").unwrap().as_bool(), Some(true)); + assert!(var.dot("array").unwrap().as_array().is_some()); + assert!(var.dot("object").unwrap().as_object().is_some()); + + // Test type names + assert_eq!(var.dot("string").unwrap().type_name(), "string"); + assert_eq!(var.dot("number").unwrap().type_name(), "number"); + assert_eq!(var.dot("bool").unwrap().type_name(), "bool"); + assert_eq!(var.dot("array").unwrap().type_name(), "array"); + assert_eq!(var.dot("object").unwrap().type_name(), "object"); + assert_eq!(var.dot("null").unwrap().type_name(), "null"); + + Ok(()) +} + +#[test] +fn edge_cases() -> TestResult { + // Empty structures + let empty_obj = Variable::empty_object(); + let empty_arr = Variable::empty_array(); + assert!(empty_obj.is_object()); + assert!(empty_arr.is_array()); + + // Dot operations on non-objects + let number = Variable::Number(dec!(42)); + assert_eq!(number.dot("anything"), None); + assert_eq!(number.dot_insert("path", Variable::Null), None); + assert_eq!(number.dot_remove("path"), None); + + // Array merge at top level + let mut doc = Variable::from(json!({"key": "value"})); + let array_patch = Variable::from(json!([1, 2, 3])); + doc.merge(&array_patch); + assert!(doc.is_array()); // Replaced with array + + // Self-merge (no-op) + let mut original = Variable::from(json!({"a": 1})); + let clone = original.shallow_clone(); + original.merge(&clone); + assert_eq!(original.dot("a"), Some(Variable::Number(dec!(1)))); + + Ok(()) +} diff --git a/core/types/tests/variable_type.rs b/core/types/tests/variable_type.rs new file mode 100644 index 00000000..f9819194 --- /dev/null +++ b/core/types/tests/variable_type.rs @@ -0,0 +1,278 @@ +use ahash::{HashMap, HashMapExt}; +use serde_json::json; +use std::cell::RefCell; +use std::rc::Rc; +use zen_types::variable_type::VariableType; + +type TestResult = Result<(), Box>; + +#[test] +fn variable_type_operations() -> TestResult { + // Test basic types and accessors + assert!(VariableType::Array(Rc::new(VariableType::Number)).is_array()); + assert!(VariableType::Interval.is_iterable()); + assert!(VariableType::String.is_string()); + assert!(VariableType::empty_object().is_object()); + assert!(VariableType::Null.is_null()); + + // Test iterator + assert_eq!( + VariableType::Array(Rc::new(VariableType::String)).iterator(), + Some(Rc::new(VariableType::String)) + ); + assert_eq!( + VariableType::Interval.iterator(), + Some(Rc::new(VariableType::Number)) + ); + assert_eq!(VariableType::String.iterator(), None); + + // Test const string extraction + let const_type = VariableType::Const(Rc::from("test")); + assert_eq!(const_type.as_const_str(), Some(Rc::from("test"))); + assert_eq!(VariableType::String.as_const_str(), None); + + // Test widen + assert_eq!(const_type.widen(), VariableType::String); + assert_eq!( + VariableType::Enum(None, vec![Rc::from("a")]).widen(), + VariableType::String + ); + assert_eq!(VariableType::Number.widen(), VariableType::Number); + + Ok(()) +} + +#[test] +fn variable_type_satisfies() -> TestResult { + // Basic type satisfaction + assert!(VariableType::Any.satisfies(&VariableType::String)); + assert!(VariableType::String.satisfies(&VariableType::Any)); + assert!(VariableType::Number.satisfies(&VariableType::Number)); + assert!(!VariableType::String.satisfies(&VariableType::Number)); + + // Date satisfaction + assert!(VariableType::Number.satisfies(&VariableType::Date)); + assert!(VariableType::String.satisfies(&VariableType::Date)); + + // Const and enum satisfaction + let const_a = VariableType::Const(Rc::from("a")); + let enum_ab = VariableType::Enum(None, vec![Rc::from("a"), Rc::from("b")]); + assert!(const_a.satisfies(&VariableType::String)); + assert!(const_a.satisfies(&enum_ab)); + assert!(enum_ab.satisfies(&VariableType::String)); + + // Array satisfaction + let arr_num = VariableType::Array(Rc::new(VariableType::Number)); + let arr_str = VariableType::Array(Rc::new(VariableType::String)); + assert!(arr_num.satisfies(&arr_num)); + assert!(!arr_num.satisfies(&arr_str)); + + Ok(()) +} + +#[test] +fn variable_type_merge() -> TestResult { + // Basic merges + assert_eq!( + VariableType::Number.merge(&VariableType::String), + VariableType::Any + ); + assert_eq!( + VariableType::String.merge(&VariableType::String), + VariableType::String + ); + assert_eq!( + VariableType::Date.merge(&VariableType::Date), + VariableType::Date + ); + assert_eq!( + VariableType::Interval.merge(&VariableType::Interval), + VariableType::Interval + ); + + // Const merges + let const_a = VariableType::Const(Rc::from("a")); + let const_b = VariableType::Const(Rc::from("b")); + let merged = const_a.merge(&const_b); + assert!(matches!(merged, VariableType::Enum(None, _))); + + // Same const merge + assert_eq!(const_a.merge(&const_a), const_a); + + // Const with string + assert_eq!(const_a.merge(&VariableType::String), VariableType::String); + + // Const with enum and enum with const + let enum_bc = VariableType::Enum(None, vec![Rc::from("b"), Rc::from("c")]); + let const_enum_merge = const_a.merge(&enum_bc); + if let VariableType::Enum(_, vals) = const_enum_merge { + assert_eq!(vals.len(), 3); + } + + let enum_const_merge = enum_bc.merge(&const_a); + if let VariableType::Enum(_, vals) = enum_const_merge { + assert_eq!(vals.len(), 3); + } + + // Array merge with same pointer + let shared_array_type = Rc::new(VariableType::Number); + let arr1 = VariableType::Array(shared_array_type.clone()); + let arr2 = VariableType::Array(shared_array_type); + assert_eq!(arr1.merge(&arr2), arr1); + + // Object merge + let obj1 = VariableType::Object(Rc::new(RefCell::new({ + let mut map = HashMap::new(); + map.insert(Rc::from("key1"), VariableType::String); + map.insert(Rc::from("shared"), VariableType::Number); + map + }))); + let obj2 = VariableType::Object(Rc::new(RefCell::new({ + let mut map = HashMap::new(); + map.insert(Rc::from("key2"), VariableType::Bool); + map.insert(Rc::from("shared"), VariableType::String); + map + }))); + let merged_obj = obj1.merge(&obj2); + if let VariableType::Object(obj_ref) = merged_obj { + let obj_map = obj_ref.borrow(); + assert_eq!(obj_map.len(), 3); + assert_eq!(obj_map.get("shared"), Some(&VariableType::Any)); + } + + // Enum merges + let enum1 = VariableType::Enum(Some(Rc::from("E1")), vec![Rc::from("x")]); + let enum2 = VariableType::Enum(Some(Rc::from("E2")), vec![Rc::from("y")]); + let merged_enum = enum1.merge(&enum2); + if let VariableType::Enum(name, vals) = merged_enum { + assert!(name.unwrap().contains("E1 | E2")); + assert_eq!(vals.len(), 2); + } + + Ok(()) +} + +#[test] +fn variable_type_dot_operations() -> TestResult { + let mut obj_map = HashMap::new(); + obj_map.insert( + Rc::from("user"), + VariableType::Object(Rc::new(RefCell::new({ + let mut user_map = HashMap::new(); + user_map.insert(Rc::from("name"), VariableType::String); + user_map + }))), + ); + let obj = VariableType::Object(Rc::new(RefCell::new(obj_map))); + + // Test dot get + assert_eq!(obj.dot("user.name"), Some(VariableType::String)); + assert_eq!(obj.dot("user.nonexistent"), None); + assert_eq!(obj.dot("nonexistent"), None); + + // Test dot insert + let prev = obj.dot_insert("user.email", VariableType::String); + assert_eq!(prev, None); // No previous value + assert_eq!(obj.dot("user.email"), Some(VariableType::String)); + + // Test dot insert detached + let new_obj = obj + .dot_insert_detached("settings.theme", VariableType::String) + .expect("should insert successfully"); + assert_eq!(new_obj.dot("settings.theme"), Some(VariableType::String)); + assert_eq!(obj.dot("settings.theme"), None); + + // Test invalid dot operations on non-objects + assert_eq!(VariableType::String.dot("anything"), None); + assert_eq!( + VariableType::String.dot_insert("path", VariableType::Number), + None + ); + + Ok(()) +} + +#[test] +fn variable_type_conversions() -> TestResult { + // Test from serde_json Value + assert_eq!(VariableType::from(json!(null)), VariableType::Null); + assert_eq!(VariableType::from(json!(true)), VariableType::Bool); + assert_eq!(VariableType::from(json!(42)), VariableType::Number); + assert_eq!(VariableType::from(json!("hello")), VariableType::String); + + // Test array conversion with mixed types + let mixed_array = json!([1, "hello", true]); + let array_type = VariableType::from(mixed_array); + assert!(matches!(array_type, VariableType::Array(_))); + + // Test empty array + let empty_array = Vec::::new(); + let empty_type = VariableType::from(empty_array); + assert_eq!(empty_type, VariableType::Array(Rc::new(VariableType::Any))); + + // Test array reference conversion + let vec_ref = vec![json!(1), json!("test")]; + let ref_type = VariableType::from(&vec_ref); + assert!(matches!(ref_type, VariableType::Array(_))); + + let empty_vec_ref = Vec::::new(); + let empty_ref_type = VariableType::from(&empty_vec_ref); + assert_eq!( + empty_ref_type, + VariableType::Array(Rc::new(VariableType::Any)) + ); + + // Test object conversion + let obj = json!({"name": "Alice", "age": 30}); + let obj_type = VariableType::from(obj); + assert!(matches!(obj_type, VariableType::Object(_))); + if let VariableType::Object(obj_ref) = obj_type { + let obj_map = obj_ref.borrow(); + assert!(obj_map.contains_key(&Rc::from("name"))); + assert!(obj_map.contains_key(&Rc::from("age"))); + } + + // Test convenience methods + assert_eq!( + VariableType::Number.array(), + VariableType::Array(Rc::new(VariableType::Number)) + ); + assert_eq!(VariableType::default(), VariableType::Null); + + Ok(()) +} + +#[test] +fn variable_type_clone_operations() -> TestResult { + let mut inner_map = HashMap::new(); + inner_map.insert(Rc::from("value"), VariableType::String); + let inner_obj = VariableType::Object(Rc::new(RefCell::new(inner_map))); + + let mut outer_map = HashMap::new(); + outer_map.insert(Rc::from("inner"), inner_obj); + let outer_obj = VariableType::Object(Rc::new(RefCell::new(outer_map))); + + // Test shallow clone - shares references + let shallow = outer_obj.shallow_clone(); + if let (VariableType::Object(orig), VariableType::Object(clone)) = (&outer_obj, &shallow) { + assert!(Rc::ptr_eq(orig, clone)); + } + + // Test depth clone + let depth1 = outer_obj.depth_clone(1); + if let (VariableType::Object(orig), VariableType::Object(clone)) = (&outer_obj, &depth1) { + assert!(!Rc::ptr_eq(orig, clone)); // Outer different + + let orig = orig.borrow(); + let clone = clone.borrow(); + let orig_inner = orig.get("inner").unwrap(); + let clone_inner = clone.get("inner").unwrap(); + if let (VariableType::Object(orig_inner_ref), VariableType::Object(clone_inner_ref)) = + (orig_inner, clone_inner) + { + assert!(Rc::ptr_eq(orig_inner_ref, clone_inner_ref)); // Inner still shared at depth 1 + } + } + + Ok(()) +}