extras: Fix UB in Peripheral

`Peripheral` assumed that interrupts can't be preempted,
when they can be preempted by higher priority interrupts.
So I put the interrupt handler inside a critical section,
and also added checks for whether the state had been dropped
before the critical section was entered.

I also added a `'static` bound to `PeripheralState`,
since `Pin` only guarantees that the memory it directly references
will not be invalidated.
It doesn't guarantee that memory its pointee references also won't be invalidated.

There were already some implementations of `PeripheralState`
that weren't `'static`, though,
so I added an unsafe `PeripheralStateUnchecked` trait
and forwarded the `unsafe` to the constructors of the implementors.
This commit is contained in:
Liam Murphy 2021-07-05 17:42:43 +10:00
parent ed83b93b6d
commit 744e2cbb8a
5 changed files with 86 additions and 28 deletions

View file

@ -17,4 +17,5 @@ embassy = { version = "0.1.0", path = "../embassy" }
defmt = { version = "0.2.0", optional = true } defmt = { version = "0.2.0", optional = true }
log = { version = "0.4.11", optional = true } log = { version = "0.4.11", optional = true }
cortex-m = "0.7.1" cortex-m = "0.7.1"
critical-section = "0.2.1"
usb-device = "0.2.7" usb-device = "0.2.7"

View file

@ -1,15 +1,38 @@
use core::cell::UnsafeCell; use core::cell::UnsafeCell;
use core::marker::{PhantomData, PhantomPinned}; use core::marker::{PhantomData, PhantomPinned};
use core::pin::Pin; use core::pin::Pin;
use core::ptr;
use embassy::interrupt::{Interrupt, InterruptExt}; use embassy::interrupt::{Interrupt, InterruptExt};
pub trait PeripheralState { /// # Safety
/// When types implementing this trait are used with `Peripheral` or `PeripheralMutex`,
/// their lifetime must not end without first calling `Drop` on the `Peripheral` or `PeripheralMutex`.
pub unsafe trait PeripheralStateUnchecked {
type Interrupt: Interrupt; type Interrupt: Interrupt;
fn on_interrupt(&mut self); fn on_interrupt(&mut self);
} }
pub struct PeripheralMutex<S: PeripheralState> { // `PeripheralMutex` is safe because `Pin` guarantees that the memory it references will not be invalidated or reused
// without calling `Drop`. However, it provides no guarantees about references contained within the state still being valid,
// so this `'static` bound is necessary.
pub trait PeripheralState: 'static {
type Interrupt: Interrupt;
fn on_interrupt(&mut self);
}
// SAFETY: `T` has to live for `'static` to implement `PeripheralState`, thus its lifetime cannot end.
unsafe impl<T> PeripheralStateUnchecked for T
where
T: PeripheralState,
{
type Interrupt = T::Interrupt;
fn on_interrupt(&mut self) {
self.on_interrupt()
}
}
pub struct PeripheralMutex<S: PeripheralStateUnchecked> {
state: UnsafeCell<S>, state: UnsafeCell<S>,
irq_setup_done: bool, irq_setup_done: bool,
@ -19,7 +42,7 @@ pub struct PeripheralMutex<S: PeripheralState> {
_pinned: PhantomPinned, _pinned: PhantomPinned,
} }
impl<S: PeripheralState> PeripheralMutex<S> { impl<S: PeripheralStateUnchecked> PeripheralMutex<S> {
pub fn new(state: S, irq: S::Interrupt) -> Self { pub fn new(state: S, irq: S::Interrupt) -> Self {
Self { Self {
irq, irq,
@ -39,11 +62,17 @@ impl<S: PeripheralState> PeripheralMutex<S> {
this.irq.disable(); this.irq.disable();
this.irq.set_handler(|p| { this.irq.set_handler(|p| {
// Safety: it's OK to get a &mut to the state, since critical_section::with(|_| {
// - We're in the IRQ, no one else can't preempt us if p.is_null() {
// - We can't have preempted a with() call because the irq is disabled during it. // The state was dropped, so we can't operate on it.
let state = unsafe { &mut *(p as *mut S) }; return;
state.on_interrupt(); }
// Safety: it's OK to get a &mut to the state, since
// - We're in a critical section, no one can preempt us (and call with())
// - We can't have preempted a with() call because the irq is disabled during it.
let state = unsafe { &mut *(p as *mut S) };
state.on_interrupt();
})
}); });
this.irq this.irq
.set_handler_context((&mut this.state) as *mut _ as *mut ()); .set_handler_context((&mut this.state) as *mut _ as *mut ());
@ -67,9 +96,12 @@ impl<S: PeripheralState> PeripheralMutex<S> {
} }
} }
impl<S: PeripheralState> Drop for PeripheralMutex<S> { impl<S: PeripheralStateUnchecked> Drop for PeripheralMutex<S> {
fn drop(&mut self) { fn drop(&mut self) {
self.irq.disable(); self.irq.disable();
self.irq.remove_handler(); self.irq.remove_handler();
// Set the context to null so that the interrupt will know we're dropped
// if we pre-empted it before it entered a critical section.
self.irq.set_handler_context(ptr::null_mut());
} }
} }

View file

@ -1,16 +1,27 @@
use core::cell::UnsafeCell;
use core::marker::{PhantomData, PhantomPinned}; use core::marker::{PhantomData, PhantomPinned};
use core::pin::Pin; use core::pin::Pin;
use core::ptr;
use embassy::interrupt::{Interrupt, InterruptExt}; use embassy::interrupt::{Interrupt, InterruptExt};
pub trait PeripheralState { /// # Safety
/// When types implementing this trait are used with `Peripheral` or `PeripheralMutex`,
/// their lifetime must not end without first calling `Drop` on the `Peripheral` or `PeripheralMutex`.
pub unsafe trait PeripheralStateUnchecked {
type Interrupt: Interrupt; type Interrupt: Interrupt;
fn on_interrupt(&self); fn on_interrupt(&self);
} }
pub struct Peripheral<S: PeripheralState> { // `Peripheral` is safe because `Pin` guarantees that the memory it references will not be invalidated or reused
state: UnsafeCell<S>, // without calling `Drop`. However, it provides no guarantees about references contained within the state still being valid,
// so this `'static` bound is necessary.
pub trait PeripheralState: 'static {
type Interrupt: Interrupt;
fn on_interrupt(&self);
}
pub struct Peripheral<S: PeripheralStateUnchecked> {
state: S,
irq_setup_done: bool, irq_setup_done: bool,
irq: S::Interrupt, irq: S::Interrupt,
@ -19,13 +30,13 @@ pub struct Peripheral<S: PeripheralState> {
_pinned: PhantomPinned, _pinned: PhantomPinned,
} }
impl<S: PeripheralState> Peripheral<S> { impl<S: PeripheralStateUnchecked> Peripheral<S> {
pub fn new(irq: S::Interrupt, state: S) -> Self { pub fn new(irq: S::Interrupt, state: S) -> Self {
Self { Self {
irq, irq,
irq_setup_done: false, irq_setup_done: false,
state: UnsafeCell::new(state), state,
_not_send: PhantomData, _not_send: PhantomData,
_pinned: PhantomPinned, _pinned: PhantomPinned,
} }
@ -39,8 +50,16 @@ impl<S: PeripheralState> Peripheral<S> {
this.irq.disable(); this.irq.disable();
this.irq.set_handler(|p| { this.irq.set_handler(|p| {
let state = unsafe { &*(p as *const S) }; // We need to be in a critical section so that no one can preempt us
state.on_interrupt(); // and drop the state after we check whether `p.is_null()`.
critical_section::with(|_| {
if p.is_null() {
// The state was dropped, so we can't operate on it.
return;
}
let state = unsafe { &*(p as *const S) };
state.on_interrupt();
});
}); });
this.irq this.irq
.set_handler_context((&this.state) as *const _ as *mut ()); .set_handler_context((&this.state) as *const _ as *mut ());
@ -49,15 +68,17 @@ impl<S: PeripheralState> Peripheral<S> {
this.irq_setup_done = true; this.irq_setup_done = true;
} }
pub fn state(self: Pin<&mut Self>) -> &S { pub fn state<'a>(self: Pin<&'a mut Self>) -> &'a S {
let this = unsafe { self.get_unchecked_mut() }; &self.into_ref().get_ref().state
unsafe { &*this.state.get() }
} }
} }
impl<S: PeripheralState> Drop for Peripheral<S> { impl<S: PeripheralStateUnchecked> Drop for Peripheral<S> {
fn drop(&mut self) { fn drop(&mut self) {
self.irq.disable(); self.irq.disable();
self.irq.remove_handler(); self.irq.remove_handler();
// Set the context to null so that the interrupt will know we're dropped
// if we pre-empted it before it entered a critical section.
self.irq.set_handler_context(ptr::null_mut());
} }
} }

View file

@ -9,7 +9,7 @@ use usb_device::device::UsbDevice;
mod cdc_acm; mod cdc_acm;
pub mod usb_serial; pub mod usb_serial;
use crate::peripheral::{PeripheralMutex, PeripheralState}; use crate::peripheral::{PeripheralMutex, PeripheralStateUnchecked};
use embassy::interrupt::Interrupt; use embassy::interrupt::Interrupt;
use usb_serial::{ReadInterface, UsbSerial, WriteInterface}; use usb_serial::{ReadInterface, UsbSerial, WriteInterface};
@ -55,10 +55,12 @@ where
} }
} }
pub fn start(self: Pin<&mut Self>) { /// # Safety
let this = unsafe { self.get_unchecked_mut() }; /// The `UsbDevice` passed to `Self::new` must not be dropped without calling `Drop` on this `Usb` first.
pub unsafe fn start(self: Pin<&mut Self>) {
let this = self.get_unchecked_mut();
let mut mutex = this.inner.borrow_mut(); let mut mutex = this.inner.borrow_mut();
let mutex = unsafe { Pin::new_unchecked(&mut *mutex) }; let mutex = Pin::new_unchecked(&mut *mutex);
// Use inner to register the irq // Use inner to register the irq
mutex.register_interrupt(); mutex.register_interrupt();
@ -125,7 +127,8 @@ where
} }
} }
impl<'bus, B, T, I> PeripheralState for State<'bus, B, T, I> // SAFETY: The safety contract of `PeripheralStateUnchecked` is forwarded to `Usb::start`.
unsafe impl<'bus, B, T, I> PeripheralStateUnchecked for State<'bus, B, T, I>
where where
B: UsbBus, B: UsbBus,
T: ClassSet<B>, T: ClassSet<B>,

