From 8c2a6df03b852233ef6c774896cbb00c2a15040f Mon Sep 17 00:00:00 2001
From: eZio Pan <eziopan@qq.com>
Date: Thu, 28 Dec 2023 16:23:47 +0800
Subject: [PATCH 1/6] implement PWM waveform generating with DMA

---
 embassy-stm32/build.rs                        |  11 ++
 embassy-stm32/src/timer/mod.rs                |  18 +++-
 embassy-stm32/src/timer/simple_pwm.rs         | 100 +++++++++++++++---
 examples/stm32f4/src/bin/pwm.rs               |  12 ++-
 .../bin/{ws2812_pwm_dma.rs => ws2812_pwm.rs}  |  75 +++----------
 examples/stm32g4/src/bin/pwm.rs               |  12 ++-
 examples/stm32h7/src/bin/pwm.rs               |  13 ++-
 7 files changed, 159 insertions(+), 82 deletions(-)
 rename examples/stm32f4/src/bin/{ws2812_pwm_dma.rs => ws2812_pwm.rs} (50%)

diff --git a/embassy-stm32/build.rs b/embassy-stm32/build.rs
index 058b8a0fc..de03827e9 100644
--- a/embassy-stm32/build.rs
+++ b/embassy-stm32/build.rs
@@ -1008,6 +1008,7 @@ fn main() {
         (("quadspi", "QUADSPI"), quote!(crate::qspi::QuadDma)),
         (("dac", "CH1"), quote!(crate::dac::DacDma1)),
         (("dac", "CH2"), quote!(crate::dac::DacDma2)),
+        (("timer", "UP"), quote!(crate::timer::UpDma)),
     ]
     .into();
 
@@ -1023,6 +1024,16 @@ fn main() {
                 }
 
                 if let Some(tr) = signals.get(&(regs.kind, ch.signal)) {
+                    // TIM6 of stm32f334 is special, DMA channel for TIM6 depending on SYSCFG state
+                    if chip_name.starts_with("stm32f334") && p.name == "TIM6" {
+                        continue;
+                    }
+
+                    // TIM6 of stm32f378 is special, DMA channel for TIM6 depending on SYSCFG state
+                    if chip_name.starts_with("stm32f378") && p.name == "TIM6" {
+                        continue;
+                    }
+
                     let peri = format_ident!("{}", p.name);
 
                     let channel = if let Some(channel) = &ch.channel {
diff --git a/embassy-stm32/src/timer/mod.rs b/embassy-stm32/src/timer/mod.rs
index 74120adad..05a0564a3 100644
--- a/embassy-stm32/src/timer/mod.rs
+++ b/embassy-stm32/src/timer/mod.rs
@@ -91,7 +91,12 @@ pub(crate) mod sealed {
 
         /// Enable/disable the update interrupt.
         fn enable_update_interrupt(&mut self, enable: bool) {
-            Self::regs().dier().write(|r| r.set_uie(enable));
+            Self::regs().dier().modify(|r| r.set_uie(enable));
+        }
+
+        /// Enable/disable the update dma.
+        fn enable_update_dma(&mut self, enable: bool) {
+            Self::regs().dier().modify(|r| r.set_ude(enable));
         }
 
         /// Enable/disable autoreload preload.
@@ -288,6 +293,14 @@ pub(crate) mod sealed {
         fn get_compare_value(&self, channel: Channel) -> u16 {
             Self::regs_gp16().ccr(channel.index()).read().ccr()
         }
+
+        /// Set output compare preload.
+        fn set_output_compare_preload(&mut self, channel: Channel, preload: bool) {
+            let channel_index = channel.index();
+            Self::regs_gp16()
+                .ccmr_output(channel_index / 2)
+                .modify(|w| w.set_ocpe(channel_index % 2, preload));
+        }
     }
 
     /// Capture/Compare 16-bit timer instance with complementary pin support.
@@ -676,3 +689,6 @@ foreach_interrupt! {
         }
     };
 }
+
+// Update Event trigger DMA for every timer
+dma_trait!(UpDma, Basic16bitInstance);
diff --git a/embassy-stm32/src/timer/simple_pwm.rs b/embassy-stm32/src/timer/simple_pwm.rs
index e6072aa15..1819c7c55 100644
--- a/embassy-stm32/src/timer/simple_pwm.rs
+++ b/embassy-stm32/src/timer/simple_pwm.rs
@@ -55,11 +55,12 @@ channel_impl!(new_ch3, Ch3, Channel3Pin);
 channel_impl!(new_ch4, Ch4, Channel4Pin);
 
 /// Simple PWM driver.
-pub struct SimplePwm<'d, T> {
+pub struct SimplePwm<'d, T, Dma> {
     inner: PeripheralRef<'d, T>,
+    dma: PeripheralRef<'d, Dma>,
 }
 
