diff --git a/embassy-nrf/src/rng.rs b/embassy-nrf/src/rng.rs index 0ec87effa..80282c4b0 100644 --- a/embassy-nrf/src/rng.rs +++ b/embassy-nrf/src/rng.rs @@ -1,14 +1,14 @@ -use core::cell::RefCell; use core::convert::Infallible; use core::future::Future; use core::marker::PhantomData; -use core::ptr::NonNull; +use core::ptr; +use core::sync::atomic::AtomicPtr; +use core::sync::atomic::Ordering; use core::task::Poll; -use core::task::Waker; use embassy::interrupt::InterruptExt; use embassy::traits; -use embassy::util::CriticalSectionMutex; +use embassy::util::AtomicWaker; use embassy::util::OnDrop; use embassy::util::Unborrow; use embassy_extras::unborrow; @@ -25,25 +25,18 @@ impl RNG { } } -static STATE: CriticalSectionMutex> = - CriticalSectionMutex::new(RefCell::new(State { - buffer: None, - waker: None, - index: 0, - })); +static STATE: State = State { + ptr: AtomicPtr::new(ptr::null_mut()), + end: AtomicPtr::new(ptr::null_mut()), + waker: AtomicWaker::new(), +}; struct State { - buffer: Option>, - waker: Option, - index: usize, + ptr: AtomicPtr, + end: AtomicPtr, + waker: AtomicWaker, } -// SAFETY: `NonNull` is `!Send` because of the possibility of it being aliased. -// However, `buffer` is only used within `on_interrupt`, -// and the original `&mut` passed to `fill_bytes` cannot be used because the safety contract of `Rng::new` -// means that it must still be borrowed by `RngFuture`, and so `rustc` will not let it be accessed. -unsafe impl Send for State {} - /// A wrapper around an nRF RNG peripheral. /// /// It has a non-blocking API, through `embassy::traits::Rng`, and a blocking api through `rand`. @@ -70,7 +63,7 @@ impl<'d> Rng<'d> { phantom: PhantomData, }; - Self::stop(); + this.stop(); this.disable_irq(); this.irq.set_handler(Self::on_interrupt); @@ -81,25 +74,54 @@ impl<'d> Rng<'d> { } fn on_interrupt(_: *mut ()) { - critical_section::with(|cs| { - let mut state = STATE.borrow(cs).borrow_mut(); - // SAFETY: the safety requirements on `Rng::new` make sure that the original `&mut`'s lifetime is still valid, - // meaning it can't be aliased and is a valid pointer. - let buffer = unsafe { state.buffer.unwrap().as_mut() }; - buffer[state.index] = RNG::regs().value.read().value().bits(); - state.index += 1; - if state.index == buffer.len() { - // Stop the RNG within the interrupt so that it doesn't get triggered again on the way to waking the future. - Self::stop(); - if let Some(waker) = state.waker.take() { - waker.wake(); + // Clear the event. + RNG::regs().events_valrdy.reset(); + + // Mutate the slice within a critical section, + // so that the future isn't dropped in between us loading the pointer and actually dereferencing it. + let (ptr, end) = critical_section::with(|_| { + let ptr = STATE.ptr.load(Ordering::Relaxed); + // We need to make sure we haven't already filled the whole slice, + // in case the interrupt fired again before the executor got back to the future. + let end = STATE.end.load(Ordering::Relaxed); + if !ptr.is_null() && ptr != end { + // If the future was dropped, the pointer would have been set to null, + // so we're still good to mutate the slice. + // The safety contract of `Rng::new` means that the future can't have been dropped + // without calling its destructor. + unsafe { + *ptr = RNG::regs().value.read().value().bits(); } } - RNG::regs().events_valrdy.reset(); + (ptr, end) }); + + if ptr.is_null() || ptr == end { + // If the future was dropped, there's nothing to do. + // If `ptr == end`, we were called by mistake, so return. + return; + } + + let new_ptr = unsafe { ptr.add(1) }; + match STATE + .ptr + .compare_exchange(ptr, new_ptr, Ordering::Relaxed, Ordering::Relaxed) + { + Ok(ptr) => { + let end = STATE.end.load(Ordering::Relaxed); + // It doesn't matter if `end` was changed under our feet, because then this will just be false. + if ptr == end { + STATE.waker.wake(); + } + } + Err(_) => { + // If the future was dropped or finished, there's no point trying to wake it. + // It will have already stopped the RNG, so there's no need to do that either. + } + } } - fn stop() { + fn stop(&self) { RNG::regs().tasks_stop.write(|w| unsafe { w.bits(1) }) } @@ -140,37 +162,41 @@ impl<'d> traits::rng::Rng for Rng<'d> { fn fill_bytes<'a>(&'a mut self, dest: &'a mut [u8]) -> Self::RngFuture<'a> { async move { - critical_section::with(|cs| { - let mut state = STATE.borrow(cs).borrow_mut(); - state.buffer = Some(dest.into()); - }); + if dest.len() == 0 { + return Ok(()); // Nothing to fill + } + + let range = dest.as_mut_ptr_range(); + // Even if we've preempted the interrupt, it can't preempt us again, + // so we don't need to worry about the order we write these in. + STATE.ptr.store(range.start, Ordering::Relaxed); + STATE.end.store(range.end, Ordering::Relaxed); self.enable_irq(); self.start(); let on_drop = OnDrop::new(|| { - Self::stop(); + self.stop(); self.disable_irq(); + + // The interrupt is now disabled and can't preempt us anymore, so the order doesn't matter here. + STATE.ptr.store(ptr::null_mut(), Ordering::Relaxed); + STATE.end.store(ptr::null_mut(), Ordering::Relaxed); }); poll_fn(|cx| { - critical_section::with(|cs| { - let mut state = STATE.borrow(cs).borrow_mut(); - state.waker = Some(cx.waker().clone()); - // SAFETY: see safety message in interrupt handler. - // Also, both here and in the interrupt handler, we're in a critical section, - // so they can't interfere with each other. - let buffer = unsafe { state.buffer.unwrap().as_ref() }; + STATE.waker.register(cx.waker()); - if state.index == buffer.len() { - // Reset the state for next time - state.buffer = None; - state.index = 0; - Poll::Ready(()) - } else { - Poll::Pending - } - }) + // The interrupt will never modify `end`, so load it first and then get the most up-to-date `ptr`. + let end = STATE.end.load(Ordering::Relaxed); + let ptr = STATE.ptr.load(Ordering::Relaxed); + + if ptr == end { + // We're done. + Poll::Ready(()) + } else { + Poll::Pending + } }) .await; @@ -193,7 +219,7 @@ impl<'d> RngCore for Rng<'d> { *byte = regs.value.read().value().bits(); } - Self::stop(); + self.stop(); } fn next_u32(&mut self) -> u32 {