Fix FairSemaphore bugs

- `acquire` and `acquire_all` futures were `!Send`, even for `M: RawMutex + Send` due to the captured `Cell`.
- If multiple `acquire` tasks were queued, waking the first would not wake the second, even if there were permits remaining after the first `acquire` completed.
This commit is contained in:
Alex Moon 2024-04-03 19:13:57 -04:00
parent 1fd260e4b1
commit c9acebf783

View file

@ -1,8 +1,7 @@
//! A synchronization primitive for controlling access to a pool of resources.
use core::cell::{Cell, RefCell};
use core::convert::Infallible;
use core::future::poll_fn;
use core::mem::MaybeUninit;
use core::future::{poll_fn, Future};
use core::task::{Poll, Waker};
use heapless::Deque;
@ -258,9 +257,9 @@ where
&self,
permits: usize,
acquire_all: bool,
cx: Option<(&Cell<Option<usize>>, &Waker)>,
cx: Option<(&mut Option<usize>, &Waker)>,
) -> Poll<Result<SemaphoreReleaser<'_, Self>, WaitQueueFull>> {
let ticket = cx.as_ref().map(|(cell, _)| cell.get()).unwrap_or(None);
let ticket = cx.as_ref().map(|(x, _)| **x).unwrap_or(None);
self.state.lock(|cell| {
let mut state = cell.borrow_mut();
if let Some(permits) = state.take(ticket, permits, acquire_all) {
@ -268,10 +267,10 @@ where
semaphore: self,
permits,
}))
} else if let Some((cell, waker)) = cx {
} else if let Some((ticket_ref, waker)) = cx {
match state.register(ticket, waker) {
Ok(ticket) => {
cell.set(Some(ticket));
*ticket_ref = Some(ticket);
Poll::Pending
}
Err(err) => Poll::Ready(Err(err)),
@ -291,10 +290,12 @@ pub struct WaitQueueFull;
impl<M: RawMutex, const N: usize> Semaphore for FairSemaphore<M, N> {
type Error = WaitQueueFull;
async fn acquire(&self, permits: usize) -> Result<SemaphoreReleaser<'_, Self>, Self::Error> {
let ticket = Cell::new(None);
let _guard = OnDrop::new(|| self.state.lock(|cell| cell.borrow_mut().cancel(ticket.get())));
poll_fn(|cx| self.poll_acquire(permits, false, Some((&ticket, cx.waker())))).await
fn acquire(&self, permits: usize) -> impl Future<Output = Result<SemaphoreReleaser<'_, Self>, Self::Error>> {
FairAcquire {
sema: self,
permits,
ticket: None,
}
}
fn try_acquire(&self, permits: usize) -> Option<SemaphoreReleaser<'_, Self>> {
@ -304,10 +305,12 @@ impl<M: RawMutex, const N: usize> Semaphore for FairSemaphore<M, N> {
}
}
async fn acquire_all(&self, min: usize) -> Result<SemaphoreReleaser<'_, Self>, Self::Error> {
let ticket = Cell::new(None);
let _guard = OnDrop::new(|| self.state.lock(|cell| cell.borrow_mut().cancel(ticket.get())));
poll_fn(|cx| self.poll_acquire(min, true, Some((&ticket, cx.waker())))).await
fn acquire_all(&self, min: usize) -> impl Future<Output = Result<SemaphoreReleaser<'_, Self>, Self::Error>> {
FairAcquireAll {
sema: self,
min,
ticket: None,
}
}
fn try_acquire_all(&self, min: usize) -> Option<SemaphoreReleaser<'_, Self>> {
@ -338,6 +341,52 @@ impl<M: RawMutex, const N: usize> Semaphore for FairSemaphore<M, N> {
}
}
struct FairAcquire<'a, M: RawMutex, const N: usize> {
sema: &'a FairSemaphore<M, N>,
permits: usize,
ticket: Option<usize>,
}
impl<'a, M: RawMutex, const N: usize> Drop for FairAcquire<'a, M, N> {
fn drop(&mut self) {
self.sema
.state
.lock(|cell| cell.borrow_mut().cancel(self.ticket.take()));
}
}
impl<'a, M: RawMutex, const N: usize> core::future::Future for FairAcquire<'a, M, N> {
type Output = Result<SemaphoreReleaser<'a, FairSemaphore<M, N>>, WaitQueueFull>;
fn poll(mut self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> Poll<Self::Output> {
self.sema
.poll_acquire(self.permits, false, Some((&mut self.ticket, cx.waker())))
}
}
struct FairAcquireAll<'a, M: RawMutex, const N: usize> {
sema: &'a FairSemaphore<M, N>,
min: usize,
ticket: Option<usize>,
}
impl<'a, M: RawMutex, const N: usize> Drop for FairAcquireAll<'a, M, N> {
fn drop(&mut self) {
self.sema
.state
.lock(|cell| cell.borrow_mut().cancel(self.ticket.take()));
}
}
impl<'a, M: RawMutex, const N: usize> core::future::Future for FairAcquireAll<'a, M, N> {
type Output = Result<SemaphoreReleaser<'a, FairSemaphore<M, N>>, WaitQueueFull>;
fn poll(mut self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> Poll<Self::Output> {
self.sema
.poll_acquire(self.min, true, Some((&mut self.ticket, cx.waker())))
}
}
struct FairSemaphoreState<const N: usize> {
permits: usize,
next_ticket: usize,
@ -406,6 +455,9 @@ impl<const N: usize> FairSemaphoreState<N> {
if ticket.is_some() {
self.pop();
if self.permits > 0 {
self.wake();
}
}
Some(permits)
@ -432,25 +484,6 @@ impl<const N: usize> FairSemaphoreState<N> {
}
}
/// A type to delay the drop handler invocation.
#[must_use = "to delay the drop handler invocation to the end of the scope"]
struct OnDrop<F: FnOnce()> {
f: MaybeUninit<F>,
}
impl<F: FnOnce()> OnDrop<F> {
/// Create a new instance.
pub fn new(f: F) -> Self {
Self { f: MaybeUninit::new(f) }
}
}
impl<F: FnOnce()> Drop for OnDrop<F> {
fn drop(&mut self) {
unsafe { self.f.as_ptr().read()() }
}
}
#[cfg(test)]
mod tests {
mod greedy {
@ -574,11 +607,16 @@ mod tests {
mod fair {
use core::pin::pin;
use core::time::Duration;
use futures_executor::ThreadPool;
use futures_timer::Delay;
use futures_util::poll;
use futures_util::task::SpawnExt;
use static_cell::StaticCell;
use super::super::*;
use crate::blocking_mutex::raw::NoopRawMutex;
use crate::blocking_mutex::raw::{CriticalSectionRawMutex, NoopRawMutex};
#[test]
fn try_acquire() {
@ -700,5 +738,35 @@ mod tests {
let c = poll!(c_fut.as_mut());
assert!(c.is_ready());
}
#[futures_test::test]
async fn wakers() {
let executor = ThreadPool::new().unwrap();
static SEMAPHORE: StaticCell<FairSemaphore<CriticalSectionRawMutex, 2>> = StaticCell::new();
let semaphore = &*SEMAPHORE.init(FairSemaphore::new(3));
let a = semaphore.try_acquire(2);
assert!(a.is_some());
let b_task = executor
.spawn_with_handle(async move { semaphore.acquire(2).await })
.unwrap();
while semaphore.state.lock(|x| x.borrow().wakers.is_empty()) {
Delay::new(Duration::from_millis(50)).await;
}
let c_task = executor
.spawn_with_handle(async move { semaphore.acquire(1).await })
.unwrap();
core::mem::drop(a);
let b = b_task.await.unwrap();
assert_eq!(b.permits(), 2);
let c = c_task.await.unwrap();
assert_eq!(c.permits(), 1);
}
}
}