diff --git a/embassy-nrf/src/usb.rs b/embassy-nrf/src/usb.rs index df0efa511..124316a29 100644 --- a/embassy-nrf/src/usb.rs +++ b/embassy-nrf/src/usb.rs @@ -443,23 +443,8 @@ unsafe fn write_dma(i: usize, buf: &[u8]) { impl<'d, T: Instance> driver::EndpointOut for Endpoint<'d, T, Out> { type ReadFuture<'a> = impl Future> + 'a where Self: 'a; - type DataReadyFuture<'a> = impl Future + 'a where Self: 'a; fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::ReadFuture<'a> { - async move { - let i = self.info.addr.index(); - assert!(i != 0); - - self.wait_data_ready().await; - - // Mark as not ready - READY_ENDPOINTS.fetch_and(!(1 << (i + 16)), Ordering::AcqRel); - - unsafe { read_dma::(i, buf) } - } - } - - fn wait_data_ready<'a>(&'a mut self) -> Self::DataReadyFuture<'a> { async move { let i = self.info.addr.index(); assert!(i != 0); @@ -475,6 +460,11 @@ impl<'d, T: Instance> driver::EndpointOut for Endpoint<'d, T, Out> { } }) .await; + + // Mark as not ready + READY_ENDPOINTS.fetch_and(!(1 << (i + 16)), Ordering::AcqRel); + + unsafe { read_dma::(i, buf) } } } } diff --git a/embassy-usb-hid/src/async_lease.rs b/embassy-usb-hid/src/async_lease.rs deleted file mode 100644 index 0971daa25..000000000 --- a/embassy-usb-hid/src/async_lease.rs +++ /dev/null @@ -1,90 +0,0 @@ -use core::cell::Cell; -use core::future::Future; -use core::task::{Poll, Waker}; - -enum AsyncLeaseState { - Empty, - Waiting(*mut u8, usize, Waker), - Done(usize), -} - -impl Default for AsyncLeaseState { - fn default() -> Self { - AsyncLeaseState::Empty - } -} - -#[derive(Default)] -pub struct AsyncLease { - state: Cell, -} - -pub struct AsyncLeaseFuture<'a> { - buf: &'a mut [u8], - state: &'a Cell, -} - -impl<'a> Drop for AsyncLeaseFuture<'a> { - fn drop(&mut self) { - self.state.set(AsyncLeaseState::Empty); - } -} - -impl<'a> Future for AsyncLeaseFuture<'a> { - type Output = usize; - - fn poll( - mut self: core::pin::Pin<&mut Self>, - cx: &mut core::task::Context<'_>, - ) -> Poll { - match self.state.take() { - AsyncLeaseState::Done(len) => Poll::Ready(len), - state => { - if let AsyncLeaseState::Waiting(ptr, _, _) = state { - assert_eq!( - ptr, - self.buf.as_mut_ptr(), - "lend() called on a busy AsyncLease." - ); - } - - self.state.set(AsyncLeaseState::Waiting( - self.buf.as_mut_ptr(), - self.buf.len(), - cx.waker().clone(), - )); - Poll::Pending - } - } - } -} - -pub struct AsyncLeaseNotReady {} - -impl AsyncLease { - pub fn new() -> Self { - Default::default() - } - - pub fn try_borrow_mut usize>( - &self, - f: F, - ) -> Result<(), AsyncLeaseNotReady> { - if let AsyncLeaseState::Waiting(data, len, waker) = self.state.take() { - let buf = unsafe { core::slice::from_raw_parts_mut(data, len) }; - let len = f(buf); - self.state.set(AsyncLeaseState::Done(len)); - waker.wake(); - Ok(()) - } else { - Err(AsyncLeaseNotReady {}) - } - } - - pub fn lend<'a>(&'a self, buf: &'a mut [u8]) -> AsyncLeaseFuture<'a> { - AsyncLeaseFuture { - buf, - state: &self.state, - } - } -} diff --git a/embassy-usb-hid/src/lib.rs b/embassy-usb-hid/src/lib.rs index 43e678806..527f014f2 100644 --- a/embassy-usb-hid/src/lib.rs +++ b/embassy-usb-hid/src/lib.rs @@ -8,24 +8,21 @@ pub(crate) mod fmt; use core::mem::MaybeUninit; +use core::ops::Range; -use async_lease::AsyncLease; use embassy::time::Duration; -use embassy_usb::driver::{EndpointOut, ReadError}; +use embassy_usb::driver::EndpointOut; use embassy_usb::{ control::{ControlHandler, InResponse, OutResponse, Request, RequestType}, driver::{Driver, Endpoint, EndpointIn, WriteError}, UsbDeviceBuilder, }; -use futures_util::future::{select, Either}; -use futures_util::pin_mut; + #[cfg(feature = "usbd-hid")] use ssmarshal::serialize; #[cfg(feature = "usbd-hid")] use usbd_hid::descriptor::AsInputReport; -mod async_lease; - const USB_CLASS_HID: u8 = 0x03; const USB_SUBCLASS_NONE: u8 = 0x00; const USB_PROTOCOL_NONE: u8 = 0x00; @@ -64,14 +61,12 @@ impl ReportId { pub struct State<'a, const IN_N: usize, const OUT_N: usize> { control: MaybeUninit>, - lease: AsyncLease, } impl<'a, const IN_N: usize, const OUT_N: usize> State<'a, IN_N, OUT_N> { pub fn new() -> Self { State { control: MaybeUninit::uninit(), - lease: AsyncLease::new(), } } } @@ -90,9 +85,9 @@ impl<'d, D: Driver<'d>, const IN_N: usize> HidClass<'d, D, (), IN_N> { /// high performance uses, and a value of 255 is good for best-effort usecases. /// /// This allocates an IN endpoint only. - pub fn new( + pub fn new( builder: &mut UsbDeviceBuilder<'d, D>, - state: &'d mut State<'d, IN_N, 0>, + state: &'d mut State<'d, IN_N, OUT_N>, report_descriptor: &'static [u8], request_handler: Option<&'d dyn RequestHandler>, poll_ms: u8, @@ -101,8 +96,7 @@ impl<'d, D: Driver<'d>, const IN_N: usize> HidClass<'d, D, (), IN_N> { let ep_in = builder.alloc_interrupt_endpoint_in(max_packet_size, poll_ms); let control = state .control - .write(Control::new(report_descriptor, None, request_handler)); - + .write(Control::new(report_descriptor, request_handler)); control.build(builder, None, &ep_in); Self { @@ -144,20 +138,14 @@ impl<'d, D: Driver<'d>, const IN_N: usize, const OUT_N: usize> let ep_out = builder.alloc_interrupt_endpoint_out(max_packet_size, poll_ms); let ep_in = builder.alloc_interrupt_endpoint_in(max_packet_size, poll_ms); - let control = state.control.write(Control::new( - report_descriptor, - Some(&state.lease), - request_handler, - )); - + let control = state + .control + .write(Control::new(report_descriptor, request_handler)); control.build(builder, Some(&ep_out), &ep_in); Self { input: ReportWriter { ep_in }, - output: ReportReader { - ep_out, - lease: &state.lease, - }, + output: ReportReader { ep_out, offset: 0 }, } } @@ -178,7 +166,21 @@ pub struct ReportWriter<'d, D: Driver<'d>, const N: usize> { pub struct ReportReader<'d, D: Driver<'d>, const N: usize> { ep_out: D::EndpointOut, - lease: &'d AsyncLease, + offset: usize, +} + +pub enum ReadError { + BufferOverflow, + Sync(Range), +} + +impl From for ReadError { + fn from(val: embassy_usb::driver::ReadError) -> Self { + use embassy_usb::driver::ReadError::*; + match val { + BufferOverflow => ReadError::BufferOverflow, + } + } } impl<'d, D: Driver<'d>, const N: usize> ReportWriter<'d, D, N> { @@ -216,31 +218,55 @@ impl<'d, D: Driver<'d>, const N: usize> ReportWriter<'d, D, N> { } impl<'d, D: Driver<'d>, const N: usize> ReportReader<'d, D, N> { + /// Starts a task to deliver output reports from the Interrupt Out pipe to + /// `handler`. + pub async fn run(mut self, handler: &T) -> ! { + assert!(self.offset == 0); + let mut buf = [0; N]; + loop { + match self.read(&mut buf).await { + Ok(len) => { handler.set_report(ReportId::Out(0), &buf[0..len]); } + Err(ReadError::BufferOverflow) => warn!("Host sent output report larger than the configured maximum output report length ({})", N), + Err(ReadError::Sync(_)) => unreachable!(), + } + } + } + + /// Reads an output report from the Interrupt Out pipe. + /// + /// **Note:** Any reports sent from the host over the control pipe will be + /// passed to [`RequestHandler::set_report()`] for handling. The application + /// is responsible for ensuring output reports from both pipes are handled + /// correctly. + /// + /// **Note:** If `N` > the maximum packet size of the endpoint (i.e. output + /// reports may be split across multiple packets) and this method's future + /// is dropped after some packets have been read, the next call to `read()` + /// will return a [`ReadError::SyncError()`]. The range in the sync error + /// indicates the portion `buf` that was filled by the current call to + /// `read()`. If the dropped future used the same `buf`, then `buf` will + /// contain the full report. pub async fn read(&mut self, buf: &mut [u8]) -> Result { assert!(buf.len() >= N); - // Wait until a packet is ready to read from the endpoint or a SET_REPORT control request is received - { - let data_ready = self.ep_out.wait_data_ready(); - pin_mut!(data_ready); - match select(data_ready, self.lease.lend(buf)).await { - Either::Left(_) => (), - Either::Right((len, _)) => return Ok(len), - } - } - // Read packets from the endpoint let max_packet_size = usize::from(self.ep_out.info().max_packet_size); - let mut total = 0; - for chunk in buf.chunks_mut(max_packet_size) { + let starting_offset = self.offset; + for chunk in buf[starting_offset..].chunks_mut(max_packet_size) { let size = self.ep_out.read(chunk).await?; - total += size; - if size < max_packet_size || total == N { + self.offset += size; + if size < max_packet_size || self.offset == N { break; } } - Ok(total) + let total = self.offset; + self.offset = 0; + if starting_offset > 0 { + Err(ReadError::Sync(starting_offset..total)) + } else { + Ok(total) + } } } @@ -254,10 +280,6 @@ pub trait RequestHandler { } /// Sets the value of report `id` to `data`. - /// - /// If an output endpoint has been allocated, output reports - /// are routed through [`HidClass::output()`]. Otherwise they - /// are sent here, along with input and feature reports. fn set_report(&self, id: ReportId, data: &[u8]) -> OutResponse { let _ = (id, data); OutResponse::Rejected @@ -266,8 +288,8 @@ pub trait RequestHandler { /// Get the idle rate for `id`. /// /// If `id` is `None`, get the idle rate for all reports. Returning `None` - /// will reject the control request. Any duration above 1.020 seconds or 0 - /// will be returned as an indefinite idle rate. + /// will reject the control request. Any duration at or above 1.024 seconds + /// or below 4ms will be returned as an indefinite idle rate. fn get_idle(&self, id: Option) -> Option { let _ = id; None @@ -284,7 +306,6 @@ pub trait RequestHandler { struct Control<'d> { report_descriptor: &'static [u8], - out_lease: Option<&'d AsyncLease>, request_handler: Option<&'d dyn RequestHandler>, hid_descriptor: [u8; 9], } @@ -292,12 +313,10 @@ struct Control<'d> { impl<'a> Control<'a> { fn new( report_descriptor: &'static [u8], - out_lease: Option<&'a AsyncLease>, request_handler: Option<&'a dyn RequestHandler>, ) -> Self { Control { report_descriptor, - out_lease, request_handler, hid_descriptor: [ // Length of buf inclusive of size prefix @@ -370,7 +389,7 @@ impl<'d> ControlHandler for Control<'d> { if let RequestType::Class = req.request_type { match req.request { HID_REQ_SET_IDLE => { - if let Some(handler) = self.request_handler.as_ref() { + if let Some(handler) = self.request_handler { let id = req.value as u8; let id = (id != 0).then(|| ReportId::In(id)); let dur = u64::from(req.value >> 8); @@ -383,25 +402,8 @@ impl<'d> ControlHandler for Control<'d> { } OutResponse::Accepted } - HID_REQ_SET_REPORT => match ( - ReportId::try_from(req.value), - self.out_lease, - self.request_handler.as_ref(), - ) { - (Ok(ReportId::Out(_)), Some(lease), _) => { - match lease.try_borrow_mut(|buf| { - let len = buf.len().min(data.len()); - buf[0..len].copy_from_slice(&data[0..len]); - len - }) { - Ok(()) => OutResponse::Accepted, - Err(_) => { - warn!("SET_REPORT received for output report with no reader listening."); - OutResponse::Rejected - } - } - } - (Ok(id), _, Some(handler)) => handler.set_report(id, data), + HID_REQ_SET_REPORT => match (ReportId::try_from(req.value), self.request_handler) { + (Ok(id), Some(handler)) => handler.set_report(id, data), _ => OutResponse::Rejected, }, HID_REQ_SET_PROTOCOL => { @@ -429,10 +431,7 @@ impl<'d> ControlHandler for Control<'d> { }, (RequestType::Class, HID_REQ_GET_REPORT) => { let size = match ReportId::try_from(req.value) { - Ok(id) => self - .request_handler - .as_ref() - .and_then(|x| x.get_report(id, buf)), + Ok(id) => self.request_handler.and_then(|x| x.get_report(id, buf)), Err(_) => None, }; @@ -443,7 +442,7 @@ impl<'d> ControlHandler for Control<'d> { } } (RequestType::Class, HID_REQ_GET_IDLE) => { - if let Some(handler) = self.request_handler.as_ref() { + if let Some(handler) = self.request_handler { let id = req.value as u8; let id = (id != 0).then(|| ReportId::In(id)); if let Some(dur) = handler.get_idle(id) { diff --git a/embassy-usb/src/driver.rs b/embassy-usb/src/driver.rs index 03e39b8c9..82b59bd1e 100644 --- a/embassy-usb/src/driver.rs +++ b/embassy-usb/src/driver.rs @@ -120,9 +120,6 @@ pub trait Endpoint { pub trait EndpointOut: Endpoint { type ReadFuture<'a>: Future> + 'a - where - Self: 'a; - type DataReadyFuture<'a>: Future + 'a where Self: 'a; @@ -131,11 +128,6 @@ pub trait EndpointOut: Endpoint { /// /// This should also clear any NAK flags and prepare the endpoint to receive the next packet. fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::ReadFuture<'a>; - - /// Waits until a packet of data is ready to be read from the endpoint. - /// - /// A call to[`read()`](Self::read()) after this future completes should not block. - fn wait_data_ready<'a>(&'a mut self) -> Self::DataReadyFuture<'a>; } pub trait ControlPipe { diff --git a/examples/nrf/src/bin/usb_hid.rs b/examples/nrf/src/bin/usb_hid.rs index 5253f225d..6ffb1fd40 100644 --- a/examples/nrf/src/bin/usb_hid.rs +++ b/examples/nrf/src/bin/usb_hid.rs @@ -52,7 +52,7 @@ async fn main(_spawner: Spawner, p: Peripherals) { let mut control_buf = [0; 16]; let request_handler = MyRequestHandler {}; - let mut state = State::<5, 0>::new(); + let mut control = State::<5, 0>::new(); let mut builder = UsbDeviceBuilder::new( driver, @@ -66,7 +66,7 @@ async fn main(_spawner: Spawner, p: Peripherals) { // Create classes on the builder. let mut hid = HidClass::new( &mut builder, - &mut state, + &mut control, MouseReport::desc(), Some(&request_handler), 60,