From 1ec9fc58f44987c11ac1e093f117679c56dbe2ed Mon Sep 17 00:00:00 2001
From: Caleb Garrett <47389035+caleb-garrett@users.noreply.github.com>
Date: Tue, 12 Mar 2024 14:52:34 -0400
Subject: [PATCH] Add async CRYP to test.

---
 embassy-stm32/src/cryp/mod.rs | 52 +++++++++++++++--------------------
 tests/stm32/src/bin/cryp.rs   | 31 ++++++++++++++-------
 tests/stm32/src/common.rs     |  1 +
 3 files changed, 44 insertions(+), 40 deletions(-)

diff --git a/embassy-stm32/src/cryp/mod.rs b/embassy-stm32/src/cryp/mod.rs
index 1a601533d..aa4c2a024 100644
--- a/embassy-stm32/src/cryp/mod.rs
+++ b/embassy-stm32/src/cryp/mod.rs
@@ -98,7 +98,7 @@ pub trait Cipher<'c> {
         DmaOut: crate::cryp::DmaOut<T>,
     {}
 
-    /// Called prior to processing the first associated data block for cipher-specific operations.
+    /// Returns the AAD header block as required by the cipher.
     fn get_header_block(&self) -> &[u8] {
         return [0; 0].as_slice();
     }
@@ -500,7 +500,7 @@ impl<'c, const KEY_SIZE: usize> Cipher<'c> for AesGcm<'c, KEY_SIZE> {
     }
 
     #[cfg(cryp_v3)]
