diff --git a/satrs-core/src/seq_count.rs b/satrs-core/src/seq_count.rs index af03f68..5658281 100644 --- a/satrs-core/src/seq_count.rs +++ b/satrs-core/src/seq_count.rs @@ -1,5 +1,4 @@ use core::cell::Cell; -use core::sync::atomic::{AtomicU16, Ordering}; #[cfg(feature = "alloc")] use dyn_clone::DynClone; use paste::paste; @@ -115,63 +114,63 @@ impl SequenceCountProviderCore for CcsdsSimpleSeqCountProvider { } } -pub struct SeqCountProviderAtomicRef { - atomic: AtomicU16, - ordering: Ordering, -} - -impl SeqCountProviderAtomicRef { - pub const fn new(ordering: Ordering) -> Self { - Self { - atomic: AtomicU16::new(0), - ordering, - } - } -} - -impl SequenceCountProviderCore for SeqCountProviderAtomicRef { - fn get(&self) -> u16 { - self.atomic.load(self.ordering) - } - - fn increment(&self) { - self.atomic.fetch_add(1, self.ordering); - } - - fn get_and_increment(&self) -> u16 { - self.atomic.fetch_add(1, self.ordering) - } -} - #[cfg(feature = "std")] pub mod stdmod { use super::*; - use std::sync::Arc; + use std::sync::{Arc, Mutex}; - #[derive(Clone, Default)] - pub struct SeqCountProviderSyncClonable { - seq_count: Arc, - } - - impl SequenceCountProviderCore for SeqCountProviderSyncClonable { - fn get(&self) -> u16 { - self.seq_count.load(Ordering::SeqCst) - } - - fn increment(&self) { - self.seq_count.fetch_add(1, Ordering::SeqCst); - } - - fn get_and_increment(&self) -> u16 { - self.seq_count.fetch_add(1, Ordering::SeqCst) - } + macro_rules! sync_clonable_seq_counter_impl { + ($($ty: ident,)+) => { + $(paste! { + #[derive(Clone, Default)] + pub struct [] { + seq_count: Arc>, + max_val: $ty + } + + impl [] { + pub fn new() -> Self { + Self::new_with_max_val($ty::MAX) + } + + pub fn new_with_max_val(max_val: $ty) -> Self { + Self { + seq_count: Arc::default(), + max_val + } + } + } + impl SequenceCountProviderCore<$ty> for [] { + fn get(&self) -> $ty { + *self.seq_count.lock().unwrap() + } + + fn increment(&self) { + self.get_and_increment(); + } + + fn get_and_increment(&self) -> $ty { + let mut counter = self.seq_count.lock().unwrap(); + let current_val = *counter; + if *counter == self.max_val { + *counter = 0; + } else { + *counter += 1; + } + current_val + } + } + })+ + } } + sync_clonable_seq_counter_impl!(u8, u16, u32, u64,); } #[cfg(test)] mod tests { use crate::seq_count::{ - CcsdsSimpleSeqCountProvider, SeqCountProviderSimple, SequenceCountProviderCore, + CcsdsSimpleSeqCountProvider, SeqCountProviderSimple, SeqCountProviderSyncU8, + SequenceCountProviderCore, }; use spacepackets::MAX_SEQ_COUNT; @@ -210,4 +209,31 @@ mod tests { } assert_eq!(ccsds_counter.get(), 0); } + + #[test] + fn test_atomic_ref_counters() { + let sync_u8_counter = SeqCountProviderSyncU8::new(); + assert_eq!(sync_u8_counter.get(), 0); + assert_eq!(sync_u8_counter.get_and_increment(), 0); + assert_eq!(sync_u8_counter.get_and_increment(), 1); + assert_eq!(sync_u8_counter.get(), 2); + } + + #[test] + fn test_atomic_ref_counters_overflow() { + let sync_u8_counter = SeqCountProviderSyncU8::new(); + for _ in 0..u8::MAX as u16 + 1 { + sync_u8_counter.increment(); + } + assert_eq!(sync_u8_counter.get(), 0); + } + + #[test] + fn test_atomic_ref_counters_overflow_custom_max_val() { + let sync_u8_counter = SeqCountProviderSyncU8::new_with_max_val(128); + for _ in 0..129 { + sync_u8_counter.increment(); + } + assert_eq!(sync_u8_counter.get(), 0); + } }