diff --git a/embassy-nrf/src/usb.rs b/embassy-nrf/src/usb.rs index 124316a29..c16e1be00 100644 --- a/embassy-nrf/src/usb.rs +++ b/embassy-nrf/src/usb.rs @@ -17,6 +17,7 @@ use futures::future::poll_fn; use futures::Future; pub use embassy_usb; +use pac::usbd::RegisterBlock; use crate::interrupt::Interrupt; use crate::pac; @@ -92,11 +93,11 @@ impl<'d, T: Instance> Driver<'d, T> { regs.epdatastatus.write(|w| unsafe { w.bits(r) }); READY_ENDPOINTS.fetch_or(r, Ordering::AcqRel); for i in 1..=7 { - if r & (1 << i) != 0 { - EP_IN_WAKERS[i - 1].wake(); + if r & In::mask(i) != 0 { + In::waker(i).wake(); } - if r & (1 << (i + 16)) != 0 { - EP_OUT_WAKERS[i - 1].wake(); + if r & Out::mask(i) != 0 { + Out::waker(i).wake(); } } } @@ -272,32 +273,48 @@ impl<'d, T: Instance> driver::Bus for Bus<'d, T> { #[inline] fn reset(&mut self) { + self.set_configured(false); + } + + #[inline] + fn set_configured(&mut self, configured: bool) { let regs = T::regs(); - // TODO: Initialize ISO buffers - - // XXX this is not spec compliant; the endpoints should only be enabled after the device - // has been put in the Configured state. However, usb-device provides no hook to do that unsafe { - regs.epinen.write(|w| w.bits(self.alloc_in.used.into())); - regs.epouten.write(|w| w.bits(self.alloc_out.used.into())); - } + if configured { + // TODO: Initialize ISO buffers - for i in 1..8 { - let out_enabled = self.alloc_out.used & (1 << i) != 0; + regs.epinen.write(|w| w.bits(self.alloc_in.used.into())); + regs.epouten.write(|w| w.bits(self.alloc_out.used.into())); - // when first enabled, bulk/interrupt OUT endpoints will *not* receive data (the - // peripheral will NAK all incoming packets) until we write a zero to the SIZE - // register (see figure 203 of the 52840 manual). To avoid that we write a 0 to the - // SIZE register - if out_enabled { - regs.size.epout[i].reset(); + for i in 1..8 { + let out_enabled = self.alloc_out.used & (1 << i) != 0; + + // when first enabled, bulk/interrupt OUT endpoints will *not* receive data (the + // peripheral will NAK all incoming packets) until we write a zero to the SIZE + // register (see figure 203 of the 52840 manual). To avoid that we write a 0 to the + // SIZE register + if out_enabled { + regs.size.epout[i].reset(); + } + } + + // IN endpoints (low bits) default to ready. + // OUT endpoints (high bits) default to NOT ready, they become ready when data comes in. + READY_ENDPOINTS.store(0x0000FFFF, Ordering::Release); + } else { + // Disable all endpoints except EP0 + regs.epinen.write(|w| w.bits(0x01)); + regs.epouten.write(|w| w.bits(0x01)); + + READY_ENDPOINTS.store(In::mask(0), Ordering::Release); } } - // IN endpoints (low bits) default to ready. - // OUT endpoints (high bits) default to NOT ready, they become ready when data comes in. - READY_ENDPOINTS.store(0x0000FFFF, Ordering::Release); + for i in 1..=7 { + In::waker(i).wake(); + Out::waker(i).wake(); + } } #[inline] @@ -332,6 +349,46 @@ impl<'d, T: Instance> driver::Bus for Bus<'d, T> { pub enum Out {} pub enum In {} +trait EndpointDir { + fn waker(i: usize) -> &'static AtomicWaker; + fn mask(i: usize) -> u32; + fn is_enabled(regs: &RegisterBlock, i: usize) -> bool; +} + +impl EndpointDir for In { + #[inline] + fn waker(i: usize) -> &'static AtomicWaker { + &EP_IN_WAKERS[i - 1] + } + + #[inline] + fn mask(i: usize) -> u32 { + 1 << i + } + + #[inline] + fn is_enabled(regs: &RegisterBlock, i: usize) -> bool { + (regs.epinen.read().bits() & (1 << i)) != 0 + } +} + +impl EndpointDir for Out { + #[inline] + fn waker(i: usize) -> &'static AtomicWaker { + &EP_OUT_WAKERS[i - 1] + } + + #[inline] + fn mask(i: usize) -> u32 { + 1 << (i + 16) + } + + #[inline] + fn is_enabled(regs: &RegisterBlock, i: usize) -> bool { + (regs.epouten.read().bits() & (1 << i)) != 0 + } +} + pub struct Endpoint<'d, T: Instance, Dir> { _phantom: PhantomData<(&'d mut T, Dir)>, info: EndpointInfo, @@ -346,7 +403,7 @@ impl<'d, T: Instance, Dir> Endpoint<'d, T, Dir> { } } -impl<'d, T: Instance, Dir> driver::Endpoint for Endpoint<'d, T, Dir> { +impl<'d, T: Instance, Dir: EndpointDir> driver::Endpoint for Endpoint<'d, T, Dir> { fn info(&self) -> &EndpointInfo { &self.info } @@ -358,6 +415,49 @@ impl<'d, T: Instance, Dir> driver::Endpoint for Endpoint<'d, T, Dir> { fn is_stalled(&self) -> bool { Driver::::is_stalled(self.info.addr) } + + type WaitEnabledFuture<'a> = impl Future + 'a where Self: 'a; + + fn wait_enabled(&mut self) -> Self::WaitEnabledFuture<'_> { + let i = self.info.addr.index(); + assert!(i != 0); + + poll_fn(move |cx| { + Dir::waker(i).register(cx.waker()); + if Dir::is_enabled(T::regs(), i) { + Poll::Ready(()) + } else { + Poll::Pending + } + }) + } +} + +impl<'d, T: Instance, Dir> Endpoint<'d, T, Dir> { + async fn wait_data_ready(&mut self) -> Result<(), ()> + where + Dir: EndpointDir, + { + let i = self.info.addr.index(); + assert!(i != 0); + poll_fn(|cx| { + Dir::waker(i).register(cx.waker()); + let r = READY_ENDPOINTS.load(Ordering::Acquire); + if !Dir::is_enabled(T::regs(), i) { + Poll::Ready(Err(())) + } else if r & Dir::mask(i) != 0 { + Poll::Ready(Ok(())) + } else { + Poll::Pending + } + }) + .await?; + + // Mark as not ready + READY_ENDPOINTS.fetch_and(!Dir::mask(i), Ordering::AcqRel); + + Ok(()) + } } unsafe fn read_dma(i: usize, buf: &mut [u8]) -> Result { @@ -449,20 +549,9 @@ impl<'d, T: Instance> driver::EndpointOut for Endpoint<'d, T, Out> { let i = self.info.addr.index(); assert!(i != 0); - // Wait until ready - poll_fn(|cx| { - EP_OUT_WAKERS[i - 1].register(cx.waker()); - let r = READY_ENDPOINTS.load(Ordering::Acquire); - if r & (1 << (i + 16)) != 0 { - Poll::Ready(()) - } else { - Poll::Pending - } - }) - .await; - - // Mark as not ready - READY_ENDPOINTS.fetch_and(!(1 << (i + 16)), Ordering::AcqRel); + self.wait_data_ready() + .await + .map_err(|_| ReadError::Disabled)?; unsafe { read_dma::(i, buf) } } @@ -477,20 +566,9 @@ impl<'d, T: Instance> driver::EndpointIn for Endpoint<'d, T, In> { let i = self.info.addr.index(); assert!(i != 0); - // Wait until ready. - poll_fn(|cx| { - EP_IN_WAKERS[i - 1].register(cx.waker()); - let r = READY_ENDPOINTS.load(Ordering::Acquire); - if r & (1 << i) != 0 { - Poll::Ready(()) - } else { - Poll::Pending - } - }) - .await; - - // Mark as not ready - READY_ENDPOINTS.fetch_and(!(1 << i), Ordering::AcqRel); + self.wait_data_ready() + .await + .map_err(|_| WriteError::Disabled)?; unsafe { write_dma::(i, buf) } diff --git a/embassy-usb-serial/src/lib.rs b/embassy-usb-serial/src/lib.rs index 8418de0f0..07352fac5 100644 --- a/embassy-usb-serial/src/lib.rs +++ b/embassy-usb-serial/src/lib.rs @@ -273,6 +273,11 @@ impl<'d, D: Driver<'d>> CdcAcmClass<'d, D> { pub async fn read_packet(&mut self, data: &mut [u8]) -> Result { self.read_ep.read(data).await } + + /// Waits for the USB host to enable this interface + pub async fn wait_connection(&mut self) { + self.read_ep.wait_enabled().await + } } /// Number of stop bits for LineCoding diff --git a/embassy-usb/src/driver.rs b/embassy-usb/src/driver.rs index 82b59bd1e..6eaa40b0d 100644 --- a/embassy-usb/src/driver.rs +++ b/embassy-usb/src/driver.rs @@ -72,6 +72,9 @@ pub trait Bus { /// Sets the device USB address to `addr`. fn set_device_address(&mut self, addr: u8); + /// Sets the device configured state. + fn set_configured(&mut self, configured: bool); + /// Sets or clears the STALL condition for an endpoint. If the endpoint is an OUT endpoint, it /// should be prepared to receive data again. Only used during control transfers. fn set_stalled(&mut self, ep_addr: EndpointAddress, stalled: bool); @@ -105,6 +108,10 @@ pub trait Bus { } pub trait Endpoint { + type WaitEnabledFuture<'a>: Future + 'a + where + Self: 'a; + /// Get the endpoint address fn info(&self) -> &EndpointInfo; @@ -115,6 +122,9 @@ pub trait Endpoint { /// Gets whether the STALL condition is set for an endpoint. fn is_stalled(&self) -> bool; + /// Waits for the endpoint to be enabled. + fn wait_enabled(&mut self) -> Self::WaitEnabledFuture<'_>; + // TODO enable/disable? } @@ -212,6 +222,7 @@ pub enum WriteError { /// class shouldn't provide more data than the `max_packet_size` it specified when allocating /// the endpoint. BufferOverflow, + Disabled, } #[derive(Copy, Clone, Eq, PartialEq, Debug)] @@ -223,4 +234,5 @@ pub enum ReadError { /// should use a buffer that is large enough for the `max_packet_size` it specified when /// allocating the endpoint. BufferOverflow, + Disabled, } diff --git a/embassy-usb/src/lib.rs b/embassy-usb/src/lib.rs index d2d3e5e0a..cf8d12539 100644 --- a/embassy-usb/src/lib.rs +++ b/embassy-usb/src/lib.rs @@ -162,18 +162,21 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { self.remote_wakeup_enabled = true; self.control.accept(stage) } - (Request::SET_ADDRESS, 1..=127) => { - self.pending_address = req.value as u8; + (Request::SET_ADDRESS, addr @ 1..=127) => { + self.pending_address = addr as u8; + self.bus.set_device_address(self.pending_address); self.control.accept(stage) } (Request::SET_CONFIGURATION, CONFIGURATION_VALUE_U16) => { self.device_state = UsbDeviceState::Configured; + self.bus.set_configured(true); self.control.accept(stage) } (Request::SET_CONFIGURATION, CONFIGURATION_NONE_U16) => match self.device_state { UsbDeviceState::Default => self.control.accept(stage), _ => { self.device_state = UsbDeviceState::Addressed; + self.bus.set_configured(false); self.control.accept(stage) } }, diff --git a/examples/nrf/src/bin/usb_serial.rs b/examples/nrf/src/bin/usb_serial.rs index cd681c5ce..9437e835f 100644 --- a/examples/nrf/src/bin/usb_serial.rs +++ b/examples/nrf/src/bin/usb_serial.rs @@ -4,12 +4,13 @@ #![feature(type_alias_impl_trait)] use core::mem; -use defmt::*; +use defmt::{info, panic}; use embassy::executor::Spawner; use embassy_nrf::interrupt; use embassy_nrf::pac; -use embassy_nrf::usb::Driver; +use embassy_nrf::usb::{Driver, Instance}; use embassy_nrf::Peripherals; +use embassy_usb::driver::{ReadError, WriteError}; use embassy_usb::{Config, UsbDeviceBuilder}; use embassy_usb_serial::{CdcAcmClass, State}; use futures::future::join; @@ -66,12 +67,11 @@ async fn main(_spawner: Spawner, p: Peripherals) { // Do stuff with the class! let echo_fut = async { - let mut buf = [0; 64]; loop { - let n = class.read_packet(&mut buf).await.unwrap(); - let data = &buf[..n]; - info!("data: {:x}", data); - class.write_packet(data).await.unwrap(); + class.wait_connection().await; + info!("Connected"); + let _ = echo(&mut class).await; + info!("Disconnected"); } }; @@ -79,3 +79,35 @@ async fn main(_spawner: Spawner, p: Peripherals) { // If we had made everything `'static` above instead, we could do this using separate tasks instead. join(usb_fut, echo_fut).await; } + +struct Disconnected {} + +impl From for Disconnected { + fn from(val: ReadError) -> Self { + match val { + ReadError::BufferOverflow => panic!("Buffer overflow"), + ReadError::Disabled => Disconnected {}, + } + } +} + +impl From for Disconnected { + fn from(val: WriteError) -> Self { + match val { + WriteError::BufferOverflow => panic!("Buffer overflow"), + WriteError::Disabled => Disconnected {}, + } + } +} + +async fn echo<'d, T: Instance + 'd>( + class: &mut CdcAcmClass<'d, Driver<'d, T>>, +) -> Result<(), Disconnected> { + let mut buf = [0; 64]; + loop { + let n = class.read_packet(&mut buf).await?; + let data = &buf[..n]; + info!("data: {:x}", data); + class.write_packet(data).await?; + } +}