-impl<'d, T: CaptureCompare16bitInstance> SimplePwm<'d, T> {
+impl<'d, T: CaptureCompare16bitInstance, Dma> SimplePwm<'d, T, Dma> {
     /// Create a new simple PWM driver.
     pub fn new(
         tim: impl Peripheral<P = T> + 'd,
@@ -69,16 +70,22 @@ impl<'d, T: CaptureCompare16bitInstance> SimplePwm<'d, T> {
         _ch4: Option<PwmPin<'d, T, Ch4>>,
         freq: Hertz,
         counting_mode: CountingMode,
+        dma: impl Peripheral<P = Dma> + 'd,
     ) -> Self {
-        Self::new_inner(tim, freq, counting_mode)
+        Self::new_inner(tim, freq, counting_mode, dma)
     }
 
-    fn new_inner(tim: impl Peripheral<P = T> + 'd, freq: Hertz, counting_mode: CountingMode) -> Self {
-        into_ref!(tim);
+    fn new_inner(
+        tim: impl Peripheral<P = T> + 'd,
+        freq: Hertz,
+        counting_mode: CountingMode,
+        dma: impl Peripheral<P = Dma> + 'd,
+    ) -> Self {
+        into_ref!(tim, dma);
 
         T::enable_and_reset();
 
-        let mut this = Self { inner: tim };
+        let mut this = Self { inner: tim, dma };
 
         this.inner.set_counting_mode(counting_mode);
         this.set_frequency(freq);
@@ -86,14 +93,13 @@ impl<'d, T: CaptureCompare16bitInstance> SimplePwm<'d, T> {
 
         this.inner.enable_outputs();
 
-        this.inner
-            .set_output_compare_mode(Channel::Ch1, OutputCompareMode::PwmMode1);
-        this.inner
-            .set_output_compare_mode(Channel::Ch2, OutputCompareMode::PwmMode1);
-        this.inner
-            .set_output_compare_mode(Channel::Ch3, OutputCompareMode::PwmMode1);
-        this.inner
-            .set_output_compare_mode(Channel::Ch4, OutputCompareMode::PwmMode1);
+        [Channel::Ch1, Channel::Ch2, Channel::Ch3, Channel::Ch4]
+            .iter()
+            .for_each(|&channel| {
+                this.inner.set_output_compare_mode(channel, OutputCompareMode::PwmMode1);
+                this.inner.set_output_compare_preload(channel, true)
+            });
+
         this
     }
 
@@ -141,7 +147,71 @@ impl<'d, T: CaptureCompare16bitInstance> SimplePwm<'d, T> {
     }
 }
 
-impl<'d, T: CaptureCompare16bitInstance> embedded_hal_02::Pwm for SimplePwm<'d, T> {
+impl<'d, T: CaptureCompare16bitInstance + Basic16bitInstance, Dma> SimplePwm<'d, T, Dma>
+where
+    Dma: super::UpDma<T>,
+{
+    /// Generate a sequence of PWM waveform
+    pub async fn gen_waveform(&mut self, channel: Channel, duty: &[u16]) {
+        duty.iter().all(|v| v.le(&self.get_max_duty()));
+
+        self.inner.enable_update_dma(true);
+
+        #[cfg_attr(any(stm32f334, stm32f378), allow(clippy::let_unit_value))]
+        let req = self.dma.request();
+
+        self.enable(channel);
+
+        #[cfg(not(any(bdma, gpdma)))]
+        let dma_regs = self.dma.regs();
+        #[cfg(not(any(bdma, gpdma)))]
+        let isr_num = self.dma.num() / 4;
+        #[cfg(not(any(bdma, gpdma)))]
+        let isr_bit = self.dma.num() % 4;
+
+        #[cfg(not(any(bdma, gpdma)))]
+        // clean DMA FIFO error before a transfer
+        if dma_regs.isr(isr_num).read().feif(isr_bit) {
+            dma_regs.ifcr(isr_num).write(|v| v.set_feif(isr_bit, true));
+        }
+
+        unsafe {
+            #[cfg(not(any(bdma, gpdma)))]
+            use crate::dma::{Burst, FifoThreshold};
+            use crate::dma::{Transfer, TransferOptions};
+
+            let dma_transfer_option = TransferOptions {
+                #[cfg(not(any(bdma, gpdma)))]
+                fifo_threshold: Some(FifoThreshold::Full),
+                #[cfg(not(any(bdma, gpdma)))]
+                mburst: Burst::Incr8,
+                ..Default::default()
+            };
+
+            Transfer::new_write(
+                &mut self.dma,
+                req,
+                duty,
+                T::regs_gp16().ccr(channel.index()).as_ptr() as *mut _,
+                dma_transfer_option,
+            )
+            .await
+        };
+
+        self.disable(channel);
+
+        self.inner.enable_update_dma(false);
+
+        #[cfg(not(any(bdma, gpdma)))]
+        // Since DMA is closed before timer update event trigger DMA is turn off, it will almost always trigger a DMA FIFO error.
+        // Thus, we will always clean DMA FEIF after each transfer
+        if dma_regs.isr(isr_num).read().feif(isr_bit) {
+            dma_regs.ifcr(isr_num).write(|v| v.set_feif(isr_bit, true));
+        }
+    }
+}
+
+impl<'d, T: CaptureCompare16bitInstance, Dma> embedded_hal_02::Pwm for SimplePwm<'d, T, Dma> {
     type Channel = Channel;
     type Time = Hertz;
     type Duty = u16;
diff --git a/examples/stm32f4/src/bin/pwm.rs b/examples/stm32f4/src/bin/pwm.rs
index 8844a9f0e..92bc42ec8 100644
--- a/examples/stm32f4/src/bin/pwm.rs
+++ b/examples/stm32f4/src/bin/pwm.rs
@@ -3,6 +3,7 @@
 
 use defmt::*;
 use embassy_executor::Spawner;
+use embassy_stm32::dma;
 use embassy_stm32::gpio::OutputType;
 use embassy_stm32::time::khz;
 use embassy_stm32::timer::simple_pwm::{PwmPin, SimplePwm};
@@ -16,7 +17,16 @@ async fn main(_spawner: Spawner) {
     info!("Hello World!");
 
     let ch1 = PwmPin::new_ch1(p.PE9, OutputType::PushPull);
-    let mut pwm = SimplePwm::new(p.TIM1, Some(ch1), None, None, None, khz(10), Default::default());
+    let mut pwm = SimplePwm::new(
+        p.TIM1,
+        Some(ch1),
+        None,
+        None,
+        None,
+        khz(10),
+        Default::default(),
+        dma::NoDma,
+    );
     let max = pwm.get_max_duty();
     pwm.enable(Channel::Ch1);
 
diff --git a/examples/stm32f4/src/bin/ws2812_pwm_dma.rs b/examples/stm32f4/src/bin/ws2812_pwm.rs
similarity index 50%
rename from examples/stm32f4/src/bin/ws2812_pwm_dma.rs
rename to examples/stm32f4/src/bin/ws2812_pwm.rs
index 4458b643f..973743e49 100644
--- a/examples/stm32f4/src/bin/ws2812_pwm_dma.rs
+++ b/examples/stm32f4/src/bin/ws2812_pwm.rs
@@ -2,15 +2,9 @@
 // We assume the DIN pin of ws2812 connect to GPIO PB4, and ws2812 is properly powered.
 //
 // The idea is that the data rate of ws2812 is 800 kHz, and it use different duty ratio to represent bit 0 and bit 1.
-// Thus we can set TIM overflow at 800 kHz, and let TIM Update Event trigger a DMA transfer, then let DMA change CCR value,
-// such that pwm duty ratio meet the bit representation of ws2812.
+// Thus we can set TIM overflow at 800 kHz, and change duty ratio of TIM to meet the bit representation of ws2812.
 //
-// You may want to modify TIM CCR with Cortex core directly,
-// but according to my test, Cortex core will need to run far more than 100 MHz to catch up with TIM.
-// Thus we need to use a DMA.
-//
-// This demo is a combination of HAL, PAC, and manually invoke `dma::Transfer`.
-// If you need a simpler way to control ws2812, you may want to take a look at `ws2812_spi.rs` file, which make use of SPI.
+// you may also want to take a look at `ws2812_spi.rs` file, which make use of SPI instead.
 //
 // Warning:
 // DO NOT stare at ws2812 directy (especially after each MCU Reset), its (max) brightness could easily make your eyes feel burn.
@@ -20,7 +14,6 @@
 
 use embassy_executor::Spawner;
 use embassy_stm32::gpio::OutputType;
-use embassy_stm32::pac;
 use embassy_stm32::time::khz;
 use embassy_stm32::timer::simple_pwm::{PwmPin, SimplePwm};
 use embassy_stm32::timer::{Channel, CountingMode};
@@ -52,7 +45,7 @@ async fn main(_spawner: Spawner) {
         device_config.rcc.sys = Sysclk::PLL1_P;
     }
 
-    let mut dp = embassy_stm32::init(device_config);
+    let dp = embassy_stm32::init(device_config);
 
     let mut ws2812_pwm = SimplePwm::new(
         dp.TIM3,
@@ -62,6 +55,7 @@ async fn main(_spawner: Spawner) {
         None,
         khz(800), // data rate of ws2812
         CountingMode::EdgeAlignedUp,
+        dp.DMA1_CH2,
     );
 
     // construct ws2812 non-return-to-zero (NRZ) code bit by bit
@@ -89,62 +83,19 @@ async fn main(_spawner: Spawner) {
 
     let pwm_channel = Channel::Ch1;
 
-    // PAC level hacking, enable output compare preload
-    // keep output waveform integrity
-    pac::TIM3
-        .ccmr_output(pwm_channel.index())
-        .modify(|v| v.set_ocpe(0, true));
-
     // make sure PWM output keep low on first start
     ws2812_pwm.set_duty(pwm_channel, 0);
 
-    {
-        use embassy_stm32::dma::{Burst, FifoThreshold, Transfer, TransferOptions};
+    // flip color at 2 Hz
+    let mut ticker = Ticker::every(Duration::from_millis(500));
 
-        // configure FIFO and MBURST of DMA, to minimize DMA occupation on AHB/APB
-        let mut dma_transfer_option = TransferOptions::default();
-        dma_transfer_option.fifo_threshold = Some(FifoThreshold::Full);
-        dma_transfer_option.mburst = Burst::Incr8;
-
-        // flip color at 2 Hz
-        let mut ticker = Ticker::every(Duration::from_millis(500));
-
-        loop {
-            for &color in color_list {
-                // start PWM output
-                ws2812_pwm.enable(pwm_channel);
-
-                // PAC level hacking, enable timer-update-event trigger DMA
-                pac::TIM3.dier().modify(|v| v.set_ude(true));
-
-                unsafe {
-                    Transfer::new_write(
-                        // with &mut, we can easily reuse same DMA channel multiple times
-                        &mut dp.DMA1_CH2,
-                        5,
-                        color,
-                        pac::TIM3.ccr(pwm_channel.index()).as_ptr() as *mut _,
-                        dma_transfer_option,
-                    )
-                    .await;
-
-                    // Turn off timer-update-event trigger DMA as soon as possible.
-                    // Then clean the FIFO Error Flag if set.
-                    pac::TIM3.dier().modify(|v| v.set_ude(false));
-                    if pac::DMA1.isr(0).read().feif(2) {
-                        pac::DMA1.ifcr(0).write(|v| v.set_feif(2, true));
-                    }
-
-                    // ws2812 need at least 50 us low level input to confirm the input data and change it's state
-                    Timer::after_micros(50).await;
-                }
-
-                // stop PWM output for saving some energy
-                ws2812_pwm.disable(pwm_channel);
-
-                // wait until ticker tick
-                ticker.next().await;
-            }
+    loop {
+        for &color in color_list {
+            ws2812_pwm.gen_waveform(Channel::Ch1, color).await;
+            // ws2812 need at least 50 us low level input to confirm the input data and change it's state
+            Timer::after_micros(50).await;
+            // wait until ticker tick
+            ticker.next().await;
         }
     }
 }
diff --git a/examples/stm32g4/src/bin/pwm.rs b/examples/stm32g4/src/bin/pwm.rs
index d4809a481..9fa004c3e 100644
--- a/examples/stm32g4/src/bin/pwm.rs
+++ b/examples/stm32g4/src/bin/pwm.rs
@@ -3,6 +3,7 @@
 
 use defmt::*;
 use embassy_executor::Spawner;
+use embassy_stm32::dma;
 use embassy_stm32::gpio::OutputType;
 use embassy_stm32::time::khz;
 use embassy_stm32::timer::simple_pwm::{PwmPin, SimplePwm};
@@ -16,7 +17,16 @@ async fn main(_spawner: Spawner) {
     info!("Hello World!");
 
     let ch1 = PwmPin::new_ch1(p.PC0, OutputType::PushPull);
-    let mut pwm = SimplePwm::new(p.TIM1, Some(ch1), None, None, None, khz(10), Default::default());
+    let mut pwm = SimplePwm::new(
+        p.TIM1,
+        Some(ch1),
+        None,
+        None,
+        None,
+        khz(10),
+        Default::default(),
+        dma::NoDma,
+    );
     let max = pwm.get_max_duty();
     pwm.enable(Channel::Ch1);
 
diff --git a/examples/stm32h7/src/bin/pwm.rs b/examples/stm32h7/src/bin/pwm.rs
index 1e48ba67b..de155fc94 100644
--- a/examples/stm32h7/src/bin/pwm.rs
+++ b/examples/stm32h7/src/bin/pwm.rs
@@ -7,7 +7,7 @@ use embassy_stm32::gpio::OutputType;
 use embassy_stm32::time::khz;
 use embassy_stm32::timer::simple_pwm::{PwmPin, SimplePwm};
 use embassy_stm32::timer::Channel;
-use embassy_stm32::Config;
+use embassy_stm32::{dma, Config};
 use embassy_time::Timer;
 use {defmt_rtt as _, panic_probe as _};
 
@@ -38,7 +38,16 @@ async fn main(_spawner: Spawner) {
     info!("Hello World!");
 
     let ch1 = PwmPin::new_ch1(p.PA6, OutputType::PushPull);
-    let mut pwm = SimplePwm::new(p.TIM3, Some(ch1), None, None, None, khz(10), Default::default());
+    let mut pwm = SimplePwm::new(
+        p.TIM3,
+        Some(ch1),
+        None,
+        None,
+        None,
+        khz(10),
+        Default::default(),
+        dma::NoDma,
+    );
     let max = pwm.get_max_duty();
     pwm.enable(Channel::Ch1);
 

From 24f569821c7812714070a1ea3692b87100fc53e1 Mon Sep 17 00:00:00 2001
From: eZio Pan <eziopan@qq.com>
Date: Fri, 29 Dec 2023 09:22:08 +0800
Subject: [PATCH 2/6] record&restore TIM OC to it's earlier state

---
 embassy-stm32/src/timer/mod.rs        | 10 +++++
 embassy-stm32/src/timer/simple_pwm.rs | 65 ++++++++++++++++++++-------
 2 files changed, 60 insertions(+), 15 deletions(-)

diff --git a/embassy-stm32/src/timer/mod.rs b/embassy-stm32/src/timer/mod.rs
index 05a0564a3..389666c40 100644
--- a/embassy-stm32/src/timer/mod.rs
+++ b/embassy-stm32/src/timer/mod.rs
@@ -99,6 +99,11 @@ pub(crate) mod sealed {
             Self::regs().dier().modify(|r| r.set_ude(enable));
         }
 
+        /// Get the update dma enable/disable state.
+        fn get_update_dma_state(&self) -> bool {
+            Self::regs().dier().read().ude()
+        }
+
         /// Enable/disable autoreload preload.
         fn set_autoreload_preload(&mut self, enable: bool) {
             Self::regs().cr1().modify(|r| r.set_arpe(enable));
@@ -274,6 +279,11 @@ pub(crate) mod sealed {
             Self::regs_gp16().ccer().modify(|w| w.set_cce(channel.index(), enable));
         }
 
+        /// Get enable/disable state of a channel
+        fn get_channel_enable_state(&self, channel: Channel) -> bool {
+            Self::regs_gp16().ccer().read().cce(channel.index())
+        }
+
         /// Set compare value for a channel.
         fn set_compare_value(&mut self, channel: Channel, value: u16) {
             Self::regs_gp16().ccr(channel.index()).modify(|w| w.set_ccr(value));
diff --git a/embassy-stm32/src/timer/simple_pwm.rs b/embassy-stm32/src/timer/simple_pwm.rs
index 1819c7c55..acf0d12f9 100644
--- a/embassy-stm32/src/timer/simple_pwm.rs
+++ b/embassy-stm32/src/timer/simple_pwm.rs
@@ -62,6 +62,12 @@ pub struct SimplePwm<'d, T, Dma> {
 
 impl<'d, T: CaptureCompare16bitInstance, Dma> SimplePwm<'d, T, Dma> {
     /// Create a new simple PWM driver.
+    ///
+    /// Note:  
+    /// If you want to use [`Self::gen_waveform()`], you need to provide corresponding TIMx_UP DMA channel.
+    /// Otherwise you can just put a [`dma::NoDma`](crate::dma::NoDma)  
+    /// Currently, you can only use one channel at a time to generate waveform with [`Self::gen_waveform()`].  
+    /// But you can always use multiple TIM to generate multiple waveform simultaneously.
     pub fn new(
         tim: impl Peripheral<P = T> + 'd,
         _ch1: Option<PwmPin<'d, T, Ch1>>,
@@ -113,6 +119,11 @@ impl<'d, T: CaptureCompare16bitInstance, Dma> SimplePwm<'d, T, Dma> {
         self.inner.enable_channel(channel, false);
     }
 
+    /// Check whether given channel is enabled
+    pub fn is_enabled(&self, channel: Channel) -> bool {
+        self.inner.get_channel_enable_state(channel)
+    }
+
     /// Set PWM frequency.
     ///
     /// Note: when you call this, the max duty value changes, so you will have to
@@ -141,6 +152,13 @@ impl<'d, T: CaptureCompare16bitInstance, Dma> SimplePwm<'d, T, Dma> {
         self.inner.set_compare_value(channel, duty)
     }
 
+    /// Get the duty for a given channel.
+    ///
+    /// The value ranges from 0 for 0% duty, to [`get_max_duty`](Self::get_max_duty) for 100% duty, both included.
+    pub fn get_duty(&self, channel: Channel) -> u16 {
+        self.inner.get_compare_value(channel)
+    }
+
     /// Set the output polarity for a given channel.
     pub fn set_polarity(&mut self, channel: Channel, polarity: OutputPolarity) {
         self.inner.set_output_polarity(channel, polarity);
@@ -153,26 +171,38 @@ where
 {
     /// Generate a sequence of PWM waveform
     pub async fn gen_waveform(&mut self, channel: Channel, duty: &[u16]) {
-        duty.iter().all(|v| v.le(&self.get_max_duty()));
-
-        self.inner.enable_update_dma(true);
+        assert!(duty.iter().all(|v| *v <= self.get_max_duty()));
 
         #[cfg_attr(any(stm32f334, stm32f378), allow(clippy::let_unit_value))]
         let req = self.dma.request();
 
-        self.enable(channel);
-
         #[cfg(not(any(bdma, gpdma)))]
         let dma_regs = self.dma.regs();
         #[cfg(not(any(bdma, gpdma)))]
         let isr_num = self.dma.num() / 4;
         #[cfg(not(any(bdma, gpdma)))]
         let isr_bit = self.dma.num() % 4;
+        #[cfg(not(any(bdma, gpdma)))]
+        let isr_reg = dma_regs.isr(isr_num);
+        #[cfg(not(any(bdma, gpdma)))]
+        let ifcr_reg = dma_regs.ifcr(isr_num);
 
         #[cfg(not(any(bdma, gpdma)))]
         // clean DMA FIFO error before a transfer
-        if dma_regs.isr(isr_num).read().feif(isr_bit) {
-            dma_regs.ifcr(isr_num).write(|v| v.set_feif(isr_bit, true));
+        if isr_reg.read().feif(isr_bit) {
+            ifcr_reg.write(|v| v.set_feif(isr_bit, true));
+        }
+
+        let original_duty_state = self.get_duty(channel);
+        let original_enable_state = self.is_enabled(channel);
+        let original_update_dma_state = self.inner.get_update_dma_state();
+
+        if !original_update_dma_state {
+            self.inner.enable_update_dma(true);
+        }
+
+        if !original_enable_state {
+            self.enable(channel);
         }
 
         unsafe {
@@ -198,15 +228,20 @@ where
             .await
         };
 
-        self.disable(channel);
+        // restore output compare state
+        if !original_enable_state {
+            self.disable(channel);
+        }
+        self.set_duty(channel, original_duty_state);
+        if !original_update_dma_state {
+            self.inner.enable_update_dma(false);
 
-        self.inner.enable_update_dma(false);
-
-        #[cfg(not(any(bdma, gpdma)))]
-        // Since DMA is closed before timer update event trigger DMA is turn off, it will almost always trigger a DMA FIFO error.
-        // Thus, we will always clean DMA FEIF after each transfer
-        if dma_regs.isr(isr_num).read().feif(isr_bit) {
-            dma_regs.ifcr(isr_num).write(|v| v.set_feif(isr_bit, true));
+            #[cfg(not(any(bdma, gpdma)))]
+            // Since DMA could be closed before timer update event trigger DMA is turn off, this can almost always trigger a DMA FIFO error.
+            // Thus, we will try clean DMA FEIF after each transfer
+            if isr_reg.read().feif(isr_bit) {
+                ifcr_reg.write(|v| v.set_feif(isr_bit, true));
+            }
         }
     }
 }

From 873ee0615147b4a4e0aacd069ce8ac8df611bbbf Mon Sep 17 00:00:00 2001
From: eZio Pan <eziopan@qq.com>
Date: Sat, 30 Dec 2023 12:01:08 +0800
Subject: [PATCH 3/6] some trivial fix

use less #[cfg] macro; reuse same variable
---
 embassy-stm32/src/timer/simple_pwm.rs  | 17 ++++++++---------
 examples/stm32f4/src/bin/ws2812_pwm.rs |  2 +-
 2 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/embassy-stm32/src/timer/simple_pwm.rs b/embassy-stm32/src/timer/simple_pwm.rs
index acf0d12f9..7a5475c31 100644
--- a/embassy-stm32/src/timer/simple_pwm.rs
+++ b/embassy-stm32/src/timer/simple_pwm.rs
@@ -177,15 +177,14 @@ where
         let req = self.dma.request();
 
         #[cfg(not(any(bdma, gpdma)))]
-        let dma_regs = self.dma.regs();
-        #[cfg(not(any(bdma, gpdma)))]
-        let isr_num = self.dma.num() / 4;
-        #[cfg(not(any(bdma, gpdma)))]
-        let isr_bit = self.dma.num() % 4;
-        #[cfg(not(any(bdma, gpdma)))]
-        let isr_reg = dma_regs.isr(isr_num);
-        #[cfg(not(any(bdma, gpdma)))]
-        let ifcr_reg = dma_regs.ifcr(isr_num);
+        let (isr_bit, isr_reg, ifcr_reg) = {
+            let dma_regs = self.dma.regs();
+            let isr_num = self.dma.num() / 4;
+            let isr_bit = self.dma.num() % 4;
+            let isr_reg = dma_regs.isr(isr_num);
+            let ifcr_reg = dma_regs.ifcr(isr_num);
+            (isr_bit, isr_reg, ifcr_reg)
+        };
 
         #[cfg(not(any(bdma, gpdma)))]
         // clean DMA FIFO error before a transfer
diff --git a/examples/stm32f4/src/bin/ws2812_pwm.rs b/examples/stm32f4/src/bin/ws2812_pwm.rs
index 973743e49..93a89f16a 100644
--- a/examples/stm32f4/src/bin/ws2812_pwm.rs
+++ b/examples/stm32f4/src/bin/ws2812_pwm.rs
@@ -91,7 +91,7 @@ async fn main(_spawner: Spawner) {
 
     loop {
         for &color in color_list {
-            ws2812_pwm.gen_waveform(Channel::Ch1, color).await;
+            ws2812_pwm.gen_waveform(pwm_channel, color).await;
             // ws2812 need at least 50 us low level input to confirm the input data and change it's state
             Timer::after_micros(50).await;
             // wait until ticker tick

From c276da5fcb93ce20da0c2f3bfccdeb7e0fee67a7 Mon Sep 17 00:00:00 2001
From: eZio Pan <eziopan@qq.com>
Date: Tue, 2 Jan 2024 13:30:13 +0800
Subject: [PATCH 4/6] ask a DMA Channel only when use .gen_waveform()

---
 embassy-stm32/src/timer/simple_pwm.rs  | 78 +++++++++-----------------
 examples/stm32f4/src/bin/pwm.rs        | 12 +---
 examples/stm32f4/src/bin/ws2812_pwm.rs |  6 +-
 examples/stm32g4/src/bin/pwm.rs        | 12 +---
 examples/stm32h7/src/bin/pwm.rs        | 13 +----
 5 files changed, 35 insertions(+), 86 deletions(-)

diff --git a/embassy-stm32/src/timer/simple_pwm.rs b/embassy-stm32/src/timer/simple_pwm.rs
index 7a5475c31..77d902e35 100644
--- a/embassy-stm32/src/timer/simple_pwm.rs
+++ b/embassy-stm32/src/timer/simple_pwm.rs
@@ -55,19 +55,12 @@ channel_impl!(new_ch3, Ch3, Channel3Pin);
 channel_impl!(new_ch4, Ch4, Channel4Pin);
 
 /// Simple PWM driver.
-pub struct SimplePwm<'d, T, Dma> {
+pub struct SimplePwm<'d, T> {
     inner: PeripheralRef<'d, T>,
-    dma: PeripheralRef<'d, Dma>,
 }
 
-impl<'d, T: CaptureCompare16bitInstance, Dma> SimplePwm<'d, T, Dma> {
+impl<'d, T: CaptureCompare16bitInstance> SimplePwm<'d, T> {
     /// Create a new simple PWM driver.
-    ///
-    /// Note:  
-    /// If you want to use [`Self::gen_waveform()`], you need to provide corresponding TIMx_UP DMA channel.
-    /// Otherwise you can just put a [`dma::NoDma`](crate::dma::NoDma)  
-    /// Currently, you can only use one channel at a time to generate waveform with [`Self::gen_waveform()`].  
-    /// But you can always use multiple TIM to generate multiple waveform simultaneously.
     pub fn new(
         tim: impl Peripheral<P = T> + 'd,
         _ch1: Option<PwmPin<'d, T, Ch1>>,
@@ -76,22 +69,16 @@ impl<'d, T: CaptureCompare16bitInstance, Dma> SimplePwm<'d, T, Dma> {
         _ch4: Option<PwmPin<'d, T, Ch4>>,
         freq: Hertz,
         counting_mode: CountingMode,
-        dma: impl Peripheral<P = Dma> + 'd,
     ) -> Self {
-        Self::new_inner(tim, freq, counting_mode, dma)
+        Self::new_inner(tim, freq, counting_mode)
     }
 
-    fn new_inner(
-        tim: impl Peripheral<P = T> + 'd,
-        freq: Hertz,
-        counting_mode: CountingMode,
-        dma: impl Peripheral<P = Dma> + 'd,
-    ) -> Self {
-        into_ref!(tim, dma);
+    fn new_inner(tim: impl Peripheral<P = T> + 'd, freq: Hertz, counting_mode: CountingMode) -> Self {
+        into_ref!(tim);
 
         T::enable_and_reset();
 
-        let mut this = Self { inner: tim, dma };
+        let mut this = Self { inner: tim };
 
         this.inner.set_counting_mode(counting_mode);
         this.set_frequency(freq);
@@ -165,32 +152,23 @@ impl<'d, T: CaptureCompare16bitInstance, Dma> SimplePwm<'d, T, Dma> {
     }
 }
 
-impl<'d, T: CaptureCompare16bitInstance + Basic16bitInstance, Dma> SimplePwm<'d, T, Dma>
-where
-    Dma: super::UpDma<T>,
-{
+impl<'d, T: CaptureCompare16bitInstance + Basic16bitInstance> SimplePwm<'d, T> {
     /// Generate a sequence of PWM waveform
-    pub async fn gen_waveform(&mut self, channel: Channel, duty: &[u16]) {
+    ///
+    /// Note:  
+    /// you will need to provide corresponding TIMx_UP DMA channel to use this method.
+    pub async fn gen_waveform(
+        &mut self,
+        dma: impl Peripheral<P = impl super::UpDma<T>>,
+        channel: Channel,
+        duty: &[u16],
+    ) {
         assert!(duty.iter().all(|v| *v <= self.get_max_duty()));
 
-        #[cfg_attr(any(stm32f334, stm32f378), allow(clippy::let_unit_value))]
-        let req = self.dma.request();
+        into_ref!(dma);
 
-        #[cfg(not(any(bdma, gpdma)))]
-        let (isr_bit, isr_reg, ifcr_reg) = {
-            let dma_regs = self.dma.regs();
-            let isr_num = self.dma.num() / 4;
-            let isr_bit = self.dma.num() % 4;
-            let isr_reg = dma_regs.isr(isr_num);
-            let ifcr_reg = dma_regs.ifcr(isr_num);
-            (isr_bit, isr_reg, ifcr_reg)
-        };
-
-        #[cfg(not(any(bdma, gpdma)))]
-        // clean DMA FIFO error before a transfer
-        if isr_reg.read().feif(isr_bit) {
-            ifcr_reg.write(|v| v.set_feif(isr_bit, true));
-        }
+        #[allow(clippy::let_unit_value)] // eg. stm32f334
+        let req = dma.request();
 
         let original_duty_state = self.get_duty(channel);
         let original_enable_state = self.is_enabled(channel);
@@ -218,7 +196,7 @@ where
             };
 
             Transfer::new_write(
-                &mut self.dma,
+                &mut dma,
                 req,
                 duty,
                 T::regs_gp16().ccr(channel.index()).as_ptr() as *mut _,
@@ -231,21 +209,21 @@ where
         if !original_enable_state {
             self.disable(channel);
         }
+
         self.set_duty(channel, original_duty_state);
+
+        // Since DMA is closed before timer update event trigger DMA is turn off,
+        // this can almost always trigger a DMA FIFO error.
+        //
+        // optional TODO:
+        // clean FEIF after disable UDE
         if !original_update_dma_state {
             self.inner.enable_update_dma(false);
-
-            #[cfg(not(any(bdma, gpdma)))]
-            // Since DMA could be closed before timer update event trigger DMA is turn off, this can almost always trigger a DMA FIFO error.
-            // Thus, we will try clean DMA FEIF after each transfer
-            if isr_reg.read().feif(isr_bit) {
-                ifcr_reg.write(|v| v.set_feif(isr_bit, true));
-            }
         }
     }
 }
 
-impl<'d, T: CaptureCompare16bitInstance, Dma> embedded_hal_02::Pwm for SimplePwm<'d, T, Dma> {
+impl<'d, T: CaptureCompare16bitInstance> embedded_hal_02::Pwm for SimplePwm<'d, T> {
     type Channel = Channel;
     type Time = Hertz;
     type Duty = u16;
diff --git a/examples/stm32f4/src/bin/pwm.rs b/examples/stm32f4/src/bin/pwm.rs
index 92bc42ec8..8844a9f0e 100644
--- a/examples/stm32f4/src/bin/pwm.rs
+++ b/examples/stm32f4/src/bin/pwm.rs
@@ -3,7 +3,6 @@
 
 use defmt::*;
 use embassy_executor::Spawner;
-use embassy_stm32::dma;
 use embassy_stm32::gpio::OutputType;
 use embassy_stm32::time::khz;
 use embassy_stm32::timer::simple_pwm::{PwmPin, SimplePwm};
@@ -17,16 +16,7 @@ async fn main(_spawner: Spawner) {
     info!("Hello World!");
 
     let ch1 = PwmPin::new_ch1(p.PE9, OutputType::PushPull);
-    let mut pwm = SimplePwm::new(
-        p.TIM1,
-        Some(ch1),
-        None,
-        None,
-        None,
-        khz(10),
-        Default::default(),
-        dma::NoDma,
-    );
+    let mut pwm = SimplePwm::new(p.TIM1, Some(ch1), None, None, None, khz(10), Default::default());
     let max = pwm.get_max_duty();
     pwm.enable(Channel::Ch1);
 
diff --git a/examples/stm32f4/src/bin/ws2812_pwm.rs b/examples/stm32f4/src/bin/ws2812_pwm.rs
index 93a89f16a..239709253 100644
--- a/examples/stm32f4/src/bin/ws2812_pwm.rs
+++ b/examples/stm32f4/src/bin/ws2812_pwm.rs
@@ -45,7 +45,7 @@ async fn main(_spawner: Spawner) {
         device_config.rcc.sys = Sysclk::PLL1_P;
     }
 
-    let dp = embassy_stm32::init(device_config);
+    let mut dp = embassy_stm32::init(device_config);
 
     let mut ws2812_pwm = SimplePwm::new(
         dp.TIM3,
@@ -55,7 +55,6 @@ async fn main(_spawner: Spawner) {
         None,
         khz(800), // data rate of ws2812
         CountingMode::EdgeAlignedUp,
-        dp.DMA1_CH2,
     );
 
     // construct ws2812 non-return-to-zero (NRZ) code bit by bit
@@ -91,7 +90,8 @@ async fn main(_spawner: Spawner) {
 
     loop {
         for &color in color_list {
-            ws2812_pwm.gen_waveform(pwm_channel, color).await;
+            // with &mut, we can easily reuse same DMA channel multiple times
+            ws2812_pwm.gen_waveform(&mut dp.DMA1_CH2, pwm_channel, color).await;
             // ws2812 need at least 50 us low level input to confirm the input data and change it's state
             Timer::after_micros(50).await;
             // wait until ticker tick
diff --git a/examples/stm32g4/src/bin/pwm.rs b/examples/stm32g4/src/bin/pwm.rs
index 9fa004c3e..d4809a481 100644
--- a/examples/stm32g4/src/bin/pwm.rs
+++ b/examples/stm32g4/src/bin/pwm.rs
@@ -3,7 +3,6 @@
 
 use defmt::*;
 use embassy_executor::Spawner;
-use embassy_stm32::dma;
 use embassy_stm32::gpio::OutputType;
 use embassy_stm32::time::khz;
 use embassy_stm32::timer::simple_pwm::{PwmPin, SimplePwm};
@@ -17,16 +16,7 @@ async fn main(_spawner: Spawner) {
     info!("Hello World!");
 
     let ch1 = PwmPin::new_ch1(p.PC0, OutputType::PushPull);
-    let mut pwm = SimplePwm::new(
-        p.TIM1,
-        Some(ch1),
-        None,
-        None,
-        None,
-        khz(10),
-        Default::default(),
-        dma::NoDma,
-    );
+    let mut pwm = SimplePwm::new(p.TIM1, Some(ch1), None, None, None, khz(10), Default::default());
     let max = pwm.get_max_duty();
     pwm.enable(Channel::Ch1);
 
diff --git a/examples/stm32h7/src/bin/pwm.rs b/examples/stm32h7/src/bin/pwm.rs
index de155fc94..1e48ba67b 100644
--- a/examples/stm32h7/src/bin/pwm.rs
+++ b/examples/stm32h7/src/bin/pwm.rs
@@ -7,7 +7,7 @@ use embassy_stm32::gpio::OutputType;
 use embassy_stm32::time::khz;
 use embassy_stm32::timer::simple_pwm::{PwmPin, SimplePwm};
 use embassy_stm32::timer::Channel;
-use embassy_stm32::{dma, Config};
+use embassy_stm32::Config;
 use embassy_time::Timer;
 use {defmt_rtt as _, panic_probe as _};
 
@@ -38,16 +38,7 @@ async fn main(_spawner: Spawner) {
     info!("Hello World!");
 
     let ch1 = PwmPin::new_ch1(p.PA6, OutputType::PushPull);
-    let mut pwm = SimplePwm::new(
-        p.TIM3,
-        Some(ch1),
-        None,
-        None,
-        None,
-        khz(10),
-        Default::default(),
-        dma::NoDma,
-    );
+    let mut pwm = SimplePwm::new(p.TIM3, Some(ch1), None, None, None, khz(10), Default::default());
     let max = pwm.get_max_duty();
     pwm.enable(Channel::Ch1);
 

From cad4efe57f9817b9368bb431dd12f18d05030c9f Mon Sep 17 00:00:00 2001
From: Dario Nieuwenhuis <dirbaio@dirbaio.net>
Date: Tue, 2 Jan 2024 17:28:08 +0100
Subject: [PATCH 5/6] stm32/timer: add missing supertrait bounds.

---
 embassy-stm32/src/timer/mod.rs | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/embassy-stm32/src/timer/mod.rs b/embassy-stm32/src/timer/mod.rs
index 389666c40..d07fd2776 100644
--- a/embassy-stm32/src/timer/mod.rs
+++ b/embassy-stm32/src/timer/mod.rs
@@ -558,13 +558,16 @@ impl From<OutputPolarity> for bool {
 pub trait Basic16bitInstance: sealed::Basic16bitInstance + 'static {}
 
 /// Gneral-purpose 16-bit timer instance.
-pub trait GeneralPurpose16bitInstance: sealed::GeneralPurpose16bitInstance + 'static {}
+pub trait GeneralPurpose16bitInstance: sealed::GeneralPurpose16bitInstance + Basic16bitInstance + 'static {}
 
 /// Gneral-purpose 32-bit timer instance.
-pub trait GeneralPurpose32bitInstance: sealed::GeneralPurpose32bitInstance + 'static {}
+pub trait GeneralPurpose32bitInstance:
+    sealed::GeneralPurpose32bitInstance + GeneralPurpose16bitInstance + 'static
+{
+}
 
 /// Advanced control timer instance.
-pub trait AdvancedControlInstance: sealed::AdvancedControlInstance + 'static {}
+pub trait AdvancedControlInstance: sealed::AdvancedControlInstance + GeneralPurpose16bitInstance + 'static {}
 
 /// Capture/Compare 16-bit timer instance.
 pub trait CaptureCompare16bitInstance:
@@ -574,7 +577,7 @@ pub trait CaptureCompare16bitInstance:
 
 /// Capture/Compare 16-bit timer instance with complementary pin support.
 pub trait ComplementaryCaptureCompare16bitInstance:
-    sealed::ComplementaryCaptureCompare16bitInstance + AdvancedControlInstance + 'static
+    sealed::ComplementaryCaptureCompare16bitInstance + CaptureCompare16bitInstance + AdvancedControlInstance + 'static
 {
 }
 

From 638aa313d4e5e80649f7f6201fef3154e5b2bbd5 Mon Sep 17 00:00:00 2001
From: Dario Nieuwenhuis <dirbaio@dirbaio.net>
Date: Tue, 2 Jan 2024 17:28:23 +0100
Subject: [PATCH 6/6] stm32/pwm: simplify impl blocks.

---
 embassy-stm32/src/timer/simple_pwm.rs | 2 --
 1 file changed, 2 deletions(-)

diff --git a/embassy-stm32/src/timer/simple_pwm.rs b/embassy-stm32/src/timer/simple_pwm.rs
index 77d902e35..80f10424c 100644
--- a/embassy-stm32/src/timer/simple_pwm.rs
+++ b/embassy-stm32/src/timer/simple_pwm.rs
@@ -150,9 +150,7 @@ impl<'d, T: CaptureCompare16bitInstance> SimplePwm<'d, T> {
     pub fn set_polarity(&mut self, channel: Channel, polarity: OutputPolarity) {
         self.inner.set_output_polarity(channel, polarity);
     }
-}
 
-impl<'d, T: CaptureCompare16bitInstance + Basic16bitInstance> SimplePwm<'d, T> {
     /// Generate a sequence of PWM waveform
     ///
     /// Note: