Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/reusable-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,8 @@ jobs:
path: tests/idl
- cmd: cd tests/lazy-account && anchor test
path: tests/lazy-account
- cmd: cd tests/test-instruction-validation && ./test.sh
path: tests/test-instruction-validation
steps:
- uses: actions/checkout@v3
- uses: ./.github/actions/setup/
Expand Down
7 changes: 7 additions & 0 deletions lang/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,13 @@ pub mod __private {
pub use crate::lazy::Lazy;
#[cfg(feature = "lazy-account")]
pub use anchor_derive_serde::Lazy;

/// Trait for compile-time type equality checking.
/// Used to enforce that instruction argument types match the `#[instruction(...)]` attribute types.
#[doc(hidden)]
pub trait IsSameType<T> {}

impl<T> IsSameType<T> for T {}
}

/// Ensures a condition is true, otherwise returns with the given error.
Expand Down
103 changes: 103 additions & 0 deletions lang/syn/src/codegen/accounts/try_accounts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,110 @@ pub fn generate(accs: &AccountsStruct) -> proc_macro2::TokenStream {
}
};

// Generate type validation methods for instruction parameters
let type_validation_methods = match &accs.instruction_api {
None => {
// generate stub methods for up to 32 possible arguments
let stub_methods: Vec<proc_macro2::TokenStream> = (0..32)
.map(|idx| {
let method_name = syn::Ident::new(
&format!("__anchor_validate_ix_arg_type_{}", idx),
proc_macro2::Span::call_site(),
);
quote! {
#[doc(hidden)]
#[inline(always)]
#[allow(unused)]
pub fn #method_name<__T>(_arg: &__T) {
// no type validation when #[instruction(...)] is missing
}
}
})
.collect();

quote! {
#(#stub_methods)*
}
}
Some(ix_api) => {
let declared_count = ix_api.len();

// Generate strict validation methods for declared parameters
let type_check_methods: Vec<proc_macro2::TokenStream> = ix_api
.iter()
.enumerate()
.map(|(idx, expr)| {
if let Expr::Type(expr_type) = expr {
let ty = &expr_type.ty;
let method_name = syn::Ident::new(
&format!("__anchor_validate_ix_arg_type_{}", idx),
proc_macro2::Span::call_site(),
);
quote! {
#[doc(hidden)]
#[inline(always)]
pub fn #method_name<__T>(_arg: &__T)
where
__T: anchor_lang::__private::IsSameType<#ty>,
{}
}
} else {
panic!("Invalid instruction declaration");
}
})
.collect();

// stub methods for remaining argument positions (up to 32 total)
let stub_methods: Vec<proc_macro2::TokenStream> = (declared_count..32)
.map(|idx| {
let method_name = syn::Ident::new(
&format!("__anchor_validate_ix_arg_type_{}", idx),
proc_macro2::Span::call_site(),
);
quote! {
#[doc(hidden)]
#[inline(always)]
#[allow(unused)]
pub fn #method_name<__T>(_arg: &__T) {
}
}
})
.collect();

quote! {
#(#type_check_methods)*
#(#stub_methods)*
}
}
};

let param_count_const = match &accs.instruction_api {
None => quote! {
#[automatically_derived]
impl<#combined_generics> #name<#struct_generics> #where_clause {
#[doc(hidden)]
pub const __ANCHOR_IX_PARAM_COUNT: usize = 0;

#type_validation_methods
}
},
Some(ix_api) => {
let count = ix_api.len();

quote! {
#[automatically_derived]
impl<#combined_generics> #name<#struct_generics> #where_clause {
#[doc(hidden)]
pub const __ANCHOR_IX_PARAM_COUNT: usize = #count;

#type_validation_methods
}
}
}
};

quote! {
#param_count_const
#[automatically_derived]
impl<#combined_generics> anchor_lang::Accounts<#trait_generics, #bumps_struct_name> for #name<#struct_generics> #where_clause {
#[inline(never)]
Expand Down
52 changes: 52 additions & 0 deletions lang/syn/src/codegen/program/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,57 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream {
anchor_lang::solana_program::program::set_return_data(&return_data);
},
};

let actual_param_count = ix.args.len();
let ix_name_str = ix_method_name.to_string();
let accounts_type_str = anchor.to_string();

// Build clear error messages
let count_error_msg = format!(
"#[instruction(...)] on Account `{}<'_>` expects MORE args, the ix `{}(...)` has only {} args.",
accounts_type_str,
ix_name_str,
actual_param_count,
);

// Generate type validation calls for each argument
let type_validations: Vec<proc_macro2::TokenStream> = ix.args
.iter()
.enumerate()
.map(|(idx, arg)| {
let arg_ty = &arg.raw_arg.ty;
let method_name = syn::Ident::new(
&format!("__anchor_validate_ix_arg_type_{}", idx),
proc_macro2::Span::call_site(),
);
quote! {
// Type validation for argument #idx
if #anchor::__ANCHOR_IX_PARAM_COUNT > #idx {
if false {
// This code is never executed but is type-checked at compile time
let __type_check_arg: #arg_ty = panic!();
Comment thread
0x4ka5h marked this conversation as resolved.
#anchor::#method_name(&__type_check_arg);
}
}
}
})
.collect();

