Rewrite to use a single socket
This commit is contained in:
parent
614740a1b2
commit
cd440a49d6
4 changed files with 133 additions and 79 deletions
|
@ -52,5 +52,5 @@ generic-array = { version = "0.14.4", default-features = false }
|
||||||
stable_deref_trait = { version = "1.2.0", default-features = false }
|
stable_deref_trait = { version = "1.2.0", default-features = false }
|
||||||
futures = { version = "0.3.17", default-features = false, features = [ "async-await" ] }
|
futures = { version = "0.3.17", default-features = false, features = [ "async-await" ] }
|
||||||
atomic-pool = "1.0"
|
atomic-pool = "1.0"
|
||||||
embedded-nal-async = { version = "0.3.0", optional = true }
|
embedded-nal-async = { version = "0.4.0", optional = true }
|
||||||
atomic-polyfill = { version = "1.0" }
|
atomic-polyfill = { version = "1.0" }
|
||||||
|
|
|
@ -1,19 +1,10 @@
|
||||||
//! DNS socket with async support.
|
//! DNS socket with async support.
|
||||||
use core::cell::RefCell;
|
|
||||||
use core::future::poll_fn;
|
|
||||||
use core::mem;
|
|
||||||
use core::task::Poll;
|
|
||||||
|
|
||||||
use embassy_hal_common::drop::OnDrop;
|
|
||||||
use embassy_net_driver::Driver;
|
|
||||||
use heapless::Vec;
|
use heapless::Vec;
|
||||||
use managed::ManagedSlice;
|
pub use smoltcp::socket::dns::{DnsQuery, Socket, MAX_ADDRESS_COUNT};
|
||||||
use smoltcp::iface::{Interface, SocketHandle};
|
pub(crate) use smoltcp::socket::dns::{GetQueryResultError, StartQueryError};
|
||||||
pub use smoltcp::socket::dns::DnsQuery;
|
|
||||||
use smoltcp::socket::dns::{self, GetQueryResultError, StartQueryError, MAX_ADDRESS_COUNT};
|
|
||||||
pub use smoltcp::wire::{DnsQueryType, IpAddress};
|
pub use smoltcp::wire::{DnsQueryType, IpAddress};
|
||||||
|
|
||||||
use crate::{SocketStack, Stack};
|
use crate::{Driver, Stack};
|
||||||
|
|
||||||
/// Errors returned by DnsSocket.
|
/// Errors returned by DnsSocket.
|
||||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||||
|
@ -46,81 +37,64 @@ impl From<StartQueryError> for Error {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Async socket for making DNS queries.
|
/// Async socket for making DNS queries.
|
||||||
pub struct DnsSocket<'a> {
|
pub struct DnsSocket<'a, D>
|
||||||
stack: &'a RefCell<SocketStack>,
|
where
|
||||||
handle: SocketHandle,
|
D: Driver + 'static,
|
||||||
|
{
|
||||||
|
stack: &'a Stack<D>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> DnsSocket<'a> {
|
impl<'a, D> DnsSocket<'a, D>
|
||||||
|
where
|
||||||
|
D: Driver + 'static,
|
||||||
|
{
|
||||||
/// Create a new DNS socket using the provided stack and query storage.
|
/// Create a new DNS socket using the provided stack and query storage.
|
||||||
///
|
///
|
||||||
/// DNS servers are derived from the stack configuration.
|
/// DNS servers are derived from the stack configuration.
|
||||||
///
|
///
|
||||||
/// NOTE: If using DHCP, make sure it has reconfigured the stack to ensure the DNS servers are updated.
|
/// NOTE: If using DHCP, make sure it has reconfigured the stack to ensure the DNS servers are updated.
|
||||||
pub fn new<D, Q>(stack: &'a Stack<D>, queries: Q) -> Self
|
pub fn new(stack: &'a Stack<D>) -> Self {
|
||||||
where
|
Self { stack }
|
||||||
D: Driver + 'static,
|
|
||||||
Q: Into<ManagedSlice<'a, Option<DnsQuery>>>,
|
|
||||||
{
|
|
||||||
let servers = stack
|
|
||||||
.config()
|
|
||||||
.map(|c| {
|
|
||||||
let v: Vec<IpAddress, 3> = c.dns_servers.iter().map(|c| IpAddress::Ipv4(*c)).collect();
|
|
||||||
v
|
|
||||||
})
|
|
||||||
.unwrap_or(Vec::new());
|
|
||||||
let s = &mut *stack.socket.borrow_mut();
|
|
||||||
let queries: ManagedSlice<'static, Option<DnsQuery>> = unsafe { mem::transmute(queries.into()) };
|
|
||||||
|
|
||||||
let handle = s.sockets.add(dns::Socket::new(&servers[..], queries));
|
|
||||||
Self {
|
|
||||||
stack: &stack.socket,
|
|
||||||
handle,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn with_mut<R>(&mut self, f: impl FnOnce(&mut dns::Socket, &mut Interface) -> R) -> R {
|
|
||||||
let s = &mut *self.stack.borrow_mut();
|
|
||||||
let socket = s.sockets.get_mut::<dns::Socket>(self.handle);
|
|
||||||
let res = f(socket, &mut s.iface);
|
|
||||||
s.waker.wake();
|
|
||||||
res
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Make a query for a given name and return the corresponding IP addresses.
|
/// Make a query for a given name and return the corresponding IP addresses.
|
||||||
pub async fn query(&mut self, name: &str, qtype: DnsQueryType) -> Result<Vec<IpAddress, MAX_ADDRESS_COUNT>, Error> {
|
pub async fn query(&self, name: &str, qtype: DnsQueryType) -> Result<Vec<IpAddress, MAX_ADDRESS_COUNT>, Error> {
|
||||||
let query = match { self.with_mut(|s, i| s.start_query(i.context(), name, qtype)) } {
|
self.stack.dns_query(name, qtype).await
|
||||||
Ok(handle) => handle,
|
}
|
||||||
Err(e) => return Err(e.into()),
|
}
|
||||||
|
|
||||||
|
#[cfg(all(feature = "unstable-traits", feature = "nightly"))]
|
||||||
|
impl<'a, D> embedded_nal_async::Dns for DnsSocket<'a, D>
|
||||||
|
where
|
||||||
|
D: Driver + 'static,
|
||||||
|
{
|
||||||
|
type Error = Error;
|
||||||
|
|
||||||
|
async fn get_host_by_name(
|
||||||
|
&self,
|
||||||
|
host: &str,
|
||||||
|
addr_type: embedded_nal_async::AddrType,
|
||||||
|
) -> Result<embedded_nal_async::IpAddr, Self::Error> {
|
||||||
|
use embedded_nal_async::{AddrType, IpAddr};
|
||||||
|
let qtype = match addr_type {
|
||||||
|
AddrType::IPv6 => DnsQueryType::Aaaa,
|
||||||
|
_ => DnsQueryType::A,
|
||||||
};
|
};
|
||||||
|
let addrs = self.query(host, qtype).await?;
|
||||||
let handle = self.handle;
|
if let Some(first) = addrs.get(0) {
|
||||||
let drop = OnDrop::new(|| {
|
Ok(match first {
|
||||||
let s = &mut *self.stack.borrow_mut();
|
IpAddress::Ipv4(addr) => IpAddr::V4(addr.0.into()),
|
||||||
let socket = s.sockets.get_mut::<dns::Socket>(handle);
|
IpAddress::Ipv6(addr) => IpAddr::V6(addr.0.into()),
|
||||||
socket.cancel_query(query);
|
|
||||||
s.waker.wake();
|
|
||||||
});
|
|
||||||
|
|
||||||
let res = poll_fn(|cx| {
|
|
||||||
self.with_mut(|s, _| match s.get_query_result(query) {
|
|
||||||
Ok(addrs) => Poll::Ready(Ok(addrs)),
|
|
||||||
Err(GetQueryResultError::Pending) => {
|
|
||||||
s.register_query_waker(query, cx.waker());
|
|
||||||
Poll::Pending
|
|
||||||
}
|
|
||||||
Err(e) => Poll::Ready(Err(e.into())),
|
|
||||||
})
|
})
|
||||||
})
|
} else {
|
||||||
.await;
|
Err(Error::Failed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
drop.defuse();
|
async fn get_host_by_address(
|
||||||
res
|
&self,
|
||||||
}
|
_addr: embedded_nal_async::IpAddr,
|
||||||
}
|
) -> Result<heapless::String<256>, Self::Error> {
|
||||||
|
todo!()
|
||||||
impl<'a> Drop for DnsSocket<'a> {
|
|
||||||
fn drop(&mut self) {
|
|
||||||
self.stack.borrow_mut().sockets.remove(self.handle);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -48,15 +48,22 @@ use crate::device::DriverAdapter;
|
||||||
|
|
||||||
const LOCAL_PORT_MIN: u16 = 1025;
|
const LOCAL_PORT_MIN: u16 = 1025;
|
||||||
const LOCAL_PORT_MAX: u16 = 65535;
|
const LOCAL_PORT_MAX: u16 = 65535;
|
||||||
|
const MAX_QUERIES: usize = 2;
|
||||||
|
|
||||||
pub struct StackResources<const SOCK: usize> {
|
pub struct StackResources<const SOCK: usize> {
|
||||||
sockets: [SocketStorage<'static>; SOCK],
|
sockets: [SocketStorage<'static>; SOCK],
|
||||||
|
#[cfg(feature = "dns")]
|
||||||
|
queries: Option<[Option<dns::DnsQuery>; MAX_QUERIES]>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<const SOCK: usize> StackResources<SOCK> {
|
impl<const SOCK: usize> StackResources<SOCK> {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
|
#[cfg(feature = "dns")]
|
||||||
|
const INIT: Option<dns::DnsQuery> = None;
|
||||||
Self {
|
Self {
|
||||||
sockets: [SocketStorage::EMPTY; SOCK],
|
sockets: [SocketStorage::EMPTY; SOCK],
|
||||||
|
#[cfg(feature = "dns")]
|
||||||
|
queries: Some([INIT; MAX_QUERIES]),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -109,6 +116,8 @@ struct Inner<D: Driver> {
|
||||||
config: Option<StaticConfig>,
|
config: Option<StaticConfig>,
|
||||||
#[cfg(feature = "dhcpv4")]
|
#[cfg(feature = "dhcpv4")]
|
||||||
dhcp_socket: Option<SocketHandle>,
|
dhcp_socket: Option<SocketHandle>,
|
||||||
|
#[cfg(feature = "dns")]
|
||||||
|
dns_socket: Option<SocketHandle>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) struct SocketStack {
|
pub(crate) struct SocketStack {
|
||||||
|
@ -153,6 +162,8 @@ impl<D: Driver + 'static> Stack<D> {
|
||||||
config: None,
|
config: None,
|
||||||
#[cfg(feature = "dhcpv4")]
|
#[cfg(feature = "dhcpv4")]
|
||||||
dhcp_socket: None,
|
dhcp_socket: None,
|
||||||
|
#[cfg(feature = "dns")]
|
||||||
|
dns_socket: None,
|
||||||
};
|
};
|
||||||
let mut socket = SocketStack {
|
let mut socket = SocketStack {
|
||||||
sockets,
|
sockets,
|
||||||
|
@ -161,8 +172,17 @@ impl<D: Driver + 'static> Stack<D> {
|
||||||
next_local_port,
|
next_local_port,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#[cfg(feature = "dns")]
|
||||||
|
{
|
||||||
|
if let Some(queries) = resources.queries.take() {
|
||||||
|
inner.dns_socket = Some(socket.sockets.add(dns::Socket::new(&[], queries)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
match config {
|
match config {
|
||||||
Config::Static(config) => inner.apply_config(&mut socket, config),
|
Config::Static(config) => {
|
||||||
|
inner.apply_config(&mut socket, config);
|
||||||
|
}
|
||||||
#[cfg(feature = "dhcpv4")]
|
#[cfg(feature = "dhcpv4")]
|
||||||
Config::Dhcp(config) => {
|
Config::Dhcp(config) => {
|
||||||
let mut dhcp_socket = smoltcp::socket::dhcpv4::Socket::new();
|
let mut dhcp_socket = smoltcp::socket::dhcpv4::Socket::new();
|
||||||
|
@ -210,6 +230,59 @@ impl<D: Driver + 'static> Stack<D> {
|
||||||
.await;
|
.await;
|
||||||
unreachable!()
|
unreachable!()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "dns")]
|
||||||
|
async fn dns_query(
|
||||||
|
&self,
|
||||||
|
name: &str,
|
||||||
|
qtype: dns::DnsQueryType,
|
||||||
|
) -> Result<Vec<IpAddress, { dns::MAX_ADDRESS_COUNT }>, dns::Error> {
|
||||||
|
let query = self.with_mut(|s, i| {
|
||||||
|
if let Some(dns_handle) = i.dns_socket {
|
||||||
|
let socket = s.sockets.get_mut::<dns::Socket>(dns_handle);
|
||||||
|
match socket.start_query(s.iface.context(), name, qtype) {
|
||||||
|
Ok(handle) => Ok(handle),
|
||||||
|
Err(e) => Err(e.into()),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Err(dns::Error::Failed)
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
|
||||||
|
use embassy_hal_common::drop::OnDrop;
|
||||||
|
let drop = OnDrop::new(|| {
|
||||||
|
self.with_mut(|s, i| {
|
||||||
|
if let Some(dns_handle) = i.dns_socket {
|
||||||
|
let socket = s.sockets.get_mut::<dns::Socket>(dns_handle);
|
||||||
|
socket.cancel_query(query);
|
||||||
|
s.waker.wake();
|
||||||
|
}
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
let res = poll_fn(|cx| {
|
||||||
|
self.with_mut(|s, i| {
|
||||||
|
if let Some(dns_handle) = i.dns_socket {
|
||||||
|
let socket = s.sockets.get_mut::<dns::Socket>(dns_handle);
|
||||||
|
match socket.get_query_result(query) {
|
||||||
|
Ok(addrs) => Poll::Ready(Ok(addrs)),
|
||||||
|
Err(dns::GetQueryResultError::Pending) => {
|
||||||
|
socket.register_query_waker(query, cx.waker());
|
||||||
|
Poll::Pending
|
||||||
|
}
|
||||||
|
Err(e) => Poll::Ready(Err(e.into())),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Poll::Ready(Err(dns::Error::Failed))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
drop.defuse();
|
||||||
|
|
||||||
|
res
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SocketStack {
|
impl SocketStack {
|
||||||
|
@ -251,6 +324,13 @@ impl<D: Driver + 'static> Inner<D> {
|
||||||
debug!(" DNS server {}: {}", i, s);
|
debug!(" DNS server {}: {}", i, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "dns")]
|
||||||
|
if let Some(dns_socket) = self.dns_socket {
|
||||||
|
let socket = s.sockets.get_mut::<smoltcp::socket::dns::Socket>(dns_socket);
|
||||||
|
let servers: Vec<IpAddress, 3> = config.dns_servers.iter().map(|c| IpAddress::Ipv4(*c)).collect();
|
||||||
|
socket.update_servers(&servers[..]);
|
||||||
|
}
|
||||||
|
|
||||||
self.config = Some(config)
|
self.config = Some(config)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -326,6 +406,7 @@ impl<D: Driver + 'static> Inner<D> {
|
||||||
//if old_link_up || self.link_up {
|
//if old_link_up || self.link_up {
|
||||||
// self.poll_configurator(timestamp)
|
// self.poll_configurator(timestamp)
|
||||||
//}
|
//}
|
||||||
|
//
|
||||||
|
|
||||||
if let Some(poll_at) = s.iface.poll_at(timestamp, &mut s.sockets) {
|
if let Some(poll_at) = s.iface.poll_at(timestamp, &mut s.sockets) {
|
||||||
let t = Timer::at(instant_from_smoltcp(poll_at));
|
let t = Timer::at(instant_from_smoltcp(poll_at));
|
||||||
|
|
|
@ -71,8 +71,7 @@ async fn main_task(spawner: Spawner) {
|
||||||
spawner.spawn(net_task(stack)).unwrap();
|
spawner.spawn(net_task(stack)).unwrap();
|
||||||
|
|
||||||
// Then we can use it!
|
// Then we can use it!
|
||||||
|
let socket = DnsSocket::new(stack);
|
||||||
let mut socket = DnsSocket::new(stack, vec![]);
|
|
||||||
|
|
||||||
let host = "example.com";
|
let host = "example.com";
|
||||||
info!("querying host {:?}...", host);
|
info!("querying host {:?}...", host);
|
||||||
|
|
Loading…
Reference in a new issue