diff --git a/CHANGELOG.md b/CHANGELOG.md index 914d89b84e..5d5990c460 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ incremented for features. * lang: Add fallback functions ([#457](https://github.com/project-serum/anchor/pull/457)). * lang: Add feature flag for using the old state account discriminator. This is a temporary flag for those with programs built prior to v0.7.0 but want to use the latest Anchor version. Expect this to be removed in a future version ([#446](https://github.com/project-serum/anchor/pull/446)). +* lang: Add generic support to Accounts ([#496](https://github.com/project-serum/anchor/pull/496)). ### Breaking Changes diff --git a/lang/attribute/account/src/lib.rs b/lang/attribute/account/src/lib.rs index 46be1fa282..7a6cbbe8ce 100644 --- a/lang/attribute/account/src/lib.rs +++ b/lang/attribute/account/src/lib.rs @@ -78,6 +78,7 @@ pub fn account( let account_strct = parse_macro_input!(input as syn::ItemStruct); let account_name = &account_strct.ident; + let (impl_gen, type_gen, where_clause) = account_strct.generics.split_for_impl(); let discriminator: proc_macro2::TokenStream = { // Namespace the discriminator to prevent collisions. @@ -103,12 +104,16 @@ pub fn account( #[zero_copy] #account_strct - unsafe impl anchor_lang::__private::bytemuck::Pod for #account_name {} - unsafe impl anchor_lang::__private::bytemuck::Zeroable for #account_name {} + #[automatically_derived] + unsafe impl #impl_gen anchor_lang::__private::bytemuck::Pod for #account_name #type_gen #where_clause {} + #[automatically_derived] + unsafe impl #impl_gen anchor_lang::__private::bytemuck::Zeroable for #account_name #type_gen #where_clause {} - impl anchor_lang::ZeroCopy for #account_name {} + #[automatically_derived] + impl #impl_gen anchor_lang::ZeroCopy for #account_name #type_gen #where_clause {} - impl anchor_lang::Discriminator for #account_name { + #[automatically_derived] + impl #impl_gen anchor_lang::Discriminator for #account_name #type_gen #where_clause { fn discriminator() -> [u8; 8] { #discriminator } @@ -116,7 +121,8 @@ pub fn account( // This trait is useful for clients deserializing accounts. // It's expected on-chain programs deserialize via zero-copy. - impl anchor_lang::AccountDeserialize for #account_name { + #[automatically_derived] + impl #impl_gen anchor_lang::AccountDeserialize for #account_name #type_gen #where_clause { fn try_deserialize(buf: &mut &[u8]) -> std::result::Result { if buf.len() < #discriminator.len() { return Err(anchor_lang::__private::ErrorCode::AccountDiscriminatorNotFound.into()); @@ -142,7 +148,8 @@ pub fn account( #[derive(AnchorSerialize, AnchorDeserialize, Clone)] #account_strct - impl anchor_lang::AccountSerialize for #account_name { + #[automatically_derived] + impl #impl_gen anchor_lang::AccountSerialize for #account_name #type_gen #where_clause { fn try_serialize(&self, writer: &mut W) -> std::result::Result<(), ProgramError> { writer.write_all(&#discriminator).map_err(|_| anchor_lang::__private::ErrorCode::AccountDidNotSerialize)?; AnchorSerialize::serialize( @@ -154,7 +161,8 @@ pub fn account( } } - impl anchor_lang::AccountDeserialize for #account_name { + #[automatically_derived] + impl #impl_gen anchor_lang::AccountDeserialize for #account_name #type_gen #where_clause { fn try_deserialize(buf: &mut &[u8]) -> std::result::Result { if buf.len() < #discriminator.len() { return Err(anchor_lang::__private::ErrorCode::AccountDiscriminatorNotFound.into()); @@ -173,7 +181,8 @@ pub fn account( } } - impl anchor_lang::Discriminator for #account_name { + #[automatically_derived] + impl #impl_gen anchor_lang::Discriminator for #account_name #type_gen #where_clause { fn discriminator() -> [u8; 8] { #discriminator } @@ -206,6 +215,7 @@ pub fn associated( ) -> proc_macro::TokenStream { let mut account_strct = parse_macro_input!(input as syn::ItemStruct); let account_name = &account_strct.ident; + let (impl_gen, ty_gen, where_clause) = account_strct.generics.split_for_impl(); // Add a `__nonce: u8` field to the struct to hold the bump seed for // the program dervied address. @@ -245,7 +255,8 @@ pub fn associated( #[anchor_lang::account(#args)] #account_strct - impl anchor_lang::Bump for #account_name { + #[automatically_derived] + impl #impl_gen anchor_lang::Bump for #account_name #ty_gen #where_clause { fn seed(&self) -> u8 { self.__nonce } @@ -257,6 +268,7 @@ pub fn associated( pub fn derive_zero_copy_accessor(item: proc_macro::TokenStream) -> proc_macro::TokenStream { let account_strct = parse_macro_input!(item as syn::ItemStruct); let account_name = &account_strct.ident; + let (impl_gen, ty_gen, where_clause) = account_strct.generics.split_for_impl(); let fields = match &account_strct.fields { syn::Fields::Named(n) => n, @@ -300,7 +312,8 @@ pub fn derive_zero_copy_accessor(item: proc_macro::TokenStream) -> proc_macro::T }) .collect(); proc_macro::TokenStream::from(quote! { - impl #account_name { + #[automatically_derived] + impl #impl_gen #account_name #ty_gen #where_clause { #(#methods)* } }) diff --git a/lang/syn/src/codegen/accounts/__client_accounts.rs b/lang/syn/src/codegen/accounts/__client_accounts.rs index 0c5841bc5b..a661288bb8 100644 --- a/lang/syn/src/codegen/accounts/__client_accounts.rs +++ b/lang/syn/src/codegen/accounts/__client_accounts.rs @@ -115,6 +115,7 @@ pub fn generate(accs: &AccountsStruct) -> proc_macro2::TokenStream { #(#account_struct_fields),* } + #[automatically_derived] impl anchor_lang::ToAccountMetas for #name { fn to_account_metas(&self, is_signer: Option) -> Vec { let mut account_metas = vec![]; diff --git a/lang/syn/src/codegen/accounts/constraints.rs b/lang/syn/src/codegen/accounts/constraints.rs index b9db65ddec..0184d75f3b 100644 --- a/lang/syn/src/codegen/accounts/constraints.rs +++ b/lang/syn/src/codegen/accounts/constraints.rs @@ -355,24 +355,24 @@ pub fn generate_constraint_associated_init( ) } -fn parse_ty(f: &Field) -> (&syn::Ident, proc_macro2::TokenStream, bool) { +fn parse_ty(f: &Field) -> (&syn::TypePath, proc_macro2::TokenStream, bool) { match &f.ty { Ty::ProgramAccount(ty) => ( - &ty.account_ident, + &ty.account_type_path, quote! { anchor_lang::ProgramAccount }, false, ), Ty::Loader(ty) => ( - &ty.account_ident, + &ty.account_type_path, quote! { anchor_lang::Loader }, true, ), Ty::CpiAccount(ty) => ( - &ty.account_ident, + &ty.account_type_path, quote! { anchor_lang::CpiAccount }, @@ -617,7 +617,7 @@ pub fn generate_constraint_state(f: &Field, c: &ConstraintState) -> proc_macro2: let program_target = c.program_target.clone(); let ident = &f.ident; let account_ty = match &f.ty { - Ty::CpiState(ty) => &ty.account_ident, + Ty::CpiState(ty) => &ty.account_type_path, _ => panic!("Invalid state constraint"), }; quote! { diff --git a/lang/syn/src/codegen/accounts/exit.rs b/lang/syn/src/codegen/accounts/exit.rs index 94b1c3dc66..60b024b124 100644 --- a/lang/syn/src/codegen/accounts/exit.rs +++ b/lang/syn/src/codegen/accounts/exit.rs @@ -1,11 +1,16 @@ -use crate::codegen::accounts::generics; +use crate::codegen::accounts::{generics, ParsedGenerics}; use crate::{AccountField, AccountsStruct}; use quote::quote; // Generates the `Exit` trait implementation. pub fn generate(accs: &AccountsStruct) -> proc_macro2::TokenStream { let name = &accs.ident; - let (combined_generics, trait_generics, strct_generics) = generics(accs); + let ParsedGenerics { + combined_generics, + trait_generics, + struct_generics, + where_clause, + } = generics(accs); let on_save: Vec = accs .fields @@ -39,7 +44,8 @@ pub fn generate(accs: &AccountsStruct) -> proc_macro2::TokenStream { }) .collect(); quote! { - impl#combined_generics anchor_lang::AccountsExit#trait_generics for #name#strct_generics { + #[automatically_derived] + impl<#combined_generics> anchor_lang::AccountsExit<#trait_generics> for #name<#struct_generics> #where_clause{ fn exit(&self, program_id: &anchor_lang::solana_program::pubkey::Pubkey) -> anchor_lang::solana_program::entrypoint::ProgramResult { #(#on_save)* Ok(()) diff --git a/lang/syn/src/codegen/accounts/mod.rs b/lang/syn/src/codegen/accounts/mod.rs index 05ff4204e8..7c71dd99cd 100644 --- a/lang/syn/src/codegen/accounts/mod.rs +++ b/lang/syn/src/codegen/accounts/mod.rs @@ -1,5 +1,9 @@ use crate::AccountsStruct; use quote::quote; +use std::iter; +use syn::punctuated::Punctuated; +use syn::{ConstParam, LifetimeDef, Token, TypeParam}; +use syn::{GenericParam, PredicateLifetime, WhereClause, WherePredicate}; mod __client_accounts; mod constraints; @@ -26,18 +30,70 @@ pub fn generate(accs: &AccountsStruct) -> proc_macro2::TokenStream { } } -fn generics( - accs: &AccountsStruct, -) -> ( - proc_macro2::TokenStream, - proc_macro2::TokenStream, - proc_macro2::TokenStream, -) { - match accs.generics.lt_token { - None => (quote! {<'info>}, quote! {<'info>}, quote! {}), - Some(_) => { - let g = &accs.generics; - (quote! {#g}, quote! {#g}, quote! {#g}) - } +fn generics(accs: &AccountsStruct) -> ParsedGenerics { + let trait_lifetime = accs + .generics + .lifetimes() + .next() + .cloned() + .unwrap_or_else(|| syn::parse_str("'info").expect("Could not parse lifetime")); + + let mut where_clause = accs.generics.where_clause.clone().unwrap_or(WhereClause { + where_token: Default::default(), + predicates: Default::default(), + }); + for lifetime in accs.generics.lifetimes().map(|def| &def.lifetime) { + where_clause + .predicates + .push(WherePredicate::Lifetime(PredicateLifetime { + lifetime: lifetime.clone(), + colon_token: Default::default(), + bounds: iter::once(trait_lifetime.lifetime.clone()).collect(), + })) + } + let trait_lifetime = GenericParam::Lifetime(trait_lifetime); + + ParsedGenerics { + combined_generics: if accs.generics.lifetimes().next().is_some() { + accs.generics.params.clone() + } else { + iter::once(trait_lifetime.clone()) + .chain(accs.generics.params.clone()) + .collect() + }, + trait_generics: iter::once(trait_lifetime).collect(), + struct_generics: accs + .generics + .params + .clone() + .into_iter() + .map(|param: GenericParam| match param { + GenericParam::Const(ConstParam { ident, .. }) + | GenericParam::Type(TypeParam { ident, .. }) => GenericParam::Type(TypeParam { + attrs: vec![], + ident, + colon_token: None, + bounds: Default::default(), + eq_token: None, + default: None, + }), + GenericParam::Lifetime(LifetimeDef { lifetime, .. }) => { + GenericParam::Lifetime(LifetimeDef { + attrs: vec![], + lifetime, + colon_token: None, + bounds: Default::default(), + }) + } + }) + .collect(), + where_clause, } } + +struct ParsedGenerics { + pub combined_generics: Punctuated, + pub trait_generics: Punctuated, + pub struct_generics: Punctuated, + pub where_clause: WhereClause, +} diff --git a/lang/syn/src/codegen/accounts/to_account_infos.rs b/lang/syn/src/codegen/accounts/to_account_infos.rs index 8a689d72b5..6ba143115a 100644 --- a/lang/syn/src/codegen/accounts/to_account_infos.rs +++ b/lang/syn/src/codegen/accounts/to_account_infos.rs @@ -1,11 +1,16 @@ -use crate::codegen::accounts::generics; +use crate::codegen::accounts::{generics, ParsedGenerics}; use crate::{AccountField, AccountsStruct}; use quote::quote; // Generates the `ToAccountInfos` trait implementation. pub fn generate(accs: &AccountsStruct) -> proc_macro2::TokenStream { let name = &accs.ident; - let (combined_generics, trait_generics, strct_generics) = generics(accs); + let ParsedGenerics { + combined_generics, + trait_generics, + struct_generics, + where_clause, + } = generics(accs); let to_acc_infos: Vec = accs .fields @@ -21,7 +26,8 @@ pub fn generate(accs: &AccountsStruct) -> proc_macro2::TokenStream { }) .collect(); quote! { - impl#combined_generics anchor_lang::ToAccountInfos#trait_generics for #name#strct_generics { + #[automatically_derived] + impl<#combined_generics> anchor_lang::ToAccountInfos<#trait_generics> for #name <#struct_generics> #where_clause{ fn to_account_infos(&self) -> Vec> { let mut account_infos = vec![]; diff --git a/lang/syn/src/codegen/accounts/to_account_metas.rs b/lang/syn/src/codegen/accounts/to_account_metas.rs index 7a72c38ace..d7e1b01ee4 100644 --- a/lang/syn/src/codegen/accounts/to_account_metas.rs +++ b/lang/syn/src/codegen/accounts/to_account_metas.rs @@ -1,11 +1,9 @@ -use crate::codegen::accounts::generics; use crate::{AccountField, AccountsStruct}; use quote::quote; // Generates the `ToAccountMetas` trait implementation. pub fn generate(accs: &AccountsStruct) -> proc_macro2::TokenStream { let name = &accs.ident; - let (combined_generics, _trait_generics, strct_generics) = generics(accs); let to_acc_metas: Vec = accs .fields @@ -26,8 +24,12 @@ pub fn generate(accs: &AccountsStruct) -> proc_macro2::TokenStream { } }) .collect(); + + let (impl_gen, ty_gen, where_clause) = accs.generics.split_for_impl(); + quote! { - impl#combined_generics anchor_lang::ToAccountMetas for #name#strct_generics { + #[automatically_derived] + impl#impl_gen anchor_lang::ToAccountMetas for #name #ty_gen #where_clause{ fn to_account_metas(&self, is_signer: Option) -> Vec { let mut account_metas = vec![]; diff --git a/lang/syn/src/codegen/accounts/try_accounts.rs b/lang/syn/src/codegen/accounts/try_accounts.rs index 31f311eb2d..19cb2d0bdb 100644 --- a/lang/syn/src/codegen/accounts/try_accounts.rs +++ b/lang/syn/src/codegen/accounts/try_accounts.rs @@ -1,4 +1,4 @@ -use crate::codegen::accounts::{constraints, generics}; +use crate::codegen::accounts::{constraints, generics, ParsedGenerics}; use crate::{AccountField, AccountsStruct, Field, SysvarTy, Ty}; use proc_macro2::TokenStream; use quote::quote; @@ -7,7 +7,12 @@ use syn::Expr; // Generates the `Accounts` trait implementation. pub fn generate(accs: &AccountsStruct) -> proc_macro2::TokenStream { let name = &accs.ident; - let (combined_generics, trait_generics, strct_generics) = generics(accs); + let ParsedGenerics { + combined_generics, + trait_generics, + struct_generics, + where_clause, + } = generics(accs); // Deserialization for each field let deser_fields: Vec = accs @@ -88,7 +93,8 @@ pub fn generate(accs: &AccountsStruct) -> proc_macro2::TokenStream { }; quote! { - impl#combined_generics anchor_lang::Accounts#trait_generics for #name#strct_generics { + #[automatically_derived] + impl<#combined_generics> anchor_lang::Accounts<#trait_generics> for #name<#struct_generics> #where_clause { #[inline(never)] fn try_accounts( program_id: &anchor_lang::solana_program::pubkey::Pubkey, @@ -133,31 +139,31 @@ fn typed_ident(field: &Field) -> TokenStream { let ty = match &field.ty { Ty::AccountInfo => quote! { AccountInfo }, Ty::ProgramState(ty) => { - let account = &ty.account_ident; + let account = &ty.account_type_path; quote! { ProgramState<#account> } } Ty::CpiState(ty) => { - let account = &ty.account_ident; + let account = &ty.account_type_path; quote! { CpiState<#account> } } Ty::ProgramAccount(ty) => { - let account = &ty.account_ident; + let account = &ty.account_type_path; quote! { ProgramAccount<#account> } } Ty::Loader(ty) => { - let account = &ty.account_ident; + let account = &ty.account_type_path; quote! { Loader<#account> } } Ty::CpiAccount(ty) => { - let account = &ty.account_ident; + let account = &ty.account_type_path; quote! { CpiAccount<#account> } diff --git a/lang/syn/src/lib.rs b/lang/syn/src/lib.rs index 6c9ed9c3da..5c59cb89f0 100644 --- a/lang/syn/src/lib.rs +++ b/lang/syn/src/lib.rs @@ -12,7 +12,7 @@ use syn::spanned::Spanned; use syn::token::Comma; use syn::{ Expr, Generics, Ident, ImplItemMethod, ItemEnum, ItemFn, ItemImpl, ItemMod, ItemStruct, LitInt, - LitStr, PatType, Token, + LitStr, PatType, Token, TypePath, }; pub mod codegen; @@ -198,30 +198,30 @@ pub enum SysvarTy { #[derive(Debug, PartialEq)] pub struct ProgramStateTy { - pub account_ident: Ident, + pub account_type_path: TypePath, } #[derive(Debug, PartialEq)] pub struct CpiStateTy { - pub account_ident: Ident, + pub account_type_path: TypePath, } #[derive(Debug, PartialEq)] pub struct ProgramAccountTy { // The struct type of the account. - pub account_ident: Ident, + pub account_type_path: TypePath, } #[derive(Debug, PartialEq)] pub struct CpiAccountTy { // The struct type of the account. - pub account_ident: Ident, + pub account_type_path: TypePath, } #[derive(Debug, PartialEq)] pub struct LoaderTy { // The struct type of the account. - pub account_ident: Ident, + pub account_type_path: TypePath, } #[derive(Debug)] diff --git a/lang/syn/src/parser/accounts/mod.rs b/lang/syn/src/parser/accounts/mod.rs index 89191a4bc9..2fa5dfb100 100644 --- a/lang/syn/src/parser/accounts/mod.rs +++ b/lang/syn/src/parser/accounts/mod.rs @@ -118,30 +118,40 @@ fn ident_string(f: &syn::Field) -> ParseResult { fn parse_program_state(path: &syn::Path) -> ParseResult { let account_ident = parse_account(path)?; - Ok(ProgramStateTy { account_ident }) + Ok(ProgramStateTy { + account_type_path: account_ident, + }) } fn parse_cpi_state(path: &syn::Path) -> ParseResult { let account_ident = parse_account(path)?; - Ok(CpiStateTy { account_ident }) + Ok(CpiStateTy { + account_type_path: account_ident, + }) } fn parse_cpi_account(path: &syn::Path) -> ParseResult { let account_ident = parse_account(path)?; - Ok(CpiAccountTy { account_ident }) + Ok(CpiAccountTy { + account_type_path: account_ident, + }) } fn parse_program_account(path: &syn::Path) -> ParseResult { let account_ident = parse_account(path)?; - Ok(ProgramAccountTy { account_ident }) + Ok(ProgramAccountTy { + account_type_path: account_ident, + }) } fn parse_program_account_zero_copy(path: &syn::Path) -> ParseResult { let account_ident = parse_account(path)?; - Ok(LoaderTy { account_ident }) + Ok(LoaderTy { + account_type_path: account_ident, + }) } -fn parse_account(path: &syn::Path) -> ParseResult { +fn parse_account(path: &syn::Path) -> ParseResult { let segments = &path.segments[0]; match &segments.arguments { syn::PathArguments::AngleBracketed(args) => { @@ -153,18 +163,7 @@ fn parse_account(path: &syn::Path) -> ParseResult { )); } match &args.args[1] { - syn::GenericArgument::Type(syn::Type::Path(ty_path)) => { - // TODO: allow segmented paths. - if ty_path.path.segments.len() != 1 { - return Err(ParseError::new( - ty_path.path.span(), - "segmented paths are not currently allowed", - )); - } - - let path_segment = &ty_path.path.segments[0]; - Ok(path_segment.ident.clone()) - } + syn::GenericArgument::Type(syn::Type::Path(ty_path)) => Ok(ty_path.clone()), _ => Err(ParseError::new( args.args[1].span(), "first bracket argument must be a lifetime", diff --git a/lang/tests/generics_test.rs b/lang/tests/generics_test.rs new file mode 100644 index 0000000000..4e71e57f91 --- /dev/null +++ b/lang/tests/generics_test.rs @@ -0,0 +1,44 @@ +#![allow(dead_code)] + +use anchor_lang::prelude::borsh::maybestd::io::Write; +use anchor_lang::prelude::*; +use borsh::{BorshDeserialize, BorshSerialize}; + +#[derive(Accounts)] +pub struct GenericsTest<'info, T, U, const N: usize> +where + T: AccountSerialize + AccountDeserialize + Clone, + U: BorshSerialize + BorshDeserialize + Default + Clone, +{ + pub non_generic: AccountInfo<'info>, + pub generic: ProgramAccount<'info, T>, + pub const_generic: Loader<'info, Account>, + pub associated: CpiAccount<'info, Associated>, +} + +#[account(zero_copy)] +pub struct Account { + pub data: WrappedU8Array, +} + +#[associated] +#[derive(Default)] +pub struct Associated +where + T: BorshDeserialize + BorshSerialize + Default, +{ + pub data: T, +} + +#[derive(Copy, Clone)] +pub struct WrappedU8Array(u8); +impl BorshSerialize for WrappedU8Array { + fn serialize(&self, _writer: &mut W) -> borsh::maybestd::io::Result<()> { + todo!() + } +} +impl BorshDeserialize for WrappedU8Array { + fn deserialize(_buf: &mut &[u8]) -> borsh::maybestd::io::Result { + todo!() + } +}