From 7256ff3e71ceea9091349b040a2ebc987aca590c Mon Sep 17 00:00:00 2001 From: huntc Date: Fri, 10 Dec 2021 12:08:00 +1100 Subject: [PATCH] Provides AsyncWrite with flush As per Tokio and others, this commit provides a `poll_flush` method on `AsyncWrite` so that a best-effort attempt at wakening once all bytes are flushed can be made. --- embassy-hal-common/src/ring_buffer.rs | 7 ++++++ embassy-hal-common/src/usb/usb_serial.rs | 15 +++++++++++ embassy-nrf/src/buffered_uarte.rs | 14 +++++++++++ embassy/src/io/traits.rs | 17 +++++++++++++ embassy/src/io/util/flush.rs | 32 ++++++++++++++++++++++++ embassy/src/io/util/mod.rs | 12 +++++++++ examples/nrf/src/bin/buffered_uart.rs | 3 +++ 7 files changed, 100 insertions(+) create mode 100644 embassy/src/io/util/flush.rs diff --git a/embassy-hal-common/src/ring_buffer.rs b/embassy-hal-common/src/ring_buffer.rs index 6829f62f5..fcad68bb1 100644 --- a/embassy-hal-common/src/ring_buffer.rs +++ b/embassy-hal-common/src/ring_buffer.rs @@ -125,5 +125,12 @@ mod tests { let buf = rb.pop_buf(); assert_eq!(1, buf.len()); assert_eq!(4, buf[0]); + rb.pop(1); + + let buf = rb.pop_buf(); + assert_eq!(0, buf.len()); + + let buf = rb.push_buf(); + assert_eq!(4, buf.len()); } } diff --git a/embassy-hal-common/src/usb/usb_serial.rs b/embassy-hal-common/src/usb/usb_serial.rs index ca43a4d73..2592d05a6 100644 --- a/embassy-hal-common/src/usb/usb_serial.rs +++ b/embassy-hal-common/src/usb/usb_serial.rs @@ -106,6 +106,17 @@ where serial.poll_write(cx, buf) }) } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + let mut mutex = this.inner.borrow_mut(); + mutex.with(|state| { + let serial = state.classes.get_serial(); + let serial = Pin::new(serial); + + serial.poll_flush(cx) + }) + } } pub struct UsbSerial<'bus, 'a, B: UsbBus> { @@ -167,6 +178,10 @@ impl<'bus, 'a, B: UsbBus> AsyncWrite for UsbSerial<'bus, 'a, B> { this.flush_write(); Poll::Ready(Ok(count)) } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } } /// Keeps track of the type of the last written packet. diff --git a/embassy-nrf/src/buffered_uarte.rs b/embassy-nrf/src/buffered_uarte.rs index 9b0451c12..e3ca74384 100644 --- a/embassy-nrf/src/buffered_uarte.rs +++ b/embassy-nrf/src/buffered_uarte.rs @@ -266,6 +266,20 @@ impl<'d, U: UarteInstance, T: TimerInstance> AsyncWrite for BufferedUarte<'d, U, poll } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.inner.with(|state| { + trace!("poll_flush"); + + if !state.tx.is_empty() { + trace!("poll_flush: pending"); + state.tx_waker.register(cx.waker()); + return Poll::Pending; + } + + Poll::Ready(Ok(())) + }) + } } impl<'a, U: UarteInstance, T: TimerInstance> Drop for StateInner<'a, U, T> { diff --git a/embassy/src/io/traits.rs b/embassy/src/io/traits.rs index 8e4a981da..06500a687 100644 --- a/embassy/src/io/traits.rs +++ b/embassy/src/io/traits.rs @@ -89,6 +89,15 @@ pub trait AsyncWrite { /// `poll_write` must try to make progress by flushing the underlying object if /// that is the only way the underlying object can become writable again. fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll>; + + /// Attempt to flush the object, ensuring that any buffered data reach their destination. + /// + /// On success, returns Poll::Ready(Ok(())). + /// + /// If flushing cannot immediately complete, this method returns [Poll::Pending] and arranges for the + /// current task (via cx.waker()) to receive a notification when the object can make progress + /// towards flushing. + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>; } macro_rules! defer_async_read { @@ -135,6 +144,10 @@ macro_rules! deref_async_write { ) -> Poll> { Pin::new(&mut **self).poll_write(cx, buf) } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut **self).poll_flush(cx) + } }; } @@ -155,4 +168,8 @@ where fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { self.get_mut().as_mut().poll_write(cx, buf) } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.get_mut().as_mut().poll_flush(cx) + } } diff --git a/embassy/src/io/util/flush.rs b/embassy/src/io/util/flush.rs new file mode 100644 index 000000000..966ef10fb --- /dev/null +++ b/embassy/src/io/util/flush.rs @@ -0,0 +1,32 @@ +use core::pin::Pin; +use futures::future::Future; +use futures::ready; +use futures::task::{Context, Poll}; + +use super::super::error::Result; +use super::super::traits::AsyncWrite; + +/// Future for the [`flush`](super::AsyncWriteExt::flush) method. +#[derive(Debug)] +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct Flush<'a, W: ?Sized> { + writer: &'a mut W, +} + +impl Unpin for Flush<'_, W> {} + +impl<'a, W: AsyncWrite + ?Sized + Unpin> Flush<'a, W> { + pub(super) fn new(writer: &'a mut W) -> Self { + Flush { writer } + } +} + +impl Future for Flush<'_, W> { + type Output = Result<()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = &mut *self; + let _ = ready!(Pin::new(&mut this.writer).poll_flush(cx))?; + Poll::Ready(Ok(())) + } +} diff --git a/embassy/src/io/util/mod.rs b/embassy/src/io/util/mod.rs index de6643cb3..49758ba99 100644 --- a/embassy/src/io/util/mod.rs +++ b/embassy/src/io/util/mod.rs @@ -27,6 +27,9 @@ pub use self::skip_while::SkipWhile; mod drain; pub use self::drain::Drain; +mod flush; +pub use self::flush::Flush; + mod write; pub use self::write::Write; @@ -160,6 +163,15 @@ pub trait AsyncWriteExt: AsyncWrite { { Write::new(self, buf) } + + /// Awaits until all bytes have actually been written, and + /// not just enqueued as per the other "write" methods. + fn flush<'a>(&mut self) -> Flush + where + Self: Unpin, + { + Flush::new(self) + } } impl AsyncWriteExt for R {} diff --git a/examples/nrf/src/bin/buffered_uart.rs b/examples/nrf/src/bin/buffered_uart.rs index 5d9075edf..c3e07e44a 100644 --- a/examples/nrf/src/bin/buffered_uart.rs +++ b/examples/nrf/src/bin/buffered_uart.rs @@ -61,5 +61,8 @@ async fn main(_spawner: Spawner, p: Peripherals) { info!("writing..."); unwrap!(u.write_all(&buf).await); info!("write done"); + + // Wait until the bytes are actually finished being transmitted + unwrap!(u.flush().await); } }