From cd440a49d677f7dfc09e405d99b87a49fba9ba31 Mon Sep 17 00:00:00 2001
From: Ulf Lilleengen <lulf@redhat.com>
Date: Fri, 10 Feb 2023 17:43:23 +0100
Subject: [PATCH] Rewrite to use a single socket

---
 embassy-net/Cargo.toml          |   2 +-
 embassy-net/src/dns.rs          | 124 +++++++++++++-------------------
 embassy-net/src/lib.rs          |  83 ++++++++++++++++++++-
 examples/std/src/bin/net_dns.rs |   3 +-
 4 files changed, 133 insertions(+), 79 deletions(-)

diff --git a/embassy-net/Cargo.toml b/embassy-net/Cargo.toml
index 6eea8c307..53778899b 100644
--- a/embassy-net/Cargo.toml
+++ b/embassy-net/Cargo.toml
@@ -52,5 +52,5 @@ generic-array = { version = "0.14.4", default-features = false }
 stable_deref_trait = { version = "1.2.0", default-features = false }
 futures = { version = "0.3.17", default-features = false, features = [ "async-await" ] }
 atomic-pool = "1.0"
-embedded-nal-async = { version = "0.3.0", optional = true }
+embedded-nal-async = { version = "0.4.0", optional = true }
 atomic-polyfill = { version = "1.0" }
diff --git a/embassy-net/src/dns.rs b/embassy-net/src/dns.rs
index e98247bfd..1815d258f 100644
--- a/embassy-net/src/dns.rs
+++ b/embassy-net/src/dns.rs
@@ -1,19 +1,10 @@
 //! DNS socket with async support.
-use core::cell::RefCell;
-use core::future::poll_fn;
-use core::mem;
-use core::task::Poll;
-
-use embassy_hal_common::drop::OnDrop;
-use embassy_net_driver::Driver;
 use heapless::Vec;
-use managed::ManagedSlice;
-use smoltcp::iface::{Interface, SocketHandle};
-pub use smoltcp::socket::dns::DnsQuery;
-use smoltcp::socket::dns::{self, GetQueryResultError, StartQueryError, MAX_ADDRESS_COUNT};
+pub use smoltcp::socket::dns::{DnsQuery, Socket, MAX_ADDRESS_COUNT};
+pub(crate) use smoltcp::socket::dns::{GetQueryResultError, StartQueryError};
 pub use smoltcp::wire::{DnsQueryType, IpAddress};
 
-use crate::{SocketStack, Stack};
+use crate::{Driver, Stack};
 
 /// Errors returned by DnsSocket.
 #[derive(Debug, PartialEq, Eq, Clone, Copy)]
@@ -46,81 +37,64 @@ impl From<StartQueryError> for Error {
 }
 
 /// Async socket for making DNS queries.
-pub struct DnsSocket<'a> {
-    stack: &'a RefCell<SocketStack>,
-    handle: SocketHandle,
+pub struct DnsSocket<'a, D>
+where
+    D: Driver + 'static,
+{
+    stack: &'a Stack<D>,
 }
 
