diff --git a/embassy-nrf/src/rng.rs b/embassy-nrf/src/rng.rs index e5ec02c67..0ec87effa 100644 --- a/embassy-nrf/src/rng.rs +++ b/embassy-nrf/src/rng.rs @@ -1,13 +1,18 @@ +use core::cell::RefCell; use core::convert::Infallible; +use core::future::Future; use core::marker::PhantomData; +use core::ptr::NonNull; +use core::task::Poll; +use core::task::Waker; use embassy::interrupt::InterruptExt; use embassy::traits; +use embassy::util::CriticalSectionMutex; use embassy::util::OnDrop; -use embassy::util::Signal; use embassy::util::Unborrow; use embassy_extras::unborrow; -use futures::Future; +use futures::future::poll_fn; use rand_core::RngCore; use crate::interrupt; @@ -20,18 +25,41 @@ impl RNG { } } -static NEXT_BYTE: Signal = Signal::new(); +static STATE: CriticalSectionMutex> = + CriticalSectionMutex::new(RefCell::new(State { + buffer: None, + waker: None, + index: 0, + })); + +struct State { + buffer: Option>, + waker: Option, + index: usize, +} + +// SAFETY: `NonNull` is `!Send` because of the possibility of it being aliased. +// However, `buffer` is only used within `on_interrupt`, +// and the original `&mut` passed to `fill_bytes` cannot be used because the safety contract of `Rng::new` +// means that it must still be borrowed by `RngFuture`, and so `rustc` will not let it be accessed. +unsafe impl Send for State {} /// A wrapper around an nRF RNG peripheral. /// /// It has a non-blocking API, through `embassy::traits::Rng`, and a blocking api through `rand`. pub struct Rng<'d> { irq: interrupt::RNG, - phantom: PhantomData<&'d mut RNG>, + phantom: PhantomData<(&'d mut RNG, &'d mut interrupt::RNG)>, } impl<'d> Rng<'d> { - pub fn new( + /// Creates a new RNG driver from the `RNG` peripheral and interrupt. + /// + /// SAFETY: The future returned from `fill_bytes` must not have its lifetime end without running its destructor, + /// e.g. using `mem::forget`. + /// + /// The synchronous API is safe. + pub unsafe fn new( _rng: impl Unborrow + 'd, irq: impl Unborrow + 'd, ) -> Self { @@ -42,7 +70,7 @@ impl<'d> Rng<'d> { phantom: PhantomData, }; - this.stop(); + Self::stop(); this.disable_irq(); this.irq.set_handler(Self::on_interrupt); @@ -53,11 +81,25 @@ impl<'d> Rng<'d> { } fn on_interrupt(_: *mut ()) { - NEXT_BYTE.signal(RNG::regs().value.read().value().bits()); - RNG::regs().events_valrdy.reset(); + critical_section::with(|cs| { + let mut state = STATE.borrow(cs).borrow_mut(); + // SAFETY: the safety requirements on `Rng::new` make sure that the original `&mut`'s lifetime is still valid, + // meaning it can't be aliased and is a valid pointer. + let buffer = unsafe { state.buffer.unwrap().as_mut() }; + buffer[state.index] = RNG::regs().value.read().value().bits(); + state.index += 1; + if state.index == buffer.len() { + // Stop the RNG within the interrupt so that it doesn't get triggered again on the way to waking the future. + Self::stop(); + if let Some(waker) = state.waker.take() { + waker.wake(); + } + } + RNG::regs().events_valrdy.reset(); + }); } - fn stop(&self) { + fn stop() { RNG::regs().tasks_stop.write(|w| unsafe { w.bits(1) }) } @@ -98,17 +140,39 @@ impl<'d> traits::rng::Rng for Rng<'d> { fn fill_bytes<'a>(&'a mut self, dest: &'a mut [u8]) -> Self::RngFuture<'a> { async move { + critical_section::with(|cs| { + let mut state = STATE.borrow(cs).borrow_mut(); + state.buffer = Some(dest.into()); + }); + self.enable_irq(); self.start(); let on_drop = OnDrop::new(|| { - self.stop(); + Self::stop(); self.disable_irq(); }); - for byte in dest.iter_mut() { - *byte = NEXT_BYTE.wait().await; - } + poll_fn(|cx| { + critical_section::with(|cs| { + let mut state = STATE.borrow(cs).borrow_mut(); + state.waker = Some(cx.waker().clone()); + // SAFETY: see safety message in interrupt handler. + // Also, both here and in the interrupt handler, we're in a critical section, + // so they can't interfere with each other. + let buffer = unsafe { state.buffer.unwrap().as_ref() }; + + if state.index == buffer.len() { + // Reset the state for next time + state.buffer = None; + state.index = 0; + Poll::Ready(()) + } else { + Poll::Pending + } + }) + }) + .await; // Trigger the teardown drop(on_drop); @@ -129,7 +193,7 @@ impl<'d> RngCore for Rng<'d> { *byte = regs.value.read().value().bits(); } - self.stop(); + Self::stop(); } fn next_u32(&mut self) -> u32 {