diff --git a/embassy/src/channel/mod.rs b/embassy/src/channel/mod.rs index 05edc55d1..5df1f5c5c 100644 --- a/embassy/src/channel/mod.rs +++ b/embassy/src/channel/mod.rs @@ -1,5 +1,5 @@ //! Async channels pub mod mpmc; - +pub mod pubsub; pub mod signal; diff --git a/embassy/src/channel/pubsub/mod.rs b/embassy/src/channel/pubsub/mod.rs new file mode 100644 index 000000000..9bfb845e0 --- /dev/null +++ b/embassy/src/channel/pubsub/mod.rs @@ -0,0 +1,542 @@ +//! Implementation of [PubSubChannel], a queue where published messages get received by all subscribers. + +#![deny(missing_docs)] + +use core::cell::RefCell; +use core::fmt::Debug; +use core::task::{Context, Poll, Waker}; + +use heapless::Deque; + +use self::publisher::{ImmediatePub, Pub}; +use self::subscriber::Sub; +use crate::blocking_mutex::raw::RawMutex; +use crate::blocking_mutex::Mutex; +use crate::waitqueue::MultiWakerRegistration; + +pub mod publisher; +pub mod subscriber; + +pub use publisher::{DynImmediatePublisher, DynPublisher, ImmediatePublisher, Publisher}; +pub use subscriber::{DynSubscriber, Subscriber}; + +/// A broadcast channel implementation where multiple publishers can send messages to multiple subscribers +/// +/// Any published message can be read by all subscribers. +/// A publisher can choose how it sends its message. +/// +/// - With [Publisher::publish] the publisher has to wait until there is space in the internal message queue. +/// - With [Publisher::publish_immediate] the publisher doesn't await and instead lets the oldest message +/// in the queue drop if necessary. This will cause any [Subscriber] that missed the message to receive +/// an error to indicate that it has lagged. +/// +/// ## Example +/// +/// ``` +/// # use embassy::blocking_mutex::raw::NoopRawMutex; +/// # use embassy::channel::pubsub::WaitResult; +/// # use embassy::channel::pubsub::PubSubChannel; +/// # use futures_executor::block_on; +/// # let test = async { +/// // Create the channel. This can be static as well +/// let channel = PubSubChannel::<NoopRawMutex, u32, 4, 4, 4>::new(); +/// +/// // This is a generic subscriber with a direct reference to the channel +/// let mut sub0 = channel.subscriber().unwrap(); +/// // This is a dynamic subscriber with a dynamic (trait object) reference to the channel +/// let mut sub1 = channel.dyn_subscriber().unwrap(); +/// +/// let pub0 = channel.publisher().unwrap(); +/// +/// // Publish a message, but wait if the queue is full +/// pub0.publish(42).await; +/// +/// // Publish a message, but if the queue is full, just kick out the oldest message. +/// // This may cause some subscribers to miss a message +/// pub0.publish_immediate(43); +/// +/// // Wait for a new message. If the subscriber missed a message, the WaitResult will be a Lag result +/// assert_eq!(sub0.next_message().await, WaitResult::Message(42)); +/// assert_eq!(sub1.next_message().await, WaitResult::Message(42)); +/// +/// // Wait again, but this time ignore any Lag results +/// assert_eq!(sub0.next_message_pure().await, 43); +/// assert_eq!(sub1.next_message_pure().await, 43); +/// +/// // There's also a polling interface +/// assert_eq!(sub0.try_next_message(), None); +/// assert_eq!(sub1.try_next_message(), None); +/// # }; +/// # +/// # block_on(test); +/// ``` +/// +pub struct PubSubChannel<M: RawMutex, T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize> { + inner: Mutex<M, RefCell<PubSubState<T, CAP, SUBS, PUBS>>>, +} + +impl<M: RawMutex, T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize> + PubSubChannel<M, T, CAP, SUBS, PUBS> +{ + /// Create a new channel + pub const fn new() -> Self { + Self { + inner: Mutex::const_new(M::INIT, RefCell::new(PubSubState::new())), + } + } + + /// Create a new subscriber. It will only receive messages that are published after its creation. + /// + /// If there are no subscriber slots left, an error will be returned. + pub fn subscriber(&self) -> Result<Subscriber<M, T, CAP, SUBS, PUBS>, Error> { + self.inner.lock(|inner| { + let mut s = inner.borrow_mut(); + + if s.subscriber_count >= SUBS { + Err(Error::MaximumSubscribersReached) + } else { + s.subscriber_count += 1; + Ok(Subscriber(Sub::new(s.next_message_id, self))) + } + }) + } + + /// Create a new subscriber. It will only receive messages that are published after its creation. + /// + /// If there are no subscriber slots left, an error will be returned. + pub fn dyn_subscriber<'a>(&'a self) -> Result<DynSubscriber<'a, T>, Error> { + self.inner.lock(|inner| { + let mut s = inner.borrow_mut(); + + if s.subscriber_count >= SUBS { + Err(Error::MaximumSubscribersReached) + } else { + s.subscriber_count += 1; + Ok(DynSubscriber(Sub::new(s.next_message_id, self))) + } + }) + } + + /// Create a new publisher + /// + /// If there are no publisher slots left, an error will be returned. + pub fn publisher(&self) -> Result<Publisher<M, T, CAP, SUBS, PUBS>, Error> { + self.inner.lock(|inner| { + let mut s = inner.borrow_mut(); + + if s.publisher_count >= PUBS { + Err(Error::MaximumPublishersReached) + } else { + s.publisher_count += 1; + Ok(Publisher(Pub::new(self))) + } + }) + } + + /// Create a new publisher + /// + /// If there are no publisher slots left, an error will be returned. + pub fn dyn_publisher<'a>(&'a self) -> Result<DynPublisher<'a, T>, Error> { + self.inner.lock(|inner| { + let mut s = inner.borrow_mut(); + + if s.publisher_count >= PUBS { + Err(Error::MaximumPublishersReached) + } else { + s.publisher_count += 1; + Ok(DynPublisher(Pub::new(self))) + } + }) + } + + /// Create a new publisher that can only send immediate messages. + /// This kind of publisher does not take up a publisher slot. + pub fn immediate_publisher(&self) -> ImmediatePublisher<M, T, CAP, SUBS, PUBS> { + ImmediatePublisher(ImmediatePub::new(self)) + } + + /// Create a new publisher that can only send immediate messages. + /// This kind of publisher does not take up a publisher slot. + pub fn dyn_immediate_publisher(&self) -> DynImmediatePublisher<T> { + DynImmediatePublisher(ImmediatePub::new(self)) + } +} + +impl<M: RawMutex, T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize> PubSubBehavior<T> + for PubSubChannel<M, T, CAP, SUBS, PUBS> +{ + fn get_message_with_context(&self, next_message_id: &mut u64, cx: Option<&mut Context<'_>>) -> Poll<WaitResult<T>> { + self.inner.lock(|s| { + let mut s = s.borrow_mut(); + + // Check if we can read a message + match s.get_message(*next_message_id) { + // Yes, so we are done polling + Some(WaitResult::Message(message)) => { + *next_message_id += 1; + Poll::Ready(WaitResult::Message(message)) + } + // No, so we need to reregister our waker and sleep again + None => { + if let Some(cx) = cx { + s.register_subscriber_waker(cx.waker()); + } + Poll::Pending + } + // We missed a couple of messages. We must do our internal bookkeeping and return that we lagged + Some(WaitResult::Lagged(amount)) => { + *next_message_id += amount; + Poll::Ready(WaitResult::Lagged(amount)) + } + } + }) + } + + fn publish_with_context(&self, message: T, cx: Option<&mut Context<'_>>) -> Result<(), T> { + self.inner.lock(|s| { + let mut s = s.borrow_mut(); + // Try to publish the message + match s.try_publish(message) { + // We did it, we are ready + Ok(()) => Ok(()), + // The queue is full, so we need to reregister our waker and go to sleep + Err(message) => { + if let Some(cx) = cx { + s.register_publisher_waker(cx.waker()); + } + Err(message) + } + } + }) + } + + fn publish_immediate(&self, message: T) { + self.inner.lock(|s| { + let mut s = s.borrow_mut(); + s.publish_immediate(message) + }) + } + + fn unregister_subscriber(&self, subscriber_next_message_id: u64) { + self.inner.lock(|s| { + let mut s = s.borrow_mut(); + s.unregister_subscriber(subscriber_next_message_id) + }) + } + + fn unregister_publisher(&self) { + self.inner.lock(|s| { + let mut s = s.borrow_mut(); + s.unregister_publisher() + }) + } +} + +/// Internal state for the PubSub channel +struct PubSubState<T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize> { + /// The queue contains the last messages that have been published and a countdown of how many subscribers are yet to read it + queue: Deque<(T, usize), CAP>, + /// Every message has an id. + /// Don't worry, we won't run out. + /// If a million messages were published every second, then the ID's would run out in about 584942 years. + next_message_id: u64, + /// Collection of wakers for Subscribers that are waiting. + subscriber_wakers: MultiWakerRegistration<SUBS>, + /// Collection of wakers for Publishers that are waiting. + publisher_wakers: MultiWakerRegistration<PUBS>, + /// The amount of subscribers that are active + subscriber_count: usize, + /// The amount of publishers that are active + publisher_count: usize, +} + +impl<T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize> PubSubState<T, CAP, SUBS, PUBS> { + /// Create a new internal channel state + const fn new() -> Self { + Self { + queue: Deque::new(), + next_message_id: 0, + subscriber_wakers: MultiWakerRegistration::new(), + publisher_wakers: MultiWakerRegistration::new(), + subscriber_count: 0, + publisher_count: 0, + } + } + + fn try_publish(&mut self, message: T) -> Result<(), T> { + if self.subscriber_count == 0 { + // We don't need to publish anything because there is no one to receive it + return Ok(()); + } + + if self.queue.is_full() { + return Err(message); + } + // We just did a check for this + self.queue.push_back((message, self.subscriber_count)).ok().unwrap(); + + self.next_message_id += 1; + + // Wake all of the subscribers + self.subscriber_wakers.wake(); + + Ok(()) + } + + fn publish_immediate(&mut self, message: T) { + // Make space in the queue if required + if self.queue.is_full() { + self.queue.pop_front(); + } + + // This will succeed because we made sure there is space + self.try_publish(message).ok().unwrap(); + } + + fn get_message(&mut self, message_id: u64) -> Option<WaitResult<T>> { + let start_id = self.next_message_id - self.queue.len() as u64; + + if message_id < start_id { + return Some(WaitResult::Lagged(start_id - message_id)); + } + + let current_message_index = (message_id - start_id) as usize; + + if current_message_index >= self.queue.len() { + return None; + } + + // We've checked that the index is valid + let queue_item = self.queue.iter_mut().nth(current_message_index).unwrap(); + + // We're reading this item, so decrement the counter + queue_item.1 -= 1; + let message = queue_item.0.clone(); + + if current_message_index == 0 && queue_item.1 == 0 { + self.queue.pop_front(); + self.publisher_wakers.wake(); + } + + Some(WaitResult::Message(message)) + } + + fn register_subscriber_waker(&mut self, waker: &Waker) { + match self.subscriber_wakers.register(waker) { + Ok(()) => {} + Err(_) => { + // All waker slots were full. This can only happen when there was a subscriber that now has dropped. + // We need to throw it away. It's a bit inefficient, but we can wake everything. + // Any future that is still active will simply reregister. + // This won't happen a lot, so it's ok. + self.subscriber_wakers.wake(); + self.subscriber_wakers.register(waker).unwrap(); + } + } + } + + fn register_publisher_waker(&mut self, waker: &Waker) { + match self.publisher_wakers.register(waker) { + Ok(()) => {} + Err(_) => { + // All waker slots were full. This can only happen when there was a publisher that now has dropped. + // We need to throw it away. It's a bit inefficient, but we can wake everything. + // Any future that is still active will simply reregister. + // This won't happen a lot, so it's ok. + self.publisher_wakers.wake(); + self.publisher_wakers.register(waker).unwrap(); + } + } + } + + fn unregister_subscriber(&mut self, subscriber_next_message_id: u64) { + self.subscriber_count -= 1; + + // All messages that haven't been read yet by this subscriber must have their counter decremented + let start_id = self.next_message_id - self.queue.len() as u64; + if subscriber_next_message_id >= start_id { + let current_message_index = (subscriber_next_message_id - start_id) as usize; + self.queue + .iter_mut() + .skip(current_message_index) + .for_each(|(_, counter)| *counter -= 1); + } + } + + fn unregister_publisher(&mut self) { + self.publisher_count -= 1; + } +} + +/// Error type for the [PubSubChannel] +#[derive(Debug, PartialEq, Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum Error { + /// All subscriber slots are used. To add another subscriber, first another subscriber must be dropped or + /// the capacity of the channels must be increased. + MaximumSubscribersReached, + /// All publisher slots are used. To add another publisher, first another publisher must be dropped or + /// the capacity of the channels must be increased. + MaximumPublishersReached, +} + +/// 'Middle level' behaviour of the pubsub channel. +/// This trait is used so that Sub and Pub can be generic over the channel. +pub trait PubSubBehavior<T> { + /// Try to get a message from the queue with the given message id. + /// + /// If the message is not yet present and a context is given, then its waker is registered in the subsriber wakers. + fn get_message_with_context(&self, next_message_id: &mut u64, cx: Option<&mut Context<'_>>) -> Poll<WaitResult<T>>; + + /// Try to publish a message to the queue. + /// + /// If the queue is full and a context is given, then its waker is registered in the publisher wakers. + fn publish_with_context(&self, message: T, cx: Option<&mut Context<'_>>) -> Result<(), T>; + + /// Publish a message immediately + fn publish_immediate(&self, message: T); + + /// Let the channel know that a subscriber has dropped + fn unregister_subscriber(&self, subscriber_next_message_id: u64); + + /// Let the channel know that a publisher has dropped + fn unregister_publisher(&self); +} + +/// The result of the subscriber wait procedure +#[derive(Debug, Clone, PartialEq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum WaitResult<T> { + /// The subscriber did not receive all messages and lagged by the given amount of messages. + /// (This is the amount of messages that were missed) + Lagged(u64), + /// A message was received + Message(T), +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::blocking_mutex::raw::NoopRawMutex; + + #[futures_test::test] + async fn dyn_pub_sub_works() { + let channel = PubSubChannel::<NoopRawMutex, u32, 4, 4, 4>::new(); + + let mut sub0 = channel.dyn_subscriber().unwrap(); + let mut sub1 = channel.dyn_subscriber().unwrap(); + let pub0 = channel.dyn_publisher().unwrap(); + + pub0.publish(42).await; + + assert_eq!(sub0.next_message().await, WaitResult::Message(42)); + assert_eq!(sub1.next_message().await, WaitResult::Message(42)); + + assert_eq!(sub0.try_next_message(), None); + assert_eq!(sub1.try_next_message(), None); + } + + #[futures_test::test] + async fn all_subscribers_receive() { + let channel = PubSubChannel::<NoopRawMutex, u32, 4, 4, 4>::new(); + + let mut sub0 = channel.subscriber().unwrap(); + let mut sub1 = channel.subscriber().unwrap(); + let pub0 = channel.publisher().unwrap(); + + pub0.publish(42).await; + + assert_eq!(sub0.next_message().await, WaitResult::Message(42)); + assert_eq!(sub1.next_message().await, WaitResult::Message(42)); + + assert_eq!(sub0.try_next_message(), None); + assert_eq!(sub1.try_next_message(), None); + } + + #[futures_test::test] + async fn lag_when_queue_full_on_immediate_publish() { + let channel = PubSubChannel::<NoopRawMutex, u32, 4, 4, 4>::new(); + + let mut sub0 = channel.subscriber().unwrap(); + let pub0 = channel.publisher().unwrap(); + + pub0.publish_immediate(42); + pub0.publish_immediate(43); + pub0.publish_immediate(44); + pub0.publish_immediate(45); + pub0.publish_immediate(46); + pub0.publish_immediate(47); + + assert_eq!(sub0.try_next_message(), Some(WaitResult::Lagged(2))); + assert_eq!(sub0.next_message().await, WaitResult::Message(44)); + assert_eq!(sub0.next_message().await, WaitResult::Message(45)); + assert_eq!(sub0.next_message().await, WaitResult::Message(46)); + assert_eq!(sub0.next_message().await, WaitResult::Message(47)); + assert_eq!(sub0.try_next_message(), None); + } + + #[test] + fn limited_subs_and_pubs() { + let channel = PubSubChannel::<NoopRawMutex, u32, 4, 4, 4>::new(); + + let sub0 = channel.subscriber(); + let sub1 = channel.subscriber(); + let sub2 = channel.subscriber(); + let sub3 = channel.subscriber(); + let sub4 = channel.subscriber(); + + assert!(sub0.is_ok()); + assert!(sub1.is_ok()); + assert!(sub2.is_ok()); + assert!(sub3.is_ok()); + assert_eq!(sub4.err().unwrap(), Error::MaximumSubscribersReached); + + drop(sub0); + + let sub5 = channel.subscriber(); + assert!(sub5.is_ok()); + + // publishers + + let pub0 = channel.publisher(); + let pub1 = channel.publisher(); + let pub2 = channel.publisher(); + let pub3 = channel.publisher(); + let pub4 = channel.publisher(); + + assert!(pub0.is_ok()); + assert!(pub1.is_ok()); + assert!(pub2.is_ok()); + assert!(pub3.is_ok()); + assert_eq!(pub4.err().unwrap(), Error::MaximumPublishersReached); + + drop(pub0); + + let pub5 = channel.publisher(); + assert!(pub5.is_ok()); + } + + #[test] + fn publisher_wait_on_full_queue() { + let channel = PubSubChannel::<NoopRawMutex, u32, 4, 4, 4>::new(); + + let pub0 = channel.publisher().unwrap(); + + // There are no subscribers, so the queue will never be full + assert_eq!(pub0.try_publish(0), Ok(())); + assert_eq!(pub0.try_publish(0), Ok(())); + assert_eq!(pub0.try_publish(0), Ok(())); + assert_eq!(pub0.try_publish(0), Ok(())); + assert_eq!(pub0.try_publish(0), Ok(())); + + let sub0 = channel.subscriber().unwrap(); + + assert_eq!(pub0.try_publish(0), Ok(())); + assert_eq!(pub0.try_publish(0), Ok(())); + assert_eq!(pub0.try_publish(0), Ok(())); + assert_eq!(pub0.try_publish(0), Ok(())); + assert_eq!(pub0.try_publish(0), Err(0)); + + drop(sub0); + } +} diff --git a/embassy/src/channel/pubsub/publisher.rs b/embassy/src/channel/pubsub/publisher.rs new file mode 100644 index 000000000..89a0b9247 --- /dev/null +++ b/embassy/src/channel/pubsub/publisher.rs @@ -0,0 +1,183 @@ +//! Implementation of anything directly publisher related + +use core::marker::PhantomData; +use core::ops::{Deref, DerefMut}; +use core::pin::Pin; +use core::task::{Context, Poll}; + +use futures::Future; + +use super::{PubSubBehavior, PubSubChannel}; +use crate::blocking_mutex::raw::RawMutex; + +/// A publisher to a channel +pub struct Pub<'a, PSB: PubSubBehavior<T> + ?Sized, T: Clone> { + /// The channel we are a publisher for + channel: &'a PSB, + _phantom: PhantomData<T>, +} + +impl<'a, PSB: PubSubBehavior<T> + ?Sized, T: Clone> Pub<'a, PSB, T> { + pub(super) fn new(channel: &'a PSB) -> Self { + Self { + channel, + _phantom: Default::default(), + } + } + + /// Publish a message right now even when the queue is full. + /// This may cause a subscriber to miss an older message. + pub fn publish_immediate(&self, message: T) { + self.channel.publish_immediate(message) + } + + /// Publish a message. But if the message queue is full, wait for all subscribers to have read the last message + pub fn publish<'s>(&'s self, message: T) -> PublisherWaitFuture<'s, 'a, PSB, T> { + PublisherWaitFuture { + message: Some(message), + publisher: self, + } + } + + /// Publish a message if there is space in the message queue + pub fn try_publish(&self, message: T) -> Result<(), T> { + self.channel.publish_with_context(message, None) + } +} + +impl<'a, PSB: PubSubBehavior<T> + ?Sized, T: Clone> Drop for Pub<'a, PSB, T> { + fn drop(&mut self) { + self.channel.unregister_publisher() + } +} + +/// A publisher that holds a dynamic reference to the channel +pub struct DynPublisher<'a, T: Clone>(pub(super) Pub<'a, dyn PubSubBehavior<T> + 'a, T>); + +impl<'a, T: Clone> Deref for DynPublisher<'a, T> { + type Target = Pub<'a, dyn PubSubBehavior<T> + 'a, T>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl<'a, T: Clone> DerefMut for DynPublisher<'a, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +/// A publisher that holds a generic reference to the channel +pub struct Publisher<'a, M: RawMutex, T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize>( + pub(super) Pub<'a, PubSubChannel<M, T, CAP, SUBS, PUBS>, T>, +); + +impl<'a, M: RawMutex, T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize> Deref + for Publisher<'a, M, T, CAP, SUBS, PUBS> +{ + type Target = Pub<'a, PubSubChannel<M, T, CAP, SUBS, PUBS>, T>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl<'a, M: RawMutex, T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize> DerefMut + for Publisher<'a, M, T, CAP, SUBS, PUBS> +{ + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +/// A publisher that can only use the `publish_immediate` function, but it doesn't have to be registered with the channel. +/// (So an infinite amount is possible) +pub struct ImmediatePub<'a, PSB: PubSubBehavior<T> + ?Sized, T: Clone> { + /// The channel we are a publisher for + channel: &'a PSB, + _phantom: PhantomData<T>, +} + +impl<'a, PSB: PubSubBehavior<T> + ?Sized, T: Clone> ImmediatePub<'a, PSB, T> { + pub(super) fn new(channel: &'a PSB) -> Self { + Self { + channel, + _phantom: Default::default(), + } + } + /// Publish the message right now even when the queue is full. + /// This may cause a subscriber to miss an older message. + pub fn publish_immediate(&mut self, message: T) { + self.channel.publish_immediate(message) + } + + /// Publish a message if there is space in the message queue + pub fn try_publish(&self, message: T) -> Result<(), T> { + self.channel.publish_with_context(message, None) + } +} + +/// An immediate publisher that holds a dynamic reference to the channel +pub struct DynImmediatePublisher<'a, T: Clone>(pub(super) ImmediatePub<'a, dyn PubSubBehavior<T> + 'a, T>); + +impl<'a, T: Clone> Deref for DynImmediatePublisher<'a, T> { + type Target = ImmediatePub<'a, dyn PubSubBehavior<T> + 'a, T>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl<'a, T: Clone> DerefMut for DynImmediatePublisher<'a, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +/// An immediate publisher that holds a generic reference to the channel +pub struct ImmediatePublisher<'a, M: RawMutex, T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize>( + pub(super) ImmediatePub<'a, PubSubChannel<M, T, CAP, SUBS, PUBS>, T>, +); + +impl<'a, M: RawMutex, T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize> Deref + for ImmediatePublisher<'a, M, T, CAP, SUBS, PUBS> +{ + type Target = ImmediatePub<'a, PubSubChannel<M, T, CAP, SUBS, PUBS>, T>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl<'a, M: RawMutex, T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize> DerefMut + for ImmediatePublisher<'a, M, T, CAP, SUBS, PUBS> +{ + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +/// Future for the publisher wait action +pub struct PublisherWaitFuture<'s, 'a, PSB: PubSubBehavior<T> + ?Sized, T: Clone> { + /// The message we need to publish + message: Option<T>, + publisher: &'s Pub<'a, PSB, T>, +} + +impl<'s, 'a, PSB: PubSubBehavior<T> + ?Sized, T: Clone> Future for PublisherWaitFuture<'s, 'a, PSB, T> { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let message = self.message.take().unwrap(); + match self.publisher.channel.publish_with_context(message, Some(cx)) { + Ok(()) => Poll::Ready(()), + Err(message) => { + self.message = Some(message); + Poll::Pending + } + } + } +} + +impl<'s, 'a, PSB: PubSubBehavior<T> + ?Sized, T: Clone> Unpin for PublisherWaitFuture<'s, 'a, PSB, T> {} diff --git a/embassy/src/channel/pubsub/subscriber.rs b/embassy/src/channel/pubsub/subscriber.rs new file mode 100644 index 000000000..23c4938d9 --- /dev/null +++ b/embassy/src/channel/pubsub/subscriber.rs @@ -0,0 +1,153 @@ +//! Implementation of anything directly subscriber related + +use core::marker::PhantomData; +use core::ops::{Deref, DerefMut}; +use core::pin::Pin; +use core::task::{Context, Poll}; + +use futures::Future; + +use super::{PubSubBehavior, PubSubChannel, WaitResult}; +use crate::blocking_mutex::raw::RawMutex; + +/// A subscriber to a channel +pub struct Sub<'a, PSB: PubSubBehavior<T> + ?Sized, T: Clone> { + /// The message id of the next message we are yet to receive + next_message_id: u64, + /// The channel we are a subscriber to + channel: &'a PSB, + _phantom: PhantomData<T>, +} + +impl<'a, PSB: PubSubBehavior<T> + ?Sized, T: Clone> Sub<'a, PSB, T> { + pub(super) fn new(next_message_id: u64, channel: &'a PSB) -> Self { + Self { + next_message_id, + channel, + _phantom: Default::default(), + } + } + + /// Wait for a published message + pub fn next_message<'s>(&'s mut self) -> SubscriberWaitFuture<'s, 'a, PSB, T> { + SubscriberWaitFuture { subscriber: self } + } + + /// Wait for a published message (ignoring lag results) + pub async fn next_message_pure(&mut self) -> T { + loop { + match self.next_message().await { + WaitResult::Lagged(_) => continue, + WaitResult::Message(message) => break message, + } + } + } + + /// Try to see if there's a published message we haven't received yet. + /// + /// This function does not peek. The message is received if there is one. + pub fn try_next_message(&mut self) -> Option<WaitResult<T>> { + match self.channel.get_message_with_context(&mut self.next_message_id, None) { + Poll::Ready(result) => Some(result), + Poll::Pending => None, + } + } + + /// Try to see if there's a published message we haven't received yet (ignoring lag results). + /// + /// This function does not peek. The message is received if there is one. + pub fn try_next_message_pure(&mut self) -> Option<T> { + loop { + match self.try_next_message() { + Some(WaitResult::Lagged(_)) => continue, + Some(WaitResult::Message(message)) => break Some(message), + None => break None, + } + } + } +} + +impl<'a, PSB: PubSubBehavior<T> + ?Sized, T: Clone> Drop for Sub<'a, PSB, T> { + fn drop(&mut self) { + self.channel.unregister_subscriber(self.next_message_id) + } +} + +impl<'a, PSB: PubSubBehavior<T> + ?Sized, T: Clone> Unpin for Sub<'a, PSB, T> {} + +/// Warning: The stream implementation ignores lag results and returns all messages. +/// This might miss some messages without you knowing it. +impl<'a, PSB: PubSubBehavior<T> + ?Sized, T: Clone> futures::Stream for Sub<'a, PSB, T> { + type Item = T; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + match self + .channel + .get_message_with_context(&mut self.next_message_id, Some(cx)) + { + Poll::Ready(WaitResult::Message(message)) => Poll::Ready(Some(message)), + Poll::Ready(WaitResult::Lagged(_)) => { + cx.waker().wake_by_ref(); + Poll::Pending + } + Poll::Pending => Poll::Pending, + } + } +} + +/// A subscriber that holds a dynamic reference to the channel +pub struct DynSubscriber<'a, T: Clone>(pub(super) Sub<'a, dyn PubSubBehavior<T> + 'a, T>); + +impl<'a, T: Clone> Deref for DynSubscriber<'a, T> { + type Target = Sub<'a, dyn PubSubBehavior<T> + 'a, T>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl<'a, T: Clone> DerefMut for DynSubscriber<'a, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +/// A subscriber that holds a generic reference to the channel +pub struct Subscriber<'a, M: RawMutex, T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize>( + pub(super) Sub<'a, PubSubChannel<M, T, CAP, SUBS, PUBS>, T>, +); + +impl<'a, M: RawMutex, T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize> Deref + for Subscriber<'a, M, T, CAP, SUBS, PUBS> +{ + type Target = Sub<'a, PubSubChannel<M, T, CAP, SUBS, PUBS>, T>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl<'a, M: RawMutex, T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize> DerefMut + for Subscriber<'a, M, T, CAP, SUBS, PUBS> +{ + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +/// Future for the subscriber wait action +pub struct SubscriberWaitFuture<'s, 'a, PSB: PubSubBehavior<T> + ?Sized, T: Clone> { + subscriber: &'s mut Sub<'a, PSB, T>, +} + +impl<'s, 'a, PSB: PubSubBehavior<T> + ?Sized, T: Clone> Future for SubscriberWaitFuture<'s, 'a, PSB, T> { + type Output = WaitResult<T>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + self.subscriber + .channel + .get_message_with_context(&mut self.subscriber.next_message_id, Some(cx)) + } +} + +impl<'s, 'a, PSB: PubSubBehavior<T> + ?Sized, T: Clone> Unpin for SubscriberWaitFuture<'s, 'a, PSB, T> {} diff --git a/embassy/src/waitqueue/mod.rs b/embassy/src/waitqueue/mod.rs index a2bafad99..5c4e1bc3b 100644 --- a/embassy/src/waitqueue/mod.rs +++ b/embassy/src/waitqueue/mod.rs @@ -3,3 +3,6 @@ #[cfg_attr(feature = "executor-agnostic", path = "waker_agnostic.rs")] mod waker; pub use waker::*; + +mod multi_waker; +pub use multi_waker::*; diff --git a/embassy/src/waitqueue/multi_waker.rs b/embassy/src/waitqueue/multi_waker.rs new file mode 100644 index 000000000..325d2cb3a --- /dev/null +++ b/embassy/src/waitqueue/multi_waker.rs @@ -0,0 +1,33 @@ +use core::task::Waker; + +use super::WakerRegistration; + +/// Utility struct to register and wake multiple wakers. +pub struct MultiWakerRegistration<const N: usize> { + wakers: [WakerRegistration; N], +} + +impl<const N: usize> MultiWakerRegistration<N> { + /// Create a new empty instance + pub const fn new() -> Self { + const WAKER: WakerRegistration = WakerRegistration::new(); + Self { wakers: [WAKER; N] } + } + + /// Register a waker. If the buffer is full the function returns it in the error + pub fn register<'a>(&mut self, w: &'a Waker) -> Result<(), &'a Waker> { + if let Some(waker_slot) = self.wakers.iter_mut().find(|waker_slot| !waker_slot.occupied()) { + waker_slot.register(w); + Ok(()) + } else { + Err(w) + } + } + + /// Wake all registered wakers. This clears the buffer + pub fn wake(&mut self) { + for waker_slot in self.wakers.iter_mut() { + waker_slot.wake() + } + } +} diff --git a/embassy/src/waitqueue/waker.rs b/embassy/src/waitqueue/waker.rs index da907300a..a90154cce 100644 --- a/embassy/src/waitqueue/waker.rs +++ b/embassy/src/waitqueue/waker.rs @@ -50,6 +50,11 @@ impl WakerRegistration { unsafe { wake_task(w) } } } + + /// Returns true if a waker is currently registered + pub fn occupied(&self) -> bool { + self.waker.is_some() + } } // SAFETY: `WakerRegistration` effectively contains an `Option<Waker>`, diff --git a/embassy/src/waitqueue/waker_agnostic.rs b/embassy/src/waitqueue/waker_agnostic.rs index 89430aa4c..62e3adb79 100644 --- a/embassy/src/waitqueue/waker_agnostic.rs +++ b/embassy/src/waitqueue/waker_agnostic.rs @@ -47,6 +47,11 @@ impl WakerRegistration { w.wake() } } + + /// Returns true if a waker is currently registered + pub fn occupied(&self) -> bool { + self.waker.is_some() + } } /// Utility struct to register and wake a waker. diff --git a/examples/nrf/src/bin/pubsub.rs b/examples/nrf/src/bin/pubsub.rs new file mode 100644 index 000000000..2c3a355c2 --- /dev/null +++ b/examples/nrf/src/bin/pubsub.rs @@ -0,0 +1,106 @@ +#![no_std] +#![no_main] +#![feature(type_alias_impl_trait)] + +use defmt::unwrap; +use embassy::blocking_mutex::raw::ThreadModeRawMutex; +use embassy::channel::pubsub::{DynSubscriber, PubSubChannel, Subscriber}; +use embassy::executor::Spawner; +use embassy::time::{Duration, Timer}; +use {defmt_rtt as _, panic_probe as _}; + +/// Create the message bus. It has a queue of 4, supports 3 subscribers and 1 publisher +static MESSAGE_BUS: PubSubChannel<ThreadModeRawMutex, Message, 4, 3, 1> = PubSubChannel::new(); + +#[derive(Clone, defmt::Format)] +enum Message { + A, + B, + C, +} + +#[embassy::main] +async fn main(spawner: Spawner, _p: embassy_nrf::Peripherals) { + defmt::info!("Hello World!"); + + // It's good to set up the subscribers before publishing anything. + // A subscriber will only yield messages that have been published after its creation. + + spawner.must_spawn(fast_logger(unwrap!(MESSAGE_BUS.subscriber()))); + spawner.must_spawn(slow_logger(unwrap!(MESSAGE_BUS.dyn_subscriber()))); + spawner.must_spawn(slow_logger_pure(unwrap!(MESSAGE_BUS.dyn_subscriber()))); + + // Get a publisher + let message_publisher = unwrap!(MESSAGE_BUS.publisher()); + // We can't get more (normal) publishers + // We can have an infinite amount of immediate publishers. They can't await a publish, only do an immediate publish + defmt::assert!(MESSAGE_BUS.publisher().is_err()); + + let mut index = 0; + loop { + Timer::after(Duration::from_millis(500)).await; + + let message = match index % 3 { + 0 => Message::A, + 1 => Message::B, + 2..=u32::MAX => Message::C, + }; + + // We publish immediately and don't await anything. + // If the queue is full, it will cause the oldest message to not be received by some/all subscribers + message_publisher.publish_immediate(message); + + // Try to comment out the last one and uncomment this line below. + // The behaviour will change: + // - The subscribers won't miss any messages any more + // - Trying to publish now has some wait time when the queue is full + + // message_publisher.publish(message).await; + + index += 1; + } +} + +/// A logger task that just awaits the messages it receives +/// +/// This takes the generic `Subscriber`. This is most performant, but requires you to write down all of the generics +#[embassy::task] +async fn fast_logger(mut messages: Subscriber<'static, ThreadModeRawMutex, Message, 4, 3, 1>) { + loop { + let message = messages.next_message().await; + defmt::info!("Received message at fast logger: {:?}", message); + } +} + +/// A logger task that awaits the messages, but also does some other work. +/// Because of this, depeding on how the messages were published, the subscriber might miss some messages +/// +/// This takes the dynamic `DynSubscriber`. This is not as performant as the generic version, but let's you ignore some of the generics +#[embassy::task] +async fn slow_logger(mut messages: DynSubscriber<'static, Message>) { + loop { + // Do some work + Timer::after(Duration::from_millis(2000)).await; + + // If the publisher has used the `publish_immediate` function, then we may receive a lag message here + let message = messages.next_message().await; + defmt::info!("Received message at slow logger: {:?}", message); + + // If the previous one was a lag message, then we should receive the next message here immediately + let message = messages.next_message().await; + defmt::info!("Received message at slow logger: {:?}", message); + } +} + +/// Same as `slow_logger` but it ignores lag results +#[embassy::task] +async fn slow_logger_pure(mut messages: DynSubscriber<'static, Message>) { + loop { + // Do some work + Timer::after(Duration::from_millis(2000)).await; + + // Instead of receiving lags here, we just ignore that and read the next message + let message = messages.next_message_pure().await; + defmt::info!("Received message at slow logger pure: {:?}", message); + } +}