microsoft/openvmm

Public

mirrored fromhttps://github.com/microsoft/openvmmAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
e6c778cbebacf3a70be1d39295b4f090134c4091

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

support/fast_select/src/lib.rs

383lines · modecode

1// Copyright (C) Microsoft Corporation. All rights reserved.
2
3//! A mechanism for efficiently selecting between futures.
4//!
5//! In async code, it is common to select between the completion of two or more
6//! futures. In this case, a naive implementation of select will poll each
//! future during each wakeup. If the poll functions are expensive (because they
//! take locks, make syscalls, or otherwise perform some computationally
//! expensive task), then this can contribute to performance problems,
10//! especially in heavily-nested async code.
11//!
12//! This crate contains an [implementation of select](FastSelect::select) that
13//! constructs a separate waker for each alternative future, allowing `select`'s
14//! poll implementation to identify exactly which futures are ready to be
15//! polled.
16
17#![warn(missing_docs)]
18// UNSAFETY: Using unchecked raw Arc, Pin, and Waker APIs.
19#![allow(unsafe_code)]
20
21use parking_lot::Mutex;
22use std::future::Future;
23use std::marker::PhantomData;
24use std::mem::ManuallyDrop;
25use std::ops::Deref;
26use std::pin::pin;
27use std::pin::Pin;
28use std::sync::atomic::AtomicU32;
29use std::sync::atomic::Ordering;
30use std::sync::Arc;
31use std::task::Context;
32use std::task::Poll;
33use std::task::RawWaker;
34use std::task::RawWakerVTable;
35use std::task::Waker;
36
37/// An object that can be used to efficiently select over alternative futures.
38///
39/// This allocates storage used by calls to [`select`](Self::select). Be careful
40/// to preallocate any instances of this outside the hot path.
41///
42/// # Example
43///
44/// ```rust
45/// # use futures::StreamExt;
46/// # use futures::executor::block_on;
47/// # use futures::channel::mpsc::unbounded;
48/// # use fast_select::FastSelect;
49/// # block_on(async {
50/// let mut fast_select = FastSelect::new();
51/// let (_cancel_send, mut cancel_recv) = unbounded::<()>();
52/// loop {
53/// let operation = async {
54/// Some(5)
55/// };
56/// let cancelled = async {
57/// let _ = cancel_recv.next().await;
58/// None
59/// };
60/// if let Some(value) = fast_select.select((operation, cancelled)).await {
61/// break value;
62/// }
63/// }
64/// # });
65/// ```
66///
67/// In cases where one future is much more common than the others, you can leave
68/// that future out and use a traditional select macro or function to select
69/// between the common future and the tuple with the remaining futures. This may
70/// even be a tuple of length one. In this case, the common future will be
71/// polled every iteration, while the uncommon futures will be only polled as
72/// necessary.
73///
74/// For example:
75///
76/// ```rust
77/// # use futures::FutureExt;
78/// # use futures::executor::block_on;
79/// # use std::future::pending;
80/// # use fast_select::FastSelect;
81/// # block_on(async {
82/// let mut fast_select = FastSelect::new();
83/// futures::select_biased! {
84/// value = async { 5u32 }.fuse() => {
85/// println!("{}", value);
86/// }
87/// _ = fast_select.select((pending::<u32>(),)).fuse() => {
88/// unreachable!()
89/// }
90/// }
91/// # });
92/// ```
93#[derive(Default, Debug)]
94pub struct FastSelect {
95 state: Arc<State>,
96}
97
98#[derive(Debug)]
99struct SelectPoll<'a, T> {
100 poll_state: PollState<'a>,
101 futures: T,
102}
103
104impl FastSelect {
105 /// Creates a new [`FastSelect`].
106 pub fn new() -> Self {
107 Default::default()
108 }
109
110 /// Selects between the futures in tuple `futures`.
111 ///
112 /// Returns the output of the first one that completes. All the other
113 /// futures are dropped without being completed.
114 ///
115 /// The futures are polled in the order they are specified in the tuple, so
116 /// there is a bias for earlier ones in the tuple.
117 pub async fn select<T: Select>(&mut self, futures: T) -> T::Output {
118 assert!(T::COUNT <= 32);
119
120 SelectPoll {
121 poll_state: PollState {
122 state: &self.state,
123 last_waker: Default::default(),
124 poll: (1u32 << (T::COUNT % 32)).wrapping_sub(1),
125 },
126 futures: pin!(futures),
127 }
128 .await
129 }
130}
131
132impl<T: Select> Future for SelectPoll<'_, Pin<&mut T>> {
133 type Output = T::Output;
134
135 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
136 let this = self.get_mut();
137 this.futures.as_mut().poll_select(cx, &mut this.poll_state)
138 }
139}
140
141#[doc(hidden)]
142#[derive(Debug)]
143pub struct PollState<'a> {
144 state: &'a Arc<State>,
145 last_waker: LastWaker,
146 poll: u32,
147}
148
149impl PollState<'_> {
150 fn refill(&mut self, cx: &mut Context<'_>) -> Poll<()> {
151 while self.poll == 0 {
152 if self.state.poll.load(Ordering::Relaxed) != 0 {
153 self.poll = self.state.poll.swap(0, Ordering::Acquire);
154 }
155 if self.poll != 0 {
156 // The waker in `state` was probably taken and dropped.
157 self.last_waker.clear();
158 } else {
159 if let Some(waker) = self.last_waker.update_waker(cx) {
160 // Update the locked waker and loop around to check
161 // `state.poll` again.
162 *self.state.waker.lock() = Some(waker);
163 } else {
164 // The waker is up to date, so do nothing.
165 return Poll::Pending;
166 }
167 }
168 }
169 Poll::Ready(())
170 }
171}
172
173/// A sealed trait for tuple types that can be selected over with
174/// [`FastSelect`].
175pub trait Select: private::Sealed {
176 #[doc(hidden)]
177 /// The number of elements in the tuple.
178 const COUNT: usize;
179 #[doc(hidden)]
180 /// The output type of the tuple futures.
181 type Output;
182
183 #[doc(hidden)]
184 fn poll_select(
185 self: Pin<&mut Self>,
186 cx: &mut Context<'_>,
187 state: &mut PollState<'_>,
188 ) -> Poll<Self::Output>;
189}
190
191mod private {
192 pub trait Sealed {}
193}
194
195macro_rules! gen_future {
196 ( $count:expr, $(($t:tt, $n:tt)),* ) => {
197 impl<R, $($t: Future<Output = R>,)*> private::Sealed for ($($t,)*) {}
198
199 impl<R, $($t: Future<Output = R>,)*> Select for ($($t,)*) {
200 const COUNT: usize = $count;
201 type Output = R;
202
203 fn poll_select(self: Pin<&mut Self>, cx: &mut Context<'_>, state: &mut PollState<'_>) -> Poll<R> {
204 // SAFETY: unpinning in order to re-pin each tuple element one
205 // at a time. This is safe because each element is only accessed
206 // via a pinned pointer.
207 let this = unsafe { self.get_unchecked_mut() };
208 loop {
209 std::task::ready!(state.refill(cx));
210 $(
211 if state.poll & (1<<$n) != 0 {
212 state.poll &= !(1<<$n);
213 // SAFETY: repinning as described above.
214 if let Poll::Ready(r) = unsafe { Pin::new_unchecked(&mut this.$n) }
215 .poll(&mut Context::from_waker(&state.state.waker_ref($n)))
216 {
217 return Poll::Ready(r);
218 }
219 }
220 )*
221 }
222 }
223 }
224 };
225}
226
227gen_future!(1, (T0, 0));
228gen_future!(2, (T0, 0), (T1, 1));
229gen_future!(3, (T0, 0), (T1, 1), (T2, 2));
230gen_future!(4, (T0, 0), (T1, 1), (T2, 2), (T3, 3));
231gen_future!(5, (T0, 0), (T1, 1), (T2, 2), (T3, 3), (T4, 4));
232gen_future!(6, (T0, 0), (T1, 1), (T2, 2), (T3, 3), (T4, 4), (T5, 5));
233
/// Remembers the raw parts of the task waker most recently registered in the
/// shared state. An unchanged waker then does not have to be cloned and
/// stored again on every poll.
#[derive(Debug, Default)]
struct LastWaker {
    last_waker: Option<RawWaker>,
}

