Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/rust_vm/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ fn main() -> Result<()> {
.args(["Hello from libkrun VM!"])
.env("HOME", "/root")
})
.on_exit(|| {
eprintln!("[on_exit] VM is shutting down — cleanup complete");
.on_exit(|exit_code| {
eprintln!("[on_exit] VM exiting with code {exit_code}");
})
.build()?
.enter()?;
Expand Down
2 changes: 1 addition & 1 deletion src/devices/src/virtio/console/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ impl VirtioDevice for Console {
}

impl VmmExitObserver for Console {
fn on_vmm_exit(&mut self) {
fn on_vmm_exit(&mut self, _exit_code: i32) {
self.reset();
log::trace!("Console on_vmm_exit finished");
}
Expand Down
112 changes: 112 additions & 0 deletions src/devices/src/virtio/console/port_io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,118 @@ impl PortOutput for PortOutputFd {
}
}

//--------------------------------------------------------------------------------------------------
// Types: Custom Console Port Backend Adapters
//--------------------------------------------------------------------------------------------------

use std::sync::Arc;

/// Trait for custom virtio-console port backends.
///
/// Implementors provide bidirectional byte I/O between the host and guest
/// without going through file descriptors on the data path. The console
/// device's RX thread calls [`read`](Self::read) and the TX thread calls
/// [`write`](Self::write) — both via `&self`, so the implementation must
/// use interior mutability (e.g. lock-free ring buffers).
///
/// One backend instance is shared behind an `Arc` by the input and output
/// adapters, which is why the trait requires `Send + Sync`.
///
/// [`read_wake_fd`](Self::read_wake_fd) returns a file descriptor that
/// becomes readable when [`read`](Self::read) would return data. The
/// console RX thread uses this fd in a `poll()` call to block efficiently.
pub trait ConsolePortBackend: Send + Sync {
/// Read bytes destined for the guest (host → guest direction).
///
/// Copies up to `buf.len()` bytes into `buf`. Returns the number of
/// bytes copied into `buf` (never more than `buf.len()`), or a
/// `WouldBlock` error if no data is available.
fn read(&self, buf: &mut [u8]) -> io::Result<usize>;

/// Write bytes from the guest (guest → host direction).
///
/// Accepts up to `buf.len()` bytes. Returns the number of bytes
/// consumed from the front of `buf`, or a `WouldBlock` error if the
/// backend cannot accept data.
fn write(&self, buf: &[u8]) -> io::Result<usize>;

/// File descriptor that becomes readable when [`read`](Self::read) has data.
///
/// Used by the console RX thread for `poll()`-based blocking. Typically
/// the read end of a wake pipe.
///
/// The input adapter only borrows the returned descriptor for the
/// duration of each `poll()`; ownership stays with the backend, so the
/// descriptor must remain open for as long as the backend is in use.
fn read_wake_fd(&self) -> RawFd;
}

/// [`PortInput`] implementation backed by a shared [`ConsolePortBackend`].
///
/// The console RX thread drives this adapter to move host → guest bytes
/// from the backend into guest memory exposed as a `VolatileSlice`.
pub struct ConsolePortBackendInputAdapter {
    backend: Arc<dyn ConsolePortBackend>,
}

impl ConsolePortBackendInputAdapter {
    /// Builds the input-side adapter around `backend`.
    pub fn new(backend: Arc<dyn ConsolePortBackend>) -> Self {
        ConsolePortBackendInputAdapter { backend }
    }
}

impl PortInput for ConsolePortBackendInputAdapter {
fn read_volatile(&mut self, buf: &mut VolatileSlice) -> io::Result<usize> {
let guard = buf.ptr_guard_mut();

// SAFETY: The VolatileSlice invariant guarantees the memory region is
// valid for writes of `buf.len()` bytes for the lifetime of the guard.
let dst = unsafe { std::slice::from_raw_parts_mut(guard.as_ptr(), buf.len()) };

let n = self.backend.read(dst)?;

if n > 0 {
buf.bitmap().mark_dirty(0, n);
}

Ok(n)
}

fn wait_until_readable(&self, stopfd: Option<&EventFd>) {
let wake_fd = self.backend.read_wake_fd();
let mut poll_fds = Vec::with_capacity(2);
let wake_bfd = unsafe { BorrowedFd::borrow_raw(wake_fd) };
poll_fds.push(PollFd::new(wake_bfd, PollFlags::POLLIN));
if let Some(stopfd) = stopfd {
let stop_bfd = unsafe { BorrowedFd::borrow_raw(stopfd.as_raw_fd()) };
poll_fds.push(PollFd::new(stop_bfd, PollFlags::POLLIN));
}
poll(&mut poll_fds, PollTimeout::NONE).expect("Failed to poll");
}
}

/// [`PortOutput`] implementation backed by a shared [`ConsolePortBackend`].
///
/// The console TX thread drives this adapter to hand guest → host bytes
/// to the backend.
pub struct ConsolePortBackendOutputAdapter {
    backend: Arc<dyn ConsolePortBackend>,
}

impl ConsolePortBackendOutputAdapter {
    /// Builds the output-side adapter around `backend`.
    pub fn new(backend: Arc<dyn ConsolePortBackend>) -> Self {
        ConsolePortBackendOutputAdapter { backend }
    }
}

impl PortOutput for ConsolePortBackendOutputAdapter {
    /// Forwards the guest bytes in `buf` to the backend.
    ///
    /// Returns whatever [`ConsolePortBackend::write`] returns: the number of
    /// bytes it consumed, or its error (for example `WouldBlock`).
    fn write_volatile(&mut self, buf: &VolatileSlice) -> io::Result<usize> {
        let len = buf.len();
        let guard = buf.ptr_guard();
        let base = guard.as_ptr();

        // SAFETY: The VolatileSlice invariant guarantees the memory region is
        // valid for reads of `buf.len()` bytes for the lifetime of the guard,
        // and `guard` is only dropped after the backend call has returned.
        let bytes: &[u8] = unsafe { std::slice::from_raw_parts(base, len) };

        self.backend.write(bytes)
    }

    /// Intentionally does nothing.
    fn wait_until_writable(&self) {
        // The backend is expected to accept data without blocking (e.g. a
        // lock-free ring buffer push). When it is full, write() reports
        // WouldBlock and backpressure is left to the console TX thread's
        // existing retry loop.
        // NOTE(review): with no wait here, that retry loop presumably spins
        // until the backend drains — confirm this is acceptable.
    }
}

fn dup_raw_fd_into_owned(raw_fd: RawFd) -> Result<OwnedFd, nix::Error> {
// SAFETY: if raw_fd is invalid the `dup` call below will fail
let borrowed_fd = unsafe { BorrowedFd::borrow_raw(raw_fd) };
Expand Down
13 changes: 8 additions & 5 deletions src/devices/src/virtio/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,13 +167,16 @@ pub trait VirtioDevice: AsAny + Send {
}

pub trait VmmExitObserver: Send {
/// Callback to finish processing or cleanup the device resources
fn on_vmm_exit(&mut self) {}
/// Callback to finish processing or cleanup the device resources.
///
/// `exit_code` is the final exit code chosen by the VMM (from guest
/// vCPU or the shared `exit_code` Arc).
fn on_vmm_exit(&mut self, _exit_code: i32) {}
}

impl<F: Fn() + Send> VmmExitObserver for F {
fn on_vmm_exit(&mut self) {
self()
impl<F: Fn(i32) + Send> VmmExitObserver for F {
fn on_vmm_exit(&mut self, exit_code: i32) {
self(exit_code)
}
}

Expand Down
19 changes: 15 additions & 4 deletions src/krun/src/api/builder.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
//! VM Builder for creating and configuring microVMs using nested builders.

use std::sync::atomic::AtomicI32;
#[cfg(not(any(feature = "tee", feature = "aws-nitro")))]
use std::sync::Arc;
#[cfg(any(feature = "tee", feature = "aws-nitro"))]
use std::sync::Arc;

use utils::eventfd::{EventFd, EFD_NONBLOCK};
use vmm::resources::{VirtioConsoleConfigMode, VmResources};
use vmm::vmm_config::machine_config::VmConfig;
use vmm::vmm_config::machine_config::VmConfigError;
Expand All @@ -20,7 +24,7 @@ use super::builders::{ConsoleBuilder, ExecBuilder, FsBuilder, KernelBuilder, Mac
#[cfg(feature = "net")]
use super::builders::{NetBuilder, NetConfig};

use super::error::{ConfigError, Error, Result};
use super::error::{BuildError, ConfigError, Error, Result};
use super::vm::Vm;

#[cfg(feature = "blk")]
Expand Down Expand Up @@ -68,7 +72,7 @@ pub struct VmBuilder {
net: NetBuilder,
#[cfg(feature = "blk")]
disk: DiskBuilder,
exit_observers: Vec<Box<dyn Fn() + Send + 'static>>,
exit_observers: Vec<Box<dyn Fn(i32) + Send + 'static>>,
}

//--------------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -258,11 +262,12 @@ impl VmBuilder {
/// ```rust,no_run
/// # use msb_krun::VmBuilder;
/// VmBuilder::new()
/// .on_exit(|| {
/// .on_exit(|exit_code| {
/// // flush logs, write final status, etc.
/// eprintln!("VM exited with code {exit_code}");
/// });
/// ```
pub fn on_exit(mut self, f: impl Fn() + Send + 'static) -> Self {
pub fn on_exit(mut self, f: impl Fn(i32) + Send + 'static) -> Self {
self.exit_observers.push(Box::new(f));
self
}
Expand Down Expand Up @@ -488,6 +493,10 @@ impl VmBuilder {
)
};

let exit_evt = EventFd::new(EFD_NONBLOCK)
.map_err(|e| Error::Build(BuildError::Start(format!("exit EventFd: {e:?}"))))?;
let exit_code = Arc::new(AtomicI32::new(i32::MAX));

Ok(Vm::new(
vmr,
self.kernel.cmdline,
Expand All @@ -499,6 +508,8 @@ impl VmBuilder {
self.kernel.krunfw_path,
self.kernel.init_path,
self.exit_observers,
exit_evt,
exit_code,
))
}
}
Expand Down
31 changes: 30 additions & 1 deletion src/krun/src/api/builders.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

use std::os::fd::RawFd;
use std::path::{Path, PathBuf};
use std::sync::Arc;

use devices::virtio::console::port_io::{
ConsolePortBackend, ConsolePortBackendInputAdapter, ConsolePortBackendOutputAdapter,
};
use vmm::resources::PortConfig;

#[cfg(not(any(feature = "tee", feature = "aws-nitro")))]
Expand Down Expand Up @@ -185,7 +189,7 @@ pub enum NetConfig {
/// .gpu_shm_size(1 << 28)
/// });
/// ```
#[derive(Debug, Clone, Default)]
#[derive(Default)]
pub struct ConsoleBuilder {
pub(crate) output: Option<PathBuf>,
pub(crate) ports: Vec<PortConfig>,
Expand Down Expand Up @@ -593,6 +597,31 @@ impl ConsoleBuilder {
self.disable_implicit = true;
self
}

/// Register a console port served by a custom backend.
///
/// The backend carries bytes in both directions without any file
/// descriptor on the data path, matching the pattern of
/// [`NetBuilder::custom()`](NetBuilder::custom) and
/// [`FsBuilder::custom()`](FsBuilder::custom). It is shared between an
/// input adapter and an output adapter, which are stored together as a
/// `PortConfig::Custom` entry named `name`.
///
/// # Example
///
/// ```rust,ignore
/// VmBuilder::new()
///     .console(|c| c.custom("agent", Box::new(my_backend)))
/// ```
pub fn custom(mut self, name: &str, backend: Box<dyn ConsolePortBackend>) -> Self {
    // Both adapters need the same backend, so move it behind an `Arc`.
    let shared: Arc<dyn ConsolePortBackend> = backend.into();

    let port = PortConfig::Custom {
        name: name.to_owned(),
        input: Box::new(ConsolePortBackendInputAdapter::new(Arc::clone(&shared))),
        output: Box::new(ConsolePortBackendOutputAdapter::new(shared)),
    };
    self.ports.push(port);

    self
}
}

//--------------------------------------------------------------------------------------------------
Expand Down
86 changes: 86 additions & 0 deletions src/krun/src/api/exit_handle.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
//! Handle for triggering VM exit from any thread.

use std::os::fd::{AsRawFd, FromRawFd, OwnedFd};
use std::{io, result};

use utils::eventfd::EventFd;

//--------------------------------------------------------------------------------------------------
// Types
//--------------------------------------------------------------------------------------------------

/// A thread-safe, cloneable handle that triggers VM exit when fired.
///
/// Obtained via [`Vm::exit_handle()`](super::vm::Vm::exit_handle) before
/// calling [`Vm::enter()`](super::vm::Vm::enter). Background tasks (idle
/// timeout, max-duration timer, relay drain) use this to shut down the VMM
/// cleanly — the exit event fd fires, exit observers run, and the process
/// terminates.
///
/// Multiple triggers are idempotent: the VMM reads the event once and calls
/// `_exit()`.
///
/// The handle owns its own duplicate of the exit event's write descriptor,
/// which is closed when the handle is dropped.
#[derive(Debug)]
pub struct ExitHandle {
    /// Owned duplicate of the descriptor the exit event is written to.
    write_fd: OwnedFd,
}

//--------------------------------------------------------------------------------------------------
// Methods
//--------------------------------------------------------------------------------------------------

impl ExitHandle {
/// Create an `ExitHandle` from an [`EventFd`].
///
/// Dups the write end so the handle is independent of the original fd
/// lifetime.
pub(crate) fn from_event_fd(evt: &EventFd) -> result::Result<Self, io::Error> {
// On Linux, EventFd is a single fd (read/write on the same fd).
// On macOS, EventFd is a pipe pair — we need the write end.
#[cfg(target_os = "linux")]
let raw_fd = evt.as_raw_fd();
#[cfg(target_os = "macos")]
let raw_fd = evt.get_write_fd();

let fd = unsafe { libc::dup(raw_fd) };
if fd < 0 {
return Err(io::Error::last_os_error());
}
let write_fd = unsafe { OwnedFd::from_raw_fd(fd) };
Ok(Self { write_fd })
}

/// Trigger VM exit.
///
/// Writes to the exit event fd, causing the VMM event loop to invoke
/// exit observers and call `_exit()`. Safe to call from any thread.
/// Async-signal-safe.
pub fn trigger(&self) {
let val: u64 = 1;
// SAFETY: write_fd is a valid, owned file descriptor. Writing 8 bytes
// (a u64) matches the EventFd/pipe protocol used by the VMM.
let _ = unsafe {
libc::write(
self.write_fd.as_raw_fd(),
&val as *const u64 as *const libc::c_void,
std::mem::size_of::<u64>(),
)
};
}
}

//--------------------------------------------------------------------------------------------------
// Trait Implementations
//--------------------------------------------------------------------------------------------------

impl Clone for ExitHandle {
fn clone(&self) -> Self {
let fd = unsafe { libc::dup(self.write_fd.as_raw_fd()) };
assert!(fd >= 0, "Failed to dup ExitHandle fd");
let write_fd = unsafe { OwnedFd::from_raw_fd(fd) };
Self { write_fd }
}
}

// SAFETY: OwnedFd is Send. The write operation is atomic for 8-byte writes
// on both eventfd (Linux) and pipes (macOS, when ≤ PIPE_BUF).
unsafe impl Send for ExitHandle {}
unsafe impl Sync for ExitHandle {}
2 changes: 2 additions & 0 deletions src/krun/src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
pub mod builder;
pub mod builders;
pub mod error;
pub mod exit_handle;
pub mod vm;

//--------------------------------------------------------------------------------------------------
Expand All @@ -45,4 +46,5 @@ pub use builders::DiskImageFormat;
pub use builders::NetBuilder;
pub use builders::{ConsoleBuilder, ExecBuilder, FsBuilder, KernelBuilder, MachineBuilder};
pub use error::{BuildError, ConfigError, Error, Result, RuntimeError};
pub use exit_handle::ExitHandle;
pub use vm::Vm;
Loading
Loading