Fix FairSemaphore
bugs
- `acquire` and `acquire_all` futures were `!Send`, even for `M: RawMutex + Send` due to the captured `Cell`. - If multiple `acquire` tasks were queued, waking the first would not wake the second, even if there were permits remaining after the first `acquire` completed.
This commit is contained in:
parent
1fd260e4b1
commit
c9acebf783
1 changed files with 102 additions and 34 deletions
|
@ -1,8 +1,7 @@
|
|||
//! A synchronization primitive for controlling access to a pool of resources.
|
||||
use core::cell::{Cell, RefCell};
|
||||
use core::convert::Infallible;
|
||||
use core::future::poll_fn;
|
||||
use core::mem::MaybeUninit;
|
||||
use core::future::{poll_fn, Future};
|
||||
use core::task::{Poll, Waker};
|
||||
|
||||
use heapless::Deque;
|
||||
|
@ -258,9 +257,9 @@ where
|
|||
&self,
|
||||
permits: usize,
|
||||
acquire_all: bool,
|
||||
cx: Option<(&Cell<Option<usize>>, &Waker)>,
|
||||
cx: Option<(&mut Option<usize>, &Waker)>,
|
||||
) -> Poll<Result<SemaphoreReleaser<'_, Self>, WaitQueueFull>> {
|
||||
let ticket = cx.as_ref().map(|(cell, _)| cell.get()).unwrap_or(None);
|
||||
let ticket = cx.as_ref().map(|(x, _)| **x).unwrap_or(None);
|
||||
self.state.lock(|cell| {
|
||||
let mut state = cell.borrow_mut();
|
||||
if let Some(permits) = state.take(ticket, permits, acquire_all) {
|
||||
|
@ -268,10 +267,10 @@ where
|
|||
semaphore: self,
|
||||
permits,
|
||||
}))
|
||||
} else if let Some((cell, waker)) = cx {
|
||||
} else if let Some((ticket_ref, waker)) = cx {
|
||||
match state.register(ticket, waker) {
|
||||
Ok(ticket) => {
|
||||
cell.set(Some(ticket));
|
||||
*ticket_ref = Some(ticket);
|
||||
Poll::Pending
|
||||
}
|
||||
Err(err) => Poll::Ready(Err(err)),
|
||||
|
@ -291,10 +290,12 @@ pub struct WaitQueueFull;
|
|||
impl<M: RawMutex, const N: usize> Semaphore for FairSemaphore<M, N> {
|
||||
type Error = WaitQueueFull;
|
||||
|
||||
async fn acquire(&self, permits: usize) -> Result<SemaphoreReleaser<'_, Self>, Self::Error> {
|
||||
let ticket = Cell::new(None);
|
||||
let _guard = OnDrop::new(|| self.state.lock(|cell| cell.borrow_mut().cancel(ticket.get())));
|
||||
poll_fn(|cx| self.poll_acquire(permits, false, Some((&ticket, cx.waker())))).await
|
||||
fn acquire(&self, permits: usize) -> impl Future<Output = Result<SemaphoreReleaser<'_, Self>, Self::Error>> {
|
||||
FairAcquire {
|
||||
sema: self,
|
||||
permits,
|
||||
ticket: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn try_acquire(&self, permits: usize) -> Option<SemaphoreReleaser<'_, Self>> {
|
||||
|
@ -304,10 +305,12 @@ impl<M: RawMutex, const N: usize> Semaphore for FairSemaphore<M, N> {
|
|||
}
|
||||
}
|
||||
|
||||
async fn acquire_all(&self, min: usize) -> Result<SemaphoreReleaser<'_, Self>, Self::Error> {
|
||||
let ticket = Cell::new(None);
|
||||
let _guard = OnDrop::new(|| self.state.lock(|cell| cell.borrow_mut().cancel(ticket.get())));
|
||||
poll_fn(|cx| self.poll_acquire(min, true, Some((&ticket, cx.waker())))).await
|
||||
fn acquire_all(&self, min: usize) -> impl Future<Output = Result<SemaphoreReleaser<'_, Self>, Self::Error>> {
|
||||
FairAcquireAll {
|
||||
sema: self,
|
||||
min,
|
||||
ticket: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn try_acquire_all(&self, min: usize) -> Option<SemaphoreReleaser<'_, Self>> {
|
||||
|
@ -338,6 +341,52 @@ impl<M: RawMutex, const N: usize> Semaphore for FairSemaphore<M, N> {
|
|||
}
|
||||
}
|
||||
|
||||
struct FairAcquire<'a, M: RawMutex, const N: usize> {
|
||||
sema: &'a FairSemaphore<M, N>,
|
||||
permits: usize,
|
||||
ticket: Option<usize>,
|
||||
}
|
||||
|
||||
impl<'a, M: RawMutex, const N: usize> Drop for FairAcquire<'a, M, N> {
|
||||
fn drop(&mut self) {
|
||||
self.sema
|
||||
.state
|
||||
.lock(|cell| cell.borrow_mut().cancel(self.ticket.take()));
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, M: RawMutex, const N: usize> core::future::Future for FairAcquire<'a, M, N> {
|
||||
type Output = Result<SemaphoreReleaser<'a, FairSemaphore<M, N>>, WaitQueueFull>;
|
||||
|
||||
fn poll(mut self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> Poll<Self::Output> {
|
||||
self.sema
|
||||
.poll_acquire(self.permits, false, Some((&mut self.ticket, cx.waker())))
|
||||
}
|
||||
}
|
||||
|
||||
struct FairAcquireAll<'a, M: RawMutex, const N: usize> {
|
||||
sema: &'a FairSemaphore<M, N>,
|
||||
min: usize,
|
||||
ticket: Option<usize>,
|
||||
}
|
||||
|
||||
impl<'a, M: RawMutex, const N: usize> Drop for FairAcquireAll<'a, M, N> {
|
||||
fn drop(&mut self) {
|
||||
self.sema
|
||||
.state
|
||||
.lock(|cell| cell.borrow_mut().cancel(self.ticket.take()));
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, M: RawMutex, const N: usize> core::future::Future for FairAcquireAll<'a, M, N> {
|
||||
type Output = Result<SemaphoreReleaser<'a, FairSemaphore<M, N>>, WaitQueueFull>;
|
||||
|
||||
fn poll(mut self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> Poll<Self::Output> {
|
||||
self.sema
|
||||
.poll_acquire(self.min, true, Some((&mut self.ticket, cx.waker())))
|
||||
}
|
||||
}
|
||||
|
||||
struct FairSemaphoreState<const N: usize> {
|
||||
permits: usize,
|
||||
next_ticket: usize,
|
||||
|
@ -406,6 +455,9 @@ impl<const N: usize> FairSemaphoreState<N> {
|
|||
|
||||
if ticket.is_some() {
|
||||
self.pop();
|
||||
if self.permits > 0 {
|
||||
self.wake();
|
||||
}
|
||||
}
|
||||
|
||||
Some(permits)
|
||||
|
@ -432,25 +484,6 @@ impl<const N: usize> FairSemaphoreState<N> {
|
|||
}
|
||||
}
|
||||
|
||||
/// A type to delay the drop handler invocation.
|
||||
#[must_use = "to delay the drop handler invocation to the end of the scope"]
|
||||
struct OnDrop<F: FnOnce()> {
|
||||
f: MaybeUninit<F>,
|
||||
}
|
||||
|
||||
impl<F: FnOnce()> OnDrop<F> {
|
||||
/// Create a new instance.
|
||||
pub fn new(f: F) -> Self {
|
||||
Self { f: MaybeUninit::new(f) }
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: FnOnce()> Drop for OnDrop<F> {
|
||||
fn drop(&mut self) {
|
||||
unsafe { self.f.as_ptr().read()() }
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
mod greedy {
|
||||
|
@ -574,11 +607,16 @@ mod tests {
|
|||
|
||||
mod fair {
|
||||
use core::pin::pin;
|
||||
use core::time::Duration;
|
||||
|
||||
use futures_executor::ThreadPool;
|
||||
use futures_timer::Delay;
|
||||
use futures_util::poll;
|
||||
use futures_util::task::SpawnExt;
|
||||
use static_cell::StaticCell;
|
||||
|
||||
use super::super::*;
|
||||
use crate::blocking_mutex::raw::NoopRawMutex;
|
||||
use crate::blocking_mutex::raw::{CriticalSectionRawMutex, NoopRawMutex};
|
||||
|
||||
#[test]
|
||||
fn try_acquire() {
|
||||
|
@ -700,5 +738,35 @@ mod tests {
|
|||
let c = poll!(c_fut.as_mut());
|
||||
assert!(c.is_ready());
|
||||
}
|
||||
|
||||
#[futures_test::test]
|
||||
async fn wakers() {
|
||||
let executor = ThreadPool::new().unwrap();
|
||||
|
||||
static SEMAPHORE: StaticCell<FairSemaphore<CriticalSectionRawMutex, 2>> = StaticCell::new();
|
||||
let semaphore = &*SEMAPHORE.init(FairSemaphore::new(3));
|
||||
|
||||
let a = semaphore.try_acquire(2);
|
||||
assert!(a.is_some());
|
||||
|
||||
let b_task = executor
|
||||
.spawn_with_handle(async move { semaphore.acquire(2).await })
|
||||
.unwrap();
|
||||
while semaphore.state.lock(|x| x.borrow().wakers.is_empty()) {
|
||||
Delay::new(Duration::from_millis(50)).await;
|
||||
}
|
||||
|
||||
let c_task = executor
|
||||
.spawn_with_handle(async move { semaphore.acquire(1).await })
|
||||
.unwrap();
|
||||
|
||||
core::mem::drop(a);
|
||||
|
||||
let b = b_task.await.unwrap();
|
||||
assert_eq!(b.permits(), 2);
|
||||
|
||||
let c = c_task.await.unwrap();
|
||||
assert_eq!(c.permits(), 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue