From c22218c72e4940a0aef3fc180ba18b557713cf40 Mon Sep 17 00:00:00 2001
From: Leon Camus <leon.c@gmx.de>
Date: Mon, 6 Mar 2023 17:50:57 +0100
Subject: [PATCH] feat: Add multicast to udp socket

---
 embassy-net/Cargo.toml |  3 ++-
 embassy-net/src/lib.rs | 37 +++++++++++++++++++++++++++++++++
 embassy-net/src/udp.rs | 47 ++++++++++++++++++++++++++++++------------
 3 files changed, 73 insertions(+), 14 deletions(-)

diff --git a/embassy-net/Cargo.toml b/embassy-net/Cargo.toml
index ca34262df..1854d2043 100644
--- a/embassy-net/Cargo.toml
+++ b/embassy-net/Cargo.toml
@@ -8,7 +8,7 @@ license = "MIT OR Apache-2.0"
 [package.metadata.embassy_docs]
 src_base = "https://github.com/embassy-rs/embassy/blob/embassy-net-v$VERSION/embassy-net/src/"
 src_base_git = "https://github.com/embassy-rs/embassy/blob/$COMMIT/embassy-net/src/"
-features = ["nightly", "unstable-traits", "defmt", "tcp", "udp", "dns", "dhcpv4", "proto-ipv6", "medium-ethernet", "medium-ip"]
+features = ["nightly", "unstable-traits", "defmt", "tcp", "udp", "dns", "dhcpv4", "proto-ipv6", "medium-ethernet", "medium-ip", "igmp"]
 target = "thumbv7em-none-eabi"
 
 [features]
@@ -27,6 +27,7 @@ dhcpv4 = ["medium-ethernet", "smoltcp/socket-dhcpv4"]
 proto-ipv6 = ["smoltcp/proto-ipv6"]
 medium-ethernet = ["smoltcp/medium-ethernet"]
 medium-ip = ["smoltcp/medium-ip"]
+igmp = ["smoltcp/proto-igmp"]
 
 [dependencies]
 
diff --git a/embassy-net/src/lib.rs b/embassy-net/src/lib.rs
index 4ec1b5a77..57055bd77 100644
--- a/embassy-net/src/lib.rs
+++ b/embassy-net/src/lib.rs
@@ -304,6 +304,43 @@ impl<D: Driver + 'static> Stack<D> {
     }
 }
 
