diff --git a/embassy-stm32/src/cordic/errors.rs b/embassy-stm32/src/cordic/errors.rs index 653014290..3c70fc9e7 100644 --- a/embassy-stm32/src/cordic/errors.rs +++ b/embassy-stm32/src/cordic/errors.rs @@ -5,12 +5,14 @@ use super::{Function, Scale}; pub enum CordicError { /// Config error ConfigError(ConfigError), - /// Argument error - ArgError(ArgError), - /// Output buffer length error - OutputLengthNotEnough, + /// Argument length is incorrect + ArgumentLengthIncorrect, + /// Result buffer length error + ResultLengthNotEnough, /// Input value is out of range for Q1.x format NumberOutOfRange(NumberOutOfRange), + /// Argument error + ArgError(ArgError), } impl From<ConfigError> for CordicError { @@ -19,18 +21,18 @@ impl From<ConfigError> for CordicError { } } -impl From<ArgError> for CordicError { - fn from(value: ArgError) -> Self { - Self::ArgError(value) - } -} - impl From<NumberOutOfRange> for CordicError { fn from(value: NumberOutOfRange) -> Self { Self::NumberOutOfRange(value) } } +impl From<ArgError> for CordicError { + fn from(value: ArgError) -> Self { + Self::ArgError(value) + } +} + #[cfg(feature = "defmt")] impl defmt::Format for CordicError { fn format(&self, fmt: defmt::Formatter) { @@ -38,9 +40,10 @@ impl defmt::Format for CordicError { match self { ConfigError(e) => defmt::write!(fmt, "{}", e), - ArgError(e) => defmt::write!(fmt, "{}", e), + ResultLengthNotEnough => defmt::write!(fmt, "Output buffer length is not long enough"), + ArgumentLengthIncorrect => defmt::write!(fmt, "Argument length incorrect"), NumberOutOfRange(e) => defmt::write!(fmt, "{}", e), - OutputLengthNotEnough => defmt::write!(fmt, "Output buffer length is not long enough"), + ArgError(e) => defmt::write!(fmt, "{}", e), } } } @@ -71,6 +74,26 @@ impl defmt::Format for ConfigError { } } +/// Input value is out of range for Q1.x format +#[allow(missing_docs)] +#[derive(Debug)] +pub enum NumberOutOfRange { + BelowLowerBound, + AboveUpperBound, +} + +#[cfg(feature = "defmt")] +impl 
defmt::Format for NumberOutOfRange { + fn format(&self, fmt: defmt::Formatter) { + use NumberOutOfRange::*; + + match self { + BelowLowerBound => defmt::write!(fmt, "input value should be equal or greater than -1"), + AboveUpperBound => defmt::write!(fmt, "input value should be equal or less than 1"), + } + } +} + /// Error on checking input arguments #[allow(dead_code)] #[derive(Debug)] @@ -119,23 +142,3 @@ pub(super) enum ArgType { Arg1, Arg2, } - -/// Input value is out of range for Q1.x format -#[allow(missing_docs)] -#[derive(Debug)] -pub enum NumberOutOfRange { - BelowLowerBound, - AboveUpperBound, -} - -#[cfg(feature = "defmt")] -impl defmt::Format for NumberOutOfRange { - fn format(&self, fmt: defmt::Formatter) { - use NumberOutOfRange::*; - - match self { - BelowLowerBound => defmt::write!(fmt, "input value should be equal or greater than -1"), - AboveUpperBound => defmt::write!(fmt, "input value should be equal or less than 1"), - } - } -} diff --git a/embassy-stm32/src/cordic/mod.rs b/embassy-stm32/src/cordic/mod.rs index f12efe2eb..2479e1b27 100644 --- a/embassy-stm32/src/cordic/mod.rs +++ b/embassy-stm32/src/cordic/mod.rs @@ -21,8 +21,6 @@ pub mod low_level { pub use super::sealed::*; } -const INPUT_BUF_MAX_LEN: usize = 16; - /// CORDIC driver pub struct Cordic<'d, T: Instance> { peri: PeripheralRef<'d, T>, @@ -38,17 +36,15 @@ pub struct Config { function: Function, precision: Precision, scale: Scale, - res1_only: bool, } impl Config { /// Create a config for Cordic driver - pub fn new(function: Function, precision: Precision, scale: Scale, res1_only: bool) -> Result<Self, CordicError> { + pub fn new(function: Function, precision: Precision, scale: Scale) -> Result<Self, CordicError> { let config = Self { function, precision, scale, - res1_only, }; config.check_scale()?; @@ -117,7 +113,32 @@ impl<'d, T: Instance> Cordic<'d, T> { self.peri.set_data_width(arg_width, res_width); } - fn reconfigure(&mut self) { + fn clean_rrdy_flag(&mut self) { + while 
self.peri.ready_to_read() { + self.peri.read_result(); + } + } + + /// Disable IRQ and DMA, clean RRDY, and set ARG2 to +1 (0x7FFFFFFF) + pub fn reconfigure(&mut self) { + // reset ARG2 to +1 + { + self.peri.disable_irq(); + self.peri.disable_read_dma(); + self.peri.disable_write_dma(); + self.clean_rrdy_flag(); + + self.peri.set_func(Function::Cos); + self.peri.set_precision(Precision::Iters4); + self.peri.set_scale(Scale::Arg1Res1); + self.peri.set_argument_count(AccessCount::Two); + self.peri.set_data_width(Width::Bits32, Width::Bits32); + self.peri.write_argument(0x0u32); + self.peri.write_argument(0x7FFFFFFFu32); + + self.clean_rrdy_flag(); + } + self.peri.set_func(self.config.function); self.peri.set_precision(self.config.precision); self.peri.set_scale(self.config.scale); @@ -125,16 +146,154 @@ impl<'d, T: Instance> Cordic<'d, T> { // we don't set NRES in here, but to make sure NRES is set each time user call "calc"-ish functions, // since each "calc"-ish functions can have different ARGSIZE and RESSIZE, thus NRES should be change accordingly. } +} - async fn launch_a_dma_transfer( +impl<'d, T: Instance> Drop for Cordic<'d, T> { + fn drop(&mut self) { + T::disable(); + } +} + +// q1.31 related +impl<'d, T: Instance> Cordic<'d, T> { + /// Run a blocking CORDIC calculation in q1.31 format + /// + /// Notice: + /// If you set `arg1_only` to `true`, please be sure ARG2 value has been set to desired value before. + /// This function won't set ARG2 to +1 before or after each round of calculation. + /// If you want to make sure ARG2 is set to +1, consider run [.reconfigure()](Self::reconfigure). 
+ pub fn blocking_calc_32bit( + &mut self, + arg: &[u32], + res: &mut [u32], + arg1_only: bool, + res1_only: bool, + ) -> Result<usize, CordicError> { + if arg.is_empty() { + return Ok(0); + } + + let res_cnt = Self::check_arg_res_length_32bit(arg.len(), res.len(), arg1_only, res1_only)?; + + self.peri + .set_argument_count(if arg1_only { AccessCount::One } else { AccessCount::Two }); + + self.peri + .set_result_count(if res1_only { AccessCount::One } else { AccessCount::Two }); + + self.peri.set_data_width(Width::Bits32, Width::Bits32); + + let mut cnt = 0; + + match arg1_only { + true => { + // To use cordic preload function, the first value is special. + // It is loaded to CORDIC WDATA register out side of loop + let first_value = arg[0]; + + // preload 1st value to CORDIC, to start the CORDIC calc + self.peri.write_argument(first_value); + + for &arg1 in &arg[1..] { + // preload arg1 (for next calc) + self.peri.write_argument(arg1); + + // then read current result out + res[cnt] = self.peri.read_result(); + cnt += 1; + if !res1_only { + res[cnt] = self.peri.read_result(); + cnt += 1; + } + } + + // read the last result + res[cnt] = self.peri.read_result(); + cnt += 1; + if !res1_only { + res[cnt] = self.peri.read_result(); + // cnt += 1; + } + } + false => { + // To use cordic preload function, the first and last value is special. 
+ // They are load to CORDIC WDATA register out side of loop + let first_value = arg[0]; + let last_value = arg[arg.len() - 1]; + + let paired_args = &arg[1..arg.len() - 1]; + + // preload 1st value to CORDIC + self.peri.write_argument(first_value); + + for args in paired_args.chunks(2) { + let arg2 = args[0]; + let arg1 = args[1]; + + // load arg2 (for current calc) first, to start the CORDIC calc + self.peri.write_argument(arg2); + + // preload arg1 (for next calc) + self.peri.write_argument(arg1); + + // then read current result out + res[cnt] = self.peri.read_result(); + cnt += 1; + if !res1_only { + res[cnt] = self.peri.read_result(); + cnt += 1; + } + } + + // load last value to CORDIC, and finish the calculation + self.peri.write_argument(last_value); + res[cnt] = self.peri.read_result(); + cnt += 1; + if !res1_only { + res[cnt] = self.peri.read_result(); + // cnt += 1; + } + } + } + + // at this point cnt should be equal to res_cnt + + Ok(res_cnt) + } + + /// Run a async CORDIC calculation in q.1.31 format + /// + /// Notice: + /// If you set `arg1_only` to `true`, please be sure ARG2 value has been set to desired value before. + /// This function won't set ARG2 to +1 before or after each round of calculation. + /// If you want to make sure ARG2 is set to +1, consider run [.reconfigure()](Self::reconfigure). 
+ pub async fn async_calc_32bit( &mut self, write_dma: impl Peripheral<P = impl WriteDma<T>>, read_dma: impl Peripheral<P = impl ReadDma<T>>, - input: &[u32], - output: &mut [u32], - ) { + arg: &[u32], + res: &mut [u32], + arg1_only: bool, + res1_only: bool, + ) -> Result<usize, CordicError> { + if arg.is_empty() { + return Ok(0); + } + + let res_cnt = Self::check_arg_res_length_32bit(arg.len(), res.len(), arg1_only, res1_only)?; + + let active_res_buf = &mut res[..res_cnt]; + into_ref!(write_dma, read_dma); + self.peri + .set_argument_count(if arg1_only { AccessCount::One } else { AccessCount::Two }); + + self.peri + .set_result_count(if res1_only { AccessCount::One } else { AccessCount::Two }); + + self.peri.set_data_width(Width::Bits32, Width::Bits32); + let write_req = write_dma.request(); let read_req = read_dma.request(); @@ -150,7 +309,7 @@ impl<'d, T: Instance> Cordic<'d, T> { let write_transfer = dma::Transfer::new_write( &mut write_dma, write_req, - input, + arg, T::regs().wdata().as_ptr() as *mut _, Default::default(), ); @@ -159,328 +318,60 @@ impl<'d, T: Instance> Cordic<'d, T> { &mut read_dma, read_req, T::regs().rdata().as_ptr() as *mut _, - output, + active_res_buf, Default::default(), ); embassy_futures::join::join(write_transfer, read_transfer).await; } - } -} -impl<'d, T: Instance> Drop for Cordic<'d, T> { - fn drop(&mut self) { - T::disable(); + Ok(res_cnt) } -} -// q1.31 related -impl<'d, T: Instance> Cordic<'d, T> { - /// Run a blocking CORDIC calculation in q1.31 format - pub fn blocking_calc_32bit( - &mut self, - arg1s: &[f64], - arg2s: Option<&[f64]>, - output: &mut [f64], + fn check_arg_res_length_32bit( + arg_len: usize, + res_len: usize, + arg1_only: bool, + res1_only: bool, ) -> Result<usize, CordicError> { - if arg1s.is_empty() { - return Ok(0); + if !arg1_only && arg_len % 2 != 0 { + return Err(CordicError::ArgumentLengthIncorrect); } - let output_length_enough = match self.config.res1_only { - true => output.len() >= arg1s.len(), - 
false => output.len() >= 2 * arg1s.len(), - }; + let mut minimal_res_length = arg_len; - if !output_length_enough { - return Err(CordicError::OutputLengthNotEnough); + if !res1_only { + minimal_res_length *= 2; } - self.check_input_f64(arg1s, arg2s)?; - - self.peri.set_result_count(if self.config.res1_only { - AccessCount::One - } else { - AccessCount::Two - }); - - self.peri.set_data_width(Width::Bits32, Width::Bits32); - - let mut output_count = 0; - - let mut consumed_input_len = 0; - - // - // handle 2 input args calculation - // - - if arg2s.is_some() && !arg2s.unwrap().is_empty() { - let arg2s = arg2s.unwrap(); - - self.peri.set_argument_count(AccessCount::Two); - - // Skip 1st value from arg1s, this value will be manually "preload" to cordic, to make use of cordic preload function. - // And we preserve last value from arg2s, since it need to manually write to cordic, and read the result out. - let double_input = arg1s.iter().skip(1).zip(&arg2s[..arg2s.len() - 1]); - // Since we preload 1st value from arg1s, the consumed input length is double_input length + 1. 
- consumed_input_len = double_input.len() + 1; - - // preload first value from arg1 to cordic - self.blocking_write_f64(arg1s[0])?; - - for (&arg1, &arg2) in double_input { - // Since we manually preload a value before, - // we will write arg2 (from the actual last pair) first, (at this moment, cordic start to calculating,) - // and write arg1 (from the actual next pair), then read the result, to "keep preloading" - - self.blocking_write_f64(arg2)?; - self.blocking_write_f64(arg1)?; - self.blocking_read_f64_to_buf(output, &mut output_count); - } - - // write last input value from arg2s, then read out the result - self.blocking_write_f64(arg2s[arg2s.len() - 1])?; - self.blocking_read_f64_to_buf(output, &mut output_count); + if !arg1_only { + minimal_res_length /= 2 } - // - // handle 1 input arg calculation - // - - let input_left = &arg1s[consumed_input_len..]; - - if !input_left.is_empty() { - self.peri.set_argument_count(AccessCount::One); - - // "preload" value to cordic (at this moment, cordic start to calculating) - self.blocking_write_f64(input_left[0])?; - - for &arg in input_left.iter().skip(1) { - // this line write arg for next round calculation to cordic, - // and read result from last round - self.blocking_write_f64(arg)?; - self.blocking_read_f64_to_buf(output, &mut output_count); - } - - // read the last output - self.blocking_read_f64_to_buf(output, &mut output_count); + if minimal_res_length > res_len { + return Err(CordicError::ResultLengthNotEnough); } - Ok(output_count) - } - - fn blocking_read_f64_to_buf(&mut self, result_buf: &mut [f64], result_index: &mut usize) { - result_buf[*result_index] = utils::q1_31_to_f64(self.peri.read_result()); - *result_index += 1; - - // We don't care about whether the function return 1 or 2 results, - // the only thing matter is whether user want 1 or 2 results. 
- if !self.config.res1_only { - result_buf[*result_index] = utils::q1_31_to_f64(self.peri.read_result()); - *result_index += 1; - } - } - - fn blocking_write_f64(&mut self, arg: f64) -> Result<(), NumberOutOfRange> { - self.peri.write_argument(utils::f64_to_q1_31(arg)?); - Ok(()) - } - - /// Run a async CORDIC calculation in q.1.31 format - pub async fn async_calc_32bit( - &mut self, - write_dma: impl Peripheral<P = impl WriteDma<T>>, - read_dma: impl Peripheral<P = impl ReadDma<T>>, - arg1s: &[f64], - arg2s: Option<&[f64]>, - output: &mut [f64], - ) -> Result<usize, CordicError> { - if arg1s.is_empty() { - return Ok(0); - } - - let output_length_enough = match self.config.res1_only { - true => output.len() >= arg1s.len(), - false => output.len() >= 2 * arg1s.len(), - }; - - if !output_length_enough { - return Err(CordicError::OutputLengthNotEnough); - } - - self.check_input_f64(arg1s, arg2s)?; - - into_ref!(write_dma, read_dma); - - self.peri.set_result_count(if self.config.res1_only { - AccessCount::One - } else { - AccessCount::Two - }); - - self.peri.set_data_width(Width::Bits32, Width::Bits32); - - let mut output_count = 0; - let mut consumed_input_len = 0; - let mut input_buf = [0u32; INPUT_BUF_MAX_LEN]; - let mut input_buf_len = 0; - - // - // handle 2 input args calculation - // - - if !arg2s.unwrap_or_default().is_empty() { - let arg2s = arg2s.unwrap(); - - self.peri.set_argument_count(AccessCount::Two); - - let double_input = arg1s.iter().zip(arg2s); - - consumed_input_len = double_input.len(); - - for (&arg1, &arg2) in double_input { - for &arg in [arg1, arg2].iter() { - input_buf[input_buf_len] = utils::f64_to_q1_31(arg)?; - input_buf_len += 1; - } - - if input_buf_len == INPUT_BUF_MAX_LEN { - self.inner_dma_calc_32bit( - &mut write_dma, - &mut read_dma, - true, - &input_buf[..input_buf_len], - output, - &mut output_count, - ) - .await; - - input_buf_len = 0; - } - } - - if input_buf_len > 0 { - self.inner_dma_calc_32bit( - &mut write_dma, - &mut 
read_dma, - true, - &input_buf[..input_buf_len], - output, - &mut output_count, - ) - .await; - - input_buf_len = 0; - } - } - - // - // handle 1 input arg calculation - // - - if arg1s.len() > consumed_input_len { - let input_remain = &arg1s[consumed_input_len..]; - - self.peri.set_argument_count(AccessCount::One); - - for &arg in input_remain { - input_buf[input_buf_len] = utils::f64_to_q1_31(arg)?; - input_buf_len += 1; - - if input_buf_len == INPUT_BUF_MAX_LEN { - self.inner_dma_calc_32bit( - &mut write_dma, - &mut read_dma, - false, - &input_buf[..input_buf_len], - output, - &mut output_count, - ) - .await; - - input_buf_len = 0; - } - } - - if input_buf_len > 0 { - self.inner_dma_calc_32bit( - &mut write_dma, - &mut read_dma, - false, - &input_buf[..input_buf_len], - output, - &mut output_count, - ) - .await; - - // input_buf_len = 0; - } - } - - Ok(output_count) - } - - // this function is highly coupled with async_calc_32bit, and is not intended to use in other place - async fn inner_dma_calc_32bit( - &mut self, - write_dma: impl Peripheral<P = impl WriteDma<T>>, - read_dma: impl Peripheral<P = impl ReadDma<T>>, - double_input: bool, // gether extra info to calc output_buf size - input_buf: &[u32], // input_buf, its content should be exact length for calculation - output: &mut [f64], // caller should uses this buf as a final output array - output_start_index: &mut usize, // the index of start point of the output for this round of calculation - ) { - // output_buf is the place to store raw value from CORDIC (via DMA). - // For buf size, we assume in this round of calculation: - // all input is 1 arg, and all calculation need 2 output, - // thus output_buf will always be long enough. - let mut output_buf = [0u32; INPUT_BUF_MAX_LEN * 2]; - - let mut output_buf_size = input_buf.len(); - if !self.config.res1_only { - // if we need 2 result for 1 input, then output_buf length should be 2x long. 
- output_buf_size *= 2; - }; - if double_input { - // if input itself is 2 args for 1 calculation, then output_buf length should be /2. - output_buf_size /= 2; - } - - let active_output_buf = &mut output_buf[..output_buf_size]; - - self.launch_a_dma_transfer(write_dma, read_dma, input_buf, active_output_buf) - .await; - - for &mut output_u32 in active_output_buf { - output[*output_start_index] = utils::q1_31_to_f64(output_u32); - *output_start_index += 1; - } + Ok(minimal_res_length) } } // q1.15 related impl<'d, T: Instance> Cordic<'d, T> { - /// Run a blocking CORDIC calculation in q1.15 format - pub fn blocking_calc_16bit( - &mut self, - arg1s: &[f32], - arg2s: Option<&[f32]>, - output: &mut [f32], - ) -> Result<usize, CordicError> { - if arg1s.is_empty() { + /// Run a blocking CORDIC calculation in q1.15 format + /// + /// Notice:: + /// User will take respond to merge two u16 arguments into one u32 data, and/or split one u32 data into two u16 results. + pub fn blocking_calc_16bit(&mut self, arg: &[u32], res: &mut [u32]) -> Result<usize, CordicError> { + if arg.is_empty() { return Ok(0); } - let output_length_enough = match self.config.res1_only { - true => output.len() >= arg1s.len(), - false => output.len() >= 2 * arg1s.len(), - }; - - if !output_length_enough { - return Err(CordicError::OutputLengthNotEnough); + if arg.len() > res.len() { + return Err(CordicError::ResultLengthNotEnough); } - self.check_input_f32(arg1s, arg2s)?; + let res_cnt = arg.len(); // In q1.15 mode, 1 write/read to access 2 arguments/results self.peri.set_argument_count(AccessCount::One); @@ -488,83 +379,53 @@ impl<'d, T: Instance> Cordic<'d, T> { self.peri.set_data_width(Width::Bits16, Width::Bits16); - let mut output_count = 0; + // To use cordic preload function, the first value is special. + // It is loaded to CORDIC WDATA register out side of loop + let first_value = arg[0]; - // In q1.15 mode, we always fill 1 pair of 16bit value into WDATA register. 
- // If arg2s is None or empty array, we assume arg2 value always 1.0 (as reset value for ARG2). - // If arg2s has some value, and but not as long as arg1s, - // we fill the reset of arg2 values with last value from arg2s (as q1.31 version does) + // preload 1st value to CORDIC, to start the CORDIC calc + self.peri.write_argument(first_value); - let arg2_default_value = match arg2s { - Some(arg2s) if !arg2s.is_empty() => arg2s[arg2s.len() - 1], - _ => 1.0, - }; + let mut cnt = 0; - let mut args = arg1s.iter().zip( - arg2s - .unwrap_or(&[]) - .iter() - .chain(core::iter::repeat(&arg2_default_value)), - ); + for &arg_val in &arg[1..] { + // preload arg_val (for next calc) + self.peri.write_argument(arg_val); - let (&arg1, &arg2) = args.next().unwrap(); - - // preloading 1 pair of arguments - self.blocking_write_f32(arg1, arg2)?; - - for (&arg1, &arg2) in args { - self.blocking_write_f32(arg1, arg2)?; - self.blocking_read_f32_to_buf(output, &mut output_count); + // then read current result out + res[cnt] = self.peri.read_result(); + cnt += 1; } - // read last pair of value from cordic - self.blocking_read_f32_to_buf(output, &mut output_count); + // read last result out + res[cnt] = self.peri.read_result(); + // cnt += 1; - Ok(output_count) + Ok(res_cnt) } - fn blocking_write_f32(&mut self, arg1: f32, arg2: f32) -> Result<(), NumberOutOfRange> { - self.peri.write_argument(utils::f32_args_to_u32(arg1, arg2)?); - Ok(()) - } - - fn blocking_read_f32_to_buf(&mut self, result_buf: &mut [f32], result_index: &mut usize) { - let (res1, res2) = utils::u32_to_f32_res(self.peri.read_result()); - - result_buf[*result_index] = res1; - *result_index += 1; - - // We don't care about whether the function return 1 or 2 results, - // the only thing matter is whether user want 1 or 2 results. 
- if !self.config.res1_only { - result_buf[*result_index] = res2; - *result_index += 1; - } - } - - /// Run a async CORDIC calculation in q1.15 format + /// Run a async CORDIC calculation in q1.15 format + /// + /// Notice:: + /// User will take respond to merge two u16 arguments into one u32 data, and/or split one u32 data into two u16 results. pub async fn async_calc_16bit( &mut self, write_dma: impl Peripheral<P = impl WriteDma<T>>, read_dma: impl Peripheral<P = impl ReadDma<T>>, - arg1s: &[f32], - arg2s: Option<&[f32]>, - output: &mut [f32], + arg: &[u32], + res: &mut [u32], ) -> Result<usize, CordicError> { - if arg1s.is_empty() { + if arg.is_empty() { return Ok(0); } - let output_length_enough = match self.config.res1_only { - true => output.len() >= arg1s.len(), - false => output.len() >= 2 * arg1s.len(), - }; - - if !output_length_enough { - return Err(CordicError::OutputLengthNotEnough); + if arg.len() > res.len() { + return Err(CordicError::ResultLengthNotEnough); } - self.check_input_f32(arg1s, arg2s)?; + let res_cnt = arg.len(); + + let active_res_buf = &mut res[..res_cnt]; into_ref!(write_dma, read_dma); @@ -574,142 +435,96 @@ impl<'d, T: Instance> Cordic<'d, T> { self.peri.set_data_width(Width::Bits16, Width::Bits16); - let mut output_count = 0; - let mut input_buf = [0u32; INPUT_BUF_MAX_LEN]; - let mut input_buf_len = 0; + let write_req = write_dma.request(); + let read_req = read_dma.request(); - // In q1.15 mode, we always fill 1 pair of 16bit value into WDATA register. - // If arg2s is None or empty array, we assume arg2 value always 1.0 (as reset value for ARG2). 
- // If arg2s has some value, and but not as long as arg1s, - // we fill the reset of arg2 values with last value from arg2s (as CORDIC behavior on q1.31 format) + self.peri.enable_write_dma(); + self.peri.enable_read_dma(); - let arg2_default_value = match arg2s { - Some(arg2s) if !arg2s.is_empty() => arg2s[arg2s.len() - 1], - _ => 1.0, - }; + let _on_drop = OnDrop::new(|| { + self.peri.disable_write_dma(); + self.peri.disable_read_dma(); + }); - let args = arg1s.iter().zip( - arg2s - .unwrap_or(&[]) - .iter() - .chain(core::iter::repeat(&arg2_default_value)), - ); - - for (&arg1, &arg2) in args { - input_buf[input_buf_len] = utils::f32_args_to_u32(arg1, arg2)?; - input_buf_len += 1; - - if input_buf_len == INPUT_BUF_MAX_LEN { - self.inner_dma_calc_16bit(&mut write_dma, &mut read_dma, &input_buf, output, &mut output_count) - .await; - } - } - - if input_buf_len > 0 { - self.inner_dma_calc_16bit( + unsafe { + let write_transfer = dma::Transfer::new_write( &mut write_dma, + write_req, + arg, + T::regs().wdata().as_ptr() as *mut _, + Default::default(), + ); + + let read_transfer = dma::Transfer::new_read( &mut read_dma, - &input_buf[..input_buf_len], - output, - &mut output_count, - ) - .await; + read_req, + T::regs().rdata().as_ptr() as *mut _, + active_res_buf, + Default::default(), + ); + + embassy_futures::join::join(write_transfer, read_transfer).await; } - Ok(output_count) - } - - // this function is highly coupled with async_calc_16bit, and is not intended to use in other place - async fn inner_dma_calc_16bit( - &mut self, - write_dma: impl Peripheral<P = impl WriteDma<T>>, - read_dma: impl Peripheral<P = impl ReadDma<T>>, - input_buf: &[u32], // input_buf, its content should be exact length for calculation - output: &mut [f32], // caller should uses this buf as a final output array - output_start_index: &mut usize, // the index of start point of the output for this round of calculation - ) { - // output_buf is the place to store raw value from CORDIC (via 
DMA). - let mut output_buf = [0u32; INPUT_BUF_MAX_LEN]; - - let active_output_buf = &mut output_buf[..input_buf.len()]; - - self.launch_a_dma_transfer(write_dma, read_dma, input_buf, active_output_buf) - .await; - - for &mut output_u32 in active_output_buf { - let (res1, res2) = utils::u32_to_f32_res(output_u32); - - output[*output_start_index] = res1; - *output_start_index += 1; - - if !self.config.res1_only { - output[*output_start_index] = res2; - *output_start_index += 1; - } - } + Ok(res_cnt) } } -// check input value ARG1, ARG2, SCALE and FUNCTION are compatible with each other -macro_rules! check_input_value { - ($func_name:ident, $float_type:ty) => { +macro_rules! check_arg_value { + ($func_arg1_name:ident, $func_arg2_name:ident, $float_type:ty) => { impl<'d, T: Instance> Cordic<'d, T> { - fn $func_name(&self, arg1s: &[$float_type], arg2s: Option<&[$float_type]>) -> Result<(), ArgError> { + /// check input value ARG1, SCALE and FUNCTION are compatible with each other + pub fn $func_arg1_name(&self, arg: $float_type) -> Result<(), ArgError> { let config = &self.config; use Function::*; struct Arg1ErrInfo { scale: Option<Scale>, - range: [f32; 2], + range: [f32; 2], // f32 is ok, it only used in error display inclusive_upper_bound: bool, } - // check ARG1 value let err_info = match config.function { - Cos | Sin | Phase | Modulus | Arctan if arg1s.iter().any(|v| !(-1.0..=1.0).contains(v)) => { - Some(Arg1ErrInfo { - scale: None, - range: [-1.0, 1.0], - inclusive_upper_bound: true, - }) - } + Cos | Sin | Phase | Modulus | Arctan if !(-1.0..=1.0).contains(arg) => Some(Arg1ErrInfo { + scale: None, + range: [-1.0, 1.0], + inclusive_upper_bound: true, + }), - Cosh | Sinh if arg1s.iter().any(|v| !(-0.559..=0.559).contains(v)) => Some(Arg1ErrInfo { + Cosh | Sinh if !(-0.559..=0.559).contains(arg) => Some(Arg1ErrInfo { scale: None, range: [-0.559, 0.559], inclusive_upper_bound: true, }), - Arctanh if arg1s.iter().any(|v| !(-0.403..=0.403).contains(v)) => 
Some(Arg1ErrInfo { + Arctanh if !(-0.403..=0.403).contains(arg) => Some(Arg1ErrInfo { scale: None, range: [-0.403, 0.403], inclusive_upper_bound: true, }), Ln => match config.scale { - Scale::Arg1o2Res2 if arg1s.iter().any(|v| !(0.0535..0.5).contains(v)) => Some(Arg1ErrInfo { + Scale::Arg1o2Res2 if !(0.0535..0.5).contains(arg) => Some(Arg1ErrInfo { scale: Some(Scale::Arg1o2Res2), range: [0.0535, 0.5], inclusive_upper_bound: false, }), - Scale::Arg1o4Res4 if arg1s.iter().any(|v| !(0.25..0.75).contains(v)) => Some(Arg1ErrInfo { + Scale::Arg1o4Res4 if !(0.25..0.75).contains(arg) => Some(Arg1ErrInfo { scale: Some(Scale::Arg1o4Res4), range: [0.25, 0.75], inclusive_upper_bound: false, }), - Scale::Arg1o8Res8 if arg1s.iter().any(|v| !(0.375..0.875).contains(v)) => Some(Arg1ErrInfo { + Scale::Arg1o8Res8 if !(0.375..0.875).contains(arg) => Some(Arg1ErrInfo { scale: Some(Scale::Arg1o8Res8), range: [0.375, 0.875], inclusive_upper_bound: false, }), - Scale::Arg1o16Res16 if arg1s.iter().any(|v| !(0.4375..0.584).contains(v)) => { - Some(Arg1ErrInfo { - scale: Some(Scale::Arg1o16Res16), - range: [0.4375, 0.584], - inclusive_upper_bound: false, - }) - } + Scale::Arg1o16Res16 if !(0.4375..0.584).contains(arg) => Some(Arg1ErrInfo { + scale: Some(Scale::Arg1o16Res16), + range: [0.4375, 0.584], + inclusive_upper_bound: false, + }), Scale::Arg1o2Res2 | Scale::Arg1o4Res4 | Scale::Arg1o8Res8 | Scale::Arg1o16Res16 => None, @@ -717,17 +532,17 @@ macro_rules! 
check_input_value { }, Sqrt => match config.scale { - Scale::Arg1Res1 if arg1s.iter().any(|v| !(0.027..0.75).contains(v)) => Some(Arg1ErrInfo { + Scale::Arg1Res1 if !(0.027..0.75).contains(arg) => Some(Arg1ErrInfo { scale: Some(Scale::Arg1Res1), range: [0.027, 0.75], inclusive_upper_bound: false, }), - Scale::Arg1o2Res2 if arg1s.iter().any(|v| !(0.375..0.875).contains(v)) => Some(Arg1ErrInfo { + Scale::Arg1o2Res2 if !(0.375..0.875).contains(arg) => Some(Arg1ErrInfo { scale: Some(Scale::Arg1o2Res2), range: [0.375, 0.875], inclusive_upper_bound: false, }), - Scale::Arg1o4Res4 if arg1s.iter().any(|v| !(0.4375..0.584).contains(v)) => Some(Arg1ErrInfo { + Scale::Arg1o4Res4 if !(0.4375..0.584).contains(arg) => Some(Arg1ErrInfo { scale: Some(Scale::Arg1o4Res4), range: [0.4375, 0.584], inclusive_upper_bound: false, @@ -749,33 +564,35 @@ macro_rules! check_input_value { }); } - // check ARG2 value - if let Some(arg2s) = arg2s { - struct Arg2ErrInfo { - range: [f32; 2], - } + Ok(()) + } - let err_info = match config.function { - Cos | Sin if arg2s.iter().any(|v| !(0.0..=1.0).contains(v)) => { - Some(Arg2ErrInfo { range: [0.0, 1.0] }) - } + /// check input value ARG2 and FUNCTION are compatible with each other + pub fn $func_arg2_name(&self, arg: $float_type) -> Result<(), ArgError> { + let config = &self.config; - Phase | Modulus if arg2s.iter().any(|v| !(-1.0..=1.0).contains(v)) => { - Some(Arg2ErrInfo { range: [-1.0, 1.0] }) - } + use Function::*; - Cos | Sin | Phase | Modulus | Arctan | Cosh | Sinh | Arctanh | Ln | Sqrt => None, - }; + struct Arg2ErrInfo { + range: [f32; 2], // f32 is ok, it only used in error display + } - if let Some(err) = err_info { - return Err(ArgError { - func: config.function, - scale: None, - arg_range: err.range, - inclusive_upper_bound: true, - arg_type: ArgType::Arg2, - }); - } + let err_info = match config.function { + Cos | Sin if !(0.0..=1.0).contains(arg) => Some(Arg2ErrInfo { range: [0.0, 1.0] }), + + Phase | Modulus if 
!(-1.0..=1.0).contains(arg) => Some(Arg2ErrInfo { range: [-1.0, 1.0] }), + + Cos | Sin | Phase | Modulus | Arctan | Cosh | Sinh | Arctanh | Ln | Sqrt => None, + }; + + if let Some(err) = err_info { + return Err(ArgError { + func: config.function, + scale: None, + arg_range: err.range, + inclusive_upper_bound: true, + arg_type: ArgType::Arg2, + }); } Ok(()) @@ -784,8 +601,8 @@ macro_rules! check_input_value { }; } -check_input_value!(check_input_f64, f64); -check_input_value!(check_input_f32, f32); +check_arg_value!(check_f64_arg1, check_f64_arg2, &f64); +check_arg_value!(check_f32_arg1, check_f32_arg2, &f32); foreach_interrupt!( ($inst:ident, cordic, $block:ident, GLOBAL, $irq:ident) => { diff --git a/embassy-stm32/src/cordic/utils.rs b/embassy-stm32/src/cordic/utils.rs index 41821d6e2..008f50270 100644 --- a/embassy-stm32/src/cordic/utils.rs +++ b/embassy-stm32/src/cordic/utils.rs @@ -1,4 +1,4 @@ -//! Common match utils +//! Common math utils use super::errors::NumberOutOfRange; macro_rules! floating_fixed_convert { @@ -60,16 +60,3 @@ floating_fixed_convert!( 15, 0x3800_0000u32 // binary form of 1f32^(-15) ); - -#[inline(always)] -pub(crate) fn f32_args_to_u32(arg1: f32, arg2: f32) -> Result<u32, NumberOutOfRange> { - Ok(f32_to_q1_15(arg1)? as u32 + ((f32_to_q1_15(arg2)? 
as u32) << 16)) -} - -#[inline(always)] -pub(crate) fn u32_to_f32_res(reg_value: u32) -> (f32, f32) { - let res1 = q1_15_to_f32((reg_value & ((1u32 << 16) - 1)) as u16); - let res2 = q1_15_to_f32((reg_value >> 16) as u16); - - (res1, res2) -} diff --git a/examples/stm32h5/src/bin/cordic.rs b/examples/stm32h5/src/bin/cordic.rs index d49f75b8f..73e873574 100644 --- a/examples/stm32h5/src/bin/cordic.rs +++ b/examples/stm32h5/src/bin/cordic.rs @@ -3,7 +3,7 @@ use defmt::*; use embassy_executor::Spawner; -use embassy_stm32::cordic; +use embassy_stm32::cordic::{self, utils}; use {defmt_rtt as _, panic_probe as _}; #[embassy_executor::main] @@ -16,20 +16,63 @@ async fn main(_spawner: Spawner) { cordic::Function::Sin, Default::default(), Default::default(), - false, )), ); - let mut output = [0f64; 16]; + // for output buf, the length is not that strict, larger than minimal required is ok. + let mut output_f64 = [0f64; 19]; + let mut output_u32 = [0u32; 21]; - let arg1 = [1.0, 0.0, -1.0]; // for trigonometric function, the ARG1 value [-pi, pi] should be map to [-1, 1] - let arg2 = [0.5, 1.0]; + // tips: + // CORDIC peripheral has some strict on input value, you can also use ".check_argX_fXX()" methods + // to make sure your input values are compatible with current CORDIC setup. + let arg1 = [-1.0, -0.5, 0.0, 0.5, 1.0]; // for trigonometric function, the ARG1 value [-pi, pi] should be map to [-1, 1] + let arg2 = [0.5]; // and for Sin function, ARG2 should be in [0, 1] - let cnt = unwrap!( + let mut input_buf = [0u32; 9]; + + // convert input from floating point to fixed point + input_buf[0] = unwrap!(utils::f64_to_q1_31(arg1[0])); + input_buf[1] = unwrap!(utils::f64_to_q1_31(arg2[0])); + + // If input length is small, blocking mode can be used to minimize overhead. 
+ let cnt0 = unwrap!(cordic.blocking_calc_32bit( + &input_buf[..2], // input length is strict, since driver use its length to detect calculation count + &mut output_u32, + false, + false + )); + + // convert result from fixed point into floating point + for (&u32_val, f64_val) in output_u32[..cnt0].iter().zip(output_f64.iter_mut()) { + *f64_val = utils::q1_31_to_f64(u32_val); + } + + // convert input from floating point to fixed point + // + // first value from arg1 is used, so truncate to arg1[1..] + for (&f64_val, u32_val) in arg1[1..].iter().zip(input_buf.iter_mut()) { + *u32_val = unwrap!(utils::f64_to_q1_31(f64_val)); + } + + // If calculation is a little longer, async mode can make use of DMA, and let core do some other stuff. + let cnt1 = unwrap!( cordic - .async_calc_32bit(&mut dp.GPDMA1_CH0, &mut dp.GPDMA1_CH1, &arg1, Some(&arg2), &mut output,) + .async_calc_32bit( + &mut dp.GPDMA1_CH0, + &mut dp.GPDMA1_CH1, + &input_buf[..arg1.len() - 1], // limit input buf to its actual length + &mut output_u32, + true, + false + ) .await ); - println!("async calc 32bit: {}", output[..cnt]); + // convert result from fixed point into floating point + for (&u32_val, f64_val) in output_u32[..cnt1].iter().zip(output_f64[cnt0..cnt0 + cnt1].iter_mut()) { + *f64_val = utils::q1_31_to_f64(u32_val); + } + + println!("result: {}", output_f64[..cnt0 + cnt1]); } diff --git a/tests/stm32/src/bin/cordic.rs b/tests/stm32/src/bin/cordic.rs index cd2e9d6f7..669fd96ab 100644 --- a/tests/stm32/src/bin/cordic.rs +++ b/tests/stm32/src/bin/cordic.rs @@ -14,6 +14,7 @@ mod common; use common::*; use embassy_executor::Spawner; +use embassy_stm32::cordic::utils; use embassy_stm32::{bind_interrupts, cordic, peripherals, rng}; use num_traits::Float; use {defmt_rtt as _, panic_probe as _}; @@ -24,11 +25,12 @@ bind_interrupts!(struct Irqs { /* input value control, can be changed */ -const ARG1_LENGTH: usize = 9; -const ARG2_LENGTH: usize = 4; // this might not be the exact length of ARG2, since ARG2 
need to be inside [0, 1] +const INPUT_U32_COUNT: usize = 9; +const INPUT_U8_COUNT: usize = 4 * INPUT_U32_COUNT; -const INPUT_Q1_31_LENGTH: usize = ARG1_LENGTH + ARG2_LENGTH; -const INPUT_U8_LENGTH: usize = 4 * INPUT_Q1_31_LENGTH; +// Assume first calculation needs 2 arguments, the rest need 1 argument. +// And all calculations generate 2 results. +const OUTPUT_LENGTH: usize = (INPUT_U32_COUNT - 1) * 2; #[embassy_executor::main] async fn main(_spawner: Spawner) { @@ -42,43 +44,28 @@ async fn main(_spawner: Spawner) { let mut rng = rng::Rng::new(dp.RNG, Irqs); - let mut input_buf_u8 = [0u8; INPUT_U8_LENGTH]; + let mut input_buf_u8 = [0u8; INPUT_U8_COUNT]; defmt::unwrap!(rng.async_fill_bytes(&mut input_buf_u8).await); // convert every [u8; 4] to a u32, for a Q1.31 value - let input_q1_31 = unsafe { core::mem::transmute::<[u8; INPUT_U8_LENGTH], [u32; INPUT_Q1_31_LENGTH]>(input_buf_u8) }; + let mut input_q1_31 = unsafe { core::mem::transmute::<[u8; INPUT_U8_COUNT], [u32; INPUT_U32_COUNT]>(input_buf_u8) }; - let mut input_f64_buf = [0f64; INPUT_Q1_31_LENGTH]; + // ARG2 for Sin function should be inside [0, 1], set MSB to 0 of a Q1.31 value, will make sure it's no less than 0. + input_q1_31[1] &= !(1u32 << 31); - let mut cordic_output_f64_buf = [0f64; ARG1_LENGTH * 2]; + // + // CORDIC calculation + // - // convert Q1.31 value back to f64, for software calculation verify - for (val_u32, val_f64) in input_q1_31.iter().zip(input_f64_buf.iter_mut()) { - *val_f64 = cordic::utils::q1_31_to_f64(*val_u32); - } - - let mut arg2_f64_buf = [0f64; ARG2_LENGTH]; - let mut arg2_f64_len = 0; - - // check if ARG2 is in range [0, 1] (limited by CORDIC peripheral with Sin mode) - for &arg2 in &input_f64_buf[ARG1_LENGTH..]
{ - if arg2 >= 0.0 { - arg2_f64_buf[arg2_f64_len] = arg2; - arg2_f64_len += 1; - } - } - - // the actual value feed to CORDIC - let arg1_f64_ls = &input_f64_buf[..ARG1_LENGTH]; - let arg2_f64_ls = &arg2_f64_buf[..arg2_f64_len]; + let mut output_q1_31 = [0u32; OUTPUT_LENGTH]; + // setup Cordic driver let mut cordic = cordic::Cordic::new( dp.CORDIC, defmt::unwrap!(cordic::Config::new( cordic::Function::Sin, Default::default(), Default::default(), - false, )), ); @@ -88,67 +75,66 @@ async fn main(_spawner: Spawner) { #[cfg(any(feature = "stm32h563zi", feature = "stm32u585ai", feature = "stm32u5a5zj"))] let (mut write_dma, mut read_dma) = (dp.GPDMA1_CH4, dp.GPDMA1_CH5); - let cordic_start_point = embassy_time::Instant::now(); + // calculate first result using blocking mode + let cnt0 = defmt::unwrap!(cordic.blocking_calc_32bit(&input_q1_31[..2], &mut output_q1_31, false, false)); - let cnt = unwrap!( + // calculate rest results using async mode + let cnt1 = defmt::unwrap!( cordic .async_calc_32bit( &mut write_dma, &mut read_dma, - arg1_f64_ls, - Some(arg2_f64_ls), - &mut cordic_output_f64_buf, + &input_q1_31[2..], + &mut output_q1_31[cnt0..], + true, + false, ) .await ); - let cordic_end_point = embassy_time::Instant::now(); + // all output value length should be the same as our output buffer size + defmt::assert_eq!(cnt0 + cnt1, output_q1_31.len()); - // since we get 2 output for 1 calculation, the output length should be ARG1_LENGTH * 2 - defmt::assert!(cnt == ARG1_LENGTH * 2); + let mut cordic_result_f64 = [0.0f64; OUTPUT_LENGTH]; - let mut software_output_f64_buf = [0f64; ARG1_LENGTH * 2]; + for (f64_val, u32_val) in cordic_result_f64.iter_mut().zip(output_q1_31) { + *f64_val = utils::q1_31_to_f64(u32_val); + } - // for software calc, if there is no ARG2 value, insert a 1.0 as value (the reset value for ARG2 in CORDIC) - let arg2_f64_ls = if arg2_f64_len == 0 { &[1.0] } else { arg2_f64_ls }; + // + // software calculation + // - let software_inputs = arg1_f64_ls + 
let mut software_result_f64 = [0.0f64; OUTPUT_LENGTH]; + + let arg2 = utils::q1_31_to_f64(input_q1_31[1]); + + for (&arg1, res) in input_q1_31 .iter() - .zip( - arg2_f64_ls - .iter() - .chain(core::iter::repeat(&arg2_f64_ls[arg2_f64_ls.len() - 1])), - ) - .zip(software_output_f64_buf.chunks_mut(2)); + .enumerate() + .filter_map(|(idx, val)| if idx != 1 { Some(val) } else { None }) + .zip(software_result_f64.chunks_mut(2)) + { + let arg1 = utils::q1_31_to_f64(arg1); - let software_start_point = embassy_time::Instant::now(); - - for ((arg1, arg2), res) in software_inputs { let (raw_res1, raw_res2) = (arg1 * core::f64::consts::PI).sin_cos(); - (res[0], res[1]) = (raw_res1 * arg2, raw_res2 * arg2); } - let software_end_point = embassy_time::Instant::now(); + // + // check result are the same + // - for (cordic_res, software_res) in cordic_output_f64_buf[..cnt] + for (cordic_res, software_res) in cordic_result_f64[..cnt0 + cnt1] .chunks(2) - .zip(software_output_f64_buf.chunks(2)) + .zip(software_result_f64.chunks(2)) { for (cord_res, soft_res) in cordic_res.iter().zip(software_res.iter()) { + // 2.0.powi(-19) is the max residual error for Sin function, in q1.31 format, with 24 iterations (aka PRECISION = 6) defmt::assert!((cord_res - soft_res).abs() <= 2.0.powi(-19)); } } - // This comparison is just for fun. Since it not a equal compare: - // software use 64-bit floating point, but CORDIC use 32-bit fixed point. - defmt::trace!( - "calculate count: {}, Cordic time: {} us, software time: {} us", - ARG1_LENGTH, - (cordic_end_point - cordic_start_point).as_micros(), - (software_end_point - software_start_point).as_micros() - ); - info!("Test OK"); cortex_m::asm::bkpt(); }