From a91f68654473e8ca6030b97ccda98e7fd868716e Mon Sep 17 00:00:00 2001
From: Karun <karun@nautilusdefense.com>
Date: Fri, 3 May 2024 17:29:21 -0400
Subject: [PATCH] Check group configuration validity

---
 embassy-stm32/src/tsc/mod.rs | 250 ++++++++++++++++++++++++++++++++---
 1 file changed, 233 insertions(+), 17 deletions(-)

diff --git a/embassy-stm32/src/tsc/mod.rs b/embassy-stm32/src/tsc/mod.rs
index 6b9d8fcf5..17a455558 100644
--- a/embassy-stm32/src/tsc/mod.rs
+++ b/embassy-stm32/src/tsc/mod.rs
@@ -215,10 +215,15 @@ impl Default for Config {
 #[allow(missing_docs)]
 pub struct TscPin<'d, T, C> {
     _pin: PeripheralRef<'d, AnyPin>,
-    _role: PinType,
+    role: PinType,
     phantom: PhantomData<(T, C)>,
 }
 
+enum GroupError {
+    NoSample,
+    ChannelShield,
+}
+
 /// Pin group definition
 /// Pins are organized into groups of four IOs, all groups with a
 /// sampling channel must also have a sampling capacitor channel.
@@ -241,6 +246,119 @@ impl<'d, T: Instance, C> PinGroup<'d, T, C> {
             d4: None,
         }
     }
+
+    fn contains_shield(&self) -> bool {
+        let mut shield_count = 0;
+
+        if let Some(pin) = &self.d1 {
+            if let PinType::Shield = pin.role {
+                shield_count += 1;
+            }
+        }
+
+        if let Some(pin) = &self.d2 {
+            if let PinType::Shield = pin.role {
+                shield_count += 1;
+            }
+        }
+
+        if let Some(pin) = &self.d3 {
+            if let PinType::Shield = pin.role {
+                shield_count += 1;
+            }
+        }
+
+        if let Some(pin) = &self.d4 {
+            if let PinType::Shield = pin.role {
+                shield_count += 1;
+            }
+        }
+
+        shield_count == 1
+    }
+
+    fn check_group(&self) -> Result<(), GroupError> {
+        let mut channel_count = 0;
+        let mut shield_count = 0;
+        let mut sample_count = 0;
+        if let Some(pin) = &self.d1 {
+            match pin.role {
+                PinType::Channel => {
+                    channel_count += 1;
+                }
+                PinType::Shield => {
+                    shield_count += 1;
+                }
+                PinType::Sample => {
+                    sample_count += 1;
+                }
+            }
+        }
+
+        if let Some(pin) = &self.d2 {
+            match pin.role {
+                PinType::Channel => {
+                    channel_count += 1;
+                }
+                PinType::Shield => {
+                    shield_count += 1;
+                }
+                PinType::Sample => {
+                    sample_count += 1;
+                }
+            }
+        }
+
+        if let Some(pin) = &self.d3 {
+            match pin.role {
+                PinType::Channel => {
+                    channel_count += 1;
+                }
+                PinType::Shield => {
+                    shield_count += 1;
+                }
+                PinType::Sample => {
+                    sample_count += 1;
+                }
+            }
+        }
+
+        if let Some(pin) = &self.d4 {
+            match pin.role {
+                PinType::Channel => {
+                    channel_count += 1;
+                }
+                PinType::Shield => {
+                    shield_count += 1;
+                }
+                PinType::Sample => {
+                    sample_count += 1;
+                }
+            }
+        }
+
+        // Every group requires one sampling capacitor
+        if sample_count != 1 {
+            return Err(GroupError::NoSample);
+        }
+
+        // Each group must have at least one shield or channel IO
+        if shield_count == 0 && channel_count == 0 {
+            return Err(GroupError::ChannelShield);
+        }
+
+        // Any group can either contain channel ios or a shield IO
+        if shield_count != 0 && channel_count != 0 {
+            return Err(GroupError::ChannelShield);
+        }
+
+        // No more than one shield IO is allow per group and amongst all groups
+        if shield_count > 1 {
+            return Err(GroupError::ChannelShield);
+        }
+
+        Ok(())
+    }
 }
 
 macro_rules! group_impl {
@@ -261,7 +379,7 @@ macro_rules! group_impl {
                     );
                     self.d1 = Some(TscPin {
                         _pin: pin.map_into(),
-                        _role: role,
+                        role: role,
                         phantom: PhantomData,
                     })
                 })
