diff --git a/satrs-core/src/pus/verification.rs b/satrs-core/src/pus/verification.rs index c3c8522..7c187f0 100644 --- a/satrs-core/src/pus/verification.rs +++ b/satrs-core/src/pus/verification.rs @@ -944,6 +944,8 @@ mod alloc_mod { #[derive(Clone)] pub struct VerificationReporter { source_data_buf: Vec, + //pub seq_count_provider: Option> + //pub msg_count_provider: Option>>, pub reporter: VerificationReporterCore, } diff --git a/satrs-core/src/seq_count.rs b/satrs-core/src/seq_count.rs index 2258d94..af03f68 100644 --- a/satrs-core/src/seq_count.rs +++ b/satrs-core/src/seq_count.rs @@ -2,6 +2,8 @@ use core::cell::Cell; use core::sync::atomic::{AtomicU16, Ordering}; #[cfg(feature = "alloc")] use dyn_clone::DynClone; +use paste::paste; +use spacepackets::MAX_SEQ_COUNT; #[cfg(feature = "std")] pub use stdmod::*; @@ -15,21 +17,11 @@ pub trait SequenceCountProviderCore { fn increment(&self); - // TODO: Maybe remove this? - fn increment_mut(&mut self) { - self.increment(); - } - fn get_and_increment(&self) -> Raw { let val = self.get(); self.increment(); val } - - // TODO: Maybe remove this? - fn get_and_increment_mut(&mut self) -> Raw { - self.get_and_increment() - } } /// Extension trait which allows cloning a sequence count provider after it was turned into @@ -42,36 +34,84 @@ dyn_clone::clone_trait_object!(SequenceCountProvider); impl SequenceCountProvider for T where T: SequenceCountProviderCore + Clone {} #[derive(Default, Clone)] -pub struct SeqCountProviderSimple { - seq_count: Cell, +pub struct SeqCountProviderSimple { + seq_count: Cell, + max_val: T, } -impl SeqCountProviderSimple { +macro_rules! impl_for_primitives { + ($($ty: ident,)+) => { + $( + paste! { + impl SeqCountProviderSimple<$ty> { + pub fn [](max_val: $ty) -> Self { + Self { + seq_count: Cell::new(0), + max_val, + } + } + + pub fn []() -> Self { + Self { + seq_count: Cell::new(0), + max_val: $ty::MAX + } + } + } + + impl SequenceCountProviderCore<$ty> for SeqCountProviderSimple<$ty> { + fn get(&self) -> $ty { + self.seq_count.get() + } + + fn increment(&self) { + self.get_and_increment(); + } + + fn get_and_increment(&self) -> $ty { + let curr_count = self.seq_count.get(); + + if curr_count == self.max_val { + self.seq_count.set(0); + } else { + self.seq_count.set(curr_count + 1); + } + curr_count + } + } + } + )+ + } +} + +impl_for_primitives!(u8, u16, u32, u64,); + +/// This is a sequence count provider which wraps around at [MAX_SEQ_COUNT]. +pub struct CcsdsSimpleSeqCountProvider { + provider: SeqCountProviderSimple, +} + +impl CcsdsSimpleSeqCountProvider { pub fn new() -> Self { Self { - seq_count: Cell::new(0), + provider: SeqCountProviderSimple::new_u16_max_val(MAX_SEQ_COUNT), } } } -impl SequenceCountProviderCore for SeqCountProviderSimple { - fn get(&self) -> u16 { - self.seq_count.get() +impl Default for CcsdsSimpleSeqCountProvider { + fn default() -> Self { + Self::new() } +} - fn increment(&self) { - self.get_and_increment(); - } - - fn get_and_increment(&self) -> u16 { - let curr_count = self.seq_count.get(); - - if curr_count == u16::MAX { - self.seq_count.set(0); - } else { - self.seq_count.set(curr_count + 1); +impl SequenceCountProviderCore for CcsdsSimpleSeqCountProvider { + delegate::delegate! { + to self.provider { + fn get(&self) -> u16; + fn increment(&self); + fn get_and_increment(&self) -> u16; } - curr_count } } @@ -127,3 +167,47 @@ pub mod stdmod { } } } + +#[cfg(test)] +mod tests { + use crate::seq_count::{ + CcsdsSimpleSeqCountProvider, SeqCountProviderSimple, SequenceCountProviderCore, + }; + use spacepackets::MAX_SEQ_COUNT; + + #[test] + fn test_u8_counter() { + let u8_counter = SeqCountProviderSimple::new_u8(); + assert_eq!(u8_counter.get(), 0); + assert_eq!(u8_counter.get_and_increment(), 0); + assert_eq!(u8_counter.get_and_increment(), 1); + assert_eq!(u8_counter.get(), 2); + } + + #[test] + fn test_u8_counter_overflow() { + let u8_counter = SeqCountProviderSimple::new_u8(); + for _ in 0..256 { + u8_counter.increment(); + } + assert_eq!(u8_counter.get(), 0); + } + + #[test] + fn test_ccsds_counter() { + let ccsds_counter = CcsdsSimpleSeqCountProvider::default(); + assert_eq!(ccsds_counter.get(), 0); + assert_eq!(ccsds_counter.get_and_increment(), 0); + assert_eq!(ccsds_counter.get_and_increment(), 1); + assert_eq!(ccsds_counter.get(), 2); + } + + #[test] + fn test_ccsds_counter_overflow() { + let ccsds_counter = CcsdsSimpleSeqCountProvider::default(); + for _ in 0..MAX_SEQ_COUNT + 1 { + ccsds_counter.increment(); + } + assert_eq!(ccsds_counter.get(), 0); + } +} diff --git a/satrs-example/src/main.rs b/satrs-example/src/main.rs index 99a475c..743a148 100644 --- a/satrs-example/src/main.rs +++ b/satrs-example/src/main.rs @@ -36,7 +36,7 @@ use satrs_core::pus::verification::{ MpscVerifSender, VerificationReporterCfg, VerificationReporterWithSender, }; use satrs_core::pus::MpscTmtcInStoreSender; -use satrs_core::seq_count::{SeqCountProviderSimple, SequenceCountProviderCore}; +use satrs_core::seq_count::{CcsdsSimpleSeqCountProvider, SequenceCountProviderCore}; use satrs_core::spacepackets::tm::PusTmZeroCopyWriter; use satrs_core::spacepackets::{ time::cds::TimeProvider, @@ -79,7 +79,7 @@ fn main() { pool: Arc::new(RwLock::new(Box::new(tc_pool))), }; - let seq_count_provider = SeqCountProviderSimple::new(); + let seq_count_provider = CcsdsSimpleSeqCountProvider::new(); let mut msg_counter_map: HashMap = HashMap::new(); let sock_addr = SocketAddr::new(IpAddr::V4(OBSW_SERVER_ADDR), SERVER_PORT); let (tc_source_tx, tc_source_rx) = channel();