From accec7a84076ff26f9bdad388809b96537cf31da Mon Sep 17 00:00:00 2001
From: Sebastian Goll <sebastian.goll@gmx.de>
Date: Thu, 21 Mar 2024 00:30:53 +0100
Subject: [PATCH] Implement asynchronous transaction for I2C v1

---
 embassy-stm32/src/i2c/mod.rs |   4 +-
 embassy-stm32/src/i2c/v1.rs  | 214 ++++++++++++++++++++---------------
 2 files changed, 123 insertions(+), 95 deletions(-)

diff --git a/embassy-stm32/src/i2c/mod.rs b/embassy-stm32/src/i2c/mod.rs
index f1b11cc44..6700f0f7d 100644
--- a/embassy-stm32/src/i2c/mod.rs
+++ b/embassy-stm32/src/i2c/mod.rs
@@ -332,8 +332,6 @@ impl<'d, T: Instance, TXDMA: TxDma<T>, RXDMA: RxDma<T>> embedded_hal_async::i2c:
         address: u8,
         operations: &mut [embedded_hal_1::i2c::Operation<'_>],
     ) -> Result<(), Self::Error> {
-        let _ = address;
-        let _ = operations;
-        todo!()
+        self.transaction(address, operations).await
     }
 }
diff --git a/embassy-stm32/src/i2c/v1.rs b/embassy-stm32/src/i2c/v1.rs
index dd2cea6b8..a740ab834 100644
--- a/embassy-stm32/src/i2c/v1.rs
+++ b/embassy-stm32/src/i2c/v1.rs
@@ -111,12 +111,21 @@ impl FrameOptions {
 /// [transaction contract]: embedded_hal_1::i2c::I2c::transaction
 fn operation_frames<'a, 'b: 'a>(
     operations: &'a mut [Operation<'b>],
