Don't wake the future for every byte in fill_bytes

This commit is contained in:
Liam Murphy 2021-06-30 12:34:57 +10:00
parent ae0219de6f
commit 89fdad3a6b

View file

@ -1,13 +1,18 @@
use core::cell::RefCell;
use core::convert::Infallible;
use core::future::Future;
use core::marker::PhantomData;
use core::ptr::NonNull;
use core::task::Poll;
use core::task::Waker;
use embassy::interrupt::InterruptExt;
use embassy::traits;
use embassy::util::CriticalSectionMutex;
use embassy::util::OnDrop;
use embassy::util::Signal;
use embassy::util::Unborrow;
use embassy_extras::unborrow;
use futures::Future;
use futures::future::poll_fn;
use rand_core::RngCore;
use crate::interrupt;
@ -20,18 +25,41 @@ impl RNG {
}
}
static NEXT_BYTE: Signal<u8> = Signal::new();
static STATE: CriticalSectionMutex<RefCell<State>> =
CriticalSectionMutex::new(RefCell::new(State {
buffer: None,
waker: None,
index: 0,
}));
struct State {
buffer: Option<NonNull<[u8]>>,
waker: Option<Waker>,
index: usize,
}
// SAFETY: `NonNull` is `!Send` because of the possibility of it being aliased.
// However, `buffer` is only used within `on_interrupt`,
// and the original `&mut` passed to `fill_bytes` cannot be used because the safety contract of `Rng::new`
// means that it must still be borrowed by `RngFuture`, and so `rustc` will not let it be accessed.
unsafe impl Send for State {}
/// A wrapper around an nRF RNG peripheral.
///
/// It has a non-blocking API, through `embassy::traits::Rng`, and a blocking api through `rand`.
pub struct Rng<'d> {
irq: interrupt::RNG,
phantom: PhantomData<&'d mut RNG>,
phantom: PhantomData<(&'d mut RNG, &'d mut interrupt::RNG)>,
}
impl<'d> Rng<'d> {
pub fn new(
/// Creates a new RNG driver from the `RNG` peripheral and interrupt.
///
/// SAFETY: The future returned from `fill_bytes` must not have its lifetime end without running its destructor,
/// e.g. using `mem::forget`.
///
/// The synchronous API is safe.
pub unsafe fn new(
_rng: impl Unborrow<Target = RNG> + 'd,
irq: impl Unborrow<Target = interrupt::RNG> + 'd,
) -> Self {
@ -42,7 +70,7 @@ impl<'d> Rng<'d> {
phantom: PhantomData,
};
this.stop();
Self::stop();
this.disable_irq();
this.irq.set_handler(Self::on_interrupt);
@ -53,11 +81,25 @@ impl<'d> Rng<'d> {
}
fn on_interrupt(_: *mut ()) {
NEXT_BYTE.signal(RNG::regs().value.read().value().bits());
RNG::regs().events_valrdy.reset();
critical_section::with(|cs| {
let mut state = STATE.borrow(cs).borrow_mut();
// SAFETY: the safety requirements on `Rng::new` make sure that the original `&mut`'s lifetime is still valid,
// meaning it can't be aliased and is a valid pointer.
let buffer = unsafe { state.buffer.unwrap().as_mut() };
buffer[state.index] = RNG::regs().value.read().value().bits();
state.index += 1;
if state.index == buffer.len() {
// Stop the RNG within the interrupt so that it doesn't get triggered again on the way to waking the future.
Self::stop();
if let Some(waker) = state.waker.take() {
waker.wake();
}
}
RNG::regs().events_valrdy.reset();
});
}
fn stop(&self) {
fn stop() {
RNG::regs().tasks_stop.write(|w| unsafe { w.bits(1) })
}
@ -98,17 +140,39 @@ impl<'d> traits::rng::Rng for Rng<'d> {
fn fill_bytes<'a>(&'a mut self, dest: &'a mut [u8]) -> Self::RngFuture<'a> {
async move {
critical_section::with(|cs| {
let mut state = STATE.borrow(cs).borrow_mut();
state.buffer = Some(dest.into());
});
self.enable_irq();
self.start();
let on_drop = OnDrop::new(|| {
self.stop();
Self::stop();
self.disable_irq();
});
for byte in dest.iter_mut() {
*byte = NEXT_BYTE.wait().await;
}
poll_fn(|cx| {
critical_section::with(|cs| {
let mut state = STATE.borrow(cs).borrow_mut();
state.waker = Some(cx.waker().clone());
// SAFETY: see safety message in interrupt handler.
// Also, both here and in the interrupt handler, we're in a critical section,
// so they can't interfere with each other.
let buffer = unsafe { state.buffer.unwrap().as_ref() };
if state.index == buffer.len() {
// Reset the state for next time
state.buffer = None;
state.index = 0;
Poll::Ready(())
} else {
Poll::Pending
}
})
})
.await;
// Trigger the teardown
drop(on_drop);
@ -129,7 +193,7 @@ impl<'d> RngCore for Rng<'d> {
*byte = regs.value.read().value().bits();
}
self.stop();
Self::stop();
}
fn next_u32(&mut self) -> u32 {