From 299d37d894c1959b278b15538e3f57aa0a4ba187 Mon Sep 17 00:00:00 2001 From: Robin Mueller Date: Mon, 4 Dec 2023 15:54:35 +0100 Subject: [PATCH] introduce new TLV abstractions --- CHANGELOG.md | 1 + src/cfdp/mod.rs | 6 +- src/cfdp/pdu/eof.rs | 4 +- src/cfdp/pdu/finished.rs | 4 +- src/cfdp/pdu/metadata.rs | 2 +- src/cfdp/tlv/mod.rs | 230 ++++++++++++++++++++++++------------ src/cfdp/tlv/msg_to_user.rs | 24 +++- 7 files changed, 186 insertions(+), 85 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 88862a7..8bfbdc7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - Add `WritablePduPacket` trait which is a common trait of all CFDP PDU implementations. - Add `CfdpPdu` trait which exposes fields and attributes common to all CFDP PDUs. +- Add `GenericTlv` and `WritableTlv` trait as abstractions for the various TLV types. ## Fixed diff --git a/src/cfdp/mod.rs b/src/cfdp/mod.rs index 4e65ec9..0cd7897 100644 --- a/src/cfdp/mod.rs +++ b/src/cfdp/mod.rs @@ -205,14 +205,14 @@ impl Display for TlvLvError { TlvLvError::ByteConversion(e) => { write!(f, "tlv or lv byte conversion: {}", e) } - TlvLvError::InvalidTlvTypeField((found, expected)) => { + TlvLvError::InvalidTlvTypeField { found, expected } => { write!( f, - "invalid TLV type field, found {found}, possibly expected {expected:?}" + "invalid TLV type field, found {found}, expected {expected:?}" ) } TlvLvError::InvalidValueLength(len) => { - write!(f, "invalid value length {len} detected") + write!(f, "invalid value length {len}") } TlvLvError::SecondNameMissing => { write!(f, "second name missing for filestore request or response") diff --git a/src/cfdp/pdu/eof.rs b/src/cfdp/pdu/eof.rs index 2ee5fa4..627c5a7 100644 --- a/src/cfdp/pdu/eof.rs +++ b/src/cfdp/pdu/eof.rs @@ -2,7 +2,7 @@ use crate::cfdp::pdu::{ add_pdu_crc, generic_length_checks_pdu_deserialization, read_fss_field, write_fss_field, FileDirectiveType, PduError, PduHeader, }; -use crate::cfdp::tlv::EntityIdTlv; +use crate::cfdp::tlv::{EntityIdTlv, WritableTlv}; use crate::cfdp::{ConditionCode, CrcFlag, Direction, LargeFileFlag}; use crate::ByteConversionError; #[cfg(feature = "serde")] @@ -147,7 +147,7 @@ impl WritablePduPacket for EofPdu { &mut buf[current_idx..], )?; if let Some(fault_location) = self.fault_location { - current_idx += fault_location.write_to_be_bytes(buf)?; + current_idx += fault_location.write_to_bytes(buf)?; } if self.crc_flag() == CrcFlag::WithCrc { current_idx = add_pdu_crc(buf, current_idx); diff --git a/src/cfdp/pdu/finished.rs b/src/cfdp/pdu/finished.rs index 6a5b912..b115f8e 100644 --- a/src/cfdp/pdu/finished.rs +++ b/src/cfdp/pdu/finished.rs @@ -1,7 +1,7 @@ use crate::cfdp::pdu::{ add_pdu_crc, generic_length_checks_pdu_deserialization, FileDirectiveType, PduError, PduHeader, }; -use crate::cfdp::tlv::{EntityIdTlv, Tlv, TlvType, TlvTypeField}; +use crate::cfdp::tlv::{EntityIdTlv, GenericTlv, Tlv, TlvType, TlvTypeField, WritableTlv}; use crate::cfdp::{ConditionCode, CrcFlag, Direction, PduType, TlvLvError}; use crate::ByteConversionError; use num_enum::{IntoPrimitive, TryFromPrimitive}; @@ -255,7 +255,7 @@ impl WritablePduPacket for FinishedPdu<'_> { current_idx += fs_responses.len(); } if let Some(fault_location) = self.fault_location { - current_idx += fault_location.write_to_be_bytes(&mut buf[current_idx..])?; + current_idx += fault_location.write_to_bytes(&mut buf[current_idx..])?; } if self.crc_flag() == CrcFlag::WithCrc { current_idx = add_pdu_crc(buf, current_idx); diff --git a/src/cfdp/pdu/metadata.rs b/src/cfdp/pdu/metadata.rs index 5ed70d0..90303d3 100644 --- a/src/cfdp/pdu/metadata.rs +++ b/src/cfdp/pdu/metadata.rs @@ -3,7 +3,7 @@ use crate::cfdp::pdu::{ add_pdu_crc, generic_length_checks_pdu_deserialization, read_fss_field, write_fss_field, FileDirectiveType, PduError, PduHeader, }; -use crate::cfdp::tlv::Tlv; +use crate::cfdp::tlv::{Tlv, WritableTlv}; use crate::cfdp::{ChecksumType, CrcFlag, Direction, LargeFileFlag, PduType}; use crate::ByteConversionError; #[cfg(feature = "alloc")] diff --git a/src/cfdp/tlv/mod.rs b/src/cfdp/tlv/mod.rs index d9013ee..ae83094 100644 --- a/src/cfdp/tlv/mod.rs +++ b/src/cfdp/tlv/mod.rs @@ -5,6 +5,10 @@ use crate::cfdp::lv::{ use crate::cfdp::TlvLvError; use crate::util::{UnsignedByteField, UnsignedByteFieldError, UnsignedEnum}; use crate::ByteConversionError; +#[cfg(feature = "alloc")] +use alloc::vec; +#[cfg(feature = "alloc")] +use alloc::vec::Vec; use num_enum::{IntoPrimitive, TryFromPrimitive}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -13,6 +17,38 @@ pub mod msg_to_user; pub const MIN_TLV_LEN: usize = 2; +pub trait GenericTlv { + fn tlv_type_field(&self) -> TlvTypeField; + + /// Checks whether the type field contains one of the standard types specified in the CFDP + /// standard and is part of the [TlvType] enum. + fn is_standard_tlv(&self) -> bool { + if let TlvTypeField::Standard(_) = self.tlv_type_field() { + return true; + } + false + } + + /// Returns the standard TLV type if the TLV field is not a custom field + fn tlv_type(&self) -> Option { + if let TlvTypeField::Standard(tlv_type) = self.tlv_type_field() { + Some(tlv_type) + } else { + None + } + } +} + +pub trait WritableTlv { + fn write_to_bytes(&self, buf: &mut [u8]) -> Result; + fn len_written(&self) -> usize; + fn to_vec(&self) -> Vec { + let mut buf = vec![0; self.len_written()]; + self.write_to_bytes(&mut buf).unwrap(); + buf + } +} + #[derive(Debug, Copy, Clone, PartialEq, Eq, TryFromPrimitive, IntoPrimitive)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[repr(u8)] @@ -103,53 +139,24 @@ impl<'data> Tlv<'data> { } } - /// Checks whether the type field contains one of the standard types specified in the CFDP - /// standard and is part of the [TlvType] enum. - pub fn is_standard_tlv(&self) -> bool { - if let TlvTypeField::Standard(_) = self.tlv_type_field { - return true; - } - false - } - - /// Returns the standard TLV type if the TLV field is not a custom field - pub fn tlv_type(&self) -> Option { - if let TlvTypeField::Standard(tlv_type) = self.tlv_type_field { - Some(tlv_type) - } else { - None - } - } - - pub fn tlv_type_field(&self) -> TlvTypeField { - self.tlv_type_field - } - - pub fn write_to_bytes(&self, buf: &mut [u8]) -> Result { - generic_len_check_data_serialization(buf, self.value().len(), MIN_TLV_LEN)?; - buf[0] = self.tlv_type_field.into(); - self.lv.write_to_be_bytes_no_len_check(&mut buf[1..]); - Ok(self.len_full()) - } - pub fn value(&self) -> &[u8] { self.lv.value() } + /// Checks whether the value field is empty. + pub fn is_empty(&self) -> bool { + self.value().is_empty() + } + /// Helper method to retrieve the length of the value. Simply calls the [slice::len] method of /// [Self::value] pub fn len_value(&self) -> usize { - self.lv.len_value() + self.value().len() } /// Returns the full raw length, including the length byte. pub fn len_full(&self) -> usize { - self.lv.len_full() + 1 - } - - /// Checks whether the value field is empty. - pub fn is_empty(&self) -> bool { - self.lv.is_empty() + self.len_value() + 2 } /// Creates a TLV give a raw bytestream. Please note that is is not necessary to pass the @@ -175,6 +182,24 @@ impl<'data> Tlv<'data> { } } +impl WritableTlv for Tlv<'_> { + fn write_to_bytes(&self, buf: &mut [u8]) -> Result { + generic_len_check_data_serialization(buf, self.value().len(), MIN_TLV_LEN)?; + buf[0] = self.tlv_type_field.into(); + self.lv.write_to_be_bytes_no_len_check(&mut buf[1..]); + Ok(self.len_full()) + } + fn len_written(&self) -> usize { + self.len_full() + } +} + +impl GenericTlv for Tlv<'_> { + fn tlv_type_field(&self) -> TlvTypeField { + self.tlv_type_field + } +} + pub(crate) fn verify_tlv_type(raw_type: u8, expected_tlv_type: TlvType) -> Result<(), TlvLvError> { let tlv_type = TlvType::try_from(raw_type).map_err(|_| TlvLvError::InvalidTlvTypeField { found: raw_type, @@ -222,13 +247,6 @@ impl EntityIdTlv { 2 + self.entity_id.size() } - pub fn write_to_be_bytes(&self, buf: &mut [u8]) -> Result { - Self::len_check(buf)?; - buf[0] = TlvType::EntityId as u8; - buf[1] = self.entity_id.size() as u8; - Ok(2 + self.entity_id.write_to_be_bytes(&mut buf[2..])?) - } - pub fn from_bytes(buf: &[u8]) -> Result { Self::len_check(buf)?; verify_tlv_type(buf[0], TlvType::EntityId)?; @@ -254,6 +272,25 @@ impl EntityIdTlv { } } +impl WritableTlv for EntityIdTlv { + fn write_to_bytes(&self, buf: &mut [u8]) -> Result { + Self::len_check(buf)?; + buf[0] = TlvType::EntityId as u8; + buf[1] = self.entity_id.size() as u8; + Ok(2 + self.entity_id.write_to_be_bytes(&mut buf[2..])?) + } + + fn len_written(&self) -> usize { + self.len_full() + } +} + +impl GenericTlv for EntityIdTlv { + fn tlv_type_field(&self) -> TlvTypeField { + TlvTypeField::Standard(TlvType::EntityId) + } +} + impl<'data> TryFrom> for EntityIdTlv { type Error = TlvLvError; @@ -426,31 +463,6 @@ impl<'first_name, 'second_name> FilestoreRequestTlv<'first_name, 'second_name> { 2 + self.len_value() } - pub fn write_to_bytes(&self, buf: &mut [u8]) -> Result { - if buf.len() < self.len_full() { - return Err(ByteConversionError::ToSliceTooSmall { - found: buf.len(), - expected: self.len_full(), - }); - } - buf[0] = TlvType::FilestoreRequest as u8; - buf[1] = self.len_value() as u8; - buf[2] = (self.action_code as u8) << 4; - let mut current_idx = 3; - // Length checks were already performed. - self.first_name.write_to_be_bytes_no_len_check( - &mut buf[current_idx..current_idx + self.first_name.len_full()], - ); - current_idx += self.first_name.len_full(); - if let Some(second_name) = self.second_name { - second_name.write_to_be_bytes_no_len_check( - &mut buf[current_idx..current_idx + second_name.len_full()], - ); - current_idx += second_name.len_full(); - } - Ok(current_idx) - } - pub fn from_bytes<'longest: 'first_name + 'second_name>( buf: &'longest [u8], ) -> Result { @@ -485,16 +497,51 @@ impl<'first_name, 'second_name> FilestoreRequestTlv<'first_name, 'second_name> { } } +impl WritableTlv for FilestoreRequestTlv<'_, '_> { + fn write_to_bytes(&self, buf: &mut [u8]) -> Result { + if buf.len() < self.len_full() { + return Err(ByteConversionError::ToSliceTooSmall { + found: buf.len(), + expected: self.len_full(), + }); + } + buf[0] = TlvType::FilestoreRequest as u8; + buf[1] = self.len_value() as u8; + buf[2] = (self.action_code as u8) << 4; + let mut current_idx = 3; + // Length checks were already performed. + self.first_name.write_to_be_bytes_no_len_check( + &mut buf[current_idx..current_idx + self.first_name.len_full()], + ); + current_idx += self.first_name.len_full(); + if let Some(second_name) = self.second_name { + second_name.write_to_be_bytes_no_len_check( + &mut buf[current_idx..current_idx + second_name.len_full()], + ); + current_idx += second_name.len_full(); + } + Ok(current_idx) + } + + fn len_written(&self) -> usize { + self.len_full() + } +} + +impl GenericTlv for FilestoreRequestTlv<'_, '_> { + fn tlv_type_field(&self) -> TlvTypeField { + TlvTypeField::Standard(TlvType::FilestoreRequest) + } +} + #[cfg(test)] mod tests { - - use alloc::string::ToString; - use super::*; use crate::cfdp::lv::Lv; use crate::cfdp::tlv::{FilestoreActionCode, FilestoreRequestTlv, Tlv, TlvType, TlvTypeField}; use crate::cfdp::TlvLvError; use crate::util::{UbfU16, UbfU8, UnsignedEnum}; + use alloc::string::ToString; const TLV_TEST_STR_0: &str = "hello.txt"; const TLV_TEST_STR_1: &str = "hello2.txt"; @@ -559,8 +606,9 @@ mod tests { let entity_id = UbfU16::new(0x0102); let entity_id_tlv = EntityIdTlv::new(entity_id.into()); let mut buf: [u8; 16] = [0; 16]; - let written_len = entity_id_tlv.write_to_be_bytes(&mut buf).unwrap(); + let written_len = entity_id_tlv.write_to_bytes(&mut buf).unwrap(); assert_eq!(written_len, entity_id_tlv.len_full()); + assert!(entity_id_tlv.is_standard_tlv()); assert_eq!(buf[0], TlvType::EntityId as u8); assert_eq!(buf[1], 2); assert_eq!(u16::from_be_bytes(buf[2..4].try_into().unwrap()), 0x0102); @@ -571,7 +619,7 @@ mod tests { let entity_id = UbfU16::new(0x0102); let entity_id_tlv = EntityIdTlv::new(entity_id.into()); let mut buf: [u8; 16] = [0; 16]; - let _ = entity_id_tlv.write_to_be_bytes(&mut buf).unwrap(); + let _ = entity_id_tlv.write_to_bytes(&mut buf).unwrap(); let entity_tlv_from_raw = EntityIdTlv::from_bytes(&buf).expect("creating entity ID TLV failed"); assert_eq!(entity_tlv_from_raw, entity_id_tlv); @@ -848,4 +896,40 @@ mod tests { assert_eq!(tlv.len_value(), 2); assert_eq!(tlv.value(), &[0x01, 0x02]); } + + #[test] + fn test_invalid_tlv_conversion() { + let msg_to_user_tlv = Tlv::new_empty(TlvType::MsgToUser); + let error = EntityIdTlv::try_from(msg_to_user_tlv); + assert!(error.is_err()); + let error = error.unwrap_err(); + if let TlvLvError::InvalidTlvTypeField { found, expected } = error { + assert_eq!(found, TlvType::MsgToUser as u8); + assert_eq!(expected, Some(TlvType::EntityId as u8)); + assert_eq!( + error.to_string(), + "invalid TLV type field, found 2, expected Some(6)" + ); + } else { + panic!("unexpected error"); + } + } + + #[test] + fn test_entity_id_invalid_value_len() { + let entity_id = UbfU16::new(0x0102); + let entity_id_tlv = EntityIdTlv::new(entity_id.into()); + let mut buf: [u8; 32] = [0; 32]; + entity_id_tlv.write_to_bytes(&mut buf).unwrap(); + buf[1] = 12; + let error = EntityIdTlv::from_bytes(&buf); + assert!(error.is_err()); + let error = error.unwrap_err(); + if let TlvLvError::InvalidValueLength(len) = error { + assert_eq!(len, 12); + assert_eq!(error.to_string(), "invalid value length 12"); + } else { + panic!("unexpected error"); + } + } } diff --git a/src/cfdp/tlv/msg_to_user.rs b/src/cfdp/tlv/msg_to_user.rs index 7f30565..9ccd0be 100644 --- a/src/cfdp/tlv/msg_to_user.rs +++ b/src/cfdp/tlv/msg_to_user.rs @@ -1,5 +1,5 @@ //! Abstractions for the Message to User CFDP TLV subtype. -use super::{Tlv, TlvLvError, TlvType, TlvTypeField}; +use super::{GenericTlv, Tlv, TlvLvError, TlvType, TlvTypeField, WritableTlv}; use crate::ByteConversionError; use delegate::delegate; @@ -18,8 +18,6 @@ impl<'data> MsgToUserTlv<'data> { delegate! { to self.tlv { - pub fn tlv_type_field(&self) -> TlvTypeField; - pub fn write_to_bytes(&self, buf: &mut [u8]) -> Result; pub fn value(&self) -> &[u8]; /// Helper method to retrieve the length of the value. Simply calls the [slice::len] method of /// [Self::value] @@ -69,7 +67,7 @@ impl<'data> MsgToUserTlv<'data> { } } TlvTypeField::Custom(raw) => { - return Err(TlvLvError::InvalidTlvTypeField{ + return Err(TlvLvError::InvalidTlvTypeField { found: raw, expected: Some(TlvType::MsgToUser as u8), }); @@ -79,6 +77,24 @@ impl<'data> MsgToUserTlv<'data> { } } +impl WritableTlv for MsgToUserTlv<'_> { + fn len_written(&self) -> usize { + self.len_full() + } + + delegate!( + to self.tlv { + fn write_to_bytes(&self, buf: &mut [u8]) -> Result; + } + ); +} + +impl GenericTlv for MsgToUserTlv<'_> { + fn tlv_type_field(&self) -> TlvTypeField { + TlvTypeField::Standard(TlvType::MsgToUser) + } +} + #[cfg(test)] mod tests { use super::*;