From e484cb1b875adc7a0865299df5167f94302fcf49 Mon Sep 17 00:00:00 2001
From: Leon Camus <leon.c@gmx.de>
Date: Wed, 8 Mar 2023 12:37:00 +0100
Subject: [PATCH] refactor: Multicast method modifiers on stack to public
 revert: udp.rs

---
 embassy-net/src/lib.rs |  6 +++---
 embassy-net/src/udp.rs | 47 ++++++++++++++----------------------------
 2 files changed, 18 insertions(+), 35 deletions(-)

diff --git a/embassy-net/src/lib.rs b/embassy-net/src/lib.rs
index 05bbd07f2..7b9d0e773 100644
--- a/embassy-net/src/lib.rs
+++ b/embassy-net/src/lib.rs
@@ -306,7 +306,7 @@ impl<D: Driver + 'static> Stack<D> {
 
 #[cfg(feature = "igmp")]
 impl<D: Driver + smoltcp::phy::Device + 'static> Stack<D> {
-    pub(crate) fn join_multicast_group<T>(&self, addr: T) -> Result<bool, smoltcp::iface::MulticastError>
+    pub fn join_multicast_group<T>(&self, addr: T) -> Result<bool, smoltcp::iface::MulticastError>
     where
         T: Into<IpAddress>,
     {
@@ -318,7 +318,7 @@ impl<D: Driver + smoltcp::phy::Device + 'static> Stack<D> {
         })
     }
 
-    pub(crate) fn leave_multicast_group<T>(&self, addr: T) -> Result<bool, smoltcp::iface::MulticastError>
+    pub fn leave_multicast_group<T>(&self, addr: T) -> Result<bool, smoltcp::iface::MulticastError>
     where
         T: Into<IpAddress>,
     {
@@ -330,7 +330,7 @@ impl<D: Driver + smoltcp::phy::Device + 'static> Stack<D> {
         })
     }
 
-    pub(crate) fn has_multicast_group<T: Into<IpAddress>>(&self, addr: T) -> bool {
+    pub fn has_multicast_group<T: Into<IpAddress>>(&self, addr: T) -> bool {
         self.socket.borrow().iface.has_multicast_group(addr)
     }
 }
diff --git a/embassy-net/src/udp.rs b/embassy-net/src/udp.rs
index 12bdbf402..0ee8c6e19 100644
--- a/embassy-net/src/udp.rs
+++ b/embassy-net/src/udp.rs
@@ -1,3 +1,4 @@
+use core::cell::RefCell;
 use core::future::poll_fn;
 use core::mem;
 use core::task::Poll;
@@ -7,7 +8,7 @@ use smoltcp::iface::{Interface, SocketHandle};
 use smoltcp::socket::udp::{self, PacketMetadata};
 use smoltcp::wire::{IpEndpoint, IpListenEndpoint};
 
-use crate::Stack;
+use crate::{SocketStack, Stack};
 
 #[derive(PartialEq, Eq, Clone, Copy, Debug)]
 #[cfg_attr(feature = "defmt", derive(defmt::Format))]
@@ -25,13 +26,13 @@ pub enum Error {
     NoRoute,
 }
 
-pub struct UdpSocket<'a, D: Driver> {
-    stack: &'a Stack<D>,
+pub struct UdpSocket<'a> {
+    stack: &'a RefCell<SocketStack>,
     handle: SocketHandle,
 }
 
-impl<'a, D: Driver> UdpSocket<'a, D> {
-    pub fn new(
+impl<'a> UdpSocket<'a> {
+    pub fn new<D: Driver>(
         stack: &'a Stack<D>,
         rx_meta: &'a mut [PacketMetadata],
         rx_buffer: &'a mut [u8],
@@ -49,7 +50,10 @@ impl<'a, D: Driver> UdpSocket<'a, D> {
             udp::PacketBuffer::new(tx_meta, tx_buffer),
         ));
 
-        Self { stack, handle }
+        Self {
+            stack: &stack.socket,
+            handle,
+        }
     }
 
     pub fn bind<T>(&mut self, endpoint: T) -> Result<(), BindError>
@@ -60,7 +64,7 @@ impl<'a, D: Driver> UdpSocket<'a, D> {
 
         if endpoint.port == 0 {
             // If user didn't specify port allocate a dynamic port.
-            endpoint.port = self.stack.socket.borrow_mut().get_local_port();
+            endpoint.port = self.stack.borrow_mut().get_local_port();
         }
 
         match self.with_mut(|s, _| s.bind(endpoint)) {
@@ -71,13 +75,13 @@ impl<'a, D: Driver> UdpSocket<'a, D> {
     }
 
     fn with<R>(&self, f: impl FnOnce(&udp::Socket, &Interface) -> R) -> R {
-        let s = &*self.stack.socket.borrow();
+        let s = &*self.stack.borrow();
         let socket = s.sockets.get::<udp::Socket>(self.handle);
         f(socket, &s.iface)
     }
 
     fn with_mut<R>(&self, f: impl FnOnce(&mut udp::Socket, &mut Interface) -> R) -> R {
-        let s = &mut *self.stack.socket.borrow_mut();
+        let s = &mut *self.stack.borrow_mut();
         let socket = s.sockets.get_mut::<udp::Socket>(self.handle);
         let res = f(socket, &mut s.iface);
         s.waker.wake();
@@ -139,29 +143,8 @@ impl<'a, D: Driver> UdpSocket<'a, D> {
     }
 }
 
-#[cfg(feature = "igmp")]
-impl<'a, D: Driver + smoltcp::phy::Device + 'static> UdpSocket<'a, D> {
-    pub fn join_multicast_group<T>(&self, addr: T) -> Result<bool, smoltcp::iface::MulticastError>
-    where
-        T: Into<smoltcp::wire::IpAddress>,
-    {
-        self.stack.join_multicast_group(addr)
-    }
-
-    pub fn leave_multicast_group<T>(&self, addr: T) -> Result<bool, smoltcp::iface::MulticastError>
-    where
-        T: Into<smoltcp::wire::IpAddress>,
-    {
-        self.stack.leave_multicast_group(addr)
-    }
-
-    pub fn has_multicast_group<T: Into<smoltcp::wire::IpAddress>>(&self, addr: T) -> bool {
-        self.stack.has_multicast_group(addr)
-    }
-}
-
-impl<D: Driver> Drop for UdpSocket<'_, D> {
+impl Drop for UdpSocket<'_> {
     fn drop(&mut self) {
-        self.stack.socket.borrow_mut().sockets.remove(self.handle);
+        self.stack.borrow_mut().sockets.remove(self.handle);
     }
 }