microsoft/openvmm

Public

mirrored from https://github.com/microsoft/openvmm (Available)

Code · Commits · Issues · Pull requests · Actions · Insights · Security
93af13fed5d5fc7a8a08fbf37c0ea1e155c4160a

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

openhcl/diag_server/src/lib.rs

254 lines · mode: code

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Underhill diagnostics server.
5
6#![cfg(target_os = "linux")]
7#![warn(missing_docs)]
8
9mod diag_service;
10mod new_pty;
11
12pub use diag_service::DiagRequest;
13pub use diag_service::StartParams;
14
15use anyhow::Context;
16use futures::AsyncWriteExt;
17use futures::FutureExt;
18use mesh::CancelReason;
19use mesh_rpc::service::Code;
20use mesh_rpc::service::Status;
21use pal_async::driver::Driver;
22use pal_async::interest::PollEvents;
23use pal_async::socket::PollReadyExt;
24use pal_async::socket::PolledSocket;
25use pal_async::task::Spawn;
26use pal_async::task::Task;
27use parking_lot::Mutex;
28use socket2::Socket;
29use std::collections::HashMap;
30use std::path::Path;
31use std::pin::pin;
32use std::sync::Arc;
33use unix_socket::UnixListener;
34use vmsocket::VmAddress;
35use vmsocket::VmListener;
36
37/// The diagnostics server, which is a ttrpc server listening on `AF_VSOCK` at
38/// for control and data.
39pub struct DiagServer {
40 // control listener
41 control_listener: Socket,
42 // data listener
43 data_listener: Socket,
44 inner: Arc<Inner>,
45 server: mesh_rpc::Server,
46}
47
48impl DiagServer {
49 /// Creates a server over VmSockets and starts listening.
50 pub fn new_vsock(control_address: VmAddress, data_address: VmAddress) -> anyhow::Result<Self> {
51 tracing::info!(?control_address, "control starting");
52 let control_listener =
53 VmListener::bind(control_address).context("failed to bind socket")?;
54
55 tracing::info!(?data_address, "data starting");
56 let data_listener = VmListener::bind(data_address).context("failed to bind socket")?;
57
58 Ok(Self::new_generic(
59 control_listener.into(),
60 data_listener.into(),
61 ))
62 }
63
64 /// Creates a server over Unix sockets and starts listening.
65 pub fn new_unix(control_address: &Path, data_address: &Path) -> anyhow::Result<Self> {
66 tracing::info!(?control_address, "control starting");
67 let control_listener =
68 UnixListener::bind(control_address).context("failed to bind socket")?;
69
70 tracing::info!(?data_address, "data starting");
71 let data_listener = UnixListener::bind(data_address).context("failed to bind socket")?;
72
73 Ok(Self::new_generic(
74 control_listener.into(),
75 data_listener.into(),
76 ))
77 }
78
79 fn new_generic(control_listener: Socket, data_listener: Socket) -> Self {
80 Self {
81 control_listener,
82 data_listener,
83 server: mesh_rpc::Server::new(),
84 inner: Arc::new(Inner {
85 connections: Mutex::new(DataConnections {
86 next_id: 1, // connection IDs start at 1, as 0 is an invalid ID.
87 active: Default::default(),
88 }),
89 }),
90 }
91 }
92
93 /// Serves requests until `cancel` is dropped.
94 pub async fn serve(
95 mut self,
96 driver: &(impl Driver + Spawn + Clone),
97 cancel: mesh::OneshotReceiver<()>,
98 request_send: mesh::Sender<DiagRequest>,
99 ) -> anyhow::Result<()> {
100 let (diag_send, diag_recv) = mesh::channel();
101 let (inspect_send, inspect_recv) = mesh::channel();
102 // Disable all diag requests for CVMs. Inspect filtering will be handled
103 // internally more granularly.
104 if !underhill_confidentiality::confidential_filtering_enabled() {
105 self.server.add_service(diag_send);
106 }
107
108 self.server.add_service(inspect_send);
109
110 // TODO: split the profiler to a separate service provider.
111 let (profile_send, profile_recv) = mesh::channel();
112 self.server.add_service(profile_send);
113
114 let diag_service = Arc::new(diag_service::DiagServiceHandler::new(
115 request_send,
116 self.inner.clone(),
117 ));
118 let process = diag_service.process_requests(driver, diag_recv, inspect_recv, profile_recv);
119
120 let serve = self.server.run(driver, self.control_listener, cancel);
121 let data_connections = self
122 .inner
123 .process_data_connections(driver, self.data_listener);
124
125 futures::future::try_join3(serve, process, data_connections).await?;
126 Ok(())
127 }
128}
129
130#[derive(Debug)]
131struct DataConnectionEntry {
132 /// Sender used to notify the hangup task to return the socket.
133 sender: mesh::OneshotSender<()>,
134 /// Task used to wait for hangup notifications or a request to return the socket.
135 task: Task<Option<PolledSocket<Socket>>>,
136}
137
138#[derive(Debug, Default)]
139struct DataConnections {
140 next_id: u64,
141 active: HashMap<u64, DataConnectionEntry>,
142}
143
144impl DataConnections {
145 fn take_connection(&mut self, id: u64) -> anyhow::Result<DataConnectionEntry> {
146 self.active
147 .remove(&id)
148 .ok_or_else(|| anyhow::anyhow!("invalid connection id"))
149 }
150}
151
152struct Inner {
153 connections: Mutex<DataConnections>,
154}
155
156impl Inner {
157 async fn take_connection(&self, id: u64) -> anyhow::Result<PolledSocket<Socket>> {
158 let DataConnectionEntry { sender, task } = self.connections.lock().take_connection(id)?;
159
160 sender.send(());
161 task.await
162 .ok_or_else(|| anyhow::anyhow!("connection disconnected"))
163 }
164
165 /// Listen for data connections and add them to the internal connections lookup table as they arrive.
166 async fn process_data_connections(
167 self: &Arc<Self>,
168 driver: &(impl Driver + Spawn + Clone),
169 listener: Socket,
170 ) -> anyhow::Result<()> {
171 let mut listener = PolledSocket::new(driver, listener)?;
172
173 loop {
174 let (connection, _addr) = listener.accept().await?;
175 let mut socket = PolledSocket::new(driver, connection)?;
176 let inner = Arc::downgrade(self);
177
178 // Send the 8 byte connection id, then stash the connection in the lookup table to be used later.
179 let id;
180 {
181 let mut state = self.connections.lock();
182 id = state.next_id;
183 state.next_id += 1;
184
185 tracing::debug!(id, "new data connection");
186 }
187
188 let (sender, recv) = mesh::oneshot();
189
190 // Spawn a task that returns the socket when asked to, or removes itself from the map if disconnected.
191 let task = driver.spawn(format!("data connection {} waiting", id), async move {
192 match socket.write_all(&id.to_ne_bytes()).await {
193 Ok(_) => {}
194 Err(error) => {
195 tracing::trace!(?error, "error writing connection id, removing.");
196 if let Some(state) = inner.upgrade() {
197 state.connections.lock().active.remove(&id);
198 }
199
200 return None;
201 }
202 }
203
204 let mut return_future = pin!(async { recv.await.is_ok() }.fuse());
205 let hangup = futures::select! { // race semantics
206 _ = socket.wait_ready(PollEvents::RDHUP).fuse() => true,
207 _ = return_future => false,
208 };
209
210 if hangup {
211 // Other side has disconnected, remove from the table if not already done.
212 tracing::trace!(id, "data connection disconnected");
213 if let Some(state) = inner.upgrade() {
214 state.connections.lock().active.remove(&id);
215 }
216
217 None
218 } else {
219 Some(socket)
220 }
221 });
222
223 let mut state = self.connections.lock();
224 let result = state
225 .active
226 .insert(id, DataConnectionEntry { sender, task });
227
228 if result.is_some() {
229 anyhow::bail!("connection id reused");
230 }
231 }
232 }
233}
234
235fn grpc_result<T>(result: Result<anyhow::Result<T>, CancelReason>) -> Result<T, Status> {
236 match result {
237 Ok(result) => match result {
238 Ok(value) => Ok(value),
239 Err(err) => Err(Status {
240 code: Code::Unknown as i32,
241 message: format!("{:#}", err),
242 details: vec![],
243 }),
244 },
245 Err(err) => Err(Status {
246 code: match &err {
247 CancelReason::Cancelled => Code::Cancelled,
248 CancelReason::DeadlineExceeded => Code::DeadlineExceeded,
249 } as i32,
250 message: format!("{:#}", err),
251 details: vec![],
252 }),
253 }
254}
255