From 4afc360873ec1970722d4cb6866710f1ae98ba9e Mon Sep 17 00:00:00 2001 From: Robin Mueller Date: Mon, 11 Mar 2024 17:49:55 +0100 Subject: [PATCH] introduce active request table abstraction --- satrs/src/pus/action.rs | 126 ++++++++++++++++++++++++++++++---------- 1 file changed, 95 insertions(+), 31 deletions(-) diff --git a/satrs/src/pus/action.rs b/satrs/src/pus/action.rs index 5db8275..15aa105 100644 --- a/satrs/src/pus/action.rs +++ b/satrs/src/pus/action.rs @@ -1,3 +1,5 @@ +use core::time::Duration; + use crate::{ action::{ActionId, ActionRequest}, params::Params, @@ -5,10 +7,10 @@ use crate::{ TargetId, }; -use super::verification::{TcStateAccepted, VerificationToken}; +use super::verification::{TcStateAccepted, TcStateStarted, VerificationToken}; use satrs_shared::res_code::ResultU16; -use spacepackets::ecss::EcssEnumU16; +use spacepackets::{ecss::EcssEnumU16, time::UnixTimestamp}; #[cfg(feature = "std")] #[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] @@ -62,6 +64,27 @@ pub trait PusActionRequestRouter { ) -> Result<(), Self::Error>; } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ActiveActionRequest { + action_id: ActionId, + token: VerificationToken, + start_time: UnixTimestamp, + timeout: Duration, +} + +pub trait ActiveRequestMapProvider: Default { + fn insert(&mut self, request_id: &RequestId, request: ActiveActionRequest); + fn get(&self, request_id: RequestId) -> Option; + fn get_mut(&mut self, request_id: RequestId) -> Option<&mut ActiveActionRequest>; + fn remove(&mut self, request_id: RequestId) -> bool; + + /// Call a user-supplied closure for each active request. + fn for_each(&self, f: F); + + /// Call a user-supplied closure for each active request. Mutable variant. + fn for_each_mut(&mut self, f: F); +} + #[cfg(feature = "alloc")] #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] pub mod alloc_mod { @@ -228,12 +251,40 @@ pub mod std_mod { } } - #[derive(Debug, Clone, PartialEq, Eq)] - pub struct ActiveActionRequest { - action_id: ActionId, - token: VerificationToken, - start_time: UnixTimestamp, - timeout: Duration, + #[derive(Clone, Debug, Default)] + pub struct DefaultActiveRequestMap(HashMap); + + impl ActiveRequestMapProvider for DefaultActiveRequestMap { + // type Iter = hashbrown::hash_map::Iter<'a, RequestId, ActiveActionRequest>; + // type IterMut = hashbrown::hash_map::IterMut<'a, RequestId, ActiveActionRequest>; + + fn insert(&mut self, request_id: &RequestId, request: ActiveActionRequest) { + self.0.insert(*request_id, request); + } + + fn get(&self, request_id: RequestId) -> Option { + self.0.get(&request_id).cloned() + } + + fn get_mut(&mut self, request_id: RequestId) -> Option<&mut ActiveActionRequest> { + self.0.get_mut(&request_id) + } + + fn remove(&mut self, request_id: RequestId) -> bool { + self.0.remove(&request_id).is_some() + } + + fn for_each(&self, mut f: F) { + for (req_id, active_req) in &self.0 { + f(req_id, active_req); + } + } + + fn for_each_mut(&mut self, mut f: F) { + for (req_id, active_req) in &mut self.0 { + f(req_id, active_req); + } + } } pub trait ActionReplyHandlerHook { @@ -244,17 +295,21 @@ pub mod std_mod { pub struct PusService8ReplyHandler< VerificationReporter: VerificationReportingProvider, + ActiveRequestMap: ActiveRequestMapProvider, UserHook: ActionReplyHandlerHook, > { - active_requests: HashMap, + active_requests: ActiveRequestMap, verification_reporter: VerificationReporter, fail_data_buf: alloc::vec::Vec, current_time: UnixTimestamp, user_hook: UserHook, } - impl - PusService8ReplyHandler + impl< + VerificationReporter: VerificationReportingProvider, + ActiveRequestMap: ActiveRequestMapProvider, + UserHook: ActionReplyHandlerHook, + > PusService8ReplyHandler { #[cfg(feature = "std")] #[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] @@ -263,15 +318,13 @@ pub mod std_mod { fail_data_buf_size: usize, user_hook: UserHook, ) -> Result { - let mut reply_handler = Self { - active_requests: HashMap::new(), + let current_time = UnixTimestamp::from_now()?; + Ok(Self::new( verification_reporter, - fail_data_buf: alloc::vec![0; fail_data_buf_size], - current_time: UnixTimestamp::default(), + fail_data_buf_size, user_hook, - }; - reply_handler.update_time_from_now()?; - Ok(reply_handler) + current_time, + )) } pub fn new( @@ -281,7 +334,7 @@ pub mod std_mod { init_time: UnixTimestamp, ) -> Self { Self { - active_requests: HashMap::new(), + active_requests: ActiveRequestMap::default(), verification_reporter, fail_data_buf: alloc::vec![0; fail_data_buf_size], current_time: init_time, @@ -296,7 +349,7 @@ pub mod std_mod { timeout: Duration, ) { self.active_requests.insert( - request_id.into(), + &request_id.into(), ActiveActionRequest { action_id, token, @@ -317,11 +370,16 @@ pub mod std_mod { /// /// It will call [Self::handle_timeout] for all active requests which have timed out. pub fn check_for_timeouts(&mut self, time_stamp: &[u8]) -> Result<(), EcssTmtcError> { - for active_req in self.active_requests.values() { + let mut timed_out_commands = alloc::vec::Vec::new(); + self.active_requests.for_each(|request_id, active_req| { let diff = self.current_time - active_req.start_time; if diff.duration_absolute > active_req.timeout { self.handle_timeout(active_req, time_stamp); } + timed_out_commands.push(*request_id); + }); + for timed_out_command in timed_out_commands { + self.active_requests.remove(timed_out_command); } Ok(()) } @@ -352,13 +410,13 @@ pub mod std_mod { action_reply_with_ids: ActionReplyPusWithIds, time_stamp: &[u8], ) -> Result<(), EcssTmtcError> { - let active_req = self.active_requests.get(&action_reply_with_ids.request_id); + let active_req = self.active_requests.get(action_reply_with_ids.request_id); if active_req.is_none() { self.user_hook .handle_unexpected_reply(&action_reply_with_ids); } - let active_req = active_req.unwrap(); - match action_reply_with_ids.reply { + let active_req = active_req.unwrap().clone(); + let remove_entry = match action_reply_with_ids.reply { ActionReplyPus::CompletionFailed { error_code, params } => { let fail_data_len = params.write_to_be_bytes(&mut self.fail_data_buf)?; self.verification_reporter @@ -371,8 +429,7 @@ pub mod std_mod { ), ) .map_err(|e| e.0)?; - self.active_requests - .remove(&action_reply_with_ids.request_id); + true } ActionReplyPus::StepFailed { error_code, @@ -391,15 +448,13 @@ pub mod std_mod { ), ) .map_err(|e| e.0)?; - self.active_requests - .remove(&action_reply_with_ids.request_id); + true } ActionReplyPus::Completed => { self.verification_reporter .completion_success(active_req.token, time_stamp) .map_err(|e| e.0)?; - self.active_requests - .remove(&action_reply_with_ids.request_id); + true } ActionReplyPus::StepSuccess { step } => { self.verification_reporter.step_success( @@ -407,7 +462,12 @@ pub mod std_mod { time_stamp, EcssEnumU16::new(step), )?; + false } + }; + if remove_entry { + self.active_requests + .remove(action_reply_with_ids.request_id); } Ok(()) } @@ -606,7 +666,11 @@ mod tests { pub struct Pus8ReplyTestbench { verif_reporter: TestVerificationReporter, - handler: PusService8ReplyHandler, + handler: PusService8ReplyHandler< + TestVerificationReporter, + DefaultActiveRequestMap, + TestReplyHandlerHook, + >, } impl Pus8ReplyTestbench {