@@ -282,7 +400,7 @@ macro_rules! group_impl {
                     );
                     self.d2 = Some(TscPin {
                         _pin: pin.map_into(),
-                        _role: role,
+                        role: role,
                         phantom: PhantomData,
                     })
                 })
@@ -303,7 +421,7 @@ macro_rules! group_impl {
                     );
                     self.d3 = Some(TscPin {
                         _pin: pin.map_into(),
-                        _role: role,
+                        role: role,
                         phantom: PhantomData,
                     })
                 })
@@ -324,7 +442,7 @@ macro_rules! group_impl {
                     );
                     self.d4 = Some(TscPin {
                         _pin: pin.map_into(),
-                        _role: role,
+                        role: role,
                         phantom: PhantomData,
                     })
                 })
@@ -391,20 +509,118 @@ impl<'d, T: Instance> Tsc<'d, T> {
         config: Config,
     ) -> Self {
         // Need to check valid pin configuration input
-        Self::new_inner(
-            peri,
-            g1,
-            g2,
-            g3,
-            g4,
-            g5,
-            g6,
+        let g1 = g1.filter(|b| b.check_group().is_ok());
+        let g2 = g2.filter(|b| b.check_group().is_ok());
+        let g3 = g3.filter(|b| b.check_group().is_ok());
+        let g4 = g4.filter(|b| b.check_group().is_ok());
+        let g5 = g5.filter(|b| b.check_group().is_ok());
+        let g6 = g6.filter(|b| b.check_group().is_ok());
+        let g7 = g7.filter(|b| b.check_group().is_ok());
+        let g8 = g8.filter(|b| b.check_group().is_ok());
+
+        match Self::check_shields(
+            &g1,
+            &g2,
+            &g3,
+            &g4,
+            &g5,
+            &g6,
             #[cfg(any(tsc_v2, tsc_v3))]
-            g7,
+            &g7,
             #[cfg(tsc_v3)]
-            g8,
-            config,
-        )
+            &g8,
+        ) {
+            Ok(()) => Self::new_inner(
+                peri,
+                g1,
+                g2,
+                g3,
+                g4,
+                g5,
+                g6,
+                #[cfg(any(tsc_v2, tsc_v3))]
+                g7,
+                #[cfg(tsc_v3)]
+                g8,
+                config,
+            ),
+            Err(_) => Self::new_inner(
+                peri,
+                None,
+                None,
+                None,
+                None,
+                None,
+                None,
+                #[cfg(any(tsc_v2, tsc_v3))]
+                None,
+                #[cfg(tsc_v3)]
+                None,
+                config,
+            ),
+        }
+    }
+
+    fn check_shields(
+        g1: &Option<PinGroup<'d, T, G1>>,
+        g2: &Option<PinGroup<'d, T, G2>>,
+        g3: &Option<PinGroup<'d, T, G3>>,
+        g4: &Option<PinGroup<'d, T, G4>>,
+        g5: &Option<PinGroup<'d, T, G5>>,
+        g6: &Option<PinGroup<'d, T, G6>>,
+        #[cfg(any(tsc_v2, tsc_v3))] g7: &Option<PinGroup<'d, T, G7>>,
+        #[cfg(tsc_v3)] g8: &Option<PinGroup<'d, T, G8>>,
+    ) -> Result<(), GroupError> {
+        let mut shield_count = 0;
+
+        if let Some(pin_group) = g1 {
+            if pin_group.contains_shield() {
+                shield_count += 1;
+            }
+        };
+        if let Some(pin_group) = g2 {
+            if pin_group.contains_shield() {
+                shield_count += 1;
+            }
+        };
+        if let Some(pin_group) = g3 {
+            if pin_group.contains_shield() {
+                shield_count += 1;
+            }
+        };
+        if let Some(pin_group) = g4 {
+            if pin_group.contains_shield() {
+                shield_count += 1;
+            }
+        };
+        if let Some(pin_group) = g5 {
+            if pin_group.contains_shield() {
+                shield_count += 1;
+            }
+        };
+        if let Some(pin_group) = g6 {
+            if pin_group.contains_shield() {
+                shield_count += 1;
+            }
+        };
+        #[cfg(any(tsc_v2, tsc_v3))]
+        if let Some(pin_group) = g7 {
+            if pin_group.contains_shield() {
+                shield_count += 1;
+            }
+        };
+        #[cfg(tsc_v3)]
+        if let Some(pin_group) = g8 {
+            if pin_group.contains_shield() {
+                shield_count += 1;
+            }
+        };
+
+        if shield_count > 1 {
+            return Err(GroupError::ChannelShield);
+        }
+
+        Ok(())
     }
 
     fn extract_groups(io_mask: u32) -> u32 {