diff --git a/embassy-extras/Cargo.toml b/embassy-extras/Cargo.toml index 3c42b5c2f..5d07901a9 100644 --- a/embassy-extras/Cargo.toml +++ b/embassy-extras/Cargo.toml @@ -17,3 +17,4 @@ embassy = { version = "0.1.0", path = "../embassy" } defmt = { version = "0.2.0", optional = true } log = { version = "0.4.11", optional = true } cortex-m = "0.7.1" +usb-device = "0.2.7" diff --git a/embassy-extras/src/lib.rs b/embassy-extras/src/lib.rs index 4a95173cf..536e86c61 100644 --- a/embassy-extras/src/lib.rs +++ b/embassy-extras/src/lib.rs @@ -5,6 +5,7 @@ pub(crate) mod fmt; pub mod peripheral; pub mod ring_buffer; +pub mod usb; /// Low power blocking wait loop using WFE/SEV. pub fn low_power_wait_until(mut condition: impl FnMut() -> bool) { diff --git a/embassy-extras/src/ring_buffer.rs b/embassy-extras/src/ring_buffer.rs index f2b9f7359..0ef66f00a 100644 --- a/embassy-extras/src/ring_buffer.rs +++ b/embassy-extras/src/ring_buffer.rs @@ -69,6 +69,12 @@ impl<'a> RingBuffer<'a> { self.empty = self.start == self.end; } + pub fn clear(&mut self) { + self.start = 0; + self.end = 0; + self.empty = true; + } + fn wrap(&self, n: usize) -> usize { assert!(n <= self.buf.len()); if n == self.buf.len() { diff --git a/embassy-extras/src/usb/cdc_acm.rs b/embassy-extras/src/usb/cdc_acm.rs new file mode 100644 index 000000000..5a85b3846 --- /dev/null +++ b/embassy-extras/src/usb/cdc_acm.rs @@ -0,0 +1,338 @@ +// Copied from https://github.com/mvirkkunen/usbd-serial +#![allow(dead_code)] + +use core::convert::TryInto; +use core::mem; +use usb_device::class_prelude::*; +use usb_device::Result; + +/// This should be used as `device_class` when building the `UsbDevice`. +pub const USB_CLASS_CDC: u8 = 0x02; + +const USB_CLASS_CDC_DATA: u8 = 0x0a; +const CDC_SUBCLASS_ACM: u8 = 0x02; +const CDC_PROTOCOL_NONE: u8 = 0x00; + +const CS_INTERFACE: u8 = 0x24; +const CDC_TYPE_HEADER: u8 = 0x00; +const CDC_TYPE_CALL_MANAGEMENT: u8 = 0x01; +const CDC_TYPE_ACM: u8 = 0x02; +const CDC_TYPE_UNION: u8 = 0x06; + +const REQ_SEND_ENCAPSULATED_COMMAND: u8 = 0x00; +#[allow(unused)] +const REQ_GET_ENCAPSULATED_COMMAND: u8 = 0x01; +const REQ_SET_LINE_CODING: u8 = 0x20; +const REQ_GET_LINE_CODING: u8 = 0x21; +const REQ_SET_CONTROL_LINE_STATE: u8 = 0x22; + +/// Packet level implementation of a CDC-ACM serial port. +/// +/// This class can be used directly and it has the least overhead due to directly reading and +/// writing USB packets with no intermediate buffers, but it will not act like a stream-like serial +/// port. The following constraints must be followed if you use this class directly: +/// +/// - `read_packet` must be called with a buffer large enough to hold max_packet_size bytes, and the +/// method will return a `WouldBlock` error if there is no packet to be read. +/// - `write_packet` must not be called with a buffer larger than max_packet_size bytes, and the +/// method will return a `WouldBlock` error if the previous packet has not been sent yet. +/// - If you write a packet that is exactly max_packet_size bytes long, it won't be processed by the +/// host operating system until a subsequent shorter packet is sent. A zero-length packet (ZLP) +/// can be sent if there is no other data to send. This is because USB bulk transactions must be +/// terminated with a short packet, even if the bulk endpoint is used for stream-like data. +pub struct CdcAcmClass<'a, B: UsbBus> { + comm_if: InterfaceNumber, + comm_ep: EndpointIn<'a, B>, + data_if: InterfaceNumber, + read_ep: EndpointOut<'a, B>, + write_ep: EndpointIn<'a, B>, + line_coding: LineCoding, + dtr: bool, + rts: bool, +} + +impl<B: UsbBus> CdcAcmClass<'_, B> { + /// Creates a new CdcAcmClass with the provided UsbBus and max_packet_size in bytes. For + /// full-speed devices, max_packet_size has to be one of 8, 16, 32 or 64. + pub fn new(alloc: &UsbBusAllocator<B>, max_packet_size: u16) -> CdcAcmClass<'_, B> { + CdcAcmClass { + comm_if: alloc.interface(), + comm_ep: alloc.interrupt(8, 255), + data_if: alloc.interface(), + read_ep: alloc.bulk(max_packet_size), + write_ep: alloc.bulk(max_packet_size), + line_coding: LineCoding { + stop_bits: StopBits::One, + data_bits: 8, + parity_type: ParityType::None, + data_rate: 8_000, + }, + dtr: false, + rts: false, + } + } + + /// Gets the maximum packet size in bytes. + pub fn max_packet_size(&self) -> u16 { + // The size is the same for both endpoints. + self.read_ep.max_packet_size() + } + + /// Gets the current line coding. The line coding contains information that's mainly relevant + /// for USB to UART serial port emulators, and can be ignored if not relevant. + pub fn line_coding(&self) -> &LineCoding { + &self.line_coding + } + + /// Gets the DTR (data terminal ready) state + pub fn dtr(&self) -> bool { + self.dtr + } + + /// Gets the RTS (request to send) state + pub fn rts(&self) -> bool { + self.rts + } + + /// Writes a single packet into the IN endpoint. + pub fn write_packet(&mut self, data: &[u8]) -> Result<usize> { + self.write_ep.write(data) + } + + /// Reads a single packet from the OUT endpoint. + pub fn read_packet(&mut self, data: &mut [u8]) -> Result<usize> { + self.read_ep.read(data) + } + + /// Gets the address of the IN endpoint. + pub fn write_ep_address(&self) -> EndpointAddress { + self.write_ep.address() + } + + /// Gets the address of the OUT endpoint. + pub fn read_ep_address(&self) -> EndpointAddress { + self.read_ep.address() + } +} + +impl<B: UsbBus> UsbClass<B> for CdcAcmClass<'_, B> { + fn get_configuration_descriptors(&self, writer: &mut DescriptorWriter) -> Result<()> { + writer.iad( + self.comm_if, + 2, + USB_CLASS_CDC, + CDC_SUBCLASS_ACM, + CDC_PROTOCOL_NONE, + )?; + + writer.interface( + self.comm_if, + USB_CLASS_CDC, + CDC_SUBCLASS_ACM, + CDC_PROTOCOL_NONE, + )?; + + writer.write( + CS_INTERFACE, + &[ + CDC_TYPE_HEADER, // bDescriptorSubtype + 0x10, + 0x01, // bcdCDC (1.10) + ], + )?; + + writer.write( + CS_INTERFACE, + &[ + CDC_TYPE_ACM, // bDescriptorSubtype + 0x00, // bmCapabilities + ], + )?; + + writer.write( + CS_INTERFACE, + &[ + CDC_TYPE_UNION, // bDescriptorSubtype + self.comm_if.into(), // bControlInterface + self.data_if.into(), // bSubordinateInterface + ], + )?; + + writer.write( + CS_INTERFACE, + &[ + CDC_TYPE_CALL_MANAGEMENT, // bDescriptorSubtype + 0x00, // bmCapabilities + self.data_if.into(), // bDataInterface + ], + )?; + + writer.endpoint(&self.comm_ep)?; + + writer.interface(self.data_if, USB_CLASS_CDC_DATA, 0x00, 0x00)?; + + writer.endpoint(&self.write_ep)?; + writer.endpoint(&self.read_ep)?; + + Ok(()) + } + + fn reset(&mut self) { + self.line_coding = LineCoding::default(); + self.dtr = false; + self.rts = false; + } + + fn control_in(&mut self, xfer: ControlIn<B>) { + let req = xfer.request(); + + if !(req.request_type == control::RequestType::Class + && req.recipient == control::Recipient::Interface + && req.index == u8::from(self.comm_if) as u16) + { + return; + } + + match req.request { + // REQ_GET_ENCAPSULATED_COMMAND is not really supported - it will be rejected below. + REQ_GET_LINE_CODING if req.length == 7 => { + xfer.accept(|data| { + data[0..4].copy_from_slice(&self.line_coding.data_rate.to_le_bytes()); + data[4] = self.line_coding.stop_bits as u8; + data[5] = self.line_coding.parity_type as u8; + data[6] = self.line_coding.data_bits; + + Ok(7) + }) + .ok(); + } + _ => { + xfer.reject().ok(); + } + } + } + + fn control_out(&mut self, xfer: ControlOut<B>) { + let req = xfer.request(); + + if !(req.request_type == control::RequestType::Class + && req.recipient == control::Recipient::Interface + && req.index == u8::from(self.comm_if) as u16) + { + return; + } + + match req.request { + REQ_SEND_ENCAPSULATED_COMMAND => { + // We don't actually support encapsulated commands but pretend we do for standards + // compatibility. + xfer.accept().ok(); + } + REQ_SET_LINE_CODING if xfer.data().len() >= 7 => { + self.line_coding.data_rate = + u32::from_le_bytes(xfer.data()[0..4].try_into().unwrap()); + self.line_coding.stop_bits = xfer.data()[4].into(); + self.line_coding.parity_type = xfer.data()[5].into(); + self.line_coding.data_bits = xfer.data()[6]; + + xfer.accept().ok(); + } + REQ_SET_CONTROL_LINE_STATE => { + self.dtr = (req.value & 0x0001) != 0; + self.rts = (req.value & 0x0002) != 0; + + xfer.accept().ok(); + } + _ => { + xfer.reject().ok(); + } + }; + } +} + +/// Number of stop bits for LineCoding +#[derive(Copy, Clone, PartialEq, Eq)] +pub enum StopBits { + /// 1 stop bit + One = 0, + + /// 1.5 stop bits + OnePointFive = 1, + + /// 2 stop bits + Two = 2, +} + +impl From<u8> for StopBits { + fn from(value: u8) -> Self { + if value <= 2 { + unsafe { mem::transmute(value) } + } else { + StopBits::One + } + } +} + +/// Parity for LineCoding +#[derive(Copy, Clone, PartialEq, Eq)] +pub enum ParityType { + None = 0, + Odd = 1, + Event = 2, + Mark = 3, + Space = 4, +} + +impl From<u8> for ParityType { + fn from(value: u8) -> Self { + if value <= 4 { + unsafe { mem::transmute(value) } + } else { + ParityType::None + } + } +} + +/// Line coding parameters +/// +/// This is provided by the host for specifying the standard UART parameters such as baud rate. Can +/// be ignored if you don't plan to interface with a physical UART. +pub struct LineCoding { + stop_bits: StopBits, + data_bits: u8, + parity_type: ParityType, + data_rate: u32, +} + +impl LineCoding { + /// Gets the number of stop bits for UART communication. + pub fn stop_bits(&self) -> StopBits { + self.stop_bits + } + + /// Gets the number of data bits for UART communication. + pub fn data_bits(&self) -> u8 { + self.data_bits + } + + /// Gets the parity type for UART communication. + pub fn parity_type(&self) -> ParityType { + self.parity_type + } + + /// Gets the data rate in bits per second for UART communication. + pub fn data_rate(&self) -> u32 { + self.data_rate + } +} + +impl Default for LineCoding { + fn default() -> Self { + LineCoding { + stop_bits: StopBits::One, + data_bits: 8, + parity_type: ParityType::None, + data_rate: 8_000, + } + } +} diff --git a/embassy-extras/src/usb/mod.rs b/embassy-extras/src/usb/mod.rs new file mode 100644 index 000000000..182cd87d0 --- /dev/null +++ b/embassy-extras/src/usb/mod.rs @@ -0,0 +1,254 @@ +use core::cell::RefCell; +use core::marker::PhantomData; +use core::pin::Pin; + +use usb_device::bus::UsbBus; +use usb_device::class::UsbClass; +use usb_device::device::UsbDevice; + +mod cdc_acm; +pub mod usb_serial; + +use crate::peripheral::{PeripheralMutex, PeripheralState}; +use embassy::interrupt::Interrupt; +use usb_serial::{ReadInterface, UsbSerial, WriteInterface}; + +/// Marker trait to mark an interrupt to be used with the [`Usb`] abstraction. +pub unsafe trait USBInterrupt: Interrupt {} + +pub(crate) struct State<'bus, B, T, I> +where + B: UsbBus, + T: ClassSet<B>, + I: USBInterrupt, +{ + device: UsbDevice<'bus, B>, + pub(crate) classes: T, + _interrupt: PhantomData<I>, +} + +pub struct Usb<'bus, B, T, I> +where + B: UsbBus, + T: ClassSet<B>, + I: USBInterrupt, +{ + // Don't you dare moving out `PeripheralMutex` + inner: RefCell<PeripheralMutex<State<'bus, B, T, I>>>, +} + +impl<'bus, B, T, I> Usb<'bus, B, T, I> +where + B: UsbBus, + T: ClassSet<B>, + I: USBInterrupt, +{ + pub fn new<S: IntoClassSet<B, T>>(device: UsbDevice<'bus, B>, class_set: S, irq: I) -> Self { + let state = State { + device, + classes: class_set.into_class_set(), + _interrupt: PhantomData, + }; + let mutex = PeripheralMutex::new(state, irq); + Self { + inner: RefCell::new(mutex), + } + } + + pub fn start(self: Pin<&mut Self>) { + let this = unsafe { self.get_unchecked_mut() }; + let mut mutex = this.inner.borrow_mut(); + let mutex = unsafe { Pin::new_unchecked(&mut *mutex) }; + + // Use inner to register the irq + mutex.register_interrupt(); + } +} + +impl<'bus, 'c, B, T, I> Usb<'bus, B, T, I> +where + B: UsbBus, + T: ClassSet<B> + SerialState<'bus, 'c, B, Index0>, + I: USBInterrupt, +{ + /// Take a serial class that was passed as the first class in a tuple + pub fn take_serial_0<'a>( + self: Pin<&'a Self>, + ) -> ( + ReadInterface<'a, 'bus, 'c, Index0, B, T, I>, + WriteInterface<'a, 'bus, 'c, Index0, B, T, I>, + ) { + let this = self.get_ref(); + + let r = ReadInterface { + inner: &this.inner, + _buf_lifetime: PhantomData, + _index: PhantomData, + }; + + let w = WriteInterface { + inner: &this.inner, + _buf_lifetime: PhantomData, + _index: PhantomData, + }; + (r, w) + } +} + +impl<'bus, 'c, B, T, I> Usb<'bus, B, T, I> +where + B: UsbBus, + T: ClassSet<B> + SerialState<'bus, 'c, B, Index1>, + I: USBInterrupt, +{ + /// Take a serial class that was passed as the second class in a tuple + pub fn take_serial_1<'a>( + self: Pin<&'a Self>, + ) -> ( + ReadInterface<'a, 'bus, 'c, Index1, B, T, I>, + WriteInterface<'a, 'bus, 'c, Index1, B, T, I>, + ) { + let this = self.get_ref(); + + let r = ReadInterface { + inner: &this.inner, + _buf_lifetime: PhantomData, + _index: PhantomData, + }; + + let w = WriteInterface { + inner: &this.inner, + _buf_lifetime: PhantomData, + _index: PhantomData, + }; + (r, w) + } +} + +impl<'bus, B, T, I> PeripheralState for State<'bus, B, T, I> +where + B: UsbBus, + T: ClassSet<B>, + I: USBInterrupt, +{ + type Interrupt = I; + fn on_interrupt(&mut self) { + self.classes.poll_all(&mut self.device); + } +} + +pub trait ClassSet<B: UsbBus> { + fn poll_all(&mut self, device: &mut UsbDevice<'_, B>) -> bool; +} + +pub trait IntoClassSet<B: UsbBus, C: ClassSet<B>> { + fn into_class_set(self) -> C; +} + +pub struct ClassSet1<B, C1> +where + B: UsbBus, + C1: UsbClass<B>, +{ + class: C1, + _bus: PhantomData<B>, +} + +pub struct ClassSet2<B, C1, C2> +where + B: UsbBus, + C1: UsbClass<B>, + C2: UsbClass<B>, +{ + class1: C1, + class2: C2, + _bus: PhantomData<B>, +} + +/// The first class into a [`ClassSet`] +pub struct Index0; + +/// The second class into a [`ClassSet`] +pub struct Index1; + +impl<B, C1> ClassSet<B> for ClassSet1<B, C1> +where + B: UsbBus, + C1: UsbClass<B>, +{ + fn poll_all(&mut self, device: &mut UsbDevice<'_, B>) -> bool { + device.poll(&mut [&mut self.class]) + } +} + +impl<B, C1, C2> ClassSet<B> for ClassSet2<B, C1, C2> +where + B: UsbBus, + C1: UsbClass<B>, + C2: UsbClass<B>, +{ + fn poll_all(&mut self, device: &mut UsbDevice<'_, B>) -> bool { + device.poll(&mut [&mut self.class1, &mut self.class2]) + } +} + +impl<B, C1> IntoClassSet<B, ClassSet1<B, C1>> for C1 +where + B: UsbBus, + C1: UsbClass<B>, +{ + fn into_class_set(self) -> ClassSet1<B, C1> { + ClassSet1 { + class: self, + _bus: PhantomData, + } + } +} + +impl<B, C1, C2> IntoClassSet<B, ClassSet2<B, C1, C2>> for (C1, C2) +where + B: UsbBus, + C1: UsbClass<B>, + C2: UsbClass<B>, +{ + fn into_class_set(self) -> ClassSet2<B, C1, C2> { + ClassSet2 { + class1: self.0, + class2: self.1, + _bus: PhantomData, + } + } +} + +/// Trait for a USB State that has a serial class inside +pub trait SerialState<'bus, 'a, B: UsbBus, I> { + fn get_serial(&mut self) -> &mut UsbSerial<'bus, 'a, B>; +} + +impl<'bus, 'a, B: UsbBus> SerialState<'bus, 'a, B, Index0> + for ClassSet1<B, UsbSerial<'bus, 'a, B>> +{ + fn get_serial(&mut self) -> &mut UsbSerial<'bus, 'a, B> { + &mut self.class + } +} + +impl<'bus, 'a, B, C2> SerialState<'bus, 'a, B, Index0> for ClassSet2<B, UsbSerial<'bus, 'a, B>, C2> +where + B: UsbBus, + C2: UsbClass<B>, +{ + fn get_serial(&mut self) -> &mut UsbSerial<'bus, 'a, B> { + &mut self.class1 + } +} + +impl<'bus, 'a, B, C1> SerialState<'bus, 'a, B, Index1> for ClassSet2<B, C1, UsbSerial<'bus, 'a, B>> +where + B: UsbBus, + C1: UsbClass<B>, +{ + fn get_serial(&mut self) -> &mut UsbSerial<'bus, 'a, B> { + &mut self.class2 + } +} diff --git a/embassy-extras/src/usb/usb_serial.rs b/embassy-extras/src/usb/usb_serial.rs new file mode 100644 index 000000000..9cbfb2da4 --- /dev/null +++ b/embassy-extras/src/usb/usb_serial.rs @@ -0,0 +1,310 @@ +use core::cell::RefCell; +use core::marker::{PhantomData, Unpin}; +use core::pin::Pin; +use core::task::{Context, Poll}; + +use embassy::io::{self, AsyncBufRead, AsyncWrite}; +use embassy::util::WakerRegistration; +use usb_device::bus::UsbBus; +use usb_device::class_prelude::*; +use usb_device::UsbError; + +use super::cdc_acm::CdcAcmClass; +use crate::peripheral::PeripheralMutex; +use crate::ring_buffer::RingBuffer; +use crate::usb::{ClassSet, SerialState, State, USBInterrupt}; + +pub struct ReadInterface<'a, 'bus, 'c, I, B, T, INT> +where + I: Unpin, + B: UsbBus, + T: SerialState<'bus, 'c, B, I> + ClassSet<B>, + INT: USBInterrupt, +{ + // Don't you dare moving out `PeripheralMutex` + pub(crate) inner: &'a RefCell<PeripheralMutex<State<'bus, B, T, INT>>>, + pub(crate) _buf_lifetime: PhantomData<&'c T>, + pub(crate) _index: PhantomData<I>, +} + +/// Write interface for USB CDC_ACM +/// +/// This interface is buffered, meaning that after the write returns the bytes might not be fully +/// on the wire just yet +pub struct WriteInterface<'a, 'bus, 'c, I, B, T, INT> +where + I: Unpin, + B: UsbBus, + T: SerialState<'bus, 'c, B, I> + ClassSet<B>, + INT: USBInterrupt, +{ + // Don't you dare moving out `PeripheralMutex` + pub(crate) inner: &'a RefCell<PeripheralMutex<State<'bus, B, T, INT>>>, + pub(crate) _buf_lifetime: PhantomData<&'c T>, + pub(crate) _index: PhantomData<I>, +} + +impl<'a, 'bus, 'c, I, B, T, INT> AsyncBufRead for ReadInterface<'a, 'bus, 'c, I, B, T, INT> +where + I: Unpin, + B: UsbBus, + T: SerialState<'bus, 'c, B, I> + ClassSet<B>, + INT: USBInterrupt, +{ + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { + let this = self.get_mut(); + let mut mutex = this.inner.borrow_mut(); + let mutex = unsafe { Pin::new_unchecked(&mut *mutex) }; + mutex.with(|state, _irq| { + let serial = state.classes.get_serial(); + let serial = Pin::new(serial); + + match serial.poll_fill_buf(cx) { + Poll::Ready(Ok(buf)) => { + let buf: &[u8] = buf; + // NOTE(unsafe) This part of the buffer won't be modified until the user calls + // consume, which will invalidate this ref + let buf: &[u8] = unsafe { core::mem::transmute(buf) }; + Poll::Ready(Ok(buf)) + } + Poll::Ready(Err(_)) => Poll::Ready(Err(io::Error::Other)), + Poll::Pending => Poll::Pending, + } + }) + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + let this = self.get_mut(); + let mut mutex = this.inner.borrow_mut(); + let mutex = unsafe { Pin::new_unchecked(&mut *mutex) }; + mutex.with(|state, _irq| { + let serial = state.classes.get_serial(); + let serial = Pin::new(serial); + + serial.consume(amt); + }) + } +} + +impl<'a, 'bus, 'c, I, B, T, INT> AsyncWrite for WriteInterface<'a, 'bus, 'c, I, B, T, INT> +where + I: Unpin, + B: UsbBus, + T: SerialState<'bus, 'c, B, I> + ClassSet<B>, + INT: USBInterrupt, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + let this = self.get_mut(); + let mut mutex = this.inner.borrow_mut(); + let mutex = unsafe { Pin::new_unchecked(&mut *mutex) }; + mutex.with(|state, _irq| { + let serial = state.classes.get_serial(); + let serial = Pin::new(serial); + + serial.poll_write(cx, buf) + }) + } +} + +pub struct UsbSerial<'bus, 'a, B: UsbBus> { + inner: CdcAcmClass<'bus, B>, + read_buf: RingBuffer<'a>, + write_buf: RingBuffer<'a>, + read_waker: WakerRegistration, + write_waker: WakerRegistration, + write_state: WriteState, + read_error: bool, + write_error: bool, +} + +impl<'bus, 'a, B: UsbBus> AsyncBufRead for UsbSerial<'bus, 'a, B> { + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { + let this = self.get_mut(); + + if this.read_error { + this.read_error = false; + return Poll::Ready(Err(io::Error::Other)); + } + + let buf = this.read_buf.pop_buf(); + if buf.is_empty() { + this.read_waker.register(cx.waker()); + return Poll::Pending; + } + Poll::Ready(Ok(buf)) + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + self.get_mut().read_buf.pop(amt); + } +} + +impl<'bus, 'a, B: UsbBus> AsyncWrite for UsbSerial<'bus, 'a, B> { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + let this = self.get_mut(); + + if this.write_error { + this.write_error = false; + return Poll::Ready(Err(io::Error::Other)); + } + + let write_buf = this.write_buf.push_buf(); + if write_buf.is_empty() { + this.write_waker.register(cx.waker()); + return Poll::Pending; + } + + let count = write_buf.len().min(buf.len()); + write_buf[..count].copy_from_slice(&buf[..count]); + this.write_buf.push(count); + + this.flush_write(); + Poll::Ready(Ok(count)) + } +} + +/// Keeps track of the type of the last written packet. +enum WriteState { + /// No packets in-flight + Idle, + + /// Short packet currently in-flight + Short, + + /// Full packet current in-flight. A full packet must be followed by a short packet for the host + /// OS to see the transaction. The data is the number of subsequent full packets sent so far. A + /// short packet is forced every SHORT_PACKET_INTERVAL packets so that the OS sees data in a + /// timely manner. + Full(usize), +} + +impl<'bus, 'a, B: UsbBus> UsbSerial<'bus, 'a, B> { + pub fn new( + alloc: &'bus UsbBusAllocator<B>, + read_buf: &'a mut [u8], + write_buf: &'a mut [u8], + ) -> Self { + Self { + inner: CdcAcmClass::new(alloc, 64), + read_buf: RingBuffer::new(read_buf), + write_buf: RingBuffer::new(write_buf), + read_waker: WakerRegistration::new(), + write_waker: WakerRegistration::new(), + write_state: WriteState::Idle, + read_error: false, + write_error: false, + } + } + + fn flush_write(&mut self) { + /// If this many full size packets have been sent in a row, a short packet will be sent so that the + /// host sees the data in a timely manner. + const SHORT_PACKET_INTERVAL: usize = 10; + + let full_size_packets = match self.write_state { + WriteState::Full(c) => c, + _ => 0, + }; + + let ep_size = self.inner.max_packet_size() as usize; + let max_size = if full_size_packets > SHORT_PACKET_INTERVAL { + ep_size - 1 + } else { + ep_size + }; + + let buf = { + let buf = self.write_buf.pop_buf(); + if buf.len() > max_size { + &buf[..max_size] + } else { + buf + } + }; + + if !buf.is_empty() { + let count = match self.inner.write_packet(buf) { + Ok(c) => c, + Err(UsbError::WouldBlock) => 0, + Err(_) => { + self.write_error = true; + return; + } + }; + + if buf.len() == ep_size { + self.write_state = WriteState::Full(full_size_packets + 1); + } else { + self.write_state = WriteState::Short; + } + self.write_buf.pop(count); + } else if full_size_packets > 0 { + if let Err(e) = self.inner.write_packet(&[]) { + if !matches!(e, UsbError::WouldBlock) { + self.write_error = true; + } + return; + } + self.write_state = WriteState::Idle; + } + } +} + +impl<B> UsbClass<B> for UsbSerial<'_, '_, B> +where + B: UsbBus, +{ + fn get_configuration_descriptors(&self, writer: &mut DescriptorWriter) -> Result<(), UsbError> { + self.inner.get_configuration_descriptors(writer) + } + + fn reset(&mut self) { + self.inner.reset(); + self.read_buf.clear(); + self.write_buf.clear(); + self.write_state = WriteState::Idle; + } + + fn endpoint_in_complete(&mut self, addr: EndpointAddress) { + if addr == self.inner.write_ep_address() { + self.write_waker.wake(); + + self.flush_write(); + } + } + + fn endpoint_out(&mut self, addr: EndpointAddress) { + if addr == self.inner.read_ep_address() { + let buf = self.read_buf.push_buf(); + let count = match self.inner.read_packet(buf) { + Ok(c) => c, + Err(UsbError::WouldBlock) => 0, + Err(_) => { + self.read_error = true; + return; + } + }; + + if count > 0 { + self.read_buf.push(count); + self.read_waker.wake(); + } + } + } + + fn control_in(&mut self, xfer: ControlIn<B>) { + self.inner.control_in(xfer); + } + + fn control_out(&mut self, xfer: ControlOut<B>) { + self.inner.control_out(xfer); + } +} diff --git a/embassy-stm32/Cargo.toml b/embassy-stm32/Cargo.toml index 437cacb27..28d9e7fe3 100644 --- a/embassy-stm32/Cargo.toml +++ b/embassy-stm32/Cargo.toml @@ -36,6 +36,7 @@ stm32l0x3 = ["stm32l0xx-hal/stm32l0x3"] [dependencies] embassy = { version = "0.1.0", path = "../embassy" } embassy-macros = { version = "0.1.0", path = "../embassy-macros", features = ["stm32"]} +embassy-extras = {version = "0.1.0", path = "../embassy-extras" } defmt = { version = "0.2.0", optional = true } log = { version = "0.4.11", optional = true } diff --git a/embassy-stm32/src/lib.rs b/embassy-stm32/src/lib.rs index 122761c1e..078e1e5a6 100644 --- a/embassy-stm32/src/lib.rs +++ b/embassy-stm32/src/lib.rs @@ -36,6 +36,7 @@ pub mod exti; pub mod interrupt; #[cfg(any( + feature = "stm32f401", feature = "stm32f405", feature = "stm32f407", feature = "stm32f412", @@ -74,6 +75,8 @@ pub mod can; ))] pub mod rtc; +unsafe impl embassy_extras::usb::USBInterrupt for interrupt::OTG_FS {} + use core::option::Option; use hal::prelude::*; use hal::rcc::Clocks; diff --git a/embassy-stm32f4-examples/Cargo.toml b/embassy-stm32f4-examples/Cargo.toml index fbdc6d794..c6ef98973 100644 --- a/embassy-stm32f4-examples/Cargo.toml +++ b/embassy-stm32f4-examples/Cargo.toml @@ -39,6 +39,7 @@ embassy = { version = "0.1.0", path = "../embassy", features = ["defmt", "defmt- embassy-traits = { version = "0.1.0", path = "../embassy-traits", features = ["defmt"] } embassy-stm32f4 = { version = "*", path = "../embassy-stm32f4" } embassy-stm32 = { version = "*", path = "../embassy-stm32" } +embassy-extras = {version = "0.1.0", path = "../embassy-extras" } defmt = "0.2.0" defmt-rtt = "0.2.0" @@ -47,7 +48,8 @@ cortex-m = "0.7.1" cortex-m-rt = "0.6.13" embedded-hal = { version = "0.2.4" } panic-probe = "0.1.0" -stm32f4xx-hal = { version = "0.8.3", features = ["rt"], git = "https://github.com/stm32-rs/stm32f4xx-hal.git"} +stm32f4xx-hal = { version = "0.8.3", features = ["rt", "usb_fs"], git = "https://github.com/stm32-rs/stm32f4xx-hal.git"} futures = { version = "0.3.8", default-features = false, features = ["async-await"] } rtt-target = { version = "0.3", features = ["cortex-m"] } -bxcan = "0.5.0" \ No newline at end of file +bxcan = "0.5.0" +usb-device = "0.2.7" diff --git a/embassy-stm32f4-examples/src/bin/usb_serial.rs b/embassy-stm32f4-examples/src/bin/usb_serial.rs new file mode 100644 index 000000000..f1c4631d7 --- /dev/null +++ b/embassy-stm32f4-examples/src/bin/usb_serial.rs @@ -0,0 +1,146 @@ +#![no_std] +#![no_main] +#![feature(type_alias_impl_trait)] +#![feature(min_type_alias_impl_trait)] +#![feature(impl_trait_in_bindings)] + +#[path = "../example_common.rs"] +mod example_common; +use example_common::*; + +use cortex_m_rt::entry; +use defmt::panic; +use embassy::executor::{task, Executor}; +use embassy::interrupt::InterruptExt; +use embassy::io::{AsyncBufReadExt, AsyncWriteExt}; +use embassy::time::{Duration, Timer}; +use embassy::util::Forever; +use embassy_extras::usb::usb_serial::UsbSerial; +use embassy_extras::usb::Usb; +use embassy_stm32f4::{interrupt, pac, rtc}; +use futures::future::{select, Either}; +use futures::pin_mut; +use stm32f4xx_hal::otg_fs::{UsbBus, USB}; +use stm32f4xx_hal::prelude::*; +use usb_device::bus::UsbBusAllocator; +use usb_device::prelude::*; + +#[task] +async fn run1(bus: &'static mut UsbBusAllocator<UsbBus<USB>>) { + info!("Async task"); + + let mut read_buf = [0u8; 128]; + let mut write_buf = [0u8; 128]; + let serial = UsbSerial::new(bus, &mut read_buf, &mut write_buf); + + let device = UsbDeviceBuilder::new(bus, UsbVidPid(0x16c0, 0x27dd)) + .manufacturer("Fake company") + .product("Serial port") + .serial_number("TEST") + .device_class(0x02) + .build(); + + let irq = interrupt::take!(OTG_FS); + irq.set_priority(interrupt::Priority::Level3); + + let usb = Usb::new(device, serial, irq); + pin_mut!(usb); + usb.as_mut().start(); + + let (mut read_interface, mut write_interface) = usb.as_ref().take_serial_0(); + + let mut buf = [0u8; 64]; + loop { + let mut n = 0; + let left = { + let recv_fut = async { + loop { + let byte = unwrap!(read_interface.read_byte().await); + unwrap!(write_interface.write_byte(byte).await); + buf[n] = byte; + + n += 1; + if byte == b'\n' || byte == b'\r' || n == buf.len() { + break; + } + } + }; + pin_mut!(recv_fut); + + let timeout = Timer::after(Duration::from_ticks(32768 * 10)); + + match select(recv_fut, timeout).await { + Either::Left(_) => true, + Either::Right(_) => false, + } + }; + + if left { + for c in buf[..n].iter_mut() { + if 0x61 <= *c && *c <= 0x7a { + *c &= !0x20; + } + } + unwrap!(write_interface.write_byte(b'\n').await); + unwrap!(write_interface.write_all(&buf[..n]).await); + unwrap!(write_interface.write_byte(b'\n').await); + } else { + unwrap!(write_interface.write_all(b"\r\nSend something\r\n").await); + } + } +} + +static RTC: Forever<rtc::RTC<pac::TIM2>> = Forever::new(); +static ALARM: Forever<rtc::Alarm<pac::TIM2>> = Forever::new(); +static EXECUTOR: Forever<Executor> = Forever::new(); +static USB_BUS: Forever<UsbBusAllocator<UsbBus<USB>>> = Forever::new(); + +#[entry] +fn main() -> ! { + static mut EP_MEMORY: [u32; 1024] = [0; 1024]; + + info!("Hello World!"); + + let p = unwrap!(pac::Peripherals::take()); + + p.RCC.ahb1enr.modify(|_, w| w.dma1en().enabled()); + let rcc = p.RCC.constrain(); + let clocks = rcc + .cfgr + .use_hse(25.mhz()) + .sysclk(48.mhz()) + .require_pll48clk() + .freeze(); + + p.DBGMCU.cr.modify(|_, w| { + w.dbg_sleep().set_bit(); + w.dbg_standby().set_bit(); + w.dbg_stop().set_bit() + }); + + let rtc = RTC.put(rtc::RTC::new(p.TIM2, interrupt::take!(TIM2), clocks)); + rtc.start(); + + unsafe { embassy::time::set_clock(rtc) }; + + let alarm = ALARM.put(rtc.alarm1()); + let executor = EXECUTOR.put(Executor::new()); + executor.set_alarm(alarm); + + let gpioa = p.GPIOA.split(); + let usb = USB { + usb_global: p.OTG_FS_GLOBAL, + usb_device: p.OTG_FS_DEVICE, + usb_pwrclk: p.OTG_FS_PWRCLK, + pin_dm: gpioa.pa11.into_alternate_af10(), + pin_dp: gpioa.pa12.into_alternate_af10(), + hclk: clocks.hclk(), + }; + // Rust analyzer isn't recognizing the static ref magic `cortex-m` does + #[allow(unused_unsafe)] + let usb_bus = USB_BUS.put(UsbBus::new(usb, unsafe { EP_MEMORY })); + + executor.run(move |spawner| { + unwrap!(spawner.spawn(run1(usb_bus))); + }); +} diff --git a/embassy-stm32f4/Cargo.toml b/embassy-stm32f4/Cargo.toml index c132d7c99..8e9b14f03 100644 --- a/embassy-stm32f4/Cargo.toml +++ b/embassy-stm32f4/Cargo.toml @@ -32,6 +32,7 @@ stm32f479 = ["stm32f4xx-hal/stm32f469", "embassy-stm32/stm32f479"] [dependencies] embassy = { version = "0.1.0", path = "../embassy" } embassy-stm32 = { version = "0.1.0", path = "../embassy-stm32" } + defmt = { version = "0.2.0", optional = true } log = { version = "0.4.11", optional = true } cortex-m-rt = "0.6.13"