diff --git a/patchable-macro/src/context.rs b/patchable-macro/src/context.rs index 12369d6..0d346a1 100644 --- a/patchable-macro/src/context.rs +++ b/patchable-macro/src/context.rs @@ -7,7 +7,6 @@ //! macro can emit the companion patch struct plus the `Patchable` and `Patch` //! trait implementations. -use std::borrow::Cow; use std::collections::HashMap; use proc_macro_crate::{FoundCrate, crate_name}; @@ -21,7 +20,7 @@ use syn::{ pub const IS_SERDE_ENABLED: bool = cfg!(feature = "serde"); -static PATCHABLE: &str = "patchable"; +const PATCHABLE: &str = "patchable"; #[derive(Debug)] enum TypeUsage { @@ -74,13 +73,15 @@ impl<'a> MacroContext<'a> { } let mut preserved_types: HashMap<&Ident, TypeUsage> = HashMap::new(); - let mut field_actions = vec![]; + let mut field_actions = Vec::new(); - let stateful_fields = fields.iter().filter(|f| !has_patchable_skip_attr(f)); + for (index, field) in fields.iter().enumerate() { + if has_patchable_skip_attr(field) { + continue; + } - for (index, field) in stateful_fields.enumerate() { let member = if let Some(field_name) = field.ident.as_ref() { - FieldMember::Named(Cow::Borrowed(field_name)) + FieldMember::Named(field_name) } else { FieldMember::Unnamed(Index::from(index)) }; @@ -136,34 +137,22 @@ impl<'a> MacroContext<'a> { pub(crate) fn build_patch_struct(&self) -> TokenStream2 { let generic_params = self.build_patch_type_generics(); - - let mut bounded_types = Vec::new(); - let patchable_trait = &self.patchable_trait; - for param in self.generics.type_params() { - if let Some(TypeUsage::Patchable) = self.preserved_types.get(¶m.ident) { - bounded_types.push(quote! { #param: #patchable_trait }); - } - } - let where_clause = if bounded_types.is_empty() { - quote! {} - } else { - quote! { where #(#bounded_types),* } - }; - + let where_clause = self.build_where_clause_with_bound(&self.patchable_trait); let patch_fields = self.generate_patch_fields(); - let body = self.select_fields( - quote! { #generic_params #where_clause { #(#patch_fields),* } }, - quote! { #generic_params ( #(#patch_fields),* ) #where_clause; }, - quote! {;}, - ); + let body = match &self.fields { + Fields::Named(_) => quote! { #generic_params #where_clause { #(#patch_fields),* } }, + Fields::Unnamed(_) => quote! { #generic_params ( #(#patch_fields),* ) #where_clause; }, + Fields::Unit => quote! {;}, + }; let patch_name = &self.patch_struct_name; - let mut derives_list = Vec::with_capacity(3); - if IS_SERDE_ENABLED { - derives_list.push(quote! { ::serde::Deserialize }); - } + let derive_attr = if IS_SERDE_ENABLED { + quote! { #[derive(::core::fmt::Debug, ::serde::Deserialize)] } + } else { + quote! { #[derive(::core::fmt::Debug)] } + }; quote! { - #[derive(#(#derives_list),*)] + #derive_attr pub struct #patch_name #body } } @@ -175,8 +164,8 @@ impl<'a> MacroContext<'a> { pub(crate) fn build_patchable_trait_impl(&self) -> TokenStream2 { let patchable_trait = &self.patchable_trait; let (impl_generics, type_generics, _) = self.generics.split_for_impl(); - let where_clause = self.build_bounded_types(patchable_trait); - let assoc_type_decl = self.build_associate_type_declaration(); + let where_clause = self.build_where_clause_with_bound(patchable_trait); + let assoc_type_decl = self.build_associated_type_declaration(); let input_struct_name = self.struct_name; @@ -189,31 +178,6 @@ impl<'a> MacroContext<'a> { } } - // ====================================================================== - // impl From> for OriginalStructPatch<...> - // ====================================================================== - - pub(crate) fn build_from_trait_impl(&self) -> TokenStream2 { - let (impl_generics, type_generics, _) = self.generics.split_for_impl(); - let patch_type_generics = self.build_patch_type_generics(); - let where_clause = self.build_from_where_clause(); - - let input_struct_name = self.struct_name; - let patch_struct_name = &self.patch_struct_name; - let from_body = self.generate_from_body(); - - quote! { - impl #impl_generics ::core::convert::From<#input_struct_name #type_generics> - for #patch_struct_name #patch_type_generics - #where_clause { - #[inline(always)] - fn from(value: #input_struct_name #type_generics) -> Self { - #from_body - } - } - } - } - // ============================================================ // impl Patch for OriginalStruct MacroContext<'a> { pub(crate) fn build_patch_trait_impl(&self) -> TokenStream2 { let patch_trait = &self.patch_trait; let (impl_generics, type_generics, _) = self.generics.split_for_impl(); - let where_clause = self.build_bounded_types(patch_trait); + let where_clause = self.build_where_clause_with_bound(patch_trait); let input_struct_name = self.struct_name; @@ -244,33 +208,56 @@ impl<'a> MacroContext<'a> { } } + // ====================================================================== + // impl From> for OriginalStructPatch<...> + // ====================================================================== + + pub(crate) fn build_from_trait_impl(&self) -> TokenStream2 { + let (impl_generics, type_generics, _) = self.generics.split_for_impl(); + let patch_type_generics = self.build_patch_type_generics(); + let where_clause = self.build_where_clause_for_from_impl(); + + let input_struct_name = self.struct_name; + let patch_struct_name = &self.patch_struct_name; + let from_body = self.generate_from_body(); + + quote! { + impl #impl_generics ::core::convert::From<#input_struct_name #type_generics> + for #patch_struct_name #patch_type_generics + #where_clause { + #[inline(always)] + fn from(value: #input_struct_name #type_generics) -> Self { + #from_body + } + } + } + } + fn generate_patch_fields(&self) -> Vec { + let patchable_trait = &self.patchable_trait; self.field_actions .iter() .map(|action| match action { - FieldAction::Keep { - member: FieldMember::Named(name), - ty, - } => { - quote! { #name : #ty } - } - FieldAction::Keep { - member: FieldMember::Unnamed(_), - ty, - } => { - quote! { #ty } - } - FieldAction::Patch { - member: FieldMember::Named(name), - ty, - } => { - quote! { #name : #ty :: Patch } - } - FieldAction::Patch { - member: FieldMember::Unnamed(_), - ty, - } => { - quote! { #ty :: Patch } + FieldAction::Keep { member, ty } => match member { + FieldMember::Named(name) => quote! { #name : #ty }, + FieldMember::Unnamed(_) => quote! { #ty }, + }, + FieldAction::Patch { member, ty } => { + let field = match member { + FieldMember::Named(name) => quote! { #name : <#ty as #patchable_trait>::Patch }, + FieldMember::Unnamed(_) => quote! { <#ty as #patchable_trait>::Patch }, + }; + if IS_SERDE_ENABLED { + let bound = quote! { <#ty as #patchable_trait>::Patch: ::serde::de::DeserializeOwned }; + let bound_string = bound.to_string(); + let bound_lit = syn::LitStr::new(&bound_string, Span::call_site()); + quote! { + #[serde(bound(deserialize = #bound_lit))] + #field + } + } else { + quote! { #field } + } } }) .collect() @@ -281,18 +268,24 @@ impl<'a> MacroContext<'a> { return quote! {}; } - let statements = self.field_actions.iter().map(|action| match action { - FieldAction::Keep { member, .. } => { - quote! { - self.#member = patch.#member; + let statements = self + .field_actions + .iter() + .enumerate() + .map(|(patch_index, action)| match action { + FieldAction::Keep { member, .. } => { + let patch_member = patch_member(member, patch_index); + quote! { + self.#member = patch.#patch_member; + } } - } - FieldAction::Patch { member, .. } => { - quote! { - self.#member.patch(patch.#member); + FieldAction::Patch { member, .. } => { + let patch_member = patch_member(member, patch_index); + quote! { + self.#member.patch(patch.#patch_member); + } } - } - }); + }); quote! { #(#statements)* @@ -309,46 +302,45 @@ impl<'a> MacroContext<'a> { ), }; - self.select_fields(quote! { #member: #expr }, quote! { #expr }, quote! {}) + match &self.fields { + Fields::Named(_) => quote! { #member: #expr }, + Fields::Unnamed(_) => quote! { #expr }, + Fields::Unit => quote! {}, + } }); let body = quote! { #(#field_expressions),* }; - let body_ref = &body; - - self.select_fields( - quote! { Self { #body_ref } }, - quote! { Self(#body_ref) }, - quote! { Self }, - ) - } - fn collect_patch_generics(&self) -> Vec> { - let mut generics = Vec::new(); - for param in self.generics.type_params() { - if self.preserved_types.contains_key(¶m.ident) { - generics.push(Cow::Borrowed(¶m.ident)); - } + match &self.fields { + Fields::Named(_) => quote! { Self { #body } }, + Fields::Unnamed(_) => quote! { Self(#body) }, + Fields::Unit => quote! { Self }, } - generics } - fn build_bounded_types(&self, bound: &TokenStream2) -> TokenStream2 { - let mut bounded_types = Vec::new(); - for param in self.generics.type_params() { - let t = ¶m.ident; - if let Some(TypeUsage::Patchable) = self.preserved_types.get(t) { - bounded_types.push(quote! { #t: #bound }); - } - } + fn iter_patchable_type_params(&self) -> impl Iterator + '_ { + self.generics.type_params().filter_map(|param| { + matches!( + self.preserved_types.get(¶m.ident), + Some(TypeUsage::Patchable) + ) + .then_some(¶m.ident) + }) + } - self.extend_where_clause(bounded_types) + fn iter_preserved_type_params(&self) -> impl Iterator + '_ { + self.generics.type_params().filter_map(|param| { + self.preserved_types + .contains_key(¶m.ident) + .then_some(¶m.ident) + }) } // ============================================================ // type Patch = MyPatch // ============================================================ - fn build_associate_type_declaration(&self) -> TokenStream2 { + fn build_associated_type_declaration(&self) -> TokenStream2 { let patch_type_generics = self.build_patch_type_generics(); let state_name = &self.patch_struct_name; quote! { @@ -356,70 +348,63 @@ impl<'a> MacroContext<'a> { } } - fn build_from_where_clause(&self) -> TokenStream2 { - let patchable_trait = &self.patchable_trait; - let mut bounded_types = Vec::new(); - for param in self.generics.type_params() { - let t = ¶m.ident; - if let Some(TypeUsage::Patchable) = self.preserved_types.get(t) { - bounded_types.push(quote! { #t: #patchable_trait }); - bounded_types.push(quote! { - <#t as #patchable_trait>::Patch: ::core::convert::From<#t> - }); - } - } - - self.extend_where_clause(bounded_types) - } - fn build_patch_type_generics(&self) -> TokenStream2 { - let patch_generic_params = self.collect_patch_generics(); + let patch_generic_params = self.iter_preserved_type_params(); // Empty `<>` is legal in Rust, and adding or dropping the `<>` doesn't affect the // definition. For example, `struct A<>(i32)` and `struct A(i32)` have the // same HIR. quote! { <#(#patch_generic_params),*> } } - fn extend_where_clause(&self, bounds: Vec) -> TokenStream2 { - if let Some(where_clause) = &self.generics.where_clause { - if bounds.is_empty() { - return quote! { #where_clause }; - } + // =========================================== + // Helper functions for building where clauses + // =========================================== - let normalized_input_where_clause = if where_clause.predicates.empty_or_trailing() { - quote! { #where_clause } - } else { - quote! { #where_clause, } - }; + fn build_where_clause_with_bound(&self, bound: &TokenStream2) -> TokenStream2 { + self.build_where_clause_for_patchable_types(|ty, patchable_trait| { quote! { - #normalized_input_where_clause - #(#bounds),* + #ty: #bound, + <#ty as #patchable_trait>::Patch: ::core::fmt::Debug, } - } else if !bounds.is_empty() { + }) + } + + fn build_where_clause_for_from_impl(&self) -> TokenStream2 { + self.build_where_clause_for_patchable_types(|ty, patchable_trait| { quote! { - where #(#bounds),* + #ty: #patchable_trait, + <#ty as #patchable_trait>::Patch: ::core::convert::From<#ty> + ::core::fmt::Debug, } - } else { - quote! {} - } + }) } - fn select_fields( - &self, - named: TokenStream2, - unnamed: TokenStream2, - unit: TokenStream2, - ) -> TokenStream2 { - match &self.fields { - Fields::Named(_) => named, - Fields::Unnamed(_) => unnamed, - Fields::Unit => unit, + fn build_where_clause_for_patchable_types(&self, mut build_bounds: F) -> TokenStream2 + where + F: FnMut(&Ident, &TokenStream2) -> TokenStream2, + { + let patchable_trait = &self.patchable_trait; + let bounded_types: Vec<_> = self + .iter_patchable_type_params() + .map(|ty| build_bounds(ty, patchable_trait)) + .collect(); + self.extend_where_clause(bounded_types) + } + + fn extend_where_clause(&self, bounds: Vec) -> TokenStream2 { + match (&self.generics.where_clause, bounds.is_empty()) { + (None, true) => quote! {}, + (None, false) => quote! { where #(#bounds),* }, + (Some(where_clause), true) => quote! { #where_clause }, + (Some(where_clause), false) => { + let sep = (!where_clause.predicates.trailing_punct()).then_some(quote! {,}); + quote! { #where_clause #sep #(#bounds),* } + } } } } enum FieldMember<'a> { - Named(Cow<'a, Ident>), + Named(&'a Ident), Unnamed(Index), } @@ -443,6 +428,16 @@ enum FieldAction<'a> { }, } +fn patch_member(member: &FieldMember<'_>, patch_index: usize) -> TokenStream2 { + match member { + FieldMember::Named(name) => quote! { #name }, + FieldMember::Unnamed(_) => { + let index = Index::from(patch_index); + quote! { #index } + } + } +} + pub fn use_site_crate_path() -> TokenStream2 { let found_crate = crate_name(PATCHABLE).expect("patchable library should be present in `Cargo.toml`"); diff --git a/patchable-macro/src/lib.rs b/patchable-macro/src/lib.rs index 8495a88..0ab551c 100644 --- a/patchable-macro/src/lib.rs +++ b/patchable-macro/src/lib.rs @@ -22,11 +22,11 @@ use proc_macro::TokenStream; use proc_macro2::TokenStream as TokenStream2; use quote::quote; -use syn::{self, DeriveInput}; +use syn::{Fields, ItemStruct, parse_macro_input, parse_quote}; mod context; -use syn::{Fields, ItemStruct, parse_macro_input, parse_quote}; +use syn::DeriveInput; use crate::context::{IS_SERDE_ENABLED, has_patchable_skip_attr, use_site_crate_path}; @@ -46,34 +46,19 @@ pub fn patchable_model(_attr: TokenStream, item: TokenStream) -> TokenStream { let mut input = parse_macro_input!(item as ItemStruct); let crate_root = use_site_crate_path(); - // Note: We use parse_quote! to easily generate Attribute types - if !IS_SERDE_ENABLED { - input.attrs.push(parse_quote! { - #[derive(#crate_root::Patchable, #crate_root::Patch)] - }); - } else { - input.attrs.push(parse_quote! { + let derives = if IS_SERDE_ENABLED { + parse_quote! { #[derive(#crate_root::Patchable, #crate_root::Patch, ::serde::Serialize)] - }); - - match input.fields { - Fields::Named(ref mut fields) => { - for field in &mut fields.named { - // Check if this field has the #[patchable(skip)] attribute - if has_patchable_skip_attr(field) { - field.attrs.push(parse_quote! { #[serde(skip)] }); - } - } - } - Fields::Unnamed(ref mut fields) => { - for field in &mut fields.unnamed { - if has_patchable_skip_attr(field) { - field.attrs.push(parse_quote! { #[serde(skip)] }); - } - } - } - Fields::Unit => {} } + } else { + parse_quote! { + #[derive(#crate_root::Patchable, #crate_root::Patch)] + } + }; + input.attrs.push(derives); + + if IS_SERDE_ENABLED { + add_serde_skip_attrs(&mut input.fields); } (quote! { #input }).into() @@ -143,9 +128,17 @@ fn derive_with(input: TokenStream, f: F) -> TokenStream where F: FnOnce(&context::MacroContext) -> TokenStream2, { - let input: DeriveInput = syn::parse_macro_input!(input as DeriveInput); + let input: DeriveInput = parse_macro_input!(input as DeriveInput); match context::MacroContext::new(&input) { Ok(ctx) => f(&ctx).into(), Err(e) => e.to_compile_error().into(), } } + +fn add_serde_skip_attrs(fields: &mut Fields) { + for field in fields.iter_mut() { + if has_patchable_skip_attr(field) { + field.attrs.push(parse_quote! { #[serde(skip)] }); + } + } +} diff --git a/patchable/tests/serde.rs b/patchable/tests/serde.rs index 4b67d6f..22236b2 100644 --- a/patchable/tests/serde.rs +++ b/patchable/tests/serde.rs @@ -3,6 +3,10 @@ use std::fmt::Debug; use patchable::{Patch, Patchable, TryPatch, patchable_model}; use serde::{Deserialize, Serialize}; +fn identity(x: &i32) -> i32 { + *x +} + #[patchable_model] #[derive(Clone, Default, Debug, PartialEq)] struct FakeMeasurement { @@ -25,10 +29,6 @@ struct ScopedMeasurement { #[test] fn test_scoped_peek() -> anyhow::Result<()> { - fn identity(x: &i32) -> i32 { - *x - } - let fake_measurement: FakeMeasurement i32> = FakeMeasurement { v: 42, how: identity, @@ -81,6 +81,37 @@ fn test_tuple_struct_patch() { assert_eq!(s, TupleStruct(10, 20)); } +#[patchable_model] +#[derive(Clone, Debug)] +struct TupleStructWithSkippedMiddle(i32, #[patchable(skip)] F, i64); + +#[test] +fn test_tuple_struct_skip_keeps_original_field_index() { + let mut s = TupleStructWithSkippedMiddle(1, identity, 2); + let patch: i32> as Patchable>::Patch = + serde_json::from_str(r#"[10, 20]"#).unwrap(); + s.patch(patch); + assert_eq!(s.0, 10); + assert_eq!(s.2, 20); +} + +#[patchable_model] +#[derive(Clone, Debug)] +struct TupleStructWithWhereClause(i32, T, i64) +where + T: From<(u32, u32)>; + +#[test] +fn test_tuple_struct_with_where_clause() { + let mut s = TupleStructWithWhereClause(1, (0, 0), 2); + let patch: as Patchable>::Patch = + serde_json::from_str(r#"[10, [42, 84], 20]"#).unwrap(); + s.patch(patch); + assert_eq!(s.0, 10); + assert_eq!(s.1, (42, 84)); + assert_eq!(s.2, 20); +} + #[patchable_model] #[derive(Clone, Debug, PartialEq, Eq)] struct UnitStruct;