From d711e8a82cef7ac26191e330aa4bd7cfebd570be Mon Sep 17 00:00:00 2001 From: huntc Date: Wed, 14 Jul 2021 16:34:32 +1000 Subject: [PATCH] Eliminates unsoundness by using an UnsafeCell for sharing the channel --- embassy/src/util/mpsc.rs | 348 +++++++++++++++++++-------------------- 1 file changed, 174 insertions(+), 174 deletions(-) diff --git a/embassy/src/util/mpsc.rs b/embassy/src/util/mpsc.rs index b30e41318..c409161f8 100644 --- a/embassy/src/util/mpsc.rs +++ b/embassy/src/util/mpsc.rs @@ -122,11 +122,10 @@ where { let sender = Sender { channel: &channel }; let receiver = Receiver { channel: &channel }; - { - let c = channel.get(); + channel.lock(|c| { c.register_receiver(); c.register_sender(); - } + }); (sender, receiver) } @@ -155,11 +154,12 @@ where } fn recv_poll(&mut self, cx: &mut Context<'_>) -> Poll> { - match self.channel.get().try_recv_with_context(Some(cx)) { - Ok(v) => Poll::Ready(Some(v)), - Err(TryRecvError::Closed) => Poll::Ready(None), - Err(TryRecvError::Empty) => Poll::Pending, - } + self.channel + .lock(|c| match c.try_recv_with_context(Some(cx)) { + Ok(v) => Poll::Ready(Some(v)), + Err(TryRecvError::Closed) => Poll::Ready(None), + Err(TryRecvError::Empty) => Poll::Pending, + }) } /// Attempts to immediately receive a message on this `Receiver` @@ -167,7 +167,7 @@ where /// This method will either receive a message from the channel immediately or return an error /// if the channel is empty. pub fn try_recv(&self) -> Result { - self.channel.get().try_recv() + self.channel.lock(|c| c.try_recv()) } /// Closes the receiving half of a channel without dropping it. @@ -181,7 +181,7 @@ where /// until those are released. /// pub fn close(&mut self) { - self.channel.get().close() + self.channel.lock(|c| c.close()) } } @@ -190,7 +190,7 @@ where M: Mutex, { fn drop(&mut self) { - self.channel.get().deregister_receiver() + self.channel.lock(|c| c.deregister_receiver()) } } @@ -245,7 +245,7 @@ where /// [`channel`]: channel /// [`close`]: Receiver::close pub fn try_send(&self, message: T) -> Result<(), TrySendError> { - self.channel.get().try_send(message) + self.channel.lock(|c| c.try_send(message)) } /// Completes when the receiver has dropped. @@ -266,7 +266,7 @@ where /// [`Receiver`]: crate::sync::mpsc::Receiver /// [`Receiver::close`]: crate::sync::mpsc::Receiver::close pub fn is_closed(&self) -> bool { - self.channel.get().is_closed() + self.channel.lock(|c| c.is_closed()) } } @@ -286,7 +286,11 @@ where fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.message.take() { - Some(m) => match self.sender.channel.get().try_send_with_context(m, Some(cx)) { + Some(m) => match self + .sender + .channel + .lock(|c| c.try_send_with_context(m, Some(cx))) + { Ok(..) => Poll::Ready(Ok(())), Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))), Err(TrySendError::Full(m)) => { @@ -315,7 +319,11 @@ where type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - if self.sender.channel.get().is_closed_with_context(Some(cx)) { + if self + .sender + .channel + .lock(|c| c.is_closed_with_context(Some(cx))) + { Poll::Ready(()) } else { Poll::Pending @@ -328,7 +336,7 @@ where M: Mutex, { fn drop(&mut self) { - self.channel.get().deregister_sender() + self.channel.lock(|c| c.deregister_sender()) } } @@ -338,7 +346,7 @@ where { #[allow(clippy::clone_double_ref)] fn clone(&self) -> Self { - self.channel.get().register_sender(); + self.channel.lock(|c| c.register_sender()); Sender { channel: self.channel.clone(), } @@ -421,6 +429,116 @@ impl ChannelState { senders_waker: WakerRegistration::new(), } } + + fn try_recv(&mut self) -> Result { + self.try_recv_with_context(None) + } + + fn try_recv_with_context(&mut self, cx: Option<&mut Context<'_>>) -> Result { + if self.read_pos != self.write_pos || self.full { + if self.full { + self.full = false; + self.senders_waker.wake(); + } + let message = unsafe { (self.buf[self.read_pos]).assume_init_mut().get().read() }; + self.read_pos = (self.read_pos + 1) % self.buf.len(); + Ok(message) + } else if !self.closed { + cx.into_iter() + .for_each(|cx| self.set_receiver_waker(&cx.waker())); + Err(TryRecvError::Empty) + } else { + Err(TryRecvError::Closed) + } + } + + fn try_send(&mut self, message: T) -> Result<(), TrySendError> { + self.try_send_with_context(message, None) + } + + fn try_send_with_context( + &mut self, + message: T, + cx: Option<&mut Context<'_>>, + ) -> Result<(), TrySendError> { + if !self.closed { + if !self.full { + self.buf[self.write_pos] = MaybeUninit::new(message.into()); + self.write_pos = (self.write_pos + 1) % self.buf.len(); + if self.write_pos == self.read_pos { + self.full = true; + } + self.receiver_waker.wake(); + Ok(()) + } else { + cx.into_iter() + .for_each(|cx| self.set_senders_waker(&cx.waker())); + Err(TrySendError::Full(message)) + } + } else { + Err(TrySendError::Closed(message)) + } + } + + fn close(&mut self) { + self.receiver_waker.wake(); + self.closed = true; + } + + fn is_closed(&mut self) -> bool { + self.is_closed_with_context(None) + } + + fn is_closed_with_context(&mut self, cx: Option<&mut Context<'_>>) -> bool { + if self.closed { + cx.into_iter() + .for_each(|cx| self.set_senders_waker(&cx.waker())); + true + } else { + false + } + } + + fn register_receiver(&mut self) { + assert!(!self.receiver_registered); + self.receiver_registered = true; + } + + fn deregister_receiver(&mut self) { + if self.receiver_registered { + self.closed = true; + self.senders_waker.wake(); + } + self.receiver_registered = false; + } + + fn register_sender(&mut self) { + self.senders_registered += 1; + } + + fn deregister_sender(&mut self) { + assert!(self.senders_registered > 0); + self.senders_registered -= 1; + if self.senders_registered == 0 { + self.receiver_waker.wake(); + self.closed = true; + } + } + + fn set_receiver_waker(&mut self, receiver_waker: &Waker) { + self.receiver_waker.register(receiver_waker); + } + + fn set_senders_waker(&mut self, senders_waker: &Waker) { + // Dispose of any existing sender causing them to be polled again. + // This could cause a spin given multiple concurrent senders, however given that + // most sends only block waiting for the receiver to become active, this should + // be a short-lived activity. The upside is a greatly simplified implementation + // that avoids the need for intrusive linked-lists and unsafe operations on pinned + // pointers. + self.senders_waker.wake(); + self.senders_waker.register(senders_waker); + } } impl Drop for ChannelState { @@ -442,6 +560,13 @@ impl Drop for ChannelState { /// /// All data sent will become available in the same order as it was sent. pub struct Channel +where + M: Mutex, +{ + sync_channel: UnsafeCell>, +} + +struct ChannelCell where M: Mutex, { @@ -468,7 +593,10 @@ impl Channel { pub const fn with_critical_sections() -> Self { let mutex = CriticalSectionMutex::new(()); let state = ChannelState::new(); - Channel { mutex, state } + let sync_channel = ChannelCell { mutex, state }; + Channel { + sync_channel: UnsafeCell::new(sync_channel), + } } } @@ -492,7 +620,10 @@ impl Channel { pub const fn with_thread_mode_only() -> Self { let mutex = ThreadModeMutex::new(()); let state = ChannelState::new(); - Channel { mutex, state } + let sync_channel = ChannelCell { mutex, state }; + Channel { + sync_channel: UnsafeCell::new(sync_channel), + } } } @@ -513,7 +644,10 @@ impl Channel { pub const fn with_no_threads() -> Self { let mutex = NoopMutex::new(()); let state = ChannelState::new(); - Channel { mutex, state } + let sync_channel = ChannelCell { mutex, state }; + Channel { + sync_channel: UnsafeCell::new(sync_channel), + } } } @@ -521,144 +655,13 @@ impl Channel where M: Mutex, { - fn get(&self) -> &mut Self { - let const_ptr = self as *const Self; - let mut_ptr = const_ptr as *mut Self; - unsafe { &mut *mut_ptr } - } - - fn try_recv(&mut self) -> Result { - self.try_recv_with_context(None) - } - - fn try_recv_with_context(&mut self, cx: Option<&mut Context<'_>>) -> Result { - let mut state = &mut self.state; - self.mutex.lock(|_| { - if state.read_pos != state.write_pos || state.full { - if state.full { - state.full = false; - state.senders_waker.wake(); - } - let message = unsafe { (state.buf[state.read_pos]).assume_init_mut().get().read() }; - state.read_pos = (state.read_pos + 1) % state.buf.len(); - Ok(message) - } else if !state.closed { - cx.into_iter() - .for_each(|cx| Self::set_receiver_waker(&mut state, &cx.waker())); - Err(TryRecvError::Empty) - } else { - Err(TryRecvError::Closed) - } - }) - } - - fn try_send(&mut self, message: T) -> Result<(), TrySendError> { - self.try_send_with_context(message, None) - } - - fn try_send_with_context( - &mut self, - message: T, - cx: Option<&mut Context<'_>>, - ) -> Result<(), TrySendError> { - let mut state = &mut self.state; - self.mutex.lock(|_| { - if !state.closed { - if !state.full { - state.buf[state.write_pos] = MaybeUninit::new(message.into()); - state.write_pos = (state.write_pos + 1) % state.buf.len(); - if state.write_pos == state.read_pos { - state.full = true; - } - state.receiver_waker.wake(); - Ok(()) - } else { - cx.into_iter() - .for_each(|cx| Self::set_senders_waker(&mut state, &cx.waker())); - Err(TrySendError::Full(message)) - } - } else { - Err(TrySendError::Closed(message)) - } - }) - } - - fn close(&mut self) { - let state = &mut self.state; - self.mutex.lock(|_| { - state.receiver_waker.wake(); - state.closed = true; - }); - } - - fn is_closed(&mut self) -> bool { - self.is_closed_with_context(None) - } - - fn is_closed_with_context(&mut self, cx: Option<&mut Context<'_>>) -> bool { - let mut state = &mut self.state; - self.mutex.lock(|_| { - if state.closed { - cx.into_iter() - .for_each(|cx| Self::set_senders_waker(&mut state, &cx.waker())); - true - } else { - false - } - }) - } - - fn register_receiver(&mut self) { - let state = &mut self.state; - self.mutex.lock(|_| { - assert!(!state.receiver_registered); - state.receiver_registered = true; - }); - } - - fn deregister_receiver(&mut self) { - let state = &mut self.state; - self.mutex.lock(|_| { - if state.receiver_registered { - state.closed = true; - state.senders_waker.wake(); - } - state.receiver_registered = false; - }) - } - - fn register_sender(&mut self) { - let state = &mut self.state; - self.mutex.lock(|_| { - state.senders_registered += 1; - }) - } - - fn deregister_sender(&mut self) { - let state = &mut self.state; - self.mutex.lock(|_| { - assert!(state.senders_registered > 0); - state.senders_registered -= 1; - if state.senders_registered == 0 { - state.receiver_waker.wake(); - state.closed = true; - } - }) - } - - fn set_receiver_waker(state: &mut ChannelState, receiver_waker: &Waker) { - state.receiver_waker.register(receiver_waker); - } - - fn set_senders_waker(state: &mut ChannelState, senders_waker: &Waker) { - // Dispose of any existing sender causing them to be polled again. - // This could cause a spin given multiple concurrent senders, however given that - // most sends only block waiting for the receiver to become active, this should - // be a short-lived activity. The upside is a greatly simplified implementation - // that avoids the need for intrusive linked-lists and unsafe operations on pinned - // pointers. - state.senders_waker.wake(); - state.senders_waker.register(senders_waker); + fn lock(&self, f: impl FnOnce(&mut ChannelState) -> R) -> R { + unsafe { + let sync_channel = &mut *(self.sync_channel.get()); + let mutex = &mut sync_channel.mutex; + let mut state = &mut sync_channel.state; + mutex.lock(|_| f(&mut state)) + } } } @@ -672,15 +675,12 @@ mod tests { use super::*; - fn capacity(c: &Channel) -> usize - where - M: Mutex, - { - if !c.state.full { - if c.state.write_pos > c.state.read_pos { - (c.state.buf.len() - c.state.write_pos) + c.state.read_pos + fn capacity(c: &ChannelState) -> usize { + if !c.full { + if c.write_pos > c.read_pos { + (c.buf.len() - c.write_pos) + c.read_pos } else { - (c.state.buf.len() - c.state.read_pos) + c.state.write_pos + (c.buf.len() - c.read_pos) + c.write_pos } } else { 0 @@ -689,14 +689,14 @@ mod tests { #[test] fn sending_once() { - let mut c = Channel::::with_no_threads(); + let mut c = ChannelState::::new(); assert!(c.try_send(1).is_ok()); assert_eq!(capacity(&c), 2); } #[test] fn sending_when_full() { - let mut c = Channel::::with_no_threads(); + let mut c = ChannelState::::new(); let _ = c.try_send(1); let _ = c.try_send(1); let _ = c.try_send(1); @@ -709,8 +709,8 @@ mod tests { #[test] fn sending_when_closed() { - let mut c = Channel::::with_no_threads(); - c.state.closed = true; + let mut c = ChannelState::::new(); + c.closed = true; match c.try_send(2) { Err(TrySendError::Closed(2)) => assert!(true), _ => assert!(false), @@ -719,7 +719,7 @@ mod tests { #[test] fn receiving_once_with_one_send() { - let mut c = Channel::::with_no_threads(); + let mut c = ChannelState::::new(); assert!(c.try_send(1).is_ok()); assert_eq!(c.try_recv().unwrap(), 1); assert_eq!(capacity(&c), 3); @@ -727,7 +727,7 @@ mod tests { #[test] fn receiving_when_empty() { - let mut c = Channel::::with_no_threads(); + let mut c = ChannelState::::new(); match c.try_recv() { Err(TryRecvError::Empty) => assert!(true), _ => assert!(false), @@ -737,8 +737,8 @@ mod tests { #[test] fn receiving_when_closed() { - let mut c = Channel::::with_no_threads(); - c.state.closed = true; + let mut c = ChannelState::::new(); + c.closed = true; match c.try_recv() { Err(TryRecvError::Closed) => assert!(true), _ => assert!(false),