From 6e9e8eeb5f6458833b28a08e7480b2630107d79c Mon Sep 17 00:00:00 2001
From: Caleb Garrett <47389035+caleb-garrett@users.noreply.github.com>
Date: Tue, 5 Mar 2024 11:25:56 -0500
Subject: [PATCH] Refactored cryp din/dout into functions.

---
 embassy-stm32/src/cryp/mod.rs    | 276 +++++++++++++++----------------
 examples/stm32f7/src/bin/cryp.rs |  14 +-
 2 files changed, 144 insertions(+), 146 deletions(-)

diff --git a/embassy-stm32/src/cryp/mod.rs b/embassy-stm32/src/cryp/mod.rs
index 8f259520a..12353baa0 100644
--- a/embassy-stm32/src/cryp/mod.rs
+++ b/embassy-stm32/src/cryp/mod.rs
@@ -4,12 +4,35 @@ use core::cmp::min;
 use core::marker::PhantomData;
 
 use embassy_hal_internal::{into_ref, PeripheralRef};
+use embassy_sync::waitqueue::AtomicWaker;
 
-use crate::{interrupt, pac, peripherals, Peripheral};
+use crate::interrupt::typelevel::Interrupt;
+use crate::{dma::NoDma, interrupt, pac, peripherals, Peripheral};
 
 const DES_BLOCK_SIZE: usize = 8; // 64 bits
 const AES_BLOCK_SIZE: usize = 16; // 128 bits
 
+static CRYP_WAKER: AtomicWaker = AtomicWaker::new();
+
+/// CRYP interrupt handler. Clears the IMSCR bit of each FIFO interrupt that fired and wakes `CRYP_WAKER`.
+pub struct InterruptHandler<T: Instance> {
+    _phantom: PhantomData<T>,
+}
+
+impl<T: Instance> interrupt::typelevel::Handler<T::Interrupt> for InterruptHandler<T> {
+    unsafe fn on_interrupt() {
+        let bits = T::regs().misr().read(); // MISR is read once; both checks use this snapshot
+        if bits.inmis() {
+            T::regs().imscr().modify(|w| w.set_inim(false)); // clear INIM; this handler never sets it again
+            CRYP_WAKER.wake();
+        }
+        if bits.outmis() {
+            T::regs().imscr().modify(|w| w.set_outim(false)); // clear OUTIM; this handler never sets it again
+            CRYP_WAKER.wake();
+        }
+    }
+}
+
 /// This trait encapsulates all cipher-specific behavior/
 pub trait Cipher<'c> {
     /// Processing block size. Determined by the processor and the algorithm.
@@ -32,7 +55,7 @@ pub trait Cipher<'c> {
     fn prepare_key(&self, _p: &pac::cryp::Cryp) {}
 
     /// Performs any cipher-specific initialization.
-    fn init_phase(&self, _p: &pac::cryp::Cryp) {}
+    fn init_phase<T: Instance, D>(&self, _p: &pac::cryp::Cryp, _cryp: &Cryp<T, D>) {}
 
     /// Called prior to processing the last data block for cipher-specific operations.
     fn pre_final_block(&self, _p: &pac::cryp::Cryp, _dir: Direction, _padding_len: usize) -> [u32; 4] {
@@ -40,9 +63,10 @@ pub trait Cipher<'c> {
     }
 
     /// Called after processing the last data block for cipher-specific operations.
