From 6f6c16f44924de4d71d0e5e3acc0908f2dd474e6 Mon Sep 17 00:00:00 2001
From: Dario Nieuwenhuis <dirbaio@dirbaio.net>
Date: Wed, 27 Apr 2022 04:27:42 +0200
Subject: [PATCH] executor: make send-spawning only require the task args to be
 Send, not the whole future.

---
 embassy-macros/src/macros/task.rs | 21 +++-------------
 embassy/src/executor/raw/mod.rs   | 41 +++++++++++++++++++++++++------
 embassy/src/executor/spawner.rs   | 20 +++++++++------
 3 files changed, 49 insertions(+), 33 deletions(-)

diff --git a/embassy-macros/src/macros/task.rs b/embassy-macros/src/macros/task.rs
index c450982c9..96932d77c 100644
--- a/embassy-macros/src/macros/task.rs
+++ b/embassy-macros/src/macros/task.rs
@@ -73,23 +73,10 @@ pub fn run(args: syn::AttributeArgs, f: syn::ItemFn) -> Result<TokenStream, Toke
         // in the user's code.
         #task_inner
 
-        #visibility fn #task_ident(#fargs) -> #embassy_path::executor::SpawnToken<impl ::core::future::Future + 'static> {
-            use ::core::future::Future;
-            use #embassy_path::executor::SpawnToken;
-            use #embassy_path::executor::raw::TaskPool;
-
-            type Fut = impl Future + 'static;
-
-            static POOL: TaskPool<Fut, #pool_size> = TaskPool::new();
-
-            // Opaque type laundering, to obscure its origin!
-            // Workaround for "opaque type's hidden type cannot be another opaque type from the same scope"
-            // https://github.com/rust-lang/rust/issues/96406
-            fn launder_tait(token: SpawnToken<impl Future+'static>) -> SpawnToken<impl Future+'static> {
-                token
-            }
-
-            launder_tait(POOL.spawn(move || #task_inner_ident(#(#arg_names,)*)))
+        #visibility fn #task_ident(#fargs) -> #embassy_path::executor::SpawnToken<impl Sized> {
+            type Fut = impl ::core::future::Future + 'static;
+            static POOL: #embassy_path::executor::raw::TaskPool<Fut, #pool_size> = #embassy_path::executor::raw::TaskPool::new();
+            POOL.spawn(move || #task_inner_ident(#(#arg_names,)*))
         }
     };
 
diff --git a/embassy/src/executor/raw/mod.rs b/embassy/src/executor/raw/mod.rs
index 5af35d868..6b14b8e8c 100644
--- a/embassy/src/executor/raw/mod.rs
+++ b/embassy/src/executor/raw/mod.rs
@@ -110,8 +110,8 @@ impl TaskHeader {
 /// Raw storage in which a task can be spawned.
 ///
 /// This struct holds the necessary memory to spawn one task whose future is `F`.
-/// At a given time, the `Task` may be in spawned or not-spawned state. You may spawn it
-/// with [`Task::spawn()`], which will fail if it is already spawned.
+/// At a given time, the `TaskStorage` may be in spawned or not-spawned state. You
+/// may spawn it with [`TaskStorage::spawn()`], which will fail if it is already spawned.
 ///
 /// A `TaskStorage` must live forever, it may not be deallocated even after the task has finished
 /// running. Hence the relevant methods require `&'static self`. It may be reused, however.
@@ -159,11 +159,11 @@ impl<F: Future + 'static> TaskStorage<F> {
     ///
     /// This function will fail if the task is already spawned and has not finished running.
     /// In this case, the error is delayed: a "poisoned" SpawnToken is returned, which will
-    /// cause [`Executor::spawn()`] to return the error.
+    /// cause [`Spawner::spawn()`] to return the error.
     ///
     /// Once the task has finished running, you may spawn it again. It is allowed to spawn it
     /// on a different executor.
-    pub fn spawn(&'static self, future: impl FnOnce() -> F) -> SpawnToken<F> {
+    pub fn spawn(&'static self, future: impl FnOnce() -> F) -> SpawnToken<impl Sized> {
         if self.spawn_allocate() {
             unsafe { self.spawn_initialize(future) }
         } else {
@@ -179,12 +179,37 @@ impl<F: Future + 'static> TaskStorage<F> {
             .is_ok()
     }
 
-    unsafe fn spawn_initialize(&'static self, future: impl FnOnce() -> F) -> SpawnToken<F> {
+    unsafe fn spawn_initialize<FutFn>(&'static self, future: FutFn) -> SpawnToken<impl Sized>
+    where
+        FutFn: FnOnce() -> F,
+    {
         // Initialize the task
         self.raw.poll_fn.write(Self::poll);
         self.future.write(future());
 
-        SpawnToken::new(NonNull::new_unchecked(&self.raw as *const TaskHeader as _))
+        // When send-spawning a task, we construct the future in this thread, and effectively
+        // "send" it to the executor thread by enqueuing it in its queue. Therefore, in theory,
+        // send-spawning should require the future `F` to be `Send`.
+        //
+        // The problem is this is more restrictive than needed. Once the future is executing,
+        // it is never sent to another thread. It is only sent when spawning. It should be
+        // enough for the task's arguments to be Send. (and in practice it's super easy to
+        // accidentally make your futures !Send, for example by holding an `Rc` or a `&RefCell` across an `.await`.)
+        //
+        // We can do it by sending the task args and constructing the future in the executor thread
+        // on first poll. However, this cannot be done in-place, so it'll waste stack space for a copy
+        // of the args.
+        //
+        // Luckily, an `async fn` future contains just the args when freshly constructed. So, if the
+        // args are Send, it's OK to send a !Send future, as long as we do it before first polling it.
+        //
+        // (Note: this is how the generators are implemented today, it's not officially guaranteed yet,
+        // but it's possible it'll be guaranteed in the future. See zulip thread:
+        // https://rust-lang.zulipchat.com/#narrow/stream/187312-wg-async/topic/.22only.20before.20poll.22.20Send.20futures )
+        //
+        // The `FutFn` captures all the args, so if it's Send, the task can be send-spawned.
+        // This is why we return `SpawnToken<FutFn>` below.
+        SpawnToken::<FutFn>::new(NonNull::new_unchecked(&self.raw as *const TaskHeader as _))
     }
 
     unsafe fn poll(p: NonNull<TaskHeader>) {
@@ -232,8 +257,8 @@ impl<F: Future + 'static, const N: usize> TaskPool<F, N> {
     ///
     /// This will loop over the pool and spawn the task in the first storage that
     /// is currently free. If none is free, a "poisoned" SpawnToken is returned,
-    /// which will cause [`Executor::spawn()`] to return the error.
-    pub fn spawn(&'static self, future: impl FnOnce() -> F) -> SpawnToken<F> {
+    /// which will cause [`Spawner::spawn()`] to return the error.
+    pub fn spawn(&'static self, future: impl FnOnce() -> F) -> SpawnToken<impl Sized> {
         for task in &self.pool {
             if task.spawn_allocate() {
                 return unsafe { task.spawn_initialize(future) };
diff --git a/embassy/src/executor/spawner.rs b/embassy/src/executor/spawner.rs
index e6770e299..73c1f786f 100644
--- a/embassy/src/executor/spawner.rs
+++ b/embassy/src/executor/spawner.rs
@@ -12,17 +12,21 @@ use super::raw;
 /// value is a `SpawnToken` that represents an instance of the task, ready to spawn. You must
 /// then spawn it into an executor, typically with [`Spawner::spawn()`].
 ///
+/// The generic parameter `S` determines whether the task can be spawned in executors
+/// in other threads or not. If `S: Send`, it can, which allows spawning it into a [`SendSpawner`].
+/// If not, it can't, so it can only be spawned into the current thread's executor, with [`Spawner`].
+///
 /// # Panics
 ///
 /// Dropping a SpawnToken instance panics. You may not "abort" spawning a task in this way.
 /// Once you've invoked a task function and obtained a SpawnToken, you *must* spawn it.
 #[must_use = "Calling a task function does nothing on its own. You must spawn the returned SpawnToken, typically with Spawner::spawn()"]
-pub struct SpawnToken<F> {
+pub struct SpawnToken<S> {
     raw_task: Option<NonNull<raw::TaskHeader>>,
-    phantom: PhantomData<*mut F>,
+    phantom: PhantomData<*mut S>,
 }
 
-impl<F> SpawnToken<F> {
+impl<S> SpawnToken<S> {
     pub(crate) unsafe fn new(raw_task: NonNull<raw::TaskHeader>) -> Self {
         Self {
             raw_task: Some(raw_task),
@@ -38,7 +42,7 @@ impl<F> SpawnToken<F> {
     }
 }
 
-impl<F> Drop for SpawnToken<F> {
+impl<S> Drop for SpawnToken<S> {
     fn drop(&mut self) {
         // TODO deallocate the task instead.
         panic!("SpawnToken instances may not be dropped. You must pass them to Spawner::spawn()")
@@ -97,7 +101,7 @@ impl Spawner {
     /// Spawn a task into an executor.
     ///
     /// You obtain the `token` by calling a task function (i.e. one marked with `#[embassy::task]`).
-    pub fn spawn<F>(&self, token: SpawnToken<F>) -> Result<(), SpawnError> {
+    pub fn spawn<S>(&self, token: SpawnToken<S>) -> Result<(), SpawnError> {
         let task = token.raw_task;
         mem::forget(token);
 
@@ -119,7 +123,7 @@ impl Spawner {
     /// # Panics
     ///
     /// Panics if the spawning fails.
-    pub fn must_spawn<F>(&self, token: SpawnToken<F>) {
+    pub fn must_spawn<S>(&self, token: SpawnToken<S>) {
         unwrap!(self.spawn(token));
     }
 
@@ -173,7 +177,7 @@ impl SendSpawner {
     /// Spawn a task into an executor.
     ///
     /// You obtain the `token` by calling a task function (i.e. one marked with `#[embassy::task]`).
-    pub fn spawn<F: Send>(&self, token: SpawnToken<F>) -> Result<(), SpawnError> {
+    pub fn spawn<S: Send>(&self, token: SpawnToken<S>) -> Result<(), SpawnError> {
         let header = token.raw_task;
         mem::forget(token);
 
@@ -191,7 +195,7 @@ impl SendSpawner {
     /// # Panics
     ///
     /// Panics if the spawning fails.
-    pub fn must_spawn<F: Send>(&self, token: SpawnToken<F>) {
+    pub fn must_spawn<S: Send>(&self, token: SpawnToken<S>) {
         unwrap!(self.spawn(token));
     }
 }