Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The minor version will be incremented upon a breaking change and the patch versi
- cli: Add `--stdout` flag to the `expand` command ([#4400](https://github.com/solana-foundation/anchor/pull/4400)).
- client: Add versioned tx support ([#4207](https://github.com/solana-foundation/anchor/pull/4207)).
- cli: Add `edition` and `rust-version` to template ([#4048](https://github.com/solana-foundation/anchor/pull/4048))
- lang: Add `program_id` verification to CPI return values ([#4411](https://github.com/solana-foundation/anchor/pull/4411)).

### Fixes

Expand Down
17 changes: 15 additions & 2 deletions lang/attribute/program/src/declare_program/mods/cpi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ fn gen_cpi_instructions(idl: &Idl) -> proc_macro2::TokenStream {
let ty = convert_idl_type_to_syn_type(ty);
(
quote! { anchor_lang::Result<Return::<#ty>> },
quote! { Ok(Return::<#ty> { phantom: std::marker::PhantomData }) },
quote! { Ok(Return::<#ty> { phantom: std::marker::PhantomData, program_id: ctx.program_id }) },
)
},
None => (
Expand Down Expand Up @@ -105,11 +105,24 @@ fn gen_cpi_instructions(idl: &Idl) -> proc_macro2::TokenStream {
fn gen_cpi_return_type() -> proc_macro2::TokenStream {
quote! {
pub struct Return<T> {
phantom: std::marker::PhantomData<T>
phantom: std::marker::PhantomData<T>,
program_id: anchor_lang::solana_program::pubkey::Pubkey,
}

impl<T: AnchorDeserialize> Return<T> {
pub fn get(&self) -> T {
let (key, data) = anchor_lang::solana_program::program::get_return_data().unwrap();
if key != self.program_id {
anchor_lang::solana_program::log::sol_log("CPI return data program_id mismatch");
panic!();
}
T::try_from_slice(&data).unwrap()
}

/// Read return data without validating the program_id.
/// Use this only when you intentionally need to read return data
/// from a different program than the one that was CPI'd into.
pub fn get_unchecked(&self) -> T {
let (_key, data) = anchor_lang::solana_program::program::get_return_data().unwrap();
T::try_from_slice(&data).unwrap()
}
Expand Down
17 changes: 15 additions & 2 deletions lang/syn/src/codegen/program/cpi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream {
"()" => (quote! {anchor_lang::Result<()> }, quote! { Ok(()) }),
_ => (
quote! { anchor_lang::Result<crate::cpi::Return::<#ret_type>> },
quote! { Ok(crate::cpi::Return::<#ret_type> { phantom: crate::cpi::PhantomData }) }
quote! { Ok(crate::cpi::Return::<#ret_type> { phantom: crate::cpi::PhantomData, program_id: ctx.program_id }) }
)
};

Expand Down Expand Up @@ -91,11 +91,24 @@ pub fn generate(program: &Program) -> proc_macro2::TokenStream {


pub struct Return<T> {
phantom: std::marker::PhantomData<T>
phantom: std::marker::PhantomData<T>,
program_id: anchor_lang::solana_program::pubkey::Pubkey,
}

impl<T: AnchorDeserialize> Return<T> {
pub fn get(&self) -> T {
let (key, data) = anchor_lang::solana_program::program::get_return_data().unwrap();
if key != self.program_id {
anchor_lang::solana_program::log::sol_log("CPI return data program_id mismatch");
panic!();
}
T::try_from_slice(&data).unwrap()
}

/// Read return data without validating the program_id.
/// Use this only when you intentionally need to read return data
/// from a different program than the one that was CPI'd into.
pub fn get_unchecked(&self) -> T {
let (_key, data) = anchor_lang::solana_program::program::get_return_data().unwrap();
T::try_from_slice(&data).unwrap()
}
Expand Down
1 change: 1 addition & 0 deletions tests/cpi-returns/Anchor.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ seeds = false
[programs.localnet]
callee = "Fg6PaFpoGXkYsidMpWTK6W2BeZ7FEfcYkg476zPFsLnS"
caller = "HmbTLCmaGvZhKnn1Zfa1JVnp7vkMV4DYVxPLWBVoN65L"
malicious = "6nWiFMhouBBrXir1h6BoZHoUzYJQTHwjUPPTGuKY9gXB"

[provider]
cluster = "localnet"
Expand Down
1 change: 1 addition & 0 deletions tests/cpi-returns/programs/caller/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ idl-build = ["anchor-lang/idl-build"]
[dependencies]
anchor-lang = { path = "../../../../lang", features = ["init-if-needed"] }
callee = { path = "../callee", features = ["cpi"] }
malicious = { path = "../malicious", features = ["cpi"] }
80 changes: 77 additions & 3 deletions tests/cpi-returns/programs/caller/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use anchor_lang::prelude::*;
use callee::cpi::accounts::CpiReturn;
use callee::program::Callee;
use callee::{self, CpiReturnAccount};
use malicious::cpi::accounts::SpoofReturn;
use malicious::program::Malicious;

declare_id!("HmbTLCmaGvZhKnn1Zfa1JVnp7vkMV4DYVxPLWBVoN65L");

Expand Down Expand Up @@ -51,17 +53,79 @@ pub mod caller {
Ok(())
}

pub fn return_u64(ctx: Context<ReturnContext>) -> Result<u64> {
pub fn return_u64(_ctx: Context<ReturnContext>) -> Result<u64> {
Ok(99)
}

pub fn return_struct(ctx: Context<ReturnContext>) -> Result<Struct> {
pub fn return_struct(_ctx: Context<ReturnContext>) -> Result<Struct> {
Ok(Struct { a: 1, b: 2 })
}

pub fn return_vec(ctx: Context<ReturnContext>) -> Result<Vec<u64>> {
pub fn return_vec(_ctx: Context<ReturnContext>) -> Result<Vec<u64>> {
Ok(vec![1, 2, 3])
}

/// PoC: Demonstrates that get_unchecked() reads spoofed return data.
/// This replicates the OLD (vulnerable) behavior of get().
///
/// 1. CPI to callee::return_u64 -> callee sets return data = 10
/// 2. CPI to malicious::spoof_return_data -> overwrites return data with 999
/// 3. get_unchecked() reads 999 instead of 10 (SPOOFED!)
pub fn cpi_call_return_u64_spoofed(ctx: Context<SpoofedReturnContext>) -> Result<()> {
// Step 1: CPI to callee, which returns u64 = 10
let cpi_program_id = ctx.accounts.cpi_return_program.key();
let cpi_accounts = CpiReturn {
account: ctx.accounts.cpi_return.to_account_info(),
};
let cpi_ctx = CpiContext::new(cpi_program_id, cpi_accounts);
let result = callee::cpi::return_u64(cpi_ctx)?;

// Step 2: CPI to malicious program, which calls set_return_data(999)
let malicious_program_id = ctx.accounts.malicious_program.key();
let spoof_accounts = SpoofReturn {
authority: ctx.accounts.authority.to_account_info(),
};
let spoof_ctx = CpiContext::new(malicious_program_id, spoof_accounts);
malicious::cpi::spoof_return_data(spoof_ctx)?;

// Step 3: Use get_unchecked() (old vulnerable behavior) to read the
// spoofed return data without program_id validation.
let spoofed_value = result.get_unchecked();

// Log the spoofed value so the test can verify it
anchor_lang::solana_program::log::sol_log_data(&[&borsh::to_vec(&spoofed_value).unwrap()]);

Ok(())
}

/// PoC: Demonstrates that get() (with fix) REJECTS spoofed return data.
///
/// Same flow as above, but uses get() instead of get_unchecked().
/// This will panic because the program_id from get_return_data() doesn't
/// match the expected callee program_id.
pub fn cpi_call_return_u64_spoofed_rejected(ctx: Context<SpoofedReturnContext>) -> Result<()> {
// Step 1: CPI to callee, which returns u64 = 10
let cpi_program_id = ctx.accounts.cpi_return_program.key();
let cpi_accounts = CpiReturn {
account: ctx.accounts.cpi_return.to_account_info(),
};
let cpi_ctx = CpiContext::new(cpi_program_id, cpi_accounts);
let result = callee::cpi::return_u64(cpi_ctx)?;

// Step 2: CPI to malicious program, which calls set_return_data(999)
let malicious_program_id = ctx.accounts.malicious_program.key();
let spoof_accounts = SpoofReturn {
authority: ctx.accounts.authority.to_account_info(),
};
let spoof_ctx = CpiContext::new(malicious_program_id, spoof_accounts);
malicious::cpi::spoof_return_data(spoof_ctx)?;

// Step 3: Use get() (FIXED) — this validates program_id and will PANIC
// because return data was set by malicious, not callee.
let _value = result.get();

Ok(())
}
}

#[derive(Accounts)]
Expand All @@ -71,5 +135,15 @@ pub struct CpiReturnContext<'info> {
pub cpi_return_program: Program<'info, Callee>,
}

#[derive(Accounts)]
pub struct SpoofedReturnContext<'info> {
#[account(mut)]
pub authority: Signer<'info>,
#[account(mut)]
pub cpi_return: Account<'info, CpiReturnAccount>,
pub cpi_return_program: Program<'info, Callee>,
pub malicious_program: Program<'info, Malicious>,
}

#[derive(Accounts)]
pub struct ReturnContext {}
20 changes: 20 additions & 0 deletions tests/cpi-returns/programs/malicious/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
[package]
name = "malicious"
version = "0.1.0"
description = "PoC: Malicious program that spoofs CPI return data"
edition = "2021"

[lib]
crate-type = ["cdylib", "lib"]
name = "malicious"

[features]
no-entrypoint = []
no-idl = []
no-log-ix-name = []
cpi = ["no-entrypoint"]
default = []
idl-build = ["anchor-lang/idl-build"]

[dependencies]
anchor-lang = { path = "../../../../lang", features = ["init-if-needed"] }
2 changes: 2 additions & 0 deletions tests/cpi-returns/programs/malicious/Xargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[target.bpfel-unknown-unknown.dependencies.std]
features = []
26 changes: 26 additions & 0 deletions tests/cpi-returns/programs/malicious/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
use anchor_lang::prelude::*;

declare_id!("6nWiFMhouBBrXir1h6BoZHoUzYJQTHwjUPPTGuKY9gXB");

#[program]
pub mod malicious {
use super::*;

/// This instruction manually calls set_return_data with a spoofed u64 value.
/// When a caller reads return data via Return<T>::get() after this CPI,
/// it will receive this spoofed value instead of the legitimate callee's value.
pub fn spoof_return_data(_ctx: Context<SpoofReturn>) -> Result<()> {
// Spoof a u64 value of 999 (0x03E7 in little-endian)
let spoofed_value: u64 = 999;
let data = spoofed_value.to_le_bytes();
anchor_lang::solana_program::program::set_return_data(&data);
Ok(())
}
}

#[derive(Accounts)]
pub struct SpoofReturn<'info> {
/// Dummy signer to satisfy CPI account requirements.
/// CHECK: No constraints needed for the PoC.
pub authority: Signer<'info>,
}
88 changes: 88 additions & 0 deletions tests/cpi-returns/tests/cpi-return.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import * as borsh from "borsh";
import { Program } from "@anchor-lang/core";
import { Callee } from "../target/types/callee";
import { Caller } from "../target/types/caller";
import { Malicious } from "../target/types/malicious";
import { ConfirmOptions } from "@solana/web3.js";

const { SystemProgram } = anchor.web3;
Expand All @@ -14,6 +15,7 @@ describe("CPI return", () => {

const callerProgram = anchor.workspace.Caller as Program<Caller>;
const calleeProgram = anchor.workspace.Callee as Program<Callee>;
const maliciousProgram = anchor.workspace.Malicious as Program<Malicious>;

const getReturnLog = (confirmedTransaction) => {
const prefix = "Program return: ";
Expand Down Expand Up @@ -233,4 +235,90 @@ describe("CPI return", () => {
assert(e.message.includes("Method does not support views"));
}
});

// === VULNERABILITY PoC: Return data spoofing ===

it("VULNERABILITY: get_unchecked() reads spoofed return data (old behavior)", async () => {
// This demonstrates what happened BEFORE the fix.
// get_unchecked() preserves the old behavior for backward compatibility,
// showing that a malicious program can spoof return data.
const tx = await callerProgram.methods
.cpiCallReturnU64Spoofed()
.accounts({
authority: provider.wallet.publicKey,
cpiReturn: cpiReturn.publicKey,
cpiReturnProgram: calleeProgram.programId,
maliciousProgram: maliciousProgram.programId,
})
.rpc(confirmOptions);

let t = await provider.connection.getTransaction(tx, {
commitment: "confirmed",
maxSupportedTransactionVersion: 0,
});

// Find the "Program data:" log emitted by the caller
const dataPrefix = "Program data: ";
const dataLogs = t.meta.logMessages.filter((log) =>
log.startsWith(dataPrefix)
);
const lastDataLog = dataLogs[dataLogs.length - 1];
const b64Data = lastDataLog.slice(dataPrefix.length);
const buffer = Buffer.from(b64Data, "base64");

const reader = new borsh.BinaryReader(buffer);
const spoofedValue = reader.readU64().toNumber();

// Callee returned 10, but malicious program overwrote with 999.
// get_unchecked() (old behavior) happily returns the spoofed value.
assert.notEqual(
spoofedValue,
10,
"Expected spoofed value, not the real callee value"
);
assert.equal(
spoofedValue,
999,
"Malicious program successfully spoofed return data"
);

console.log(`\n VULNERABILITY CONFIRMED (get_unchecked / old behavior):`);
console.log(` Callee returned: 10`);
console.log(` Malicious spoofed: 999`);
console.log(` Caller received: ${spoofedValue} (SPOOFED!)\n`);
});

it("FIX: get() rejects spoofed return data with program_id validation", async () => {
// After the fix, get() validates the program_id from get_return_data()
// against the expected program. This should FAIL because the return data
// was set by the malicious program, not the callee.
try {
await callerProgram.methods
.cpiCallReturnU64SpoofedRejected()
.accounts({
authority: provider.wallet.publicKey,
cpiReturn: cpiReturn.publicKey,
cpiReturnProgram: calleeProgram.programId,
maliciousProgram: maliciousProgram.programId,
})
.rpc(confirmOptions);

// If we get here, the fix didn't work
assert.fail("Expected transaction to fail due to program_id mismatch");
} catch (e) {
// Verify the error is specifically from the program_id validation,
// not some unrelated failure.
const errStr = JSON.stringify(e);
assert(
errStr.includes("program_id mismatch") ||
errStr.includes("ProgramFailedToComplete"),
`Expected program_id mismatch error, got: ${e.message?.substring(
0,
200
)}`
);
console.log(`\n FIX CONFIRMED: get() rejected spoofed return data`);
console.log(` Error: ${e.message?.substring(0, 100)}...\n`);
}
});
});
Loading