From 7d51c2813d7c6f711fe1802bba71332a79ee32e5 Mon Sep 17 00:00:00 2001 From: Robin Mueller Date: Thu, 18 Apr 2024 20:08:08 +0200 Subject: [PATCH] i think mio would still be good here --- Cargo.lock | 2 + Cargo.toml | 2 + pytmtc/common.py | 1 + pytmtc/main.py | 25 ++++--- pytmtc/tcp_server.py | 77 ++++++++++++++------ src/interface/tcp_spp_client.rs | 120 ++++++++++++++++++++------------ src/main.rs | 5 +- 7 files changed, 152 insertions(+), 80 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f6c288d..90468ca 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -485,9 +485,11 @@ dependencies = [ "humantime", "lazy_static", "log", + "mio", "num_enum", "satrs", "satrs-mib", + "socket2", "strum", "thiserror", "toml", diff --git a/Cargo.toml b/Cargo.toml index d8e7bd5..bc552ed 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,8 @@ strum = { version = "0.26", features = ["derive"] } thiserror = "1" derive-new = "0.6" num_enum = "0.7" +mio = "0.8" +socket2 = "0.5" [dependencies.satrs] version = "0.2.0-rc.3" diff --git a/pytmtc/common.py b/pytmtc/common.py index 56b469f..e0e0c0c 100644 --- a/pytmtc/common.py +++ b/pytmtc/common.py @@ -7,6 +7,7 @@ import struct EXPERIMENT_ID = 278 EXPERIMENT_APID = 1024 + EXPERIMENT_ID +TCP_SERVER_PORT = 4096 class EventSeverity(enum.IntEnum): diff --git a/pytmtc/main.py b/pytmtc/main.py index deb1fbe..deadf9a 100755 --- a/pytmtc/main.py +++ b/pytmtc/main.py @@ -18,7 +18,7 @@ from spacepackets.ccsds.time import CdsShortTimestamp from tmtccmd import TcHandlerBase, ProcedureParamsWrapper from tmtccmd.core.base import BackendRequest from tmtccmd.pus import VerificationWrapper -from tmtccmd.tmtc import CcsdsTmHandler, GenericApidHandlerBase, TelemetryListT +from tmtccmd.tmtc import CcsdsTmHandler, GenericApidHandlerBase from tmtccmd.com import ComInterface from tmtccmd.config import ( CmdTreeNode, @@ -44,13 +44,12 @@ from tmtccmd.tmtc import ( QueueWrapper, ) from spacepackets.seqcount import FileSeqCountProvider, PusFileSeqCountProvider +from tcp_server import TcpServer from tmtccmd.util.obj_id import ObjectIdDictT -from collections import deque -import socket import pus_tc -from common import EXPERIMENT_APID, EventU32 +from common import EXPERIMENT_APID, EventU32, TCP_SERVER_PORT _LOGGER = logging.getLogger() @@ -68,13 +67,17 @@ class SatRsConfigHook(HookBase): assert self.cfg_path is not None packet_id_list = [] packet_id_list.append(PacketId(PacketType.TM, True, EXPERIMENT_APID)) - cfg = create_com_interface_cfg_default( - com_if_key=com_if_key, - json_cfg_path=self.cfg_path, - space_packet_ids=packet_id_list, - ) - assert cfg is not None - return create_com_interface_default(cfg) + if com_if_key == "tcp_server": + tcp_server = TcpServer(TCP_SERVER_PORT) + return tcp_server + else: + cfg = create_com_interface_cfg_default( + com_if_key=com_if_key, + json_cfg_path=self.cfg_path, + space_packet_ids=packet_id_list, + ) + assert cfg is not None + return create_com_interface_default(cfg) def get_command_definitions(self) -> CmdTreeNode: """This function should return the root node of the command definition tree.""" diff --git a/pytmtc/tcp_server.py b/pytmtc/tcp_server.py index 1e4570f..0c2c6dc 100644 --- a/pytmtc/tcp_server.py +++ b/pytmtc/tcp_server.py @@ -1,4 +1,6 @@ from typing import Any, Optional +import select +import time import socket import logging from threading import Thread, Event, Lock @@ -16,6 +18,7 @@ class TcpServer(ComInterface): self.port = port self._max_num_packets_in_tc_queue = 500 self._max_num_packets_in_tm_queue = 500 + self._default_timeout_secs = 0.5 self._server_addr = ("localhost", self.port) self._tc_packet_queue = deque() self._tm_packet_queue = deque() @@ -23,8 +26,11 @@ class TcpServer(ComInterface): self._tm_lock = Lock() self._kill_signal = Event() self._server_socket: Optional[socket.socket] = None - self._server_thread = Thread(target=TcpServer._server_task, daemon=True) + self._server_thread = Thread(target=self._server_task, daemon=True) self._connected = False + self._conn_start = None + self._writing_done = False + self._reading_done = False @property def connected(self) -> bool: @@ -47,31 +53,44 @@ class TcpServer(ComInterface): """ if self.connected: return - self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - # We need to check the kill signal periodically to allow closing the server. - self.server_socket.settimeout(0.5) - self.server_socket.bind(self._server_addr) self._connected = True self._server_thread.start() def _server_task(self): - assert self._server_socket is not None + self._server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + # We need to check the kill signal periodically to allow closing the server. + self._server_socket.settimeout(self._default_timeout_secs) + self._server_socket.bind(self._server_addr) + self._server_socket.listen() while True and not self._kill_signal.is_set(): - self._server_socket.listen() - (conn_socket, conn_addr) = self._server_socket.accept() - _LOGGER.info("TCP client {} connected", conn_addr) - while True: - bytes_recvd = conn_socket.recv(4096) - if len(bytes_recvd) > 0: - print(f"Received bytes from TCP client: {bytes_recvd.decode()}") - with self._tm_lock: - self._tm_packet_queue.append(bytes_recvd) - elif len(bytes_recvd) == 0: - break - else: - print("error receiving data from TCP client") + try: + (conn_socket, conn_addr) = self._server_socket.accept() + self.conn_start = time.time() + while True: + self._handle_connection(conn_socket, conn_addr) + if ( + self._reading_done and self._writing_done + ) or time.time() - self.conn_start > 0.5: + print("reading and writing done") + break + except TimeoutError: + print("timeout error") + continue + def _handle_connection(self, conn_socket: socket.socket, conn_addr: Any): + _LOGGER.info(f"TCP client {conn_addr} connected") + (readable, writable, _) = select.select( + [conn_socket], + [conn_socket], + [], + 0.1, + ) + + # TODO: Why is the stupid conn socket never readable? + print(f"Writable: {writable}") + print(f"Readable: {readable}") + if writable and writable[0]: queue_len = 0 with self._tc_lock: queue_len = len(self._tc_packet_queue) @@ -79,8 +98,23 @@ class TcpServer(ComInterface): next_packet = bytes() with self._tc_lock: next_packet = self._tc_packet_queue.popleft() - conn_socket.sendall(next_packet) + if len(next_packet) > 0: + conn_socket.sendall(next_packet) queue_len -= 1 + self._writing_done = True + if readable and readable[0]: + print("reading shit") + while True: + bytes_recvd = conn_socket.recv(4096) + if len(bytes_recvd) > 0: + print(f"Received bytes from TCP client: {bytes_recvd.decode()}") + with self._tm_lock: + self._tm_packet_queue.append(bytes_recvd) + elif len(bytes_recvd) == 0: + self._reading_done = True + break + else: + print("error receiving data from TCP client") def is_open(self) -> bool: """Can be used to check whether the communication interface is open. This is useful if @@ -103,7 +137,6 @@ class TcpServer(ComInterface): :raises SendError: Sending failed for some reason. """ with self._tc_lock: - # Deque is thread-safe according to the documentation.. so this should be fine. if len(self._tc_packet_queue) >= self._max_num_packets_in_tc_queue: # Remove oldest packet self._tc_packet_queue.popleft() diff --git a/src/interface/tcp_spp_client.rs b/src/interface/tcp_spp_client.rs index 186491a..6f208c5 100644 --- a/src/interface/tcp_spp_client.rs +++ b/src/interface/tcp_spp_client.rs @@ -1,25 +1,20 @@ use std::io::{self, Read, Write}; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::sync::mpsc; -use std::time::Duration; -use ops_sat_rs::config::SPP_CLIENT_WIRETAPPING_RX; +use mio::net::TcpStream; +use mio::{Events, Interest, Poll, Token}; +use ops_sat_rs::config::tasks::STOP_CHECK_FREQUENCY; +use ops_sat_rs::config::{SPP_CLIENT_WIRETAPPING_RX, TCP_SPP_SERVER_PORT}; use satrs::encoding::ccsds::parse_buffer_for_ccsds_space_packets; use satrs::queue::GenericSendError; use satrs::spacepackets::PacketId; use satrs::tmtc::PacketAsVec; use satrs::ComponentId; -use std::net::TcpStream; use thiserror::Error; use super::{SimpleSpValidator, TcpComponent}; -#[derive(Debug)] -pub enum ConnectionResult { - Connected(TcpStream), - Timeout, -} - #[derive(Debug, Error)] pub enum ClientError { #[error("send error: {0}")] @@ -30,6 +25,10 @@ pub enum ClientError { pub struct TcpSppClient { id: ComponentId, + poll: Poll, + events: Events, + // Optional to allow periodic reconnection attempts on the TCP server. + client: TcpStream, read_buf: [u8; 4096], tm_tcp_client_rx: mpsc::Receiver, server_addr: SocketAddr, @@ -45,9 +44,20 @@ impl TcpSppClient { valid_ids: &'static [PacketId], port: u16, ) -> io::Result { + let poll = Poll::new()?; + let events = Events::with_capacity(128); let server_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port); + let mut client = TcpStream::connect(server_addr)?; + poll.registry().register( + &mut client, + Token(0), + Interest::READABLE | Interest::WRITABLE, + )?; Ok(Self { id, + poll, + events, + client, read_buf: [0; 4096], server_addr, tm_tcp_client_rx, @@ -56,59 +66,65 @@ impl TcpSppClient { }) } - pub fn attempt_connection( - &mut self, - timeout: Duration, - ) -> Result { - match TcpStream::connect_timeout(&self.server_addr, timeout) { - Ok(client) => Ok(ConnectionResult::Connected(client)), - Err(e) => { - if e.kind() == io::ErrorKind::WouldBlock || e.kind() == io::ErrorKind::TimedOut { - return Ok(ConnectionResult::Timeout); + pub fn periodic_operation(&mut self) -> Result<(), ClientError> { + // if self.client.is_some() { + return self.perform_regular_operation(); + /* + } else { + log::info!("attempting reconnect"); + let client_result = self.attempt_connection(); + match client_result { + Ok(_) => { + self.perform_regular_operation()?; } - Err(e.into()) - } - } - } - - pub fn operation(&mut self, timeout: Duration) { - match self.attempt_connection(timeout) { - Ok(result) => { - if let ConnectionResult::Connected(mut client) = result { - if let Err(e) = self.handle_client_operation(&mut client) { - log::error!("error handling TCP client operation: {}", e) - } - drop(client); - std::thread::sleep(timeout); + Err(ref e) => { + log::warn!( + "connection to TCP server {} failed: {}", + self.server_addr, + e + ); } } - Err(e) => { - log::error!("TCP client error: {}", e); - } } - } - - pub fn handle_client_operation(&mut self, client: &mut TcpStream) -> Result<(), ClientError> { - self.write_to_server(client)?; - client.shutdown(std::net::Shutdown::Write)?; - self.read_from_server(client)?; + */ Ok(()) } - pub fn read_from_server(&mut self, client: &mut TcpStream) -> Result<(), ClientError> { - match client.read(&mut self.read_buf) { - Ok(0) => (), + pub fn perform_regular_operation(&mut self) -> Result<(), ClientError> { + self.poll + .poll(&mut self.events, Some(STOP_CHECK_FREQUENCY))?; + let events: Vec = self.events.iter().cloned().collect(); + for event in events { + if event.token() == Token(0) { + if event.is_readable() { + log::info!("readable event"); + self.check_conn_status()?; + self.read_from_server()?; + } + if event.is_writable() { + log::info!("writable event"); + self.check_conn_status()?; + self.write_to_server()?; + } + } + } + Ok(()) + } + + pub fn read_from_server(&mut self) -> Result<(), ClientError> { + match self.client.read(&mut self.read_buf) { + Ok(0) => (), // return Err(io::Error::from(io::ErrorKind::BrokenPipe).into()), Ok(read_bytes) => self.handle_read_bytstream(read_bytes)?, Err(e) => return Err(e.into()), } Ok(()) } - pub fn write_to_server(&mut self, client: &mut TcpStream) -> io::Result<()> { + pub fn write_to_server(&mut self) -> io::Result<()> { loop { match self.tm_tcp_client_rx.try_recv() { Ok(tm) => { - client.write_all(&tm.packet)?; + self.client.write_all(&tm.packet)?; } Err(e) => match e { mpsc::TryRecvError::Empty => break, @@ -122,6 +138,18 @@ impl TcpSppClient { Ok(()) } + pub fn check_conn_status(&mut self) -> io::Result { + match self.client.peer_addr() { + Ok(_) => Ok(true), + Err(e) => { + if e.kind() == io::ErrorKind::NotConnected { + return Ok(false); + } + Err(e) + } + } + } + pub fn handle_read_bytstream(&mut self, read_bytes: usize) -> Result<(), ClientError> { let mut dummy = 0; if SPP_CLIENT_WIRETAPPING_RX { diff --git a/src/main.rs b/src/main.rs index d6a65c1..8f8f7b0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -232,7 +232,10 @@ fn main() { .spawn(move || { info!("Running TCP SPP client"); loop { - tcp_spp_client.operation(STOP_CHECK_FREQUENCY); + let result = tcp_spp_client.periodic_operation(); + if let Err(e) = result { + log::error!("TCP SPP client error: {}", e); + } if tcp_client_stop_signal.load(std::sync::atomic::Ordering::Relaxed) { break; }