Merge pull request #277 from Liamolucko/fix-peripheral-ub

extras: Fix UB in `Peripheral`
This commit is contained in:
Dario Nieuwenhuis 2021-07-29 13:08:30 +02:00 committed by GitHub
commit c8a48d726a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 209 additions and 53 deletions

View file

@ -2,9 +2,19 @@ use core::cell::UnsafeCell;
use core::marker::{PhantomData, PhantomPinned};
use core::pin::Pin;
use cortex_m::peripheral::scb::VectActive;
use cortex_m::peripheral::{NVIC, SCB};
use embassy::interrupt::{Interrupt, InterruptExt};
pub trait PeripheralState {
/// A type which can be used as state with `PeripheralMutex`.
///
/// It needs to be `Send` because `&mut` references are sent back and forth between the 'thread' which owns the `PeripheralMutex` and the interrupt,
/// and `&mut T` is only `Send` where `T: Send`.
///
/// It also requires `'static` to be used safely with `PeripheralMutex::register_interrupt`,
/// because although `Pin` guarantees that the memory of the state won't be invalidated,
/// it doesn't guarantee that the lifetime will last.
pub trait PeripheralState: Send {
type Interrupt: Interrupt;
fn on_interrupt(&mut self);
}
@ -19,8 +29,51 @@ pub struct PeripheralMutex<S: PeripheralState> {
_pinned: PhantomPinned,
}
/// Whether `irq` can be preempted by the current interrupt.
pub(crate) fn can_be_preempted(irq: &impl Interrupt) -> bool {
match SCB::vect_active() {
// Thread mode can't preempt anything.
VectActive::ThreadMode => false,
// Exceptions don't always preempt interrupts,
// but there isn't much of a good reason to be keeping a `PeripheralMutex` in an exception anyway.
VectActive::Exception(_) => true,
VectActive::Interrupt { irqn } => {
#[derive(Clone, Copy)]
struct NrWrap(u16);
unsafe impl cortex_m::interrupt::InterruptNumber for NrWrap {
fn number(self) -> u16 {
self.0
}
}
NVIC::get_priority(NrWrap(irqn.into())) < irq.get_priority().into()
}
}
}
impl<S: PeripheralState + 'static> PeripheralMutex<S> {
/// Registers `on_interrupt` as the wrapped interrupt's interrupt handler and enables it.
///
/// This requires this `PeripheralMutex`'s `PeripheralState` to live for `'static`,
/// because `Pin` only guarantees that it's memory won't be repurposed,
/// not that it's lifetime will last.
///
/// To use non-`'static` `PeripheralState`, use the unsafe `register_interrupt_unchecked`.
///
/// Note: `'static` doesn't mean it _has_ to live for the entire program, like an `&'static T`;
/// it just means it _can_ live for the entire program - for example, `u8` lives for `'static`.
pub fn register_interrupt(self: Pin<&mut Self>) {
// SAFETY: `S: 'static`, so there's no way it's lifetime can expire.
unsafe { self.register_interrupt_unchecked() }
}
}
impl<S: PeripheralState> PeripheralMutex<S> {
/// Create a new `PeripheralMutex` wrapping `irq`, with the initial state `state`.
pub fn new(state: S, irq: S::Interrupt) -> Self {
if can_be_preempted(&irq) {
panic!("`PeripheralMutex` cannot be created in an interrupt with higher priority than the interrupt it wraps");
}
Self {
irq,
irq_setup_done: false,
@ -31,8 +84,18 @@ impl<S: PeripheralState> PeripheralMutex<S> {
}
}
pub fn register_interrupt(self: Pin<&mut Self>) {
let this = unsafe { self.get_unchecked_mut() };
/// Registers `on_interrupt` as the wrapped interrupt's interrupt handler and enables it.
///
/// # Safety
/// The lifetime of any data in `PeripheralState` that is accessed by the interrupt handler
/// must not end without `Drop` being called on this `PeripheralMutex`.
///
/// This can be accomplished by either not accessing any data with a lifetime in `on_interrupt`,
/// or making sure that nothing like `mem::forget` is used on the `PeripheralMutex`.
// TODO: this name isn't the best.
pub unsafe fn register_interrupt_unchecked(self: Pin<&mut Self>) {
let this = self.get_unchecked_mut();
if this.irq_setup_done {
return;
}
@ -40,7 +103,9 @@ impl<S: PeripheralState> PeripheralMutex<S> {
this.irq.disable();
this.irq.set_handler(|p| {
// Safety: it's OK to get a &mut to the state, since
// - We're in the IRQ, no one else can't preempt us
// - We checked that the thread owning the `PeripheralMutex` can't preempt us in `new`.
// Interrupts' priorities can only be changed with raw embassy `Interrupts`,
// which can't safely store a `PeripheralMutex` across invocations.
// - We can't have preempted a with() call because the irq is disabled during it.
let state = unsafe { &mut *(p as *mut S) };
state.on_interrupt();
@ -52,19 +117,39 @@ impl<S: PeripheralState> PeripheralMutex<S> {
this.irq_setup_done = true;
}
pub fn with<R>(self: Pin<&mut Self>, f: impl FnOnce(&mut S, &mut S::Interrupt) -> R) -> R {
pub fn with<R>(self: Pin<&mut Self>, f: impl FnOnce(&mut S) -> R) -> R {
let this = unsafe { self.get_unchecked_mut() };
this.irq.disable();
// Safety: it's OK to get a &mut to the state, since the irq is disabled.
let state = unsafe { &mut *this.state.get() };
let r = f(state, &mut this.irq);
let r = f(state);
this.irq.enable();
r
}
/// Returns whether the wrapped interrupt is currently in a pending state.
pub fn is_pending(&self) -> bool {
self.irq.is_pending()
}
/// Forces the wrapped interrupt into a pending state.
pub fn pend(&self) {
self.irq.pend()
}
/// Forces the wrapped interrupt out of a pending state.
pub fn unpend(&self) {
self.irq.unpend()
}
/// Gets the priority of the wrapped interrupt.
pub fn priority(&self) -> <S::Interrupt as Interrupt>::Priority {
self.irq.get_priority()
}
}
impl<S: PeripheralState> Drop for PeripheralMutex<S> {

View file

@ -1,16 +1,24 @@
use core::cell::UnsafeCell;
use core::marker::{PhantomData, PhantomPinned};
use core::pin::Pin;
use embassy::interrupt::{Interrupt, InterruptExt};
pub trait PeripheralState {
use crate::peripheral::can_be_preempted;
/// A type which can be used as state with `Peripheral`.
///
/// It needs to be `Sync` because references are shared between the 'thread' which owns the `Peripheral` and the interrupt.
///
/// It also requires `'static` to be used safely with `Peripheral::register_interrupt`,
/// because although `Pin` guarantees that the memory of the state won't be invalidated,
/// it doesn't guarantee that the lifetime will last.
pub trait PeripheralState: Sync {
type Interrupt: Interrupt;
fn on_interrupt(&self);
}
pub struct Peripheral<S: PeripheralState> {
state: UnsafeCell<S>,
state: S,
irq_setup_done: bool,
irq: S::Interrupt,
@ -19,26 +27,58 @@ pub struct Peripheral<S: PeripheralState> {
_pinned: PhantomPinned,
}
impl<S: PeripheralState + 'static> Peripheral<S> {
/// Registers `on_interrupt` as the wrapped interrupt's interrupt handler and enables it.
///
/// This requires this `Peripheral`'s `PeripheralState` to live for `'static`,
/// because `Pin` only guarantees that it's memory won't be repurposed,
/// not that it's lifetime will last.
///
/// To use non-`'static` `PeripheralState`, use the unsafe `register_interrupt_unchecked`.
///
/// Note: `'static` doesn't mean it _has_ to live for the entire program, like an `&'static T`;
/// it just means it _can_ live for the entire program - for example, `u8` lives for `'static`.
pub fn register_interrupt(self: Pin<&mut Self>) {
// SAFETY: `S: 'static`, so there's no way it's lifetime can expire.
unsafe { self.register_interrupt_unchecked() }
}
}
impl<S: PeripheralState> Peripheral<S> {
pub fn new(irq: S::Interrupt, state: S) -> Self {
if can_be_preempted(&irq) {
panic!("`Peripheral` cannot be created in an interrupt with higher priority than the interrupt it wraps");
}
Self {
irq,
irq_setup_done: false,
state: UnsafeCell::new(state),
state,
_not_send: PhantomData,
_pinned: PhantomPinned,
}
}
pub fn register_interrupt(self: Pin<&mut Self>) {
let this = unsafe { self.get_unchecked_mut() };
/// Registers `on_interrupt` as the wrapped interrupt's interrupt handler and enables it.
///
/// # Safety
/// The lifetime of any data in `PeripheralState` that is accessed by the interrupt handler
/// must not end without `Drop` being called on this `Peripheral`.
///
/// This can be accomplished by either not accessing any data with a lifetime in `on_interrupt`,
/// or making sure that nothing like `mem::forget` is used on the `Peripheral`.
pub unsafe fn register_interrupt_unchecked(self: Pin<&mut Self>) {
let this = self.get_unchecked_mut();
if this.irq_setup_done {
return;
}
this.irq.disable();
this.irq.set_handler(|p| {
// The state can't have been dropped, otherwise the interrupt would have been disabled.
// We checked in `new` that the thread owning the `Peripheral` can't preempt the interrupt,
// so someone can't have preempted us before this point and dropped the `Peripheral`.
let state = unsafe { &*(p as *const S) };
state.on_interrupt();
});
@ -50,8 +90,27 @@ impl<S: PeripheralState> Peripheral<S> {
}
pub fn state(self: Pin<&mut Self>) -> &S {
let this = unsafe { self.get_unchecked_mut() };
unsafe { &*this.state.get() }
&self.into_ref().get_ref().state
}
/// Returns whether the wrapped interrupt is currently in a pending state.
pub fn is_pending(&self) -> bool {
self.irq.is_pending()
}
/// Forces the wrapped interrupt into a pending state.
pub fn pend(&self) {
self.irq.pend()
}
/// Forces the wrapped interrupt out of a pending state.
pub fn unpend(&self) {
self.irq.unpend()
}
/// Gets the priority of the wrapped interrupt.
pub fn priority(&self) -> <S::Interrupt as Interrupt>::Priority {
self.irq.get_priority()
}
}

View file

@ -14,7 +14,7 @@ use embassy::interrupt::Interrupt;
use usb_serial::{ReadInterface, UsbSerial, WriteInterface};
/// Marker trait to mark an interrupt to be used with the [`Usb`] abstraction.
pub unsafe trait USBInterrupt: Interrupt {}
pub unsafe trait USBInterrupt: Interrupt + Send {}
pub(crate) struct State<'bus, B, T, I>
where
@ -55,13 +55,17 @@ where
}
}
pub fn start(self: Pin<&mut Self>) {
let this = unsafe { self.get_unchecked_mut() };
/// # Safety
/// The `UsbDevice` passed to `Self::new` must not be dropped without calling `Drop` on this `Usb` first.
pub unsafe fn start(self: Pin<&mut Self>) {
let this = self.get_unchecked_mut();
let mut mutex = this.inner.borrow_mut();
let mutex = unsafe { Pin::new_unchecked(&mut *mutex) };
let mutex = Pin::new_unchecked(&mut *mutex);
// Use inner to register the irq
mutex.register_interrupt();
// SAFETY: the safety contract of this function makes sure the `UsbDevice` won't be invalidated
// without the `PeripheralMutex` being dropped.
mutex.register_interrupt_unchecked();
}
}
@ -137,7 +141,7 @@ where
}
}
pub trait ClassSet<B: UsbBus> {
pub trait ClassSet<B: UsbBus>: Send {
fn poll_all(&mut self, device: &mut UsbDevice<'_, B>) -> bool;
}
@ -173,8 +177,8 @@ pub struct Index1;
impl<B, C1> ClassSet<B> for ClassSet1<B, C1>
where
B: UsbBus,
C1: UsbClass<B>,
B: UsbBus + Send,
C1: UsbClass<B> + Send,
{
fn poll_all(&mut self, device: &mut UsbDevice<'_, B>) -> bool {
device.poll(&mut [&mut self.class])
@ -183,9 +187,9 @@ where
impl<B, C1, C2> ClassSet<B> for ClassSet2<B, C1, C2>
where
B: UsbBus,
C1: UsbClass<B>,
C2: UsbClass<B>,
B: UsbBus + Send,
C1: UsbClass<B> + Send,
C2: UsbClass<B> + Send,
{
fn poll_all(&mut self, device: &mut UsbDevice<'_, B>) -> bool {
device.poll(&mut [&mut self.class1, &mut self.class2])
@ -194,8 +198,8 @@ where
impl<B, C1> IntoClassSet<B, ClassSet1<B, C1>> for C1
where
B: UsbBus,
C1: UsbClass<B>,
B: UsbBus + Send,
C1: UsbClass<B> + Send,
{
fn into_class_set(self) -> ClassSet1<B, C1> {
ClassSet1 {
@ -207,9 +211,9 @@ where
impl<B, C1, C2> IntoClassSet<B, ClassSet2<B, C1, C2>> for (C1, C2)
where
B: UsbBus,
C1: UsbClass<B>,
C2: UsbClass<B>,
B: UsbBus + Send,
C1: UsbClass<B> + Send,
C2: UsbClass<B> + Send,
{
fn into_class_set(self) -> ClassSet2<B, C1, C2> {
ClassSet2 {

View file

@ -55,7 +55,7 @@ where
let this = self.get_mut();
let mut mutex = this.inner.borrow_mut();
let mutex = unsafe { Pin::new_unchecked(&mut *mutex) };
mutex.with(|state, _irq| {
mutex.with(|state| {
let serial = state.classes.get_serial();
let serial = Pin::new(serial);
@ -77,7 +77,7 @@ where
let this = self.get_mut();
let mut mutex = this.inner.borrow_mut();
let mutex = unsafe { Pin::new_unchecked(&mut *mutex) };
mutex.with(|state, _irq| {
mutex.with(|state| {
let serial = state.classes.get_serial();
let serial = Pin::new(serial);
@ -101,7 +101,7 @@ where
let this = self.get_mut();
let mut mutex = this.inner.borrow_mut();
let mutex = unsafe { Pin::new_unchecked(&mut *mutex) };
mutex.with(|state, _irq| {
mutex.with(|state| {
let serial = state.classes.get_serial();
let serial = Pin::new(serial);

View file

@ -175,8 +175,8 @@ impl<'d, U: UarteInstance, T: TimerInstance> BufferedUarte<'d, U, T> {
pub fn set_baudrate(self: Pin<&mut Self>, baudrate: Baudrate) {
let mut inner = self.inner();
inner.as_mut().register_interrupt();
inner.with(|state, _irq| {
unsafe { inner.as_mut().register_interrupt_unchecked() }
inner.with(|state| {
let r = U::regs();
let timeout = 0x8000_0000 / (baudrate as u32 / 40);
@ -195,8 +195,8 @@ impl<'d, U: UarteInstance, T: TimerInstance> BufferedUarte<'d, U, T> {
impl<'d, U: UarteInstance, T: TimerInstance> AsyncBufRead for BufferedUarte<'d, U, T> {
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<&[u8]>> {
let mut inner = self.inner();
inner.as_mut().register_interrupt();
inner.with(|state, _irq| {
unsafe { inner.as_mut().register_interrupt_unchecked() }
inner.with(|state| {
// Conservative compiler fence to prevent optimizations that do not
// take in to account actions by DMA. The fence has been placed here,
// before any DMA action has started
@ -220,20 +220,20 @@ impl<'d, U: UarteInstance, T: TimerInstance> AsyncBufRead for BufferedUarte<'d,
fn consume(self: Pin<&mut Self>, amt: usize) {
let mut inner = self.inner();
inner.as_mut().register_interrupt();
inner.with(|state, irq| {
unsafe { inner.as_mut().register_interrupt_unchecked() }
inner.as_mut().with(|state| {
trace!("consume {:?}", amt);
state.rx.pop(amt);
irq.pend();
})
});
inner.pend();
}
}
impl<'d, U: UarteInstance, T: TimerInstance> AsyncWrite for BufferedUarte<'d, U, T> {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
let mut inner = self.inner();
inner.as_mut().register_interrupt();
inner.with(|state, irq| {
unsafe { inner.as_mut().register_interrupt_unchecked() }
let poll = inner.as_mut().with(|state| {
trace!("poll_write: {:?}", buf.len());
let tx_buf = state.tx.push_buf();
@ -254,10 +254,12 @@ impl<'d, U: UarteInstance, T: TimerInstance> AsyncWrite for BufferedUarte<'d, U,
// before any DMA action has started
compiler_fence(Ordering::SeqCst);
irq.pend();
Poll::Ready(Ok(n))
})
});
inner.pend();
poll
}
}

View file

@ -29,7 +29,7 @@ pub(crate) mod sealed {
pub trait ExtendedInstance {}
}
pub trait Instance: Unborrow<Target = Self> + sealed::Instance + 'static {
pub trait Instance: Unborrow<Target = Self> + sealed::Instance + 'static + Send {
type Interrupt: Interrupt;
}
pub trait ExtendedInstance: Instance + sealed::ExtendedInstance {}

View file

@ -461,7 +461,7 @@ pub(crate) mod sealed {
}
}
pub trait Instance: Unborrow<Target = Self> + sealed::Instance + 'static {
pub trait Instance: Unborrow<Target = Self> + sealed::Instance + 'static + Send {
type Interrupt: Interrupt;
}

View file

@ -159,9 +159,10 @@ impl<'d, P: PHY, const TX: usize, const RX: usize> Ethernet<'d, P, TX, RX> {
// NOTE(unsafe) We won't move this
let this = unsafe { self.get_unchecked_mut() };
let mut mutex = unsafe { Pin::new_unchecked(&mut this.state) };
mutex.as_mut().register_interrupt();
// SAFETY: The lifetime of `Inner` is only due to `PhantomData`; it isn't actually referencing any data with that lifetime.
unsafe { mutex.as_mut().register_interrupt_unchecked() }
mutex.with(|s, _| {
mutex.with(|s| {
s.desc_ring.init();
fence(Ordering::SeqCst);
@ -237,7 +238,7 @@ impl<'d, P: PHY, const TX: usize, const RX: usize> Device for Pin<&mut Ethernet<
let this = unsafe { self.as_mut().get_unchecked_mut() };
let mutex = unsafe { Pin::new_unchecked(&mut this.state) };
mutex.with(|s, _| s.desc_ring.tx.available())
mutex.with(|s| s.desc_ring.tx.available())
}
fn transmit(&mut self, pkt: PacketBuf) {
@ -245,7 +246,7 @@ impl<'d, P: PHY, const TX: usize, const RX: usize> Device for Pin<&mut Ethernet<
let this = unsafe { self.as_mut().get_unchecked_mut() };
let mutex = unsafe { Pin::new_unchecked(&mut this.state) };
mutex.with(|s, _| unwrap!(s.desc_ring.tx.transmit(pkt)));
mutex.with(|s| unwrap!(s.desc_ring.tx.transmit(pkt)));
}
fn receive(&mut self) -> Option<PacketBuf> {
@ -253,7 +254,7 @@ impl<'d, P: PHY, const TX: usize, const RX: usize> Device for Pin<&mut Ethernet<
let this = unsafe { self.as_mut().get_unchecked_mut() };
let mutex = unsafe { Pin::new_unchecked(&mut this.state) };
mutex.with(|s, _| s.desc_ring.rx.pop_packet())
mutex.with(|s| s.desc_ring.rx.pop_packet())
}
fn register_waker(&mut self, waker: &Waker) {

View file

@ -48,6 +48,11 @@ impl WakerRegistration {
}
}
// SAFETY: `WakerRegistration` effectively contains an `Option<Waker>`,
// which is `Send` and `Sync`.
unsafe impl Send for WakerRegistration {}
unsafe impl Sync for WakerRegistration {}
pub struct AtomicWaker {
waker: AtomicPtr<TaskHeader>,
}