microsoft/openvmm
Public · mirrored from https://github.com/microsoft/openvmm · Available
openhcl/diag_server/src/lib.rs
254 lines · mode: code
| 1 | // Copyright (c) Microsoft Corporation. |
| 2 | // Licensed under the MIT License. |
| 3 | |
| 4 | //! Underhill diagnostics server. |
| 5 | |
| 6 | #![cfg(target_os = "linux")] |
| 7 | #![warn(missing_docs)] |
| 8 | |
| 9 | mod diag_service; |
| 10 | mod new_pty; |
| 11 | |
| 12 | pub use diag_service::DiagRequest; |
| 13 | pub use diag_service::StartParams; |
| 14 | |
| 15 | use anyhow::Context; |
| 16 | use futures::AsyncWriteExt; |
| 17 | use futures::FutureExt; |
| 18 | use mesh::CancelReason; |
| 19 | use mesh_rpc::service::Code; |
| 20 | use mesh_rpc::service::Status; |
| 21 | use pal_async::driver::Driver; |
| 22 | use pal_async::interest::PollEvents; |
| 23 | use pal_async::socket::PollReadyExt; |
| 24 | use pal_async::socket::PolledSocket; |
| 25 | use pal_async::task::Spawn; |
| 26 | use pal_async::task::Task; |
| 27 | use parking_lot::Mutex; |
| 28 | use socket2::Socket; |
| 29 | use std::collections::HashMap; |
| 30 | use std::path::Path; |
| 31 | use std::pin::pin; |
| 32 | use std::sync::Arc; |
| 33 | use unix_socket::UnixListener; |
| 34 | use vmsocket::VmAddress; |
| 35 | use vmsocket::VmListener; |
| 36 | |
| 37 | /// The diagnostics server, which is a ttrpc server listening on `AF_VSOCK` at |
| 38 | /// for control and data. |
| 39 | pub struct DiagServer { |
| 40 | // control listener |
| 41 | control_listener: Socket, |
| 42 | // data listener |
| 43 | data_listener: Socket, |
| 44 | inner: Arc<Inner>, |
| 45 | server: mesh_rpc::Server, |
| 46 | } |
| 47 | |
| 48 | impl DiagServer { |
| 49 | /// Creates a server over VmSockets and starts listening. |
| 50 | pub fn new_vsock(control_address: VmAddress, data_address: VmAddress) -> anyhow::Result<Self> { |
| 51 | tracing::info!(?control_address, "control starting"); |
| 52 | let control_listener = |
| 53 | VmListener::bind(control_address).context("failed to bind socket")?; |
| 54 | |
| 55 | tracing::info!(?data_address, "data starting"); |
| 56 | let data_listener = VmListener::bind(data_address).context("failed to bind socket")?; |
| 57 | |
| 58 | Ok(Self::new_generic( |
| 59 | control_listener.into(), |
| 60 | data_listener.into(), |
| 61 | )) |
| 62 | } |
| 63 | |
| 64 | /// Creates a server over Unix sockets and starts listening. |
| 65 | pub fn new_unix(control_address: &Path, data_address: &Path) -> anyhow::Result<Self> { |
| 66 | tracing::info!(?control_address, "control starting"); |
| 67 | let control_listener = |
| 68 | UnixListener::bind(control_address).context("failed to bind socket")?; |
| 69 | |
| 70 | tracing::info!(?data_address, "data starting"); |
| 71 | let data_listener = UnixListener::bind(data_address).context("failed to bind socket")?; |
| 72 | |
| 73 | Ok(Self::new_generic( |
| 74 | control_listener.into(), |
| 75 | data_listener.into(), |
| 76 | )) |
| 77 | } |
| 78 | |
| 79 | fn new_generic(control_listener: Socket, data_listener: Socket) -> Self { |
| 80 | Self { |
| 81 | control_listener, |
| 82 | data_listener, |
| 83 | server: mesh_rpc::Server::new(), |
| 84 | inner: Arc::new(Inner { |
| 85 | connections: Mutex::new(DataConnections { |
| 86 | next_id: 1, // connection IDs start at 1, as 0 is an invalid ID. |
| 87 | active: Default::default(), |
| 88 | }), |
| 89 | }), |
| 90 | } |
| 91 | } |
| 92 | |
| 93 | /// Serves requests until `cancel` is dropped. |
| 94 | pub async fn serve( |
| 95 | mut self, |
| 96 | driver: &(impl Driver + Spawn + Clone), |
| 97 | cancel: mesh::OneshotReceiver<()>, |
| 98 | request_send: mesh::Sender<DiagRequest>, |
| 99 | ) -> anyhow::Result<()> { |
| 100 | let (diag_send, diag_recv) = mesh::channel(); |
| 101 | let (inspect_send, inspect_recv) = mesh::channel(); |
| 102 | // Disable all diag requests for CVMs. Inspect filtering will be handled |
| 103 | // internally more granularly. |
| 104 | if !underhill_confidentiality::confidential_filtering_enabled() { |
| 105 | self.server.add_service(diag_send); |
| 106 | } |
| 107 | |
| 108 | self.server.add_service(inspect_send); |
| 109 | |
| 110 | // TODO: split the profiler to a separate service provider. |
| 111 | let (profile_send, profile_recv) = mesh::channel(); |
| 112 | self.server.add_service(profile_send); |
| 113 | |
| 114 | let diag_service = Arc::new(diag_service::DiagServiceHandler::new( |
| 115 | request_send, |
| 116 | self.inner.clone(), |
| 117 | )); |
| 118 | let process = diag_service.process_requests(driver, diag_recv, inspect_recv, profile_recv); |
| 119 | |
| 120 | let serve = self.server.run(driver, self.control_listener, cancel); |
| 121 | let data_connections = self |
| 122 | .inner |
| 123 | .process_data_connections(driver, self.data_listener); |
| 124 | |
| 125 | futures::future::try_join3(serve, process, data_connections).await?; |
| 126 | Ok(()) |
| 127 | } |
| 128 | } |
| 129 | |
| 130 | #[derive(Debug)] |
| 131 | struct DataConnectionEntry { |
| 132 | /// Sender used to notify the hangup task to return the socket. |
| 133 | sender: mesh::OneshotSender<()>, |
| 134 | /// Task used to wait for hangup notifications or a request to return the socket. |
| 135 | task: Task<Option<PolledSocket<Socket>>>, |
| 136 | } |
| 137 | |
| 138 | #[derive(Debug, Default)] |
| 139 | struct DataConnections { |
| 140 | next_id: u64, |
| 141 | active: HashMap<u64, DataConnectionEntry>, |
| 142 | } |
| 143 | |
| 144 | impl DataConnections { |
| 145 | fn take_connection(&mut self, id: u64) -> anyhow::Result<DataConnectionEntry> { |
| 146 | self.active |
| 147 | .remove(&id) |
| 148 | .ok_or_else(|| anyhow::anyhow!("invalid connection id")) |
| 149 | } |
| 150 | } |
| 151 | |
| 152 | struct Inner { |
| 153 | connections: Mutex<DataConnections>, |
| 154 | } |
| 155 | |
| 156 | impl Inner { |
| 157 | async fn take_connection(&self, id: u64) -> anyhow::Result<PolledSocket<Socket>> { |
| 158 | let DataConnectionEntry { sender, task } = self.connections.lock().take_connection(id)?; |
| 159 | |
| 160 | sender.send(()); |
| 161 | task.await |
| 162 | .ok_or_else(|| anyhow::anyhow!("connection disconnected")) |
| 163 | } |
| 164 | |
| 165 | /// Listen for data connections and add them to the internal connections lookup table as they arrive. |
| 166 | async fn process_data_connections( |
| 167 | self: &Arc<Self>, |
| 168 | driver: &(impl Driver + Spawn + Clone), |
| 169 | listener: Socket, |
| 170 | ) -> anyhow::Result<()> { |
| 171 | let mut listener = PolledSocket::new(driver, listener)?; |
| 172 | |
| 173 | loop { |
| 174 | let (connection, _addr) = listener.accept().await?; |
| 175 | let mut socket = PolledSocket::new(driver, connection)?; |
| 176 | let inner = Arc::downgrade(self); |
| 177 | |
| 178 | // Send the 8 byte connection id, then stash the connection in the lookup table to be used later. |
| 179 | let id; |
| 180 | { |
| 181 | let mut state = self.connections.lock(); |
| 182 | id = state.next_id; |
| 183 | state.next_id += 1; |
| 184 | |
| 185 | tracing::debug!(id, "new data connection"); |
| 186 | } |
| 187 | |
| 188 | let (sender, recv) = mesh::oneshot(); |
| 189 | |
| 190 | // Spawn a task that returns the socket when asked to, or removes itself from the map if disconnected. |
| 191 | let task = driver.spawn(format!("data connection {} waiting", id), async move { |
| 192 | match socket.write_all(&id.to_ne_bytes()).await { |
| 193 | Ok(_) => {} |
| 194 | Err(error) => { |
| 195 | tracing::trace!(?error, "error writing connection id, removing."); |
| 196 | if let Some(state) = inner.upgrade() { |
| 197 | state.connections.lock().active.remove(&id); |
| 198 | } |
| 199 | |
| 200 | return None; |
| 201 | } |
| 202 | } |
| 203 | |
| 204 | let mut return_future = pin!(async { recv.await.is_ok() }.fuse()); |
| 205 | let hangup = futures::select! { // race semantics |
| 206 | _ = socket.wait_ready(PollEvents::RDHUP).fuse() => true, |
| 207 | _ = return_future => false, |
| 208 | }; |
| 209 | |
| 210 | if hangup { |
| 211 | // Other side has disconnected, remove from the table if not already done. |
| 212 | tracing::trace!(id, "data connection disconnected"); |
| 213 | if let Some(state) = inner.upgrade() { |
| 214 | state.connections.lock().active.remove(&id); |
| 215 | } |
| 216 | |
| 217 | None |
| 218 | } else { |
| 219 | Some(socket) |
| 220 | } |
| 221 | }); |
| 222 | |
| 223 | let mut state = self.connections.lock(); |
| 224 | let result = state |
| 225 | .active |
| 226 | .insert(id, DataConnectionEntry { sender, task }); |
| 227 | |
| 228 | if result.is_some() { |
| 229 | anyhow::bail!("connection id reused"); |
| 230 | } |
| 231 | } |
| 232 | } |
| 233 | } |
| 234 | |
| 235 | fn grpc_result<T>(result: Result<anyhow::Result<T>, CancelReason>) -> Result<T, Status> { |
| 236 | match result { |
| 237 | Ok(result) => match result { |
| 238 | Ok(value) => Ok(value), |
| 239 | Err(err) => Err(Status { |
| 240 | code: Code::Unknown as i32, |
| 241 | message: format!("{:#}", err), |
| 242 | details: vec![], |
| 243 | }), |
| 244 | }, |
| 245 | Err(err) => Err(Status { |
| 246 | code: match &err { |
| 247 | CancelReason::Cancelled => Code::Cancelled, |
| 248 | CancelReason::DeadlineExceeded => Code::DeadlineExceeded, |
| 249 | } as i32, |
| 250 | message: format!("{:#}", err), |
| 251 | details: vec![], |
| 252 | }), |
| 253 | } |
| 254 | } |
| 255 | |