From 2548bbdd65fc3094f624bd043a1a9a296f9184b5 Mon Sep 17 00:00:00 2001
From: Dario Nieuwenhuis <dirbaio@dirbaio.net>
Date: Tue, 27 Dec 2022 01:19:26 +0100
Subject: [PATCH] Update Embassy.

---
 Cargo.toml                      |   2 +-
 examples/rpi-pico-w/Cargo.toml  |  18 +--
 examples/rpi-pico-w/src/main.rs |   6 +-
 src/lib.rs                      | 218 +++++++++++++-------------------
 4 files changed, 101 insertions(+), 143 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index dadfb5c5a..6e3237448 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -14,7 +14,7 @@ firmware-logs = []
 embassy-time = { version = "0.1.0" }
 embassy-sync = { version = "0.1.0" }
 embassy-futures = { version = "0.1.0" }
-embassy-net = { version = "0.1.0" }
+embassy-net-driver-channel = { version = "0.1.0" }
 atomic-polyfill = "0.1.5"
 
 defmt = { version = "0.3", optional = true }
diff --git a/examples/rpi-pico-w/Cargo.toml b/examples/rpi-pico-w/Cargo.toml
index b817289e5..fa1cad8c7 100644
--- a/examples/rpi-pico-w/Cargo.toml
+++ b/examples/rpi-pico-w/Cargo.toml
@@ -9,7 +9,7 @@ cyw43 = { path = "../../", features = ["defmt", "firmware-logs"]}
 embassy-executor = { version = "0.1.0",  features = ["defmt", "integrated-timers"] }
 embassy-time = { version = "0.1.0",  features = ["defmt", "defmt-timestamp-uptime"] }
 embassy-rp = { version = "0.1.0",  features = ["defmt", "unstable-traits", "nightly", "unstable-pac", "time-driver"] }
-embassy-net = { version = "0.1.0", features = ["defmt", "tcp", "dhcpv4", "medium-ethernet", "pool-16", "unstable-traits", "nightly"] }
+embassy-net = { version = "0.1.0", features = ["defmt", "tcp", "dhcpv4", "medium-ethernet", "unstable-traits", "nightly"] }
 atomic-polyfill = "0.1.5"
 static_cell = "1.0"
 
@@ -28,12 +28,14 @@ heapless = "0.7.15"
 
 
 [patch.crates-io]
-embassy-executor = { git = "https://github.com/embassy-rs/embassy", rev = "645fb66a5122bdc8180e0e65d076ca103431a426" }
-embassy-time = { git = "https://github.com/embassy-rs/embassy", rev = "645fb66a5122bdc8180e0e65d076ca103431a426" }
-embassy-futures = { git = "https://github.com/embassy-rs/embassy", rev = "645fb66a5122bdc8180e0e65d076ca103431a426" }
-embassy-sync = { git = "https://github.com/embassy-rs/embassy", rev = "645fb66a5122bdc8180e0e65d076ca103431a426" }
-embassy-rp = { git = "https://github.com/embassy-rs/embassy", rev = "645fb66a5122bdc8180e0e65d076ca103431a426" }
-embassy-net = { git = "https://github.com/embassy-rs/embassy", rev = "645fb66a5122bdc8180e0e65d076ca103431a426" }
+embassy-executor = { git = "https://github.com/embassy-rs/embassy", rev = "771806be790a2758f1314d6460defe7c2f0d3e99" }
+embassy-time = { git = "https://github.com/embassy-rs/embassy", rev = "771806be790a2758f1314d6460defe7c2f0d3e99" }
+embassy-futures = { git = "https://github.com/embassy-rs/embassy", rev = "771806be790a2758f1314d6460defe7c2f0d3e99" }
+embassy-sync = { git = "https://github.com/embassy-rs/embassy", rev = "771806be790a2758f1314d6460defe7c2f0d3e99" }
+embassy-rp = { git = "https://github.com/embassy-rs/embassy", rev = "771806be790a2758f1314d6460defe7c2f0d3e99" }
+embassy-net = { git = "https://github.com/embassy-rs/embassy", rev = "771806be790a2758f1314d6460defe7c2f0d3e99" }
+embassy-net-driver = { git = "https://github.com/embassy-rs/embassy", rev = "771806be790a2758f1314d6460defe7c2f0d3e99" }
+embassy-net-driver-channel = { git = "https://github.com/embassy-rs/embassy", rev = "771806be790a2758f1314d6460defe7c2f0d3e99" }
 
 [profile.dev]
 debug = 2
