monero_io/
varint.rs

1//! Monero's VarInt type, frequently used to encode integers expected to be of low norm.
2//!
3//! This corresponds to
4//! https://github.com/monero-project/monero/blob/8e9ab9677f90492bca3c7555a246f2a8677bd570
5//!   /src/common/varint.h.
6
7#[allow(unused_imports)]
8use std_shims::prelude::*;
9use std_shims::io::{self, Read, Write};
10
11use crate::{read_byte, write_byte};
12
/// The most-significant bit of an encoded byte, set when a further byte follows.
const VARINT_CONTINUATION_FLAG: u8 = 1 << 7;
/// Mask selecting the seven bits of an encoded byte which carry the value (all but the flag).
const VARINT_VALUE_MASK: u8 = u8::MAX ^ VARINT_CONTINUATION_FLAG;
15
mod sealed {
  /// The seal for `VarInt`: as this module is private, foreign types cannot implement it.
  pub trait Sealed {
    /// Convert this value into a `u64`, losslessly and without any possibility of failure.
    ///
    /// The encoding is implemented internally for `u64` alone, and the standard library offers
    /// no infallible conversion from `usize` into `u64`, so each implementor provides its own.
    // Declared on the seal, and not on `VarInt`, so it stays out of our public API commitment.
    fn into_u64(self) -> u64;
  }
}
27
/// Compute the upper bound for the encoded length of an integer type as a VarInt.
///
/// `bits` is the width of the integer type. Each encoded byte carries seven bits of the value,
/// so the bound is `bits / 7`, rounded up.
///
/// This is a private function only called at compile-time, which is why unexpected input causes
/// a panic rather than an error.
#[allow(clippy::as_conversions, clippy::cast_possible_truncation)]
const fn upper_bound(bits: u32) -> usize {
  // Bounding the input keeps the result small enough for the final cast to be lossless, even
  // on 8-bit platforms
  assert!(bits <= 256, "defining a number exceeding u256 as a VarInt");
  // `div_ceil` is spelled out by hand as it was introduced with 1.73, and `std-shims` cannot
  // provide a `const fn` shim due to exposing it as a trait method
  let full_groups = bits / 7;
  let has_partial_group = (bits % 7) != 0;
  (if has_partial_group { full_groups + 1 } else { full_groups }) as usize
}
40
41/// A trait for a number readable/writable as a VarInt.
42///
43/// This is sealed to prevent unintended implementations. It MUST only be implemented for primitive
44/// types (or sufficiently approximate types like `NonZero<_>`).
45pub trait VarInt: TryFrom<u64> + Copy + sealed::Sealed {
46  /// The lower bound on the amount of bytes this will take up when encoded.
47  const LOWER_BOUND: usize;
48
49  /// The upper bound on the amount of bytes this will take up when encoded.
50  const UPPER_BOUND: usize;
51
52  /// The amount of bytes this number will take when serialized as a VarInt.
53  fn varint_len(self) -> usize {
54    let varint_u64 = self.into_u64();
55    let bits = usize::try_from(u64::BITS - varint_u64.leading_zeros()).expect("64 > usize::MAX?");
56    let encoded_bytes = bits.div_ceil(7);
57    encoded_bytes.max(1)
58  }
59
60  /// Read a canonically-encoded VarInt.
61  fn read<R: Read>(r: &mut R) -> io::Result<Self> {
62    let mut bits = 0;
63    let mut res = 0;
64    while {
65      let b = read_byte(r)?;
66      // Reject trailing zero bytes
67      // https://github.com/monero-project/monero/blob/8e9ab9677f90492bca3c7555a246f2a8677bd570
68      //   /src/common/varint.h#L107
69      if (bits != 0) && (b == 0) {
70        Err(io::Error::other("non-canonical varint"))?;
71      }
72
73      // We use `size_of` here as we control what `VarInt` is implemented for, and it's only for
74      // types whose size correspond to their range
75      #[allow(non_snake_case)]
76      let U_BITS = core::mem::size_of::<Self>() * 8;
77      if ((bits + 7) >= U_BITS) && (b >= (1 << (U_BITS - bits))) {
78        Err(io::Error::other("varint overflow"))?;
79      }
80
81      res += u64::from(b & VARINT_VALUE_MASK) << bits;
82      bits += 7;
83      (b & VARINT_CONTINUATION_FLAG) == VARINT_CONTINUATION_FLAG
84    } {}
85    res.try_into().map_err(|_| io::Error::other("VarInt does not fit into integer type"))
86  }
87
88  /// Encode a number as a VarInt.
89  ///
90  /// This doesn't accept `self` to force writing it as `VarInt::write`, making it clear it's being
91  /// written with the VarInt encoding.
92  fn write<W: Write>(varint: &Self, w: &mut W) -> io::Result<()> {
93    let mut varint: u64 = varint.into_u64();
94
95    // A do-while loop as we always encode at least one byte
96    while {
97      // Take the next seven bits
98      let mut b = u8::try_from(varint & u64::from(VARINT_VALUE_MASK))
99        .expect("& 0b0111_1111 left more than 8 bits set");
100      varint >>= 7;
101
102      // If there's more, set the continuation flag
103      if varint != 0 {
104        b |= VARINT_CONTINUATION_FLAG;
105      }
106
107      // Write this byte
108      write_byte(&b, w)?;
109
110      // Continue until the number is fully encoded
111      varint != 0
112    } {}
113
114    Ok(())
115  }
116}
117
118impl sealed::Sealed for u8 {
119  fn into_u64(self) -> u64 {
120    self.into()
121  }
122}
123impl VarInt for u8 {
124  const LOWER_BOUND: usize = 1;
125  const UPPER_BOUND: usize = upper_bound(Self::BITS);
126}
127
128impl sealed::Sealed for u32 {
129  fn into_u64(self) -> u64 {
130    self.into()
131  }
132}
133impl VarInt for u32 {
134  const LOWER_BOUND: usize = 1;
135  const UPPER_BOUND: usize = upper_bound(Self::BITS);
136}
137
138impl sealed::Sealed for u64 {
139  fn into_u64(self) -> u64 {
140    self
141  }
142}
143impl VarInt for u64 {
144  const LOWER_BOUND: usize = 1;
145  const UPPER_BOUND: usize = upper_bound(Self::BITS);
146}
147
148impl sealed::Sealed for usize {
149  fn into_u64(self) -> u64 {
150    // Ensure the falling conversion is infallible
151    #[allow(clippy::as_conversions)]
152    const _NO_128_BIT_PLATFORMS: [(); (u64::BITS - usize::BITS) as usize] =
153      [(); (u64::BITS - usize::BITS) as usize];
154
155    self.try_into().expect("compiling on platform with <64-bit usize yet value didn't fit in u64")
156  }
157}
158impl VarInt for usize {
159  const LOWER_BOUND: usize = 1;
160  const UPPER_BOUND: usize = upper_bound(Self::BITS);
161}