-) -> impl IntoIterator<Item = (&'a mut Operation<'b>, FrameOptions)> {
+) -> Result<impl IntoIterator<Item = (&'a mut Operation<'b>, FrameOptions)>, Error> {
+    // Check empty read buffer before starting transaction. Otherwise, we would risk halting with an
+    // error in the middle of the transaction.
+    if operations.iter().any(|op| match op {
+        Operation::Read(read) => read.is_empty(),
+        Operation::Write(_) => false,
+    }) {
+        return Err(Error::Overrun);
+    }
+
     let mut operations = operations.iter_mut().peekable();
 
     let mut next_first_frame = true;
 
-    iter::from_fn(move || {
+    Ok(iter::from_fn(move || {
         let Some(op) = operations.next() else {
             return None;
         };
@@ -156,7 +165,7 @@ fn operation_frames<'a, 'b: 'a>(
         };
 
         Some((op, frame))
-    })
+    }))
 }
 
 impl<'d, T: Instance, TXDMA, RXDMA> I2c<'d, T, TXDMA, RXDMA> {
@@ -442,18 +451,9 @@ impl<'d, T: Instance, TXDMA, RXDMA> I2c<'d, T, TXDMA, RXDMA> {
     ///
     /// [transaction contract]: embedded_hal_1::i2c::I2c::transaction
     pub fn blocking_transaction(&mut self, addr: u8, operations: &mut [Operation<'_>]) -> Result<(), Error> {
-        // Check empty read buffer before starting transaction. Otherwise, we would not generate the
-        // stop condition below.
-        if operations.iter().any(|op| match op {
-            Operation::Read(read) => read.is_empty(),
-            Operation::Write(_) => false,
-        }) {
-            return Err(Error::Overrun);
-        }
-
         let timeout = self.timeout();
 
-        for (op, frame) in operation_frames(operations) {
+        for (op, frame) in operation_frames(operations)? {
             match op {
                 Operation::Read(read) => self.blocking_read_timeout(addr, read, timeout, frame)?,
                 Operation::Write(write) => self.write_bytes(addr, write, timeout, frame)?,
@@ -480,9 +480,12 @@ impl<'d, T: Instance, TXDMA, RXDMA> I2c<'d, T, TXDMA, RXDMA> {
         let dma_transfer = unsafe {
             let regs = T::regs();
             regs.cr2().modify(|w| {
+                // Note: Do not enable the ITBUFEN bit in the I2C_CR2 register if DMA is used for reception.
+                w.set_itbufen(false);
                 // DMA mode can be enabled for transmission by setting the DMAEN bit in the I2C_CR2 register.
                 w.set_dmaen(true);
-                w.set_itbufen(false);
+                // Sending NACK is not necessary (nor possible) for write transfer.
+                w.set_last(false);
             });
             // Set the I2C_DR register address in the DMA_SxPAR register. The data will be moved to this address from the memory after each TxE event.
             let dst = regs.dr().as_ptr() as *mut u8;
@@ -520,6 +523,9 @@ impl<'d, T: Instance, TXDMA, RXDMA> I2c<'d, T, TXDMA, RXDMA> {
                         if sr1.start() {
                             Poll::Ready(Ok(()))
                         } else {
+                            // If we need to go around, then re-enable the interrupts, otherwise nothing
+                            // can wake us up and we'll hang.
+                            Self::enable_interrupts();
                             Poll::Pending
                         }
                     }
@@ -537,6 +543,9 @@ impl<'d, T: Instance, TXDMA, RXDMA> I2c<'d, T, TXDMA, RXDMA> {
                     Ok(_) => {
                         let sr2 = T::regs().sr2().read();
                         if !sr2.msl() && !sr2.busy() {
+                            // If we need to go around, then re-enable the interrupts, otherwise nothing
+                            // can wake us up and we'll hang.
+                            Self::enable_interrupts();
                             Poll::Pending
                         } else {
                             Poll::Ready(Ok(()))
@@ -550,14 +559,14 @@ impl<'d, T: Instance, TXDMA, RXDMA> I2c<'d, T, TXDMA, RXDMA> {
             Self::enable_interrupts();
             T::regs().dr().write(|reg| reg.set_dr(address << 1));
 
+            // Wait for the address to be acknowledged
             poll_fn(|cx| {
                 state.waker.register(cx.waker());
+
                 match Self::check_and_clear_error_flags() {
                     Err(e) => Poll::Ready(Err(e)),
                     Ok(sr1) => {
                         if sr1.addr() {
-                            // Clear the ADDR condition by reading SR2.
-                            T::regs().sr2().read();
                             Poll::Ready(Ok(()))
                         } else {
                             // If we need to go around, then re-enable the interrupts, otherwise nothing
@@ -569,8 +578,12 @@ impl<'d, T: Instance, TXDMA, RXDMA> I2c<'d, T, TXDMA, RXDMA> {
                 }
             })
             .await?;
+
+            // Clear condition by reading SR2
+            T::regs().sr2().read();
         }
 
+        // Wait for bytes to be sent, or an error to occur.
         Self::enable_interrupts();
         let poll_error = poll_fn(|cx| {
             state.waker.register(cx.waker());
@@ -579,7 +592,12 @@ impl<'d, T: Instance, TXDMA, RXDMA> I2c<'d, T, TXDMA, RXDMA> {
                 // Unclear why the Err turbofish is necessary here? The compiler didn’t require it in the other
                 // identical poll_fn check_and_clear matches.
                 Err(e) => Poll::Ready(Err::<T, Error>(e)),
-                Ok(_) => Poll::Pending,
+                Ok(_) => {
+                    // If we need to go around, then re-enable the interrupts, otherwise nothing
+                    // can wake us up and we'll hang.
+                    Self::enable_interrupts();
+                    Poll::Pending
+                }
             }
         });
 
@@ -589,52 +607,38 @@ impl<'d, T: Instance, TXDMA, RXDMA> I2c<'d, T, TXDMA, RXDMA> {
             _ => Ok(()),
         }?;
 
-        // The I2C transfer itself will take longer than the DMA transfer, so wait for that to finish too.
-
-        // 18.3.8 “Master transmitter: In the interrupt routine after the EOT interrupt, disable DMA
-        // requests then wait for a BTF event before programming the Stop condition.”
-
-        // TODO: If this has to be done “in the interrupt routine after the EOT interrupt”, where to put it?
         T::regs().cr2().modify(|w| {
             w.set_dmaen(false);
         });
 
-        Self::enable_interrupts();
-        poll_fn(|cx| {
-            state.waker.register(cx.waker());
-
-            match Self::check_and_clear_error_flags() {
-                Err(e) => Poll::Ready(Err(e)),
-                Ok(sr1) => {
-                    if sr1.btf() {
-                        if frame.send_stop() {
-                            T::regs().cr1().modify(|w| {
-                                w.set_stop(true);
-                            });
-                        }
-
-                        Poll::Ready(Ok(()))
-                    } else {
-                        Poll::Pending
-                    }
-                }
-            }
-        })
-        .await?;
-
         if frame.send_stop() {
-            // Wait for STOP condition to transmit.
+            // The I2C transfer itself will take longer than the DMA transfer, so wait for that to finish too.
+
+            // 18.3.8 “Master transmitter: In the interrupt routine after the EOT interrupt, disable DMA
+            // requests then wait for a BTF event before programming the Stop condition.”
             Self::enable_interrupts();
             poll_fn(|cx| {
-                T::state().waker.register(cx.waker());
-                // TODO: error interrupts are enabled here, should we additional check for and return errors?
-                if T::regs().cr1().read().stop() {
-                    Poll::Pending
-                } else {
-                    Poll::Ready(Ok(()))
+                state.waker.register(cx.waker());
+
+                match Self::check_and_clear_error_flags() {
+                    Err(e) => Poll::Ready(Err(e)),
+                    Ok(sr1) => {
+                        if sr1.btf() {
+                            Poll::Ready(Ok(()))
+                        } else {
+                            // If we need to go around, then re-enable the interrupts, otherwise nothing
+                            // can wake us up and we'll hang.
+                            Self::enable_interrupts();
+                            Poll::Pending
+                        }
+                    }
                 }
             })
             .await?;
+
+            T::regs().cr1().modify(|w| {
+                w.set_stop(true);
+            });
         }
 
         drop(on_drop);
@@ -669,15 +673,19 @@ impl<'d, T: Instance, TXDMA, RXDMA> I2c<'d, T, TXDMA, RXDMA> {
     where
         RXDMA: crate::i2c::RxDma<T>,
     {
-        let state = T::state();
         let buffer_len = buffer.len();
 
         let dma_transfer = unsafe {
             let regs = T::regs();
             regs.cr2().modify(|w| {
-                // DMA mode can be enabled for transmission by setting the DMAEN bit in the I2C_CR2 register.
+                // Note: Do not enable the ITBUFEN bit in the I2C_CR2 register if DMA is used for reception.
                 w.set_itbufen(false);
+                // DMA mode can be enabled for transmission by setting the DMAEN bit in the I2C_CR2 register.
                 w.set_dmaen(true);
+                // If, in the I2C_CR2 register, the LAST bit is set, I2C
+                // automatically sends a NACK after the next byte following EOT_1. The user can
+                // generate a Stop condition in the DMA Transfer Complete interrupt routine if enabled.
+                w.set_last(frame.send_nack() && buffer_len != 1);
             });
             // Set the I2C_DR register address in the DMA_SxPAR register. The data will be moved to this address from the memory after each TxE event.
             let src = regs.dr().as_ptr() as *mut u8;
@@ -696,6 +704,8 @@ impl<'d, T: Instance, TXDMA, RXDMA> I2c<'d, T, TXDMA, RXDMA> {
             })
         });
 
+        let state = T::state();
+
         if frame.send_start() {
             // Send a START condition and set ACK bit
             Self::enable_interrupts();
@@ -714,6 +724,9 @@ impl<'d, T: Instance, TXDMA, RXDMA> I2c<'d, T, TXDMA, RXDMA> {
                         if sr1.start() {
                             Poll::Ready(Ok(()))
                         } else {
+                            // If we need to go around, then re-enable the interrupts, otherwise nothing
+                            // can wake us up and we'll hang.
+                            Self::enable_interrupts();
                             Poll::Pending
                         }
                     }
@@ -733,6 +746,9 @@ impl<'d, T: Instance, TXDMA, RXDMA> I2c<'d, T, TXDMA, RXDMA> {
                     Ok(_) => {
                         let sr2 = T::regs().sr2().read();
                         if !sr2.msl() && !sr2.busy() {
+                            // If we need to go around, then re-enable the interrupts, otherwise nothing
+                            // can wake us up and we'll hang.
+                            Self::enable_interrupts();
                             Poll::Pending
                         } else {
                             Poll::Ready(Ok(()))
@@ -743,11 +759,10 @@ impl<'d, T: Instance, TXDMA, RXDMA> I2c<'d, T, TXDMA, RXDMA> {
             .await?;
 
             // Set up current address, we're trying to talk to
+            Self::enable_interrupts();
             T::regs().dr().write(|reg| reg.set_dr((address << 1) + 1));
 
             // Wait for the address to be acknowledged
-
-            Self::enable_interrupts();
             poll_fn(|cx| {
                 state.waker.register(cx.waker());
 
@@ -755,15 +770,11 @@ impl<'d, T: Instance, TXDMA, RXDMA> I2c<'d, T, TXDMA, RXDMA> {
                     Err(e) => Poll::Ready(Err(e)),
                     Ok(sr1) => {
                         if sr1.addr() {
-                            // 18.3.8: When a single byte must be received: the NACK must be programmed during EV6
-                            // event, i.e. program ACK=0 when ADDR=1, before clearing ADDR flag.
-                            if buffer_len == 1 && frame.send_nack() {
-                                T::regs().cr1().modify(|w| {
-                                    w.set_ack(false);
-                                });
-                            }
                             Poll::Ready(Ok(()))
                         } else {
+                            // If we need to go around, then re-enable the interrupts, otherwise nothing
+                            // can wake us up and we'll hang.
+                            Self::enable_interrupts();
                             Poll::Pending
                         }
                     }
@@ -771,24 +782,29 @@ impl<'d, T: Instance, TXDMA, RXDMA> I2c<'d, T, TXDMA, RXDMA> {
             })
             .await?;
 
-            // Clear ADDR condition by reading SR2
+            // 18.3.8: When a single byte must be received: the NACK must be programmed during EV6
+            // event, i.e. program ACK=0 when ADDR=1, before clearing ADDR flag.
+            if frame.send_nack() && buffer_len == 1 {
+                T::regs().cr1().modify(|w| {
+                    w.set_ack(false);
+                });
+            }
+
+            // Clear condition by reading SR2
             T::regs().sr2().read();
+        } else if frame.send_nack() && buffer_len == 1 {
+            T::regs().cr1().modify(|w| {
+                w.set_ack(false);
+            });
         }
 
         // 18.3.8: When a single byte must be received: [snip] Then the
         // user can program the STOP condition either after clearing ADDR flag, or in the
         // DMA Transfer Complete interrupt routine.
-        if buffer_len == 1 && frame.send_stop() {
+        if frame.send_stop() && buffer_len == 1 {
             T::regs().cr1().modify(|w| {
                 w.set_stop(true);
             });
-        } else if buffer_len != 1 && frame.send_nack() {
-            // If, in the I2C_CR2 register, the LAST bit is set, I2C
-            // automatically sends a NACK after the next byte following EOT_1. The user can
-            // generate a Stop condition in the DMA Transfer Complete interrupt routine if enabled.
-            T::regs().cr2().modify(|w| {
-                w.set_last(true);
-            });
         }
 
         // Wait for bytes to be received, or an error to occur.
@@ -798,7 +814,12 @@ impl<'d, T: Instance, TXDMA, RXDMA> I2c<'d, T, TXDMA, RXDMA> {
 
             match Self::check_and_clear_error_flags() {
                 Err(e) => Poll::Ready(Err::<T, Error>(e)),
-                _ => Poll::Pending,
+                _ => {
+                    // If we need to go around, then re-enable the interrupts, otherwise nothing
+                    // can wake us up and we'll hang.
+                    Self::enable_interrupts();
+                    Poll::Pending
+                }
             }
         });
 
@@ -807,25 +828,14 @@ impl<'d, T: Instance, TXDMA, RXDMA> I2c<'d, T, TXDMA, RXDMA> {
             _ => Ok(()),
         }?;
 
-        if frame.send_stop() {
-            if buffer_len != 1 {
-                T::regs().cr1().modify(|w| {
-                    w.set_stop(true);
-                });
-            }
+        T::regs().cr2().modify(|w| {
+            w.set_dmaen(false);
+        });
 
-            // Wait for the STOP to be sent (STOP bit cleared).
-            Self::enable_interrupts();
-            poll_fn(|cx| {
-                state.waker.register(cx.waker());
-                // TODO: error interrupts are enabled here, should we additional check for and return errors?
-                if T::regs().cr1().read().stop() {
-                    Poll::Pending
-                } else {
-                    Poll::Ready(Ok(()))
-                }
-            })
-            .await?;
+        if frame.send_stop() && buffer_len != 1 {
+            T::regs().cr1().modify(|w| {
+                w.set_stop(true);
+            });
         }
 
         drop(on_drop);
@@ -843,6 +853,26 @@ impl<'d, T: Instance, TXDMA, RXDMA> I2c<'d, T, TXDMA, RXDMA> {
         self.write_frame(address, write, FrameOptions::FirstFrame).await?;
         self.read_frame(address, read, FrameOptions::FirstAndLastFrame).await
     }
+
+    /// Transaction with operations.
+    ///
+    /// Consecutive operations of same type are merged. See [transaction contract] for details.
+    ///
+    /// [transaction contract]: embedded_hal_1::i2c::I2c::transaction
+    pub async fn transaction(&mut self, addr: u8, operations: &mut [Operation<'_>]) -> Result<(), Error>
+    where
+        RXDMA: crate::i2c::RxDma<T>,
+        TXDMA: crate::i2c::TxDma<T>,
+    {
+        for (op, frame) in operation_frames(operations)? {
+            match op {
+                Operation::Read(read) => self.read_frame(addr, read, frame).await?,
+                Operation::Write(write) => self.write_frame(addr, write, frame).await?,
+            }
+        }
+
+        Ok(())
+    }
 }
 
 impl<'d, T: Instance, TXDMA, RXDMA> Drop for I2c<'d, T, TXDMA, RXDMA> {