Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions smite/src/bolt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ mod shutdown;
mod tlv;
mod tx_abort;
mod tx_ack_rbf;
mod tx_add_output;
mod tx_complete;
mod tx_init_rbf;
mod tx_remove_input;
Expand Down Expand Up @@ -48,6 +49,7 @@ pub use shutdown::Shutdown;
pub use tlv::{TlvRecord, TlvStream};
pub use tx_abort::TxAbort;
pub use tx_ack_rbf::{TxAckRbf, TxAckRbfTlvs};
pub use tx_add_output::TxAddOutput;
pub use tx_complete::TxComplete;
pub use tx_init_rbf::{TxInitRbf, TxInitRbfTlvs};
pub use tx_remove_input::TxRemoveInput;
Expand Down Expand Up @@ -127,6 +129,8 @@ pub mod msg_type {
pub const OPEN_CHANNEL2: u16 = 64;
/// `accept_channel2` message (BOLT 2).
pub const ACCEPT_CHANNEL2: u16 = 65;
/// `tx_add_output` message (BOLT 2).
pub const TX_ADD_OUTPUT: u16 = 67;
/// `tx_remove_input` message (BOLT 2).
pub const TX_REMOVE_INPUT: u16 = 68;
/// `tx_remove_output` message (BOLT 2).
Expand Down Expand Up @@ -179,6 +183,8 @@ pub enum Message {
OpenChannel2(OpenChannel2),
/// `accept_channel2` message (type 65).
AcceptChannel2(AcceptChannel2),
/// `tx_add_output` message (type 67).
TxAddOutput(TxAddOutput),
/// `tx_remove_input` message (type 68).
TxRemoveInput(TxRemoveInput),
/// `tx_remove_output` message (type 69).
Expand Down Expand Up @@ -229,6 +235,7 @@ impl Message {
Self::Shutdown(_) => msg_type::SHUTDOWN,
Self::OpenChannel2(_) => msg_type::OPEN_CHANNEL2,
Self::AcceptChannel2(_) => msg_type::ACCEPT_CHANNEL2,
Self::TxAddOutput(_) => msg_type::TX_ADD_OUTPUT,
Self::TxRemoveInput(_) => msg_type::TX_REMOVE_INPUT,
Self::TxRemoveOutput(_) => msg_type::TX_REMOVE_OUTPUT,
Self::TxComplete(_) => msg_type::TX_COMPLETE,
Expand Down Expand Up @@ -262,6 +269,7 @@ impl Message {
Self::Shutdown(m) => out.extend(m.encode()),
Self::OpenChannel2(m) => out.extend(m.encode()),
Self::AcceptChannel2(m) => out.extend(m.encode()),
Self::TxAddOutput(m) => out.extend(m.encode()),
Self::TxRemoveInput(m) => out.extend(m.encode()),
Self::TxRemoveOutput(m) => out.extend(m.encode()),
Self::TxComplete(m) => out.extend(m.encode()),
Expand Down Expand Up @@ -302,6 +310,7 @@ impl Message {
msg_type::SHUTDOWN => Ok(Self::Shutdown(Shutdown::decode(cursor)?)),
msg_type::OPEN_CHANNEL2 => Ok(Self::OpenChannel2(OpenChannel2::decode(cursor)?)),
msg_type::ACCEPT_CHANNEL2 => Ok(Self::AcceptChannel2(AcceptChannel2::decode(cursor)?)),
msg_type::TX_ADD_OUTPUT => Ok(Self::TxAddOutput(TxAddOutput::decode(cursor)?)),
msg_type::TX_REMOVE_INPUT => Ok(Self::TxRemoveInput(TxRemoveInput::decode(cursor)?)),
msg_type::TX_REMOVE_OUTPUT => Ok(Self::TxRemoveOutput(TxRemoveOutput::decode(cursor)?)),
msg_type::TX_COMPLETE => Ok(Self::TxComplete(TxComplete::decode(cursor)?)),
Expand Down Expand Up @@ -630,6 +639,25 @@ mod tests {
assert_eq!(decoded, Message::AcceptChannel2(accept2));
}

/// Valid `TxAddOutput` message for testing.
fn sample_tx_add_output() -> TxAddOutput {
TxAddOutput {
channel_id: ChannelId::new([0xab; CHANNEL_ID_SIZE]),
serial_id: 42,
sats: 100_000,
script: vec![0x76, 0xa9, 0x14, 0xab, 0xcd],
}
}

#[test]
fn message_tx_add_output_roundtrip() {
let tx_add_output = sample_tx_add_output();
let msg = Message::TxAddOutput(tx_add_output.clone());
let encoded = msg.encode();
let decoded = Message::decode(&encoded).unwrap();
assert_eq!(decoded, Message::TxAddOutput(tx_add_output));
}

#[test]
fn message_tx_remove_input_roundtrip() {
let tx_remove_input = TxRemoveInput {
Expand Down Expand Up @@ -806,6 +834,10 @@ mod tests {
Message::AcceptChannel2(sample_accept_channel2(None)).msg_type(),
msg_type::ACCEPT_CHANNEL2
);
assert_eq!(
Message::TxAddOutput(sample_tx_add_output()).msg_type(),
msg_type::TX_ADD_OUTPUT
);
assert_eq!(
Message::TxRemoveInput(TxRemoveInput {
channel_id: ChannelId::new([0; CHANNEL_ID_SIZE]),
Expand Down
173 changes: 173 additions & 0 deletions smite/src/bolt/tx_add_output.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
//! BOLT 2 `tx_add_output` message.

use super::BoltError;
use super::types::ChannelId;
use super::wire::WireFormat;

/// BOLT 2 `tx_add_output` message (type 67).
///
/// Sent during interactive transaction construction to propose adding an
/// output to the shared transaction.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TxAddOutput {
/// The channel this message pertains to
pub channel_id: ChannelId,
/// Serial ID for this output, must be even if sent by the initiator,
/// odd if sent by the non-initiator (BOLT 2 parity rule)
pub serial_id: u64,
/// The value of this output in satoshis
pub sats: u64,
/// The scriptPubKey for this output
pub script: Vec<u8>,
}

impl TxAddOutput {
/// Encodes to wire format (without message type prefix).
#[must_use]
pub fn encode(&self) -> Vec<u8> {
Comment thread
morehouse marked this conversation as resolved.
let mut out = Vec::new();
self.channel_id.write(&mut out);
self.serial_id.write(&mut out);
self.sats.write(&mut out);
self.script.write(&mut out);
out
}

/// Decodes from wire format (without message type prefix).
///
/// # Errors
///
/// Returns `Truncated` if the payload is too short.
pub fn decode(payload: &[u8]) -> Result<Self, BoltError> {
let mut cursor = payload;
let channel_id = WireFormat::read(&mut cursor)?;
let serial_id = WireFormat::read(&mut cursor)?;
let sats = WireFormat::read(&mut cursor)?;
let script: Vec<u8> = WireFormat::read(&mut cursor)?;
Ok(Self {
channel_id,
serial_id,
sats,
script,
})
}
}

#[cfg(test)]
mod tests {
use super::super::CHANNEL_ID_SIZE;
use super::*;

fn sample_msg() -> TxAddOutput {
TxAddOutput {
channel_id: ChannelId::new([0xab; CHANNEL_ID_SIZE]),
serial_id: 42,
sats: 100_000,
script: vec![0x76, 0xa9, 0x14, 0xab, 0xcd],
}
}

#[test]
fn roundtrip() {
let original = sample_msg();
let encoded = original.encode();
let decoded = TxAddOutput::decode(&encoded).unwrap();
assert_eq!(original, decoded);
}

#[test]
fn decode_ignores_trailing_bytes() {
let original = sample_msg();
let mut encoded = original.encode();
encoded.extend_from_slice(&[0xaa, 0xbb, 0xcc]);
let decoded = TxAddOutput::decode(&encoded).unwrap();
assert_eq!(decoded, original);
}

#[test]
fn decode_truncated_channel_id() {
assert_eq!(
TxAddOutput::decode(&[0x00; 20]),
Err(BoltError::Truncated {
expected: CHANNEL_ID_SIZE,
actual: 20
})
);
}

#[test]
fn decode_truncated_serial_id() {
// channel_id (32 bytes) + 4 bytes of serial_id
assert_eq!(
TxAddOutput::decode(&[0x00; CHANNEL_ID_SIZE + 4]),
Err(BoltError::Truncated {
expected: 8,
actual: 4
})
);
}

#[test]
fn decode_truncated_sats() {
// channel_id (32) + serial_id (8) + 4 bytes of sats
assert_eq!(
TxAddOutput::decode(&[0x00; CHANNEL_ID_SIZE + 8 + 4]),
Err(BoltError::Truncated {
expected: 8,
actual: 4
})
);
}

Comment thread
morehouse marked this conversation as resolved.
#[test]
fn decode_truncated_script_len() {
Comment thread
morehouse marked this conversation as resolved.
// channel_id (32) + serial_id (8) + sats (8) + only 1 byte of the 2-byte script length
let mut payload = vec![0x00u8; CHANNEL_ID_SIZE + 8 + 8];
payload.push(0x00); // only 1 byte of the 2-byte script length field
assert_eq!(
TxAddOutput::decode(&payload),
Err(BoltError::Truncated {
expected: 2,
actual: 1
})
);
}

#[test]
fn decode_truncated_script_data() {
// channel_id (32) + serial_id (8) + sats (8) + script_len=10 (2 bytes) + only 3 bytes of data
let mut payload = vec![0x00u8; CHANNEL_ID_SIZE + 8 + 8];
payload.push(0x00); // script_len high byte
payload.push(0x0a); // script_len low byte = 10
payload.extend_from_slice(&[0xde, 0xad, 0xbe]); // only 3 bytes instead of 10
assert_eq!(
TxAddOutput::decode(&payload),
Err(BoltError::Truncated {
expected: 10,
actual: 3
})
);
}

#[test]
fn decode_empty() {
assert_eq!(
TxAddOutput::decode(&[]),
Err(BoltError::Truncated {
expected: CHANNEL_ID_SIZE,
actual: 0
})
);
}

#[test]
fn roundtrip_empty_script() {
let msg = TxAddOutput {
script: vec![],
..sample_msg()
};
let encoded = msg.encode();
let decoded = TxAddOutput::decode(&encoded).unwrap();
assert_eq!(decoded, msg);
}
}
Loading