View file

@ -7,7 +7,7 @@ use core::task::{Context, Poll};
use embassy::interrupt::InterruptExt; use embassy::interrupt::InterruptExt;
use embassy::io::{AsyncBufRead, AsyncWrite, Result}; use embassy::io::{AsyncBufRead, AsyncWrite, Result};
use embassy::util::{Unborrow, WakerRegistration}; use embassy::util::{Unborrow, WakerRegistration};
use embassy_extras::peripheral::{PeripheralMutex, PeripheralState}; use embassy_extras::peripheral::{PeripheralMutex, PeripheralStateUnchecked};
use embassy_extras::ring_buffer::RingBuffer; use embassy_extras::ring_buffer::RingBuffer;
use embassy_extras::{low_power_wait_until, unborrow}; use embassy_extras::{low_power_wait_until, unborrow};
@ -283,7 +283,8 @@ impl<'a, U: UarteInstance, T: TimerInstance> Drop for State<'a, U, T> {
} }
} }
impl<'a, U: UarteInstance, T: TimerInstance> PeripheralState for State<'a, U, T> { // SAFETY: the safety contract of `PeripheralStateUnchecked` is forwarded to `BufferedUarte::new`.
unsafe impl<'a, U: UarteInstance, T: TimerInstance> PeripheralStateUnchecked for State<'a, U, T> {
type Interrupt = U::Interrupt; type Interrupt = U::Interrupt;
fn on_interrupt(&mut self) { fn on_interrupt(&mut self) {
trace!("irq: start"); trace!("irq: start");