Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 54 additions & 1 deletion clients/rust-legacy/tests/cpi_guard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ use {
},
solana_sdk::{
instruction::InstructionError, pubkey::Pubkey, rent::Rent, signature::Signer,
signer::keypair::Keypair, transaction::TransactionError, transport::TransportError,
signer::keypair::Keypair, transaction::Transaction, transaction::TransactionError,
transport::TransportError,
},
solana_system_interface::instruction as system_instruction,
spl_instruction_padding_interface::instruction::wrap_instruction,
spl_token_2022_interface::{
error::TokenError,
Expand Down Expand Up @@ -685,6 +687,57 @@ async fn test_cpi_guard_unwrap_lamports() {
assert_eq!(alice_state.base.amount, amount);
}

#[tokio::test]
async fn test_cpi_guard_withdraw_excess_lamports() {
let context = make_context_with_new_mint().await;
let program_context = context.context.clone();
let TokenContext {
token, alice, bob, ..
} = context.token_context.unwrap();

let withdraw_excess_lamports = [wrap_instruction(
spl_instruction_padding_interface::id(),
instruction::withdraw_excess_lamports(
&spl_token_2022_interface::id(),
&alice.pubkey(),
&bob.pubkey(),
&alice.pubkey(),
&[],
)
.unwrap(),
vec![],
0,
)
.unwrap()];

token
.enable_cpi_guard(&alice.pubkey(), &alice.pubkey(), &[&alice])
.await
.unwrap();

{
let context = program_context.lock().await;
let instructions = vec![system_instruction::transfer(
&context.payer.pubkey(),
&alice.pubkey(),
1,
)];
let tx = Transaction::new_signed_with_payer(
&instructions,
Some(&context.payer.pubkey()),
&[&context.payer],
context.last_blockhash,
);
context.banks_client.process_transaction(tx).await.unwrap();
}

let error = token
.process_ixs(&withdraw_excess_lamports, &[&alice])
.await
.expect_err("expected CPI withdraw_excess_lamports to be blocked by CPI Guard");
assert_eq!(error, client_error(TokenError::CpiGuardTransferBlocked));
}

async fn make_close_test_account<S: Signer>(
token: &Token<ProgramBanksClientProcessTransaction>,
owner: &S,
Expand Down
6 changes: 6 additions & 0 deletions program/src/processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1705,6 +1705,12 @@ impl Processor {
authority_info.data_len(),
account_info_iter.as_slice(),
)?;

if let Ok(cpi_guard) = account.get_extension::<CpiGuard>() {
if cpi_guard.lock_cpi.into() && in_cpi() {
return Err(TokenError::CpiGuardTransferBlocked.into());
}
}
} else if let Ok(mint) = PodStateWithExtensions::<PodMint>::unpack(&source_data) {
match &mint.base.mint_authority {
PodCOption {
Expand Down
Loading