Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion coverage_config_x86_64.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"coverage_score": 85.41,
"coverage_score": 86.22,
"exclude_path": "vhost/src/vhost_kern/",
"crate_features": "vhost/vhost-user-frontend,vhost/vhost-user-backend,vhost-user-backend/postcopy"
}
2 changes: 2 additions & 0 deletions vhost/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
## [Unreleased]

### Added
- [[#251]](https://github.com/rust-vmm/vhost/pull/251) Add `SHMEM_MAP` and `SHMEM_UNMAP` support

### Changed
### Deprecated
### Fixed
Expand Down
73 changes: 66 additions & 7 deletions vhost/src/vhost_user/backend_req.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,16 @@ impl VhostUserFrontendReqHandler for Backend {
Some(&[fd.as_raw_fd()]),
)
}

/// Forward vhost-user memory map file request to the frontend.
fn shmem_map(&self, req: &VhostUserMMap, fd: &dyn AsRawFd) -> HandlerResult<u64> {
self.send_message(BackendReq::SHMEM_MAP, req, Some(&[fd.as_raw_fd()]))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In shared_object_*() functions above we check if the related feature is negotiated, should we do something similar also here?

}

/// Forward vhost-user memory unmap file request to the frontend.
fn shmem_unmap(&self, req: &VhostUserMMap) -> HandlerResult<u64> {
self.send_message(BackendReq::SHMEM_UNMAP, req, None)
}
}

