From 62e799da09e144c9bd2bc3935011913e62c86d16 Mon Sep 17 00:00:00 2001
From: Rasmus Melchior Jacobsen <rmja@laesoe.org>
Date: Fri, 26 May 2023 21:40:12 +0200
Subject: [PATCH] Create flash partition for shared flash access

---
 embassy-embedded-hal/Cargo.toml               |   1 +
 .../src/adapter/yielding_async.rs             |  62 +-------
 .../src/{flash.rs => flash/concat_flash.rs}   |  86 ++--------
 embassy-embedded-hal/src/flash/mem_flash.rs   | 127 +++++++++++++++
 embassy-embedded-hal/src/flash/mod.rs         |   9 ++
 embassy-embedded-hal/src/flash/partition.rs   | 150 ++++++++++++++++++
 6 files changed, 308 insertions(+), 127 deletions(-)
 rename embassy-embedded-hal/src/{flash.rs => flash/concat_flash.rs} (72%)
 create mode 100644 embassy-embedded-hal/src/flash/mem_flash.rs
 create mode 100644 embassy-embedded-hal/src/flash/mod.rs
 create mode 100644 embassy-embedded-hal/src/flash/partition.rs

diff --git a/embassy-embedded-hal/Cargo.toml b/embassy-embedded-hal/Cargo.toml
index ad2f14568..35c70bb63 100644
--- a/embassy-embedded-hal/Cargo.toml
+++ b/embassy-embedded-hal/Cargo.toml
@@ -31,4 +31,5 @@ nb = "1.0.0"
 defmt = { version = "0.3", optional = true }
 
 [dev-dependencies]
+critical-section = { version = "1.1.1", features = ["std"] }
 futures-test = "0.3.17"
diff --git a/embassy-embedded-hal/src/adapter/yielding_async.rs b/embassy-embedded-hal/src/adapter/yielding_async.rs
index 96d5cca8e..f51e4076f 100644
--- a/embassy-embedded-hal/src/adapter/yielding_async.rs
+++ b/embassy-embedded-hal/src/adapter/yielding_async.rs
@@ -167,66 +167,18 @@ mod tests {
     use embedded_storage_async::nor_flash::NorFlash;
 
     use super::*;
-
-    extern crate std;
-
-    #[derive(Default)]
-    struct FakeFlash(Vec<(u32, u32)>);
-
-    impl embedded_storage::nor_flash::ErrorType for FakeFlash {
-        type Error = std::convert::Infallible;
-    }
-
-    impl embedded_storage_async::nor_flash::ReadNorFlash for FakeFlash {
-        const READ_SIZE: usize = 1;
-
-        async fn read(&mut self, _offset: u32, _bytes: &mut [u8]) -> Result<(), Self::Error> {
-            unimplemented!()
-        }
-
-        fn capacity(&self) -> usize {
-            unimplemented!()
-        }
-    }
-
-    impl embedded_storage_async::nor_flash::NorFlash for FakeFlash {
-        const WRITE_SIZE: usize = 4;
-        const ERASE_SIZE: usize = 128;
-
-        async fn write(&mut self, _offset: u32, _bytes: &[u8]) -> Result<(), Self::Error> {
-            unimplemented!()
-        }
-
-        async fn erase(&mut self, from: u32, to: u32) -> Result<(), Self::Error> {
-            self.0.push((from, to));
-            Ok(())
-        }
-    }
+    use crate::flash::mem_flash::MemFlash;
 
     #[futures_test::test]
     async fn can_erase() {
-        let fake = FakeFlash::default();
-        let mut yielding = YieldingAsync::new(fake);
+        let flash = MemFlash::<1024, 128, 4>::new(0x00);
+        let mut yielding = YieldingAsync::new(flash);
 
         yielding.erase(0, 256).await.unwrap();
 
-        let fake = yielding.wrapped;
-        assert_eq!(2, fake.0.len());
-        assert_eq!((0, 128), fake.0[0]);
-        assert_eq!((128, 256), fake.0[1]);
-    }
-
-    #[futures_test::test]
-    async fn can_erase_wrong_erase_size() {
-        let fake = FakeFlash::default();
-        let mut yielding = YieldingAsync::new(fake);
-
-        yielding.erase(0, 257).await.unwrap();
-
-        let fake = yielding.wrapped;
-        assert_eq!(3, fake.0.len());
-        assert_eq!((0, 128), fake.0[0]);
-        assert_eq!((128, 256), fake.0[1]);
-        assert_eq!((256, 257), fake.0[2]);
+        let flash = yielding.wrapped;
+        assert_eq!(2, flash.erases.len());
+        assert_eq!((0, 128), flash.erases[0]);
+        assert_eq!((128, 256), flash.erases[1]);
     }
 }
