From 2afff617f652d0fdcfa9ffcfd49062e3d2c996a3 Mon Sep 17 00:00:00 2001
From: Ulf Lilleengen <ulf.lilleengen@gmail.com>
Date: Tue, 19 Apr 2022 14:42:38 +0200
Subject: [PATCH] Support multiple flash instances in embassy-boot

* Add FlashProvider and FlashConfig traits to define flash
characteristics
* Use traits in bootloader to retrieve flash handles and for
copying data between flash instances
* Add convenience implementations for using a single flash instance.
---
 embassy-boot/boot/src/lib.rs        | 234 +++++++++++++++++++++++-----
 embassy-boot/nrf/.cargo/config.toml |   1 -
 embassy-boot/nrf/Cargo.toml         |   2 +-
 embassy-boot/nrf/src/lib.rs         |   6 +-
 embassy-boot/nrf/src/main.rs        |   6 +-
 examples/boot/.cargo/config.toml    |   1 -
 examples/boot/Cargo.toml            |   4 +-
 examples/boot/src/bin/a.rs          |   5 +-
 8 files changed, 207 insertions(+), 52 deletions(-)

diff --git a/embassy-boot/boot/src/lib.rs b/embassy-boot/boot/src/lib.rs
index 6f31e280d..0d33ad1a6 100644
--- a/embassy-boot/boot/src/lib.rs
+++ b/embassy-boot/boot/src/lib.rs
@@ -14,7 +14,7 @@
 ///!
 mod fmt;
 
-use embedded_storage::nor_flash::{NorFlash, ReadNorFlash};
+use embedded_storage::nor_flash::{NorFlash, NorFlashError, NorFlashErrorKind, ReadNorFlash};
 use embedded_storage_async::nor_flash::AsyncNorFlash;
 
 pub const BOOT_MAGIC: u32 = 0xD00DF00D;
@@ -44,18 +44,41 @@ pub enum State {
 }
 
 #[derive(PartialEq, Debug)]