+#[cfg(feature = "igmp")]
+impl<D: Driver + smoltcp::phy::Device + 'static> Stack<D> {
+    pub(crate) fn join_multicast_group<T>(&self, addr: T) -> Result<bool, smoltcp::iface::MulticastError>
+        where
+            T: Into<IpAddress>
+    {
+        let addr = addr.into();
+
+        self.with_mut(|s, i| {
+            s.iface.join_multicast_group(
+                &mut i.device,
+                addr,
+                instant_to_smoltcp(Instant::now()),
+            )
+        })
+    }
+
+    pub(crate) fn leave_multicast_group<T>(&self, addr: T) -> Result<bool, smoltcp::iface::MulticastError>
+        where
+            T: Into<IpAddress>
+    {
+        let addr = addr.into();
+
+        self.with_mut(|s, i| {
+            s.iface.leave_multicast_group(
+                &mut i.device,
+                addr,
+                instant_to_smoltcp(Instant::now()),
+            )
+        })
+    }
+
+    pub(crate) fn has_multicast_group<T: Into<IpAddress>>(&self, addr: T) -> bool {
+        self.socket.borrow().iface.has_multicast_group(addr)
+    }
+}
+
 impl SocketStack {
     #[allow(clippy::absurd_extreme_comparisons, dead_code)]
     pub fn get_local_port(&mut self) -> u16 {
diff --git a/embassy-net/src/udp.rs b/embassy-net/src/udp.rs
index 0ee8c6e19..c840eeaa2 100644
--- a/embassy-net/src/udp.rs
+++ b/embassy-net/src/udp.rs
@@ -6,9 +6,9 @@ use core::task::Poll;
 use embassy_net_driver::Driver;
 use smoltcp::iface::{Interface, SocketHandle};
 use smoltcp::socket::udp::{self, PacketMetadata};
-use smoltcp::wire::{IpEndpoint, IpListenEndpoint};
+use smoltcp::wire::{IpAddress, IpEndpoint, IpListenEndpoint};
 
-use crate::{SocketStack, Stack};
+use crate::Stack;
 
 #[derive(PartialEq, Eq, Clone, Copy, Debug)]
 #[cfg_attr(feature = "defmt", derive(defmt::Format))]
@@ -26,13 +26,13 @@ pub enum Error {
     NoRoute,
 }
 
-pub struct UdpSocket<'a> {
-    stack: &'a RefCell<SocketStack>,
+pub struct UdpSocket<'a, D: Driver> {
+    stack: &'a Stack<D>,
     handle: SocketHandle,
 }
 
-impl<'a> UdpSocket<'a> {
-    pub fn new<D: Driver>(
+impl<'a, D: Driver> UdpSocket<'a, D> {
+    pub fn new(
         stack: &'a Stack<D>,
         rx_meta: &'a mut [PacketMetadata],
         rx_buffer: &'a mut [u8],
@@ -51,7 +51,7 @@ impl<'a> UdpSocket<'a> {
         ));
 
         Self {
-            stack: &stack.socket,
+            stack,
             handle,
         }
     }
@@ -64,7 +64,7 @@ impl<'a> UdpSocket<'a> {
 
         if endpoint.port == 0 {
             // If user didn't specify port allocate a dynamic port.
-            endpoint.port = self.stack.borrow_mut().get_local_port();
+            endpoint.port = self.stack.socket.borrow_mut().get_local_port();
         }
 
         match self.with_mut(|s, _| s.bind(endpoint)) {
@@ -75,13 +75,13 @@ impl<'a> UdpSocket<'a> {
     }
 
     fn with<R>(&self, f: impl FnOnce(&udp::Socket, &Interface) -> R) -> R {
-        let s = &*self.stack.borrow();
+        let s = &*self.stack.socket.borrow();
         let socket = s.sockets.get::<udp::Socket>(self.handle);
         f(socket, &s.iface)
     }
 
     fn with_mut<R>(&self, f: impl FnOnce(&mut udp::Socket, &mut Interface) -> R) -> R {
-        let s = &mut *self.stack.borrow_mut();
+        let s = &mut *self.stack.socket.borrow_mut();
         let socket = s.sockets.get_mut::<udp::Socket>(self.handle);
         let res = f(socket, &mut s.iface);
         s.waker.wake();
@@ -143,8 +143,29 @@ impl<'a> UdpSocket<'a> {
     }
 }
 
-impl Drop for UdpSocket<'_> {
-    fn drop(&mut self) {
-        self.stack.borrow_mut().sockets.remove(self.handle);
+#[cfg(feature = "igmp")]
+impl<'a, D: Driver + smoltcp::phy::Device + 'static> UdpSocket<'a, D> {
+    pub fn join_multicast_group<T>(&self, addr: T) -> Result<bool, smoltcp::iface::MulticastError>
+        where
+            T: Into<IpAddress>
+    {
+        self.stack.join_multicast_group(addr)
+    }
+
+    pub fn leave_multicast_group<T>(&self, addr: T) -> Result<bool, smoltcp::iface::MulticastError>
+        where
+            T: Into<IpAddress>
+    {
+        self.stack.leave_multicast_group(addr)
+    }
+
+    pub fn has_multicast_group<T: Into<IpAddress>>(&self, addr: T) -> bool {
+        self.stack.has_multicast_group(addr)
+    }
+}
+
+impl<D: Driver> Drop for UdpSocket<'_, D> {
+    fn drop(&mut self) {
+        self.stack.socket.borrow_mut().sockets.remove(self.handle);
     }
 }