microsoft/openvmm
Public · mirrored from https://github.com/microsoft/openvmm · Available
openhcl/diag_server/src/lib.rs
266 lines · mode: code
| 1 | // Copyright (c) Microsoft Corporation. |
| 2 | // Licensed under the MIT License. |
| 3 | |
| 4 | //! Underhill diagnostics server. |
| 5 | |
| 6 | #![cfg(target_os = "linux")] |
| 7 | |
| 8 | mod diag_service; |
| 9 | mod new_pty; |
| 10 | |
| 11 | pub use diag_service::DiagRequest; |
| 12 | pub use diag_service::StartParams; |
| 13 | |
| 14 | use anyhow::Context; |
| 15 | use cvm_tracing::CVM_ALLOWED; |
| 16 | use futures::AsyncWriteExt; |
| 17 | use futures::FutureExt; |
| 18 | use mesh::CancelReason; |
| 19 | use mesh_rpc::server::RpcReceiver; |
| 20 | use mesh_rpc::service::Code; |
| 21 | use mesh_rpc::service::Status; |
| 22 | use pal_async::driver::Driver; |
| 23 | use pal_async::interest::PollEvents; |
| 24 | use pal_async::socket::PollReadyExt; |
| 25 | use pal_async::socket::PolledSocket; |
| 26 | use pal_async::task::Spawn; |
| 27 | use pal_async::task::Task; |
| 28 | use parking_lot::Mutex; |
| 29 | use socket2::Socket; |
| 30 | use std::collections::HashMap; |
| 31 | use std::path::Path; |
| 32 | use std::pin::pin; |
| 33 | use std::sync::Arc; |
| 34 | use unix_socket::UnixListener; |
| 35 | use vmsocket::VmAddress; |
| 36 | use vmsocket::VmListener; |
| 37 | |
| 38 | /// The diagnostics server, which is a ttrpc server listening on `AF_VSOCK` at |
| 39 | /// for control and data. |
| 40 | pub struct DiagServer { |
| 41 | // control listener |
| 42 | control_listener: Socket, |
| 43 | // data listener |
| 44 | data_listener: Socket, |
| 45 | inner: Arc<Inner>, |
| 46 | server: mesh_rpc::Server, |
| 47 | } |
| 48 | |
| 49 | impl DiagServer { |
| 50 | /// Creates a server over VmSockets and starts listening. |
| 51 | pub fn new_vsock(control_address: VmAddress, data_address: VmAddress) -> anyhow::Result<Self> { |
| 52 | tracing::info!(CVM_ALLOWED, ?control_address, "control starting"); |
| 53 | let control_listener = |
| 54 | VmListener::bind(control_address).context("failed to bind socket")?; |
| 55 | |
| 56 | tracing::info!(CVM_ALLOWED, ?data_address, "data starting"); |
| 57 | let data_listener = VmListener::bind(data_address).context("failed to bind socket")?; |
| 58 | |
| 59 | Ok(Self::new_generic( |
| 60 | control_listener.into(), |
| 61 | data_listener.into(), |
| 62 | )) |
| 63 | } |
| 64 | |
| 65 | /// Creates a server over Unix sockets and starts listening. |
| 66 | pub fn new_unix(control_address: &Path, data_address: &Path) -> anyhow::Result<Self> { |
| 67 | tracing::info!(CVM_ALLOWED, ?control_address, "control starting"); |
| 68 | let control_listener = |
| 69 | UnixListener::bind(control_address).context("failed to bind socket")?; |
| 70 | |
| 71 | tracing::info!(CVM_ALLOWED, ?data_address, "data starting"); |
| 72 | let data_listener = UnixListener::bind(data_address).context("failed to bind socket")?; |
| 73 | |
| 74 | Ok(Self::new_generic( |
| 75 | control_listener.into(), |
| 76 | data_listener.into(), |
| 77 | )) |
| 78 | } |
| 79 | |
| 80 | fn new_generic(control_listener: Socket, data_listener: Socket) -> Self { |
| 81 | Self { |
| 82 | control_listener, |
| 83 | data_listener, |
| 84 | server: mesh_rpc::Server::new(), |
| 85 | inner: Arc::new(Inner { |
| 86 | connections: Mutex::new(DataConnections { |
| 87 | next_id: 1, // connection IDs start at 1, as 0 is an invalid ID. |
| 88 | active: Default::default(), |
| 89 | }), |
| 90 | }), |
| 91 | } |
| 92 | } |
| 93 | |
| 94 | /// Serves requests until `cancel` is dropped. |
| 95 | pub async fn serve( |
| 96 | mut self, |
| 97 | driver: &(impl Driver + Spawn + Clone), |
| 98 | cancel: mesh::OneshotReceiver<()>, |
| 99 | request_send: mesh::Sender<DiagRequest>, |
| 100 | ) -> anyhow::Result<()> { |
| 101 | // Disable all diag requests for CVMs. Inspect filtering will be handled |
| 102 | // internally more granularly. |
| 103 | let (diag_recv, diag2_recv) = if underhill_confidentiality::confidential_filtering_enabled() |
| 104 | { |
| 105 | (RpcReceiver::disconnected(), RpcReceiver::disconnected()) |
| 106 | } else { |
| 107 | ( |
| 108 | self.server.add_service::<diag_proto::UnderhillDiag>(), |
| 109 | self.server.add_service::<diag_proto::OpenhclDiag>(), |
| 110 | ) |
| 111 | }; |
| 112 | |
| 113 | let inspect_recv = self.server.add_service::<inspect_proto::InspectService>(); |
| 114 | |
| 115 | // TODO: split the profiler to a separate service provider. |
| 116 | let profile_recv = self |
| 117 | .server |
| 118 | .add_service::<azure_profiler_proto::AzureProfiler>(); |
| 119 | |
| 120 | let diag_service = Arc::new(diag_service::DiagServiceHandler::new( |
| 121 | request_send, |
| 122 | self.inner.clone(), |
| 123 | )); |
| 124 | let process = diag_service.process_requests( |
| 125 | driver, |
| 126 | diag_recv, |
| 127 | diag2_recv, |
| 128 | inspect_recv, |
| 129 | profile_recv, |
| 130 | ); |
| 131 | |
| 132 | let serve = self.server.run(driver, self.control_listener, cancel); |
| 133 | let data_connections = self |
| 134 | .inner |
| 135 | .process_data_connections(driver, self.data_listener); |
| 136 | |
| 137 | futures::future::try_join3(serve, process, data_connections).await?; |
| 138 | Ok(()) |
| 139 | } |
| 140 | } |
| 141 | |
| 142 | #[derive(Debug)] |
| 143 | struct DataConnectionEntry { |
| 144 | /// Sender used to notify the hangup task to return the socket. |
| 145 | sender: mesh::OneshotSender<()>, |
| 146 | /// Task used to wait for hangup notifications or a request to return the socket. |
| 147 | task: Task<Option<PolledSocket<Socket>>>, |
| 148 | } |
| 149 | |
| 150 | #[derive(Debug, Default)] |
| 151 | struct DataConnections { |
| 152 | next_id: u64, |
| 153 | active: HashMap<u64, DataConnectionEntry>, |
| 154 | } |
| 155 | |
| 156 | impl DataConnections { |
| 157 | fn take_connection(&mut self, id: u64) -> anyhow::Result<DataConnectionEntry> { |
| 158 | self.active |
| 159 | .remove(&id) |
| 160 | .ok_or_else(|| anyhow::anyhow!("invalid connection id")) |
| 161 | } |
| 162 | } |
| 163 | |
| 164 | struct Inner { |
| 165 | connections: Mutex<DataConnections>, |
| 166 | } |
| 167 | |
| 168 | impl Inner { |
| 169 | async fn take_connection(&self, id: u64) -> anyhow::Result<PolledSocket<Socket>> { |
| 170 | let DataConnectionEntry { sender, task } = self.connections.lock().take_connection(id)?; |
| 171 | |
| 172 | sender.send(()); |
| 173 | task.await |
| 174 | .ok_or_else(|| anyhow::anyhow!("connection disconnected")) |
| 175 | } |
| 176 | |
| 177 | /// Listen for data connections and add them to the internal connections lookup table as they arrive. |
| 178 | async fn process_data_connections( |
| 179 | self: &Arc<Self>, |
| 180 | driver: &(impl Driver + Spawn + Clone), |
| 181 | listener: Socket, |
| 182 | ) -> anyhow::Result<()> { |
| 183 | let mut listener = PolledSocket::new(driver, listener)?; |
| 184 | |
| 185 | loop { |
| 186 | let (connection, _addr) = listener.accept().await?; |
| 187 | let mut socket = PolledSocket::new(driver, connection)?; |
| 188 | let inner = Arc::downgrade(self); |
| 189 | |
| 190 | // Send the 8 byte connection id, then stash the connection in the lookup table to be used later. |
| 191 | let id; |
| 192 | { |
| 193 | let mut state = self.connections.lock(); |
| 194 | id = state.next_id; |
| 195 | state.next_id += 1; |
| 196 | |
| 197 | tracing::debug!(id, "new data connection"); |
| 198 | } |
| 199 | |
| 200 | let (sender, recv) = mesh::oneshot(); |
| 201 | |
| 202 | // Spawn a task that returns the socket when asked to, or removes itself from the map if disconnected. |
| 203 | let task = driver.spawn(format!("data connection {} waiting", id), async move { |
| 204 | match socket.write_all(&id.to_ne_bytes()).await { |
| 205 | Ok(_) => {} |
| 206 | Err(error) => { |
| 207 | tracing::trace!(?error, "error writing connection id, removing."); |
| 208 | if let Some(state) = inner.upgrade() { |
| 209 | state.connections.lock().active.remove(&id); |
| 210 | } |
| 211 | |
| 212 | return None; |
| 213 | } |
| 214 | } |
| 215 | |
| 216 | let mut return_future = pin!(async { recv.await.is_ok() }.fuse()); |
| 217 | let hangup = futures::select! { // race semantics |
| 218 | _ = socket.wait_ready(PollEvents::RDHUP).fuse() => true, |
| 219 | _ = return_future => false, |
| 220 | }; |
| 221 | |
| 222 | if hangup { |
| 223 | // Other side has disconnected, remove from the table if not already done. |
| 224 | tracing::trace!(id, "data connection disconnected"); |
| 225 | if let Some(state) = inner.upgrade() { |
| 226 | state.connections.lock().active.remove(&id); |
| 227 | } |
| 228 | |
| 229 | None |
| 230 | } else { |
| 231 | Some(socket) |
| 232 | } |
| 233 | }); |
| 234 | |
| 235 | let mut state = self.connections.lock(); |
| 236 | let result = state |
| 237 | .active |
| 238 | .insert(id, DataConnectionEntry { sender, task }); |
| 239 | |
| 240 | if result.is_some() { |
| 241 | anyhow::bail!("connection id reused"); |
| 242 | } |
| 243 | } |
| 244 | } |
| 245 | } |
| 246 | |
| 247 | fn grpc_result<T>(result: Result<anyhow::Result<T>, CancelReason>) -> Result<T, Status> { |
| 248 | match result { |
| 249 | Ok(result) => match result { |
| 250 | Ok(value) => Ok(value), |
| 251 | Err(err) => Err(Status { |
| 252 | code: Code::Unknown as i32, |
| 253 | message: format!("{:#}", err), |
| 254 | details: vec![], |
| 255 | }), |
| 256 | }, |
| 257 | Err(err) => Err(Status { |
| 258 | code: match &err { |
| 259 | CancelReason::Cancelled => Code::Cancelled, |
| 260 | CancelReason::DeadlineExceeded => Code::DeadlineExceeded, |
| 261 | } as i32, |
| 262 | message: format!("{:#}", err), |
| 263 | details: vec![], |
| 264 | }), |
| 265 | } |
| 266 | } |
| 267 | |