Refactor and improve TCP servers #150

Merged
muellerr merged 3 commits from refactor-tcp-server into main 2024-04-10 12:29:24 +02:00
6 changed files with 331 additions and 113 deletions
Showing only changes of commit d27ac5dfc9 - Show all commits

View File

@ -5,7 +5,7 @@ use std::{
use log::{info, warn}; use log::{info, warn};
use satrs::{ use satrs::{
hal::std::tcp_server::{ServerConfig, TcpSpacepacketsServer}, hal::std::tcp_server::{HandledConnectionHandler, ServerConfig, TcpSpacepacketsServer},
pus::ReceivesEcssPusTc, pus::ReceivesEcssPusTc,
spacepackets::PacketId, spacepackets::PacketId,
tmtc::{CcsdsDistributor, CcsdsError, ReceivesCcsdsTc, TmPacketSourceCore}, tmtc::{CcsdsDistributor, CcsdsError, ReceivesCcsdsTc, TmPacketSourceCore},
@ -13,6 +13,18 @@ use satrs::{
use crate::ccsds::CcsdsReceiver; use crate::ccsds::CcsdsReceiver;
#[derive(Default)]
pub struct ConnectionFinishedHandler {}
impl HandledConnectionHandler for ConnectionFinishedHandler {
fn handled_connection(&mut self, info: satrs::hal::std::tcp_server::HandledConnectionInfo) {
info!(
"Served {} TMs and {} TCs for client {:?}",
info.num_sent_tms, info.num_received_tcs, info.addr
);
}
}
#[derive(Default, Clone)] #[derive(Default, Clone)]
pub struct SyncTcpTmSource { pub struct SyncTcpTmSource {
tm_queue: Arc<Mutex<VecDeque<Vec<u8>>>>, tm_queue: Arc<Mutex<VecDeque<Vec<u8>>>>,
@ -70,11 +82,12 @@ impl TmPacketSourceCore for SyncTcpTmSource {
} }
pub type TcpServerType<TcSource, MpscErrorType> = TcpSpacepacketsServer< pub type TcpServerType<TcSource, MpscErrorType> = TcpSpacepacketsServer<
(),
CcsdsError<MpscErrorType>,
SyncTcpTmSource, SyncTcpTmSource,
CcsdsDistributor<CcsdsReceiver<TcSource, MpscErrorType>, MpscErrorType>, CcsdsDistributor<CcsdsReceiver<TcSource, MpscErrorType>, MpscErrorType>,
HashSet<PacketId>, HashSet<PacketId>,
ConnectionFinishedHandler,
(),
CcsdsError<MpscErrorType>,
>; >;
pub struct TcpTask< pub struct TcpTask<
@ -109,6 +122,7 @@ impl<
tm_source, tm_source,
tc_receiver, tc_receiver,
packet_id_lookup, packet_id_lookup,
ConnectionFinishedHandler::default(),
None, None,
)?, )?,
}) })
@ -116,14 +130,9 @@ impl<
pub fn periodic_operation(&mut self) { pub fn periodic_operation(&mut self) {
loop { loop {
let result = self.server.handle_next_connection(); let result = self.server.handle_next_connection(None);
match result { match result {
Ok(conn_result) => { Ok(_conn_result) => (),
info!(
"Served {} TMs and {} TCs for client {:?}",
conn_result.num_sent_tms, conn_result.num_received_tcs, conn_result.addr
);
}
Err(e) => { Err(e) => {
warn!("TCP server error: {e:?}"); warn!("TCP server error: {e:?}");
} }

View File

@ -22,9 +22,13 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
parameter type. This message also contains the sender ID which can be useful for debugging parameter type. This message also contains the sender ID which can be useful for debugging
or application layer / FDIR logic. or application layer / FDIR logic.
- Stop signal handling for the TCP servers. - Stop signal handling for the TCP servers.
- TCP server now uses `mio` crate to allow non-blocking operation. The server can now handle
multiple connections at once, and the context information about handled transfers is
passed via a callback which is inserted as a generic as well.
## Changed ## Changed
- TCP server generics order. The error generics come last now.
- `encoding::ccsds::PacketIdValidator` renamed to `ValidatorU16Id`, which lives in the crate root. - `encoding::ccsds::PacketIdValidator` renamed to `ValidatorU16Id`, which lives in the crate root.
It can be used for both CCSDS packet ID and CCSDS APID validation. It can be used for both CCSDS packet ID and CCSDS APID validation.
- `EventManager::try_event_handling` not expects a mutable error handling closure instead of - `EventManager::try_event_handling` not expects a mutable error handling closure instead of

View File

@ -2,11 +2,11 @@ use alloc::sync::Arc;
use alloc::vec; use alloc::vec;
use cobs::encode; use cobs::encode;
use core::sync::atomic::AtomicBool; use core::sync::atomic::AtomicBool;
use core::time::Duration;
use delegate::delegate; use delegate::delegate;
use mio::net::{TcpListener, TcpStream};
use std::io::Write; use std::io::Write;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::net::TcpListener;
use std::net::TcpStream;
use std::vec::Vec; use std::vec::Vec;
use crate::encoding::parse_buffer_for_cobs_encoded_packets; use crate::encoding::parse_buffer_for_cobs_encoded_packets;
@ -17,6 +17,9 @@ use crate::hal::std::tcp_server::{
ConnectionResult, ServerConfig, TcpTcParser, TcpTmSender, TcpTmtcError, TcpTmtcGenericServer, ConnectionResult, ServerConfig, TcpTcParser, TcpTmSender, TcpTmtcError, TcpTmtcGenericServer,
}; };
use super::tcp_server::HandledConnectionHandler;
use super::tcp_server::HandledConnectionInfo;
/// Concrete [TcpTcParser] implementation for the [TcpTmtcInCobsServer]. /// Concrete [TcpTcParser] implementation for the [TcpTmtcInCobsServer].
#[derive(Default)] #[derive(Default)]
pub struct CobsTcParser {} pub struct CobsTcParser {}
@ -26,7 +29,7 @@ impl<TmError, TcError: 'static> TcpTcParser<TmError, TcError> for CobsTcParser {
&mut self, &mut self,
tc_buffer: &mut [u8], tc_buffer: &mut [u8],
tc_receiver: &mut (impl ReceivesTc<Error = TcError> + ?Sized), tc_receiver: &mut (impl ReceivesTc<Error = TcError> + ?Sized),
conn_result: &mut ConnectionResult, conn_result: &mut HandledConnectionInfo,
current_write_idx: usize, current_write_idx: usize,
next_write_idx: &mut usize, next_write_idx: &mut usize,
) -> Result<(), TcpTmtcError<TmError, TcError>> { ) -> Result<(), TcpTmtcError<TmError, TcError>> {
@ -60,7 +63,7 @@ impl<TmError, TcError> TcpTmSender<TmError, TcError> for CobsTmSender {
&mut self, &mut self,
tm_buffer: &mut [u8], tm_buffer: &mut [u8],
tm_source: &mut (impl TmPacketSource<Error = TmError> + ?Sized), tm_source: &mut (impl TmPacketSource<Error = TmError> + ?Sized),
conn_result: &mut ConnectionResult, conn_result: &mut HandledConnectionInfo,
stream: &mut TcpStream, stream: &mut TcpStream,
) -> Result<bool, TcpTmtcError<TmError, TcError>> { ) -> Result<bool, TcpTmtcError<TmError, TcError>> {
let mut tm_was_sent = false; let mut tm_was_sent = false;
@ -112,21 +115,30 @@ impl<TmError, TcError> TcpTmSender<TmError, TcError> for CobsTmSender {
/// The [TCP integration tests](https://egit.irs.uni-stuttgart.de/rust/sat-rs/src/branch/main/satrs/tests/tcp_servers.rs) /// The [TCP integration tests](https://egit.irs.uni-stuttgart.de/rust/sat-rs/src/branch/main/satrs/tests/tcp_servers.rs)
/// test also serves as the example application for this module. /// test also serves as the example application for this module.
pub struct TcpTmtcInCobsServer< pub struct TcpTmtcInCobsServer<
TmError,
TcError: 'static,
TmSource: TmPacketSource<Error = TmError>, TmSource: TmPacketSource<Error = TmError>,
TcReceiver: ReceivesTc<Error = TcError>, TcReceiver: ReceivesTc<Error = TcError>,
HandledConnection: HandledConnectionHandler,
TmError,
TcError: 'static,
> { > {
generic_server: pub generic_server: TcpTmtcGenericServer<
TcpTmtcGenericServer<TmError, TcError, TmSource, TcReceiver, CobsTmSender, CobsTcParser>, TmSource,
TcReceiver,
CobsTmSender,
CobsTcParser,
HandledConnection,
TmError,
TcError,
>,
} }
impl< impl<
TmError: 'static,
TcError: 'static,
TmSource: TmPacketSource<Error = TmError>, TmSource: TmPacketSource<Error = TmError>,
TcReceiver: ReceivesTc<Error = TcError>, TcReceiver: ReceivesTc<Error = TcError>,
> TcpTmtcInCobsServer<TmError, TcError, TmSource, TcReceiver> HandledConnection: HandledConnectionHandler,
TmError: 'static,
TcError: 'static,
> TcpTmtcInCobsServer<TmSource, TcReceiver, HandledConnection, TmError, TcError>
{ {
/// Create a new TCP TMTC server which exchanges TMTC packets encoded with /// Create a new TCP TMTC server which exchanges TMTC packets encoded with
/// [COBS protocol](https://en.wikipedia.org/wiki/Consistent_Overhead_Byte_Stuffing). /// [COBS protocol](https://en.wikipedia.org/wiki/Consistent_Overhead_Byte_Stuffing).
@ -142,6 +154,7 @@ impl<
cfg: ServerConfig, cfg: ServerConfig,
tm_source: TmSource, tm_source: TmSource,
tc_receiver: TcReceiver, tc_receiver: TcReceiver,
handled_connection: HandledConnection,
stop_signal: Option<Arc<AtomicBool>>, stop_signal: Option<Arc<AtomicBool>>,
) -> Result<Self, std::io::Error> { ) -> Result<Self, std::io::Error> {
Ok(Self { Ok(Self {
@ -151,6 +164,7 @@ impl<
CobsTmSender::new(cfg.tm_buffer_size), CobsTmSender::new(cfg.tm_buffer_size),
tm_source, tm_source,
tc_receiver, tc_receiver,
handled_connection,
stop_signal, stop_signal,
)?, )?,
}) })
@ -167,6 +181,7 @@ impl<
/// Delegation to the [TcpTmtcGenericServer::handle_next_connection] call. /// Delegation to the [TcpTmtcGenericServer::handle_next_connection] call.
pub fn handle_next_connection( pub fn handle_next_connection(
&mut self, &mut self,
poll_duration: Option<Duration>,
) -> Result<ConnectionResult, TcpTmtcError<TmError, TcError>>; ) -> Result<ConnectionResult, TcpTmtcError<TmError, TcError>>;
} }
} }
@ -181,15 +196,15 @@ mod tests {
use std::{ use std::{
io::{Read, Write}, io::{Read, Write},
net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream}, net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream},
panic, println, thread, panic, thread,
time::Instant, time::Instant,
}; };
use crate::{ use crate::{
encoding::tests::{INVERTED_PACKET, SIMPLE_PACKET}, encoding::tests::{INVERTED_PACKET, SIMPLE_PACKET},
hal::std::tcp_server::{ hal::std::tcp_server::{
tests::{SyncTcCacher, SyncTmSource}, tests::{ConnectionFinishedHandler, SyncTcCacher, SyncTmSource},
ServerConfig, ConnectionResult, ServerConfig,
}, },
}; };
use alloc::sync::Arc; use alloc::sync::Arc;
@ -218,11 +233,12 @@ mod tests {
tc_receiver: SyncTcCacher, tc_receiver: SyncTcCacher,
tm_source: SyncTmSource, tm_source: SyncTmSource,
stop_signal: Option<Arc<AtomicBool>>, stop_signal: Option<Arc<AtomicBool>>,
) -> TcpTmtcInCobsServer<(), (), SyncTmSource, SyncTcCacher> { ) -> TcpTmtcInCobsServer<SyncTmSource, SyncTcCacher, ConnectionFinishedHandler, (), ()> {
TcpTmtcInCobsServer::new( TcpTmtcInCobsServer::new(
ServerConfig::new(*addr, Duration::from_millis(2), 1024, 1024), ServerConfig::new(*addr, Duration::from_millis(2), 1024, 1024),
tm_source, tm_source,
tc_receiver, tc_receiver,
ConnectionFinishedHandler::default(),
stop_signal, stop_signal,
) )
.expect("TCP server generation failed") .expect("TCP server generation failed")
@ -242,13 +258,20 @@ mod tests {
let set_if_done = conn_handled.clone(); let set_if_done = conn_handled.clone();
// Call the connection handler in separate thread, does block. // Call the connection handler in separate thread, does block.
thread::spawn(move || { thread::spawn(move || {
let result = tcp_server.handle_next_connection(); let result = tcp_server.handle_next_connection(Some(Duration::from_millis(100)));
if result.is_err() { if result.is_err() {
panic!("handling connection failed: {:?}", result.unwrap_err()); panic!("handling connection failed: {:?}", result.unwrap_err());
} }
let conn_result = result.unwrap(); let result = result.unwrap();
assert_eq!(conn_result.num_received_tcs, 1); assert_eq!(result, ConnectionResult::HandledConnections(1));
assert_eq!(conn_result.num_sent_tms, 0); tcp_server
.generic_server
.finished_handler
.check_last_connection(0, 1);
tcp_server
.generic_server
.finished_handler
.check_no_connections_left();
set_if_done.store(true, Ordering::Relaxed); set_if_done.store(true, Ordering::Relaxed);
}); });
// Send TC to server now. // Send TC to server now.
@ -299,13 +322,20 @@ mod tests {
let set_if_done = conn_handled.clone(); let set_if_done = conn_handled.clone();
// Call the connection handler in separate thread, does block. // Call the connection handler in separate thread, does block.
thread::spawn(move || { thread::spawn(move || {
let result = tcp_server.handle_next_connection(); let result = tcp_server.handle_next_connection(Some(Duration::from_millis(100)));
if result.is_err() { if result.is_err() {
panic!("handling connection failed: {:?}", result.unwrap_err()); panic!("handling connection failed: {:?}", result.unwrap_err());
} }
let conn_result = result.unwrap(); let result = result.unwrap();
assert_eq!(conn_result.num_received_tcs, 2, "Not enough TCs received"); assert_eq!(result, ConnectionResult::HandledConnections(1));
assert_eq!(conn_result.num_sent_tms, 2, "Not enough TMs received"); tcp_server
.generic_server
.finished_handler
.check_last_connection(2, 2);
tcp_server
.generic_server
.finished_handler
.check_no_connections_left();
set_if_done.store(true, Ordering::Relaxed); set_if_done.store(true, Ordering::Relaxed);
}); });
// Send TC to server now. // Send TC to server now.
@ -389,6 +419,31 @@ mod tests {
drop(tc_queue); drop(tc_queue);
} }
#[test]
fn test_server_accept_timeout() {
let auto_port_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0);
let tc_receiver = SyncTcCacher::default();
let tm_source = SyncTmSource::default();
let mut tcp_server =
generic_tmtc_server(&auto_port_addr, tc_receiver.clone(), tm_source, None);
let start = Instant::now();
// Call the connection handler in separate thread, does block.
let thread_jh = thread::spawn(move || loop {
let result = tcp_server.handle_next_connection(Some(Duration::from_millis(20)));
if result.is_err() {
panic!("handling connection failed: {:?}", result.unwrap_err());
}
let result = result.unwrap();
if result == ConnectionResult::AcceptTimeout {
break;
}
if Instant::now() - start > Duration::from_millis(100) {
panic!("regular stop signal handling failed");
}
});
thread_jh.join().expect("thread join failed");
}
#[test] #[test]
fn test_server_stop_signal() { fn test_server_stop_signal() {
let auto_port_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0); let auto_port_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0);
@ -401,26 +456,32 @@ mod tests {
tm_source, tm_source,
Some(stop_signal.clone()), Some(stop_signal.clone()),
); );
let dest_addr = tcp_server
.local_addr()
.expect("retrieving dest addr failed");
let stop_signal_copy = stop_signal.clone();
let start = Instant::now(); let start = Instant::now();
// Call the connection handler in separate thread, does block. // Call the connection handler in separate thread, does block.
let thread_jh = thread::spawn(move || loop { let thread_jh = thread::spawn(move || loop {
println!("hello wtf!!!!"); let result = tcp_server.handle_next_connection(Some(Duration::from_millis(20)));
let result = tcp_server.handle_next_connection();
if result.is_err() { if result.is_err() {
panic!("handling connection failed: {:?}", result.unwrap_err()); panic!("handling connection failed: {:?}", result.unwrap_err());
} }
println!("helluuuu");
let result = result.unwrap(); let result = result.unwrap();
if result.stopped_by_signal { if result == ConnectionResult::AcceptTimeout {
panic!("unexpected accept timeout");
}
if stop_signal_copy.load(Ordering::Relaxed) {
break; break;
} }
if Instant::now() - start > Duration::from_millis(50) { if Instant::now() - start > Duration::from_millis(100) {
panic!("regular stop signal handling failed"); panic!("regular stop signal handling failed");
} }
}); });
// We connect but do not do anything.
let _stream = TcpStream::connect(dest_addr).expect("connecting to TCP server failed");
stop_signal.store(true, Ordering::Relaxed); stop_signal.store(true, Ordering::Relaxed);
thread::sleep(Duration::from_millis(100)); // No need to drop the connection, the stop signal should take take of everything.
panic!("shit") thread_jh.join().expect("thread join failed");
//thread_jh.join().expect("thread join failed");
} }
} }

