diff --git a/satrs-example/src/tcp.rs b/satrs-example/src/tcp.rs index 7abcf5b..7db4c2c 100644 --- a/satrs-example/src/tcp.rs +++ b/satrs-example/src/tcp.rs @@ -5,7 +5,7 @@ use std::{ use log::{info, warn}; use satrs::{ - hal::std::tcp_server::{ServerConfig, TcpSpacepacketsServer}, + hal::std::tcp_server::{HandledConnectionHandler, ServerConfig, TcpSpacepacketsServer}, pus::ReceivesEcssPusTc, spacepackets::PacketId, tmtc::{CcsdsDistributor, CcsdsError, ReceivesCcsdsTc, TmPacketSourceCore}, @@ -13,6 +13,18 @@ use satrs::{ use crate::ccsds::CcsdsReceiver; +#[derive(Default)] +pub struct ConnectionFinishedHandler {} + +impl HandledConnectionHandler for ConnectionFinishedHandler { + fn handled_connection(&mut self, info: satrs::hal::std::tcp_server::HandledConnectionInfo) { + info!( + "Served {} TMs and {} TCs for client {:?}", + info.num_sent_tms, info.num_received_tcs, info.addr + ); + } +} + #[derive(Default, Clone)] pub struct SyncTcpTmSource { tm_queue: Arc>>>, @@ -70,11 +82,12 @@ impl TmPacketSourceCore for SyncTcpTmSource { } pub type TcpServerType = TcpSpacepacketsServer< - (), - CcsdsError, SyncTcpTmSource, CcsdsDistributor, MpscErrorType>, HashSet, + ConnectionFinishedHandler, + (), + CcsdsError, >; pub struct TcpTask< @@ -109,6 +122,7 @@ impl< tm_source, tc_receiver, packet_id_lookup, + ConnectionFinishedHandler::default(), None, )?, }) @@ -116,14 +130,9 @@ impl< pub fn periodic_operation(&mut self) { loop { - let result = self.server.handle_next_connection(); + let result = self.server.handle_next_connection(None); match result { - Ok(conn_result) => { - info!( - "Served {} TMs and {} TCs for client {:?}", - conn_result.num_sent_tms, conn_result.num_received_tcs, conn_result.addr - ); - } + Ok(_conn_result) => (), Err(e) => { warn!("TCP server error: {e:?}"); } diff --git a/satrs/CHANGELOG.md b/satrs/CHANGELOG.md index e119322..508c69b 100644 --- a/satrs/CHANGELOG.md +++ b/satrs/CHANGELOG.md @@ -22,9 +22,13 @@ and this project adheres to [Semantic Versioning](http://semver.org/). parameter type. This message also contains the sender ID which can be useful for debugging or application layer / FDIR logic. - Stop signal handling for the TCP servers. +- TCP server now uses `mio` crate to allow non-blocking operation. The server can now handle + multiple connections at once, and the context information about handled transfers is + passed via a callback which is inserted as a generic as well. ## Changed +- TCP server generics order. The error generics come last now. - `encoding::ccsds::PacketIdValidator` renamed to `ValidatorU16Id`, which lives in the crate root. It can be used for both CCSDS packet ID and CCSDS APID validation. - `EventManager::try_event_handling` not expects a mutable error handling closure instead of diff --git a/satrs/src/hal/std/tcp_cobs_server.rs b/satrs/src/hal/std/tcp_cobs_server.rs index 076de55..8b044f6 100644 --- a/satrs/src/hal/std/tcp_cobs_server.rs +++ b/satrs/src/hal/std/tcp_cobs_server.rs @@ -2,11 +2,11 @@ use alloc::sync::Arc; use alloc::vec; use cobs::encode; use core::sync::atomic::AtomicBool; +use core::time::Duration; use delegate::delegate; +use mio::net::{TcpListener, TcpStream}; use std::io::Write; use std::net::SocketAddr; -use std::net::TcpListener; -use std::net::TcpStream; use std::vec::Vec; use crate::encoding::parse_buffer_for_cobs_encoded_packets; @@ -17,6 +17,9 @@ use crate::hal::std::tcp_server::{ ConnectionResult, ServerConfig, TcpTcParser, TcpTmSender, TcpTmtcError, TcpTmtcGenericServer, }; +use super::tcp_server::HandledConnectionHandler; +use super::tcp_server::HandledConnectionInfo; + /// Concrete [TcpTcParser] implementation for the [TcpTmtcInCobsServer]. #[derive(Default)] pub struct CobsTcParser {} @@ -26,7 +29,7 @@ impl TcpTcParser for CobsTcParser { &mut self, tc_buffer: &mut [u8], tc_receiver: &mut (impl ReceivesTc + ?Sized), - conn_result: &mut ConnectionResult, + conn_result: &mut HandledConnectionInfo, current_write_idx: usize, next_write_idx: &mut usize, ) -> Result<(), TcpTmtcError> { @@ -60,7 +63,7 @@ impl TcpTmSender for CobsTmSender { &mut self, tm_buffer: &mut [u8], tm_source: &mut (impl TmPacketSource + ?Sized), - conn_result: &mut ConnectionResult, + conn_result: &mut HandledConnectionInfo, stream: &mut TcpStream, ) -> Result> { let mut tm_was_sent = false; @@ -112,21 +115,30 @@ impl TcpTmSender for CobsTmSender { /// The [TCP integration tests](https://egit.irs.uni-stuttgart.de/rust/sat-rs/src/branch/main/satrs/tests/tcp_servers.rs) /// test also serves as the example application for this module. pub struct TcpTmtcInCobsServer< - TmError, - TcError: 'static, TmSource: TmPacketSource, TcReceiver: ReceivesTc, + HandledConnection: HandledConnectionHandler, + TmError, + TcError: 'static, > { - generic_server: - TcpTmtcGenericServer, + pub generic_server: TcpTmtcGenericServer< + TmSource, + TcReceiver, + CobsTmSender, + CobsTcParser, + HandledConnection, + TmError, + TcError, + >, } impl< - TmError: 'static, - TcError: 'static, TmSource: TmPacketSource, TcReceiver: ReceivesTc, - > TcpTmtcInCobsServer + HandledConnection: HandledConnectionHandler, + TmError: 'static, + TcError: 'static, + > TcpTmtcInCobsServer { /// Create a new TCP TMTC server which exchanges TMTC packets encoded with /// [COBS protocol](https://en.wikipedia.org/wiki/Consistent_Overhead_Byte_Stuffing). @@ -142,6 +154,7 @@ impl< cfg: ServerConfig, tm_source: TmSource, tc_receiver: TcReceiver, + handled_connection: HandledConnection, stop_signal: Option>, ) -> Result { Ok(Self { @@ -151,6 +164,7 @@ impl< CobsTmSender::new(cfg.tm_buffer_size), tm_source, tc_receiver, + handled_connection, stop_signal, )?, }) @@ -167,6 +181,7 @@ impl< /// Delegation to the [TcpTmtcGenericServer::handle_next_connection] call. pub fn handle_next_connection( &mut self, + poll_duration: Option, ) -> Result>; } } @@ -181,15 +196,15 @@ mod tests { use std::{ io::{Read, Write}, net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream}, - panic, println, thread, + panic, thread, time::Instant, }; use crate::{ encoding::tests::{INVERTED_PACKET, SIMPLE_PACKET}, hal::std::tcp_server::{ - tests::{SyncTcCacher, SyncTmSource}, - ServerConfig, + tests::{ConnectionFinishedHandler, SyncTcCacher, SyncTmSource}, + ConnectionResult, ServerConfig, }, }; use alloc::sync::Arc; @@ -218,11 +233,12 @@ mod tests { tc_receiver: SyncTcCacher, tm_source: SyncTmSource, stop_signal: Option>, - ) -> TcpTmtcInCobsServer<(), (), SyncTmSource, SyncTcCacher> { + ) -> TcpTmtcInCobsServer { TcpTmtcInCobsServer::new( ServerConfig::new(*addr, Duration::from_millis(2), 1024, 1024), tm_source, tc_receiver, + ConnectionFinishedHandler::default(), stop_signal, ) .expect("TCP server generation failed") @@ -242,13 +258,20 @@ mod tests { let set_if_done = conn_handled.clone(); // Call the connection handler in separate thread, does block. thread::spawn(move || { - let result = tcp_server.handle_next_connection(); + let result = tcp_server.handle_next_connection(Some(Duration::from_millis(100))); if result.is_err() { panic!("handling connection failed: {:?}", result.unwrap_err()); } - let conn_result = result.unwrap(); - assert_eq!(conn_result.num_received_tcs, 1); - assert_eq!(conn_result.num_sent_tms, 0); + let result = result.unwrap(); + assert_eq!(result, ConnectionResult::HandledConnections(1)); + tcp_server + .generic_server + .finished_handler + .check_last_connection(0, 1); + tcp_server + .generic_server + .finished_handler + .check_no_connections_left(); set_if_done.store(true, Ordering::Relaxed); }); // Send TC to server now. @@ -299,13 +322,20 @@ mod tests { let set_if_done = conn_handled.clone(); // Call the connection handler in separate thread, does block. thread::spawn(move || { - let result = tcp_server.handle_next_connection(); + let result = tcp_server.handle_next_connection(Some(Duration::from_millis(100))); if result.is_err() { panic!("handling connection failed: {:?}", result.unwrap_err()); } - let conn_result = result.unwrap(); - assert_eq!(conn_result.num_received_tcs, 2, "Not enough TCs received"); - assert_eq!(conn_result.num_sent_tms, 2, "Not enough TMs received"); + let result = result.unwrap(); + assert_eq!(result, ConnectionResult::HandledConnections(1)); + tcp_server + .generic_server + .finished_handler + .check_last_connection(2, 2); + tcp_server + .generic_server + .finished_handler + .check_no_connections_left(); set_if_done.store(true, Ordering::Relaxed); }); // Send TC to server now. @@ -389,6 +419,31 @@ mod tests { drop(tc_queue); } + #[test] + fn test_server_accept_timeout() { + let auto_port_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0); + let tc_receiver = SyncTcCacher::default(); + let tm_source = SyncTmSource::default(); + let mut tcp_server = + generic_tmtc_server(&auto_port_addr, tc_receiver.clone(), tm_source, None); + let start = Instant::now(); + // Call the connection handler in separate thread, does block. + let thread_jh = thread::spawn(move || loop { + let result = tcp_server.handle_next_connection(Some(Duration::from_millis(20))); + if result.is_err() { + panic!("handling connection failed: {:?}", result.unwrap_err()); + } + let result = result.unwrap(); + if result == ConnectionResult::AcceptTimeout { + break; + } + if Instant::now() - start > Duration::from_millis(100) { + panic!("regular stop signal handling failed"); + } + }); + thread_jh.join().expect("thread join failed"); + } + #[test] fn test_server_stop_signal() { let auto_port_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0); @@ -401,26 +456,32 @@ mod tests { tm_source, Some(stop_signal.clone()), ); + let dest_addr = tcp_server + .local_addr() + .expect("retrieving dest addr failed"); + let stop_signal_copy = stop_signal.clone(); let start = Instant::now(); // Call the connection handler in separate thread, does block. let thread_jh = thread::spawn(move || loop { - println!("hello wtf!!!!"); - let result = tcp_server.handle_next_connection(); + let result = tcp_server.handle_next_connection(Some(Duration::from_millis(20))); if result.is_err() { panic!("handling connection failed: {:?}", result.unwrap_err()); } - println!("helluuuu"); let result = result.unwrap(); - if result.stopped_by_signal { + if result == ConnectionResult::AcceptTimeout { + panic!("unexpected accept timeout"); + } + if stop_signal_copy.load(Ordering::Relaxed) { break; } - if Instant::now() - start > Duration::from_millis(50) { + if Instant::now() - start > Duration::from_millis(100) { panic!("regular stop signal handling failed"); } }); + // We connect but do not do anything. + let _stream = TcpStream::connect(dest_addr).expect("connecting to TCP server failed"); stop_signal.store(true, Ordering::Relaxed); - thread::sleep(Duration::from_millis(100)); - panic!("shit") - //thread_jh.join().expect("thread join failed"); + // No need to drop the connection, the stop signal should take take of everything. + thread_jh.join().expect("thread join failed"); } } diff --git a/satrs/src/hal/std/tcp_server.rs b/satrs/src/hal/std/tcp_server.rs index b81f6c8..9b05f43 100644 --- a/satrs/src/hal/std/tcp_server.rs +++ b/satrs/src/hal/std/tcp_server.rs @@ -7,7 +7,7 @@ use core::time::Duration; use mio::net::{TcpListener, TcpStream}; use mio::{Events, Interest, Poll, Token}; use socket2::{Domain, Socket, Type}; -use std::io::Read; +use std::io::{self, Read}; use std::net::SocketAddr; // use std::net::TcpListener; // use std::net::{SocketAddr, TcpStream}; @@ -84,21 +84,15 @@ pub enum TcpTmtcError { /// Result of one connection attempt. Contains the client address if a connection was established, /// in addition to the number of telecommands and telemetry packets exchanged. -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq)] pub enum ConnectionResult { AcceptTimeout, - HandledConnection(HandledConnectionInfo), + HandledConnections(u32), } -impl From for ConnectionResult { - fn from(info: HandledConnectionInfo) -> Self { - ConnectionResult::HandledConnection(info) - } -} - -#[derive(Debug, Default)] +#[derive(Debug)] pub struct HandledConnectionInfo { - pub addr: Option, + pub addr: SocketAddr, pub num_received_tcs: u32, pub num_sent_tms: u32, /// The generic TCP server can be stopped using an external signal. If this happened, this @@ -106,6 +100,17 @@ pub struct HandledConnectionInfo { pub stopped_by_signal: bool, } +impl HandledConnectionInfo { + pub fn new(addr: SocketAddr) -> Self { + Self { + addr, + num_received_tcs: 0, + num_sent_tms: 0, + stopped_by_signal: false, + } + } +} + pub trait HandledConnectionHandler { fn handled_connection(&mut self, info: HandledConnectionInfo); } @@ -156,14 +161,15 @@ pub trait TcpTmSender { /// /// 1. [TcpTmtcInCobsServer] to exchange TMTC wrapped inside the COBS framing protocol. pub struct TcpTmtcGenericServer< - TmError, - TcError, TmSource: TmPacketSource, TcReceiver: ReceivesTc, TmSender: TcpTmSender, TcParser: TcpTcParser, - //HandledConnection: HandledConnectionHandler + HandledConnection: HandledConnectionHandler, + TmError, + TcError, > { + pub finished_handler: HandledConnection, pub(crate) listener: TcpListener, pub(crate) inner_loop_delay: Duration, pub(crate) tm_source: TmSource, @@ -172,19 +178,29 @@ pub struct TcpTmtcGenericServer< pub(crate) tc_buffer: Vec, poll: Poll, events: Events, - stop_signal: Option>, tc_handler: TcParser, tm_handler: TmSender, + stop_signal: Option>, } impl< - TmError: 'static, - TcError: 'static, TmSource: TmPacketSource, TcReceiver: ReceivesTc, TmSender: TcpTmSender, TcParser: TcpTcParser, - > TcpTmtcGenericServer + HandledConnection: HandledConnectionHandler, + TmError: 'static, + TcError: 'static, + > + TcpTmtcGenericServer< + TmSource, + TcReceiver, + TmSender, + TcParser, + HandledConnection, + TmError, + TcError, + > { /// Create a new generic TMTC server instance. /// @@ -198,12 +214,14 @@ impl< /// then sent back to the client. /// * `tc_receiver` - Any received telecommand which was decoded successfully will be forwarded /// to this TC receiver. + /// * `stop_signal` - Can be used to stop the server even if a connection is ongoing. pub fn new( cfg: ServerConfig, tc_parser: TcParser, tm_sender: TmSender, tm_source: TmSource, tc_receiver: TcReceiver, + finished_handler: HandledConnection, stop_signal: Option>, ) -> Result { // Create a TCP listener bound to two addresses. @@ -212,14 +230,16 @@ impl< socket.set_reuse_address(cfg.reuse_addr)?; #[cfg(unix)] socket.set_reuse_port(cfg.reuse_port)?; + // MIO does not do this for us. We want the accept calls to be non-blocking. + socket.set_nonblocking(true)?; let addr = (cfg.addr).into(); socket.bind(&addr)?; socket.listen(128)?; // Create a poll instance. - let mut poll = Poll::new()?; + let poll = Poll::new()?; // Create storage for events. - let mut events = Events::with_capacity(10); + let events = Events::with_capacity(10); let listener: std::net::TcpListener = socket.into(); let mut mio_listener = TcpListener::from_std(listener); @@ -239,6 +259,7 @@ impl< tc_receiver, tc_buffer: vec![0; cfg.tc_buffer_size], stop_signal, + finished_handler, }) } @@ -268,19 +289,38 @@ impl< /// client does not send any telecommands and no telemetry needs to be sent back to the client. pub fn handle_next_connection( &mut self, + poll_timeout: Option, ) -> Result> { - // Poll Mio for events, blocking until we get an event. - self.poll - .poll(&mut self.events, Some(Duration::from_millis(400)))?; + let mut handled_connections = 0; + // Poll Mio for events. + self.poll.poll(&mut self.events, poll_timeout)?; + let mut acceptable_connection = false; // Process each event. - if let Some(event) = self.events.iter().next() { + for event in self.events.iter() { if event.token() == Token(0) { - let connection = self.listener.accept()?; - return self - .handle_accepted_connection(connection.0, connection.1) - .map(|v| v.into()); + acceptable_connection = true; + } else { + // Should never happen.. + panic!("unexpected TCP event token"); } - panic!("unexpected TCP event token"); + } + // I'd love to do this in the loop above, but there are issues with multiple borrows. + if acceptable_connection { + // There might be mutliple connections available. Accept until all of them have + // been handled. + loop { + match self.listener.accept() { + Ok((stream, addr)) => { + self.handle_accepted_connection(stream, addr)?; + handled_connections += 1; + } + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => break, + Err(err) => return Err(TcpTmtcError::Io(err)), + } + } + } + if handled_connections > 0 { + return Ok(ConnectionResult::HandledConnections(handled_connections)); } Ok(ConnectionResult::AcceptTimeout) } @@ -289,13 +329,10 @@ impl< &mut self, mut stream: TcpStream, addr: SocketAddr, - ) -> Result> { + ) -> Result<(), TcpTmtcError> { let mut current_write_idx; let mut next_write_idx = 0; - let mut connection_result = HandledConnectionInfo::default(); - // stream.set_nonblocking(true)?; - connection_result.addr = Some(addr); - connection_result.stopped_by_signal = false; + let mut connection_result = HandledConnectionInfo::new(addr); current_write_idx = next_write_idx; loop { let read_result = stream.read(&mut self.tc_buffer[current_write_idx..]); @@ -359,7 +396,8 @@ impl< .load(std::sync::atomic::Ordering::Relaxed) { connection_result.stopped_by_signal = true; - return Ok(connection_result); + self.finished_handler.handled_connection(connection_result); + return Ok(()); } } } @@ -375,7 +413,8 @@ impl< &mut connection_result, &mut stream, )?; - Ok(connection_result) + self.finished_handler.handled_connection(connection_result); + Ok(()) } } @@ -387,6 +426,8 @@ pub(crate) mod tests { use crate::tmtc::{ReceivesTcCore, TmPacketSourceCore}; + use super::*; + #[derive(Default, Clone)] pub(crate) struct SyncTcCacher { pub(crate) tc_queue: Arc>>>, @@ -433,4 +474,30 @@ pub(crate) mod tests { Ok(0) } } + + #[derive(Default)] + pub struct ConnectionFinishedHandler { + connection_info: VecDeque, + } + + impl HandledConnectionHandler for ConnectionFinishedHandler { + fn handled_connection(&mut self, info: HandledConnectionInfo) { + self.connection_info.push_back(info); + } + } + + impl ConnectionFinishedHandler { + pub fn check_last_connection(&mut self, num_tms: u32, num_tcs: u32) { + let last_conn_result = self + .connection_info + .pop_back() + .expect("no connection info available"); + assert_eq!(last_conn_result.num_received_tcs, num_tcs); + assert_eq!(last_conn_result.num_sent_tms, num_tms); + } + + pub fn check_no_connections_left(&self) { + assert!(self.connection_info.is_empty()); + } + } } diff --git a/satrs/src/hal/std/tcp_spacepackets_server.rs b/satrs/src/hal/std/tcp_spacepackets_server.rs index a1abba9..45f5fd8 100644 --- a/satrs/src/hal/std/tcp_spacepackets_server.rs +++ b/satrs/src/hal/std/tcp_spacepackets_server.rs @@ -1,10 +1,8 @@ use alloc::sync::Arc; -use core::sync::atomic::AtomicBool; +use core::{sync::atomic::AtomicBool, time::Duration}; use delegate::delegate; -use std::{ - io::Write, - net::{SocketAddr, TcpListener, TcpStream}, -}; +use mio::net::{TcpListener, TcpStream}; +use std::{io::Write, net::SocketAddr}; use crate::{ encoding::parse_buffer_for_ccsds_space_packets, @@ -13,7 +11,8 @@ use crate::{ }; use super::tcp_server::{ - ConnectionResult, ServerConfig, TcpTcParser, TcpTmSender, TcpTmtcError, TcpTmtcGenericServer, + ConnectionResult, HandledConnectionHandler, HandledConnectionInfo, ServerConfig, TcpTcParser, + TcpTmSender, TcpTmtcError, TcpTmtcGenericServer, }; /// Concrete [TcpTcParser] implementation for the [TcpSpacepacketsServer]. @@ -27,14 +26,14 @@ impl SpacepacketsTcParser { } } -impl TcpTcParser +impl TcpTcParser for SpacepacketsTcParser { fn handle_tc_parsing( &mut self, tc_buffer: &mut [u8], tc_receiver: &mut (impl ReceivesTc + ?Sized), - conn_result: &mut ConnectionResult, + conn_result: &mut HandledConnectionInfo, current_write_idx: usize, next_write_idx: &mut usize, ) -> Result<(), TcpTmtcError> { @@ -59,7 +58,7 @@ impl TcpTmSender for SpacepacketsTmSender { &mut self, tm_buffer: &mut [u8], tm_source: &mut (impl TmPacketSource + ?Sized), - conn_result: &mut ConnectionResult, + conn_result: &mut HandledConnectionInfo, stream: &mut TcpStream, ) -> Result> { let mut tm_was_sent = false; @@ -94,29 +93,40 @@ impl TcpTmSender for SpacepacketsTmSender { /// The [TCP server integration tests](https://egit.irs.uni-stuttgart.de/rust/sat-rs/src/branch/main/satrs/tests/tcp_servers.rs) /// also serves as the example application for this module. pub struct TcpSpacepacketsServer< - TmError, - TcError: 'static, TmSource: TmPacketSource, TcReceiver: ReceivesTc, PacketIdChecker: ValidatorU16Id, + HandledConnection: HandledConnectionHandler, + TmError, + TcError: 'static, > { - generic_server: TcpTmtcGenericServer< - TmError, - TcError, + pub generic_server: TcpTmtcGenericServer< TmSource, TcReceiver, SpacepacketsTmSender, SpacepacketsTcParser, + HandledConnection, + TmError, + TcError, >, } impl< - TmError: 'static, - TcError: 'static, TmSource: TmPacketSource, TcReceiver: ReceivesTc, PacketIdChecker: ValidatorU16Id, - > TcpSpacepacketsServer + HandledConnection: HandledConnectionHandler, + TmError: 'static, + TcError: 'static, + > + TcpSpacepacketsServer< + TmSource, + TcReceiver, + PacketIdChecker, + HandledConnection, + TmError, + TcError, + > { /// /// ## Parameter @@ -133,6 +143,7 @@ impl< tm_source: TmSource, tc_receiver: TcReceiver, packet_id_checker: PacketIdChecker, + handled_connection: HandledConnection, stop_signal: Option>, ) -> Result { Ok(Self { @@ -142,6 +153,7 @@ impl< SpacepacketsTmSender::default(), tm_source, tc_receiver, + handled_connection, stop_signal, )?, }) @@ -158,6 +170,7 @@ impl< /// Delegation to the [TcpTmtcGenericServer::handle_next_connection] call. pub fn handle_next_connection( &mut self, + poll_timeout: Option ) -> Result>; } } @@ -185,8 +198,8 @@ mod tests { }; use crate::hal::std::tcp_server::{ - tests::{SyncTcCacher, SyncTmSource}, - ServerConfig, + tests::{ConnectionFinishedHandler, SyncTcCacher, SyncTmSource}, + ConnectionResult, ServerConfig, }; use super::TcpSpacepacketsServer; @@ -202,12 +215,20 @@ mod tests { tm_source: SyncTmSource, packet_id_lookup: HashSet, stop_signal: Option>, - ) -> TcpSpacepacketsServer<(), (), SyncTmSource, SyncTcCacher, HashSet> { + ) -> TcpSpacepacketsServer< + SyncTmSource, + SyncTcCacher, + HashSet, + ConnectionFinishedHandler, + (), + (), + > { TcpSpacepacketsServer::new( ServerConfig::new(*addr, Duration::from_millis(2), 1024, 1024), tm_source, tc_receiver, packet_id_lookup, + ConnectionFinishedHandler::default(), stop_signal, ) .expect("TCP server generation failed") @@ -234,13 +255,20 @@ mod tests { let set_if_done = conn_handled.clone(); // Call the connection handler in separate thread, does block. thread::spawn(move || { - let result = tcp_server.handle_next_connection(); + let result = tcp_server.handle_next_connection(Some(Duration::from_millis(100))); if result.is_err() { panic!("handling connection failed: {:?}", result.unwrap_err()); } let conn_result = result.unwrap(); - assert_eq!(conn_result.num_received_tcs, 1); - assert_eq!(conn_result.num_sent_tms, 0); + matches!(conn_result, ConnectionResult::HandledConnections(1)); + tcp_server + .generic_server + .finished_handler + .check_last_connection(0, 1); + tcp_server + .generic_server + .finished_handler + .check_no_connections_left(); set_if_done.store(true, Ordering::Relaxed); }); let ping_tc = @@ -305,16 +333,20 @@ mod tests { // Call the connection handler in separate thread, does block. thread::spawn(move || { - let result = tcp_server.handle_next_connection(); + let result = tcp_server.handle_next_connection(Some(Duration::from_millis(100))); if result.is_err() { panic!("handling connection failed: {:?}", result.unwrap_err()); } let conn_result = result.unwrap(); - assert_eq!( - conn_result.num_received_tcs, 2, - "wrong number of received TCs" - ); - assert_eq!(conn_result.num_sent_tms, 2, "wrong number of sent TMs"); + matches!(conn_result, ConnectionResult::HandledConnections(1)); + tcp_server + .generic_server + .finished_handler + .check_last_connection(2, 2); + tcp_server + .generic_server + .finished_handler + .check_no_connections_left(); set_if_done.store(true, Ordering::Relaxed); }); let mut stream = TcpStream::connect(dest_addr).expect("connecting to TCP server failed"); diff --git a/satrs/tests/tcp_servers.rs b/satrs/tests/tcp_servers.rs index 02e3824..65d124d 100644 --- a/satrs/tests/tcp_servers.rs +++ b/satrs/tests/tcp_servers.rs @@ -24,7 +24,10 @@ use std::{ use hashbrown::HashSet; use satrs::{ encoding::cobs::encode_packet_with_cobs, - hal::std::tcp_server::{ServerConfig, TcpSpacepacketsServer, TcpTmtcInCobsServer}, + hal::std::tcp_server::{ + ConnectionResult, HandledConnectionHandler, HandledConnectionInfo, ServerConfig, + TcpSpacepacketsServer, TcpTmtcInCobsServer, + }, tmtc::{ReceivesTcCore, TmPacketSourceCore}, }; use spacepackets::{ @@ -33,10 +36,36 @@ use spacepackets::{ }; use std::{collections::VecDeque, sync::Arc, vec::Vec}; +#[derive(Default)] +pub struct ConnectionFinishedHandler { + connection_info: VecDeque, +} + +impl HandledConnectionHandler for ConnectionFinishedHandler { + fn handled_connection(&mut self, info: HandledConnectionInfo) { + self.connection_info.push_back(info); + } +} + +impl ConnectionFinishedHandler { + pub fn check_last_connection(&mut self, num_tms: u32, num_tcs: u32) { + let last_conn_result = self + .connection_info + .pop_back() + .expect("no connection info available"); + assert_eq!(last_conn_result.num_received_tcs, num_tcs); + assert_eq!(last_conn_result.num_sent_tms, num_tms); + } + + pub fn check_no_connections_left(&self) { + assert!(self.connection_info.is_empty()); + } +} #[derive(Default, Clone)] struct SyncTcCacher { tc_queue: Arc>>>, } + impl ReceivesTcCore for SyncTcCacher { type Error = (); @@ -96,6 +125,7 @@ fn test_cobs_server() { ServerConfig::new(AUTO_PORT_ADDR, Duration::from_millis(2), 1024, 1024), tm_source, tc_receiver.clone(), + ConnectionFinishedHandler::default(), None, ) .expect("TCP server generation failed"); @@ -107,13 +137,20 @@ fn test_cobs_server() { // Call the connection handler in separate thread, does block. thread::spawn(move || { - let result = tcp_server.handle_next_connection(); + let result = tcp_server.handle_next_connection(Some(Duration::from_millis(400))); if result.is_err() { panic!("handling connection failed: {:?}", result.unwrap_err()); } let conn_result = result.unwrap(); - assert_eq!(conn_result.num_received_tcs, 1, "No TC received"); - assert_eq!(conn_result.num_sent_tms, 1, "No TM received"); + assert_eq!(conn_result, ConnectionResult::HandledConnections(1)); + tcp_server + .generic_server + .finished_handler + .check_last_connection(1, 1); + tcp_server + .generic_server + .finished_handler + .check_no_connections_left(); // Signal the main thread we are done. set_if_done.store(true, Ordering::Relaxed); }); @@ -180,6 +217,7 @@ fn test_ccsds_server() { tm_source, tc_receiver.clone(), packet_id_lookup, + ConnectionFinishedHandler::default(), None, ) .expect("TCP server generation failed"); @@ -190,13 +228,20 @@ fn test_ccsds_server() { let set_if_done = conn_handled.clone(); // Call the connection handler in separate thread, does block. thread::spawn(move || { - let result = tcp_server.handle_next_connection(); + let result = tcp_server.handle_next_connection(Some(Duration::from_millis(500))); if result.is_err() { panic!("handling connection failed: {:?}", result.unwrap_err()); } let conn_result = result.unwrap(); - assert_eq!(conn_result.num_received_tcs, 1); - assert_eq!(conn_result.num_sent_tms, 1); + assert_eq!(conn_result, ConnectionResult::HandledConnections(1)); + tcp_server + .generic_server + .finished_handler + .check_last_connection(1, 1); + tcp_server + .generic_server + .finished_handler + .check_no_connections_left(); set_if_done.store(true, Ordering::Relaxed); }); let mut stream = TcpStream::connect(dest_addr).expect("connecting to TCP server failed");