From ffa0c08140be6c90bde9f11e797eda95e4b8331b Mon Sep 17 00:00:00 2001
From: xoviat <xoviat@users.noreply.github.com>
Date: Sun, 30 Jul 2023 20:22:14 -0500
Subject: [PATCH] stm32/dma: fix condition check

---
 embassy-stm32/src/dma/ringbuffer.rs | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/embassy-stm32/src/dma/ringbuffer.rs b/embassy-stm32/src/dma/ringbuffer.rs
index 8056a7c3a..c3e4f20c0 100644
--- a/embassy-stm32/src/dma/ringbuffer.rs
+++ b/embassy-stm32/src/dma/ringbuffer.rs
@@ -230,9 +230,10 @@ impl<'a, W: Word> WritableDmaRingBuffer<'a, W> {
         let start = self.pos(dma.get_remaining_transfers());
         if start > self.end {
             trace!(
-                "[1]: start, end, complete_count: {}, {}, {}",
+                "[1]: start, end, len, complete_count: {}, {}, {}, {}",
                 start,
                 self.end,
+                buf.len(),
                 dma.get_complete_count()
             );
 
@@ -242,8 +243,9 @@ impl<'a, W: Word> WritableDmaRingBuffer<'a, W> {
             compiler_fence(Ordering::SeqCst);
 
             // Confirm that the DMA is not inside data we could have written
-            let pos = self.pos(dma.get_remaining_transfers());
-            if (pos > self.end && pos <= start) || dma.get_complete_count() > 0 {
+            let (pos, complete_count) =
+                critical_section::with(|_| (self.pos(dma.get_remaining_transfers()), dma.get_complete_count()));
+            if (pos >= self.end && pos < start) || (complete_count > 0 && pos >= start) || complete_count > 1 {
                 Err(OverrunError)
             } else {
                 self.end = (self.end + len) % self.cap();
@@ -252,18 +254,20 @@ impl<'a, W: Word> WritableDmaRingBuffer<'a, W> {
             }
         } else if start == self.end && dma.get_complete_count() == 0 {
             trace!(
-                "[2]: start, end, complete_count: {}, {}, {}",
+                "[2]: start, end, len, complete_count: {}, {}, {}, {}",
                 start,
                 self.end,
+                buf.len(),
                 dma.get_complete_count()
             );
 
             Ok((0, 0))
         } else if start <= self.end && self.end + buf.len() < self.cap() {
             trace!(
-                "[3]: start, end, complete_count: {}, {}, {}",
+                "[3]: start, end, len, complete_count: {}, {}, {}, {}",
                 start,
                 self.end,
+                buf.len(),
                 dma.get_complete_count()
             );
 
@@ -286,9 +290,10 @@ impl<'a, W: Word> WritableDmaRingBuffer<'a, W> {
             }
         } else {
             trace!(
-                "[4]: start, end, complete_count: {}, {}, {}",
+                "[4]: start, end, len, complete_count: {}, {}, {}, {}",
                 start,
                 self.end,
+                buf.len(),
                 dma.get_complete_count()
             );