diff --git a/embassy-executor/src/raw/mod.rs b/embassy-executor/src/raw/mod.rs index 6783c4853..e93e60362 100644 --- a/embassy-executor/src/raw/mod.rs +++ b/embassy-executor/src/raw/mod.rs @@ -141,25 +141,14 @@ impl<F: Future + 'static> TaskStorage<F> { /// Once the task has finished running, you may spawn it again. It is allowed to spawn it /// on a different executor. pub fn spawn(&'static self, future: impl FnOnce() -> F) -> SpawnToken<impl Sized> { - if self.spawn_mark_used() { - return unsafe { SpawnToken::<F>::new(self.spawn_initialize(future)) }; + let task = AvailableTask::claim(self); + match task { + Some(task) => { + let task = task.initialize(future); + unsafe { SpawnToken::<F>::new(task) } + } + None => SpawnToken::new_failed(), } - - SpawnToken::<F>::new_failed() - } - - fn spawn_mark_used(&'static self) -> bool { - let state = STATE_SPAWNED | STATE_RUN_QUEUED; - self.raw - .state - .compare_exchange(0, state, Ordering::AcqRel, Ordering::Acquire) - .is_ok() - } - - unsafe fn spawn_initialize(&'static self, future: impl FnOnce() -> F) -> TaskRef { - // Initialize the task - self.future.write(future()); - TaskRef::new(self) } unsafe fn poll(p: TaskRef) { @@ -184,6 +173,27 @@ impl<F: Future + 'static> TaskStorage<F> { unsafe impl<F: Future + 'static> Sync for TaskStorage<F> {} +struct AvailableTask<F: Future + 'static> { + task: &'static TaskStorage<F>, +} + +impl<F: Future + 'static> AvailableTask<F> { + fn claim(task: &'static TaskStorage<F>) -> Option<Self> { + task.raw + .state + .compare_exchange(0, STATE_SPAWNED | STATE_RUN_QUEUED, Ordering::AcqRel, Ordering::Acquire) + .ok() + .map(|_| Self { task }) + } + + fn initialize(self, future: impl FnOnce() -> F) -> TaskRef { + unsafe { + self.task.future.write(future()); + } + TaskRef::new(self.task) + } +} + /// Raw storage that can hold up to N tasks of the same type. /// /// This is essentially a `[TaskStorage<F>; N]`. @@ -207,13 +217,14 @@ impl<F: Future + 'static, const N: usize> TaskPool<F, N> { /// is currently free. If none is free, a "poisoned" SpawnToken is returned, /// which will cause [`Spawner::spawn()`](super::Spawner::spawn) to return the error. pub fn spawn(&'static self, future: impl FnOnce() -> F) -> SpawnToken<impl Sized> { - for task in &self.pool { - if task.spawn_mark_used() { - return unsafe { SpawnToken::<F>::new(task.spawn_initialize(future)) }; + let task = self.pool.iter().find_map(AvailableTask::claim); + match task { + Some(task) => { + let task = task.initialize(future); + unsafe { SpawnToken::<F>::new(task) } } + None => SpawnToken::new_failed(), } - - SpawnToken::<F>::new_failed() } /// Like spawn(), but allows the task to be send-spawned if the args are Send even if @@ -255,13 +266,14 @@ impl<F: Future + 'static, const N: usize> TaskPool<F, N> { // This ONLY holds for `async fn` futures. The other `spawn` methods can be called directly // by the user, with arbitrary hand-implemented futures. This is why these return `SpawnToken<F>`. - for task in &self.pool { - if task.spawn_mark_used() { - return SpawnToken::<FutFn>::new(task.spawn_initialize(future)); + let task = self.pool.iter().find_map(AvailableTask::claim); + match task { + Some(task) => { + let task = task.initialize(future); + unsafe { SpawnToken::<FutFn>::new(task) } } + None => SpawnToken::new_failed(), } - - SpawnToken::<FutFn>::new_failed() } }