diff --git a/.buildkite/rust-vmm-ci-tests.json b/.buildkite/rust-vmm-ci-tests.json index 630fc618..8700098e 100644 --- a/.buildkite/rust-vmm-ci-tests.json +++ b/.buildkite/rust-vmm-ci-tests.json @@ -22,7 +22,7 @@ }, { "test_name": "unittests-gnu-all-with-xen", - "command": "cargo test --workspace --no-default-features --features test-utils,vhost-vsock,vhost-kern,vhost-vdpa,vhost-net,vhost-user,vhost-user-frontend,vhost-user-backend,xen", + "command": "cargo test --workspace --no-default-features --features test-utils,vhost-vsock,vhost-kern,vhost-vdpa,vhost-net,vhost-user,vhost-user-frontend,gpu-socket,vhost-user-backend,xen", "platform": [ "x86_64", "aarch64" @@ -30,7 +30,7 @@ }, { "test_name": "unittests-gnu-all-without-xen", - "command": "cargo test --workspace --no-default-features --features test-utils,vhost-vsock,vhost-kern,vhost-vdpa,vhost-net,vhost-user,vhost-user-frontend,vhost-user-backend,postcopy", + "command": "cargo test --workspace --no-default-features --features test-utils,vhost-vsock,vhost-kern,vhost-vdpa,vhost-net,vhost-user,vhost-user-frontend,gpu-socket,vhost-user-backend,postcopy", "platform": [ "x86_64", "aarch64" @@ -38,7 +38,7 @@ }, { "test_name": "unittests-musl-all-with-xen", - "command": "cargo test --workspace --target {target_platform}-unknown-linux-musl --no-default-features --features test-utils,vhost-vsock,vhost-kern,vhost-vdpa,vhost-net,vhost-user,vhost-user-frontend,vhost-user-backend,xen", + "command": "cargo test --workspace --target {target_platform}-unknown-linux-musl --no-default-features --features test-utils,vhost-vsock,vhost-kern,vhost-vdpa,vhost-net,vhost-user,vhost-user-frontend,gpu-socket,vhost-user-backend,xen", "platform": [ "x86_64", "aarch64" @@ -46,7 +46,7 @@ }, { "test_name": "unittests-musl-all-without-xen", - "command": "cargo test --workspace --target {target_platform}-unknown-linux-musl --no-default-features --features 
test-utils,vhost-vsock,vhost-kern,vhost-vdpa,vhost-net,vhost-user,vhost-user-frontend,vhost-user-backend,postcopy", + "command": "cargo test --workspace --target {target_platform}-unknown-linux-musl --no-default-features --features test-utils,vhost-vsock,vhost-kern,vhost-vdpa,vhost-net,vhost-user,vhost-user-frontend,gpu-socket,vhost-user-backend,postcopy", "platform": [ "x86_64", "aarch64" @@ -54,7 +54,7 @@ }, { "test_name": "clippy-all-with-xen", - "command": "cargo clippy --workspace --bins --examples --benches --all-targets --no-default-features --features test-utils,vhost-vsock,vhost-kern,vhost-vdpa,vhost-net,vhost-user,vhost-user-frontend,vhost-user-backend,xen -- -D warnings -D clippy::undocumented_unsafe_blocks", + "command": "cargo clippy --workspace --bins --examples --benches --all-targets --no-default-features --features test-utils,vhost-vsock,vhost-kern,vhost-vdpa,vhost-net,vhost-user,vhost-user-frontend,gpu-socket,vhost-user-backend,xen -- -D warnings -D clippy::undocumented_unsafe_blocks", "platform": [ "x86_64", "aarch64" @@ -62,7 +62,7 @@ }, { "test_name": "clippy-all-without-xen", - "command": "cargo clippy --workspace --bins --examples --benches --all-targets --no-default-features --features test-utils,vhost-vsock,vhost-kern,vhost-vdpa,vhost-net,vhost-user,vhost-user-frontend,vhost-user-backend,postcopy -- -D warnings -D clippy::undocumented_unsafe_blocks", + "command": "cargo clippy --workspace --bins --examples --benches --all-targets --no-default-features --features test-utils,vhost-vsock,vhost-kern,vhost-vdpa,vhost-net,vhost-user,vhost-user-frontend,gpu-socket,vhost-user-backend,postcopy -- -D warnings -D clippy::undocumented_unsafe_blocks", "platform": [ "x86_64", "aarch64" @@ -70,7 +70,7 @@ }, { "test_name": "check-warnings-all-with-xen", - "command": "RUSTFLAGS=\"-D warnings\" cargo check --all-targets --workspace --no-default-features --features 
test-utils,vhost-vsock,vhost-kern,vhost-vdpa,vhost-net,vhost-user,vhost-user-frontend,vhost-user-backend,xen", + "command": "RUSTFLAGS=\"-D warnings\" cargo check --all-targets --workspace --no-default-features --features test-utils,vhost-vsock,vhost-kern,vhost-vdpa,vhost-net,vhost-user,vhost-user-frontend,gpu-socket,vhost-user-backend,xen", "platform": [ "x86_64", "aarch64" @@ -78,7 +78,7 @@ }, { "test_name": "check-warnings-all-without-xen", - "command": "RUSTFLAGS=\"-D warnings\" cargo check --all-targets --workspace --no-default-features --features test-utils,vhost-vsock,vhost-kern,vhost-vdpa,vhost-net,vhost-user,vhost-user-frontend,vhost-user-backend,postcopy", + "command": "RUSTFLAGS=\"-D warnings\" cargo check --all-targets --workspace --no-default-features --features test-utils,vhost-vsock,vhost-kern,vhost-vdpa,vhost-net,vhost-user,vhost-user-frontend,gpu-socket,vhost-user-backend,postcopy", "platform": [ "x86_64", "aarch64" diff --git a/vhost-user-backend/CHANGELOG.md b/vhost-user-backend/CHANGELOG.md index 08ac715c..14bc3615 100644 --- a/vhost-user-backend/CHANGELOG.md +++ b/vhost-user-backend/CHANGELOG.md @@ -2,7 +2,7 @@ ## [Unreleased] ### Added - +- [[#239]](https://github.com/rust-vmm/vhost/pull/239) Add support for `VHOST_USER_GPU_SET_SOCKET` ### Changed ### Fixed diff --git a/vhost-user-backend/Cargo.toml b/vhost-user-backend/Cargo.toml index df232538..7093466a 100644 --- a/vhost-user-backend/Cargo.toml +++ b/vhost-user-backend/Cargo.toml @@ -11,6 +11,7 @@ license = "Apache-2.0" [features] xen = ["vm-memory/xen", "vhost/xen"] postcopy = ["vhost/postcopy", "userfaultfd"] +gpu-socket = ["vhost/gpu-socket"] [dependencies] libc = "0.2.39" diff --git a/vhost-user-backend/src/backend.rs b/vhost-user-backend/src/backend.rs index e78dac51..cfe05a2d 100644 --- a/vhost-user-backend/src/backend.rs +++ b/vhost-user-backend/src/backend.rs @@ -31,6 +31,9 @@ use vm_memory::bitmap::Bitmap; use vmm_sys_util::epoll::EventSet; use vmm_sys_util::eventfd::EventFd; 
+#[cfg(feature = "gpu-socket")] +use vhost::vhost_user::GpuBackend; + use super::vring::VringT; use super::GM; @@ -84,6 +87,9 @@ pub trait VhostUserBackend: Send + Sync { /// function. fn set_backend_req_fd(&self, _backend: Backend) {} + #[cfg(feature = "gpu-socket")] + fn set_gpu_socket(&self, _gpu_backend: GpuBackend) {} + /// Get the map to map queue index to worker thread index. /// /// A return value of [2, 2, 4] means: the first two queues will be handled by worker thread 0, @@ -194,6 +200,9 @@ pub trait VhostUserBackendMut: Send + Sync { /// function. fn set_backend_req_fd(&mut self, _backend: Backend) {} + #[cfg(feature = "gpu-socket")] + fn set_gpu_socket(&mut self, _gpu_backend: GpuBackend) {} + /// Get the map to map queue index to worker thread index. /// /// A return value of [2, 2, 4] means: the first two queues will be handled by worker thread 0, @@ -299,6 +308,11 @@ impl VhostUserBackend for Arc { self.deref().set_backend_req_fd(backend) } + #[cfg(feature = "gpu-socket")] + fn set_gpu_socket(&self, gpu_backend: GpuBackend) { + self.deref().set_gpu_socket(gpu_backend) + } + fn queues_per_thread(&self) -> Vec { self.deref().queues_per_thread() } @@ -376,6 +390,11 @@ impl VhostUserBackend for Mutex { self.lock().unwrap().set_backend_req_fd(backend) } + #[cfg(feature = "gpu-socket")] + fn set_gpu_socket(&self, gpu_backend: GpuBackend) { + self.lock().unwrap().set_gpu_socket(gpu_backend) + } + fn queues_per_thread(&self) -> Vec { self.lock().unwrap().queues_per_thread() } @@ -456,6 +475,11 @@ impl VhostUserBackend for RwLock { self.write().unwrap().set_backend_req_fd(backend) } + #[cfg(feature = "gpu-socket")] + fn set_gpu_socket(&self, gpu_backend: GpuBackend) { + self.write().unwrap().set_gpu_socket(gpu_backend) + } + fn queues_per_thread(&self) -> Vec { self.read().unwrap().queues_per_thread() } @@ -576,6 +600,9 @@ pub mod tests { fn set_backend_req_fd(&mut self, _backend: Backend) {} + #[cfg(feature = "gpu-socket")] + fn set_gpu_socket(&mut self, 
_gpu_backend: GpuBackend) {} + fn queues_per_thread(&self) -> Vec { vec![1, 1] } diff --git a/vhost-user-backend/src/handler.rs b/vhost-user-backend/src/handler.rs index 62c4a66d..7e471e95 100644 --- a/vhost-user-backend/src/handler.rs +++ b/vhost-user-backend/src/handler.rs @@ -21,9 +21,12 @@ use vhost::vhost_user::message::{ VhostUserMemoryRegion, VhostUserProtocolFeatures, VhostUserSingleMemoryRegion, VhostUserVirtioFeatures, VhostUserVringAddrFlags, VhostUserVringState, }; +#[cfg(feature = "gpu-socket")] +use vhost::vhost_user::GpuBackend; use vhost::vhost_user::{ Backend, Error as VhostUserError, Result as VhostUserResult, VhostUserBackendReqHandlerMut, }; + use virtio_bindings::bindings::virtio_ring::VIRTIO_RING_F_EVENT_IDX; use virtio_queue::{Error as VirtQueError, QueueT}; use vm_memory::mmap::NewBitmap; @@ -549,6 +552,11 @@ where self.backend.set_backend_req_fd(backend); } + #[cfg(feature = "gpu-socket")] + fn set_gpu_socket(&mut self, gpu_backend: GpuBackend) { + self.backend.set_gpu_socket(gpu_backend); + } + fn get_inflight_fd( &mut self, _inflight: &vhost::vhost_user::message::VhostUserInflight, diff --git a/vhost/CHANGELOG.md b/vhost/CHANGELOG.md index 7ce77bef..a5103311 100644 --- a/vhost/CHANGELOG.md +++ b/vhost/CHANGELOG.md @@ -2,7 +2,7 @@ ## [Unreleased] ### Added - +- [[#239]](https://github.com/rust-vmm/vhost/pull/239) Add support for `VHOST_USER_GPU_SET_SOCKET` ### Changed ### Fixed diff --git a/vhost/Cargo.toml b/vhost/Cargo.toml index 44388bc0..f7fcef36 100644 --- a/vhost/Cargo.toml +++ b/vhost/Cargo.toml @@ -23,6 +23,7 @@ vhost-net = ["vhost-kern"] vhost-user = [] vhost-user-frontend = ["vhost-user"] vhost-user-backend = ["vhost-user"] +gpu-socket = ["vhost-user", "vhost-user-backend"] xen = ["vm-memory/xen"] postcopy = [] diff --git a/vhost/src/vhost_user/backend.rs b/vhost/src/vhost_user/backend.rs index 48e44571..8463e1ad 100644 --- a/vhost/src/vhost_user/backend.rs +++ b/vhost/src/vhost_user/backend.rs @@ -32,7 +32,7 @@ impl 
BackendListener { pub fn accept(&mut self) -> Result>> { if let Some(fd) = self.listener.accept()? { return Ok(Some(BackendReqHandler::new( - Endpoint::::from_stream(fd), + Endpoint::>::from_stream(fd), self.backend.take().unwrap(), ))); } diff --git a/vhost/src/vhost_user/backend_req.rs b/vhost/src/vhost_user/backend_req.rs index b43982f0..d89cdb33 100644 --- a/vhost/src/vhost_user/backend_req.rs +++ b/vhost/src/vhost_user/backend_req.rs @@ -14,7 +14,7 @@ use super::{Error, HandlerResult, Result, VhostUserFrontendReqHandler}; use vm_memory::ByteValued; struct BackendInternal { - sock: Endpoint, + sock: Endpoint>, // Protocol feature VHOST_USER_PROTOCOL_F_REPLY_ACK has been negotiated. reply_ack_negotiated: bool, @@ -83,7 +83,7 @@ pub struct Backend { } impl Backend { - fn new(ep: Endpoint) -> Self { + fn new(ep: Endpoint>) -> Self { Backend { node: Arc::new(Mutex::new(BackendInternal { sock: ep, @@ -110,7 +110,9 @@ impl Backend { /// Create a new instance from a `UnixStream` object. pub fn from_stream(sock: UnixStream) -> Self { - Self::new(Endpoint::::from_stream(sock)) + Self::new(Endpoint::>::from_stream( + sock, + )) } /// Set the negotiation state of the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature. 
@@ -175,7 +177,7 @@ mod tests { fn test_backend_req_recv_negative() { let (p1, p2) = UnixStream::pair().unwrap(); let backend = Backend::from_stream(p1); - let mut frontend = Endpoint::::from_stream(p2); + let mut frontend = Endpoint::>::from_stream(p2); let len = mem::size_of::(); let mut hdr = VhostUserMsgHeader::new( diff --git a/vhost/src/vhost_user/backend_req_handler.rs b/vhost/src/vhost_user/backend_req_handler.rs index 25ffd9c5..7af7df5c 100644 --- a/vhost/src/vhost_user/backend_req_handler.rs +++ b/vhost/src/vhost_user/backend_req_handler.rs @@ -12,6 +12,8 @@ use vm_memory::ByteValued; use super::backend_req::Backend; use super::connection::Endpoint; +#[cfg(feature = "gpu-socket")] +use super::gpu_backend_req::GpuBackend; use super::message::*; use super::{take_single_file, Error, Result}; @@ -65,6 +67,8 @@ pub trait VhostUserBackendReqHandler { fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result>; fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>; fn set_backend_req_fd(&self, _backend: Backend) {} + #[cfg(feature = "gpu-socket")] + fn set_gpu_socket(&self, _gpu_backend: GpuBackend) {} fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)>; fn set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()>; fn get_max_mem_slots(&self) -> Result; @@ -124,6 +128,8 @@ pub trait VhostUserBackendReqHandlerMut { ) -> Result>; fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>; fn set_backend_req_fd(&mut self, _backend: Backend) {} + #[cfg(feature = "gpu-socket")] + fn set_gpu_socket(&mut self, _gpu_backend: GpuBackend) {} fn get_inflight_fd( &mut self, inflight: &VhostUserInflight, @@ -235,6 +241,11 @@ impl VhostUserBackendReqHandler for Mutex { self.lock().unwrap().set_backend_req_fd(backend) } + #[cfg(feature = "gpu-socket")] + fn set_gpu_socket(&self, gpu_backend: GpuBackend) { + 
self.lock().unwrap().set_gpu_socket(gpu_backend); + } + fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)> { self.lock().unwrap().get_inflight_fd(inflight) } @@ -302,7 +313,7 @@ impl VhostUserBackendReqHandler for Mutex { /// [BackendReqHandler]: struct.BackendReqHandler.html pub struct BackendReqHandler { // underlying Unix domain socket for communication - main_sock: Endpoint, + main_sock: Endpoint>, // the vhost-user backend device object backend: Arc, @@ -319,7 +330,10 @@ pub struct BackendReqHandler { impl BackendReqHandler { /// Create a vhost-user backend endpoint. - pub(super) fn new(main_sock: Endpoint, backend: Arc) -> Self { + pub(super) fn new( + main_sock: Endpoint>, + backend: Arc, + ) -> Self { BackendReqHandler { main_sock, backend, @@ -359,7 +373,10 @@ impl BackendReqHandler { /// * - `path` - path of Unix domain socket listener to connect to /// * - `backend` - handler for requests from the frontend to the backend pub fn connect(path: &str, backend: Arc) -> Result { - Ok(Self::new(Endpoint::::connect(path)?, backend)) + Ok(Self::new( + Endpoint::>::connect(path)?, + backend, + )) } /// Mark endpoint as failed with specified error code. @@ -554,6 +571,11 @@ impl BackendReqHandler { let res = self.backend.set_inflight_fd(&msg, file); self.send_ack_message(&hdr, res)?; } + #[cfg(feature = "gpu-socket")] + Ok(FrontendReq::GPU_SET_SOCKET) => { + let res = self.set_gpu_socket(files); + self.send_ack_message(&hdr, res)?; + } Ok(FrontendReq::GET_MAX_MEM_SLOTS) => { self.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?; self.check_request_size(&hdr, size, 0)?; @@ -790,6 +812,18 @@ impl BackendReqHandler { Ok(()) } + #[cfg(feature = "gpu-socket")] + fn set_gpu_socket(&mut self, files: Option>) -> Result<()> { + let file = take_single_file(files).ok_or(Error::InvalidMessage)?; + // SAFETY: Safe because we have ownership of the files that were + // checked when received. 
We have to trust that they are Unix sockets + // since we have no way to check this. If not, it will fail later. + let sock = unsafe { UnixStream::from_raw_fd(file.into_raw_fd()) }; + let gpu_backend = GpuBackend::from_stream(sock); + self.backend.set_gpu_socket(gpu_backend); + Ok(()) + } + fn handle_vring_fd_request( &mut self, buf: &[u8], @@ -859,7 +893,8 @@ impl BackendReqHandler { | FrontendReq::SET_BACKEND_REQ_FD | FrontendReq::SET_INFLIGHT_FD | FrontendReq::ADD_MEM_REG - | FrontendReq::SET_DEVICE_STATE_FD, + | FrontendReq::SET_DEVICE_STATE_FD + | FrontendReq::GPU_SET_SOCKET, ) => Ok(()), _ if files.is_some() => Err(Error::InvalidMessage), _ => Ok(()), @@ -965,7 +1000,7 @@ mod tests { #[test] fn test_backend_req_handler_new() { let (p1, _p2) = UnixStream::pair().unwrap(); - let endpoint = Endpoint::::from_stream(p1); + let endpoint = Endpoint::>::from_stream(p1); let backend = Arc::new(Mutex::new(DummyBackendReqHandler::new())); let mut handler = BackendReqHandler::new(endpoint, backend); diff --git a/vhost/src/vhost_user/connection.rs b/vhost/src/vhost_user/connection.rs index 9f614271..92995cf4 100644 --- a/vhost/src/vhost_user/connection.rs +++ b/vhost/src/vhost_user/connection.rs @@ -102,12 +102,12 @@ impl Drop for Listener { } /// Unix domain socket endpoint for vhost-user connection. -pub(super) struct Endpoint { +pub(super) struct Endpoint { sock: UnixStream, - _r: PhantomData, + _h: PhantomData, } -impl Endpoint { +impl Endpoint { /// Create a new stream by connecting to server at `str`. /// /// # Return: @@ -122,7 +122,7 @@ impl Endpoint { pub fn from_stream(sock: UnixStream) -> Self { Endpoint { sock, - _r: PhantomData, + _h: PhantomData, } } @@ -196,20 +196,16 @@ impl Endpoint { /// * - SocketBroken: the underline socket is broken. /// * - SocketError: other socket related errors. /// * - PartialMessage: received a partial message. 
- pub fn send_header( - &mut self, - hdr: &VhostUserMsgHeader, - fds: Option<&[RawFd]>, - ) -> Result<()> { + pub fn send_header(&mut self, hdr: &H, fds: Option<&[RawFd]>) -> Result<()> { // SAFETY: Safe because there can't be other mutable referance to hdr. let iovs = unsafe { [slice::from_raw_parts( - hdr as *const VhostUserMsgHeader as *const u8, - mem::size_of::>(), + hdr as *const H as *const u8, + mem::size_of::(), )] }; let bytes = self.send_iovec_all(&iovs[..], fds)?; - if bytes != mem::size_of::>() { + if bytes != mem::size_of::() { return Err(Error::PartialMessage); } Ok(()) @@ -226,15 +222,15 @@ impl Endpoint { /// * - PartialMessage: received a partial message. pub fn send_message( &mut self, - hdr: &VhostUserMsgHeader, + hdr: &H, body: &T, fds: Option<&[RawFd]>, ) -> Result<()> { - if mem::size_of::() > MAX_MSG_SIZE { + if mem::size_of::() > H::MAX_MSG_SIZE { return Err(Error::OversizedMsg); } let bytes = self.send_iovec_all(&[hdr.as_slice(), body.as_slice()], fds)?; - if bytes != mem::size_of::>() + mem::size_of::() { + if bytes != mem::size_of::() + mem::size_of::() { return Err(Error::PartialMessage); } Ok(()) @@ -253,16 +249,16 @@ impl Endpoint { /// * - IncorrectFds: wrong number of attached fds. 
pub fn send_message_with_payload( &mut self, - hdr: &VhostUserMsgHeader, + hdr: &H, body: &T, payload: &[u8], fds: Option<&[RawFd]>, ) -> Result<()> { let len = payload.len(); - if mem::size_of::() > MAX_MSG_SIZE { + if mem::size_of::() > H::MAX_MSG_SIZE { return Err(Error::OversizedMsg); } - if len > MAX_MSG_SIZE - mem::size_of::() { + if len > H::MAX_MSG_SIZE - mem::size_of::() { return Err(Error::OversizedMsg); } if let Some(fd_arr) = fds { @@ -271,7 +267,7 @@ impl Endpoint { } } - let total = mem::size_of::>() + mem::size_of::() + len; + let total = mem::size_of::() + mem::size_of::() + len; let len = self.send_iovec_all(&[hdr.as_slice(), body.as_slice(), payload], fds)?; if len != total { return Err(Error::PartialMessage); @@ -445,18 +441,18 @@ impl Endpoint { /// * - SocketError: other socket related errors. /// * - PartialMessage: received a partial message. /// * - InvalidMessage: received a invalid message. - pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader, Option>)> { - let mut hdr = VhostUserMsgHeader::default(); + pub fn recv_header(&mut self) -> Result<(H, Option>)> { + let mut hdr = H::default(); let mut iovs = [iovec { - iov_base: (&mut hdr as *mut VhostUserMsgHeader) as *mut c_void, - iov_len: mem::size_of::>(), + iov_base: (&mut hdr as *mut H) as *mut c_void, + iov_len: mem::size_of::(), }]; // SAFETY: Safe because we own hdr and it's ByteValued. let (bytes, files) = unsafe { self.recv_into_iovec_all(&mut iovs[..])? }; if bytes == 0 { return Err(Error::Disconnected); - } else if bytes != mem::size_of::>() { + } else if bytes != mem::size_of::() { return Err(Error::PartialMessage); } else if !hdr.is_valid() { return Err(Error::InvalidMessage); @@ -478,13 +474,13 @@ impl Endpoint { /// * - InvalidMessage: received a invalid message. 
pub fn recv_body( &mut self, - ) -> Result<(VhostUserMsgHeader, T, Option>)> { - let mut hdr = VhostUserMsgHeader::default(); + ) -> Result<(H, T, Option>)> { + let mut hdr = H::default(); let mut body: T = Default::default(); let mut iovs = [ iovec { - iov_base: (&mut hdr as *mut VhostUserMsgHeader) as *mut c_void, - iov_len: mem::size_of::>(), + iov_base: (&mut hdr as *mut H) as *mut c_void, + iov_len: mem::size_of::(), }, iovec { iov_base: (&mut body as *mut T) as *mut c_void, @@ -494,7 +490,7 @@ impl Endpoint { // SAFETY: Safe because we own hdr and body and they're ByteValued. let (bytes, files) = unsafe { self.recv_into_iovec_all(&mut iovs[..])? }; - let total = mem::size_of::>() + mem::size_of::(); + let total = mem::size_of::() + mem::size_of::(); if bytes != total { return Err(Error::PartialMessage); } else if !hdr.is_valid() || !body.is_valid() { @@ -518,15 +514,12 @@ impl Endpoint { /// * - SocketError: other socket related errors. /// * - PartialMessage: received a partial message. /// * - InvalidMessage: received a invalid message. - pub fn recv_body_into_buf( - &mut self, - buf: &mut [u8], - ) -> Result<(VhostUserMsgHeader, usize, Option>)> { - let mut hdr = VhostUserMsgHeader::default(); + pub fn recv_body_into_buf(&mut self, buf: &mut [u8]) -> Result<(H, usize, Option>)> { + let mut hdr = H::default(); let mut iovs = [ iovec { - iov_base: (&mut hdr as *mut VhostUserMsgHeader) as *mut c_void, - iov_len: mem::size_of::>(), + iov_base: (&mut hdr as *mut H) as *mut c_void, + iov_len: mem::size_of::(), }, iovec { iov_base: buf.as_mut_ptr() as *mut c_void, @@ -537,13 +530,13 @@ impl Endpoint { // and it's safe to fill a byte slice with arbitrary data. let (bytes, files) = unsafe { self.recv_into_iovec_all(&mut iovs[..])? 
}; - if bytes < mem::size_of::>() { + if bytes < mem::size_of::() { return Err(Error::PartialMessage); } else if !hdr.is_valid() { return Err(Error::InvalidMessage); } - Ok((hdr, bytes - mem::size_of::>(), files)) + Ok((hdr, bytes - mem::size_of::(), files)) } /// Receive a message with optional payload and attached file descriptors. @@ -561,13 +554,13 @@ impl Endpoint { pub fn recv_payload_into_buf( &mut self, buf: &mut [u8], - ) -> Result<(VhostUserMsgHeader, T, usize, Option>)> { - let mut hdr = VhostUserMsgHeader::default(); + ) -> Result<(H, T, usize, Option>)> { + let mut hdr = H::default(); let mut body: T = Default::default(); let mut iovs = [ iovec { - iov_base: (&mut hdr as *mut VhostUserMsgHeader) as *mut c_void, - iov_len: mem::size_of::>(), + iov_base: (&mut hdr as *mut H) as *mut c_void, + iov_len: mem::size_of::(), }, iovec { iov_base: (&mut body as *mut T) as *mut c_void, @@ -583,7 +576,7 @@ impl Endpoint { // arbitrary data. let (bytes, files) = unsafe { self.recv_into_iovec_all(&mut iovs[..])? 
}; - let total = mem::size_of::>() + mem::size_of::(); + let total = mem::size_of::() + mem::size_of::(); if bytes < total { return Err(Error::PartialMessage); } else if !hdr.is_valid() || !body.is_valid() { @@ -594,7 +587,7 @@ impl Endpoint { } } -impl AsRawFd for Endpoint { +impl AsRawFd for Endpoint { fn as_raw_fd(&self) -> RawFd { self.sock.as_raw_fd() } @@ -669,9 +662,9 @@ mod tests { let path = temp_path(); let listener = Listener::new(&path, true).unwrap(); listener.set_nonblocking(true).unwrap(); - let mut frontend = Endpoint::::connect(&path).unwrap(); + let mut frontend = Endpoint::>::connect(&path).unwrap(); let sock = listener.accept().unwrap().unwrap(); - let mut backend = Endpoint::::from_stream(sock); + let mut backend = Endpoint::>::from_stream(sock); let buf1 = [0x1, 0x2, 0x3, 0x4]; let mut len = frontend.send_slice(&buf1[..], None).unwrap(); @@ -695,9 +688,9 @@ mod tests { let path = temp_path(); let listener = Listener::new(&path, true).unwrap(); listener.set_nonblocking(true).unwrap(); - let mut frontend = Endpoint::::connect(&path).unwrap(); + let mut frontend = Endpoint::>::connect(&path).unwrap(); let sock = listener.accept().unwrap().unwrap(); - let mut backend = Endpoint::::from_stream(sock); + let mut backend = Endpoint::>::from_stream(sock); let mut fd = TempFile::new().unwrap().into_file(); write!(fd, "test").unwrap(); @@ -846,9 +839,9 @@ mod tests { let path = temp_path(); let listener = Listener::new(&path, true).unwrap(); listener.set_nonblocking(true).unwrap(); - let mut frontend = Endpoint::::connect(&path).unwrap(); + let mut frontend = Endpoint::>::connect(&path).unwrap(); let sock = listener.accept().unwrap().unwrap(); - let mut backend = Endpoint::::from_stream(sock); + let mut backend = Endpoint::>::from_stream(sock); let mut hdr1 = VhostUserMsgHeader::new(FrontendReq::GET_FEATURES, 0, mem::size_of::() as u32); @@ -883,7 +876,7 @@ mod tests { let listener = Listener::new(&path, true).unwrap(); let mut frontend = 
UnixStream::connect(&path).unwrap(); let sock = listener.accept().unwrap().unwrap(); - let mut backend = Endpoint::::from_stream(sock); + let mut backend = Endpoint::>::from_stream(sock); write!(frontend, "a").unwrap(); drop(frontend); @@ -896,7 +889,7 @@ mod tests { let listener = Listener::new(&path, true).unwrap(); let _ = UnixStream::connect(&path).unwrap(); let sock = listener.accept().unwrap().unwrap(); - let mut backend = Endpoint::::from_stream(sock); + let mut backend = Endpoint::>::from_stream(sock); assert!(matches!(backend.recv_header(), Err(Error::Disconnected))); } diff --git a/vhost/src/vhost_user/frontend.rs b/vhost/src/vhost_user/frontend.rs index 8f625ead..5234b12a 100644 --- a/vhost/src/vhost_user/frontend.rs +++ b/vhost/src/vhost_user/frontend.rs @@ -106,7 +106,7 @@ pub struct Frontend { impl Frontend { /// Create a new instance. - fn new(ep: Endpoint, max_queue_num: u64) -> Self { + fn new(ep: Endpoint>, max_queue_num: u64) -> Self { Frontend { node: Arc::new(Mutex::new(FrontendInternal { main_sock: ep, @@ -128,7 +128,10 @@ impl Frontend { /// Create a new instance from a Unix stream socket. pub fn from_stream(sock: UnixStream, max_queue_num: u64) -> Self { - Self::new(Endpoint::::from_stream(sock), max_queue_num) + Self::new( + Endpoint::>::from_stream(sock), + max_queue_num, + ) } /// Create a new vhost-user frontend endpoint. @@ -140,7 +143,7 @@ impl Frontend { pub fn connect>(path: P, max_queue_num: u64) -> Result { let mut retry_count = 5; let endpoint = loop { - match Endpoint::::connect(&path) { + match Endpoint::>::connect(&path) { Ok(endpoint) => break Ok(endpoint), Err(e) => match &e { VhostUserError::SocketConnect(why) => { @@ -600,7 +603,7 @@ impl VhostUserMemoryContext { struct FrontendInternal { // Used to send requests to the backend. - main_sock: Endpoint, + main_sock: Endpoint>, // Cached virtio features from the backend. virtio_features: u64, // Cached acked virtio features from the driver. 
@@ -818,7 +821,9 @@ mod tests { )) } - fn create_pair>(path: P) -> (Frontend, Endpoint) { + fn create_pair>( + path: P, + ) -> (Frontend, Endpoint>) { let listener = Listener::new(&path, true).unwrap(); listener.set_nonblocking(true).unwrap(); let frontend = Frontend::connect(path, 2).unwrap(); @@ -833,7 +838,9 @@ mod tests { listener.set_nonblocking(true).unwrap(); let frontend = Frontend::connect(&path, 1).unwrap(); - let mut backend = Endpoint::::from_stream(listener.accept().unwrap().unwrap()); + let mut backend = Endpoint::>::from_stream( + listener.accept().unwrap().unwrap(), + ); assert!(frontend.as_raw_fd() > 0); // Send two messages continuously @@ -1001,7 +1008,7 @@ mod tests { .unwrap_err(); } - fn create_pair2() -> (Frontend, Endpoint) { + fn create_pair2() -> (Frontend, Endpoint>) { let path = temp_path(); let (frontend, peer) = create_pair(path); diff --git a/vhost/src/vhost_user/frontend_req_handler.rs b/vhost/src/vhost_user/frontend_req_handler.rs index fb2dc16f..cf8a9e48 100644 --- a/vhost/src/vhost_user/frontend_req_handler.rs +++ b/vhost/src/vhost_user/frontend_req_handler.rs @@ -130,7 +130,7 @@ impl VhostUserFrontendReqHandler for Mutex /// [VhostUserFrontendReqHandler]: trait.VhostUserFrontendReqHandler.html pub struct FrontendReqHandler { // underlying Unix domain socket for communication - sub_sock: Endpoint, + sub_sock: Endpoint>, tx_sock: UnixStream, // Protocol feature VHOST_USER_PROTOCOL_F_REPLY_ACK has been negotiated. 
reply_ack_negotiated: bool, @@ -153,7 +153,7 @@ impl FrontendReqHandler { let (tx, rx) = UnixStream::pair().map_err(Error::SocketError)?; Ok(FrontendReqHandler { - sub_sock: Endpoint::::from_stream(rx), + sub_sock: Endpoint::>::from_stream(rx), tx_sock: tx, reply_ack_negotiated: false, backend, diff --git a/vhost/src/vhost_user/gpu_backend_req.rs b/vhost/src/vhost_user/gpu_backend_req.rs new file mode 100644 index 00000000..265b7fcb --- /dev/null +++ b/vhost/src/vhost_user/gpu_backend_req.rs @@ -0,0 +1,473 @@ +// Copyright (C) 2024 Red Hat, Inc. +// SPDX-License-Identifier: Apache-2.0 + +use std::os::fd::{AsRawFd, RawFd}; +use std::os::unix::net::UnixStream; +use std::sync::{Arc, Mutex, MutexGuard}; +use std::{io, mem, slice}; + +use vm_memory::ByteValued; + +use crate::vhost_user; +use crate::vhost_user::connection::Endpoint; +use crate::vhost_user::gpu_message::*; +use crate::vhost_user::message::{VhostUserMsgValidator, VhostUserU64}; +use crate::vhost_user::Error; + +struct BackendInternal { + sock: Endpoint>, + // whether the endpoint has encountered any failure + error: Option, +} + +impl BackendInternal { + fn check_state(&self) -> vhost_user::Result { + match self.error { + Some(e) => Err(Error::SocketBroken(io::Error::from_raw_os_error(e))), + None => Ok(0), + } + } + + fn send_header( + &mut self, + request: GpuBackendReq, + fds: Option<&[RawFd]>, + ) -> vhost_user::Result { + self.check_state()?; + + let hdr = VhostUserGpuMsgHeader::new(request, 0, 0); + self.sock.send_header(&hdr, fds)?; + + self.wait_for_ack(&hdr) + } + + fn send_message_no_reply( + &mut self, + request: GpuBackendReq, + body: &T, + fds: Option<&[RawFd]>, + ) -> vhost_user::Result<()> { + self.check_state()?; + + let len = mem::size_of::(); + let hdr = VhostUserGpuMsgHeader::new(request, 0, len as u32); + self.sock.send_message(&hdr, body, fds)?; + Ok(()) + } + + fn send_message_with_payload_no_reply( + &mut self, + request: GpuBackendReq, + body: &T, + data: &[u8], + fds: 
Option<&[RawFd]>, + ) -> vhost_user::Result<()> { + self.check_state()?; + + let len = mem::size_of::() + data.len(); + let hdr = VhostUserGpuMsgHeader::new(request, 0, len as u32); + self.sock.send_message_with_payload(&hdr, body, data, fds)?; + Ok(()) + } + + fn send_message( + &mut self, + request: GpuBackendReq, + body: &T, + fds: Option<&[RawFd]>, + ) -> vhost_user::Result { + self.check_state()?; + + let len = mem::size_of::(); + let hdr = VhostUserGpuMsgHeader::new(request, 0, len as u32); + self.sock.send_message(&hdr, body, fds)?; + + self.wait_for_ack(&hdr) + } + + fn wait_for_ack( + &mut self, + hdr: &VhostUserGpuMsgHeader, + ) -> vhost_user::Result { + self.check_state()?; + let (reply, body, rfds) = self.sock.recv_body::()?; + if !reply.is_reply_for(hdr) || rfds.is_some() || !body.is_valid() { + return Err(Error::InvalidMessage); + } + Ok(body) + } +} + +/// Proxy for sending messages from the backend to the frontend +/// over the socket obtained from VHOST_USER_GPU_SET_SOCKET. 
+/// The protocol is documented here: https://www.qemu.org/docs/master/interop/vhost-user-gpu.html +#[derive(Clone)] +pub struct GpuBackend { + // underlying Unix domain socket for communication + node: Arc>, +} + +impl GpuBackend { + fn new(ep: Endpoint>) -> Self { + Self { + node: Arc::new(Mutex::new(BackendInternal { + sock: ep, + error: None, + })), + } + } + + fn node(&self) -> MutexGuard { + self.node.lock().unwrap() + } + + fn send_header( + &self, + request: GpuBackendReq, + fds: Option<&[RawFd]>, + ) -> io::Result { + self.node() + .send_header(request, fds) + .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("{}", e))) + } + + fn send_message( + &self, + request: GpuBackendReq, + body: &T, + fds: Option<&[RawFd]>, + ) -> io::Result { + self.node() + .send_message(request, body, fds) + .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("{}", e))) + } + + fn send_message_no_reply( + &self, + request: GpuBackendReq, + body: &T, + fds: Option<&[RawFd]>, + ) -> io::Result<()> { + self.node() + .send_message_no_reply(request, body, fds) + .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("{}", e))) + } + + fn send_message_with_payload_no_reply( + &self, + request: GpuBackendReq, + body: &T, + data: &[u8], + fds: Option<&[RawFd]>, + ) -> io::Result<()> { + self.node() + .send_message_with_payload_no_reply(request, body, data, fds) + .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("{}", e))) + } + + /// Send the VHOST_USER_GPU_GET_PROTOCOL_FEATURES message to the frontend and wait for a reply. + /// Get the supported protocol features bitmask. + pub fn get_protocol_features(&self) -> io::Result { + self.send_header(GpuBackendReq::GET_PROTOCOL_FEATURES, None) + } + + /// Send the VHOST_USER_GPU_SET_PROTOCOL_FEATURES message to the frontend. Doesn't wait for + /// a reply. + /// Enable protocol features using a bitmask. 
+ pub fn set_protocol_features(&self, msg: &VhostUserU64) -> io::Result<()> { + self.send_message_no_reply(GpuBackendReq::SET_PROTOCOL_FEATURES, msg, None) + } + + /// Send the VHOST_USER_GPU_GET_DISPLAY_INFO message to the frontend and wait for a reply. + /// Get the preferred display configuration. + pub fn get_display_info(&self) -> io::Result { + self.send_header(GpuBackendReq::GET_DISPLAY_INFO, None) + } + + /// Send the VHOST_USER_GPU_GET_EDID message to the frontend and wait for a reply. + /// Retrieve the EDID data for a given scanout. + /// This message requires the VHOST_USER_GPU_PROTOCOL_F_EDID protocol feature to be supported. + pub fn get_edid(&self, get_edid: &VhostUserGpuEdidRequest) -> io::Result { + self.send_message(GpuBackendReq::GET_EDID, get_edid, None) + } + + /// Send the VHOST_USER_GPU_SCANOUT message to the frontend. Doesn't wait for a reply. + /// Set the scanout resolution. To disable a scanout, the dimensions width/height are set to 0. + pub fn set_scanout(&self, scanout: &VhostUserGpuScanout) -> io::Result<()> { + self.send_message_no_reply(GpuBackendReq::SCANOUT, scanout, None) + } + + /// Sends the VHOST_USER_GPU_UPDATE message to the frontend. Doesn't wait for a reply. + /// Updates the scanout content. The data payload contains the graphical bits. + /// The display should be flushed and presented. + pub fn update_scanout(&self, update: &VhostUserGpuUpdate, data: &[u8]) -> io::Result<()> { + self.send_message_with_payload_no_reply(GpuBackendReq::UPDATE, update, data, None) + } + + /// Send the VHOST_USER_GPU_DMABUF_SCANOUT message to the frontend. Doesn't wait for a reply. + /// Set the scanout resolution/configuration, and share a DMABUF file descriptor for the scanout + /// content, which is passed as ancillary data. To disable a scanout, the dimensions + /// width/height are set to 0, there is no file descriptor passed. 
+ pub fn set_dmabuf_scanout( + &self, + scanout: &VhostUserGpuDMABUFScanout, + fd: Option<&impl AsRawFd>, + ) -> io::Result<()> { + let fd = fd.map(AsRawFd::as_raw_fd); + let fd = fd.as_ref().map(slice::from_ref); + self.send_message_no_reply(GpuBackendReq::DMABUF_SCANOUT, scanout, fd) + } + + /// Send the VHOST_USER_GPU_DMABUF_SCANOUT2 message to the frontend. Doesn't wait for a reply. + /// Same as `set_dmabuf_scanout` (VHOST_USER_GPU_DMABUF_SCANOUT), but also sends the dmabuf + /// modifiers appended to the message, which were not provided in the other message. This + /// message requires the VhostUserGpuProtocolFeatures::DMABUF2 + /// (VHOST_USER_GPU_PROTOCOL_F_DMABUF2) protocol feature to be supported. + pub fn set_dmabuf_scanout2( + &self, + scanout: &VhostUserGpuDMABUFScanout2, + fd: Option<&impl AsRawFd>, + ) -> io::Result<()> { + let fd = fd.map(AsRawFd::as_raw_fd); + let fd = fd.as_ref().map(slice::from_ref); + self.send_message_no_reply(GpuBackendReq::DMABUF_SCANOUT2, scanout, fd) + } + + /// Send the VHOST_USER_GPU_DMABUF_UPDATE message to the frontend. Doesn't wait for a reply. + /// The display should be flushed and presented according to updated region + /// from VhostUserGpuUpdate. + pub fn update_dmabuf_scanout(&self, update: &VhostUserGpuUpdate) -> io::Result<()> { + self.send_message_no_reply(GpuBackendReq::DMABUF_UPDATE, update, None) + } + + /// Send the VHOST_USER_GPU_CURSOR_POS message to the frontend. Doesn't wait for a reply. + /// Set/show the cursor position. + pub fn cursor_pos(&self, cursor_pos: &VhostUserGpuCursorPos) -> io::Result<()> { + self.send_message_no_reply(GpuBackendReq::CURSOR_POS, cursor_pos, None) + } + + /// Send the VHOST_USER_GPU_CURSOR_POS_HIDE message to the frontend. Doesn't wait for a reply. + /// Set/hide the cursor. 
+ pub fn cursor_pos_hide(&self, cursor_pos: &VhostUserGpuCursorPos) -> io::Result<()> { + self.send_message_no_reply(GpuBackendReq::CURSOR_POS_HIDE, cursor_pos, None) + } + + /// Send the VHOST_USER_GPU_CURSOR_UPDATE message to the frontend. Doesn't wait for a reply. + /// Update the cursor shape and location. + /// `data` represents a 64*64 cursor image (PIXMAN_x8r8g8b8 format). + pub fn cursor_update( + &self, + cursor_update: &VhostUserGpuCursorUpdate, + data: &[u8; 4 * 64 * 64], + ) -> io::Result<()> { + self.send_message_with_payload_no_reply( + GpuBackendReq::CURSOR_UPDATE, + cursor_update, + data, + None, + ) + } + + /// Create a new instance from a `UnixStream` object. + pub fn from_stream(sock: UnixStream) -> Self { + Self::new(Endpoint::>::from_stream(sock)) + } + + /// Mark endpoint as failed with specified error code. + pub fn set_failed(&self, error: i32) { + self.node().error = Some(error); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::vhost_user::message::VhostUserEmpty; + use libc::STDOUT_FILENO; + use std::mem::{size_of, size_of_val}; + use std::os::fd::RawFd; + use std::thread; + + #[derive(Default, Copy, Clone, Debug)] + struct MockUnreachableResponse; + + // SAFETY: Safe because type is zero size. + unsafe impl ByteValued for MockUnreachableResponse {} + + impl VhostUserMsgValidator for MockUnreachableResponse { + fn is_valid(&self) -> bool { + panic!("UnreachableResponse should not have been constructed and validated!"); + } + } + + #[derive(Copy, Default, Clone, Debug, PartialEq)] + struct MockData(u32); + + // SAFETY: Safe because type is POD. 
+ unsafe impl ByteValued for MockData {} + + impl VhostUserMsgValidator for MockData {} + + const REQUEST: GpuBackendReq = GpuBackendReq::GET_PROTOCOL_FEATURES; + const VALID_REQUEST: MockData = MockData(123); + const VALID_RESPONSE: MockData = MockData(456); + + fn reply_with_valid_response( + frontend: &mut Endpoint>, + hdr: &VhostUserGpuMsgHeader, + ) { + let response_hdr = VhostUserGpuMsgHeader::new( + hdr.get_code().unwrap(), + VhostUserGpuHeaderFlag::REPLY.bits(), + size_of::() as u32, + ); + + frontend + .send_message(&response_hdr, &VALID_RESPONSE, None) + .unwrap(); + } + + fn frontend_backend_pair() -> (Endpoint>, GpuBackend) { + let (backend, frontend) = UnixStream::pair().unwrap(); + let backend = GpuBackend::from_stream(backend); + let frontend = Endpoint::from_stream(frontend); + + (frontend, backend) + } + + #[test] + fn test_gpu_backend_req_set_failed() { + let (p1, _p2) = UnixStream::pair().unwrap(); + let backend = GpuBackend::from_stream(p1); + assert!(backend.node().error.is_none()); + backend.set_failed(libc::EAGAIN); + assert_eq!(backend.node().error, Some(libc::EAGAIN)); + } + + #[test] + fn test_gpu_backend_cannot_send_when_failed() { + let (_frontend, backend) = frontend_backend_pair(); + backend.set_failed(libc::EAGAIN); + + backend + .send_header::(GpuBackendReq::GET_PROTOCOL_FEATURES, None) + .unwrap_err(); + + backend + .send_message::( + GpuBackendReq::GET_PROTOCOL_FEATURES, + &VhostUserEmpty, + None, + ) + .unwrap_err(); + + backend + .send_message_no_reply::( + GpuBackendReq::GET_PROTOCOL_FEATURES, + &VhostUserEmpty, + None, + ) + .unwrap_err(); + + backend + .send_message_with_payload_no_reply::( + GpuBackendReq::GET_PROTOCOL_FEATURES, + &VhostUserEmpty, + &[], + None, + ) + .unwrap_err(); + } + + #[test] + fn test_gpu_backend_send_header() { + let (mut frontend, backend) = frontend_backend_pair(); + + let sender_thread = thread::spawn(move || { + let response: MockData = backend.send_header(REQUEST, None).unwrap(); + 
assert_eq!(response, VALID_RESPONSE); + }); + + let (hdr, fds) = frontend.recv_header().unwrap(); + assert!(fds.is_none()); + + assert_eq!(hdr, VhostUserGpuMsgHeader::new(REQUEST, 0, 0)); + reply_with_valid_response(&mut frontend, &hdr); + sender_thread.join().expect("Sender failed"); + } + + #[test] + fn test_gpu_backend_send_message() { + let (mut frontend, backend) = frontend_backend_pair(); + + let sender_thread = thread::spawn(move || { + let response: MockData = backend.send_message(REQUEST, &VALID_REQUEST, None).unwrap(); + assert_eq!(response, VALID_RESPONSE); + }); + + let (hdr, body, fds) = frontend.recv_body::().unwrap(); + let expected_hdr = + VhostUserGpuMsgHeader::new(REQUEST, 0, size_of_val(&VALID_REQUEST) as u32); + assert_eq!(hdr, expected_hdr); + assert_eq!(body, VALID_REQUEST); + assert!(fds.is_none()); + + reply_with_valid_response(&mut frontend, &hdr); + sender_thread.join().expect("Sender failed"); + } + + #[test] + fn test_gpu_backend_send_message_no_reply_with_fd() { + let (mut frontend, backend) = frontend_backend_pair(); + let expected_hdr = + VhostUserGpuMsgHeader::new(REQUEST, 0, size_of_val(&VALID_REQUEST) as u32); + + let requested_fds: &[RawFd] = &[STDOUT_FILENO]; + + let sender_thread = thread::spawn(move || { + backend + .send_message_no_reply(REQUEST, &VALID_REQUEST, Some(requested_fds)) + .unwrap(); + }); + + let (hdr, body, fds) = frontend.recv_body::().unwrap(); + assert_eq!(hdr, expected_hdr); + assert_eq!(body, VALID_REQUEST); + let fds = fds.unwrap(); + assert_eq!(fds.len(), 1); + + sender_thread.join().expect("Sender failed"); + } + + #[test] + fn test_gpu_backend_send_message_with_big_payload() { + let (mut frontend, backend) = frontend_backend_pair(); + + let expected_payload = vec![1; 8192]; + let expected_hdr = VhostUserGpuMsgHeader::new( + REQUEST, + 0, + (size_of_val(&VALID_REQUEST) + expected_payload.len()) as u32, + ); + + let sending_data = expected_payload.clone(); + let sender_thread = thread::spawn(move || { + 
backend + .send_message_with_payload_no_reply(REQUEST, &VALID_REQUEST, &sending_data, None) + .unwrap(); + }); + + let mut recv_payload = vec![0; 8192]; + let (hdr, body, num_bytes, fds) = frontend + .recv_payload_into_buf::(&mut recv_payload) + .unwrap(); + + assert_eq!(hdr, expected_hdr); + assert_eq!(body, VALID_REQUEST); + assert_eq!(num_bytes, expected_payload.len()); + assert!(fds.is_none()); + assert_eq!(&recv_payload[..], &expected_payload[..]); + + sender_thread.join().expect("Sender failed"); + } +} diff --git a/vhost/src/vhost_user/gpu_message.rs b/vhost/src/vhost_user/gpu_message.rs new file mode 100644 index 00000000..f90b68d7 --- /dev/null +++ b/vhost/src/vhost_user/gpu_message.rs @@ -0,0 +1,427 @@ +// Copyright (C) 2024 Red Hat, Inc. +// SPDX-License-Identifier: Apache-2.0 + +//! Implementation parts of the protocol on the socket from VHOST_USER_SET_GPU_SOCKET +//! see: https://www.qemu.org/docs/master/interop/vhost-user-gpu.html + +use crate::vhost_user::message::{enum_value, MsgHeader, Req, VhostUserMsgValidator}; +use crate::vhost_user::Error; +use std::fmt::Debug; +use std::marker::PhantomData; +use vm_memory::ByteValued; + +enum_value! { + /// Type of requests sending from gpu backends to gpu frontends. + #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] + #[allow(non_camel_case_types, clippy::upper_case_acronyms)] + pub enum GpuBackendReq: u32 { + /// Get the supported protocol features bitmask. + GET_PROTOCOL_FEATURES = 1, + /// Enable protocol features using a bitmask. + SET_PROTOCOL_FEATURES = 2, + /// Get the preferred display configuration. + GET_DISPLAY_INFO = 3, + /// Set/show the cursor position. + CURSOR_POS = 4, + /// Set/hide the cursor. + CURSOR_POS_HIDE = 5, + /// Update the cursor shape and location. + CURSOR_UPDATE = 6, + /// Set the scanout resolution. + /// To disable a scanout, the dimensions width/height are set to 0. + SCANOUT = 7, + /// Update the scanout content. The data payload contains the graphical bits. 
+ /// The display should be flushed and presented. + UPDATE = 8, + /// Set the scanout resolution/configuration, and share a DMABUF file descriptor for the + /// scanout content, which is passed as ancillary data. + /// To disable a scanout, the dimensions width/height are set to 0, there is no file + /// descriptor passed. + DMABUF_SCANOUT = 9, + /// The display should be flushed and presented according to updated region from + /// VhostUserGpuUpdate. + /// Note: there is no data payload, since the scanout is shared thanks to DMABUF, + /// that must have been set previously with VHOST_USER_GPU_DMABUF_SCANOUT. + DMABUF_UPDATE = 10, + /// Retrieve the EDID data for a given scanout. + /// This message requires the VHOST_USER_GPU_PROTOCOL_F_EDID protocol feature to be + /// supported. + GET_EDID = 11, + /// Same as DMABUF_SCANOUT, but also sends the dmabuf modifiers appended to the message, + /// which were not provided in the other message. + /// This message requires the VHOST_USER_GPU_PROTOCOL_F_DMABUF2 protocol feature to be + /// supported. + DMABUF_SCANOUT2 = 12, + } +} + +impl Req for GpuBackendReq {} + +// Bit mask for common message flags. +bitflags! { + /// Common message flags for vhost-user requests and replies. + pub struct VhostUserGpuHeaderFlag: u32 { + /// Mark message as reply. + const REPLY = 0x4; + } +} + +/// A vhost-user message consists of 3 header fields and an optional payload. All numbers are in the +/// machine native byte order. 
+#[repr(C, packed)] +#[derive(Copy)] +pub(super) struct VhostUserGpuMsgHeader { + request: u32, + flags: u32, + size: u32, + _r: PhantomData, +} + +impl Debug for VhostUserGpuMsgHeader { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("VhostUserMsgHeader") + .field("request", &{ self.request }) + .field("flags", &{ self.flags }) + .field("size", &{ self.size }) + .finish() + } +} + +impl Clone for VhostUserGpuMsgHeader { + fn clone(&self) -> VhostUserGpuMsgHeader { + *self + } +} + +impl PartialEq for VhostUserGpuMsgHeader { + fn eq(&self, other: &Self) -> bool { + self.request == other.request && self.flags == other.flags && self.size == other.size + } +} + +#[allow(dead_code)] +impl VhostUserGpuMsgHeader { + /// Create a new instance of `VhostUserMsgHeader`. + pub fn new(request: R, flags: u32, size: u32) -> Self { + VhostUserGpuMsgHeader { + request: request.into(), + flags, + size, + _r: PhantomData, + } + } + + /// Get message type. + pub fn get_code(&self) -> crate::vhost_user::Result { + R::try_from(self.request).map_err(|_| Error::InvalidMessage) + } + + /// Check whether it's a reply message. + pub fn is_reply(&self) -> bool { + (self.flags & VhostUserGpuHeaderFlag::REPLY.bits()) != 0 + } + + /// Mark message as reply. + pub fn set_reply(&mut self, is_reply: bool) { + if is_reply { + self.flags |= VhostUserGpuHeaderFlag::REPLY.bits(); + } else { + self.flags &= !VhostUserGpuHeaderFlag::REPLY.bits(); + } + } + + /// Check whether it's the reply message for the request `req`. + pub fn is_reply_for(&self, req: &VhostUserGpuMsgHeader) -> bool { + if let (Ok(code1), Ok(code2)) = (self.get_code(), req.get_code()) { + self.is_reply() && !req.is_reply() && code1 == code2 + } else { + false + } + } + + /// Get message size. + pub fn get_size(&self) -> u32 { + self.size + } + + /// Set message size. 
+ pub fn set_size(&mut self, size: u32) { + self.size = size; + } +} + +impl Default for VhostUserGpuMsgHeader { + fn default() -> Self { + VhostUserGpuMsgHeader { + request: 0, + flags: 0, + size: 0, + _r: PhantomData, + } + } +} + +// SAFETY: Safe because all fields of VhostUserMsgHeader are POD. +unsafe impl ByteValued for VhostUserGpuMsgHeader {} + +impl VhostUserMsgValidator for VhostUserGpuMsgHeader { + fn is_valid(&self) -> bool { + self.get_code().is_ok() && VhostUserGpuHeaderFlag::from_bits(self.flags).is_some() + } +} + +impl MsgHeader for VhostUserGpuMsgHeader { + type Request = R; + const MAX_MSG_SIZE: usize = u32::MAX as usize; +} + +// Bit mask for vhost-user-gpu protocol feature flags. +bitflags! { + #[derive(Copy, Clone, Debug, Eq, PartialEq)] + /// Vhost-user-gpu protocol feature flags from the QEMU vhost-user-gpu specification. + pub struct VhostUserGpuProtocolFeatures: u64 { + /// Frontend support for EDID + const EDID = 0; + /// Frontend support for DMABUF_SCANOUT2 + const DMABUF2 = 1; + } +} + +/// The virtio_gpu_ctrl_hdr from virtio specification +/// Defined here because some GpuBackend commands return virtio structs, which contain this header. +#[derive(Copy, Clone, Debug, Default)] +#[repr(C)] +pub struct VirtioGpuCtrlHdr { + /// Specifies the type of the driver request (VIRTIO_GPU_CMD_*) + /// or device response (VIRTIO_GPU_RESP_*). + pub type_: u32, + /// Request / response flags. + pub flags: u32, + /// Set VIRTIO_GPU_FLAG_FENCE bit in the response + pub fence_id: u64, + /// Rendering context (used in 3D mode only). + pub ctx_id: u32, + /// ring_idx indicates the value of a context-specific ring index. + /// The minimum value is 0 and maximum value is 63 (inclusive). + pub ring_idx: u8, + /// padding of the structure + pub padding: [u8; 3], +} + +// SAFETY: Safe because all fields are POD. +unsafe impl ByteValued for VirtioGpuCtrlHdr {} + +/// The virtio_gpu_rect struct from virtio specification. 
+/// Part of the reply for GpuBackend::get_display_info +#[derive(Copy, Clone, Debug, Default)] +#[repr(C)] +pub struct VirtioGpuRect { + /// The position field x describes how the displays are arranged + pub x: u32, + /// The position field y describes how the displays are arranged + pub y: u32, + /// Display resolution width + pub width: u32, + /// Display resolution height + pub height: u32, +} + +// SAFETY: Safe because all fields are POD. +unsafe impl ByteValued for VirtioGpuRect {} + +/// The virtio_gpu_display_one struct from virtio specification. +/// Part of the reply for GpuBackend::get_display_info +#[derive(Copy, Clone, Debug, Default)] +#[repr(C)] +pub struct VirtioGpuDisplayOne { + /// Preferred display resolutions and display positions relative to each other + pub r: VirtioGpuRect, + /// The enabled field is set when the user enabled the display. + pub enabled: u32, + /// The display flags + pub flags: u32, +} + +// SAFETY: Safe because all fields are POD. +unsafe impl ByteValued for VirtioGpuDisplayOne {} + +/// Constant for maximum number of scanouts, defined in the virtio specification. +pub const VIRTIO_GPU_MAX_SCANOUTS: usize = 16; + +/// The virtio_gpu_resp_display_info from the virtio specification. +/// This is the reply from GpuBackend::get_display_info +#[derive(Copy, Clone, Debug, Default)] +#[repr(C)] +pub struct VirtioGpuRespDisplayInfo { + /// The fixed header struct + pub hdr: VirtioGpuCtrlHdr, + /// pmodes contains whether the scanout is enabled and what + /// its preferred position and size is + pub pmodes: [VirtioGpuDisplayOne; VIRTIO_GPU_MAX_SCANOUTS], +} + +// SAFETY: Safe because all fields are POD. +unsafe impl ByteValued for VirtioGpuRespDisplayInfo {} + +impl VhostUserMsgValidator for VirtioGpuRespDisplayInfo {} + +/// The VhostUserGpuEdidRequest from the QEMU vhost-user-gpu specification. 
+#[derive(Copy, Clone, Debug, Default)] +#[repr(C)] +pub struct VhostUserGpuEdidRequest { + /// The id of the scanout to retrieve EDID data for + pub scanout_id: u32, +} + +// SAFETY: Safe because all fields are POD. +unsafe impl ByteValued for VhostUserGpuEdidRequest {} + +impl VhostUserMsgValidator for VhostUserGpuEdidRequest {} + +/// The VhostUserGpuUpdate from the QEMU vhost-user-gpu specification. +#[derive(Copy, Clone, Debug, Default)] +#[repr(C)] +pub struct VhostUserGpuUpdate { + /// The id of the scanout that is being updated + pub scanout_id: u32, + /// The x coordinate of the region to update + pub x: u32, + /// The y coordinate of the region to update + pub y: u32, + /// The width of the region to update + pub width: u32, + /// The height of the region to update + pub height: u32, +} + +// SAFETY: Safe because all fields are POD. +unsafe impl ByteValued for VhostUserGpuUpdate {} + +impl VhostUserMsgValidator for VhostUserGpuUpdate {} + +/// The VhostUserGpuDMABUFScanout from the QEMU vhost-user-gpu specification. +#[derive(Copy, Clone, Debug, Default)] +#[repr(C)] +pub struct VhostUserGpuDMABUFScanout { + /// The id of the scanout to update + pub scanout_id: u32, + /// The position field x of the scanout within the DMABUF + pub x: u32, + /// The position field y of the scanout within the DMABUF + pub y: u32, + /// Scanout width size + pub width: u32, + /// Scanout height size + pub height: u32, + /// The DMABUF width + pub fd_width: u32, + /// The DMABUF height + pub fd_height: u32, + /// The DMABUF stride + pub fd_stride: u32, + /// The DMABUF flags + pub fd_flags: u32, + /// The DMABUF fourcc + pub fd_drm_fourcc: u32, +} + +// SAFETY: Safe because all fields are POD. +unsafe impl ByteValued for VhostUserGpuDMABUFScanout {} + +impl VhostUserMsgValidator for VhostUserGpuDMABUFScanout {} + +/// The VhostUserGpuDMABUFScanout2 from the QEMU vhost-user-gpu specification. 
+#[derive(Copy, Clone, Debug, Default)] +#[repr(C, packed)] +pub struct VhostUserGpuDMABUFScanout2 { + /// The dmabuf scanout parameters + pub dmabuf_scanout: VhostUserGpuDMABUFScanout, + /// The DMABUF modifiers + pub modifier: u64, +} + +// SAFETY: Safe because all fields are POD. +unsafe impl ByteValued for VhostUserGpuDMABUFScanout2 {} + +impl VhostUserMsgValidator for VhostUserGpuDMABUFScanout2 {} + +/// The VhostUserGpuCursorPos from the QEMU vhost-user-gpu specification. +#[derive(Default, Copy, Clone, Debug)] +#[repr(C)] +pub struct VhostUserGpuCursorPos { + /// The scanout where the cursor is located + pub scanout_id: u32, + /// The cursor position field x + pub x: u32, + /// The cursor position field y + pub y: u32, +} + +// SAFETY: Safe because all fields are POD. +unsafe impl ByteValued for VhostUserGpuCursorPos {} + +impl VhostUserMsgValidator for VhostUserGpuCursorPos {} + +/// The VhostUserGpuCursorUpdate from the QEMU vhost-user-gpu specification. +#[derive(Copy, Clone, Default, Debug)] +#[repr(C)] +pub struct VhostUserGpuCursorUpdate { + /// The cursor location + pub pos: VhostUserGpuCursorPos, + /// The cursor hot location x + pub hot_x: u32, + /// The cursor hot location y + pub hot_y: u32, +} + +// SAFETY: Safe because all fields are POD. +unsafe impl ByteValued for VhostUserGpuCursorUpdate {} + +impl VhostUserMsgValidator for VhostUserGpuCursorUpdate {} + +/// The virtio_gpu_resp_edid struct from the virtio specification. +#[derive(Copy, Clone, Debug)] +#[repr(C)] +pub struct VirtioGpuRespGetEdid { + /// The fixed header struct + pub hdr: VirtioGpuCtrlHdr, + /// The actual size of the `edid` field. + pub size: u32, + /// Padding of the structure + pub padding: u32, + /// The EDID display data blob (as specified by VESA) for the scanout. 
+ pub edid: [u8; 1024], +} + +impl Default for VirtioGpuRespGetEdid { + fn default() -> Self { + VirtioGpuRespGetEdid { + hdr: VirtioGpuCtrlHdr::default(), + size: u32::default(), + padding: u32::default(), + edid: [0; 1024], // Default value for the edid array (filled with zeros) + } + } +} + +// SAFETY: Safe because all fields are POD. +unsafe impl ByteValued for VirtioGpuRespGetEdid {} + +impl VhostUserMsgValidator for VirtioGpuRespGetEdid {} + +/// The VhostUserGpuScanout from the QEMU vhost-user-gpu specification. +#[derive(Copy, Clone, Debug, Default)] +#[repr(C)] +pub struct VhostUserGpuScanout { + /// The id of the scanout + pub scanout_id: u32, + /// The scanout width + pub width: u32, + /// The scanout height + pub height: u32, +} + +// SAFETY: Safe because all fields are POD. +unsafe impl ByteValued for VhostUserGpuScanout {} + +impl VhostUserMsgValidator for VhostUserGpuScanout {} diff --git a/vhost/src/vhost_user/message.rs b/vhost/src/vhost_user/message.rs index f24331c2..31d8997e 100644 --- a/vhost/src/vhost_user/message.rs +++ b/vhost/src/vhost_user/message.rs @@ -23,6 +23,15 @@ use vm_memory::{GuestAddress, MmapRange, MmapXenFlags}; use super::{Error, Result}; use crate::VringConfigData; +/* +TODO: Consider deprecating this. We don't actually have any preallocated buffers except in tests, +so we should be able to support u32::MAX normally. +Also this doesn't need to be public api, since Endpoint is private anyway, this doesn't seem +useful for consumers of this crate. + +There are GPU specific messages (GpuBackendReq::UPDATE and CURSOR_UPDATE) that are larger than 4K. +We can use MsgHeader::MAX_MSG_SIZE, if we want to support larger messages only for GPU headers. +*/ /// The vhost-user specification uses a field of u32 to store message length. /// On the other hand, preallocated buffers are needed to receive messages from the Unix domain /// socket. To preallocating a 4GB buffer for each vhost-user message is really just an overhead. 
@@ -54,6 +63,13 @@ pub(super) trait Req: { } +pub(super) trait MsgHeader: ByteValued + Copy + Default + VhostUserMsgValidator { + type Request: Req; + + /// The maximum size of a msg that can be encapsulated by this MsgHeader + const MAX_MSG_SIZE: usize; +} + macro_rules! enum_value { ( $(#[$meta:meta])* @@ -88,6 +104,7 @@ macro_rules! enum_value { } } } +pub(crate) use enum_value; enum_value! { #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] @@ -255,6 +272,11 @@ pub(super) struct VhostUserMsgHeader { _r: PhantomData, } +impl MsgHeader for VhostUserMsgHeader { + type Request = R; + const MAX_MSG_SIZE: usize = MAX_MSG_SIZE; +} + impl Debug for VhostUserMsgHeader { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("VhostUserMsgHeader") diff --git a/vhost/src/vhost_user/mod.rs b/vhost/src/vhost_user/mod.rs index 66f8c109..dc9e3d6d 100644 --- a/vhost/src/vhost_user/mod.rs +++ b/vhost/src/vhost_user/mod.rs @@ -52,6 +52,12 @@ pub use self::backend_req_handler::{ mod backend_req; #[cfg(feature = "vhost-user-backend")] pub use self::backend_req::Backend; +#[cfg(feature = "gpu-socket")] +mod gpu_backend_req; +#[cfg(feature = "gpu-socket")] +pub mod gpu_message; +#[cfg(feature = "gpu-socket")] +pub use self::gpu_backend_req::GpuBackend; /// Errors for vhost-user operations #[derive(Debug)]