-    fn pre_final_block(&self, p: &pac::cryp::Cryp, _dir: Direction, padding_len: usize) -> [u32; 4] {
+    fn pre_final(&self, p: &pac::cryp::Cryp, _dir: Direction, padding_len: usize) -> [u32; 4] {
         //Handle special GCM partial block process.
         p.cr().modify(|w| w.set_npblb(padding_len as u8));
         [0; 4]
@@ -643,7 +643,7 @@ impl<'c, const KEY_SIZE: usize> Cipher<'c> for AesGmac<'c, KEY_SIZE> {
     }
 
     #[cfg(cryp_v3)]
-    fn pre_final_block(&self, p: &pac::cryp::Cryp, _dir: Direction, padding_len: usize) -> [u32; 4] {
+    fn pre_final(&self, p: &pac::cryp::Cryp, _dir: Direction, padding_len: usize) -> [u32; 4] {
         //Handle special GCM partial block process.
         p.cr().modify(|w| w.set_npblb(padding_len as u8));
         [0; 4]
@@ -861,7 +861,7 @@ impl<'c, const KEY_SIZE: usize, const TAG_SIZE: usize, const IV_SIZE: usize> Cip
     }
 
     #[cfg(cryp_v3)]
-    fn pre_final_block(&self, p: &pac::cryp::Cryp, _dir: Direction, padding_len: usize) -> [u32; 4] {
+    fn pre_final(&self, p: &pac::cryp::Cryp, _dir: Direction, padding_len: usize) -> [u32; 4] {
         //Handle special GCM partial block process.
         p.cr().modify(|w| w.set_npblb(padding_len as u8));
         [0; 4]
@@ -1039,10 +1039,7 @@ impl<'d, T: Instance, DmaIn, DmaOut> Cryp<'d, T, DmaIn, DmaOut> {
         instance
     }
 
-    /// Start a new cipher operation.
-    /// Key size must be 128, 192, or 256 bits.
-    /// Initialization vector must only be supplied if necessary.
-    /// Panics if there is any mismatch in parameters, such as an incorrect IV length or invalid mode.
+    /// Start a new encrypt or decrypt operation for the given cipher.
     pub fn start_blocking<'c, C: Cipher<'c> + CipherSized + IVSized>(&self, cipher: &'c C, dir: Direction) -> Context<'c, C> {
         let mut ctx: Context<'c, C> = Context {
             dir,
@@ -1117,10 +1114,7 @@ impl<'d, T: Instance, DmaIn, DmaOut> Cryp<'d, T, DmaIn, DmaOut> {
         ctx
     }
 
-    /// Start a new cipher operation.
-    /// Key size must be 128, 192, or 256 bits.
-    /// Initialization vector must only be supplied if necessary.
-    /// Panics if there is any mismatch in parameters, such as an incorrect IV length or invalid mode.
+    /// Start a new encrypt or decrypt operation for the given cipher.
     pub async fn start<'c, C: Cipher<'c> + CipherSized + IVSized>(&mut self, cipher: &'c C, dir: Direction) -> Context<'c, C> 
     where
         DmaIn: crate::cryp::DmaIn<T>,
@@ -1201,10 +1195,9 @@ impl<'d, T: Instance, DmaIn, DmaOut> Cryp<'d, T, DmaIn, DmaOut> {
 
     #[cfg(any(cryp_v2, cryp_v3))]
     /// Controls the header phase of cipher processing.
-    /// This function is only valid for GCM, CCM, and GMAC modes.
-    /// It only needs to be called if using one of these modes and there is associated data.
-    /// All AAD must be supplied to this function prior to starting the payload phase with `payload_blocking`.
-    /// The AAD must be supplied in multiples of the block size (128 bits), except when supplying the last block.
+    /// This function is only valid for authenticated ciphers including GCM, CCM, and GMAC.
+    /// All additional associated data (AAD) must be supplied to this function prior to starting the payload phase with `payload_blocking`.
+    /// The AAD must be supplied in multiples of the block size (128-bits for AES, 64-bits for DES), except when supplying the last block.
     /// When supplying the last block of AAD, `last_aad_block` must be `true`.
     pub fn aad_blocking<
         'c,
@@ -1299,10 +1292,9 @@ impl<'d, T: Instance, DmaIn, DmaOut> Cryp<'d, T, DmaIn, DmaOut> {
 
     #[cfg(any(cryp_v2, cryp_v3))]
     /// Controls the header phase of cipher processing.
-    /// This function is only valid for GCM, CCM, and GMAC modes.
-    /// It only needs to be called if using one of these modes and there is associated data.
-    /// All AAD must be supplied to this function prior to starting the payload phase with `payload_blocking`.
-    /// The AAD must be supplied in multiples of the block size (128 bits), except when supplying the last block.
+    /// This function is only valid for authenticated ciphers including GCM, CCM, and GMAC.
+    /// All additional associated data (AAD) must be supplied to this function prior to starting the payload phase with `payload`.
+    /// The AAD must be supplied in multiples of the block size (128-bits for AES, 64-bits for DES), except when supplying the last block.
     /// When supplying the last block of AAD, `last_aad_block` must be `true`.
     pub async fn aad<
         'c,
@@ -1402,7 +1394,7 @@ impl<'d, T: Instance, DmaIn, DmaOut> Cryp<'d, T, DmaIn, DmaOut> {
     /// The context determines algorithm, mode, and state of the crypto accelerator.
     /// When the last piece of data is supplied, `last_block` should be `true`.
     /// This function panics under various mismatches of parameters.
-    /// Input and output buffer lengths must match.
+    /// Output buffer must be at least as long as the input buffer.
     /// Data must be a multiple of block size (128-bits for AES, 64-bits for DES) for CBC and ECB modes.
     /// Padding or ciphertext stealing must be managed by the application for these modes.
     /// Data must also be a multiple of block size unless `last_block` is `true`.
@@ -1455,9 +1447,9 @@ impl<'d, T: Instance, DmaIn, DmaOut> Cryp<'d, T, DmaIn, DmaOut> {
         for block in 0..num_full_blocks {
             let index = block * C::BLOCK_SIZE;
             // Write block in
-            self.write_bytes_blocking(C::BLOCK_SIZE, &input[index..index + 4]);
+            self.write_bytes_blocking(C::BLOCK_SIZE, &input[index..index + C::BLOCK_SIZE]);
             // Read block out
-            self.read_bytes_blocking(C::BLOCK_SIZE, &mut output[index..index + 4]);
+            self.read_bytes_blocking(C::BLOCK_SIZE, &mut output[index..index + C::BLOCK_SIZE]);
         }
 
         // Handle the final block, which is incomplete.
@@ -1491,7 +1483,7 @@ impl<'d, T: Instance, DmaIn, DmaOut> Cryp<'d, T, DmaIn, DmaOut> {
     /// The context determines algorithm, mode, and state of the crypto accelerator.
     /// When the last piece of data is supplied, `last_block` should be `true`.
     /// This function panics under various mismatches of parameters.
-    /// Input and output buffer lengths must match.
+    /// Output buffer must be at least as long as the input buffer.
     /// Data must be a multiple of block size (128-bits for AES, 64-bits for DES) for CBC and ECB modes.
     /// Padding or ciphertext stealing must be managed by the application for these modes.
     /// Data must also be a multiple of block size unless `last_block` is `true`.
@@ -1548,9 +1540,9 @@ impl<'d, T: Instance, DmaIn, DmaOut> Cryp<'d, T, DmaIn, DmaOut> {
         for block in 0..num_full_blocks {
             let index = block * C::BLOCK_SIZE;
             // Read block out
-            let read = Self::read_bytes(&mut self.outdma, C::BLOCK_SIZE, &mut output[index..index + 4]);
+            let read = Self::read_bytes(&mut self.outdma, C::BLOCK_SIZE, &mut output[index..index + C::BLOCK_SIZE]);
             // Write block in
-            let write = Self::write_bytes(&mut self.indma, C::BLOCK_SIZE, &input[index..index + 4]);
+            let write = Self::write_bytes(&mut self.indma, C::BLOCK_SIZE, &input[index..index + C::BLOCK_SIZE]);
             embassy_futures::join::join(read, write).await;
         }
 
@@ -1583,8 +1575,8 @@ impl<'d, T: Instance, DmaIn, DmaOut> Cryp<'d, T, DmaIn, DmaOut> {
     }
 
     #[cfg(any(cryp_v2, cryp_v3))]
-    /// This function only needs to be called for GCM, CCM, and GMAC modes to
-    /// generate an authentication tag.
+    /// Generates an authentication tag for authenticated ciphers including GCM, CCM, and GMAC.
+    /// Called after the all data has been encrypted/decrypted by `payload`.
     pub fn finish_blocking<
         'c,
         const TAG_SIZE: usize,
@@ -1629,8 +1621,8 @@ impl<'d, T: Instance, DmaIn, DmaOut> Cryp<'d, T, DmaIn, DmaOut> {
     }
 
     #[cfg(any(cryp_v2, cryp_v3))]
-    /// This function only needs to be called for GCM, CCM, and GMAC modes to
-    /// generate an authentication tag.
+    // Generates an authentication tag for authenticated ciphers including GCM, CCM, and GMAC.
+    /// Called after the all data has been encrypted/decrypted by `payload`.
     pub async fn finish<'c, const TAG_SIZE: usize, C: Cipher<'c> + CipherSized + IVSized + CipherAuthenticated<TAG_SIZE>>(&mut self, mut ctx: Context<'c, C>) -> [u8; TAG_SIZE]
     where
         DmaIn: crate::cryp::DmaIn<T>,
diff --git a/tests/stm32/src/bin/cryp.rs b/tests/stm32/src/bin/cryp.rs
index f105abf26..6bca55f55 100644
--- a/tests/stm32/src/bin/cryp.rs
+++ b/tests/stm32/src/bin/cryp.rs
@@ -10,9 +10,17 @@ use aes_gcm::aead::{AeadInPlace, KeyInit};
 use aes_gcm::Aes128Gcm;
 use common::*;
 use embassy_executor::Spawner;
-use embassy_stm32::cryp::*;
+use embassy_stm32::{
+    bind_interrupts,
+    cryp::{self, *},
+    peripherals
+};
 use {defmt_rtt as _, panic_probe as _};
 
+bind_interrupts!(struct Irqs {
+    CRYP => cryp::InterruptHandler<peripherals::CRYP>;
+});
+
 #[embassy_executor::main]
 async fn main(_spawner: Spawner) {
     let p: embassy_stm32::Peripherals = embassy_stm32::init(config());
@@ -22,27 +30,30 @@ async fn main(_spawner: Spawner) {
     const AAD1: &[u8] = b"additional data 1 stdargadrhaethaethjatjatjaetjartjstrjsfkk;'jopofyuisrteytweTASTUIKFUKIXTRDTEREharhaeryhaterjartjarthaethjrtjarthaetrhartjatejatrjsrtjartjyt1";
     const AAD2: &[u8] = b"additional data 2 stdhthsthsthsrthsrthsrtjdykjdukdyuldadfhsdghsdghsdghsadghjk'hioethjrtjarthaetrhartjatecfgjhzdfhgzdfhzdfghzdfhzdfhzfhjatrjsrtjartjytjfytjfyg";
 
-    let hw_cryp = Cryp::new(p.CRYP);
+    let in_dma = peri!(p, CRYP_IN_DMA);
+    let out_dma = peri!(p, CRYP_OUT_DMA);
+
+    let mut hw_cryp = Cryp::new(p.CRYP, in_dma, out_dma, Irqs);
     let key: [u8; 16] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
     let mut ciphertext: [u8; PAYLOAD1.len() + PAYLOAD2.len()] = [0; PAYLOAD1.len() + PAYLOAD2.len()];
     let mut plaintext: [u8; PAYLOAD1.len() + PAYLOAD2.len()] = [0; PAYLOAD1.len() + PAYLOAD2.len()];
     let iv: [u8; 12] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
 
-    // Encrypt in hardware using AES-GCM 128-bit
+    // Encrypt in hardware using AES-GCM 128-bit in blocking mode.
     let aes_gcm = AesGcm::new(&key, &iv);
-    let mut gcm_encrypt = hw_cryp.start(&aes_gcm, Direction::Encrypt);
+    let mut gcm_encrypt = hw_cryp.start_blocking(&aes_gcm, Direction::Encrypt);
     hw_cryp.aad_blocking(&mut gcm_encrypt, AAD1, false);
     hw_cryp.aad_blocking(&mut gcm_encrypt, AAD2, true);
     hw_cryp.payload_blocking(&mut gcm_encrypt, PAYLOAD1, &mut ciphertext[..PAYLOAD1.len()], false);
     hw_cryp.payload_blocking(&mut gcm_encrypt, PAYLOAD2, &mut ciphertext[PAYLOAD1.len()..], true);
     let encrypt_tag = hw_cryp.finish_blocking(gcm_encrypt);
 
-    // Decrypt in hardware using AES-GCM 128-bit
-    let mut gcm_decrypt = hw_cryp.start(&aes_gcm, Direction::Decrypt);
-    hw_cryp.aad_blocking(&mut gcm_decrypt, AAD1, false);
-    hw_cryp.aad_blocking(&mut gcm_decrypt, AAD2, true);
-    hw_cryp.payload_blocking(&mut gcm_decrypt, &ciphertext, &mut plaintext, true);
-    let decrypt_tag = hw_cryp.finish_blocking(gcm_decrypt);
+    // Decrypt in hardware using AES-GCM 128-bit in async (DMA) mode.
+    let mut gcm_decrypt = hw_cryp.start(&aes_gcm, Direction::Decrypt).await;
+    hw_cryp.aad(&mut gcm_decrypt, AAD1, false).await;
+    hw_cryp.aad(&mut gcm_decrypt, AAD2, true).await;
+    hw_cryp.payload(&mut gcm_decrypt, &ciphertext, &mut plaintext, true).await;
+    let decrypt_tag = hw_cryp.finish(gcm_decrypt).await;
 
     info!("AES-GCM Ciphertext: {:?}", ciphertext);
     info!("AES-GCM Plaintext: {:?}", plaintext);
diff --git a/tests/stm32/src/common.rs b/tests/stm32/src/common.rs
index 3297ea7e2..c379863a8 100644
--- a/tests/stm32/src/common.rs
+++ b/tests/stm32/src/common.rs
@@ -140,6 +140,7 @@ define_peris!(
 );
 #[cfg(any(feature = "stm32h755zi", feature = "stm32h753zi"))]
 define_peris!(
+    CRYP_IN_DMA = DMA1_CH0, CRYP_OUT_DMA = DMA1_CH1,
     UART = USART1, UART_TX = PB6, UART_RX = PB7, UART_TX_DMA = DMA1_CH0, UART_RX_DMA = DMA1_CH1,
     SPI = SPI1, SPI_SCK = PA5, SPI_MOSI = PB5, SPI_MISO = PA6, SPI_TX_DMA = DMA1_CH0, SPI_RX_DMA = DMA1_CH1,
     ADC = ADC1, DAC = DAC1, DAC_PIN = PA4,