From baab52d40cb8d1ede339a3a422006108a86d8efb Mon Sep 17 00:00:00 2001
From: huntc <huntchr@gmail.com>
Date: Sun, 11 Jul 2021 13:01:36 +1000
Subject: [PATCH] Avoid a race condition by reducing the locks to one

---
 embassy/src/util/mpsc.rs | 84 +++++++++++++++++++++++++---------------
 1 file changed, 53 insertions(+), 31 deletions(-)

diff --git a/embassy/src/util/mpsc.rs b/embassy/src/util/mpsc.rs
index 68fcdf7f9..8d534dc49 100644
--- a/embassy/src/util/mpsc.rs
+++ b/embassy/src/util/mpsc.rs
@@ -145,14 +145,11 @@ where
         futures::future::poll_fn(|cx| self.recv_poll(cx)).await
     }
 
-    fn recv_poll(self: &mut Self, cx: &mut Context<'_>) -> Poll<Option<T>> {
-        match self.try_recv() {
+    fn recv_poll(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
+        match self.channel.get().try_recv_with_context(Some(cx)) {
             Ok(v) => Poll::Ready(Some(v)),
             Err(TryRecvError::Closed) => Poll::Ready(None),
-            Err(TryRecvError::Empty) => {
-                self.channel.get().set_receiver_waker(&cx.waker());
-                Poll::Pending
-            }
+            Err(TryRecvError::Empty) => Poll::Pending,
         }
     }
 
@@ -279,11 +276,15 @@ where
     type Output = Result<(), SendError<T>>;
 
     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
-        match self.sender.try_send(unsafe { self.message.get().read() }) {
+        match self
+            .sender
+            .channel
+            .get()
+            .try_send_with_context(unsafe { self.message.get().read() }, Some(cx))
+        {
             Ok(..) => Poll::Ready(Ok(())),
             Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))),
             Err(TrySendError::Full(..)) => {
-                self.sender.channel.get().set_senders_waker(&cx.waker());
                 Poll::Pending
                 // Note we leave the existing UnsafeCell contents - they still
                 // contain the original message. We could create another UnsafeCell
@@ -307,10 +308,9 @@ where
     type Output = ();
 
     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
-        if self.sender.is_closed() {
+        if self.sender.channel.get().is_closed_with_context(Some(cx)) {
             Poll::Ready(())
         } else {
-            self.sender.channel.get().set_senders_waker(&cx.waker());
             Poll::Pending
         }
     }
@@ -513,7 +513,11 @@ where
     }
 
     fn try_recv(&mut self) -> Result<T, TryRecvError> {
-        let state = &mut self.state;
+        self.try_recv_with_context(None)
+    }
+
+    fn try_recv_with_context(&mut self, cx: Option<&mut Context<'_>>) -> Result<T, TryRecvError> {
+        let mut state = &mut self.state;
         self.mutex.lock(|_| {
             if !state.closed {
                 if state.read_pos != state.write_pos || state.full {
@@ -526,6 +530,8 @@ where
                     state.read_pos = (state.read_pos + 1) % state.buf.len();
                     Ok(message)
                 } else if !state.closing {
+                    cx.into_iter()
+                        .for_each(|cx| Self::set_receiver_waker(&mut state, &cx.waker()));
                     Err(TryRecvError::Empty)
                 } else {
                     state.closed = true;
@@ -539,7 +545,15 @@ where
     }
 
     fn try_send(&mut self, message: T) -> Result<(), TrySendError<T>> {
-        let state = &mut self.state;
+        self.try_send_with_context(message, None)
+    }
+
+    fn try_send_with_context(
+        &mut self,
+        message: T,
+        cx: Option<&mut Context<'_>>,
+    ) -> Result<(), TrySendError<T>> {
+        let mut state = &mut self.state;
         self.mutex.lock(|_| {
             if !state.closed {
                 if !state.full {
@@ -551,6 +565,8 @@ where
                     state.receiver_waker.wake();
                     Ok(())
                 } else {
+                    cx.into_iter()
+                        .for_each(|cx| Self::set_senders_waker(&mut state, &cx.waker()));
                     Err(TrySendError::Full(message))
                 }
             } else {
@@ -568,8 +584,20 @@ where
     }
 
     fn is_closed(&mut self) -> bool {
-        let state = &self.state;
-        self.mutex.lock(|_| state.closing || state.closed)
+        self.is_closed_with_context(None)
+    }
+
+    fn is_closed_with_context(&mut self, cx: Option<&mut Context<'_>>) -> bool {
+        let mut state = &mut self.state;
+        self.mutex.lock(|_| {
+            if state.closing || state.closed {
+                cx.into_iter()
+                    .for_each(|cx| Self::set_senders_waker(&mut state, &cx.waker()));
+                true
+            } else {
+                false
+            }
+        })
     }
 
     fn register_receiver(&mut self) {
@@ -610,25 +638,19 @@ where
         })
     }
 
-    fn set_receiver_waker(&mut self, receiver_waker: &Waker) {
-        let state = &mut self.state;
-        self.mutex.lock(|_| {
-            state.receiver_waker.register(receiver_waker);
-        })
+    fn set_receiver_waker(state: &mut ChannelState<T, N>, receiver_waker: &Waker) {
+        state.receiver_waker.register(receiver_waker);
     }
 
-    fn set_senders_waker(&mut self, senders_waker: &Waker) {
-        let state = &mut self.state;
-        self.mutex.lock(|_| {
-            // Dispose of any existing sender causing them to be polled again.
-            // This could cause a spin given multiple concurrent senders, however given that
-            // most sends only block waiting for the receiver to become active, this should
-            // be a short-lived activity. The upside is a greatly simplified implementation
-            // that avoids the need for intrusive linked-lists and unsafe operations on pinned
-            // pointers.
-            state.senders_waker.wake();
-            state.senders_waker.register(senders_waker);
-        })
+    fn set_senders_waker(state: &mut ChannelState<T, N>, senders_waker: &Waker) {
+        // Dispose of any existing sender causing them to be polled again.
+        // This could cause a spin given multiple concurrent senders, however given that
+        // most sends only block waiting for the receiver to become active, this should
+        // be a short-lived activity. The upside is a greatly simplified implementation
+        // that avoids the need for intrusive linked-lists and unsafe operations on pinned
+        // pointers.
+        state.senders_waker.wake();
+        state.senders_waker.register(senders_waker);
     }
 }