From 1c9bb7c2e1ff7d68bfa709745f935040c8b415e1 Mon Sep 17 00:00:00 2001
From: Dario Nieuwenhuis <dirbaio@dirbaio.net>
Date: Mon, 13 May 2024 00:35:46 +0200
Subject: [PATCH] time/generic-queue: fix ub in tests.

---
 embassy-time/src/queue_generic.rs | 120 +++++++++++++-----------------
 1 file changed, 53 insertions(+), 67 deletions(-)

diff --git a/embassy-time/src/queue_generic.rs b/embassy-time/src/queue_generic.rs
index cf7a986d5..4882afd3e 100644
--- a/embassy-time/src/queue_generic.rs
+++ b/embassy-time/src/queue_generic.rs
@@ -177,9 +177,10 @@ embassy_time_queue_driver::timer_queue_impl!(static QUEUE: Queue = Queue::new())
 #[cfg(test)]
 #[cfg(feature = "mock-driver")]
 mod tests {
-    use core::cell::Cell;
-    use core::task::{RawWaker, RawWakerVTable, Waker};
-    use std::rc::Rc;
+    use core::sync::atomic::{AtomicBool, Ordering};
+    use core::task::Waker;
+    use std::sync::Arc;
+    use std::task::Wake;
 
     use serial_test::serial;
 
@@ -188,42 +189,26 @@ mod tests {
     use crate::{Duration, Instant};
 
     struct TestWaker {
-        pub awoken: Rc<Cell<bool>>,
-        pub waker: Waker,
+        pub awoken: AtomicBool,
     }
 
-    impl TestWaker {
-        fn new() -> Self {
-            let flag = Rc::new(Cell::new(false));
-
-            const VTABLE: RawWakerVTable = RawWakerVTable::new(
-                |data: *const ()| {
-                    unsafe {
-                        Rc::increment_strong_count(data as *const Cell<bool>);
-                    }
-
-                    RawWaker::new(data as _, &VTABLE)
-                },
-                |data: *const ()| unsafe {
-                    let data = data as *const Cell<bool>;
-                    data.as_ref().unwrap().set(true);
-                    Rc::decrement_strong_count(data);
-                },
-                |data: *const ()| unsafe {
-                    (data as *const Cell<bool>).as_ref().unwrap().set(true);
-                },
-                |data: *const ()| unsafe {
-                    Rc::decrement_strong_count(data);
-                },
-            );
-
-            let raw = RawWaker::new(Rc::into_raw(flag.clone()) as _, &VTABLE);
-
-            Self {
-                awoken: flag.clone(),
-                waker: unsafe { Waker::from_raw(raw) },
-            }
+    impl Wake for TestWaker {
+        fn wake(self: Arc<Self>) {
+            self.awoken.store(true, Ordering::Relaxed);
         }
+
+        fn wake_by_ref(self: &Arc<Self>) {
+            self.awoken.store(true, Ordering::Relaxed);
+        }
+    }
+
+    fn test_waker() -> (Arc<TestWaker>, Waker) {
+        let arc = Arc::new(TestWaker {
+            awoken: AtomicBool::new(false),
+        });
+        let waker = Waker::from(arc.clone());
+
+        (arc, waker)
     }
 
     fn setup() {
@@ -249,11 +234,11 @@ mod tests {
 
         assert_eq!(queue_len(), 0);
 
-        let waker = TestWaker::new();
+        let (flag, waker) = test_waker();
 
-        QUEUE.schedule_wake(Instant::from_secs(1), &waker.waker);
+        QUEUE.schedule_wake(Instant::from_secs(1), &waker);
 
-        assert!(!waker.awoken.get());
+        assert!(!flag.awoken.load(Ordering::Relaxed));
         assert_eq!(queue_len(), 1);
     }
 
@@ -262,23 +247,23 @@ mod tests {
     fn test_schedule_same() {
         setup();
 
-        let waker = TestWaker::new();
+        let (_flag, waker) = test_waker();
 
-        QUEUE.schedule_wake(Instant::from_secs(1), &waker.waker);
+        QUEUE.schedule_wake(Instant::from_secs(1), &waker);
 
         assert_eq!(queue_len(), 1);
 
-        QUEUE.schedule_wake(Instant::from_secs(1), &waker.waker);
+        QUEUE.schedule_wake(Instant::from_secs(1), &waker);
 
         assert_eq!(queue_len(), 1);
 
-        QUEUE.schedule_wake(Instant::from_secs(100), &waker.waker);
+        QUEUE.schedule_wake(Instant::from_secs(100), &waker);
 
         assert_eq!(queue_len(), 1);
 
-        let waker2 = TestWaker::new();
+        let (_flag2, waker2) = test_waker();
 
-        QUEUE.schedule_wake(Instant::from_secs(100), &waker2.waker);
+        QUEUE.schedule_wake(Instant::from_secs(100), &waker2);
 
         assert_eq!(queue_len(), 2);
     }
@@ -288,21 +273,21 @@ mod tests {
     fn test_trigger() {
         setup();
 
-        let waker = TestWaker::new();
+        let (flag, waker) = test_waker();
 
-        QUEUE.schedule_wake(Instant::from_secs(100), &waker.waker);
+        QUEUE.schedule_wake(Instant::from_secs(100), &waker);
 
-        assert!(!waker.awoken.get());
+        assert!(!flag.awoken.load(Ordering::Relaxed));
 
         MockDriver::get().advance(Duration::from_secs(99));
 
-        assert!(!waker.awoken.get());
+        assert!(!flag.awoken.load(Ordering::Relaxed));
 
         assert_eq!(queue_len(), 1);
 
         MockDriver::get().advance(Duration::from_secs(1));
 
-        assert!(waker.awoken.get());
+        assert!(flag.awoken.load(Ordering::Relaxed));
 
         assert_eq!(queue_len(), 0);
     }
@@ -312,18 +297,18 @@ mod tests {
     fn test_immediate_trigger() {
         setup();
 
-        let waker = TestWaker::new();
+        let (flag, waker) = test_waker();
 
-        QUEUE.schedule_wake(Instant::from_secs(100), &waker.waker);
+        QUEUE.schedule_wake(Instant::from_secs(100), &waker);
 
         MockDriver::get().advance(Duration::from_secs(50));
 
-        let waker2 = TestWaker::new();
+        let (flag2, waker2) = test_waker();
 
-        QUEUE.schedule_wake(Instant::from_secs(40), &waker2.waker);
+        QUEUE.schedule_wake(Instant::from_secs(40), &waker2);
 
-        assert!(!waker.awoken.get());
-        assert!(waker2.awoken.get());
+        assert!(!flag.awoken.load(Ordering::Relaxed));
+        assert!(flag2.awoken.load(Ordering::Relaxed));
         assert_eq!(queue_len(), 1);
     }
 
@@ -333,30 +318,31 @@ mod tests {
         setup();
 
         for i in 1..super::QUEUE_SIZE {
-            let waker = TestWaker::new();
+            let (flag, waker) = test_waker();
 
-            QUEUE.schedule_wake(Instant::from_secs(310), &waker.waker);
+            QUEUE.schedule_wake(Instant::from_secs(310), &waker);
 
             assert_eq!(queue_len(), i);
-            assert!(!waker.awoken.get());
+            assert!(!flag.awoken.load(Ordering::Relaxed));
         }
 
-        let first_waker = TestWaker::new();
+        let (flag, waker) = test_waker();
 
-        QUEUE.schedule_wake(Instant::from_secs(300), &first_waker.waker);
+        QUEUE.schedule_wake(Instant::from_secs(300), &waker);
 
         assert_eq!(queue_len(), super::QUEUE_SIZE);
-        assert!(!first_waker.awoken.get());
+        assert!(!flag.awoken.load(Ordering::Relaxed));
 
-        let second_waker = TestWaker::new();
+        let (flag2, waker2) = test_waker();
 
-        QUEUE.schedule_wake(Instant::from_secs(305), &second_waker.waker);
+        QUEUE.schedule_wake(Instant::from_secs(305), &waker2);
 
         assert_eq!(queue_len(), super::QUEUE_SIZE);
-        assert!(first_waker.awoken.get());
+        assert!(flag.awoken.load(Ordering::Relaxed));
 
-        QUEUE.schedule_wake(Instant::from_secs(320), &TestWaker::new().waker);
+        let (_flag3, waker3) = test_waker();
+        QUEUE.schedule_wake(Instant::from_secs(320), &waker3);
         assert_eq!(queue_len(), super::QUEUE_SIZE);
-        assert!(second_waker.awoken.get());
+        assert!(flag2.awoken.load(Ordering::Relaxed));
     }
 }