Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ incremented for features.

* lang: Add fallback functions ([#457](https://github.com/project-serum/anchor/pull/457)).
* lang: Add feature flag for using the old state account discriminator. This is a temporary flag for those with programs built prior to v0.7.0 but want to use the latest Anchor version. Expect this to be removed in a future version ([#446](https://github.com/project-serum/anchor/pull/446)).
* lang: Add generic support to Accounts ([#496](https://github.com/project-serum/anchor/pull/496)).

### Breaking Changes

Expand Down
33 changes: 23 additions & 10 deletions lang/attribute/account/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ pub fn account(

let account_strct = parse_macro_input!(input as syn::ItemStruct);
let account_name = &account_strct.ident;
let (impl_gen, type_gen, where_clause) = account_strct.generics.split_for_impl();

let discriminator: proc_macro2::TokenStream = {
// Namespace the discriminator to prevent collisions.
Expand All @@ -103,20 +104,25 @@ pub fn account(
#[zero_copy]
#account_strct

unsafe impl anchor_lang::__private::bytemuck::Pod for #account_name {}
unsafe impl anchor_lang::__private::bytemuck::Zeroable for #account_name {}
#[automatically_derived]
unsafe impl #impl_gen anchor_lang::__private::bytemuck::Pod for #account_name #type_gen #where_clause {}
#[automatically_derived]
unsafe impl #impl_gen anchor_lang::__private::bytemuck::Zeroable for #account_name #type_gen #where_clause {}

impl anchor_lang::ZeroCopy for #account_name {}
#[automatically_derived]
impl #impl_gen anchor_lang::ZeroCopy for #account_name #type_gen #where_clause {}

impl anchor_lang::Discriminator for #account_name {
#[automatically_derived]
impl #impl_gen anchor_lang::Discriminator for #account_name #type_gen #where_clause {
fn discriminator() -> [u8; 8] {
#discriminator
}
}

// This trait is useful for clients deserializing accounts.
// It's expected on-chain programs deserialize via zero-copy.
impl anchor_lang::AccountDeserialize for #account_name {
#[automatically_derived]
impl #impl_gen anchor_lang::AccountDeserialize for #account_name #type_gen #where_clause {
fn try_deserialize(buf: &mut &[u8]) -> std::result::Result<Self, ProgramError> {
if buf.len() < #discriminator.len() {
return Err(anchor_lang::__private::ErrorCode::AccountDiscriminatorNotFound.into());
Expand All @@ -142,7 +148,8 @@ pub fn account(
#[derive(AnchorSerialize, AnchorDeserialize, Clone)]
#account_strct

impl anchor_lang::AccountSerialize for #account_name {
#[automatically_derived]
impl #impl_gen anchor_lang::AccountSerialize for #account_name #type_gen #where_clause {
fn try_serialize<W: std::io::Write>(&self, writer: &mut W) -> std::result::Result<(), ProgramError> {
writer.write_all(&#discriminator).map_err(|_| anchor_lang::__private::ErrorCode::AccountDidNotSerialize)?;
AnchorSerialize::serialize(
Expand All @@ -154,7 +161,8 @@ pub fn account(
}
}

impl anchor_lang::AccountDeserialize for #account_name {
#[automatically_derived]
impl #impl_gen anchor_lang::AccountDeserialize for #account_name #type_gen #where_clause {
fn try_deserialize(buf: &mut &[u8]) -> std::result::Result<Self, ProgramError> {
if buf.len() < #discriminator.len() {
return Err(anchor_lang::__private::ErrorCode::AccountDiscriminatorNotFound.into());
Expand All @@ -173,7 +181,8 @@ pub fn account(
}
}

impl anchor_lang::Discriminator for #account_name {
#[automatically_derived]
impl #impl_gen anchor_lang::Discriminator for #account_name #type_gen #where_clause {
fn discriminator() -> [u8; 8] {
#discriminator
}
Expand Down Expand Up @@ -206,6 +215,7 @@ pub fn associated(
) -> proc_macro::TokenStream {
let mut account_strct = parse_macro_input!(input as syn::ItemStruct);
let account_name = &account_strct.ident;
let (impl_gen, ty_gen, where_clause) = account_strct.generics.split_for_impl();

// Add a `__nonce: u8` field to the struct to hold the bump seed for
// the program dervied address.
Expand Down Expand Up @@ -245,7 +255,8 @@ pub fn associated(
#[anchor_lang::account(#args)]
#account_strct

impl anchor_lang::Bump for #account_name {
#[automatically_derived]
impl #impl_gen anchor_lang::Bump for #account_name #ty_gen #where_clause {
fn seed(&self) -> u8 {
self.__nonce
}
Expand All @@ -257,6 +268,7 @@ pub fn associated(
pub fn derive_zero_copy_accessor(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
let account_strct = parse_macro_input!(item as syn::ItemStruct);
let account_name = &account_strct.ident;
let (impl_gen, ty_gen, where_clause) = account_strct.generics.split_for_impl();

let fields = match &account_strct.fields {
syn::Fields::Named(n) => n,
Expand Down Expand Up @@ -300,7 +312,8 @@ pub fn derive_zero_copy_accessor(item: proc_macro::TokenStream) -> proc_macro::T
})
.collect();
proc_macro::TokenStream::from(quote! {
impl #account_name {
#[automatically_derived]
impl #impl_gen #account_name #ty_gen #where_clause {
#(#methods)*
}
})
Expand Down
1 change: 1 addition & 0 deletions lang/syn/src/codegen/accounts/__client_accounts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ pub fn generate(accs: &AccountsStruct) -> proc_macro2::TokenStream {
#(#account_struct_fields),*
}

#[automatically_derived]
impl anchor_lang::ToAccountMetas for #name {
fn to_account_metas(&self, is_signer: Option<bool>) -> Vec<anchor_lang::solana_program::instruction::AccountMeta> {
let mut account_metas = vec![];
Expand Down
10 changes: 5 additions & 5 deletions lang/syn/src/codegen/accounts/constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -355,24 +355,24 @@ pub fn generate_constraint_associated_init(
)
}

fn parse_ty(f: &Field) -> (&syn::Ident, proc_macro2::TokenStream, bool) {
fn parse_ty(f: &Field) -> (&syn::TypePath, proc_macro2::TokenStream, bool) {
match &f.ty {
Ty::ProgramAccount(ty) => (
&ty.account_ident,
&ty.account_type_path,
quote! {
anchor_lang::ProgramAccount
},
false,
),
Ty::Loader(ty) => (
&ty.account_ident,
&ty.account_type_path,
quote! {
anchor_lang::Loader
},
true,
),
Ty::CpiAccount(ty) => (
&ty.account_ident,
&ty.account_type_path,
quote! {
anchor_lang::CpiAccount
},
Expand Down Expand Up @@ -617,7 +617,7 @@ pub fn generate_constraint_state(f: &Field, c: &ConstraintState) -> proc_macro2:
let program_target = c.program_target.clone();
let ident = &f.ident;
let account_ty = match &f.ty {
Ty::CpiState(ty) => &ty.account_ident,
Ty::CpiState(ty) => &ty.account_type_path,
_ => panic!("Invalid state constraint"),
};
quote! {
Expand Down
12 changes: 9 additions & 3 deletions lang/syn/src/codegen/accounts/exit.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
use crate::codegen::accounts::generics;
use crate::codegen::accounts::{generics, ParsedGenerics};
use crate::{AccountField, AccountsStruct};
use quote::quote;

// Generates the `Exit` trait implementation.
pub fn generate(accs: &AccountsStruct) -> proc_macro2::TokenStream {
let name = &accs.ident;
let (combined_generics, trait_generics, strct_generics) = generics(accs);
let ParsedGenerics {
combined_generics,
trait_generics,
struct_generics,
where_clause,
} = generics(accs);

let on_save: Vec<proc_macro2::TokenStream> = accs
.fields
Expand Down Expand Up @@ -39,7 +44,8 @@ pub fn generate(accs: &AccountsStruct) -> proc_macro2::TokenStream {
})
.collect();
quote! {
impl#combined_generics anchor_lang::AccountsExit#trait_generics for #name#strct_generics {
#[automatically_derived]
impl<#combined_generics> anchor_lang::AccountsExit<#trait_generics> for #name<#struct_generics> #where_clause{
fn exit(&self, program_id: &anchor_lang::solana_program::pubkey::Pubkey) -> anchor_lang::solana_program::entrypoint::ProgramResult {
#(#on_save)*
Ok(())
Expand Down
82 changes: 69 additions & 13 deletions lang/syn/src/codegen/accounts/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
use crate::AccountsStruct;
use quote::quote;
use std::iter;
use syn::punctuated::Punctuated;
use syn::{ConstParam, LifetimeDef, Token, TypeParam};
use syn::{GenericParam, PredicateLifetime, WhereClause, WherePredicate};

mod __client_accounts;
mod constraints;
Expand All @@ -26,18 +30,70 @@ pub fn generate(accs: &AccountsStruct) -> proc_macro2::TokenStream {
}
}

fn generics(
accs: &AccountsStruct,
) -> (
proc_macro2::TokenStream,
proc_macro2::TokenStream,
proc_macro2::TokenStream,
) {
match accs.generics.lt_token {
None => (quote! {<'info>}, quote! {<'info>}, quote! {}),
Some(_) => {
let g = &accs.generics;
(quote! {#g}, quote! {#g}, quote! {#g})
}
fn generics(accs: &AccountsStruct) -> ParsedGenerics {
let trait_lifetime = accs
.generics
.lifetimes()
.next()
.cloned()
.unwrap_or_else(|| syn::parse_str("'info").expect("Could not parse lifetime"));

let mut where_clause = accs.generics.where_clause.clone().unwrap_or(WhereClause {
where_token: Default::default(),
predicates: Default::default(),
});
for lifetime in accs.generics.lifetimes().map(|def| &def.lifetime) {
where_clause
.predicates
.push(WherePredicate::Lifetime(PredicateLifetime {
lifetime: lifetime.clone(),
colon_token: Default::default(),
bounds: iter::once(trait_lifetime.lifetime.clone()).collect(),
}))
}
let trait_lifetime = GenericParam::Lifetime(trait_lifetime);

ParsedGenerics {
combined_generics: if accs.generics.lifetimes().next().is_some() {
accs.generics.params.clone()
} else {
iter::once(trait_lifetime.clone())
.chain(accs.generics.params.clone())
.collect()
},
trait_generics: iter::once(trait_lifetime).collect(),
struct_generics: accs
.generics
.params
.clone()
.into_iter()
.map(|param: GenericParam| match param {
GenericParam::Const(ConstParam { ident, .. })
| GenericParam::Type(TypeParam { ident, .. }) => GenericParam::Type(TypeParam {
attrs: vec![],
ident,
colon_token: None,
bounds: Default::default(),
eq_token: None,
default: None,
}),
GenericParam::Lifetime(LifetimeDef { lifetime, .. }) => {
GenericParam::Lifetime(LifetimeDef {
attrs: vec![],
lifetime,
colon_token: None,
bounds: Default::default(),
})
}
})
.collect(),
where_clause,
}
}

struct ParsedGenerics {
pub combined_generics: Punctuated<GenericParam, Token![,]>,
pub trait_generics: Punctuated<GenericParam, Token![,]>,
pub struct_generics: Punctuated<GenericParam, Token![,]>,
pub where_clause: WhereClause,
}
12 changes: 9 additions & 3 deletions lang/syn/src/codegen/accounts/to_account_infos.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
use crate::codegen::accounts::generics;
use crate::codegen::accounts::{generics, ParsedGenerics};
use crate::{AccountField, AccountsStruct};
use quote::quote;

// Generates the `ToAccountInfos` trait implementation.
pub fn generate(accs: &AccountsStruct) -> proc_macro2::TokenStream {
let name = &accs.ident;
let (combined_generics, trait_generics, strct_generics) = generics(accs);
let ParsedGenerics {
combined_generics,
trait_generics,
struct_generics,
where_clause,
} = generics(accs);

let to_acc_infos: Vec<proc_macro2::TokenStream> = accs
.fields
Expand All @@ -21,7 +26,8 @@ pub fn generate(accs: &AccountsStruct) -> proc_macro2::TokenStream {
})
.collect();
quote! {
impl#combined_generics anchor_lang::ToAccountInfos#trait_generics for #name#strct_generics {
#[automatically_derived]
impl<#combined_generics> anchor_lang::ToAccountInfos<#trait_generics> for #name <#struct_generics> #where_clause{
fn to_account_infos(&self) -> Vec<anchor_lang::solana_program::account_info::AccountInfo<'info>> {
let mut account_infos = vec![];

Expand Down
8 changes: 5 additions & 3 deletions lang/syn/src/codegen/accounts/to_account_metas.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
use crate::codegen::accounts::generics;
use crate::{AccountField, AccountsStruct};
use quote::quote;

// Generates the `ToAccountMetas` trait implementation.
pub fn generate(accs: &AccountsStruct) -> proc_macro2::TokenStream {
let name = &accs.ident;
let (combined_generics, _trait_generics, strct_generics) = generics(accs);

let to_acc_metas: Vec<proc_macro2::TokenStream> = accs
.fields
Expand All @@ -26,8 +24,12 @@ pub fn generate(accs: &AccountsStruct) -> proc_macro2::TokenStream {
}
})
.collect();

let (impl_gen, ty_gen, where_clause) = accs.generics.split_for_impl();

quote! {
impl#combined_generics anchor_lang::ToAccountMetas for #name#strct_generics {
#[automatically_derived]
impl#impl_gen anchor_lang::ToAccountMetas for #name #ty_gen #where_clause{
fn to_account_metas(&self, is_signer: Option<bool>) -> Vec<anchor_lang::solana_program::instruction::AccountMeta> {
let mut account_metas = vec![];

Expand Down
Loading