From d5f0bceb7ce33c4ac761daad7aed052ae646e36a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=D0=A0=D0=BE=D0=BC=D0=B0=D0=BD=20=D0=9A=D1=80=D0=B8=D0=B2?=
 =?UTF-8?q?=D0=B5=D0=BD=D0=BA=D0=BE=D0=B2?= <qwerty19106@gmail.com>
Date: Tue, 30 Apr 2024 15:42:11 +0400
Subject: [PATCH] Clear Receiver enable flag before write in Half-Duplex mode

---
 embassy-stm32/src/usart/mod.rs | 50 +++++++++++++++++++++++++++++++---
 1 file changed, 46 insertions(+), 4 deletions(-)

diff --git a/embassy-stm32/src/usart/mod.rs b/embassy-stm32/src/usart/mod.rs
index 68899bfff..a6dfbd482 100644
--- a/embassy-stm32/src/usart/mod.rs
+++ b/embassy-stm32/src/usart/mod.rs
@@ -359,16 +359,32 @@ impl<'d, T: BasicInstance> UartTx<'d, T, Async> {
 
     /// Initiate an asynchronous UART write
     pub async fn write(&mut self, buffer: &[u8]) -> Result<(), Error> {
+        let r = T::regs();
+
+        // Disable Receiver for Half-Duplex mode
+        if r.cr3().read().hdsel() {
+            r.cr1().modify(|reg| reg.set_re(false));
+        }
+
         let ch = self.tx_dma.as_mut().unwrap();
-        T::regs().cr3().modify(|reg| {
+        r.cr3().modify(|reg| {
             reg.set_dmat(true);
         });
         // If we don't assign future to a variable, the data register pointer
         // is held across an await and makes the future non-Send.
-        let transfer = unsafe { ch.write(buffer, tdr(T::regs()), Default::default()) };
+        let transfer = unsafe { ch.write(buffer, tdr(r), Default::default()) };
         transfer.await;
         Ok(())
     }
+
+    async fn flush_inner() -> Result<(), Error> {
+        Self::blocking_flush_inner()
+    }
+
+    /// Wait until transmission complete
+    pub async fn flush(&mut self) -> Result<(), Error> {
+        Self::flush_inner().await
+    }
 }
 
 impl<'d, T: BasicInstance> UartTx<'d, T, Blocking> {
@@ -436,6 +452,12 @@ impl<'d, T: BasicInstance, M: Mode> UartTx<'d, T, M> {
     /// Perform a blocking UART write
     pub fn blocking_write(&mut self, buffer: &[u8]) -> Result<(), Error> {
         let r = T::regs();
+
+        // Disable Receiver for Half-Duplex mode
+        if r.cr3().read().hdsel() {
+            r.cr1().modify(|reg| reg.set_re(false));
+        }
+
         for &b in buffer {
             while !sr(r).read().txe() {}
             unsafe { tdr(r).write_volatile(b) };
@@ -443,12 +465,21 @@ impl<'d, T: BasicInstance, M: Mode> UartTx<'d, T, M> {
         Ok(())
     }
 
-    /// Block until transmission complete
-    pub fn blocking_flush(&mut self) -> Result<(), Error> {
+    fn blocking_flush_inner() -> Result<(), Error> {
         let r = T::regs();
         while !sr(r).read().tc() {}
+
+        // Enable Receiver after transmission complete for Half-Duplex mode
+        if r.cr3().read().hdsel() {
+            r.cr1().modify(|reg| reg.set_re(true));
+        }
         Ok(())
     }
+
+    /// Block until transmission complete
+    pub fn blocking_flush(&mut self) -> Result<(), Error> {
+        Self::blocking_flush_inner()
+    }
 }
 
 impl<'d, T: BasicInstance> UartRx<'d, T, Async> {
@@ -502,6 +533,11 @@ impl<'d, T: BasicInstance> UartRx<'d, T, Async> {
     ) -> Result<ReadCompletionEvent, Error> {
         let r = T::regs();
 
+        // Call flush for Half-Duplex mode. It prevents reading of bytes which have just been written.
+        if r.cr3().read().hdsel() {
+            UartTx::<'d, T, Async>::flush_inner().await?;
+        }
+
         // make sure USART state is restored to neutral state when this future is dropped
         let on_drop = OnDrop::new(move || {
             // defmt::trace!("Clear all USART interrupts and DMA Read Request");
@@ -825,6 +861,12 @@ impl<'d, T: BasicInstance, M: Mode> UartRx<'d, T, M> {
     /// Perform a blocking read into `buffer`
     pub fn blocking_read(&mut self, buffer: &mut [u8]) -> Result<(), Error> {
         let r = T::regs();
+
+        // Call flush for Half-Duplex mode. It prevents reading of bytes which have just been written.
+        if r.cr3().read().hdsel() {
+            UartTx::<'d, T, M>::blocking_flush_inner()?;
+        }
+
         for b in buffer {
             while !self.check_rx_flags()? {}
             unsafe { *b = rdr(r).read_volatile() }