-    fn post_final_block(
+    fn post_final_block<T: Instance, D>(
         &self,
         _p: &pac::cryp::Cryp,
+        _cryp: &Cryp<T, D>,
         _dir: Direction,
         _int_data: &mut [u8; AES_BLOCK_SIZE],
         _temp1: [u32; 4],
@@ -425,7 +449,7 @@ impl<'c, const KEY_SIZE: usize> Cipher<'c> for AesGcm<'c, KEY_SIZE> {
         p.cr().modify(|w| w.set_algomode3(true));
     }
 
-    fn init_phase(&self, p: &pac::cryp::Cryp) {
+    fn init_phase<T: Instance, D>(&self, p: &pac::cryp::Cryp, _cryp: &Cryp<T, D>) {
         p.cr().modify(|w| w.set_gcm_ccmph(0));
         p.cr().modify(|w| w.set_crypen(true));
         while p.cr().read().crypen() {}
@@ -453,9 +477,10 @@ impl<'c, const KEY_SIZE: usize> Cipher<'c> for AesGcm<'c, KEY_SIZE> {
     }
 
     #[cfg(cryp_v2)]
-    fn post_final_block(
+    fn post_final_block<T: Instance, D>(
         &self,
         p: &pac::cryp::Cryp,
+        cryp: &Cryp<T, D>,
         dir: Direction,
         int_data: &mut [u8; AES_BLOCK_SIZE],
         _temp1: [u32; 4],
@@ -471,17 +496,9 @@ impl<'c, const KEY_SIZE: usize> Cipher<'c> for AesGcm<'c, KEY_SIZE> {
             }
             p.cr().modify(|w| w.set_crypen(true));
             p.cr().modify(|w| w.set_gcm_ccmph(3));
-            let mut index = 0;
-            let end_index = Self::BLOCK_SIZE;
-            while index < end_index {
-                let mut in_word: [u8; 4] = [0; 4];
-                in_word.copy_from_slice(&int_data[index..index + 4]);
-                p.din().write_value(u32::from_ne_bytes(in_word));
-                index += 4;
-            }
-            for _ in 0..4 {
-                p.dout().read();
-            }
+
+            cryp.write_bytes_blocking(Self::BLOCK_SIZE, int_data);
+            cryp.read_bytes_blocking(Self::BLOCK_SIZE, int_data);
         }
     }
 }
@@ -532,7 +549,7 @@ impl<'c, const KEY_SIZE: usize> Cipher<'c> for AesGmac<'c, KEY_SIZE> {
         p.cr().modify(|w| w.set_algomode3(true));
     }
 
-    fn init_phase(&self, p: &pac::cryp::Cryp) {
+    fn init_phase<T: Instance, D>(&self, p: &pac::cryp::Cryp, _cryp: &Cryp<T, D>) {
         p.cr().modify(|w| w.set_gcm_ccmph(0));
         p.cr().modify(|w| w.set_crypen(true));
         while p.cr().read().crypen() {}
@@ -560,9 +577,10 @@ impl<'c, const KEY_SIZE: usize> Cipher<'c> for AesGmac<'c, KEY_SIZE> {
     }
 
     #[cfg(cryp_v2)]
-    fn post_final_block(
+    fn post_final_block<T: Instance, D>(
         &self,
         p: &pac::cryp::Cryp,
+        cryp: &Cryp<T, D>,
         dir: Direction,
         int_data: &mut [u8; AES_BLOCK_SIZE],
         _temp1: [u32; 4],
@@ -578,17 +596,9 @@ impl<'c, const KEY_SIZE: usize> Cipher<'c> for AesGmac<'c, KEY_SIZE> {
             }
             p.cr().modify(|w| w.set_crypen(true));
             p.cr().modify(|w| w.set_gcm_ccmph(3));
-            let mut index = 0;
-            let end_index = Self::BLOCK_SIZE;
-            while index < end_index {
-                let mut in_word: [u8; 4] = [0; 4];
-                in_word.copy_from_slice(&int_data[index..index + 4]);
-                p.din().write_value(u32::from_ne_bytes(in_word));
-                index += 4;
-            }
-            for _ in 0..4 {
-                p.dout().read();
-            }
+
+            cryp.write_bytes_blocking(Self::BLOCK_SIZE, int_data);
+            cryp.read_bytes_blocking(Self::BLOCK_SIZE, int_data);
         }
     }
 }
@@ -697,18 +707,11 @@ impl<'c, const KEY_SIZE: usize, const TAG_SIZE: usize, const IV_SIZE: usize> Cip
         p.cr().modify(|w| w.set_algomode3(true));
     }
 
