diff --git a/embassy/src/util/mpsc.rs b/embassy/src/util/mpsc.rs index 5f2dd729a..aabced416 100644 --- a/embassy/src/util/mpsc.rs +++ b/embassy/src/util/mpsc.rs @@ -51,11 +51,36 @@ use super::CriticalSectionMutex; use super::Mutex; use super::ThreadModeMutex; +/// A ChannelCell permits a channel to be shared between senders and their receivers. +// Derived from UnsafeCell. +#[repr(transparent)] +pub struct ChannelCell { + _value: T, +} + +impl ChannelCell { + #[inline(always)] + pub const fn new(value: T) -> ChannelCell { + ChannelCell { _value: value } + } +} + +impl ChannelCell { + #[inline(always)] + const fn get(&self) -> *mut T { + // As per UnsafeCell: + // We can just cast the pointer from `ChannelCell` to `T` because of + // #[repr(transparent)]. This exploits libstd's special status, there is + // no guarantee for user code that this will work in future versions of the compiler! + self as *const ChannelCell as *const T as *mut T + } +} + /// Send values to the associated `Receiver`. /// /// Instances are created by the [`split`](split) function. pub struct Sender<'ch, T> { - channel: &'ch UnsafeCell>, + channel: &'ch ChannelCell>, } // Safe to pass the sender around @@ -66,7 +91,7 @@ unsafe impl<'ch, T> Sync for Sender<'ch, T> {} /// /// Instances are created by the [`split`](split) function. pub struct Receiver<'ch, T> { - channel: &'ch UnsafeCell>, + channel: &'ch ChannelCell>, } // Safe to pass the receiver around @@ -89,16 +114,15 @@ unsafe impl<'ch, T> Sync for Receiver<'ch, T> {} /// their channel. The following will therefore fail compilation: //// /// ```compile_fail -/// use core::cell::UnsafeCell; /// use embassy::util::mpsc; -/// use embassy::util::mpsc::{Channel, WithThreadModeOnly}; +/// use embassy::util::mpsc::{Channel, ChannelCell, WithThreadModeOnly}; /// /// let (sender, receiver) = { -/// let mut channel = UnsafeCell::new(Channel::::with_thread_mode_only()); +/// let mut channel = ChannelCell::new(Channel::::with_thread_mode_only()); /// mpsc::split(&channel) /// }; /// ``` -pub fn split(channel: &UnsafeCell>) -> (Sender, Receiver) { +pub fn split(channel: &ChannelCell>) -> (Sender, Receiver) { let sender = Sender { channel: &channel }; let receiver = Receiver { channel: &channel }; { @@ -439,12 +463,11 @@ impl Channel { /// from exception mode e.g. interrupt handlers. To create one: /// /// ``` - /// use core::cell::UnsafeCell; /// use embassy::util::mpsc; - /// use embassy::util::mpsc::{Channel, WithCriticalSections}; + /// use embassy::util::mpsc::{Channel, ChannelCell, WithCriticalSections}; /// /// // Declare a bounded channel of 3 u32s. - /// let mut channel = UnsafeCell::new(mpsc::Channel::::with_critical_sections()); + /// let mut channel = ChannelCell::new(mpsc::Channel::::with_critical_sections()); /// // once we have a channel, obtain its sender and receiver /// let (sender, receiver) = mpsc::split(&channel); /// ``` @@ -464,12 +487,11 @@ impl Channel { /// channel avoids all locks. To create one: /// /// ``` no_run - /// use core::cell::UnsafeCell; /// use embassy::util::mpsc; - /// use embassy::util::mpsc::{Channel, WithThreadModeOnly}; + /// use embassy::util::mpsc::{Channel, ChannelCell, WithThreadModeOnly}; /// /// // Declare a bounded channel of 3 u32s. - /// let mut channel = UnsafeCell::new(Channel::::with_thread_mode_only()); + /// let mut channel = ChannelCell::new(Channel::::with_thread_mode_only()); /// // once we have a channel, obtain its sender and receiver /// let (sender, receiver) = mpsc::split(&channel); /// ``` @@ -744,7 +766,7 @@ mod tests { #[test] fn simple_send_and_receive() { - let c = UnsafeCell::new(Channel::::with_no_threads()); + let c = ChannelCell::new(Channel::::with_no_threads()); let (s, r) = split(&c); assert!(s.clone().try_send(1).is_ok()); assert_eq!(r.try_recv().unwrap(), 1); @@ -752,7 +774,7 @@ mod tests { #[test] fn should_close_without_sender() { - let c = UnsafeCell::new(Channel::::with_no_threads()); + let c = ChannelCell::new(Channel::::with_no_threads()); let (s, r) = split(&c); drop(s); match r.try_recv() { @@ -763,7 +785,7 @@ mod tests { #[test] fn should_close_once_drained() { - let c = UnsafeCell::new(Channel::::with_no_threads()); + let c = ChannelCell::new(Channel::::with_no_threads()); let (s, r) = split(&c); assert!(s.try_send(1).is_ok()); drop(s); @@ -776,7 +798,7 @@ mod tests { #[test] fn should_reject_send_when_receiver_dropped() { - let c = UnsafeCell::new(Channel::::with_no_threads()); + let c = ChannelCell::new(Channel::::with_no_threads()); let (s, r) = split(&c); drop(r); match s.try_send(1) { @@ -787,7 +809,7 @@ mod tests { #[test] fn should_reject_send_when_channel_closed() { - let c = UnsafeCell::new(Channel::::with_no_threads()); + let c = ChannelCell::new(Channel::::with_no_threads()); let (s, mut r) = split(&c); assert!(s.try_send(1).is_ok()); r.close(); @@ -803,8 +825,8 @@ mod tests { async fn receiver_closes_when_sender_dropped_async() { let executor = ThreadPool::new().unwrap(); - static mut CHANNEL: UnsafeCell> = - UnsafeCell::new(Channel::with_critical_sections()); + static mut CHANNEL: ChannelCell> = + ChannelCell::new(Channel::with_critical_sections()); let (s, mut r) = split(unsafe { &CHANNEL }); assert!(executor .spawn(async move { @@ -818,8 +840,8 @@ mod tests { async fn receiver_receives_given_try_send_async() { let executor = ThreadPool::new().unwrap(); - static mut CHANNEL: UnsafeCell> = - UnsafeCell::new(Channel::with_critical_sections()); + static mut CHANNEL: ChannelCell> = + ChannelCell::new(Channel::with_critical_sections()); let (s, mut r) = split(unsafe { &CHANNEL }); assert!(executor .spawn(async move { @@ -831,8 +853,8 @@ mod tests { #[futures_test::test] async fn sender_send_completes_if_capacity() { - static mut CHANNEL: UnsafeCell> = - UnsafeCell::new(Channel::with_critical_sections()); + static mut CHANNEL: ChannelCell> = + ChannelCell::new(Channel::with_critical_sections()); let (s, mut r) = split(unsafe { &CHANNEL }); assert!(s.send(1).await.is_ok()); assert_eq!(r.recv().await, Some(1)); @@ -840,8 +862,8 @@ mod tests { #[futures_test::test] async fn sender_send_completes_if_closed() { - static mut CHANNEL: UnsafeCell> = - UnsafeCell::new(Channel::with_critical_sections()); + static mut CHANNEL: ChannelCell> = + ChannelCell::new(Channel::with_critical_sections()); let (s, r) = split(unsafe { &CHANNEL }); drop(r); match s.send(1).await { @@ -854,8 +876,8 @@ mod tests { async fn senders_sends_wait_until_capacity() { let executor = ThreadPool::new().unwrap(); - static mut CHANNEL: UnsafeCell> = - UnsafeCell::new(Channel::with_critical_sections()); + static mut CHANNEL: ChannelCell> = + ChannelCell::new(Channel::with_critical_sections()); let (s0, mut r) = split(unsafe { &CHANNEL }); assert!(s0.try_send(1).is_ok()); let s1 = s0.clone(); @@ -874,8 +896,8 @@ mod tests { #[futures_test::test] async fn sender_close_completes_if_closing() { - static mut CHANNEL: UnsafeCell> = - UnsafeCell::new(Channel::with_critical_sections()); + static mut CHANNEL: ChannelCell> = + ChannelCell::new(Channel::with_critical_sections()); let (s, mut r) = split(unsafe { &CHANNEL }); r.close(); s.closed().await; @@ -883,8 +905,8 @@ mod tests { #[futures_test::test] async fn sender_close_completes_if_closed() { - static mut CHANNEL: UnsafeCell> = - UnsafeCell::new(Channel::with_critical_sections()); + static mut CHANNEL: ChannelCell> = + ChannelCell::new(Channel::with_critical_sections()); let (s, r) = split(unsafe { &CHANNEL }); drop(r); s.closed().await; diff --git a/examples/nrf/src/bin/mpsc.rs b/examples/nrf/src/bin/mpsc.rs index d692abee2..eafa29e60 100644 --- a/examples/nrf/src/bin/mpsc.rs +++ b/examples/nrf/src/bin/mpsc.rs @@ -8,12 +8,10 @@ #[path = "../example_common.rs"] mod example_common; -use core::cell::UnsafeCell; - use defmt::panic; use embassy::executor::Spawner; use embassy::time::{Duration, Timer}; -use embassy::util::mpsc::TryRecvError; +use embassy::util::mpsc::{ChannelCell, TryRecvError}; use embassy::util::{mpsc, Forever}; use embassy_nrf::gpio::{Level, Output, OutputDrive}; use embassy_nrf::Peripherals; @@ -25,7 +23,7 @@ enum LedState { Off, } -static CHANNEL: Forever>> = Forever::new(); +static CHANNEL: Forever>> = Forever::new(); #[embassy::task(pool_size = 1)] async fn my_task(sender: Sender<'static, LedState>) { @@ -41,7 +39,7 @@ async fn my_task(sender: Sender<'static, LedState>) { async fn main(spawner: Spawner, p: Peripherals) { let mut led = Output::new(p.P0_13, Level::Low, OutputDrive::Standard); - let channel = CHANNEL.put(UnsafeCell::new(Channel::with_thread_mode_only())); + let channel = CHANNEL.put(ChannelCell::new(Channel::with_thread_mode_only())); let (sender, mut receiver) = mpsc::split(channel); spawner.spawn(my_task(sender)).unwrap();