diff --git a/satrs/src/action.rs b/satrs/src/action.rs index c51a99a..8803b7a 100644 --- a/satrs/src/action.rs +++ b/satrs/src/action.rs @@ -41,7 +41,7 @@ impl TargetedActionRequest { } } -/// A reply to an action request. +/// A reply to an action request specific to PUS. #[non_exhaustive] #[derive(Clone, Debug)] pub enum ActionReply { diff --git a/satrs/src/lib.rs b/satrs/src/lib.rs index 5040d58..b3374f9 100644 --- a/satrs/src/lib.rs +++ b/satrs/src/lib.rs @@ -32,6 +32,9 @@ pub mod events; #[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] pub mod executable; pub mod hal; +#[cfg(feature = "std")] +#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] +pub mod mode_tree; pub mod objects; pub mod pool; pub mod power; @@ -49,8 +52,7 @@ pub mod params; pub use spacepackets; -/// Generic channel ID type. -pub type ChannelId = u32; +pub use queue::ChannelId; /// Generic target ID type. pub type TargetId = u64; diff --git a/satrs/src/mode.rs b/satrs/src/mode.rs index c5968b4..1346226 100644 --- a/satrs/src/mode.rs +++ b/satrs/src/mode.rs @@ -5,19 +5,22 @@ use spacepackets::ByteConversionError; use crate::TargetId; +pub type Mode = u32; +pub type Submode = u16; + #[derive(Debug, Copy, Clone, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct ModeAndSubmode { - mode: u32, - submode: u16, + mode: Mode, + submode: Submode, } impl ModeAndSubmode { - pub const fn new_mode_only(mode: u32) -> Self { + pub const fn new_mode_only(mode: Mode) -> Self { Self { mode, submode: 0 } } - pub const fn new(mode: u32, submode: u16) -> Self { + pub const fn new(mode: Mode, submode: Submode) -> Self { Self { mode, submode } } @@ -33,16 +36,20 @@ impl ModeAndSubmode { }); } Ok(Self { - mode: u32::from_be_bytes(buf[0..4].try_into().unwrap()), - submode: u16::from_be_bytes(buf[4..6].try_into().unwrap()), + mode: Mode::from_be_bytes(buf[0..size_of::()].try_into().unwrap()), + submode: Submode::from_be_bytes( + buf[size_of::()..size_of::() + size_of::()] + .try_into() + .unwrap(), + ), }) } - pub fn mode(&self) -> u32 { + pub fn mode(&self) -> Mode { self.mode } - pub fn submode(&self) -> u16 { + pub fn submode(&self) -> Submode { self.submode } } @@ -87,6 +94,20 @@ pub enum ModeRequest { AnnounceModeRecursive, } +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum ModeReply { + /// Unrequest mode information. Can be used to notify other components of changed modes. + ModeInfo(ModeAndSubmode), + /// Reply to a mode request to confirm the commanded mode was reached. + ModeReply(ModeAndSubmode), + CantReachMode(ModeAndSubmode), + WrongMode { + expected: ModeAndSubmode, + reached: ModeAndSubmode, + }, +} + #[derive(Debug, Copy, Clone, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct TargetedModeRequest { diff --git a/satrs/src/mode_tree.rs b/satrs/src/mode_tree.rs new file mode 100644 index 0000000..62e622d --- /dev/null +++ b/satrs/src/mode_tree.rs @@ -0,0 +1,367 @@ +use alloc::vec::Vec; +use hashbrown::HashMap; +use std::sync::mpsc; + +use crate::{ + mode::{Mode, ModeAndSubmode, ModeReply, ModeRequest, Submode}, + queue::GenericTargetedMessagingError, + request::{ + MessageReceiver, MessageReceiverWithId, MessageSender, MessageSenderAndReceiver, + MessageSenderMap, MessageSenderMapWithId, MessageWithSenderId, + RequestAndReplySenderAndReceiver, + }, + ChannelId, +}; + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum TableEntryType { + /// Target table containing information of the expected children modes for given mode. + Target, + /// Sequence table which contains information about how to reach a target table, including + /// the order of the sequences. + Sequence, +} + +pub struct ModeTableEntry { + /// Name of respective table entry. + pub name: &'static str, + /// Target channel ID. + pub channel_id: ChannelId, + pub mode_submode: ModeAndSubmode, + pub allowed_submode_mask: Option, + pub check_success: bool, +} + +pub struct ModeTableMapValue { + /// Name for a given mode table entry. + pub name: &'static str, + pub entries: Vec, +} + +pub type ModeTable = HashMap; + +pub trait ModeRequestSender { + fn local_channel_id(&self) -> ChannelId; + fn send_mode_request( + &self, + target_id: ChannelId, + request: ModeRequest, + ) -> Result<(), GenericTargetedMessagingError>; +} + +pub trait ModeReplySender { + fn local_channel_id(&self) -> ChannelId; + + fn send_mode_reply( + &self, + target_id: ChannelId, + reply: ModeReply, + ) -> Result<(), GenericTargetedMessagingError>; +} + +pub trait ModeRequestReceiver { + fn try_recv_mode_request( + &self, + ) -> Result>, GenericTargetedMessagingError>; +} + +pub trait ModeReplyReceiver { + fn try_recv_mode_reply( + &self, + ) -> Result>, GenericTargetedMessagingError>; +} + +impl> MessageSenderMap { + pub fn send_mode_request( + &self, + local_id: ChannelId, + target_id: ChannelId, + request: ModeRequest, + ) -> Result<(), GenericTargetedMessagingError> { + self.send_message(local_id, target_id, request) + } + + pub fn add_request_target(&mut self, target_id: ChannelId, request_sender: S) { + self.add_message_target(target_id, request_sender) + } +} + +impl> MessageSenderMap { + pub fn send_mode_reply( + &self, + local_id: ChannelId, + target_id: ChannelId, + request: ModeReply, + ) -> Result<(), GenericTargetedMessagingError> { + self.send_message(local_id, target_id, request) + } + + pub fn add_reply_target(&mut self, target_id: ChannelId, request_sender: S) { + self.add_message_target(target_id, request_sender) + } +} + +impl> ModeReplySender for MessageSenderMapWithId { + fn send_mode_reply( + &self, + target_channel_id: ChannelId, + reply: ModeReply, + ) -> Result<(), GenericTargetedMessagingError> { + self.send_message(target_channel_id, reply) + } + + fn local_channel_id(&self) -> ChannelId { + self.local_channel_id + } +} + +impl> ModeRequestSender for MessageSenderMapWithId { + fn local_channel_id(&self) -> ChannelId { + self.local_channel_id + } + + fn send_mode_request( + &self, + target_id: ChannelId, + request: ModeRequest, + ) -> Result<(), GenericTargetedMessagingError> { + self.send_message(target_id, request) + } +} + +impl> ModeReplyReceiver for MessageReceiverWithId { + fn try_recv_mode_reply( + &self, + ) -> Result>, GenericTargetedMessagingError> { + self.try_recv_message() + } +} + +impl> ModeRequestReceiver + for MessageReceiverWithId +{ + fn try_recv_mode_request( + &self, + ) -> Result>, GenericTargetedMessagingError> { + self.try_recv_message() + } +} + +impl, R: MessageReceiver> ModeRequestSender + for MessageSenderAndReceiver +{ + fn local_channel_id(&self) -> ChannelId { + self.local_channel_id_generic() + } + + fn send_mode_request( + &self, + target_id: ChannelId, + request: ModeRequest, + ) -> Result<(), GenericTargetedMessagingError> { + self.message_sender_map + .send_mode_request(self.local_channel_id(), target_id, request) + } +} + +impl, R: MessageReceiver> ModeReplySender + for MessageSenderAndReceiver +{ + fn local_channel_id(&self) -> ChannelId { + self.local_channel_id_generic() + } + + fn send_mode_reply( + &self, + target_id: ChannelId, + request: ModeReply, + ) -> Result<(), GenericTargetedMessagingError> { + self.message_sender_map + .send_mode_reply(self.local_channel_id(), target_id, request) + } +} + +impl, R: MessageReceiver> ModeReplyReceiver + for MessageSenderAndReceiver +{ + fn try_recv_mode_reply( + &self, + ) -> Result>, GenericTargetedMessagingError> { + self.message_receiver + .try_recv_message(self.local_channel_id_generic()) + } +} +impl, R: MessageReceiver> ModeRequestReceiver + for MessageSenderAndReceiver +{ + fn try_recv_mode_request( + &self, + ) -> Result>, GenericTargetedMessagingError> { + self.message_receiver + .try_recv_message(self.local_channel_id_generic()) + } +} + +pub type ModeRequestHandlerConnector = MessageSenderAndReceiver; +pub type MpscModeRequestHandlerConnector = ModeRequestHandlerConnector< + mpsc::Sender>, + mpsc::Receiver>, +>; +pub type MpscBoundedModeRequestHandlerConnector = ModeRequestHandlerConnector< + mpsc::SyncSender>, + mpsc::Receiver>, +>; + +pub type ModeRequestorConnector = MessageSenderAndReceiver; +pub type MpscModeRequestorConnector = ModeRequestorConnector< + mpsc::Sender>, + mpsc::Receiver>, +>; +pub type MpscBoundedModeRequestorConnector = ModeRequestorConnector< + mpsc::SyncSender>, + mpsc::Receiver>, +>; + +pub type ModeConnector = + RequestAndReplySenderAndReceiver; +pub type MpscModeConnector = ModeConnector< + mpsc::Sender>, + mpsc::Receiver>, + mpsc::Sender>, + mpsc::Receiver>, +>; +pub type MpscBoundedModeConnector = ModeConnector< + mpsc::SyncSender>, + mpsc::Receiver>, + mpsc::SyncSender>, + mpsc::Receiver>, +>; + +impl< + REPLY, + S0: MessageSender, + R0: MessageReceiver, + S1: MessageSender, + R1: MessageReceiver, + > RequestAndReplySenderAndReceiver +{ + pub fn add_request_target(&mut self, target_id: ChannelId, request_sender: S0) { + self.request_sender_map + .add_message_target(target_id, request_sender) + } +} + +impl< + REQUEST, + S0: MessageSender, + R0: MessageReceiver, + S1: MessageSender, + R1: MessageReceiver, + > RequestAndReplySenderAndReceiver +{ + pub fn add_reply_target(&mut self, target_id: ChannelId, reply_sender: S1) { + self.reply_sender_map + .add_message_target(target_id, reply_sender) + } +} + +impl< + REPLY, + S0: MessageSender, + R0: MessageReceiver, + S1: MessageSender, + R1: MessageReceiver, + > ModeRequestSender for RequestAndReplySenderAndReceiver +{ + fn local_channel_id(&self) -> ChannelId { + self.local_channel_id_generic() + } + + fn send_mode_request( + &self, + target_id: ChannelId, + request: ModeRequest, + ) -> Result<(), GenericTargetedMessagingError> { + self.request_sender_map + .send_mode_request(self.local_channel_id(), target_id, request) + } +} + +impl< + REQUEST, + S0: MessageSender, + R0: MessageReceiver, + S1: MessageSender, + R1: MessageReceiver, + > ModeReplySender for RequestAndReplySenderAndReceiver +{ + fn local_channel_id(&self) -> ChannelId { + self.local_channel_id_generic() + } + + fn send_mode_reply( + &self, + target_id: ChannelId, + request: ModeReply, + ) -> Result<(), GenericTargetedMessagingError> { + self.reply_sender_map + .send_mode_reply(self.local_channel_id(), target_id, request) + } +} + +impl< + REQUEST, + S0: MessageSender, + R0: MessageReceiver, + S1: MessageSender, + R1: MessageReceiver, + > ModeReplyReceiver for RequestAndReplySenderAndReceiver +{ + fn try_recv_mode_reply( + &self, + ) -> Result>, GenericTargetedMessagingError> { + self.reply_receiver + .try_recv_message(self.local_channel_id_generic()) + } +} + +impl< + REPLY, + S0: MessageSender, + R0: MessageReceiver, + S1: MessageSender, + R1: MessageReceiver, + > ModeRequestReceiver for RequestAndReplySenderAndReceiver +{ + fn try_recv_mode_request( + &self, + ) -> Result>, GenericTargetedMessagingError> { + self.request_receiver + .try_recv_message(self.local_channel_id_generic()) + } +} + +pub trait ModeProvider { + fn mode_and_submode(&self) -> ModeAndSubmode; +} + +#[derive(Debug, Clone)] +pub enum ModeError { + Messaging(GenericTargetedMessagingError), +} + +impl From for ModeError { + fn from(value: GenericTargetedMessagingError) -> Self { + Self::Messaging(value) + } +} + +pub trait ModeRequestHandler: ModeProvider { + fn start_transition(&mut self, mode_and_submode: ModeAndSubmode) -> Result<(), ModeError>; + + fn announce_mode(&self, recursive: bool); + fn handle_mode_reached(&mut self) -> Result<(), GenericTargetedMessagingError>; +} + +#[cfg(test)] +mod tests {} diff --git a/satrs/src/pus/mod.rs b/satrs/src/pus/mod.rs index 1f62c36..43f8462 100644 --- a/satrs/src/pus/mod.rs +++ b/satrs/src/pus/mod.rs @@ -566,9 +566,9 @@ pub mod std_mod { fn recv_tc(&self) -> Result { self.receiver.try_recv().map_err(|e| match e { TryRecvError::Empty => TryRecvTmtcError::Empty, - TryRecvError::Disconnected => { - TryRecvTmtcError::Tmtc(EcssTmtcError::from(GenericReceiveError::TxDisconnected)) - } + TryRecvError::Disconnected => TryRecvTmtcError::Tmtc(EcssTmtcError::from( + GenericReceiveError::TxDisconnected(Some(self.channel_id())), + )), }) } } @@ -662,7 +662,7 @@ pub mod std_mod { self.receiver.try_recv().map_err(|e| match e { cb::TryRecvError::Empty => TryRecvTmtcError::Empty, cb::TryRecvError::Disconnected => TryRecvTmtcError::Tmtc(EcssTmtcError::from( - GenericReceiveError::TxDisconnected, + GenericReceiveError::TxDisconnected(Some(self.id())), )), }) } diff --git a/satrs/src/queue.rs b/satrs/src/queue.rs index a80f44f..7be4551 100644 --- a/satrs/src/queue.rs +++ b/satrs/src/queue.rs @@ -4,11 +4,15 @@ use std::error::Error; #[cfg(feature = "std")] use std::sync::mpsc; +/// Generic channel ID type. +pub type ChannelId = u32; + /// Generic error type for sending something via a message queue. #[derive(Debug, Copy, Clone)] pub enum GenericSendError { RxDisconnected, QueueFull(Option), + TargetDoesNotExist(ChannelId), } impl Display for GenericSendError { @@ -20,6 +24,9 @@ impl Display for GenericSendError { GenericSendError::QueueFull(max_cap) => { write!(f, "queue with max capacity of {max_cap:?} is full") } + GenericSendError::TargetDoesNotExist(target) => { + write!(f, "target queue with ID {target} does not exist") + } } } } @@ -31,14 +38,14 @@ impl Error for GenericSendError {} #[derive(Debug, Copy, Clone)] pub enum GenericReceiveError { Empty, - TxDisconnected, + TxDisconnected(Option), } impl Display for GenericReceiveError { fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { match self { - Self::TxDisconnected => { - write!(f, "tx side has disconnected") + Self::TxDisconnected(channel_id) => { + write!(f, "tx side with id {channel_id:?} has disconnected") } Self::Empty => { write!(f, "nothing to receive") @@ -50,6 +57,23 @@ impl Display for GenericReceiveError { #[cfg(feature = "std")] impl Error for GenericReceiveError {} +#[derive(Debug, Clone)] +pub enum GenericTargetedMessagingError { + Send(GenericSendError), + Receive(GenericReceiveError), +} +impl From for GenericTargetedMessagingError { + fn from(value: GenericSendError) -> Self { + Self::Send(value) + } +} + +impl From for GenericTargetedMessagingError { + fn from(value: GenericReceiveError) -> Self { + Self::Receive(value) + } +} + #[cfg(feature = "std")] impl From> for GenericSendError { fn from(_: mpsc::SendError) -> Self { diff --git a/satrs/src/request.rs b/satrs/src/request.rs index 24ca497..0ca9d03 100644 --- a/satrs/src/request.rs +++ b/satrs/src/request.rs @@ -2,12 +2,16 @@ use core::fmt; #[cfg(feature = "std")] use std::error::Error; +#[cfg(feature = "std")] +#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] +pub use std_mod::*; + use spacepackets::{ ecss::{tc::IsPusTelecommand, PusPacket}, ByteConversionError, CcsdsPacket, }; -use crate::TargetId; +use crate::{queue::GenericTargetedMessagingError, ChannelId, TargetId}; pub type Apid = u16; @@ -108,3 +112,264 @@ impl fmt::Display for TargetAndApidId { write!(f, "{}, {}", self.apid, self.target) } } + +pub struct MessageWithSenderId { + pub sender_id: ChannelId, + pub message: MSG, +} + +impl MessageWithSenderId { + pub fn new(sender_id: ChannelId, message: MSG) -> Self { + Self { sender_id, message } + } +} + +/// Generic trait for objects which can send targeted messages. +pub trait MessageSender: Send { + fn send(&self, message: MessageWithSenderId) -> Result<(), GenericTargetedMessagingError>; +} + +// Generic trait for objects which can receive targeted messages. +pub trait MessageReceiver { + fn try_recv(&self) -> Result>, GenericTargetedMessagingError>; +} + +#[cfg(feature = "std")] +#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] +mod std_mod { + use core::marker::PhantomData; + use std::sync::mpsc; + + use hashbrown::HashMap; + + use crate::{ + queue::{GenericReceiveError, GenericSendError, GenericTargetedMessagingError}, + ChannelId, + }; + + use super::{MessageReceiver, MessageSender, MessageWithSenderId}; + + impl MessageSender for mpsc::Sender> { + fn send( + &self, + message: MessageWithSenderId, + ) -> Result<(), GenericTargetedMessagingError> { + self.send(message) + .map_err(|_| GenericSendError::RxDisconnected)?; + Ok(()) + } + } + impl MessageSender for mpsc::SyncSender> { + fn send( + &self, + message: MessageWithSenderId, + ) -> Result<(), GenericTargetedMessagingError> { + if let Err(e) = self.try_send(message) { + match e { + mpsc::TrySendError::Full(_) => { + return Err(GenericSendError::QueueFull(None).into()); + } + mpsc::TrySendError::Disconnected(_) => todo!(), + } + } + Ok(()) + } + } + + pub struct MessageSenderMap>( + pub HashMap, + PhantomData, + ); + + pub type MpscSenderMap = MessageReceiverWithId>; + pub type MpscBoundedSenderMap = MessageReceiverWithId>; + + impl> Default for MessageSenderMap { + fn default() -> Self { + Self(Default::default(), PhantomData) + } + } + + impl> MessageSenderMap { + pub fn add_message_target(&mut self, target_id: ChannelId, message_sender: S) { + self.0.insert(target_id, message_sender); + } + + pub fn send_message( + &self, + local_channel_id: ChannelId, + target_channel_id: ChannelId, + message: MSG, + ) -> Result<(), GenericTargetedMessagingError> { + if self.0.contains_key(&target_channel_id) { + self.0 + .get(&target_channel_id) + .unwrap() + .send(MessageWithSenderId::new(local_channel_id, message)) + .map_err(|_| GenericSendError::RxDisconnected)?; + return Ok(()); + } + Err(GenericSendError::TargetDoesNotExist(target_channel_id).into()) + } + } + + pub struct MessageSenderMapWithId> { + pub local_channel_id: ChannelId, + pub message_sender_map: MessageSenderMap, + } + + impl> MessageSenderMapWithId { + pub fn new(local_channel_id: ChannelId) -> Self { + Self { + local_channel_id, + message_sender_map: Default::default(), + } + } + + pub fn send_message( + &self, + target_channel_id: ChannelId, + message: MSG, + ) -> Result<(), GenericTargetedMessagingError> { + self.message_sender_map + .send_message(self.local_channel_id, target_channel_id, message) + } + + pub fn add_message_target(&mut self, target_id: ChannelId, message_sender: S) { + self.message_sender_map + .add_message_target(target_id, message_sender) + } + } + + impl MessageReceiver for mpsc::Receiver> { + fn try_recv( + &self, + ) -> Result>, GenericTargetedMessagingError> { + match self.try_recv() { + Ok(msg) => Ok(Some(msg)), + Err(e) => match e { + mpsc::TryRecvError::Empty => Ok(None), + mpsc::TryRecvError::Disconnected => { + Err(GenericReceiveError::TxDisconnected(None).into()) + } + }, + } + } + } + + pub struct MessageWithSenderIdReceiver>(pub R, PhantomData); + + impl> From for MessageWithSenderIdReceiver { + fn from(receiver: R) -> Self { + MessageWithSenderIdReceiver(receiver, PhantomData) + } + } + + impl> MessageWithSenderIdReceiver { + pub fn try_recv_message( + &self, + _local_id: ChannelId, + ) -> Result>, GenericTargetedMessagingError> { + self.0.try_recv() + } + } + + pub struct MessageReceiverWithId> { + local_channel_id: ChannelId, + reply_receiver: MessageWithSenderIdReceiver, + } + + pub type MpscMessageReceiverWithId = MessageReceiverWithId>; + + impl> MessageReceiverWithId { + pub fn new( + local_channel_id: ChannelId, + reply_receiver: MessageWithSenderIdReceiver, + ) -> Self { + Self { + local_channel_id, + reply_receiver, + } + } + + pub fn local_channel_id(&self) -> ChannelId { + self.local_channel_id + } + } + + impl> MessageReceiverWithId { + pub fn try_recv_message( + &self, + ) -> Result>, GenericTargetedMessagingError> { + self.reply_receiver.0.try_recv() + } + } + + pub struct MessageSenderAndReceiver, R: MessageReceiver> { + pub local_channel_id: ChannelId, + pub message_sender_map: MessageSenderMap, + pub message_receiver: MessageWithSenderIdReceiver, + } + + impl, R: MessageReceiver> + MessageSenderAndReceiver + { + pub fn new(local_channel_id: ChannelId, message_receiver: R) -> Self { + Self { + local_channel_id, + message_sender_map: Default::default(), + message_receiver: MessageWithSenderIdReceiver::from(message_receiver), + } + } + + pub fn add_message_target(&mut self, target_id: ChannelId, message_sender: S) { + self.message_sender_map + .add_message_target(target_id, message_sender) + } + + pub fn local_channel_id_generic(&self) -> ChannelId { + self.local_channel_id + } + } + + pub struct RequestAndReplySenderAndReceiver< + REQUEST, + REPLY, + S0: MessageSender, + R0: MessageReceiver, + S1: MessageSender, + R1: MessageReceiver, + > { + pub local_channel_id: ChannelId, + // These 2 are a functional group. + pub request_sender_map: MessageSenderMap, + pub reply_receiver: MessageWithSenderIdReceiver, + // These 2 are a functional group. + pub request_receiver: MessageWithSenderIdReceiver, + pub reply_sender_map: MessageSenderMap, + } + + impl< + REQUEST, + REPLY, + S0: MessageSender, + R0: MessageReceiver, + S1: MessageSender, + R1: MessageReceiver, + > RequestAndReplySenderAndReceiver + { + pub fn new(local_channel_id: ChannelId, request_receiver: R1, reply_receiver: R0) -> Self { + Self { + local_channel_id, + request_receiver: request_receiver.into(), + reply_receiver: reply_receiver.into(), + request_sender_map: Default::default(), + reply_sender_map: Default::default(), + } + } + + pub fn local_channel_id_generic(&self) -> ChannelId { + self.local_channel_id + } + } +} diff --git a/satrs/tests/mode_tree.rs b/satrs/tests/mode_tree.rs new file mode 100644 index 0000000..3c8680f --- /dev/null +++ b/satrs/tests/mode_tree.rs @@ -0,0 +1,287 @@ +use std::{println, sync::mpsc}; + +use satrs::mode_tree::ModeRequestSender; +use satrs::{ + mode::{ModeAndSubmode, ModeReply, ModeRequest}, + mode_tree::{ + ModeError, ModeProvider, ModeReplyReceiver, ModeReplySender, ModeRequestHandler, + ModeRequestReceiver, MpscBoundedModeConnector, MpscBoundedModeRequestHandlerConnector, + MpscBoundedModeRequestorConnector, + }, + queue::GenericTargetedMessagingError, + request::MessageWithSenderId, + ChannelId, +}; +use std::string::{String, ToString}; + +pub enum TestChannelId { + Device1 = 1, + Device2 = 2, + Assembly = 3, + PusModeService = 4, +} + +struct PusModeService { + pub mode_node: MpscBoundedModeRequestorConnector, +} + +impl PusModeService { + pub fn send_announce_mode_cmd_to_assy(&self) { + self.mode_node + .send_mode_request( + TestChannelId::Assembly as u32, + ModeRequest::AnnounceModeRecursive, + ) + .unwrap(); + } +} + +struct TestDevice { + pub name: String, + pub mode_node: MpscBoundedModeRequestHandlerConnector, + pub mode_and_submode: ModeAndSubmode, + pub mode_req_commander: Option, +} + +impl TestDevice { + pub fn run(&mut self) { + self.check_mode_requests().expect("mode messaging error"); + } + + pub fn check_mode_requests(&mut self) -> Result<(), GenericTargetedMessagingError> { + if let Some(request_and_id) = self.mode_node.try_recv_mode_request()? { + match request_and_id.message { + ModeRequest::SetMode(mode_and_submode) => { + self.start_transition(mode_and_submode).unwrap(); + self.mode_req_commander = Some(request_and_id.sender_id); + } + ModeRequest::ReadMode => self + .mode_node + .send_mode_reply( + request_and_id.sender_id, + ModeReply::ModeReply(self.mode_and_submode), + ) + .unwrap(), + ModeRequest::AnnounceMode => self.announce_mode(false), + ModeRequest::AnnounceModeRecursive => self.announce_mode(true), + } + } + Ok(()) + } +} + +impl ModeProvider for TestDevice { + fn mode_and_submode(&self) -> ModeAndSubmode { + self.mode_and_submode + } +} +impl ModeRequestHandler for TestDevice { + fn start_transition(&mut self, mode_and_submode: ModeAndSubmode) -> Result<(), ModeError> { + self.mode_and_submode = mode_and_submode; + self.handle_mode_reached()?; + Ok(()) + } + + fn announce_mode(&self, _recursive: bool) { + println!( + "{}: announcing mode: {:?}", + self.name, self.mode_and_submode + ); + } + + fn handle_mode_reached(&mut self) -> Result<(), GenericTargetedMessagingError> { + self.mode_node.send_mode_reply( + self.mode_req_commander.unwrap(), + ModeReply::ModeReply(self.mode_and_submode), + )?; + Ok(()) + } +} + +struct TestAssembly { + pub mode_node: MpscBoundedModeConnector, + pub mode_req_commander: Option, + pub mode_and_submode: ModeAndSubmode, + pub target_mode_and_submode: Option, +} + +impl ModeProvider for TestAssembly { + fn mode_and_submode(&self) -> ModeAndSubmode { + self.mode_and_submode + } +} + +impl TestAssembly { + pub fn run(&mut self) { + self.check_mode_requests().expect("mode messaging error"); + self.check_mode_replies().expect("mode messaging error"); + } + + pub fn check_mode_requests(&mut self) -> Result<(), GenericTargetedMessagingError> { + if let Some(request_and_id) = self.mode_node.try_recv_mode_request()? { + match request_and_id.message { + ModeRequest::SetMode(mode_and_submode) => { + self.start_transition(mode_and_submode).unwrap(); + self.mode_req_commander = Some(request_and_id.sender_id); + } + ModeRequest::ReadMode => self + .mode_node + .send_mode_reply( + request_and_id.sender_id, + ModeReply::ModeReply(self.mode_and_submode), + ) + .unwrap(), + ModeRequest::AnnounceMode => self.announce_mode(false), + ModeRequest::AnnounceModeRecursive => self.announce_mode(true), + } + } + Ok(()) + } + + pub fn check_mode_replies(&mut self) -> Result<(), GenericTargetedMessagingError> { + if let Some(reply_and_id) = self.mode_node.try_recv_mode_reply()? { + match reply_and_id.message { + ModeReply::ModeInfo(_) => todo!(), + ModeReply::ModeReply(reply) => { + println!( + "TestAssembly: Received mode reply from {:?}, reached: {:?}", + reply_and_id.sender_id, reply + ); + } + ModeReply::CantReachMode(_) => todo!(), + ModeReply::WrongMode { expected, reached } => { + println!( + "TestAssembly: Wrong mode reply from {:?}, reached {:?}, expected {:?}", + reply_and_id.sender_id, reached, expected + ); + } + } + } + Ok(()) + } +} + +impl ModeRequestHandler for TestAssembly { + fn start_transition(&mut self, mode_and_submode: ModeAndSubmode) -> Result<(), ModeError> { + self.target_mode_and_submode = Some(mode_and_submode); + Ok(()) + } + + fn announce_mode(&self, recursive: bool) { + println!( + "TestAssembly: Announcing mode (recursively: {}): {:?}", + recursive, self.mode_and_submode + ); + let mut mode_request = ModeRequest::AnnounceMode; + if recursive { + mode_request = ModeRequest::AnnounceModeRecursive; + } + self.mode_node + .request_sender_map + .0 + .iter() + .for_each(|(_, sender)| { + sender + .send(MessageWithSenderId::new( + self.mode_node.local_channel_id_generic(), + mode_request, + )) + .expect("sending mode request failed"); + }); + } + + fn handle_mode_reached(&mut self) -> Result<(), GenericTargetedMessagingError> { + self.mode_node.send_mode_reply( + self.mode_req_commander.unwrap(), + ModeReply::ModeReply(self.mode_and_submode), + )?; + Ok(()) + } +} + +fn main() { + // All request channel handles. + let (request_sender_to_dev1, request_receiver_dev1) = mpsc::sync_channel(10); + let (request_sender_to_dev2, request_receiver_dev2) = mpsc::sync_channel(10); + let (request_sender_to_assy, request_receiver_assy) = mpsc::sync_channel(10); + + // All reply channel handles. + let (reply_sender_to_assy, reply_receiver_assy) = mpsc::sync_channel(10); + let (reply_sender_to_pus, reply_receiver_pus) = mpsc::sync_channel(10); + + // Mode requestors and handlers. + let mut mode_node_assy = MpscBoundedModeConnector::new( + TestChannelId::Assembly as u32, + request_receiver_assy, + reply_receiver_assy, + ); + // Mode requestors only. + let mut mode_node_pus = MpscBoundedModeRequestorConnector::new( + TestChannelId::PusModeService as u32, + reply_receiver_pus, + ); + + // Request handlers only. + let mut mode_node_dev1 = MpscBoundedModeRequestHandlerConnector::new( + TestChannelId::Device1 as u32, + request_receiver_dev1, + ); + let mut mode_node_dev2 = MpscBoundedModeRequestHandlerConnector::new( + TestChannelId::Device2 as u32, + request_receiver_dev2, + ); + + // Set up mode request senders first. + mode_node_pus.add_message_target(TestChannelId::Assembly as u32, request_sender_to_assy); + mode_node_pus.add_message_target( + TestChannelId::Device1 as u32, + request_sender_to_dev1.clone(), + ); + mode_node_pus.add_message_target( + TestChannelId::Device2 as u32, + request_sender_to_dev2.clone(), + ); + mode_node_assy.add_request_target(TestChannelId::Device1 as u32, request_sender_to_dev1); + mode_node_assy.add_request_target(TestChannelId::Device2 as u32, request_sender_to_dev2); + + // Set up mode reply senders. + mode_node_dev1.add_message_target(TestChannelId::Assembly as u32, reply_sender_to_assy.clone()); + mode_node_dev1.add_message_target( + TestChannelId::PusModeService as u32, + reply_sender_to_pus.clone(), + ); + mode_node_dev2.add_message_target(TestChannelId::Assembly as u32, reply_sender_to_assy); + mode_node_dev2.add_message_target( + TestChannelId::PusModeService as u32, + reply_sender_to_pus.clone(), + ); + mode_node_assy.add_reply_target(TestChannelId::PusModeService as u32, reply_sender_to_pus); + + let mut device1 = TestDevice { + name: "Test Device 1".to_string(), + mode_node: mode_node_dev1, + mode_req_commander: None, + mode_and_submode: ModeAndSubmode::new(0, 0), + }; + let mut device2 = TestDevice { + name: "Test Device 2".to_string(), + mode_node: mode_node_dev2, + mode_req_commander: None, + mode_and_submode: ModeAndSubmode::new(0, 0), + }; + let mut assy = TestAssembly { + mode_node: mode_node_assy, + mode_req_commander: None, + mode_and_submode: ModeAndSubmode::new(0, 0), + target_mode_and_submode: None, + }; + let pus_service = PusModeService { + mode_node: mode_node_pus, + }; + + pus_service.send_announce_mode_cmd_to_assy(); + assy.run(); + device1.run(); + device2.run(); + assy.run(); +}