From 1b9d5e50710cefde4bd1e234695783d62e824c68 Mon Sep 17 00:00:00 2001 From: huntc Date: Sun, 6 Jun 2021 18:36:16 +1000 Subject: [PATCH 01/23] Multi Producer Single Consumer channel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit An MPSC inspired by Tokio and Crossbeam. The MPSC is designed to support both single and multi core processors, with only single core implemented at this time. The allocation of the channel’s buffer is inspired by the const generic parameters that Heapless provides. --- embassy/Cargo.toml | 7 + embassy/src/lib.rs | 1 + embassy/src/util/mod.rs | 1 + embassy/src/util/mpsc.rs | 919 +++++++++++++++++++++++++++++++++++ embassy/src/util/mutex.rs | 27 + examples/nrf/src/bin/mpsc.rs | 64 +++ 6 files changed, 1019 insertions(+) create mode 100644 embassy/src/util/mpsc.rs create mode 100644 examples/nrf/src/bin/mpsc.rs diff --git a/embassy/Cargo.toml b/embassy/Cargo.toml index b2ad80495..d26490247 100644 --- a/embassy/Cargo.toml +++ b/embassy/Cargo.toml @@ -3,6 +3,7 @@ name = "embassy" version = "0.1.0" authors = ["Dario Nieuwenhuis "] edition = "2018" +resolver = "2" [features] default = [] @@ -36,3 +37,9 @@ embedded-hal = "0.2.5" # Workaround https://github.com/japaric/cast.rs/pull/27 cast = { version = "=0.2.3", default-features = false } + +[dev-dependencies] +futures-executor = { version = "0.3", features = [ "thread-pool" ] } +futures-test = "0.3" +futures-timer = "0.3" +futures-util = { version = "0.3", features = [ "channel" ] } diff --git a/embassy/src/lib.rs b/embassy/src/lib.rs index 41102a180..845f82a3f 100644 --- a/embassy/src/lib.rs +++ b/embassy/src/lib.rs @@ -7,6 +7,7 @@ #![feature(min_type_alias_impl_trait)] #![feature(impl_trait_in_bindings)] #![feature(type_alias_impl_trait)] +#![feature(maybe_uninit_ref)] // This mod MUST go first, so that the others see its macros. pub(crate) mod fmt; diff --git a/embassy/src/util/mod.rs b/embassy/src/util/mod.rs index 88ae5c285..87d313e28 100644 --- a/embassy/src/util/mod.rs +++ b/embassy/src/util/mod.rs @@ -11,6 +11,7 @@ mod waker; pub use drop_bomb::*; pub use forever::*; +pub mod mpsc; pub use mutex::*; pub use on_drop::*; pub use portal::*; diff --git a/embassy/src/util/mpsc.rs b/embassy/src/util/mpsc.rs new file mode 100644 index 000000000..d24eb00bf --- /dev/null +++ b/embassy/src/util/mpsc.rs @@ -0,0 +1,919 @@ +//! A multi-producer, single-consumer queue for sending values between +//! asynchronous tasks. This queue takes a Mutex type so that various +//! targets can be attained. For example, a ThreadModeMutex can be used +//! for single-core Cortex-M targets where messages are only passed +//! between tasks running in thread mode. Similarly, a CriticalSectionMutex +//! can also be used for single-core targets where messages are to be +//! passed from exception mode e.g. out of an interrupt handler. +//! +//! This module provides a bounded channel that has a limit on the number of +//! messages that it can store, and if this limit is reached, trying to send +//! another message will result in an error being returned. +//! +//! Similar to the `mpsc` channels provided by `std`, the channel constructor +//! functions provide separate send and receive handles, [`Sender`] and +//! [`Receiver`]. If there is no message to read, the current task will be +//! notified when a new value is sent. [`Sender`] allows sending values into +//! the channel. If the bounded channel is at capacity, the send is rejected. +//! +//! # Disconnection +//! +//! When all [`Sender`] handles have been dropped, it is no longer +//! possible to send values into the channel. This is considered the termination +//! event of the stream. +//! +//! If the [`Receiver`] handle is dropped, then messages can no longer +//! be read out of the channel. In this case, all further attempts to send will +//! result in an error. +//! +//! # Clean Shutdown +//! +//! When the [`Receiver`] is dropped, it is possible for unprocessed messages to +//! remain in the channel. Instead, it is usually desirable to perform a "clean" +//! shutdown. To do this, the receiver first calls `close`, which will prevent +//! any further messages to be sent into the channel. Then, the receiver +//! consumes the channel to completion, at which point the receiver can be +//! dropped. +//! +//! This channel and its associated types were derived from https://docs.rs/tokio/0.1.22/tokio/sync/mpsc/fn.channel.html + +use core::cell::UnsafeCell; +use core::fmt; +use core::marker::PhantomData; +use core::mem::MaybeUninit; +use core::pin::Pin; +use core::task::Context; +use core::task::Poll; +use core::task::Waker; + +use futures::Future; + +use super::CriticalSectionMutex; +use super::Mutex; +use super::ThreadModeMutex; + +/// Send values to the associated `Receiver`. +/// +/// Instances are created by the [`split`](split) function. +pub struct Sender<'ch, M, T, const N: usize> +where + M: Mutex, +{ + channel: *mut Channel, + phantom_data: &'ch PhantomData, +} + +// Safe to pass the sender around +unsafe impl<'ch, M, T, const N: usize> Send for Sender<'ch, M, T, N> where M: Mutex {} +unsafe impl<'ch, M, T, const N: usize> Sync for Sender<'ch, M, T, N> where M: Mutex {} + +/// Receive values from the associated `Sender`. +/// +/// Instances are created by the [`split`](split) function. +pub struct Receiver<'ch, M, T, const N: usize> +where + M: Mutex, +{ + channel: *mut Channel, + _phantom_data: &'ch PhantomData, +} + +// Safe to pass the receiver around +unsafe impl<'ch, M, T, const N: usize> Send for Receiver<'ch, M, T, N> where M: Mutex {} +unsafe impl<'ch, M, T, const N: usize> Sync for Receiver<'ch, M, T, N> where M: Mutex {} + +/// Splits a bounded mpsc channel into a `Sender` and `Receiver`. +/// +/// All data sent on `Sender` will become available on `Receiver` in the same +/// order as it was sent. +/// +/// The `Sender` can be cloned to `send` to the same channel from multiple code +/// locations. Only one `Receiver` is valid. +/// +/// If the `Receiver` is disconnected while trying to `send`, the `send` method +/// will return a `SendError`. Similarly, if `Sender` is disconnected while +/// trying to `recv`, the `recv` method will return a `RecvError`. +/// +/// Note that when splitting the channel, the sender and receiver cannot outlive +/// their channel. The following will therefore fail compilation: +//// +/// ```compile_fail +/// use embassy::util::mpsc; +/// use embassy::util::mpsc::{Channel, WithThreadModeOnly}; +/// +/// let (sender, receiver) = { +/// let mut channel = Channel::::with_thread_mode_only(); +/// mpsc::split(&mut channel) +/// }; +/// ``` +pub fn split<'ch, M, T, const N: usize>( + channel: &'ch mut Channel, +) -> (Sender<'ch, M, T, N>, Receiver<'ch, M, T, N>) +where + M: Mutex, +{ + let sender = Sender { + channel, + phantom_data: &PhantomData, + }; + let receiver = Receiver { + channel, + _phantom_data: &PhantomData, + }; + channel.register_receiver(); + channel.register_sender(); + (sender, receiver) +} + +impl<'ch, M, T, const N: usize> Receiver<'ch, M, T, N> +where + M: Mutex, +{ + /// Receives the next value for this receiver. + /// + /// This method returns `None` if the channel has been closed and there are + /// no remaining messages in the channel's buffer. This indicates that no + /// further values can ever be received from this `Receiver`. The channel is + /// closed when all senders have been dropped, or when [`close`] is called. + /// + /// If there are no messages in the channel's buffer, but the channel has + /// not yet been closed, this method will sleep until a message is sent or + /// the channel is closed. + /// + /// Note that if [`close`] is called, but there are still outstanding + /// messages from before it was closed, the channel is not considered + /// closed by `recv` until they are all consumed. + /// + /// [`close`]: Self::close + pub async fn recv(&mut self) -> Option { + self.await + } + + /// Attempts to immediately receive a message on this `Receiver` + /// + /// This method will either receive a message from the channel immediately or return an error + /// if the channel is empty. + pub fn try_recv(&self) -> Result { + unsafe { self.channel.as_mut().unwrap().try_recv() } + } + + /// Closes the receiving half of a channel without dropping it. + /// + /// This prevents any further messages from being sent on the channel while + /// still enabling the receiver to drain messages that are buffered. + /// + /// To guarantee that no messages are dropped, after calling `close()`, + /// `recv()` must be called until `None` is returned. If there are + /// outstanding messages, the `recv` method will not return `None` + /// until those are released. + /// + pub fn close(&mut self) { + unsafe { self.channel.as_mut().unwrap().close() } + } +} + +impl<'ch, M, T, const N: usize> Future for Receiver<'ch, M, T, N> +where + M: Mutex, +{ + type Output = Option; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.try_recv() { + Ok(v) => Poll::Ready(Some(v)), + Err(TryRecvError::Closed) => Poll::Ready(None), + Err(TryRecvError::Empty) => { + unsafe { + self.channel + .as_mut() + .unwrap() + .set_receiver_waker(cx.waker().clone()); + }; + Poll::Pending + } + } + } +} + +impl<'ch, M, T, const N: usize> Drop for Receiver<'ch, M, T, N> +where + M: Mutex, +{ + fn drop(&mut self) { + unsafe { self.channel.as_mut().unwrap().deregister_receiver() } + } +} + +impl<'ch, M, T, const N: usize> Sender<'ch, M, T, N> +where + M: Mutex, +{ + /// Sends a value, waiting until there is capacity. + /// + /// A successful send occurs when it is determined that the other end of the + /// channel has not hung up already. An unsuccessful send would be one where + /// the corresponding receiver has already been closed. Note that a return + /// value of `Err` means that the data will never be received, but a return + /// value of `Ok` does not mean that the data will be received. It is + /// possible for the corresponding receiver to hang up immediately after + /// this function returns `Ok`. + /// + /// # Errors + /// + /// If the receive half of the channel is closed, either due to [`close`] + /// being called or the [`Receiver`] handle dropping, the function returns + /// an error. The error includes the value passed to `send`. + /// + /// [`close`]: Receiver::close + /// [`Receiver`]: Receiver + pub async fn send(&self, message: T) -> Result<(), SendError> { + SendFuture { + sender: self.clone(), + message: UnsafeCell::new(message), + } + .await + } + + /// Attempts to immediately send a message on this `Sender` + /// + /// This method differs from [`send`] by returning immediately if the channel's + /// buffer is full or no receiver is waiting to acquire some data. Compared + /// with [`send`], this function has two failure cases instead of one (one for + /// disconnection, one for a full buffer). + /// + /// # Errors + /// + /// If the channel capacity has been reached, i.e., the channel has `n` + /// buffered values where `n` is the argument passed to [`channel`], then an + /// error is returned. + /// + /// If the receive half of the channel is closed, either due to [`close`] + /// being called or the [`Receiver`] handle dropping, the function returns + /// an error. The error includes the value passed to `send`. + /// + /// [`send`]: Sender::send + /// [`channel`]: channel + /// [`close`]: Receiver::close + pub fn try_send(&self, message: T) -> Result<(), TrySendError> { + unsafe { self.channel.as_mut().unwrap().try_send(message) } + } + + /// Completes when the receiver has dropped. + /// + /// This allows the producers to get notified when interest in the produced + /// values is canceled and immediately stop doing work. + pub async fn closed(&self) { + CloseFuture { + sender: self.clone(), + } + .await + } + + /// Checks if the channel has been closed. This happens when the + /// [`Receiver`] is dropped, or when the [`Receiver::close`] method is + /// called. + /// + /// [`Receiver`]: crate::sync::mpsc::Receiver + /// [`Receiver::close`]: crate::sync::mpsc::Receiver::close + pub fn is_closed(&self) -> bool { + unsafe { self.channel.as_mut().unwrap().is_closed() } + } +} + +struct SendFuture<'ch, M, T, const N: usize> +where + M: Mutex, +{ + sender: Sender<'ch, M, T, N>, + message: UnsafeCell, +} + +impl<'ch, M, T, const N: usize> Future for SendFuture<'ch, M, T, N> +where + M: Mutex, +{ + type Output = Result<(), SendError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.sender.try_send(unsafe { self.message.get().read() }) { + Ok(..) => Poll::Ready(Ok(())), + Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))), + Err(TrySendError::Full(..)) => { + unsafe { + self.sender + .channel + .as_mut() + .unwrap() + .set_senders_waker(cx.waker().clone()); + }; + Poll::Pending + // Note we leave the existing UnsafeCell contents - they still + // contain the original message. We could create another UnsafeCell + // with the message of Full, but there's no real need. + } + } + } +} + +struct CloseFuture<'ch, M, T, const N: usize> +where + M: Mutex, +{ + sender: Sender<'ch, M, T, N>, +} + +impl<'ch, M, T, const N: usize> Future for CloseFuture<'ch, M, T, N> +where + M: Mutex, +{ + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if self.sender.is_closed() { + Poll::Ready(()) + } else { + unsafe { + self.sender + .channel + .as_mut() + .unwrap() + .set_senders_waker(cx.waker().clone()); + }; + Poll::Pending + } + } +} + +impl<'ch, M, T, const N: usize> Drop for Sender<'ch, M, T, N> +where + M: Mutex, +{ + fn drop(&mut self) { + unsafe { self.channel.as_mut().unwrap().deregister_sender() } + } +} + +impl<'ch, M, T, const N: usize> Clone for Sender<'ch, M, T, N> +where + M: Mutex, +{ + fn clone(&self) -> Self { + unsafe { self.channel.as_mut().unwrap().register_sender() }; + Sender { + channel: self.channel, + phantom_data: self.phantom_data, + } + } +} + +/// An error returned from the [`try_recv`] method. +/// +/// [`try_recv`]: super::Receiver::try_recv +#[derive(PartialEq, Eq, Clone, Copy, Debug)] +pub enum TryRecvError { + /// A message could not be received because the channel is empty. + Empty, + + /// The message could not be received because the channel is empty and closed. + Closed, +} + +/// Error returned by the `Sender`. +#[derive(Debug)] +pub struct SendError(pub T); + +impl fmt::Display for SendError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "channel closed") + } +} + +/// This enumeration is the list of the possible error outcomes for the +/// [try_send](super::Sender::try_send) method. +#[derive(Debug)] +pub enum TrySendError { + /// The data could not be sent on the channel because the channel is + /// currently full and sending would require blocking. + Full(T), + + /// The receive half of the channel was explicitly closed or has been + /// dropped. + Closed(T), +} + +impl fmt::Display for TrySendError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + fmt, + "{}", + match self { + TrySendError::Full(..) => "no available capacity", + TrySendError::Closed(..) => "channel closed", + } + ) + } +} + +pub struct ChannelState { + buf: [MaybeUninit>; N], + read_pos: usize, + write_pos: usize, + full: bool, + closing: bool, + closed: bool, + receiver_registered: bool, + senders_registered: u32, + receiver_waker: Option, + senders_waker: Option, +} + +impl ChannelState { + const INIT: MaybeUninit> = MaybeUninit::uninit(); + + const fn new() -> Self { + let buf = [Self::INIT; N]; + let read_pos = 0; + let write_pos = 0; + let full = false; + let closing = false; + let closed = false; + let receiver_registered = false; + let senders_registered = 0; + let receiver_waker = None; + let senders_waker = None; + ChannelState { + buf, + read_pos, + write_pos, + full, + closing, + closed, + receiver_registered, + senders_registered, + receiver_waker, + senders_waker, + } + } +} + +/// A a bounded mpsc channel for communicating between asynchronous tasks +/// with backpressure. +/// +/// The channel will buffer up to the provided number of messages. Once the +/// buffer is full, attempts to `send` new messages will wait until a message is +/// received from the channel. +/// +/// All data sent will become available in the same order as it was sent. +pub struct Channel +where + M: Mutex, +{ + mutex: M, + state: ChannelState, +} + +pub type WithCriticalSections = CriticalSectionMutex<()>; + +impl Channel { + /// Establish a new bounded channel using critical sections. Critical sections + /// should be used only single core targets where communication is required + /// from exception mode e.g. interrupt handlers. To create one: + /// + /// ``` + /// use embassy::util::mpsc; + /// use embassy::util::mpsc::{Channel, WithCriticalSections}; + /// + /// // Declare a bounded channel of 3 u32s. + /// let mut channel = mpsc::Channel::::with_critical_sections(); + /// // once we have a channel, obtain its sender and receiver + /// let (sender, receiver) = mpsc::split(&mut channel); + /// ``` + pub const fn with_critical_sections() -> Self { + let mutex = CriticalSectionMutex::new(()); + let state = ChannelState::new(); + Channel { mutex, state } + } +} + +pub type WithThreadModeOnly = ThreadModeMutex<()>; + +impl Channel { + /// Establish a new bounded channel for use in Cortex-M thread mode. Thread + /// mode is intended for application threads on a single core, not interrupts. + /// As such, only one task at a time can acquire a resource and so this + /// channel avoids all locks. To create one: + /// + /// ``` no_run + /// use embassy::util::mpsc; + /// use embassy::util::mpsc::{Channel, WithThreadModeOnly}; + /// + /// // Declare a bounded channel of 3 u32s. + /// let mut channel = Channel::::with_thread_mode_only(); + /// // once we have a channel, obtain its sender and receiver + /// let (sender, receiver) = mpsc::split(&mut channel); + /// ``` + pub const fn with_thread_mode_only() -> Self { + let mutex = ThreadModeMutex::new(()); + let state = ChannelState::new(); + Channel { mutex, state } + } +} + +impl Channel +where + M: Mutex, +{ + fn try_recv(&mut self) -> Result { + let state = &mut self.state; + self.mutex.lock(|_| { + if !state.closed { + if state.read_pos != state.write_pos || state.full { + if state.full { + state.full = false; + state.senders_waker.take().map(|w| w.wake()); + } + let message = + unsafe { (state.buf[state.read_pos]).assume_init_mut().get().read() }; + state.read_pos = (state.read_pos + 1) % state.buf.len(); + Ok(message) + } else if !state.closing { + Err(TryRecvError::Empty) + } else { + state.closed = true; + state.senders_waker.take().map(|w| w.wake()); + Err(TryRecvError::Closed) + } + } else { + Err(TryRecvError::Closed) + } + }) + } + + fn try_send(&mut self, message: T) -> Result<(), TrySendError> { + let state = &mut self.state; + self.mutex.lock(|_| { + if !state.closed { + if !state.full { + state.buf[state.write_pos] = MaybeUninit::new(message.into()); + state.write_pos = (state.write_pos + 1) % state.buf.len(); + if state.write_pos == state.read_pos { + state.full = true; + } + state.receiver_waker.take().map(|w| w.wake()); + Ok(()) + } else { + Err(TrySendError::Full(message)) + } + } else { + Err(TrySendError::Closed(message)) + } + }) + } + + fn close(&mut self) { + let state = &mut self.state; + self.mutex.lock(|_| { + state.receiver_waker.take().map(|w| w.wake()); + state.closing = true; + }); + } + + fn is_closed(&mut self) -> bool { + let state = &self.state; + self.mutex.lock(|_| state.closing || state.closed) + } + + fn register_receiver(&mut self) { + let state = &mut self.state; + self.mutex.lock(|_| { + assert!(!state.receiver_registered); + state.receiver_registered = true; + }); + } + + fn deregister_receiver(&mut self) { + let state = &mut self.state; + self.mutex.lock(|_| { + if state.receiver_registered { + state.closed = true; + state.senders_waker.take().map(|w| w.wake()); + } + state.receiver_registered = false; + }) + } + + fn register_sender(&mut self) { + let state = &mut self.state; + self.mutex.lock(|_| { + state.senders_registered = state.senders_registered + 1; + }) + } + + fn deregister_sender(&mut self) { + let state = &mut self.state; + self.mutex.lock(|_| { + assert!(state.senders_registered > 0); + state.senders_registered = state.senders_registered - 1; + if state.senders_registered == 0 { + state.receiver_waker.take().map(|w| w.wake()); + state.closing = true; + } + }) + } + + fn set_receiver_waker(&mut self, receiver_waker: Waker) { + let state = &mut self.state; + self.mutex.lock(|_| { + state.receiver_waker = Some(receiver_waker); + }) + } + + fn set_senders_waker(&mut self, senders_waker: Waker) { + let state = &mut self.state; + self.mutex.lock(|_| { + + // Dispose of any existing sender causing them to be polled again. + // This could cause a spin given multiple concurrent senders, however given that + // most sends only block waiting for the receiver to become active, this should + // be a short-lived activity. The upside is a greatly simplified implementation + // that avoids the need for intrusive linked-lists and unsafe operations on pinned + // pointers. + if let Some(waker) = state.senders_waker.clone() { + if !senders_waker.will_wake(&waker) { + trace!("Waking an an active send waker due to being superseded with a new one. While benign, please report this."); + waker.wake(); + } + } + state.senders_waker = Some(senders_waker); + }) + } +} + +#[cfg(test)] +mod tests { + use core::time::Duration; + + use futures::task::SpawnExt; + use futures_executor::ThreadPool; + use futures_timer::Delay; + + use super::*; + + fn capacity(c: &Channel) -> usize + where + M: Mutex, + { + if !c.state.full { + if c.state.write_pos > c.state.read_pos { + (c.state.buf.len() - c.state.write_pos) + c.state.read_pos + } else { + (c.state.buf.len() - c.state.read_pos) + c.state.write_pos + } + } else { + 0 + } + } + + /// A mutex that does nothing - useful for our testing purposes + pub struct NoopMutex { + inner: UnsafeCell, + } + + impl NoopMutex { + pub const fn new(value: T) -> Self { + NoopMutex { + inner: UnsafeCell::new(value), + } + } + } + + impl NoopMutex { + pub fn borrow(&self) -> &T { + unsafe { &*self.inner.get() } + } + } + + impl Mutex for NoopMutex { + type Data = T; + + fn lock(&mut self, f: impl FnOnce(&Self::Data) -> R) -> R { + f(self.borrow()) + } + } + + pub type WithNoThreads = NoopMutex<()>; + + impl Channel { + pub const fn with_no_threads() -> Self { + let mutex = NoopMutex::new(()); + let state = ChannelState::new(); + Channel { mutex, state } + } + } + + #[test] + fn sending_once() { + let mut c = Channel::::with_no_threads(); + assert!(c.try_send(1).is_ok()); + assert_eq!(capacity(&c), 2); + } + + #[test] + fn sending_when_full() { + let mut c = Channel::::with_no_threads(); + let _ = c.try_send(1); + let _ = c.try_send(1); + let _ = c.try_send(1); + match c.try_send(2) { + Err(TrySendError::Full(2)) => assert!(true), + _ => assert!(false), + } + assert_eq!(capacity(&c), 0); + } + + #[test] + fn sending_when_closed() { + let mut c = Channel::::with_no_threads(); + c.state.closed = true; + match c.try_send(2) { + Err(TrySendError::Closed(2)) => assert!(true), + _ => assert!(false), + } + } + + #[test] + fn receiving_once_with_one_send() { + let mut c = Channel::::with_no_threads(); + assert!(c.try_send(1).is_ok()); + assert_eq!(c.try_recv().unwrap(), 1); + assert_eq!(capacity(&c), 3); + } + + #[test] + fn receiving_when_empty() { + let mut c = Channel::::with_no_threads(); + match c.try_recv() { + Err(TryRecvError::Empty) => assert!(true), + _ => assert!(false), + } + assert_eq!(capacity(&c), 3); + } + + #[test] + fn receiving_when_closed() { + let mut c = Channel::::with_no_threads(); + c.state.closed = true; + match c.try_recv() { + Err(TryRecvError::Closed) => assert!(true), + _ => assert!(false), + } + } + + #[test] + fn simple_send_and_receive() { + let mut c = Channel::::with_no_threads(); + let (s, r) = split(&mut c); + assert!(s.clone().try_send(1).is_ok()); + assert_eq!(r.try_recv().unwrap(), 1); + } + + #[test] + fn should_close_without_sender() { + let mut c = Channel::::with_no_threads(); + let (s, r) = split(&mut c); + drop(s); + match r.try_recv() { + Err(TryRecvError::Closed) => assert!(true), + _ => assert!(false), + } + } + + #[test] + fn should_close_once_drained() { + let mut c = Channel::::with_no_threads(); + let (s, r) = split(&mut c); + assert!(s.try_send(1).is_ok()); + drop(s); + assert_eq!(r.try_recv().unwrap(), 1); + match r.try_recv() { + Err(TryRecvError::Closed) => assert!(true), + _ => assert!(false), + } + } + + #[test] + fn should_reject_send_when_receiver_dropped() { + let mut c = Channel::::with_no_threads(); + let (s, r) = split(&mut c); + drop(r); + match s.try_send(1) { + Err(TrySendError::Closed(1)) => assert!(true), + _ => assert!(false), + } + } + + #[test] + fn should_reject_send_when_channel_closed() { + let mut c = Channel::::with_no_threads(); + let (s, mut r) = split(&mut c); + assert!(s.try_send(1).is_ok()); + r.close(); + assert_eq!(r.try_recv().unwrap(), 1); + match r.try_recv() { + Err(TryRecvError::Closed) => assert!(true), + _ => assert!(false), + } + assert!(s.is_closed()); + } + + #[futures_test::test] + async fn receiver_closes_when_sender_dropped_async() { + let executor = ThreadPool::new().unwrap(); + + static mut CHANNEL: Channel = + Channel::with_critical_sections(); + let (s, mut r) = split(unsafe { &mut CHANNEL }); + assert!(executor + .spawn(async move { + drop(s); + }) + .is_ok()); + assert_eq!(r.recv().await, None); + } + + #[futures_test::test] + async fn receiver_receives_given_try_send_async() { + let executor = ThreadPool::new().unwrap(); + + static mut CHANNEL: Channel = + Channel::with_critical_sections(); + let (s, mut r) = split(unsafe { &mut CHANNEL }); + assert!(executor + .spawn(async move { + let _ = s.try_send(1); + }) + .is_ok()); + assert_eq!(r.recv().await, Some(1)); + } + + #[futures_test::test] + async fn sender_send_completes_if_capacity() { + static mut CHANNEL: Channel = + Channel::with_critical_sections(); + let (s, mut r) = split(unsafe { &mut CHANNEL }); + assert!(s.send(1).await.is_ok()); + assert_eq!(r.recv().await, Some(1)); + } + + #[futures_test::test] + async fn sender_send_completes_if_closed() { + static mut CHANNEL: Channel = + Channel::with_critical_sections(); + let (s, r) = split(unsafe { &mut CHANNEL }); + drop(r); + match s.send(1).await { + Err(SendError(1)) => assert!(true), + _ => assert!(false), + } + } + + #[futures_test::test] + async fn senders_sends_wait_until_capacity() { + let executor = ThreadPool::new().unwrap(); + + static mut CHANNEL: Channel = + Channel::with_critical_sections(); + let (s0, mut r) = split(unsafe { &mut CHANNEL }); + assert!(s0.try_send(1).is_ok()); + let s1 = s0.clone(); + let send_task_1 = executor.spawn_with_handle(async move { s0.send(2).await }); + let send_task_2 = executor.spawn_with_handle(async move { s1.send(3).await }); + // Wish I could think of a means of determining that the async send is waiting instead. + // However, I've used the debugger to observe that the send does indeed wait. + assert!(Delay::new(Duration::from_millis(500)).await.is_ok()); + assert_eq!(r.recv().await, Some(1)); + assert!(executor + .spawn(async move { while let Some(_) = r.recv().await {} }) + .is_ok()); + assert!(send_task_1.unwrap().await.is_ok()); + assert!(send_task_2.unwrap().await.is_ok()); + } + + #[futures_test::test] + async fn sender_close_completes_if_closing() { + static mut CHANNEL: Channel = + Channel::with_critical_sections(); + let (s, mut r) = split(unsafe { &mut CHANNEL }); + r.close(); + s.closed().await; + } + + #[futures_test::test] + async fn sender_close_completes_if_closed() { + static mut CHANNEL: Channel = + Channel::with_critical_sections(); + let (s, r) = split(unsafe { &mut CHANNEL }); + drop(r); + s.closed().await; + } +} diff --git a/embassy/src/util/mutex.rs b/embassy/src/util/mutex.rs index e4b7764ce..682fcb39d 100644 --- a/embassy/src/util/mutex.rs +++ b/embassy/src/util/mutex.rs @@ -1,6 +1,17 @@ use core::cell::UnsafeCell; use critical_section::CriticalSection; +/// Any object implementing this trait guarantees exclusive access to the data contained +/// within the mutex for the duration of the lock. +/// Adapted from https://github.com/rust-embedded/mutex-trait. +pub trait Mutex { + /// Data protected by the mutex. + type Data; + + /// Creates a critical section and grants temporary access to the protected data. + fn lock(&mut self, f: impl FnOnce(&Self::Data) -> R) -> R; +} + /// A "mutex" based on critical sections /// /// # Safety @@ -33,6 +44,14 @@ impl CriticalSectionMutex { } } +impl Mutex for CriticalSectionMutex { + type Data = T; + + fn lock(&mut self, f: impl FnOnce(&Self::Data) -> R) -> R { + critical_section::with(|cs| f(self.borrow(cs))) + } +} + /// A "mutex" that only allows borrowing from thread mode. /// /// # Safety @@ -70,6 +89,14 @@ impl ThreadModeMutex { } } +impl Mutex for ThreadModeMutex { + type Data = T; + + fn lock(&mut self, f: impl FnOnce(&Self::Data) -> R) -> R { + f(self.borrow()) + } +} + pub fn in_thread_mode() -> bool { #[cfg(feature = "std")] return Some("main") == std::thread::current().name(); diff --git a/examples/nrf/src/bin/mpsc.rs b/examples/nrf/src/bin/mpsc.rs new file mode 100644 index 000000000..6a0f8f471 --- /dev/null +++ b/examples/nrf/src/bin/mpsc.rs @@ -0,0 +1,64 @@ +#![no_std] +#![no_main] +#![feature(min_type_alias_impl_trait)] +#![feature(impl_trait_in_bindings)] +#![feature(type_alias_impl_trait)] +#![allow(incomplete_features)] + +#[path = "../example_common.rs"] +mod example_common; + +use defmt::panic; +use embassy::executor::Spawner; +use embassy::time::{Duration, Timer}; +use embassy::util::mpsc::TryRecvError; +use embassy::util::{mpsc, Forever}; +use embassy_nrf::gpio::{Level, Output, OutputDrive}; +use embassy_nrf::Peripherals; +use embedded_hal::digital::v2::OutputPin; +use mpsc::{Channel, Sender, WithThreadModeOnly}; + +enum LedState { + On, + Off, +} + +static CHANNEL: Forever> = Forever::new(); + +#[embassy::task(pool_size = 1)] +async fn my_task(sender: Sender<'static, WithThreadModeOnly, LedState, 1>) { + loop { + let _ = sender.send(LedState::On).await; + Timer::after(Duration::from_secs(1)).await; + let _ = sender.send(LedState::Off).await; + Timer::after(Duration::from_secs(1)).await; + } +} + +#[embassy::main] +async fn main(spawner: Spawner, p: Peripherals) { + let mut led = Output::new(p.P0_13, Level::Low, OutputDrive::Standard); + + let channel = CHANNEL.put(Channel::with_thread_mode_only()); + let (sender, mut receiver) = mpsc::split(channel); + + spawner.spawn(my_task(sender)).unwrap(); + + // We could just loop on `receiver.recv()` for simplicity. The code below + // is optimized to drain the queue as fast as possible in the spirit of + // handling events as fast as possible. This optimization is benign when in + // thread mode, but can be useful when interrupts are sending messages + // with the channel having been created via with_critical_sections. + loop { + let maybe_message = match receiver.try_recv() { + m @ Ok(..) => m.ok(), + Err(TryRecvError::Empty) => receiver.recv().await, + Err(TryRecvError::Closed) => break, + }; + match maybe_message { + Some(LedState::On) => led.set_high().unwrap(), + Some(LedState::Off) => led.set_low().unwrap(), + _ => (), + } + } +} From 816b78c0d9733362d8653eb2032f126e6a710030 Mon Sep 17 00:00:00 2001 From: huntc Date: Tue, 6 Jul 2021 23:20:47 +1000 Subject: [PATCH 02/23] Reduces the types on sender and receiver In exchange for an UnsafeCell being passed into split --- embassy/src/util/mpsc.rs | 248 +++++++++++++++-------------------- examples/nrf/src/bin/mpsc.rs | 8 +- 2 files changed, 110 insertions(+), 146 deletions(-) diff --git a/embassy/src/util/mpsc.rs b/embassy/src/util/mpsc.rs index d24eb00bf..d8a010d7d 100644 --- a/embassy/src/util/mpsc.rs +++ b/embassy/src/util/mpsc.rs @@ -39,7 +39,6 @@ use core::cell::UnsafeCell; use core::fmt; -use core::marker::PhantomData; use core::mem::MaybeUninit; use core::pin::Pin; use core::task::Context; @@ -55,32 +54,24 @@ use super::ThreadModeMutex; /// Send values to the associated `Receiver`. /// /// Instances are created by the [`split`](split) function. -pub struct Sender<'ch, M, T, const N: usize> -where - M: Mutex, -{ - channel: *mut Channel, - phantom_data: &'ch PhantomData, +pub struct Sender<'ch, T> { + channel: &'ch UnsafeCell>, } // Safe to pass the sender around -unsafe impl<'ch, M, T, const N: usize> Send for Sender<'ch, M, T, N> where M: Mutex {} -unsafe impl<'ch, M, T, const N: usize> Sync for Sender<'ch, M, T, N> where M: Mutex {} +unsafe impl<'ch, T> Send for Sender<'ch, T> {} +unsafe impl<'ch, T> Sync for Sender<'ch, T> {} /// Receive values from the associated `Sender`. /// /// Instances are created by the [`split`](split) function. -pub struct Receiver<'ch, M, T, const N: usize> -where - M: Mutex, -{ - channel: *mut Channel, - _phantom_data: &'ch PhantomData, +pub struct Receiver<'ch, T> { + channel: &'ch UnsafeCell>, } // Safe to pass the receiver around -unsafe impl<'ch, M, T, const N: usize> Send for Receiver<'ch, M, T, N> where M: Mutex {} -unsafe impl<'ch, M, T, const N: usize> Sync for Receiver<'ch, M, T, N> where M: Mutex {} +unsafe impl<'ch, T> Send for Receiver<'ch, T> {} +unsafe impl<'ch, T> Sync for Receiver<'ch, T> {} /// Splits a bounded mpsc channel into a `Sender` and `Receiver`. /// @@ -98,37 +89,29 @@ unsafe impl<'ch, M, T, const N: usize> Sync for Receiver<'ch, M, T, N> where M: /// their channel. The following will therefore fail compilation: //// /// ```compile_fail +/// use core::cell::UnsafeCell; /// use embassy::util::mpsc; /// use embassy::util::mpsc::{Channel, WithThreadModeOnly}; /// /// let (sender, receiver) = { -/// let mut channel = Channel::::with_thread_mode_only(); -/// mpsc::split(&mut channel) +/// let mut channel = UnsafeCell::new(Channel::::with_thread_mode_only()); +/// mpsc::split(&channel) /// }; /// ``` -pub fn split<'ch, M, T, const N: usize>( - channel: &'ch mut Channel, -) -> (Sender<'ch, M, T, N>, Receiver<'ch, M, T, N>) -where - M: Mutex, -{ - let sender = Sender { - channel, - phantom_data: &PhantomData, - }; - let receiver = Receiver { - channel, - _phantom_data: &PhantomData, - }; - channel.register_receiver(); - channel.register_sender(); +pub fn split<'ch, T>( + channel: &'ch UnsafeCell>, +) -> (Sender<'ch, T>, Receiver<'ch, T>) { + let sender = Sender { channel: &channel }; + let receiver = Receiver { channel: &channel }; + { + let c = unsafe { &mut *channel.get() }; + c.register_receiver(); + c.register_sender(); + } (sender, receiver) } -impl<'ch, M, T, const N: usize> Receiver<'ch, M, T, N> -where - M: Mutex, -{ +impl<'ch, T> Receiver<'ch, T> { /// Receives the next value for this receiver. /// /// This method returns `None` if the channel has been closed and there are @@ -154,7 +137,7 @@ where /// This method will either receive a message from the channel immediately or return an error /// if the channel is empty. pub fn try_recv(&self) -> Result { - unsafe { self.channel.as_mut().unwrap().try_recv() } + unsafe { &mut *self.channel.get() }.try_recv() } /// Closes the receiving half of a channel without dropping it. @@ -168,14 +151,11 @@ where /// until those are released. /// pub fn close(&mut self) { - unsafe { self.channel.as_mut().unwrap().close() } + unsafe { &mut *self.channel.get() }.close() } } -impl<'ch, M, T, const N: usize> Future for Receiver<'ch, M, T, N> -where - M: Mutex, -{ +impl<'ch, T> Future for Receiver<'ch, T> { type Output = Option; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -183,31 +163,20 @@ where Ok(v) => Poll::Ready(Some(v)), Err(TryRecvError::Closed) => Poll::Ready(None), Err(TryRecvError::Empty) => { - unsafe { - self.channel - .as_mut() - .unwrap() - .set_receiver_waker(cx.waker().clone()); - }; + unsafe { &mut *self.channel.get() }.set_receiver_waker(cx.waker().clone()); Poll::Pending } } } } -impl<'ch, M, T, const N: usize> Drop for Receiver<'ch, M, T, N> -where - M: Mutex, -{ +impl<'ch, T> Drop for Receiver<'ch, T> { fn drop(&mut self) { - unsafe { self.channel.as_mut().unwrap().deregister_receiver() } + unsafe { &mut *self.channel.get() }.deregister_receiver() } } -impl<'ch, M, T, const N: usize> Sender<'ch, M, T, N> -where - M: Mutex, -{ +impl<'ch, T> Sender<'ch, T> { /// Sends a value, waiting until there is capacity. /// /// A successful send occurs when it is determined that the other end of the @@ -255,7 +224,7 @@ where /// [`channel`]: channel /// [`close`]: Receiver::close pub fn try_send(&self, message: T) -> Result<(), TrySendError> { - unsafe { self.channel.as_mut().unwrap().try_send(message) } + unsafe { &mut *self.channel.get() }.try_send(message) } /// Completes when the receiver has dropped. @@ -276,22 +245,16 @@ where /// [`Receiver`]: crate::sync::mpsc::Receiver /// [`Receiver::close`]: crate::sync::mpsc::Receiver::close pub fn is_closed(&self) -> bool { - unsafe { self.channel.as_mut().unwrap().is_closed() } + unsafe { &mut *self.channel.get() }.is_closed() } } -struct SendFuture<'ch, M, T, const N: usize> -where - M: Mutex, -{ - sender: Sender<'ch, M, T, N>, +struct SendFuture<'ch, T> { + sender: Sender<'ch, T>, message: UnsafeCell, } -impl<'ch, M, T, const N: usize> Future for SendFuture<'ch, M, T, N> -where - M: Mutex, -{ +impl<'ch, T> Future for SendFuture<'ch, T> { type Output = Result<(), SendError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -299,13 +262,7 @@ where Ok(..) => Poll::Ready(Ok(())), Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))), Err(TrySendError::Full(..)) => { - unsafe { - self.sender - .channel - .as_mut() - .unwrap() - .set_senders_waker(cx.waker().clone()); - }; + unsafe { &mut *self.sender.channel.get() }.set_senders_waker(cx.waker().clone()); Poll::Pending // Note we leave the existing UnsafeCell contents - they still // contain the original message. We could create another UnsafeCell @@ -315,53 +272,34 @@ where } } -struct CloseFuture<'ch, M, T, const N: usize> -where - M: Mutex, -{ - sender: Sender<'ch, M, T, N>, +struct CloseFuture<'ch, T> { + sender: Sender<'ch, T>, } -impl<'ch, M, T, const N: usize> Future for CloseFuture<'ch, M, T, N> -where - M: Mutex, -{ +impl<'ch, T> Future for CloseFuture<'ch, T> { type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { if self.sender.is_closed() { Poll::Ready(()) } else { - unsafe { - self.sender - .channel - .as_mut() - .unwrap() - .set_senders_waker(cx.waker().clone()); - }; + unsafe { &mut *self.sender.channel.get() }.set_senders_waker(cx.waker().clone()); Poll::Pending } } } -impl<'ch, M, T, const N: usize> Drop for Sender<'ch, M, T, N> -where - M: Mutex, -{ +impl<'ch, T> Drop for Sender<'ch, T> { fn drop(&mut self) { - unsafe { self.channel.as_mut().unwrap().deregister_sender() } + unsafe { &mut *self.channel.get() }.deregister_sender() } } -impl<'ch, M, T, const N: usize> Clone for Sender<'ch, M, T, N> -where - M: Mutex, -{ +impl<'ch, T> Clone for Sender<'ch, T> { fn clone(&self) -> Self { - unsafe { self.channel.as_mut().unwrap().register_sender() }; + unsafe { &mut *self.channel.get() }.register_sender(); Sender { - channel: self.channel, - phantom_data: self.phantom_data, + channel: self.channel.clone(), } } } @@ -414,6 +352,28 @@ impl fmt::Display for TrySendError { } } +pub trait ChannelLike { + fn try_recv(&mut self) -> Result; + + fn try_send(&mut self, message: T) -> Result<(), TrySendError>; + + fn close(&mut self); + + fn is_closed(&mut self) -> bool; + + fn register_receiver(&mut self); + + fn deregister_receiver(&mut self); + + fn register_sender(&mut self); + + fn deregister_sender(&mut self); + + fn set_receiver_waker(&mut self, receiver_waker: Waker); + + fn set_senders_waker(&mut self, senders_waker: Waker); +} + pub struct ChannelState { buf: [MaybeUninit>; N], read_pos: usize, @@ -480,13 +440,14 @@ impl Channel { /// from exception mode e.g. interrupt handlers. To create one: /// /// ``` + /// use core::cell::UnsafeCell; /// use embassy::util::mpsc; /// use embassy::util::mpsc::{Channel, WithCriticalSections}; /// /// // Declare a bounded channel of 3 u32s. - /// let mut channel = mpsc::Channel::::with_critical_sections(); + /// let mut channel = UnsafeCell::new(mpsc::Channel::::with_critical_sections()); /// // once we have a channel, obtain its sender and receiver - /// let (sender, receiver) = mpsc::split(&mut channel); + /// let (sender, receiver) = mpsc::split(&channel); /// ``` pub const fn with_critical_sections() -> Self { let mutex = CriticalSectionMutex::new(()); @@ -504,13 +465,14 @@ impl Channel { /// channel avoids all locks. To create one: /// /// ``` no_run + /// use core::cell::UnsafeCell; /// use embassy::util::mpsc; /// use embassy::util::mpsc::{Channel, WithThreadModeOnly}; /// /// // Declare a bounded channel of 3 u32s. - /// let mut channel = Channel::::with_thread_mode_only(); + /// let mut channel = UnsafeCell::new(Channel::::with_thread_mode_only()); /// // once we have a channel, obtain its sender and receiver - /// let (sender, receiver) = mpsc::split(&mut channel); + /// let (sender, receiver) = mpsc::split(&channel); /// ``` pub const fn with_thread_mode_only() -> Self { let mutex = ThreadModeMutex::new(()); @@ -519,7 +481,7 @@ impl Channel { } } -impl Channel +impl ChannelLike for Channel where M: Mutex, { @@ -771,16 +733,16 @@ mod tests { #[test] fn simple_send_and_receive() { - let mut c = Channel::::with_no_threads(); - let (s, r) = split(&mut c); + let c = UnsafeCell::new(Channel::::with_no_threads()); + let (s, r) = split(&c); assert!(s.clone().try_send(1).is_ok()); assert_eq!(r.try_recv().unwrap(), 1); } #[test] fn should_close_without_sender() { - let mut c = Channel::::with_no_threads(); - let (s, r) = split(&mut c); + let c = UnsafeCell::new(Channel::::with_no_threads()); + let (s, r) = split(&c); drop(s); match r.try_recv() { Err(TryRecvError::Closed) => assert!(true), @@ -790,8 +752,8 @@ mod tests { #[test] fn should_close_once_drained() { - let mut c = Channel::::with_no_threads(); - let (s, r) = split(&mut c); + let c = UnsafeCell::new(Channel::::with_no_threads()); + let (s, r) = split(&c); assert!(s.try_send(1).is_ok()); drop(s); assert_eq!(r.try_recv().unwrap(), 1); @@ -803,8 +765,8 @@ mod tests { #[test] fn should_reject_send_when_receiver_dropped() { - let mut c = Channel::::with_no_threads(); - let (s, r) = split(&mut c); + let c = UnsafeCell::new(Channel::::with_no_threads()); + let (s, r) = split(&c); drop(r); match s.try_send(1) { Err(TrySendError::Closed(1)) => assert!(true), @@ -814,8 +776,8 @@ mod tests { #[test] fn should_reject_send_when_channel_closed() { - let mut c = Channel::::with_no_threads(); - let (s, mut r) = split(&mut c); + let c = UnsafeCell::new(Channel::::with_no_threads()); + let (s, mut r) = split(&c); assert!(s.try_send(1).is_ok()); r.close(); assert_eq!(r.try_recv().unwrap(), 1); @@ -830,9 +792,9 @@ mod tests { async fn receiver_closes_when_sender_dropped_async() { let executor = ThreadPool::new().unwrap(); - static mut CHANNEL: Channel = - Channel::with_critical_sections(); - let (s, mut r) = split(unsafe { &mut CHANNEL }); + static mut CHANNEL: UnsafeCell> = + UnsafeCell::new(Channel::with_critical_sections()); + let (s, mut r) = split(unsafe { &CHANNEL }); assert!(executor .spawn(async move { drop(s); @@ -845,12 +807,12 @@ mod tests { async fn receiver_receives_given_try_send_async() { let executor = ThreadPool::new().unwrap(); - static mut CHANNEL: Channel = - Channel::with_critical_sections(); - let (s, mut r) = split(unsafe { &mut CHANNEL }); + static mut CHANNEL: UnsafeCell> = + UnsafeCell::new(Channel::with_critical_sections()); + let (s, mut r) = split(unsafe { &CHANNEL }); assert!(executor .spawn(async move { - let _ = s.try_send(1); + assert!(s.try_send(1).is_ok()); }) .is_ok()); assert_eq!(r.recv().await, Some(1)); @@ -858,18 +820,18 @@ mod tests { #[futures_test::test] async fn sender_send_completes_if_capacity() { - static mut CHANNEL: Channel = - Channel::with_critical_sections(); - let (s, mut r) = split(unsafe { &mut CHANNEL }); + static mut CHANNEL: UnsafeCell> = + UnsafeCell::new(Channel::with_critical_sections()); + let (s, mut r) = split(unsafe { &CHANNEL }); assert!(s.send(1).await.is_ok()); assert_eq!(r.recv().await, Some(1)); } #[futures_test::test] async fn sender_send_completes_if_closed() { - static mut CHANNEL: Channel = - Channel::with_critical_sections(); - let (s, r) = split(unsafe { &mut CHANNEL }); + static mut CHANNEL: UnsafeCell> = + UnsafeCell::new(Channel::with_critical_sections()); + let (s, r) = split(unsafe { &CHANNEL }); drop(r); match s.send(1).await { Err(SendError(1)) => assert!(true), @@ -881,9 +843,9 @@ mod tests { async fn senders_sends_wait_until_capacity() { let executor = ThreadPool::new().unwrap(); - static mut CHANNEL: Channel = - Channel::with_critical_sections(); - let (s0, mut r) = split(unsafe { &mut CHANNEL }); + static mut CHANNEL: UnsafeCell> = + UnsafeCell::new(Channel::with_critical_sections()); + let (s0, mut r) = split(unsafe { &CHANNEL }); assert!(s0.try_send(1).is_ok()); let s1 = s0.clone(); let send_task_1 = executor.spawn_with_handle(async move { s0.send(2).await }); @@ -901,18 +863,18 @@ mod tests { #[futures_test::test] async fn sender_close_completes_if_closing() { - static mut CHANNEL: Channel = - Channel::with_critical_sections(); - let (s, mut r) = split(unsafe { &mut CHANNEL }); + static mut CHANNEL: UnsafeCell> = + UnsafeCell::new(Channel::with_critical_sections()); + let (s, mut r) = split(unsafe { &CHANNEL }); r.close(); s.closed().await; } #[futures_test::test] async fn sender_close_completes_if_closed() { - static mut CHANNEL: Channel = - Channel::with_critical_sections(); - let (s, r) = split(unsafe { &mut CHANNEL }); + static mut CHANNEL: UnsafeCell> = + UnsafeCell::new(Channel::with_critical_sections()); + let (s, r) = split(unsafe { &CHANNEL }); drop(r); s.closed().await; } diff --git a/examples/nrf/src/bin/mpsc.rs b/examples/nrf/src/bin/mpsc.rs index 6a0f8f471..d692abee2 100644 --- a/examples/nrf/src/bin/mpsc.rs +++ b/examples/nrf/src/bin/mpsc.rs @@ -8,6 +8,8 @@ #[path = "../example_common.rs"] mod example_common; +use core::cell::UnsafeCell; + use defmt::panic; use embassy::executor::Spawner; use embassy::time::{Duration, Timer}; @@ -23,10 +25,10 @@ enum LedState { Off, } -static CHANNEL: Forever> = Forever::new(); +static CHANNEL: Forever>> = Forever::new(); #[embassy::task(pool_size = 1)] -async fn my_task(sender: Sender<'static, WithThreadModeOnly, LedState, 1>) { +async fn my_task(sender: Sender<'static, LedState>) { loop { let _ = sender.send(LedState::On).await; Timer::after(Duration::from_secs(1)).await; @@ -39,7 +41,7 @@ async fn my_task(sender: Sender<'static, WithThreadModeOnly, LedState, 1>) { async fn main(spawner: Spawner, p: Peripherals) { let mut led = Output::new(p.P0_13, Level::Low, OutputDrive::Standard); - let channel = CHANNEL.put(Channel::with_thread_mode_only()); + let channel = CHANNEL.put(UnsafeCell::new(Channel::with_thread_mode_only())); let (sender, mut receiver) = mpsc::split(channel); spawner.spawn(my_task(sender)).unwrap(); From a156f72bfba854087afd17f4769060577417698a Mon Sep 17 00:00:00 2001 From: huntc Date: Tue, 6 Jul 2021 23:51:47 +1000 Subject: [PATCH 03/23] Improves the representation of side effects --- embassy/src/util/mpsc.rs | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/embassy/src/util/mpsc.rs b/embassy/src/util/mpsc.rs index d8a010d7d..f09bcedbc 100644 --- a/embassy/src/util/mpsc.rs +++ b/embassy/src/util/mpsc.rs @@ -492,7 +492,9 @@ where if state.read_pos != state.write_pos || state.full { if state.full { state.full = false; - state.senders_waker.take().map(|w| w.wake()); + if let Some(w) = state.senders_waker.take() { + w.wake(); + } } let message = unsafe { (state.buf[state.read_pos]).assume_init_mut().get().read() }; @@ -502,7 +504,9 @@ where Err(TryRecvError::Empty) } else { state.closed = true; - state.senders_waker.take().map(|w| w.wake()); + if let Some(w) = state.senders_waker.take() { + w.wake(); + } Err(TryRecvError::Closed) } } else { @@ -521,7 +525,9 @@ where if state.write_pos == state.read_pos { state.full = true; } - state.receiver_waker.take().map(|w| w.wake()); + if let Some(w) = state.receiver_waker.take() { + w.wake(); + } Ok(()) } else { Err(TrySendError::Full(message)) @@ -535,7 +541,9 @@ where fn close(&mut self) { let state = &mut self.state; self.mutex.lock(|_| { - state.receiver_waker.take().map(|w| w.wake()); + if let Some(w) = state.receiver_waker.take() { + w.wake(); + } state.closing = true; }); } @@ -558,7 +566,9 @@ where self.mutex.lock(|_| { if state.receiver_registered { state.closed = true; - state.senders_waker.take().map(|w| w.wake()); + if let Some(w) = state.senders_waker.take() { + w.wake(); + } } state.receiver_registered = false; }) @@ -577,7 +587,9 @@ where assert!(state.senders_registered > 0); state.senders_registered = state.senders_registered - 1; if state.senders_registered == 0 { - state.receiver_waker.take().map(|w| w.wake()); + if let Some(w) = state.receiver_waker.take() { + w.wake(); + } state.closing = true; } }) From 1b49acc2f734201d8bc8d7b59e692817f4f64dea Mon Sep 17 00:00:00 2001 From: huntc Date: Wed, 7 Jul 2021 00:12:58 +1000 Subject: [PATCH 04/23] Fixed some clippy warnings --- embassy/src/util/mpsc.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/embassy/src/util/mpsc.rs b/embassy/src/util/mpsc.rs index f09bcedbc..5f2dd729a 100644 --- a/embassy/src/util/mpsc.rs +++ b/embassy/src/util/mpsc.rs @@ -98,9 +98,7 @@ unsafe impl<'ch, T> Sync for Receiver<'ch, T> {} /// mpsc::split(&channel) /// }; /// ``` -pub fn split<'ch, T>( - channel: &'ch UnsafeCell>, -) -> (Sender<'ch, T>, Receiver<'ch, T>) { +pub fn split(channel: &UnsafeCell>) -> (Sender, Receiver) { let sender = Sender { channel: &channel }; let receiver = Receiver { channel: &channel }; { @@ -296,6 +294,7 @@ impl<'ch, T> Drop for Sender<'ch, T> { } impl<'ch, T> Clone for Sender<'ch, T> { + #[allow(clippy::clone_double_ref)] fn clone(&self) -> Self { unsafe { &mut *self.channel.get() }.register_sender(); Sender { @@ -577,7 +576,7 @@ where fn register_sender(&mut self) { let state = &mut self.state; self.mutex.lock(|_| { - state.senders_registered = state.senders_registered + 1; + state.senders_registered += 1; }) } @@ -585,7 +584,7 @@ where let state = &mut self.state; self.mutex.lock(|_| { assert!(state.senders_registered > 0); - state.senders_registered = state.senders_registered - 1; + state.senders_registered -= 1; if state.senders_registered == 0 { if let Some(w) = state.receiver_waker.take() { w.wake(); From ae62948d6c21bc1ac4af50c3e39888c52d696b24 Mon Sep 17 00:00:00 2001 From: huntc Date: Wed, 7 Jul 2021 18:08:36 +1000 Subject: [PATCH 05/23] Replace UnsafeCell Using a new ChannelCell so that there's no leaking of the abstraction --- embassy/src/util/mpsc.rs | 84 +++++++++++++++++++++++------------- examples/nrf/src/bin/mpsc.rs | 8 ++-- 2 files changed, 56 insertions(+), 36 deletions(-) diff --git a/embassy/src/util/mpsc.rs b/embassy/src/util/mpsc.rs index 5f2dd729a..aabced416 100644 --- a/embassy/src/util/mpsc.rs +++ b/embassy/src/util/mpsc.rs @@ -51,11 +51,36 @@ use super::CriticalSectionMutex; use super::Mutex; use super::ThreadModeMutex; +/// A ChannelCell permits a channel to be shared between senders and their receivers. +// Derived from UnsafeCell. +#[repr(transparent)] +pub struct ChannelCell { + _value: T, +} + +impl ChannelCell { + #[inline(always)] + pub const fn new(value: T) -> ChannelCell { + ChannelCell { _value: value } + } +} + +impl ChannelCell { + #[inline(always)] + const fn get(&self) -> *mut T { + // As per UnsafeCell: + // We can just cast the pointer from `ChannelCell` to `T` because of + // #[repr(transparent)]. This exploits libstd's special status, there is + // no guarantee for user code that this will work in future versions of the compiler! + self as *const ChannelCell as *const T as *mut T + } +} + /// Send values to the associated `Receiver`. /// /// Instances are created by the [`split`](split) function. pub struct Sender<'ch, T> { - channel: &'ch UnsafeCell>, + channel: &'ch ChannelCell>, } // Safe to pass the sender around @@ -66,7 +91,7 @@ unsafe impl<'ch, T> Sync for Sender<'ch, T> {} /// /// Instances are created by the [`split`](split) function. pub struct Receiver<'ch, T> { - channel: &'ch UnsafeCell>, + channel: &'ch ChannelCell>, } // Safe to pass the receiver around @@ -89,16 +114,15 @@ unsafe impl<'ch, T> Sync for Receiver<'ch, T> {} /// their channel. The following will therefore fail compilation: //// /// ```compile_fail -/// use core::cell::UnsafeCell; /// use embassy::util::mpsc; -/// use embassy::util::mpsc::{Channel, WithThreadModeOnly}; +/// use embassy::util::mpsc::{Channel, ChannelCell, WithThreadModeOnly}; /// /// let (sender, receiver) = { -/// let mut channel = UnsafeCell::new(Channel::::with_thread_mode_only()); +/// let mut channel = ChannelCell::new(Channel::::with_thread_mode_only()); /// mpsc::split(&channel) /// }; /// ``` -pub fn split(channel: &UnsafeCell>) -> (Sender, Receiver) { +pub fn split(channel: &ChannelCell>) -> (Sender, Receiver) { let sender = Sender { channel: &channel }; let receiver = Receiver { channel: &channel }; { @@ -439,12 +463,11 @@ impl Channel { /// from exception mode e.g. interrupt handlers. To create one: /// /// ``` - /// use core::cell::UnsafeCell; /// use embassy::util::mpsc; - /// use embassy::util::mpsc::{Channel, WithCriticalSections}; + /// use embassy::util::mpsc::{Channel, ChannelCell, WithCriticalSections}; /// /// // Declare a bounded channel of 3 u32s. - /// let mut channel = UnsafeCell::new(mpsc::Channel::::with_critical_sections()); + /// let mut channel = ChannelCell::new(mpsc::Channel::::with_critical_sections()); /// // once we have a channel, obtain its sender and receiver /// let (sender, receiver) = mpsc::split(&channel); /// ``` @@ -464,12 +487,11 @@ impl Channel { /// channel avoids all locks. To create one: /// /// ``` no_run - /// use core::cell::UnsafeCell; /// use embassy::util::mpsc; - /// use embassy::util::mpsc::{Channel, WithThreadModeOnly}; + /// use embassy::util::mpsc::{Channel, ChannelCell, WithThreadModeOnly}; /// /// // Declare a bounded channel of 3 u32s. - /// let mut channel = UnsafeCell::new(Channel::::with_thread_mode_only()); + /// let mut channel = ChannelCell::new(Channel::::with_thread_mode_only()); /// // once we have a channel, obtain its sender and receiver /// let (sender, receiver) = mpsc::split(&channel); /// ``` @@ -744,7 +766,7 @@ mod tests { #[test] fn simple_send_and_receive() { - let c = UnsafeCell::new(Channel::::with_no_threads()); + let c = ChannelCell::new(Channel::::with_no_threads()); let (s, r) = split(&c); assert!(s.clone().try_send(1).is_ok()); assert_eq!(r.try_recv().unwrap(), 1); @@ -752,7 +774,7 @@ mod tests { #[test] fn should_close_without_sender() { - let c = UnsafeCell::new(Channel::::with_no_threads()); + let c = ChannelCell::new(Channel::::with_no_threads()); let (s, r) = split(&c); drop(s); match r.try_recv() { @@ -763,7 +785,7 @@ mod tests { #[test] fn should_close_once_drained() { - let c = UnsafeCell::new(Channel::::with_no_threads()); + let c = ChannelCell::new(Channel::::with_no_threads()); let (s, r) = split(&c); assert!(s.try_send(1).is_ok()); drop(s); @@ -776,7 +798,7 @@ mod tests { #[test] fn should_reject_send_when_receiver_dropped() { - let c = UnsafeCell::new(Channel::::with_no_threads()); + let c = ChannelCell::new(Channel::::with_no_threads()); let (s, r) = split(&c); drop(r); match s.try_send(1) { @@ -787,7 +809,7 @@ mod tests { #[test] fn should_reject_send_when_channel_closed() { - let c = UnsafeCell::new(Channel::::with_no_threads()); + let c = ChannelCell::new(Channel::::with_no_threads()); let (s, mut r) = split(&c); assert!(s.try_send(1).is_ok()); r.close(); @@ -803,8 +825,8 @@ mod tests { async fn receiver_closes_when_sender_dropped_async() { let executor = ThreadPool::new().unwrap(); - static mut CHANNEL: UnsafeCell> = - UnsafeCell::new(Channel::with_critical_sections()); + static mut CHANNEL: ChannelCell> = + ChannelCell::new(Channel::with_critical_sections()); let (s, mut r) = split(unsafe { &CHANNEL }); assert!(executor .spawn(async move { @@ -818,8 +840,8 @@ mod tests { async fn receiver_receives_given_try_send_async() { let executor = ThreadPool::new().unwrap(); - static mut CHANNEL: UnsafeCell> = - UnsafeCell::new(Channel::with_critical_sections()); + static mut CHANNEL: ChannelCell> = + ChannelCell::new(Channel::with_critical_sections()); let (s, mut r) = split(unsafe { &CHANNEL }); assert!(executor .spawn(async move { @@ -831,8 +853,8 @@ mod tests { #[futures_test::test] async fn sender_send_completes_if_capacity() { - static mut CHANNEL: UnsafeCell> = - UnsafeCell::new(Channel::with_critical_sections()); + static mut CHANNEL: ChannelCell> = + ChannelCell::new(Channel::with_critical_sections()); let (s, mut r) = split(unsafe { &CHANNEL }); assert!(s.send(1).await.is_ok()); assert_eq!(r.recv().await, Some(1)); @@ -840,8 +862,8 @@ mod tests { #[futures_test::test] async fn sender_send_completes_if_closed() { - static mut CHANNEL: UnsafeCell> = - UnsafeCell::new(Channel::with_critical_sections()); + static mut CHANNEL: ChannelCell> = + ChannelCell::new(Channel::with_critical_sections()); let (s, r) = split(unsafe { &CHANNEL }); drop(r); match s.send(1).await { @@ -854,8 +876,8 @@ mod tests { async fn senders_sends_wait_until_capacity() { let executor = ThreadPool::new().unwrap(); - static mut CHANNEL: UnsafeCell> = - UnsafeCell::new(Channel::with_critical_sections()); + static mut CHANNEL: ChannelCell> = + ChannelCell::new(Channel::with_critical_sections()); let (s0, mut r) = split(unsafe { &CHANNEL }); assert!(s0.try_send(1).is_ok()); let s1 = s0.clone(); @@ -874,8 +896,8 @@ mod tests { #[futures_test::test] async fn sender_close_completes_if_closing() { - static mut CHANNEL: UnsafeCell> = - UnsafeCell::new(Channel::with_critical_sections()); + static mut CHANNEL: ChannelCell> = + ChannelCell::new(Channel::with_critical_sections()); let (s, mut r) = split(unsafe { &CHANNEL }); r.close(); s.closed().await; @@ -883,8 +905,8 @@ mod tests { #[futures_test::test] async fn sender_close_completes_if_closed() { - static mut CHANNEL: UnsafeCell> = - UnsafeCell::new(Channel::with_critical_sections()); + static mut CHANNEL: ChannelCell> = + ChannelCell::new(Channel::with_critical_sections()); let (s, r) = split(unsafe { &CHANNEL }); drop(r); s.closed().await; diff --git a/examples/nrf/src/bin/mpsc.rs b/examples/nrf/src/bin/mpsc.rs index d692abee2..eafa29e60 100644 --- a/examples/nrf/src/bin/mpsc.rs +++ b/examples/nrf/src/bin/mpsc.rs @@ -8,12 +8,10 @@ #[path = "../example_common.rs"] mod example_common; -use core::cell::UnsafeCell; - use defmt::panic; use embassy::executor::Spawner; use embassy::time::{Duration, Timer}; -use embassy::util::mpsc::TryRecvError; +use embassy::util::mpsc::{ChannelCell, TryRecvError}; use embassy::util::{mpsc, Forever}; use embassy_nrf::gpio::{Level, Output, OutputDrive}; use embassy_nrf::Peripherals; @@ -25,7 +23,7 @@ enum LedState { Off, } -static CHANNEL: Forever>> = Forever::new(); +static CHANNEL: Forever>> = Forever::new(); #[embassy::task(pool_size = 1)] async fn my_task(sender: Sender<'static, LedState>) { @@ -41,7 +39,7 @@ async fn my_task(sender: Sender<'static, LedState>) { async fn main(spawner: Spawner, p: Peripherals) { let mut led = Output::new(p.P0_13, Level::Low, OutputDrive::Standard); - let channel = CHANNEL.put(UnsafeCell::new(Channel::with_thread_mode_only())); + let channel = CHANNEL.put(ChannelCell::new(Channel::with_thread_mode_only())); let (sender, mut receiver) = mpsc::split(channel); spawner.spawn(my_task(sender)).unwrap(); From 3fdf61c819d4612c721d846faf9821170f5f75a3 Mon Sep 17 00:00:00 2001 From: huntc Date: Thu, 8 Jul 2021 11:09:02 +1000 Subject: [PATCH 06/23] Constraint the use of ChannelCell to just channels --- embassy/src/util/mpsc.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/embassy/src/util/mpsc.rs b/embassy/src/util/mpsc.rs index aabced416..354c3e95c 100644 --- a/embassy/src/util/mpsc.rs +++ b/embassy/src/util/mpsc.rs @@ -60,7 +60,10 @@ pub struct ChannelCell { impl ChannelCell { #[inline(always)] - pub const fn new(value: T) -> ChannelCell { + pub const fn new(value: T) -> ChannelCell + where + T: ChannelLike, + { ChannelCell { _value: value } } } From 56b3e927fe2c779c4bc6d556ff9fc836d2a4f2d4 Mon Sep 17 00:00:00 2001 From: huntc Date: Fri, 9 Jul 2021 10:25:50 +1000 Subject: [PATCH 07/23] ChannelState should be private --- embassy/src/util/mpsc.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/embassy/src/util/mpsc.rs b/embassy/src/util/mpsc.rs index 354c3e95c..fc8006c38 100644 --- a/embassy/src/util/mpsc.rs +++ b/embassy/src/util/mpsc.rs @@ -400,7 +400,7 @@ pub trait ChannelLike { fn set_senders_waker(&mut self, senders_waker: Waker); } -pub struct ChannelState { +struct ChannelState { buf: [MaybeUninit>; N], read_pos: usize, write_pos: usize, From 5f87c7808c9d896a2a2d5e064a58ed2ac23a4348 Mon Sep 17 00:00:00 2001 From: huntc Date: Fri, 9 Jul 2021 12:04:22 +1000 Subject: [PATCH 08/23] Remove the cell and trait At the expense of exposing the channel types again. We do this as we want to avoid using dyn traits given their overhead for embedded environments. --- embassy/src/util/mpsc.rs | 169 ++++++++++++++++++----------------- examples/nrf/src/bin/mpsc.rs | 8 +- 2 files changed, 90 insertions(+), 87 deletions(-) diff --git a/embassy/src/util/mpsc.rs b/embassy/src/util/mpsc.rs index fc8006c38..65e4bf7b7 100644 --- a/embassy/src/util/mpsc.rs +++ b/embassy/src/util/mpsc.rs @@ -51,55 +51,33 @@ use super::CriticalSectionMutex; use super::Mutex; use super::ThreadModeMutex; -/// A ChannelCell permits a channel to be shared between senders and their receivers. -// Derived from UnsafeCell. -#[repr(transparent)] -pub struct ChannelCell { - _value: T, -} - -impl ChannelCell { - #[inline(always)] - pub const fn new(value: T) -> ChannelCell - where - T: ChannelLike, - { - ChannelCell { _value: value } - } -} - -impl ChannelCell { - #[inline(always)] - const fn get(&self) -> *mut T { - // As per UnsafeCell: - // We can just cast the pointer from `ChannelCell` to `T` because of - // #[repr(transparent)]. This exploits libstd's special status, there is - // no guarantee for user code that this will work in future versions of the compiler! - self as *const ChannelCell as *const T as *mut T - } -} - /// Send values to the associated `Receiver`. /// /// Instances are created by the [`split`](split) function. -pub struct Sender<'ch, T> { - channel: &'ch ChannelCell>, +pub struct Sender<'ch, M, T, const N: usize> +where + M: Mutex, +{ + channel: &'ch Channel, } // Safe to pass the sender around -unsafe impl<'ch, T> Send for Sender<'ch, T> {} -unsafe impl<'ch, T> Sync for Sender<'ch, T> {} +unsafe impl<'ch, M, T, const N: usize> Send for Sender<'ch, M, T, N> where M: Mutex {} +unsafe impl<'ch, M, T, const N: usize> Sync for Sender<'ch, M, T, N> where M: Mutex {} /// Receive values from the associated `Sender`. /// /// Instances are created by the [`split`](split) function. -pub struct Receiver<'ch, T> { - channel: &'ch ChannelCell>, +pub struct Receiver<'ch, M, T, const N: usize> +where + M: Mutex, +{ + channel: &'ch Channel, } // Safe to pass the receiver around -unsafe impl<'ch, T> Send for Receiver<'ch, T> {} -unsafe impl<'ch, T> Sync for Receiver<'ch, T> {} +unsafe impl<'ch, M, T, const N: usize> Send for Receiver<'ch, M, T, N> where M: Mutex {} +unsafe impl<'ch, M, T, const N: usize> Sync for Receiver<'ch, M, T, N> where M: Mutex {} /// Splits a bounded mpsc channel into a `Sender` and `Receiver`. /// @@ -125,18 +103,26 @@ unsafe impl<'ch, T> Sync for Receiver<'ch, T> {} /// mpsc::split(&channel) /// }; /// ``` -pub fn split(channel: &ChannelCell>) -> (Sender, Receiver) { +pub fn split( + channel: &Channel, +) -> (Sender, Receiver) +where + M: Mutex, +{ let sender = Sender { channel: &channel }; let receiver = Receiver { channel: &channel }; { - let c = unsafe { &mut *channel.get() }; + let c = channel.get(); c.register_receiver(); c.register_sender(); } (sender, receiver) } -impl<'ch, T> Receiver<'ch, T> { +impl<'ch, M, T, const N: usize> Receiver<'ch, M, T, N> +where + M: Mutex, +{ /// Receives the next value for this receiver. /// /// This method returns `None` if the channel has been closed and there are @@ -162,7 +148,7 @@ impl<'ch, T> Receiver<'ch, T> { /// This method will either receive a message from the channel immediately or return an error /// if the channel is empty. pub fn try_recv(&self) -> Result { - unsafe { &mut *self.channel.get() }.try_recv() + self.channel.get().try_recv() } /// Closes the receiving half of a channel without dropping it. @@ -176,11 +162,14 @@ impl<'ch, T> Receiver<'ch, T> { /// until those are released. /// pub fn close(&mut self) { - unsafe { &mut *self.channel.get() }.close() + self.channel.get().close() } } -impl<'ch, T> Future for Receiver<'ch, T> { +impl<'ch, M, T, const N: usize> Future for Receiver<'ch, M, T, N> +where + M: Mutex, +{ type Output = Option; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -188,20 +177,26 @@ impl<'ch, T> Future for Receiver<'ch, T> { Ok(v) => Poll::Ready(Some(v)), Err(TryRecvError::Closed) => Poll::Ready(None), Err(TryRecvError::Empty) => { - unsafe { &mut *self.channel.get() }.set_receiver_waker(cx.waker().clone()); + self.channel.get().set_receiver_waker(cx.waker().clone()); Poll::Pending } } } } -impl<'ch, T> Drop for Receiver<'ch, T> { +impl<'ch, M, T, const N: usize> Drop for Receiver<'ch, M, T, N> +where + M: Mutex, +{ fn drop(&mut self) { - unsafe { &mut *self.channel.get() }.deregister_receiver() + self.channel.get().deregister_receiver() } } -impl<'ch, T> Sender<'ch, T> { +impl<'ch, M, T, const N: usize> Sender<'ch, M, T, N> +where + M: Mutex, +{ /// Sends a value, waiting until there is capacity. /// /// A successful send occurs when it is determined that the other end of the @@ -249,7 +244,7 @@ impl<'ch, T> Sender<'ch, T> { /// [`channel`]: channel /// [`close`]: Receiver::close pub fn try_send(&self, message: T) -> Result<(), TrySendError> { - unsafe { &mut *self.channel.get() }.try_send(message) + self.channel.get().try_send(message) } /// Completes when the receiver has dropped. @@ -270,16 +265,22 @@ impl<'ch, T> Sender<'ch, T> { /// [`Receiver`]: crate::sync::mpsc::Receiver /// [`Receiver::close`]: crate::sync::mpsc::Receiver::close pub fn is_closed(&self) -> bool { - unsafe { &mut *self.channel.get() }.is_closed() + self.channel.get().is_closed() } } -struct SendFuture<'ch, T> { - sender: Sender<'ch, T>, +struct SendFuture<'ch, M, T, const N: usize> +where + M: Mutex, +{ + sender: Sender<'ch, M, T, N>, message: UnsafeCell, } -impl<'ch, T> Future for SendFuture<'ch, T> { +impl<'ch, M, T, const N: usize> Future for SendFuture<'ch, M, T, N> +where + M: Mutex, +{ type Output = Result<(), SendError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -287,7 +288,10 @@ impl<'ch, T> Future for SendFuture<'ch, T> { Ok(..) => Poll::Ready(Ok(())), Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))), Err(TrySendError::Full(..)) => { - unsafe { &mut *self.sender.channel.get() }.set_senders_waker(cx.waker().clone()); + self.sender + .channel + .get() + .set_senders_waker(cx.waker().clone()); Poll::Pending // Note we leave the existing UnsafeCell contents - they still // contain the original message. We could create another UnsafeCell @@ -297,33 +301,48 @@ impl<'ch, T> Future for SendFuture<'ch, T> { } } -struct CloseFuture<'ch, T> { - sender: Sender<'ch, T>, +struct CloseFuture<'ch, M, T, const N: usize> +where + M: Mutex, +{ + sender: Sender<'ch, M, T, N>, } -impl<'ch, T> Future for CloseFuture<'ch, T> { +impl<'ch, M, T, const N: usize> Future for CloseFuture<'ch, M, T, N> +where + M: Mutex, +{ type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { if self.sender.is_closed() { Poll::Ready(()) } else { - unsafe { &mut *self.sender.channel.get() }.set_senders_waker(cx.waker().clone()); + self.sender + .channel + .get() + .set_senders_waker(cx.waker().clone()); Poll::Pending } } } -impl<'ch, T> Drop for Sender<'ch, T> { +impl<'ch, M, T, const N: usize> Drop for Sender<'ch, M, T, N> +where + M: Mutex, +{ fn drop(&mut self) { - unsafe { &mut *self.channel.get() }.deregister_sender() + self.channel.get().deregister_sender() } } -impl<'ch, T> Clone for Sender<'ch, T> { +impl<'ch, M, T, const N: usize> Clone for Sender<'ch, M, T, N> +where + M: Mutex, +{ #[allow(clippy::clone_double_ref)] fn clone(&self) -> Self { - unsafe { &mut *self.channel.get() }.register_sender(); + self.channel.get().register_sender(); Sender { channel: self.channel.clone(), } @@ -378,28 +397,6 @@ impl fmt::Display for TrySendError { } } -pub trait ChannelLike { - fn try_recv(&mut self) -> Result; - - fn try_send(&mut self, message: T) -> Result<(), TrySendError>; - - fn close(&mut self); - - fn is_closed(&mut self) -> bool; - - fn register_receiver(&mut self); - - fn deregister_receiver(&mut self); - - fn register_sender(&mut self); - - fn deregister_sender(&mut self); - - fn set_receiver_waker(&mut self, receiver_waker: Waker); - - fn set_senders_waker(&mut self, senders_waker: Waker); -} - struct ChannelState { buf: [MaybeUninit>; N], read_pos: usize, @@ -505,10 +502,16 @@ impl Channel { } } -impl ChannelLike for Channel +impl Channel where M: Mutex, { + fn get(&self) -> &mut Self { + let const_ptr = self as *const Self; + let mut_ptr = const_ptr as *mut Self; + unsafe { &mut *mut_ptr } + } + fn try_recv(&mut self) -> Result { let state = &mut self.state; self.mutex.lock(|_| { diff --git a/examples/nrf/src/bin/mpsc.rs b/examples/nrf/src/bin/mpsc.rs index eafa29e60..6a0f8f471 100644 --- a/examples/nrf/src/bin/mpsc.rs +++ b/examples/nrf/src/bin/mpsc.rs @@ -11,7 +11,7 @@ mod example_common; use defmt::panic; use embassy::executor::Spawner; use embassy::time::{Duration, Timer}; -use embassy::util::mpsc::{ChannelCell, TryRecvError}; +use embassy::util::mpsc::TryRecvError; use embassy::util::{mpsc, Forever}; use embassy_nrf::gpio::{Level, Output, OutputDrive}; use embassy_nrf::Peripherals; @@ -23,10 +23,10 @@ enum LedState { Off, } -static CHANNEL: Forever>> = Forever::new(); +static CHANNEL: Forever> = Forever::new(); #[embassy::task(pool_size = 1)] -async fn my_task(sender: Sender<'static, LedState>) { +async fn my_task(sender: Sender<'static, WithThreadModeOnly, LedState, 1>) { loop { let _ = sender.send(LedState::On).await; Timer::after(Duration::from_secs(1)).await; @@ -39,7 +39,7 @@ async fn my_task(sender: Sender<'static, LedState>) { async fn main(spawner: Spawner, p: Peripherals) { let mut led = Output::new(p.P0_13, Level::Low, OutputDrive::Standard); - let channel = CHANNEL.put(ChannelCell::new(Channel::with_thread_mode_only())); + let channel = CHANNEL.put(Channel::with_thread_mode_only()); let (sender, mut receiver) = mpsc::split(channel); spawner.spawn(my_task(sender)).unwrap(); From f159beec1cbd1406f63ca7c3e84a1d598bbadaa1 Mon Sep 17 00:00:00 2001 From: huntc Date: Fri, 9 Jul 2021 12:13:07 +1000 Subject: [PATCH 09/23] Use of a NoopMutex --- embassy/src/util/mpsc.rs | 109 +++++++++++++++-------------------- embassy/src/util/mutex.rs | 27 +++++++++ examples/nrf/src/bin/mpsc.rs | 8 +-- 3 files changed, 78 insertions(+), 66 deletions(-) diff --git a/embassy/src/util/mpsc.rs b/embassy/src/util/mpsc.rs index 65e4bf7b7..e54c507c1 100644 --- a/embassy/src/util/mpsc.rs +++ b/embassy/src/util/mpsc.rs @@ -49,6 +49,7 @@ use futures::Future; use super::CriticalSectionMutex; use super::Mutex; +use super::NoopMutex; use super::ThreadModeMutex; /// Send values to the associated `Receiver`. @@ -96,10 +97,10 @@ unsafe impl<'ch, M, T, const N: usize> Sync for Receiver<'ch, M, T, N> where M: //// /// ```compile_fail /// use embassy::util::mpsc; -/// use embassy::util::mpsc::{Channel, ChannelCell, WithThreadModeOnly}; +/// use embassy::util::mpsc::{Channel, WithThreadModeOnly}; /// /// let (sender, receiver) = { -/// let mut channel = ChannelCell::new(Channel::::with_thread_mode_only()); +/// let mut channel = Channel::::with_thread_mode_only(); /// mpsc::split(&channel) /// }; /// ``` @@ -464,10 +465,10 @@ impl Channel { /// /// ``` /// use embassy::util::mpsc; - /// use embassy::util::mpsc::{Channel, ChannelCell, WithCriticalSections}; + /// use embassy::util::mpsc::{Channel, WithCriticalSections}; /// /// // Declare a bounded channel of 3 u32s. - /// let mut channel = ChannelCell::new(mpsc::Channel::::with_critical_sections()); + /// let mut channel = Channel::::with_critical_sections(); /// // once we have a channel, obtain its sender and receiver /// let (sender, receiver) = mpsc::split(&channel); /// ``` @@ -488,10 +489,10 @@ impl Channel { /// /// ``` no_run /// use embassy::util::mpsc; - /// use embassy::util::mpsc::{Channel, ChannelCell, WithThreadModeOnly}; + /// use embassy::util::mpsc::{Channel, WithThreadModeOnly}; /// /// // Declare a bounded channel of 3 u32s. - /// let mut channel = ChannelCell::new(Channel::::with_thread_mode_only()); + /// let mut channel = Channel::::with_thread_mode_only(); /// // once we have a channel, obtain its sender and receiver /// let (sender, receiver) = mpsc::split(&channel); /// ``` @@ -502,6 +503,27 @@ impl Channel { } } +pub type WithNoThreads = NoopMutex<()>; + +impl Channel { + /// Establish a new bounded channel for within a single thread. To create one: + /// + /// ``` + /// use embassy::util::mpsc; + /// use embassy::util::mpsc::{Channel, WithNoThreads}; + /// + /// // Declare a bounded channel of 3 u32s. + /// let mut channel = Channel::::with_no_threads(); + /// // once we have a channel, obtain its sender and receiver + /// let (sender, receiver) = mpsc::split(&channel); + /// ``` + pub const fn with_no_threads() -> Self { + let mutex = NoopMutex::new(()); + let state = ChannelState::new(); + Channel { mutex, state } + } +} + impl Channel where M: Mutex, @@ -675,43 +697,6 @@ mod tests { } } - /// A mutex that does nothing - useful for our testing purposes - pub struct NoopMutex { - inner: UnsafeCell, - } - - impl NoopMutex { - pub const fn new(value: T) -> Self { - NoopMutex { - inner: UnsafeCell::new(value), - } - } - } - - impl NoopMutex { - pub fn borrow(&self) -> &T { - unsafe { &*self.inner.get() } - } - } - - impl Mutex for NoopMutex { - type Data = T; - - fn lock(&mut self, f: impl FnOnce(&Self::Data) -> R) -> R { - f(self.borrow()) - } - } - - pub type WithNoThreads = NoopMutex<()>; - - impl Channel { - pub const fn with_no_threads() -> Self { - let mutex = NoopMutex::new(()); - let state = ChannelState::new(); - Channel { mutex, state } - } - } - #[test] fn sending_once() { let mut c = Channel::::with_no_threads(); @@ -772,7 +757,7 @@ mod tests { #[test] fn simple_send_and_receive() { - let c = ChannelCell::new(Channel::::with_no_threads()); + let c = Channel::::with_no_threads(); let (s, r) = split(&c); assert!(s.clone().try_send(1).is_ok()); assert_eq!(r.try_recv().unwrap(), 1); @@ -780,7 +765,7 @@ mod tests { #[test] fn should_close_without_sender() { - let c = ChannelCell::new(Channel::::with_no_threads()); + let c = Channel::::with_no_threads(); let (s, r) = split(&c); drop(s); match r.try_recv() { @@ -791,7 +776,7 @@ mod tests { #[test] fn should_close_once_drained() { - let c = ChannelCell::new(Channel::::with_no_threads()); + let c = Channel::::with_no_threads(); let (s, r) = split(&c); assert!(s.try_send(1).is_ok()); drop(s); @@ -804,7 +789,7 @@ mod tests { #[test] fn should_reject_send_when_receiver_dropped() { - let c = ChannelCell::new(Channel::::with_no_threads()); + let c = Channel::::with_no_threads(); let (s, r) = split(&c); drop(r); match s.try_send(1) { @@ -815,7 +800,7 @@ mod tests { #[test] fn should_reject_send_when_channel_closed() { - let c = ChannelCell::new(Channel::::with_no_threads()); + let c = Channel::::with_no_threads(); let (s, mut r) = split(&c); assert!(s.try_send(1).is_ok()); r.close(); @@ -831,8 +816,8 @@ mod tests { async fn receiver_closes_when_sender_dropped_async() { let executor = ThreadPool::new().unwrap(); - static mut CHANNEL: ChannelCell> = - ChannelCell::new(Channel::with_critical_sections()); + static mut CHANNEL: Channel = + Channel::with_critical_sections(); let (s, mut r) = split(unsafe { &CHANNEL }); assert!(executor .spawn(async move { @@ -846,8 +831,8 @@ mod tests { async fn receiver_receives_given_try_send_async() { let executor = ThreadPool::new().unwrap(); - static mut CHANNEL: ChannelCell> = - ChannelCell::new(Channel::with_critical_sections()); + static mut CHANNEL: Channel = + Channel::with_critical_sections(); let (s, mut r) = split(unsafe { &CHANNEL }); assert!(executor .spawn(async move { @@ -859,8 +844,8 @@ mod tests { #[futures_test::test] async fn sender_send_completes_if_capacity() { - static mut CHANNEL: ChannelCell> = - ChannelCell::new(Channel::with_critical_sections()); + static mut CHANNEL: Channel = + Channel::with_critical_sections(); let (s, mut r) = split(unsafe { &CHANNEL }); assert!(s.send(1).await.is_ok()); assert_eq!(r.recv().await, Some(1)); @@ -868,8 +853,8 @@ mod tests { #[futures_test::test] async fn sender_send_completes_if_closed() { - static mut CHANNEL: ChannelCell> = - ChannelCell::new(Channel::with_critical_sections()); + static mut CHANNEL: Channel = + Channel::with_critical_sections(); let (s, r) = split(unsafe { &CHANNEL }); drop(r); match s.send(1).await { @@ -882,8 +867,8 @@ mod tests { async fn senders_sends_wait_until_capacity() { let executor = ThreadPool::new().unwrap(); - static mut CHANNEL: ChannelCell> = - ChannelCell::new(Channel::with_critical_sections()); + static mut CHANNEL: Channel = + Channel::with_critical_sections(); let (s0, mut r) = split(unsafe { &CHANNEL }); assert!(s0.try_send(1).is_ok()); let s1 = s0.clone(); @@ -902,8 +887,8 @@ mod tests { #[futures_test::test] async fn sender_close_completes_if_closing() { - static mut CHANNEL: ChannelCell> = - ChannelCell::new(Channel::with_critical_sections()); + static mut CHANNEL: Channel = + Channel::with_critical_sections(); let (s, mut r) = split(unsafe { &CHANNEL }); r.close(); s.closed().await; @@ -911,8 +896,8 @@ mod tests { #[futures_test::test] async fn sender_close_completes_if_closed() { - static mut CHANNEL: ChannelCell> = - ChannelCell::new(Channel::with_critical_sections()); + static mut CHANNEL: Channel = + Channel::with_critical_sections(); let (s, r) = split(unsafe { &CHANNEL }); drop(r); s.closed().await; diff --git a/embassy/src/util/mutex.rs b/embassy/src/util/mutex.rs index 682fcb39d..db3423cb3 100644 --- a/embassy/src/util/mutex.rs +++ b/embassy/src/util/mutex.rs @@ -105,3 +105,30 @@ pub fn in_thread_mode() -> bool { return cortex_m::peripheral::SCB::vect_active() == cortex_m::peripheral::scb::VectActive::ThreadMode; } + +/// A "mutex" that does nothing and cannot be shared between threads. +pub struct NoopMutex { + inner: UnsafeCell, +} + +impl NoopMutex { + pub const fn new(value: T) -> Self { + NoopMutex { + inner: UnsafeCell::new(value), + } + } +} + +impl NoopMutex { + pub fn borrow(&self) -> &T { + unsafe { &*self.inner.get() } + } +} + +impl Mutex for NoopMutex { + type Data = T; + + fn lock(&mut self, f: impl FnOnce(&Self::Data) -> R) -> R { + f(self.borrow()) + } +} diff --git a/examples/nrf/src/bin/mpsc.rs b/examples/nrf/src/bin/mpsc.rs index 6a0f8f471..c2cb107e1 100644 --- a/examples/nrf/src/bin/mpsc.rs +++ b/examples/nrf/src/bin/mpsc.rs @@ -16,17 +16,17 @@ use embassy::util::{mpsc, Forever}; use embassy_nrf::gpio::{Level, Output, OutputDrive}; use embassy_nrf::Peripherals; use embedded_hal::digital::v2::OutputPin; -use mpsc::{Channel, Sender, WithThreadModeOnly}; +use mpsc::{Channel, Sender, WithNoThreads}; enum LedState { On, Off, } -static CHANNEL: Forever> = Forever::new(); +static CHANNEL: Forever> = Forever::new(); #[embassy::task(pool_size = 1)] -async fn my_task(sender: Sender<'static, WithThreadModeOnly, LedState, 1>) { +async fn my_task(sender: Sender<'static, WithNoThreads, LedState, 1>) { loop { let _ = sender.send(LedState::On).await; Timer::after(Duration::from_secs(1)).await; @@ -39,7 +39,7 @@ async fn my_task(sender: Sender<'static, WithThreadModeOnly, LedState, 1>) { async fn main(spawner: Spawner, p: Peripherals) { let mut led = Output::new(p.P0_13, Level::Low, OutputDrive::Standard); - let channel = CHANNEL.put(Channel::with_thread_mode_only()); + let channel = CHANNEL.put(Channel::with_no_threads()); let (sender, mut receiver) = mpsc::split(channel); spawner.spawn(my_task(sender)).unwrap(); From dcd0c38109ed6711d91c4bdff42825f25e3ee402 Mon Sep 17 00:00:00 2001 From: huntc Date: Sun, 11 Jul 2021 10:54:35 +1000 Subject: [PATCH 10/23] Return a new future each time recv is called --- embassy/src/util/mpsc.rs | 31 ++++++++++++------------------- 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/embassy/src/util/mpsc.rs b/embassy/src/util/mpsc.rs index e54c507c1..8f1bba764 100644 --- a/embassy/src/util/mpsc.rs +++ b/embassy/src/util/mpsc.rs @@ -141,7 +141,18 @@ where /// /// [`close`]: Self::close pub async fn recv(&mut self) -> Option { - self.await + futures::future::poll_fn(|cx| self.recv_poll(cx)).await + } + + fn recv_poll(self: &mut Self, cx: &mut Context<'_>) -> Poll> { + match self.try_recv() { + Ok(v) => Poll::Ready(Some(v)), + Err(TryRecvError::Closed) => Poll::Ready(None), + Err(TryRecvError::Empty) => { + self.channel.get().set_receiver_waker(cx.waker().clone()); + Poll::Pending + } + } } /// Attempts to immediately receive a message on this `Receiver` @@ -167,24 +178,6 @@ where } } -impl<'ch, M, T, const N: usize> Future for Receiver<'ch, M, T, N> -where - M: Mutex, -{ - type Output = Option; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.try_recv() { - Ok(v) => Poll::Ready(Some(v)), - Err(TryRecvError::Closed) => Poll::Ready(None), - Err(TryRecvError::Empty) => { - self.channel.get().set_receiver_waker(cx.waker().clone()); - Poll::Pending - } - } - } -} - impl<'ch, M, T, const N: usize> Drop for Receiver<'ch, M, T, N> where M: Mutex, From 108cffcba02d5f84099991d670cddfb458e2c106 Mon Sep 17 00:00:00 2001 From: huntc Date: Sun, 11 Jul 2021 11:47:09 +1000 Subject: [PATCH 11/23] Migrated to the waker registration functionality for Embassy specific optimisations --- embassy/Cargo.toml | 1 + embassy/src/util/mpsc.rs | 61 +++++++++++++--------------------------- 2 files changed, 20 insertions(+), 42 deletions(-) diff --git a/embassy/Cargo.toml b/embassy/Cargo.toml index d26490247..c03fc0df5 100644 --- a/embassy/Cargo.toml +++ b/embassy/Cargo.toml @@ -39,6 +39,7 @@ embedded-hal = "0.2.5" cast = { version = "=0.2.3", default-features = false } [dev-dependencies] +embassy = { path = ".", features = ["executor-agnostic"] } futures-executor = { version = "0.3", features = [ "thread-pool" ] } futures-test = "0.3" futures-timer = "0.3" diff --git a/embassy/src/util/mpsc.rs b/embassy/src/util/mpsc.rs index 8f1bba764..580c6794f 100644 --- a/embassy/src/util/mpsc.rs +++ b/embassy/src/util/mpsc.rs @@ -51,6 +51,7 @@ use super::CriticalSectionMutex; use super::Mutex; use super::NoopMutex; use super::ThreadModeMutex; +use super::WakerRegistration; /// Send values to the associated `Receiver`. /// @@ -149,7 +150,7 @@ where Ok(v) => Poll::Ready(Some(v)), Err(TryRecvError::Closed) => Poll::Ready(None), Err(TryRecvError::Empty) => { - self.channel.get().set_receiver_waker(cx.waker().clone()); + self.channel.get().set_receiver_waker(&cx.waker()); Poll::Pending } } @@ -282,10 +283,7 @@ where Ok(..) => Poll::Ready(Ok(())), Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))), Err(TrySendError::Full(..)) => { - self.sender - .channel - .get() - .set_senders_waker(cx.waker().clone()); + self.sender.channel.get().set_senders_waker(&cx.waker()); Poll::Pending // Note we leave the existing UnsafeCell contents - they still // contain the original message. We could create another UnsafeCell @@ -312,10 +310,7 @@ where if self.sender.is_closed() { Poll::Ready(()) } else { - self.sender - .channel - .get() - .set_senders_waker(cx.waker().clone()); + self.sender.channel.get().set_senders_waker(&cx.waker()); Poll::Pending } } @@ -400,8 +395,8 @@ struct ChannelState { closed: bool, receiver_registered: bool, senders_registered: u32, - receiver_waker: Option, - senders_waker: Option, + receiver_waker: WakerRegistration, + senders_waker: WakerRegistration, } impl ChannelState { @@ -416,8 +411,8 @@ impl ChannelState { let closed = false; let receiver_registered = false; let senders_registered = 0; - let receiver_waker = None; - let senders_waker = None; + let receiver_waker = WakerRegistration::new(); + let senders_waker = WakerRegistration::new(); ChannelState { buf, read_pos, @@ -534,9 +529,7 @@ where if state.read_pos != state.write_pos || state.full { if state.full { state.full = false; - if let Some(w) = state.senders_waker.take() { - w.wake(); - } + state.senders_waker.wake(); } let message = unsafe { (state.buf[state.read_pos]).assume_init_mut().get().read() }; @@ -546,9 +539,7 @@ where Err(TryRecvError::Empty) } else { state.closed = true; - if let Some(w) = state.senders_waker.take() { - w.wake(); - } + state.senders_waker.wake(); Err(TryRecvError::Closed) } } else { @@ -567,9 +558,7 @@ where if state.write_pos == state.read_pos { state.full = true; } - if let Some(w) = state.receiver_waker.take() { - w.wake(); - } + state.receiver_waker.wake(); Ok(()) } else { Err(TrySendError::Full(message)) @@ -583,9 +572,7 @@ where fn close(&mut self) { let state = &mut self.state; self.mutex.lock(|_| { - if let Some(w) = state.receiver_waker.take() { - w.wake(); - } + state.receiver_waker.wake(); state.closing = true; }); } @@ -608,9 +595,7 @@ where self.mutex.lock(|_| { if state.receiver_registered { state.closed = true; - if let Some(w) = state.senders_waker.take() { - w.wake(); - } + state.senders_waker.wake(); } state.receiver_registered = false; }) @@ -629,38 +614,30 @@ where assert!(state.senders_registered > 0); state.senders_registered -= 1; if state.senders_registered == 0 { - if let Some(w) = state.receiver_waker.take() { - w.wake(); - } + state.receiver_waker.wake(); state.closing = true; } }) } - fn set_receiver_waker(&mut self, receiver_waker: Waker) { + fn set_receiver_waker(&mut self, receiver_waker: &Waker) { let state = &mut self.state; self.mutex.lock(|_| { - state.receiver_waker = Some(receiver_waker); + state.receiver_waker.register(receiver_waker); }) } - fn set_senders_waker(&mut self, senders_waker: Waker) { + fn set_senders_waker(&mut self, senders_waker: &Waker) { let state = &mut self.state; self.mutex.lock(|_| { - // Dispose of any existing sender causing them to be polled again. // This could cause a spin given multiple concurrent senders, however given that // most sends only block waiting for the receiver to become active, this should // be a short-lived activity. The upside is a greatly simplified implementation // that avoids the need for intrusive linked-lists and unsafe operations on pinned // pointers. - if let Some(waker) = state.senders_waker.clone() { - if !senders_waker.will_wake(&waker) { - trace!("Waking an an active send waker due to being superseded with a new one. While benign, please report this."); - waker.wake(); - } - } - state.senders_waker = Some(senders_waker); + state.senders_waker.wake(); + state.senders_waker.register(senders_waker); }) } } From 9b5f2e465bf97300664263cedecd9b8d8034d435 Mon Sep 17 00:00:00 2001 From: huntc Date: Sun, 11 Jul 2021 11:49:51 +1000 Subject: [PATCH 12/23] Tidying --- embassy/src/util/mpsc.rs | 30 ++++++++++-------------------- 1 file changed, 10 insertions(+), 20 deletions(-) diff --git a/embassy/src/util/mpsc.rs b/embassy/src/util/mpsc.rs index 580c6794f..68fcdf7f9 100644 --- a/embassy/src/util/mpsc.rs +++ b/embassy/src/util/mpsc.rs @@ -403,27 +403,17 @@ impl ChannelState { const INIT: MaybeUninit> = MaybeUninit::uninit(); const fn new() -> Self { - let buf = [Self::INIT; N]; - let read_pos = 0; - let write_pos = 0; - let full = false; - let closing = false; - let closed = false; - let receiver_registered = false; - let senders_registered = 0; - let receiver_waker = WakerRegistration::new(); - let senders_waker = WakerRegistration::new(); ChannelState { - buf, - read_pos, - write_pos, - full, - closing, - closed, - receiver_registered, - senders_registered, - receiver_waker, - senders_waker, + buf: [Self::INIT; N], + read_pos: 0, + write_pos: 0, + full: false, + closing: false, + closed: false, + receiver_registered: false, + senders_registered: 0, + receiver_waker: WakerRegistration::new(), + senders_waker: WakerRegistration::new(), } } } From 5a5795ef2b25fe227978935a1d5b2efb3cf5b08c Mon Sep 17 00:00:00 2001 From: huntc Date: Sun, 11 Jul 2021 12:05:50 +1000 Subject: [PATCH 13/23] NoopMutex does not require an UnsafeCell --- embassy/src/util/mutex.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/embassy/src/util/mutex.rs b/embassy/src/util/mutex.rs index db3423cb3..c8fe84026 100644 --- a/embassy/src/util/mutex.rs +++ b/embassy/src/util/mutex.rs @@ -108,20 +108,18 @@ pub fn in_thread_mode() -> bool { /// A "mutex" that does nothing and cannot be shared between threads. pub struct NoopMutex { - inner: UnsafeCell, + inner: T, } impl NoopMutex { pub const fn new(value: T) -> Self { - NoopMutex { - inner: UnsafeCell::new(value), - } + NoopMutex { inner: value } } } impl NoopMutex { pub fn borrow(&self) -> &T { - unsafe { &*self.inner.get() } + &self.inner } } From baab52d40cb8d1ede339a3a422006108a86d8efb Mon Sep 17 00:00:00 2001 From: huntc Date: Sun, 11 Jul 2021 13:01:36 +1000 Subject: [PATCH 14/23] Avoid a race condition by reducing the locks to one --- embassy/src/util/mpsc.rs | 84 +++++++++++++++++++++++++--------------- 1 file changed, 53 insertions(+), 31 deletions(-) diff --git a/embassy/src/util/mpsc.rs b/embassy/src/util/mpsc.rs index 68fcdf7f9..8d534dc49 100644 --- a/embassy/src/util/mpsc.rs +++ b/embassy/src/util/mpsc.rs @@ -145,14 +145,11 @@ where futures::future::poll_fn(|cx| self.recv_poll(cx)).await } - fn recv_poll(self: &mut Self, cx: &mut Context<'_>) -> Poll> { - match self.try_recv() { + fn recv_poll(&mut self, cx: &mut Context<'_>) -> Poll> { + match self.channel.get().try_recv_with_context(Some(cx)) { Ok(v) => Poll::Ready(Some(v)), Err(TryRecvError::Closed) => Poll::Ready(None), - Err(TryRecvError::Empty) => { - self.channel.get().set_receiver_waker(&cx.waker()); - Poll::Pending - } + Err(TryRecvError::Empty) => Poll::Pending, } } @@ -279,11 +276,15 @@ where type Output = Result<(), SendError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.sender.try_send(unsafe { self.message.get().read() }) { + match self + .sender + .channel + .get() + .try_send_with_context(unsafe { self.message.get().read() }, Some(cx)) + { Ok(..) => Poll::Ready(Ok(())), Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))), Err(TrySendError::Full(..)) => { - self.sender.channel.get().set_senders_waker(&cx.waker()); Poll::Pending // Note we leave the existing UnsafeCell contents - they still // contain the original message. We could create another UnsafeCell @@ -307,10 +308,9 @@ where type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - if self.sender.is_closed() { + if self.sender.channel.get().is_closed_with_context(Some(cx)) { Poll::Ready(()) } else { - self.sender.channel.get().set_senders_waker(&cx.waker()); Poll::Pending } } @@ -513,7 +513,11 @@ where } fn try_recv(&mut self) -> Result { - let state = &mut self.state; + self.try_recv_with_context(None) + } + + fn try_recv_with_context(&mut self, cx: Option<&mut Context<'_>>) -> Result { + let mut state = &mut self.state; self.mutex.lock(|_| { if !state.closed { if state.read_pos != state.write_pos || state.full { @@ -526,6 +530,8 @@ where state.read_pos = (state.read_pos + 1) % state.buf.len(); Ok(message) } else if !state.closing { + cx.into_iter() + .for_each(|cx| Self::set_receiver_waker(&mut state, &cx.waker())); Err(TryRecvError::Empty) } else { state.closed = true; @@ -539,7 +545,15 @@ where } fn try_send(&mut self, message: T) -> Result<(), TrySendError> { - let state = &mut self.state; + self.try_send_with_context(message, None) + } + + fn try_send_with_context( + &mut self, + message: T, + cx: Option<&mut Context<'_>>, + ) -> Result<(), TrySendError> { + let mut state = &mut self.state; self.mutex.lock(|_| { if !state.closed { if !state.full { @@ -551,6 +565,8 @@ where state.receiver_waker.wake(); Ok(()) } else { + cx.into_iter() + .for_each(|cx| Self::set_senders_waker(&mut state, &cx.waker())); Err(TrySendError::Full(message)) } } else { @@ -568,8 +584,20 @@ where } fn is_closed(&mut self) -> bool { - let state = &self.state; - self.mutex.lock(|_| state.closing || state.closed) + self.is_closed_with_context(None) + } + + fn is_closed_with_context(&mut self, cx: Option<&mut Context<'_>>) -> bool { + let mut state = &mut self.state; + self.mutex.lock(|_| { + if state.closing || state.closed { + cx.into_iter() + .for_each(|cx| Self::set_senders_waker(&mut state, &cx.waker())); + true + } else { + false + } + }) } fn register_receiver(&mut self) { @@ -610,25 +638,19 @@ where }) } - fn set_receiver_waker(&mut self, receiver_waker: &Waker) { - let state = &mut self.state; - self.mutex.lock(|_| { - state.receiver_waker.register(receiver_waker); - }) + fn set_receiver_waker(state: &mut ChannelState, receiver_waker: &Waker) { + state.receiver_waker.register(receiver_waker); } - fn set_senders_waker(&mut self, senders_waker: &Waker) { - let state = &mut self.state; - self.mutex.lock(|_| { - // Dispose of any existing sender causing them to be polled again. - // This could cause a spin given multiple concurrent senders, however given that - // most sends only block waiting for the receiver to become active, this should - // be a short-lived activity. The upside is a greatly simplified implementation - // that avoids the need for intrusive linked-lists and unsafe operations on pinned - // pointers. - state.senders_waker.wake(); - state.senders_waker.register(senders_waker); - }) + fn set_senders_waker(state: &mut ChannelState, senders_waker: &Waker) { + // Dispose of any existing sender causing them to be polled again. + // This could cause a spin given multiple concurrent senders, however given that + // most sends only block waiting for the receiver to become active, this should + // be a short-lived activity. The upside is a greatly simplified implementation + // that avoids the need for intrusive linked-lists and unsafe operations on pinned + // pointers. + state.senders_waker.wake(); + state.senders_waker.register(senders_waker); } } From 7c723d2bfd3e8b8cf6fa289b822a254180601528 Mon Sep 17 00:00:00 2001 From: huntc Date: Wed, 14 Jul 2021 11:31:15 +1000 Subject: [PATCH 15/23] Removed UB code around the send future --- embassy/src/util/mpsc.rs | 32 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/embassy/src/util/mpsc.rs b/embassy/src/util/mpsc.rs index 8d534dc49..f049b6217 100644 --- a/embassy/src/util/mpsc.rs +++ b/embassy/src/util/mpsc.rs @@ -210,7 +210,7 @@ where pub async fn send(&self, message: T) -> Result<(), SendError> { SendFuture { sender: self.clone(), - message: UnsafeCell::new(message), + message: Some(message), } .await } @@ -266,7 +266,7 @@ where M: Mutex, { sender: Sender<'ch, M, T, N>, - message: UnsafeCell, + message: Option, } impl<'ch, M, T, const N: usize> Future for SendFuture<'ch, M, T, N> @@ -275,25 +275,23 @@ where { type Output = Result<(), SendError>; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self - .sender - .channel - .get() - .try_send_with_context(unsafe { self.message.get().read() }, Some(cx)) - { - Ok(..) => Poll::Ready(Ok(())), - Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))), - Err(TrySendError::Full(..)) => { - Poll::Pending - // Note we leave the existing UnsafeCell contents - they still - // contain the original message. We could create another UnsafeCell - // with the message of Full, but there's no real need. - } + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.message.take() { + Some(m) => match self.sender.channel.get().try_send_with_context(m, Some(cx)) { + Ok(..) => Poll::Ready(Ok(())), + Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))), + Err(TrySendError::Full(m)) => { + self.message.insert(m); + Poll::Pending + } + }, + None => panic!("Message cannot be None"), } } } +impl<'ch, M, T, const N: usize> Unpin for SendFuture<'ch, M, T, N> where M: Mutex {} + struct CloseFuture<'ch, M, T, const N: usize> where M: Mutex, From a247fa4f2c90993bad3501349029c52e7bb06f9d Mon Sep 17 00:00:00 2001 From: huntc Date: Wed, 14 Jul 2021 12:17:27 +1000 Subject: [PATCH 16/23] Explicitly drop non consumed items --- embassy/src/util/mpsc.rs | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/embassy/src/util/mpsc.rs b/embassy/src/util/mpsc.rs index f049b6217..b64d81c89 100644 --- a/embassy/src/util/mpsc.rs +++ b/embassy/src/util/mpsc.rs @@ -41,6 +41,7 @@ use core::cell::UnsafeCell; use core::fmt; use core::mem::MaybeUninit; use core::pin::Pin; +use core::ptr; use core::task::Context; use core::task::Poll; use core::task::Waker; @@ -416,6 +417,16 @@ impl ChannelState { } } +impl Drop for ChannelState { + fn drop(&mut self) { + while self.read_pos != self.write_pos || self.full { + self.full = false; + unsafe { ptr::drop_in_place(self.buf[self.read_pos].as_mut_ptr()) }; + self.read_pos = (self.read_pos + 1) % N; + } + } +} + /// A a bounded mpsc channel for communicating between asynchronous tasks /// with backpressure. /// From d86892ca566907aa7b9c29971229262557be49dc Mon Sep 17 00:00:00 2001 From: huntc Date: Wed, 14 Jul 2021 13:31:23 +1000 Subject: [PATCH 17/23] Removed the closing state as it was not required --- embassy/src/util/mpsc.rs | 35 +++++++++++++---------------------- 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/embassy/src/util/mpsc.rs b/embassy/src/util/mpsc.rs index b64d81c89..7f37eece4 100644 --- a/embassy/src/util/mpsc.rs +++ b/embassy/src/util/mpsc.rs @@ -390,7 +390,6 @@ struct ChannelState { read_pos: usize, write_pos: usize, full: bool, - closing: bool, closed: bool, receiver_registered: bool, senders_registered: u32, @@ -407,7 +406,6 @@ impl ChannelState { read_pos: 0, write_pos: 0, full: false, - closing: false, closed: false, receiver_registered: false, senders_registered: 0, @@ -528,25 +526,18 @@ where fn try_recv_with_context(&mut self, cx: Option<&mut Context<'_>>) -> Result { let mut state = &mut self.state; self.mutex.lock(|_| { - if !state.closed { - if state.read_pos != state.write_pos || state.full { - if state.full { - state.full = false; - state.senders_waker.wake(); - } - let message = - unsafe { (state.buf[state.read_pos]).assume_init_mut().get().read() }; - state.read_pos = (state.read_pos + 1) % state.buf.len(); - Ok(message) - } else if !state.closing { - cx.into_iter() - .for_each(|cx| Self::set_receiver_waker(&mut state, &cx.waker())); - Err(TryRecvError::Empty) - } else { - state.closed = true; + if state.read_pos != state.write_pos || state.full { + if state.full { + state.full = false; state.senders_waker.wake(); - Err(TryRecvError::Closed) } + let message = unsafe { (state.buf[state.read_pos]).assume_init_mut().get().read() }; + state.read_pos = (state.read_pos + 1) % state.buf.len(); + Ok(message) + } else if !state.closed { + cx.into_iter() + .for_each(|cx| Self::set_receiver_waker(&mut state, &cx.waker())); + Err(TryRecvError::Empty) } else { Err(TryRecvError::Closed) } @@ -588,7 +579,7 @@ where let state = &mut self.state; self.mutex.lock(|_| { state.receiver_waker.wake(); - state.closing = true; + state.closed = true; }); } @@ -599,7 +590,7 @@ where fn is_closed_with_context(&mut self, cx: Option<&mut Context<'_>>) -> bool { let mut state = &mut self.state; self.mutex.lock(|_| { - if state.closing || state.closed { + if state.closed { cx.into_iter() .for_each(|cx| Self::set_senders_waker(&mut state, &cx.waker())); true @@ -642,7 +633,7 @@ where state.senders_registered -= 1; if state.senders_registered == 0 { state.receiver_waker.wake(); - state.closing = true; + state.closed = true; } }) } From babee7f32a4919957836a002e2c971aac368bfab Mon Sep 17 00:00:00 2001 From: huntc Date: Wed, 14 Jul 2021 13:39:23 +1000 Subject: [PATCH 18/23] Tighten sender/receiver bounds --- embassy/src/util/mpsc.rs | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/embassy/src/util/mpsc.rs b/embassy/src/util/mpsc.rs index 7f37eece4..b30e41318 100644 --- a/embassy/src/util/mpsc.rs +++ b/embassy/src/util/mpsc.rs @@ -65,8 +65,10 @@ where } // Safe to pass the sender around -unsafe impl<'ch, M, T, const N: usize> Send for Sender<'ch, M, T, N> where M: Mutex {} -unsafe impl<'ch, M, T, const N: usize> Sync for Sender<'ch, M, T, N> where M: Mutex {} +unsafe impl<'ch, M, T, const N: usize> Send for Sender<'ch, M, T, N> where M: Mutex + Sync +{} +unsafe impl<'ch, M, T, const N: usize> Sync for Sender<'ch, M, T, N> where M: Mutex + Sync +{} /// Receive values from the associated `Sender`. /// @@ -79,8 +81,14 @@ where } // Safe to pass the receiver around -unsafe impl<'ch, M, T, const N: usize> Send for Receiver<'ch, M, T, N> where M: Mutex {} -unsafe impl<'ch, M, T, const N: usize> Sync for Receiver<'ch, M, T, N> where M: Mutex {} +unsafe impl<'ch, M, T, const N: usize> Send for Receiver<'ch, M, T, N> where + M: Mutex + Sync +{ +} +unsafe impl<'ch, M, T, const N: usize> Sync for Receiver<'ch, M, T, N> where + M: Mutex + Sync +{ +} /// Splits a bounded mpsc channel into a `Sender` and `Receiver`. /// From d711e8a82cef7ac26191e330aa4bd7cfebd570be Mon Sep 17 00:00:00 2001 From: huntc Date: Wed, 14 Jul 2021 16:34:32 +1000 Subject: [PATCH 19/23] Eliminates unsoundness by using an UnsafeCell for sharing the channel --- embassy/src/util/mpsc.rs | 348 +++++++++++++++++++-------------------- 1 file changed, 174 insertions(+), 174 deletions(-) diff --git a/embassy/src/util/mpsc.rs b/embassy/src/util/mpsc.rs index b30e41318..c409161f8 100644 --- a/embassy/src/util/mpsc.rs +++ b/embassy/src/util/mpsc.rs @@ -122,11 +122,10 @@ where { let sender = Sender { channel: &channel }; let receiver = Receiver { channel: &channel }; - { - let c = channel.get(); + channel.lock(|c| { c.register_receiver(); c.register_sender(); - } + }); (sender, receiver) } @@ -155,11 +154,12 @@ where } fn recv_poll(&mut self, cx: &mut Context<'_>) -> Poll> { - match self.channel.get().try_recv_with_context(Some(cx)) { - Ok(v) => Poll::Ready(Some(v)), - Err(TryRecvError::Closed) => Poll::Ready(None), - Err(TryRecvError::Empty) => Poll::Pending, - } + self.channel + .lock(|c| match c.try_recv_with_context(Some(cx)) { + Ok(v) => Poll::Ready(Some(v)), + Err(TryRecvError::Closed) => Poll::Ready(None), + Err(TryRecvError::Empty) => Poll::Pending, + }) } /// Attempts to immediately receive a message on this `Receiver` @@ -167,7 +167,7 @@ where /// This method will either receive a message from the channel immediately or return an error /// if the channel is empty. pub fn try_recv(&self) -> Result { - self.channel.get().try_recv() + self.channel.lock(|c| c.try_recv()) } /// Closes the receiving half of a channel without dropping it. @@ -181,7 +181,7 @@ where /// until those are released. /// pub fn close(&mut self) { - self.channel.get().close() + self.channel.lock(|c| c.close()) } } @@ -190,7 +190,7 @@ where M: Mutex, { fn drop(&mut self) { - self.channel.get().deregister_receiver() + self.channel.lock(|c| c.deregister_receiver()) } } @@ -245,7 +245,7 @@ where /// [`channel`]: channel /// [`close`]: Receiver::close pub fn try_send(&self, message: T) -> Result<(), TrySendError> { - self.channel.get().try_send(message) + self.channel.lock(|c| c.try_send(message)) } /// Completes when the receiver has dropped. @@ -266,7 +266,7 @@ where /// [`Receiver`]: crate::sync::mpsc::Receiver /// [`Receiver::close`]: crate::sync::mpsc::Receiver::close pub fn is_closed(&self) -> bool { - self.channel.get().is_closed() + self.channel.lock(|c| c.is_closed()) } } @@ -286,7 +286,11 @@ where fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.message.take() { - Some(m) => match self.sender.channel.get().try_send_with_context(m, Some(cx)) { + Some(m) => match self + .sender + .channel + .lock(|c| c.try_send_with_context(m, Some(cx))) + { Ok(..) => Poll::Ready(Ok(())), Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))), Err(TrySendError::Full(m)) => { @@ -315,7 +319,11 @@ where type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - if self.sender.channel.get().is_closed_with_context(Some(cx)) { + if self + .sender + .channel + .lock(|c| c.is_closed_with_context(Some(cx))) + { Poll::Ready(()) } else { Poll::Pending @@ -328,7 +336,7 @@ where M: Mutex, { fn drop(&mut self) { - self.channel.get().deregister_sender() + self.channel.lock(|c| c.deregister_sender()) } } @@ -338,7 +346,7 @@ where { #[allow(clippy::clone_double_ref)] fn clone(&self) -> Self { - self.channel.get().register_sender(); + self.channel.lock(|c| c.register_sender()); Sender { channel: self.channel.clone(), } @@ -421,6 +429,116 @@ impl ChannelState { senders_waker: WakerRegistration::new(), } } + + fn try_recv(&mut self) -> Result { + self.try_recv_with_context(None) + } + + fn try_recv_with_context(&mut self, cx: Option<&mut Context<'_>>) -> Result { + if self.read_pos != self.write_pos || self.full { + if self.full { + self.full = false; + self.senders_waker.wake(); + } + let message = unsafe { (self.buf[self.read_pos]).assume_init_mut().get().read() }; + self.read_pos = (self.read_pos + 1) % self.buf.len(); + Ok(message) + } else if !self.closed { + cx.into_iter() + .for_each(|cx| self.set_receiver_waker(&cx.waker())); + Err(TryRecvError::Empty) + } else { + Err(TryRecvError::Closed) + } + } + + fn try_send(&mut self, message: T) -> Result<(), TrySendError> { + self.try_send_with_context(message, None) + } + + fn try_send_with_context( + &mut self, + message: T, + cx: Option<&mut Context<'_>>, + ) -> Result<(), TrySendError> { + if !self.closed { + if !self.full { + self.buf[self.write_pos] = MaybeUninit::new(message.into()); + self.write_pos = (self.write_pos + 1) % self.buf.len(); + if self.write_pos == self.read_pos { + self.full = true; + } + self.receiver_waker.wake(); + Ok(()) + } else { + cx.into_iter() + .for_each(|cx| self.set_senders_waker(&cx.waker())); + Err(TrySendError::Full(message)) + } + } else { + Err(TrySendError::Closed(message)) + } + } + + fn close(&mut self) { + self.receiver_waker.wake(); + self.closed = true; + } + + fn is_closed(&mut self) -> bool { + self.is_closed_with_context(None) + } + + fn is_closed_with_context(&mut self, cx: Option<&mut Context<'_>>) -> bool { + if self.closed { + cx.into_iter() + .for_each(|cx| self.set_senders_waker(&cx.waker())); + true + } else { + false + } + } + + fn register_receiver(&mut self) { + assert!(!self.receiver_registered); + self.receiver_registered = true; + } + + fn deregister_receiver(&mut self) { + if self.receiver_registered { + self.closed = true; + self.senders_waker.wake(); + } + self.receiver_registered = false; + } + + fn register_sender(&mut self) { + self.senders_registered += 1; + } + + fn deregister_sender(&mut self) { + assert!(self.senders_registered > 0); + self.senders_registered -= 1; + if self.senders_registered == 0 { + self.receiver_waker.wake(); + self.closed = true; + } + } + + fn set_receiver_waker(&mut self, receiver_waker: &Waker) { + self.receiver_waker.register(receiver_waker); + } + + fn set_senders_waker(&mut self, senders_waker: &Waker) { + // Dispose of any existing sender causing them to be polled again. + // This could cause a spin given multiple concurrent senders, however given that + // most sends only block waiting for the receiver to become active, this should + // be a short-lived activity. The upside is a greatly simplified implementation + // that avoids the need for intrusive linked-lists and unsafe operations on pinned + // pointers. + self.senders_waker.wake(); + self.senders_waker.register(senders_waker); + } } impl Drop for ChannelState { @@ -442,6 +560,13 @@ impl Drop for ChannelState { /// /// All data sent will become available in the same order as it was sent. pub struct Channel +where + M: Mutex, +{ + sync_channel: UnsafeCell>, +} + +struct ChannelCell where M: Mutex, { @@ -468,7 +593,10 @@ impl Channel { pub const fn with_critical_sections() -> Self { let mutex = CriticalSectionMutex::new(()); let state = ChannelState::new(); - Channel { mutex, state } + let sync_channel = ChannelCell { mutex, state }; + Channel { + sync_channel: UnsafeCell::new(sync_channel), + } } } @@ -492,7 +620,10 @@ impl Channel { pub const fn with_thread_mode_only() -> Self { let mutex = ThreadModeMutex::new(()); let state = ChannelState::new(); - Channel { mutex, state } + let sync_channel = ChannelCell { mutex, state }; + Channel { + sync_channel: UnsafeCell::new(sync_channel), + } } } @@ -513,7 +644,10 @@ impl Channel { pub const fn with_no_threads() -> Self { let mutex = NoopMutex::new(()); let state = ChannelState::new(); - Channel { mutex, state } + let sync_channel = ChannelCell { mutex, state }; + Channel { + sync_channel: UnsafeCell::new(sync_channel), + } } } @@ -521,144 +655,13 @@ impl Channel where M: Mutex, { - fn get(&self) -> &mut Self { - let const_ptr = self as *const Self; - let mut_ptr = const_ptr as *mut Self; - unsafe { &mut *mut_ptr } - } - - fn try_recv(&mut self) -> Result { - self.try_recv_with_context(None) - } - - fn try_recv_with_context(&mut self, cx: Option<&mut Context<'_>>) -> Result { - let mut state = &mut self.state; - self.mutex.lock(|_| { - if state.read_pos != state.write_pos || state.full { - if state.full { - state.full = false; - state.senders_waker.wake(); - } - let message = unsafe { (state.buf[state.read_pos]).assume_init_mut().get().read() }; - state.read_pos = (state.read_pos + 1) % state.buf.len(); - Ok(message) - } else if !state.closed { - cx.into_iter() - .for_each(|cx| Self::set_receiver_waker(&mut state, &cx.waker())); - Err(TryRecvError::Empty) - } else { - Err(TryRecvError::Closed) - } - }) - } - - fn try_send(&mut self, message: T) -> Result<(), TrySendError> { - self.try_send_with_context(message, None) - } - - fn try_send_with_context( - &mut self, - message: T, - cx: Option<&mut Context<'_>>, - ) -> Result<(), TrySendError> { - let mut state = &mut self.state; - self.mutex.lock(|_| { - if !state.closed { - if !state.full { - state.buf[state.write_pos] = MaybeUninit::new(message.into()); - state.write_pos = (state.write_pos + 1) % state.buf.len(); - if state.write_pos == state.read_pos { - state.full = true; - } - state.receiver_waker.wake(); - Ok(()) - } else { - cx.into_iter() - .for_each(|cx| Self::set_senders_waker(&mut state, &cx.waker())); - Err(TrySendError::Full(message)) - } - } else { - Err(TrySendError::Closed(message)) - } - }) - } - - fn close(&mut self) { - let state = &mut self.state; - self.mutex.lock(|_| { - state.receiver_waker.wake(); - state.closed = true; - }); - } - - fn is_closed(&mut self) -> bool { - self.is_closed_with_context(None) - } - - fn is_closed_with_context(&mut self, cx: Option<&mut Context<'_>>) -> bool { - let mut state = &mut self.state; - self.mutex.lock(|_| { - if state.closed { - cx.into_iter() - .for_each(|cx| Self::set_senders_waker(&mut state, &cx.waker())); - true - } else { - false - } - }) - } - - fn register_receiver(&mut self) { - let state = &mut self.state; - self.mutex.lock(|_| { - assert!(!state.receiver_registered); - state.receiver_registered = true; - }); - } - - fn deregister_receiver(&mut self) { - let state = &mut self.state; - self.mutex.lock(|_| { - if state.receiver_registered { - state.closed = true; - state.senders_waker.wake(); - } - state.receiver_registered = false; - }) - } - - fn register_sender(&mut self) { - let state = &mut self.state; - self.mutex.lock(|_| { - state.senders_registered += 1; - }) - } - - fn deregister_sender(&mut self) { - let state = &mut self.state; - self.mutex.lock(|_| { - assert!(state.senders_registered > 0); - state.senders_registered -= 1; - if state.senders_registered == 0 { - state.receiver_waker.wake(); - state.closed = true; - } - }) - } - - fn set_receiver_waker(state: &mut ChannelState, receiver_waker: &Waker) { - state.receiver_waker.register(receiver_waker); - } - - fn set_senders_waker(state: &mut ChannelState, senders_waker: &Waker) { - // Dispose of any existing sender causing them to be polled again. - // This could cause a spin given multiple concurrent senders, however given that - // most sends only block waiting for the receiver to become active, this should - // be a short-lived activity. The upside is a greatly simplified implementation - // that avoids the need for intrusive linked-lists and unsafe operations on pinned - // pointers. - state.senders_waker.wake(); - state.senders_waker.register(senders_waker); + fn lock(&self, f: impl FnOnce(&mut ChannelState) -> R) -> R { + unsafe { + let sync_channel = &mut *(self.sync_channel.get()); + let mutex = &mut sync_channel.mutex; + let mut state = &mut sync_channel.state; + mutex.lock(|_| f(&mut state)) + } } } @@ -672,15 +675,12 @@ mod tests { use super::*; - fn capacity(c: &Channel) -> usize - where - M: Mutex, - { - if !c.state.full { - if c.state.write_pos > c.state.read_pos { - (c.state.buf.len() - c.state.write_pos) + c.state.read_pos + fn capacity(c: &ChannelState) -> usize { + if !c.full { + if c.write_pos > c.read_pos { + (c.buf.len() - c.write_pos) + c.read_pos } else { - (c.state.buf.len() - c.state.read_pos) + c.state.write_pos + (c.buf.len() - c.read_pos) + c.write_pos } } else { 0 @@ -689,14 +689,14 @@ mod tests { #[test] fn sending_once() { - let mut c = Channel::::with_no_threads(); + let mut c = ChannelState::::new(); assert!(c.try_send(1).is_ok()); assert_eq!(capacity(&c), 2); } #[test] fn sending_when_full() { - let mut c = Channel::::with_no_threads(); + let mut c = ChannelState::::new(); let _ = c.try_send(1); let _ = c.try_send(1); let _ = c.try_send(1); @@ -709,8 +709,8 @@ mod tests { #[test] fn sending_when_closed() { - let mut c = Channel::::with_no_threads(); - c.state.closed = true; + let mut c = ChannelState::::new(); + c.closed = true; match c.try_send(2) { Err(TrySendError::Closed(2)) => assert!(true), _ => assert!(false), @@ -719,7 +719,7 @@ mod tests { #[test] fn receiving_once_with_one_send() { - let mut c = Channel::::with_no_threads(); + let mut c = ChannelState::::new(); assert!(c.try_send(1).is_ok()); assert_eq!(c.try_recv().unwrap(), 1); assert_eq!(capacity(&c), 3); @@ -727,7 +727,7 @@ mod tests { #[test] fn receiving_when_empty() { - let mut c = Channel::::with_no_threads(); + let mut c = ChannelState::::new(); match c.try_recv() { Err(TryRecvError::Empty) => assert!(true), _ => assert!(false), @@ -737,8 +737,8 @@ mod tests { #[test] fn receiving_when_closed() { - let mut c = Channel::::with_no_threads(); - c.state.closed = true; + let mut c = ChannelState::::new(); + c.closed = true; match c.try_recv() { Err(TryRecvError::Closed) => assert!(true), _ => assert!(false), From 076198a3b99e153fa9b6189d60b2abbfc0b1f29a Mon Sep 17 00:00:00 2001 From: huntc Date: Wed, 14 Jul 2021 20:06:15 +1000 Subject: [PATCH 20/23] Small tidy up --- embassy/src/util/mpsc.rs | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/embassy/src/util/mpsc.rs b/embassy/src/util/mpsc.rs index c409161f8..f350c6e53 100644 --- a/embassy/src/util/mpsc.rs +++ b/embassy/src/util/mpsc.rs @@ -563,7 +563,7 @@ pub struct Channel where M: Mutex, { - sync_channel: UnsafeCell>, + channel_cell: UnsafeCell>, } struct ChannelCell @@ -593,9 +593,9 @@ impl Channel { pub const fn with_critical_sections() -> Self { let mutex = CriticalSectionMutex::new(()); let state = ChannelState::new(); - let sync_channel = ChannelCell { mutex, state }; + let channel_cell = ChannelCell { mutex, state }; Channel { - sync_channel: UnsafeCell::new(sync_channel), + channel_cell: UnsafeCell::new(channel_cell), } } } @@ -620,9 +620,9 @@ impl Channel { pub const fn with_thread_mode_only() -> Self { let mutex = ThreadModeMutex::new(()); let state = ChannelState::new(); - let sync_channel = ChannelCell { mutex, state }; + let channel_cell = ChannelCell { mutex, state }; Channel { - sync_channel: UnsafeCell::new(sync_channel), + channel_cell: UnsafeCell::new(channel_cell), } } } @@ -644,9 +644,9 @@ impl Channel { pub const fn with_no_threads() -> Self { let mutex = NoopMutex::new(()); let state = ChannelState::new(); - let sync_channel = ChannelCell { mutex, state }; + let channel_cell = ChannelCell { mutex, state }; Channel { - sync_channel: UnsafeCell::new(sync_channel), + channel_cell: UnsafeCell::new(channel_cell), } } } @@ -657,9 +657,9 @@ where { fn lock(&self, f: impl FnOnce(&mut ChannelState) -> R) -> R { unsafe { - let sync_channel = &mut *(self.sync_channel.get()); - let mutex = &mut sync_channel.mutex; - let mut state = &mut sync_channel.state; + let channel_cell = &mut *(self.channel_cell.get()); + let mutex = &mut channel_cell.mutex; + let mut state = &mut channel_cell.state; mutex.lock(|_| f(&mut state)) } } From 6f78527aeb7a0bacb02ca3264edd04d37550ea02 Mon Sep 17 00:00:00 2001 From: huntc Date: Wed, 14 Jul 2021 22:10:14 +1000 Subject: [PATCH 21/23] Partial borrow for receiver to enforce compile-time mpssc --- embassy/src/util/mpsc.rs | 111 +++++++++++++++++++++------------------ 1 file changed, 61 insertions(+), 50 deletions(-) diff --git a/embassy/src/util/mpsc.rs b/embassy/src/util/mpsc.rs index f350c6e53..246bd27e4 100644 --- a/embassy/src/util/mpsc.rs +++ b/embassy/src/util/mpsc.rs @@ -39,6 +39,7 @@ use core::cell::UnsafeCell; use core::fmt; +use core::marker::PhantomData; use core::mem::MaybeUninit; use core::pin::Pin; use core::ptr; @@ -61,7 +62,7 @@ pub struct Sender<'ch, M, T, const N: usize> where M: Mutex, { - channel: &'ch Channel, + channel_cell: &'ch UnsafeCell>, } // Safe to pass the sender around @@ -77,7 +78,8 @@ pub struct Receiver<'ch, M, T, const N: usize> where M: Mutex, { - channel: &'ch Channel, + channel_cell: &'ch UnsafeCell>, + _receiver_consumed: &'ch mut PhantomData<()>, } // Safe to pass the receiver around @@ -111,18 +113,23 @@ unsafe impl<'ch, M, T, const N: usize> Sync for Receiver<'ch, M, T, N> where /// /// let (sender, receiver) = { /// let mut channel = Channel::::with_thread_mode_only(); -/// mpsc::split(&channel) +/// mpsc::split(&mut channel) /// }; /// ``` pub fn split( - channel: &Channel, + channel: &mut Channel, ) -> (Sender, Receiver) where M: Mutex, { - let sender = Sender { channel: &channel }; - let receiver = Receiver { channel: &channel }; - channel.lock(|c| { + let sender = Sender { + channel_cell: &channel.channel_cell, + }; + let receiver = Receiver { + channel_cell: &channel.channel_cell, + _receiver_consumed: &mut channel.receiver_consumed, + }; + Channel::lock(&channel.channel_cell, |c| { c.register_receiver(); c.register_sender(); }); @@ -154,12 +161,13 @@ where } fn recv_poll(&mut self, cx: &mut Context<'_>) -> Poll> { - self.channel - .lock(|c| match c.try_recv_with_context(Some(cx)) { + Channel::lock(self.channel_cell, |c| { + match c.try_recv_with_context(Some(cx)) { Ok(v) => Poll::Ready(Some(v)), Err(TryRecvError::Closed) => Poll::Ready(None), Err(TryRecvError::Empty) => Poll::Pending, - }) + } + }) } /// Attempts to immediately receive a message on this `Receiver` @@ -167,7 +175,7 @@ where /// This method will either receive a message from the channel immediately or return an error /// if the channel is empty. pub fn try_recv(&self) -> Result { - self.channel.lock(|c| c.try_recv()) + Channel::lock(self.channel_cell, |c| c.try_recv()) } /// Closes the receiving half of a channel without dropping it. @@ -181,7 +189,7 @@ where /// until those are released. /// pub fn close(&mut self) { - self.channel.lock(|c| c.close()) + Channel::lock(self.channel_cell, |c| c.close()) } } @@ -190,7 +198,7 @@ where M: Mutex, { fn drop(&mut self) { - self.channel.lock(|c| c.deregister_receiver()) + Channel::lock(self.channel_cell, |c| c.deregister_receiver()) } } @@ -245,7 +253,7 @@ where /// [`channel`]: channel /// [`close`]: Receiver::close pub fn try_send(&self, message: T) -> Result<(), TrySendError> { - self.channel.lock(|c| c.try_send(message)) + Channel::lock(self.channel_cell, |c| c.try_send(message)) } /// Completes when the receiver has dropped. @@ -266,7 +274,7 @@ where /// [`Receiver`]: crate::sync::mpsc::Receiver /// [`Receiver::close`]: crate::sync::mpsc::Receiver::close pub fn is_closed(&self) -> bool { - self.channel.lock(|c| c.is_closed()) + Channel::lock(self.channel_cell, |c| c.is_closed()) } } @@ -286,11 +294,9 @@ where fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.message.take() { - Some(m) => match self - .sender - .channel - .lock(|c| c.try_send_with_context(m, Some(cx))) - { + Some(m) => match Channel::lock(self.sender.channel_cell, |c| { + c.try_send_with_context(m, Some(cx)) + }) { Ok(..) => Poll::Ready(Ok(())), Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))), Err(TrySendError::Full(m)) => { @@ -319,11 +325,9 @@ where type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - if self - .sender - .channel - .lock(|c| c.is_closed_with_context(Some(cx))) - { + if Channel::lock(self.sender.channel_cell, |c| { + c.is_closed_with_context(Some(cx)) + }) { Poll::Ready(()) } else { Poll::Pending @@ -336,7 +340,7 @@ where M: Mutex, { fn drop(&mut self) { - self.channel.lock(|c| c.deregister_sender()) + Channel::lock(self.channel_cell, |c| c.deregister_sender()) } } @@ -346,9 +350,9 @@ where { #[allow(clippy::clone_double_ref)] fn clone(&self) -> Self { - self.channel.lock(|c| c.register_sender()); + Channel::lock(self.channel_cell, |c| c.register_sender()); Sender { - channel: self.channel.clone(), + channel_cell: self.channel_cell.clone(), } } } @@ -564,6 +568,7 @@ where M: Mutex, { channel_cell: UnsafeCell>, + receiver_consumed: PhantomData<()>, } struct ChannelCell @@ -588,7 +593,7 @@ impl Channel { /// // Declare a bounded channel of 3 u32s. /// let mut channel = Channel::::with_critical_sections(); /// // once we have a channel, obtain its sender and receiver - /// let (sender, receiver) = mpsc::split(&channel); + /// let (sender, receiver) = mpsc::split(&mut channel); /// ``` pub const fn with_critical_sections() -> Self { let mutex = CriticalSectionMutex::new(()); @@ -596,6 +601,7 @@ impl Channel { let channel_cell = ChannelCell { mutex, state }; Channel { channel_cell: UnsafeCell::new(channel_cell), + receiver_consumed: PhantomData, } } } @@ -615,7 +621,7 @@ impl Channel { /// // Declare a bounded channel of 3 u32s. /// let mut channel = Channel::::with_thread_mode_only(); /// // once we have a channel, obtain its sender and receiver - /// let (sender, receiver) = mpsc::split(&channel); + /// let (sender, receiver) = mpsc::split(&mut channel); /// ``` pub const fn with_thread_mode_only() -> Self { let mutex = ThreadModeMutex::new(()); @@ -623,6 +629,7 @@ impl Channel { let channel_cell = ChannelCell { mutex, state }; Channel { channel_cell: UnsafeCell::new(channel_cell), + receiver_consumed: PhantomData, } } } @@ -639,7 +646,7 @@ impl Channel { /// // Declare a bounded channel of 3 u32s. /// let mut channel = Channel::::with_no_threads(); /// // once we have a channel, obtain its sender and receiver - /// let (sender, receiver) = mpsc::split(&channel); + /// let (sender, receiver) = mpsc::split(&mut channel); /// ``` pub const fn with_no_threads() -> Self { let mutex = NoopMutex::new(()); @@ -647,6 +654,7 @@ impl Channel { let channel_cell = ChannelCell { mutex, state }; Channel { channel_cell: UnsafeCell::new(channel_cell), + receiver_consumed: PhantomData, } } } @@ -655,9 +663,12 @@ impl Channel where M: Mutex, { - fn lock(&self, f: impl FnOnce(&mut ChannelState) -> R) -> R { + fn lock( + channel_cell: &UnsafeCell>, + f: impl FnOnce(&mut ChannelState) -> R, + ) -> R { unsafe { - let channel_cell = &mut *(self.channel_cell.get()); + let channel_cell = &mut *(channel_cell.get()); let mutex = &mut channel_cell.mutex; let mut state = &mut channel_cell.state; mutex.lock(|_| f(&mut state)) @@ -747,16 +758,16 @@ mod tests { #[test] fn simple_send_and_receive() { - let c = Channel::::with_no_threads(); - let (s, r) = split(&c); + let mut c = Channel::::with_no_threads(); + let (s, r) = split(&mut c); assert!(s.clone().try_send(1).is_ok()); assert_eq!(r.try_recv().unwrap(), 1); } #[test] fn should_close_without_sender() { - let c = Channel::::with_no_threads(); - let (s, r) = split(&c); + let mut c = Channel::::with_no_threads(); + let (s, r) = split(&mut c); drop(s); match r.try_recv() { Err(TryRecvError::Closed) => assert!(true), @@ -766,8 +777,8 @@ mod tests { #[test] fn should_close_once_drained() { - let c = Channel::::with_no_threads(); - let (s, r) = split(&c); + let mut c = Channel::::with_no_threads(); + let (s, r) = split(&mut c); assert!(s.try_send(1).is_ok()); drop(s); assert_eq!(r.try_recv().unwrap(), 1); @@ -779,8 +790,8 @@ mod tests { #[test] fn should_reject_send_when_receiver_dropped() { - let c = Channel::::with_no_threads(); - let (s, r) = split(&c); + let mut c = Channel::::with_no_threads(); + let (s, r) = split(&mut c); drop(r); match s.try_send(1) { Err(TrySendError::Closed(1)) => assert!(true), @@ -790,8 +801,8 @@ mod tests { #[test] fn should_reject_send_when_channel_closed() { - let c = Channel::::with_no_threads(); - let (s, mut r) = split(&c); + let mut c = Channel::::with_no_threads(); + let (s, mut r) = split(&mut c); assert!(s.try_send(1).is_ok()); r.close(); assert_eq!(r.try_recv().unwrap(), 1); @@ -808,7 +819,7 @@ mod tests { static mut CHANNEL: Channel = Channel::with_critical_sections(); - let (s, mut r) = split(unsafe { &CHANNEL }); + let (s, mut r) = split(unsafe { &mut CHANNEL }); assert!(executor .spawn(async move { drop(s); @@ -823,7 +834,7 @@ mod tests { static mut CHANNEL: Channel = Channel::with_critical_sections(); - let (s, mut r) = split(unsafe { &CHANNEL }); + let (s, mut r) = split(unsafe { &mut CHANNEL }); assert!(executor .spawn(async move { assert!(s.try_send(1).is_ok()); @@ -836,7 +847,7 @@ mod tests { async fn sender_send_completes_if_capacity() { static mut CHANNEL: Channel = Channel::with_critical_sections(); - let (s, mut r) = split(unsafe { &CHANNEL }); + let (s, mut r) = split(unsafe { &mut CHANNEL }); assert!(s.send(1).await.is_ok()); assert_eq!(r.recv().await, Some(1)); } @@ -845,7 +856,7 @@ mod tests { async fn sender_send_completes_if_closed() { static mut CHANNEL: Channel = Channel::with_critical_sections(); - let (s, r) = split(unsafe { &CHANNEL }); + let (s, r) = split(unsafe { &mut CHANNEL }); drop(r); match s.send(1).await { Err(SendError(1)) => assert!(true), @@ -859,7 +870,7 @@ mod tests { static mut CHANNEL: Channel = Channel::with_critical_sections(); - let (s0, mut r) = split(unsafe { &CHANNEL }); + let (s0, mut r) = split(unsafe { &mut CHANNEL }); assert!(s0.try_send(1).is_ok()); let s1 = s0.clone(); let send_task_1 = executor.spawn_with_handle(async move { s0.send(2).await }); @@ -879,7 +890,7 @@ mod tests { async fn sender_close_completes_if_closing() { static mut CHANNEL: Channel = Channel::with_critical_sections(); - let (s, mut r) = split(unsafe { &CHANNEL }); + let (s, mut r) = split(unsafe { &mut CHANNEL }); r.close(); s.closed().await; } @@ -888,7 +899,7 @@ mod tests { async fn sender_close_completes_if_closed() { static mut CHANNEL: Channel = Channel::with_critical_sections(); - let (s, r) = split(unsafe { &CHANNEL }); + let (s, r) = split(unsafe { &mut CHANNEL }); drop(r); s.closed().await; } From 3778f55d80f70b336f6ca846f365cf619032a685 Mon Sep 17 00:00:00 2001 From: huntc Date: Thu, 15 Jul 2021 12:08:35 +1000 Subject: [PATCH 22/23] Provides a cleaner construction of the channel with the common "new" naming --- embassy/src/util/mpsc.rs | 122 ++++++++++------------------------- embassy/src/util/mutex.rs | 14 ++++ examples/nrf/src/bin/mpsc.rs | 3 +- 3 files changed, 51 insertions(+), 88 deletions(-) diff --git a/embassy/src/util/mpsc.rs b/embassy/src/util/mpsc.rs index 246bd27e4..cc9e2a5dd 100644 --- a/embassy/src/util/mpsc.rs +++ b/embassy/src/util/mpsc.rs @@ -581,75 +581,27 @@ where pub type WithCriticalSections = CriticalSectionMutex<()>; -impl Channel { - /// Establish a new bounded channel using critical sections. Critical sections - /// should be used only single core targets where communication is required - /// from exception mode e.g. interrupt handlers. To create one: - /// - /// ``` - /// use embassy::util::mpsc; - /// use embassy::util::mpsc::{Channel, WithCriticalSections}; - /// - /// // Declare a bounded channel of 3 u32s. - /// let mut channel = Channel::::with_critical_sections(); - /// // once we have a channel, obtain its sender and receiver - /// let (sender, receiver) = mpsc::split(&mut channel); - /// ``` - pub const fn with_critical_sections() -> Self { - let mutex = CriticalSectionMutex::new(()); - let state = ChannelState::new(); - let channel_cell = ChannelCell { mutex, state }; - Channel { - channel_cell: UnsafeCell::new(channel_cell), - receiver_consumed: PhantomData, - } - } -} - pub type WithThreadModeOnly = ThreadModeMutex<()>; -impl Channel { - /// Establish a new bounded channel for use in Cortex-M thread mode. Thread - /// mode is intended for application threads on a single core, not interrupts. - /// As such, only one task at a time can acquire a resource and so this - /// channel avoids all locks. To create one: - /// - /// ``` no_run - /// use embassy::util::mpsc; - /// use embassy::util::mpsc::{Channel, WithThreadModeOnly}; - /// - /// // Declare a bounded channel of 3 u32s. - /// let mut channel = Channel::::with_thread_mode_only(); - /// // once we have a channel, obtain its sender and receiver - /// let (sender, receiver) = mpsc::split(&mut channel); - /// ``` - pub const fn with_thread_mode_only() -> Self { - let mutex = ThreadModeMutex::new(()); - let state = ChannelState::new(); - let channel_cell = ChannelCell { mutex, state }; - Channel { - channel_cell: UnsafeCell::new(channel_cell), - receiver_consumed: PhantomData, - } - } -} - pub type WithNoThreads = NoopMutex<()>; -impl Channel { - /// Establish a new bounded channel for within a single thread. To create one: +impl Channel +where + M: Mutex, +{ + /// Establish a new bounded channel. For example, to create one with a NoopMutex: /// /// ``` /// use embassy::util::mpsc; /// use embassy::util::mpsc::{Channel, WithNoThreads}; /// /// // Declare a bounded channel of 3 u32s. - /// let mut channel = Channel::::with_no_threads(); + /// let mut channel = Channel::::new(); /// // once we have a channel, obtain its sender and receiver /// let (sender, receiver) = mpsc::split(&mut channel); /// ``` - pub const fn with_no_threads() -> Self { - let mutex = NoopMutex::new(()); + pub fn new() -> Self { + let mutex = M::new(()); let state = ChannelState::new(); let channel_cell = ChannelCell { mutex, state }; Channel { @@ -657,12 +609,7 @@ impl Channel { receiver_consumed: PhantomData, } } -} -impl Channel -where - M: Mutex, -{ fn lock( channel_cell: &UnsafeCell>, f: impl FnOnce(&mut ChannelState) -> R, @@ -684,6 +631,8 @@ mod tests { use futures_executor::ThreadPool; use futures_timer::Delay; + use crate::util::Forever; + use super::*; fn capacity(c: &ChannelState) -> usize { @@ -758,7 +707,7 @@ mod tests { #[test] fn simple_send_and_receive() { - let mut c = Channel::::with_no_threads(); + let mut c = Channel::::new(); let (s, r) = split(&mut c); assert!(s.clone().try_send(1).is_ok()); assert_eq!(r.try_recv().unwrap(), 1); @@ -766,7 +715,7 @@ mod tests { #[test] fn should_close_without_sender() { - let mut c = Channel::::with_no_threads(); + let mut c = Channel::::new(); let (s, r) = split(&mut c); drop(s); match r.try_recv() { @@ -777,7 +726,7 @@ mod tests { #[test] fn should_close_once_drained() { - let mut c = Channel::::with_no_threads(); + let mut c = Channel::::new(); let (s, r) = split(&mut c); assert!(s.try_send(1).is_ok()); drop(s); @@ -790,7 +739,7 @@ mod tests { #[test] fn should_reject_send_when_receiver_dropped() { - let mut c = Channel::::with_no_threads(); + let mut c = Channel::::new(); let (s, r) = split(&mut c); drop(r); match s.try_send(1) { @@ -801,7 +750,7 @@ mod tests { #[test] fn should_reject_send_when_channel_closed() { - let mut c = Channel::::with_no_threads(); + let mut c = Channel::::new(); let (s, mut r) = split(&mut c); assert!(s.try_send(1).is_ok()); r.close(); @@ -817,9 +766,9 @@ mod tests { async fn receiver_closes_when_sender_dropped_async() { let executor = ThreadPool::new().unwrap(); - static mut CHANNEL: Channel = - Channel::with_critical_sections(); - let (s, mut r) = split(unsafe { &mut CHANNEL }); + static CHANNEL: Forever> = Forever::new(); + let c = CHANNEL.put(Channel::new()); + let (s, mut r) = split(c); assert!(executor .spawn(async move { drop(s); @@ -832,9 +781,9 @@ mod tests { async fn receiver_receives_given_try_send_async() { let executor = ThreadPool::new().unwrap(); - static mut CHANNEL: Channel = - Channel::with_critical_sections(); - let (s, mut r) = split(unsafe { &mut CHANNEL }); + static CHANNEL: Forever> = Forever::new(); + let c = CHANNEL.put(Channel::new()); + let (s, mut r) = split(c); assert!(executor .spawn(async move { assert!(s.try_send(1).is_ok()); @@ -845,18 +794,17 @@ mod tests { #[futures_test::test] async fn sender_send_completes_if_capacity() { - static mut CHANNEL: Channel = - Channel::with_critical_sections(); - let (s, mut r) = split(unsafe { &mut CHANNEL }); + let mut c = Channel::::new(); + let (s, mut r) = split(&mut c); assert!(s.send(1).await.is_ok()); assert_eq!(r.recv().await, Some(1)); } #[futures_test::test] async fn sender_send_completes_if_closed() { - static mut CHANNEL: Channel = - Channel::with_critical_sections(); - let (s, r) = split(unsafe { &mut CHANNEL }); + static CHANNEL: Forever> = Forever::new(); + let c = CHANNEL.put(Channel::new()); + let (s, r) = split(c); drop(r); match s.send(1).await { Err(SendError(1)) => assert!(true), @@ -868,9 +816,9 @@ mod tests { async fn senders_sends_wait_until_capacity() { let executor = ThreadPool::new().unwrap(); - static mut CHANNEL: Channel = - Channel::with_critical_sections(); - let (s0, mut r) = split(unsafe { &mut CHANNEL }); + static CHANNEL: Forever> = Forever::new(); + let c = CHANNEL.put(Channel::new()); + let (s0, mut r) = split(c); assert!(s0.try_send(1).is_ok()); let s1 = s0.clone(); let send_task_1 = executor.spawn_with_handle(async move { s0.send(2).await }); @@ -888,18 +836,18 @@ mod tests { #[futures_test::test] async fn sender_close_completes_if_closing() { - static mut CHANNEL: Channel = - Channel::with_critical_sections(); - let (s, mut r) = split(unsafe { &mut CHANNEL }); + static CHANNEL: Forever> = Forever::new(); + let c = CHANNEL.put(Channel::new()); + let (s, mut r) = split(c); r.close(); s.closed().await; } #[futures_test::test] async fn sender_close_completes_if_closed() { - static mut CHANNEL: Channel = - Channel::with_critical_sections(); - let (s, r) = split(unsafe { &mut CHANNEL }); + static CHANNEL: Forever> = Forever::new(); + let c = CHANNEL.put(Channel::new()); + let (s, r) = split(c); drop(r); s.closed().await; } diff --git a/embassy/src/util/mutex.rs b/embassy/src/util/mutex.rs index c8fe84026..0506ffe6f 100644 --- a/embassy/src/util/mutex.rs +++ b/embassy/src/util/mutex.rs @@ -8,6 +8,8 @@ pub trait Mutex { /// Data protected by the mutex. type Data; + fn new(data: Self::Data) -> Self; + /// Creates a critical section and grants temporary access to the protected data. fn lock(&mut self, f: impl FnOnce(&Self::Data) -> R) -> R; } @@ -47,6 +49,10 @@ impl CriticalSectionMutex { impl Mutex for CriticalSectionMutex { type Data = T; + fn new(data: T) -> Self { + Self::new(data) + } + fn lock(&mut self, f: impl FnOnce(&Self::Data) -> R) -> R { critical_section::with(|cs| f(self.borrow(cs))) } @@ -92,6 +98,10 @@ impl ThreadModeMutex { impl Mutex for ThreadModeMutex { type Data = T; + fn new(data: T) -> Self { + Self::new(data) + } + fn lock(&mut self, f: impl FnOnce(&Self::Data) -> R) -> R { f(self.borrow()) } @@ -126,6 +136,10 @@ impl NoopMutex { impl Mutex for NoopMutex { type Data = T; + fn new(data: T) -> Self { + Self::new(data) + } + fn lock(&mut self, f: impl FnOnce(&Self::Data) -> R) -> R { f(self.borrow()) } diff --git a/examples/nrf/src/bin/mpsc.rs b/examples/nrf/src/bin/mpsc.rs index c2cb107e1..443955239 100644 --- a/examples/nrf/src/bin/mpsc.rs +++ b/examples/nrf/src/bin/mpsc.rs @@ -37,9 +37,10 @@ async fn my_task(sender: Sender<'static, WithNoThreads, LedState, 1>) { #[embassy::main] async fn main(spawner: Spawner, p: Peripherals) { + let mut led = Output::new(p.P0_13, Level::Low, OutputDrive::Standard); - let channel = CHANNEL.put(Channel::with_no_threads()); + let channel = CHANNEL.put(Channel::new()); let (sender, mut receiver) = mpsc::split(channel); spawner.spawn(my_task(sender)).unwrap(); From 72d6f79ec731980bf776d35bfc0af848a1d994da Mon Sep 17 00:00:00 2001 From: huntc Date: Thu, 15 Jul 2021 12:31:36 +1000 Subject: [PATCH 23/23] Feature no longer required given 1.55 --- embassy/src/lib.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/embassy/src/lib.rs b/embassy/src/lib.rs index 845f82a3f..41102a180 100644 --- a/embassy/src/lib.rs +++ b/embassy/src/lib.rs @@ -7,7 +7,6 @@ #![feature(min_type_alias_impl_trait)] #![feature(impl_trait_in_bindings)] #![feature(type_alias_impl_trait)] -#![feature(maybe_uninit_ref)] // This mod MUST go first, so that the others see its macros. pub(crate) mod fmt;