From f64a62149e423f6fdb643f7343d971eedc4a3a12 Mon Sep 17 00:00:00 2001
From: Caleb Garrett <47389035+caleb-garrett@users.noreply.github.com>
Date: Tue, 20 Feb 2024 15:26:31 -0500
Subject: [PATCH] Corrected CCM partial block ops.

---
 embassy-stm32/src/cryp/mod.rs | 46 ++++++++++++++++++-----------------
 1 file changed, 24 insertions(+), 22 deletions(-)

diff --git a/embassy-stm32/src/cryp/mod.rs b/embassy-stm32/src/cryp/mod.rs
index 81446e39e..634c85883 100644
--- a/embassy-stm32/src/cryp/mod.rs
+++ b/embassy-stm32/src/cryp/mod.rs
@@ -327,14 +327,16 @@ impl<'c, const KEY_SIZE: usize> Cipher<'c> for AesGmac<'c, KEY_SIZE> {
         dir: Direction,
         int_data: &mut [u8; AES_BLOCK_SIZE],
         _temp1: [u32; 4],
-        _padding_mask: [u8; 16],
+        padding_mask: [u8; AES_BLOCK_SIZE],
     ) {
         if dir == Direction::Encrypt {
             //Handle special GCM partial block process.
             p.cr().modify(|w| w.set_crypen(false));
-            p.cr().write(|w| w.set_algomode3(true));
-            p.cr().write(|w| w.set_algomode0(0));
-            p.init(1).ivrr().write_value(2);
+            p.cr().modify(|w| w.set_algomode3(true));
+            p.cr().modify(|w| w.set_algomode0(0));
+            for i in 0..AES_BLOCK_SIZE {
+                int_data[i] = int_data[i] & padding_mask[i];
+            }
             p.cr().modify(|w| w.set_crypen(true));
             p.cr().modify(|w| w.set_gcm_ccmph(3));
             let mut index = 0;
@@ -479,10 +481,10 @@ impl<'c, const KEY_SIZE: usize> Cipher<'c> for AesCcm<'c, KEY_SIZE> {
         if dir == Direction::Decrypt {
             p.cr().modify(|w| w.set_crypen(false));
             let iv1temp = p.init(1).ivrr().read();
-            temp1[0] = p.csgcmccmr(0).read();
-            temp1[1] = p.csgcmccmr(1).read();
-            temp1[2] = p.csgcmccmr(2).read();
-            temp1[3] = p.csgcmccmr(3).read();
+            temp1[0] = p.csgcmccmr(0).read().swap_bytes();
+            temp1[1] = p.csgcmccmr(1).read().swap_bytes();
+            temp1[2] = p.csgcmccmr(2).read().swap_bytes();
+            temp1[3] = p.csgcmccmr(3).read().swap_bytes();
             p.init(1).ivrr().write_value(iv1temp);
             p.cr().modify(|w| w.set_algomode3(false));
             p.cr().modify(|w| w.set_algomode0(6));
@@ -501,27 +503,27 @@ impl<'c, const KEY_SIZE: usize> Cipher<'c> for AesCcm<'c, KEY_SIZE> {
     ) {
         if dir == Direction::Decrypt {
             //Handle special CCM partial block process.
-            let mut intdata_o: [u32; 4] = [0; 4];
-            for i in 0..intdata_o.len() {
-                intdata_o[i] = p.dout().read();
-            }
             let mut temp2 = [0; 4];
-            temp2[0] = p.csgcmccmr(0).read();
-            temp2[1] = p.csgcmccmr(1).read();
-            temp2[2] = p.csgcmccmr(2).read();
-            temp2[3] = p.csgcmccmr(3).read();
-            p.cr().write(|w| w.set_algomode3(true));
-            p.cr().write(|w| w.set_algomode0(1));
+            temp2[0] = p.csgcmccmr(0).read().swap_bytes();
+            temp2[1] = p.csgcmccmr(1).read().swap_bytes();
+            temp2[2] = p.csgcmccmr(2).read().swap_bytes();
+            temp2[3] = p.csgcmccmr(3).read().swap_bytes();
+            p.cr().modify(|w| w.set_algomode3(true));
+            p.cr().modify(|w| w.set_algomode0(1));
             p.cr().modify(|w| w.set_gcm_ccmph(3));
             // Header phase
             p.cr().modify(|w| w.set_gcm_ccmph(1));
+            for i in 0..AES_BLOCK_SIZE {
+                int_data[i] = int_data[i] & padding_mask[i];
+            }
             let mut in_data: [u32; 4] = [0; 4];
             for i in 0..in_data.len() {
-                let mut mask_bytes: [u8; 4] = [0; 4];
-                mask_bytes.copy_from_slice(&padding_mask[(i * 4)..(i * 4) + 4]);
-                let mask_word = u32::from_le_bytes(mask_bytes);
-                in_data[i] = intdata_o[i] & mask_word;
+                let mut int_bytes: [u8; 4] = [0; 4];
+                int_bytes.copy_from_slice(&int_data[(i * 4)..(i * 4) + 4]);
+                let int_word = u32::from_le_bytes(int_bytes);
+                in_data[i] = int_word;
                 in_data[i] = in_data[i] ^ temp1[i] ^ temp2[i];
+                p.din().write_value(in_data[i]);
             }
         }
     }