diff --git a/embassy-embedded-hal/src/flash.rs b/embassy-embedded-hal/src/flash/concat_flash.rs
similarity index 72%
rename from embassy-embedded-hal/src/flash.rs
rename to embassy-embedded-hal/src/flash/concat_flash.rs
index 9a6e4bd92..1ea84269c 100644
--- a/embassy-embedded-hal/src/flash.rs
+++ b/embassy-embedded-hal/src/flash/concat_flash.rs
@@ -1,5 +1,3 @@
-//! Utilities related to flash.
-
 use embedded_storage::nor_flash::{ErrorType, NorFlash, NorFlashError, ReadNorFlash};
 #[cfg(feature = "nightly")]
 use embedded_storage_async::nor_flash::{NorFlash as AsyncNorFlash, ReadNorFlash as AsyncReadNorFlash};
@@ -192,18 +190,21 @@ where
 
 #[cfg(test)]
 mod tests {
-    use super::*;
+    use embedded_storage::nor_flash::{NorFlash, ReadNorFlash};
+
+    use super::ConcatFlash;
+    use crate::flash::mem_flash::MemFlash;
 
     #[test]
     fn can_write_and_read_across_flashes() {
-        let first = MemFlash::<64, 16, 4>::new();
-        let second = MemFlash::<64, 64, 4>::new();
+        let first = MemFlash::<64, 16, 4>::default();
+        let second = MemFlash::<64, 64, 4>::default();
         let mut f = ConcatFlash::new(first, second);
 
         f.write(60, &[0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88]).unwrap();
 
-        assert_eq!(&[0x11, 0x22, 0x33, 0x44], &f.0 .0[60..]);
-        assert_eq!(&[0x55, 0x66, 0x77, 0x88], &f.1 .0[0..4]);
+        assert_eq!(&[0x11, 0x22, 0x33, 0x44], &f.0.mem[60..]);
+        assert_eq!(&[0x55, 0x66, 0x77, 0x88], &f.1.mem[0..4]);
 
         let mut read_buf = [0; 8];
         f.read(60, &mut read_buf).unwrap();
@@ -213,74 +214,15 @@ mod tests {
 
     #[test]
     fn can_erase_across_flashes() {
-        let mut first = MemFlash::<128, 16, 4>::new();
-        let mut second = MemFlash::<128, 64, 4>::new();
-        first.0.fill(0x00);
-        second.0.fill(0x00);
-
+        let first = MemFlash::<128, 16, 4>::new(0x00);
+        let second = MemFlash::<128, 64, 4>::new(0x00);
         let mut f = ConcatFlash::new(first, second);
 
         f.erase(64, 192).unwrap();
 
-        assert_eq!(&[0x00; 64], &f.0 .0[0..64]);
-        assert_eq!(&[0xff; 64], &f.0 .0[64..128]);
-        assert_eq!(&[0xff; 64], &f.1 .0[0..64]);
-        assert_eq!(&[0x00; 64], &f.1 .0[64..128]);
-    }
-
-    pub struct MemFlash<const SIZE: usize, const ERASE_SIZE: usize, const WRITE_SIZE: usize>([u8; SIZE]);
-
-    impl<const SIZE: usize, const ERASE_SIZE: usize, const WRITE_SIZE: usize> MemFlash<SIZE, ERASE_SIZE, WRITE_SIZE> {
-        pub const fn new() -> Self {
-            Self([0xff; SIZE])
-        }
-    }
-
-    impl<const SIZE: usize, const ERASE_SIZE: usize, const WRITE_SIZE: usize> ErrorType
-        for MemFlash<SIZE, ERASE_SIZE, WRITE_SIZE>
-    {
-        type Error = core::convert::Infallible;
-    }
-
-    impl<const SIZE: usize, const ERASE_SIZE: usize, const WRITE_SIZE: usize> ReadNorFlash
-        for MemFlash<SIZE, ERASE_SIZE, WRITE_SIZE>
-    {
-        const READ_SIZE: usize = 1;
-
-        fn read(&mut self, offset: u32, bytes: &mut [u8]) -> Result<(), Self::Error> {
-            let len = bytes.len();
-            bytes.copy_from_slice(&self.0[offset as usize..offset as usize + len]);
-            Ok(())
-        }
-
-        fn capacity(&self) -> usize {
-            SIZE
-        }
-    }
-
-    impl<const SIZE: usize, const ERASE_SIZE: usize, const WRITE_SIZE: usize> NorFlash
-        for MemFlash<SIZE, ERASE_SIZE, WRITE_SIZE>
-    {
-        const WRITE_SIZE: usize = WRITE_SIZE;
-        const ERASE_SIZE: usize = ERASE_SIZE;
-
-        fn erase(&mut self, from: u32, to: u32) -> Result<(), Self::Error> {
-            let from = from as usize;
-            let to = to as usize;
-            assert_eq!(0, from % ERASE_SIZE);
-            assert_eq!(0, to % ERASE_SIZE);
-            self.0[from..to].fill(0xff);
-            Ok(())
-        }
-
-        fn write(&mut self, offset: u32, bytes: &[u8]) -> Result<(), Self::Error> {
-            let offset = offset as usize;
-            assert_eq!(0, bytes.len() % WRITE_SIZE);
-            assert_eq!(0, offset % WRITE_SIZE);
-            assert!(offset + bytes.len() <= SIZE);
-
-            self.0[offset..offset + bytes.len()].copy_from_slice(bytes);
-            Ok(())
-        }
+        assert_eq!(&[0x00; 64], &f.0.mem[0..64]);
+        assert_eq!(&[0xff; 64], &f.0.mem[64..128]);
+        assert_eq!(&[0xff; 64], &f.1.mem[0..64]);
+        assert_eq!(&[0x00; 64], &f.1.mem[64..128]);
     }
 }
diff --git a/embassy-embedded-hal/src/flash/mem_flash.rs b/embassy-embedded-hal/src/flash/mem_flash.rs
new file mode 100644
index 000000000..4e10627df
--- /dev/null
+++ b/embassy-embedded-hal/src/flash/mem_flash.rs
@@ -0,0 +1,127 @@
+use alloc::vec::Vec;
+
+use embedded_storage::nor_flash::{ErrorType, NorFlash, ReadNorFlash};
+use embedded_storage_async::nor_flash::{NorFlash as AsyncNorFlash, ReadNorFlash as AsyncReadNorFlash};
+
+extern crate alloc;
+
+pub(crate) struct MemFlash<const SIZE: usize, const ERASE_SIZE: usize, const WRITE_SIZE: usize> {
+    pub mem: [u8; SIZE],
+    pub writes: Vec<(u32, usize)>,
+    pub erases: Vec<(u32, u32)>,
+}
+
+impl<const SIZE: usize, const ERASE_SIZE: usize, const WRITE_SIZE: usize> MemFlash<SIZE, ERASE_SIZE, WRITE_SIZE> {
+    #[allow(unused)]
+    pub const fn new(fill: u8) -> Self {
+        Self {
+            mem: [fill; SIZE],
+            writes: Vec::new(),
+            erases: Vec::new(),
+        }
+    }
+
+    fn read(&mut self, offset: u32, bytes: &mut [u8]) {
+        let len = bytes.len();
+        bytes.copy_from_slice(&self.mem[offset as usize..offset as usize + len]);
+    }
+
+    fn write(&mut self, offset: u32, bytes: &[u8]) {
+        self.writes.push((offset, bytes.len()));
+        let offset = offset as usize;
+        assert_eq!(0, bytes.len() % WRITE_SIZE);
+        assert_eq!(0, offset % WRITE_SIZE);
+        assert!(offset + bytes.len() <= SIZE);
+
+        self.mem[offset..offset + bytes.len()].copy_from_slice(bytes);
+    }
+
+    fn erase(&mut self, from: u32, to: u32) {
+        self.erases.push((from, to));
+        let from = from as usize;
+        let to = to as usize;
+        assert_eq!(0, from % ERASE_SIZE);
+        assert_eq!(0, to % ERASE_SIZE);
+        self.mem[from..to].fill(0xff);
+    }
+}
+
+impl<const SIZE: usize, const ERASE_SIZE: usize, const WRITE_SIZE: usize> Default
+    for MemFlash<SIZE, ERASE_SIZE, WRITE_SIZE>
+{
+    fn default() -> Self {
+        Self::new(0xff)
+    }
+}
+
+impl<const SIZE: usize, const ERASE_SIZE: usize, const WRITE_SIZE: usize> ErrorType
+    for MemFlash<SIZE, ERASE_SIZE, WRITE_SIZE>
+{
+    type Error = core::convert::Infallible;
+}
+
+impl<const SIZE: usize, const ERASE_SIZE: usize, const WRITE_SIZE: usize> ReadNorFlash
+    for MemFlash<SIZE, ERASE_SIZE, WRITE_SIZE>
+{
+    const READ_SIZE: usize = 1;
+
+    fn read(&mut self, offset: u32, bytes: &mut [u8]) -> Result<(), Self::Error> {
+        self.read(offset, bytes);
+        Ok(())
+    }
+
+    fn capacity(&self) -> usize {
+        SIZE
+    }
+}
+
+impl<const SIZE: usize, const ERASE_SIZE: usize, const WRITE_SIZE: usize> NorFlash
+    for MemFlash<SIZE, ERASE_SIZE, WRITE_SIZE>
+{
+    const WRITE_SIZE: usize = WRITE_SIZE;
+    const ERASE_SIZE: usize = ERASE_SIZE;
+
+    fn write(&mut self, offset: u32, bytes: &[u8]) -> Result<(), Self::Error> {
+        self.write(offset, bytes);
+        Ok(())
+    }
+
+    fn erase(&mut self, from: u32, to: u32) -> Result<(), Self::Error> {
+        self.erase(from, to);
+        Ok(())
+    }
+}
+
+#[cfg(feature = "nightly")]
+impl<const SIZE: usize, const ERASE_SIZE: usize, const WRITE_SIZE: usize> AsyncReadNorFlash
+    for MemFlash<SIZE, ERASE_SIZE, WRITE_SIZE>
+{
+    const READ_SIZE: usize = 1;
+
+    async fn read(&mut self, offset: u32, bytes: &mut [u8]) -> Result<(), Self::Error> {
+        self.read(offset, bytes);
+        Ok(())
+    }
+
+    fn capacity(&self) -> usize {
+        SIZE
+    }
+}
+
+#[cfg(feature = "nightly")]
+impl<const SIZE: usize, const ERASE_SIZE: usize, const WRITE_SIZE: usize> AsyncNorFlash
+    for MemFlash<SIZE, ERASE_SIZE, WRITE_SIZE>
+{
+    const WRITE_SIZE: usize = WRITE_SIZE;
+    const ERASE_SIZE: usize = ERASE_SIZE;
+
+    async fn write(&mut self, offset: u32, bytes: &[u8]) -> Result<(), Self::Error> {
+        self.write(offset, bytes);
+        Ok(())
+    }
+
+    async fn erase(&mut self, from: u32, to: u32) -> Result<(), Self::Error> {
+        self.erase(from, to);
+        Ok(())
+    }
+}
diff --git a/embassy-embedded-hal/src/flash/mod.rs b/embassy-embedded-hal/src/flash/mod.rs
new file mode 100644
index 000000000..c80dd6aac
--- /dev/null
+++ b/embassy-embedded-hal/src/flash/mod.rs
@@ -0,0 +1,9 @@
+//! Utilities related to flash.
+
+mod concat_flash;
+#[cfg(test)]
+pub(crate) mod mem_flash;
+mod partition;
+
+pub use concat_flash::ConcatFlash;
+pub use partition::Partition;
diff --git a/embassy-embedded-hal/src/flash/partition.rs b/embassy-embedded-hal/src/flash/partition.rs
new file mode 100644
index 000000000..084425e95
--- /dev/null
+++ b/embassy-embedded-hal/src/flash/partition.rs
@@ -0,0 +1,150 @@
+use embassy_sync::blocking_mutex::raw::RawMutex;
+use embassy_sync::mutex::Mutex;
+use embedded_storage::nor_flash::{ErrorType, NorFlashError, NorFlashErrorKind};
+#[cfg(feature = "nightly")]
+use embedded_storage_async::nor_flash::{NorFlash, ReadNorFlash};
+
+/// A logical partition of an underlying shared flash
+///
+/// A partition holds an offset and a size of the flash,
+/// and is restricted to operate with that range.
+/// There is no guarantee that muliple partitions on the same flash
+/// operate on mutually exclusive ranges - such a separation is up to
+/// the user to guarantee.
+pub struct Partition<'a, M: RawMutex, T> {
+    flash: &'a Mutex<M, T>,
+    offset: u32,
+    size: u32,
+}
+
+#[derive(Debug)]
+#[cfg_attr(feature = "defmt", derive(defmt::Format))]
+pub enum Error<T> {
+    Partition,
+    OutOfBounds,
+    Flash(T),
+}
+
+impl<'a, M: RawMutex, T> Partition<'a, M, T> {
+    /// Create a new partition
+    pub const fn new(flash: &'a Mutex<M, T>, offset: u32, size: u32) -> Self {
+        Self { flash, offset, size }
+    }
+}
+
+impl<T: NorFlashError> NorFlashError for Error<T> {
+    fn kind(&self) -> NorFlashErrorKind {
+        match self {
+            Error::Partition => NorFlashErrorKind::Other,
+            Error::OutOfBounds => NorFlashErrorKind::OutOfBounds,
+            Error::Flash(f) => f.kind(),
+        }
+    }
+}
+
+impl<M: RawMutex, T: ErrorType> ErrorType for Partition<'_, M, T> {
+    type Error = Error<T::Error>;
+}
+
+#[cfg(feature = "nightly")]
+impl<M: RawMutex, T: ReadNorFlash> ReadNorFlash for Partition<'_, M, T> {
+    const READ_SIZE: usize = T::READ_SIZE;
+
+    async fn read(&mut self, offset: u32, bytes: &mut [u8]) -> Result<(), Self::Error> {
+        if self.offset % T::READ_SIZE as u32 != 0 || self.size % T::READ_SIZE as u32 != 0 {
+            return Err(Error::Partition);
+        }
+        if offset + bytes.len() as u32 > self.size {
+            return Err(Error::OutOfBounds);
+        }
+
+        let mut flash = self.flash.lock().await;
+        flash.read(self.offset + offset, bytes).await.map_err(Error::Flash)
+    }
+
+    fn capacity(&self) -> usize {
+        self.size as usize
+    }
+}
+
+#[cfg(feature = "nightly")]
+impl<M: RawMutex, T: NorFlash> NorFlash for Partition<'_, M, T> {
+    const WRITE_SIZE: usize = T::WRITE_SIZE;
+    const ERASE_SIZE: usize = T::ERASE_SIZE;
+
+    async fn write(&mut self, offset: u32, bytes: &[u8]) -> Result<(), Self::Error> {
+        if self.offset % T::WRITE_SIZE as u32 != 0 || self.size % T::WRITE_SIZE as u32 != 0 {
+            return Err(Error::Partition);
+        }
+        if offset + bytes.len() as u32 > self.size {
+            return Err(Error::OutOfBounds);
+        }
+
+        let mut flash = self.flash.lock().await;
+        flash.write(self.offset + offset, bytes).await.map_err(Error::Flash)
+    }
+
+    async fn erase(&mut self, from: u32, to: u32) -> Result<(), Self::Error> {
+        if self.offset % T::ERASE_SIZE as u32 != 0 || self.size % T::ERASE_SIZE as u32 != 0 {
+            return Err(Error::Partition);
+        }
+        if to > self.size {
+            return Err(Error::OutOfBounds);
+        }
+
+        let mut flash = self.flash.lock().await;
+        flash
+            .erase(self.offset + from, self.offset + to)
+            .await
+            .map_err(Error::Flash)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use embassy_sync::blocking_mutex::raw::NoopRawMutex;
+
+    use super::*;
+    use crate::flash::mem_flash::MemFlash;
+
+    #[futures_test::test]
+    async fn can_read() {
+        let mut flash = MemFlash::<1024, 128, 4>::default();
+        flash.mem[12..20].fill(0xAA);
+
+        let flash = Mutex::<NoopRawMutex, _>::new(flash);
+        let mut partition = Partition::new(&flash, 8, 12);
+
+        let mut read_buf = [0; 8];
+        partition.read(4, &mut read_buf).await.unwrap();
+
+        assert!(read_buf.iter().position(|&x| x != 0xAA).is_none());
+    }
+
+    #[futures_test::test]
+    async fn can_write() {
+        let flash = MemFlash::<1024, 128, 4>::default();
+
+        let flash = Mutex::<NoopRawMutex, _>::new(flash);
+        let mut partition = Partition::new(&flash, 8, 12);
+
+        let write_buf = [0xAA; 8];
+        partition.write(4, &write_buf).await.unwrap();
+
+        let flash = flash.try_lock().unwrap();
+        assert!(flash.mem[12..20].iter().position(|&x| x != 0xAA).is_none());
+    }
+
+    #[futures_test::test]
+    async fn can_erase() {
+        let flash = MemFlash::<1024, 128, 4>::new(0x00);
+
+        let flash = Mutex::<NoopRawMutex, _>::new(flash);
+        let mut partition = Partition::new(&flash, 128, 256);
+
+        partition.erase(0, 128).await.unwrap();
+
+        let flash = flash.try_lock().unwrap();
+        assert!(flash.mem[128..256].iter().position(|&x| x != 0xFF).is_none());
+    }
+}