diff --git a/src/transport/mod.rs b/src/transport/mod.rs index 4b388e2..c4db470 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -5,6 +5,7 @@ //! something like CCSDS space packets over different transport mechanisms. pub mod serial; pub mod tcp; +pub mod udp; /// Generic send error. #[derive(Debug, thiserror::Error)] diff --git a/src/transport/tcp.rs b/src/transport/tcp.rs index cb8b021..7af22e2 100644 --- a/src/transport/tcp.rs +++ b/src/transport/tcp.rs @@ -7,6 +7,8 @@ use std::{ use crate::transport::PacketTransport; /// Packet transport via TCP with COBS encoding. +/// +/// Currently only allows a maxium packet size of 4096. pub struct PacketTransportTcpWithCobs { /// Underlying TCP stream. pub tcp_stream: std::net::TcpStream, @@ -14,7 +16,7 @@ pub struct PacketTransportTcpWithCobs { pub log_decoding_errors: bool, /// Decoder object. decoder: cobs::CobsDecoderOwned, - reception_buffer: [u8; 1024], + reception_buffer: [u8; 4096], } impl PacketTransportTcpWithCobs { @@ -30,7 +32,7 @@ impl PacketTransportTcpWithCobs { Ok(Self { tcp_stream, decoder, - reception_buffer: [0u8; 1024], + reception_buffer: [0u8; 4096], log_decoding_errors: true, }) } diff --git a/src/transport/udp.rs b/src/transport/udp.rs new file mode 100644 index 0000000..88a2406 --- /dev/null +++ b/src/transport/udp.rs @@ -0,0 +1,175 @@ +//! # Packet transport via UDP + +use std::io::ErrorKind; + +use crate::transport::PacketTransport; + +/// Generic packet transport via UDP. +/// +/// Currently only allows a maxium packet size of 4096. +pub struct PacketTransportUdp { + /// Underlying UDP socket. + pub socket: std::net::UdpSocket, + reception_buffer: [u8; 4096], +} + +impl PacketTransportUdp { + /// Generic constructor. + /// + /// The `socket` parameter is the underlying UDP stream which should already be connected. + /// It will be set non-blocking by the construtor. + pub fn new(socket: std::net::UdpSocket) -> Result { + socket.set_nonblocking(true)?; + Ok(Self { + socket, + reception_buffer: [0u8; 4096], + }) + } +} + +impl PacketTransport for PacketTransportUdp { + fn send(&mut self, packet: &[u8]) -> Result<(), super::SendError> { + self.socket.send(packet).map_err(super::SendError::Io)?; + Ok(()) + } + + fn receive(&mut self, mut f: F) -> Result { + let mut packets_received = 0; + loop { + match self.socket.recv_from(&mut self.reception_buffer) { + Ok((bytes, _)) => { + packets_received += 1; + f(&self.reception_buffer[..bytes]); + } + Err(e) => { + if e.kind() == ErrorKind::WouldBlock || e.kind() == ErrorKind::TimedOut { + break; + } + log::error!("UDP reception error: {e}"); + } + } + } + Ok(packets_received) + } + + fn close(&mut self) {} +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn basic_send_test() { + let receiver = std::net::UdpSocket::bind("127.0.0.1:0").unwrap(); + let sender = std::net::UdpSocket::bind("127.0.0.1:0").unwrap(); + + let receiver_addr = receiver.local_addr().unwrap(); + sender.connect(receiver_addr).unwrap(); + + let mut transport = PacketTransportUdp::new(sender).unwrap(); + let payload = [1u8, 2, 3, 4]; + + transport.send(&payload).unwrap(); + + let mut buf = [0u8; 16]; + let (len, from) = receiver.recv_from(&mut buf).unwrap(); + + assert_eq!(&buf[..len], &payload); + assert_eq!(from, transport.socket.local_addr().unwrap()); + } + + #[test] + fn receive_is_non_blocking_when_no_data_is_available() { + let receiver = std::net::UdpSocket::bind("127.0.0.1:0").unwrap(); + let mut transport = PacketTransportUdp::new(receiver).unwrap(); + + let start = std::time::Instant::now(); + let mut callback_called = false; + + let packets_received = transport + .receive(|_| { + callback_called = true; + }) + .unwrap(); + + let elapsed = start.elapsed(); + + assert_eq!(packets_received, 0); + assert!(!callback_called); + assert!( + elapsed < std::time::Duration::from_millis(50), + "receive() took too long for a non-blocking socket: {:?}", + elapsed + ); + } + + #[test] + fn basic_receive_test_single() { + let receiver = std::net::UdpSocket::bind("127.0.0.1:0").unwrap(); + let sender = std::net::UdpSocket::bind("127.0.0.1:0").unwrap(); + + let receiver_addr = receiver.local_addr().unwrap(); + sender.send_to(&[1u8, 2, 3, 4], receiver_addr).unwrap(); + + let mut transport = PacketTransportUdp::new(receiver).unwrap(); + + let mut received_packets: Vec> = Vec::new(); + let packets_received = transport + .receive(|packet| received_packets.push(packet.to_vec())) + .unwrap(); + + assert_eq!(packets_received, 1); + assert_eq!(received_packets.len(), 1); + assert_eq!(received_packets[0], vec![1u8, 2, 3, 4]); + } + + #[test] + fn multi_packet_receive_test() { + let receiver = std::net::UdpSocket::bind("127.0.0.1:0").unwrap(); + let sender = std::net::UdpSocket::bind("127.0.0.1:0").unwrap(); + + let receiver_addr = receiver.local_addr().unwrap(); + sender.send_to(&[1u8, 2, 3, 4], receiver_addr).unwrap(); + sender.send_to(&[5u8, 6, 7, 8], receiver_addr).unwrap(); + sender.send_to(&[9u8, 10, 11, 12], receiver_addr).unwrap(); + + let mut transport = PacketTransportUdp::new(receiver).unwrap(); + + let mut received_packets: Vec> = Vec::new(); + let packets_received = transport + .receive(|packet| received_packets.push(packet.to_vec())) + .unwrap(); + + assert_eq!(packets_received, 3); + assert_eq!(received_packets.len(), 3); + assert_eq!(received_packets[0], vec![1u8, 2, 3, 4]); + assert_eq!(received_packets[1], vec![5u8, 6, 7, 8]); + assert_eq!(received_packets[2], vec![9u8, 10, 11, 12]); + } + + #[test] + fn send_and_receive_test() { + let receiver = std::net::UdpSocket::bind("127.0.0.1:0").unwrap(); + let sender = std::net::UdpSocket::bind("127.0.0.1:0").unwrap(); + + let receiver_addr = receiver.local_addr().unwrap(); + sender.connect(receiver_addr).unwrap(); + + let mut send_transport = PacketTransportUdp::new(sender).unwrap(); + let payload = [1u8, 2, 3, 4]; + + send_transport.send(&payload).unwrap(); + + let mut receive_transport = PacketTransportUdp::new(receiver).unwrap(); + let mut received_packets: Vec> = Vec::new(); + + let packets_received = receive_transport + .receive(|packet| received_packets.push(packet.to_vec())) + .unwrap(); + + assert_eq!(packets_received, 1); + assert_eq!(received_packets.len(), 1); + assert_eq!(received_packets[0], payload.to_vec()); + } +}