diff --git a/embassy-net/src/lib.rs b/embassy-net/src/lib.rs index 05bbd07f2..7b9d0e773 100644 --- a/embassy-net/src/lib.rs +++ b/embassy-net/src/lib.rs @@ -306,7 +306,7 @@ impl<D: Driver + 'static> Stack<D> { #[cfg(feature = "igmp")] impl<D: Driver + smoltcp::phy::Device + 'static> Stack<D> { - pub(crate) fn join_multicast_group<T>(&self, addr: T) -> Result<bool, smoltcp::iface::MulticastError> + pub fn join_multicast_group<T>(&self, addr: T) -> Result<bool, smoltcp::iface::MulticastError> where T: Into<IpAddress>, { @@ -318,7 +318,7 @@ impl<D: Driver + smoltcp::phy::Device + 'static> Stack<D> { }) } - pub(crate) fn leave_multicast_group<T>(&self, addr: T) -> Result<bool, smoltcp::iface::MulticastError> + pub fn leave_multicast_group<T>(&self, addr: T) -> Result<bool, smoltcp::iface::MulticastError> where T: Into<IpAddress>, { @@ -330,7 +330,7 @@ impl<D: Driver + smoltcp::phy::Device + 'static> Stack<D> { }) } - pub(crate) fn has_multicast_group<T: Into<IpAddress>>(&self, addr: T) -> bool { + pub fn has_multicast_group<T: Into<IpAddress>>(&self, addr: T) -> bool { self.socket.borrow().iface.has_multicast_group(addr) } } diff --git a/embassy-net/src/udp.rs b/embassy-net/src/udp.rs index 12bdbf402..0ee8c6e19 100644 --- a/embassy-net/src/udp.rs +++ b/embassy-net/src/udp.rs @@ -1,3 +1,4 @@ +use core::cell::RefCell; use core::future::poll_fn; use core::mem; use core::task::Poll; @@ -7,7 +8,7 @@ use smoltcp::iface::{Interface, SocketHandle}; use smoltcp::socket::udp::{self, PacketMetadata}; use smoltcp::wire::{IpEndpoint, IpListenEndpoint}; -use crate::Stack; +use crate::{SocketStack, Stack}; #[derive(PartialEq, Eq, Clone, Copy, Debug)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] @@ -25,13 +26,13 @@ pub enum Error { NoRoute, } -pub struct UdpSocket<'a, D: Driver> { - stack: &'a Stack<D>, +pub struct UdpSocket<'a> { + stack: &'a RefCell<SocketStack>, handle: SocketHandle, } -impl<'a, D: Driver> UdpSocket<'a, D> { - pub fn new( +impl<'a> UdpSocket<'a> { + pub fn new<D: Driver>( stack: &'a Stack<D>, rx_meta: &'a mut [PacketMetadata], rx_buffer: &'a mut [u8], @@ -49,7 +50,10 @@ impl<'a, D: Driver> UdpSocket<'a, D> { udp::PacketBuffer::new(tx_meta, tx_buffer), )); - Self { stack, handle } + Self { + stack: &stack.socket, + handle, + } } pub fn bind<T>(&mut self, endpoint: T) -> Result<(), BindError> @@ -60,7 +64,7 @@ impl<'a, D: Driver> UdpSocket<'a, D> { if endpoint.port == 0 { // If user didn't specify port allocate a dynamic port. - endpoint.port = self.stack.socket.borrow_mut().get_local_port(); + endpoint.port = self.stack.borrow_mut().get_local_port(); } match self.with_mut(|s, _| s.bind(endpoint)) { @@ -71,13 +75,13 @@ impl<'a, D: Driver> UdpSocket<'a, D> { } fn with<R>(&self, f: impl FnOnce(&udp::Socket, &Interface) -> R) -> R { - let s = &*self.stack.socket.borrow(); + let s = &*self.stack.borrow(); let socket = s.sockets.get::<udp::Socket>(self.handle); f(socket, &s.iface) } fn with_mut<R>(&self, f: impl FnOnce(&mut udp::Socket, &mut Interface) -> R) -> R { - let s = &mut *self.stack.socket.borrow_mut(); + let s = &mut *self.stack.borrow_mut(); let socket = s.sockets.get_mut::<udp::Socket>(self.handle); let res = f(socket, &mut s.iface); s.waker.wake(); @@ -139,29 +143,8 @@ impl<'a, D: Driver> UdpSocket<'a, D> { } } -#[cfg(feature = "igmp")] -impl<'a, D: Driver + smoltcp::phy::Device + 'static> UdpSocket<'a, D> { - pub fn join_multicast_group<T>(&self, addr: T) -> Result<bool, smoltcp::iface::MulticastError> - where - T: Into<smoltcp::wire::IpAddress>, - { - self.stack.join_multicast_group(addr) - } - - pub fn leave_multicast_group<T>(&self, addr: T) -> Result<bool, smoltcp::iface::MulticastError> - where - T: Into<smoltcp::wire::IpAddress>, - { - self.stack.leave_multicast_group(addr) - } - - pub fn has_multicast_group<T: Into<smoltcp::wire::IpAddress>>(&self, addr: T) -> bool { - self.stack.has_multicast_group(addr) - } -} - -impl<D: Driver> Drop for UdpSocket<'_, D> { +impl Drop for UdpSocket<'_> { fn drop(&mut self) { - self.stack.socket.borrow_mut().sockets.remove(self.handle); + self.stack.borrow_mut().sockets.remove(self.handle); } }