diff --git a/embassy-stm32/src/dma/mod.rs b/embassy-stm32/src/dma/mod.rs index f27a55f61..5e4b5d3de 100644 --- a/embassy-stm32/src/dma/mod.rs +++ b/embassy-stm32/src/dma/mod.rs @@ -8,7 +8,13 @@ mod dmamux; #[cfg(dmamux)] pub use dmamux::*; +use core::future::Future; +use core::marker::PhantomData; +use core::pin::Pin; +use core::task::Waker; +use core::task::{Context, Poll}; use embassy::util::Unborrow; +use embassy_hal_common::unborrow; #[cfg(feature = "unstable-pac")] pub use transfers::*; @@ -23,7 +29,7 @@ pub type Request = (); pub(crate) mod sealed { use super::*; - use core::task::Waker; + pub trait Channel { /// Starts this channel for writing a stream of words. unsafe fn start_write(&mut self, request: Request, buf: &[W], reg_addr: *mut u32); @@ -86,104 +92,89 @@ impl Word for u32 { } mod transfers { - use core::task::Poll; - - use super::Channel; - use embassy_hal_common::{drop::OnDrop, unborrow}; - use futures::future::poll_fn; - use super::*; #[allow(unused)] - pub async fn read<'a, W: Word>( - channel: &mut impl Unborrow, + pub fn read<'a, W: Word>( + channel: impl Unborrow + 'a, request: Request, reg_addr: *mut u32, buf: &'a mut [W], - ) { + ) -> impl Future + 'a { assert!(buf.len() <= 0xFFFF); - let drop_clone = unsafe { channel.unborrow() }; unborrow!(channel); - channel.request_stop(); - let on_drop = OnDrop::new({ - let mut channel = drop_clone; - move || { - channel.request_stop(); - } - }); - unsafe { channel.start_read::(request, reg_addr, buf) }; - wait_for_stopped(&mut channel).await; - drop(on_drop) + + Transfer { + channel, + _phantom: PhantomData, + } } #[allow(unused)] - pub async fn write<'a, W: Word>( - channel: &mut impl Unborrow, + pub fn write<'a, W: Word>( + channel: impl Unborrow + 'a, request: Request, buf: &'a [W], reg_addr: *mut u32, - ) { + ) -> impl Future + 'a { assert!(buf.len() <= 0xFFFF); - let drop_clone = unsafe { channel.unborrow() }; unborrow!(channel); - channel.request_stop(); - let on_drop = OnDrop::new({ - let mut channel = drop_clone; - move || { - channel.request_stop(); - } - }); - unsafe { channel.start_write::(request, buf, reg_addr) }; - wait_for_stopped(&mut channel).await; - drop(on_drop) + + Transfer { + channel, + _phantom: PhantomData, + } } #[allow(unused)] - pub async fn write_repeated( - channel: &mut impl Unborrow, + pub fn write_repeated<'a, W: Word>( + channel: impl Unborrow + 'a, request: Request, repeated: W, count: usize, reg_addr: *mut u32, - ) { - let drop_clone = unsafe { channel.unborrow() }; + ) -> impl Future + 'a { unborrow!(channel); - channel.request_stop(); - let on_drop = OnDrop::new({ - let mut channel = drop_clone; - move || { - channel.request_stop(); - } - }); - unsafe { channel.start_write_repeated::(request, repeated, count, reg_addr) }; - wait_for_stopped(&mut channel).await; - drop(on_drop) + + Transfer { + channel, + _phantom: PhantomData, + } } - async fn wait_for_stopped(channel: &mut impl Unborrow) { - unborrow!(channel); - poll_fn(move |cx| { - channel.set_waker(cx.waker()); + struct Transfer<'a, C: Channel> { + channel: C, + _phantom: PhantomData<&'a mut C>, + } - // TODO in the future, error checking could be added so that this function returns an error + impl<'a, C: Channel> Drop for Transfer<'a, C> { + fn drop(&mut self) { + self.channel.request_stop(); + while self.channel.is_running() {} + } + } - if channel.is_running() { + impl<'a, C: Channel> Unpin for Transfer<'a, C> {} + impl<'a, C: Channel> Future for Transfer<'a, C> { + type Output = (); + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.channel.set_waker(cx.waker()); + if self.channel.is_running() { Poll::Pending } else { Poll::Ready(()) } - }) - .await + } } } -pub trait Channel: sealed::Channel + Unborrow {} +pub trait Channel: sealed::Channel + Unborrow + 'static {} pub struct NoDma;