-#[cfg_attr(feature = "defmt", derive(defmt::Format))]
-pub enum BootError<E> {
-    Flash(E),
+pub enum BootError {
+    Flash(NorFlashErrorKind),
     BadMagic,
 }
 
-impl<E> From<E> for BootError<E> {
+impl<E> From<E> for BootError
+where
+    E: NorFlashError,
+{
     fn from(error: E) -> Self {
-        BootError::Flash(error)
+        BootError::Flash(error.kind())
     }
 }
 
+pub trait FlashConfig {
+    const BLOCK_SIZE: usize;
+    type FLASH: NorFlash + ReadNorFlash;
+
+    fn flash(&mut self) -> &mut Self::FLASH;
+}
+
+/// Trait defining the flash handles used for active and DFU partition
+pub trait FlashProvider {
+    type STATE: FlashConfig;
+    type ACTIVE: FlashConfig;
+    type DFU: FlashConfig;
+
+    /// Return flash instance used to write/read to/from active partition.
+    fn active(&mut self) -> &mut Self::ACTIVE;
+    /// Return flash instance used to write/read to/from dfu partition.
+    fn dfu(&mut self) -> &mut Self::DFU;
+    /// Return flash instance used to write/read to/from bootloader state.
+    fn state(&mut self) -> &mut Self::STATE;
+}
+
 /// BootLoader works with any flash implementing embedded_storage and can also work with
 /// different page sizes.
 pub struct BootLoader<const PAGE_SIZE: usize> {
@@ -168,29 +191,27 @@ impl<const PAGE_SIZE: usize> BootLoader<PAGE_SIZE> {
     /// |       DFU |            3 |      3 |      2 |      1 |      3 |
     /// +-----------+--------------+--------+--------+--------+--------+
     ///
-    pub fn prepare_boot<F: NorFlash + ReadNorFlash>(
-        &mut self,
-        flash: &mut F,
-    ) -> Result<State, BootError<F::Error>> {
+    pub fn prepare_boot<P: FlashProvider>(&mut self, p: &mut P) -> Result<State, BootError> {
         // Copy contents from partition N to active
-        let state = self.read_state(flash)?;
+        let state = self.read_state(p.state())?;
         match state {
             State::Swap => {
                 //
                 // Check if we already swapped. If we're in the swap state, this means we should revert
                 // since the app has failed to mark boot as successful
                 //
-                if !self.is_swapped(flash)? {
+                if !self.is_swapped(p.state())? {
                     trace!("Swapping");
-                    self.swap(flash)?;
+                    self.swap(p)?;
                 } else {
                     trace!("Reverting");
-                    self.revert(flash)?;
+                    self.revert(p)?;
 
                     // Overwrite magic and reset progress
-                    flash.write(self.state.from as u32, &[0, 0, 0, 0])?;
-                    flash.erase(self.state.from as u32, self.state.to as u32)?;
-                    flash.write(self.state.from as u32, &BOOT_MAGIC.to_le_bytes())?;
+                    let fstate = p.state().flash();
+                    fstate.write(self.state.from as u32, &[0, 0, 0, 0])?;
+                    fstate.erase(self.state.from as u32, self.state.to as u32)?;
+                    fstate.write(self.state.from as u32, &BOOT_MAGIC.to_le_bytes())?;
                 }
             }
             _ => {}
@@ -198,15 +219,16 @@ impl<const PAGE_SIZE: usize> BootLoader<PAGE_SIZE> {
         Ok(state)
     }
 
-    fn is_swapped<F: ReadNorFlash>(&mut self, flash: &mut F) -> Result<bool, F::Error> {
+    fn is_swapped<P: FlashConfig>(&mut self, p: &mut P) -> Result<bool, BootError> {
         let page_count = self.active.len() / PAGE_SIZE;
-        let progress = self.current_progress(flash)?;
+        let progress = self.current_progress(p)?;
 
         Ok(progress >= page_count * 2)
     }
 
-    fn current_progress<F: ReadNorFlash>(&mut self, flash: &mut F) -> Result<usize, F::Error> {
+    fn current_progress<P: FlashConfig>(&mut self, p: &mut P) -> Result<usize, BootError> {
         let max_index = ((self.state.len() - 4) / 4) - 1;
+        let flash = p.flash();
         for i in 0..max_index {
             let mut buf: [u8; 4] = [0; 4];
             flash.read((self.state.from + 4 + i * 4) as u32, &mut buf)?;
@@ -217,7 +239,8 @@ impl<const PAGE_SIZE: usize> BootLoader<PAGE_SIZE> {
         Ok(max_index)
     }
 
-    fn update_progress<F: NorFlash>(&mut self, idx: usize, flash: &mut F) -> Result<(), F::Error> {
+    fn update_progress<P: FlashConfig>(&mut self, idx: usize, p: &mut P) -> Result<(), BootError> {
+        let flash = p.flash();
         let w = self.state.from + 4 + idx * 4;
         flash.write(w as u32, &[0, 0, 0, 0])?;
         Ok(())
@@ -231,62 +254,104 @@ impl<const PAGE_SIZE: usize> BootLoader<PAGE_SIZE> {
         self.dfu.from + n * PAGE_SIZE
     }
 
-    fn copy_page_once<F: NorFlash + ReadNorFlash>(
+    fn copy_page_once_to_active<P: FlashProvider>(
         &mut self,
         idx: usize,
-        from: usize,
-        to: usize,
-        flash: &mut F,
-    ) -> Result<(), F::Error> {
+        from_page: usize,
+        to_page: usize,
+        p: &mut P,
+    ) -> Result<(), BootError> {
         let mut buf: [u8; PAGE_SIZE] = [0; PAGE_SIZE];
-        if self.current_progress(flash)? <= idx {
-            flash.read(from as u32, &mut buf)?;
-            flash.erase(to as u32, (to + PAGE_SIZE) as u32)?;
-            flash.write(to as u32, &buf)?;
-            self.update_progress(idx, flash)?;
+        if self.current_progress(p.state())? <= idx {
+            let mut offset = from_page;
+            for chunk in buf.chunks_mut(P::DFU::BLOCK_SIZE) {
+                p.dfu().flash().read(offset as u32, chunk)?;
+                offset += chunk.len();
+            }
+
+            p.active()
+                .flash()
+                .erase(to_page as u32, (to_page + PAGE_SIZE) as u32)?;
+
+            let mut offset = to_page;
+            for chunk in buf.chunks(P::ACTIVE::BLOCK_SIZE) {
+                p.active().flash().write(offset as u32, &chunk)?;
+                offset += chunk.len();
+            }
+            self.update_progress(idx, p.state())?;
         }
         Ok(())
     }
 
-    fn swap<F: NorFlash + ReadNorFlash>(&mut self, flash: &mut F) -> Result<(), F::Error> {
+    fn copy_page_once_to_dfu<P: FlashProvider>(
+        &mut self,
+        idx: usize,
+        from_page: usize,
+        to_page: usize,
+        p: &mut P,
+    ) -> Result<(), BootError> {
+        let mut buf: [u8; PAGE_SIZE] = [0; PAGE_SIZE];
+        if self.current_progress(p.state())? <= idx {
+            let mut offset = from_page;
+            for chunk in buf.chunks_mut(P::ACTIVE::BLOCK_SIZE) {
+                p.active().flash().read(offset as u32, chunk)?;
+                offset += chunk.len();
+            }
+
+            p.dfu()
+                .flash()
+                .erase(to_page as u32, (to_page + PAGE_SIZE) as u32)?;
+
+            let mut offset = to_page;
+            for chunk in buf.chunks(P::DFU::BLOCK_SIZE) {
+                p.dfu().flash().write(offset as u32, chunk)?;
+                offset += chunk.len();
+            }
+            self.update_progress(idx, p.state())?;
+        }
+        Ok(())
+    }
+
+    fn swap<P: FlashProvider>(&mut self, p: &mut P) -> Result<(), BootError> {
         let page_count = self.active.len() / PAGE_SIZE;
         // trace!("Page count: {}", page_count);
         for page in 0..page_count {
             // Copy active page to the 'next' DFU page.
             let active_page = self.active_addr(page_count - 1 - page);
             let dfu_page = self.dfu_addr(page_count - page);
-            // info!("Copy active {} to dfu {}", active_page, dfu_page);
-            self.copy_page_once(page * 2, active_page, dfu_page, flash)?;
+            info!("Copy active {} to dfu {}", active_page, dfu_page);
+            self.copy_page_once_to_dfu(page * 2, active_page, dfu_page, p)?;
 
             // Copy DFU page to the active page
             let active_page = self.active_addr(page_count - 1 - page);
             let dfu_page = self.dfu_addr(page_count - 1 - page);
-            //info!("Copy dfy {} to active {}", dfu_page, active_page);
-            self.copy_page_once(page * 2 + 1, dfu_page, active_page, flash)?;
+            info!("Copy dfy {} to active {}", dfu_page, active_page);
+            self.copy_page_once_to_active(page * 2 + 1, dfu_page, active_page, p)?;
         }
 
         Ok(())
     }
 
-    fn revert<F: NorFlash + ReadNorFlash>(&mut self, flash: &mut F) -> Result<(), F::Error> {
+    fn revert<P: FlashProvider>(&mut self, p: &mut P) -> Result<(), BootError> {
         let page_count = self.active.len() / PAGE_SIZE;
         for page in 0..page_count {
             // Copy the bad active page to the DFU page
             let active_page = self.active_addr(page);
             let dfu_page = self.dfu_addr(page);
-            self.copy_page_once(page_count * 2 + page * 2, active_page, dfu_page, flash)?;
+            self.copy_page_once_to_dfu(page_count * 2 + page * 2, active_page, dfu_page, p)?;
 
             // Copy the DFU page back to the active page
             let active_page = self.active_addr(page);
             let dfu_page = self.dfu_addr(page + 1);
-            self.copy_page_once(page_count * 2 + page * 2 + 1, dfu_page, active_page, flash)?;
+            self.copy_page_once_to_active(page_count * 2 + page * 2 + 1, dfu_page, active_page, p)?;
         }
 
         Ok(())
     }
 
-    fn read_state<F: ReadNorFlash>(&mut self, flash: &mut F) -> Result<State, BootError<F::Error>> {
+    fn read_state<P: FlashConfig>(&mut self, p: &mut P) -> Result<State, BootError> {
         let mut magic: [u8; 4] = [0; 4];
+        let flash = p.flash();
         flash.read(self.state.from as u32, &mut magic)?;
 
         match u32::from_le_bytes(magic) {
@@ -296,6 +361,62 @@ impl<const PAGE_SIZE: usize> BootLoader<PAGE_SIZE> {
     }
 }
 
+/// Convenience provider that uses a single flash for everything
+pub struct SingleFlashProvider<'a, F>
+where
+    F: NorFlash + ReadNorFlash,
+{
+    config: SingleFlashConfig<'a, F>,
+}
+
+impl<'a, F> SingleFlashProvider<'a, F>
+where
+    F: NorFlash + ReadNorFlash,
+{
+    pub fn new(flash: &'a mut F) -> Self {
+        Self {
+            config: SingleFlashConfig { flash },
+        }
+    }
+}
+
+pub struct SingleFlashConfig<'a, F>
+where
+    F: NorFlash + ReadNorFlash,
+{
+    flash: &'a mut F,
+}
+
+impl<'a, F> FlashProvider for SingleFlashProvider<'a, F>
+where
+    F: NorFlash + ReadNorFlash,
+{
+    type STATE = SingleFlashConfig<'a, F>;
+    type ACTIVE = SingleFlashConfig<'a, F>;
+    type DFU = SingleFlashConfig<'a, F>;
+
+    fn active(&mut self) -> &mut Self::STATE {
+        &mut self.config
+    }
+    fn dfu(&mut self) -> &mut Self::ACTIVE {
+        &mut self.config
+    }
+    fn state(&mut self) -> &mut Self::DFU {
+        &mut self.config
+    }
+}
+
+impl<'a, F> FlashConfig for SingleFlashConfig<'a, F>
+where
+    F: NorFlash + ReadNorFlash,
+{
+    const BLOCK_SIZE: usize = F::ERASE_SIZE;
+    type FLASH = F;
+    fn flash(&mut self) -> &mut F {
+        self.flash
+    }
+}
+
 /// FirmwareUpdater is an application API for interacting with the BootLoader without the ability to
 /// 'mess up' the internal bootloader state
 pub struct FirmwareUpdater {
@@ -371,7 +492,10 @@ impl FirmwareUpdater {
         offset: usize,
         data: &[u8],
         flash: &mut F,
+        block_size: usize,
     ) -> Result<(), F::Error> {
+        assert!(data.len() >= F::ERASE_SIZE);
+
         trace!(
             "Writing firmware at offset 0x{:x} len {}",
             self.dfu.from + offset,
@@ -384,7 +508,35 @@ impl FirmwareUpdater {
                 (self.dfu.from + offset + data.len()) as u32,
             )
             .await?;
-        flash.write((self.dfu.from + offset) as u32, data).await
+
+        trace!(
+            "Erased from {} to {}",
+            self.dfu.from + offset,
+            self.dfu.from + offset + data.len()
+        );
+
+        let mut write_offset = self.dfu.from + offset;
+        for chunk in data.chunks(block_size) {
+            trace!("Wrote chunk at {}: {:?}", write_offset, chunk);
+            flash.write(write_offset as u32, chunk).await?;
+            write_offset += chunk.len();
+        }
+        /*
+        trace!("Wrote data, reading back for verification");
+
+        let mut buf: [u8; 4096] = [0; 4096];
+        let mut data_offset = 0;
+        let mut read_offset = self.dfu.from + offset;
+        for chunk in buf.chunks_mut(block_size) {
+            flash.read(read_offset as u32, chunk).await?;
+            trace!("Read chunk at {}: {:?}", read_offset, chunk);
+            assert_eq!(&data[data_offset..data_offset + block_size], chunk);
+            read_offset += chunk.len();
+            data_offset += chunk.len();
+        }
+        */
+
+        Ok(())
     }
 }
 
diff --git a/embassy-boot/nrf/.cargo/config.toml b/embassy-boot/nrf/.cargo/config.toml
index c3957b866..27bc9708c 100644
--- a/embassy-boot/nrf/.cargo/config.toml
+++ b/embassy-boot/nrf/.cargo/config.toml
@@ -1,5 +1,4 @@
 [unstable]
-namespaced-features = true
 build-std = ["core"]
 build-std-features = ["panic_immediate_abort"]
 
diff --git a/embassy-boot/nrf/Cargo.toml b/embassy-boot/nrf/Cargo.toml
index 512e7d378..97207ac29 100644
--- a/embassy-boot/nrf/Cargo.toml
+++ b/embassy-boot/nrf/Cargo.toml
@@ -12,7 +12,7 @@ defmt = { version = "0.3", optional = true }
 defmt-rtt = { version = "0.3", optional = true }
 
 embassy = { path = "../../embassy", default-features = false }
-embassy-nrf = { path = "../../embassy-nrf", default-features = false }
+embassy-nrf = { path = "../../embassy-nrf", default-features = false, features = ["nightly"] }
 embassy-boot = { path = "../boot", default-features = false }
 cortex-m = { version = "0.7" }
 cortex-m-rt = { version = "0.7" }
diff --git a/embassy-boot/nrf/src/lib.rs b/embassy-boot/nrf/src/lib.rs
index 32250b2db..785cb67e8 100644
--- a/embassy-boot/nrf/src/lib.rs
+++ b/embassy-boot/nrf/src/lib.rs
@@ -4,7 +4,9 @@
 
 mod fmt;
 
-pub use embassy_boot::{FirmwareUpdater, Partition, State, BOOT_MAGIC};
+pub use embassy_boot::{
+    FirmwareUpdater, FlashProvider, Partition, SingleFlashProvider, State, BOOT_MAGIC,
+};
 use embassy_nrf::{
     nvmc::{Nvmc, PAGE_SIZE},
     peripherals::WDT,
@@ -62,7 +64,7 @@ impl BootLoader {
     }
 
     /// Boots the application without softdevice mechanisms
-    pub fn prepare<F: NorFlash + ReadNorFlash>(&mut self, flash: &mut F) -> usize {
+    pub fn prepare<F: FlashProvider>(&mut self, flash: &mut F) -> usize {
         match self.boot.prepare_boot(flash) {
             Ok(_) => self.boot.boot_address(),
             Err(_) => panic!("boot prepare error!"),
diff --git a/embassy-boot/nrf/src/main.rs b/embassy-boot/nrf/src/main.rs
index cd264d4c2..63de7c869 100644
--- a/embassy-boot/nrf/src/main.rs
+++ b/embassy-boot/nrf/src/main.rs
@@ -22,7 +22,11 @@ fn main() -> ! {
     */
 
     let mut bl = BootLoader::default();
-    let start = bl.prepare(&mut WatchdogFlash::start(Nvmc::new(p.NVMC), p.WDT, 5));
+    let start = bl.prepare(&mut SingleFlashProvider::new(&mut WatchdogFlash::start(
+        Nvmc::new(p.NVMC),
+        p.WDT,
+        5,
+    )));
     unsafe { bl.load(start) }
 }
 
diff --git a/examples/boot/.cargo/config.toml b/examples/boot/.cargo/config.toml
index d044e9b4c..bbe06fd03 100644
--- a/examples/boot/.cargo/config.toml
+++ b/examples/boot/.cargo/config.toml
@@ -1,5 +1,4 @@
 [unstable]
-namespaced-features = true
 build-std = ["core"]
 build-std-features = ["panic_immediate_abort"]
 
diff --git a/examples/boot/Cargo.toml b/examples/boot/Cargo.toml
index 36e2e169d..2da659478 100644
--- a/examples/boot/Cargo.toml
+++ b/examples/boot/Cargo.toml
@@ -5,8 +5,8 @@ name = "embassy-boot-examples"
 version = "0.1.0"
 
 [dependencies]
-embassy = { version = "0.1.0", path = "../../embassy" }
-embassy-nrf = { version = "0.1.0", path = "../../embassy-nrf", features = ["time-driver-rtc1", "gpiote"] }
+embassy = { version = "0.1.0", path = "../../embassy", features = ["nightly"] }
+embassy-nrf = { version = "0.1.0", path = "../../embassy-nrf", features = ["time-driver-rtc1", "gpiote", "nightly"] }
 embassy-boot-nrf = { version = "0.1.0", path = "../../embassy-boot/nrf" }
 embassy-traits = { version = "0.1.0", path = "../../embassy-traits" }
 
diff --git a/examples/boot/src/bin/a.rs b/examples/boot/src/bin/a.rs
index 88880e688..d18b508cc 100644
--- a/examples/boot/src/bin/a.rs
+++ b/examples/boot/src/bin/a.rs
@@ -12,7 +12,6 @@ use embassy_nrf::{
     Peripherals,
 };
 use embassy_traits::adapter::BlockingAsync;
-use embedded_hal::digital::v2::InputPin;
 use panic_reset as _;
 
 static APP_B: &[u8] = include_bytes!("../../b.bin");
@@ -29,14 +28,14 @@ async fn main(_s: embassy::executor::Spawner, p: Peripherals) {
 
     loop {
         button.wait_for_any_edge().await;
-        if button.is_low().unwrap() {
+        if button.is_low() {
             let mut updater = updater::new();
             let mut offset = 0;
             for chunk in APP_B.chunks(4096) {
                 let mut buf: [u8; 4096] = [0; 4096];
                 buf[..chunk.len()].copy_from_slice(chunk);
                 updater
-                    .write_firmware(offset, &buf, &mut nvmc)
+                    .write_firmware(offset, &buf, &mut nvmc, 4096)
                     .await
                     .unwrap();
                 offset += chunk.len();