diff --git a/embassy/Cargo.toml b/embassy/Cargo.toml index 0a8ab4434..ae06bc198 100644 --- a/embassy/Cargo.toml +++ b/embassy/Cargo.toml @@ -42,6 +42,7 @@ embassy-traits = { version = "0.1.0", path = "../embassy-traits"} atomic-polyfill = "0.1.3" critical-section = "0.2.1" embedded-hal = "0.2.6" +heapless = "0.7.5" [dev-dependencies] embassy = { path = ".", features = ["executor-agnostic"] } diff --git a/embassy/src/channel/mpsc.rs b/embassy/src/channel/mpsc.rs index b20d48a95..c77452441 100644 --- a/embassy/src/channel/mpsc.rs +++ b/embassy/src/channel/mpsc.rs @@ -40,14 +40,13 @@ use core::cell::UnsafeCell; use core::fmt; use core::marker::PhantomData; -use core::mem::MaybeUninit; use core::pin::Pin; -use core::ptr; use core::task::Context; use core::task::Poll; use core::task::Waker; use futures::Future; +use heapless::Deque; use crate::blocking_mutex::{CriticalSectionMutex, Mutex, NoopMutex, ThreadModeMutex}; use crate::waitqueue::WakerRegistration; @@ -446,10 +445,7 @@ impl defmt::Format for TrySendError { } struct ChannelState { - buf: [MaybeUninit>; N], - read_pos: usize, - write_pos: usize, - full: bool, + queue: Deque, closed: bool, receiver_registered: bool, senders_registered: u32, @@ -458,14 +454,9 @@ struct ChannelState { } impl ChannelState { - const INIT: MaybeUninit> = MaybeUninit::uninit(); - const fn new() -> Self { ChannelState { - buf: [Self::INIT; N], - read_pos: 0, - write_pos: 0, - full: false, + queue: Deque::new(), closed: false, receiver_registered: false, senders_registered: 0, @@ -479,17 +470,16 @@ impl ChannelState { } fn try_recv_with_context(&mut self, cx: Option<&mut Context<'_>>) -> Result { - if self.read_pos != self.write_pos || self.full { - if self.full { - self.full = false; - self.senders_waker.wake(); - } - let message = unsafe { (self.buf[self.read_pos]).assume_init_mut().get().read() }; - self.read_pos = (self.read_pos + 1) % self.buf.len(); + if self.queue.is_full() { + self.senders_waker.wake(); + } + + if let Some(message) = self.queue.pop_front() { Ok(message) } else if !self.closed { - cx.into_iter() - .for_each(|cx| self.set_receiver_waker(&cx.waker())); + if let Some(cx) = cx { + self.set_receiver_waker(cx.waker()); + } Err(TryRecvError::Empty) } else { Err(TryRecvError::Closed) @@ -505,22 +495,21 @@ impl ChannelState { message: T, cx: Option<&mut Context<'_>>, ) -> Result<(), TrySendError> { - if !self.closed { - if !self.full { - self.buf[self.write_pos] = MaybeUninit::new(message.into()); - self.write_pos = (self.write_pos + 1) % self.buf.len(); - if self.write_pos == self.read_pos { - self.full = true; - } + if self.closed { + return Err(TrySendError::Closed(message)); + } + + match self.queue.push_back(message) { + Ok(()) => { self.receiver_waker.wake(); + Ok(()) - } else { + } + Err(message) => { cx.into_iter() .for_each(|cx| self.set_senders_waker(&cx.waker())); Err(TrySendError::Full(message)) } - } else { - Err(TrySendError::Closed(message)) } } @@ -585,16 +574,6 @@ impl ChannelState { } } -impl Drop for ChannelState { - fn drop(&mut self) { - while self.read_pos != self.write_pos || self.full { - self.full = false; - unsafe { ptr::drop_in_place(self.buf[self.read_pos].as_mut_ptr()) }; - self.read_pos = (self.read_pos + 1) % N; - } - } -} - /// A a bounded mpsc channel for communicating between asynchronous tasks /// with backpressure. /// @@ -676,15 +655,7 @@ mod tests { use super::*; fn capacity(c: &ChannelState) -> usize { - if !c.full { - if c.write_pos > c.read_pos { - (c.buf.len() - c.write_pos) + c.read_pos - } else { - (c.buf.len() - c.read_pos) + c.write_pos - } - } else { - 0 - } + c.queue.capacity() - c.queue.len() } #[test]