From 61e3ca049c604384593168037e80186c317dacbb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?D=C3=A1niel=20Buga?= <bugadani@gmail.com>
Date: Sat, 27 Apr 2024 17:55:57 +0200
Subject: [PATCH] Only access the necessary parts of State

---
 embassy-usb-synopsys-otg/src/lib.rs | 60 ++++++++++++++++-------------
 1 file changed, 33 insertions(+), 27 deletions(-)

diff --git a/embassy-usb-synopsys-otg/src/lib.rs b/embassy-usb-synopsys-otg/src/lib.rs
index 402347017..59be2c0be 100644
--- a/embassy-usb-synopsys-otg/src/lib.rs
+++ b/embassy-usb-synopsys-otg/src/lib.rs
@@ -55,12 +55,12 @@ pub unsafe fn on_interrupt(r: Otg, state: &State<{ MAX_EP_COUNT }>, ep_count: us
                     while r.grstctl().read().txfflsh() {}
                 }
 
-                if state.ep0_setup_ready.load(Ordering::Relaxed) == false {
+                if state.cp_state.setup_ready.load(Ordering::Relaxed) == false {
                     // SAFETY: exclusive access ensured by atomic bool
-                    let data = unsafe { &mut *state.ep0_setup_data.get() };
+                    let data = unsafe { &mut *state.cp_state.setup_data.get() };
                     data[0..4].copy_from_slice(&r.fifo(0).read().0.to_ne_bytes());
                     data[4..8].copy_from_slice(&r.fifo(0).read().0.to_ne_bytes());
-                    state.ep0_setup_ready.store(true, Ordering::Release);
+                    state.cp_state.setup_ready.store(true, Ordering::Release);
                     state.ep_states[0].out_waker.wake();
                 } else {
                     error!("received SETUP before previous finished processing");
@@ -216,11 +216,15 @@ struct EpState {
     out_size: AtomicU16,
 }
 
+struct ControlPipeSetupState {
+    /// Holds received SETUP packets. Available if [Ep0State::setup_ready] is true.
+    setup_data: UnsafeCell<[u8; 8]>,
+    setup_ready: AtomicBool,
+}
+
 /// USB OTG driver state.
 pub struct State<const EP_COUNT: usize> {
-    /// Holds received SETUP packets. Available if [State::ep0_setup_ready] is true.
-    ep0_setup_data: UnsafeCell<[u8; 8]>,
-    ep0_setup_ready: AtomicBool,
+    cp_state: ControlPipeSetupState,
     ep_states: [EpState; EP_COUNT],
     bus_waker: AtomicWaker,
 }
@@ -242,8 +246,10 @@ impl<const EP_COUNT: usize> State<EP_COUNT> {
         };
 
         Self {
-            ep0_setup_data: UnsafeCell::new([0u8; 8]),
-            ep0_setup_ready: AtomicBool::new(false),
+            cp_state: ControlPipeSetupState {
+                setup_data: UnsafeCell::new([0u8; 8]),
+                setup_ready: AtomicBool::new(false),
+            },
             ep_states: [NEW_EP_STATE; EP_COUNT],
             bus_waker: NEW_AW,
         }
@@ -385,11 +391,11 @@ impl<'d> Driver<'d> {
 
         trace!("  index={}", index);
 
+        let state = &self.instance.state.ep_states[index];
         if D::dir() == Direction::Out {
             // Buffer capacity check was done above, now allocation cannot fail
             unsafe {
-                *self.instance.state.ep_states[index].out_buffer.get() =
-                    self.ep_out_buffer.as_mut_ptr().offset(self.ep_out_buffer_offset as _);
+                *state.out_buffer.get() = self.ep_out_buffer.as_mut_ptr().offset(self.ep_out_buffer_offset as _);
             }
             self.ep_out_buffer_offset += max_packet_size as usize;
         }
@@ -397,7 +403,7 @@ impl<'d> Driver<'d> {
         Ok(Endpoint {
             _phantom: PhantomData,
             regs: self.instance.regs,
-            state: self.instance.state,
+            state,
             info: EndpointInfo {
                 addr: EndpointAddress::from_parts(index, D::dir()),
                 ep_type,
@@ -446,6 +452,7 @@ impl<'d> embassy_usb_driver::Driver<'d> for Driver<'d> {
 
         let regs = self.instance.regs;
         let quirk_setup_late_cnak = self.instance.quirk_setup_late_cnak;
+        let cp_setup_state = &self.instance.state.cp_state;
         (
             Bus {
                 config: self.config,
@@ -456,6 +463,7 @@ impl<'d> embassy_usb_driver::Driver<'d> for Driver<'d> {
             },
             ControlPipe {
                 max_packet_size: control_max_packet_size,
+                setup_state: cp_setup_state,
                 ep_out,
                 ep_in,
                 regs,
@@ -955,7 +963,7 @@ pub struct Endpoint<'d, D> {
     _phantom: PhantomData<D>,
     regs: Otg,
     info: EndpointInfo,
-    state: &'d State<{ MAX_EP_COUNT }>,
+    state: &'d EpState,
 }
 
 impl<'d> embassy_usb_driver::Endpoint for Endpoint<'d, In> {
@@ -967,7 +975,7 @@ impl<'d> embassy_usb_driver::Endpoint for Endpoint<'d, In> {
         poll_fn(|cx| {
             let ep_index = self.info.addr.index();
 
-            self.state.ep_states[ep_index].in_waker.register(cx.waker());
+            self.state.in_waker.register(cx.waker());
 
             if self.regs.diepctl(ep_index).read().usbaep() {
                 Poll::Ready(())
@@ -988,7 +996,7 @@ impl<'d> embassy_usb_driver::Endpoint for Endpoint<'d, Out> {
         poll_fn(|cx| {
             let ep_index = self.info.addr.index();
 
-            self.state.ep_states[ep_index].out_waker.register(cx.waker());
+            self.state.out_waker.register(cx.waker());
 
             if self.regs.doepctl(ep_index).read().usbaep() {
                 Poll::Ready(())
@@ -1006,7 +1014,7 @@ impl<'d> embassy_usb_driver::EndpointOut for Endpoint<'d, Out> {
 
         poll_fn(|cx| {
             let index = self.info.addr.index();
-            self.state.ep_states[index].out_waker.register(cx.waker());
+            self.state.out_waker.register(cx.waker());
 
             let doepctl = self.regs.doepctl(index).read();
             trace!("read ep={:?}: doepctl {:08x}", self.info.addr, doepctl.0,);
@@ -1015,7 +1023,7 @@ impl<'d> embassy_usb_driver::EndpointOut for Endpoint<'d, Out> {
                 return Poll::Ready(Err(EndpointError::Disabled));
             }
 
-            let len = self.state.ep_states[index].out_size.load(Ordering::Relaxed);
+            let len = self.state.out_size.load(Ordering::Relaxed);
             if len != EP_OUT_BUFFER_EMPTY {
                 trace!("read ep={:?} done len={}", self.info.addr, len);
 
@@ -1024,14 +1032,11 @@ impl<'d> embassy_usb_driver::EndpointOut for Endpoint<'d, Out> {
                 }
 
                 // SAFETY: exclusive access ensured by `ep_out_size` atomic variable
-                let data =
-                    unsafe { core::slice::from_raw_parts(*self.state.ep_states[index].out_buffer.get(), len as usize) };
+                let data = unsafe { core::slice::from_raw_parts(*self.state.out_buffer.get(), len as usize) };
                 buf[..len as usize].copy_from_slice(data);
 
                 // Release buffer
-                self.state.ep_states[index]
-                    .out_size
-                    .store(EP_OUT_BUFFER_EMPTY, Ordering::Release);
+                self.state.out_size.store(EP_OUT_BUFFER_EMPTY, Ordering::Release);
 
                 critical_section::with(|_| {
                     // Receive 1 packet
@@ -1066,7 +1071,7 @@ impl<'d> embassy_usb_driver::EndpointIn for Endpoint<'d, In> {
         let index = self.info.addr.index();
         // Wait for previous transfer to complete and check if endpoint is disabled
         poll_fn(|cx| {
-            self.state.ep_states[index].in_waker.register(cx.waker());
+            self.state.in_waker.register(cx.waker());
 
             let diepctl = self.regs.diepctl(index).read();
             let dtxfsts = self.regs.dtxfsts(index).read();
@@ -1091,7 +1096,7 @@ impl<'d> embassy_usb_driver::EndpointIn for Endpoint<'d, In> {
 
         if buf.len() > 0 {
             poll_fn(|cx| {
-                self.state.ep_states[index].in_waker.register(cx.waker());
+                self.state.in_waker.register(cx.waker());
 
                 let size_words = (buf.len() + 3) / 4;
 
@@ -1151,6 +1156,7 @@ impl<'d> embassy_usb_driver::EndpointIn for Endpoint<'d, In> {
 pub struct ControlPipe<'d> {
     max_packet_size: u16,
     regs: Otg,
+    setup_state: &'d ControlPipeSetupState,
     ep_in: Endpoint<'d, In>,
     ep_out: Endpoint<'d, Out>,
     quirk_setup_late_cnak: bool,
@@ -1163,11 +1169,11 @@ impl<'d> embassy_usb_driver::ControlPipe for ControlPipe<'d> {
 
     async fn setup(&mut self) -> [u8; 8] {
         poll_fn(|cx| {
-            self.ep_out.state.ep_states[0].out_waker.register(cx.waker());
+            self.ep_out.state.out_waker.register(cx.waker());
 
-            if self.ep_out.state.ep0_setup_ready.load(Ordering::Relaxed) {
-                let data = unsafe { *self.ep_out.state.ep0_setup_data.get() };
-                self.ep_out.state.ep0_setup_ready.store(false, Ordering::Release);
+            if self.setup_state.setup_ready.load(Ordering::Relaxed) {
+                let data = unsafe { *self.setup_state.setup_data.get() };
+                self.setup_state.setup_ready.store(false, Ordering::Release);
 
                 // EP0 should not be controlled by `Bus` so this RMW does not need a critical section
                 // Receive 1 SETUP packet