diff --git a/examples/anti_fee_sniping.rs b/examples/anti_fee_sniping.rs index 43f7796..82174ca 100644 --- a/examples/anti_fee_sniping.rs +++ b/examples/anti_fee_sniping.rs @@ -88,13 +88,10 @@ fn main() -> anyhow::Result<()> { }, )?; - let fallback_locktime: LockTime = LockTime::from_consensus(tip_height.to_consensus_u32()); - let selection_inputs = selection.inputs.clone(); let psbt = selection.create_psbt(PsbtParams { - enable_anti_fee_sniping: true, - fallback_locktime, + anti_fee_sniping: Some(tip_height), ..Default::default() })?; diff --git a/src/lib.rs b/src/lib.rs index 46866f4..0587350 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -32,7 +32,7 @@ pub use rbf::*; pub use selection::*; pub use selector::*; pub use signer::*; -use utils::*; +pub use utils::*; #[cfg(feature = "std")] pub(crate) mod collections { diff --git a/src/selection.rs b/src/selection.rs index 22daae3..48cb851 100644 --- a/src/selection.rs +++ b/src/selection.rs @@ -3,14 +3,11 @@ use alloc::vec::Vec; use core::fmt::{Debug, Display}; use miniscript::bitcoin; -use miniscript::bitcoin::{ - absolute::{self, LockTime}, - transaction, Psbt, Sequence, -}; +use miniscript::bitcoin::{absolute, transaction, Psbt, Sequence}; use miniscript::psbt::PsbtExt; use rand_core::RngCore; -use crate::{apply_anti_fee_sniping, Finalizer, Input, Output}; +use crate::{apply_anti_fee_sniping, AntiFeeSnipingError, Finalizer, Input, Output}; /// Final selection of inputs and outputs. #[derive(Debug, Clone)] @@ -27,15 +24,17 @@ pub struct PsbtParams { /// Use a specific [`transaction::Version`]. pub version: transaction::Version, - /// Fallback tx locktime. + /// Minimum tx locktime — a floor on the resulting `tx.lock_time`. /// - /// The locktime to use if no input specifies a required absolute locktime. - /// - /// It is best practice to set this to the latest block height to avoid fee sniping. - pub fallback_locktime: absolute::LockTime, + /// The final `tx.lock_time` is the maximum of this value and any absolute locktime required by + /// an input's CLTV, provided the locktime units agree. If `min_locktime` uses a different unit + /// (block-height vs. time) than an input's CLTV, it is ignored — a height-based `min_locktime` + /// will not be combined with a time-based CLTV (and vice versa). + pub min_locktime: absolute::LockTime, - /// Whether to require the full tx, aka [`non_witness_utxo`] for segwit v0 inputs, - /// default is `true`. + /// Whether to require the full tx, aka [`non_witness_utxo`] for segwit v0 inputs. + /// + /// Default is `true`. /// /// [`non_witness_utxo`]: bitcoin::psbt::Input::non_witness_utxo pub mandate_full_tx_for_segwit_v0: bool, @@ -48,66 +47,45 @@ pub struct PsbtParams { /// cover all of the outputs). pub sighash_type: Option, - /// Whether to use BIP326 anti-fee-sniping protection. - /// - /// When enabled, the transaction's nLockTime or nSequence will be set to indicate - /// the transaction should only be valid at or after the current block height. - /// This discourages miners from reorganizing recent blocks to capture fees. + /// Apply BIP-326 anti-fee-sniping (AFS) protection, using the given block height. /// - /// # Assumptions - /// - The current height is determined by the transaction's locktime (must be a block height) - /// - Transaction version must be >= 2 to support relative locktimes + /// * `None` (default) — no AFS is applied. + /// * `Some(tip_height)` — AFS is applied with `tip_height` as the current chain tip. /// - /// # Effects on Transaction - /// When enabled, this will modify the transaction in one of two ways: - /// - **nLockTime approach**: Sets `tx.lock_time` to current height (possibly with random offset) - /// - **nSequence approach**: Sets sequence on a randomly selected Taproot input to current - /// confirmation depth (possibly with random offset) + /// AFS discourages miners from reorganizing recent blocks to capture fees by constraining the + /// transaction to only be valid at or after the chain tip. When enabled, + /// [`Selection::create_psbt`] sets either the transaction's `nLockTime` or the `nSequence` of + /// one Taproot input to a value derived from `tip_height`. /// - /// The choice between approaches is randomized based on BIP326 probabilities, with - /// certain conditions forcing nLockTime usage (unconfirmed inputs, non-Taproot inputs, - /// RBF disabled, etc.). + /// AFS only operates on a height-based `tx.lock_time`. If [`min_locktime`] or any input's + /// CLTV is time-based, enabling AFS produces [`AntiFeeSnipingError::UnsupportedLockTime`]. /// - /// # Error Cases - /// - Returns [`CreatePsbtError::InvalidLockTime`] if the locktime is not a block height - /// - Returns [`CreatePsbtError::UnsupportedVersion`] if transaction version is less than 2 + /// If `tx.lock_time` is already a block height greater than `tip_height` (e.g., because an + /// input's CLTV pins the tx to a future block), AFS leaves the transaction unchanged — the + /// existing CLTV already provides equivalent protection. /// - /// # Default - /// - Disabled by default (`false`). + /// # Errors /// - /// # Example - /// ``` - /// use miniscript::bitcoin::absolute::{LockTime, Height}; - /// use bdk_tx::{PsbtParams, Selection, Output}; - /// - /// fn main() -> Result<(), Box> { - /// let params = PsbtParams { - /// fallback_locktime: LockTime::from_height(800000).expect("valid height"), - /// enable_anti_fee_sniping: true, - /// ..PsbtParams::default() - /// }; - /// let selection = Selection { - /// inputs: vec![], /* Inputs */ - /// outputs: vec![], /* Outputs */ - /// }; - /// let psbt = selection.create_psbt(params)?; - /// // the resulting transaction will have anti-fee-sniping applied. - /// Ok(()) - /// } - /// ``` + /// When `Some(..)`, [`Selection::create_psbt`] returns [`CreatePsbtError::AntiFeeSniping`] if: + /// - the transaction version is less than 2 + /// ([`AntiFeeSnipingError::UnsupportedVersion`]) — v2 is required for relative locktimes; or + /// - a time-based (MTP) locktime is in effect + /// ([`AntiFeeSnipingError::UnsupportedLockTime`]) — AFS only supports height-based locktimes. /// /// See [BIP326](https://github.com/bitcoin/bips/blob/master/bip-0326.mediawiki) for more details. - pub enable_anti_fee_sniping: bool, + /// + /// [`min_locktime`]: Self::min_locktime + pub anti_fee_sniping: Option, } impl Default for PsbtParams { fn default() -> Self { Self { version: transaction::Version::TWO, - fallback_locktime: absolute::LockTime::ZERO, + min_locktime: absolute::LockTime::ZERO, mandate_full_tx_for_segwit_v0: true, sighash_type: None, - enable_anti_fee_sniping: false, + anti_fee_sniping: None, } } } @@ -125,10 +103,14 @@ pub enum CreatePsbtError { Psbt(bitcoin::psbt::Error), /// Update psbt output with descriptor error. OutputUpdate(miniscript::psbt::OutputUpdateError), - /// Invalid locktime - InvalidLockTime(absolute::LockTime), - /// Unsupported version for anti fee snipping - UnsupportedVersion(transaction::Version), + /// Occurs when applying anti-fee-sniping fails. + AntiFeeSniping(AntiFeeSnipingError), +} + +impl From for CreatePsbtError { + fn from(e: AntiFeeSnipingError) -> Self { + Self::AntiFeeSniping(e) + } } impl core::fmt::Display for CreatePsbtError { @@ -149,12 +131,7 @@ impl core::fmt::Display for CreatePsbtError { CreatePsbtError::OutputUpdate(output_update_error) => { Display::fmt(&output_update_error, f) } - CreatePsbtError::InvalidLockTime(locktime) => { - write!(f, "The locktime - {}, is invalid", locktime) - } - CreatePsbtError::UnsupportedVersion(version) => { - write!(f, "Unsupported version {}", version) - } + CreatePsbtError::AntiFeeSniping(e) => Display::fmt(e, f), } } } @@ -165,18 +142,18 @@ impl std::error::Error for CreatePsbtError {} impl Selection { /// Accumulates the maximum locktime from an iterator of input-required locktimes. /// - /// Returns the `fallback_locktime` if the locktimes iterator is empty, `Ok(lock_time)` with + /// Returns the `min_locktime` if the locktimes iterator is empty, `Ok(lock_time)` with /// the maximum locktime if all items share the same unit. Errors if there is a mismatch of /// lock type units among the required locktimes. fn accumulate_max_locktime( locktimes: impl IntoIterator, - fallback_locktime: absolute::LockTime, + min_locktime: absolute::LockTime, ) -> Result { // Accumulate locktimes required by inputs. An input-vs-input unit mismatch is an error. - // The fallback is only used when it is compatible with the input requirements. - // If the fallback is a different unit from the required locktime it is - // intentionally ignored so that a height-based fallback does not conflict with a - // time-based CLTV requirement. + // `min_locktime` is only used when it is compatible with the input requirements. + // If it is a different unit from the required locktime it is intentionally ignored + // so that a height-based `min_locktime` does not conflict with a time-based CLTV + // requirement. let mut acc = Option::::None; for locktime in locktimes { match &mut acc { @@ -192,17 +169,17 @@ impl Selection { }; } match acc { - // No required locktimes from inputs: use fallback directly. - None => Ok(fallback_locktime), - // Same unit as fallback: take the maximum of required and fallback. - Some(lock_time) if lock_time.is_same_unit(fallback_locktime) => { - if lock_time.is_implied_by(fallback_locktime) { - Ok(fallback_locktime) + // No required locktimes from inputs: use `min_locktime` directly. + None => Ok(min_locktime), + // Same unit as `min_locktime`: take the maximum of required and `min_locktime`. + Some(lock_time) if lock_time.is_same_unit(min_locktime) => { + if lock_time.is_implied_by(min_locktime) { + Ok(min_locktime) } else { Ok(lock_time) } } - // Fallback is a different unit: use required locktime and ignore fallback. + // `min_locktime` is a different unit: use required locktime and ignore it. Some(lock_time) => Ok(lock_time), } } @@ -225,7 +202,7 @@ impl Selection { self.inputs .iter() .filter_map(|input| input.absolute_timelock()), - params.fallback_locktime, + params.min_locktime, )?, input: self .inputs @@ -239,16 +216,8 @@ impl Selection { output: self.outputs.iter().map(|output| output.txout()).collect(), }; - if params.enable_anti_fee_sniping { - let rbf_enabled = tx.is_explicitly_rbf(); - let current_height = match tx.lock_time { - LockTime::Blocks(height) => height, - LockTime::Seconds(_) => { - return Err(CreatePsbtError::InvalidLockTime(tx.lock_time)); - } - }; - - apply_anti_fee_sniping(&mut tx, &self.inputs, current_height, rbf_enabled, rng)?; + if let Some(tip_height) = params.anti_fee_sniping { + apply_anti_fee_sniping(&mut tx, &self.inputs, tip_height, rng)?; }; let mut psbt = Psbt::from_unsigned_tx(tx).map_err(CreatePsbtError::Psbt)?; @@ -315,30 +284,30 @@ impl Selection { mod tests { use super::*; use bitcoin::{ - absolute::{self, Height, Time}, + absolute::{self, LockTime, Time}, + relative, secp256k1::Secp256k1, transaction::{self, Version}, - Amount, ScriptBuf, Transaction, TxIn, TxOut, + Amount, ScriptBuf, Sequence, Transaction, TxIn, TxOut, }; use miniscript::{plan::Assets, Descriptor, DescriptorPublicKey}; use rand_core::OsRng; const TEST_DESCRIPTOR: &str = "tr([83737d5e/86h/1h/0h]tpubDDR5GgtoxS8fJyjjvdahN4VzV5DV6jtbcyvVXhEKq2XtpxjxBXmxH3r8QrNbQqHg4bJM1EGkxi7Pjfkgnui9jQWqS7kxHvX6rhUeriLDKxz/0/*)"; const TEST_DESCRIPTOR_PK: &str = "[83737d5e/86h/1h/0h]tpubDDR5GgtoxS8fJyjjvdahN4VzV5DV6jtbcyvVXhEKq2XtpxjxBXmxH3r8QrNbQqHg4bJM1EGkxi7Pjfkgnui9jQWqS7kxHvX6rhUeriLDKxz/0/*"; + const TEST_HEX_PK: &str = "032b0558078bec38694a84933d659303e2575dae7e91685911454115bfd64487e3"; - #[test] - fn test_fallback_locktime_height() -> anyhow::Result<()> { - let abs_locktime = absolute::LockTime::from_consensus(100_000); + fn setup_cltv_input( + cltv: absolute::LockTime, + ) -> anyhow::Result<(Input, Descriptor)> { let secp = Secp256k1::new(); - let pk = "032b0558078bec38694a84933d659303e2575dae7e91685911454115bfd64487e3"; - let desc_str = format!("wsh(and_v(v:pk({pk}),after({abs_locktime})))"); - let desc_pk: DescriptorPublicKey = pk.parse()?; + let desc_str = format!("wsh(and_v(v:pk({TEST_HEX_PK}),after({cltv})))"); + let desc_pk: DescriptorPublicKey = TEST_HEX_PK.parse()?; let (desc, _) = Descriptor::parse_descriptor(&secp, &desc_str)?; let plan = desc .at_derivation_index(0)? - .plan(&Assets::new().add(desc_pk).after(abs_locktime)) + .plan(&Assets::new().add(desc_pk).after(cltv)) .unwrap(); - let prev_tx = Transaction { version: transaction::Version::TWO, lock_time: absolute::LockTime::ZERO, @@ -349,6 +318,14 @@ mod tests { }], }; let input = Input::from_prev_tx(plan, prev_tx, 0, None)?; + Ok((input, desc)) + } + + #[test] + fn test_min_locktime_height() -> anyhow::Result<()> { + let abs_locktime = absolute::LockTime::from_consensus(100_000); + + let (input, desc) = setup_cltv_input(abs_locktime)?; let selection = Selection { inputs: vec![input], @@ -366,22 +343,22 @@ mod tests { let cases = vec![ TestCase { - name: "no fallback locktime, use plan locktime", + name: "no min_locktime, use plan locktime", psbt_params: PsbtParams::default(), exp_locktime: 100_000, }, TestCase { - name: "larger fallback locktime is used", + name: "larger min_locktime is used", psbt_params: PsbtParams { - fallback_locktime: absolute::LockTime::from_consensus(100_100), + min_locktime: absolute::LockTime::from_consensus(100_100), ..Default::default() }, exp_locktime: 100_100, }, TestCase { - name: "smaller fallback locktime is ignored", + name: "smaller min_locktime is ignored", psbt_params: PsbtParams { - fallback_locktime: absolute::LockTime::from_consensus(99_900), + min_locktime: absolute::LockTime::from_consensus(99_900), ..Default::default() }, exp_locktime: 100_000, @@ -401,32 +378,14 @@ mod tests { Ok(()) } - /// Tests that a height-based fallback locktime is ignored when the input + /// Tests that a height-based `min_locktime` is ignored when the input /// requires a time-based (UNIX timestamp) CLTV, and that an explicit time-based - /// fallback greater than the requirement is respected. + /// `min_locktime` greater than the requirement is respected. #[test] - fn test_fallback_locktime_respects_lock_type() -> anyhow::Result<()> { + fn test_min_locktime_respects_lock_type() -> anyhow::Result<()> { let time_locktime = absolute::LockTime::from_consensus(1_734_230_218); - let secp = Secp256k1::new(); - let pk = "032b0558078bec38694a84933d659303e2575dae7e91685911454115bfd64487e3"; - let desc_str = format!("wsh(and_v(v:pk({pk}),after({time_locktime})))"); - let desc_pk: DescriptorPublicKey = pk.parse()?; - let (desc, _) = Descriptor::parse_descriptor(&secp, &desc_str)?; - let plan = desc - .at_derivation_index(0)? - .plan(&Assets::new().add(desc_pk).after(time_locktime)) - .unwrap(); - let prev_tx = Transaction { - version: transaction::Version::TWO, - lock_time: absolute::LockTime::ZERO, - input: vec![TxIn::default()], - output: vec![TxOut { - script_pubkey: desc.at_derivation_index(0)?.script_pubkey(), - value: Amount::ONE_BTC, - }], - }; - let input = Input::from_prev_tx(plan, prev_tx, 0, None)?; + let (input, desc) = setup_cltv_input(time_locktime)?; let selection = Selection { inputs: vec![input], @@ -436,24 +395,24 @@ mod tests { )], }; - // Default fallback is height 0 (block-height unit). It is incompatible with the - // time-based CLTV requirement, so it must be ignored. + // Default `min_locktime` is height 0 (block-height unit). It is incompatible with + // the time-based CLTV requirement, so it must be ignored. let psbt = selection.create_psbt(PsbtParams::default())?; assert_eq!( psbt.unsigned_tx.lock_time, time_locktime, - "time-based CLTV requirement should be used; height-based fallback must be ignored", + "time-based CLTV requirement should be used; height-based `min_locktime` must be ignored", ); - // An explicit time-based fallback *greater* than the requirement should be respected. + // An explicit time-based `min_locktime` *greater* than the requirement should be respected. let larger_time = absolute::LockTime::from_consensus(1_772_167_108); assert!(larger_time > time_locktime); let psbt = selection.create_psbt(PsbtParams { - fallback_locktime: larger_time, + min_locktime: larger_time, ..Default::default() })?; assert_eq!( psbt.unsigned_tx.lock_time, larger_time, - "a larger time-based fallback should override the CLTV requirement", + "a larger time-based `min_locktime` should override the CLTV requirement", ); Ok(()) @@ -502,7 +461,7 @@ mod tests { // Disabled - default behavior is disable let psbt = selection.create_psbt(PsbtParams { - fallback_locktime: absolute::LockTime::from_consensus(current_height), + min_locktime: absolute::LockTime::from_consensus(current_height), ..Default::default() })?; let tx = psbt.unsigned_tx; @@ -512,33 +471,10 @@ mod tests { } #[test] - fn test_anti_fee_sniping_invalid_locktime_error() -> anyhow::Result<()> { - let input = setup_test_input(2_000).unwrap(); - let output = Output::with_script(ScriptBuf::new(), Amount::from_sat(9_000)); - let selection = Selection { - inputs: vec![input], - outputs: vec![output], - }; - - // Use time-based locktime instead of height-based - let result = selection.create_psbt(PsbtParams { - fallback_locktime: LockTime::from_consensus(500_000_000), // Time-based - enable_anti_fee_sniping: true, - ..Default::default() - }); - - assert!( - matches!(result, Err(CreatePsbtError::InvalidLockTime(_))), - "should return InvalidLockTime error for time-based locktime" - ); - - Ok(()) - } - - #[test] - fn test_anti_fee_sniping_protection() { + fn test_anti_fee_sniping_protection() -> anyhow::Result<()> { let current_height = 2_500; - let input = setup_test_input(2_000).unwrap(); + let tip = absolute::Height::from_consensus(current_height)?; + let input = setup_test_input(2_000)?; let mut used_locktime = false; let mut used_sequence = false; @@ -550,20 +486,19 @@ mod tests { inputs: vec![input.clone()], outputs: vec![output], }; - let psbt = selection - .create_psbt(PsbtParams { - fallback_locktime: absolute::LockTime::from_consensus(current_height), - enable_anti_fee_sniping: true, - ..Default::default() - }) - .unwrap(); + + let psbt = selection.create_psbt(PsbtParams { + anti_fee_sniping: Some(tip), + ..Default::default() + })?; + let tx = psbt.unsigned_tx; if tx.lock_time > absolute::LockTime::ZERO { used_locktime = true; let locktime_value = tx.lock_time.to_consensus_u32(); let min_height = current_height.saturating_sub(100); - assert!((min_height..=current_height).contains(&tx.lock_time.to_consensus_u32())); + assert!((min_height..=current_height).contains(&locktime_value)); assert!(locktime_value <= current_height); assert!(locktime_value >= current_height.saturating_sub(100)); } else { @@ -580,16 +515,15 @@ mod tests { } loops += 1; - assert!( - loops < 20, - "Failed to observe both behaviors within reasonable attempts" - ); + assert!(loops < 20, "Failed to observe both behaviors"); } + Ok(()) } #[test] fn test_anti_fee_sniping_multiple_taproot_inputs() { let current_height = 3_000; + let tip = absolute::Height::from_consensus(current_height).unwrap(); let input1 = setup_test_input(2_500).unwrap(); let input2 = setup_test_input(2_700).unwrap(); let input3 = setup_test_input(3_000).unwrap(); @@ -606,11 +540,11 @@ mod tests { }; let psbt = selection .create_psbt(PsbtParams { - fallback_locktime: absolute::LockTime::from_consensus(current_height), - enable_anti_fee_sniping: true, + anti_fee_sniping: Some(tip), ..Default::default() }) .unwrap(); + let tx = psbt.unsigned_tx; if tx.lock_time > absolute::LockTime::ZERO { @@ -619,7 +553,8 @@ mod tests { used_sequence = true; // One of the inputs should have modified sequence let has_modified_sequence = tx.input.iter().any(|txin| { - txin.sequence.to_consensus_u32() > 0 && txin.sequence.to_consensus_u32() < 65535 + let seq = txin.sequence.to_consensus_u32(); + seq > 0 && seq < 65_535 }); assert!(has_modified_sequence); } @@ -632,6 +567,158 @@ mod tests { } } + /// Regression: pre-fix, the AFS nLockTime path could overwrite `tx.lock_time` with a value + /// lower than an input's required CLTV. + #[test] + fn test_anti_fee_sniping_preserves_input_cltv() -> anyhow::Result<()> { + let cltv = absolute::LockTime::from_consensus(100_000); + let (input, desc) = setup_cltv_input(cltv)?; + // Tip is well below the input's CLTV requirement. + let tip = absolute::Height::from_consensus(50_000)?; + + let selection = Selection { + inputs: vec![input], + outputs: vec![Output::with_descriptor( + desc.at_derivation_index(1)?, + Amount::from_sat(1000), + )], + }; + + // The input is wsh (not Taproot), so AFS deterministically takes the locktime path; loop a + // few times anyway as cheap insurance against future control-flow changes. + for _ in 0..100 { + let psbt = selection.create_psbt(PsbtParams { + anti_fee_sniping: Some(tip), + ..Default::default() + })?; + assert_eq!( + psbt.unsigned_tx.lock_time, cltv, + "AFS must not overwrite an input's CLTV with a lower value", + ); + } + + Ok(()) + } + + /// Regression: pre-fix, the AFS nSequence path could pick a Taproot input that already carried + /// a CSV (relative-timelock) requirement and overwrite its sequence. The presence of a regular + /// Taproot input ensures the sequence path remains reachable — so the test also catches a + /// regression where AFS degrades to "never use the sequence path." + #[test] + fn test_anti_fee_sniping_skips_taproot_csv_input() -> anyhow::Result<()> { + let tip = absolute::Height::from_consensus(3_000)?; + let csv_blocks = 10; + + // Input A: regular Taproot, no CSV. + let regular_input = setup_test_input(2_500)?; + let regular_outpoint = regular_input.prev_outpoint(); + + // Input B: Taproot whose script-path requires CSV. The internal key is omitted from + // `assets`, forcing planning to use the script-path leaf (which sets + // `plan.relative_timelock`). + let secp = Secp256k1::new(); + let desc_str = + format!("tr({TEST_HEX_PK},and_v(v:pk({TEST_DESCRIPTOR_PK}),older({csv_blocks})))"); + let desc = Descriptor::parse_descriptor(&secp, &desc_str)? + .0 + .at_derivation_index(0)?; + let prev_tx = Transaction { + version: Version::TWO, + lock_time: LockTime::ZERO, + input: vec![TxIn::default()], + output: vec![TxOut { + script_pubkey: desc.script_pubkey(), + value: Amount::from_sat(10_000), + }], + }; + let assets = Assets::new() + .add(TEST_DESCRIPTOR_PK.parse::()?) + .older(relative::LockTime::from_height(csv_blocks)); + let plan = desc.plan(&assets).expect("script-path plan with CSV"); + let status = crate::ConfirmationStatus { + height: absolute::Height::from_consensus(2_500)?, + prev_mtp: Some(Time::from_consensus(500_000_000)?), + }; + let csv_input = Input::from_prev_tx(plan, prev_tx, 0, Some(status))?; + let csv_outpoint = csv_input.prev_outpoint(); + let csv_sequence = csv_input.sequence().expect("plan-derived sequence"); + + let output = Output::with_script(ScriptBuf::new(), Amount::from_sat(18_000)); + + // We will run AFS for 100 rounds. + // Track whether AFS's nSequence path actually fired for at least one of the rounds. + let mut observed_sequence_path = false; + + for _ in 0..100 { + let selection = Selection { + inputs: vec![regular_input.clone(), csv_input.clone()], + outputs: vec![output.clone()], + }; + let psbt = selection.create_psbt(PsbtParams { + anti_fee_sniping: Some(tip), + ..Default::default() + })?; + let tx = psbt.unsigned_tx; + + let csv_txin = tx + .input + .iter() + .find(|t| t.previous_output == csv_outpoint) + .expect("csv input must be present"); + assert_eq!( + csv_txin.sequence, csv_sequence, + "AFS must not overwrite the sequence of a CSV-bearing Taproot input", + ); + + let regular_txin = tx + .input + .iter() + .find(|t| t.previous_output == regular_outpoint) + .expect("regular input must be present"); + if regular_txin.sequence != Sequence::ENABLE_RBF_NO_LOCKTIME { + observed_sequence_path = true; + } + } + + assert!( + observed_sequence_path, + "AFS nSequence path must fire at least once across the 100 trials (otherwise the \ + CSV-preservation check above doesn't exercise the candidate-pool exclusion)", + ); + + Ok(()) + } + + /// A time-based CLTV propagates to `tx.lock_time`; AFS only supports height-based locktimes, so + /// it must surface `UnsupportedLockTime`. + #[test] + fn test_anti_fee_sniping_rejects_time_based_locktime() -> anyhow::Result<()> { + let time_locktime = absolute::LockTime::from_consensus(1_734_230_218); + let (input, desc) = setup_cltv_input(time_locktime)?; + let tip = absolute::Height::from_consensus(800_000)?; + + let selection = Selection { + inputs: vec![input], + outputs: vec![Output::with_descriptor( + desc.at_derivation_index(1)?, + Amount::from_sat(1000), + )], + }; + + let result = selection.create_psbt(PsbtParams { + anti_fee_sniping: Some(tip), + ..Default::default() + }); + + assert!(matches!( + result, + Err(CreatePsbtError::AntiFeeSniping(AntiFeeSnipingError::UnsupportedLockTime(lt))) + if lt == time_locktime + )); + + Ok(()) + } + #[test] fn test_anti_fee_sniping_unsupported_version_error() { let confirmation_height = 800_000; @@ -649,11 +736,10 @@ mod tests { output: vec![], }; - let current_height = Height::from_consensus(800_050).unwrap(); - let result = apply_anti_fee_sniping(&mut tx, &inputs, current_height, true, &mut OsRng); + let result = apply_anti_fee_sniping(&mut tx, &inputs, current_height, &mut OsRng); assert!( - matches!(result, Err(CreatePsbtError::UnsupportedVersion(_))), + matches!(result, Err(AntiFeeSnipingError::UnsupportedVersion(_))), "should return UnsupportedVersion error for version < 2" ); } diff --git a/src/utils.rs b/src/utils.rs index 9fcfe67..2fd5eaa 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,20 +1,49 @@ -use crate::{CreatePsbtError, Input}; +use crate::Input; use alloc::vec::Vec; use miniscript::bitcoin::{ absolute::{self, LockTime}, transaction::Version, Sequence, Transaction, }; -#[cfg(feature = "std")] -use rand::Rng; use rand_core::RngCore; +/// Error returned by `apply_anti_fee_sniping`. +#[derive(Debug, Clone, PartialEq)] +pub enum AntiFeeSnipingError { + /// Transaction `version` must be >= 2 for AFS to use relative locktimes. + UnsupportedVersion(Version), + /// AFS only supports height-based locktimes. The transaction's locktime is + /// time-based (MTP), which can originate from either `PsbtParams::min_locktime` + /// or an input's time-based CLTV requirement. + UnsupportedLockTime(absolute::LockTime), +} + +impl core::fmt::Display for AntiFeeSnipingError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + Self::UnsupportedVersion(version) => write!( + f, + "anti-fee-sniping requires tx.version >= 2 (got {version})" + ), + Self::UnsupportedLockTime(locktime) => write!( + f, + "anti-fee-sniping requires a height-based tx locktime (got time-based {locktime}); \ + check `min_locktime` and any input CLTV requirements" + ), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for AntiFeeSnipingError {} + /// Applies BIP326 anti-fee-sniping protection to a transaction. /// /// Anti-fee-sniping makes transaction replay attacks less profitable by setting /// either nLockTime or nSequence to indicate the transaction should only be valid /// at or after the current block height. This discourages miners from attempting -/// to reorganize recent blocks to claim fees from transactions. +/// to reorganize recent blocks to claim fees from transactions. It must be called +/// **before** the PSBT is signed. /// /// # Strategy /// The function randomly chooses between two approaches: @@ -27,47 +56,30 @@ use rand_core::RngCore; /// # Parameters /// - `tx`: The transaction to modify /// - `inputs`: The inputs associated with the transaction -/// - `current_height`: The current blockchain height (used as the base for time locks) -/// - `rbf_enabled`: Whether Replace-By-Fee is enabled (affects strategy selection) +/// - `tip_height`: The current blockchain height (used as the base for time locks) /// - `rng`: Random number generator implementing `RngCore` /// -/// # Errors -/// Returns an error if: -/// - Transaction version is less than 2 [`CreatePsbtError::UnsupportedVersion`] +/// # Behavior with existing locktime constraints /// -/// # Example -/// ```ignore -/// # use bdk_tx::Input; -/// # use miniscript::bitcoin::{ -/// # absolute::{Height, LockTime}, transaction::Version, Transaction, TxIn, TxOut, ScriptBuf, Amount -/// # }; -/// # use rand_core::OsRng; +/// If `tx.lock_time` is already a block height greater than the AFS target +/// (e.g., because an input's CLTV pins the transaction to a future height), +/// this function leaves `tx.lock_time` untouched and returns `Ok(())`. The +/// existing CLTV already prevents inclusion before `tip_height + 1`, so AFS +/// is implicitly satisfied. /// -/// fn main() -> Result<(), Box> { -/// let inputs: Vec = vec![]; -/// let mut tx = Transaction { -/// version: Version::TWO, -/// lock_time: LockTime::from_height(800_000)?, -/// input: vec![/* corresponding TxIns */], -/// output: vec![/* your outputs */], -/// }; -/// let current_height = Height::from_consensus(800_000)?; -/// let mut rng = OsRng; -/// apply_anti_fee_sniping(&mut tx, &inputs, current_height, true, &mut rng)?; -/// // tx now has anti-fee-sniping protection applied -/// Ok(()) -/// } -/// ``` +/// # Errors +/// - [`AntiFeeSnipingError::UnsupportedVersion`] if `tx.version < 2`. +/// - [`AntiFeeSnipingError::UnsupportedLockTime`] if `tx.lock_time` is time-based +/// (either from `PsbtParams::min_locktime` or an input's time-based CLTV). /// /// # See Also /// [BIP326](https://github.com/bitcoin/bips/blob/master/bip-0326.mediawiki) -pub fn apply_anti_fee_sniping( +pub(crate) fn apply_anti_fee_sniping( tx: &mut Transaction, inputs: &[Input], - current_height: absolute::Height, - rbf_enabled: bool, + tip_height: absolute::Height, rng: &mut impl RngCore, -) -> Result<(), CreatePsbtError> { +) -> Result<(), AntiFeeSnipingError> { const MAX_RELATIVE_HEIGHT: u32 = 65_535; const FIFTY_PERCENT_PROBABILITY_RANGE: u32 = 2; const MIN_SEQUENCE_VALUE: u32 = 1; @@ -75,9 +87,15 @@ pub fn apply_anti_fee_sniping( const MAX_RANDOM_OFFSET: u32 = 100; if tx.version < Version::TWO { - return Err(CreatePsbtError::UnsupportedVersion(tx.version)); + return Err(AntiFeeSnipingError::UnsupportedVersion(tx.version)); } + if !tx.lock_time.is_block_height() { + return Err(AntiFeeSnipingError::UnsupportedLockTime(tx.lock_time)); + } + + let rbf_enabled = tx.is_explicitly_rbf(); + // vector of input_index and associated Input ref. let taproot_inputs: Vec<(usize, &Input)> = tx .input @@ -87,7 +105,7 @@ pub fn apply_anti_fee_sniping( let input = inputs .iter() .find(|input| input.prev_outpoint() == txin.previous_output)?; - if input.prev_txout().script_pubkey.is_p2tr() { + if input.prev_txout().script_pubkey.is_p2tr() && input.relative_timelock().is_none() { Some((vin, input)) } else { None @@ -95,14 +113,12 @@ pub fn apply_anti_fee_sniping( }) .collect(); - // Check always‐locktime conditions - let must_use_locktime = inputs.iter().any(|input| { - let confirmation = input.confirmations(current_height); - confirmation == 0 - || confirmation > MAX_RELATIVE_HEIGHT - || !input.prev_txout().script_pubkey.is_p2tr() - }); - + // Conditions that force nLockTime (vs nSequence). + let must_use_locktime = taproot_inputs.is_empty() + || inputs.iter().any(|input| { + let confirmation = input.confirmations(tip_height); + confirmation == 0 || confirmation > MAX_RELATIVE_HEIGHT + }); let use_locktime = !rbf_enabled || must_use_locktime || taproot_inputs.is_empty() @@ -110,22 +126,23 @@ pub fn apply_anti_fee_sniping( if use_locktime { // Use nLockTime - let mut locktime = current_height.to_consensus_u32(); + let mut afs_height = tip_height.to_consensus_u32(); if random_probability(rng, TEN_PERCENT_PROBABILITY_RANGE) { let random_offset = random_range(rng, MAX_RANDOM_OFFSET); - locktime = locktime.saturating_sub(random_offset); + afs_height = afs_height.saturating_sub(random_offset); } - let new_locktime = LockTime::from_height(locktime).expect("must be valid Height"); + let afs_locktime = LockTime::from_height(afs_height).expect("must be valid Height"); - tx.lock_time = new_locktime; + if tx.lock_time.is_implied_by(afs_locktime) { + tx.lock_time = afs_locktime; + } } else { // Use Sequence - tx.lock_time = LockTime::ZERO; let random_index = random_range(rng, taproot_inputs.len() as u32); let (input_index, input) = taproot_inputs[random_index as usize]; - let confirmation = input.confirmations(current_height); + let confirmation = input.confirmations(tip_height); let mut sequence_value = confirmation; if random_probability(rng, TEN_PERCENT_PROBABILITY_RANGE) { @@ -142,34 +159,11 @@ pub fn apply_anti_fee_sniping( } /// Returns true with probability 1/n. -#[cfg(feature = "std")] -fn random_probability(rng: &mut impl RngCore, n: u32) -> bool { - rng.gen_bool(1.0 / n as f64) -} - -/// Returns true with probability 1/n. -/// -/// This `no-std` implementation avoids depending on the full `rand` crate, -/// keeping the dependency tree minimal while supporting `no-std` environments -/// through `rand_core` alone. -#[cfg(not(feature = "std"))] fn random_probability(rng: &mut impl RngCore, n: u32) -> bool { random_range(rng, n) == 0 } -/// Returns a random value in the range [0, n). -#[cfg(feature = "std")] -fn random_range(rng: &mut impl RngCore, n: u32) -> u32 { - rng.gen_range(0..n) -} - -/// Returns a random value in the range [0, n) using unbiased sampling. -/// -/// This `no-std` implementation uses rejection sampling to ensure uniform -/// distribution and avoid modulo bias, without depending on the full `rand` crate. -/// This keeps the dependency tree minimal while supporting `no-std` environments -/// through `rand_core` alone. -#[cfg(not(feature = "std"))] +/// Returns a random value in the range [0, n) using unbiased rejection sampling. fn random_range(rng: &mut impl RngCore, n: u32) -> u32 { let threshold = n.wrapping_neg() % n;