#[cfg(test)]
Expand All @@ -182,10 +192,16 @@ mod tests {

use super::*;

fn frontend_backend_pair() -> (Endpoint<VhostUserMsgHeader<BackendReq>>, Backend) {
let (p1, p2) = UnixStream::pair().unwrap();
let backend = Backend::from_stream(p1);
let frontend = Endpoint::<VhostUserMsgHeader<BackendReq>>::from_stream(p2);
(frontend, backend)
}

#[test]
fn test_backend_req_set_failed() {
let (p1, _p2) = UnixStream::pair().unwrap();
let backend = Backend::from_stream(p1);
let (_, backend) = frontend_backend_pair();

assert!(backend.node().error.is_none());
backend.set_failed(libc::EAGAIN);
Expand All @@ -194,8 +210,7 @@ mod tests {

#[test]
fn test_backend_req_send_failure() {
let (p1, _) = UnixStream::pair().unwrap();
let backend = Backend::from_stream(p1);
let (_, backend) = frontend_backend_pair();

backend.set_failed(libc::ECONNRESET);
backend
Expand All @@ -209,9 +224,7 @@ mod tests {

#[test]
fn test_backend_req_recv_negative() {
let (p1, p2) = UnixStream::pair().unwrap();
let backend = Backend::from_stream(p1);
let mut frontend = Endpoint::<VhostUserMsgHeader<BackendReq>>::from_stream(p2);
let (mut frontend, backend) = frontend_backend_pair();

let len = mem::size_of::<VhostUserSharedMsg>();
let mut hdr = VhostUserMsgHeader::new(
Expand Down Expand Up @@ -257,4 +270,50 @@ mod tests {
.shared_object_add(&VhostUserSharedMsg::default())
.unwrap();
}

#[test]
fn test_shmem_map() {
let (mut frontend, backend) = frontend_backend_pair();

let (_, some_fd_to_send) = UnixStream::pair().unwrap();
let map_request = VhostUserMMap {
shmid: 0,
padding: Default::default(),
fd_offset: 0,
shm_offset: 1028,
len: 4096,
flags: VhostUserMMapFlags::WRITABLE.bits(),
};

backend.shmem_map(&map_request, &some_fd_to_send).unwrap();

let (hdr, request, fd) = frontend.recv_body::<VhostUserMMap>().unwrap();
assert_eq!(hdr.get_code().unwrap(), BackendReq::SHMEM_MAP);
assert!(fd.is_some());
assert_eq!({ request.shm_offset }, { map_request.shm_offset });
assert_eq!({ request.len }, { map_request.len },);
assert_eq!({ request.flags }, { map_request.flags });
}

#[test]
fn test_shmem_unmap() {
let (mut frontend, backend) = frontend_backend_pair();

let unmap_request = VhostUserMMap {
shmid: 0,
padding: Default::default(),
fd_offset: 0,
shm_offset: 1028,
len: 4096,
flags: 0,
};

backend.shmem_unmap(&unmap_request).unwrap();

let (hdr, request, fd) = frontend.recv_body::<VhostUserMMap>().unwrap();
assert_eq!(hdr.get_code().unwrap(), BackendReq::SHMEM_UNMAP);
assert!(fd.is_none());
assert_eq!({ request.shm_offset }, { unmap_request.shm_offset });
assert_eq!({ request.len }, { unmap_request.len });
}
}
177 changes: 176 additions & 1 deletion vhost/src/vhost_user/frontend_req_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,16 @@ pub trait VhostUserFrontendReqHandler {
Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
}

/// Handle shared memory region mapping requests.
fn shmem_map(&self, _req: &VhostUserMMap, _fd: &dyn AsRawFd) -> HandlerResult<u64> {
Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
}

/// Handle shared memory region unmapping requests.
fn shmem_unmap(&self, _req: &VhostUserMMap) -> HandlerResult<u64> {
Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
}

// fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb);
// fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: &dyn AsRawFd);
}
Expand Down Expand Up @@ -84,6 +94,16 @@ pub trait VhostUserFrontendReqHandlerMut {
Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
}

/// Handle shared memory region mapping requests.
fn shmem_map(&mut self, _req: &VhostUserMMap, _fd: &dyn AsRawFd) -> HandlerResult<u64> {
Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
}

/// Handle shared memory region unmapping requests.
fn shmem_unmap(&mut self, _req: &VhostUserMMap) -> HandlerResult<u64> {
Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
}

// fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb);
// fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: RawFd);
}
Expand Down Expand Up @@ -111,6 +131,14 @@ impl<S: VhostUserFrontendReqHandlerMut> VhostUserFrontendReqHandler for Mutex<S>
) -> HandlerResult<u64> {
self.lock().unwrap().shared_object_lookup(uuid, fd)
}

fn shmem_map(&self, req: &VhostUserMMap, fd: &dyn AsRawFd) -> HandlerResult<u64> {
self.lock().unwrap().shmem_map(req, fd)
}

fn shmem_unmap(&self, req: &VhostUserMMap) -> HandlerResult<u64> {
self.lock().unwrap().shmem_unmap(req)
}
}

/// Server to handle service requests from backends from the backend communication channel.
Expand Down Expand Up @@ -241,6 +269,18 @@ impl<S: VhostUserFrontendReqHandler> FrontendReqHandler<S> {
.shared_object_lookup(&msg, &files.unwrap()[0])
.map_err(Error::ReqHandlerError)
}
Ok(BackendReq::SHMEM_MAP) => {
let msg = self.extract_msg_body::<VhostUserMMap>(&hdr, size, &buf)?;
self.backend
.shmem_map(&msg, &files.unwrap()[0])
.map_err(Error::ReqHandlerError)
}
Ok(BackendReq::SHMEM_UNMAP) => {
let msg = self.extract_msg_body::<VhostUserMMap>(&hdr, size, &buf)?;
self.backend
.shmem_unmap(&msg)
.map_err(Error::ReqHandlerError)
}
_ => Err(Error::InvalidMessage),
};

Expand Down Expand Up @@ -278,7 +318,7 @@ impl<S: VhostUserFrontendReqHandler> FrontendReqHandler<S> {
files: &Option<Vec<File>>,
) -> Result<()> {
match hdr.get_code() {
Ok(BackendReq::SHARED_OBJECT_LOOKUP) => {
Ok(BackendReq::SHARED_OBJECT_LOOKUP | BackendReq::SHMEM_MAP) => {
// Expect a single file is passed.
match files {
Some(files) if files.len() == 1 => Ok(()),
Expand Down Expand Up @@ -356,6 +396,7 @@ mod tests {
use super::*;

use std::collections::HashSet;
use std::io::ErrorKind;

use uuid::Uuid;

Expand All @@ -366,12 +407,14 @@ mod tests {

struct MockFrontendReqHandler {
shared_objects: HashSet<Uuid>,
shmem_mappings: HashSet<(u64, u64)>,
}

impl MockFrontendReqHandler {
fn new() -> Self {
Self {
shared_objects: HashSet::new(),
shmem_mappings: HashSet::new(),
}
}
}
Expand All @@ -395,6 +438,88 @@ mod tests {
}
Ok(1)
}

fn shmem_map(&mut self, req: &VhostUserMMap, _fd: &dyn AsRawFd) -> HandlerResult<u64> {
assert_eq!({ req.shmid }, 0);
if self.shmem_mappings.insert((req.shm_offset, req.len)) {
return Ok(0);
};
Ok(1)
}

fn shmem_unmap(&mut self, req: &VhostUserMMap) -> HandlerResult<u64> {
assert_eq!({ req.shmid }, 0);
if self.shmem_mappings.remove(&(req.shm_offset, req.len)) {
return Ok(0);
}
Ok(1)
}
}

#[test]
fn test_default_frontend_impl() {
struct Frontend;
impl VhostUserFrontendReqHandler for Frontend {}

let f = Frontend;
assert_eq!(
f.shared_object_add(&Default::default()).unwrap_err().kind(),
ErrorKind::Unsupported
);
assert_eq!(
f.shared_object_remove(&Default::default())
.unwrap_err()
.kind(),
ErrorKind::Unsupported
);
assert_eq!(
f.shared_object_lookup(&Default::default(), &0)
.unwrap_err()
.kind(),
ErrorKind::Unsupported
);

assert_eq!(
f.shmem_map(&Default::default(), &0).unwrap_err().kind(),
ErrorKind::Unsupported
);
assert_eq!(
f.shmem_unmap(&Default::default()).unwrap_err().kind(),
ErrorKind::Unsupported
);
}

#[test]
fn test_default_frontend_impl_mut() {
struct FrontendMut;
impl VhostUserFrontendReqHandlerMut for FrontendMut {}

let mut f = FrontendMut;
assert_eq!(
f.shared_object_add(&Default::default()).unwrap_err().kind(),
ErrorKind::Unsupported
);
assert_eq!(
f.shared_object_remove(&Default::default())
.unwrap_err()
.kind(),
ErrorKind::Unsupported
);
assert_eq!(
f.shared_object_lookup(&Default::default(), &0)
.unwrap_err()
.kind(),
ErrorKind::Unsupported
);

assert_eq!(
f.shmem_map(&Default::default(), &0).unwrap_err().kind(),
ErrorKind::Unsupported
);
assert_eq!(
f.shmem_unmap(&Default::default()).unwrap_err().kind(),
ErrorKind::Unsupported
);
}

#[test]
Expand Down Expand Up @@ -436,6 +561,13 @@ mod tests {
assert_eq!(handler.handle_request().unwrap(), 1);
assert_eq!(handler.handle_request().unwrap(), 0);
assert_eq!(handler.handle_request().unwrap(), 1);

// Testing shmem map/unmap messages.
assert_eq!(handler.handle_request().unwrap(), 0);
assert_eq!(handler.handle_request().unwrap(), 1);
assert_eq!(handler.handle_request().unwrap(), 0);
assert_eq!(handler.handle_request().unwrap(), 0);
assert_eq!(handler.handle_request().unwrap(), 0);
});

backend.set_shared_object_flag(true);
Expand All @@ -456,6 +588,24 @@ mod tests {
.is_ok());
assert!(backend.shared_object_remove(&shobj_msg).is_ok());
assert!(backend.shared_object_remove(&shobj_msg).is_ok());

let (_, some_fd_to_map) = UnixStream::pair().unwrap();
let map_request1 = VhostUserMMap {
shm_offset: 0,
len: 4096,
..Default::default()
};
let map_request2 = VhostUserMMap {
shm_offset: 4096,
len: 8192,
..Default::default()
};
backend.shmem_map(&map_request1, &some_fd_to_map).unwrap();
backend.shmem_unmap(&map_request2).unwrap();
backend.shmem_map(&map_request2, &some_fd_to_map).unwrap();
backend.shmem_unmap(&map_request2).unwrap();
backend.shmem_unmap(&map_request1).unwrap();

// Ensure that the handler thread did not panic.
assert!(frontend_handler.join().is_ok());
}
Expand Down Expand Up @@ -485,6 +635,13 @@ mod tests {
assert_eq!(handler.handle_request().unwrap(), 1);
assert_eq!(handler.handle_request().unwrap(), 0);
assert_eq!(handler.handle_request().unwrap(), 1);

// Testing shmem map/unmap messages.
assert_eq!(handler.handle_request().unwrap(), 0);
assert_eq!(handler.handle_request().unwrap(), 1);
assert_eq!(handler.handle_request().unwrap(), 0);
assert_eq!(handler.handle_request().unwrap(), 0);
assert_eq!(handler.handle_request().unwrap(), 0);
});

backend.set_reply_ack_flag(true);
Expand All @@ -506,6 +663,24 @@ mod tests {
.is_err());
assert!(backend.shared_object_remove(&shobj_msg).is_ok());
assert!(backend.shared_object_remove(&shobj_msg).is_err());

let (_, some_fd_to_map) = UnixStream::pair().unwrap();
let map_request1 = VhostUserMMap {
shm_offset: 0,
len: 4096,
..Default::default()
};
let map_request2 = VhostUserMMap {
shm_offset: 4096,
len: 8192,
..Default::default()
};
backend.shmem_map(&map_request1, &some_fd_to_map).unwrap();
backend.shmem_unmap(&map_request2).unwrap_err();
backend.shmem_map(&map_request2, &some_fd_to_map).unwrap();
backend.shmem_unmap(&map_request2).unwrap();
backend.shmem_unmap(&map_request1).unwrap();

// Ensure that the handler thread did not panic.
assert!(frontend_handler.join().is_ok());
}
Expand Down
Loading