monero_io/
lib.rs

1#![cfg_attr(docsrs, feature(doc_auto_cfg))]
2#![doc = include_str!("../README.md")]
3#![deny(missing_docs)]
4#![cfg_attr(not(feature = "std"), no_std)]
5
6use core::fmt::Debug;
7use std_shims::{
8  vec,
9  vec::Vec,
10  io::{self, Read, Write},
11};
12
13use curve25519_dalek::{
14  scalar::Scalar,
15  edwards::{EdwardsPoint, CompressedEdwardsY},
16};
17
/// The high bit of a VarInt byte, set when at least one more byte of the encoding follows it.
const VARINT_CONTINUATION_MASK: u8 = 1 << 7;
19
mod sealed {
  use core::fmt::Debug;

  /// An unsigned integer type which may be VarInt-encoded and decoded.
  ///
  /// The trait is declared `pub` inside this private module so it can appear in the crate's public
  /// bounds while no other crate is able to implement it.
  pub trait VarInt: TryInto<u64, Error: Debug> + TryFrom<u64, Error: Debug> + Copy {
    /// The width of the type, in bits.
    const BITS: usize;
  }

  // Implement `VarInt` for each listed primitive, taking `BITS` from the primitive itself.
  macro_rules! impl_varint {
    ($($ty:ty),* $(,)?) => {
      $(
        impl VarInt for $ty {
          const BITS: usize = <$ty>::BITS as usize;
        }
      )*
    };
  }

  // Refuse to compile where `usize` is wider than `u64`, as converting a `usize` into a `u64`
  // could then fail at runtime
  const _: () = assert!(usize::BITS <= u64::BITS, "usize must not be wider than u64");

