From 690b2118c6fdad88bf1e595b6a0c0afdb0583d28 Mon Sep 17 00:00:00 2001
From: Caleb Garrett <47389035+caleb-garrett@users.noreply.github.com>
Date: Tue, 20 Feb 2024 11:54:39 -0500
Subject: [PATCH] CCM mode functional.

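Implements the AES-CCM cipher (NIST SP 800-38C) on the CRYP peripheral.
The cipher precomputes the CCM B0 block and the initial counter block at
construction time, emits the B1 AAD-length header through the new
get_header_block() hook, and the context now buffers associated data so
AAD may be passed in arbitrary chunk sizes. The GCM/GMAC final-block
hooks gain direction, saved-state, and padding-mask parameters to
support the CCM partial-block workarounds.

Worked example of the B0 flags byte: with AAD present, a 13-byte nonce,
and a 16-byte tag, flags = 0x40 | (((16 - 2) / 2) << 3) | (15 - 13 - 1)
= 0x79.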
---
 embassy-stm32/src/cryp/mod.rs | 372 ++++++++++++++++++++++++++--------
 1 file changed, 293 insertions(+), 79 deletions(-)

diff --git a/embassy-stm32/src/cryp/mod.rs b/embassy-stm32/src/cryp/mod.rs
index 29c1db12e..fe248def1 100644
--- a/embassy-stm32/src/cryp/mod.rs
+++ b/embassy-stm32/src/cryp/mod.rs
@@ -1,6 +1,6 @@
 //! Crypto Accelerator (CRYP)
+use core::cmp::min;
 use core::marker::PhantomData;
-
 use embassy_hal_internal::{into_ref, PeripheralRef};
 
 use crate::pac;
@@ -21,7 +21,7 @@ pub trait Cipher<'c> {
     const REQUIRES_PADDING: bool = false;
 
     /// Returns the symmetric key.
-    fn key(&self) -> &'c [u8];
+    fn key(&self) -> &[u8];
 
     /// Returns the initialization vector.
     fn iv(&self) -> &[u8];
@@ -36,10 +36,25 @@ pub trait Cipher<'c> {
     fn init_phase(&self, _p: &pac::cryp::Cryp) {}
 
     /// Called prior to processing the last data block for cipher-specific operations.
-    fn pre_final_block(&self, _p: &pac::cryp::Cryp) {}
+    fn pre_final_block(&self, _p: &pac::cryp::Cryp, _dir: Direction) -> [u32; 4] {
+        [0; 4]
+    }
 
     /// Called after processing the last data block for cipher-specific operations.
-    fn post_final_block(&self, _p: &pac::cryp::Cryp, _dir: Direction, _int_data: &[u8; AES_BLOCK_SIZE]) {}
+    fn post_final_block(
+        &self,
+        _p: &pac::cryp::Cryp,
+        _dir: Direction,
+        _int_data: &[u8; AES_BLOCK_SIZE],
+        _temp1: [u32; 4],
+        _padding_mask: [u8; 16],
+    ) {
+    }
+
+    /// Called prior to processing the first associated data block for cipher-specific operations.
+    fn get_header_block(&self) -> &[u8] {
+        &[]
+    }
 }
 
 /// This trait enables restriction of ciphers to specific key sizes.
@@ -204,17 +219,27 @@ impl<'c, const KEY_SIZE: usize> Cipher<'c> for AesGcm<'c, KEY_SIZE> {
         while p.cr().read().crypen() {}
     }
 
-    fn pre_final_block(&self, p: &pac::cryp::Cryp) {
+    fn pre_final_block(&self, p: &pac::cryp::Cryp, dir: Direction) -> [u32; 4] {
         //Handle special GCM partial block process.
-        p.cr().modify(|w| w.set_crypen(false));
-        p.cr().modify(|w| w.set_algomode3(false));
-        p.cr().modify(|w| w.set_algomode0(6));
-        let iv1r = p.csgcmccmr(7).read() - 1;
-        p.init(1).ivrr().write_value(iv1r);
-        p.cr().modify(|w| w.set_crypen(true));
+        if dir == Direction::Encrypt {
+            p.cr().modify(|w| w.set_crypen(false));
+            p.cr().modify(|w| w.set_algomode3(false));
+            p.cr().modify(|w| w.set_algomode0(6));
+            let iv1r = p.csgcmccmr(7).read() - 1;
+            p.init(1).ivrr().write_value(iv1r);
+            p.cr().modify(|w| w.set_crypen(true));
+        }
+        [0; 4]
     }
 
-    fn post_final_block(&self, p: &pac::cryp::Cryp, dir: Direction, int_data: &[u8; AES_BLOCK_SIZE]) {
+    fn post_final_block(
+        &self,
+        p: &pac::cryp::Cryp,
+        dir: Direction,
+        int_data: &[u8; AES_BLOCK_SIZE],
+        _temp1: [u32; 4],
+        _padding_mask: [u8; 16],
+    ) {
         if dir == Direction::Encrypt {
             //Handle special GCM partial block process.
             p.cr().modify(|w| w.set_crypen(false));
@@ -281,17 +306,27 @@ impl<'c, const KEY_SIZE: usize> Cipher<'c> for AesGmac<'c, KEY_SIZE> {
         while p.cr().read().crypen() {}
     }
 
-    fn pre_final_block(&self, p: &pac::cryp::Cryp) {
+    fn pre_final_block(&self, p: &pac::cryp::Cryp, dir: Direction) -> [u32; 4] {
         //Handle special GCM partial block process.
-        p.cr().modify(|w| w.set_crypen(false));
-        p.cr().modify(|w| w.set_algomode3(false));
-        p.cr().modify(|w| w.set_algomode0(6));
-        let iv1r = p.csgcmccmr(7).read() - 1;
-        p.init(1).ivrr().write_value(iv1r);
-        p.cr().modify(|w| w.set_crypen(true));
+        if dir == Direction::Encrypt {
+            p.cr().modify(|w| w.set_crypen(false));
+            p.cr().modify(|w| w.set_algomode3(false));
+            p.cr().modify(|w| w.set_algomode0(6));
+            let iv1r = p.csgcmccmr(7).read() - 1;
+            p.init(1).ivrr().write_value(iv1r);
+            p.cr().modify(|w| w.set_crypen(true));
+        }
+        [0; 4]
     }
 
-    fn post_final_block(&self, p: &pac::cryp::Cryp, dir: Direction, int_data: &[u8; AES_BLOCK_SIZE]) {
+    fn post_final_block(
+        &self,
+        p: &pac::cryp::Cryp,
+        dir: Direction,
+        int_data: &[u8; AES_BLOCK_SIZE],
+        _temp1: [u32; 4],
+        _padding_mask: [u8; 16],
+    ) {
         if dir == Direction::Encrypt {
             //Handle special GCM partial block process.
             p.cr().modify(|w| w.set_crypen(false));
@@ -320,49 +355,180 @@ impl<'c> CipherSized for AesGmac<'c, { 192 / 8 }> {}
 impl<'c> CipherSized for AesGmac<'c, { 256 / 8 }> {}
 impl<'c, const KEY_SIZE: usize> CipherAuthenticated for AesGmac<'c, KEY_SIZE> {}
 
-// struct AesCcm<'c, const KEY_SIZE: usize> {
-//     iv: &'c [u8],
-//     key: &'c [u8; KEY_SIZE],
-//     aad_len: usize,
-//     payload_len: usize,
-// }
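+/// AES-CCM cipher mode (counter with CBC-MAC, per NIST SP 800-38C).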
+pub struct AesCcm<'c, const KEY_SIZE: usize> {
+    key: &'c [u8; KEY_SIZE],
+    aad_header: [u8; 6],
+    aad_header_len: usize,
+    block0: [u8; 16],
+    ctr: [u8; 16],
+}
 
-// impl<'c, const KEY_SIZE: usize> AesCcm<'c, KEY_SIZE> {
-//     pub fn new(&self, key: &[u8; KEY_SIZE], iv: &[u8], aad_len: usize, payload_len: usize) {
-//         if iv.len() > 13 {
-//             panic!("CCM IV length must be 13 bytes or less.");
-//         }
-//         self.key = key;
-//         self.iv = iv;
-//         self.aad_len = aad_len;
-//         self.payload_len = payload_len;
-//     }
-// }
+impl<'c, const KEY_SIZE: usize> AesCcm<'c, KEY_SIZE> {
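+    /// Constructs a new AES-CCM cipher for a cryp operation.
+    ///
+    /// `aad_len` and `payload_len` are the total associated data and payload
+    /// lengths in bytes; `tag_len` is the desired MAC length in bytes.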
+    pub fn new(key: &'c [u8; KEY_SIZE], iv: &'c [u8], aad_len: usize, payload_len: usize, tag_len: u8) -> Self {
+        if iv.len() < 7 || iv.len() > 13 {
+            panic!("CCM IV length must be 7-13 bytes.");
+        }
+        if tag_len < 4 || tag_len > 16 {
+            panic!("Tag length must be between 4 and 16 bytes.");
+        }
+        if tag_len % 2 != 0 {
+            panic!("Tag length must be a multiple of 2 bytes.");
+        }
 
-// impl<'c, const KEY_SIZE: usize> Cipher<'c> for AesCcm<'c, KEY_SIZE> {
-//     const BLOCK_SIZE: usize = AES_BLOCK_SIZE;
+        let mut aad_header: [u8; 6] = [0; 6];
+        let mut aad_header_len = 0;
+        let mut block0: [u8; 16] = [0; 16];
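+        // Per SP 800-38C, the AAD length is encoded as two big-endian bytes when
+        // it is below 0xFF00, otherwise as the marker 0xFF 0xFE followed by a
+        // four-byte big-endian length.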
+        if aad_len != 0 {
+            if aad_len < 0xFF00 {
+                aad_header[0] = (aad_len >> 8) as u8;
+                aad_header[1] = aad_len as u8;
+                aad_header_len = 2;
+            } else {
+                aad_header[0] = 0xFF;
+                aad_header[1] = 0xFE;
+                let aad_len_bytes: [u8; 4] = (aad_len as u32).to_be_bytes();
+                aad_header[2] = aad_len_bytes[0];
+                aad_header[3] = aad_len_bytes[1];
+                aad_header[4] = aad_len_bytes[2];
+                aad_header[5] = aad_len_bytes[3];
+                aad_header_len = 6;
+            }
+        }
+        let total_aad_len = aad_header_len + aad_len;
+        let mut aad_padding_len = 16 - (total_aad_len % 16);
+        if aad_padding_len == 16 {
+            aad_padding_len = 0;
+        }
+        let total_aad_len_padded = total_aad_len + aad_padding_len;
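+        // Assemble the B0 flags byte: bit 6 = Adata (AAD present),
+        // bits 5..3 = (tag_len - 2) / 2, bits 2..0 = q - 1, where q = 15 - iv.len()
+        // is the size of the message length field.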
+        if total_aad_len_padded > 0 {
+            block0[0] = 0x40;
+        }
+        block0[0] |= (((tag_len - 2) >> 1) & 0x07) << 3;
+        block0[0] |= ((15 - (iv.len() as u8)) - 1) & 0x07;
+        block0[1..1 + iv.len()].copy_from_slice(iv);
+        let payload_len_bytes: [u8; 4] = (payload_len as u32).to_be_bytes();
+        if iv.len() <= 11 {
+            block0[12] = payload_len_bytes[0];
+        } else if payload_len_bytes[0] > 0 {
+            panic!("Message is too large for given IV size.");
+        }
+        if iv.len() <= 12 {
+            block0[13] = payload_len_bytes[1];
+        } else if payload_len_bytes[1] > 0 {
+            panic!("Message is too large for given IV size.");
+        }
+        block0[14] = payload_len_bytes[2];
+        block0[15] = payload_len_bytes[3];
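+        // Build the initial counter block: the flags byte keeps only the q - 1
+        // bits, followed by the nonce, with the counter field starting at 1 for
+        // the first payload block.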
+        let mut ctr: [u8; 16] = [0; 16];
+        ctr[0] = block0[0] & 0x07;
+        ctr[1..1 + iv.len()].copy_from_slice(&block0[1..1 + iv.len()]);
+        ctr[15] = 0x01;
 
-//     fn key(&self) -> &'c [u8] {
-//         self.key
-//     }
+        Self {
+            key,
+            aad_header,
+            aad_header_len,
+            block0,
+            ctr,
+        }
+    }
+}
 
-//     fn iv(&self) -> &'c [u8] {
-//         self.iv
-//     }
+impl<'c, const KEY_SIZE: usize> Cipher<'c> for AesCcm<'c, KEY_SIZE> {
+    const BLOCK_SIZE: usize = AES_BLOCK_SIZE;
 
-//     fn set_algomode(&self, p: &pac::cryp::Cryp) {
-//         p.cr().modify(|w| w.set_algomode0(1));
-//         p.cr().modify(|w| w.set_algomode3(true));
-//     }
+    fn key(&self) -> &'c [u8] {
+        self.key
+    }
 
-//     fn init_phase(&self, p: &pac::cryp::Cryp) {
-//         todo!();
-//     }
-// }
+    fn iv(&self) -> &[u8] {
+        self.ctr.as_slice()
+    }
 
-// impl<'c> CipherSized for AesCcm<'c, { 128 / 8 }> {}
-// impl<'c> CipherSized for AesCcm<'c, { 192 / 8 }> {}
-// impl<'c> CipherSized for AesCcm<'c, { 256 / 8 }> {}
+    fn set_algomode(&self, p: &pac::cryp::Cryp) {
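+        // Select AES-CCM (ALGOMODE = 0b1001, split across ALGOMODE0/ALGOMODE3).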
+        p.cr().modify(|w| w.set_algomode0(1));
+        p.cr().modify(|w| w.set_algomode3(true));
+    }
+
+    fn init_phase(&self, p: &pac::cryp::Cryp) {
+        p.cr().modify(|w| w.set_gcm_ccmph(0));
+
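+        // The CCM init phase consumes B0: write it to DIN one word at a time,
+        // then enable the core and wait for the phase to finish.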
+        let mut index = 0;
+        let end_index = index + Self::BLOCK_SIZE;
+        // Write block in
+        while index < end_index {
+            let mut in_word: [u8; 4] = [0; 4];
+            in_word.copy_from_slice(&self.block0[index..index + 4]);
+            p.din().write_value(u32::from_ne_bytes(in_word));
+            index += 4;
+        }
+        p.cr().modify(|w| w.set_crypen(true));
+        while p.cr().read().crypen() {}
+    }
+
+    fn get_header_block(&self) -> &[u8] {
+        &self.aad_header[0..self.aad_header_len]
+    }
+
+    fn pre_final_block(&self, p: &pac::cryp::Cryp, dir: Direction) -> [u32; 4] {
+        // Handle the special CCM partial-block process.
+        let mut temp1 = [0; 4];
+        if dir == Direction::Decrypt {
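+            // Partial-block decryption workaround: save the intermediate CBC-MAC
+            // state (CSGCMCCM0..3R) and switch the core to CTR mode so the final
+            // block can be decrypted and masked in software.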
+            p.cr().modify(|w| w.set_crypen(false));
+            let iv1temp = p.init(1).ivrr().read();
+            temp1[0] = p.csgcmccmr(0).read();
+            temp1[1] = p.csgcmccmr(1).read();
+            temp1[2] = p.csgcmccmr(2).read();
+            temp1[3] = p.csgcmccmr(3).read();
+            p.init(1).ivrr().write_value(iv1temp);
+            p.cr().modify(|w| w.set_algomode3(false));
+            p.cr().modify(|w| w.set_algomode0(6));
+            p.cr().modify(|w| w.set_crypen(true));
+        }
+        temp1
+    }
+
+    fn post_final_block(
+        &self,
+        p: &pac::cryp::Cryp,
+        dir: Direction,
+        int_data: &[u8; AES_BLOCK_SIZE],
+        temp1: [u32; 4],
+        padding_mask: [u8; 16],
+    ) {
+        if dir == Direction::Decrypt {
+            // Handle the special CCM partial-block process.
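+            // Fold the final partial block back into the MAC: mask the decrypted
+            // block to its valid bytes, XOR it with the MAC state saved before
+            // (temp1) and after (temp2) the CTR pass, and feed the result through
+            // the header phase.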
+            let mut intdata_o: [u32; 4] = [0; 4];
+            for i in 0..intdata_o.len() {
+                let mut int_bytes: [u8; 4] = [0; 4];
+                int_bytes.copy_from_slice(&int_data[(i * 4)..(i * 4) + 4]);
+                intdata_o[i] = u32::from_ne_bytes(int_bytes);
+            }
+            let mut temp2 = [0; 4];
+            temp2[0] = p.csgcmccmr(0).read();
+            temp2[1] = p.csgcmccmr(1).read();
+            temp2[2] = p.csgcmccmr(2).read();
+            temp2[3] = p.csgcmccmr(3).read();
+            p.cr().modify(|w| w.set_algomode3(true));
+            p.cr().modify(|w| w.set_algomode0(1));
+            p.cr().modify(|w| w.set_gcm_ccmph(3));
+            // Header phase
+            p.cr().modify(|w| w.set_gcm_ccmph(1));
+            let mut in_data: [u32; 4] = [0; 4];
+            for i in 0..in_data.len() {
+                let mut mask_bytes: [u8; 4] = [0; 4];
+                mask_bytes.copy_from_slice(&padding_mask[(i * 4)..(i * 4) + 4]);
+                let mask_word = u32::from_le_bytes(mask_bytes);
+                in_data[i] = intdata_o[i] & mask_word;
+                in_data[i] = in_data[i] ^ temp1[i] ^ temp2[i];
+                p.din().write_value(in_data[i]);
+            }
+        }
+    }
+}
+
+impl<'c> CipherSized for AesCcm<'c, { 128 / 8 }> {}
+impl<'c> CipherSized for AesCcm<'c, { 192 / 8 }> {}
+impl<'c> CipherSized for AesCcm<'c, { 256 / 8 }> {}
+impl<'c, const KEY_SIZE: usize> CipherAuthenticated for AesCcm<'c, KEY_SIZE> {}
 
 /// Holds the state information for a cipher operation.
 /// Allows suspending/resuming of cipher operations.
@@ -371,6 +537,7 @@ pub struct Context<'c, C: Cipher<'c> + CipherSized> {
     cipher: &'c C,
     dir: Direction,
     last_block_processed: bool,
+    header_processed: bool,
     aad_complete: bool,
     cr: u32,
     iv: [u32; 4],
@@ -378,6 +545,8 @@ pub struct Context<'c, C: Cipher<'c> + CipherSized> {
     csgcm: [u32; 8],
     header_len: u64,
     payload_len: u64,
+    aad_buffer: [u8; 16],
+    aad_buffer_len: usize,
 }
 
 /// Selects whether the crypto processor operates in encryption or decryption mode.
@@ -420,6 +589,9 @@ impl<'d, T: Instance> Cryp<'d, T> {
             payload_len: 0,
             cipher: cipher,
             phantom_data: PhantomData,
+            header_processed: false,
+            aad_buffer: [0; 16],
+            aad_buffer_len: 0,
         };
 
         T::regs().cr().modify(|w| w.set_crypen(false));
@@ -492,16 +664,9 @@ impl<'d, T: Instance> Cryp<'d, T> {
     ) {
         self.load_context(ctx);
 
-        let last_block_remainder = aad.len() % C::BLOCK_SIZE;
-
         // Perform checks for correctness.
         if ctx.aad_complete {
-            panic!("Cannot update AAD after calling 'update'!")
-        }
-        if !last_aad_block {
-            if last_block_remainder != 0 {
-                panic!("Input length must be a multiple of {} bytes.", C::BLOCK_SIZE);
-            }
+            panic!("Cannot update AAD after starting payload!")
         }
 
         ctx.header_len += aad.len() as u64;
@@ -511,11 +676,49 @@ impl<'d, T: Instance> Cryp<'d, T> {
         T::regs().cr().modify(|w| w.set_gcm_ccmph(1));
         T::regs().cr().modify(|w| w.set_crypen(true));
 
-        // Load data into core, block by block.
-        let num_full_blocks = aad.len() / C::BLOCK_SIZE;
-        for block in 0..num_full_blocks {
-            let mut index = block * C::BLOCK_SIZE;
-            let end_index = index + C::BLOCK_SIZE;
+        // First write the cipher-specific header block (CCM B1) if it has not been written yet.
+        if !ctx.header_processed {
+            ctx.header_processed = true;
+            let header = ctx.cipher.get_header_block();
+            ctx.aad_buffer[0..header.len()].copy_from_slice(header);
+            ctx.aad_buffer_len += header.len();
+        }
+
+        // Fill the buffer with incoming AAD, up to a full block.
+        let len_to_copy = min(aad.len(), C::BLOCK_SIZE - ctx.aad_buffer_len);
+        ctx.aad_buffer[ctx.aad_buffer_len..ctx.aad_buffer_len + len_to_copy].copy_from_slice(&aad[..len_to_copy]);
+        ctx.aad_buffer_len += len_to_copy;
+        ctx.aad_buffer[ctx.aad_buffer_len..].fill(0);
+        let mut aad_len_remaining = aad.len() - len_to_copy;
+
+        if ctx.aad_buffer_len < C::BLOCK_SIZE {
+            // The buffer isn't full and this is the last buffer, so process it as is (already padded).
+            if last_aad_block {
+                let mut index = 0;
+                let end_index = C::BLOCK_SIZE;
+                // Write block in
+                while index < end_index {
+                    let mut in_word: [u8; 4] = [0; 4];
+                    in_word.copy_from_slice(&ctx.aad_buffer[index..index + 4]);
+                    T::regs().din().write_value(u32::from_ne_bytes(in_word));
+                    index += 4;
+                }
+                // Block until input FIFO is empty.
+                while !T::regs().sr().read().ifem() {}
+
+                // Switch to payload phase.
+                ctx.aad_complete = true;
+                T::regs().cr().modify(|w| w.set_crypen(false));
+                T::regs().cr().modify(|w| w.set_gcm_ccmph(2));
+                T::regs().cr().modify(|w| w.set_fflush(true));
+            } else {
+                // Just return because we don't yet have a full block to process.
+                return;
+            }
+        } else {
+            // Load the full block from the buffer.
+            let mut index = 0;
+            let end_index = C::BLOCK_SIZE;
             // Write block in
             while index < end_index {
                 let mut in_word: [u8; 4] = [0; 4];
@@ -527,20 +730,26 @@ impl<'d, T: Instance> Cryp<'d, T> {
             while !T::regs().sr().read().ifem() {}
         }
 
-        // Handle the final block, which is incomplete.
-        if last_block_remainder > 0 {
-            let mut last_block: [u8; AES_BLOCK_SIZE] = [0; AES_BLOCK_SIZE];
-            last_block[..last_block_remainder].copy_from_slice(&aad[aad.len() - last_block_remainder..aad.len()]);
-            let mut index = 0;
-            let end_index = C::BLOCK_SIZE;
+        // Handle a partial block that is passed in.
+        ctx.aad_buffer_len = 0;
+        let leftovers = aad_len_remaining % C::BLOCK_SIZE;
+        ctx.aad_buffer[..leftovers].copy_from_slice(&aad[aad.len() - leftovers..aad.len()]);
+        ctx.aad_buffer_len += leftovers;
+        aad_len_remaining -= leftovers;
+        assert_eq!(aad_len_remaining % C::BLOCK_SIZE, 0);
+
+        // Load full data blocks into core.
+        let num_full_blocks = aad_len_remaining / C::BLOCK_SIZE;
+        for block in 0..num_full_blocks {
+            let mut index = len_to_copy + (block * C::BLOCK_SIZE);
+            let end_index = index + C::BLOCK_SIZE;
             // Write block in
             while index < end_index {
                 let mut in_word: [u8; 4] = [0; 4];
-                in_word.copy_from_slice(&last_block[index..index + 4]);
+                in_word.copy_from_slice(&aad[index..index + 4]);
                 T::regs().din().write_value(u32::from_ne_bytes(in_word));
                 index += 4;
             }
-            // Block until input FIFO is empty
+            // Block until input FIFO is empty.
             while !T::regs().sr().read().ifem() {}
         }
 
@@ -630,7 +839,7 @@ impl<'d, T: Instance> Cryp<'d, T> {
 
         // Handle the final block, which is incomplete.
         if last_block_remainder > 0 {
-            ctx.cipher.pre_final_block(&T::regs());
+            let temp1 = ctx.cipher.pre_final_block(&T::regs(), ctx.dir);
 
             let mut intermediate_data: [u8; AES_BLOCK_SIZE] = [0; AES_BLOCK_SIZE];
             let mut last_block: [u8; AES_BLOCK_SIZE] = [0; AES_BLOCK_SIZE];
@@ -660,10 +869,15 @@ impl<'d, T: Instance> Cryp<'d, T> {
             output[output_len - last_block_remainder..output_len]
                 .copy_from_slice(&intermediate_data[0..last_block_remainder]);
 
-            ctx.cipher.post_final_block(&T::regs(), ctx.dir, &intermediate_data);
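+            // Build a byte mask covering only the valid bytes of the final
+            // partial block; cipher-specific post-processing uses it to strip
+            // the padding.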
+            let mut mask: [u8; 16] = [0; 16];
+            mask[..last_block_remainder].fill(0xFF);
+            ctx.cipher
+                .post_final_block(&T::regs(), ctx.dir, &intermediate_data, temp1, mask);
         }
 
         ctx.payload_len += input.len() as u64;
+
+        self.store_context(ctx);
     }
 
     /// This function only needs to be called for GCM, CCM, and GMAC modes to