diff --git a/clients/rust-legacy/tests/cpi_guard.rs b/clients/rust-legacy/tests/cpi_guard.rs index c2c3061fd..39757585a 100644 --- a/clients/rust-legacy/tests/cpi_guard.rs +++ b/clients/rust-legacy/tests/cpi_guard.rs @@ -8,8 +8,10 @@ use { }, solana_sdk::{ instruction::InstructionError, pubkey::Pubkey, rent::Rent, signature::Signer, - signer::keypair::Keypair, transaction::TransactionError, transport::TransportError, + signer::keypair::Keypair, transaction::Transaction, transaction::TransactionError, + transport::TransportError, }, + solana_system_interface::instruction as system_instruction, spl_instruction_padding_interface::instruction::wrap_instruction, spl_token_2022_interface::{ error::TokenError, @@ -685,6 +687,57 @@ async fn test_cpi_guard_unwrap_lamports() { assert_eq!(alice_state.base.amount, amount); } +#[tokio::test] +async fn test_cpi_guard_withdraw_excess_lamports() { + let context = make_context_with_new_mint().await; + let program_context = context.context.clone(); + let TokenContext { + token, alice, bob, .. + } = context.token_context.unwrap(); + + let withdraw_excess_lamports = [wrap_instruction( + spl_instruction_padding_interface::id(), + instruction::withdraw_excess_lamports( + &spl_token_2022_interface::id(), + &alice.pubkey(), + &bob.pubkey(), + &alice.pubkey(), + &[], + ) + .unwrap(), + vec![], + 0, + ) + .unwrap()]; + + token + .enable_cpi_guard(&alice.pubkey(), &alice.pubkey(), &[&alice]) + .await + .unwrap(); + + { + let context = program_context.lock().await; + let instructions = vec![system_instruction::transfer( + &context.payer.pubkey(), + &alice.pubkey(), + 1, + )]; + let tx = Transaction::new_signed_with_payer( + &instructions, + Some(&context.payer.pubkey()), + &[&context.payer], + context.last_blockhash, + ); + context.banks_client.process_transaction(tx).await.unwrap(); + } + + let error = token + .process_ixs(&withdraw_excess_lamports, &[&alice]) + .await + .expect_err("expected CPI withdraw_excess_lamports to be blocked by CPI Guard"); + assert_eq!(error, client_error(TokenError::CpiGuardTransferBlocked)); +} + async fn make_close_test_account( token: &Token, owner: &S, diff --git a/program/src/processor.rs b/program/src/processor.rs index 64d4abb23..694643502 100644 --- a/program/src/processor.rs +++ b/program/src/processor.rs @@ -1705,6 +1705,12 @@ impl Processor { authority_info.data_len(), account_info_iter.as_slice(), )?; + + if let Ok(cpi_guard) = account.get_extension::() { + if cpi_guard.lock_cpi.into() && in_cpi() { + return Err(TokenError::CpiGuardTransferBlocked.into()); + } + } } else if let Ok(mint) = PodStateWithExtensions::::unpack(&source_data) { match &mint.base.mint_authority { PodCOption {