Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 27 additions & 18 deletions msg-sim/src/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use tokio::sync::oneshot;

use crate::dynch::{DynCh, DynRequestSender};
use crate::namespace::helpers::current_netns;
use crate::network::RuntimeFactory;

/// Base directory for named network namespaces.
///
Expand Down Expand Up @@ -230,6 +231,7 @@ impl NetworkNamespaceInner {
/// `setns(2)`, which is thread-local.
pub fn spawn<Ctx: 'static>(
self,
runtime_factory: RuntimeFactory,
make_ctx: impl FnOnce() -> Ctx + Send + 'static,
) -> (std::thread::JoinHandle<Result<()>>, DynRequestSender<Ctx>) {
let (tx, mut rx) = DynCh::<Ctx>::channel(8);
Expand All @@ -245,7 +247,7 @@ impl NetworkNamespaceInner {
// Create mount namespace and remount /proc for namespace-specific sysctl access
helpers::setup_mount_namespace()?;

let rt = tokio::runtime::Builder::new_current_thread().enable_all().build()?;
let rt = runtime_factory();

tracing::debug!("started runtime");
drop(_span);
Expand Down Expand Up @@ -286,6 +288,7 @@ pub struct NetworkNamespace<Ctx = ()> {
impl NetworkNamespace {
pub async fn new<Ctx: 'static>(
name: impl Into<String>,
runtime_factory: RuntimeFactory,
make_ctx: impl FnOnce() -> Ctx + Send + 'static,
) -> Result<NetworkNamespace<Ctx>> {
let name = name.into();
Expand All @@ -296,7 +299,7 @@ impl NetworkNamespace {
let file = tokio::fs::File::open(path).await?.into_std().await;

let inner = NetworkNamespaceInner { name, file };
let (_receiver_task, task_sender) = inner.try_clone()?.spawn(make_ctx);
let (_receiver_task, task_sender) = inner.try_clone()?.spawn(runtime_factory, make_ctx);

Ok(NetworkNamespace::<Ctx> { inner, task_sender, _receiver_task })
}
Expand Down Expand Up @@ -349,11 +352,19 @@ mod tests {

const TCP_SLOW_START_AFTER_IDLE: &str = "/proc/sys/net/ipv4/tcp_slow_start_after_idle";

/// Builds the multi-threaded Tokio runtime handed to each test namespace.
///
/// Panics if the runtime cannot be created.
fn default_runtime() -> tokio::runtime::Runtime {
    let mut builder = tokio::runtime::Builder::new_multi_thread();
    builder.enable_all();
    builder.build().expect("to create runtime")
}

#[tokio::test(flavor = "multi_thread")]
async fn mount_namespace_isolates_proc() {
// Create two namespaces
let ns1 = NetworkNamespace::new("test-ns-mount-1", || ()).await.unwrap();
let ns2 = NetworkNamespace::new("test-ns-mount-2", || ()).await.unwrap();
let ns1 = NetworkNamespace::new("test-ns-mount-1", Box::new(default_runtime), || ())
.await
.unwrap();
let ns2 = NetworkNamespace::new("test-ns-mount-2", Box::new(default_runtime), || ())
.await
.unwrap();

// Verify /proc is mounted in ns1 by checking /proc/self/ns/net exists
let proc_mounted_ns1: bool = ns1
Expand Down Expand Up @@ -385,8 +396,12 @@ mod tests {
#[tokio::test(flavor = "multi_thread")]
async fn sysctl_values_are_namespace_specific() {
// Create two namespaces
let ns1 = NetworkNamespace::new("test-ns-sysctl-1", || ()).await.unwrap();
let ns2 = NetworkNamespace::new("test-ns-sysctl-2", || ()).await.unwrap();
let ns1 = NetworkNamespace::new("test-ns-sysctl-1", Box::new(default_runtime), || ())
.await
.unwrap();
let ns2 = NetworkNamespace::new("test-ns-sysctl-2", Box::new(default_runtime), || ())
.await
.unwrap();

// Set different values in each namespace
let write_result_ns1: std::io::Result<()> = ns1
Expand Down Expand Up @@ -446,24 +461,21 @@ mod tests {

assert_eq!(value_ns1, "0", "ns1 should have tcp_slow_start_after_idle=0");
assert_eq!(value_ns2, "1", "ns2 should have tcp_slow_start_after_idle=1");
assert_ne!(
value_ns1, value_ns2,
"sysctls should be isolated between namespaces"
);
assert_ne!(value_ns1, value_ns2, "sysctls should be isolated between namespaces");
}

#[tokio::test(flavor = "multi_thread")]
async fn namespace_has_isolated_network_identity() {
// Create a namespace
let ns = NetworkNamespace::new("test-ns-identity", || ()).await.unwrap();
let ns = NetworkNamespace::new("test-ns-identity", Box::new(default_runtime), || ())
.await
.unwrap();

// Get the network namespace inode from inside the namespace
let ns_inode_inside: u64 = ns
.task_sender
.submit(|_: &mut ()| -> DynFuture<'_, u64> {
Box::pin(async {
helpers::current_netns().map(|id| id.inode).unwrap_or(0)
})
Box::pin(async { helpers::current_netns().map(|id| id.inode).unwrap_or(0) })
})
.await
.unwrap()
Expand All @@ -476,9 +488,6 @@ mod tests {

assert_ne!(ns_inode_inside, 0, "should get valid inode inside namespace");
assert_ne!(host_inode, 0, "should get valid host inode");
assert_ne!(
ns_inode_inside, host_inode,
"namespace inode should differ from host"
);
assert_ne!(ns_inode_inside, host_inode, "namespace inode should differ from host");
}
}
105 changes: 81 additions & 24 deletions msg-sim/src/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
//! classes. See the [`crate::tc`] module for details on the qdisc hierarchy.

use std::{
any::Any,
collections::{HashMap, HashSet},
fmt::{Debug, Display},
io,
Expand Down Expand Up @@ -137,8 +136,8 @@ pub struct Link(pub PeerId, pub PeerId);
impl Link {
/// Create a new directed link from source to destination.
#[inline]
pub fn new(source: PeerId, destination: PeerId) -> Self {
Link(source, destination)
pub fn new(source: impl Into<PeerId>, destination: impl Into<PeerId>) -> Self {
Link(source.into(), destination.into())
}

/// Get the source peer (traffic originates here).
Expand Down Expand Up @@ -194,7 +193,7 @@ impl PeerTcState {
}

/// Map from peer ID to peer instance.
pub type PeerMap = HashMap<PeerId, Peer<Context>>;
pub type PeerMap = HashMap<PeerId, Peer<PeerContext>>;

/// Map from peer ID to traffic control state.
type TcStateMap = HashMap<PeerId, PeerTcState>;
Expand All @@ -217,21 +216,66 @@ impl Peer {
}
}

/// Boxed one-shot closure that builds the Tokio runtime for a namespace.
///
/// It is passed to `NetworkNamespace::new` and invoked once inside
/// `NetworkNamespaceInner::spawn` to create the runtime.
pub(crate) type RuntimeFactory = Box<dyn FnOnce() -> tokio::runtime::Runtime + Send>;

/// Returns a [`RuntimeFactory`] that builds a multi-threaded Tokio runtime
/// configured with `enable_all()`.
///
/// # Panics
///
/// The returned closure panics when invoked if the runtime cannot be built.
pub fn default_runtime_factory() -> RuntimeFactory {
    let build = || {
        let mut builder = tokio::runtime::Builder::new_multi_thread();
        builder.enable_all();
        builder.build().expect("to create runtime")
    };
    Box::new(build)
}

/// Minimal namespace context: rtnetlink access only.
///
/// This is the context type of the network's hub namespace. It gives tasks
/// access to rtnetlink for network configuration, without any per-peer metadata.
#[derive(Debug)]
pub struct CommonContext {
    /// Handle for sending rtnetlink messages within this namespace.
    handle: rtnetlink::Handle,
    /// Join handle of the spawned task that drives the rtnetlink connection.
    /// It is stored but never read, hence the leading underscore.
    _connection_task: tokio::task::JoinHandle<()>,
}

/// Context provided to tasks running within a peer's namespace.
///
/// This context gives access to rtnetlink for network configuration
/// and metadata about the peer's position in the network.
#[derive(Debug)]
pub struct Context {
pub struct PeerContext {
/// Handle for sending rtnetlink messages within this namespace.
handle: rtnetlink::Handle,
pub handle: rtnetlink::Handle,
/// Background task processing rtnetlink responses.
_connection_task: tokio::task::JoinHandle<()>,

/// The subnet this network uses.
subnet: Subnet,
pub subnet: Subnet,
/// This peer's ID.
peer_id: usize,
pub peer_id: PeerId,
}

/// Options for configuring a peer.
///
/// Build one with [`PeerOptions::default`] or [`PeerOptions::with_runtime`].
pub struct PeerOptions {
    /// Factory that builds the Tokio runtime for the peer's namespace; it is
    /// forwarded to `NetworkNamespace::new` when the peer is added.
    runtime_factory: RuntimeFactory,
}

impl Default for PeerOptions {
fn default() -> Self {
Self {
runtime_factory: Box::new(|| {
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.expect("to create runtime")
}),
}
}
}

impl PeerOptions {
/// Create new peer options with a custom runtime factory.
pub fn with_runtime(
runtime_factory: impl FnOnce() -> tokio::runtime::Runtime + Send + 'static,
) -> Self {
Self { runtime_factory: Box::new(runtime_factory) }
}
}

// -------------------------------------------------------------------------------------
Expand Down Expand Up @@ -289,7 +333,7 @@ pub type Result<T> = std::result::Result<T, Error>;
/// # Example
///
/// ```no_run
/// use msg_sim::network::{Network, Link};
/// use msg_sim::network::{Network, Link, PeerOptions};
/// use msg_sim::tc::impairment::LinkImpairment;
/// use msg_sim::ip::Subnet;
/// use std::net::Ipv4Addr;
Expand Down Expand Up @@ -344,7 +388,7 @@ pub struct Network {
subnet: Subnet,

/// The hub namespace containing the bridge device.
network_hub_namespace: NetworkNamespace<Context>,
network_hub_namespace: NetworkNamespace<CommonContext>,

/// Rtnetlink handle bound to the host namespace.
///
Expand Down Expand Up @@ -379,11 +423,13 @@ impl Network {
.map(|(connection, handle, _)| (handle, tokio::task::spawn(connection)))
.unwrap();

Context { handle, subnet, peer_id: 0, _connection_task }
CommonContext { handle, _connection_task }
};

// Create the hub namespace that will host the bridge.
let namespace_hub = NetworkNamespace::new(Self::hub_namespace_name(), make_ctx).await?;
let namespace_hub =
NetworkNamespace::new(Self::hub_namespace_name(), default_runtime_factory(), make_ctx)
.await?;
let fd = namespace_hub.fd();

let network = Self {
Expand Down Expand Up @@ -418,7 +464,7 @@ impl Network {
/// 1. A new network namespace for the peer
/// 2. A veth pair connecting the peer to the hub bridge
/// 3. IP address assignment based on the subnet and peer ID
pub async fn add_peer(&mut self) -> Result<PeerId> {
pub async fn add_peer_with_options(&mut self, options: PeerOptions) -> Result<PeerId> {
let peer_id = PEER_ID_NEXT.load(Ordering::Relaxed);
let namespace_name = peer_id.namespace_name();
let veth_name = Arc::new(peer_id.veth_name());
Expand All @@ -435,10 +481,12 @@ impl Network {
.map(|(connection, handle, _)| (handle, tokio::task::spawn(connection)))
.expect("to create rtnetlink socket");

Context { handle, peer_id, subnet, _connection_task }
PeerContext { handle, _connection_task, subnet, peer_id }
};

let network_namespace = NetworkNamespace::new(namespace_name.clone(), make_ctx).await?;
let network_namespace =
NetworkNamespace::new(namespace_name.clone(), options.runtime_factory, make_ctx)
.await?;

// Step 1: Create the veth pair in the host namespace.
// One end (veth_name) will go to the peer, the other (veth_br_name) to the bridge.
Expand Down Expand Up @@ -476,7 +524,7 @@ impl Network {

network_namespace
.task_sender
.submit(|ctx| {
.submit(|ctx: &mut PeerContext| {
Box::pin(async move {
let address = ctx.peer_id.veth_address(ctx.subnet);
let mask = ctx.subnet.netmask;
Expand Down Expand Up @@ -541,6 +589,11 @@ impl Network {
Ok(peer_id)
}

/// Adds a peer configured with [`PeerOptions::default`].
///
/// This is a thin wrapper; see [`Self::add_peer_with_options`] for what is
/// created and for the errors that can be returned.
pub async fn add_peer(&mut self) -> Result<PeerId> {
    let options = PeerOptions::default();
    self.add_peer_with_options(options).await
}

/// Run a task in a peer's network namespace.
///
/// The provided closure receives a mutable reference to the namespace's context,
Expand All @@ -560,7 +613,7 @@ impl Network {
///
/// ```no_run
/// use msg_sim::ip::Subnet;
/// use msg_sim::network::Network;
/// use msg_sim::network::{Network, PeerOptions};
/// use std::net::Ipv4Addr;
/// use tokio::net::TcpListener;
///
Expand Down Expand Up @@ -589,8 +642,8 @@ impl Network {
fut: F,
) -> Result<impl Future<Output = std::result::Result<T, oneshot::error::RecvError>>>
where
T: Any + Send + 'static,
F: for<'a> FnOnce(&'a mut Context) -> DynFuture<'a, T> + Send + 'static,
T: Send + 'static,
F: for<'a> FnOnce(&'a mut PeerContext) -> DynFuture<'a, T> + Send + 'static,
{
let Some(peer) = self.peers.get(&peer_id) else {
return Err(Error::PeerNotFound(peer_id));
Expand Down Expand Up @@ -630,7 +683,7 @@ impl Network {
/// ```no_run
/// use msg_sim::{
/// ip::Subnet,
/// network::{Link, Network},
/// network::{Link, Network, PeerOptions},
/// tc::impairment::LinkImpairment
/// };
///
Expand Down Expand Up @@ -681,7 +734,7 @@ impl Network {
src_peer
.namespace
.task_sender
.submit(move |ctx| {
.submit(move |ctx: &mut PeerContext| {
let span = tracing::debug_span!(
"apply_impairment",
link = %link,
Expand Down Expand Up @@ -996,13 +1049,17 @@ mod msg_sim_network {
req_socket_2.connect_sync(SocketAddr::new(address_2, port));
req_socket_3.connect_sync(SocketAddr::new(address_3, port));

// Measure RTT to peer 2 (should be ~200ms round trip for 100ms one-way)
// Wait for both TCP connections to be established before starting
// measurements. Peer 3 has 500ms one-way latency, so TCP handshake takes ~1s.
tokio::time::sleep(std::time::Duration::from_millis(1500)).await;

// Measure RTT to peer 2 (should be ~100ms for one-way latency)
let start = Instant::now();
let resp = req_socket_2.request("ping".into()).await.unwrap();
let rtt_2 = start.elapsed();
assert_eq!(resp.as_ref(), b"peer2");

// Measure RTT to peer 3 (should be ~1000ms round trip for 500ms one-way)
// Measure RTT to peer 3 (should be ~500ms for one-way latency)
let start = Instant::now();
let resp = req_socket_3.request("ping".into()).await.unwrap();
let rtt_3 = start.elapsed();
Expand Down
Loading