Refactor and improve TCP servers #150

Merged
muellerr merged 3 commits from refactor-tcp-server into main 2024-04-10 12:29:24 +02:00
3 changed files with 89 additions and 12 deletions
Showing only changes of commit c67b7cb93a - Show all commits

View File

@ -70,6 +70,11 @@ version = "0.5.4"
features = ["all"]
optional = true
[dependencies.mio]
version = "0.8"
features = ["os-poll", "net"]
optional = true
[dependencies.spacepackets]
# git = "https://egit.irs.uni-stuttgart.de/rust/spacepackets.git"
version = "0.11.0-rc.2"
@ -104,7 +109,8 @@ std = [
"spacepackets/std",
"num_enum/std",
"thiserror",
"socket2"
"socket2",
"mio"
]
alloc = [
"serde/alloc",

View File

@ -181,7 +181,7 @@ mod tests {
use std::{
io::{Read, Write},
net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream},
thread,
panic, println, thread,
time::Instant,
};
@ -403,15 +403,24 @@ mod tests {
);
let start = Instant::now();
// Call the connection handler in separate thread, does block.
thread::spawn(move || loop {
let thread_jh = thread::spawn(move || loop {
println!("hello wtf!!!!");
let result = tcp_server.handle_next_connection();
if result.is_err() {
panic!("handling connection failed: {:?}", result.unwrap_err());
}
println!("helluuuu");
let result = result.unwrap();
if result.stopped_by_signal {
break;
}
if Instant::now() - start > Duration::from_millis(50) {
panic!("regular stop signal handling failed");
}
});
stop_signal.store(true, Ordering::Relaxed);
thread::sleep(Duration::from_millis(100));
panic!("shit")
//thread_jh.join().expect("thread join failed");
}
}

View File

@ -4,10 +4,13 @@ use alloc::vec;
use alloc::vec::Vec;
use core::sync::atomic::AtomicBool;
use core::time::Duration;
use mio::net::{TcpListener, TcpStream};
use mio::{Events, Interest, Poll, Token};
use socket2::{Domain, Socket, Type};
use std::io::Read;
use std::net::TcpListener;
use std::net::{SocketAddr, TcpStream};
use std::net::SocketAddr;
// use std::net::TcpListener;
// use std::net::{SocketAddr, TcpStream};
use std::thread;
use crate::tmtc::{ReceivesTc, TmPacketSource};
@ -81,11 +84,30 @@ pub enum TcpTmtcError<TmError, TcError> {
/// Result of one connection attempt. Contains the client address if a connection was established,
/// in addition to the number of telecommands and telemetry packets exchanged.
#[derive(Debug)]
pub enum ConnectionResult {
AcceptTimeout,
HandledConnection(HandledConnectionInfo),
}
impl From<HandledConnectionInfo> for ConnectionResult {
fn from(info: HandledConnectionInfo) -> Self {
ConnectionResult::HandledConnection(info)
}
}
#[derive(Debug, Default)]
pub struct ConnectionResult {
pub struct HandledConnectionInfo {
pub addr: Option<SocketAddr>,
pub num_received_tcs: u32,
pub num_sent_tms: u32,
/// The generic TCP server can be stopped using an external signal. If this happened, this
/// boolean will be set to true.
pub stopped_by_signal: bool,
}
pub trait HandledConnectionHandler {
fn handled_connection(&mut self, info: HandledConnectionInfo);
}
/// Generic parser abstraction for an object which can parse for telecommands given a raw
@ -96,7 +118,7 @@ pub trait TcpTcParser<TmError, TcError> {
&mut self,
tc_buffer: &mut [u8],
tc_receiver: &mut (impl ReceivesTc<Error = TcError> + ?Sized),
conn_result: &mut ConnectionResult,
conn_result: &mut HandledConnectionInfo,
current_write_idx: usize,
next_write_idx: &mut usize,
) -> Result<(), TcpTmtcError<TmError, TcError>>;
@ -111,7 +133,7 @@ pub trait TcpTmSender<TmError, TcError> {
&mut self,
tm_buffer: &mut [u8],
tm_source: &mut (impl TmPacketSource<Error = TmError> + ?Sized),
conn_result: &mut ConnectionResult,
conn_result: &mut HandledConnectionInfo,
stream: &mut TcpStream,
) -> Result<bool, TcpTmtcError<TmError, TcError>>;
}
@ -140,6 +162,7 @@ pub struct TcpTmtcGenericServer<
TcReceiver: ReceivesTc<Error = TcError>,
TmSender: TcpTmSender<TmError, TcError>,
TcParser: TcpTcParser<TmError, TcError>,
//HandledConnection: HandledConnectionHandler
> {
pub(crate) listener: TcpListener,
pub(crate) inner_loop_delay: Duration,
@ -147,6 +170,8 @@ pub struct TcpTmtcGenericServer<
pub(crate) tm_buffer: Vec<u8>,
pub(crate) tc_receiver: TcReceiver,
pub(crate) tc_buffer: Vec<u8>,
poll: Poll,
events: Events,
stop_signal: Option<Arc<AtomicBool>>,
tc_handler: TcParser,
tm_handler: TmSender,
@ -183,16 +208,31 @@ impl<
) -> Result<Self, std::io::Error> {
// Create a TCP listener bound to two addresses.
let socket = Socket::new(Domain::IPV4, Type::STREAM, None)?;
socket.set_reuse_address(cfg.reuse_addr)?;
#[cfg(unix)]
socket.set_reuse_port(cfg.reuse_port)?;
let addr = (cfg.addr).into();
socket.bind(&addr)?;
socket.listen(128)?;
// Create a poll instance.
let mut poll = Poll::new()?;
// Create storage for events.
let mut events = Events::with_capacity(10);
let listener: std::net::TcpListener = socket.into();
let mut mio_listener = TcpListener::from_std(listener);
// Start listening for incoming connections.
poll.registry()
.register(&mut mio_listener, Token(0), Interest::READABLE)?;
Ok(Self {
tc_handler: tc_parser,
tm_handler: tm_sender,
listener: socket.into(),
poll,
events,
listener: mio_listener,
inner_loop_delay: cfg.inner_loop_delay,
tm_source,
tm_buffer: vec![0; cfg.tm_buffer_size],
@ -229,12 +269,33 @@ impl<
pub fn handle_next_connection(
&mut self,
) -> Result<ConnectionResult, TcpTmtcError<TmError, TcError>> {
let mut connection_result = ConnectionResult::default();
// Poll Mio for events, blocking until we get an event.
self.poll
.poll(&mut self.events, Some(Duration::from_millis(400)))?;
// Process each event.
if let Some(event) = self.events.iter().next() {
if event.token() == Token(0) {
let connection = self.listener.accept()?;
return self
.handle_accepted_connection(connection.0, connection.1)
.map(|v| v.into());
}
panic!("unexpected TCP event token");
}
Ok(ConnectionResult::AcceptTimeout)
}
fn handle_accepted_connection(
&mut self,
mut stream: TcpStream,
addr: SocketAddr,
) -> Result<HandledConnectionInfo, TcpTmtcError<TmError, TcError>> {
let mut current_write_idx;
let mut next_write_idx = 0;
let (mut stream, addr) = self.listener.accept()?;
stream.set_nonblocking(true)?;
let mut connection_result = HandledConnectionInfo::default();
// stream.set_nonblocking(true)?;
connection_result.addr = Some(addr);
connection_result.stopped_by_signal = false;
current_write_idx = next_write_idx;
loop {
let read_result = stream.read(&mut self.tc_buffer[current_write_idx..]);
@ -297,6 +358,7 @@ impl<
.unwrap()
.load(std::sync::atomic::Ordering::Relaxed)
{
connection_result.stopped_by_signal = true;
return Ok(connection_result);
}
}