From 1b410d6f3f08f12f2bd250a8b76f217291f4df26 Mon Sep 17 00:00:00 2001
From: kbleeke <pluth@0t.re>
Date: Wed, 1 Mar 2023 19:03:46 +0100
Subject: [PATCH] add event handling to join

---
 examples/rpi-pico-w/src/main.rs |  9 ++++---
 src/events.rs                   | 14 +++++++++++
 src/lib.rs                      | 42 +++++++++++++++++++++++++++------
 3 files changed, 53 insertions(+), 12 deletions(-)

diff --git a/examples/rpi-pico-w/src/main.rs b/examples/rpi-pico-w/src/main.rs
index c706e121d..91caa5e3a 100644
--- a/examples/rpi-pico-w/src/main.rs
+++ b/examples/rpi-pico-w/src/main.rs
@@ -70,16 +70,11 @@ async fn main(spawner: Spawner) {
     let state = singleton!(cyw43::State::new());
     let (net_device, mut control, runner) = cyw43::new(state, pwr, spi, fw).await;
 
-    spawner.spawn(wifi_task(runner)).unwrap();
-
     control.init(clm).await;
     control
         .set_power_management(cyw43::PowerManagementMode::PowerSave)
         .await;
 
-    //control.join_open(env!("WIFI_NETWORK")).await;
-    control.join_wpa2(env!("WIFI_NETWORK"), env!("WIFI_PASSWORD")).await;
-
     let config = Config::Dhcp(Default::default());
     //let config = embassy_net::Config::Static(embassy_net::Config {
     //    address: Ipv4Cidr::new(Ipv4Address::new(192, 168, 69, 2), 24),
@@ -98,8 +93,12 @@ async fn main(spawner: Spawner) {
         seed
     ));
 
+    unwrap!(spawner.spawn(wifi_task(runner)));
     unwrap!(spawner.spawn(net_task(stack)));
 
+    //control.join_open(env!("WIFI_NETWORK")).await;
+    control.join_wpa2(env!("WIFI_NETWORK"), env!("WIFI_PASSWORD")).await;
+
     // And now we can use it!
 
     let mut rx_buffer = [0; 4096];
diff --git a/src/events.rs b/src/events.rs
index a828eec98..9e6bb9625 100644
--- a/src/events.rs
+++ b/src/events.rs
@@ -3,6 +3,9 @@
 
 use core::num;
 
+use embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex;
+use embassy_sync::pubsub::{PubSubChannel, Publisher, Subscriber};
+
 #[derive(Clone, Copy, PartialEq, Eq, num_enum::FromPrimitive)]
 #[cfg_attr(feature = "defmt", derive(defmt::Format))]
 #[repr(u8)]
@@ -280,3 +283,14 @@ pub enum Event {
     /// highest val + 1 for range checking
     LAST = 190,
 }
+
+pub type EventQueue = PubSubChannel<CriticalSectionRawMutex, EventStatus, 2, 1, 1>;
+pub type EventPublisher<'a> = Publisher<'a, CriticalSectionRawMutex, EventStatus, 2, 1, 1>;
+pub type EventSubscriber<'a> = Subscriber<'a, CriticalSectionRawMutex, EventStatus, 2, 1, 1>;
+
+#[derive(Clone, Copy)]
+#[cfg_attr(feature = "defmt", derive(defmt::Format))]
+pub struct EventStatus {
+    pub event_type: Event,
+    pub status: u32,
+}
diff --git a/src/lib.rs b/src/lib.rs
index 5733506ac..c58ac8e7d 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -19,13 +19,15 @@ use core::slice;
 use ch::driver::LinkState;
 use embassy_futures::yield_now;
 use embassy_net_driver_channel as ch;
+use embassy_sync::pubsub::PubSubBehavior;
 use embassy_time::{block_for, Duration, Timer};
 use embedded_hal_1::digital::OutputPin;
 use embedded_hal_async::spi::{SpiBusRead, SpiBusWrite, SpiDevice};
