diff --git a/embassy-stm32/src/cryp/mod.rs b/embassy-stm32/src/cryp/mod.rs index 038923870..9d1a62905 100644 --- a/embassy-stm32/src/cryp/mod.rs +++ b/embassy-stm32/src/cryp/mod.rs @@ -988,7 +988,7 @@ impl<'d, T: Instance> Cryp<'d, T> { // Write block in while index < end_index { let mut in_word: [u8; 4] = [0; 4]; - in_word.copy_from_slice(&aad[index..index + 4]); + in_word.copy_from_slice(&ctx.aad_buffer[index..index + 4]); T::regs().din().write_value(u32::from_ne_bytes(in_word)); index += 4; } @@ -1000,14 +1000,16 @@ impl<'d, T: Instance> Cryp<'d, T> { ctx.aad_buffer_len = 0; let leftovers = aad_len_remaining % C::BLOCK_SIZE; ctx.aad_buffer[..leftovers].copy_from_slice(&aad[aad.len() - leftovers..aad.len()]); + ctx.aad_buffer_len += leftovers; + ctx.aad_buffer[ctx.aad_buffer_len..].fill(0); aad_len_remaining -= leftovers; assert_eq!(aad_len_remaining % C::BLOCK_SIZE, 0); // Load full data blocks into core. let num_full_blocks = aad_len_remaining / C::BLOCK_SIZE; - for _ in 0..num_full_blocks { - let mut index = len_to_copy; - let end_index = len_to_copy + C::BLOCK_SIZE; + for block in 0..num_full_blocks { + let mut index = len_to_copy + (block * C::BLOCK_SIZE); + let end_index = index + C::BLOCK_SIZE; // Write block in while index < end_index { let mut in_word: [u8; 4] = [0; 4]; @@ -1020,6 +1022,19 @@ impl<'d, T: Instance> Cryp<'d, T> { } if last_aad_block { + if leftovers > 0 { + let mut index = 0; + let end_index = C::BLOCK_SIZE; + // Write block in + while index < end_index { + let mut in_word: [u8; 4] = [0; 4]; + in_word.copy_from_slice(&ctx.aad_buffer[index..index + 4]); + T::regs().din().write_value(u32::from_ne_bytes(in_word)); + index += 4; + } + // Block until input FIFO is empty. + while !T::regs().sr().read().ifem() {} + } // Switch to payload phase. ctx.aad_complete = true; T::regs().cr().modify(|w| w.set_crypen(false)); @@ -1065,7 +1080,7 @@ impl<'d, T: Instance> Cryp<'d, T> { if ctx.last_block_processed { panic!("The last block has already been processed!"); } - if input.len() != output.len() { + if input.len() > output.len() { panic!("Output buffer length must match input length."); } if !last_block {