From 86706bdc1438cdccf017d073f9b8f8ff2a0322fe Mon Sep 17 00:00:00 2001
From: Caio <c410.f3r@gmail.com>
Date: Sun, 14 Apr 2024 19:35:59 -0300
Subject: [PATCH] Add map method

---
 embassy-sync/src/mutex.rs | 133 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 133 insertions(+)

diff --git a/embassy-sync/src/mutex.rs b/embassy-sync/src/mutex.rs
index 72459d660..b48a408c4 100644
--- a/embassy-sync/src/mutex.rs
+++ b/embassy-sync/src/mutex.rs
@@ -3,6 +3,7 @@
 //! This module provides a mutex that can be used to synchronize data between asynchronous tasks.
 use core::cell::{RefCell, UnsafeCell};
 use core::future::poll_fn;
+use core::mem;
 use core::ops::{Deref, DerefMut};
 use core::task::Poll;
 
@@ -134,6 +135,7 @@ where
 /// successfully locked the mutex, and grants access to the contents.
 ///
 /// Dropping it unlocks the mutex.
+#[clippy::has_significant_drop]
 pub struct MutexGuard<'a, M, T>
 where
     M: RawMutex,
@@ -142,6 +144,25 @@ where
     mutex: &'a Mutex<M, T>,
 }
 
+impl<'a, M, T> MutexGuard<'a, M, T>
+where
+    M: RawMutex,
+    T: ?Sized,
+{
+    /// Returns a locked view over a portion of the locked data.
+    pub fn map<U>(this: Self, fun: impl FnOnce(&mut T) -> &mut U) -> MappedMutexGuard<'a, M, U> {
+        let mutex = this.mutex;
+        let value = fun(unsafe { &mut *this.mutex.inner.get() });
+        // Don't run the `drop` method for MutexGuard. The ownership of the underlying
+        // locked state is being moved to the returned MappedMutexGuard.
+        mem::forget(this);
+        MappedMutexGuard {
+            state: &mutex.state,
+            value,
+        }
+    }
+}
+
 impl<'a, M, T> Drop for MutexGuard<'a, M, T>
 where
     M: RawMutex,
@@ -180,3 +201,115 @@ where
         unsafe { &mut *(self.mutex.inner.get()) }
     }
 }
+
+/// A handle to a held `Mutex` that has had a function applied to it via [`MutexGuard::map`] or
+/// [`MappedMutexGuard::map`].
+///
+/// This can be used to hold a subfield of the protected data.
+#[clippy::has_significant_drop]
+pub struct MappedMutexGuard<'a, M, T>
+where
+    M: RawMutex,
+    T: ?Sized,
+{
+    state: &'a BlockingMutex<M, RefCell<State>>,
+    value: *mut T,
+}
+
+impl<'a, M, T> MappedMutexGuard<'a, M, T>
+where
+    M: RawMutex,
+    T: ?Sized,
+{
+    /// Returns a locked view over a portion of the locked data.
+    pub fn map<U>(this: Self, fun: impl FnOnce(&mut T) -> &mut U) -> MappedMutexGuard<'a, M, U> {
+        let state = this.state;
+        let value = fun(unsafe { &mut *this.value });
+        // Don't run the `drop` method for MutexGuard. The ownership of the underlying
+        // locked state is being moved to the returned MappedMutexGuard.
+        mem::forget(this);
+        MappedMutexGuard { state, value }
+    }
+}
+
+impl<'a, M, T> Deref for MappedMutexGuard<'a, M, T>
+where
+    M: RawMutex,
+    T: ?Sized,
+{
+    type Target = T;
+    fn deref(&self) -> &Self::Target {
+        // Safety: the MutexGuard represents exclusive access to the contents
+        // of the mutex, so it's OK to get it.
+        unsafe { &*self.value }
+    }
+}
+
+impl<'a, M, T> DerefMut for MappedMutexGuard<'a, M, T>
+where
+    M: RawMutex,
+    T: ?Sized,
+{
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        // Safety: the MutexGuard represents exclusive access to the contents
+        // of the mutex, so it's OK to get it.
+        unsafe { &mut *self.value }
+    }
+}
+
+impl<'a, M, T> Drop for MappedMutexGuard<'a, M, T>
+where
+    M: RawMutex,
+    T: ?Sized,
+{
+    fn drop(&mut self) {
+        self.state.lock(|s| {
+            let mut s = unwrap!(s.try_borrow_mut());
+            s.locked = false;
+            s.waker.wake();
+        })
+    }
+}
+
+unsafe impl<M, T> Send for MappedMutexGuard<'_, M, T>
+where
+    M: RawMutex + Sync,
+    T: Send + ?Sized,
+{
+}
+
+unsafe impl<M, T> Sync for MappedMutexGuard<'_, M, T>
+where
+    M: RawMutex + Sync,
+    T: Sync + ?Sized,
+{
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::blocking_mutex::raw::NoopRawMutex;
+    use crate::mutex::{Mutex, MutexGuard};
+
+    #[futures_test::test]
+    async fn mapped_guard_releases_lock_when_dropped() {
+        let mutex: Mutex<NoopRawMutex, [i32; 2]> = Mutex::new([0, 1]);
+
+        {
+            let guard = mutex.lock().await;
+            assert_eq!(*guard, [0, 1]);
+            let mut mapped = MutexGuard::map(guard, |this| &mut this[1]);
+            assert_eq!(*mapped, 1);
+            *mapped = 2;
+        }
+
+        {
+            let guard = mutex.lock().await;
+            assert_eq!(*guard, [0, 2]);
+            let mut mapped = MutexGuard::map(guard, |this| &mut this[1]);
+            assert_eq!(*mapped, 2);
+            *mapped = 3;
+        }
+
+        assert_eq!(*mutex.lock().await, [0, 3]);
+    }
+}