// SAFETY: the stored `RawWaker` is never used to wake or to access its data;
// it is only compared for identity, so sharing or sending it is harmless.
unsafe impl Send for LastWaker {}
// SAFETY: the stored `RawWaker` is never used to wake or to access its data;
// it is only compared for identity, so sharing or sending it is harmless.
unsafe impl Sync for LastWaker {}

/// Returns a bitwise copy of the `RawWaker` inside `waker`, for identity
/// comparisons only.
fn raw_waker_copy(waker: &Waker) -> RawWaker {
    // FUTURE: switch to `Waker::as_raw` and `RawWaker::{data, vtable}` once
    // they are available, which removes the need for `unsafe` here.
    let raw_ptr = std::ptr::from_ref(waker).cast::<RawWaker>();
    // SAFETY: `Waker` is `repr(transparent)` over `RawWaker`, and `RawWaker`
    // is just two pointers with no `Drop` impl, so reading a copy neither
    // duplicates ownership nor causes a double drop.
    unsafe { std::ptr::read(raw_ptr) }
}

impl LastWaker {
    /// Forgets the remembered waker, so the next `update_waker` call returns
    /// a waker to register unconditionally.
    fn clear(&mut self) {
        self.last_waker = None;
    }

    /// If the context's waker differs from the remembered one, clones it,
    /// remembers the clone's identity, and returns the clone for
    /// registration. Otherwise returns `None`.
    fn update_waker(&mut self, cx: &Context<'_>) -> Option<Waker> {
        let current = raw_waker_copy(cx.waker());
        let unchanged = self
            .last_waker
            .as_ref()
            .is_some_and(|previous| *previous == current);
        if unchanged {
            return None;
        }
        let cloned = cx.waker().clone();
        self.last_waker = Some(raw_waker_copy(&cloned));
        Some(cloned)
    }
}
270
271#[repr(C, align(4))]
272#[derive(Default, Debug)]
273struct State {
274 poll: AtomicU32,
275 waker: Mutex<Option<Waker>>,
276}
277
278impl State {
279 fn wake(&self, i: usize) {
280 let old = self.poll.fetch_or(1 << i, Ordering::Release);
281 if old == 0 {
282 let waker = self.waker.lock().take();
283 if let Some(waker) = waker {
284 waker.wake();
285 }
286 }
287 }
288
289 /// Gets the pointer and wake index from the data pointer.
290 unsafe fn from_ptr(data: *const ()) -> (ManuallyDrop<Arc<Self>>, usize) {
291 let align_mask = align_of::<Self>() - 1;
292 let i = (data as usize) & align_mask;
293 let this = (data as usize & !align_mask) as *const Self;
294 // SAFETY: caller guarantees that this is a valid reference.
295 let this = unsafe { Arc::from_raw(this) };
296 (ManuallyDrop::new(this), i)
297 }
298
299 unsafe fn clone_fn(data: *const ()) -> RawWaker {
300 // SAFETY: caller guarantees this is a valid data pointer.
301 let (this, _) = unsafe { Self::from_ptr(data) };
302 let _ = Arc::into_raw(Arc::clone(&this));
303 RawWaker::new(
304 data,
305 &RawWakerVTable::new(
306 Self::clone_fn,
307 Self::wake_fn,
308 Self::wake_by_ref_fn,
309 Self::drop_fn,
310 ),
311 )
312 }
313
314 unsafe fn wake_fn(data: *const ()) {
315 // SAFETY: caller guarantees this is a valid data pointer.
316 let (this, i) = unsafe { Self::from_ptr(data) };
317 let this = ManuallyDrop::into_inner(this);
318 this.wake(i);
319 }
320
321 unsafe fn wake_by_ref_fn(data: *const ()) {
322 // SAFETY: caller guarantees this is a valid data pointer.
323 let (this, i) = unsafe { Self::from_ptr(data) };
324 this.wake(i);
325 }
326
327 unsafe fn drop_fn(data: *const ()) {
328 // SAFETY: caller guarantees this is a valid data pointer.
329 let (this, _) = unsafe { Self::from_ptr(data) };
330 drop(ManuallyDrop::into_inner(this));
331 }
332
333 fn waker_ref<'a>(self: &'a Arc<Self>, i: usize) -> WakerRef<'a> {
334 let data = ((Arc::as_ptr(self) as usize) | i) as *const ();
335 let waker = RawWaker::new(
336 data,
337 &RawWakerVTable::new(
338 Self::clone_fn,
339 Self::wake_by_ref_fn,
340 Self::wake_by_ref_fn,
341 |_| (),
342 ),
343 );
344 // SAFETY: the vtable methods implement the waker contract.
345 let waker = unsafe { Waker::from_raw(waker) };
346 WakerRef {
347 waker,
348 _phantom: PhantomData,
349 }
350 }
351}
352
/// A `Waker` that must not outlive the borrow `'a` it was created from.
///
/// Dereferences to the inner `Waker`, so a `&WakerRef` can be passed wherever
/// a `&Waker` is expected.
struct WakerRef<'a> {
    waker: Waker,
    _phantom: PhantomData<&'a ()>,
}

