From ca0d02933b2551e89f0862493be217d1fb9097f1 Mon Sep 17 00:00:00 2001 From: Scott Mabin Date: Sat, 18 Nov 2023 14:21:43 +0000 Subject: [PATCH] Priority channel using binary heap --- embassy-sync/src/lib.rs | 1 + embassy-sync/src/priority_channel.rs | 744 +++++++++++++++++++++++++++ 2 files changed, 745 insertions(+) create mode 100644 embassy-sync/src/priority_channel.rs diff --git a/embassy-sync/src/lib.rs b/embassy-sync/src/lib.rs index c40fa3b6a..3ffcb9135 100644 --- a/embassy-sync/src/lib.rs +++ b/embassy-sync/src/lib.rs @@ -15,6 +15,7 @@ pub mod blocking_mutex; pub mod channel; pub mod mutex; pub mod pipe; +pub mod priority_channel; pub mod pubsub; pub mod signal; pub mod waitqueue; diff --git a/embassy-sync/src/priority_channel.rs b/embassy-sync/src/priority_channel.rs new file mode 100644 index 000000000..04148b56f --- /dev/null +++ b/embassy-sync/src/priority_channel.rs @@ -0,0 +1,744 @@ +//! A queue for sending values between asynchronous tasks. +//! +//! Similar to a [`Channel`](crate::channel::Channel), however [`PriorityChannel`] sifts higher priority items to the front of the queue. + +use core::cell::RefCell; +use core::future::Future; +use core::pin::Pin; +use core::task::{Context, Poll}; + +use heapless::BinaryHeap; + +use crate::blocking_mutex::raw::RawMutex; +use crate::blocking_mutex::Mutex; +use crate::waitqueue::WakerRegistration; + +/// Send-only access to a [`PriorityChannel`]. +pub struct Sender<'ch, M, T, K, const N: usize> +where + T: Ord, + K: heapless::binary_heap::Kind, + M: RawMutex, +{ + channel: &'ch PriorityChannel, +} + +impl<'ch, M, T, K, const N: usize> Clone for Sender<'ch, M, T, K, N> +where + T: Ord, + K: heapless::binary_heap::Kind, + M: RawMutex, +{ + fn clone(&self) -> Self { + Sender { channel: self.channel } + } +} + +impl<'ch, M, T, K, const N: usize> Copy for Sender<'ch, M, T, K, N> +where + T: Ord, + K: heapless::binary_heap::Kind, + M: RawMutex, +{ +} + +impl<'ch, M, T, K, const N: usize> Sender<'ch, M, T, K, N> +where + T: Ord, + K: heapless::binary_heap::Kind, + M: RawMutex, +{ + /// Sends a value. + /// + /// See [`PriorityChannel::send()`] + pub fn send(&self, message: T) -> SendFuture<'ch, M, T, K, N> { + self.channel.send(message) + } + + /// Attempt to immediately send a message. + /// + /// See [`PriorityChannel::send()`] + pub fn try_send(&self, message: T) -> Result<(), TrySendError> { + self.channel.try_send(message) + } + + /// Allows a poll_fn to poll until the channel is ready to send + /// + /// See [`PriorityChannel::poll_ready_to_send()`] + pub fn poll_ready_to_send(&self, cx: &mut Context<'_>) -> Poll<()> { + self.channel.poll_ready_to_send(cx) + } +} + +/// Send-only access to a [`PriorityChannel`] without knowing channel size. +pub struct DynamicSender<'ch, T> { + channel: &'ch dyn DynamicChannel, +} + +impl<'ch, T> Clone for DynamicSender<'ch, T> { + fn clone(&self) -> Self { + DynamicSender { channel: self.channel } + } +} + +impl<'ch, T> Copy for DynamicSender<'ch, T> {} + +impl<'ch, M, T, K, const N: usize> From> for DynamicSender<'ch, T> +where + T: Ord, + K: heapless::binary_heap::Kind, + M: RawMutex, +{ + fn from(s: Sender<'ch, M, T, K, N>) -> Self { + Self { channel: s.channel } + } +} + +impl<'ch, T> DynamicSender<'ch, T> { + /// Sends a value. + /// + /// See [`PriorityChannel::send()`] + pub fn send(&self, message: T) -> DynamicSendFuture<'ch, T> { + DynamicSendFuture { + channel: self.channel, + message: Some(message), + } + } + + /// Attempt to immediately send a message. + /// + /// See [`PriorityChannel::send()`] + pub fn try_send(&self, message: T) -> Result<(), TrySendError> { + self.channel.try_send_with_context(message, None) + } + + /// Allows a poll_fn to poll until the channel is ready to send + /// + /// See [`PriorityChannel::poll_ready_to_send()`] + pub fn poll_ready_to_send(&self, cx: &mut Context<'_>) -> Poll<()> { + self.channel.poll_ready_to_send(cx) + } +} + +/// Receive-only access to a [`PriorityChannel`]. +pub struct Receiver<'ch, M, T, K, const N: usize> +where + T: Ord, + K: heapless::binary_heap::Kind, + M: RawMutex, +{ + channel: &'ch PriorityChannel, +} + +impl<'ch, M, T, K, const N: usize> Clone for Receiver<'ch, M, T, K, N> +where + T: Ord, + K: heapless::binary_heap::Kind, + M: RawMutex, +{ + fn clone(&self) -> Self { + Receiver { channel: self.channel } + } +} + +impl<'ch, M, T, K, const N: usize> Copy for Receiver<'ch, M, T, K, N> +where + T: Ord, + K: heapless::binary_heap::Kind, + M: RawMutex, +{ +} + +impl<'ch, M, T, K, const N: usize> Receiver<'ch, M, T, K, N> +where + T: Ord, + K: heapless::binary_heap::Kind, + M: RawMutex, +{ + /// Receive the next value. + /// + /// See [`PriorityChannel::receive()`]. + pub fn receive(&self) -> ReceiveFuture<'_, M, T, K, N> { + self.channel.receive() + } + + /// Attempt to immediately receive the next value. + /// + /// See [`PriorityChannel::try_receive()`] + pub fn try_receive(&self) -> Result { + self.channel.try_receive() + } + + /// Allows a poll_fn to poll until the channel is ready to receive + /// + /// See [`PriorityChannel::poll_ready_to_receive()`] + pub fn poll_ready_to_receive(&self, cx: &mut Context<'_>) -> Poll<()> { + self.channel.poll_ready_to_receive(cx) + } + + /// Poll the channel for the next item + /// + /// See [`PriorityChannel::poll_receive()`] + pub fn poll_receive(&self, cx: &mut Context<'_>) -> Poll { + self.channel.poll_receive(cx) + } +} + +/// Receive-only access to a [`PriorityChannel`] without knowing channel size. +pub struct DynamicReceiver<'ch, T> { + channel: &'ch dyn DynamicChannel, +} + +impl<'ch, T> Clone for DynamicReceiver<'ch, T> { + fn clone(&self) -> Self { + DynamicReceiver { channel: self.channel } + } +} + +impl<'ch, T> Copy for DynamicReceiver<'ch, T> {} + +impl<'ch, T> DynamicReceiver<'ch, T> { + /// Receive the next value. + /// + /// See [`PriorityChannel::receive()`]. + pub fn receive(&self) -> DynamicReceiveFuture<'_, T> { + DynamicReceiveFuture { channel: self.channel } + } + + /// Attempt to immediately receive the next value. + /// + /// See [`PriorityChannel::try_receive()`] + pub fn try_receive(&self) -> Result { + self.channel.try_receive_with_context(None) + } + + /// Allows a poll_fn to poll until the channel is ready to receive + /// + /// See [`PriorityChannel::poll_ready_to_receive()`] + pub fn poll_ready_to_receive(&self, cx: &mut Context<'_>) -> Poll<()> { + self.channel.poll_ready_to_receive(cx) + } + + /// Poll the channel for the next item + /// + /// See [`PriorityChannel::poll_receive()`] + pub fn poll_receive(&self, cx: &mut Context<'_>) -> Poll { + self.channel.poll_receive(cx) + } +} + +impl<'ch, M, T, K, const N: usize> From> for DynamicReceiver<'ch, T> +where + T: Ord, + K: heapless::binary_heap::Kind, + M: RawMutex, +{ + fn from(s: Receiver<'ch, M, T, K, N>) -> Self { + Self { channel: s.channel } + } +} + +/// Future returned by [`PriorityChannel::receive`] and [`Receiver::receive`]. +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct ReceiveFuture<'ch, M, T, K, const N: usize> +where + T: Ord, + K: heapless::binary_heap::Kind, + M: RawMutex, +{ + channel: &'ch PriorityChannel, +} + +impl<'ch, M, T, K, const N: usize> Future for ReceiveFuture<'ch, M, T, K, N> +where + T: Ord, + K: heapless::binary_heap::Kind, + M: RawMutex, +{ + type Output = T; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.channel.poll_receive(cx) + } +} + +/// Future returned by [`DynamicReceiver::receive`]. +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct DynamicReceiveFuture<'ch, T> { + channel: &'ch dyn DynamicChannel, +} + +impl<'ch, T> Future for DynamicReceiveFuture<'ch, T> { + type Output = T; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.channel.try_receive_with_context(Some(cx)) { + Ok(v) => Poll::Ready(v), + Err(TryReceiveError::Empty) => Poll::Pending, + } + } +} + +/// Future returned by [`PriorityChannel::send`] and [`Sender::send`]. +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct SendFuture<'ch, M, T, K, const N: usize> +where + T: Ord, + K: heapless::binary_heap::Kind, + M: RawMutex, +{ + channel: &'ch PriorityChannel, + message: Option, +} + +impl<'ch, M, T, K, const N: usize> Future for SendFuture<'ch, M, T, K, N> +where + T: Ord, + K: heapless::binary_heap::Kind, + M: RawMutex, +{ + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.message.take() { + Some(m) => match self.channel.try_send_with_context(m, Some(cx)) { + Ok(..) => Poll::Ready(()), + Err(TrySendError::Full(m)) => { + self.message = Some(m); + Poll::Pending + } + }, + None => panic!("Message cannot be None"), + } + } +} + +impl<'ch, M, T, K, const N: usize> Unpin for SendFuture<'ch, M, T, K, N> +where + T: Ord, + K: heapless::binary_heap::Kind, + M: RawMutex, +{ +} + +/// Future returned by [`DynamicSender::send`]. +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct DynamicSendFuture<'ch, T> { + channel: &'ch dyn DynamicChannel, + message: Option, +} + +impl<'ch, T> Future for DynamicSendFuture<'ch, T> { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.message.take() { + Some(m) => match self.channel.try_send_with_context(m, Some(cx)) { + Ok(..) => Poll::Ready(()), + Err(TrySendError::Full(m)) => { + self.message = Some(m); + Poll::Pending + } + }, + None => panic!("Message cannot be None"), + } + } +} + +impl<'ch, T> Unpin for DynamicSendFuture<'ch, T> {} + +trait DynamicChannel { + fn try_send_with_context(&self, message: T, cx: Option<&mut Context<'_>>) -> Result<(), TrySendError>; + + fn try_receive_with_context(&self, cx: Option<&mut Context<'_>>) -> Result; + + fn poll_ready_to_send(&self, cx: &mut Context<'_>) -> Poll<()>; + fn poll_ready_to_receive(&self, cx: &mut Context<'_>) -> Poll<()>; + + fn poll_receive(&self, cx: &mut Context<'_>) -> Poll; +} + +/// Error returned by [`try_receive`](PriorityChannel::try_receive). +#[derive(PartialEq, Eq, Clone, Copy, Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum TryReceiveError { + /// A message could not be received because the channel is empty. + Empty, +} + +/// Error returned by [`try_send`](PriorityChannel::try_send). +#[derive(PartialEq, Eq, Clone, Copy, Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum TrySendError { + /// The data could not be sent on the channel because the channel is + /// currently full and sending would require blocking. + Full(T), +} + +struct ChannelState { + queue: BinaryHeap, + receiver_waker: WakerRegistration, + senders_waker: WakerRegistration, +} + +impl ChannelState +where + T: Ord, + K: heapless::binary_heap::Kind, +{ + const fn new() -> Self { + ChannelState { + queue: BinaryHeap::new(), + receiver_waker: WakerRegistration::new(), + senders_waker: WakerRegistration::new(), + } + } + + fn try_receive(&mut self) -> Result { + self.try_receive_with_context(None) + } + + fn try_receive_with_context(&mut self, cx: Option<&mut Context<'_>>) -> Result { + if self.queue.len() == self.queue.capacity() { + self.senders_waker.wake(); + } + + if let Some(message) = self.queue.pop() { + Ok(message) + } else { + if let Some(cx) = cx { + self.receiver_waker.register(cx.waker()); + } + Err(TryReceiveError::Empty) + } + } + + fn poll_receive(&mut self, cx: &mut Context<'_>) -> Poll { + if self.queue.len() == self.queue.capacity() { + self.senders_waker.wake(); + } + + if let Some(message) = self.queue.pop() { + Poll::Ready(message) + } else { + self.receiver_waker.register(cx.waker()); + Poll::Pending + } + } + + fn poll_ready_to_receive(&mut self, cx: &mut Context<'_>) -> Poll<()> { + self.receiver_waker.register(cx.waker()); + + if !self.queue.is_empty() { + Poll::Ready(()) + } else { + Poll::Pending + } + } + + fn try_send(&mut self, message: T) -> Result<(), TrySendError> { + self.try_send_with_context(message, None) + } + + fn try_send_with_context(&mut self, message: T, cx: Option<&mut Context<'_>>) -> Result<(), TrySendError> { + match self.queue.push(message) { + Ok(()) => { + self.receiver_waker.wake(); + Ok(()) + } + Err(message) => { + if let Some(cx) = cx { + self.senders_waker.register(cx.waker()); + } + Err(TrySendError::Full(message)) + } + } + } + + fn poll_ready_to_send(&mut self, cx: &mut Context<'_>) -> Poll<()> { + self.senders_waker.register(cx.waker()); + + if !self.queue.len() == self.queue.capacity() { + Poll::Ready(()) + } else { + Poll::Pending + } + } +} + +/// A bounded channel for communicating between asynchronous tasks +/// with backpressure. +/// +/// The channel will buffer up to the provided number of messages. Once the +/// buffer is full, attempts to `send` new messages will wait until a message is +/// received from the channel. +/// +/// All data sent will become available in the same order as it was sent. +pub struct PriorityChannel +where + T: Ord, + K: heapless::binary_heap::Kind, + M: RawMutex, +{ + inner: Mutex>>, +} + +impl PriorityChannel +where + T: Ord, + K: heapless::binary_heap::Kind, + M: RawMutex, +{ + /// Establish a new bounded channel. For example, to create one with a NoopMutex: + /// + /// ``` + /// use embassy_sync::channel::PriorityChannel; + /// use embassy_sync::blocking_mutex::raw::NoopRawMutex; + /// + /// // Declare a bounded channel of 3 u32s. + /// let mut channel = PriorityChannel::::new(); + /// ``` + pub const fn new() -> Self { + Self { + inner: Mutex::new(RefCell::new(ChannelState::new())), + } + } + + fn lock(&self, f: impl FnOnce(&mut ChannelState) -> R) -> R { + self.inner.lock(|rc| f(&mut *unwrap!(rc.try_borrow_mut()))) + } + + fn try_receive_with_context(&self, cx: Option<&mut Context<'_>>) -> Result { + self.lock(|c| c.try_receive_with_context(cx)) + } + + /// Poll the channel for the next message + pub fn poll_receive(&self, cx: &mut Context<'_>) -> Poll { + self.lock(|c| c.poll_receive(cx)) + } + + fn try_send_with_context(&self, m: T, cx: Option<&mut Context<'_>>) -> Result<(), TrySendError> { + self.lock(|c| c.try_send_with_context(m, cx)) + } + + /// Allows a poll_fn to poll until the channel is ready to receive + pub fn poll_ready_to_receive(&self, cx: &mut Context<'_>) -> Poll<()> { + self.lock(|c| c.poll_ready_to_receive(cx)) + } + + /// Allows a poll_fn to poll until the channel is ready to send + pub fn poll_ready_to_send(&self, cx: &mut Context<'_>) -> Poll<()> { + self.lock(|c| c.poll_ready_to_send(cx)) + } + + /// Get a sender for this channel. + pub fn sender(&self) -> Sender<'_, M, T, K, N> { + Sender { channel: self } + } + + /// Get a receiver for this channel. + pub fn receiver(&self) -> Receiver<'_, M, T, K, N> { + Receiver { channel: self } + } + + /// Send a value, waiting until there is capacity. + /// + /// Sending completes when the value has been pushed to the channel's queue. + /// This doesn't mean the value has been received yet. + pub fn send(&self, message: T) -> SendFuture<'_, M, T, K, N> { + SendFuture { + channel: self, + message: Some(message), + } + } + + /// Attempt to immediately send a message. + /// + /// This method differs from [`send`](PriorityChannel::send) by returning immediately if the channel's + /// buffer is full, instead of waiting. + /// + /// # Errors + /// + /// If the channel capacity has been reached, i.e., the channel has `n` + /// buffered values where `n` is the argument passed to [`PriorityChannel`], then an + /// error is returned. + pub fn try_send(&self, message: T) -> Result<(), TrySendError> { + self.lock(|c| c.try_send(message)) + } + + /// Receive the next value. + /// + /// If there are no messages in the channel's buffer, this method will + /// wait until a message is sent. + pub fn receive(&self) -> ReceiveFuture<'_, M, T, K, N> { + ReceiveFuture { channel: self } + } + + /// Attempt to immediately receive a message. + /// + /// This method will either receive a message from the channel immediately or return an error + /// if the channel is empty. + pub fn try_receive(&self) -> Result { + self.lock(|c| c.try_receive()) + } +} + +/// Implements the DynamicChannel to allow creating types that are unaware of the queue size with the +/// tradeoff cost of dynamic dispatch. +impl DynamicChannel for PriorityChannel +where + T: Ord, + K: heapless::binary_heap::Kind, + M: RawMutex, +{ + fn try_send_with_context(&self, m: T, cx: Option<&mut Context<'_>>) -> Result<(), TrySendError> { + PriorityChannel::try_send_with_context(self, m, cx) + } + + fn try_receive_with_context(&self, cx: Option<&mut Context<'_>>) -> Result { + PriorityChannel::try_receive_with_context(self, cx) + } + + fn poll_ready_to_send(&self, cx: &mut Context<'_>) -> Poll<()> { + PriorityChannel::poll_ready_to_send(self, cx) + } + + fn poll_ready_to_receive(&self, cx: &mut Context<'_>) -> Poll<()> { + PriorityChannel::poll_ready_to_receive(self, cx) + } + + fn poll_receive(&self, cx: &mut Context<'_>) -> Poll { + PriorityChannel::poll_receive(self, cx) + } +} + +#[cfg(test)] +mod tests { + use core::time::Duration; + + use futures_executor::ThreadPool; + use futures_timer::Delay; + use futures_util::task::SpawnExt; + use static_cell::StaticCell; + + use super::*; + use crate::blocking_mutex::raw::{CriticalSectionRawMutex, NoopRawMutex}; + + fn capacity(c: &ChannelState) -> usize { + c.queue.capacity() - c.queue.len() + } + + #[test] + fn sending_once() { + let mut c = ChannelState::::new(); + assert!(c.try_send(1).is_ok()); + assert_eq!(capacity(&c), 2); + } + + #[test] + fn sending_when_full() { + let mut c = ChannelState::::new(); + let _ = c.try_send(1); + let _ = c.try_send(1); + let _ = c.try_send(1); + match c.try_send(2) { + Err(TrySendError::Full(2)) => assert!(true), + _ => assert!(false), + } + assert_eq!(capacity(&c), 0); + } + + #[test] + fn receiving_once_with_one_send() { + let mut c = ChannelState::::new(); + assert!(c.try_send(1).is_ok()); + assert_eq!(c.try_receive().unwrap(), 1); + assert_eq!(capacity(&c), 3); + } + + #[test] + fn receiving_when_empty() { + let mut c = ChannelState::::new(); + match c.try_receive() { + Err(TryReceiveError::Empty) => assert!(true), + _ => assert!(false), + } + assert_eq!(capacity(&c), 3); + } + + #[test] + fn simple_send_and_receive() { + let c = PriorityChannel::::new(); + assert!(c.try_send(1).is_ok()); + assert_eq!(c.try_receive().unwrap(), 1); + } + + #[test] + fn cloning() { + let c = PriorityChannel::::new(); + let r1 = c.receiver(); + let s1 = c.sender(); + + let _ = r1.clone(); + let _ = s1.clone(); + } + + #[test] + fn dynamic_dispatch() { + let c = PriorityChannel::::new(); + let s: DynamicSender<'_, u32> = c.sender().into(); + let r: DynamicReceiver<'_, u32> = c.receiver().into(); + + assert!(s.try_send(1).is_ok()); + assert_eq!(r.try_receive().unwrap(), 1); + } + + #[futures_test::test] + async fn receiver_receives_given_try_send_async() { + let executor = ThreadPool::new().unwrap(); + + static CHANNEL: StaticCell> = StaticCell::new(); + let c = &*CHANNEL.init(PriorityChannel::new()); + let c2 = c; + assert!(executor + .spawn(async move { + assert!(c2.try_send(1).is_ok()); + }) + .is_ok()); + assert_eq!(c.receive().await, 1); + } + + #[futures_test::test] + async fn sender_send_completes_if_capacity() { + let c = PriorityChannel::::new(); + c.send(1).await; + assert_eq!(c.receive().await, 1); + } + + #[futures_test::test] + async fn senders_sends_wait_until_capacity() { + let executor = ThreadPool::new().unwrap(); + + static CHANNEL: StaticCell> = StaticCell::new(); + let c = &*CHANNEL.init(PriorityChannel::new()); + assert!(c.try_send(1).is_ok()); + + let c2 = c; + let send_task_1 = executor.spawn_with_handle(async move { c2.send(2).await }); + let c2 = c; + let send_task_2 = executor.spawn_with_handle(async move { c2.send(3).await }); + // Wish I could think of a means of determining that the async send is waiting instead. + // However, I've used the debugger to observe that the send does indeed wait. + Delay::new(Duration::from_millis(500)).await; + assert_eq!(c.receive().await, 1); + assert!(executor + .spawn(async move { + loop { + c.receive().await; + } + }) + .is_ok()); + send_task_1.unwrap().await; + send_task_2.unwrap().await; + } +}