From cd440a49d677f7dfc09e405d99b87a49fba9ba31 Mon Sep 17 00:00:00 2001 From: Ulf Lilleengen Date: Fri, 10 Feb 2023 17:43:23 +0100 Subject: [PATCH] Rewrite to use a single socket --- embassy-net/Cargo.toml | 2 +- embassy-net/src/dns.rs | 124 +++++++++++++------------------- embassy-net/src/lib.rs | 83 ++++++++++++++++++++- examples/std/src/bin/net_dns.rs | 3 +- 4 files changed, 133 insertions(+), 79 deletions(-) diff --git a/embassy-net/Cargo.toml b/embassy-net/Cargo.toml index 6eea8c307..53778899b 100644 --- a/embassy-net/Cargo.toml +++ b/embassy-net/Cargo.toml @@ -52,5 +52,5 @@ generic-array = { version = "0.14.4", default-features = false } stable_deref_trait = { version = "1.2.0", default-features = false } futures = { version = "0.3.17", default-features = false, features = [ "async-await" ] } atomic-pool = "1.0" -embedded-nal-async = { version = "0.3.0", optional = true } +embedded-nal-async = { version = "0.4.0", optional = true } atomic-polyfill = { version = "1.0" } diff --git a/embassy-net/src/dns.rs b/embassy-net/src/dns.rs index e98247bfd..1815d258f 100644 --- a/embassy-net/src/dns.rs +++ b/embassy-net/src/dns.rs @@ -1,19 +1,10 @@ //! DNS socket with async support. -use core::cell::RefCell; -use core::future::poll_fn; -use core::mem; -use core::task::Poll; - -use embassy_hal_common::drop::OnDrop; -use embassy_net_driver::Driver; use heapless::Vec; -use managed::ManagedSlice; -use smoltcp::iface::{Interface, SocketHandle}; -pub use smoltcp::socket::dns::DnsQuery; -use smoltcp::socket::dns::{self, GetQueryResultError, StartQueryError, MAX_ADDRESS_COUNT}; +pub use smoltcp::socket::dns::{DnsQuery, Socket, MAX_ADDRESS_COUNT}; +pub(crate) use smoltcp::socket::dns::{GetQueryResultError, StartQueryError}; pub use smoltcp::wire::{DnsQueryType, IpAddress}; -use crate::{SocketStack, Stack}; +use crate::{Driver, Stack}; /// Errors returned by DnsSocket. #[derive(Debug, PartialEq, Eq, Clone, Copy)] @@ -46,81 +37,64 @@ impl From for Error { } /// Async socket for making DNS queries. -pub struct DnsSocket<'a> { - stack: &'a RefCell, - handle: SocketHandle, +pub struct DnsSocket<'a, D> +where + D: Driver + 'static, +{ + stack: &'a Stack, } -impl<'a> DnsSocket<'a> { +impl<'a, D> DnsSocket<'a, D> +where + D: Driver + 'static, +{ /// Create a new DNS socket using the provided stack and query storage. /// /// DNS servers are derived from the stack configuration. /// /// NOTE: If using DHCP, make sure it has reconfigured the stack to ensure the DNS servers are updated. - pub fn new(stack: &'a Stack, queries: Q) -> Self - where - D: Driver + 'static, - Q: Into>>, - { - let servers = stack - .config() - .map(|c| { - let v: Vec = c.dns_servers.iter().map(|c| IpAddress::Ipv4(*c)).collect(); - v - }) - .unwrap_or(Vec::new()); - let s = &mut *stack.socket.borrow_mut(); - let queries: ManagedSlice<'static, Option> = unsafe { mem::transmute(queries.into()) }; - - let handle = s.sockets.add(dns::Socket::new(&servers[..], queries)); - Self { - stack: &stack.socket, - handle, - } - } - - fn with_mut(&mut self, f: impl FnOnce(&mut dns::Socket, &mut Interface) -> R) -> R { - let s = &mut *self.stack.borrow_mut(); - let socket = s.sockets.get_mut::(self.handle); - let res = f(socket, &mut s.iface); - s.waker.wake(); - res + pub fn new(stack: &'a Stack) -> Self { + Self { stack } } /// Make a query for a given name and return the corresponding IP addresses. - pub async fn query(&mut self, name: &str, qtype: DnsQueryType) -> Result, Error> { - let query = match { self.with_mut(|s, i| s.start_query(i.context(), name, qtype)) } { - Ok(handle) => handle, - Err(e) => return Err(e.into()), + pub async fn query(&self, name: &str, qtype: DnsQueryType) -> Result, Error> { + self.stack.dns_query(name, qtype).await + } +} + +#[cfg(all(feature = "unstable-traits", feature = "nightly"))] +impl<'a, D> embedded_nal_async::Dns for DnsSocket<'a, D> +where + D: Driver + 'static, +{ + type Error = Error; + + async fn get_host_by_name( + &self, + host: &str, + addr_type: embedded_nal_async::AddrType, + ) -> Result { + use embedded_nal_async::{AddrType, IpAddr}; + let qtype = match addr_type { + AddrType::IPv6 => DnsQueryType::Aaaa, + _ => DnsQueryType::A, }; - - let handle = self.handle; - let drop = OnDrop::new(|| { - let s = &mut *self.stack.borrow_mut(); - let socket = s.sockets.get_mut::(handle); - socket.cancel_query(query); - s.waker.wake(); - }); - - let res = poll_fn(|cx| { - self.with_mut(|s, _| match s.get_query_result(query) { - Ok(addrs) => Poll::Ready(Ok(addrs)), - Err(GetQueryResultError::Pending) => { - s.register_query_waker(query, cx.waker()); - Poll::Pending - } - Err(e) => Poll::Ready(Err(e.into())), + let addrs = self.query(host, qtype).await?; + if let Some(first) = addrs.get(0) { + Ok(match first { + IpAddress::Ipv4(addr) => IpAddr::V4(addr.0.into()), + IpAddress::Ipv6(addr) => IpAddr::V6(addr.0.into()), }) - }) - .await; + } else { + Err(Error::Failed) + } + } - drop.defuse(); - res - } -} - -impl<'a> Drop for DnsSocket<'a> { - fn drop(&mut self) { - self.stack.borrow_mut().sockets.remove(self.handle); + async fn get_host_by_address( + &self, + _addr: embedded_nal_async::IpAddr, + ) -> Result, Self::Error> { + todo!() } } diff --git a/embassy-net/src/lib.rs b/embassy-net/src/lib.rs index ae447d063..b63aa83df 100644 --- a/embassy-net/src/lib.rs +++ b/embassy-net/src/lib.rs @@ -48,15 +48,22 @@ use crate::device::DriverAdapter; const LOCAL_PORT_MIN: u16 = 1025; const LOCAL_PORT_MAX: u16 = 65535; +const MAX_QUERIES: usize = 2; pub struct StackResources { sockets: [SocketStorage<'static>; SOCK], + #[cfg(feature = "dns")] + queries: Option<[Option; MAX_QUERIES]>, } impl StackResources { pub fn new() -> Self { + #[cfg(feature = "dns")] + const INIT: Option = None; Self { sockets: [SocketStorage::EMPTY; SOCK], + #[cfg(feature = "dns")] + queries: Some([INIT; MAX_QUERIES]), } } } @@ -109,6 +116,8 @@ struct Inner { config: Option, #[cfg(feature = "dhcpv4")] dhcp_socket: Option, + #[cfg(feature = "dns")] + dns_socket: Option, } pub(crate) struct SocketStack { @@ -153,6 +162,8 @@ impl Stack { config: None, #[cfg(feature = "dhcpv4")] dhcp_socket: None, + #[cfg(feature = "dns")] + dns_socket: None, }; let mut socket = SocketStack { sockets, @@ -161,8 +172,17 @@ impl Stack { next_local_port, }; + #[cfg(feature = "dns")] + { + if let Some(queries) = resources.queries.take() { + inner.dns_socket = Some(socket.sockets.add(dns::Socket::new(&[], queries))); + } + } + match config { - Config::Static(config) => inner.apply_config(&mut socket, config), + Config::Static(config) => { + inner.apply_config(&mut socket, config); + } #[cfg(feature = "dhcpv4")] Config::Dhcp(config) => { let mut dhcp_socket = smoltcp::socket::dhcpv4::Socket::new(); @@ -210,6 +230,59 @@ impl Stack { .await; unreachable!() } + + #[cfg(feature = "dns")] + async fn dns_query( + &self, + name: &str, + qtype: dns::DnsQueryType, + ) -> Result, dns::Error> { + let query = self.with_mut(|s, i| { + if let Some(dns_handle) = i.dns_socket { + let socket = s.sockets.get_mut::(dns_handle); + match socket.start_query(s.iface.context(), name, qtype) { + Ok(handle) => Ok(handle), + Err(e) => Err(e.into()), + } + } else { + Err(dns::Error::Failed) + } + })?; + + use embassy_hal_common::drop::OnDrop; + let drop = OnDrop::new(|| { + self.with_mut(|s, i| { + if let Some(dns_handle) = i.dns_socket { + let socket = s.sockets.get_mut::(dns_handle); + socket.cancel_query(query); + s.waker.wake(); + } + }) + }); + + let res = poll_fn(|cx| { + self.with_mut(|s, i| { + if let Some(dns_handle) = i.dns_socket { + let socket = s.sockets.get_mut::(dns_handle); + match socket.get_query_result(query) { + Ok(addrs) => Poll::Ready(Ok(addrs)), + Err(dns::GetQueryResultError::Pending) => { + socket.register_query_waker(query, cx.waker()); + Poll::Pending + } + Err(e) => Poll::Ready(Err(e.into())), + } + } else { + Poll::Ready(Err(dns::Error::Failed)) + } + }) + }) + .await; + + drop.defuse(); + + res + } } impl SocketStack { @@ -251,6 +324,13 @@ impl Inner { debug!(" DNS server {}: {}", i, s); } + #[cfg(feature = "dns")] + if let Some(dns_socket) = self.dns_socket { + let socket = s.sockets.get_mut::(dns_socket); + let servers: Vec = config.dns_servers.iter().map(|c| IpAddress::Ipv4(*c)).collect(); + socket.update_servers(&servers[..]); + } + self.config = Some(config) } @@ -326,6 +406,7 @@ impl Inner { //if old_link_up || self.link_up { // self.poll_configurator(timestamp) //} + // if let Some(poll_at) = s.iface.poll_at(timestamp, &mut s.sockets) { let t = Timer::at(instant_from_smoltcp(poll_at)); diff --git a/examples/std/src/bin/net_dns.rs b/examples/std/src/bin/net_dns.rs index 6203f8370..e787cb823 100644 --- a/examples/std/src/bin/net_dns.rs +++ b/examples/std/src/bin/net_dns.rs @@ -71,8 +71,7 @@ async fn main_task(spawner: Spawner) { spawner.spawn(net_task(stack)).unwrap(); // Then we can use it! - - let mut socket = DnsSocket::new(stack, vec![]); + let socket = DnsSocket::new(stack); let host = "example.com"; info!("querying host {:?}...", host);