let param_validation = quote! {
const _: () = {
const EXPECTED_COUNT: usize = #anchor::__ANCHOR_IX_PARAM_COUNT;
const HANDLER_PARAM_COUNT: usize = #actual_param_count;

// Count validation
if EXPECTED_COUNT > HANDLER_PARAM_COUNT {
panic!(#count_error_msg);
}
};

// Type validations
#(#type_validations)*
};

quote! {
#(#cfgs)*
#[inline(never)]
Expand All @@ -124,6 +175,7 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream {
#[cfg(not(feature = "no-log-ix-name"))]
anchor_lang::prelude::msg!(#ix_name_log);

#param_validation
// Deserialize data.
let ix = instruction::#ix_name::deserialize(&mut &__ix_data[..])
.map_err(|_| anchor_lang::error::ErrorCode::InstructionDidNotDeserialize)?;
Expand Down
15 changes: 15 additions & 0 deletions tests/test-instruction-validation/fail-args-count/Anchor.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
[toolchain]

[features]
resolution = true
skip-lint = false

[programs.localnet]
test_instruction_validation = "Fg6PaFpoGXkYsidMpWTK6W2BeZ7FEfcYkg476zPFsLnS"

[registry]
url = "https://api.apr.dev"

[provider]
cluster = "Localnet"
wallet = "~/.config/solana/id.json"
16 changes: 16 additions & 0 deletions tests/test-instruction-validation/fail-args-count/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
[workspace]
members = [
"programs/*"
]
resolver = "2"

[profile.release]
overflow-checks = true
lto = "fat"
codegen-units = 1

[profile.release.build-override]
opt-level = 3
incremental = false
codegen-units = 1

Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
[package]
name = "test-instruction-validation"
version = "0.1.0"
description = "Test for instruction parameter validation"
edition = "2021"

[lib]
crate-type = ["cdylib", "lib"]
name = "test_instruction_validation"

[features]
default = []
cpi = ["no-entrypoint"]
no-entrypoint = []
no-idl = []
no-log-ix-name = []
idl-build = ["anchor-lang/idl-build"]
anchor-debug = ["anchor-lang/anchor-debug"]

[dependencies]
anchor-lang = { path = "../../../../../lang" }

Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#![allow(unexpected_cfgs)]

use anchor_lang::prelude::*;

declare_id!("Fg6PaFpoGXkYsidMpWTK6W2BeZ7FEfcYkg476zPFsLnS");

type MyType = u64;
#[program]
pub mod test_instruction_validation {
use super::*;

// Test 1: Missing parameter - handler only has 1 arg but #[instruction] expects 2
pub fn missing_instruction_attr(
_ctx: Context<MissingInstructionAttr>,
data: u64, // Handler has only 1 parameter
) -> Result<()> {
msg!("Data: {}", data);
Ok(())
}

pub fn no_params(_ctx: Context<NoParams>) -> Result<()> {
msg!("No params needed");
Ok(())
}

// Test 2: Type mismatch - handler has u64 but #[instruction(...)] has u8
pub fn type_mismatch(
_ctx: Context<TypeMismatch>,
data: u64, // Handler parameter is u64
) -> Result<()> {
msg!("Data: {}", data);
Ok(())
}
}

#[derive(Accounts)]
#[instruction(data: u64, ehe: u64)] // Expects 2 params but handler only has 1
pub struct MissingInstructionAttr<'info> {
#[account(mut)]
pub user: Signer<'info>,
}

#[derive(Accounts)]
// No #[instruction(...)] - correct for no params
pub struct NoParams<'info> {
pub user: Signer<'info>,
}

#[derive(Accounts)]
#[instruction(data: MyType)] // Attribute specifies u8 but handler has u64
pub struct TypeMismatch<'info> {
pub user: Signer<'info>,
}
15 changes: 15 additions & 0 deletions tests/test-instruction-validation/fail-type/Anchor.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
[toolchain]

[features]
resolution = true
skip-lint = false

[programs.localnet]
test_instruction_validation = "Fg6PaFpoGXkYsidMpWTK6W2BeZ7FEfcYkg476zPFsLnS"

[registry]
url = "https://api.apr.dev"

[provider]
cluster = "Localnet"
wallet = "~/.config/solana/id.json"
16 changes: 16 additions & 0 deletions tests/test-instruction-validation/fail-type/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
[workspace]
members = [
"programs/*"
]
resolver = "2"

[profile.release]
overflow-checks = true
lto = "fat"
codegen-units = 1

[profile.release.build-override]
opt-level = 3
incremental = false
codegen-units = 1

Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
[package]
name = "test-instruction-validation"
version = "0.1.0"
description = "Test for instruction parameter validation"
edition = "2021"

[lib]
crate-type = ["cdylib", "lib"]
name = "test_instruction_validation"

[features]
default = []
cpi = ["no-entrypoint"]
no-entrypoint = []
no-idl = []
no-log-ix-name = []
idl-build = ["anchor-lang/idl-build"]
anchor-debug = ["anchor-lang/anchor-debug"]

[dependencies]
anchor-lang = { path = "../../../../../lang" }

Loading
Loading