i think mio would still be good here

Robin Müller 2024-04-18 20:08:08 +02:00
parent c510df3154
commit 7d51c2813d
Signed by: muellerr
GPG Key ID: A649FB78196E3849
7 changed files with 152 additions and 80 deletions

Cargo.lock (generated)

@@ -485,9 +485,11 @@ dependencies = [
  "humantime",
  "lazy_static",
  "log",
+ "mio",
  "num_enum",
  "satrs",
  "satrs-mib",
+ "socket2",
  "strum",
  "thiserror",
  "toml",

Cargo.toml

@@ -16,6 +16,8 @@ strum = { version = "0.26", features = ["derive"] }
 thiserror = "1"
 derive-new = "0.6"
 num_enum = "0.7"
+mio = "0.8"
+socket2 = "0.5"

 [dependencies.satrs]
 version = "0.2.0-rc.3"

common.py

@@ -7,6 +7,7 @@ import struct
 EXPERIMENT_ID = 278
 EXPERIMENT_APID = 1024 + EXPERIMENT_ID
+TCP_SERVER_PORT = 4096


 class EventSeverity(enum.IntEnum):

Python TMTC hook module

@@ -18,7 +18,7 @@ from spacepackets.ccsds.time import CdsShortTimestamp
 from tmtccmd import TcHandlerBase, ProcedureParamsWrapper
 from tmtccmd.core.base import BackendRequest
 from tmtccmd.pus import VerificationWrapper
-from tmtccmd.tmtc import CcsdsTmHandler, GenericApidHandlerBase, TelemetryListT
+from tmtccmd.tmtc import CcsdsTmHandler, GenericApidHandlerBase
 from tmtccmd.com import ComInterface
 from tmtccmd.config import (
     CmdTreeNode,
@@ -44,13 +44,12 @@ from tmtccmd.tmtc import (
     QueueWrapper,
 )
 from spacepackets.seqcount import FileSeqCountProvider, PusFileSeqCountProvider
+from tcp_server import TcpServer
 from tmtccmd.util.obj_id import ObjectIdDictT
-from collections import deque
-import socket

 import pus_tc
-from common import EXPERIMENT_APID, EventU32
+from common import EXPERIMENT_APID, EventU32, TCP_SERVER_PORT

 _LOGGER = logging.getLogger()
@@ -68,6 +67,10 @@ class SatRsConfigHook(HookBase):
         assert self.cfg_path is not None
         packet_id_list = []
         packet_id_list.append(PacketId(PacketType.TM, True, EXPERIMENT_APID))
+        if com_if_key == "tcp_server":
+            tcp_server = TcpServer(TCP_SERVER_PORT)
+            return tcp_server
+        else:
             cfg = create_com_interface_cfg_default(
                 com_if_key=com_if_key,
                 json_cfg_path=self.cfg_path,

tcp_server.py

@@ -1,4 +1,6 @@
 from typing import Any, Optional
+import select
+import time
 import socket
 import logging
 from threading import Thread, Event, Lock
@@ -16,6 +18,7 @@ class TcpServer(ComInterface):
         self.port = port
         self._max_num_packets_in_tc_queue = 500
         self._max_num_packets_in_tm_queue = 500
+        self._default_timeout_secs = 0.5
         self._server_addr = ("localhost", self.port)
         self._tc_packet_queue = deque()
         self._tm_packet_queue = deque()
@@ -23,8 +26,11 @@
         self._tm_lock = Lock()
         self._kill_signal = Event()
         self._server_socket: Optional[socket.socket] = None
-        self._server_thread = Thread(target=TcpServer._server_task, daemon=True)
+        self._server_thread = Thread(target=self._server_task, daemon=True)
         self._connected = False
+        self._conn_start = None
+        self._writing_done = False
+        self._reading_done = False

     @property
     def connected(self) -> bool:
@@ -47,31 +53,44 @@
         """
         if self.connected:
             return
-        self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-        self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-        # We need to check the kill signal periodically to allow closing the server.
-        self.server_socket.settimeout(0.5)
-        self.server_socket.bind(self._server_addr)
         self._connected = True
         self._server_thread.start()

     def _server_task(self):
-        assert self._server_socket is not None
-        while True and not self._kill_signal.is_set():
-            self._server_socket.listen()
-            (conn_socket, conn_addr) = self._server_socket.accept()
-            _LOGGER.info("TCP client {} connected", conn_addr)
-            while True:
-                bytes_recvd = conn_socket.recv(4096)
-                if len(bytes_recvd) > 0:
-                    print(f"Received bytes from TCP client: {bytes_recvd.decode()}")
-                    with self._tm_lock:
-                        self._tm_packet_queue.append(bytes_recvd)
-                elif len(bytes_recvd) == 0:
-                    break
-                else:
-                    print("error receiving data from TCP client")
+        self._server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        self._server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        # We need to check the kill signal periodically to allow closing the server.
+        self._server_socket.settimeout(self._default_timeout_secs)
+        self._server_socket.bind(self._server_addr)
+        self._server_socket.listen()
+        while not self._kill_signal.is_set():
+            try:
+                (conn_socket, conn_addr) = self._server_socket.accept()
+                self._conn_start = time.time()
+                while True:
+                    self._handle_connection(conn_socket, conn_addr)
+                    if (
+                        self._reading_done and self._writing_done
+                    ) or time.time() - self._conn_start > 0.5:
+                        print("reading and writing done")
+                        break
+            except TimeoutError:
+                print("timeout error")
+                continue
+
+    def _handle_connection(self, conn_socket: socket.socket, conn_addr: Any):
+        _LOGGER.info(f"TCP client {conn_addr} connected")
+        (readable, writable, _) = select.select(
+            [conn_socket],
+            [conn_socket],
+            [],
+            0.1,
+        )
+        # TODO: Why is the connection socket never readable?
+        print(f"Writable: {writable}")
+        print(f"Readable: {readable}")
+        if writable and writable[0]:
             queue_len = 0
             with self._tc_lock:
                 queue_len = len(self._tc_packet_queue)
@@ -79,8 +98,23 @@
             next_packet = bytes()
             with self._tc_lock:
                 next_packet = self._tc_packet_queue.popleft()
+            if len(next_packet) > 0:
                 conn_socket.sendall(next_packet)
             queue_len -= 1
+        self._writing_done = True
+        if readable and readable[0]:
+            print("reading from TCP client")
+            while True:
+                bytes_recvd = conn_socket.recv(4096)
+                if len(bytes_recvd) > 0:
+                    print(f"Received bytes from TCP client: {bytes_recvd.decode()}")
+                    with self._tm_lock:
+                        self._tm_packet_queue.append(bytes_recvd)
+                elif len(bytes_recvd) == 0:
+                    self._reading_done = True
+                    break
+                else:
+                    print("error receiving data from TCP client")

     def is_open(self) -> bool:
         """Can be used to check whether the communication interface is open. This is useful if
@@ -103,7 +137,6 @@
         :raises SendError: Sending failed for some reason.
         """
         with self._tc_lock:
-            # Deque is thread-safe according to the documentation.. so this should be fine.
             if len(self._tc_packet_queue) >= self._max_num_packets_in_tc_queue:
                 # Remove oldest packet
                 self._tc_packet_queue.popleft()
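
Regarding the TODO above: select() only reports a connection socket as readable once the peer has actually queued data or closed its end, so a client that connects and then waits for its own writable event before sending anything will leave this socket writable but not readable for the first iterations. Separately, the commit adds socket2 on the Rust side without using it in the hunks shown; a hedged sketch of how the listener setup mirrored from this Python code could look with socket2 0.5 (assumed usage, not the author's code):

use std::io;
use std::net::{SocketAddr, TcpListener};
use std::time::Duration;

use socket2::{Domain, Protocol, Socket, Type};

// Mirrors the Python server setup: SO_REUSEADDR plus a receive timeout,
// intended to bound accept() waits the way settimeout() does above, so a
// kill signal can be checked between accept attempts.
fn make_listener(addr: SocketAddr) -> io::Result<TcpListener> {
    let socket = Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP))?;
    socket.set_reuse_address(true)?;
    socket.set_read_timeout(Some(Duration::from_millis(500)))?;
    socket.bind(&addr.into())?;
    socket.listen(128)?;
    Ok(socket.into())
}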

Rust TCP SPP client module

@@ -1,25 +1,20 @@
 use std::io::{self, Read, Write};
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::sync::mpsc;
-use std::time::Duration;

-use ops_sat_rs::config::SPP_CLIENT_WIRETAPPING_RX;
+use mio::net::TcpStream;
+use mio::{Events, Interest, Poll, Token};
+use ops_sat_rs::config::tasks::STOP_CHECK_FREQUENCY;
+use ops_sat_rs::config::{SPP_CLIENT_WIRETAPPING_RX, TCP_SPP_SERVER_PORT};
 use satrs::encoding::ccsds::parse_buffer_for_ccsds_space_packets;
 use satrs::queue::GenericSendError;
 use satrs::spacepackets::PacketId;
 use satrs::tmtc::PacketAsVec;
 use satrs::ComponentId;
-use std::net::TcpStream;
 use thiserror::Error;

 use super::{SimpleSpValidator, TcpComponent};

-#[derive(Debug)]
-pub enum ConnectionResult {
-    Connected(TcpStream),
-    Timeout,
-}
-
 #[derive(Debug, Error)]
 pub enum ClientError {
     #[error("send error: {0}")]
@@ -30,6 +25,10 @@ pub enum ClientError {
 pub struct TcpSppClient {
     id: ComponentId,
+    poll: Poll,
+    events: Events,
+    // Could be made optional to allow periodic reconnection attempts on the TCP server.
+    client: TcpStream,
     read_buf: [u8; 4096],
     tm_tcp_client_rx: mpsc::Receiver<PacketAsVec>,
     server_addr: SocketAddr,
@@ -45,9 +44,20 @@ impl TcpSppClient {
         valid_ids: &'static [PacketId],
         port: u16,
     ) -> io::Result<Self> {
+        let poll = Poll::new()?;
+        let events = Events::with_capacity(128);
         let server_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port);
+        let mut client = TcpStream::connect(server_addr)?;
+        poll.registry().register(
+            &mut client,
+            Token(0),
+            Interest::READABLE | Interest::WRITABLE,
+        )?;
         Ok(Self {
             id,
+            poll,
+            events,
+            client,
             read_buf: [0; 4096],
             server_addr,
             tm_tcp_client_rx,
@@ -56,59 +66,65 @@
         })
     }

-    pub fn attempt_connection(
-        &mut self,
-        timeout: Duration,
-    ) -> Result<ConnectionResult, ClientError> {
-        match TcpStream::connect_timeout(&self.server_addr, timeout) {
-            Ok(client) => Ok(ConnectionResult::Connected(client)),
-            Err(e) => {
-                if e.kind() == io::ErrorKind::WouldBlock || e.kind() == io::ErrorKind::TimedOut {
-                    return Ok(ConnectionResult::Timeout);
-                }
-                Err(e.into())
-            }
-        }
-    }
-
-    pub fn operation(&mut self, timeout: Duration) {
-        match self.attempt_connection(timeout) {
-            Ok(result) => {
-                if let ConnectionResult::Connected(mut client) = result {
-                    if let Err(e) = self.handle_client_operation(&mut client) {
-                        log::error!("error handling TCP client operation: {}", e)
-                    }
-                    drop(client);
-                    std::thread::sleep(timeout);
-                }
-            }
-            Err(e) => {
-                log::error!("TCP client error: {}", e);
-            }
-        }
-    }
-
-    pub fn handle_client_operation(&mut self, client: &mut TcpStream) -> Result<(), ClientError> {
-        self.write_to_server(client)?;
-        client.shutdown(std::net::Shutdown::Write)?;
-        self.read_from_server(client)?;
-        Ok(())
-    }
+    pub fn periodic_operation(&mut self) -> Result<(), ClientError> {
+        // if self.client.is_some() {
+        return self.perform_regular_operation();
+        /*
+        } else {
+            log::info!("attempting reconnect");
+            let client_result = self.attempt_connection();
+            match client_result {
+                Ok(_) => {
+                    self.perform_regular_operation()?;
+                }
+                Err(ref e) => {
+                    log::warn!(
+                        "connection to TCP server {} failed: {}",
+                        self.server_addr,
+                        e
+                    );
+                }
+            }
+        }
+        */
+    }
+
+    pub fn perform_regular_operation(&mut self) -> Result<(), ClientError> {
+        self.poll
+            .poll(&mut self.events, Some(STOP_CHECK_FREQUENCY))?;
+        let events: Vec<mio::event::Event> = self.events.iter().cloned().collect();
+        for event in events {
+            if event.token() == Token(0) {
+                if event.is_readable() {
+                    log::info!("readable event");
+                    self.check_conn_status()?;
+                    self.read_from_server()?;
+                }
+                if event.is_writable() {
+                    log::info!("writable event");
+                    self.check_conn_status()?;
+                    self.write_to_server()?;
+                }
+            }
+        }
+        Ok(())
+    }

-    pub fn read_from_server(&mut self, client: &mut TcpStream) -> Result<(), ClientError> {
-        match client.read(&mut self.read_buf) {
-            Ok(0) => (),
+    pub fn read_from_server(&mut self) -> Result<(), ClientError> {
+        match self.client.read(&mut self.read_buf) {
+            Ok(0) => (), // return Err(io::Error::from(io::ErrorKind::BrokenPipe).into()),
             Ok(read_bytes) => self.handle_read_bytstream(read_bytes)?,
             Err(e) => return Err(e.into()),
         }
         Ok(())
     }

-    pub fn write_to_server(&mut self, client: &mut TcpStream) -> io::Result<()> {
+    pub fn write_to_server(&mut self) -> io::Result<()> {
         loop {
             match self.tm_tcp_client_rx.try_recv() {
                 Ok(tm) => {
-                    client.write_all(&tm.packet)?;
+                    self.client.write_all(&tm.packet)?;
                 }
                 Err(e) => match e {
                     mpsc::TryRecvError::Empty => break,
@@ -122,6 +138,18 @@
         Ok(())
     }

+    pub fn check_conn_status(&mut self) -> io::Result<bool> {
+        match self.client.peer_addr() {
+            Ok(_) => Ok(true),
+            Err(e) => {
+                if e.kind() == io::ErrorKind::NotConnected {
+                    return Ok(false);
+                }
+                Err(e)
+            }
+        }
+    }
+
     pub fn handle_read_bytstream(&mut self, read_bytes: usize) -> Result<(), ClientError> {
         let mut dummy = 0;
         if SPP_CLIENT_WIRETAPPING_RX {
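
One caveat with the poll-based rework: mio's readiness notifications should be treated as edge-style, so after a readable event the stream should be read until WouldBlock, or remaining bytes may go unnoticed until the next event. read_from_server above performs a single read; a sketch of a draining read loop, using a hypothetical free function for illustration rather than the author's API:

use std::io::{self, Read};

use mio::net::TcpStream;

// Sketch: after a readable event, drain the stream until WouldBlock.
// Returns the total number of bytes read; a return value of 0 can then be
// treated as "peer closed" by the caller, as in read_from_server above.
fn drain_reads(client: &mut TcpStream, read_buf: &mut [u8]) -> io::Result<usize> {
    let mut total = 0;
    loop {
        match client.read(read_buf) {
            Ok(0) => break, // connection closed by the peer
            // NOTE: real code would hand each chunk to the packet parser
            // before the buffer is overwritten by the next read.
            Ok(n) => total += n,
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => break,
            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        }
    }
    Ok(total)
}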

main.rs

@@ -232,7 +232,10 @@ fn main() {
             .spawn(move || {
                 info!("Running TCP SPP client");
                 loop {
-                    tcp_spp_client.operation(STOP_CHECK_FREQUENCY);
+                    let result = tcp_spp_client.periodic_operation();
+                    if let Err(e) = result {
+                        log::error!("TCP SPP client error: {}", e);
+                    }
                     if tcp_client_stop_signal.load(std::sync::atomic::Ordering::Relaxed) {
                         break;
                     }