diff --git a/macros/macros_impl/src/config.rs b/macros/macros_impl/src/config.rs index 25a9d7d2..4d3274cc 100644 --- a/macros/macros_impl/src/config.rs +++ b/macros/macros_impl/src/config.rs @@ -913,7 +913,7 @@ impl<'a> FieldConfig<'a> { )?; ) } else if self.extension_addition_group { - quote!(encoder.encode_extension_addition_group(#this #field.as_ref(), #identifier)?;) + quote!(encoder.encode_extension_addition_group(#tag, #this #field.as_ref(), #identifier)?;) } else { match (self.constraints.has_constraints(), self.default.is_some()) { (true, true) => { @@ -959,7 +959,7 @@ impl<'a> FieldConfig<'a> { )?; ) } else if self.extension_addition_group { - quote!(encoder.encode_extension_addition_group(#this #field.as_ref(), #identifier)?;) + quote!(encoder.encode_extension_addition_group(#tag, #this #field.as_ref(), #identifier)?;) } else { match (self.constraints.has_constraints(), self.default.is_some()) { (true, true) => { @@ -1063,7 +1063,7 @@ impl<'a> FieldConfig<'a> { }; let decode = if self.extension_addition_group { - quote!(decoder.decode_extension_addition_group() #or_else) + quote!(decoder.decode_extension_addition_group(#tag) #or_else) } else { match ( (self.tag.is_some() || self.container_config.automatic_tags) diff --git a/macros/macros_impl/src/decode.rs b/macros/macros_impl/src/decode.rs index 2af52217..90b111c9 100644 --- a/macros/macros_impl/src/decode.rs +++ b/macros/macros_impl/src/decode.rs @@ -142,7 +142,7 @@ pub fn derive_struct_impl( let decode_impl = if config.extension_addition { quote!(#field_name(decoder.decode_extension_addition()?)) } else if config.extension_addition_group { - quote!(#field_name(decoder.decode_extension_addition_group()?)) + quote!(#field_name(decoder.decode_extension_addition_group(#tag)?)) } else { quote!(<_>::decode(decoder)?) }; diff --git a/src/ber.rs b/src/ber.rs index ac0c136a..d08d7eba 100644 --- a/src/ber.rs +++ b/src/ber.rs @@ -8,6 +8,22 @@ mod rules; pub use identifier::Identifier; pub(crate) use rules::EncodingRules; +#[derive(Clone, Copy, Debug)] +pub(crate) enum ExtensionGroupState { + None, + Pending(crate::types::Tag), + Active(crate::types::Tag), +} + +impl ExtensionGroupState { + pub(crate) fn base_tag(self) -> Option { + match self { + Self::Pending(tag) | Self::Active(tag) => Some(tag), + Self::None => None, + } + } +} + /// Attempts to decode `T` from `input` using BER. /// # Errors /// Returns error specific to BER decoder if decoding is not possible. diff --git a/src/ber/de.rs b/src/ber/de.rs index 1430b6a1..c6098456 100644 --- a/src/ber/de.rs +++ b/src/ber/de.rs @@ -3,7 +3,7 @@ mod config; pub(super) mod parser; -use super::identifier::Identifier; +use super::{identifier::Identifier, ExtensionGroupState}; use crate::{ types::{ self, @@ -29,6 +29,7 @@ pub struct Decoder<'input> { input: &'input [u8], config: DecoderOptions, initial_len: usize, + extension_group: ExtensionGroupState, } impl<'input> Decoder<'input> { @@ -49,9 +50,14 @@ impl<'input> Decoder<'input> { input, config, initial_len: input.len(), + extension_group: ExtensionGroupState::None, } } + fn translate_tag(&self, tag: Tag) -> Tag { + tag.with_context_offset(self.extension_group.base_tag()) + } + /// Return a number of the decoded bytes by this decoder #[must_use] pub fn decoded_len(&self) -> usize { @@ -85,6 +91,9 @@ impl<'input> Decoder<'input> { { return Ok(None); } + + let tag = self.translate_tag(tag); + if tag != Tag::EOC { let upcoming_tag = self.peek_tag()?; if tag != upcoming_tag { @@ -102,6 +111,7 @@ impl<'input> Decoder<'input> { } pub(crate) fn parse_value(&mut self, tag: Tag) -> Result<(Identifier, Option<&'input [u8]>)> { + let tag = self.translate_tag(tag); let (input, (identifier, contents)) = self::parser::parse_value(self.config, self.input, Some(tag))?; self.input = input; @@ -109,6 +119,7 @@ impl<'input> Decoder<'input> { } pub(crate) fn parse_primitive_value(&mut self, tag: Tag) -> Result<(Identifier, &'input [u8])> { + let tag = self.translate_tag(tag); let (input, (identifier, contents)) = self::parser::parse_value(self.config, self.input, Some(tag))?; self.input = input; @@ -762,6 +773,28 @@ impl<'input> crate::Decoder for Decoder<'input> { default_initializer_fn: Option, decode_fn: F, ) -> Result { + if tag == Tag::SEQUENCE && matches!(self.extension_group, ExtensionGroupState::Pending(_)) { + // Extension addition groups are encoded flattened: skip the SEQUENCE wrapper once. + if let ExtensionGroupState::Pending(tag) = self.extension_group { + self.extension_group = ExtensionGroupState::Active(tag); + } + return if D::FIELDS.is_empty() && D::EXTENDED_FIELDS.is_none() + || (D::FIELDS.len() == D::FIELDS.number_of_optional_and_default_fields() + && self.input.is_empty()) + { + if let Some(default_initializer_fn) = default_initializer_fn { + Ok((default_initializer_fn)()) + } else { + Err(DecodeError::from_kind( + DecodeErrorKind::UnexpectedEmptyInput, + self.codec(), + )) + } + } else { + (decode_fn)(self) + }; + } + self.parse_constructed_contents(tag, true, |decoder| { // If there are no fields, or the input is empty and we know that // all fields are optional or default fields, we call the default @@ -806,7 +839,7 @@ impl<'input> crate::Decoder for Decoder<'input> { D: Fn(&mut Self, usize, Tag) -> Result, F: FnOnce(Vec) -> Result, { - self.parse_constructed_contents(tag, true, |decoder| { + let collect_fields = |decoder: &mut Self| -> Result, Self::Error> { let mut fields = Vec::new(); loop { @@ -823,6 +856,20 @@ impl<'input> crate::Decoder for Decoder<'input> { } } + Ok(fields) + }; + + if tag == Tag::SET && matches!(self.extension_group, ExtensionGroupState::Pending(_)) { + // Extension addition groups are encoded flattened: skip the SET wrapper once. + if let ExtensionGroupState::Pending(tag) = self.extension_group { + self.extension_group = ExtensionGroupState::Active(tag); + } + let fields = collect_fields(self)?; + return (field_fn)(fields); + } + + self.parse_constructed_contents(tag, true, |decoder| { + let fields = collect_fields(decoder)?; (field_fn)(fields) }) } @@ -897,8 +944,26 @@ impl<'input> crate::Decoder for Decoder<'input> { D: Decode + crate::types::Constructed, >( &mut self, + tag: Tag, ) -> Result, Self::Error> { - >::decode(self) + if self.input.is_empty() { + return Ok(None); + } + + let (_, identifier) = parser::parse_identifier_octet(self.input).map_err(|e| match e { + ParseNumberError::Nom(e) => DecodeError::map_nom_err(e, self.codec()), + ParseNumberError::Overflow => DecodeError::integer_overflow(32u32, self.codec()), + })?; + + if identifier.tag == tag { + let previous = self.extension_group; + self.extension_group = ExtensionGroupState::Pending(tag); + let result = D::decode(self).map(Some); + self.extension_group = previous; + result + } else { + Ok(None) + } } } diff --git a/src/ber/enc.rs b/src/ber/enc.rs index c2b34fb6..b0963f5f 100644 --- a/src/ber/enc.rs +++ b/src/ber/enc.rs @@ -5,7 +5,7 @@ mod config; use alloc::{borrow::ToOwned, collections::VecDeque, string::ToString, vec::Vec}; use chrono::Timelike; -use super::Identifier; +use super::{ExtensionGroupState, Identifier}; use crate::{ bits::octet_string_ascending, types::{ @@ -28,6 +28,7 @@ pub struct Encoder { config: EncoderOptions, is_set_encoding: bool, set_buffer: alloc::collections::BTreeMap>, + extension_group: ExtensionGroupState, } /// A convenience type around results needing to return one or many bytes. @@ -45,6 +46,7 @@ impl Encoder { is_set_encoding: false, output: <_>::default(), set_buffer: <_>::default(), + extension_group: ExtensionGroupState::None, } } @@ -63,6 +65,7 @@ impl Encoder { is_set_encoding: true, output: <_>::default(), set_buffer: <_>::default(), + extension_group: ExtensionGroupState::None, } } @@ -78,9 +81,14 @@ impl Encoder { config, is_set_encoding: false, set_buffer: <_>::default(), + extension_group: ExtensionGroupState::None, } } + fn translate_tag(&self, tag: Tag) -> Tag { + tag.with_context_offset(self.extension_group.base_tag()) + } + /// Consumes the encoder and returns the output of the encoding. #[must_use] pub fn output(self) -> Vec { @@ -243,6 +251,10 @@ impl Encoder { /// Encodes a given ASN.1 BER value with the `identifier`. fn encode_value(&mut self, identifier: Identifier, value: &[u8]) { + let identifier = Identifier::from_tag( + self.translate_tag(identifier.tag), + identifier.is_constructed, + ); let ident_bytes = self.encode_identifier(identifier); self.append_byte_or_bytes(ident_bytes); self.encode_length(identifier, value); @@ -711,6 +723,14 @@ impl crate::Encoder<'_> for Encoder { C: crate::types::Constructed, F: FnOnce(&mut Self::AnyEncoder<'b, 0, 0>) -> Result<(), Self::Error>, { + if tag == Tag::SEQUENCE && matches!(self.extension_group, ExtensionGroupState::Pending(_)) { + // Extension addition groups are encoded flattened: skip the SEQUENCE wrapper once. + if let ExtensionGroupState::Pending(tag) = self.extension_group { + self.extension_group = ExtensionGroupState::Active(tag); + } + return (encoder_scope)(self); + } + let mut encoder = Self::new(self.config); (encoder_scope)(&mut encoder)?; @@ -730,6 +750,14 @@ impl crate::Encoder<'_> for Encoder { C: crate::types::Constructed, F: FnOnce(&mut Self::AnyEncoder<'b, 0, 0>) -> Result<(), Self::Error>, { + if tag == Tag::SET && matches!(self.extension_group, ExtensionGroupState::Pending(_)) { + // Extension addition groups are encoded flattened: skip the SET wrapper once. + if let ExtensionGroupState::Pending(tag) = self.extension_group { + self.extension_group = ExtensionGroupState::Active(tag); + } + return (encoder_scope)(self); + } + let mut encoder = Self::new_set(self.config); (encoder_scope)(&mut encoder)?; @@ -757,13 +785,23 @@ impl crate::Encoder<'_> for Encoder { /// Encode a extension addition group value. fn encode_extension_addition_group( &mut self, + tag: Tag, value: Option<&E>, _: crate::types::Identifier, ) -> Result where E: Encode + crate::types::Constructed, { - value.encode(self) + match value { + Some(v) => { + let previous = self.extension_group; + self.extension_group = ExtensionGroupState::Pending(tag); + let result = v.encode(self); + self.extension_group = previous; + result + } + None => Ok(()), + } } } diff --git a/src/de.rs b/src/de.rs index 9ec8d26a..728771c8 100644 --- a/src/de.rs +++ b/src/de.rs @@ -520,6 +520,7 @@ pub trait Decoder: Sized { D: Decode + crate::types::Constructed, >( &mut self, + tag: Tag, ) -> Result, Self::Error>; } diff --git a/src/enc.rs b/src/enc.rs index f5038abb..541c3c6b 100644 --- a/src/enc.rs +++ b/src/enc.rs @@ -527,6 +527,7 @@ pub trait Encoder<'encoder, const RCL: usize = 0, const ECL: usize = 0> { /// `E` is the type of the extension addition group value being encoded. fn encode_extension_addition_group( &mut self, + tag: Tag, value: Option<&E>, identifier: Identifier, ) -> Result diff --git a/src/jer/de.rs b/src/jer/de.rs index f23bf23c..528688de 100644 --- a/src/jer/de.rs +++ b/src/jer/de.rs @@ -179,7 +179,7 @@ impl crate::Decoder for Decoder { F: FnOnce(&mut Self) -> Result, { let mut last = self.stack.pop().ok_or_else(JerDecodeErrorKind::eoi)?; - let value_map = last + let _ = last .as_object_mut() .ok_or_else(|| JerDecodeErrorKind::TypeMismatch { needed: "object", @@ -193,12 +193,24 @@ impl crate::Decoder for Decoder { field_names.extend(extended_fields.iter().map(|f| f.name)); } field_names.reverse(); + // Push the (now partially consumed) object onto the stack so extension-addition-group + // decoding can pull group fields from the same flattened object. + self.stack.push(last); + let scope_index = self.stack.len() - 1; for name in field_names { - self.stack - .push(value_map.remove(name).unwrap_or(Value::Null)); + let value = self + .stack + .get_mut(scope_index) + .and_then(|v| v.as_object_mut()) + .and_then(|obj| obj.remove(name)) + .unwrap_or(Value::Null); + self.stack.push(value); } - (decode_fn)(self) + let result = (decode_fn)(self); + // Pop the scope object frame. + let _ = self.stack.pop(); + result } fn decode_sequence_of( @@ -476,8 +488,53 @@ impl crate::Decoder for Decoder { D: crate::Decode + Constructed, >( &mut self, + _tag: Tag, ) -> Result, Self::Error> { - self.decode_optional() + // The SEQUENCE decoder pushes a placeholder for the extension group field (which is not + // explicitly present in JER because extension groups are flattened). + // + // We decode a group by extracting only the group's fields from the current object and + // decoding the group from that scoped object. + let _ = self.stack.pop().ok_or_else(JerDecodeErrorKind::eoi)?; + + let index = self + .stack + .iter() + .rposition(|v| v.is_object()) + .ok_or_else(JerDecodeErrorKind::eoi)?; + let obj = self + .stack + .get_mut(index) + .and_then(|v| v.as_object_mut()) + .ok_or_else(|| JerDecodeErrorKind::TypeMismatch { + needed: "object", + found: "unknown".into(), + })?; + + let mut group_obj = serde_json::Map::with_capacity( + D::FIELDS.len() + D::EXTENDED_FIELDS.as_ref().map_or(0, |fields| fields.len()), + ); + let mut is_present = false; + for field in D::FIELDS.iter() { + if let Some(value) = obj.remove(field.name) { + is_present |= !value.is_null(); + group_obj.insert(alloc::string::String::from(field.name), value); + } + } + if let Some(extended_fields) = D::EXTENDED_FIELDS { + for field in extended_fields.iter() { + if let Some(value) = obj.remove(field.name) { + is_present |= !value.is_null(); + group_obj.insert(alloc::string::String::from(field.name), value); + } + } + } + if is_present { + self.stack.push(Value::Object(group_obj)); + D::decode(self).map(Some) + } else { + Ok(None) + } } fn codec(&self) -> crate::Codec { diff --git a/src/jer/enc.rs b/src/jer/enc.rs index 67264e59..a6513977 100644 --- a/src/jer/enc.rs +++ b/src/jer/enc.rs @@ -529,15 +529,29 @@ impl crate::Encoder<'_> for Encoder { fn encode_extension_addition_group( &mut self, + _t: Tag, value: Option<&E>, _: Identifier, ) -> Result where E: crate::Encode + crate::types::Constructed, { + self.stack.pop(); match value { - Some(v) => v.encode(self), - None => self.encode_none::(Identifier::EMPTY), + Some(v) => { + let mut inner = Self::new(); + v.encode(&mut inner)?; + if let Value::Object(obj) = inner.to_json()? { + self.constructed_stack + .last_mut() + .ok_or_else(|| JerEncodeErrorKind::JsonEncoder { + msg: "Internal stack mismatch!".into(), + })? + .extend(obj); + } + Ok(()) + } + None => Ok(()), } } diff --git a/src/oer/de.rs b/src/oer/de.rs index c82f0da1..862bdb93 100644 --- a/src/oer/de.rs +++ b/src/oer/de.rs @@ -1027,6 +1027,7 @@ impl<'input, const RFC: usize, const EFC: usize> crate::Decoder for Decoder<'inp D: Decode + Constructed, >( &mut self, + _tag: Tag, ) -> Result, Self::Error> { if !self.parse_extension_header()? { return Ok(None); diff --git a/src/oer/enc.rs b/src/oer/enc.rs index 420f1f90..e1a27e32 100644 --- a/src/oer/enc.rs +++ b/src/oer/enc.rs @@ -1151,6 +1151,7 @@ impl<'buffer, const RFC: usize, const EFC: usize> crate::Encoder<'buffer> } fn encode_extension_addition_group( &mut self, + _tag: Tag, value: Option<&E>, _: Identifier, ) -> Result diff --git a/src/per/de.rs b/src/per/de.rs index ca78638a..81611305 100644 --- a/src/per/de.rs +++ b/src/per/de.rs @@ -1132,6 +1132,7 @@ impl<'input, const RFC: usize, const EFC: usize> crate::Decoder for Decoder<'inp D: Decode + crate::types::Constructed, >( &mut self, + _tag: Tag, ) -> Result, Self::Error> { if !self.parse_extension_header()? { return Ok(None); diff --git a/src/per/enc.rs b/src/per/enc.rs index aa3e4f9a..b2e80523 100644 --- a/src/per/enc.rs +++ b/src/per/enc.rs @@ -1330,6 +1330,7 @@ impl crate::Encoder<'_> for Encoder( &mut self, + _tag: Tag, value: Option<&E>, _: Identifier, ) -> Result diff --git a/src/types/tag.rs b/src/types/tag.rs index 44dc1b56..d234facc 100644 --- a/src/types/tag.rs +++ b/src/types/tag.rs @@ -192,6 +192,17 @@ impl Tag { self } + /// Applies a context-specific base tag offset when both tags are context class. + #[must_use] + pub fn with_context_offset(self, base: Option) -> Self { + match base { + Some(base) if base.class == Class::Context && self.class == Class::Context => { + Tag::new(self.class, base.value.saturating_add(self.value)) + } + _ => self, + } + } + #[doc(hidden)] #[must_use] pub const fn const_eq(self, rhs: &Self) -> bool { diff --git a/src/xer/de.rs b/src/xer/de.rs index ba90d8a6..0f08b3ba 100644 --- a/src/xer/de.rs +++ b/src/xer/de.rs @@ -885,6 +885,7 @@ impl crate::Decoder for Decoder { D: crate::Decode + Constructed, >( &mut self, + _tag: Tag, ) -> Result, Self::Error> { self.decode_optional() } diff --git a/src/xer/enc.rs b/src/xer/enc.rs index a575839b..1eb5038d 100644 --- a/src/xer/enc.rs +++ b/src/xer/enc.rs @@ -694,6 +694,7 @@ impl crate::Encoder<'_> for Encoder { fn encode_extension_addition_group( &mut self, + _tag: Tag, value: Option<&E>, identifier: Identifier, ) -> Result diff --git a/tests/extension_group.rs b/tests/extension_group.rs new file mode 100644 index 00000000..64caf262 --- /dev/null +++ b/tests/extension_group.rs @@ -0,0 +1,131 @@ +pub mod extension_addition_group { + extern crate alloc; + use rasn::prelude::*; + #[doc = " Inner type "] + #[derive(AsnType, Debug, Clone, Decode, Encode, PartialEq, Eq, Hash)] + #[rasn(automatic_tags)] + pub struct S1ExtGroupB2 { + pub b2: Option, + } + impl S1ExtGroupB2 { + pub fn new(b2: Option) -> Self { + Self { b2 } + } + } + #[doc = " Inner type "] + #[derive(AsnType, Debug, Clone, Decode, Encode, PartialEq, Eq, Hash)] + #[rasn(automatic_tags)] + pub struct S1ExtGroupB3 { + pub b3: Option, + } + impl S1ExtGroupB3 { + pub fn new(b3: Option) -> Self { + Self { b3 } + } + } + #[derive(AsnType, Debug, Clone, Decode, Encode, PartialEq, Eq, Hash)] + #[rasn(automatic_tags)] + #[non_exhaustive] + pub struct S1 { + pub b1: bool, + #[rasn(extension_addition_group, identifier = "SEQUENCE")] + pub ext_group_b2: Option, + #[rasn(extension_addition_group, identifier = "SEQUENCE")] + pub ext_group_b3: Option, + } + impl S1 { + pub fn new( + b1: bool, + ext_group_b2: Option, + ext_group_b3: Option, + ) -> Self { + Self { + b1, + ext_group_b2, + ext_group_b3, + } + } + } +} + +const SAMPLE_S1: extension_addition_group::S1 = extension_addition_group::S1 { + b1: true, + ext_group_b2: None, + ext_group_b3: Some(extension_addition_group::S1ExtGroupB3 { b3: Some(true) }), +}; + +macro_rules! round_trip { + ($codec:ident, $typ:ty, $value:expr, $expected:expr) => {{ + let value: $typ = $value; + let expected: &[u8] = $expected; + let actual_encoding = rasn::$codec::encode(&value).unwrap(); + + pretty_assertions::assert_eq!(&*actual_encoding, expected); + + let decoded_value = rasn::$codec::decode::<$typ>(&actual_encoding); + match decoded_value { + Ok(decoded) => { + pretty_assertions::assert_eq!(value, decoded); + } + Err(err) => { + panic!("{:?}", err); + } + } + }}; +} + +#[test] +fn extension_group_roundtrip_aper() { + let encoded = &[0xc0, 0xa0, 0x01, 0xc0]; + round_trip!(aper, extension_addition_group::S1, SAMPLE_S1, encoded); +} + +#[test] +fn extension_group_roundtrip_uper() { + let encoded_correct = &[0xc0, 0xa0, 0x38, 0x00]; + round_trip!( + uper, + extension_addition_group::S1, + SAMPLE_S1, + encoded_correct + ); +} + +#[test] +fn extension_group_roundtrip_ber() { + let encoded = &[0x30, 0x06, 0x80, 0x01, 0xff, 0x82, 0x01, 0xff]; + round_trip!(ber, extension_addition_group::S1, SAMPLE_S1, encoded); +} + +#[test] +fn extension_group_roundtrip_cer() { + let encoded = &[0x30, 0x80, 0x80, 0x01, 0xff, 0x82, 0x01, 0xff, 0x00, 0x00]; + round_trip!(cer, extension_addition_group::S1, SAMPLE_S1, encoded); +} + +#[test] +fn extension_group_roundtrip_coer() { + let encoded = &[0x80, 0xff, 0x02, 0x06, 0x40, 0x02, 0x80, 0xff]; + round_trip!(coer, extension_addition_group::S1, SAMPLE_S1, encoded); +} + +#[test] +fn extension_group_roundtrip_der() { + let encoded = &[0x30, 0x06, 0x80, 0x01, 0xff, 0x82, 0x01, 0xff]; + round_trip!(der, extension_addition_group::S1, SAMPLE_S1, encoded); +} + +#[test] +fn extension_group_roundtrip_jer() { + let expected = "{\"b1\":true,\"b3\":true}"; + let encoded = rasn::jer::encode(&SAMPLE_S1).unwrap(); + pretty_assertions::assert_eq!(expected, encoded); + let decoded = rasn::jer::decode::(&encoded).unwrap(); + pretty_assertions::assert_eq!(SAMPLE_S1, decoded); +} + +#[test] +fn extension_group_roundtrip_oer() { + let encoded = &[0x80, 0xff, 0x02, 0x06, 0x40, 0x02, 0x80, 0xff]; + round_trip!(oer, extension_addition_group::S1, SAMPLE_S1, encoded); +}