@@ -43,7 +45,7 @@ overflow-checks = true
 
 [profile.release]
 codegen-units = 1
-debug = 2
+debug = 1
 debug-assertions = false
 incremental = false
 lto = 'fat'
diff --git a/examples/rpi-pico-w/src/main.rs b/examples/rpi-pico-w/src/main.rs
index a19f38591..fd58e46df 100644
--- a/examples/rpi-pico-w/src/main.rs
+++ b/examples/rpi-pico-w/src/main.rs
@@ -34,7 +34,7 @@ async fn wifi_task(
 }
 
 #[embassy_executor::task]
-async fn net_task(stack: &'static Stack<cyw43::NetDevice<'static>>) -> ! {
+async fn net_task(stack: &'static Stack<cyw43::NetDriver<'static>>) -> ! {
     stack.run().await
 }
 
@@ -66,11 +66,11 @@ async fn main(spawner: Spawner) {
     let spi = ExclusiveDevice::new(bus, cs);
 
     let state = singleton!(cyw43::State::new());
-    let (mut control, runner) = cyw43::new(state, pwr, spi, fw).await;
+    let (net_device, mut control, runner) = cyw43::new(state, pwr, spi, fw).await;
 
     spawner.spawn(wifi_task(runner)).unwrap();
 
-    let net_device = control.init(clm).await;
+    control.init(clm).await;
 
     //control.join_open(env!("WIFI_NETWORK")).await;
     control.join_wpa2(env!("WIFI_NETWORK"), env!("WIFI_PASSWORD")).await;
diff --git a/src/lib.rs b/src/lib.rs
index fa73b32e0..25e6f8f16 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -15,14 +15,10 @@ mod structs;
 use core::cell::Cell;
 use core::cmp::{max, min};
 use core::slice;
-use core::sync::atomic::Ordering;
-use core::task::Waker;
 
-use atomic_polyfill::AtomicBool;
+use ch::driver::LinkState;
 use embassy_futures::yield_now;
-use embassy_net::{PacketBoxExt, PacketBuf};
-use embassy_sync::blocking_mutex::raw::NoopRawMutex;
-use embassy_sync::channel::Channel;
+use embassy_net_driver_channel as ch;
 use embassy_time::{block_for, Duration, Timer};
 use embedded_hal_1::digital::OutputPin;
 use embedded_hal_async::spi::{SpiBusRead, SpiBusWrite, SpiDevice};
@@ -32,6 +28,8 @@ use crate::consts::*;
 use crate::events::Event;
 use crate::structs::*;
 
+const MTU: usize = 1514;
+
 #[derive(Clone, Copy)]
 pub enum IoctlType {
     Get = 0,
@@ -128,30 +126,25 @@ enum IoctlState {
 
 pub struct State {
     ioctl_state: Cell<IoctlState>,
-
-    tx_channel: Channel<NoopRawMutex, PacketBuf, 8>,
-    rx_channel: Channel<NoopRawMutex, PacketBuf, 8>,
-    link_up: AtomicBool,
+    ch: ch::State<MTU, 4, 4>,
 }
 
 impl State {
     pub fn new() -> Self {
         Self {
             ioctl_state: Cell::new(IoctlState::Idle),
-
-            tx_channel: Channel::new(),
-            rx_channel: Channel::new(),
-            link_up: AtomicBool::new(true), // TODO set up/down as we join/deassociate
+            ch: ch::State::new(),
         }
     }
 }
 
 pub struct Control<'a> {
-    state: &'a State,
+    state_ch: ch::StateRunner<'a>,
+    ioctl_state: &'a Cell<IoctlState>,
 }
 
 impl<'a> Control<'a> {
-    pub async fn init(&mut self, clm: &[u8]) -> NetDevice<'a> {
+    pub async fn init(&mut self, clm: &[u8]) {
         const CHUNK_SIZE: usize = 1024;
 
         info!("Downloading CLM...");
@@ -258,12 +251,10 @@ impl<'a> Control<'a> {
 
         Timer::after(Duration::from_millis(100)).await;
 
-        info!("INIT DONE");
+        self.state_ch.set_ethernet_address(mac_addr);
+        self.state_ch.set_link_state(LinkState::Up); // TODO do on join/leave
 
-        NetDevice {
-            state: self.state,
-            mac_addr,
-        }
+        info!("INIT DONE");
     }
 
     pub async fn join_open(&mut self, ssid: &str) {
@@ -381,75 +372,30 @@ impl<'a> Control<'a> {
     async fn ioctl(&mut self, kind: IoctlType, cmd: u32, iface: u32, buf: &mut [u8]) -> usize {
         // TODO cancel ioctl on future drop.
 
-        while !matches!(self.state.ioctl_state.get(), IoctlState::Idle) {
+        while !matches!(self.ioctl_state.get(), IoctlState::Idle) {
             yield_now().await;
         }
 
-        self.state
-            .ioctl_state
-            .set(IoctlState::Pending { kind, cmd, iface, buf });
+        self.ioctl_state.set(IoctlState::Pending { kind, cmd, iface, buf });
 
         let resp_len = loop {
-            if let IoctlState::Done { resp_len } = self.state.ioctl_state.get() {
+            if let IoctlState::Done { resp_len } = self.ioctl_state.get() {
                 break resp_len;
             }
             yield_now().await;
         };
 
-        self.state.ioctl_state.set(IoctlState::Idle);
+        self.ioctl_state.set(IoctlState::Idle);
 
         resp_len
     }
 }
 
-pub struct NetDevice<'a> {
-    state: &'a State,
-    mac_addr: [u8; 6],
-}
-
-impl<'a> embassy_net::Device for NetDevice<'a> {
-    fn register_waker(&mut self, waker: &Waker) {
-        // loopy loopy wakey wakey
-        waker.wake_by_ref()
-    }
-
-    fn link_state(&mut self) -> embassy_net::LinkState {
-        match self.state.link_up.load(Ordering::Relaxed) {
-            true => embassy_net::LinkState::Up,
-            false => embassy_net::LinkState::Down,
-        }
-    }
-
-    fn capabilities(&self) -> embassy_net::DeviceCapabilities {
-        let mut caps = embassy_net::DeviceCapabilities::default();
-        caps.max_transmission_unit = 1514; // 1500 IP + 14 ethernet header
-        caps.medium = embassy_net::Medium::Ethernet;
-        caps
-    }
-
-    fn is_transmit_ready(&mut self) -> bool {
-        true
-    }
-
-    fn transmit(&mut self, pkt: PacketBuf) {
-        if self.state.tx_channel.try_send(pkt).is_err() {
-            warn!("TX failed")
-        }
-    }
-
-    fn receive(&mut self) -> Option<PacketBuf> {
-        self.state.rx_channel.try_recv().ok()
-    }
-
-    fn ethernet_address(&self) -> [u8; 6] {
-        self.mac_addr
-    }
-}
-
 pub struct Runner<'a, PWR, SPI> {
+    ch: ch::Runner<'a, MTU>,
     bus: Bus<PWR, SPI>,
 
-    state: &'a State,
+    ioctl_state: &'a Cell<IoctlState>,
     ioctl_id: u16,
     sdpcm_seq: u8,
     sdpcm_seq_max: u8,
@@ -466,21 +412,27 @@ struct LogState {
     buf_count: usize,
 }
 
+pub type NetDriver<'a> = ch::Device<'a, MTU>;
+
 pub async fn new<'a, PWR, SPI>(
-    state: &'a State,
+    state: &'a mut State,
     pwr: PWR,
     spi: SPI,
     firmware: &[u8],
-) -> (Control<'a>, Runner<'a, PWR, SPI>)
+) -> (NetDriver<'a>, Control<'a>, Runner<'a, PWR, SPI>)
 where
     PWR: OutputPin,
     SPI: SpiDevice,
     SPI::Bus: SpiBusRead<u32> + SpiBusWrite<u32>,
 {
+    let (ch_runner, device) = ch::new(&mut state.ch, [0; 6]);
+    let state_ch = ch_runner.state_runner();
+
     let mut runner = Runner {
+        ch: ch_runner,
         bus: Bus::new(pwr, spi),
 
-        state,
+        ioctl_state: &state.ioctl_state,
         ioctl_id: 0,
         sdpcm_seq: 0,
         sdpcm_seq_max: 1,
@@ -496,7 +448,14 @@ where
 
     runner.init(firmware).await;
 
-    (Control { state }, runner)
+    (
+        device,
+        Control {
+            state_ch,
+            ioctl_state: &state.ioctl_state,
+        },
+        runner,
+    )
 }
 
 impl<'a, PWR, SPI> Runner<'a, PWR, SPI>
@@ -662,15 +621,55 @@ where
             if !self.has_credit() {
                 warn!("TX stalled");
             } else {
-                if let IoctlState::Pending { kind, cmd, iface, buf } = self.state.ioctl_state.get() {
+                if let IoctlState::Pending { kind, cmd, iface, buf } = self.ioctl_state.get() {
                     self.send_ioctl(kind, cmd, iface, unsafe { &*buf }).await;
-                    self.state.ioctl_state.set(IoctlState::Sent { buf });
+                    self.ioctl_state.set(IoctlState::Sent { buf });
                 }
                 if !self.has_credit() {
                     warn!("TX stalled");
                 } else {
-                    if let Ok(p) = self.state.tx_channel.try_recv() {
-                        self.send_packet(&p).await;
+                    if let Some(packet) = self.ch.try_tx_buf() {
+                        trace!("tx pkt {:02x}", &packet[..packet.len().min(48)]);
+
+                        let mut buf = [0; 512];
+                        let buf8 = slice8_mut(&mut buf);
+
+                        let total_len = SdpcmHeader::SIZE + BcdHeader::SIZE + packet.len();
+
+                        let seq = self.sdpcm_seq;
+                        self.sdpcm_seq = self.sdpcm_seq.wrapping_add(1);
+
+                        let sdpcm_header = SdpcmHeader {
+                            len: total_len as u16, // TODO does this len need to be rounded up to u32?
+                            len_inv: !total_len as u16,
+                            sequence: seq,
+                            channel_and_flags: CHANNEL_TYPE_DATA,
+                            next_length: 0,
+                            header_length: SdpcmHeader::SIZE as _,
+                            wireless_flow_control: 0,
+                            bus_data_credit: 0,
+                            reserved: [0, 0],
+                        };
+
+                        let bcd_header = BcdHeader {
+                            flags: BDC_VERSION << BDC_VERSION_SHIFT,
+                            priority: 0,
+                            flags2: 0,
+                            data_offset: 0,
+                        };
+                        trace!("tx {:?}", sdpcm_header);
+                        trace!("    {:?}", bcd_header);
+
+                        buf8[0..SdpcmHeader::SIZE].copy_from_slice(&sdpcm_header.to_bytes());
+                        buf8[SdpcmHeader::SIZE..][..BcdHeader::SIZE].copy_from_slice(&bcd_header.to_bytes());
+                        buf8[SdpcmHeader::SIZE + BcdHeader::SIZE..][..packet.len()].copy_from_slice(packet);
+
+                        let total_len = (total_len + 3) & !3; // round up to 4byte
+
+                        trace!("    {:02x}", &buf8[..total_len.min(48)]);
+
+                        self.bus.wlan_write(&buf[..(total_len / 4)]).await;
+                        self.ch.tx_done();
                     }
                 }
             }
@@ -686,7 +685,6 @@ where
 
                 if status & STATUS_F2_PKT_AVAILABLE != 0 {
                     let len = (status & STATUS_F2_PKT_LEN_MASK) >> STATUS_F2_PKT_LEN_SHIFT;
-
                     self.bus.wlan_read(&mut buf[..(len as usize + 3) / 4]).await;
                     trace!("rx {:02x}", &slice8_mut(&mut buf)[..(len as usize).min(48)]);
                     self.rx(&slice8_mut(&mut buf)[..len as usize]);
@@ -698,49 +696,6 @@ where
         }
     }
 
-    async fn send_packet(&mut self, packet: &[u8]) {
-        trace!("tx pkt {:02x}", &packet[..packet.len().min(48)]);
-
-        let mut buf = [0; 512];
-        let buf8 = slice8_mut(&mut buf);
-
-        let total_len = SdpcmHeader::SIZE + BcdHeader::SIZE + packet.len();
-
-        let seq = self.sdpcm_seq;
-        self.sdpcm_seq = self.sdpcm_seq.wrapping_add(1);
-
-        let sdpcm_header = SdpcmHeader {
-            len: total_len as u16, // TODO does this len need to be rounded up to u32?
-            len_inv: !total_len as u16,
-            sequence: seq,
-            channel_and_flags: CHANNEL_TYPE_DATA,
-            next_length: 0,
-            header_length: SdpcmHeader::SIZE as _,
-            wireless_flow_control: 0,
-            bus_data_credit: 0,
-            reserved: [0, 0],
-        };
-
-        let bcd_header = BcdHeader {
-            flags: BDC_VERSION << BDC_VERSION_SHIFT,
-            priority: 0,
-            flags2: 0,
-            data_offset: 0,
-        };
-        trace!("tx {:?}", sdpcm_header);
-        trace!("    {:?}", bcd_header);
-
-        buf8[0..SdpcmHeader::SIZE].copy_from_slice(&sdpcm_header.to_bytes());
-        buf8[SdpcmHeader::SIZE..][..BcdHeader::SIZE].copy_from_slice(&bcd_header.to_bytes());
-        buf8[SdpcmHeader::SIZE + BcdHeader::SIZE..][..packet.len()].copy_from_slice(packet);
-
-        let total_len = (total_len + 3) & !3; // round up to 4byte
-
-        trace!("    {:02x}", &buf8[..total_len.min(48)]);
-
-        self.bus.wlan_write(&buf[..(total_len / 4)]).await;
-    }
-
     fn rx(&mut self, packet: &[u8]) {
         if packet.len() < SdpcmHeader::SIZE {
             warn!("packet too short, len={}", packet.len());
@@ -775,7 +730,7 @@ where
                 let cdc_header = CdcHeader::from_bytes(payload[..CdcHeader::SIZE].try_into().unwrap());
                 trace!("    {:?}", cdc_header);
 
-                if let IoctlState::Sent { buf } = self.state.ioctl_state.get() {
+                if let IoctlState::Sent { buf } = self.ioctl_state.get() {
                     if cdc_header.id == self.ioctl_id {
                         if cdc_header.status != 0 {
                             // TODO: propagate error instead
@@ -786,7 +741,7 @@ where
                         info!("IOCTL Response: {:02x}", &payload[CdcHeader::SIZE..][..resp_len]);
 
                         (unsafe { &mut *buf }[..resp_len]).copy_from_slice(&payload[CdcHeader::SIZE..][..resp_len]);
-                        self.state.ioctl_state.set(IoctlState::Done { resp_len });
+                        self.ioctl_state.set(IoctlState::Done { resp_len });
                     }
                 }
             }
@@ -859,11 +814,12 @@ where
                 let packet = &payload[packet_start..];
                 trace!("rx pkt {:02x}", &packet[..(packet.len() as usize).min(48)]);
 
-                let mut p = unwrap!(embassy_net::PacketBox::new(embassy_net::Packet::new()));
-                p[..packet.len()].copy_from_slice(packet);
-
-                if let Err(_) = self.state.rx_channel.try_send(p.slice(0..packet.len())) {
-                    warn!("failed to push rxd packet to the channel.")
+                match self.ch.try_rx_buf() {
+                    Some(buf) => {
+                        buf[..packet.len()].copy_from_slice(packet);
+                        self.ch.rx_done(packet.len())
+                    }
+                    None => warn!("failed to push rxd packet to the channel."),
                 }
             }
             _ => {}