  impl_varint!(u8, u32, u64, usize);
}
47
48/// The amount of bytes this number will take when serialized as a VarInt.
49///
50/// This function will panic if the VarInt exceeds u64::MAX.
51pub fn varint_len<V: sealed::VarInt>(varint: V) -> usize {
52  let varint_u64: u64 = varint.try_into().expect("varint exceeded u64");
53  ((usize::try_from(u64::BITS - varint_u64.leading_zeros())
54    .expect("64 > usize::MAX")
55    .saturating_sub(1)) /
56    7) +
57    1
58}
59
/// Write a single byte.
///
/// The by-reference signature matches the `Fn(&T, &mut W) -> io::Result<()>` element writers
/// accepted by `write_raw_vec` and `write_vec`, so this can be passed to them directly.
pub fn write_byte<W: Write>(byte: &u8, w: &mut W) -> io::Result<()> {
  w.write_all(core::slice::from_ref(byte))
}
66
67/// Write a number, VarInt-encoded.
68///
69/// This will panic if the VarInt exceeds u64::MAX.
70pub fn write_varint<W: Write, U: sealed::VarInt>(varint: &U, w: &mut W) -> io::Result<()> {
71  let mut varint: u64 = (*varint).try_into().expect("varint exceeded u64");
72  while {
73    let mut b = u8::try_from(varint & u64::from(!VARINT_CONTINUATION_MASK))
74      .expect("& eight_bit_mask left more than 8 bits set");
75    varint >>= 7;
76    if varint != 0 {
77      b |= VARINT_CONTINUATION_MASK;
78    }
79    write_byte(&b, w)?;
80    varint != 0
81  } {}
82  Ok(())
83}
84
85/// Write a scalar.
86pub fn write_scalar<W: Write>(scalar: &Scalar, w: &mut W) -> io::Result<()> {
87  w.write_all(&scalar.to_bytes())
88}
89
90/// Write a point.
91pub fn write_point<W: Write>(point: &EdwardsPoint, w: &mut W) -> io::Result<()> {
92  w.write_all(&point.compress().to_bytes())
93}
94
/// Write a list of elements, without length-prefixing.
///
/// Each element of `values` is written in order with `f`. Writing stops at, and returns, the
/// first error `f` produces.
pub fn write_raw_vec<T, W: Write, F: Fn(&T, &mut W) -> io::Result<()>>(
  f: F,
  values: &[T],
  w: &mut W,
) -> io::Result<()> {
  values.iter().try_for_each(|value| f(value, w))
}
106
107/// Write a list of elements, with length-prefixing.
108pub fn write_vec<T, W: Write, F: Fn(&T, &mut W) -> io::Result<()>>(
109  f: F,
110  values: &[T],
111  w: &mut W,
112) -> io::Result<()> {
113  write_varint(&values.len(), w)?;
114  write_raw_vec(f, values, w)
115}
116
/// Read a constant amount of bytes.
///
/// # Errors
///
/// Propagates any error from `Read::read_exact`, including when the reader ends before `N`
/// bytes are available.
pub fn read_bytes<R: Read, const N: usize>(r: &mut R) -> io::Result<[u8; N]> {
  let mut buf = [0u8; N];
  r.read_exact(&mut buf).map(|()| buf)
}
123
124/// Read a single byte.
125pub fn read_byte<R: Read>(r: &mut R) -> io::Result<u8> {
126  Ok(read_bytes::<_, 1>(r)?[0])
127}
128
129/// Read a u16, little-endian encoded.
130pub fn read_u16<R: Read>(r: &mut R) -> io::Result<u16> {
131  read_bytes(r).map(u16::from_le_bytes)
132}
133
134/// Read a u32, little-endian encoded.
135pub fn read_u32<R: Read>(r: &mut R) -> io::Result<u32> {
136  read_bytes(r).map(u32::from_le_bytes)
137}
138
139/// Read a u64, little-endian encoded.
140pub fn read_u64<R: Read>(r: &mut R) -> io::Result<u64> {
141  read_bytes(r).map(u64::from_le_bytes)
142}
143
144/// Read a canonically-encoded VarInt.
145pub fn read_varint<R: Read, U: sealed::VarInt>(r: &mut R) -> io::Result<U> {
146  let mut bits = 0;
147  let mut res = 0;
148  while {
149    let b = read_byte(r)?;
150    if (bits != 0) && (b == 0) {
151      Err(io::Error::other("non-canonical varint"))?;
152    }
153    if ((bits + 7) >= U::BITS) && (b >= (1 << (U::BITS - bits))) {
154      Err(io::Error::other("varint overflow"))?;
155    }
156
157    res += u64::from(b & (!VARINT_CONTINUATION_MASK)) << bits;
158    bits += 7;
159    b & VARINT_CONTINUATION_MASK == VARINT_CONTINUATION_MASK
160  } {}
161  res.try_into().map_err(|_| io::Error::other("VarInt does not fit into integer type"))
162}
163
164/// Read a canonically-encoded scalar.
165///
166/// Some scalars within the Monero protocol are not enforced to be canonically encoded. For such
167/// scalars, they should be represented as `[u8; 32]` and later converted to scalars as relevant.
168pub fn read_scalar<R: Read>(r: &mut R) -> io::Result<Scalar> {
169  Option::from(Scalar::from_canonical_bytes(read_bytes(r)?))
170    .ok_or_else(|| io::Error::other("unreduced scalar"))
171}
172
173/// Decompress a canonically-encoded Ed25519 point.
174///
175/// Ed25519 is of order `8 * l`. This function ensures each of those `8 * l` points have a singular
176/// encoding by checking points aren't encoded with an unreduced field element, and aren't negative
177/// when the negative is equivalent (0 == -0).
178///
179/// Since this decodes an Ed25519 point, it does not check the point is in the prime-order
180/// subgroup. Torsioned points do have a canonical encoding, and only aren't canonical when
181/// considered in relation to the prime-order subgroup.
182pub fn decompress_point(bytes: [u8; 32]) -> Option<EdwardsPoint> {
183  CompressedEdwardsY(bytes)
184    .decompress()
185    // Ban points which are either unreduced or -0
186    .filter(|point| point.compress().to_bytes() == bytes)
187}
188
189/// Read a canonically-encoded Ed25519 point.
190///
191/// This internally calls `decompress_point` and has the same definition of canonicity. This
192/// function does not check the resulting point is within the prime-order subgroup.
193pub fn read_point<R: Read>(r: &mut R) -> io::Result<EdwardsPoint> {
194  let bytes = read_bytes(r)?;
195  decompress_point(bytes).ok_or_else(|| io::Error::other("invalid point"))
196}
197
198/// Read a canonically-encoded Ed25519 point, within the prime-order subgroup.
199pub fn read_torsion_free_point<R: Read>(r: &mut R) -> io::Result<EdwardsPoint> {
200  read_point(r)
201    .ok()
202    .filter(EdwardsPoint::is_torsion_free)
203    .ok_or_else(|| io::Error::other("invalid point"))
204}
205
/// Read a variable-length list of elements, without length-prefixing.
///
/// Calls `f` exactly `len` times, in order, stopping at and returning the first error.
pub fn read_raw_vec<R: Read, T, F: Fn(&mut R) -> io::Result<T>>(
  f: F,
  len: usize,
  r: &mut R,
) -> io::Result<Vec<T>> {
  // Start from an empty `Vec` rather than reserving `len` slots, as `len` may come from
  // untrusted input (such as the length `read_vec` reads off the reader)
  (0 .. len).try_fold(Vec::new(), |mut elements: Vec<T>, _| -> io::Result<Vec<T>> {
    elements.push(f(r)?);
    Ok(elements)
  })
}
218
/// Read a constant-length list of elements.
///
/// Calls `f` exactly `N` times, in order, stopping at and returning the first error.
pub fn read_array<R: Read, T: Debug, F: Fn(&mut R) -> io::Result<T>, const N: usize>(
  f: F,
  r: &mut R,
) -> io::Result<[T; N]> {
  let elements: Vec<T> = (0 .. N).map(|_| f(r)).collect::<io::Result<_>>()?;
  // Exactly `N` elements were collected, so converting to `[T; N]` can't fail
  Ok(elements.try_into().expect(
    "read vector of specific length yet couldn't transform to an array of the same length",
  ))
}
230
231/// Read a length-prefixed variable-length list of elements.
232///
233/// An optional bound on the length of the result may be provided. If `None`, the returned `Vec`
234/// will be of the length read off the reader, if successfully read. If `Some(_)`, an error will be
235/// raised if the length read off the read is greater than the bound.
236pub fn read_vec<R: Read, T, F: Fn(&mut R) -> io::Result<T>>(
237  f: F,
238  length_bound: Option<usize>,
239  r: &mut R,
240) -> io::Result<Vec<T>> {
241  let declared_length: usize = read_varint(r)?;
242  if let Some(length_bound) = length_bound {
243    if declared_length > length_bound {
244      Err(io::Error::other("vector exceeds bound on length"))?;
245    }
246  }
247  read_raw_vec(f, declared_length, r)
248}