From c9f759bb21782eb0487c96a59500310d1283694c Mon Sep 17 00:00:00 2001
From: eZio Pan <eziopan@qq.com>
Date: Sat, 16 Mar 2024 21:20:17 +0800
Subject: [PATCH] stm32 CORDIC: ZeroOverhead for q1.31 and q1.15

---
 embassy-stm32/src/cordic/enums.rs |  13 -
 embassy-stm32/src/cordic/mod.rs   | 468 ++++++++++++++++++------------
 embassy-stm32/src/cordic/utils.rs |   4 +-
 3 files changed, 278 insertions(+), 207 deletions(-)

diff --git a/embassy-stm32/src/cordic/enums.rs b/embassy-stm32/src/cordic/enums.rs
index 3e1c47f7f..37c73f549 100644
--- a/embassy-stm32/src/cordic/enums.rs
+++ b/embassy-stm32/src/cordic/enums.rs
@@ -68,16 +68,3 @@ pub enum Width {
     Bits32,
     Bits16,
 }
-
-/// Cordic driver running mode
-#[derive(Clone, Copy)]
-pub enum Mode {
-    /// After caculation start, a read to RDATA register will block AHB until the caculation finished
-    ZeroOverhead,
-
-    /// Use CORDIC interrupt to trigger a read result value
-    Interrupt,
-
-    /// Use DMA to write/read value
-    Dma,
-}
diff --git a/embassy-stm32/src/cordic/mod.rs b/embassy-stm32/src/cordic/mod.rs
index 997ace113..61277d7e1 100644
--- a/embassy-stm32/src/cordic/mod.rs
+++ b/embassy-stm32/src/cordic/mod.rs
@@ -1,8 +1,9 @@
 //! CORDIC co-processor
 
-use crate::peripherals;
 use embassy_hal_internal::{into_ref, Peripheral, PeripheralRef};
 
+use crate::peripherals;
+
 mod enums;
 pub use enums::*;
 
@@ -10,10 +11,6 @@ pub mod utils;
 
 pub(crate) mod sealed;
 