impl<'a> Deref for WakerRef<'a> {
    type Target = Waker;

    fn deref(&self) -> &Waker {
        let WakerRef { waker, .. } = self;
        waker
    }
}
365
#[cfg(test)]
mod tests {
    use crate::FastSelect;
    use pal_async::async_test;
    use pal_async::timer::PolledTimer;
    use pal_async::DefaultDriver;
    use std::future::pending;
    use std::time::Duration;

    /// A three-way select completes once its only non-pending future (a timer)
    /// fires.
    #[async_test]
    async fn test_foo(driver: DefaultDriver) {
        let mut select = FastSelect::new();
        let mut timer = PolledTimer::new(&driver);
        select
            .select((pending(), pending(), timer.sleep(Duration::from_millis(30))))
            .await;
    }

    /// The largest supported tuple, with the completing future in the last
    /// slot. Its waker carries the highest index packed into the state
    /// pointer, so this exercises the alignment that packing depends on.
    #[async_test]
    async fn test_last_of_six(driver: DefaultDriver) {
        let mut select = FastSelect::new();
        let mut timer = PolledTimer::new(&driver);
        select
            .select((
                pending(),
                pending(),
                pending(),
                pending(),
                pending(),
                timer.sleep(Duration::from_millis(30)),
            ))
            .await;
    }

    /// When several futures are ready at once, the earliest in the tuple wins.
    #[async_test]
    async fn test_biased_to_first(_driver: DefaultDriver) {
        let mut select = FastSelect::new();
        let value = select.select((async { 1u32 }, async { 2u32 })).await;
        assert_eq!(value, 1);
    }
}
384