begin implementing stick shenanigans

This commit is contained in:
Naxdy 2024-03-19 21:06:18 +01:00
parent cbaa4f4ca9
commit 5244f1a75e
No known key found for this signature in database
GPG key ID: C0437AAE9755550F
5 changed files with 535 additions and 22 deletions

7
Cargo.lock generated
View file

@ -828,6 +828,12 @@ version = "0.2.153"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd"
[[package]]
name = "libm"
version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
[[package]]
name = "libredox"
version = "0.0.1"
@ -885,6 +891,7 @@ dependencies = [
"fixed",
"fixed-macro",
"format_no_std",
"libm",
"packed_struct",
"panic-probe",
"portable-atomic",

View file

@ -42,6 +42,7 @@ fixed = "1.23.1"
fixed-macro = "1.2"
static_cell = "2"
portable-atomic = { version = "1.5", features = ["critical-section"] }
libm = { version = "0.2.8" }
#cortex-m = { version = "0.7.6", features = ["critical-section-single-core"] }
cortex-m = { version = "0.7.6", features = ["inline-asm"] }

View file

@ -1,21 +1,53 @@
use defmt::{debug, info};
use core::task::Poll;
use defmt::{debug, Format};
use embassy_futures::{join::join, yield_now};
use embassy_rp::{
clocks::RoscRng,
flash::{Async, Flash, ERASE_SIZE},
flash::{Async, Flash},
gpio::{Input, Output, Pin},
peripherals::{
DMA_CH0, FLASH, PIN_10, PIN_11, PIN_15, PIN_16, PIN_17, PIN_18, PIN_19, PIN_20, PIN_21,
PIN_22, PIN_23, PIN_24, PIN_25, PIN_29, PIN_5, PIN_8, PIN_9, SPI0,
FLASH, PIN_10, PIN_11, PIN_16, PIN_17, PIN_18, PIN_19, PIN_20, PIN_21, PIN_22, PIN_23,
PIN_24, PIN_5, PIN_8, PIN_9, PWM_CH4, PWM_CH6, SPI0,
},
pwm::Pwm,
spi::Spi,
};
use embassy_sync::{blocking_mutex::raw::CriticalSectionRawMutex, signal::Signal};
use rand::RngCore;
use embassy_time::{Instant, Timer};
use packed_struct::derive::PackedStruct;
use crate::{gcc_hid::GcReport, ADDR_OFFSET, FLASH_SIZE};
use crate::{
gcc_hid::GcReport,
stick::{linearize, run_kalman, FilterGains, StickParams},
PackedFloat, ADDR_OFFSET, FLASH_SIZE,
};
pub static GCC_SIGNAL: Signal<CriticalSectionRawMutex, GcReport> = Signal::new();
static STICK_SIGNAL: Signal<CriticalSectionRawMutex, StickState> = Signal::new();
#[derive(Debug, Clone, Default, Format, PackedStruct)]
#[packed_struct(endian = "msb")]
pub struct ControllerConfig {
#[packed_field(size_bits = "8")]
pub config_version: u8,
#[packed_field(size_bits = "32")]
pub ax_waveshaping: PackedFloat,
#[packed_field(size_bits = "32")]
pub ay_waveshaping: PackedFloat,
#[packed_field(size_bits = "32")]
pub cx_waveshaping: PackedFloat,
#[packed_field(size_bits = "32")]
pub cy_waveshaping: PackedFloat,
}
struct StickState {
ax: u8,
ay: u8,
cx: u8,
cy: u8,
}
#[derive(PartialEq, Eq)]
enum Stick {
ControlStick,
@ -28,10 +60,10 @@ enum StickAxis {
YAxis,
}
fn read_ext_adc<'a, Acs: Pin, Ccs: Pin>(
fn read_ext_adc<'a, Acs: Pin, Ccs: Pin, I: embassy_rp::spi::Instance, M: embassy_rp::spi::Mode>(
which_stick: Stick,
which_axis: StickAxis,
spi: &mut Spi<'a, SPI0, embassy_rp::spi::Blocking>,
spi: &mut Spi<'a, I, M>,
spi_acs: &mut Output<'a, Acs>,
spi_ccs: &mut Output<'a, Ccs>,
) -> u16 {
@ -61,6 +93,134 @@ fn read_ext_adc<'a, Acs: Pin, Ccs: Pin>(
return temp_value;
}
/// Gets the average stick state over a 1ms interval in a non-blocking fashion.
async fn update_stick_states<
'a,
Acs: Pin,
Ccs: Pin,
I: embassy_rp::spi::Instance,
M: embassy_rp::spi::Mode,
>(
mut spi: &mut Spi<'a, I, M>,
mut spi_acs: &mut Output<'a, Acs>,
mut spi_ccs: &mut Output<'a, Ccs>,
adc_scale: f32,
controlstick_params: &StickParams,
cstick_params: &StickParams,
controller_config: &ControllerConfig,
filter_gains: &FilterGains,
) {
let mut adc_count = 0u32;
let mut ax_sum = 0u32;
let mut ay_sum = 0u32;
let mut cx_sum = 0u32;
let mut cy_sum = 0u32;
// TODO: lower interval possible?
let mut timer = Timer::at(Instant::now() + embassy_time::Duration::from_millis(1));
while embassy_futures::poll_once(&mut timer) != Poll::Ready(()) {
adc_count += 1;
ax_sum += read_ext_adc(
Stick::ControlStick,
StickAxis::XAxis,
&mut spi,
&mut spi_acs,
&mut spi_ccs,
) as u32;
ay_sum += read_ext_adc(
Stick::ControlStick,
StickAxis::YAxis,
&mut spi,
&mut spi_acs,
&mut spi_ccs,
) as u32;
cx_sum += read_ext_adc(
Stick::CStick,
StickAxis::XAxis,
&mut spi,
&mut spi_acs,
&mut spi_ccs,
) as u32;
cy_sum += read_ext_adc(
Stick::CStick,
StickAxis::YAxis,
&mut spi,
&mut spi_acs,
&mut spi_ccs,
) as u32;
// with this, we can poll the sticks at 1000Hz (ish), while updating
// the rest of the controller (the buttons) much faster, to ensure
// better input integrity for button inputs.
yield_now().await;
}
timer.await;
let raw_controlstick_x = (ax_sum as f32) / (adc_count as f32) / 4096.0f32 * adc_scale;
let raw_controlstick_y = (ay_sum as f32) / (adc_count as f32) / 4096.0f32 * adc_scale;
let raw_cstick_x = (cx_sum as f32) / (adc_count as f32) / 4096.0f32 * adc_scale;
let raw_cstick_y = (cy_sum as f32) / (adc_count as f32) / 4096.0f32 * adc_scale;
let x_z = linearize(raw_controlstick_x, &controlstick_params.fit_coeffs_x);
let y_z = linearize(raw_controlstick_y, &controlstick_params.fit_coeffs_y);
let pos_cx = linearize(raw_cstick_x, &cstick_params.fit_coeffs_x);
let pos_cy = linearize(raw_cstick_y, &cstick_params.fit_coeffs_y);
let (x_pos_filt, y_pos_filt) = run_kalman(x_z, y_z, controller_config, filter_gains);
STICK_SIGNAL.signal(StickState {
ax: 127,
ay: 127,
cx: 127,
cy: 127,
})
}
fn update_button_states<
A: Pin,
B: Pin,
X: Pin,
Y: Pin,
Start: Pin,
L: Pin,
R: Pin,
Z: Pin,
DLeft: Pin,
DRight: Pin,
DUp: Pin,
DDown: Pin,
>(
gcc_state: &mut GcReport,
btn_a: &Input<'_, A>,
btn_b: &Input<'_, B>,
btn_x: &Input<'_, X>,
btn_y: &Input<'_, Y>,
btn_start: &Input<'_, Start>,
btn_l: &Input<'_, L>,
btn_r: &Input<'_, R>,
btn_z: &Input<'_, Z>,
btn_dleft: &Input<'_, DLeft>,
btn_dright: &Input<'_, DRight>,
btn_dup: &Input<'_, DUp>,
btn_ddown: &Input<'_, DDown>,
) {
gcc_state.buttons_1.button_a = btn_a.is_low();
gcc_state.buttons_1.button_b = btn_b.is_low();
gcc_state.buttons_1.button_x = btn_x.is_low();
gcc_state.buttons_1.button_y = btn_y.is_low();
gcc_state.buttons_2.button_z = btn_z.is_low();
gcc_state.buttons_2.button_start = btn_start.is_low();
gcc_state.buttons_2.button_l = btn_l.is_low();
gcc_state.buttons_2.button_r = btn_r.is_low();
gcc_state.buttons_1.dpad_left = btn_dleft.is_low();
gcc_state.buttons_1.dpad_right = btn_dright.is_low();
gcc_state.buttons_1.dpad_up = btn_dup.is_low();
gcc_state.buttons_1.dpad_down = btn_ddown.is_low();
}
#[embassy_executor::task]
pub async fn input_loop(
mut flash: Flash<'static, FLASH, Async, FLASH_SIZE>,
@ -76,26 +236,67 @@ pub async fn input_loop(
btn_x: Input<'static, PIN_18>,
btn_y: Input<'static, PIN_19>,
btn_start: Input<'static, PIN_5>,
btn_rumble: Input<'static, PIN_25>,
btn_brake: Input<'static, PIN_29>,
pwm_rumble: Pwm<'static, PWM_CH4>,
pwm_brake: Pwm<'static, PWM_CH6>,
mut spi: Spi<'static, SPI0, embassy_rp::spi::Blocking>,
mut spi_acs: Output<'static, PIN_24>,
mut spi_ccs: Output<'static, PIN_23>,
) -> ! {
) {
let mut gcc_state = GcReport::default();
let mut rng = RoscRng;
// Set the stick states to the center
gcc_state.stick_x = 127;
gcc_state.stick_y = 127;
gcc_state.cstick_x = 127;
gcc_state.cstick_y = 127;
let mut uid = [0u8; 1];
flash.blocking_read(ADDR_OFFSET, &mut uid).unwrap();
debug!("Read from flash: {:02X}", uid);
loop {
gcc_state.buttons_1.button_a = btn_z.is_low();
gcc_state.stick_x = rng.next_u32() as u8;
gcc_state.stick_y = rng.next_u32() as u8;
// TODO: load controller config here
GCC_SIGNAL.signal(gcc_state);
}
let stick_state_fut = async {
loop {
// update_stick_states(&mut spi, &mut spi_acs, &mut spi_ccs, 1.0).await;
}
};
let input_fut = async {
loop {
update_button_states(
&mut gcc_state,
&btn_a,
&btn_b,
&btn_x,
&btn_y,
&btn_start,
&btn_l,
&btn_r,
&btn_z,
&btn_dleft,
&btn_dright,
&btn_dup,
&btn_ddown,
);
yield_now().await;
// not every loop pass is going to update the stick state
match STICK_SIGNAL.try_take() {
Some(stick_state) => {
gcc_state.stick_x = stick_state.ax;
gcc_state.stick_y = stick_state.ay;
gcc_state.cstick_x = stick_state.cx;
gcc_state.cstick_y = stick_state.cy;
}
None => (),
}
GCC_SIGNAL.signal(gcc_state);
}
};
join(input_fut, stick_state_fut).await;
}

View file

@ -6,8 +6,11 @@
#![no_main]
mod gcc_hid;
mod input;
mod stick;
use defmt::{debug, info};
use core::ops::Deref;
use defmt::{debug, info, Format};
use embassy_executor::Executor;
use embassy_rp::{
bind_interrupts,
@ -15,12 +18,14 @@ use embassy_rp::{
gpio::{self, Input},
multicore::{spawn_core1, Stack},
peripherals::USB,
pwm::Pwm,
spi::{self, Spi},
usb::{Driver, InterruptHandler},
};
use gcc_hid::usb_transfer_loop;
use gpio::{Level, Output};
use input::input_loop;
use packed_struct::PackedStruct;
use static_cell::StaticCell;
use {defmt_rtt as _, panic_probe as _};
@ -31,6 +36,31 @@ static EXECUTOR1: StaticCell<Executor> = StaticCell::new();
const FLASH_SIZE: usize = 2 * 1024 * 1024;
const ADDR_OFFSET: u32 = 0x100000;
/// wrapper type because packed_struct doesn't implement float
/// packing by default for some reason
#[derive(Debug, Format, Clone, Default)]
pub struct PackedFloat(f32);
impl PackedStruct for PackedFloat {
type ByteArray = [u8; 4];
fn pack(&self) -> packed_struct::PackingResult<Self::ByteArray> {
Ok(self.to_be_bytes())
}
fn unpack(src: &Self::ByteArray) -> packed_struct::PackingResult<Self> {
Ok(Self(f32::from_be_bytes(*src)))
}
}
impl Deref for PackedFloat {
type Target = f32;
fn deref(&self) -> &Self::Target {
&self.0
}
}
bind_interrupts!(struct Irqs {
USBCTRL_IRQ => InterruptHandler<USB>;
});
@ -79,6 +109,17 @@ fn main() -> ! {
let spi_acs = Output::new(p_acs, Level::High); // active low
let spi_ccs = Output::new(p_ccs, Level::High); // active low
let mut pwm_config: embassy_rp::pwm::Config = Default::default();
pwm_config.top = 255;
pwm_config.enable = true;
pwm_config.compare_b = 255;
let pwm_rumble = Pwm::new_output_b(p.PWM_CH4, p.PIN_25, pwm_config.clone());
let pwm_brake = Pwm::new_output_b(p.PWM_CH6, p.PIN_29, pwm_config.clone());
pwm_rumble.set_counter(255);
pwm_brake.set_counter(0);
executor0.run(|spawner| {
spawner
.spawn(input_loop(
@ -95,8 +136,8 @@ fn main() -> ! {
Input::new(p.PIN_18, gpio::Pull::Up),
Input::new(p.PIN_19, gpio::Pull::Up),
Input::new(p.PIN_5, gpio::Pull::Up),
Input::new(p.PIN_25, gpio::Pull::Up),
Input::new(p.PIN_29, gpio::Pull::Up),
pwm_rumble,
pwm_brake,
spi,
spi_acs,
spi_ccs,

263
src/stick.rs Normal file
View file

@ -0,0 +1,263 @@
// vast majority of this is taken from Phob firmware
use defmt::Format;
use libm::fabs;
use crate::input::ControllerConfig;
/// fit order for the linearization
const FIT_ORDER: usize = 3;
const N_COEFFS: usize = FIT_ORDER + 1;
const NO_OF_NOTCHES: usize = 16;
const MAX_ORDER: usize = 20;
#[derive(Clone, Debug, Default, Format)]
pub struct StickParams {
// these are the linearization coefficients
pub fit_coeffs_x: [f32; N_COEFFS],
pub fit_coeffs_y: [f32; N_COEFFS],
// these are the notch remap parameters
pub affine_coeffs_x: [[f32; 16]; 4], // affine transformation coefficients for all regions of the stick
pub boundary_angles_x: [f32; 4], // angles at the boundaries between regions of the stick (in the plane)
}
#[derive(Clone, Debug, Default, Format)]
pub struct FilterGains {
/// What's the max stick distance from the center
pub max_stick: f32,
/// filtered velocity terms
/// how fast the filtered velocity falls off in the absence of stick movement.
/// Probably don't touch this.
pub x_vel_decay: f32, //0.1 default for 1.2ms timesteps, larger for bigger timesteps
pub y_vel_decay: f32,
/// how much the current position disagreement impacts the filtered velocity.
/// Probably don't touch this.
pub x_vel_pos_factor: f32, //0.01 default for 1.2ms timesteps, larger for bigger timesteps
pub y_vel_pos_factor: f32,
/// how much to ignore filtered velocity when computing the new stick position.
/// DO CHANGE THIS
/// Higher gives shorter rise times and slower fall times (more pode, less snapback)
pub x_vel_damp: f32, //0.125 default for 1.2ms timesteps, smaller for bigger timesteps
pub y_vel_damp: f32,
/// speed and accel thresholds below which we try to follow the stick better
/// These may need tweaking according to how noisy the signal is
/// If it's noisier, we may need to add additional filtering
/// If the timesteps are *really small* then it may need to be increased to get
/// above the noise floor. Or some combination of filtering and playing with
/// the thresholds.
pub vel_thresh: f32, //1 default for 1.2ms timesteps, larger for bigger timesteps
pub accel_thresh: f32, //5 default for 1.2ms timesteps, larger for bigger timesteps
/// This just applies a low-pass filter.
/// The purpose is to provide delay for single-axis ledgedashes.
/// Must be between 0 and 1. Larger = more smoothing and delay.
pub x_smoothing: f32,
pub y_smoothing: f32,
/// Same thing but for C-stick
pub c_xsmoothing: f32,
pub c_ysmoothing: f32,
}
#[derive(Clone, Debug, Default)]
struct LinearizeCalibrationOutput {
pub fit_coeffs_x: [f64; N_COEFFS],
pub fit_coeffs_y: [f64; N_COEFFS],
pub out_x: [f32; NO_OF_NOTCHES],
pub out_y: [f32; NO_OF_NOTCHES],
}
pub fn run_kalman(
x_z: f32,
y_z: f32,
controller_config: &ControllerConfig,
filter_gains: &FilterGains,
) -> (f32, f32) {
todo!()
}
fn curve_fit_power(base: f64, exponent: u32) -> f64 {
if exponent == 0 {
return 1.0;
}
let mut val = base;
for _ in 1..exponent {
val *= base;
}
val
}
fn sub_col<const N: usize>(
matrix: &[[f64; N]; N],
t: &[f64; MAX_ORDER],
col: usize,
n: usize,
) -> [[f64; N]; N] {
let mut m = *matrix;
for i in 0..n {
m[i][col] = t[i];
}
m
}
fn det<const N: usize>(matrix: &[[f64; N]; N]) -> f64 {
let sign = trianglize(matrix);
if sign == 0 {
return 0.;
}
let mut p = 1f64;
for i in 0..N {
p *= matrix[i][i];
}
p * (sign as f64)
}
fn trianglize<const N: usize>(matrix: &[[f64; N]; N]) -> i32 {
let mut sign = 1;
let mut matrix = *matrix;
for i in 0..N {
let mut max = 0;
for row in i..N {
if fabs(matrix[row][i]) > fabs(matrix[max][i]) {
max = row;
}
}
if max > 0 {
sign = -sign;
let tmp = matrix[i];
matrix[i] = matrix[max];
matrix[max] = tmp;
}
if matrix[i][i] == 0. {
return 0;
}
for row in i + 1..N {
let factor = matrix[row][i] / matrix[i][i];
if factor == 0. {
continue;
}
for col in i..N {
matrix[row][col] -= factor * matrix[i][col];
}
}
}
sign
}
fn fit_curve<const N: usize, const NCoeffs: usize>(
order: i32,
px: &[f64; N],
py: &[f64; N],
) -> [f64; NCoeffs] {
let mut coeffs = [0f64; NCoeffs];
if NCoeffs != (order + 1) as usize {
panic!(
"Invalid coefficients length, expected {}, but got {}",
order + 1,
NCoeffs
);
}
if NCoeffs > MAX_ORDER || NCoeffs < 2 {
panic!("Matrix size out of bounds");
}
if N < 1 {
panic!("Not enough points to fit");
}
let mut t = [0f64; MAX_ORDER];
let mut s = [0f64; MAX_ORDER * 2 + 1];
for i in 0..N {
let x = px[i];
let y = py[i];
for j in 0..NCoeffs * 2 - 1 {
s[j] += curve_fit_power(x, j as u32);
}
for j in 0..NCoeffs {
t[j] += y * curve_fit_power(x, j as u32);
}
}
//Master matrix LHS of linear equation
let mut matrix = [[0f64; NCoeffs]; NCoeffs];
for i in 0..NCoeffs {
for j in 0..NCoeffs {
matrix[i][j] = s[i + j];
}
}
let denom = det(&matrix);
for i in 0..NCoeffs {
coeffs[NCoeffs - i - 1] = det(&sub_col(&matrix, &t, i, NCoeffs)) / denom;
}
coeffs
}
pub fn linearize(point: f32, coefficients: &[f32; 4]) -> f32 {
coefficients[0] * (point * point * point)
+ coefficients[1] * (point * point)
+ coefficients[2] * point
+ coefficients[3]
}
///
/// Generate a fit to linearize the stick response.
///
/// Inputs:
/// cleaned points X and Y, (must be 17 points for each of these, the first being the center, the others starting at 3 oclock and going around counterclockwise)
///
/// Outputs:
/// linearization fit coefficients for X and Y
pub fn linearize_calibration(in_x: [f32; 17], in_y: [f32; 17]) -> LinearizeCalibrationOutput {
let mut fit_points_x = [0f64; 5];
let mut fit_points_y = [0f64; 5];
fit_points_x[0] = in_x[8 + 1] as f64;
fit_points_x[1] = (in_x[6 + 1] as f64 + in_x[10 + 1] as f64) / 2.0f64;
fit_points_x[2] = in_x[0] as f64;
fit_points_x[3] = (in_x[2 + 1] as f64 + in_x[14 + 1] as f64) / 2.0f64;
fit_points_x[4] = in_x[0 + 1] as f64;
fit_points_y[0] = in_y[12 + 1] as f64;
fit_points_y[1] = (in_y[10 + 1] as f64 + in_y[14 + 1] as f64) / 2.0f64;
fit_points_y[2] = in_y[0] as f64;
fit_points_y[3] = (in_y[6 + 1] as f64 + in_y[2 + 1] as f64) / 2.0f64;
fit_points_y[4] = in_y[4 + 1] as f64;
let x_output: [f64; 5] = [27.5, 53.2537879754, 127.5, 201.7462120246, 227.5];
let y_output: [f64; 5] = [27.5, 53.2537879754, 127.5, 201.7462120246, 227.5];
let mut fit_coeffs_x = fit_curve::<5, N_COEFFS>(FIT_ORDER as i32, &fit_points_x, &x_output);
let mut fit_coeffs_y = fit_curve::<5, N_COEFFS>(FIT_ORDER as i32, &fit_points_y, &y_output);
let x_zero_error = linearize(fit_points_x[2] as f32, &fit_coeffs_x.map(|e| e as f32));
let y_zero_error = linearize(fit_points_y[2] as f32, &fit_coeffs_y.map(|e| e as f32));
fit_coeffs_x[3] = fit_coeffs_x[3] - x_zero_error as f64;
fit_coeffs_y[3] = fit_coeffs_y[3] - y_zero_error as f64;
let mut out_x = [0f32; NO_OF_NOTCHES];
let mut out_y = [0f32; NO_OF_NOTCHES];
for i in 0..=NO_OF_NOTCHES {
out_x[i] = linearize(in_x[i], &fit_coeffs_x.map(|e| e as f32));
out_y[i] = linearize(in_y[i], &fit_coeffs_y.map(|e| e as f32));
}
LinearizeCalibrationOutput {
fit_coeffs_x,
fit_coeffs_y,
out_x,
out_y,
}
}