From 0a6cdd894bb0f224e672cfd37a789d4bc02390e1 Mon Sep 17 00:00:00 2001 From: Robin Mueller Date: Wed, 13 May 2026 10:54:20 +0200 Subject: [PATCH] added UDP transport --- src/transport/udp.rs | 121 ++++++++++++++++++++++++++++++++----------- 1 file changed, 91 insertions(+), 30 deletions(-) diff --git a/src/transport/udp.rs b/src/transport/udp.rs index 88a2406..21cdcd3 100644 --- a/src/transport/udp.rs +++ b/src/transport/udp.rs @@ -1,6 +1,7 @@ //! # Packet transport via UDP use std::io::ErrorKind; +use std::net::ToSocketAddrs; use crate::transport::PacketTransport; @@ -10,6 +11,7 @@ use crate::transport::PacketTransport; pub struct PacketTransportUdp { /// Underlying UDP socket. pub socket: std::net::UdpSocket, + target: std::net::SocketAddr, reception_buffer: [u8; 4096], } @@ -18,22 +20,45 @@ impl PacketTransportUdp { /// /// The `socket` parameter is the underlying UDP stream which should already be connected. /// It will be set non-blocking by the construtor. - pub fn new(socket: std::net::UdpSocket) -> Result { + pub fn new( + socket: std::net::UdpSocket, + target: std::net::SocketAddr, + ) -> Result { socket.set_nonblocking(true)?; Ok(Self { socket, + target, reception_buffer: [0u8; 4096], }) } -} -impl PacketTransport for PacketTransportUdp { - fn send(&mut self, packet: &[u8]) -> Result<(), super::SendError> { - self.socket.send(packet).map_err(super::SendError::Io)?; + /// Update default target. + pub fn set_default_target(&mut self, target: std::net::SocketAddr) { + self.target = target; + } + + /// Send a packet to the target address specified in the constructor. + pub fn send(&mut self, packet: &[u8]) -> Result<(), super::SendError> { + self.socket + .send_to(packet, self.target) + .map_err(super::SendError::Io)?; Ok(()) } - fn receive(&mut self, mut f: F) -> Result { + /// Send packet to a specific address. + pub fn send_to( + &mut self, + packet: &[u8], + addr: A, + ) -> Result<(), super::SendError> { + self.socket + .send_to(packet, addr) + .map_err(super::SendError::Io)?; + Ok(()) + } + + /// Receive packets and call the provided callback for each received packet. + pub fn receive(&mut self, mut f: F) -> Result { let mut packets_received = 0; loop { match self.socket.recv_from(&mut self.reception_buffer) { @@ -51,6 +76,16 @@ impl PacketTransport for PacketTransportUdp { } Ok(packets_received) } +} + +impl PacketTransport for PacketTransportUdp { + fn send(&mut self, packet: &[u8]) -> Result<(), super::SendError> { + self.send(packet) + } + + fn receive(&mut self, f: F) -> Result { + self.receive(f) + } fn close(&mut self) {} } @@ -67,7 +102,7 @@ mod tests { let receiver_addr = receiver.local_addr().unwrap(); sender.connect(receiver_addr).unwrap(); - let mut transport = PacketTransportUdp::new(sender).unwrap(); + let mut transport = PacketTransportUdp::new(sender, receiver_addr).unwrap(); let payload = [1u8, 2, 3, 4]; transport.send(&payload).unwrap(); @@ -82,7 +117,10 @@ mod tests { #[test] fn receive_is_non_blocking_when_no_data_is_available() { let receiver = std::net::UdpSocket::bind("127.0.0.1:0").unwrap(); - let mut transport = PacketTransportUdp::new(receiver).unwrap(); + let sender = std::net::UdpSocket::bind("127.0.0.1:0").unwrap(); + + let receiver_addr = receiver.local_addr().unwrap(); + let mut transport = PacketTransportUdp::new(sender, receiver_addr).unwrap(); let start = std::time::Instant::now(); let mut callback_called = false; @@ -106,13 +144,16 @@ mod tests { #[test] fn basic_receive_test_single() { - let receiver = std::net::UdpSocket::bind("127.0.0.1:0").unwrap(); - let sender = std::net::UdpSocket::bind("127.0.0.1:0").unwrap(); + let plain_sender = std::net::UdpSocket::bind("127.0.0.1:0").unwrap(); + let transport_socket = std::net::UdpSocket::bind("127.0.0.1:0").unwrap(); - let receiver_addr = receiver.local_addr().unwrap(); - sender.send_to(&[1u8, 2, 3, 4], receiver_addr).unwrap(); + let transport_addr = transport_socket.local_addr().unwrap(); - let mut transport = PacketTransportUdp::new(receiver).unwrap(); + let mut transport = PacketTransportUdp::new(transport_socket, transport_addr).unwrap(); + + plain_sender + .send_to(&[1u8, 2, 3, 4], transport_addr) + .unwrap(); let mut received_packets: Vec> = Vec::new(); let packets_received = transport @@ -126,15 +167,21 @@ mod tests { #[test] fn multi_packet_receive_test() { - let receiver = std::net::UdpSocket::bind("127.0.0.1:0").unwrap(); - let sender = std::net::UdpSocket::bind("127.0.0.1:0").unwrap(); + let plain_sender = std::net::UdpSocket::bind("127.0.0.1:0").unwrap(); + let transport_socket = std::net::UdpSocket::bind("127.0.0.1:0").unwrap(); - let receiver_addr = receiver.local_addr().unwrap(); - sender.send_to(&[1u8, 2, 3, 4], receiver_addr).unwrap(); - sender.send_to(&[5u8, 6, 7, 8], receiver_addr).unwrap(); - sender.send_to(&[9u8, 10, 11, 12], receiver_addr).unwrap(); + let receiver_addr = transport_socket.local_addr().unwrap(); + let mut transport = PacketTransportUdp::new(transport_socket, receiver_addr).unwrap(); - let mut transport = PacketTransportUdp::new(receiver).unwrap(); + plain_sender + .send_to(&[1u8, 2, 3, 4], receiver_addr) + .unwrap(); + plain_sender + .send_to(&[5u8, 6, 7, 8], receiver_addr) + .unwrap(); + plain_sender + .send_to(&[9u8, 10, 11, 12], receiver_addr) + .unwrap(); let mut received_packets: Vec> = Vec::new(); let packets_received = transport @@ -150,26 +197,40 @@ mod tests { #[test] fn send_and_receive_test() { - let receiver = std::net::UdpSocket::bind("127.0.0.1:0").unwrap(); - let sender = std::net::UdpSocket::bind("127.0.0.1:0").unwrap(); + // Bind both sockets + let plain_receiver = std::net::UdpSocket::bind("127.0.0.1:0").unwrap(); + let sender_socket = std::net::UdpSocket::bind("127.0.0.1:0").unwrap(); - let receiver_addr = receiver.local_addr().unwrap(); - sender.connect(receiver_addr).unwrap(); + let receiver_addr = plain_receiver.local_addr().unwrap(); - let mut send_transport = PacketTransportUdp::new(sender).unwrap(); + // Build the transport around the sender socket + let mut send_transport = PacketTransportUdp::new(sender_socket, receiver_addr).unwrap(); + + // Send a packet let payload = [1u8, 2, 3, 4]; - send_transport.send(&payload).unwrap(); - let mut receive_transport = PacketTransportUdp::new(receiver).unwrap(); - let mut received_packets: Vec> = Vec::new(); + // Plain receiver reads what was sent + let mut buf = [0u8; 4096]; + let (n, src) = plain_receiver.recv_from(&mut buf).unwrap(); + assert_eq!(n, payload.len()); + assert_eq!(&buf[..n], &payload); - let packets_received = receive_transport + // Plain receiver sends back some test data + let reply = [5u8, 6, 7, 8]; + plain_receiver.send_to(&reply, src).unwrap(); + + // Read the reply via the transport (non-blocking socket, reply should already be in flight) + let mut received_packets: Vec> = Vec::new(); + // Small yield to ensure the reply has arrived on the loopback interface + std::thread::sleep(std::time::Duration::from_millis(10)); + + let packets_received = send_transport .receive(|packet| received_packets.push(packet.to_vec())) .unwrap(); assert_eq!(packets_received, 1); assert_eq!(received_packets.len(), 1); - assert_eq!(received_packets[0], payload.to_vec()); + assert_eq!(received_packets[0], reply.to_vec()); } }