diff --git a/src/transport/mod.rs b/src/transport/mod.rs index 423da47..4b388e2 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -44,4 +44,7 @@ pub trait PacketTransport { /// For each received packet, the closure will be called with the packet as an argument. /// The function will return the number of received packets. fn receive(&mut self, f: F) -> Result; + + /// Close the connection, used for graceful shutdowns. + fn close(&mut self); } diff --git a/src/transport/serial.rs b/src/transport/serial.rs index 3050a98..f688ff8 100644 --- a/src/transport/serial.rs +++ b/src/transport/serial.rs @@ -2,6 +2,7 @@ use std::time::Duration; use cobs::CobsDecoderOwned; +use serialport::ReadMode; use crate::transport::PacketTransport; @@ -20,15 +21,22 @@ impl PacketTransportSerialCobs { /// the passed parameters. /// /// The `max_rx_packet_size` parameter defines the expected maximum size of a received packet. + /// On non-linux platforms, the serial timeout parameter has to be specified as well. pub fn new_from_params( port_name: &str, baud_rate: u32, max_rx_packet_size: usize, - serial_timeout: Duration, ) -> Result { + #[cfg(target_os = "linux")] + let mut serial = serialport::new(port_name, baud_rate).open_native()?; + #[cfg(target_os = "linux")] + serial.set_read_mode(ReadMode::Immediate)?; + #[cfg(not(target_os = "linux"))] let mut serial = serialport::new(port_name, baud_rate).open()?; - serial.set_timeout(serial_timeout)?; - Ok(Self::new(serial, CobsDecoderOwned::new(max_rx_packet_size))) + Ok(Self::new( + Box::new(serial), + CobsDecoderOwned::new(max_rx_packet_size), + )) } /// Generic constructor. @@ -41,6 +49,11 @@ impl PacketTransportSerialCobs { } } + /// Set the serial port timeout. + pub fn set_serial_timeout(&mut self, timeout: Duration) -> Result<(), serialport::Error> { + self.serial.set_timeout(timeout) + } + /// Send a packet. /// /// It encodes the packet using COBS encoding before sending it over the serial port. @@ -103,6 +116,8 @@ impl PacketTransport for PacketTransportSerialCobs { fn receive(&mut self, f: F) -> Result { self.receive(f) } + + fn close(&mut self) {} } #[cfg(test)] diff --git a/src/transport/tcp.rs b/src/transport/tcp.rs index 21f3766..0164e61 100644 --- a/src/transport/tcp.rs +++ b/src/transport/tcp.rs @@ -21,17 +21,18 @@ impl PacketTransportTcpWithCobs { /// Generic constructor. /// /// The `tcp_stream` parameter is the underlying TCP stream which should already be connected. - pub fn new(tcp_stream: std::net::TcpStream, decoder: cobs::CobsDecoderOwned) -> Self { - tcp_stream.set_nonblocking(true).unwrap(); - tcp_stream - .set_read_timeout(Some(Duration::from_millis(100))) - .unwrap(); - Self { + pub fn new( + tcp_stream: std::net::TcpStream, + decoder: cobs::CobsDecoderOwned, + ) -> std::io::Result { + tcp_stream.set_nonblocking(true)?; + tcp_stream.set_read_timeout(Some(Duration::from_millis(100)))?; + Ok(Self { tcp_stream, decoder, reception_buffer: [0u8; 1024], log_decoding_errors: true, - } + }) } /// Send a packet. @@ -80,6 +81,10 @@ impl PacketTransportTcpWithCobs { log::warn!("COBS decoding error: {:?}", error); } } + + pub fn close(&mut self) -> std::io::Result<()> { + self.tcp_stream.shutdown(std::net::Shutdown::Both) + } } impl PacketTransport for PacketTransportTcpWithCobs { @@ -90,6 +95,10 @@ impl PacketTransport for PacketTransportTcpWithCobs { fn receive(&mut self, f: F) -> Result { self.receive(f) } + + fn close(&mut self) { + let _ = self.close(); + } } #[cfg(test)] @@ -105,7 +114,7 @@ mod tests { let addr = tcp_server.local_addr().unwrap(); let tcp_client = std::net::TcpStream::connect(addr).unwrap(); let mut transport = - PacketTransportTcpWithCobs::new(tcp_client, cobs::CobsDecoderOwned::new(1024)); + PacketTransportTcpWithCobs::new(tcp_client, cobs::CobsDecoderOwned::new(1024)).unwrap(); let packet = [1, 2, 3, 4]; transport.send(&packet).unwrap(); tcp_server @@ -128,7 +137,7 @@ mod tests { let addr = tcp_server.local_addr().unwrap(); let tcp_client = std::net::TcpStream::connect(addr).unwrap(); let mut transport = - PacketTransportTcpWithCobs::new(tcp_client, cobs::CobsDecoderOwned::new(1024)); + PacketTransportTcpWithCobs::new(tcp_client, cobs::CobsDecoderOwned::new(1024)).unwrap(); let rx_data = [1, 2, 3, 4]; let encoded_data = cobs::encode_vec_including_sentinels(&rx_data); tcp_server