diff --git a/src/database/de.rs b/src/database/de.rs index 2180e0fb..7ee42129 100644 --- a/src/database/de.rs +++ b/src/database/de.rs @@ -22,7 +22,7 @@ pub(crate) fn from_slice<'a, T>(buf: &'a [u8]) -> Result where T: Deserialize<'a>, { - let mut deserializer = Deserializer { buf, pos: 0, rec: 0, seq: false }; + let mut deserializer = Deserializer { buf, pos: 0, rec: 0, seq: 0 }; T::deserialize(&mut deserializer).debug_inspect(|_| { deserializer @@ -36,7 +36,7 @@ pub(crate) struct Deserializer<'de> { buf: &'de [u8], pos: usize, rec: usize, - seq: bool, + seq: usize, } /// Directive to ignore a record. This type can be used to skip deserialization @@ -70,9 +70,9 @@ impl<'de> Deserializer<'de> { /// Called at the start of arrays and tuples #[inline] - fn sequence_start(&mut self) { - debug_assert!(!self.seq, "Nested sequences are not handled at this time"); - self.seq = true; + fn sequence_start(&mut self, len: usize) { + debug_assert!(self.seq == 0, "Nested sequences are not handled at this time"); + self.seq = len; } /// Consume the current record to ignore it. Inside a sequence the next @@ -80,7 +80,7 @@ impl<'de> Deserializer<'de> { /// deserialization completes with self.finished() == Ok. #[inline] fn record_ignore(&mut self) { - if self.seq { + if self.seq > 0 { self.record_next(); } else { self.record_ignore_all(); @@ -123,12 +123,16 @@ impl<'de> Deserializer<'de> { #[inline] fn record_start(&mut self) { let started = self.pos != 0 || self.rec > 0; + let input_done = self.pos >= self.buf.len(); + let output_done = self.rec >= self.seq; + let incomplete = input_done && !output_done; debug_assert!( - !started || self.buf[self.pos] == Self::SEP, + !started || incomplete || self.buf.get(self.pos) == Some(&Self::SEP), "Missing expected record separator at current position" ); - self.inc_pos(started.into()); + let inc = started && !incomplete; + self.inc_pos(inc.into()); self.inc_rec(1); } @@ -179,7 +183,7 @@ impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { where V: Visitor<'de>, { - self.sequence_start(); + self.sequence_start(1); visitor.visit_seq(self) } @@ -187,11 +191,11 @@ impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { unabridged, tracing::instrument(level = "trace", skip(self, visitor)) )] - fn deserialize_tuple(self, _len: usize, visitor: V) -> Result + fn deserialize_tuple(self, len: usize, visitor: V) -> Result where V: Visitor<'de>, { - self.sequence_start(); + self.sequence_start(len); visitor.visit_seq(self) } @@ -202,13 +206,13 @@ impl<'a, 'de: 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { fn deserialize_tuple_struct( self, _name: &'static str, - _len: usize, + len: usize, visitor: V, ) -> Result where V: Visitor<'de>, { - self.sequence_start(); + self.sequence_start(len); visitor.visit_seq(self) } @@ -462,7 +466,18 @@ impl<'a, 'de: 'a> de::SeqAccess<'de> for &'a mut Deserializer<'de> { where T: DeserializeSeed<'de>, { - if self.pos >= self.buf.len() { + // Finished parsing the input. + let finished = self.pos >= self.buf.len(); + + // Completely satisfied the output. + let complete = self.rec >= self.seq; + + // Leave after reaching the end of both the input and output. Leaving before + // reaching the end of the input trips the finished() assertion. Leaving before + // reaching the end of the output causes a length expectation panic on type T + // i.e. tuple of size X instead of Y and trailing defaulting elements will not + // be possible. + if finished && complete { return Ok(None); } diff --git a/src/database/tests.rs b/src/database/tests.rs index bab704c5..9ae3aa5d 100644 --- a/src/database/tests.rs +++ b/src/database/tests.rs @@ -240,6 +240,39 @@ fn de_tuple_incomplete() { assert_eq!(a, user_id, "deserialized user_id does not match"); } +#[test] +fn de_tuple_incomplete_default() { + let user_id: &UserId = "@user:example.com".try_into().unwrap(); + + let raw: &[u8] = b"@user:example.com"; + let (a, b): (&UserId, &str) = de::from_slice(raw).expect("failed to deserialize"); + + assert_eq!(a, user_id, "deserialized user_id does not match"); + assert_eq!(b, "", "deserialized defaulted str does not match"); +} + +#[test] +#[should_panic(expected = "failed to deserialize")] +fn de_tuple_incomplete_nodefault() { + let user_id: &UserId = "@user:example.com".try_into().unwrap(); + + let raw: &[u8] = b"@user:example.com"; + let (a, _): (&UserId, u64) = de::from_slice(raw).expect("failed to deserialize"); + + assert_eq!(a, user_id, "deserialized user_id does not match"); +} + +#[test] +fn de_tuple_incomplete_option() { + let user_id: &UserId = "@user:example.com".try_into().unwrap(); + + let raw: &[u8] = b"@user:example.com"; + let (a, b): (&UserId, Option<&str>) = de::from_slice(raw).expect("failed to deserialize"); + + assert_eq!(a, user_id, "deserialized user_id does not match"); + assert_eq!(b, None, "deserialized defaulted Option does not match"); +} + #[test] #[should_panic(expected = "failed to deserialize")] fn de_tuple_incomplete_with_sep() { @@ -480,6 +513,28 @@ fn serde_tuple_option_some_value() { assert_eq!(cc.1, bb.1); } +#[test] +fn serde_tuple_option_value_incomplete() { + let room_id: &RoomId = "!room:example.com".try_into().unwrap(); + let user_id: &UserId = "@user:example.com".try_into().unwrap(); + + let mut aa = Vec::::new(); + aa.extend_from_slice(room_id.as_bytes()); + aa.push(0xFF); + aa.extend_from_slice(user_id.as_bytes()); + + let bb: (&RoomId, &UserId) = (room_id, user_id); + let bbs = serialize_to_vec(&bb).expect("failed to serialize tuple"); + assert_eq!(aa, bbs); + + let cc: (&RoomId, &UserId, Option) = + de::from_slice(&bbs).expect("failed to deserialize tuple"); + + assert_eq!(bb.0, cc.0); + assert_eq!(bb.1, cc.1); + assert_eq!(cc.2, None); +} + #[test] fn serde_tuple_option_some_some() { let room_id: &RoomId = "!room:example.com".try_into().unwrap();