-impl<'a> DnsSocket<'a> {
+impl<'a, D> DnsSocket<'a, D>
+where
+    D: Driver + 'static,
+{
     /// Create a new DNS socket using the provided stack and query storage.
     ///
     /// DNS servers are derived from the stack configuration.
     ///
     /// NOTE: If using DHCP, make sure it has reconfigured the stack to ensure the DNS servers are updated.
-    pub fn new<D, Q>(stack: &'a Stack<D>, queries: Q) -> Self
-    where
-        D: Driver + 'static,
-        Q: Into<ManagedSlice<'a, Option<DnsQuery>>>,
-    {
-        let servers = stack
-            .config()
-            .map(|c| {
-                let v: Vec<IpAddress, 3> = c.dns_servers.iter().map(|c| IpAddress::Ipv4(*c)).collect();
-                v
-            })
-            .unwrap_or(Vec::new());
-        let s = &mut *stack.socket.borrow_mut();
-        let queries: ManagedSlice<'static, Option<DnsQuery>> = unsafe { mem::transmute(queries.into()) };
-
-        let handle = s.sockets.add(dns::Socket::new(&servers[..], queries));
-        Self {
-            stack: &stack.socket,
-            handle,
-        }
-    }
-
-    fn with_mut<R>(&mut self, f: impl FnOnce(&mut dns::Socket, &mut Interface) -> R) -> R {
-        let s = &mut *self.stack.borrow_mut();
-        let socket = s.sockets.get_mut::<dns::Socket>(self.handle);
-        let res = f(socket, &mut s.iface);
-        s.waker.wake();
-        res
+    pub fn new(stack: &'a Stack<D>) -> Self {
+        Self { stack }
     }
 
     /// Make a query for a given name and return the corresponding IP addresses.
-    pub async fn query(&mut self, name: &str, qtype: DnsQueryType) -> Result<Vec<IpAddress, MAX_ADDRESS_COUNT>, Error> {
-        let query = match { self.with_mut(|s, i| s.start_query(i.context(), name, qtype)) } {
-            Ok(handle) => handle,
-            Err(e) => return Err(e.into()),
+    pub async fn query(&self, name: &str, qtype: DnsQueryType) -> Result<Vec<IpAddress, MAX_ADDRESS_COUNT>, Error> {
+        self.stack.dns_query(name, qtype).await
+    }
+}
+
+#[cfg(all(feature = "unstable-traits", feature = "nightly"))]
+impl<'a, D> embedded_nal_async::Dns for DnsSocket<'a, D>
+where
+    D: Driver + 'static,
+{
+    type Error = Error;
+
+    async fn get_host_by_name(
+        &self,
+        host: &str,
+        addr_type: embedded_nal_async::AddrType,
+    ) -> Result<embedded_nal_async::IpAddr, Self::Error> {
+        use embedded_nal_async::{AddrType, IpAddr};
+        let qtype = match addr_type {
+            AddrType::IPv6 => DnsQueryType::Aaaa,
+            _ => DnsQueryType::A,
         };
-
-        let handle = self.handle;
-        let drop = OnDrop::new(|| {
-            let s = &mut *self.stack.borrow_mut();
-            let socket = s.sockets.get_mut::<dns::Socket>(handle);
-            socket.cancel_query(query);
-            s.waker.wake();
-        });
-
-        let res = poll_fn(|cx| {
-            self.with_mut(|s, _| match s.get_query_result(query) {
-                Ok(addrs) => Poll::Ready(Ok(addrs)),
-                Err(GetQueryResultError::Pending) => {
-                    s.register_query_waker(query, cx.waker());
-                    Poll::Pending
-                }
-                Err(e) => Poll::Ready(Err(e.into())),
+        let addrs = self.query(host, qtype).await?;
+        if let Some(first) = addrs.get(0) {
+            Ok(match first {
+                IpAddress::Ipv4(addr) => IpAddr::V4(addr.0.into()),
+                IpAddress::Ipv6(addr) => IpAddr::V6(addr.0.into()),
             })
-        })
-        .await;
+        } else {
+            Err(Error::Failed)
+        }
+    }
 
-        drop.defuse();
-        res
-    }
-}
-
-impl<'a> Drop for DnsSocket<'a> {
-    fn drop(&mut self) {
-        self.stack.borrow_mut().sockets.remove(self.handle);
+    async fn get_host_by_address(
+        &self,
+        _addr: embedded_nal_async::IpAddr,
+    ) -> Result<heapless::String<256>, Self::Error> {
+        todo!()
     }
 }
diff --git a/embassy-net/src/lib.rs b/embassy-net/src/lib.rs
index ae447d063..b63aa83df 100644
--- a/embassy-net/src/lib.rs
+++ b/embassy-net/src/lib.rs
@@ -48,15 +48,22 @@ use crate::device::DriverAdapter;
 
 const LOCAL_PORT_MIN: u16 = 1025;
 const LOCAL_PORT_MAX: u16 = 65535;
+const MAX_QUERIES: usize = 2;
 
 pub struct StackResources<const SOCK: usize> {
     sockets: [SocketStorage<'static>; SOCK],
+    #[cfg(feature = "dns")]
+    queries: Option<[Option<dns::DnsQuery>; MAX_QUERIES]>,
 }
 
 impl<const SOCK: usize> StackResources<SOCK> {
     pub fn new() -> Self {
+        #[cfg(feature = "dns")]
+        const INIT: Option<dns::DnsQuery> = None;
         Self {
             sockets: [SocketStorage::EMPTY; SOCK],
+            #[cfg(feature = "dns")]
+            queries: Some([INIT; MAX_QUERIES]),
         }
     }
 }
@@ -109,6 +116,8 @@ struct Inner<D: Driver> {
     config: Option<StaticConfig>,
     #[cfg(feature = "dhcpv4")]
     dhcp_socket: Option<SocketHandle>,
+    #[cfg(feature = "dns")]
+    dns_socket: Option<SocketHandle>,
 }
 
 pub(crate) struct SocketStack {
@@ -153,6 +162,8 @@ impl<D: Driver + 'static> Stack<D> {
             config: None,
             #[cfg(feature = "dhcpv4")]
             dhcp_socket: None,
+            #[cfg(feature = "dns")]
+            dns_socket: None,
         };
         let mut socket = SocketStack {
             sockets,
@@ -161,8 +172,17 @@ impl<D: Driver + 'static> Stack<D> {
             next_local_port,
         };
 
+        #[cfg(feature = "dns")]
+        {
+            if let Some(queries) = resources.queries.take() {
+                inner.dns_socket = Some(socket.sockets.add(dns::Socket::new(&[], queries)));
+            }
+        }
+
         match config {
-            Config::Static(config) => inner.apply_config(&mut socket, config),
+            Config::Static(config) => {
+                inner.apply_config(&mut socket, config);
+            }
             #[cfg(feature = "dhcpv4")]
             Config::Dhcp(config) => {
                 let mut dhcp_socket = smoltcp::socket::dhcpv4::Socket::new();
@@ -210,6 +230,59 @@ impl<D: Driver + 'static> Stack<D> {
         .await;
         unreachable!()
     }
+
+    #[cfg(feature = "dns")]
+    async fn dns_query(
+        &self,
+        name: &str,
+        qtype: dns::DnsQueryType,
+    ) -> Result<Vec<IpAddress, { dns::MAX_ADDRESS_COUNT }>, dns::Error> {
+        let query = self.with_mut(|s, i| {
+            if let Some(dns_handle) = i.dns_socket {
+                let socket = s.sockets.get_mut::<dns::Socket>(dns_handle);
+                match socket.start_query(s.iface.context(), name, qtype) {
+                    Ok(handle) => Ok(handle),
+                    Err(e) => Err(e.into()),
+                }
+            } else {
+                Err(dns::Error::Failed)
+            }
+        })?;
+
+        use embassy_hal_common::drop::OnDrop;
+        let drop = OnDrop::new(|| {
+            self.with_mut(|s, i| {
+                if let Some(dns_handle) = i.dns_socket {
+                    let socket = s.sockets.get_mut::<dns::Socket>(dns_handle);
+                    socket.cancel_query(query);
+                    s.waker.wake();
+                }
+            })
+        });
+
+        let res = poll_fn(|cx| {
+            self.with_mut(|s, i| {
+                if let Some(dns_handle) = i.dns_socket {
+                    let socket = s.sockets.get_mut::<dns::Socket>(dns_handle);
+                    match socket.get_query_result(query) {
+                        Ok(addrs) => Poll::Ready(Ok(addrs)),
+                        Err(dns::GetQueryResultError::Pending) => {
+                            socket.register_query_waker(query, cx.waker());
+                            Poll::Pending
+                        }
+                        Err(e) => Poll::Ready(Err(e.into())),
+                    }
+                } else {
+                    Poll::Ready(Err(dns::Error::Failed))
+                }
+            })
+        })
+        .await;
+
+        drop.defuse();
+
+        res
+    }
 }
 
 impl SocketStack {
@@ -251,6 +324,13 @@ impl<D: Driver + 'static> Inner<D> {
             debug!("   DNS server {}:    {}", i, s);
         }
 
+        #[cfg(feature = "dns")]
+        if let Some(dns_socket) = self.dns_socket {
+            let socket = s.sockets.get_mut::<smoltcp::socket::dns::Socket>(dns_socket);
+            let servers: Vec<IpAddress, 3> = config.dns_servers.iter().map(|c| IpAddress::Ipv4(*c)).collect();
+            socket.update_servers(&servers[..]);
+        }
+
         self.config = Some(config)
     }
 
@@ -326,6 +406,7 @@ impl<D: Driver + 'static> Inner<D> {
         //if old_link_up || self.link_up {
         //    self.poll_configurator(timestamp)
         //}
+        //
 
         if let Some(poll_at) = s.iface.poll_at(timestamp, &mut s.sockets) {
             let t = Timer::at(instant_from_smoltcp(poll_at));
diff --git a/examples/std/src/bin/net_dns.rs b/examples/std/src/bin/net_dns.rs
index 6203f8370..e787cb823 100644
--- a/examples/std/src/bin/net_dns.rs
+++ b/examples/std/src/bin/net_dns.rs
@@ -71,8 +71,7 @@ async fn main_task(spawner: Spawner) {
     spawner.spawn(net_task(stack)).unwrap();
 
     // Then we can use it!
-
-    let mut socket = DnsSocket::new(stack, vec![]);
+    let socket = DnsSocket::new(stack);
 
     let host = "example.com";
     info!("querying host {:?}...", host);