View File

@ -7,7 +7,7 @@ use core::time::Duration;
use mio::net::{TcpListener, TcpStream}; use mio::net::{TcpListener, TcpStream};
use mio::{Events, Interest, Poll, Token}; use mio::{Events, Interest, Poll, Token};
use socket2::{Domain, Socket, Type}; use socket2::{Domain, Socket, Type};
use std::io::Read; use std::io::{self, Read};
use std::net::SocketAddr; use std::net::SocketAddr;
// use std::net::TcpListener; // use std::net::TcpListener;
// use std::net::{SocketAddr, TcpStream}; // use std::net::{SocketAddr, TcpStream};
@ -84,21 +84,15 @@ pub enum TcpTmtcError<TmError, TcError> {
/// Result of one connection attempt. Contains the client address if a connection was established, /// Result of one connection attempt. Contains the client address if a connection was established,
/// in addition to the number of telecommands and telemetry packets exchanged. /// in addition to the number of telecommands and telemetry packets exchanged.
#[derive(Debug)] #[derive(Debug, PartialEq, Eq)]
pub enum ConnectionResult { pub enum ConnectionResult {
AcceptTimeout, AcceptTimeout,
HandledConnection(HandledConnectionInfo), HandledConnections(u32),
} }
impl From<HandledConnectionInfo> for ConnectionResult { #[derive(Debug)]
fn from(info: HandledConnectionInfo) -> Self {
ConnectionResult::HandledConnection(info)
}
}
#[derive(Debug, Default)]
pub struct HandledConnectionInfo { pub struct HandledConnectionInfo {
pub addr: Option<SocketAddr>, pub addr: SocketAddr,
pub num_received_tcs: u32, pub num_received_tcs: u32,
pub num_sent_tms: u32, pub num_sent_tms: u32,
/// The generic TCP server can be stopped using an external signal. If this happened, this /// The generic TCP server can be stopped using an external signal. If this happened, this
@ -106,6 +100,17 @@ pub struct HandledConnectionInfo {
pub stopped_by_signal: bool, pub stopped_by_signal: bool,
} }
impl HandledConnectionInfo {
pub fn new(addr: SocketAddr) -> Self {
Self {
addr,
num_received_tcs: 0,
num_sent_tms: 0,
stopped_by_signal: false,
}
}
}
pub trait HandledConnectionHandler { pub trait HandledConnectionHandler {
fn handled_connection(&mut self, info: HandledConnectionInfo); fn handled_connection(&mut self, info: HandledConnectionInfo);
} }
@ -156,14 +161,15 @@ pub trait TcpTmSender<TmError, TcError> {
/// ///
/// 1. [TcpTmtcInCobsServer] to exchange TMTC wrapped inside the COBS framing protocol. /// 1. [TcpTmtcInCobsServer] to exchange TMTC wrapped inside the COBS framing protocol.
pub struct TcpTmtcGenericServer< pub struct TcpTmtcGenericServer<
TmError,
TcError,
TmSource: TmPacketSource<Error = TmError>, TmSource: TmPacketSource<Error = TmError>,
TcReceiver: ReceivesTc<Error = TcError>, TcReceiver: ReceivesTc<Error = TcError>,
TmSender: TcpTmSender<TmError, TcError>, TmSender: TcpTmSender<TmError, TcError>,
TcParser: TcpTcParser<TmError, TcError>, TcParser: TcpTcParser<TmError, TcError>,
//HandledConnection: HandledConnectionHandler HandledConnection: HandledConnectionHandler,
TmError,
TcError,
> { > {
pub finished_handler: HandledConnection,
pub(crate) listener: TcpListener, pub(crate) listener: TcpListener,
pub(crate) inner_loop_delay: Duration, pub(crate) inner_loop_delay: Duration,
pub(crate) tm_source: TmSource, pub(crate) tm_source: TmSource,
@ -172,19 +178,29 @@ pub struct TcpTmtcGenericServer<
pub(crate) tc_buffer: Vec<u8>, pub(crate) tc_buffer: Vec<u8>,
poll: Poll, poll: Poll,
events: Events, events: Events,
stop_signal: Option<Arc<AtomicBool>>,
tc_handler: TcParser, tc_handler: TcParser,
tm_handler: TmSender, tm_handler: TmSender,
stop_signal: Option<Arc<AtomicBool>>,
} }
impl< impl<
TmError: 'static,
TcError: 'static,
TmSource: TmPacketSource<Error = TmError>, TmSource: TmPacketSource<Error = TmError>,
TcReceiver: ReceivesTc<Error = TcError>, TcReceiver: ReceivesTc<Error = TcError>,
TmSender: TcpTmSender<TmError, TcError>, TmSender: TcpTmSender<TmError, TcError>,
TcParser: TcpTcParser<TmError, TcError>, TcParser: TcpTcParser<TmError, TcError>,
> TcpTmtcGenericServer<TmError, TcError, TmSource, TcReceiver, TmSender, TcParser> HandledConnection: HandledConnectionHandler,
TmError: 'static,
TcError: 'static,
>
TcpTmtcGenericServer<
TmSource,
TcReceiver,
TmSender,
TcParser,
HandledConnection,
TmError,
TcError,
>
{ {
/// Create a new generic TMTC server instance. /// Create a new generic TMTC server instance.
/// ///
@ -198,12 +214,14 @@ impl<
/// then sent back to the client. /// then sent back to the client.
/// * `tc_receiver` - Any received telecommand which was decoded successfully will be forwarded /// * `tc_receiver` - Any received telecommand which was decoded successfully will be forwarded
/// to this TC receiver. /// to this TC receiver.
/// * `stop_signal` - Can be used to stop the server even if a connection is ongoing.
pub fn new( pub fn new(
cfg: ServerConfig, cfg: ServerConfig,
tc_parser: TcParser, tc_parser: TcParser,
tm_sender: TmSender, tm_sender: TmSender,
tm_source: TmSource, tm_source: TmSource,
tc_receiver: TcReceiver, tc_receiver: TcReceiver,
finished_handler: HandledConnection,
stop_signal: Option<Arc<AtomicBool>>, stop_signal: Option<Arc<AtomicBool>>,
) -> Result<Self, std::io::Error> { ) -> Result<Self, std::io::Error> {
// Create a TCP listener bound to two addresses. // Create a TCP listener bound to two addresses.
@ -212,14 +230,16 @@ impl<
socket.set_reuse_address(cfg.reuse_addr)?; socket.set_reuse_address(cfg.reuse_addr)?;
#[cfg(unix)] #[cfg(unix)]
socket.set_reuse_port(cfg.reuse_port)?; socket.set_reuse_port(cfg.reuse_port)?;
// MIO does not do this for us. We want the accept calls to be non-blocking.
socket.set_nonblocking(true)?;
let addr = (cfg.addr).into(); let addr = (cfg.addr).into();
socket.bind(&addr)?; socket.bind(&addr)?;
socket.listen(128)?; socket.listen(128)?;
// Create a poll instance. // Create a poll instance.
let mut poll = Poll::new()?; let poll = Poll::new()?;
// Create storage for events. // Create storage for events.
let mut events = Events::with_capacity(10); let events = Events::with_capacity(10);
let listener: std::net::TcpListener = socket.into(); let listener: std::net::TcpListener = socket.into();
let mut mio_listener = TcpListener::from_std(listener); let mut mio_listener = TcpListener::from_std(listener);
@ -239,6 +259,7 @@ impl<
tc_receiver, tc_receiver,
tc_buffer: vec![0; cfg.tc_buffer_size], tc_buffer: vec![0; cfg.tc_buffer_size],
stop_signal, stop_signal,
finished_handler,
}) })
} }
@ -268,20 +289,39 @@ impl<
/// client does not send any telecommands and no telemetry needs to be sent back to the client. /// client does not send any telecommands and no telemetry needs to be sent back to the client.
pub fn handle_next_connection( pub fn handle_next_connection(
&mut self, &mut self,
poll_timeout: Option<Duration>,
) -> Result<ConnectionResult, TcpTmtcError<TmError, TcError>> { ) -> Result<ConnectionResult, TcpTmtcError<TmError, TcError>> {
// Poll Mio for events, blocking until we get an event. let mut handled_connections = 0;
self.poll // Poll Mio for events.
.poll(&mut self.events, Some(Duration::from_millis(400)))?; self.poll.poll(&mut self.events, poll_timeout)?;
let mut acceptable_connection = false;
// Process each event. // Process each event.
if let Some(event) = self.events.iter().next() { for event in self.events.iter() {
if event.token() == Token(0) { if event.token() == Token(0) {
let connection = self.listener.accept()?; acceptable_connection = true;
return self } else {
.handle_accepted_connection(connection.0, connection.1) // Should never happen..
.map(|v| v.into());
}
panic!("unexpected TCP event token"); panic!("unexpected TCP event token");
} }
}
// I'd love to do this in the loop above, but there are issues with multiple borrows.
if acceptable_connection {
// There might be mutliple connections available. Accept until all of them have
// been handled.
loop {
match self.listener.accept() {
Ok((stream, addr)) => {
self.handle_accepted_connection(stream, addr)?;
handled_connections += 1;
}
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => break,
Err(err) => return Err(TcpTmtcError::Io(err)),
}
}
}
if handled_connections > 0 {
return Ok(ConnectionResult::HandledConnections(handled_connections));
}
Ok(ConnectionResult::AcceptTimeout) Ok(ConnectionResult::AcceptTimeout)
} }
@ -289,13 +329,10 @@ impl<
&mut self, &mut self,
mut stream: TcpStream, mut stream: TcpStream,
addr: SocketAddr, addr: SocketAddr,
) -> Result<HandledConnectionInfo, TcpTmtcError<TmError, TcError>> { ) -> Result<(), TcpTmtcError<TmError, TcError>> {
let mut current_write_idx; let mut current_write_idx;
let mut next_write_idx = 0; let mut next_write_idx = 0;
let mut connection_result = HandledConnectionInfo::default(); let mut connection_result = HandledConnectionInfo::new(addr);
// stream.set_nonblocking(true)?;
connection_result.addr = Some(addr);
connection_result.stopped_by_signal = false;
current_write_idx = next_write_idx; current_write_idx = next_write_idx;
loop { loop {
let read_result = stream.read(&mut self.tc_buffer[current_write_idx..]); let read_result = stream.read(&mut self.tc_buffer[current_write_idx..]);
@ -359,7 +396,8 @@ impl<
.load(std::sync::atomic::Ordering::Relaxed) .load(std::sync::atomic::Ordering::Relaxed)
{ {
connection_result.stopped_by_signal = true; connection_result.stopped_by_signal = true;
return Ok(connection_result); self.finished_handler.handled_connection(connection_result);
return Ok(());
} }
} }
} }
@ -375,7 +413,8 @@ impl<
&mut connection_result, &mut connection_result,
&mut stream, &mut stream,
)?; )?;
Ok(connection_result) self.finished_handler.handled_connection(connection_result);
Ok(())
} }
} }
@ -387,6 +426,8 @@ pub(crate) mod tests {
use crate::tmtc::{ReceivesTcCore, TmPacketSourceCore}; use crate::tmtc::{ReceivesTcCore, TmPacketSourceCore};
use super::*;
#[derive(Default, Clone)] #[derive(Default, Clone)]
pub(crate) struct SyncTcCacher { pub(crate) struct SyncTcCacher {
pub(crate) tc_queue: Arc<Mutex<VecDeque<Vec<u8>>>>, pub(crate) tc_queue: Arc<Mutex<VecDeque<Vec<u8>>>>,
@ -433,4 +474,30 @@ pub(crate) mod tests {
Ok(0) Ok(0)
} }
} }
#[derive(Default)]
pub struct ConnectionFinishedHandler {
connection_info: VecDeque<HandledConnectionInfo>,
}
impl HandledConnectionHandler for ConnectionFinishedHandler {
fn handled_connection(&mut self, info: HandledConnectionInfo) {
self.connection_info.push_back(info);
}
}
impl ConnectionFinishedHandler {
pub fn check_last_connection(&mut self, num_tms: u32, num_tcs: u32) {
let last_conn_result = self
.connection_info
.pop_back()
.expect("no connection info available");
assert_eq!(last_conn_result.num_received_tcs, num_tcs);
assert_eq!(last_conn_result.num_sent_tms, num_tms);
}
pub fn check_no_connections_left(&self) {
assert!(self.connection_info.is_empty());
}
}
} }

