From 5d12f594303bdb76bf2356d9fc0661826e2e658e Mon Sep 17 00:00:00 2001
From: eZio Pan <eziopan@qq.com>
Date: Sat, 16 Mar 2024 00:25:38 +0800
Subject: [PATCH] stm32 CORDIC: make use of "preload" feature

---
 embassy-stm32/src/cordic/mod.rs | 180 +++++++++++++++-----------------
 1 file changed, 85 insertions(+), 95 deletions(-)

diff --git a/embassy-stm32/src/cordic/mod.rs b/embassy-stm32/src/cordic/mod.rs
index b15521ca6..997ace113 100644
--- a/embassy-stm32/src/cordic/mod.rs
+++ b/embassy-stm32/src/cordic/mod.rs
@@ -22,9 +22,8 @@ pub mod low_level {
 
 /// CORDIC driver
 pub struct Cordic<'d, T: Instance> {
-    cordic: PeripheralRef<'d, T>,
+    peri: PeripheralRef<'d, T>,
     config: Config,
-    state: State,
 }
 
 /// CORDIC instance trait
@@ -83,23 +82,16 @@ impl<'d, T: Instance> Cordic<'d, T> {
     /// Note:  
     /// If you need a periperhal -> CORDIC -> peripehral mode,  
     /// you may want to set Cordic into [Mode::ZeroOverhead] mode, and add extra arguemnts with [Self::extra_config]
-    pub fn new(cordic: impl Peripheral<P = T> + 'd, config: Config) -> Self {
+    pub fn new(peri: impl Peripheral<P = T> + 'd, config: Config) -> Self {
         T::enable_and_reset();
 
-        into_ref!(cordic);
+        into_ref!(peri);
 
         if !config.check_scale() {
             panic!("Scale value is not compatible with Function")
         }
 
-        let mut instance = Self {
-            cordic,
-            config,
-            state: State {
-                input_buf: [0u32; 8],
-                buf_index: 0,
-            },
-        };
+        let mut instance = Self { peri, config };
 
         instance.reconfigure();
 
@@ -114,51 +106,71 @@ impl<'d, T: Instance> Cordic<'d, T> {
 
     /// Set extra config for data count and data width.
     pub fn extra_config(&mut self, arg_cnt: Count, arg_width: Width, res_width: Width) {
-        let peri = &self.cordic;
-        peri.set_argument_count(arg_cnt);
-        peri.set_data_width(arg_width, res_width);
+        self.peri.set_argument_count(arg_cnt);
+        self.peri.set_data_width(arg_width, res_width);
     }
 
     fn reconfigure(&mut self) {
-        let peri = &self.cordic;
-        let config = &self.config;
-
-        if peri.ready_to_read() {
+        if self.peri.ready_to_read() {
             warn!("At least 1 result hasn't been read, reconfigure will cause DATA LOST");
         };
 
-        peri.disable_irq();
-        peri.disable_write_dma();
-        peri.disable_read_dma();
+        self.peri.disable_irq();
+        self.peri.disable_write_dma();
+        self.peri.disable_read_dma();
 
         // clean RRDY flag
-        while peri.ready_to_read() {
-            peri.read_result();
+        while self.peri.ready_to_read() {
+            self.peri.read_result();
         }
 
-        peri.set_func(config.function);
-        peri.set_precision(config.precision);
-        peri.set_scale(config.scale);
+        self.peri.set_func(self.config.function);
+        self.peri.set_precision(self.config.precision);
+        self.peri.set_scale(self.config.scale);
 
-        if config.first_result {
-            peri.set_result_count(Count::One)
+        if self.config.first_result {
+            self.peri.set_result_count(Count::One)
         } else {
-            peri.set_result_count(Count::Two)
+            self.peri.set_result_count(Count::Two)
         }
 
-        match config.mode {
+        match self.config.mode {
             Mode::ZeroOverhead => (),
             Mode::Interrupt => {
-                peri.enable_irq();
+                self.peri.enable_irq();
             }
             Mode::Dma => {
-                peri.enable_write_dma();
-                peri.enable_read_dma();
+                self.peri.enable_write_dma();
+                self.peri.enable_read_dma();
             }
         }
+    }
 
-        self.state.input_buf.fill(0u32);
-        self.state.buf_index = 0;
+    /// Read one result, plus a second one unless `first_result` is set, converted by `utils::q1_31_to_f64`.
+    fn blocking_read_f64(&mut self) -> (f64, Option<f64>) {
+        let primary = utils::q1_31_to_f64(self.peri.read_result());
+
+        let secondary = match self.config.first_result {
+            true => None,
+            false => Some(utils::q1_31_to_f64(self.peri.read_result())),
+        };
+
+        (primary, secondary)
+    }
+
+    /// Read the pending result(s) and store them in `result_buf` at `result_index`, advancing the index.
+    fn blocking_read_f64_to_buf(&mut self, result_buf: &mut [f64], result_index: &mut usize) {
+        let (first, second) = self.blocking_read_f64();
+
+        // `second` is `None` when only the first result is read, so it adds nothing to the chain.
+        for value in core::iter::once(first).chain(second) {
+            result_buf[*result_index] = value;
+            *result_index += 1;
+        }
+    }
+
+    fn blocking_write_f64(&mut self, arg: f64) {
+        self.peri.write_argument(utils::f64_to_q1_31(arg));
     }
 }
 
@@ -172,11 +184,8 @@ impl<'d, T: Instance> Drop for Cordic<'d, T> {
 impl<'d, T: Instance> Cordic<'d, T> {
     /// Run a CORDIC calculation
     pub fn calc_32bit(&mut self, arg1s: &[f64], arg2s: Option<&[f64]>, output: &mut [f64]) -> usize {
-        let peri = &self.cordic;
-        let config = &self.config;
-
         assert!(
-            match config.first_result {
+            match self.config.first_result {
                 true => output.len() >= arg1s.len(),
                 false => output.len() >= 2 * arg1s.len(),
             },
@@ -185,87 +194,68 @@ impl<'d, T: Instance> Cordic<'d, T> {
 
         self.check_input_f64(arg1s, arg2s);
 
-        peri.set_result_count(if config.first_result { Count::One } else { Count::Two });
-        peri.set_data_width(Width::Bits32, Width::Bits32);
+        self.peri.set_result_count(if self.config.first_result {
+            Count::One
+        } else {
+            Count::Two
+        });
 
-        let state = &mut self.state;
+        self.peri.set_data_width(Width::Bits32, Width::Bits32);
 
         let mut output_count = 0;
 
         let mut consumed_input_len = 0;
 
-        match config.mode {
+        match self.config.mode {
             Mode::ZeroOverhead => {
                 // put double input into cordic
                 if arg2s.is_some() && !arg2s.unwrap().is_empty() {
                     let arg2s = arg2s.unwrap();
 
-                    peri.set_argument_count(Count::Two);
+                    self.peri.set_argument_count(Count::Two);
 
-                    let double_value = arg1s.iter().zip(arg2s);
-                    consumed_input_len = double_value.len();
+                    // Skip the 1st value of arg1s: it is manually "preloaded" into cordic, to make use of the cordic preload feature.
+                    // Also hold back the last value of arg2s, since it must be manually written to cordic before reading the result out.
+                    let double_input = arg1s.iter().skip(1).zip(&arg2s[..arg2s.len() - 1]);
+                    // Since we preload the 1st value of arg1s, the consumed input length is the double_input length + 1.
+                    consumed_input_len = double_input.len() + 1;
 
-                    for (arg1, arg2) in double_value {
-                        // if input_buf is full, send values to cordic
-                        if state.buf_index == INPUT_BUF_LEN - 1 {
-                            for arg in state.input_buf.chunks(2) {
-                                peri.write_argument(arg[0]);
-                                peri.write_argument(arg[1]);
+                    // preload the first value of arg1s into cordic
+                    self.blocking_write_f64(arg1s[0]);
 
-                                output[output_count] = utils::q1_31_to_f64(peri.read_result());
-                                output_count += 1;
+                    for (&arg1, &arg2) in double_input {
+                        // Since we manually preloaded a value before,
+                        // we write arg2 (from the actual previous pair) first (at this moment, cordic starts calculating),
+                        // then write arg1 (from the actual next pair) and read the result, to "keep preloading".
 
-                                if !config.first_result {
-                                    output[output_count] = utils::q1_31_to_f64(peri.read_result());
-                                    output_count += 1;
-                                }
-                            }
-
-                            state.buf_index = 0;
-                        }
-
-                        for &&arg in [arg1, arg2].iter() {
-                            state.input_buf[state.buf_index] = utils::f64_to_q1_31(arg);
-                            state.buf_index += 1;
-                        }
+                        self.blocking_write_f64(arg2);
+                        self.blocking_write_f64(arg1);
+                        self.blocking_read_f64_to_buf(output, &mut output_count);
                     }
 
-                    // put left paired args into cordic
-                    if state.buf_index > 0 {
-                        for arg in state.input_buf[..state.buf_index].chunks(2) {
-                            peri.write_argument(arg[0]);
-                            peri.write_argument(arg[1]);
-
-                            output[output_count] = utils::q1_31_to_f64(peri.read_result());
-                            output_count += 1;
-
-                            if !config.first_result {
-                                output[output_count] = utils::q1_31_to_f64(peri.read_result());
-                                output_count += 1;
-                            }
-                        }
-
-                        state.buf_index = 0;
-                    }
+                    // write last input value from arg2s, then read out the result
+                    self.blocking_write_f64(arg2s[arg2s.len() - 1]);
+                    self.blocking_read_f64_to_buf(output, &mut output_count);
                 }
 
                 // put single input into cordic
                 let input_left = &arg1s[consumed_input_len..];
 
                 if !input_left.is_empty() {
-                    peri.set_argument_count(Count::One);
+                    self.peri.set_argument_count(Count::One);
 
-                    for &arg in input_left.iter() {
-                        peri.write_argument(utils::f64_to_q1_31(arg));
+                    // "preload" the first value into cordic (at this moment, cordic starts calculating)
+                    self.blocking_write_f64(input_left[0]);
 
-                        output[output_count] = utils::q1_31_to_f64(peri.read_result());
-                        output_count += 1;
-
-                        if !config.first_result {
-                            output[output_count] = utils::q1_31_to_f64(peri.read_result());
-                            output_count += 1;
-                        }
+                    for &arg in input_left.iter().skip(1) {
+                        // write the arg for the next round of calculation to cordic,
+                        // and read the result of the previous round
+                        self.blocking_write_f64(arg);
+                        self.blocking_read_f64_to_buf(output, &mut output_count);
                     }
+
+                    // read the last output
+                    self.blocking_read_f64_to_buf(output, &mut output_count);
                 }
 
                 output_count