Merge pull request #1185 from embassy-rs/dns-impl
Add DNS socket to embassy-net
This commit is contained in:
commit
e1eac15c42
7 changed files with 317 additions and 12 deletions
|
@ -22,7 +22,7 @@ unstable-traits = []
|
|||
|
||||
udp = ["smoltcp/socket-udp"]
|
||||
tcp = ["smoltcp/socket-tcp"]
|
||||
dns = ["smoltcp/socket-dns"]
|
||||
dns = ["smoltcp/socket-dns", "smoltcp/proto-dns"]
|
||||
dhcpv4 = ["medium-ethernet", "smoltcp/socket-dhcpv4"]
|
||||
proto-ipv6 = ["smoltcp/proto-ipv6"]
|
||||
medium-ethernet = ["smoltcp/medium-ethernet"]
|
||||
|
@ -40,6 +40,7 @@ smoltcp = { version = "0.9.0", default-features = false, features = [
|
|||
]}
|
||||
|
||||
embassy-net-driver = { version = "0.1.0", path = "../embassy-net-driver" }
|
||||
embassy-hal-common = { version = "0.1.0", path = "../embassy-hal-common" }
|
||||
embassy-time = { version = "0.1.0", path = "../embassy-time" }
|
||||
embassy-sync = { version = "0.1.0", path = "../embassy-sync" }
|
||||
embedded-io = { version = "0.4.0", optional = true }
|
||||
|
@ -51,5 +52,5 @@ generic-array = { version = "0.14.4", default-features = false }
|
|||
stable_deref_trait = { version = "1.2.0", default-features = false }
|
||||
futures = { version = "0.3.17", default-features = false, features = [ "async-await" ] }
|
||||
atomic-pool = "1.0"
|
||||
embedded-nal-async = { version = "0.3.0", optional = true }
|
||||
embedded-nal-async = { version = "0.4.0", optional = true }
|
||||
atomic-polyfill = { version = "1.0" }
|
||||
|
|
97
embassy-net/src/dns.rs
Normal file
97
embassy-net/src/dns.rs
Normal file
|
@ -0,0 +1,97 @@
|
|||
//! DNS socket with async support.
|
||||
use heapless::Vec;
|
||||
pub use smoltcp::socket::dns::{DnsQuery, Socket};
|
||||
pub(crate) use smoltcp::socket::dns::{GetQueryResultError, StartQueryError};
|
||||
pub use smoltcp::wire::{DnsQueryType, IpAddress};
|
||||
|
||||
use crate::{Driver, Stack};
|
||||
|
||||
/// Errors returned by DnsSocket.
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
|
||||
pub enum Error {
|
||||
/// Invalid name
|
||||
InvalidName,
|
||||
/// Name too long
|
||||
NameTooLong,
|
||||
/// Name lookup failed
|
||||
Failed,
|
||||
}
|
||||
|
||||
impl From<GetQueryResultError> for Error {
|
||||
fn from(_: GetQueryResultError) -> Self {
|
||||
Self::Failed
|
||||
}
|
||||
}
|
||||
|
||||
impl From<StartQueryError> for Error {
|
||||
fn from(e: StartQueryError) -> Self {
|
||||
match e {
|
||||
StartQueryError::NoFreeSlot => Self::Failed,
|
||||
StartQueryError::InvalidName => Self::InvalidName,
|
||||
StartQueryError::NameTooLong => Self::NameTooLong,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Async socket for making DNS queries.
|
||||
pub struct DnsSocket<'a, D>
|
||||
where
|
||||
D: Driver + 'static,
|
||||
{
|
||||
stack: &'a Stack<D>,
|
||||
}
|
||||
|
||||
impl<'a, D> DnsSocket<'a, D>
|
||||
where
|
||||
D: Driver + 'static,
|
||||
{
|
||||
/// Create a new DNS socket using the provided stack.
|
||||
///
|
||||
/// NOTE: If using DHCP, make sure it has reconfigured the stack to ensure the DNS servers are updated.
|
||||
pub fn new(stack: &'a Stack<D>) -> Self {
|
||||
Self { stack }
|
||||
}
|
||||
|
||||
/// Make a query for a given name and return the corresponding IP addresses.
|
||||
pub async fn query(&self, name: &str, qtype: DnsQueryType) -> Result<Vec<IpAddress, 1>, Error> {
|
||||
self.stack.dns_query(name, qtype).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "unstable-traits", feature = "nightly"))]
|
||||
impl<'a, D> embedded_nal_async::Dns for DnsSocket<'a, D>
|
||||
where
|
||||
D: Driver + 'static,
|
||||
{
|
||||
type Error = Error;
|
||||
|
||||
async fn get_host_by_name(
|
||||
&self,
|
||||
host: &str,
|
||||
addr_type: embedded_nal_async::AddrType,
|
||||
) -> Result<embedded_nal_async::IpAddr, Self::Error> {
|
||||
use embedded_nal_async::{AddrType, IpAddr};
|
||||
let qtype = match addr_type {
|
||||
AddrType::IPv6 => DnsQueryType::Aaaa,
|
||||
_ => DnsQueryType::A,
|
||||
};
|
||||
let addrs = self.query(host, qtype).await?;
|
||||
if let Some(first) = addrs.get(0) {
|
||||
Ok(match first {
|
||||
IpAddress::Ipv4(addr) => IpAddr::V4(addr.0.into()),
|
||||
#[cfg(feature = "proto-ipv6")]
|
||||
IpAddress::Ipv6(addr) => IpAddr::V6(addr.0.into()),
|
||||
})
|
||||
} else {
|
||||
Err(Error::Failed)
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_host_by_address(
|
||||
&self,
|
||||
_addr: embedded_nal_async::IpAddr,
|
||||
) -> Result<heapless::String<256>, Self::Error> {
|
||||
todo!()
|
||||
}
|
||||
}
|
|
@ -11,6 +11,8 @@ pub(crate) mod fmt;
|
|||
pub use embassy_net_driver as driver;
|
||||
|
||||
mod device;
|
||||
#[cfg(feature = "dns")]
|
||||
pub mod dns;
|
||||
#[cfg(feature = "tcp")]
|
||||
pub mod tcp;
|
||||
#[cfg(feature = "udp")]
|
||||
|
@ -46,15 +48,23 @@ use crate::device::DriverAdapter;
|
|||
|
||||
const LOCAL_PORT_MIN: u16 = 1025;
|
||||
const LOCAL_PORT_MAX: u16 = 65535;
|
||||
#[cfg(feature = "dns")]
|
||||
const MAX_QUERIES: usize = 4;
|
||||
|
||||
pub struct StackResources<const SOCK: usize> {
|
||||
sockets: [SocketStorage<'static>; SOCK],
|
||||
#[cfg(feature = "dns")]
|
||||
queries: [Option<dns::DnsQuery>; MAX_QUERIES],
|
||||
}
|
||||
|
||||
impl<const SOCK: usize> StackResources<SOCK> {
|
||||
pub fn new() -> Self {
|
||||
#[cfg(feature = "dns")]
|
||||
const INIT: Option<dns::DnsQuery> = None;
|
||||
Self {
|
||||
sockets: [SocketStorage::EMPTY; SOCK],
|
||||
#[cfg(feature = "dns")]
|
||||
queries: [INIT; MAX_QUERIES],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -107,6 +117,10 @@ struct Inner<D: Driver> {
|
|||
config: Option<StaticConfig>,
|
||||
#[cfg(feature = "dhcpv4")]
|
||||
dhcp_socket: Option<SocketHandle>,
|
||||
#[cfg(feature = "dns")]
|
||||
dns_socket: SocketHandle,
|
||||
#[cfg(feature = "dns")]
|
||||
dns_waker: WakerRegistration,
|
||||
}
|
||||
|
||||
pub(crate) struct SocketStack {
|
||||
|
@ -145,13 +159,6 @@ impl<D: Driver + 'static> Stack<D> {
|
|||
|
||||
let next_local_port = (random_seed % (LOCAL_PORT_MAX - LOCAL_PORT_MIN) as u64) as u16 + LOCAL_PORT_MIN;
|
||||
|
||||
let mut inner = Inner {
|
||||
device,
|
||||
link_up: false,
|
||||
config: None,
|
||||
#[cfg(feature = "dhcpv4")]
|
||||
dhcp_socket: None,
|
||||
};
|
||||
let mut socket = SocketStack {
|
||||
sockets,
|
||||
iface,
|
||||
|
@ -159,8 +166,25 @@ impl<D: Driver + 'static> Stack<D> {
|
|||
next_local_port,
|
||||
};
|
||||
|
||||
let mut inner = Inner {
|
||||
device,
|
||||
link_up: false,
|
||||
config: None,
|
||||
#[cfg(feature = "dhcpv4")]
|
||||
dhcp_socket: None,
|
||||
#[cfg(feature = "dns")]
|
||||
dns_socket: socket.sockets.add(dns::Socket::new(
|
||||
&[],
|
||||
managed::ManagedSlice::Borrowed(&mut resources.queries),
|
||||
)),
|
||||
#[cfg(feature = "dns")]
|
||||
dns_waker: WakerRegistration::new(),
|
||||
};
|
||||
|
||||
match config {
|
||||
Config::Static(config) => inner.apply_config(&mut socket, config),
|
||||
Config::Static(config) => {
|
||||
inner.apply_config(&mut socket, config);
|
||||
}
|
||||
#[cfg(feature = "dhcpv4")]
|
||||
Config::Dhcp(config) => {
|
||||
let mut dhcp_socket = smoltcp::socket::dhcpv4::Socket::new();
|
||||
|
@ -208,6 +232,60 @@ impl<D: Driver + 'static> Stack<D> {
|
|||
.await;
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
/// Make a query for a given name and return the corresponding IP addresses.
|
||||
#[cfg(feature = "dns")]
|
||||
pub async fn dns_query(&self, name: &str, qtype: dns::DnsQueryType) -> Result<Vec<IpAddress, 1>, dns::Error> {
|
||||
let query = poll_fn(|cx| {
|
||||
self.with_mut(|s, i| {
|
||||
let socket = s.sockets.get_mut::<dns::Socket>(i.dns_socket);
|
||||
match socket.start_query(s.iface.context(), name, qtype) {
|
||||
Ok(handle) => Poll::Ready(Ok(handle)),
|
||||
Err(dns::StartQueryError::NoFreeSlot) => {
|
||||
i.dns_waker.register(cx.waker());
|
||||
Poll::Pending
|
||||
}
|
||||
Err(e) => Poll::Ready(Err(e)),
|
||||
}
|
||||
})
|
||||
})
|
||||
.await?;
|
||||
|
||||
use embassy_hal_common::drop::OnDrop;
|
||||
let drop = OnDrop::new(|| {
|
||||
self.with_mut(|s, i| {
|
||||
let socket = s.sockets.get_mut::<dns::Socket>(i.dns_socket);
|
||||
socket.cancel_query(query);
|
||||
s.waker.wake();
|
||||
i.dns_waker.wake();
|
||||
})
|
||||
});
|
||||
|
||||
let res = poll_fn(|cx| {
|
||||
self.with_mut(|s, i| {
|
||||
let socket = s.sockets.get_mut::<dns::Socket>(i.dns_socket);
|
||||
match socket.get_query_result(query) {
|
||||
Ok(addrs) => {
|
||||
i.dns_waker.wake();
|
||||
Poll::Ready(Ok(addrs))
|
||||
}
|
||||
Err(dns::GetQueryResultError::Pending) => {
|
||||
socket.register_query_waker(query, cx.waker());
|
||||
Poll::Pending
|
||||
}
|
||||
Err(e) => {
|
||||
i.dns_waker.wake();
|
||||
Poll::Ready(Err(e.into()))
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
.await;
|
||||
|
||||
drop.defuse();
|
||||
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
impl SocketStack {
|
||||
|
@ -249,6 +327,13 @@ impl<D: Driver + 'static> Inner<D> {
|
|||
debug!(" DNS server {}: {}", i, s);
|
||||
}
|
||||
|
||||
#[cfg(feature = "dns")]
|
||||
{
|
||||
let socket = s.sockets.get_mut::<smoltcp::socket::dns::Socket>(self.dns_socket);
|
||||
let servers: Vec<IpAddress, 3> = config.dns_servers.iter().map(|c| IpAddress::Ipv4(*c)).collect();
|
||||
socket.update_servers(&servers[..]);
|
||||
}
|
||||
|
||||
self.config = Some(config)
|
||||
}
|
||||
|
||||
|
@ -324,6 +409,7 @@ impl<D: Driver + 'static> Inner<D> {
|
|||
//if old_link_up || self.link_up {
|
||||
// self.poll_configurator(timestamp)
|
||||
//}
|
||||
//
|
||||
|
||||
if let Some(poll_at) = s.iface.poll_at(timestamp, &mut s.sockets) {
|
||||
let t = Timer::at(instant_from_smoltcp(poll_at));
|
||||
|
|
|
@ -46,8 +46,31 @@ async fn net_task(stack: &'static Stack<Device<'static, MTU>>) -> ! {
|
|||
stack.run().await
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn test_function() -> (usize, u32, [u32; 2]) {
|
||||
let mut array = [3; 2];
|
||||
|
||||
let mut index = 0;
|
||||
let mut result = 0;
|
||||
|
||||
for x in [1, 2] {
|
||||
if x == 1 {
|
||||
array[1] = 99;
|
||||
} else {
|
||||
index = if x == 2 { 1 } else { 0 };
|
||||
|
||||
// grabs value from array[0], not array[1]
|
||||
result = array[index];
|
||||
}
|
||||
}
|
||||
|
||||
(index, result, array)
|
||||
}
|
||||
|
||||
#[embassy_executor::main]
|
||||
async fn main(spawner: Spawner) {
|
||||
info!("{:?}", test_function());
|
||||
|
||||
let p = embassy_nrf::init(Default::default());
|
||||
let clock: pac::CLOCK = unsafe { mem::transmute(()) };
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ license = "MIT OR Apache-2.0"
|
|||
embassy-sync = { version = "0.1.0", path = "../../embassy-sync", features = ["log"] }
|
||||
embassy-executor = { version = "0.1.0", path = "../../embassy-executor", features = ["log", "std", "nightly", "integrated-timers"] }
|
||||
embassy-time = { version = "0.1.0", path = "../../embassy-time", features = ["log", "std", "nightly"] }
|
||||
embassy-net = { version = "0.1.0", path = "../../embassy-net", features=[ "std", "nightly", "log", "medium-ethernet", "tcp", "udp", "dhcpv4"] }
|
||||
embassy-net = { version = "0.1.0", path = "../../embassy-net", features=[ "std", "nightly", "log", "medium-ethernet", "tcp", "udp", "dns", "dhcpv4", "unstable-traits", "proto-ipv6"] }
|
||||
embassy-net-driver = { version = "0.1.0", path = "../../embassy-net-driver" }
|
||||
embedded-io = { version = "0.4.0", features = ["async", "std", "futures"] }
|
||||
critical-section = { version = "1.1", features = ["std"] }
|
||||
|
|
98
examples/std/src/bin/net_dns.rs
Normal file
98
examples/std/src/bin/net_dns.rs
Normal file
|
@ -0,0 +1,98 @@
|
|||
#![feature(type_alias_impl_trait)]
|
||||
|
||||
use std::default::Default;
|
||||
|
||||
use clap::Parser;
|
||||
use embassy_executor::{Executor, Spawner};
|
||||
use embassy_net::dns::DnsQueryType;
|
||||
use embassy_net::{Config, Ipv4Address, Ipv4Cidr, Stack, StackResources};
|
||||
use heapless::Vec;
|
||||
use log::*;
|
||||
use rand_core::{OsRng, RngCore};
|
||||
use static_cell::StaticCell;
|
||||
|
||||
#[path = "../tuntap.rs"]
|
||||
mod tuntap;
|
||||
|
||||
use crate::tuntap::TunTapDevice;
|
||||
|
||||
macro_rules! singleton {
|
||||
($val:expr) => {{
|
||||
type T = impl Sized;
|
||||
static STATIC_CELL: StaticCell<T> = StaticCell::new();
|
||||
STATIC_CELL.init_with(move || $val)
|
||||
}};
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
#[clap(version = "1.0")]
|
||||
struct Opts {
|
||||
/// TAP device name
|
||||
#[clap(long, default_value = "tap0")]
|
||||
tap: String,
|
||||
/// use a static IP instead of DHCP
|
||||
#[clap(long)]
|
||||
static_ip: bool,
|
||||
}
|
||||
|
||||
#[embassy_executor::task]
|
||||
async fn net_task(stack: &'static Stack<TunTapDevice>) -> ! {
|
||||
stack.run().await
|
||||
}
|
||||
|
||||
#[embassy_executor::task]
|
||||
async fn main_task(spawner: Spawner) {
|
||||
let opts: Opts = Opts::parse();
|
||||
|
||||
// Init network device
|
||||
let device = TunTapDevice::new(&opts.tap).unwrap();
|
||||
|
||||
// Choose between dhcp or static ip
|
||||
let config = if opts.static_ip {
|
||||
Config::Static(embassy_net::StaticConfig {
|
||||
address: Ipv4Cidr::new(Ipv4Address::new(192, 168, 69, 1), 24),
|
||||
dns_servers: Vec::from_slice(&[Ipv4Address::new(8, 8, 4, 4).into(), Ipv4Address::new(8, 8, 8, 8).into()])
|
||||
.unwrap(),
|
||||
gateway: Some(Ipv4Address::new(192, 168, 69, 100)),
|
||||
})
|
||||
} else {
|
||||
Config::Dhcp(Default::default())
|
||||
};
|
||||
|
||||
// Generate random seed
|
||||
let mut seed = [0; 8];
|
||||
OsRng.fill_bytes(&mut seed);
|
||||
let seed = u64::from_le_bytes(seed);
|
||||
|
||||
// Init network stack
|
||||
let stack: &Stack<_> = &*singleton!(Stack::new(device, config, singleton!(StackResources::<2>::new()), seed));
|
||||
|
||||
// Launch network task
|
||||
spawner.spawn(net_task(stack)).unwrap();
|
||||
|
||||
let host = "example.com";
|
||||
info!("querying host {:?}...", host);
|
||||
match stack.dns_query(host, DnsQueryType::A).await {
|
||||
Ok(r) => {
|
||||
info!("query response: {:?}", r);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("query error: {:?}", e);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
static EXECUTOR: StaticCell<Executor> = StaticCell::new();
|
||||
|
||||
fn main() {
|
||||
env_logger::builder()
|
||||
.filter_level(log::LevelFilter::Debug)
|
||||
.filter_module("async_io", log::LevelFilter::Info)
|
||||
.format_timestamp_nanos()
|
||||
.init();
|
||||
|
||||
let executor = EXECUTOR.init(Executor::new());
|
||||
executor.run(|spawner| {
|
||||
spawner.spawn(main_task(spawner)).unwrap();
|
||||
});
|
||||
}
|
|
@ -21,7 +21,7 @@ cortex-m-rt = "0.7.0"
|
|||
embedded-hal = "0.2.6"
|
||||
embedded-hal-1 = { package = "embedded-hal", version = "=1.0.0-alpha.9" }
|
||||
embedded-hal-async = { version = "=0.2.0-alpha.0" }
|
||||
embedded-nal-async = "0.3.0"
|
||||
embedded-nal-async = "0.4.0"
|
||||
panic-probe = { version = "0.3", features = ["print-defmt"] }
|
||||
futures = { version = "0.3.17", default-features = false, features = ["async-await"] }
|
||||
heapless = { version = "0.7.5", default-features = false }
|
||||
|
|
Loading…
Reference in a new issue