From c42d9f9eaae546faae46c4d1121f1fbc393c2073 Mon Sep 17 00:00:00 2001
From: eZio Pan <eziopan@qq.com>
Date: Thu, 21 Mar 2024 13:25:40 +0800
Subject: [PATCH] stm32 CORDIC: bug fix

---
 embassy-stm32/src/cordic/errors.rs | 76 +++++++++++++++++++++++-------
 embassy-stm32/src/cordic/mod.rs    | 53 ++++++++++-----------
 embassy-stm32/src/cordic/utils.rs  | 43 +++++++++--------
 3 files changed, 109 insertions(+), 63 deletions(-)

diff --git a/embassy-stm32/src/cordic/errors.rs b/embassy-stm32/src/cordic/errors.rs
index d0b2dc618..2c0aca4a2 100644
--- a/embassy-stm32/src/cordic/errors.rs
+++ b/embassy-stm32/src/cordic/errors.rs
@@ -9,6 +9,26 @@ pub enum CordicError {
     ArgError(ArgError),
     /// Output buffer length error
     OutputLengthNotEnough,
+    /// Input value is out of range for Q1.x format
+    NumberOutOfRange(NumberOutOfRange),
+}
+
+impl From<ConfigError> for CordicError {
+    fn from(value: ConfigError) -> Self {
+        Self::ConfigError(value)
+    }
+}
+
+impl From<ArgError> for CordicError {
+    fn from(value: ArgError) -> Self {
+        Self::ArgError(value)
+    }
+}
+
+impl From<NumberOutOfRange> for CordicError {
+    fn from(value: NumberOutOfRange) -> Self {
+        Self::NumberOutOfRange(value)
+    }
 }
 
 #[cfg(feature = "defmt")]
@@ -19,6 +39,7 @@ impl defmt::Format for CordicError {
         match self {
             ConfigError(e) => defmt::write!(fmt, "{}", e),
             ArgError(e) => defmt::write!(fmt, "{}", e),
+            NumberOutOfRange(e) => defmt::write!(fmt, "{}", e),
             OutputLengthNotEnough => defmt::write!(fmt, "Output buffer length is not long enough"),
         }
     }
@@ -68,28 +89,51 @@ impl defmt::Format for ArgError {
             defmt::write!(fmt, " when SCALE is {},", scale);
         }
 
-        let arg_string = match self.arg_type {
-            ArgType::Arg1 => "ARG1",
-            ArgType::Arg2 => "ARG2",
+        defmt::write!(fmt, " {} should be", self.arg_type);
+
+        if self.inclusive_upper_bound {
+            defmt::write!(
+                fmt,
+                " {} <= {} <= {}",
+                self.arg_range[0],
+                self.arg_type,
+                self.arg_range[1]
+            )
+        } else {
+            defmt::write!(
+                fmt,
+                " {} <= {} < {}",
+                self.arg_range[0],
+                self.arg_type,
+                self.arg_range[1]
+            )
         };
-
-        defmt::write!(fmt, " {} should be", arg_string);
-
-        let inclusive_string = if self.inclusive_upper_bound { "=" } else { "" };
-
-        defmt::write!(
-            fmt,
-            " {} <= {} <{} {}",
-            self.arg_range[0],
-            arg_string,
-            inclusive_string,
-            self.arg_range[1]
-        )
     }
 }
 
 #[derive(Debug)]
+#[cfg_attr(feature = "defmt", derive(defmt::Format))]
 pub(super) enum ArgType {
     Arg1,
     Arg2,
 }