-// length of pre-allocated [u32] memory for CORDIC input,
-// length should be multiple of 2
-const INPUT_BUF_LEN: usize = 8;
-
 /// Low-level CORDIC access.
 #[cfg(feature = "unstable-pac")]
 pub mod low_level {
@@ -31,30 +28,16 @@ pub trait Instance: sealed::Instance + Peripheral<P = Self> + crate::rcc::RccPer
 
 /// CORDIC configuration
 pub struct Config {
-    mode: Mode,
     function: Function,
     precision: Precision,
     scale: Scale,
     first_result: bool,
 }
 
-// CORDIC running state
-struct State {
-    input_buf: [u32; INPUT_BUF_LEN],
-    buf_index: usize,
-}
-
 impl Config {
     /// Create a config for Cordic driver
-    pub fn new(
-        mode: Mode,
-        function: Function,
-        precision: Option<Precision>,
-        scale: Option<Scale>,
-        first_result: bool,
-    ) -> Self {
+    pub fn new(function: Function, precision: Option<Precision>, scale: Option<Scale>, first_result: bool) -> Self {
         Self {
-            mode,
             function,
             precision: precision.unwrap_or_default(),
             scale: scale.unwrap_or_default(),
@@ -133,22 +116,123 @@ impl<'d, T: Instance> Cordic<'d, T> {
         } else {
             self.peri.set_result_count(Count::Two)
         }
+    }
 
-        match self.config.mode {
-            Mode::ZeroOverhead => (),
-            Mode::Interrupt => {
-                self.peri.enable_irq();
-            }
-            Mode::Dma => {
-                self.peri.enable_write_dma();
-                self.peri.enable_read_dma();
-            }
+    fn blocking_read_f32(&mut self) -> (f32, Option<f32>) {
+        let reg_value = self.peri.read_result();
+
+        let res1 = utils::q1_15_to_f32((reg_value & ((1u32 << 16) - 1)) as u16);
+
+        // We don't care about whether the function return 1 or 2 results,
+        // the only thing matter is whether user want 1 or 2 results.
+        let res2 = if !self.config.first_result {
+            Some(utils::q1_15_to_f32((reg_value >> 16) as u16))
+        } else {
+            None
+        };
+
+        (res1, res2)
+    }
+}
+
+impl<'d, T: Instance> Drop for Cordic<'d, T> {
+    fn drop(&mut self) {
+        T::disable();
+    }
+}
+
+// q1.31 related
+impl<'d, T: Instance> Cordic<'d, T> {
+    /// Run a CORDIC calculation
+    pub fn blocking_calc_32bit(&mut self, arg1s: &[f64], arg2s: Option<&[f64]>, output: &mut [f64]) -> usize {
+        if arg1s.is_empty() {
+            return 0;
         }
+
+        assert!(
+            match self.config.first_result {
+                true => output.len() >= arg1s.len(),
+                false => output.len() >= 2 * arg1s.len(),
+            },
+            "Output buf length is not long enough"
+        );
+
+        self.check_input_f64(arg1s, arg2s);
+
+        self.peri.disable_irq();
+        self.peri.disable_write_dma();
+        self.peri.disable_read_dma();
+
+        self.peri.set_result_count(if self.config.first_result {
+            Count::One
+        } else {
+            Count::Two
+        });
+
+        self.peri.set_data_width(Width::Bits32, Width::Bits32);
+
+        let mut output_count = 0;
+
+        let mut consumed_input_len = 0;
+
+        // put double input into cordic
+        if arg2s.is_some() && !arg2s.expect("It's infailable").is_empty() {
+            let arg2s = arg2s.expect("It's infailable");
+
+            self.peri.set_argument_count(Count::Two);
+
+            // Skip 1st value from arg1s, this value will be manually "preload" to cordic, to make use of cordic preload function.
+            // And we preserve last value from arg2s, since it need to manually write to cordic, and read the result out.
+            let double_input = arg1s.iter().skip(1).zip(&arg2s[..arg2s.len() - 1]);
+            // Since we preload 1st value from arg1s, the consumed input length is double_input length + 1.
+            consumed_input_len = double_input.len() + 1;
+
+            // preload first value from arg1 to cordic
+            self.blocking_write_f64(arg1s[0]);
+
+            for (&arg1, &arg2) in double_input {
+                // Since we manually preload a value before,
+                // we will write arg2 (from the actual last pair) first, (at this moment, cordic start to calculating,)
+                // and write arg1 (from the actual next pair), then read the result, to "keep preloading"
+
+                self.blocking_write_f64(arg2);
+                self.blocking_write_f64(arg1);
+                self.blocking_read_f64_to_buf(output, &mut output_count);
+            }
+
+            // write last input value from arg2s, then read out the result
+            self.blocking_write_f64(arg2s[arg2s.len() - 1]);
+            self.blocking_read_f64_to_buf(output, &mut output_count);
+        }
+
+        // put single input into cordic
+        let input_left = &arg1s[consumed_input_len..];
+
+        if !input_left.is_empty() {
+            self.peri.set_argument_count(Count::One);
+
+            // "preload" value to cordic (at this moment, cordic start to calculating)
+            self.blocking_write_f64(input_left[0]);
+
+            for &arg in input_left.iter().skip(1) {
+                // this line write arg for next round caculation to cordic,
+                // and read result from last round
+                self.blocking_write_f64(arg);
+                self.blocking_read_f64_to_buf(output, &mut output_count);
+            }
+
+            // read the last output
+            self.blocking_read_f64_to_buf(output, &mut output_count);
+        }
+
+        output_count
     }
 
     fn blocking_read_f64(&mut self) -> (f64, Option<f64>) {
         let res1 = utils::q1_31_to_f64(self.peri.read_result());
 
+        // We don't care about whether the function return 1 or 2 results,
+        // the only thing matter is whether user want 1 or 2 results.
         let res2 = if !self.config.first_result {
             Some(utils::q1_31_to_f64(self.peri.read_result()))
         } else {
@@ -174,16 +258,14 @@ impl<'d, T: Instance> Cordic<'d, T> {
     }
 }
 
-impl<'d, T: Instance> Drop for Cordic<'d, T> {
-    fn drop(&mut self) {
-        T::disable();
-    }
-}
-
-// q1.31 related
+// q1.15 related
 impl<'d, T: Instance> Cordic<'d, T> {
     /// Run a CORDIC calculation
-    pub fn calc_32bit(&mut self, arg1s: &[f64], arg2s: Option<&[f64]>, output: &mut [f64]) -> usize {
+    pub fn blocking_calc_16bit(&mut self, arg1s: &[f32], arg2s: Option<&[f32]>, output: &mut [f32]) -> usize {
+        if arg1s.is_empty() {
+            return 0;
+        }
+
         assert!(
             match self.config.first_result {
                 true => output.len() >= arg1s.len(),
@@ -192,180 +274,182 @@ impl<'d, T: Instance> Cordic<'d, T> {
             "Output buf length is not long enough"
         );
 
-        self.check_input_f64(arg1s, arg2s);
+        self.check_input_f32(arg1s, arg2s);
 
-        self.peri.set_result_count(if self.config.first_result {
-            Count::One
-        } else {
-            Count::Two
-        });
+        self.peri.disable_irq();
+        self.peri.disable_write_dma();
+        self.peri.disable_read_dma();
 
-        self.peri.set_data_width(Width::Bits32, Width::Bits32);
+        // In q1.15 mode, 1 write/read to access 2 arguments/results
+        self.peri.set_argument_count(Count::One);
+        self.peri.set_result_count(Count::One);
+
+        self.peri.set_data_width(Width::Bits16, Width::Bits16);
 
         let mut output_count = 0;
 
-        let mut consumed_input_len = 0;
+        // In q1.15 mode, we always fill 1 pair of 16bit value into WDATA register.
+        // If arg2s is None or empty array, we assume arg2 value always 1.0 (as reset value for ARG2).
+        // If arg2s has some value, and but not as long as arg1s,
+        // we fill the reset of arg2 values with last value from arg2s (as q1.31 version does)
 
-        match self.config.mode {
-            Mode::ZeroOverhead => {
-                // put double input into cordic
-                if arg2s.is_some() && !arg2s.unwrap().is_empty() {
-                    let arg2s = arg2s.unwrap();
+        let arg2_default_value = match arg2s {
+            Some(arg2s) if !arg2s.is_empty() => arg2s[arg2s.len() - 1],
+            _ => 1.0,
+        };
 
-                    self.peri.set_argument_count(Count::Two);
+        let mut args = arg1s.iter().zip(
+            arg2s
+                .unwrap_or(&[])
+                .iter()
+                .chain(core::iter::repeat(&arg2_default_value)),
+        );
 
-                    // Skip 1st value from arg1s, this value will be manually "preload" to cordic, to make use of cordic preload function.
-                    // And we preserve last value from arg2s, since it need to manually write to cordic, and read the result out.
-                    let double_input = arg1s.iter().skip(1).zip(&arg2s[..arg2s.len() - 1]);
-                    // Since we preload 1st value from arg1s, the consumed input length is double_input length + 1.
-                    consumed_input_len = double_input.len() + 1;
+        let (&arg1, &arg2) = args
+            .next()
+            .expect("This should be infallible, since arg1s is not empty");
 
-                    // preload first value from arg1 to cordic
-                    self.blocking_write_f64(arg1s[0]);
+        // preloading 1 pair of arguments
+        self.blocking_write_f32(arg1, arg2);
 
-                    for (&arg1, &arg2) in double_input {
-                        // Since we manually preload a value before,
-                        // we will write arg2 (from the actual last pair) first, (at this moment, cordic start to calculating,)
-                        // and write arg1 (from the actual next pair), then read the result, to "keep preloading"
-
-                        self.blocking_write_f64(arg2);
-                        self.blocking_write_f64(arg1);
-                        self.blocking_read_f64_to_buf(output, &mut output_count);
-                    }
-
-                    // write last input value from arg2s, then read out the result
-                    self.blocking_write_f64(arg2s[arg2s.len() - 1]);
-                    self.blocking_read_f64_to_buf(output, &mut output_count);
-                }
-
-                // put single input into cordic
-                let input_left = &arg1s[consumed_input_len..];
-
-                if !input_left.is_empty() {
-                    self.peri.set_argument_count(Count::One);
-
-                    // "preload" value to cordic (at this moment, cordic start to calculating)
-                    self.blocking_write_f64(input_left[0]);
-
-                    for &arg in input_left.iter().skip(1) {
-                        // this line write arg for next round caculation to cordic,
-                        // and read result from last round
-                        self.blocking_write_f64(arg);
-                        self.blocking_read_f64_to_buf(output, &mut output_count);
-                    }
-
-                    // read the last output
-                    self.blocking_read_f64_to_buf(output, &mut output_count);
-                }
-
-                output_count
-            }
-            Mode::Interrupt => todo!(),
-            Mode::Dma => todo!(),
+        for (&arg1, &arg2) in args {
+            self.blocking_write_f32(arg1, arg2);
+            self.blocking_read_f32_to_buf(output, &mut output_count);
         }
+
+        // read last pair of value from cordic
+        self.blocking_read_f32_to_buf(output, &mut output_count);
+
+        output_count
     }
 
-    fn check_input_f64(&self, arg1s: &[f64], arg2s: Option<&[f64]>) {
-        let config = &self.config;
+    fn blocking_write_f32(&mut self, arg1: f32, arg2: f32) {
+        let reg_value: u32 = utils::f32_to_q1_15(arg1) as u32 + ((utils::f32_to_q1_15(arg2) as u32) << 16);
+        self.peri.write_argument(reg_value);
+    }
 
-        use Function::*;
+    fn blocking_read_f32_to_buf(&mut self, result_buf: &mut [f32], result_index: &mut usize) {
+        let (res1, res2) = self.blocking_read_f32();
+        result_buf[*result_index] = res1;
+        *result_index += 1;
 
-        // check SCALE value
-        match config.function {
-            Cos | Sin | Phase | Modulus => assert!(Scale::A1_R1 == config.scale, "SCALE should be 0"),
-            Arctan => assert!(
-                (0..=7).contains(&(config.scale as u8)),
-                "SCALE should be: 0 <= SCALE <= 7"
-            ),
-            Cosh | Sinh | Arctanh => assert!(Scale::A1o2_R2 == config.scale, "SCALE should be 1"),
-
-            Ln => assert!(
-                (1..=4).contains(&(config.scale as u8)),
-                "SCALE should be: 1 <= SCALE <= 4"
-            ),
-            Sqrt => assert!(
-                (0..=2).contains(&(config.scale as u8)),
-                "SCALE should be: 0 <= SCALE <= 2"
-            ),
-        }
-
-        // check ARG1 value
-        match config.function {
-            Cos | Sin | Phase | Modulus | Arctan => {
-                assert!(
-                    arg1s.iter().all(|v| (-1.0..=1.0).contains(v)),
-                    "ARG1 should be: -1 <= ARG1 <= 1"
-                );
-            }
-
-            Cosh | Sinh => assert!(
-                arg1s.iter().all(|v| (-0.559..=0.559).contains(v)),
-                "ARG1 should be: -0.559 <= ARG1 <= 0.559"
-            ),
-
-            Arctanh => assert!(
-                arg1s.iter().all(|v| (-0.403..=0.403).contains(v)),
-                "ARG1 should be: -0.403 <= ARG1 <= 0.403"
-            ),
-
-            Ln => {
-                match config.scale {
-                    Scale::A1o2_R2 => assert!(
-                        arg1s.iter().all(|v| (0.05354..0.5).contains(v)),
-                        "When SCALE set to 1, ARG1 should be: 0.05354 <= ARG1 < 0.5"
-                    ),
-                    Scale::A1o4_R4 => assert!(
-                        arg1s.iter().all(|v| (0.25..0.75).contains(v)),
-                        "When SCALE set to 2, ARG1 should be: 0.25 <= ARG1 < 0.75"
-                    ),
-                    Scale::A1o8_R8 => assert!(
-                        arg1s.iter().all(|v| (0.375..0.875).contains(v)),
-                        "When SCALE set to 3, ARG1 should be: 0.375 <= ARG1 < 0.875"
-                    ),
-                    Scale::A1o16_R16 => assert!(
-                        arg1s.iter().all(|v| (0.4375f64..0.584f64).contains(v)),
-                        "When SCALE set to 4, ARG1 should be: 0.4375 <= ARG1 < 0.584"
-                    ),
-                    _ => unreachable!(),
-                };
-            }
-
-            Function::Sqrt => match config.scale {
-                Scale::A1_R1 => assert!(
-                    arg1s.iter().all(|v| (0.027..0.75).contains(v)),
-                    "When SCALE set to 0, ARG1 should be: 0.027 <= ARG1 < 0.75"
-                ),
-                Scale::A1o2_R2 => assert!(
-                    arg1s.iter().all(|v| (0.375..0.875).contains(v)),
-                    "When SCALE set to 1, ARG1 should be: 0.375 <= ARG1 < 0.875"
-                ),
-                Scale::A1o4_R4 => assert!(
-                    arg1s.iter().all(|v| (0.4375..0.585).contains(v)),
-                    "When SCALE set to 2, ARG1 should be: 0.4375  <= ARG1 < 0.585"
-                ),
-                _ => unreachable!(),
-            },
-        }
-
-        // check ARG2 value
-        if let Some(arg2s) = arg2s {
-            match config.function {
-                Cos | Sin => assert!(
-                    arg2s.iter().all(|v| (0.0..=1.0).contains(v)),
-                    "ARG2 should be: 0 <= ARG2 <= 1"
-                ),
-
-                Phase | Modulus => assert!(
-                    arg2s.iter().all(|v| (-1.0..=1.0).contains(v)),
-                    "ARG2 should be: -1 <= ARG2 <= 1"
-                ),
-
-                _ => (),
-            }
+        if let Some(res2) = res2 {
+            result_buf[*result_index] = res2;
+            *result_index += 1;
         }
     }
 }
 
+// check input value ARG1, ARG2, SCALE and FUNCTION are compatible with each other
+macro_rules! check_input_value {
+    ($func_name:ident, $float_type:ty) => {
+        impl<'d, T: Instance> Cordic<'d, T> {
+            fn $func_name(&self, arg1s: &[$float_type], arg2s: Option<&[$float_type]>) {
+                let config = &self.config;
+
+                use Function::*;
+
+                // check SCALE value
+                match config.function {
+                    Cos | Sin | Phase | Modulus => assert!(Scale::A1_R1 == config.scale, "SCALE should be 0"),
+                    Arctan => assert!(
+                        (0..=7).contains(&(config.scale as u8)),
+                        "SCALE should be: 0 <= SCALE <= 7"
+                    ),
+                    Cosh | Sinh | Arctanh => assert!(Scale::A1o2_R2 == config.scale, "SCALE should be 1"),
+
+                    Ln => assert!(
+                        (1..=4).contains(&(config.scale as u8)),
+                        "SCALE should be: 1 <= SCALE <= 4"
+                    ),
+                    Sqrt => assert!(
+                        (0..=2).contains(&(config.scale as u8)),
+                        "SCALE should be: 0 <= SCALE <= 2"
+                    ),
+                }
+
+                // check ARG1 value
+                match config.function {
+                    Cos | Sin | Phase | Modulus | Arctan => {
+                        assert!(
+                            arg1s.iter().all(|v| (-1.0..=1.0).contains(v)),
+                            "ARG1 should be: -1 <= ARG1 <= 1"
+                        );
+                    }
+
+                    Cosh | Sinh => assert!(
+                        arg1s.iter().all(|v| (-0.559..=0.559).contains(v)),
+                        "ARG1 should be: -0.559 <= ARG1 <= 0.559"
+                    ),
+
+                    Arctanh => assert!(
+                        arg1s.iter().all(|v| (-0.403..=0.403).contains(v)),
+                        "ARG1 should be: -0.403 <= ARG1 <= 0.403"
+                    ),
+
+                    Ln => {
+                        match config.scale {
+                            Scale::A1o2_R2 => assert!(
+                                arg1s.iter().all(|v| (0.05354..0.5).contains(v)),
+                                "When SCALE set to 1, ARG1 should be: 0.05354 <= ARG1 < 0.5"
+                            ),
+                            Scale::A1o4_R4 => assert!(
+                                arg1s.iter().all(|v| (0.25..0.75).contains(v)),
+                                "When SCALE set to 2, ARG1 should be: 0.25 <= ARG1 < 0.75"
+                            ),
+                            Scale::A1o8_R8 => assert!(
+                                arg1s.iter().all(|v| (0.375..0.875).contains(v)),
+                                "When SCALE set to 3, ARG1 should be: 0.375 <= ARG1 < 0.875"
+                            ),
+                            Scale::A1o16_R16 => assert!(
+                                arg1s.iter().all(|v| (0.4375..0.584).contains(v)),
+                                "When SCALE set to 4, ARG1 should be: 0.4375 <= ARG1 < 0.584"
+                            ),
+                            _ => unreachable!(),
+                        };
+                    }
+
+                    Function::Sqrt => match config.scale {
+                        Scale::A1_R1 => assert!(
+                            arg1s.iter().all(|v| (0.027..0.75).contains(v)),
+                            "When SCALE set to 0, ARG1 should be: 0.027 <= ARG1 < 0.75"
+                        ),
+                        Scale::A1o2_R2 => assert!(
+                            arg1s.iter().all(|v| (0.375..0.875).contains(v)),
+                            "When SCALE set to 1, ARG1 should be: 0.375 <= ARG1 < 0.875"
+                        ),
+                        Scale::A1o4_R4 => assert!(
+                            arg1s.iter().all(|v| (0.4375..0.585).contains(v)),
+                            "When SCALE set to 2, ARG1 should be: 0.4375  <= ARG1 < 0.585"
+                        ),
+                        _ => unreachable!(),
+                    },
+                }
+
+                // check ARG2 value
+                if let Some(arg2s) = arg2s {
+                    match config.function {
+                        Cos | Sin => assert!(
+                            arg2s.iter().all(|v| (0.0..=1.0).contains(v)),
+                            "ARG2 should be: 0 <= ARG2 <= 1"
+                        ),
+
+                        Phase | Modulus => assert!(
+                            arg2s.iter().all(|v| (-1.0..=1.0).contains(v)),
+                            "ARG2 should be: -1 <= ARG2 <= 1"
+                        ),
+
+                        _ => (),
+                    }
+                }
+            }
+        }
+    };
+}
+
+check_input_value!(check_input_f64, f64);
+check_input_value!(check_input_f32, f32);
+
 foreach_interrupt!(
     ($inst:ident, cordic, $block:ident, GLOBAL, $irq:ident) => {
         impl Instance for peripherals::$inst {
diff --git a/embassy-stm32/src/cordic/utils.rs b/embassy-stm32/src/cordic/utils.rs
index 3f055c34b..2f4b5c5e8 100644
--- a/embassy-stm32/src/cordic/utils.rs
+++ b/embassy-stm32/src/cordic/utils.rs
@@ -3,7 +3,7 @@
 macro_rules! floating_fixed_convert {
     ($f_to_q:ident, $q_to_f:ident, $unsigned_bin_typ:ty, $signed_bin_typ:ty, $float_ty:ty, $offset:literal, $min_positive:literal) => {
         /// convert float point to fixed point format
-        pub fn $f_to_q(value: $float_ty) -> $unsigned_bin_typ {
+        pub(crate) fn $f_to_q(value: $float_ty) -> $unsigned_bin_typ {
             const MIN_POSITIVE: $float_ty = unsafe { core::mem::transmute($min_positive) };
 
             assert!(
@@ -31,7 +31,7 @@ macro_rules! floating_fixed_convert {
 
         #[inline(always)]
         /// convert fixed point to float point format
-        pub fn $q_to_f(value: $unsigned_bin_typ) -> $float_ty {
+        pub(crate) fn $q_to_f(value: $unsigned_bin_typ) -> $float_ty {
             // It's needed to convert from unsigned to signed first, for correct result.
             -(value as $signed_bin_typ as $float_ty) / ((1 as $unsigned_bin_typ << $offset) as $float_ty)
         }