+use events::EventQueue;
 
 use crate::bus::Bus;
 use crate::consts::*;
-use crate::events::Event;
+use crate::events::{Event, EventStatus};
 use crate::structs::*;
 
 const MTU: usize = 1514;
@@ -127,6 +129,7 @@ enum IoctlState {
 pub struct State {
     ioctl_state: Cell<IoctlState>,
     ch: ch::State<MTU, 4, 4>,
+    events: EventQueue,
 }
 
 impl State {
@@ -134,12 +137,14 @@ impl State {
         Self {
             ioctl_state: Cell::new(IoctlState::Idle),
             ch: ch::State::new(),
+            events: EventQueue::new(),
         }
     }
 }
 
 pub struct Control<'a> {
     state_ch: ch::StateRunner<'a>,
+    event_sub: &'a EventQueue,
     ioctl_state: &'a Cell<IoctlState>,
 }
 
@@ -313,6 +318,7 @@ impl<'a> Control<'a> {
         evts.unset(Event::PROBREQ_MSG_RX);
         evts.unset(Event::PROBRESP_MSG);
         evts.unset(Event::PROBRESP_MSG);
+        evts.unset(Event::ROAM);
 
         self.set_iovar("bsscfg:event_msgs", &evts.to_bytes()).await;
 
@@ -393,8 +399,22 @@ impl<'a> Control<'a> {
             ssid: [0; 32],
         };
         i.ssid[..ssid.len()].copy_from_slice(ssid.as_bytes());
+
+        let mut subscriber = self.event_sub.subscriber().unwrap();
         self.ioctl(IoctlType::Set, 26, 0, &mut i.to_bytes()).await; // set_ssid
 
+        loop {
+            let msg = subscriber.next_message_pure().await;
+            if msg.event_type == Event::AUTH && msg.status != 0 {
+                // retry
+                defmt::warn!("JOIN failed with status={}", msg.status);
+                self.ioctl(IoctlType::Set, 26, 0, &mut i.to_bytes()).await;
+            } else if msg.event_type == Event::JOIN && msg.status == 0 {
+                // successful join
+                break;
+            }
+        }
+
         info!("JOINED");
     }
 
@@ -489,6 +509,8 @@ pub struct Runner<'a, PWR, SPI> {
     sdpcm_seq: u8,
     sdpcm_seq_max: u8,
 
+    events: &'a EventQueue,
+
     #[cfg(feature = "firmware-logs")]
     log: LogState,
 }
@@ -526,6 +548,8 @@ where
         sdpcm_seq: 0,
         sdpcm_seq_max: 1,
 
+        events: &state.events,
+
         #[cfg(feature = "firmware-logs")]
         log: LogState {
             addr: 0,
@@ -541,6 +565,7 @@ where
         device,
         Control {
             state_ch,
+            event_sub: &&state.events,
             ioctl_state: &state.ioctl_state,
         },
         runner,
@@ -883,13 +908,16 @@ where
                     return;
                 }
 
+                let evt_type = events::Event::from(event_packet.msg.event_type as u8);
                 let evt_data = &bcd_packet[EventMessage::SIZE..][..event_packet.msg.datalen as usize];
-                debug!(
-                    "=== EVENT {}: {} {:02x}",
-                    events::Event::from(event_packet.msg.event_type as u8),
-                    event_packet.msg,
-                    evt_data
-                );
+                debug!("=== EVENT {}: {} {:02x}", evt_type, event_packet.msg, evt_data);
+
+                if evt_type == events::Event::AUTH || evt_type == events::Event::JOIN {
+                    self.events.publish_immediate(EventStatus {
+                        status: event_packet.msg.status,
+                        event_type: evt_type,
+                    });
+                }
             }
             CHANNEL_TYPE_DATA => {
                 let bcd_header = BcdHeader::from_bytes(&payload[..BcdHeader::SIZE].try_into().unwrap());