stm32/i2c: fix races when using dma.

fixes #1341.
This commit is contained in:
xoviat 2023-04-17 15:24:24 -05:00
parent 201a038134
commit f5216624bb

View file

@ -1,6 +1,5 @@
use core::cmp;
use core::future::poll_fn;
use core::sync::atomic::{AtomicUsize, Ordering};
use core::task::Poll;
use embassy_embedded_hal::SetConfig;
@ -35,14 +34,12 @@ impl Default for Config {
pub struct State {
waker: AtomicWaker,
chunks_transferred: AtomicUsize,
}
impl State {
pub(crate) const fn new() -> Self {
Self {
waker: AtomicWaker::new(),
chunks_transferred: AtomicUsize::new(0),
}
}
}
@ -130,10 +127,7 @@ impl<'d, T: Instance, TXDMA, RXDMA> I2c<'d, T, TXDMA, RXDMA> {
let isr = regs.isr().read();
if isr.tcr() || isr.tc() {
let state = T::state();
let transferred = state.chunks_transferred.load(Ordering::Relaxed);
state.chunks_transferred.store(transferred + 1, Ordering::Relaxed);
state.waker.wake();
T::state().waker.wake();
}
// The flag can only be cleared by writting to nbytes, we won't do that here, so disable
// the interrupt
@ -457,12 +451,6 @@ impl<'d, T: Instance, TXDMA, RXDMA> I2c<'d, T, TXDMA, RXDMA> {
TXDMA: crate::i2c::TxDma<T>,
{
let total_len = write.len();
let completed_chunks = total_len / 255;
let total_chunks = if completed_chunks * 255 == total_len {
completed_chunks
} else {
completed_chunks + 1
};
let dma_transfer = unsafe {
let regs = T::regs();
@ -480,7 +468,6 @@ impl<'d, T: Instance, TXDMA, RXDMA> I2c<'d, T, TXDMA, RXDMA> {
};
let state = T::state();
state.chunks_transferred.store(0, Ordering::Relaxed);
let mut remaining_len = total_len;
let on_drop = OnDrop::new(|| {
@ -495,33 +482,31 @@ impl<'d, T: Instance, TXDMA, RXDMA> I2c<'d, T, TXDMA, RXDMA> {
}
});
// NOTE(unsafe) self.tx_dma does not fiddle with the i2c registers
if first_slice {
unsafe {
Self::master_write(
address,
total_len.min(255),
Stop::Software,
(total_chunks != 1) || !last_slice,
&check_timeout,
)?;
}
} else {
unsafe {
Self::master_continue(total_len.min(255), (total_chunks != 1) || !last_slice, &check_timeout)?;
T::regs().cr1().modify(|w| w.set_tcie(true));
}
}
poll_fn(|cx| {
state.waker.register(cx.waker());
let chunks_transferred = state.chunks_transferred.load(Ordering::Relaxed);
if chunks_transferred == total_chunks {
if remaining_len == total_len {
// NOTE(unsafe) self.tx_dma does not fiddle with the i2c registers
if first_slice {
unsafe {
Self::master_write(
address,
total_len.min(255),
Stop::Software,
(total_len > 255) || !last_slice,
&check_timeout,
)?;
}
} else {
unsafe {
Self::master_continue(total_len.min(255), (total_len > 255) || !last_slice, &check_timeout)?;
T::regs().cr1().modify(|w| w.set_tcie(true));
}
}
} else if remaining_len == 0 {
return Poll::Ready(Ok(()));
} else if chunks_transferred != 0 {
remaining_len = remaining_len.saturating_sub(255);
let last_piece = (chunks_transferred + 1 == total_chunks) && last_slice;
} else {
let last_piece = (remaining_len <= 255) && last_slice;
// NOTE(unsafe) self.tx_dma does not fiddle with the i2c registers
unsafe {
@ -531,6 +516,8 @@ impl<'d, T: Instance, TXDMA, RXDMA> I2c<'d, T, TXDMA, RXDMA> {
T::regs().cr1().modify(|w| w.set_tcie(true));
}
}
remaining_len = remaining_len.saturating_sub(255);
Poll::Pending
})
.await?;
@ -559,12 +546,6 @@ impl<'d, T: Instance, TXDMA, RXDMA> I2c<'d, T, TXDMA, RXDMA> {
RXDMA: crate::i2c::RxDma<T>,
{
let total_len = buffer.len();
let completed_chunks = total_len / 255;
let total_chunks = if completed_chunks * 255 == total_len {
completed_chunks
} else {
completed_chunks + 1
};
let dma_transfer = unsafe {
let regs = T::regs();
@ -580,7 +561,6 @@ impl<'d, T: Instance, TXDMA, RXDMA> I2c<'d, T, TXDMA, RXDMA> {
};
let state = T::state();
state.chunks_transferred.store(0, Ordering::Relaxed);
let mut remaining_len = total_len;
let on_drop = OnDrop::new(|| {
@ -593,27 +573,24 @@ impl<'d, T: Instance, TXDMA, RXDMA> I2c<'d, T, TXDMA, RXDMA> {
}
});
// NOTE(unsafe) self.rx_dma does not fiddle with the i2c registers
unsafe {
Self::master_read(
address,
total_len.min(255),
Stop::Software,
total_chunks != 1,
restart,
&check_timeout,
)?;
}
poll_fn(|cx| {
state.waker.register(cx.waker());
let chunks_transferred = state.chunks_transferred.load(Ordering::Relaxed);
if chunks_transferred == total_chunks {
if remaining_len == total_len {
// NOTE(unsafe) self.rx_dma does not fiddle with the i2c registers
unsafe {
Self::master_read(
address,
total_len.min(255),
Stop::Software,
total_len > 255,
restart,
&check_timeout,
)?;
}
} else if remaining_len == 0 {
return Poll::Ready(Ok(()));
} else if chunks_transferred != 0 {
remaining_len = remaining_len.saturating_sub(255);
let last_piece = chunks_transferred + 1 == total_chunks;
} else {
let last_piece = remaining_len <= 255;
// NOTE(unsafe) self.rx_dma does not fiddle with the i2c registers
unsafe {
@ -623,6 +600,8 @@ impl<'d, T: Instance, TXDMA, RXDMA> I2c<'d, T, TXDMA, RXDMA> {
T::regs().cr1().modify(|w| w.set_tcie(true));
}
}
remaining_len = remaining_len.saturating_sub(255);
Poll::Pending
})
.await?;