+
+/// Input value is out of range for Q1.x format
+#[allow(missing_docs)]
+#[derive(Debug)]
+pub enum NumberOutOfRange {
+    BelowLowerBound,
+    AboveUpperBound,
+}
+
+#[cfg(feature = "defmt")]
+impl defmt::Format for NumberOutOfRange {
+    fn format(&self, fmt: defmt::Formatter) {
+        use NumberOutOfRange::*;
+
+        match self {
+            BelowLowerBound => defmt::write!(fmt, "input value should be equal or greater than -1"),
+            AboveUpperBound => defmt::write!(fmt, "input value should be equal or less than 1"),
+        }
+    }
+}
diff --git a/embassy-stm32/src/cordic/mod.rs b/embassy-stm32/src/cordic/mod.rs
index 5ac9addd8..b0db3f060 100644
--- a/embassy-stm32/src/cordic/mod.rs
+++ b/embassy-stm32/src/cordic/mod.rs
@@ -56,7 +56,7 @@ impl Config {
         Ok(config)
     }
 
-    fn check_scale(&self) -> Result<(), CordicError> {
+    fn check_scale(&self) -> Result<(), ConfigError> {
         use Function::*;
 
         let scale_raw = self.scale as u8;
@@ -76,10 +76,10 @@ impl Config {
         };
 
         if let Some(range) = err_range {
-            Err(CordicError::ConfigError(ConfigError {
+            Err(ConfigError {
                 func: self.function,
                 scale_range: range,
-            }))
+            })
         } else {
             Ok(())
         }
@@ -226,20 +226,20 @@ impl<'d, T: Instance> Cordic<'d, T> {
             consumed_input_len = double_input.len() + 1;
 
             // preload first value from arg1 to cordic
-            self.blocking_write_f64(arg1s[0]);
+            self.blocking_write_f64(arg1s[0])?;
 
             for (&arg1, &arg2) in double_input {
                 // Since we manually preload a value before,
                 // we will write arg2 (from the actual last pair) first, (at this moment, cordic start to calculating,)
                 // and write arg1 (from the actual next pair), then read the result, to "keep preloading"
 
-                self.blocking_write_f64(arg2);
-                self.blocking_write_f64(arg1);
+                self.blocking_write_f64(arg2)?;
+                self.blocking_write_f64(arg1)?;
                 self.blocking_read_f64_to_buf(output, &mut output_count);
             }
 
             // write last input value from arg2s, then read out the result
-            self.blocking_write_f64(arg2s[arg2s.len() - 1]);
+            self.blocking_write_f64(arg2s[arg2s.len() - 1])?;
             self.blocking_read_f64_to_buf(output, &mut output_count);
         }
 
@@ -253,12 +253,12 @@ impl<'d, T: Instance> Cordic<'d, T> {
             self.peri.set_argument_count(AccessCount::One);
 
             // "preload" value to cordic (at this moment, cordic start to calculating)
-            self.blocking_write_f64(input_left[0]);
+            self.blocking_write_f64(input_left[0])?;
 
             for &arg in input_left.iter().skip(1) {
                 // this line write arg for next round caculation to cordic,
                 // and read result from last round
-                self.blocking_write_f64(arg);
+                self.blocking_write_f64(arg)?;
                 self.blocking_read_f64_to_buf(output, &mut output_count);
             }
 
@@ -281,8 +281,9 @@ impl<'d, T: Instance> Cordic<'d, T> {
         }
     }
 
-    fn blocking_write_f64(&mut self, arg: f64) {
-        self.peri.write_argument(utils::f64_to_q1_31(arg));
+    fn blocking_write_f64(&mut self, arg: f64) -> Result<(), NumberOutOfRange> {
+        self.peri.write_argument(utils::f64_to_q1_31(arg)?);
+        Ok(())
     }
 
     /// Run a async CORDIC calculation in q.1.31 format
@@ -339,7 +340,7 @@ impl<'d, T: Instance> Cordic<'d, T> {
 
             for (&arg1, &arg2) in double_input {
                 for &arg in [arg1, arg2].iter() {
-                    input_buf[input_buf_len] = utils::f64_to_q1_31(arg);
+                    input_buf[input_buf_len] = utils::f64_to_q1_31(arg)?;
                     input_buf_len += 1;
                 }
 
@@ -383,7 +384,7 @@ impl<'d, T: Instance> Cordic<'d, T> {
             self.peri.set_argument_count(AccessCount::One);
 
             for &arg in input_remain {
-                input_buf[input_buf_len] = utils::f64_to_q1_31(arg);
+                input_buf[input_buf_len] = utils::f64_to_q1_31(arg)?;
                 input_buf_len += 1;
 
                 if input_buf_len == INPUT_BUF_MAX_LEN {
@@ -509,10 +510,10 @@ impl<'d, T: Instance> Cordic<'d, T> {
         let (&arg1, &arg2) = args.next().unwrap();
 
         // preloading 1 pair of arguments
-        self.blocking_write_f32(arg1, arg2);
+        self.blocking_write_f32(arg1, arg2)?;
 
         for (&arg1, &arg2) in args {
-            self.blocking_write_f32(arg1, arg2);
+            self.blocking_write_f32(arg1, arg2)?;
             self.blocking_read_f32_to_buf(output, &mut output_count);
         }
 
@@ -522,15 +523,13 @@ impl<'d, T: Instance> Cordic<'d, T> {
         Ok(output_count)
     }
 
-    fn blocking_write_f32(&mut self, arg1: f32, arg2: f32) {
-        let reg_value: u32 = utils::f32_args_to_u32(arg1, arg2);
-        self.peri.write_argument(reg_value);
+    fn blocking_write_f32(&mut self, arg1: f32, arg2: f32) -> Result<(), NumberOutOfRange> {
+        self.peri.write_argument(utils::f32_args_to_u32(arg1, arg2)?);
+        Ok(())
     }
 
     fn blocking_read_f32_to_buf(&mut self, result_buf: &mut [f32], result_index: &mut usize) {
-        let reg_value = self.peri.read_result();
-
-        let (res1, res2) = utils::u32_to_f32_res(reg_value);
+        let (res1, res2) = utils::u32_to_f32_res(self.peri.read_result());
 
         result_buf[*result_index] = res1;
         *result_index += 1;
@@ -597,7 +596,7 @@ impl<'d, T: Instance> Cordic<'d, T> {
         );
 
         for (&arg1, &arg2) in args {
-            input_buf[input_buf_len] = utils::f32_args_to_u32(arg1, arg2);
+            input_buf[input_buf_len] = utils::f32_args_to_u32(arg1, arg2)?;
             input_buf_len += 1;
 
             if input_buf_len == INPUT_BUF_MAX_LEN {
@@ -655,7 +654,7 @@ impl<'d, T: Instance> Cordic<'d, T> {
 macro_rules! check_input_value {
     ($func_name:ident, $float_type:ty) => {
         impl<'d, T: Instance> Cordic<'d, T> {
-            fn $func_name(&self, arg1s: &[$float_type], arg2s: Option<&[$float_type]>) -> Result<(), CordicError> {
+            fn $func_name(&self, arg1s: &[$float_type], arg2s: Option<&[$float_type]>) -> Result<(), ArgError> {
                 let config = &self.config;
 
                 use Function::*;
@@ -741,13 +740,13 @@ macro_rules! check_input_value {
                 };
 
                 if let Some(err) = err_info {
-                    return Err(CordicError::ArgError(ArgError {
+                    return Err(ArgError {
                         func: config.function,
                         scale: err.scale,
                         arg_range: err.range,
                         inclusive_upper_bound: err.inclusive_upper_bound,
                         arg_type: ArgType::Arg1,
-                    }));
+                    });
                 }
 
                 // check ARG2 value
@@ -769,13 +768,13 @@ macro_rules! check_input_value {
                     };
 
                     if let Some(err) = err_info {
-                        return Err(CordicError::ArgError(ArgError {
+                        return Err(ArgError {
                             func: config.function,
                             scale: None,
                             arg_range: err.range,
                             inclusive_upper_bound: true,
                             arg_type: ArgType::Arg2,
-                        }));
+                        });
                     }
                 }
 
diff --git a/embassy-stm32/src/cordic/utils.rs b/embassy-stm32/src/cordic/utils.rs
index 79bef6b97..3c3ed224f 100644
--- a/embassy-stm32/src/cordic/utils.rs
+++ b/embassy-stm32/src/cordic/utils.rs
@@ -1,39 +1,42 @@
 //! Common match utils
+use super::errors::NumberOutOfRange;
 
 macro_rules! floating_fixed_convert {
     ($f_to_q:ident, $q_to_f:ident, $unsigned_bin_typ:ty, $signed_bin_typ:ty, $float_ty:ty, $offset:literal, $min_positive:literal) => {
         /// convert float point to fixed point format
-        pub(crate) fn $f_to_q(value: $float_ty) -> $unsigned_bin_typ {
+        pub fn $f_to_q(value: $float_ty) -> Result<$unsigned_bin_typ, NumberOutOfRange> {
             const MIN_POSITIVE: $float_ty = unsafe { core::mem::transmute($min_positive) };
 
-            assert!(
-                (-1.0 as $float_ty) <= value,
-                "input value {} should be equal or greater than -1",
-                value
-            );
+            if value < -1.0 {
+                return Err(NumberOutOfRange::BelowLowerBound)
+            }
+
+            if value > 1.0 {
+                return Err(NumberOutOfRange::AboveUpperBound)
+            }
 
 
-            let value = if value == 1.0 as $float_ty{
-                // make a exception for user specifing exact 1.0 float point,
-                // convert 1.0 to max representable value of q1.x format
+            let value = if 1.0 - MIN_POSITIVE < value && value <= 1.0 {
+                // make a exception for value between (1.0^{-x} , 1.0] float point,
+                // convert it to max representable value of q1.x format
                 (1.0 as $float_ty) - MIN_POSITIVE
             } else {
-                assert!(
-                    value <= (1.0 as $float_ty) - MIN_POSITIVE,
-                    "input value {} should be equal or less than 1-2^(-{})",
-                    value, $offset
-                );
                 value
             };
 
-            (value * ((1 as $unsigned_bin_typ << $offset) as $float_ty)) as $unsigned_bin_typ
+            // It's necessary to cast the float value to signed integer, before convert it to a unsigned value.
+            // Since value from register is actually a "signed value", a "as" cast will keep original binary format but mark it as unsgined value.
+            // see https://doc.rust-lang.org/reference/expressions/operator-expr.html#numeric-cast
+            Ok((value * ((1 as $unsigned_bin_typ << $offset) as $float_ty)) as $signed_bin_typ as $unsigned_bin_typ)
         }
 
         #[inline(always)]
         /// convert fixed point to float point format
-        pub(crate) fn $q_to_f(value: $unsigned_bin_typ) -> $float_ty {
-            // It's needed to convert from unsigned to signed first, for correct result.
-            -(value as $signed_bin_typ as $float_ty) / ((1 as $unsigned_bin_typ << $offset) as $float_ty)
+        pub fn $q_to_f(value: $unsigned_bin_typ) -> $float_ty {
+            // It's necessary to cast the unsigned integer to signed integer, before convert it to a float value.
+            // Since value from register is actually a "signed value", a "as" cast will keep original binary format but mark it as signed value.
+            // see https://doc.rust-lang.org/reference/expressions/operator-expr.html#numeric-cast
+            (value as $signed_bin_typ as $float_ty) / ((1 as $unsigned_bin_typ << $offset) as $float_ty)
         }
     };
 }
@@ -59,8 +62,8 @@ floating_fixed_convert!(
 );
 
 #[inline(always)]
-pub(crate) fn f32_args_to_u32(arg1: f32, arg2: f32) -> u32 {
-    f32_to_q1_15(arg1) as u32 + ((f32_to_q1_15(arg2) as u32) << 16)
+pub(crate) fn f32_args_to_u32(arg1: f32, arg2: f32) -> Result<u32, NumberOutOfRange> {
+    Ok(f32_to_q1_15(arg1)? as u32 + ((f32_to_q1_15(arg2)? as u32) << 16))
 }
 
 #[inline(always)]