From cc9af5121dcbed0a10e85ac6dad9305be32ea136 Mon Sep 17 00:00:00 2001 From: Cameron Cross Date: Thu, 6 Nov 2025 20:25:22 +1100 Subject: [PATCH 1/5] Work with Result --- examples/trivial.rs | 10 +++++----- inner/src/lib.rs | 43 +++++++++++++++++++++++++++++++++++++------ 2 files changed, 42 insertions(+), 11 deletions(-) diff --git a/examples/trivial.rs b/examples/trivial.rs index 7ee1e05..aec17a2 100644 --- a/examples/trivial.rs +++ b/examples/trivial.rs @@ -1,16 +1,16 @@ use memoize::memoize; #[memoize] -fn hello(arg: String, arg2: usize) -> bool { +fn hello(arg: String, arg2: usize) -> Result { println!("{} => {}", arg, arg2); - arg.len() % 2 == arg2 + Ok(arg.len() % 2 == arg2) } fn main() { // `hello` is only called once here. - assert!(!hello("World".to_string(), 0)); - assert!(!hello("World".to_string(), 0)); + assert!(!hello("World".to_string(), 0).unwrap()); + assert!(!hello("World".to_string(), 0).unwrap()); // Sometimes one might need the original function. - assert!(!memoized_original_hello("World".to_string(), 0)); + assert!(!memoized_original_hello("World".to_string(), 0).unwrap()); memoized_flush_hello(); } diff --git a/inner/src/lib.rs b/inner/src/lib.rs index 4bbbde4..45e1528 100644 --- a/inner/src/lib.rs +++ b/inner/src/lib.rs @@ -150,6 +150,35 @@ mod store { mod store { use crate::CacheOptions; use proc_macro::TokenStream; + use quote::quote; + use syn::{parse2, AngleBracketedGenericArguments, PathArguments, Type, TypePath}; + + fn get_inner_value(outer: proc_macro2::TokenStream) -> proc_macro2::TokenStream { + // Parse the input as a Rust type + let input_ty = parse2::(outer).expect("failed to parse outer type"); + + // Ensure it’s a path type (e.g., Result) + if let Type::Path(path) = input_ty { + // Look at the last segment (e.g., "Result") + let last_segment = path.path.segments.last().expect("Expected a Result"); + + if let PathArguments::AngleBracketed(AngleBracketedGenericArguments { args, .. }) = + &last_segment.arguments + { + // The first generic argument of Result is the Ok type + if let Some(syn::GenericArgument::Type(ok_type)) = args.first() { + return quote! { #ok_type }.into() + } else { + panic!("Expected a type argument inside Result<>"); + } + } else { + panic!("Expected angle bracketed generic arguments"); + } + } else { + panic!("Expected a type path like Result"); + }; + } + /// Returns TokenStreams to be used in quote!{} for parametrizing the memoize store variable, /// and initializing it. @@ -161,6 +190,8 @@ mod store { key_type: proc_macro2::TokenStream, value_type: proc_macro2::TokenStream, ) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) { + let value_type = get_inner_value(value_type); + let value_type = match options.time_to_live { None => quote::quote! {#value_type}, Some(_) => quote::quote! {(std::time::Instant, #value_type)}, @@ -384,15 +415,15 @@ pub fn memoize(attr: TokenStream, item: TokenStream) -> TokenStream { { let mut ATTR_MEMOIZE_HM__ = #store_ident.lock().unwrap(); if let Some(ATTR_MEMOIZE_RETURN__) = #read_memo { - return ATTR_MEMOIZE_RETURN__ + return Ok(ATTR_MEMOIZE_RETURN__) } } - let ATTR_MEMOIZE_RETURN__ = #memoized_id #forwarding_tuple; + let ATTR_MEMOIZE_RETURN__ = #memoized_id #forwarding_tuple?; let mut ATTR_MEMOIZE_HM__ = #store_ident.lock().unwrap(); #memoize - ATTR_MEMOIZE_RETURN__ + Ok(ATTR_MEMOIZE_RETURN__) } } else { quote::quote! { @@ -401,17 +432,17 @@ pub fn memoize(attr: TokenStream, item: TokenStream) -> TokenStream { #read_memo }); if let Some(ATTR_MEMOIZE_RETURN__) = ATTR_MEMOIZE_RETURN__ { - return ATTR_MEMOIZE_RETURN__; + return Ok(ATTR_MEMOIZE_RETURN__); } - let ATTR_MEMOIZE_RETURN__ = #memoized_id #forwarding_tuple; + let ATTR_MEMOIZE_RETURN__ = #memoized_id #forwarding_tuple?; #store_ident.with(|ATTR_MEMOIZE_HM__| { let mut ATTR_MEMOIZE_HM__ = ATTR_MEMOIZE_HM__.borrow_mut(); #memoize }); - ATTR_MEMOIZE_RETURN__ + Ok(ATTR_MEMOIZE_RETURN__) } }; From 5f430a09702063da8ed857a7d6f0221d20d4ed50 Mon Sep 17 00:00:00 2001 From: Cameron Cross Date: Sat, 8 Nov 2025 03:46:18 +1100 Subject: [PATCH 2/5] Tests, better understanding of how this all works. Revert changes to ignore.rs --- examples/trivial.rs | 23 +++++-- inner/Cargo.toml | 3 + inner/src/lib.rs | 155 +++++++++++++++++++++++++++++++++----------- 3 files changed, 139 insertions(+), 42 deletions(-) diff --git a/examples/trivial.rs b/examples/trivial.rs index aec17a2..3d6bb07 100644 --- a/examples/trivial.rs +++ b/examples/trivial.rs @@ -1,16 +1,29 @@ use memoize::memoize; #[memoize] -fn hello(arg: String, arg2: usize) -> Result { +fn hello1(arg: String, arg2: usize) -> bool { + println!("{} => {}", arg, arg2); + arg.len() % 2 == arg2 +} + +#[memoize] +fn hello2(arg: String, arg2: usize) -> Result { println!("{} => {}", arg, arg2); Ok(arg.len() % 2 == arg2) } fn main() { // `hello` is only called once here. - assert!(!hello("World".to_string(), 0).unwrap()); - assert!(!hello("World".to_string(), 0).unwrap()); + assert!(hello1("World1".to_string(), 0)); + assert!(hello1("World1".to_string(), 0)); + // Sometimes one might need the original function. + assert!(memoized_original_hello1("World1".to_string(), 0)); + memoized_flush_hello1(); + + // `hello` is only called once here. + assert!(hello2("World2".to_string(), 0).unwrap()); + assert!(hello2("World2".to_string(), 0).unwrap()); // Sometimes one might need the original function. - assert!(!memoized_original_hello("World".to_string(), 0).unwrap()); - memoized_flush_hello(); + assert!(memoized_original_hello2("World2".to_string(), 0).unwrap()); + memoized_flush_hello2(); } diff --git a/inner/Cargo.toml b/inner/Cargo.toml index 4b07514..a5cb892 100644 --- a/inner/Cargo.toml +++ b/inner/Cargo.toml @@ -19,6 +19,9 @@ proc-macro2 = "1.0" quote = "1.0" syn = { version = "1.0", features = ["full"] } +[dev-dependencies] +parameterized = "2.1.0" + [features] default = [] full = [] diff --git a/inner/src/lib.rs b/inner/src/lib.rs index 45e1528..2571efc 100644 --- a/inner/src/lib.rs +++ b/inner/src/lib.rs @@ -1,6 +1,6 @@ #![crate_type = "proc-macro"] #![allow(unused_imports)] // Spurious complaints about a required trait import. -use syn::{self, parse, parse_macro_input, spanned::Spanned, Expr, ExprCall, ItemFn, Path}; +use syn::{self, parse, parse2, parse_macro_input, spanned::Spanned, AngleBracketedGenericArguments, Expr, ExprCall, ItemFn, Path, PathArguments, Type}; use proc_macro::TokenStream; use quote::{self, ToTokens}; @@ -110,6 +110,40 @@ impl parse::Parse for CacheOptions { } } + +fn check_for_result_type(outer: proc_macro2::TokenStream) -> bool { + // Parse the input as a Rust type + let input_ty = parse2::(outer).expect("failed to parse outer type"); + + if let Type::Path(path) = input_ty { + return path.path.segments.last().expect("O length path?").ident == "Result"; + } + false +} + +fn try_unwrap_result_type(outer: proc_macro2::TokenStream) -> proc_macro2::TokenStream { + let original = outer.clone(); + // Parse the input as a Rust type + let input_ty = parse2::(outer).expect("failed to parse outer type"); + + // Ensure it’s a path type (e.g., Result) + if let Type::Path(path) = input_ty { + // Look at the last segment (e.g., "Result") + let last_segment = path.path.segments.last().expect("Expected a Result"); + + if last_segment.ident == "Result" { + if let PathArguments::AngleBracketed(AngleBracketedGenericArguments { args, .. }) = &last_segment.arguments + { + // The first generic argument of Result is the Ok type + if let Some(syn::GenericArgument::Type(ok_type)) = args.first() { + return ok_type.to_token_stream() + } + } + } + } + original +} + // This implementation of the storage backend does not depend on any more crates. #[cfg(not(feature = "full"))] mod store { @@ -122,6 +156,7 @@ mod store { key_type: proc_macro2::TokenStream, value_type: proc_macro2::TokenStream, ) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) { + let value_type = crate::try_unwrap_result_type(value_type); // This is the unbounded default. if let Some(hasher) = &_options.custom_hasher { return ( @@ -148,37 +183,11 @@ mod store { // This implementation of the storage backend also depends on the `lru` crate. #[cfg(feature = "full")] mod store { - use crate::CacheOptions; + use crate::{try_unwrap_result_type, CacheOptions}; use proc_macro::TokenStream; use quote::quote; use syn::{parse2, AngleBracketedGenericArguments, PathArguments, Type, TypePath}; - fn get_inner_value(outer: proc_macro2::TokenStream) -> proc_macro2::TokenStream { - // Parse the input as a Rust type - let input_ty = parse2::(outer).expect("failed to parse outer type"); - - // Ensure it’s a path type (e.g., Result) - if let Type::Path(path) = input_ty { - // Look at the last segment (e.g., "Result") - let last_segment = path.path.segments.last().expect("Expected a Result"); - - if let PathArguments::AngleBracketed(AngleBracketedGenericArguments { args, .. }) = - &last_segment.arguments - { - // The first generic argument of Result is the Ok type - if let Some(syn::GenericArgument::Type(ok_type)) = args.first() { - return quote! { #ok_type }.into() - } else { - panic!("Expected a type argument inside Result<>"); - } - } else { - panic!("Expected angle bracketed generic arguments"); - } - } else { - panic!("Expected a type path like Result"); - }; - } - /// Returns TokenStreams to be used in quote!{} for parametrizing the memoize store variable, /// and initializing it. @@ -190,7 +199,7 @@ mod store { key_type: proc_macro2::TokenStream, value_type: proc_macro2::TokenStream, ) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) { - let value_type = get_inner_value(value_type); + let value_type = try_unwrap_result_type(value_type); let value_type = match options.time_to_live { None => quote::quote! {#value_type}, @@ -232,7 +241,6 @@ mod store { } } } - /// Returns names of methods as TokenStreams to insert and get (respectively) elements from a /// store. pub(crate) fn cache_access_methods( @@ -410,20 +418,40 @@ pub fn memoize(attr: TokenStream, item: TokenStream) -> TokenStream { ), }; + let get_value = if check_for_result_type(return_type.clone()) { + quote::quote! { + let ATTR_MEMOIZE_RETURN__ = #memoized_id #forwarding_tuple?; + } + } else { + quote::quote! { + let ATTR_MEMOIZE_RETURN__ = #memoized_id #forwarding_tuple; + } + }; + + let return_value = if check_for_result_type(return_type.clone()) { + quote::quote! { + Ok(ATTR_MEMOIZE_RETURN__) + } + } else { + quote::quote! { + ATTR_MEMOIZE_RETURN__ + } + }; + let memoizer = if options.shared_cache { quote::quote! { { let mut ATTR_MEMOIZE_HM__ = #store_ident.lock().unwrap(); if let Some(ATTR_MEMOIZE_RETURN__) = #read_memo { - return Ok(ATTR_MEMOIZE_RETURN__) + return #return_value; } } - let ATTR_MEMOIZE_RETURN__ = #memoized_id #forwarding_tuple?; + #get_value let mut ATTR_MEMOIZE_HM__ = #store_ident.lock().unwrap(); #memoize - Ok(ATTR_MEMOIZE_RETURN__) + #return_value } } else { quote::quote! { @@ -432,17 +460,17 @@ pub fn memoize(attr: TokenStream, item: TokenStream) -> TokenStream { #read_memo }); if let Some(ATTR_MEMOIZE_RETURN__) = ATTR_MEMOIZE_RETURN__ { - return Ok(ATTR_MEMOIZE_RETURN__); + return #return_value; } - let ATTR_MEMOIZE_RETURN__ = #memoized_id #forwarding_tuple?; + #get_value #store_ident.with(|ATTR_MEMOIZE_HM__| { let mut ATTR_MEMOIZE_HM__ = ATTR_MEMOIZE_HM__.borrow_mut(); #memoize }); - Ok(ATTR_MEMOIZE_RETURN__) + #return_value } }; @@ -536,4 +564,57 @@ fn check_signature( } #[cfg(test)] -mod tests {} +mod tests { + use std::str::FromStr; + use parameterized::parameterized; + use proc_macro2::TokenStream; + use quote::quote; + use crate::{check_for_result_type, try_unwrap_result_type}; + + #[parameterized(typestr = { + "Result", + "anyhow::Result", + "std::io::Result", + "io::Result", + })] + fn test_check_for_result_type_success(typestr: &str) { + let input = TokenStream::from_str(typestr).unwrap(); + assert_eq!(true, check_for_result_type(input)); + } + + #[parameterized(typestr = { + "Option", + "(bool, bool)", + "bool", + })] + fn test_check_for_result_type_fail(typestr: &str) { + let input = TokenStream::from_str(typestr).unwrap(); + assert_eq!(false, check_for_result_type(input)); + } + + #[parameterized(params = { + ("Result", "bool"), + ("anyhow::Result", "bool"), + ("std::io::Result", "bool"), + ("io::Result", "bool"), + })] + fn test_try_unwrap_result_type_inner(params: (&str, &str)) { + let (input_type, output_type) = params; + let input = TokenStream::from_str(input_type).unwrap(); + let output = TokenStream::from_str(output_type).unwrap(); + assert_eq!(output_type, + try_unwrap_result_type(input).to_string()); + } + + #[parameterized(typestr = { + "Option < bool >", + "(bool , bool)", + "bool", + })] + fn test_try_unwrap_result_type_original(typestr: &str) { + let input = TokenStream::from_str(typestr).unwrap(); + assert_eq!(typestr.replace(" ", ""), + try_unwrap_result_type(input).to_string().replace(" ", "")); + } + +} From 6c6feedcc640326df943be766a28c14118c027b3 Mon Sep 17 00:00:00 2001 From: Cameron Cross Date: Sat, 8 Nov 2025 11:40:58 +1100 Subject: [PATCH 3/5] Handle other Result types, now only care that Result is in the name of the type. (doco?) Better testing. --- inner/src/lib.rs | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/inner/src/lib.rs b/inner/src/lib.rs index 2571efc..39408b9 100644 --- a/inner/src/lib.rs +++ b/inner/src/lib.rs @@ -131,7 +131,7 @@ fn try_unwrap_result_type(outer: proc_macro2::TokenStream) -> proc_macro2::Token // Look at the last segment (e.g., "Result") let last_segment = path.path.segments.last().expect("Expected a Result"); - if last_segment.ident == "Result" { + if last_segment.ident.to_string().contains("Result") { if let PathArguments::AngleBracketed(AngleBracketedGenericArguments { args, .. }) = &last_segment.arguments { // The first generic argument of Result is the Ok type @@ -571,6 +571,11 @@ mod tests { use quote::quote; use crate::{check_for_result_type, try_unwrap_result_type}; + #[test] + fn test() { + + } + #[parameterized(typestr = { "Result", "anyhow::Result", @@ -594,14 +599,14 @@ mod tests { #[parameterized(params = { ("Result", "bool"), - ("anyhow::Result", "bool"), - ("std::io::Result", "bool"), - ("io::Result", "bool"), + ("anyhow::Result", "u32"), + ("std::io::Result", "String"), + ("io::Result<(u32, u32)>", "(u32 , u32)"), + ("CustomResult", "CustomStruct"), })] fn test_try_unwrap_result_type_inner(params: (&str, &str)) { let (input_type, output_type) = params; let input = TokenStream::from_str(input_type).unwrap(); - let output = TokenStream::from_str(output_type).unwrap(); assert_eq!(output_type, try_unwrap_result_type(input).to_string()); } From f4452aee37ed032e279d0fe5801d18f444b7f2c9 Mon Sep 17 00:00:00 2001 From: Cameron Cross Date: Sun, 7 Dec 2025 16:09:23 +1100 Subject: [PATCH 4/5] Review fixes. --- examples/result.rs | 16 ++++++++++++++++ inner/src/lib.rs | 3 --- 2 files changed, 16 insertions(+), 3 deletions(-) create mode 100644 examples/result.rs diff --git a/examples/result.rs b/examples/result.rs new file mode 100644 index 0000000..7ee1e05 --- /dev/null +++ b/examples/result.rs @@ -0,0 +1,16 @@ +use memoize::memoize; + +#[memoize] +fn hello(arg: String, arg2: usize) -> bool { + println!("{} => {}", arg, arg2); + arg.len() % 2 == arg2 +} + +fn main() { + // `hello` is only called once here. + assert!(!hello("World".to_string(), 0)); + assert!(!hello("World".to_string(), 0)); + // Sometimes one might need the original function. + assert!(!memoized_original_hello("World".to_string(), 0)); + memoized_flush_hello(); +} diff --git a/inner/src/lib.rs b/inner/src/lib.rs index 4bbbde4..ec1cb7f 100644 --- a/inner/src/lib.rs +++ b/inner/src/lib.rs @@ -503,6 +503,3 @@ fn check_signature( } Ok(params) } - -#[cfg(test)] -mod tests {} From 90d3f49ca244108989d2c43d130ac4f55d31c707 Mon Sep 17 00:00:00 2001 From: Cameron Cross Date: Sun, 7 Dec 2025 16:25:58 +1100 Subject: [PATCH 5/5] Tweaks --- examples/trivial.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/trivial.rs b/examples/trivial.rs index c0fd68e..7ee1e05 100644 --- a/examples/trivial.rs +++ b/examples/trivial.rs @@ -13,4 +13,4 @@ fn main() { // Sometimes one might need the original function. assert!(!memoized_original_hello("World".to_string(), 0)); memoized_flush_hello(); -} \ No newline at end of file +}