Skip to content

Commit 24bfc08

Browse files
authored
lang: Allow CPI return values (otter-sec#1598)
1 parent 069330a commit 24bfc08

23 files changed

Lines changed: 471 additions & 8 deletions

File tree

.github/workflows/tests.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,8 @@ jobs:
329329
path: tests/custom-coder
330330
- cmd: cd tests/validator-clone && yarn --frozen-lockfile && anchor test --skip-lint
331331
path: tests/validator-clone
332+
- cmd: cd tests/cpi-returns && anchor test --skip-lint
333+
path: tests/cpi-returns
332334
steps:
333335
- uses: actions/checkout@v2
334336
- uses: ./.github/actions/setup/

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ The minor version will be incremented upon a breaking change and the patch versi
1212

1313
### Features
1414

15+
* lang: Add return values to CPI client. ([#1598](https://github.com/project-serum/anchor/pull/1598)).
1516
* avm: New `avm update` command to update the Anchor CLI to the latest version ([#1670](https://github.com/project-serum/anchor/pull/1670)).
1617

1718
### Fixes

lang/syn/src/codegen/program/cpi.rs

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use crate::codegen::program::common::{generate_ix_variant, sighash, SIGHASH_GLOB
22
use crate::Program;
33
use crate::StateIx;
44
use heck::SnakeCase;
5-
use quote::quote;
5+
use quote::{quote, ToTokens};
66

77
pub fn generate(program: &Program) -> proc_macro2::TokenStream {
88
// Generate cpi methods for the state struct.
@@ -70,11 +70,20 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream {
7070
let sighash_arr = sighash(SIGHASH_GLOBAL_NAMESPACE, name);
7171
let sighash_tts: proc_macro2::TokenStream =
7272
format!("{:?}", sighash_arr).parse().unwrap();
73+
let ret_type = &ix.returns.ty.to_token_stream();
74+
let (method_ret, maybe_return) = match ret_type.to_string().as_str() {
75+
"()" => (quote! {anchor_lang::Result<()> }, quote! { Ok(()) }),
76+
_ => (
77+
quote! { anchor_lang::Result<crate::cpi::Return::<#ret_type>> },
78+
quote! { Ok(crate::cpi::Return::<#ret_type> { phantom: crate::cpi::PhantomData }) }
79+
)
80+
};
81+
7382
quote! {
7483
pub fn #method_name<'a, 'b, 'c, 'info>(
7584
ctx: anchor_lang::context::CpiContext<'a, 'b, 'c, 'info, #accounts_ident<'info>>,
7685
#(#args),*
77-
) -> anchor_lang::Result<()> {
86+
) -> #method_ret {
7887
let ix = {
7988
let ix = instruction::#ix_variant;
8089
let mut ix_data = AnchorSerialize::try_to_vec(&ix)
@@ -93,7 +102,11 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream {
93102
&ix,
94103
&acc_infos,
95104
ctx.signer_seeds,
96-
).map_err(Into::into)
105+
).map_or_else(
106+
|e| Err(Into::into(e)),
107+
// Maybe handle Solana return data.
108+
|_| { #maybe_return }
109+
)
97110
}
98111
}
99112
};
@@ -108,13 +121,25 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream {
108121
#[cfg(feature = "cpi")]
109122
pub mod cpi {
110123
use super::*;
124+
use std::marker::PhantomData;
111125

112126
pub mod state {
113127
use super::*;
114128

115129
#(#state_cpi_methods)*
116130
}
117131

132+
pub struct Return<T> {
133+
phantom: std::marker::PhantomData<T>
134+
}
135+
136+
impl<T: AnchorDeserialize> Return<T> {
137+
pub fn get(&self) -> T {
138+
let (_key, data) = anchor_lang::solana_program::program::get_return_data().unwrap();
139+
T::try_from_slice(&data).unwrap()
140+
}
141+
}
142+
118143
#(#global_cpi_methods)*
119144

120145
#accounts

lang/syn/src/codegen/program/handlers.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use crate::codegen::program::common::*;
22
use crate::{Program, State};
33
use heck::CamelCase;
4-
use quote::quote;
4+
use quote::{quote, ToTokens};
55

66
// Generate non-inlined wrappers for each instruction handler, since Solana's
77
// BPF max stack size can't handle reasonable sized dispatch trees without doing
@@ -694,6 +694,13 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream {
694694
let anchor = &ix.anchor_ident;
695695
let variant_arm = generate_ix_variant(ix.raw_method.sig.ident.to_string(), &ix.args);
696696
let ix_name_log = format!("Instruction: {}", ix_name);
697+
let ret_type = &ix.returns.ty.to_token_stream();
698+
let maybe_set_return_data = match ret_type.to_string().as_str() {
699+
"()" => quote! {},
700+
_ => quote! {
701+
anchor_lang::solana_program::program::set_return_data(&result.try_to_vec().unwrap());
702+
},
703+
};
697704
quote! {
698705
#[inline(never)]
699706
pub fn #ix_method_name(
@@ -722,7 +729,7 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream {
722729
)?;
723730

724731
// Invoke user defined handler.
725-
#program_name::#ix_method_name(
732+
let result = #program_name::#ix_method_name(
726733
anchor_lang::context::Context::new(
727734
program_id,
728735
&mut accounts,
@@ -732,6 +739,9 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream {
732739
#(#ix_arg_names),*
733740
)?;
734741

742+
// Maybe set Solana return data.
743+
#maybe_set_return_data
744+
735745
// Exit routine.
736746
accounts.exit(program_id)
737747
}

lang/syn/src/idl/file.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ pub fn parse(
6666
name,
6767
accounts,
6868
args,
69+
returns: None,
6970
}
7071
})
7172
.collect::<Vec<_>>()
@@ -105,6 +106,7 @@ pub fn parse(
105106
name,
106107
accounts,
107108
args,
109+
returns: None,
108110
}
109111
};
110112

@@ -164,10 +166,16 @@ pub fn parse(
164166
// todo: don't unwrap
165167
let accounts_strct = accs.get(&ix.anchor_ident.to_string()).unwrap();
166168
let accounts = idl_accounts(&ctx, accounts_strct, &accs, seeds_feature);
169+
let ret_type_str = ix.returns.ty.to_token_stream().to_string();
170+
let returns = match ret_type_str.as_str() {
171+
"()" => None,
172+
_ => Some(ret_type_str.parse().unwrap()),
173+
};
167174
IdlInstruction {
168175
name: ix.ident.to_string().to_mixed_case(),
169176
accounts,
170177
args,
178+
returns,
171179
}
172180
})
173181
.collect::<Vec<_>>();

lang/syn/src/idl/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ pub struct IdlInstruction {
4545
pub name: String,
4646
pub accounts: Vec<IdlAccountItem>,
4747
pub args: Vec<IdlField>,
48+
pub returns: Option<IdlType>,
4849
}
4950

5051
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]

lang/syn/src/lib.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use syn::spanned::Spanned;
1414
use syn::token::Comma;
1515
use syn::{
1616
Expr, Generics, Ident, ImplItemMethod, ItemEnum, ItemFn, ItemImpl, ItemMod, ItemStruct, LitInt,
17-
LitStr, PatType, Token, TypePath,
17+
LitStr, PatType, Token, Type, TypePath,
1818
};
1919

2020
pub mod codegen;
@@ -85,6 +85,7 @@ pub struct Ix {
8585
pub raw_method: ItemFn,
8686
pub ident: Ident,
8787
pub args: Vec<IxArg>,
88+
pub returns: IxReturn,
8889
// The ident for the struct deriving Accounts.
8990
pub anchor_ident: Ident,
9091
}
@@ -95,6 +96,11 @@ pub struct IxArg {
9596
pub raw_arg: PatType,
9697
}
9798

99+
#[derive(Debug)]
100+
pub struct IxReturn {
101+
pub ty: Type,
102+
}
103+
98104
#[derive(Debug)]
99105
pub struct FallbackFn {
100106
raw_method: ItemFn,

lang/syn/src/parser/program/instructions.rs

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use crate::parser::program::ctx_accounts_ident;
2-
use crate::{FallbackFn, Ix, IxArg};
2+
use crate::{FallbackFn, Ix, IxArg, IxReturn};
33
use syn::parse::{Error as ParseError, Result as ParseResult};
44
use syn::spanned::Spanned;
55

@@ -23,12 +23,14 @@ pub fn parse(program_mod: &syn::ItemMod) -> ParseResult<(Vec<Ix>, Option<Fallbac
2323
})
2424
.map(|method: &syn::ItemFn| {
2525
let (ctx, args) = parse_args(method)?;
26+
let returns = parse_return(method)?;
2627
let anchor_ident = ctx_accounts_ident(&ctx.raw_arg)?;
2728
Ok(Ix {
2829
raw_method: method.clone(),
2930
ident: method.sig.ident.clone(),
3031
args,
3132
anchor_ident,
33+
returns,
3234
})
3335
})
3436
.collect::<ParseResult<Vec<Ix>>>()?;
@@ -91,3 +93,34 @@ pub fn parse_args(method: &syn::ItemFn) -> ParseResult<(IxArg, Vec<IxArg>)> {
9193

9294
Ok((ctx, args))
9395
}
96+
97+
pub fn parse_return(method: &syn::ItemFn) -> ParseResult<IxReturn> {
98+
match method.sig.output {
99+
syn::ReturnType::Type(_, ref ty) => {
100+
let ty = match ty.as_ref() {
101+
syn::Type::Path(ty) => ty,
102+
_ => return Err(ParseError::new(ty.span(), "expected a return type")),
103+
};
104+
// Assume unit return by default
105+
let default_generic_arg = syn::GenericArgument::Type(syn::parse_str("()").unwrap());
106+
let generic_args = match &ty.path.segments.last().unwrap().arguments {
107+
syn::PathArguments::AngleBracketed(params) => params.args.iter().last().unwrap(),
108+
_ => &default_generic_arg,
109+
};
110+
let ty = match generic_args {
111+
syn::GenericArgument::Type(ty) => ty.clone(),
112+
_ => {
113+
return Err(ParseError::new(
114+
ty.span(),
115+
"expected generic return type to be a type",
116+
))
117+
}
118+
};
119+
Ok(IxReturn { ty })
120+
}
121+
_ => Err(ParseError::new(
122+
method.sig.output.span(),
123+
"expected a return type",
124+
)),
125+
}
126+
}

tests/cpi-returns/.gitignore

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
2+
.anchor
3+
.DS_Store
4+
target
5+
**/*.rs.bk
6+
node_modules

tests/cpi-returns/Anchor.toml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
[features]
2+
seeds = false
3+
4+
[programs.localnet]
5+
callee = "Fg6PaFpoGXkYsidMpWTK6W2BeZ7FEfcYkg476zPFsLnS"
6+
caller = "HmbTLCmaGvZhKnn1Zfa1JVnp7vkMV4DYVxPLWBVoN65L"
7+
8+
[registry]
9+
url = "https://anchor.projectserum.com"
10+
11+
[provider]
12+
cluster = "localnet"
13+
wallet = "~/.config/solana/id.json"
14+
15+
[scripts]
16+
test = "yarn run ts-mocha -p ./tsconfig.json -t 1000000 tests/**/*.ts"

0 commit comments

Comments
 (0)