diff --git a/Cargo.lock b/Cargo.lock index e077a95..55f8bbd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -828,6 +828,12 @@ version = "0.2.153" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" +[[package]] +name = "libm" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" + [[package]] name = "libredox" version = "0.0.1" @@ -885,6 +891,7 @@ dependencies = [ "fixed", "fixed-macro", "format_no_std", + "libm", "packed_struct", "panic-probe", "portable-atomic", diff --git a/Cargo.toml b/Cargo.toml index 9755d54..2b51105 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,6 +42,7 @@ fixed = "1.23.1" fixed-macro = "1.2" static_cell = "2" portable-atomic = { version = "1.5", features = ["critical-section"] } +libm = { version = "0.2.8" } #cortex-m = { version = "0.7.6", features = ["critical-section-single-core"] } cortex-m = { version = "0.7.6", features = ["inline-asm"] } diff --git a/src/input.rs b/src/input.rs index de0acb4..fc37f58 100644 --- a/src/input.rs +++ b/src/input.rs @@ -1,21 +1,53 @@ -use defmt::{debug, info}; +use core::task::Poll; + +use defmt::{debug, Format}; +use embassy_futures::{join::join, yield_now}; use embassy_rp::{ - clocks::RoscRng, - flash::{Async, Flash, ERASE_SIZE}, + flash::{Async, Flash}, gpio::{Input, Output, Pin}, peripherals::{ - DMA_CH0, FLASH, PIN_10, PIN_11, PIN_15, PIN_16, PIN_17, PIN_18, PIN_19, PIN_20, PIN_21, - PIN_22, PIN_23, PIN_24, PIN_25, PIN_29, PIN_5, PIN_8, PIN_9, SPI0, + FLASH, PIN_10, PIN_11, PIN_16, PIN_17, PIN_18, PIN_19, PIN_20, PIN_21, PIN_22, PIN_23, + PIN_24, PIN_5, PIN_8, PIN_9, PWM_CH4, PWM_CH6, SPI0, }, + pwm::Pwm, spi::Spi, }; use embassy_sync::{blocking_mutex::raw::CriticalSectionRawMutex, signal::Signal}; -use rand::RngCore; +use embassy_time::{Instant, Timer}; +use packed_struct::derive::PackedStruct; -use crate::{gcc_hid::GcReport, ADDR_OFFSET, FLASH_SIZE}; +use crate::{ + gcc_hid::GcReport, + stick::{linearize, run_kalman, FilterGains, StickParams}, + PackedFloat, ADDR_OFFSET, FLASH_SIZE, +}; pub static GCC_SIGNAL: Signal = Signal::new(); +static STICK_SIGNAL: Signal = Signal::new(); + +#[derive(Debug, Clone, Default, Format, PackedStruct)] +#[packed_struct(endian = "msb")] +pub struct ControllerConfig { + #[packed_field(size_bits = "8")] + pub config_version: u8, + #[packed_field(size_bits = "32")] + pub ax_waveshaping: PackedFloat, + #[packed_field(size_bits = "32")] + pub ay_waveshaping: PackedFloat, + #[packed_field(size_bits = "32")] + pub cx_waveshaping: PackedFloat, + #[packed_field(size_bits = "32")] + pub cy_waveshaping: PackedFloat, +} + +struct StickState { + ax: u8, + ay: u8, + cx: u8, + cy: u8, +} + #[derive(PartialEq, Eq)] enum Stick { ControlStick, @@ -28,10 +60,10 @@ enum StickAxis { YAxis, } -fn read_ext_adc<'a, Acs: Pin, Ccs: Pin>( +fn read_ext_adc<'a, Acs: Pin, Ccs: Pin, I: embassy_rp::spi::Instance, M: embassy_rp::spi::Mode>( which_stick: Stick, which_axis: StickAxis, - spi: &mut Spi<'a, SPI0, embassy_rp::spi::Blocking>, + spi: &mut Spi<'a, I, M>, spi_acs: &mut Output<'a, Acs>, spi_ccs: &mut Output<'a, Ccs>, ) -> u16 { @@ -61,6 +93,134 @@ fn read_ext_adc<'a, Acs: Pin, Ccs: Pin>( return temp_value; } +/// Gets the average stick state over a 1ms interval in a non-blocking fashion. +async fn update_stick_states< + 'a, + Acs: Pin, + Ccs: Pin, + I: embassy_rp::spi::Instance, + M: embassy_rp::spi::Mode, +>( + mut spi: &mut Spi<'a, I, M>, + mut spi_acs: &mut Output<'a, Acs>, + mut spi_ccs: &mut Output<'a, Ccs>, + adc_scale: f32, + controlstick_params: &StickParams, + cstick_params: &StickParams, + controller_config: &ControllerConfig, + filter_gains: &FilterGains, +) { + let mut adc_count = 0u32; + let mut ax_sum = 0u32; + let mut ay_sum = 0u32; + let mut cx_sum = 0u32; + let mut cy_sum = 0u32; + + // TODO: lower interval possible? + let mut timer = Timer::at(Instant::now() + embassy_time::Duration::from_millis(1)); + + while embassy_futures::poll_once(&mut timer) != Poll::Ready(()) { + adc_count += 1; + ax_sum += read_ext_adc( + Stick::ControlStick, + StickAxis::XAxis, + &mut spi, + &mut spi_acs, + &mut spi_ccs, + ) as u32; + ay_sum += read_ext_adc( + Stick::ControlStick, + StickAxis::YAxis, + &mut spi, + &mut spi_acs, + &mut spi_ccs, + ) as u32; + cx_sum += read_ext_adc( + Stick::CStick, + StickAxis::XAxis, + &mut spi, + &mut spi_acs, + &mut spi_ccs, + ) as u32; + cy_sum += read_ext_adc( + Stick::CStick, + StickAxis::YAxis, + &mut spi, + &mut spi_acs, + &mut spi_ccs, + ) as u32; + + // with this, we can poll the sticks at 1000Hz (ish), while updating + // the rest of the controller (the buttons) much faster, to ensure + // better input integrity for button inputs. + yield_now().await; + } + + timer.await; + + let raw_controlstick_x = (ax_sum as f32) / (adc_count as f32) / 4096.0f32 * adc_scale; + let raw_controlstick_y = (ay_sum as f32) / (adc_count as f32) / 4096.0f32 * adc_scale; + let raw_cstick_x = (cx_sum as f32) / (adc_count as f32) / 4096.0f32 * adc_scale; + let raw_cstick_y = (cy_sum as f32) / (adc_count as f32) / 4096.0f32 * adc_scale; + + let x_z = linearize(raw_controlstick_x, &controlstick_params.fit_coeffs_x); + let y_z = linearize(raw_controlstick_y, &controlstick_params.fit_coeffs_y); + + let pos_cx = linearize(raw_cstick_x, &cstick_params.fit_coeffs_x); + let pos_cy = linearize(raw_cstick_y, &cstick_params.fit_coeffs_y); + + let (x_pos_filt, y_pos_filt) = run_kalman(x_z, y_z, controller_config, filter_gains); + + STICK_SIGNAL.signal(StickState { + ax: 127, + ay: 127, + cx: 127, + cy: 127, + }) +} + +fn update_button_states< + A: Pin, + B: Pin, + X: Pin, + Y: Pin, + Start: Pin, + L: Pin, + R: Pin, + Z: Pin, + DLeft: Pin, + DRight: Pin, + DUp: Pin, + DDown: Pin, +>( + gcc_state: &mut GcReport, + btn_a: &Input<'_, A>, + btn_b: &Input<'_, B>, + btn_x: &Input<'_, X>, + btn_y: &Input<'_, Y>, + btn_start: &Input<'_, Start>, + btn_l: &Input<'_, L>, + btn_r: &Input<'_, R>, + btn_z: &Input<'_, Z>, + btn_dleft: &Input<'_, DLeft>, + btn_dright: &Input<'_, DRight>, + btn_dup: &Input<'_, DUp>, + btn_ddown: &Input<'_, DDown>, +) { + gcc_state.buttons_1.button_a = btn_a.is_low(); + gcc_state.buttons_1.button_b = btn_b.is_low(); + gcc_state.buttons_1.button_x = btn_x.is_low(); + gcc_state.buttons_1.button_y = btn_y.is_low(); + gcc_state.buttons_2.button_z = btn_z.is_low(); + gcc_state.buttons_2.button_start = btn_start.is_low(); + gcc_state.buttons_2.button_l = btn_l.is_low(); + gcc_state.buttons_2.button_r = btn_r.is_low(); + gcc_state.buttons_1.dpad_left = btn_dleft.is_low(); + gcc_state.buttons_1.dpad_right = btn_dright.is_low(); + gcc_state.buttons_1.dpad_up = btn_dup.is_low(); + gcc_state.buttons_1.dpad_down = btn_ddown.is_low(); +} + #[embassy_executor::task] pub async fn input_loop( mut flash: Flash<'static, FLASH, Async, FLASH_SIZE>, @@ -76,26 +236,67 @@ pub async fn input_loop( btn_x: Input<'static, PIN_18>, btn_y: Input<'static, PIN_19>, btn_start: Input<'static, PIN_5>, - btn_rumble: Input<'static, PIN_25>, - btn_brake: Input<'static, PIN_29>, + pwm_rumble: Pwm<'static, PWM_CH4>, + pwm_brake: Pwm<'static, PWM_CH6>, mut spi: Spi<'static, SPI0, embassy_rp::spi::Blocking>, mut spi_acs: Output<'static, PIN_24>, mut spi_ccs: Output<'static, PIN_23>, -) -> ! { +) { let mut gcc_state = GcReport::default(); - let mut rng = RoscRng; + // Set the stick states to the center + gcc_state.stick_x = 127; + gcc_state.stick_y = 127; + gcc_state.cstick_x = 127; + gcc_state.cstick_y = 127; let mut uid = [0u8; 1]; flash.blocking_read(ADDR_OFFSET, &mut uid).unwrap(); debug!("Read from flash: {:02X}", uid); - loop { - gcc_state.buttons_1.button_a = btn_z.is_low(); - gcc_state.stick_x = rng.next_u32() as u8; - gcc_state.stick_y = rng.next_u32() as u8; + // TODO: load controller config here - GCC_SIGNAL.signal(gcc_state); - } + let stick_state_fut = async { + loop { + // update_stick_states(&mut spi, &mut spi_acs, &mut spi_ccs, 1.0).await; + } + }; + + let input_fut = async { + loop { + update_button_states( + &mut gcc_state, + &btn_a, + &btn_b, + &btn_x, + &btn_y, + &btn_start, + &btn_l, + &btn_r, + &btn_z, + &btn_dleft, + &btn_dright, + &btn_dup, + &btn_ddown, + ); + + yield_now().await; + + // not every loop pass is going to update the stick state + match STICK_SIGNAL.try_take() { + Some(stick_state) => { + gcc_state.stick_x = stick_state.ax; + gcc_state.stick_y = stick_state.ay; + gcc_state.cstick_x = stick_state.cx; + gcc_state.cstick_y = stick_state.cy; + } + None => (), + } + + GCC_SIGNAL.signal(gcc_state); + } + }; + + join(input_fut, stick_state_fut).await; } diff --git a/src/main.rs b/src/main.rs index 6de22be..0ea5ebb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,8 +6,11 @@ #![no_main] mod gcc_hid; mod input; +mod stick; -use defmt::{debug, info}; +use core::ops::Deref; + +use defmt::{debug, info, Format}; use embassy_executor::Executor; use embassy_rp::{ bind_interrupts, @@ -15,12 +18,14 @@ use embassy_rp::{ gpio::{self, Input}, multicore::{spawn_core1, Stack}, peripherals::USB, + pwm::Pwm, spi::{self, Spi}, usb::{Driver, InterruptHandler}, }; use gcc_hid::usb_transfer_loop; use gpio::{Level, Output}; use input::input_loop; +use packed_struct::PackedStruct; use static_cell::StaticCell; use {defmt_rtt as _, panic_probe as _}; @@ -31,6 +36,31 @@ static EXECUTOR1: StaticCell = StaticCell::new(); const FLASH_SIZE: usize = 2 * 1024 * 1024; const ADDR_OFFSET: u32 = 0x100000; +/// wrapper type because packed_struct doesn't implement float +/// packing by default for some reason +#[derive(Debug, Format, Clone, Default)] +pub struct PackedFloat(f32); + +impl PackedStruct for PackedFloat { + type ByteArray = [u8; 4]; + + fn pack(&self) -> packed_struct::PackingResult { + Ok(self.to_be_bytes()) + } + + fn unpack(src: &Self::ByteArray) -> packed_struct::PackingResult { + Ok(Self(f32::from_be_bytes(*src))) + } +} + +impl Deref for PackedFloat { + type Target = f32; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + bind_interrupts!(struct Irqs { USBCTRL_IRQ => InterruptHandler; }); @@ -79,6 +109,17 @@ fn main() -> ! { let spi_acs = Output::new(p_acs, Level::High); // active low let spi_ccs = Output::new(p_ccs, Level::High); // active low + let mut pwm_config: embassy_rp::pwm::Config = Default::default(); + pwm_config.top = 255; + pwm_config.enable = true; + pwm_config.compare_b = 255; + + let pwm_rumble = Pwm::new_output_b(p.PWM_CH4, p.PIN_25, pwm_config.clone()); + let pwm_brake = Pwm::new_output_b(p.PWM_CH6, p.PIN_29, pwm_config.clone()); + + pwm_rumble.set_counter(255); + pwm_brake.set_counter(0); + executor0.run(|spawner| { spawner .spawn(input_loop( @@ -95,8 +136,8 @@ fn main() -> ! { Input::new(p.PIN_18, gpio::Pull::Up), Input::new(p.PIN_19, gpio::Pull::Up), Input::new(p.PIN_5, gpio::Pull::Up), - Input::new(p.PIN_25, gpio::Pull::Up), - Input::new(p.PIN_29, gpio::Pull::Up), + pwm_rumble, + pwm_brake, spi, spi_acs, spi_ccs, diff --git a/src/stick.rs b/src/stick.rs new file mode 100644 index 0000000..0789af1 --- /dev/null +++ b/src/stick.rs @@ -0,0 +1,263 @@ +// vast majority of this is taken from Phob firmware + +use defmt::Format; +use libm::fabs; + +use crate::input::ControllerConfig; + +/// fit order for the linearization +const FIT_ORDER: usize = 3; +const N_COEFFS: usize = FIT_ORDER + 1; +const NO_OF_NOTCHES: usize = 16; +const MAX_ORDER: usize = 20; + +#[derive(Clone, Debug, Default, Format)] +pub struct StickParams { + // these are the linearization coefficients + pub fit_coeffs_x: [f32; N_COEFFS], + pub fit_coeffs_y: [f32; N_COEFFS], + + // these are the notch remap parameters + pub affine_coeffs_x: [[f32; 16]; 4], // affine transformation coefficients for all regions of the stick + pub boundary_angles_x: [f32; 4], // angles at the boundaries between regions of the stick (in the plane) +} + +#[derive(Clone, Debug, Default, Format)] +pub struct FilterGains { + /// What's the max stick distance from the center + pub max_stick: f32, + /// filtered velocity terms + /// how fast the filtered velocity falls off in the absence of stick movement. + /// Probably don't touch this. + pub x_vel_decay: f32, //0.1 default for 1.2ms timesteps, larger for bigger timesteps + pub y_vel_decay: f32, + /// how much the current position disagreement impacts the filtered velocity. + /// Probably don't touch this. + pub x_vel_pos_factor: f32, //0.01 default for 1.2ms timesteps, larger for bigger timesteps + pub y_vel_pos_factor: f32, + /// how much to ignore filtered velocity when computing the new stick position. + /// DO CHANGE THIS + /// Higher gives shorter rise times and slower fall times (more pode, less snapback) + pub x_vel_damp: f32, //0.125 default for 1.2ms timesteps, smaller for bigger timesteps + pub y_vel_damp: f32, + /// speed and accel thresholds below which we try to follow the stick better + /// These may need tweaking according to how noisy the signal is + /// If it's noisier, we may need to add additional filtering + /// If the timesteps are *really small* then it may need to be increased to get + /// above the noise floor. Or some combination of filtering and playing with + /// the thresholds. + pub vel_thresh: f32, //1 default for 1.2ms timesteps, larger for bigger timesteps + pub accel_thresh: f32, //5 default for 1.2ms timesteps, larger for bigger timesteps + /// This just applies a low-pass filter. + /// The purpose is to provide delay for single-axis ledgedashes. + /// Must be between 0 and 1. Larger = more smoothing and delay. + pub x_smoothing: f32, + pub y_smoothing: f32, + /// Same thing but for C-stick + pub c_xsmoothing: f32, + pub c_ysmoothing: f32, +} + +#[derive(Clone, Debug, Default)] +struct LinearizeCalibrationOutput { + pub fit_coeffs_x: [f64; N_COEFFS], + pub fit_coeffs_y: [f64; N_COEFFS], + + pub out_x: [f32; NO_OF_NOTCHES], + pub out_y: [f32; NO_OF_NOTCHES], +} + +pub fn run_kalman( + x_z: f32, + y_z: f32, + controller_config: &ControllerConfig, + filter_gains: &FilterGains, +) -> (f32, f32) { + todo!() +} + +fn curve_fit_power(base: f64, exponent: u32) -> f64 { + if exponent == 0 { + return 1.0; + } + + let mut val = base; + for _ in 1..exponent { + val *= base; + } + + val +} + +fn sub_col( + matrix: &[[f64; N]; N], + t: &[f64; MAX_ORDER], + col: usize, + n: usize, +) -> [[f64; N]; N] { + let mut m = *matrix; + for i in 0..n { + m[i][col] = t[i]; + } + m +} + +fn det(matrix: &[[f64; N]; N]) -> f64 { + let sign = trianglize(matrix); + if sign == 0 { + return 0.; + } + let mut p = 1f64; + for i in 0..N { + p *= matrix[i][i]; + } + p * (sign as f64) +} + +fn trianglize(matrix: &[[f64; N]; N]) -> i32 { + let mut sign = 1; + let mut matrix = *matrix; + + for i in 0..N { + let mut max = 0; + for row in i..N { + if fabs(matrix[row][i]) > fabs(matrix[max][i]) { + max = row; + } + } + if max > 0 { + sign = -sign; + let tmp = matrix[i]; + matrix[i] = matrix[max]; + matrix[max] = tmp; + } + if matrix[i][i] == 0. { + return 0; + } + for row in i + 1..N { + let factor = matrix[row][i] / matrix[i][i]; + if factor == 0. { + continue; + } + for col in i..N { + matrix[row][col] -= factor * matrix[i][col]; + } + } + } + + sign +} + +fn fit_curve( + order: i32, + px: &[f64; N], + py: &[f64; N], +) -> [f64; NCoeffs] { + let mut coeffs = [0f64; NCoeffs]; + + if NCoeffs != (order + 1) as usize { + panic!( + "Invalid coefficients length, expected {}, but got {}", + order + 1, + NCoeffs + ); + } + + if NCoeffs > MAX_ORDER || NCoeffs < 2 { + panic!("Matrix size out of bounds"); + } + + if N < 1 { + panic!("Not enough points to fit"); + } + + let mut t = [0f64; MAX_ORDER]; + let mut s = [0f64; MAX_ORDER * 2 + 1]; + + for i in 0..N { + let x = px[i]; + let y = py[i]; + for j in 0..NCoeffs * 2 - 1 { + s[j] += curve_fit_power(x, j as u32); + } + for j in 0..NCoeffs { + t[j] += y * curve_fit_power(x, j as u32); + } + } + + //Master matrix LHS of linear equation + let mut matrix = [[0f64; NCoeffs]; NCoeffs]; + + for i in 0..NCoeffs { + for j in 0..NCoeffs { + matrix[i][j] = s[i + j]; + } + } + + let denom = det(&matrix); + + for i in 0..NCoeffs { + coeffs[NCoeffs - i - 1] = det(&sub_col(&matrix, &t, i, NCoeffs)) / denom; + } + + coeffs +} + +pub fn linearize(point: f32, coefficients: &[f32; 4]) -> f32 { + coefficients[0] * (point * point * point) + + coefficients[1] * (point * point) + + coefficients[2] * point + + coefficients[3] +} + +/// +/// Generate a fit to linearize the stick response. +/// +/// Inputs: +/// cleaned points X and Y, (must be 17 points for each of these, the first being the center, the others starting at 3 oclock and going around counterclockwise) +/// +/// Outputs: +/// linearization fit coefficients for X and Y +pub fn linearize_calibration(in_x: [f32; 17], in_y: [f32; 17]) -> LinearizeCalibrationOutput { + let mut fit_points_x = [0f64; 5]; + let mut fit_points_y = [0f64; 5]; + + fit_points_x[0] = in_x[8 + 1] as f64; + fit_points_x[1] = (in_x[6 + 1] as f64 + in_x[10 + 1] as f64) / 2.0f64; + fit_points_x[2] = in_x[0] as f64; + fit_points_x[3] = (in_x[2 + 1] as f64 + in_x[14 + 1] as f64) / 2.0f64; + fit_points_x[4] = in_x[0 + 1] as f64; + + fit_points_y[0] = in_y[12 + 1] as f64; + fit_points_y[1] = (in_y[10 + 1] as f64 + in_y[14 + 1] as f64) / 2.0f64; + fit_points_y[2] = in_y[0] as f64; + fit_points_y[3] = (in_y[6 + 1] as f64 + in_y[2 + 1] as f64) / 2.0f64; + fit_points_y[4] = in_y[4 + 1] as f64; + + let x_output: [f64; 5] = [27.5, 53.2537879754, 127.5, 201.7462120246, 227.5]; + let y_output: [f64; 5] = [27.5, 53.2537879754, 127.5, 201.7462120246, 227.5]; + + let mut fit_coeffs_x = fit_curve::<5, N_COEFFS>(FIT_ORDER as i32, &fit_points_x, &x_output); + let mut fit_coeffs_y = fit_curve::<5, N_COEFFS>(FIT_ORDER as i32, &fit_points_y, &y_output); + + let x_zero_error = linearize(fit_points_x[2] as f32, &fit_coeffs_x.map(|e| e as f32)); + let y_zero_error = linearize(fit_points_y[2] as f32, &fit_coeffs_y.map(|e| e as f32)); + + fit_coeffs_x[3] = fit_coeffs_x[3] - x_zero_error as f64; + fit_coeffs_y[3] = fit_coeffs_y[3] - y_zero_error as f64; + + let mut out_x = [0f32; NO_OF_NOTCHES]; + let mut out_y = [0f32; NO_OF_NOTCHES]; + + for i in 0..=NO_OF_NOTCHES { + out_x[i] = linearize(in_x[i], &fit_coeffs_x.map(|e| e as f32)); + out_y[i] = linearize(in_y[i], &fit_coeffs_y.map(|e| e as f32)); + } + + LinearizeCalibrationOutput { + fit_coeffs_x, + fit_coeffs_y, + out_x, + out_y, + } +}