diff --git a/satrs-example/src/main.rs b/satrs-example/src/main.rs index a6456d6..8e6ee8f 100644 --- a/satrs-example/src/main.rs +++ b/satrs-example/src/main.rs @@ -225,7 +225,7 @@ fn static_tmtc_pool_main() { info!("Starting TMTC and UDP task"); let jh_udp_tmtc = thread::Builder::new() - .name("TMTC and UDP".to_string()) + .name("SATRS tmtc-udp".to_string()) .spawn(move || { info!("Running UDP server on port {SERVER_PORT}"); loop { @@ -238,7 +238,7 @@ fn static_tmtc_pool_main() { info!("Starting TCP task"); let jh_tcp = thread::Builder::new() - .name("TCP".to_string()) + .name("sat-rs tcp".to_string()) .spawn(move || { info!("Running TCP server on port {SERVER_PORT}"); loop { @@ -257,7 +257,7 @@ fn static_tmtc_pool_main() { info!("Starting event handling task"); let jh_event_handling = thread::Builder::new() - .name("Event".to_string()) + .name("sat-rs events".to_string()) .spawn(move || loop { event_handler.periodic_operation(); thread::sleep(Duration::from_millis(FREQ_MS_EVENT_HANDLING)); @@ -266,7 +266,7 @@ fn static_tmtc_pool_main() { info!("Starting AOCS thread"); let jh_aocs = thread::Builder::new() - .name("AOCS".to_string()) + .name("sat-rs aocs".to_string()) .spawn(move || loop { mgm_handler.periodic_operation(); thread::sleep(Duration::from_millis(FREQ_MS_AOCS)); @@ -275,7 +275,7 @@ fn static_tmtc_pool_main() { info!("Starting PUS handler thread"); let jh_pus_handler = thread::Builder::new() - .name("PUS".to_string()) + .name("sat-rs pus".to_string()) .spawn(move || loop { pus_stack.periodic_operation(); thread::sleep(Duration::from_millis(FREQ_MS_PUS_STACK)); @@ -444,7 +444,7 @@ fn dyn_tmtc_pool_main() { info!("Starting TMTC and UDP task"); let jh_udp_tmtc = thread::Builder::new() - .name("TMTC and UDP".to_string()) + .name("sat-rs tmtc-udp".to_string()) .spawn(move || { info!("Running UDP server on port {SERVER_PORT}"); loop { @@ -457,7 +457,7 @@ fn dyn_tmtc_pool_main() { info!("Starting TCP task"); let jh_tcp = thread::Builder::new() - .name("TCP".to_string()) + .name("sat-rs tcp".to_string()) .spawn(move || { info!("Running TCP server on port {SERVER_PORT}"); loop { @@ -468,7 +468,7 @@ fn dyn_tmtc_pool_main() { info!("Starting TM funnel task"); let jh_tm_funnel = thread::Builder::new() - .name("TM Funnel".to_string()) + .name("sat-rs tm-funnel".to_string()) .spawn(move || loop { tm_funnel.operation(); }) @@ -476,7 +476,7 @@ fn dyn_tmtc_pool_main() { info!("Starting event handling task"); let jh_event_handling = thread::Builder::new() - .name("Event".to_string()) + .name("sat-rs events".to_string()) .spawn(move || loop { event_handler.periodic_operation(); thread::sleep(Duration::from_millis(FREQ_MS_EVENT_HANDLING)); @@ -485,7 +485,7 @@ fn dyn_tmtc_pool_main() { info!("Starting AOCS thread"); let jh_aocs = thread::Builder::new() - .name("AOCS".to_string()) + .name("sat-rs aocs".to_string()) .spawn(move || loop { mgm_handler.periodic_operation(); thread::sleep(Duration::from_millis(FREQ_MS_AOCS)); @@ -494,7 +494,7 @@ fn dyn_tmtc_pool_main() { info!("Starting PUS handler thread"); let jh_pus_handler = thread::Builder::new() - .name("PUS".to_string()) + .name("sat-rs pus".to_string()) .spawn(move || loop { pus_stack.periodic_operation(); thread::sleep(Duration::from_millis(FREQ_MS_PUS_STACK)); diff --git a/satrs-example/src/pus/action.rs b/satrs-example/src/pus/action.rs index 22b6b93..59859fc 100644 --- a/satrs-example/src/pus/action.rs +++ b/satrs-example/src/pus/action.rs @@ -267,7 +267,7 @@ impl Targete for ActionServiceWrapper { /// Returns [true] if the packet handling is finished. - fn poll_and_handle_next_tc(&mut self, time_stamp: &[u8]) -> bool { + fn poll_and_handle_next_tc(&mut self, time_stamp: &[u8]) -> HandlingStatus { match self.service.poll_and_handle_next_tc(time_stamp) { Ok(result) => match result { PusPacketHandlerResult::RequestHandled => {} @@ -280,15 +280,13 @@ impl Targete PusPacketHandlerResult::SubserviceNotImplemented(subservice, _) => { warn!("PUS 8 subservice {subservice} not implemented"); } - PusPacketHandlerResult::Empty => { - return true; - } + PusPacketHandlerResult::Empty => return HandlingStatus::Empty, }, Err(error) => { error!("PUS packet handling error: {error:?}") } } - false + HandlingStatus::HandledOne } fn poll_and_handle_next_reply(&mut self, time_stamp: &[u8]) -> HandlingStatus { diff --git a/satrs-example/src/pus/event.rs b/satrs-example/src/pus/event.rs index 865b1f1..23cc2ca 100644 --- a/satrs-example/src/pus/event.rs +++ b/satrs-example/src/pus/event.rs @@ -13,6 +13,8 @@ use satrs::pus::{ }; use satrs_example::config::components::PUS_EVENT_MANAGEMENT; +use super::HandlingStatus; + pub fn create_event_service_static( tm_sender: TmInSharedPoolSender>, tc_pool: SharedStaticMemoryPool, @@ -62,7 +64,7 @@ pub struct EventServiceWrapper EventServiceWrapper { - pub fn poll_and_handle_next_tc(&mut self, time_stamp: &[u8]) -> bool { + pub fn poll_and_handle_next_tc(&mut self, time_stamp: &[u8]) -> HandlingStatus { match self.handler.poll_and_handle_next_tc(time_stamp) { Ok(result) => match result { PusPacketHandlerResult::RequestHandled => {} @@ -75,14 +77,12 @@ impl PusPacketHandlerResult::SubserviceNotImplemented(subservice, _) => { warn!("PUS 5 subservice {subservice} not implemented"); } - PusPacketHandlerResult::Empty => { - return true; - } + PusPacketHandlerResult::Empty => return HandlingStatus::Empty, }, Err(error) => { error!("PUS packet handling error: {error:?}") } } - false + HandlingStatus::HandledOne } } diff --git a/satrs-example/src/pus/hk.rs b/satrs-example/src/pus/hk.rs index cb3ebb9..92f74ba 100644 --- a/satrs-example/src/pus/hk.rs +++ b/satrs-example/src/pus/hk.rs @@ -300,7 +300,7 @@ pub struct HkServiceWrapper HkServiceWrapper { - pub fn poll_and_handle_next_tc(&mut self, time_stamp: &[u8]) -> bool { + pub fn poll_and_handle_next_tc(&mut self, time_stamp: &[u8]) -> HandlingStatus { match self.service.poll_and_handle_next_tc(time_stamp) { Ok(result) => match result { PusPacketHandlerResult::RequestHandled => {} @@ -313,15 +313,13 @@ impl PusPacketHandlerResult::SubserviceNotImplemented(subservice, _) => { warn!("PUS 3 subservice {subservice} not implemented"); } - PusPacketHandlerResult::Empty => { - return true; - } + PusPacketHandlerResult::Empty => return HandlingStatus::Empty, }, Err(error) => { error!("PUS packet handling error: {error:?}") } } - false + HandlingStatus::HandledOne } pub fn poll_and_handle_next_reply(&mut self, time_stamp: &[u8]) -> HandlingStatus { diff --git a/satrs-example/src/pus/mod.rs b/satrs-example/src/pus/mod.rs index 83bd34a..28c645a 100644 --- a/satrs-example/src/pus/mod.rs +++ b/satrs-example/src/pus/mod.rs @@ -157,7 +157,7 @@ impl PusReceiver { pub trait TargetedPusService { /// Returns [true] if the packet handling is finished. - fn poll_and_handle_next_tc(&mut self, time_stamp: &[u8]) -> bool; + fn poll_and_handle_next_tc(&mut self, time_stamp: &[u8]) -> HandlingStatus; fn poll_and_handle_next_reply(&mut self, time_stamp: &[u8]) -> HandlingStatus; fn check_for_request_timeouts(&mut self); } diff --git a/satrs-example/src/pus/mode.rs b/satrs-example/src/pus/mode.rs index 4f2ff13..36a6ee6 100644 --- a/satrs-example/src/pus/mode.rs +++ b/satrs-example/src/pus/mode.rs @@ -272,7 +272,7 @@ impl Targete for ModeServiceWrapper { /// Returns [true] if the packet handling is finished. - fn poll_and_handle_next_tc(&mut self, time_stamp: &[u8]) -> bool { + fn poll_and_handle_next_tc(&mut self, time_stamp: &[u8]) -> HandlingStatus { match self.service.poll_and_handle_next_tc(time_stamp) { Ok(result) => match result { PusPacketHandlerResult::RequestHandled => {} @@ -285,15 +285,13 @@ impl Targete PusPacketHandlerResult::SubserviceNotImplemented(subservice, _) => { warn!("PUS mode service: {subservice} not implemented"); } - PusPacketHandlerResult::Empty => { - return true; - } + PusPacketHandlerResult::Empty => return HandlingStatus::Empty, }, Err(error) => { error!("PUS mode service: packet handling error: {error:?}") } } - false + HandlingStatus::HandledOne } fn poll_and_handle_next_reply(&mut self, time_stamp: &[u8]) -> HandlingStatus { diff --git a/satrs-example/src/pus/scheduler.rs b/satrs-example/src/pus/scheduler.rs index d75c666..a774577 100644 --- a/satrs-example/src/pus/scheduler.rs +++ b/satrs-example/src/pus/scheduler.rs @@ -16,6 +16,8 @@ use satrs_example::config::components::PUS_SCHED_SERVICE; use crate::tmtc::PusTcSourceProviderSharedPool; +use super::HandlingStatus; + pub trait TcReleaser { fn release(&mut self, enabled: bool, info: &TcInfo, tc: &[u8]) -> bool; } @@ -92,7 +94,7 @@ impl } } - pub fn poll_and_handle_next_tc(&mut self, time_stamp: &[u8]) -> bool { + pub fn poll_and_handle_next_tc(&mut self, time_stamp: &[u8]) -> HandlingStatus { match self .pus_11_handler .poll_and_handle_next_tc(time_stamp, &mut self.sched_tc_pool) @@ -108,15 +110,13 @@ impl PusPacketHandlerResult::SubserviceNotImplemented(subservice, _) => { warn!("PUS11: Subservice {subservice} not implemented"); } - PusPacketHandlerResult::Empty => { - return true; - } + PusPacketHandlerResult::Empty => return HandlingStatus::Empty, }, Err(error) => { error!("PUS packet handling error: {error:?}") } } - false + HandlingStatus::HandledOne } } diff --git a/satrs-example/src/pus/stack.rs b/satrs-example/src/pus/stack.rs index a11463c..524e35b 100644 --- a/satrs-example/src/pus/stack.rs +++ b/satrs-example/src/pus/stack.rs @@ -35,18 +35,29 @@ impl loop { let mut nothing_to_do = true; let mut is_srv_finished = - |tc_handling_done: bool, reply_handling_done: Option| { - if !tc_handling_done + |_srv_id: u8, + tc_handling_done: HandlingStatus, + reply_handling_done: Option| { + if tc_handling_done == HandlingStatus::HandledOne || (reply_handling_done.is_some() - && reply_handling_done.unwrap() == HandlingStatus::Empty) + && reply_handling_done.unwrap() == HandlingStatus::HandledOne) { nothing_to_do = false; } }; - is_srv_finished(self.test_srv.poll_and_handle_next_packet(&time_stamp), None); - is_srv_finished(self.schedule_srv.poll_and_handle_next_tc(&time_stamp), None); - is_srv_finished(self.event_srv.poll_and_handle_next_tc(&time_stamp), None); is_srv_finished( + 17, + self.test_srv.poll_and_handle_next_packet(&time_stamp), + None, + ); + is_srv_finished( + 11, + self.schedule_srv.poll_and_handle_next_tc(&time_stamp), + None, + ); + is_srv_finished(5, self.event_srv.poll_and_handle_next_tc(&time_stamp), None); + is_srv_finished( + 8, self.action_srv_wrapper.poll_and_handle_next_tc(&time_stamp), Some( self.action_srv_wrapper @@ -54,10 +65,12 @@ impl ), ); is_srv_finished( + 3, self.hk_srv_wrapper.poll_and_handle_next_tc(&time_stamp), Some(self.hk_srv_wrapper.poll_and_handle_next_reply(&time_stamp)), ); is_srv_finished( + 200, self.mode_srv.poll_and_handle_next_tc(&time_stamp), Some(self.mode_srv.poll_and_handle_next_reply(&time_stamp)), ); diff --git a/satrs-example/src/pus/test.rs b/satrs-example/src/pus/test.rs index 0111026..c471897 100644 --- a/satrs-example/src/pus/test.rs +++ b/satrs-example/src/pus/test.rs @@ -18,6 +18,8 @@ use satrs_example::config::components::PUS_TEST_SERVICE; use satrs_example::config::{tmtc_err, TEST_EVENT}; use std::sync::mpsc; +use super::HandlingStatus; + pub fn create_test_service_static( tm_sender: TmInSharedPoolSender>, tc_pool: SharedStaticMemoryPool, @@ -67,11 +69,11 @@ pub struct TestCustomServiceWrapper< impl TestCustomServiceWrapper { - pub fn poll_and_handle_next_packet(&mut self, time_stamp: &[u8]) -> bool { + pub fn poll_and_handle_next_packet(&mut self, time_stamp: &[u8]) -> HandlingStatus { let res = self.handler.poll_and_handle_next_tc(time_stamp); if res.is_err() { warn!("PUS17 handler failed with error {:?}", res.unwrap_err()); - return true; + return HandlingStatus::HandledOne; } match res.unwrap() { PusPacketHandlerResult::RequestHandled => { @@ -135,10 +137,8 @@ impl .expect("Sending start failure verification failed"); } } - PusPacketHandlerResult::Empty => { - return true; - } + PusPacketHandlerResult::Empty => return HandlingStatus::Empty, } - false + HandlingStatus::HandledOne } } diff --git a/satrs-example/src/tcp.rs b/satrs-example/src/tcp.rs index 7abcf5b..7db4c2c 100644 --- a/satrs-example/src/tcp.rs +++ b/satrs-example/src/tcp.rs @@ -5,7 +5,7 @@ use std::{ use log::{info, warn}; use satrs::{ - hal::std::tcp_server::{ServerConfig, TcpSpacepacketsServer}, + hal::std::tcp_server::{HandledConnectionHandler, ServerConfig, TcpSpacepacketsServer}, pus::ReceivesEcssPusTc, spacepackets::PacketId, tmtc::{CcsdsDistributor, CcsdsError, ReceivesCcsdsTc, TmPacketSourceCore}, @@ -13,6 +13,18 @@ use satrs::{ use crate::ccsds::CcsdsReceiver; +#[derive(Default)] +pub struct ConnectionFinishedHandler {} + +impl HandledConnectionHandler for ConnectionFinishedHandler { + fn handled_connection(&mut self, info: satrs::hal::std::tcp_server::HandledConnectionInfo) { + info!( + "Served {} TMs and {} TCs for client {:?}", + info.num_sent_tms, info.num_received_tcs, info.addr + ); + } +} + #[derive(Default, Clone)] pub struct SyncTcpTmSource { tm_queue: Arc>>>, @@ -70,11 +82,12 @@ impl TmPacketSourceCore for SyncTcpTmSource { } pub type TcpServerType = TcpSpacepacketsServer< - (), - CcsdsError, SyncTcpTmSource, CcsdsDistributor, MpscErrorType>, HashSet, + ConnectionFinishedHandler, + (), + CcsdsError, >; pub struct TcpTask< @@ -109,6 +122,7 @@ impl< tm_source, tc_receiver, packet_id_lookup, + ConnectionFinishedHandler::default(), None, )?, }) @@ -116,14 +130,9 @@ impl< pub fn periodic_operation(&mut self) { loop { - let result = self.server.handle_next_connection(); + let result = self.server.handle_next_connection(None); match result { - Ok(conn_result) => { - info!( - "Served {} TMs and {} TCs for client {:?}", - conn_result.num_sent_tms, conn_result.num_received_tcs, conn_result.addr - ); - } + Ok(_conn_result) => (), Err(e) => { warn!("TCP server error: {e:?}"); } diff --git a/satrs/CHANGELOG.md b/satrs/CHANGELOG.md index e119322..508c69b 100644 --- a/satrs/CHANGELOG.md +++ b/satrs/CHANGELOG.md @@ -22,9 +22,13 @@ and this project adheres to [Semantic Versioning](http://semver.org/). parameter type. This message also contains the sender ID which can be useful for debugging or application layer / FDIR logic. - Stop signal handling for the TCP servers. +- TCP server now uses `mio` crate to allow non-blocking operation. The server can now handle + multiple connections at once, and the context information about handled transfers is + passed via a callback which is inserted as a generic as well. ## Changed +- TCP server generics order. The error generics come last now. - `encoding::ccsds::PacketIdValidator` renamed to `ValidatorU16Id`, which lives in the crate root. It can be used for both CCSDS packet ID and CCSDS APID validation. - `EventManager::try_event_handling` not expects a mutable error handling closure instead of diff --git a/satrs/Cargo.toml b/satrs/Cargo.toml index 7f913e6..841861a 100644 --- a/satrs/Cargo.toml +++ b/satrs/Cargo.toml @@ -70,6 +70,11 @@ version = "0.5.4" features = ["all"] optional = true +[dependencies.mio] +version = "0.8" +features = ["os-poll", "net"] +optional = true + [dependencies.spacepackets] # git = "https://egit.irs.uni-stuttgart.de/rust/spacepackets.git" version = "0.11.0-rc.2" @@ -104,7 +109,8 @@ std = [ "spacepackets/std", "num_enum/std", "thiserror", - "socket2" + "socket2", + "mio" ] alloc = [ "serde/alloc", diff --git a/satrs/src/hal/std/tcp_cobs_server.rs b/satrs/src/hal/std/tcp_cobs_server.rs index 4158408..8b044f6 100644 --- a/satrs/src/hal/std/tcp_cobs_server.rs +++ b/satrs/src/hal/std/tcp_cobs_server.rs @@ -2,11 +2,11 @@ use alloc::sync::Arc; use alloc::vec; use cobs::encode; use core::sync::atomic::AtomicBool; +use core::time::Duration; use delegate::delegate; +use mio::net::{TcpListener, TcpStream}; use std::io::Write; use std::net::SocketAddr; -use std::net::TcpListener; -use std::net::TcpStream; use std::vec::Vec; use crate::encoding::parse_buffer_for_cobs_encoded_packets; @@ -17,6 +17,9 @@ use crate::hal::std::tcp_server::{ ConnectionResult, ServerConfig, TcpTcParser, TcpTmSender, TcpTmtcError, TcpTmtcGenericServer, }; +use super::tcp_server::HandledConnectionHandler; +use super::tcp_server::HandledConnectionInfo; + /// Concrete [TcpTcParser] implementation for the [TcpTmtcInCobsServer]. #[derive(Default)] pub struct CobsTcParser {} @@ -26,7 +29,7 @@ impl TcpTcParser for CobsTcParser { &mut self, tc_buffer: &mut [u8], tc_receiver: &mut (impl ReceivesTc + ?Sized), - conn_result: &mut ConnectionResult, + conn_result: &mut HandledConnectionInfo, current_write_idx: usize, next_write_idx: &mut usize, ) -> Result<(), TcpTmtcError> { @@ -60,7 +63,7 @@ impl TcpTmSender for CobsTmSender { &mut self, tm_buffer: &mut [u8], tm_source: &mut (impl TmPacketSource + ?Sized), - conn_result: &mut ConnectionResult, + conn_result: &mut HandledConnectionInfo, stream: &mut TcpStream, ) -> Result> { let mut tm_was_sent = false; @@ -112,21 +115,30 @@ impl TcpTmSender for CobsTmSender { /// The [TCP integration tests](https://egit.irs.uni-stuttgart.de/rust/sat-rs/src/branch/main/satrs/tests/tcp_servers.rs) /// test also serves as the example application for this module. pub struct TcpTmtcInCobsServer< - TmError, - TcError: 'static, TmSource: TmPacketSource, TcReceiver: ReceivesTc, + HandledConnection: HandledConnectionHandler, + TmError, + TcError: 'static, > { - generic_server: - TcpTmtcGenericServer, + pub generic_server: TcpTmtcGenericServer< + TmSource, + TcReceiver, + CobsTmSender, + CobsTcParser, + HandledConnection, + TmError, + TcError, + >, } impl< - TmError: 'static, - TcError: 'static, TmSource: TmPacketSource, TcReceiver: ReceivesTc, - > TcpTmtcInCobsServer + HandledConnection: HandledConnectionHandler, + TmError: 'static, + TcError: 'static, + > TcpTmtcInCobsServer { /// Create a new TCP TMTC server which exchanges TMTC packets encoded with /// [COBS protocol](https://en.wikipedia.org/wiki/Consistent_Overhead_Byte_Stuffing). @@ -142,6 +154,7 @@ impl< cfg: ServerConfig, tm_source: TmSource, tc_receiver: TcReceiver, + handled_connection: HandledConnection, stop_signal: Option>, ) -> Result { Ok(Self { @@ -151,6 +164,7 @@ impl< CobsTmSender::new(cfg.tm_buffer_size), tm_source, tc_receiver, + handled_connection, stop_signal, )?, }) @@ -167,6 +181,7 @@ impl< /// Delegation to the [TcpTmtcGenericServer::handle_next_connection] call. pub fn handle_next_connection( &mut self, + poll_duration: Option, ) -> Result>; } } @@ -181,15 +196,15 @@ mod tests { use std::{ io::{Read, Write}, net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream}, - thread, + panic, thread, time::Instant, }; use crate::{ encoding::tests::{INVERTED_PACKET, SIMPLE_PACKET}, hal::std::tcp_server::{ - tests::{SyncTcCacher, SyncTmSource}, - ServerConfig, + tests::{ConnectionFinishedHandler, SyncTcCacher, SyncTmSource}, + ConnectionResult, ServerConfig, }, }; use alloc::sync::Arc; @@ -218,11 +233,12 @@ mod tests { tc_receiver: SyncTcCacher, tm_source: SyncTmSource, stop_signal: Option>, - ) -> TcpTmtcInCobsServer<(), (), SyncTmSource, SyncTcCacher> { + ) -> TcpTmtcInCobsServer { TcpTmtcInCobsServer::new( ServerConfig::new(*addr, Duration::from_millis(2), 1024, 1024), tm_source, tc_receiver, + ConnectionFinishedHandler::default(), stop_signal, ) .expect("TCP server generation failed") @@ -242,13 +258,20 @@ mod tests { let set_if_done = conn_handled.clone(); // Call the connection handler in separate thread, does block. thread::spawn(move || { - let result = tcp_server.handle_next_connection(); + let result = tcp_server.handle_next_connection(Some(Duration::from_millis(100))); if result.is_err() { panic!("handling connection failed: {:?}", result.unwrap_err()); } - let conn_result = result.unwrap(); - assert_eq!(conn_result.num_received_tcs, 1); - assert_eq!(conn_result.num_sent_tms, 0); + let result = result.unwrap(); + assert_eq!(result, ConnectionResult::HandledConnections(1)); + tcp_server + .generic_server + .finished_handler + .check_last_connection(0, 1); + tcp_server + .generic_server + .finished_handler + .check_no_connections_left(); set_if_done.store(true, Ordering::Relaxed); }); // Send TC to server now. @@ -299,13 +322,20 @@ mod tests { let set_if_done = conn_handled.clone(); // Call the connection handler in separate thread, does block. thread::spawn(move || { - let result = tcp_server.handle_next_connection(); + let result = tcp_server.handle_next_connection(Some(Duration::from_millis(100))); if result.is_err() { panic!("handling connection failed: {:?}", result.unwrap_err()); } - let conn_result = result.unwrap(); - assert_eq!(conn_result.num_received_tcs, 2, "Not enough TCs received"); - assert_eq!(conn_result.num_sent_tms, 2, "Not enough TMs received"); + let result = result.unwrap(); + assert_eq!(result, ConnectionResult::HandledConnections(1)); + tcp_server + .generic_server + .finished_handler + .check_last_connection(2, 2); + tcp_server + .generic_server + .finished_handler + .check_no_connections_left(); set_if_done.store(true, Ordering::Relaxed); }); // Send TC to server now. @@ -389,6 +419,31 @@ mod tests { drop(tc_queue); } + #[test] + fn test_server_accept_timeout() { + let auto_port_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0); + let tc_receiver = SyncTcCacher::default(); + let tm_source = SyncTmSource::default(); + let mut tcp_server = + generic_tmtc_server(&auto_port_addr, tc_receiver.clone(), tm_source, None); + let start = Instant::now(); + // Call the connection handler in separate thread, does block. + let thread_jh = thread::spawn(move || loop { + let result = tcp_server.handle_next_connection(Some(Duration::from_millis(20))); + if result.is_err() { + panic!("handling connection failed: {:?}", result.unwrap_err()); + } + let result = result.unwrap(); + if result == ConnectionResult::AcceptTimeout { + break; + } + if Instant::now() - start > Duration::from_millis(100) { + panic!("regular stop signal handling failed"); + } + }); + thread_jh.join().expect("thread join failed"); + } + #[test] fn test_server_stop_signal() { let auto_port_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0); @@ -401,17 +456,32 @@ mod tests { tm_source, Some(stop_signal.clone()), ); + let dest_addr = tcp_server + .local_addr() + .expect("retrieving dest addr failed"); + let stop_signal_copy = stop_signal.clone(); let start = Instant::now(); // Call the connection handler in separate thread, does block. - thread::spawn(move || loop { - let result = tcp_server.handle_next_connection(); + let thread_jh = thread::spawn(move || loop { + let result = tcp_server.handle_next_connection(Some(Duration::from_millis(20))); if result.is_err() { panic!("handling connection failed: {:?}", result.unwrap_err()); } - if Instant::now() - start > Duration::from_millis(50) { + let result = result.unwrap(); + if result == ConnectionResult::AcceptTimeout { + panic!("unexpected accept timeout"); + } + if stop_signal_copy.load(Ordering::Relaxed) { + break; + } + if Instant::now() - start > Duration::from_millis(100) { panic!("regular stop signal handling failed"); } }); + // We connect but do not do anything. + let _stream = TcpStream::connect(dest_addr).expect("connecting to TCP server failed"); stop_signal.store(true, Ordering::Relaxed); + // No need to drop the connection, the stop signal should take take of everything. + thread_jh.join().expect("thread join failed"); } } diff --git a/satrs/src/hal/std/tcp_server.rs b/satrs/src/hal/std/tcp_server.rs index bfbcae0..9b05f43 100644 --- a/satrs/src/hal/std/tcp_server.rs +++ b/satrs/src/hal/std/tcp_server.rs @@ -4,10 +4,13 @@ use alloc::vec; use alloc::vec::Vec; use core::sync::atomic::AtomicBool; use core::time::Duration; +use mio::net::{TcpListener, TcpStream}; +use mio::{Events, Interest, Poll, Token}; use socket2::{Domain, Socket, Type}; -use std::io::Read; -use std::net::TcpListener; -use std::net::{SocketAddr, TcpStream}; +use std::io::{self, Read}; +use std::net::SocketAddr; +// use std::net::TcpListener; +// use std::net::{SocketAddr, TcpStream}; use std::thread; use crate::tmtc::{ReceivesTc, TmPacketSource}; @@ -81,11 +84,35 @@ pub enum TcpTmtcError { /// Result of one connection attempt. Contains the client address if a connection was established, /// in addition to the number of telecommands and telemetry packets exchanged. -#[derive(Debug, Default)] -pub struct ConnectionResult { - pub addr: Option, +#[derive(Debug, PartialEq, Eq)] +pub enum ConnectionResult { + AcceptTimeout, + HandledConnections(u32), +} + +#[derive(Debug)] +pub struct HandledConnectionInfo { + pub addr: SocketAddr, pub num_received_tcs: u32, pub num_sent_tms: u32, + /// The generic TCP server can be stopped using an external signal. If this happened, this + /// boolean will be set to true. + pub stopped_by_signal: bool, +} + +impl HandledConnectionInfo { + pub fn new(addr: SocketAddr) -> Self { + Self { + addr, + num_received_tcs: 0, + num_sent_tms: 0, + stopped_by_signal: false, + } + } +} + +pub trait HandledConnectionHandler { + fn handled_connection(&mut self, info: HandledConnectionInfo); } /// Generic parser abstraction for an object which can parse for telecommands given a raw @@ -96,7 +123,7 @@ pub trait TcpTcParser { &mut self, tc_buffer: &mut [u8], tc_receiver: &mut (impl ReceivesTc + ?Sized), - conn_result: &mut ConnectionResult, + conn_result: &mut HandledConnectionInfo, current_write_idx: usize, next_write_idx: &mut usize, ) -> Result<(), TcpTmtcError>; @@ -111,7 +138,7 @@ pub trait TcpTmSender { &mut self, tm_buffer: &mut [u8], tm_source: &mut (impl TmPacketSource + ?Sized), - conn_result: &mut ConnectionResult, + conn_result: &mut HandledConnectionInfo, stream: &mut TcpStream, ) -> Result>; } @@ -134,32 +161,46 @@ pub trait TcpTmSender { /// /// 1. [TcpTmtcInCobsServer] to exchange TMTC wrapped inside the COBS framing protocol. pub struct TcpTmtcGenericServer< - TmError, - TcError, TmSource: TmPacketSource, TcReceiver: ReceivesTc, TmSender: TcpTmSender, TcParser: TcpTcParser, + HandledConnection: HandledConnectionHandler, + TmError, + TcError, > { + pub finished_handler: HandledConnection, pub(crate) listener: TcpListener, pub(crate) inner_loop_delay: Duration, pub(crate) tm_source: TmSource, pub(crate) tm_buffer: Vec, pub(crate) tc_receiver: TcReceiver, pub(crate) tc_buffer: Vec, - stop_signal: Option>, + poll: Poll, + events: Events, tc_handler: TcParser, tm_handler: TmSender, + stop_signal: Option>, } impl< - TmError: 'static, - TcError: 'static, TmSource: TmPacketSource, TcReceiver: ReceivesTc, TmSender: TcpTmSender, TcParser: TcpTcParser, - > TcpTmtcGenericServer + HandledConnection: HandledConnectionHandler, + TmError: 'static, + TcError: 'static, + > + TcpTmtcGenericServer< + TmSource, + TcReceiver, + TmSender, + TcParser, + HandledConnection, + TmError, + TcError, + > { /// Create a new generic TMTC server instance. /// @@ -173,32 +214,52 @@ impl< /// then sent back to the client. /// * `tc_receiver` - Any received telecommand which was decoded successfully will be forwarded /// to this TC receiver. + /// * `stop_signal` - Can be used to stop the server even if a connection is ongoing. pub fn new( cfg: ServerConfig, tc_parser: TcParser, tm_sender: TmSender, tm_source: TmSource, tc_receiver: TcReceiver, + finished_handler: HandledConnection, stop_signal: Option>, ) -> Result { // Create a TCP listener bound to two addresses. let socket = Socket::new(Domain::IPV4, Type::STREAM, None)?; + socket.set_reuse_address(cfg.reuse_addr)?; #[cfg(unix)] socket.set_reuse_port(cfg.reuse_port)?; + // MIO does not do this for us. We want the accept calls to be non-blocking. + socket.set_nonblocking(true)?; let addr = (cfg.addr).into(); socket.bind(&addr)?; socket.listen(128)?; + + // Create a poll instance. + let poll = Poll::new()?; + // Create storage for events. + let events = Events::with_capacity(10); + let listener: std::net::TcpListener = socket.into(); + let mut mio_listener = TcpListener::from_std(listener); + + // Start listening for incoming connections. + poll.registry() + .register(&mut mio_listener, Token(0), Interest::READABLE)?; + Ok(Self { tc_handler: tc_parser, tm_handler: tm_sender, - listener: socket.into(), + poll, + events, + listener: mio_listener, inner_loop_delay: cfg.inner_loop_delay, tm_source, tm_buffer: vec![0; cfg.tm_buffer_size], tc_receiver, tc_buffer: vec![0; cfg.tc_buffer_size], stop_signal, + finished_handler, }) } @@ -228,13 +289,50 @@ impl< /// client does not send any telecommands and no telemetry needs to be sent back to the client. pub fn handle_next_connection( &mut self, + poll_timeout: Option, ) -> Result> { - let mut connection_result = ConnectionResult::default(); + let mut handled_connections = 0; + // Poll Mio for events. + self.poll.poll(&mut self.events, poll_timeout)?; + let mut acceptable_connection = false; + // Process each event. + for event in self.events.iter() { + if event.token() == Token(0) { + acceptable_connection = true; + } else { + // Should never happen.. + panic!("unexpected TCP event token"); + } + } + // I'd love to do this in the loop above, but there are issues with multiple borrows. + if acceptable_connection { + // There might be mutliple connections available. Accept until all of them have + // been handled. + loop { + match self.listener.accept() { + Ok((stream, addr)) => { + self.handle_accepted_connection(stream, addr)?; + handled_connections += 1; + } + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => break, + Err(err) => return Err(TcpTmtcError::Io(err)), + } + } + } + if handled_connections > 0 { + return Ok(ConnectionResult::HandledConnections(handled_connections)); + } + Ok(ConnectionResult::AcceptTimeout) + } + + fn handle_accepted_connection( + &mut self, + mut stream: TcpStream, + addr: SocketAddr, + ) -> Result<(), TcpTmtcError> { let mut current_write_idx; let mut next_write_idx = 0; - let (mut stream, addr) = self.listener.accept()?; - stream.set_nonblocking(true)?; - connection_result.addr = Some(addr); + let mut connection_result = HandledConnectionInfo::new(addr); current_write_idx = next_write_idx; loop { let read_result = stream.read(&mut self.tc_buffer[current_write_idx..]); @@ -297,7 +395,9 @@ impl< .unwrap() .load(std::sync::atomic::Ordering::Relaxed) { - return Ok(connection_result); + connection_result.stopped_by_signal = true; + self.finished_handler.handled_connection(connection_result); + return Ok(()); } } } @@ -313,7 +413,8 @@ impl< &mut connection_result, &mut stream, )?; - Ok(connection_result) + self.finished_handler.handled_connection(connection_result); + Ok(()) } } @@ -325,6 +426,8 @@ pub(crate) mod tests { use crate::tmtc::{ReceivesTcCore, TmPacketSourceCore}; + use super::*; + #[derive(Default, Clone)] pub(crate) struct SyncTcCacher { pub(crate) tc_queue: Arc>>>, @@ -371,4 +474,30 @@ pub(crate) mod tests { Ok(0) } } + + #[derive(Default)] + pub struct ConnectionFinishedHandler { + connection_info: VecDeque, + } + + impl HandledConnectionHandler for ConnectionFinishedHandler { + fn handled_connection(&mut self, info: HandledConnectionInfo) { + self.connection_info.push_back(info); + } + } + + impl ConnectionFinishedHandler { + pub fn check_last_connection(&mut self, num_tms: u32, num_tcs: u32) { + let last_conn_result = self + .connection_info + .pop_back() + .expect("no connection info available"); + assert_eq!(last_conn_result.num_received_tcs, num_tcs); + assert_eq!(last_conn_result.num_sent_tms, num_tms); + } + + pub fn check_no_connections_left(&self) { + assert!(self.connection_info.is_empty()); + } + } } diff --git a/satrs/src/hal/std/tcp_spacepackets_server.rs b/satrs/src/hal/std/tcp_spacepackets_server.rs index a1abba9..45f5fd8 100644 --- a/satrs/src/hal/std/tcp_spacepackets_server.rs +++ b/satrs/src/hal/std/tcp_spacepackets_server.rs @@ -1,10 +1,8 @@ use alloc::sync::Arc; -use core::sync::atomic::AtomicBool; +use core::{sync::atomic::AtomicBool, time::Duration}; use delegate::delegate; -use std::{ - io::Write, - net::{SocketAddr, TcpListener, TcpStream}, -}; +use mio::net::{TcpListener, TcpStream}; +use std::{io::Write, net::SocketAddr}; use crate::{ encoding::parse_buffer_for_ccsds_space_packets, @@ -13,7 +11,8 @@ use crate::{ }; use super::tcp_server::{ - ConnectionResult, ServerConfig, TcpTcParser, TcpTmSender, TcpTmtcError, TcpTmtcGenericServer, + ConnectionResult, HandledConnectionHandler, HandledConnectionInfo, ServerConfig, TcpTcParser, + TcpTmSender, TcpTmtcError, TcpTmtcGenericServer, }; /// Concrete [TcpTcParser] implementation for the [TcpSpacepacketsServer]. @@ -27,14 +26,14 @@ impl SpacepacketsTcParser { } } -impl TcpTcParser +impl TcpTcParser for SpacepacketsTcParser { fn handle_tc_parsing( &mut self, tc_buffer: &mut [u8], tc_receiver: &mut (impl ReceivesTc + ?Sized), - conn_result: &mut ConnectionResult, + conn_result: &mut HandledConnectionInfo, current_write_idx: usize, next_write_idx: &mut usize, ) -> Result<(), TcpTmtcError> { @@ -59,7 +58,7 @@ impl TcpTmSender for SpacepacketsTmSender { &mut self, tm_buffer: &mut [u8], tm_source: &mut (impl TmPacketSource + ?Sized), - conn_result: &mut ConnectionResult, + conn_result: &mut HandledConnectionInfo, stream: &mut TcpStream, ) -> Result> { let mut tm_was_sent = false; @@ -94,29 +93,40 @@ impl TcpTmSender for SpacepacketsTmSender { /// The [TCP server integration tests](https://egit.irs.uni-stuttgart.de/rust/sat-rs/src/branch/main/satrs/tests/tcp_servers.rs) /// also serves as the example application for this module. pub struct TcpSpacepacketsServer< - TmError, - TcError: 'static, TmSource: TmPacketSource, TcReceiver: ReceivesTc, PacketIdChecker: ValidatorU16Id, + HandledConnection: HandledConnectionHandler, + TmError, + TcError: 'static, > { - generic_server: TcpTmtcGenericServer< - TmError, - TcError, + pub generic_server: TcpTmtcGenericServer< TmSource, TcReceiver, SpacepacketsTmSender, SpacepacketsTcParser, + HandledConnection, + TmError, + TcError, >, } impl< - TmError: 'static, - TcError: 'static, TmSource: TmPacketSource, TcReceiver: ReceivesTc, PacketIdChecker: ValidatorU16Id, - > TcpSpacepacketsServer + HandledConnection: HandledConnectionHandler, + TmError: 'static, + TcError: 'static, + > + TcpSpacepacketsServer< + TmSource, + TcReceiver, + PacketIdChecker, + HandledConnection, + TmError, + TcError, + > { /// /// ## Parameter @@ -133,6 +143,7 @@ impl< tm_source: TmSource, tc_receiver: TcReceiver, packet_id_checker: PacketIdChecker, + handled_connection: HandledConnection, stop_signal: Option>, ) -> Result { Ok(Self { @@ -142,6 +153,7 @@ impl< SpacepacketsTmSender::default(), tm_source, tc_receiver, + handled_connection, stop_signal, )?, }) @@ -158,6 +170,7 @@ impl< /// Delegation to the [TcpTmtcGenericServer::handle_next_connection] call. pub fn handle_next_connection( &mut self, + poll_timeout: Option ) -> Result>; } } @@ -185,8 +198,8 @@ mod tests { }; use crate::hal::std::tcp_server::{ - tests::{SyncTcCacher, SyncTmSource}, - ServerConfig, + tests::{ConnectionFinishedHandler, SyncTcCacher, SyncTmSource}, + ConnectionResult, ServerConfig, }; use super::TcpSpacepacketsServer; @@ -202,12 +215,20 @@ mod tests { tm_source: SyncTmSource, packet_id_lookup: HashSet, stop_signal: Option>, - ) -> TcpSpacepacketsServer<(), (), SyncTmSource, SyncTcCacher, HashSet> { + ) -> TcpSpacepacketsServer< + SyncTmSource, + SyncTcCacher, + HashSet, + ConnectionFinishedHandler, + (), + (), + > { TcpSpacepacketsServer::new( ServerConfig::new(*addr, Duration::from_millis(2), 1024, 1024), tm_source, tc_receiver, packet_id_lookup, + ConnectionFinishedHandler::default(), stop_signal, ) .expect("TCP server generation failed") @@ -234,13 +255,20 @@ mod tests { let set_if_done = conn_handled.clone(); // Call the connection handler in separate thread, does block. thread::spawn(move || { - let result = tcp_server.handle_next_connection(); + let result = tcp_server.handle_next_connection(Some(Duration::from_millis(100))); if result.is_err() { panic!("handling connection failed: {:?}", result.unwrap_err()); } let conn_result = result.unwrap(); - assert_eq!(conn_result.num_received_tcs, 1); - assert_eq!(conn_result.num_sent_tms, 0); + matches!(conn_result, ConnectionResult::HandledConnections(1)); + tcp_server + .generic_server + .finished_handler + .check_last_connection(0, 1); + tcp_server + .generic_server + .finished_handler + .check_no_connections_left(); set_if_done.store(true, Ordering::Relaxed); }); let ping_tc = @@ -305,16 +333,20 @@ mod tests { // Call the connection handler in separate thread, does block. thread::spawn(move || { - let result = tcp_server.handle_next_connection(); + let result = tcp_server.handle_next_connection(Some(Duration::from_millis(100))); if result.is_err() { panic!("handling connection failed: {:?}", result.unwrap_err()); } let conn_result = result.unwrap(); - assert_eq!( - conn_result.num_received_tcs, 2, - "wrong number of received TCs" - ); - assert_eq!(conn_result.num_sent_tms, 2, "wrong number of sent TMs"); + matches!(conn_result, ConnectionResult::HandledConnections(1)); + tcp_server + .generic_server + .finished_handler + .check_last_connection(2, 2); + tcp_server + .generic_server + .finished_handler + .check_no_connections_left(); set_if_done.store(true, Ordering::Relaxed); }); let mut stream = TcpStream::connect(dest_addr).expect("connecting to TCP server failed"); diff --git a/satrs/tests/tcp_servers.rs b/satrs/tests/tcp_servers.rs index 02e3824..65d124d 100644 --- a/satrs/tests/tcp_servers.rs +++ b/satrs/tests/tcp_servers.rs @@ -24,7 +24,10 @@ use std::{ use hashbrown::HashSet; use satrs::{ encoding::cobs::encode_packet_with_cobs, - hal::std::tcp_server::{ServerConfig, TcpSpacepacketsServer, TcpTmtcInCobsServer}, + hal::std::tcp_server::{ + ConnectionResult, HandledConnectionHandler, HandledConnectionInfo, ServerConfig, + TcpSpacepacketsServer, TcpTmtcInCobsServer, + }, tmtc::{ReceivesTcCore, TmPacketSourceCore}, }; use spacepackets::{ @@ -33,10 +36,36 @@ use spacepackets::{ }; use std::{collections::VecDeque, sync::Arc, vec::Vec}; +#[derive(Default)] +pub struct ConnectionFinishedHandler { + connection_info: VecDeque, +} + +impl HandledConnectionHandler for ConnectionFinishedHandler { + fn handled_connection(&mut self, info: HandledConnectionInfo) { + self.connection_info.push_back(info); + } +} + +impl ConnectionFinishedHandler { + pub fn check_last_connection(&mut self, num_tms: u32, num_tcs: u32) { + let last_conn_result = self + .connection_info + .pop_back() + .expect("no connection info available"); + assert_eq!(last_conn_result.num_received_tcs, num_tcs); + assert_eq!(last_conn_result.num_sent_tms, num_tms); + } + + pub fn check_no_connections_left(&self) { + assert!(self.connection_info.is_empty()); + } +} #[derive(Default, Clone)] struct SyncTcCacher { tc_queue: Arc>>>, } + impl ReceivesTcCore for SyncTcCacher { type Error = (); @@ -96,6 +125,7 @@ fn test_cobs_server() { ServerConfig::new(AUTO_PORT_ADDR, Duration::from_millis(2), 1024, 1024), tm_source, tc_receiver.clone(), + ConnectionFinishedHandler::default(), None, ) .expect("TCP server generation failed"); @@ -107,13 +137,20 @@ fn test_cobs_server() { // Call the connection handler in separate thread, does block. thread::spawn(move || { - let result = tcp_server.handle_next_connection(); + let result = tcp_server.handle_next_connection(Some(Duration::from_millis(400))); if result.is_err() { panic!("handling connection failed: {:?}", result.unwrap_err()); } let conn_result = result.unwrap(); - assert_eq!(conn_result.num_received_tcs, 1, "No TC received"); - assert_eq!(conn_result.num_sent_tms, 1, "No TM received"); + assert_eq!(conn_result, ConnectionResult::HandledConnections(1)); + tcp_server + .generic_server + .finished_handler + .check_last_connection(1, 1); + tcp_server + .generic_server + .finished_handler + .check_no_connections_left(); // Signal the main thread we are done. set_if_done.store(true, Ordering::Relaxed); }); @@ -180,6 +217,7 @@ fn test_ccsds_server() { tm_source, tc_receiver.clone(), packet_id_lookup, + ConnectionFinishedHandler::default(), None, ) .expect("TCP server generation failed"); @@ -190,13 +228,20 @@ fn test_ccsds_server() { let set_if_done = conn_handled.clone(); // Call the connection handler in separate thread, does block. thread::spawn(move || { - let result = tcp_server.handle_next_connection(); + let result = tcp_server.handle_next_connection(Some(Duration::from_millis(500))); if result.is_err() { panic!("handling connection failed: {:?}", result.unwrap_err()); } let conn_result = result.unwrap(); - assert_eq!(conn_result.num_received_tcs, 1); - assert_eq!(conn_result.num_sent_tms, 1); + assert_eq!(conn_result, ConnectionResult::HandledConnections(1)); + tcp_server + .generic_server + .finished_handler + .check_last_connection(1, 1); + tcp_server + .generic_server + .finished_handler + .check_no_connections_left(); set_if_done.store(true, Ordering::Relaxed); }); let mut stream = TcpStream::connect(dest_addr).expect("connecting to TCP server failed");