From 36b363a5b71497d32d2968fcecf046773822c2e9 Mon Sep 17 00:00:00 2001
From: Dion Dokter <diondokter@gmail.com>
Date: Thu, 16 Jun 2022 13:48:26 +0200
Subject: [PATCH] Changed names of subscriber methods and implemented the
 Stream trait for it

---
 embassy/src/channel/pubsub.rs | 55 +++++++++++++++++++++++++++--------
 1 file changed, 43 insertions(+), 12 deletions(-)

diff --git a/embassy/src/channel/pubsub.rs b/embassy/src/channel/pubsub.rs
index fb8d0ef5f..5d81431ec 100644
--- a/embassy/src/channel/pubsub.rs
+++ b/embassy/src/channel/pubsub.rs
@@ -270,14 +270,14 @@ pub struct Subscriber<'a, T: Clone> {
 
 impl<'a, T: Clone> Subscriber<'a, T> {
     /// Wait for a published message
-    pub fn wait<'s>(&'s mut self) -> SubscriberWaitFuture<'s, 'a, T> {
+    pub fn next<'s>(&'s mut self) -> SubscriberWaitFuture<'s, 'a, T> {
         SubscriberWaitFuture { subscriber: self }
     }
 
     /// Try to see if there's a published message we haven't received yet.
     ///
     /// This function does not peek. The message is received if there is one.
-    pub fn check(&mut self) -> Option<WaitResult<T>> {
+    pub fn try_next(&mut self) -> Option<WaitResult<T>> {
         match self.channel.get_message(self.next_message_id) {
             Some(WaitResult::Lagged(amount)) => {
                 self.next_message_id += amount;
@@ -300,6 +300,37 @@ impl<'a, T: Clone> Drop for Subscriber<'a, T> {
     }
 }
 
+impl<'a, T: Clone> futures::Stream for Subscriber<'a, T> {
+    type Item = WaitResult<T>;
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        let this = unsafe { self.get_unchecked_mut() };
+
+        // Check if we can read a message
+        match this.channel.get_message(this.next_message_id) {
+            // Yes, so we are done polling
+            Some(WaitResult::Message(message)) => {
+                this.next_message_id += 1;
+                Poll::Ready(Some(WaitResult::Message(message)))
+            }
+            // No, so we need to reregister our waker and sleep again
+            None => {
+                unsafe {
+                    this
+                        .channel
+                        .register_subscriber_waker(this.subscriber_index, cx.waker());
+                }
+                Poll::Pending
+            }
+            // We missed a couple of messages. We must do our internal bookkeeping and return that we lagged
+            Some(WaitResult::Lagged(amount)) => {
+                this.next_message_id += amount;
+                Poll::Ready(Some(WaitResult::Lagged(amount)))
+            }
+        }
+    }
+}
+
 /// A publisher to a channel
 ///
 /// This instance carries a reference to the channel, but uses a trait object for it so that the channel's
@@ -494,11 +525,11 @@ mod tests {
 
         pub0.publish(42).await;
 
-        assert_eq!(sub0.wait().await, WaitResult::Message(42));
-        assert_eq!(sub1.wait().await, WaitResult::Message(42));
+        assert_eq!(sub0.next().await, WaitResult::Message(42));
+        assert_eq!(sub1.next().await, WaitResult::Message(42));
 
-        assert_eq!(sub0.check(), None);
-        assert_eq!(sub1.check(), None);
+        assert_eq!(sub0.try_next(), None);
+        assert_eq!(sub1.try_next(), None);
     }
 
     #[futures_test::test]
@@ -515,12 +546,12 @@ mod tests {
         pub0.publish_immediate(46);
         pub0.publish_immediate(47);
 
-        assert_eq!(sub0.check(), Some(WaitResult::Lagged(2)));
-        assert_eq!(sub0.wait().await, WaitResult::Message(44));
-        assert_eq!(sub0.wait().await, WaitResult::Message(45));
-        assert_eq!(sub0.wait().await, WaitResult::Message(46));
-        assert_eq!(sub0.wait().await, WaitResult::Message(47));
-        assert_eq!(sub0.check(), None);
+        assert_eq!(sub0.try_next(), Some(WaitResult::Lagged(2)));
+        assert_eq!(sub0.next().await, WaitResult::Message(44));
+        assert_eq!(sub0.next().await, WaitResult::Message(45));
+        assert_eq!(sub0.next().await, WaitResult::Message(46));
+        assert_eq!(sub0.next().await, WaitResult::Message(47));
+        assert_eq!(sub0.try_next(), None);
     }
 
     #[test]