diff --git a/Cargo.toml b/Cargo.toml index df6cd3c..9c00338 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ simple-ai-backend = { path = "packages/backend" } simple-ai-frontend = { path = "packages/frontend" } simple-ai-macros = { path = "packages/macros" } +uuid = { version = "1.19", features = [ "v5" ]} dioxus = { version = "0.7.1" } [profile] diff --git a/nodes/bundled_node/e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855/0.0.1.bin b/nodes/bundled_node/e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855/0.0.1.bin new file mode 100644 index 0000000..95ec8ec Binary files /dev/null and b/nodes/bundled_node/e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855/0.0.1.bin differ diff --git a/packages/backend/nodes/bundled_node/meta.toml b/nodes/bundled_node/meta.toml similarity index 76% rename from packages/backend/nodes/bundled_node/meta.toml rename to nodes/bundled_node/meta.toml index 80c3777..2f33926 100644 --- a/packages/backend/nodes/bundled_node/meta.toml +++ b/nodes/bundled_node/meta.toml @@ -1,7 +1,7 @@ name = "bundled_node" description = "A bundled node" author = "Author" -date = "2025-03-10T12:33:26.326778746Z" +date = "2025-12-04T12:05:46.983020666Z" [[versions]] version = "0.0.1" diff --git a/packages/backend/Cargo.toml b/packages/backend/Cargo.toml index 46b9b58..7dcde7a 100644 --- a/packages/backend/Cargo.toml +++ b/packages/backend/Cargo.toml @@ -24,3 +24,4 @@ bincode = { version = "1.1.3" } tokio = { version = "1.43.0", features = ["sync"] } fuzzy-matcher = "0.3.7" walkdir = "2.5.0" +onnx-ir = "0.19.1" diff --git a/packages/backend/nodes/bundled_node/e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855/0.0.1.bin b/packages/backend/nodes/bundled_node/e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855/0.0.1.bin deleted file mode 100644 index 91eb49c..0000000 Binary files a/packages/backend/nodes/bundled_node/e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855/0.0.1.bin and /dev/null differ diff --git a/packages/backend/nodes/bundled_node/e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855/node.bin b/packages/backend/nodes/bundled_node/e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855/node.bin deleted file mode 100644 index 8428a81..0000000 Binary files a/packages/backend/nodes/bundled_node/e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855/node.bin and /dev/null differ diff --git a/packages/backend/nodes/complex_bundled_node/347d4e7d76b463a7b8415486415c698bf2c792ee6384f96eca96e1d45a2f3986/0.0.1.bin b/packages/backend/nodes/complex_bundled_node/347d4e7d76b463a7b8415486415c698bf2c792ee6384f96eca96e1d45a2f3986/0.0.1.bin deleted file mode 100644 index 901dca4..0000000 Binary files a/packages/backend/nodes/complex_bundled_node/347d4e7d76b463a7b8415486415c698bf2c792ee6384f96eca96e1d45a2f3986/0.0.1.bin and /dev/null differ diff --git a/packages/backend/nodes/complex_bundled_node/347d4e7d76b463a7b8415486415c698bf2c792ee6384f96eca96e1d45a2f3986/node.bin b/packages/backend/nodes/complex_bundled_node/347d4e7d76b463a7b8415486415c698bf2c792ee6384f96eca96e1d45a2f3986/node.bin deleted file mode 100644 index 1848209..0000000 Binary files a/packages/backend/nodes/complex_bundled_node/347d4e7d76b463a7b8415486415c698bf2c792ee6384f96eca96e1d45a2f3986/node.bin and /dev/null differ diff --git a/packages/backend/nodes/complex_bundled_node/meta.toml b/packages/backend/nodes/complex_bundled_node/meta.toml deleted file mode 100644 index b833f7a..0000000 --- a/packages/backend/nodes/complex_bundled_node/meta.toml +++ /dev/null @@ -1,25 +0,0 @@ -name = "complex_bundled_node" -description = "A complex bundled node" -author = "Author" -date = "2025-03-10T12:33:26.326848257Z" - -[[versions]] -version = "0.0.1" - -[[versions.env.deps]] -name = "serde" -versions = ["1.0"] -lib = true - -[[versions]] -version = "0.0.2" - -[[versions.env.deps]] -name = "serde" -versions = ["1.0"] -lib = true - -[[versions.env.deps]] -name = "torch" -versions = ["2.0"] -lib = true diff --git a/packages/backend/nodes/test_code_node/e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855/0.0.1.bin b/packages/backend/nodes/test_code_node/e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855/0.0.1.bin deleted file mode 100644 index fb384a6..0000000 Binary files a/packages/backend/nodes/test_code_node/e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855/0.0.1.bin and /dev/null differ diff --git a/packages/backend/nodes/test_code_node/e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855/node.bin b/packages/backend/nodes/test_code_node/e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855/node.bin deleted file mode 100644 index c1be2f7..0000000 Binary files a/packages/backend/nodes/test_code_node/e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855/node.bin and /dev/null differ diff --git a/packages/backend/nodes/test_code_node/meta.toml b/packages/backend/nodes/test_code_node/meta.toml deleted file mode 100644 index e4bb6ca..0000000 --- a/packages/backend/nodes/test_code_node/meta.toml +++ /dev/null @@ -1,10 +0,0 @@ -name = "test_code_node" -description = "A simple code node" -author = "Author" -date = "2025-03-10T12:33:26.326772004Z" - -[[versions]] -version = "0.0.1" - -[versions.env] -deps = [] diff --git a/packages/backend/src/lib.rs b/packages/backend/src/lib.rs index ba64358..55e5f1d 100644 --- a/packages/backend/src/lib.rs +++ b/packages/backend/src/lib.rs @@ -1,6 +1 @@ pub mod modules; - -pub mod prelude { - pub use super::modules::nms::*; - pub use super::modules::utils::prelude::*; -} diff --git a/packages/backend/src/modules.rs b/packages/backend/src/modules.rs index 14f774b..8587ad2 100644 --- a/packages/backend/src/modules.rs +++ b/packages/backend/src/modules.rs @@ -1,3 +1,3 @@ -pub mod compiler; -pub mod nms; +pub mod nodes; +pub mod projects; pub mod utils; diff --git a/packages/backend/src/modules/compiler.rs b/packages/backend/src/modules/compiler.rs deleted file mode 100644 index 65880be..0000000 --- a/packages/backend/src/modules/compiler.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod model; diff --git a/packages/backend/src/modules/nms/create.rs b/packages/backend/src/modules/nms/create.rs deleted file mode 100644 index 17b32be..0000000 --- a/packages/backend/src/modules/nms/create.rs +++ /dev/null @@ -1,52 +0,0 @@ -use crate::modules::nms::check_name; -use crate::modules::utils::prelude::*; -use anyhow::Result; -use std::fs::{create_dir, File}; -use std::io::Write; -use std::path::Path; - -// #[cfg(feature = "desktop")] -// pub fn create_node(node: Node) -> Result<(), String> { -// let name = node.name.clone(); -// -// if !check_name(name.clone()) { -// return Err(format!( -// "Node name {} is not allowed! Please only use letters, dashes and underscores.", -// name -// )); -// } -// if !Path::new("nodes/").exists() { -// create_dir(Path::new("nodes/")).map_err(|e| e.to_string())?; -// } -// if Path::new("nodes/").join(name.clone()).exists() { -// return Err(format!("A node named {} does already exist!", name)); -// } -// -// create_dir(Path::new("nodes/").join(name.clone())).map_err(|e| e.to_string())?; -// let meta: Metadata = node.clone().into(); -// let meta_toml = toml::to_string(&meta).unwrap(); -// let mut meta_file = File::create(Path::new("nodes/").join(name.clone()).join("meta.toml")) -// .map_err(|e| e.to_string())?; -// meta_file -// .write_all(meta_toml.as_bytes()) -// .map_err(|e| e.to_string())?; -// -// let env_hash = meta.impls[0].clone().1 meta.; -// create_dir( -// Path::new("nodes/") -// .join(name.clone()) -// .join(env_hash.clone()), -// ) -// .map_err(|e| e.to_string())?; -// let node_bin = bincode::serialize(&SaveNode::from(node)).map_err(|e| e.to_string())?; -// let mut node_file = File::create( -// Path::new("nodes/") -// .join(name) -// .join(env_hash) -// .join("node.bin"), -// ) -// .map_err(|e| e.to_string())?; -// node_file.write_all(&node_bin).map_err(|e| e.to_string())?; -// -// Ok(()) -// } diff --git a/packages/backend/src/modules/nms.rs b/packages/backend/src/modules/nodes.rs similarity index 84% rename from packages/backend/src/modules/nms.rs rename to packages/backend/src/modules/nodes.rs index 7ef94ba..e6cc05a 100644 --- a/packages/backend/src/modules/nms.rs +++ b/packages/backend/src/modules/nodes.rs @@ -1,6 +1,4 @@ -pub mod create; pub mod delete; -pub mod modify; pub mod query; pub mod save; diff --git a/packages/backend/src/modules/nms/delete.rs b/packages/backend/src/modules/nodes/delete.rs similarity index 92% rename from packages/backend/src/modules/nms/delete.rs rename to packages/backend/src/modules/nodes/delete.rs index a924d71..0d7f0bd 100644 --- a/packages/backend/src/modules/nms/delete.rs +++ b/packages/backend/src/modules/nodes/delete.rs @@ -4,8 +4,8 @@ use std::path::Path; use crate::modules::utils::prelude::*; -pub fn delete_node(name: String, version: Option) -> Result<(), String> { - let node_path = Path::new("nodes/").join(&name); +pub fn delete_node(author: String, name: String, version: Option) -> Result<(), String> { + let node_path = Path::new("nodes/").join(&author).join(&name); let meta_path = node_path.join("meta.toml"); if !node_path.exists() { diff --git a/packages/backend/src/modules/nms/query.rs b/packages/backend/src/modules/nodes/query.rs similarity index 87% rename from packages/backend/src/modules/nms/query.rs rename to packages/backend/src/modules/nodes/query.rs index c796d07..853c8c5 100644 --- a/packages/backend/src/modules/nms/query.rs +++ b/packages/backend/src/modules/nodes/query.rs @@ -23,7 +23,7 @@ pub fn get_all_nodes() -> Result { let save_node: SaveNode = bincode::deserialize(&data) .map_err(|e| format!("Failed to deserialize {}: {}", path.display(), e))?; - let node = Node::from(save_node); + let node = Node::from_save_node(save_node, None); nc.push_context(StrongNode::from(node)); } @@ -31,13 +31,13 @@ pub fn get_all_nodes() -> Result { } /// This function searches through all available Nodes and returns a NodeContainer containing all Nodes available for the inferred environment. -pub fn query_nodes(query_filter: Vec) -> NodeContainer { +pub fn query_nodes(query_filters: Vec) -> NodeContainer { let all_nodes = get_all_nodes().expect("Error walking directory!"); all_nodes .iter() .filter(|node| { - query_filter.iter().all(|filter| { + query_filters.iter().all(|filter| { filter .clone() .is_ok(node.context.try_lock().unwrap().to_owned()) diff --git a/packages/backend/src/modules/nms/save.rs b/packages/backend/src/modules/nodes/save.rs similarity index 95% rename from packages/backend/src/modules/nms/save.rs rename to packages/backend/src/modules/nodes/save.rs index a7fca44..ec34038 100644 --- a/packages/backend/src/modules/nms/save.rs +++ b/packages/backend/src/modules/nodes/save.rs @@ -1,4 +1,4 @@ -use crate::modules::nms::check_name; +use crate::modules::nodes::check_name; use crate::modules::utils::prelude::*; use anyhow::Result; use std::fs::{self, create_dir, create_dir_all, File}; @@ -9,6 +9,7 @@ use toml; // #[cfg(feature = "desktop")] pub fn save_node(node: Node) -> Result<(), String> { let name = node.name.clone(); + let author = node.author.clone(); if !check_name(name.clone()) { return Err(format!( @@ -17,7 +18,7 @@ pub fn save_node(node: Node) -> Result<(), String> { )); } - let node_path = Path::new("nodes/").join(&name); + let node_path = Path::new("nodes/").join(&author).join(&name); if !node_path.exists() { create_dir_all(&node_path).map_err(|e| e.to_string())?; } diff --git a/packages/backend/src/modules/projects.rs b/packages/backend/src/modules/projects.rs new file mode 100644 index 0000000..a99f84e --- /dev/null +++ b/packages/backend/src/modules/projects.rs @@ -0,0 +1,4 @@ +pub mod create; +pub mod delete; +pub mod query; +pub mod settings; diff --git a/packages/backend/src/modules/compiler/model.rs b/packages/backend/src/modules/projects/create.rs similarity index 100% rename from packages/backend/src/modules/compiler/model.rs rename to packages/backend/src/modules/projects/create.rs diff --git a/packages/backend/src/modules/nms/modify.rs b/packages/backend/src/modules/projects/delete.rs similarity index 100% rename from packages/backend/src/modules/nms/modify.rs rename to packages/backend/src/modules/projects/delete.rs diff --git a/packages/backend/src/modules/projects/query.rs b/packages/backend/src/modules/projects/query.rs new file mode 100644 index 0000000..d764326 --- /dev/null +++ b/packages/backend/src/modules/projects/query.rs @@ -0,0 +1,45 @@ +use std::path::Path; + +use crate::modules::utils::prelude::*; +use walkdir::WalkDir; + +pub fn get_all_projects() -> Result, String> { + let projects_dir = Path::new("projects"); + if !projects_dir.exists() { + return Ok(Vec::default()); + } + + let mut projects = Vec::default(); + + for entry in WalkDir::new(projects_dir) + .into_iter() + .filter_map(Result::ok) + .filter(|e| e.path().extension().is_some_and(|ext| ext == "json")) + { + let path = entry.path(); + let data = std::fs::read_to_string(path) + .map_err(|e| format!("Failed to read {}: {}", path.display(), e))?; + + let project: Project = serde_json::from_str(&data) + .map_err(|e| format!("Failed to deserialize {}: {}", path.display(), e))?; + + projects.push(project); + } + + Ok(projects) +} + +/// This function searches through all available Projects +pub fn query_projects(query_filters: Vec) -> Vec { + let all_projects = get_all_projects().expect("Error walking directory!"); + + all_projects + .iter() + .filter(|project| { + query_filters + .iter() + .all(|filter| filter.clone().is_ok(project)) + }) + .cloned() + .collect() +} diff --git a/packages/backend/src/modules/projects/settings.rs b/packages/backend/src/modules/projects/settings.rs new file mode 100644 index 0000000..e69de29 diff --git a/packages/backend/src/modules/utils.rs b/packages/backend/src/modules/utils.rs index ac3899c..3518450 100644 --- a/packages/backend/src/modules/utils.rs +++ b/packages/backend/src/modules/utils.rs @@ -5,7 +5,9 @@ pub mod dtype; pub mod environment; pub mod metadata; pub mod node; +pub mod node_type; pub mod param; +pub mod project; pub mod query_filter; pub mod save_node; pub mod save_param; @@ -18,7 +20,9 @@ pub mod prelude { pub use super::environment::*; pub use super::metadata::*; pub use super::node::*; + pub use super::node_type::*; pub use super::param::*; + pub use super::project::*; pub use super::query_filter::*; pub use super::save_node::*; pub use super::save_param::*; diff --git a/packages/backend/src/modules/utils/node.rs b/packages/backend/src/modules/utils/node.rs index 4a542ab..1a7d715 100644 --- a/packages/backend/src/modules/utils/node.rs +++ b/packages/backend/src/modules/utils/node.rs @@ -1,10 +1,11 @@ use super::prelude::*; use derive_builder::Builder; +use onnx_ir::NodeType; use std::collections::HashMap; // -------------------- NODE KIND -------------------- // #[derive(Clone, PartialEq)] pub enum NodeKind { - Code { code: String }, + Onnx { onnx: OnnxNode }, Bundled { bundle: NodeContainer }, } // -------------------- NODE -------------------- // @@ -15,39 +16,136 @@ pub type WeakNode = WeakContext; pub struct Node { #[builder(setter(into))] pub name: String, - pub params: Vec, pub version: Version, pub kind: NodeKind, pub description: String, pub author: String, - #[builder(default)] - pub compiled: Option, // or and bytes... pub date: Date, #[builder(default)] pub position: Option<(f64, f64)>, } +#[derive(Builder, Clone, PartialEq)] +pub struct OnnxNode { + pub node_type: NodeType, + pub name: String, + pub inputs: Vec, + pub outputs: Vec, +} + impl Node { pub fn get_full_env(self) -> Environment { let mut env = self.version.env; if let NodeKind::Bundled { bundle } = self.kind { for context in bundle.tree.iter() { - let node: Node = context.context.try_lock().unwrap().to_owned(); + let node = context.context.try_lock().unwrap().to_owned(); env = node.get_full_env().merge(&env).unwrap(); } } env } + + pub fn from_save_node( + node: SaveNode, + param_map: Option<&mut HashMap>, + ) -> Self { + let top = param_map.is_none(); + let mut new_map; + let param_map: &mut HashMap = match param_map { + Some(pm) => pm, + None => { + new_map = HashMap::new(); + &mut new_map + } + }; + + let mut binding = NodeBuilder::default(); + let builder = binding + .name(node.name) + .description(node.description) + .author(node.author) + .version(Version { + version: node.version.version, + env: node.version.env, + }) + .date(node.date); + + let kind = match node.kind { + SaveNodeKind::Onnx { onnx } => NodeKind::Onnx { + onnx: OnnxNode::from_save_node(onnx, param_map), + }, + SaveNodeKind::Bundled { bundle } => { + let mut node_container = NodeContainer::new(); + for save_node in bundle { + node_container.push_context(StrongContext::from(Node::from_save_node( + save_node, + Some(param_map), + ))); + } + NodeKind::Bundled { + bundle: node_container, + } + } + }; + + // Second pass: resolve connections + let mut resolve = Vec::new(); + for (_, strong_param) in param_map.iter_mut() { + let param = strong_param.context.try_lock().unwrap(); + if let ParamKind::Runtime { id, .. } = ¶m.kind { + resolve.push((*id, strong_param.clone())); + } + } + + if top { + for (id, strong_param) in resolve { + if let Some(target) = param_map.get(&id) { + let mut param = strong_param.context.try_lock().unwrap(); + if let ParamKind::Runtime { connection, .. } = &mut param.kind { + *connection = Some(WeakContext::from(target.clone())); + } + } + } + } + + builder.kind(kind).build().expect("Failed to build Node") + } + + pub fn get_params(&self) -> Vec { + match &self.kind { + NodeKind::Onnx { onnx } => [onnx.inputs.clone(), onnx.outputs.clone()].concat(), + + NodeKind::Bundled { bundle } => { + let mut res = Vec::new(); + + for context in &bundle.tree { + let node = context.context.try_lock().unwrap().clone(); + let node_params = node.get_params(); + + res.extend(node_params.into_iter().filter(|p| { + let p = p.context.try_lock().unwrap(); + match &p.kind { + ParamKind::Runtime { connection, .. } => connection.is_some(), + ParamKind::Static { .. } => true, + } + })); + } + + res + } + } + } } -impl From for Node { - fn from(node: SaveNode) -> Self { - let mut param_map: HashMap = HashMap::new(); - let mut params = Vec::new(); +impl OnnxNode { + fn from_save_node(node: SaveOnnxNode, param_map: &mut HashMap) -> Self { + let mut binding = OnnxNodeBuilder::default(); + let mut inputs = Vec::new(); + let mut outputs = Vec::new(); // First pass: instantiate Params without setting connections - for save_param in &node.params { + for save_param in &node.inputs { let strong_param = StrongParam::from(Param { name: save_param.name.clone(), desc: save_param.desc.clone(), @@ -64,45 +162,34 @@ impl From for Node { }, }); param_map.insert(save_param.id, strong_param.clone()); - params.push(strong_param); + inputs.push(strong_param); } - - // Second pass: resolve connections - for strong_param in ¶ms { - let mut param = strong_param.context.try_lock().unwrap(); - if let ParamKind::Runtime { connection, id, .. } = &mut param.kind { - if let Some(target) = param_map.get(id) { - *connection = Some(WeakContext::from(target.clone())); - } - } + for save_param in &node.outputs { + let strong_param = StrongParam::from(Param { + name: save_param.name.clone(), + desc: save_param.desc.clone(), + dtype: save_param.dtype.clone(), + kind: match &save_param.kind { + SaveParamKind::Static { value } => ParamKind::Static { + value: value.clone(), + }, + SaveParamKind::Runtime { kind, .. } => ParamKind::Runtime { + kind: kind.clone(), + connection: None, + id: save_param.id, + }, + }, + }); + param_map.insert(save_param.id, strong_param.clone()); + outputs.push(strong_param); } - let mut binding = NodeBuilder::default(); let builder = binding + .node_type(node.node_type) .name(node.name) - .description(node.description) - .author(node.author) - .compiled(node.compiled) - .version(Version { - version: node.version.version, - env: node.version.env, - }) - .date(node.date) - .params(params); - - let kind = match node.kind { - SaveNodeKind::Code { code } => NodeKind::Code { code }, - SaveNodeKind::Bundled { bundle } => { - let mut node_container = NodeContainer::new(); - for save_node in bundle { - node_container.push_context(StrongNode::from(Node::from(save_node))); - } - NodeKind::Bundled { - bundle: node_container, - } - } - }; + .inputs(inputs) + .outputs(outputs); - builder.kind(kind).build().expect("Failed to build Node") + builder.build().expect("Failed to build OnnxNode") } } diff --git a/packages/backend/src/modules/utils/node_type.rs b/packages/backend/src/modules/utils/node_type.rs new file mode 100644 index 0000000..035a9b8 --- /dev/null +++ b/packages/backend/src/modules/utils/node_type.rs @@ -0,0 +1,207 @@ +use onnx_ir::NodeType; +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize)] +#[serde(remote = "NodeType")] +pub enum OnnxNodeType { + Abs, + Acos, + Acosh, + Add, + And, + ArgMax, + ArgMin, + Asin, + Asinh, + Atan, + Atanh, + Attention, + AveragePool, + AveragePool1d, + AveragePool2d, + BatchNormalization, + Bernoulli, + BitShift, + BitwiseAnd, + BitwiseNot, + BitwiseOr, + BitwiseXor, + BlackmanWindow, + Cast, + CastLike, + Ceil, + Celu, + CenterCropPad, + Clip, + Col, + Compress, + Concat, + ConcatFromSequence, + Constant, + ConstantOfShape, + Conv, + Conv1d, + Conv2d, + Conv3d, + ConvInteger, + ConvTranspose, + ConvTranspose1d, + ConvTranspose2d, + ConvTranspose3d, + Cos, + Cosh, + CumSum, + DepthToSpace, + DequantizeLinear, + Det, + DFT, + Div, + Dropout, + DynamicQuantizeLinear, + Einsum, + Elu, + Equal, + Erf, + Exp, + Expand, + EyeLike, + Flatten, + Floor, + Gather, + GatherElements, + GatherND, + Gelu, + Gemm, + GlobalAveragePool, + GlobalLpPool, + GlobalMaxPool, + Greater, + GreaterOrEqual, + GridSample, + GroupNormalization, + GRU, + HammingWindow, + HannWindow, + Hardmax, + HardSigmoid, + HardSwish, + Identity, + If, + Im, + InstanceNormalization, + IsInf, + IsNaN, + LayerNormalization, + LeakyRelu, + Less, + LessOrEqual, + Linear, + Log, + LogSoftmax, + Loop, + LpNormalization, + LpPool, + LRN, + LSTM, + MatMul, + MatMulInteger, + Max, + MaxPool, + MaxPool1d, + MaxPool2d, + MaxRoiPool, + MaxUnpool, + Mean, + MeanVarianceNormalization, + MelWeightMatrix, + Min, + Mish, + Mod, + Mul, + Multinomial, + Neg, + NegativeLogLikelihoodLoss, + NonMaxSuppression, + NonZero, + Not, + OneHot, + Optional, + OptionalGetElement, + OptionalHasElement, + Or, + Pad, + Pow, + PRelu, + QLinearConv, + QLinearMatMul, + QuantizeLinear, + RandomNormal, + RandomNormalLike, + RandomUniform, + RandomUniformLike, + Range, + Reciprocal, + ReduceL1, + ReduceL2, + ReduceLogSum, + ReduceLogSumExp, + ReduceMax, + ReduceMean, + ReduceMin, + ReduceProd, + ReduceSum, + ReduceSumSquare, + Relu, + Reshape, + Resize, + ReverseSequence, + RNN, + RoiAlign, + Round, + Scan, + Scatter, + ScatterElements, + ScatterND, + Selu, + SequenceAt, + SequenceConstruct, + SequenceEmpty, + SequenceErase, + SequenceInsert, + SequenceLength, + SequenceMap, + Shape, + Shrink, + Sigmoid, + Sign, + Sin, + Sinh, + Size, + Slice, + Softmax, + SoftmaxCrossEntropyLoss, + Softplus, + Softsign, + SpaceToDepth, + Split, + SplitToSequence, + Sqrt, + Squeeze, + STFT, + StringNormalizer, + Sub, + Sum, + Tan, + Tanh, + TfIdfVectorizer, + ThresholdedRelu, + Tile, + TopK, + Transpose, + Trilu, + Unique, + Unsqueeze, + Upsample, + Where, + Xor, +} diff --git a/packages/backend/src/modules/utils/project.rs b/packages/backend/src/modules/utils/project.rs new file mode 100644 index 0000000..80783bc --- /dev/null +++ b/packages/backend/src/modules/utils/project.rs @@ -0,0 +1,12 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; + +// -------------------- PROJECT -------------------- // +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct Project { + pub name: String, + pub date: DateTime, + pub desc: String, + pub author: String, + pub node: String, +} diff --git a/packages/backend/src/modules/utils/query_filter.rs b/packages/backend/src/modules/utils/query_filter.rs index 05ea1ba..7eec125 100644 --- a/packages/backend/src/modules/utils/query_filter.rs +++ b/packages/backend/src/modules/utils/query_filter.rs @@ -3,7 +3,7 @@ use fuzzy_matcher::{skim::SkimMatcherV2, FuzzyMatcher}; // ---------------- QUERY FILTER ---------------- // #[derive(Clone)] -pub enum QueryFilter { +pub enum NodeQueryFilter { Older { date: Date }, Newer { date: Date }, Author { author: String }, @@ -11,20 +11,50 @@ pub enum QueryFilter { Name { name: String }, } -impl QueryFilter { +#[derive(Clone)] +pub enum ProjectQueryFilter { + Older { date: Date }, + Newer { date: Date }, + Name { name: String }, + Author { author: String }, + Node { node: String }, +} + +impl NodeQueryFilter { pub fn is_ok(self, node: Node) -> bool { match self { - QueryFilter::Name { name } => { + NodeQueryFilter::Name { name } => { let matcher = SkimMatcherV2::default(); matcher.fuzzy_match(&node.name, &name).is_some() } - QueryFilter::Older { date } => date > node.date, - QueryFilter::Newer { date } => date < node.date, - QueryFilter::Author { author } => { + NodeQueryFilter::Older { date } => date > node.date, + NodeQueryFilter::Newer { date } => date < node.date, + NodeQueryFilter::Author { author } => { let matcher = SkimMatcherV2::default(); matcher.fuzzy_match(&node.author, &author).is_some() } - QueryFilter::Environment { env } => env.merge(&node.version.env).is_ok(), + NodeQueryFilter::Environment { env } => env.merge(&node.version.env).is_ok(), + } + } +} + +impl ProjectQueryFilter { + pub fn is_ok(self, project: &Project) -> bool { + match self { + ProjectQueryFilter::Name { name } => { + let matcher = SkimMatcherV2::default(); + matcher.fuzzy_match(&project.name, &name).is_some() + } + ProjectQueryFilter::Node { node } => { + let matcher = SkimMatcherV2::default(); + matcher.fuzzy_match(&project.node, &node).is_some() + } + ProjectQueryFilter::Older { date } => date > project.date, + ProjectQueryFilter::Newer { date } => date < project.date, + ProjectQueryFilter::Author { author } => { + let matcher = SkimMatcherV2::default(); + matcher.fuzzy_match(&project.author, &author).is_some() + } } } } diff --git a/packages/backend/src/modules/utils/save_node.rs b/packages/backend/src/modules/utils/save_node.rs index c93a890..299900e 100644 --- a/packages/backend/src/modules/utils/save_node.rs +++ b/packages/backend/src/modules/utils/save_node.rs @@ -1,30 +1,39 @@ use super::prelude::*; use derive_builder::Builder; +use onnx_ir::NodeType; use serde::{Deserialize, Serialize}; // -------------------- SAVE NODES -------------------- // #[derive(Clone, Serialize, Deserialize)] pub enum SaveNodeKind { - Code { code: String }, + Onnx { onnx: SaveOnnxNode }, Bundled { bundle: Vec }, } #[derive(Builder, Clone, Serialize, Deserialize)] pub struct SaveNode { pub name: String, - pub params: Vec, + pub version: Version, pub kind: SaveNodeKind, pub description: String, pub author: String, - pub compiled: Option, - pub version: Version, pub date: Date, + pub position: Option<(f64, f64)>, +} + +#[derive(Builder, Clone, Serialize, Deserialize)] +pub struct SaveOnnxNode { + #[serde(with = "OnnxNodeType")] + pub node_type: NodeType, + pub name: String, + pub inputs: Vec, + pub outputs: Vec, } impl From for SaveNodeKind { fn from(nodes: NodeContainer) -> Self { - let mut save_nodes: Vec = Vec::new(); + let mut save_nodes = Vec::new(); for context in nodes.tree.iter() { - let node: Node = context.context.try_lock().unwrap().to_owned(); + let node = context.context.try_lock().unwrap().to_owned(); save_nodes.push(node.into()); } SaveNodeKind::Bundled { bundle: save_nodes } @@ -38,21 +47,36 @@ impl From for SaveNode { .name(node.name.clone()) .description(node.description.clone()) .author(node.author.clone()) - .compiled(node.compiled.clone()) .version(Version { version: node.version.version.clone(), env: node.clone().get_full_env(), }) + .position(node.position) .date(node.date); - if let NodeKind::Code { code } = node.kind { - builder = builder.kind(SaveNodeKind::Code { code }); + if let NodeKind::Onnx { onnx } = node.kind { + builder = builder.kind(SaveNodeKind::Onnx { onnx: onnx.into() }); } else if let NodeKind::Bundled { bundle } = node.kind { builder = builder.kind(bundle.into()); } - builder = builder.params( - node.params + builder.build().unwrap() + } +} + +impl From for SaveOnnxNode { + fn from(node: OnnxNode) -> Self { + let mut binding = SaveOnnxNodeBuilder::default(); + let mut builder = binding.node_type(node.node_type).name(node.name); + + builder = builder.inputs( + node.inputs + .iter() + .map(|param| SaveParam::from(param.clone())) + .collect(), + ); + builder = builder.outputs( + node.outputs .iter() .map(|param| SaveParam::from(param.clone())) .collect(), diff --git a/packages/backend/tests.rs b/packages/backend/tests.rs deleted file mode 100644 index c5fb369..0000000 --- a/packages/backend/tests.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod create; diff --git a/packages/backend/tests/create.rs b/packages/backend/tests/create.rs index f2500a4..d0b2888 100644 --- a/packages/backend/tests/create.rs +++ b/packages/backend/tests/create.rs @@ -1,214 +1,81 @@ use chrono::Utc; -use sai_backend::nms::save::save_node; -use sai_backend::utils::prelude::*; -use uuid::Uuid; - -#[test] -fn test_create_code_node() { - let node = Node { - name: "test_code_node".to_string(), - params: vec![], - kind: NodeKind::Code { - code: "fn main() { println!(\"Hello, world!\"); }".to_string(), - }, - description: "A simple code node".to_string(), - author: "Author".to_string(), - compiled: None, - version: Version { - version: String::from("0.0.1"), - env: Environment { deps: vec![] }, - }, - date: Utc::now(), - }; - - let res = save_node(node); - println!("{:?}", res); - - assert!(res.is_ok()); -} +use simple_ai_backend::modules::{ + nodes::save::save_node, + utils::{ + dtype::DType, + node::{Node, NodeKind, OnnxNode}, + param::{ParamKind, RuntimeParamKind, StrongParam}, + prelude::{Environment, NodeContainer, ParamBuilder, StrongContext, Version, WeakContext}, + }, +}; #[test] fn test_create_bundled_node() { - let code_node1 = Node { - name: "code_node1".to_string(), - params: vec![], - kind: NodeKind::Code { - code: "fn main() { println!(\"Node 1\"); }".to_string(), - }, - description: "First code node".to_string(), - author: "Author".to_string(), - compiled: None, - version: Version { - version: String::from("0.0.1"), - env: Environment { deps: vec![] }, - }, - date: Utc::now(), + let node1_param_out = ParamBuilder::default() + .name("Param1".into()) + .desc("A param".into()) + .dtype(DType::F32) + .kind(ParamKind::Static { value: "5".into() }) + .build() + .expect("Failed to build param"); + let node2_param_in = ParamBuilder::default() + .name("Param2".into()) + .desc("Another param".into()) + .dtype(DType::F32) + .kind(ParamKind::Static { value: "5".into() }) + .build() + .expect("Failed to build param"); + let node1_param_out = StrongParam::from(node1_param_out); + let node2_param_in = StrongParam::from(node2_param_in); + node1_param_out.context.try_lock().unwrap().kind = ParamKind::Runtime { + kind: RuntimeParamKind::Output, + connection: Some(WeakContext::from(node2_param_in.clone())), + id: 1, + }; + node2_param_in.context.try_lock().unwrap().kind = ParamKind::Runtime { + kind: RuntimeParamKind::Input, + connection: Some(WeakContext::from(node1_param_out.clone())), + id: 2, }; - let code_node2 = Node { - name: "code_node2".to_string(), - params: vec![], - kind: NodeKind::Code { - code: "fn main() { println!(\"Node 2\"); }".to_string(), + let node1 = Node { + name: "node1".to_string(), + kind: NodeKind::Onnx { + onnx: OnnxNode { + node_type: onnx_ir::NodeType::Relu, + name: "node1".into(), + inputs: vec![], + outputs: vec![StrongParam::from(node1_param_out)], + }, }, - description: "Second code node".to_string(), - author: "Author".to_string(), - compiled: None, + description: "First code node".to_string(), + author: "It's mee".to_string(), version: Version { version: String::from("0.0.1"), env: Environment { deps: vec![] }, }, date: Utc::now(), + position: Some((20.0, 50.0)), }; let mut nc: NodeContainer = NodeContainer::new(); - nc.push_context(StrongContext::from(code_node1)); - nc.push_context(StrongContext::from(code_node2)); + nc.push_context(StrongContext::from(node1)); let bundled_node = Node { name: "bundled_node".to_string(), - params: vec![], kind: NodeKind::Bundled { bundle: nc }, description: "A bundled node".to_string(), author: "Author".to_string(), - compiled: None, - version: Version { - version: String::from("0.0.1"), - env: Environment { deps: vec![] }, - }, - date: Utc::now(), - }; - - let res = save_node(bundled_node); - - println!("{:?}", res); - - assert!(res.is_ok()); -} - -#[test] -fn test_create_complex_bundled_node() { - let static_param = StrongParam::from(Param { - name: "static_param".to_string(), - desc: "A static parameter".to_string(), - dtype: DType::String, - kind: ParamKind::Static { - value: "static_value".to_string(), - }, - }); - - let runtime_param1 = StrongParam::from(Param { - name: "runtime_param1".to_string(), - desc: "A runtime parameter".to_string(), - dtype: DType::String, - kind: ParamKind::Runtime { - kind: RuntimeParamKind::Input, - connection: None, - id: Uuid::new_v4().as_u128(), - }, - }); - - let runtime_param2 = StrongParam::from(Param { - name: "runtime_param2".to_string(), - desc: "Another runtime parameter".to_string(), - dtype: DType::String, - kind: ParamKind::Runtime { - kind: RuntimeParamKind::Output, - connection: Some(WeakContext::from(runtime_param1.clone())), - id: Uuid::new_v4().as_u128(), - }, - }); - - let code_node = Node { - name: "complex_code_node".to_string(), - params: vec![runtime_param1.clone(), runtime_param2.clone()], - kind: NodeKind::Code { - code: "fn main() { println!(\"Complex Node\"); }".to_string(), - }, - description: "A complex code node".to_string(), - author: "Author".to_string(), - compiled: None, version: Version { version: String::from("0.0.1"), - env: Environment { - deps: vec![Dependency { - name: "serde".to_string(), - versions: vec!["1.0".to_string()], - lib: true, - }], - }, - }, - date: Utc::now(), - }; - - let code_node2 = Node { - name: "complex_code_node".to_string(), - params: vec![runtime_param1.clone(), runtime_param2.clone()], - kind: NodeKind::Code { - code: "fn main() { println!(\"Complex Node\"); }".to_string(), - }, - description: "A complex code node".to_string(), - author: "Author".to_string(), - compiled: None, - version: Version { - version: String::from("0.0.2"), - env: Environment { - deps: vec![ - Dependency { - name: "serde".to_string(), - versions: vec!["1.0".to_string()], - lib: true, - }, - Dependency { - name: "torch".to_string(), - versions: vec!["2.0".to_string()], - lib: true, - }, - ], - }, - }, - date: Utc::now(), - }; - - let mut nc: NodeContainer = NodeContainer::new(); - nc.push_context(StrongContext::from(code_node)); - - let mut nc2: NodeContainer = NodeContainer::new(); - nc2.push_context(StrongContext::from(code_node2)); - - let bundled_node = Node { - name: "complex_bundled_node".to_string(), - params: vec![static_param.clone()], - kind: NodeKind::Bundled { bundle: nc }, - description: "A complex bundled node".to_string(), - author: "Author".to_string(), - compiled: None, - version: Version { - version: String::from("0.0.1"), - env: Environment { deps: vec![] }, - }, - date: Utc::now(), - }; - - let bundled_node2 = Node { - name: "complex_bundled_node".to_string(), - params: vec![static_param.clone()], - kind: NodeKind::Bundled { bundle: nc2 }, - description: "A complex bundled node".to_string(), - author: "Author".to_string(), - compiled: None, - version: Version { - version: String::from("0.0.2"), env: Environment { deps: vec![] }, }, date: Utc::now(), + position: None, }; let res = save_node(bundled_node); - let res2 = save_node(bundled_node2); println!("{:?}", res); - println!("{:?}", res2); assert!(res.is_ok()); - assert!(res2.is_ok()); } diff --git a/packages/backend/tests/delete.rs b/packages/backend/tests/delete.rs deleted file mode 100644 index 6379244..0000000 --- a/packages/backend/tests/delete.rs +++ /dev/null @@ -1,121 +0,0 @@ -use std::path::Path; - -use chrono::Utc; -use sai_backend::{ - nms::{delete::delete_node, save::save_node}, - utils::prelude::*, -}; - -#[test] -fn test_delete_entire_node() { - let node = Node { - name: "test_delete_node".to_string(), - params: vec![], - kind: NodeKind::Code { - code: "fn main() { println!(\"Delete me\"); }".to_string(), - }, - description: "A node to be deleted".to_string(), - author: "Author".to_string(), - compiled: None, - version: Version { - version: String::from("1.0.0"), - env: Environment { deps: vec![] }, - }, - date: Utc::now(), - }; - - let res = save_node(node.clone()); - assert!(res.is_ok()); - - let del_res = delete_node(String::from("test_delete_node"), None); - assert!(del_res.is_ok()); - - let node_path = Path::new("nodes/").join("test_delete_node"); - assert!(!node_path.exists(), "Node folder should be deleted"); -} - -#[test] -fn test_delete_specific_version() { - let node_v1 = Node { - name: "test_delete_version".to_string(), - params: vec![], - kind: NodeKind::Code { - code: "fn main() { println!(\"Version 1\"); }".to_string(), - }, - description: "A node with multiple versions".to_string(), - author: "Author".to_string(), - compiled: None, - version: Version { - version: String::from("1.0.0"), - env: Environment { deps: vec![] }, - }, - date: Utc::now(), - }; - - let node_v2 = Node { - name: "test_delete_version".to_string(), - params: vec![], - kind: NodeKind::Code { - code: "fn main() { println!(\"Version 2\"); }".to_string(), - }, - description: "A node with multiple versions".to_string(), - author: "Author".to_string(), - compiled: None, - version: Version { - version: String::from("2.0.0"), - env: Environment { deps: vec![] }, - }, - date: Utc::now(), - }; - - assert!(save_node(node_v1.clone()).is_ok()); - assert!(save_node(node_v2.clone()).is_ok()); - - let del_res = delete_node( - String::from("test_delete_version"), - Some(String::from("1.0.0")), - ); - assert!(del_res.is_ok()); - - let meta_path = Path::new("nodes/test_delete_version/meta.toml"); - assert!(meta_path.exists(), "Meta file should still exist"); - - let meta_content = std::fs::read_to_string(meta_path).unwrap(); - let meta: Metadata = toml::from_str(&meta_content).unwrap(); - - assert_eq!(meta.versions.len(), 1); - assert_eq!(meta.versions[0].version, "2.0.0"); -} - -#[test] -fn test_delete_nonexistent_node() { - let del_res = delete_node(String::from("nonexistent_node"), None); - assert!(del_res.is_err()); -} - -#[test] -fn test_delete_nonexistent_version() { - let node = Node { - name: "test_delete_invalid_version".to_string(), - params: vec![], - kind: NodeKind::Code { - code: "fn main() { println!(\"Only one version\"); }".to_string(), - }, - description: "A node with a single version".to_string(), - author: "Author".to_string(), - compiled: None, - version: Version { - version: String::from("1.0.0"), - env: Environment { deps: vec![] }, - }, - date: Utc::now(), - }; - - assert!(save_node(node.clone()).is_ok()); - - let del_res = delete_node( - String::from("test_delete_invalid_version"), - Some(String::from("2.0.0")), - ); - assert!(del_res.is_err()); -} diff --git a/packages/frontend/Cargo.toml b/packages/frontend/Cargo.toml index f06db88..4ec1830 100644 --- a/packages/frontend/Cargo.toml +++ b/packages/frontend/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [dependencies] dioxus = { workspace = true, features = ["router"]} +uuid = { workspace = true } simple-ai-backend = { workspace = true } simple-ai-macros = { workspace = true } async-recursion = { version = "1.1" } diff --git a/packages/frontend/assets/scripts/onnx-drag-in.js b/packages/frontend/assets/scripts/onnx-drag-in.js index 514a306..2ca5084 100644 --- a/packages/frontend/assets/scripts/onnx-drag-in.js +++ b/packages/frontend/assets/scripts/onnx-drag-in.js @@ -5,64 +5,45 @@ // Also this could be replaced eventually with a native rendering solution from dixous class OnnxDragIn { constructor(rootelm, search, nodeContainer) { - if (!(rootelm instanceof HTMLElement)) { - throw new Error( - "Yout first argument is no valid DOM element. Please provide the drag in element itself.", - ); - } - if (!(search instanceof HTMLElement)) { - throw new Error( - "Your second argument is no valid DOM element. Please provide a search type input.", - ); - } - if (!(nodeContainer instanceof HTMLElement)) { - throw new Error( - "Your third argument is no valid DOM element. Please provide the container of the nodes.", - ); - } this.rootelm = rootelm; this.search = search; - this.draggingChild = null; this.nodeContainer = nodeContainer; - // this.nodeContainer.style.display = "relative"; - this.nodes = Array.from(this.nodeContainer.children); + this.draggingChild = null; + this._ondrag = this._ondrag.bind(this); this._ondrop = this._ondrop.bind(this); - this.start_drag = null; - this._attachListeners(); - } - _ondrop(e) { - this.draggingChild.style.left = "initial"; - this.draggingChild.style.top = "initial"; - this.draggingChild.style.position = "initial"; - window.activeOnnxViewport.addNodeFromId(this.draggingChild.id, { - position: { x: e.x, y: e.y }, + + this.nodeContainer.addEventListener("mousedown", (e) => { + let child = e.target.closest(".draggable-node"); + if (!child || !this.nodeContainer.contains(child)) return; + + this.draggingChild = child; + this.draggingChild.style.position = "fixed"; + window.addEventListener("mousemove", this._ondrag); + }); + + window.addEventListener("mouseup", (e) => { + if (this.draggingChild) { + window.removeEventListener("mousemove", this._ondrag); + this._ondrop(e); + } }); - this.draggingChild = null; } + _ondrag(e) { let rect = this.draggingChild.getBoundingClientRect(); this.draggingChild.style.left = `${e.x - rect.width / 2}px`; this.draggingChild.style.top = `${e.y - rect.height / 2}px`; } - _attachListeners() { - this.nodes.forEach((child) => { - console.log("hello"); - child.addEventListener("mousedown", (e) => { - this.draggingChild = child; - // let container_rect = this.nodeContainer.getBoundingClientRect(); - // let node_rect = this.draggingChild.getBoundingClientRect(); - // this.start_drag = { x: e.x, y: e.y }; - this.draggingChild.style.position = "fixed"; - window.addEventListener("mousemove", this._ondrag); - }); - }); - window.addEventListener("mouseup", (e) => { - if (this.draggingChild) { - window.removeEventListener("mousemove", this._ondrag); - this._ondrop(e); - } + + _ondrop(e) { + this.draggingChild.style.position = "initial"; + this.draggingChild.style.left = "initial"; + this.draggingChild.style.top = "initial"; + window.activeOnnxViewport.addNodeFromId(this.draggingChild.id, { + position: { x: e.x, y: e.y }, }); + this.draggingChild = null; } } diff --git a/packages/frontend/assets/scripts/onnx-viewport.js b/packages/frontend/assets/scripts/onnx-viewport.js index 47ae978..028b6d5 100644 --- a/packages/frontend/assets/scripts/onnx-viewport.js +++ b/packages/frontend/assets/scripts/onnx-viewport.js @@ -53,7 +53,7 @@ function drawRoundedRect(ctx, x, y, width, height, radius = 6, options = {}) { } } -class Parameter { +class VNodeParameter { constructor(type = "input", name = "") { this.type = type; // "input" or "output" this.name = name; @@ -87,7 +87,7 @@ class VNode { this.label = label; // params is array of Parameter instances or plain objects {type,name} this.params = params.map((p) => - p instanceof Parameter ? p : new Parameter(p.type, p.name), + p instanceof VNodeParameter ? p : new VNodeParameter(p.type, p.name), ); this.inputs = this.params.filter((p) => p.type === "input"); this.outputs = this.params.filter((p) => p.type === "output"); @@ -186,7 +186,7 @@ class VNode { } } -class Connection { +class VConnection { constructor(fromNode, fromOutput, toNode, toInput) { this.fromNode = fromNode; this.fromOutput = fromOutput; @@ -232,10 +232,13 @@ class Connection { } class Viewport { - constructor(containerElement, options = {}) { + constructor(dioxus, containerElement, options = {}) { if (!(containerElement instanceof HTMLElement)) { throw new Error("Viewport constructor requires a DOM container element"); } + + this.dioxus = dioxus; + // options and defaults this.gridSpacing = options.gridSpacing || 40; this.dotRadius = options.dotRadius || 2; @@ -292,13 +295,6 @@ class Viewport { this._onResize(); } - // convenience factory - createNode(x, y, label, params = []) { - const node = new VNode(x, y, label, params); - this.addNode(node); - return node; - } - addNode(node) { this.nodes.push(node); this.draw(); @@ -312,12 +308,24 @@ class Viewport { let { x, y } = this.toEditor(position.x - rect.x, position.y - rect.y); - this.addNode( - new window.VNode(x, y, "SampleNODE", [ - new window.Parameter("input", "inA"), - new window.Parameter("output", "outA"), - ]), + this.dioxus.send({ + AddNode: { id: "00000000-0000-0000-0000-000000000000", x: x, y: y }, + }); + } + + handleAddNode(jsonNode) { + let node = new VNode( + jsonNode.x, + jsonNode.y, + jsonNode.label, + jsonNode.params, ); + console.log(node); + this.addNode(node); + } + + save() { + return { nodes: this.nodes, connections: this.connections }; } addConnection(conn) { @@ -521,7 +529,7 @@ class Viewport { let inIdx = node.inputHit(p.x, p.y); if (inIdx !== null && node !== this.connectingFrom.node) { this.connections.push( - new Connection( + new VConnection( this.connectingFrom.node, this.connectingFrom.outIdx, node, @@ -602,6 +610,6 @@ class Viewport { } } -window.Parameter = Parameter; +window.Parameter = VNodeParameter; window.VNode = VNode; window.Viewport = Viewport; diff --git a/packages/frontend/assets/style/items/search.css b/packages/frontend/assets/style/items/search.css index f3d9ff7..3a2734d 100644 --- a/packages/frontend/assets/style/items/search.css +++ b/packages/frontend/assets/style/items/search.css @@ -8,11 +8,13 @@ } .Search > header { + width: 100%; position: sticky; background-color: rgba(var(--background-highlight-color), 0.5); } .Search > main { + width: 100%; padding: 1rem 0; flex-grow: 1; display: flex; diff --git a/packages/frontend/src/modules/components.rs b/packages/frontend/src/modules/components.rs index f8f493d..6e91483 100644 --- a/packages/frontend/src/modules/components.rs +++ b/packages/frontend/src/modules/components.rs @@ -55,5 +55,4 @@ pub mod prelude { // %% utils %% pub(crate) mod utils { pub use crate::utils::*; - pub use simple_ai_backend::prelude::*; } diff --git a/packages/frontend/src/modules/components/connection.rs b/packages/frontend/src/modules/components/connection.rs index 43e91d8..f11dad5 100644 --- a/packages/frontend/src/modules/components/connection.rs +++ b/packages/frontend/src/modules/components/connection.rs @@ -1,5 +1,7 @@ // %%% components / connection.rs %%% +use simple_ai_backend::modules::utils::param::RuntimeParamKind; + // %% includes %% use super::utils::*; diff --git a/packages/frontend/src/modules/components/node.rs b/packages/frontend/src/modules/components/node.rs index 449cae3..f7ec3fa 100644 --- a/packages/frontend/src/modules/components/node.rs +++ b/packages/frontend/src/modules/components/node.rs @@ -1,13 +1,16 @@ // %%% components / node.rs %%% +use simple_ai_backend::modules::utils::node::StrongNode; +use simple_ai_backend::modules::utils::param::ParamKind; + // %% includes %% use super::runtime_param::{InternRuntimeParam, RuntimeParam}; -use super::static_param::{InternStaticParam, StaticParam}; +use super::static_param::InternStaticParam; use super::utils::*; // %% main %% // % transferrer % // -pub static NODE_TRANSFERER: GlobalSignal> = +pub static NODE_TRANSFERER: GlobalSignal> = GlobalSignal::new(|| None); // % Node % // @@ -41,7 +44,7 @@ impl From for InternNode { let node = node_ctx.context.try_lock().unwrap(); let mut runtime_params = Vec::::new(); let mut static_params = Vec::::new(); - node.params.iter().for_each(move |param_ctx| { + node.get_params().iter().for_each(move |param_ctx| { let param = param_ctx.context.try_lock().unwrap(); match param.kind { ParamKind::Runtime { .. } => { @@ -79,12 +82,11 @@ pub fn Node(intern: InternNode) -> Element { } }); - let rendered_params = intern - .runtime_params - .iter() - .map(|intern| rsx! { + let rendered_params = intern.runtime_params.iter().map(|intern| { + rsx! { RuntimeParam { intern: intern.clone() } - }); + } + }); rsx! { body { diff --git a/packages/frontend/src/modules/components/runtime_param.rs b/packages/frontend/src/modules/components/runtime_param.rs index 0c4c67d..de256c0 100644 --- a/packages/frontend/src/modules/components/runtime_param.rs +++ b/packages/frontend/src/modules/components/runtime_param.rs @@ -1,5 +1,7 @@ // %%% components / runtime_param.rs %%% +use simple_ai_backend::modules::utils::param::{ParamKind, RuntimeParamKind, StrongParam}; + // %% includes %% use super::connection::{Connection, InternConnection}; use super::utils::*; @@ -30,8 +32,7 @@ impl From for InternRuntimeParam { #[component] pub fn RuntimeParam(intern: InternRuntimeParam) -> Element { rsx! { - body { - class: "Param", + body { class: "Param", Connection { intern: (intern.connection)() } } } diff --git a/packages/frontend/src/modules/components/search.rs b/packages/frontend/src/modules/components/search.rs index 1098c29..4738612 100644 --- a/packages/frontend/src/modules/components/search.rs +++ b/packages/frontend/src/modules/components/search.rs @@ -3,20 +3,18 @@ // %% includes %% use super::search_result::{InternSearchResult, SearchResult}; use super::utils::*; -use chrono::Utc; -use simple_ai_backend::modules::utils::node::NodeBuilder; +use simple_ai_backend::modules::nodes::query::query_nodes; +use simple_ai_backend::modules::utils::prelude::NodeQueryFilter; use tokio::time::*; // %% main %% #[item] pub fn Search(#[props(extends = GlobalAttributes)] attributes: Vec) -> Element { let mut intern_search_results = use_signal(Vec::::new); - let mut search_results = use_signal(Container::new); + let mut search_results = use_signal(|| query_nodes(vec![])); let input = move |e: FormEvent| { - search_results.set(query::query_nodes(vec![QueryFilter::Name { - name: e.value(), - }])); + search_results.set(query_nodes(vec![NodeQueryFilter::Name { name: e.value() }])); intern_search_results.clear(); }; @@ -24,30 +22,11 @@ pub fn Search(#[props(extends = GlobalAttributes)] attributes: Vec) - intern_search_results.set( search_results() .iter() - .map(|result| InternSearchResult::from(result.context.blocking_lock().clone())) + .map(|result| InternSearchResult::from(result.context.try_lock().unwrap().clone())) .collect::>(), ); }); - // Todo: remove this its just a test - let intern = InternSearchResult::from( - NodeBuilder::default() - .name("SampleNode".to_string()) - .params(Vec::new()) - .version(Version { - version: "0.0.1".to_string(), - env: Environment { deps: Vec::new() }, - }) - .kind(NodeKind::Bundled { - bundle: Container::new(), - }) - .description("this is a sample node".to_string()) - .author("sert".to_string()) - .date(Utc::now()) - .build() - .unwrap(), - ); - let future = use_resource(move || async move { // You can create as many eval instances as you want let mut eval = document::eval(r#"console.log("HELLOO")"#); @@ -85,12 +64,6 @@ pub fn Search(#[props(extends = GlobalAttributes)] attributes: Vec) - for intern in intern_search_results() { SearchResult { intern } } - SearchResult { intern } - SearchResult { intern } - SearchResult { intern } - SearchResult { intern } - SearchResult { intern } - SearchResult { intern } } } } diff --git a/packages/frontend/src/modules/components/search_result.rs b/packages/frontend/src/modules/components/search_result.rs index 62b05b5..f67f481 100644 --- a/packages/frontend/src/modules/components/search_result.rs +++ b/packages/frontend/src/modules/components/search_result.rs @@ -1,7 +1,8 @@ // %%% components / search_result.rs %%% +use simple_ai_backend::modules::utils::node::Node; + // %% includes %% -use super::draggable::Draggable; use super::node::NODE_TRANSFERER; use super::utils::*; @@ -29,6 +30,8 @@ pub fn SearchResult(intern: InternSearchResult) -> Element { rsx! { article { + class: "draggable-node", + id: "{intern.node.cloned().author}/{intern.node.cloned().name}", h3 { span { id: "name", "{intern.node.cloned().name}" } } diff --git a/packages/frontend/src/modules/components/static_param.rs b/packages/frontend/src/modules/components/static_param.rs index e508d85..3542da4 100644 --- a/packages/frontend/src/modules/components/static_param.rs +++ b/packages/frontend/src/modules/components/static_param.rs @@ -1,8 +1,9 @@ // %%% components / static_param.rs %%% +use simple_ai_backend::modules::utils::param::StrongParam; + // %% includes %% use super::utils::*; -use simple_ai_backend::prelude::*; // %% main %% #[derive(PartialEq, Props, Clone)] @@ -19,9 +20,7 @@ impl From for InternStaticParam { #[component] pub fn StaticParam(style: String, intern: InternStaticParam) -> Element { rsx! { - style { { style } } - body { - class: "Param", - } + style { {style} } + body { class: "Param" } } } diff --git a/packages/frontend/src/modules/components/viewport.rs b/packages/frontend/src/modules/components/viewport.rs index 99d9af8..4ffe780 100644 --- a/packages/frontend/src/modules/components/viewport.rs +++ b/packages/frontend/src/modules/components/viewport.rs @@ -1,10 +1,11 @@ // %%% components / viewport.rs %%% +use simple_ai_backend::modules::utils::{node::StrongNode, prelude::NodeContainer}; + use crate::modules::pages::utils::NODE_TRANSFERER; // %% includes %% use super::utils::*; -use simple_ai_backend::prelude::*; // %% main %% #[derive(Clone)] @@ -38,9 +39,11 @@ pub fn Viewport( ) -> Element { // ------------------------------ VARIABLES ------------------------------ // let fnc = node_container().frontend_node_container.cloned(); - let rendered_nodes = fnc - .iter() - .map(|intern| rsx! { super::node::Node { intern: intern.clone() } }); + let rendered_nodes = fnc.iter().map(|intern| { + rsx! { + super::node::Node { intern: intern.clone() } + } + }); let get_client_rect = move || async move { if let Some(data) = mounted_data() { @@ -190,7 +193,7 @@ pub fn Viewport( transform: "translate({position().x}px, {position().y}px) scale({scale()})", user_select: "none", onmounted: mounted, - { rendered_nodes } + {rendered_nodes} } } } diff --git a/packages/frontend/src/modules/pages/edit.rs b/packages/frontend/src/modules/pages/edit.rs index 2288282..f522bf5 100644 --- a/packages/frontend/src/modules/pages/edit.rs +++ b/packages/frontend/src/modules/pages/edit.rs @@ -9,8 +9,8 @@ pub struct Project { #[page] pub fn Edit() -> Element { let mut proj = Project { - name: "".into(), - description: "".into(), + name: "bundlenode".into(), + description: "hello this is bundlenode".into(), }; rsx! { main { diff --git a/packages/frontend/src/modules/pages/editor.rs b/packages/frontend/src/modules/pages/editor.rs index f430a76..3a733dd 100644 --- a/packages/frontend/src/modules/pages/editor.rs +++ b/packages/frontend/src/modules/pages/editor.rs @@ -1,11 +1,121 @@ use super::utils::*; +use serde::{Deserialize, Serialize}; +use serde_json::Value as JsonValue; use tokio::time::*; +use uuid::Uuid; const CREATE_VIEWPORT: &str = r#" - window.activeOnnxViewport = new window.Viewport(document.getElementById("viewport")); - // window.activeOnnxViewport.listener(); + window.activeOnnxViewport = new window.Viewport(dioxus, document.getElementById("viewport")); "#; +#[derive(Serialize, Deserialize, Debug)] +struct VNodeParameter { + r#type: String, + name: String, +} + +#[derive(Serialize, Deserialize, Debug)] +struct VNode { + x: f32, + y: f32, + label: String, + params: Vec, +} + +#[derive(Serialize, Deserialize, Debug)] +struct VConnection { + from_node: VNode, + from_output: VNodeParameter, + to_node: VNode, + to_input: VNodeParameter, +} + +#[derive(Serialize, Deserialize, Debug)] +struct ViewportSave { + nodes: Vec, + connections: Vec, +} + +#[derive(Serialize, Deserialize, Debug)] +struct AddNodeData { + id: Uuid, + x: f32, + y: f32, +} + +#[derive(Serialize, Deserialize, Debug)] +struct SaveNodeData { + viewport_save: ViewportSave, +} + +/// # handy for json representation: +/// let dbg: String = ViewportEvent::AddNode(Uuid::default()).into(); +/// debug!("{}", dbg); +/// +/// ``` { AddNode: ".." } ``` +#[derive(Serialize, Deserialize, Debug)] +enum ViewportEvent { + AddNode(AddNodeData), + Save(SaveNodeData), +} + +impl ViewportEvent { + pub fn exec(self) { + match self { + Self::AddNode(AddNodeData { id, x, y }) => { + // TODO: let node = simple_ai_backend::modules::nodes::query_onnx(); + + // then remove this + let params = vec![ + VNodeParameter { + r#type: "input".into(), + name: "in1".into(), + }, + VNodeParameter { + r#type: "output".into(), + name: "out1".into(), + }, + ]; + let node = VNode { + x, + y, + label: "sample".into(), + params, + }; + + let json_node = serde_json::to_string(&node).unwrap(); + + document::eval(&format!( + r#"window.activeOnnxViewport.handleAddNode({})"#, + json_node + )); + } + Self::Save(SaveNodeData { viewport_save }) => { + // TODO: simple_ai_backend::modules::porject::save_onnx(viewport_save); + debug!("Saving..."); + } + } + } +} + +impl From for ViewportEvent { + fn from(value: JsonValue) -> Self { + serde_json::from_value(value).expect("Could not convert ViewportEvent from json") + } +} + +impl Into for ViewportEvent { + fn into(self) -> String { + serde_json::to_string(&self).expect("Could not convert ViewportEvent to json.") + } +} + +impl From for ViewportEvent { + fn from(value: String) -> Self { + serde_json::from_str(&value).expect("Could not convert ViewportEvent from json string.") + } +} + // TODO: // - add a train button // -> add something like simple_ai_backend :: onnx :: Project . train() @@ -19,21 +129,26 @@ pub fn Editor(children: Element) -> Element { main { onmounted: move |e| async move { sleep(Duration::from_millis(100)).await; - let mut viewport_listener = document::eval(CREATE_VIEWPORT); + + let mut handle = document::eval(CREATE_VIEWPORT); + + loop { + let s: JsonValue = handle.recv().await.expect("error recieving string"); + let e: ViewportEvent = s.into(); + e.exec(); + } + // loop { // TODO: also convert the string to a rust new Event enum // - the event types of the enum are AddNode and CheckConnection // - the backend will need something like simple_ai_backend :: onnx :: // check_connection (param1, param2) // - // let event: String = viewport_listener.recv().await.unwrap(); - // - // - // TODO: simple_ai_backend :: onnx :: fetch_node_from_id(id) -> Node; - // - Then make a function that converts the Node to the js one; - // - Lastly document::eval(format!(r#"window.activeOnnxViewport.addNode({})"#, node)); - // + // TODO: If something failes make a notification. + // debug!("{:?}", viewport_listener.recv::< String > (). await); + // ViewportEvent::from(viewport_listener.recv::().await.unwrap()) + // .exec(); // } }, section { id: "viewport" } diff --git a/packages/frontend/src/modules/pages/projects.rs b/packages/frontend/src/modules/pages/projects.rs index 7790933..817bf23 100644 --- a/packages/frontend/src/modules/pages/projects.rs +++ b/packages/frontend/src/modules/pages/projects.rs @@ -1,34 +1,30 @@ +use simple_ai_backend::modules::projects::query::query_projects; +use simple_ai_backend::modules::utils::prelude::ProjectQueryFilter; + use super::super::components::project::Project; use super::utils::*; -use chrono::{DateTime, Utc}; -use dioxus::router::NavigationTarget; #[page] pub fn Projects() -> Element { + let mut search_results = use_signal(|| query_projects(vec![])); + let input = move |e: FormEvent| { + search_results.set(query_projects(vec![ProjectQueryFilter::Name { + name: e.value(), + }])); + }; + rsx! { main { - input { r#type: "search" } + input { + oninput: input, + r#type: "search", + placeholder: "search", + id: "search", + } article { class: "projects-wrapper", div { class: "projects-view", - Project { - name: "sample project", - date: Utc::now(), - desc: "this is a sample Project for development", - } - Project { - name: "sample project", - date: Utc::now(), - desc: "this is a sample Project for development", - } - Project { - name: "sample project", - date: Utc::now(), - desc: "this is a sample Project for development", - } - Project { - name: "sample project", - date: Utc::now(), - desc: "this is a sample Project for development", + for res in search_results() { + Project { name: res.name, date: res.date, desc: res.date } } } } diff --git a/packages/macros/src/formifiable.rs b/packages/macros/src/formifiable.rs index 3cbd306..aed1deb 100644 --- a/packages/macros/src/formifiable.rs +++ b/packages/macros/src/formifiable.rs @@ -19,6 +19,7 @@ pub fn macro_impl(item: TokenStream) -> TokenStream { input { class: "FormifyInput", name: {#field_name}, + value: {self.#field_ident.clone()}, required: true, r#type: "text", } diff --git a/projects/Author/something.json b/projects/Author/something.json new file mode 100644 index 0000000..4c1c5b0 --- /dev/null +++ b/projects/Author/something.json @@ -0,0 +1,7 @@ +{ + "name": "bundled_node", + "desc": "A bundled node", + "author": "Author", + "node": "bundled_node", + "date": "2025-12-04T12:05:46.983020666Z" +}