diff --git a/src/interface/tcp_spp_client.rs b/src/interface/tcp_spp_client.rs index 280a68b..fe34e52 100644 --- a/src/interface/tcp_spp_client.rs +++ b/src/interface/tcp_spp_client.rs @@ -25,6 +25,12 @@ pub enum ClientError { Io(#[from] io::Error), } +#[derive(Debug)] +pub enum ClientResult { + Ok, + ConnectionLost, +} + #[allow(dead_code)] pub struct TcpSppClientCommon { id: ComponentId, @@ -73,6 +79,7 @@ impl TcpSppClientCommon { Err(e) => match e { mpsc::TryRecvError::Empty => break, mpsc::TryRecvError::Disconnected => { + println!("god fuckikng damn it"); log::error!("TM sender to TCP client has disconnected"); break; } @@ -132,7 +139,12 @@ impl TcpSppClientStd { }) } - pub fn operation(&mut self) -> Result<(), ClientError> { + #[allow(dead_code)] + pub fn connected(&self) -> bool { + self.stream.is_some() + } + + pub fn operation(&mut self) -> Result { if let Some(client) = &mut self.stream { // Write TM first before blocking on the read call. self.common.write_to_server(client)?; @@ -141,17 +153,19 @@ impl TcpSppClientStd { Ok(0) => { log::info!("server closed connection"); self.stream = None; + return Ok(ClientResult::ConnectionLost); } Ok(read_bytes) => self.common.handle_read_bytstream(read_bytes)?, Err(e) => { if e.kind() == io::ErrorKind::WouldBlock || e.kind() == io::ErrorKind::TimedOut { self.common.write_to_server(client)?; - return Ok(()); + return Ok(ClientResult::ConnectionLost); } log::warn!("server error: {e:?}"); if e.kind() == io::ErrorKind::ConnectionReset { self.stream = None; + return Ok(ClientResult::ConnectionLost); } return Err(e.into()); } @@ -164,7 +178,7 @@ impl TcpSppClientStd { std::thread::sleep(self.read_and_idle_delay); } - Ok(()) + Ok(ClientResult::Ok) } } @@ -301,6 +315,7 @@ mod tests { use std::{ io::Write, net::{TcpListener, TcpStream}, + sync::{atomic::AtomicBool, Arc}, thread, time::Duration, }; @@ -320,13 +335,18 @@ mod tests { 1, ); + fn init() { + let _ = env_logger::builder().is_test(true).try_init(); + } + struct TcpServerTestbench { tcp_server: TcpListener, } impl TcpServerTestbench { - fn new() -> Self { - let tcp_server = TcpListener::bind("127.0.0.1:0").unwrap(); + fn new(port: u16) -> Self { + let tcp_server = + TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port)).unwrap(); tcp_server .set_nonblocking(true) .expect("setting TCP server non-blocking failed"); @@ -337,7 +357,7 @@ mod tests { self.tcp_server.local_addr().unwrap() } - fn attempt_connection(&mut self, limit: u32) -> Result { + fn check_for_connections(&mut self, limit: u32) -> Result { for _ in 0..limit { match self.tcp_server.accept() { Ok((stream, _)) => { @@ -392,11 +412,11 @@ mod tests { fn basic_client_test() { let (tc_source_tx, _tc_source_rx) = mpsc::channel(); let (_tm_tcp_client_tx, tm_tcp_client_rx) = mpsc::channel(); - let mut tcp_server = TcpServerTestbench::new(); + let mut tcp_server = TcpServerTestbench::new(0); let local_addr = tcp_server.local_addr(); let jh0 = thread::spawn(move || { tcp_server - .attempt_connection(3) + .check_for_connections(3) .expect("no client connection detected"); }); let mut spp_client = TcpSppClientStd::new( @@ -417,7 +437,7 @@ mod tests { fn basic_client_tm_test() { let (tc_source_tx, _tc_source_rx) = mpsc::channel(); let (tm_tcp_client_tx, tm_tcp_client_rx) = mpsc::channel(); - let mut tcp_server = TcpServerTestbench::new(); + let mut tcp_server = TcpServerTestbench::new(0); let local_addr = tcp_server.local_addr(); let mut buf: [u8; 7] = [0; 7]; TEST_TM @@ -426,7 +446,7 @@ mod tests { let jh0 = thread::spawn(move || { let mut read_buf: [u8; 64] = [0; 64]; let mut stream = tcp_server - .attempt_connection(3) + .check_for_connections(3) .expect("no client connection detected"); stream .set_read_timeout(Some(Duration::from_millis(10))) @@ -461,7 +481,7 @@ mod tests { fn basic_client_tc_test() { let (tc_source_tx, tc_source_rx) = mpsc::channel(); let (_tm_tcp_client_tx, tm_tcp_client_rx) = mpsc::channel(); - let mut tcp_server = TcpServerTestbench::new(); + let mut tcp_server = TcpServerTestbench::new(0); let local_addr = tcp_server.local_addr(); let mut buf: [u8; 8] = [0; 8]; TEST_TC @@ -469,7 +489,7 @@ mod tests { .expect("writing TM failed"); let jh0 = thread::spawn(move || { let mut stream = tcp_server - .attempt_connection(3) + .check_for_connections(3) .expect("no client connection detected"); stream .set_read_timeout(Some(Duration::from_millis(10))) @@ -486,6 +506,7 @@ mod tests { local_addr.port(), ) .expect("creating TCP SPP client failed"); + assert!(spp_client.connected()); let mut received_packet = false; (0..3).for_each(|_| { spp_client.operation().unwrap(); @@ -506,7 +527,7 @@ mod tests { fn basic_client_tmtc_test() { let (tc_source_tx, tc_source_rx) = mpsc::channel(); let (tm_tcp_client_tx, tm_tcp_client_rx) = mpsc::channel(); - let mut tcp_server = TcpServerTestbench::new(); + let mut tcp_server = TcpServerTestbench::new(0); let local_addr = tcp_server.local_addr(); let mut tc_buf: [u8; 8] = [0; 8]; let mut tm_buf: [u8; 8] = [0; 8]; @@ -519,7 +540,7 @@ mod tests { let jh0 = thread::spawn(move || { let mut read_buf: [u8; 64] = [0; 64]; let mut stream = tcp_server - .attempt_connection(3) + .check_for_connections(3) .expect("no client connection detected"); stream .set_read_timeout(Some(Duration::from_millis(10))) @@ -545,6 +566,7 @@ mod tests { local_addr.port(), ) .expect("creating TCP SPP client failed"); + assert!(spp_client.connected()); let mut received_packet = false; (0..3).for_each(|_| { spp_client.operation().unwrap(); @@ -561,7 +583,83 @@ mod tests { #[test] fn test_broken_connection() { - // TODO: Verify the client re-connects automatically if the server is dropped and then set - // up again. + init(); + let (tc_source_tx, _tc_source_rx) = mpsc::channel(); + let (tm_tcp_client_tx, tm_tcp_client_rx) = mpsc::channel(); + let mut tcp_server = TcpServerTestbench::new(0); + let local_port = tcp_server.local_addr().port(); + let drop_signal = Arc::new(AtomicBool::new(false)); + let drop_signal_0 = drop_signal.clone(); + let mut tc_buf: [u8; 8] = [0; 8]; + let mut tm_buf: [u8; 8] = [0; 8]; + TEST_TC + .write_to_be_bytes(&mut tc_buf) + .expect("writing TM failed"); + TEST_TM + .write_to_be_bytes(&mut tm_buf) + .expect("writing TM failed"); + + let mut jh0 = thread::spawn(move || { + tcp_server + .check_for_connections(3) + .expect("no client connection detected"); + drop_signal_0.store(true, std::sync::atomic::Ordering::Relaxed); + }); + let mut spp_client = TcpSppClientStd::new( + 1, + tc_source_tx, + tm_tcp_client_rx, + VALID_IDS, + Duration::from_millis(30), + local_port, + ) + .expect("creating TCP SPP client failed"); + while !drop_signal.load(std::sync::atomic::Ordering::Relaxed) { + std::thread::sleep(Duration::from_millis(100)); + } + tm_tcp_client_tx + .send(PacketAsVec::new(0, tm_buf.to_vec())) + .unwrap(); + match spp_client.operation() { + Ok(ClientResult::ConnectionLost) => (), + Ok(ClientResult::Ok) => { + panic!("expected operation error"); + } + Err(ClientError::Io(e)) => { + println!("io error: {:?}", e); + if e.kind() != io::ErrorKind::ConnectionReset + && e.kind() != io::ErrorKind::ConnectionAborted + { + panic!("expected some disconnet error"); + } + } + _ => { + panic!("unexpected error") + } + }; + assert!(!spp_client.connected()); + jh0.join().unwrap(); + // spp_client.operation(); + tcp_server = TcpServerTestbench::new(local_port); + tm_tcp_client_tx + .send(PacketAsVec::new(0, tm_buf.to_vec())) + .unwrap(); + jh0 = thread::spawn(move || { + let mut stream = tcp_server + .check_for_connections(3) + .expect("no client connection detected"); + let mut read_buf: [u8; 64] = [0; 64]; + let read_bytes = tcp_server.try_reading_one_packet(&mut stream, 5, &mut read_buf); + if read_bytes == 0 { + panic!("did not receive expected data"); + } else { + assert_eq!(&tm_buf, &read_buf[0..read_bytes]); + } + }); + let result = spp_client.operation(); + println!("{:?}", result); + assert!(!spp_client.connected()); + assert!(result.is_ok()); + jh0.join().unwrap(); } }