diff --git a/embassy-sync/src/semaphore.rs b/embassy-sync/src/semaphore.rs index 52c468b4a..d30eee30b 100644 --- a/embassy-sync/src/semaphore.rs +++ b/embassy-sync/src/semaphore.rs @@ -1,8 +1,7 @@ //! A synchronization primitive for controlling access to a pool of resources. use core::cell::{Cell, RefCell}; use core::convert::Infallible; -use core::future::poll_fn; -use core::mem::MaybeUninit; +use core::future::{poll_fn, Future}; use core::task::{Poll, Waker}; use heapless::Deque; @@ -258,9 +257,9 @@ where &self, permits: usize, acquire_all: bool, - cx: Option<(&Cell<Option<usize>>, &Waker)>, + cx: Option<(&mut Option<usize>, &Waker)>, ) -> Poll<Result<SemaphoreReleaser<'_, Self>, WaitQueueFull>> { - let ticket = cx.as_ref().map(|(cell, _)| cell.get()).unwrap_or(None); + let ticket = cx.as_ref().map(|(x, _)| **x).unwrap_or(None); self.state.lock(|cell| { let mut state = cell.borrow_mut(); if let Some(permits) = state.take(ticket, permits, acquire_all) { @@ -268,10 +267,10 @@ where semaphore: self, permits, })) - } else if let Some((cell, waker)) = cx { + } else if let Some((ticket_ref, waker)) = cx { match state.register(ticket, waker) { Ok(ticket) => { - cell.set(Some(ticket)); + *ticket_ref = Some(ticket); Poll::Pending } Err(err) => Poll::Ready(Err(err)), @@ -291,10 +290,12 @@ pub struct WaitQueueFull; impl<M: RawMutex, const N: usize> Semaphore for FairSemaphore<M, N> { type Error = WaitQueueFull; - async fn acquire(&self, permits: usize) -> Result<SemaphoreReleaser<'_, Self>, Self::Error> { - let ticket = Cell::new(None); - let _guard = OnDrop::new(|| self.state.lock(|cell| cell.borrow_mut().cancel(ticket.get()))); - poll_fn(|cx| self.poll_acquire(permits, false, Some((&ticket, cx.waker())))).await + fn acquire(&self, permits: usize) -> impl Future<Output = Result<SemaphoreReleaser<'_, Self>, Self::Error>> { + FairAcquire { + sema: self, + permits, + ticket: None, + } } fn try_acquire(&self, permits: usize) -> Option<SemaphoreReleaser<'_, Self>> { @@ -304,10 +305,12 @@ impl<M: RawMutex, const N: usize> Semaphore for FairSemaphore<M, N> { } } - async fn acquire_all(&self, min: usize) -> Result<SemaphoreReleaser<'_, Self>, Self::Error> { - let ticket = Cell::new(None); - let _guard = OnDrop::new(|| self.state.lock(|cell| cell.borrow_mut().cancel(ticket.get()))); - poll_fn(|cx| self.poll_acquire(min, true, Some((&ticket, cx.waker())))).await + fn acquire_all(&self, min: usize) -> impl Future<Output = Result<SemaphoreReleaser<'_, Self>, Self::Error>> { + FairAcquireAll { + sema: self, + min, + ticket: None, + } } fn try_acquire_all(&self, min: usize) -> Option<SemaphoreReleaser<'_, Self>> { @@ -338,6 +341,52 @@ impl<M: RawMutex, const N: usize> Semaphore for FairSemaphore<M, N> { } } +struct FairAcquire<'a, M: RawMutex, const N: usize> { + sema: &'a FairSemaphore<M, N>, + permits: usize, + ticket: Option<usize>, +} + +impl<'a, M: RawMutex, const N: usize> Drop for FairAcquire<'a, M, N> { + fn drop(&mut self) { + self.sema + .state + .lock(|cell| cell.borrow_mut().cancel(self.ticket.take())); + } +} + +impl<'a, M: RawMutex, const N: usize> core::future::Future for FairAcquire<'a, M, N> { + type Output = Result<SemaphoreReleaser<'a, FairSemaphore<M, N>>, WaitQueueFull>; + + fn poll(mut self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> Poll<Self::Output> { + self.sema + .poll_acquire(self.permits, false, Some((&mut self.ticket, cx.waker()))) + } +} + +struct FairAcquireAll<'a, M: RawMutex, const N: usize> { + sema: &'a FairSemaphore<M, N>, + min: usize, + ticket: Option<usize>, +} + +impl<'a, M: RawMutex, const N: usize> Drop for FairAcquireAll<'a, M, N> { + fn drop(&mut self) { + self.sema + .state + .lock(|cell| cell.borrow_mut().cancel(self.ticket.take())); + } +} + +impl<'a, M: RawMutex, const N: usize> core::future::Future for FairAcquireAll<'a, M, N> { + type Output = Result<SemaphoreReleaser<'a, FairSemaphore<M, N>>, WaitQueueFull>; + + fn poll(mut self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> Poll<Self::Output> { + self.sema + .poll_acquire(self.min, true, Some((&mut self.ticket, cx.waker()))) + } +} + struct FairSemaphoreState<const N: usize> { permits: usize, next_ticket: usize, @@ -406,6 +455,9 @@ impl<const N: usize> FairSemaphoreState<N> { if ticket.is_some() { self.pop(); + if self.permits > 0 { + self.wake(); + } } Some(permits) @@ -432,25 +484,6 @@ impl<const N: usize> FairSemaphoreState<N> { } } -/// A type to delay the drop handler invocation. -#[must_use = "to delay the drop handler invocation to the end of the scope"] -struct OnDrop<F: FnOnce()> { - f: MaybeUninit<F>, -} - -impl<F: FnOnce()> OnDrop<F> { - /// Create a new instance. - pub fn new(f: F) -> Self { - Self { f: MaybeUninit::new(f) } - } -} - -impl<F: FnOnce()> Drop for OnDrop<F> { - fn drop(&mut self) { - unsafe { self.f.as_ptr().read()() } - } -} - #[cfg(test)] mod tests { mod greedy { @@ -574,11 +607,16 @@ mod tests { mod fair { use core::pin::pin; + use core::time::Duration; + use futures_executor::ThreadPool; + use futures_timer::Delay; use futures_util::poll; + use futures_util::task::SpawnExt; + use static_cell::StaticCell; use super::super::*; - use crate::blocking_mutex::raw::NoopRawMutex; + use crate::blocking_mutex::raw::{CriticalSectionRawMutex, NoopRawMutex}; #[test] fn try_acquire() { @@ -700,5 +738,35 @@ mod tests { let c = poll!(c_fut.as_mut()); assert!(c.is_ready()); } + + #[futures_test::test] + async fn wakers() { + let executor = ThreadPool::new().unwrap(); + + static SEMAPHORE: StaticCell<FairSemaphore<CriticalSectionRawMutex, 2>> = StaticCell::new(); + let semaphore = &*SEMAPHORE.init(FairSemaphore::new(3)); + + let a = semaphore.try_acquire(2); + assert!(a.is_some()); + + let b_task = executor + .spawn_with_handle(async move { semaphore.acquire(2).await }) + .unwrap(); + while semaphore.state.lock(|x| x.borrow().wakers.is_empty()) { + Delay::new(Duration::from_millis(50)).await; + } + + let c_task = executor + .spawn_with_handle(async move { semaphore.acquire(1).await }) + .unwrap(); + + core::mem::drop(a); + + let b = b_task.await.unwrap(); + assert_eq!(b.permits(), 2); + + let c = c_task.await.unwrap(); + assert_eq!(c.permits(), 1); + } } }