-    fn init_phase(&self, p: &pac::cryp::Cryp) {
+    fn init_phase<T: Instance, D>(&self, p: &pac::cryp::Cryp, cryp: &Cryp<T, D>) {
         p.cr().modify(|w| w.set_gcm_ccmph(0));
 
-        let mut index = 0;
-        let end_index = index + Self::BLOCK_SIZE;
-        // Write block in
-        while index < end_index {
-            let mut in_word: [u8; 4] = [0; 4];
-            in_word.copy_from_slice(&self.block0[index..index + 4]);
-            p.din().write_value(u32::from_ne_bytes(in_word));
-            index += 4;
-        }
+        cryp.write_bytes_blocking(Self::BLOCK_SIZE, &self.block0);
+
         p.cr().modify(|w| w.set_crypen(true));
         while p.cr().read().crypen() {}
     }
@@ -744,9 +747,10 @@ impl<'c, const KEY_SIZE: usize, const TAG_SIZE: usize, const IV_SIZE: usize> Cip
     }
 
     #[cfg(cryp_v2)]
-    fn post_final_block(
+    fn post_final_block<T: Instance, D>(
         &self,
         p: &pac::cryp::Cryp,
+        cryp: &Cryp<T, D>,
         dir: Direction,
         int_data: &mut [u8; AES_BLOCK_SIZE],
         temp1: [u32; 4],
@@ -774,8 +778,8 @@ impl<'c, const KEY_SIZE: usize, const TAG_SIZE: usize, const IV_SIZE: usize> Cip
                 let int_word = u32::from_le_bytes(int_bytes);
                 in_data[i] = int_word;
                 in_data[i] = in_data[i] ^ temp1[i] ^ temp2[i];
-                p.din().write_value(in_data[i]);
             }
+            cryp.write_words_blocking(Self::BLOCK_SIZE, &in_data);
         }
     }
 }
@@ -845,16 +849,31 @@ pub enum Direction {
 }
 
 /// Crypto Accelerator Driver
-pub struct Cryp<'d, T: Instance> {
+pub struct Cryp<'d, T: Instance, D = NoDma> {
     _peripheral: PeripheralRef<'d, T>,
+    indma: PeripheralRef<'d, D>,
+    outdma: PeripheralRef<'d, D>,
 }
 