View File

@ -1,10 +1,8 @@
use alloc::sync::Arc; use alloc::sync::Arc;
use core::sync::atomic::AtomicBool; use core::{sync::atomic::AtomicBool, time::Duration};
use delegate::delegate; use delegate::delegate;
use std::{ use mio::net::{TcpListener, TcpStream};
io::Write, use std::{io::Write, net::SocketAddr};
net::{SocketAddr, TcpListener, TcpStream},
};
use crate::{ use crate::{
encoding::parse_buffer_for_ccsds_space_packets, encoding::parse_buffer_for_ccsds_space_packets,
@ -13,7 +11,8 @@ use crate::{
}; };
use super::tcp_server::{ use super::tcp_server::{
ConnectionResult, ServerConfig, TcpTcParser, TcpTmSender, TcpTmtcError, TcpTmtcGenericServer, ConnectionResult, HandledConnectionHandler, HandledConnectionInfo, ServerConfig, TcpTcParser,
TcpTmSender, TcpTmtcError, TcpTmtcGenericServer,
}; };
/// Concrete [TcpTcParser] implementation for the [TcpSpacepacketsServer]. /// Concrete [TcpTcParser] implementation for the [TcpSpacepacketsServer].
@ -27,14 +26,14 @@ impl<PacketIdChecker: ValidatorU16Id> SpacepacketsTcParser<PacketIdChecker> {
} }
} }
impl<TmError, TcError: 'static, PacketIdChecker: ValidatorU16Id> TcpTcParser<TmError, TcError> impl<PacketIdChecker: ValidatorU16Id, TmError, TcError: 'static> TcpTcParser<TmError, TcError>
for SpacepacketsTcParser<PacketIdChecker> for SpacepacketsTcParser<PacketIdChecker>
{ {
fn handle_tc_parsing( fn handle_tc_parsing(
&mut self, &mut self,
tc_buffer: &mut [u8], tc_buffer: &mut [u8],
tc_receiver: &mut (impl ReceivesTc<Error = TcError> + ?Sized), tc_receiver: &mut (impl ReceivesTc<Error = TcError> + ?Sized),
conn_result: &mut ConnectionResult, conn_result: &mut HandledConnectionInfo,
current_write_idx: usize, current_write_idx: usize,
next_write_idx: &mut usize, next_write_idx: &mut usize,
) -> Result<(), TcpTmtcError<TmError, TcError>> { ) -> Result<(), TcpTmtcError<TmError, TcError>> {
@ -59,7 +58,7 @@ impl<TmError, TcError> TcpTmSender<TmError, TcError> for SpacepacketsTmSender {
&mut self, &mut self,
tm_buffer: &mut [u8], tm_buffer: &mut [u8],
tm_source: &mut (impl TmPacketSource<Error = TmError> + ?Sized), tm_source: &mut (impl TmPacketSource<Error = TmError> + ?Sized),
conn_result: &mut ConnectionResult, conn_result: &mut HandledConnectionInfo,
stream: &mut TcpStream, stream: &mut TcpStream,
) -> Result<bool, TcpTmtcError<TmError, TcError>> { ) -> Result<bool, TcpTmtcError<TmError, TcError>> {
let mut tm_was_sent = false; let mut tm_was_sent = false;
@ -94,29 +93,40 @@ impl<TmError, TcError> TcpTmSender<TmError, TcError> for SpacepacketsTmSender {
/// The [TCP server integration tests](https://egit.irs.uni-stuttgart.de/rust/sat-rs/src/branch/main/satrs/tests/tcp_servers.rs) /// The [TCP server integration tests](https://egit.irs.uni-stuttgart.de/rust/sat-rs/src/branch/main/satrs/tests/tcp_servers.rs)
/// also serves as the example application for this module. /// also serves as the example application for this module.
pub struct TcpSpacepacketsServer< pub struct TcpSpacepacketsServer<
TmError,
TcError: 'static,
TmSource: TmPacketSource<Error = TmError>, TmSource: TmPacketSource<Error = TmError>,
TcReceiver: ReceivesTc<Error = TcError>, TcReceiver: ReceivesTc<Error = TcError>,
PacketIdChecker: ValidatorU16Id, PacketIdChecker: ValidatorU16Id,
> { HandledConnection: HandledConnectionHandler,
generic_server: TcpTmtcGenericServer<
TmError, TmError,
TcError, TcError: 'static,
> {
pub generic_server: TcpTmtcGenericServer<
TmSource, TmSource,
TcReceiver, TcReceiver,
SpacepacketsTmSender, SpacepacketsTmSender,
SpacepacketsTcParser<PacketIdChecker>, SpacepacketsTcParser<PacketIdChecker>,
HandledConnection,
TmError,
TcError,
>, >,
} }
impl< impl<
TmError: 'static,
TcError: 'static,
TmSource: TmPacketSource<Error = TmError>, TmSource: TmPacketSource<Error = TmError>,
TcReceiver: ReceivesTc<Error = TcError>, TcReceiver: ReceivesTc<Error = TcError>,
PacketIdChecker: ValidatorU16Id, PacketIdChecker: ValidatorU16Id,
> TcpSpacepacketsServer<TmError, TcError, TmSource, TcReceiver, PacketIdChecker> HandledConnection: HandledConnectionHandler,
TmError: 'static,
TcError: 'static,
>
TcpSpacepacketsServer<
TmSource,
TcReceiver,
PacketIdChecker,
HandledConnection,
TmError,
TcError,
>
{ {
/// ///
/// ## Parameter /// ## Parameter
@ -133,6 +143,7 @@ impl<
tm_source: TmSource, tm_source: TmSource,
tc_receiver: TcReceiver, tc_receiver: TcReceiver,
packet_id_checker: PacketIdChecker, packet_id_checker: PacketIdChecker,
handled_connection: HandledConnection,
stop_signal: Option<Arc<AtomicBool>>, stop_signal: Option<Arc<AtomicBool>>,
) -> Result<Self, std::io::Error> { ) -> Result<Self, std::io::Error> {
Ok(Self { Ok(Self {
@ -142,6 +153,7 @@ impl<
SpacepacketsTmSender::default(), SpacepacketsTmSender::default(),
tm_source, tm_source,
tc_receiver, tc_receiver,
handled_connection,
stop_signal, stop_signal,
)?, )?,
}) })
@ -158,6 +170,7 @@ impl<
/// Delegation to the [TcpTmtcGenericServer::handle_next_connection] call. /// Delegation to the [TcpTmtcGenericServer::handle_next_connection] call.
pub fn handle_next_connection( pub fn handle_next_connection(
&mut self, &mut self,
poll_timeout: Option<Duration>
) -> Result<ConnectionResult, TcpTmtcError<TmError, TcError>>; ) -> Result<ConnectionResult, TcpTmtcError<TmError, TcError>>;
} }
} }
@ -185,8 +198,8 @@ mod tests {
}; };
use crate::hal::std::tcp_server::{ use crate::hal::std::tcp_server::{
tests::{SyncTcCacher, SyncTmSource}, tests::{ConnectionFinishedHandler, SyncTcCacher, SyncTmSource},
ServerConfig, ConnectionResult, ServerConfig,
}; };
use super::TcpSpacepacketsServer; use super::TcpSpacepacketsServer;
@ -202,12 +215,20 @@ mod tests {
tm_source: SyncTmSource, tm_source: SyncTmSource,
packet_id_lookup: HashSet<PacketId>, packet_id_lookup: HashSet<PacketId>,
stop_signal: Option<Arc<AtomicBool>>, stop_signal: Option<Arc<AtomicBool>>,
) -> TcpSpacepacketsServer<(), (), SyncTmSource, SyncTcCacher, HashSet<PacketId>> { ) -> TcpSpacepacketsServer<
SyncTmSource,
SyncTcCacher,
HashSet<PacketId>,
ConnectionFinishedHandler,
(),
(),
> {
TcpSpacepacketsServer::new( TcpSpacepacketsServer::new(
ServerConfig::new(*addr, Duration::from_millis(2), 1024, 1024), ServerConfig::new(*addr, Duration::from_millis(2), 1024, 1024),
tm_source, tm_source,
tc_receiver, tc_receiver,
packet_id_lookup, packet_id_lookup,
ConnectionFinishedHandler::default(),
stop_signal, stop_signal,
) )
.expect("TCP server generation failed") .expect("TCP server generation failed")
@ -234,13 +255,20 @@ mod tests {
let set_if_done = conn_handled.clone(); let set_if_done = conn_handled.clone();
// Call the connection handler in separate thread, does block. // Call the connection handler in separate thread, does block.
thread::spawn(move || { thread::spawn(move || {
let result = tcp_server.handle_next_connection(); let result = tcp_server.handle_next_connection(Some(Duration::from_millis(100)));
if result.is_err() { if result.is_err() {
panic!("handling connection failed: {:?}", result.unwrap_err()); panic!("handling connection failed: {:?}", result.unwrap_err());
} }
let conn_result = result.unwrap(); let conn_result = result.unwrap();
assert_eq!(conn_result.num_received_tcs, 1); matches!(conn_result, ConnectionResult::HandledConnections(1));
assert_eq!(conn_result.num_sent_tms, 0); tcp_server
.generic_server
.finished_handler
.check_last_connection(0, 1);
tcp_server
.generic_server
.finished_handler
.check_no_connections_left();
set_if_done.store(true, Ordering::Relaxed); set_if_done.store(true, Ordering::Relaxed);
}); });
let ping_tc = let ping_tc =
@ -305,16 +333,20 @@ mod tests {
// Call the connection handler in separate thread, does block. // Call the connection handler in separate thread, does block.
thread::spawn(move || { thread::spawn(move || {
let result = tcp_server.handle_next_connection(); let result = tcp_server.handle_next_connection(Some(Duration::from_millis(100)));
if result.is_err() { if result.is_err() {
panic!("handling connection failed: {:?}", result.unwrap_err()); panic!("handling connection failed: {:?}", result.unwrap_err());
} }
let conn_result = result.unwrap(); let conn_result = result.unwrap();
assert_eq!( matches!(conn_result, ConnectionResult::HandledConnections(1));
conn_result.num_received_tcs, 2, tcp_server
"wrong number of received TCs" .generic_server
); .finished_handler
assert_eq!(conn_result.num_sent_tms, 2, "wrong number of sent TMs"); .check_last_connection(2, 2);
tcp_server
.generic_server
.finished_handler
.check_no_connections_left();
set_if_done.store(true, Ordering::Relaxed); set_if_done.store(true, Ordering::Relaxed);
}); });
let mut stream = TcpStream::connect(dest_addr).expect("connecting to TCP server failed"); let mut stream = TcpStream::connect(dest_addr).expect("connecting to TCP server failed");

View File

@ -24,7 +24,10 @@ use std::{
use hashbrown::HashSet; use hashbrown::HashSet;
use satrs::{ use satrs::{
encoding::cobs::encode_packet_with_cobs, encoding::cobs::encode_packet_with_cobs,
hal::std::tcp_server::{ServerConfig, TcpSpacepacketsServer, TcpTmtcInCobsServer}, hal::std::tcp_server::{
ConnectionResult, HandledConnectionHandler, HandledConnectionInfo, ServerConfig,
TcpSpacepacketsServer, TcpTmtcInCobsServer,
},
tmtc::{ReceivesTcCore, TmPacketSourceCore}, tmtc::{ReceivesTcCore, TmPacketSourceCore},
}; };
use spacepackets::{ use spacepackets::{
@ -33,10 +36,36 @@ use spacepackets::{
}; };
use std::{collections::VecDeque, sync::Arc, vec::Vec}; use std::{collections::VecDeque, sync::Arc, vec::Vec};
#[derive(Default)]
pub struct ConnectionFinishedHandler {
connection_info: VecDeque<HandledConnectionInfo>,
}
impl HandledConnectionHandler for ConnectionFinishedHandler {
fn handled_connection(&mut self, info: HandledConnectionInfo) {
self.connection_info.push_back(info);
}
}
impl ConnectionFinishedHandler {
pub fn check_last_connection(&mut self, num_tms: u32, num_tcs: u32) {
let last_conn_result = self
.connection_info
.pop_back()
.expect("no connection info available");
assert_eq!(last_conn_result.num_received_tcs, num_tcs);
assert_eq!(last_conn_result.num_sent_tms, num_tms);
}
pub fn check_no_connections_left(&self) {
assert!(self.connection_info.is_empty());
}
}
#[derive(Default, Clone)] #[derive(Default, Clone)]
struct SyncTcCacher { struct SyncTcCacher {
tc_queue: Arc<Mutex<VecDeque<Vec<u8>>>>, tc_queue: Arc<Mutex<VecDeque<Vec<u8>>>>,
} }
impl ReceivesTcCore for SyncTcCacher { impl ReceivesTcCore for SyncTcCacher {
type Error = (); type Error = ();
@ -96,6 +125,7 @@ fn test_cobs_server() {
ServerConfig::new(AUTO_PORT_ADDR, Duration::from_millis(2), 1024, 1024), ServerConfig::new(AUTO_PORT_ADDR, Duration::from_millis(2), 1024, 1024),
tm_source, tm_source,
tc_receiver.clone(), tc_receiver.clone(),
ConnectionFinishedHandler::default(),
None, None,
) )
.expect("TCP server generation failed"); .expect("TCP server generation failed");
@ -107,13 +137,20 @@ fn test_cobs_server() {
// Call the connection handler in separate thread, does block. // Call the connection handler in separate thread, does block.
thread::spawn(move || { thread::spawn(move || {
let result = tcp_server.handle_next_connection(); let result = tcp_server.handle_next_connection(Some(Duration::from_millis(400)));
if result.is_err() { if result.is_err() {
panic!("handling connection failed: {:?}", result.unwrap_err()); panic!("handling connection failed: {:?}", result.unwrap_err());
} }
let conn_result = result.unwrap(); let conn_result = result.unwrap();
assert_eq!(conn_result.num_received_tcs, 1, "No TC received"); assert_eq!(conn_result, ConnectionResult::HandledConnections(1));
assert_eq!(conn_result.num_sent_tms, 1, "No TM received"); tcp_server
.generic_server
.finished_handler
.check_last_connection(1, 1);
tcp_server
.generic_server
.finished_handler
.check_no_connections_left();
// Signal the main thread we are done. // Signal the main thread we are done.
set_if_done.store(true, Ordering::Relaxed); set_if_done.store(true, Ordering::Relaxed);
}); });
@ -180,6 +217,7 @@ fn test_ccsds_server() {
tm_source, tm_source,
tc_receiver.clone(), tc_receiver.clone(),
packet_id_lookup, packet_id_lookup,
ConnectionFinishedHandler::default(),
None, None,
) )
.expect("TCP server generation failed"); .expect("TCP server generation failed");
@ -190,13 +228,20 @@ fn test_ccsds_server() {
let set_if_done = conn_handled.clone(); let set_if_done = conn_handled.clone();
// Call the connection handler in separate thread, does block. // Call the connection handler in separate thread, does block.
thread::spawn(move || { thread::spawn(move || {
let result = tcp_server.handle_next_connection(); let result = tcp_server.handle_next_connection(Some(Duration::from_millis(500)));
if result.is_err() { if result.is_err() {
panic!("handling connection failed: {:?}", result.unwrap_err()); panic!("handling connection failed: {:?}", result.unwrap_err());
} }
let conn_result = result.unwrap(); let conn_result = result.unwrap();
assert_eq!(conn_result.num_received_tcs, 1); assert_eq!(conn_result, ConnectionResult::HandledConnections(1));
assert_eq!(conn_result.num_sent_tms, 1); tcp_server
.generic_server
.finished_handler
.check_last_connection(1, 1);
tcp_server
.generic_server
.finished_handler
.check_no_connections_left();
set_if_done.store(true, Ordering::Relaxed); set_if_done.store(true, Ordering::Relaxed);
}); });
let mut stream = TcpStream::connect(dest_addr).expect("connecting to TCP server failed"); let mut stream = TcpStream::connect(dest_addr).expect("connecting to TCP server failed");