diff --git a/examples/result.rs b/examples/result.rs new file mode 100644 index 0000000..928d710 --- /dev/null +++ b/examples/result.rs @@ -0,0 +1,16 @@ +use memoize::memoize; + +#[memoize] +fn hello(arg: String, arg2: usize) -> std::io::Result { + println!("{} => {}", arg, arg2); + Ok(arg.len() % 2 == arg2) +} + +fn main() { + // `hello` is only called once here. + assert!(hello("World2".to_string(), 0).unwrap()); + assert!(hello("World2".to_string(), 0).unwrap()); + // Sometimes one might need the original function. + assert!(memoized_original_hello("World2".to_string(), 0).unwrap()); + memoized_flush_hello(); +} diff --git a/inner/Cargo.toml b/inner/Cargo.toml index 4b07514..cdf7fbb 100644 --- a/inner/Cargo.toml +++ b/inner/Cargo.toml @@ -19,6 +19,9 @@ proc-macro2 = "1.0" quote = "1.0" syn = { version = "1.0", features = ["full"] } +[dev-dependencies] +test-case = "3.3.1" + [features] default = [] full = [] diff --git a/inner/src/lib.rs b/inner/src/lib.rs index 4bbbde4..6aac75b 100644 --- a/inner/src/lib.rs +++ b/inner/src/lib.rs @@ -1,6 +1,6 @@ #![crate_type = "proc-macro"] #![allow(unused_imports)] // Spurious complaints about a required trait import. -use syn::{self, parse, parse_macro_input, spanned::Spanned, Expr, ExprCall, ItemFn, Path}; +use syn::{self, parse, parse2, parse_macro_input, spanned::Spanned, AngleBracketedGenericArguments, Expr, ExprCall, ItemFn, Path, PathArguments, Type}; use proc_macro::TokenStream; use quote::{self, ToTokens}; @@ -110,6 +110,40 @@ impl parse::Parse for CacheOptions { } } + +fn check_for_result_type(outer: proc_macro2::TokenStream) -> bool { + // Parse the input as a Rust type + let input_ty = parse2::(outer).expect("failed to parse outer type"); + + if let Type::Path(path) = input_ty { + return path.path.segments.last().expect("O length path?").ident == "Result"; + } + false +} + +fn try_unwrap_result_type(outer: proc_macro2::TokenStream) -> proc_macro2::TokenStream { + let original = outer.clone(); + // Parse the input as a Rust type + let input_ty = parse2::(outer).expect("failed to parse outer type"); + + // Ensure it’s a path type (e.g., Result) + if let Type::Path(path) = input_ty { + // Look at the last segment (e.g., "Result") + let last_segment = path.path.segments.last().expect("Expected a Result"); + + if last_segment.ident.to_string().contains("Result") { + if let PathArguments::AngleBracketed(AngleBracketedGenericArguments { args, .. }) = &last_segment.arguments + { + // The first generic argument of Result is the Ok type + if let Some(syn::GenericArgument::Type(ok_type)) = args.first() { + return ok_type.to_token_stream() + } + } + } + } + original +} + // This implementation of the storage backend does not depend on any more crates. #[cfg(not(feature = "full"))] mod store { @@ -122,6 +156,7 @@ mod store { key_type: proc_macro2::TokenStream, value_type: proc_macro2::TokenStream, ) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) { + let value_type = crate::try_unwrap_result_type(value_type); // This is the unbounded default. if let Some(hasher) = &_options.custom_hasher { return ( @@ -148,8 +183,11 @@ mod store { // This implementation of the storage backend also depends on the `lru` crate. #[cfg(feature = "full")] mod store { - use crate::CacheOptions; + use crate::{try_unwrap_result_type, CacheOptions}; use proc_macro::TokenStream; + use quote::quote; + use syn::{parse2, AngleBracketedGenericArguments, PathArguments, Type, TypePath}; + /// Returns TokenStreams to be used in quote!{} for parametrizing the memoize store variable, /// and initializing it. @@ -161,6 +199,8 @@ mod store { key_type: proc_macro2::TokenStream, value_type: proc_macro2::TokenStream, ) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) { + let value_type = try_unwrap_result_type(value_type); + let value_type = match options.time_to_live { None => quote::quote! {#value_type}, Some(_) => quote::quote! {(std::time::Instant, #value_type)}, @@ -201,7 +241,6 @@ mod store { } } } - /// Returns names of methods as TokenStreams to insert and get (respectively) elements from a /// store. pub(crate) fn cache_access_methods( @@ -379,20 +418,40 @@ pub fn memoize(attr: TokenStream, item: TokenStream) -> TokenStream { ), }; + let get_value = if check_for_result_type(return_type.clone()) { + quote::quote! { + let ATTR_MEMOIZE_RETURN__ = #memoized_id #forwarding_tuple?; + } + } else { + quote::quote! { + let ATTR_MEMOIZE_RETURN__ = #memoized_id #forwarding_tuple; + } + }; + + let return_value = if check_for_result_type(return_type.clone()) { + quote::quote! { + Ok(ATTR_MEMOIZE_RETURN__) + } + } else { + quote::quote! { + ATTR_MEMOIZE_RETURN__ + } + }; + let memoizer = if options.shared_cache { quote::quote! { { let mut ATTR_MEMOIZE_HM__ = #store_ident.lock().unwrap(); if let Some(ATTR_MEMOIZE_RETURN__) = #read_memo { - return ATTR_MEMOIZE_RETURN__ + return #return_value; } } - let ATTR_MEMOIZE_RETURN__ = #memoized_id #forwarding_tuple; + #get_value let mut ATTR_MEMOIZE_HM__ = #store_ident.lock().unwrap(); #memoize - ATTR_MEMOIZE_RETURN__ + #return_value } } else { quote::quote! { @@ -401,17 +460,17 @@ pub fn memoize(attr: TokenStream, item: TokenStream) -> TokenStream { #read_memo }); if let Some(ATTR_MEMOIZE_RETURN__) = ATTR_MEMOIZE_RETURN__ { - return ATTR_MEMOIZE_RETURN__; + return #return_value; } - let ATTR_MEMOIZE_RETURN__ = #memoized_id #forwarding_tuple; + #get_value #store_ident.with(|ATTR_MEMOIZE_HM__| { let mut ATTR_MEMOIZE_HM__ = ATTR_MEMOIZE_HM__.borrow_mut(); #memoize }); - ATTR_MEMOIZE_RETURN__ + #return_value } }; @@ -505,4 +564,47 @@ fn check_signature( } #[cfg(test)] -mod tests {} +mod tests { + use std::str::FromStr; + use test_case::test_case; + use proc_macro2::TokenStream; + use quote::quote; + use crate::{check_for_result_type, try_unwrap_result_type}; + + #[test_case("Result")] + #[test_case("anyhow::Result")] + #[test_case("std::io::Result")] + #[test_case("io::Result")] + fn test_check_for_result_type_success(typestr: &str) { + let input = TokenStream::from_str(typestr).unwrap(); + assert_eq!(true, check_for_result_type(input)); + } + + #[test_case("Option")] + #[test_case("(bool, bool)")] + #[test_case("bool")] + fn test_check_for_result_type_fail(typestr: &str) { + let input = TokenStream::from_str(typestr).unwrap(); + assert_eq!(false, check_for_result_type(input)); + } + + #[test_case("Result", "bool")] + #[test_case("anyhow::Result", "u32")] + #[test_case("std::io::Result", "String")] + #[test_case("io::Result<(u32, u32)>", "(u32 , u32)")] + #[test_case("CustomResult", "CustomStruct")] + fn test_try_unwrap_result_type_inner(input_type: &str, output_type: &str) { + let input = TokenStream::from_str(input_type).unwrap(); + assert_eq!(output_type, + try_unwrap_result_type(input).to_string()); + } + + #[test_case("Option < bool >")] + #[test_case("(bool , bool)")] + #[test_case("bool")] + fn test_try_unwrap_result_type_original(typestr: &str) { + let input = TokenStream::from_str(typestr).unwrap(); + assert_eq!(typestr.replace(" ", ""), + try_unwrap_result_type(input).to_string().replace(" ", "")); + } +}