-impl<'d, T: Instance> Cryp<'d, T> {
+impl<'d, T: Instance, D> Cryp<'d, T, D> {
     /// Create a new CRYP driver.
-    pub fn new(peri: impl Peripheral<P = T> + 'd) -> Self {
+    pub fn new(
+        peri: impl Peripheral<P = T> + 'd,
+        indma: impl Peripheral<P = D> + 'd,
+        outdma: impl Peripheral<P = D> + 'd,
+        _irq: impl interrupt::typelevel::Binding<T::Interrupt, InterruptHandler<T>> + 'd,
+    ) -> Self {
         T::enable_and_reset();
-        into_ref!(peri);
-        let instance = Self { _peripheral: peri };
+        into_ref!(peri, indma, outdma);
+        let instance = Self {
+            _peripheral: peri,
+            indma,
+            outdma,
+        };
+
+        T::Interrupt::unpend();
+        unsafe { T::Interrupt::enable() };
+
         instance
     }
 
@@ -929,7 +948,7 @@ impl<'d, T: Instance> Cryp<'d, T> {
         // Flush in/out FIFOs
         T::regs().cr().modify(|w| w.fflush());
 
-        ctx.cipher.init_phase(&T::regs());
+        ctx.cipher.init_phase(&T::regs(), self);
 
         self.store_context(&mut ctx);
 
@@ -985,15 +1004,7 @@ impl<'d, T: Instance> Cryp<'d, T> {
         if ctx.aad_buffer_len < C::BLOCK_SIZE {
             // The buffer isn't full and this is the last buffer, so process it as is (already padded).
             if last_aad_block {
-                let mut index = 0;
-                let end_index = C::BLOCK_SIZE;
-                // Write block in
-                while index < end_index {
-                    let mut in_word: [u8; 4] = [0; 4];
-                    in_word.copy_from_slice(&ctx.aad_buffer[index..index + 4]);
-                    T::regs().din().write_value(u32::from_ne_bytes(in_word));
-                    index += 4;
-                }
+                self.write_bytes_blocking(C::BLOCK_SIZE, &ctx.aad_buffer);
                 // Block until input FIFO is empty.
                 while !T::regs().sr().read().ifem() {}
 
@@ -1008,15 +1019,7 @@ impl<'d, T: Instance> Cryp<'d, T> {
             }
         } else {
             // Load the full block from the buffer.
-            let mut index = 0;
-            let end_index = C::BLOCK_SIZE;
-            // Write block in
-            while index < end_index {
-                let mut in_word: [u8; 4] = [0; 4];
-                in_word.copy_from_slice(&ctx.aad_buffer[index..index + 4]);
-                T::regs().din().write_value(u32::from_ne_bytes(in_word));
-                index += 4;
-            }
+            self.write_bytes_blocking(C::BLOCK_SIZE, &ctx.aad_buffer);
             // Block until input FIFO is empty.
             while !T::regs().sr().read().ifem() {}
         }
@@ -1032,33 +1035,13 @@ impl<'d, T: Instance> Cryp<'d, T> {
 
         // Load full data blocks into core.
         let num_full_blocks = aad_len_remaining / C::BLOCK_SIZE;
-        for block in 0..num_full_blocks {
-            let mut index = len_to_copy + (block * C::BLOCK_SIZE);
-            let end_index = index + C::BLOCK_SIZE;
-            // Write block in
-            while index < end_index {
-                let mut in_word: [u8; 4] = [0; 4];
-                in_word.copy_from_slice(&aad[index..index + 4]);
-                T::regs().din().write_value(u32::from_ne_bytes(in_word));
-                index += 4;
-            }
-            // Block until input FIFO is empty.
-            while !T::regs().sr().read().ifem() {}
-        }
+        let start_index = len_to_copy;
+        let end_index = start_index + (C::BLOCK_SIZE * num_full_blocks);
+        self.write_bytes_blocking(C::BLOCK_SIZE, &aad[start_index..end_index]);
 
         if last_aad_block {
             if leftovers > 0 {
-                let mut index = 0;
-                let end_index = C::BLOCK_SIZE;
-                // Write block in
-                while index < end_index {
-                    let mut in_word: [u8; 4] = [0; 4];
-                    in_word.copy_from_slice(&ctx.aad_buffer[index..index + 4]);
-                    T::regs().din().write_value(u32::from_ne_bytes(in_word));
-                    index += 4;
-                }
-                // Block until input FIFO is empty.
-                while !T::regs().sr().read().ifem() {}
+                self.write_bytes_blocking(C::BLOCK_SIZE, &ctx.aad_buffer);
             }
             // Switch to payload phase.
             ctx.aad_complete = true;
@@ -1125,25 +1108,11 @@ impl<'d, T: Instance> Cryp<'d, T> {
         // Load data into core, block by block.
         let num_full_blocks = input.len() / C::BLOCK_SIZE;
         for block in 0..num_full_blocks {
-            let mut index = block * C::BLOCK_SIZE;
-            let end_index = index + C::BLOCK_SIZE;
+            let index = block * C::BLOCK_SIZE;
             // Write block in
-            while index < end_index {
-                let mut in_word: [u8; 4] = [0; 4];
-                in_word.copy_from_slice(&input[index..index + 4]);
-                T::regs().din().write_value(u32::from_ne_bytes(in_word));
-                index += 4;
-            }
-            let mut index = block * C::BLOCK_SIZE;
-            let end_index = index + C::BLOCK_SIZE;
-            // Block until there is output to read.
-            while !T::regs().sr().read().ofne() {}
+            self.write_bytes_blocking(C::BLOCK_SIZE, &input[index..index + C::BLOCK_SIZE]);
             // Read block out
-            while index < end_index {
-                let out_word: u32 = T::regs().dout().read();
-                output[index..index + 4].copy_from_slice(u32::to_ne_bytes(out_word).as_slice());
-                index += 4;
-            }
+            self.read_bytes_blocking(C::BLOCK_SIZE, &mut output[index..index + C::BLOCK_SIZE]);
         }
 
         // Handle the final block, which is incomplete.
@@ -1154,25 +1123,8 @@ impl<'d, T: Instance> Cryp<'d, T> {
             let mut intermediate_data: [u8; AES_BLOCK_SIZE] = [0; AES_BLOCK_SIZE];
             let mut last_block: [u8; AES_BLOCK_SIZE] = [0; AES_BLOCK_SIZE];
             last_block[..last_block_remainder].copy_from_slice(&input[input.len() - last_block_remainder..input.len()]);
-            let mut index = 0;
-            let end_index = C::BLOCK_SIZE;
-            // Write block in
-            while index < end_index {
-                let mut in_word: [u8; 4] = [0; 4];
-                in_word.copy_from_slice(&last_block[index..index + 4]);
-                T::regs().din().write_value(u32::from_ne_bytes(in_word));
-                index += 4;
-            }
-            let mut index = 0;
-            let end_index = C::BLOCK_SIZE;
-            // Block until there is output to read.
-            while !T::regs().sr().read().ofne() {}
-            // Read block out
-            while index < end_index {
-                let out_word: u32 = T::regs().dout().read();
-                intermediate_data[index..index + 4].copy_from_slice(u32::to_ne_bytes(out_word).as_slice());
-                index += 4;
-            }
+            self.write_bytes_blocking(C::BLOCK_SIZE, &last_block);
+            self.read_bytes_blocking(C::BLOCK_SIZE, &mut intermediate_data);
 
             // Handle the last block depending on mode.
             let output_len = output.len();
@@ -1182,7 +1134,7 @@ impl<'d, T: Instance> Cryp<'d, T> {
             let mut mask: [u8; 16] = [0; 16];
             mask[..last_block_remainder].fill(0xFF);
             ctx.cipher
-                .post_final_block(&T::regs(), ctx.dir, &mut intermediate_data, temp1, mask);
+                .post_final_block(&T::regs(), self, ctx.dir, &mut intermediate_data, temp1, mask);
         }
 
         ctx.payload_len += input.len() as u64;
@@ -1213,28 +1165,21 @@ impl<'d, T: Instance> Cryp<'d, T> {
         let payloadlen2: u32 = (ctx.payload_len * 8) as u32;
 
         #[cfg(cryp_v2)]
-        {
-            T::regs().din().write_value(headerlen1.swap_bytes());
-            T::regs().din().write_value(headerlen2.swap_bytes());
-            T::regs().din().write_value(payloadlen1.swap_bytes());
-            T::regs().din().write_value(payloadlen2.swap_bytes());
-        }
-
+        let footer: [u32; 4] = [
+            headerlen1.swap_bytes(),
+            headerlen2.swap_bytes(),
+            payloadlen1.swap_bytes(),
+            payloadlen2.swap_bytes(),
+        ];
         #[cfg(cryp_v3)]
-        {
-            T::regs().din().write_value(headerlen1);
-            T::regs().din().write_value(headerlen2);
-            T::regs().din().write_value(payloadlen1);
-            T::regs().din().write_value(payloadlen2);
-        }
+        let footer: [u32; 4] = [headerlen1, headerlen2, payloadlen1, payloadlen2];
+
+        self.write_words_blocking(C::BLOCK_SIZE, &footer);
 
         while !T::regs().sr().read().ofne() {}
 
         let mut full_tag: [u8; 16] = [0; 16];
-        full_tag[0..4].copy_from_slice(T::regs().dout().read().to_ne_bytes().as_slice());
-        full_tag[4..8].copy_from_slice(T::regs().dout().read().to_ne_bytes().as_slice());
-        full_tag[8..12].copy_from_slice(T::regs().dout().read().to_ne_bytes().as_slice());
-        full_tag[12..16].copy_from_slice(T::regs().dout().read().to_ne_bytes().as_slice());
+        self.read_bytes_blocking(C::BLOCK_SIZE, &mut full_tag);
         let mut tag: [u8; TAG_SIZE] = [0; TAG_SIZE];
         tag.copy_from_slice(&full_tag[0..TAG_SIZE]);
 
@@ -1325,6 +1270,51 @@ impl<'d, T: Instance> Cryp<'d, T> {
         // Enable crypto processor.
         T::regs().cr().modify(|w| w.set_crypen(true));
     }
+
+    /// Writes `blocks` to the DIN register one native-endian 32-bit word at a time.
+    /// After every `block_size` bytes, spins until the status register reports the input FIFO empty.
+    /// Panics if `blocks.len()` is not a multiple of `block_size`.
+    fn write_bytes_blocking(&self, block_size: usize, blocks: &[u8]) {
+        // Ensure input is a multiple of block size.
+        assert_eq!(blocks.len() % block_size, 0);
+        for offset in (0..blocks.len()).step_by(4) {
+            let mut in_word: [u8; 4] = [0; 4];
+            in_word.copy_from_slice(&blocks[offset..offset + 4]);
+            T::regs().din().write_value(u32::from_ne_bytes(in_word));
+            if (offset + 4) % block_size == 0 {
+                // Block until input FIFO is empty.
+                while !T::regs().sr().read().ifem() {}
+            }
+        }
+    }
+
+    /// Writes `blocks` to DIN word by word, spinning after every `block_size` bytes until IFEM is set.
+    fn write_words_blocking(&self, block_size: usize, blocks: &[u32]) {
+        assert_eq!((blocks.len() * 4) % block_size, 0);
+        for (position, &word) in blocks.iter().enumerate() {
+            T::regs().din().write_value(word);
+            let bytes_written = (position + 1) * 4;
+            if bytes_written % block_size == 0 {
+                // Block until input FIFO is empty.
+                while !T::regs().sr().read().ifem() {}
+            }
+        }
+    }
+
+    /// Fills `blocks` from the DOUT register one native-endian 32-bit word at a time, spinning on
+    /// OFNE before each `block_size`-byte block. Panics if the length is not a multiple of `block_size`.
+    fn read_bytes_blocking(&self, block_size: usize, blocks: &mut [u8]) {
+        // Ensure output buffer is a multiple of block size.
+        assert_eq!(blocks.len() % block_size, 0);
+        for offset in (0..blocks.len()).step_by(4) {
+            if offset % block_size == 0 {
+                // Block until this block's output is available; an empty buffer never waits.
+                while !T::regs().sr().read().ofne() {}
+            }
+            let out_word: u32 = T::regs().dout().read();
+            blocks[offset..offset + 4].copy_from_slice(&out_word.to_ne_bytes());
+        }
+    }
 }
 
 pub(crate) mod sealed {
diff --git a/examples/stm32f7/src/bin/cryp.rs b/examples/stm32f7/src/bin/cryp.rs
index 04927841a..79b74e569 100644
--- a/examples/stm32f7/src/bin/cryp.rs
+++ b/examples/stm32f7/src/bin/cryp.rs
@@ -6,11 +6,19 @@ use aes_gcm::aead::{AeadInPlace, KeyInit};
 use aes_gcm::Aes128Gcm;
 use defmt::info;
 use embassy_executor::Spawner;
-use embassy_stm32::cryp::*;
-use embassy_stm32::Config;
+use embassy_stm32::dma::NoDma;
+use embassy_stm32::{
+    bind_interrupts,
+    cryp::{self, *},
+};
+use embassy_stm32::{peripherals, Config};
 use embassy_time::Instant;
 use {defmt_rtt as _, panic_probe as _};
 
+bind_interrupts!(struct Irqs {
+    CRYP => cryp::InterruptHandler<peripherals::CRYP>;
+});
+
 #[embassy_executor::main]
 async fn main(_spawner: Spawner) -> ! {
     let config = Config::default();
@@ -19,7 +27,7 @@ async fn main(_spawner: Spawner) -> ! {
     let payload: &[u8] = b"hello world";
     let aad: &[u8] = b"additional data";
 
-    let hw_cryp = Cryp::new(p.CRYP);
+    let hw_cryp = Cryp::new(p.CRYP, NoDma, NoDma, Irqs);
     let key: [u8; 16] = [0; 16];
     let mut ciphertext: [u8; 11] = [0; 11];
     let mut plaintext: [u8; 11] = [0; 11];