1#![cfg_attr(docsrs, feature(doc_auto_cfg))]
2#![doc = include_str!("../README.md")]
3#![deny(missing_docs)]
4#![cfg_attr(not(feature = "std"), no_std)]
5
6use core::fmt::Debug;
7use std_shims::{
8 vec,
9 vec::Vec,
10 io::{self, Read, Write},
11};
12
13use curve25519_dalek::{
14 scalar::Scalar,
15 edwards::{EdwardsPoint, CompressedEdwardsY},
16};
17
/// The bit of a VarInt byte which, when set, signals that another byte follows.
const VARINT_CONTINUATION_MASK: u8 = 1 << 7;
19
mod sealed {
  use core::fmt::Debug;

  /// An unsigned integer type which may be encoded as a VarInt.
  ///
  /// This lives in a private module, so code outside this crate can't name (and therefore can't
  /// implement) it.
  pub trait VarInt: TryInto<u64, Error: Debug> + TryFrom<u64, Error: Debug> + Copy {
    /// The width of the type, in bits.
    const BITS: usize;
  }

  // Each implementor reports its width via the primitive's own `BITS` constant.
  macro_rules! impl_var_int {
    ($($ty: ty),+ $(,)?) => {
      $(
        impl VarInt for $ty {
          const BITS: usize = <$ty>::BITS as usize;
        }
      )+
    };
  }
  impl_var_int!(u8, u32, u64, usize);

  // `usize` may only be treated as a VarInt if it's no wider than `u64`, so refuse to compile on
  // any platform where that doesn't hold.
  const _: () = assert!(usize::BITS <= u64::BITS);
}
47
48pub fn varint_len<V: sealed::VarInt>(varint: V) -> usize {
52 let varint_u64: u64 = varint.try_into().expect("varint exceeded u64");
53 ((usize::try_from(u64::BITS - varint_u64.leading_zeros())
54 .expect("64 > usize::MAX")
55 .saturating_sub(1)) /
56 7) +
57 1
58}
59
/// Write a single byte.
///
/// Taking `&u8` matches the `Fn(&T, &mut W)` callback shape used by [`write_vec`] and
/// [`write_raw_vec`].
pub fn write_byte<W: Write>(byte: &u8, w: &mut W) -> io::Result<()> {
  w.write_all(core::slice::from_ref(byte))
}
66
67pub fn write_varint<W: Write, U: sealed::VarInt>(varint: &U, w: &mut W) -> io::Result<()> {
71 let mut varint: u64 = (*varint).try_into().expect("varint exceeded u64");
72 while {
73 let mut b = u8::try_from(varint & u64::from(!VARINT_CONTINUATION_MASK))
74 .expect("& eight_bit_mask left more than 8 bits set");
75 varint >>= 7;
76 if varint != 0 {
77 b |= VARINT_CONTINUATION_MASK;
78 }
79 write_byte(&b, w)?;
80 varint != 0
81 } {}
82 Ok(())
83}
84
85pub fn write_scalar<W: Write>(scalar: &Scalar, w: &mut W) -> io::Result<()> {
87 w.write_all(&scalar.to_bytes())
88}
89
90pub fn write_point<W: Write>(point: &EdwardsPoint, w: &mut W) -> io::Result<()> {
92 w.write_all(&point.compress().to_bytes())
93}
94
/// Write every element of `values` with `f`, in order, without any length prefix.
///
/// Stops at, and returns, the first error `f` produces.
pub fn write_raw_vec<T, W: Write, F: Fn(&T, &mut W) -> io::Result<()>>(
  f: F,
  values: &[T],
  w: &mut W,
) -> io::Result<()> {
  values.iter().try_for_each(|value| f(value, w))
}
106
107pub fn write_vec<T, W: Write, F: Fn(&T, &mut W) -> io::Result<()>>(
109 f: F,
110 values: &[T],
111 w: &mut W,
112) -> io::Result<()> {
113 write_varint(&values.len(), w)?;
114 write_raw_vec(f, values, w)
115}
116
/// Read exactly `N` bytes.
///
/// Errors if the reader can't provide all `N` bytes, as per [`Read::read_exact`].
pub fn read_bytes<R: Read, const N: usize>(r: &mut R) -> io::Result<[u8; N]> {
  let mut buf = [0u8; N];
  r.read_exact(&mut buf).map(|()| buf)
}
123
124pub fn read_byte<R: Read>(r: &mut R) -> io::Result<u8> {
126 Ok(read_bytes::<_, 1>(r)?[0])
127}
128
129pub fn read_u16<R: Read>(r: &mut R) -> io::Result<u16> {
131 read_bytes(r).map(u16::from_le_bytes)
132}
133
134pub fn read_u32<R: Read>(r: &mut R) -> io::Result<u32> {
136 read_bytes(r).map(u32::from_le_bytes)
137}
138
139pub fn read_u64<R: Read>(r: &mut R) -> io::Result<u64> {
141 read_bytes(r).map(u64::from_le_bytes)
142}
143
144pub fn read_varint<R: Read, U: sealed::VarInt>(r: &mut R) -> io::Result<U> {
146 let mut bits = 0;
147 let mut res = 0;
148 while {
149 let b = read_byte(r)?;
150 if (bits != 0) && (b == 0) {
151 Err(io::Error::other("non-canonical varint"))?;
152 }
153 if ((bits + 7) >= U::BITS) && (b >= (1 << (U::BITS - bits))) {
154 Err(io::Error::other("varint overflow"))?;
155 }
156
157 res += u64::from(b & (!VARINT_CONTINUATION_MASK)) << bits;
158 bits += 7;
159 b & VARINT_CONTINUATION_MASK == VARINT_CONTINUATION_MASK
160 } {}
161 res.try_into().map_err(|_| io::Error::other("VarInt does not fit into integer type"))
162}
163
164pub fn read_scalar<R: Read>(r: &mut R) -> io::Result<Scalar> {
169 Option::from(Scalar::from_canonical_bytes(read_bytes(r)?))
170 .ok_or_else(|| io::Error::other("unreduced scalar"))
171}
172
173pub fn decompress_point(bytes: [u8; 32]) -> Option<EdwardsPoint> {
183 CompressedEdwardsY(bytes)
184 .decompress()
185 .filter(|point| point.compress().to_bytes() == bytes)
187}
188
189pub fn read_point<R: Read>(r: &mut R) -> io::Result<EdwardsPoint> {
194 let bytes = read_bytes(r)?;
195 decompress_point(bytes).ok_or_else(|| io::Error::other("invalid point"))
196}
197
198pub fn read_torsion_free_point<R: Read>(r: &mut R) -> io::Result<EdwardsPoint> {
200 read_point(r)
201 .ok()
202 .filter(EdwardsPoint::is_torsion_free)
203 .ok_or_else(|| io::Error::other("invalid point"))
204}
205
/// Read `len` elements with `f`, without any length prefix.
///
/// Stops at, and returns, the first error `f` produces. Elements are accumulated as they're read
/// rather than reserving `len` slots up front, as `len` may come from the input itself (as it
/// does via [`read_vec`]).
pub fn read_raw_vec<R: Read, T, F: Fn(&mut R) -> io::Result<T>>(
  f: F,
  len: usize,
  r: &mut R,
) -> io::Result<Vec<T>> {
  (0 .. len).map(|_| f(r)).collect()
}
218
219pub fn read_array<R: Read, T: Debug, F: Fn(&mut R) -> io::Result<T>, const N: usize>(
221 f: F,
222 r: &mut R,
223) -> io::Result<[T; N]> {
224 read_raw_vec(f, N, r).map(|vec| {
225 vec.try_into().expect(
226 "read vector of specific length yet couldn't transform to an array of the same length",
227 )
228 })
229}
230
231pub fn read_vec<R: Read, T, F: Fn(&mut R) -> io::Result<T>>(
237 f: F,
238 length_bound: Option<usize>,
239 r: &mut R,
240) -> io::Result<Vec<T>> {
241 let declared_length: usize = read_varint(r)?;
242 if let Some(length_bound) = length_bound {
243 if declared_length > length_bound {
244 Err(io::Error::other("vector exceeds bound on length"))?;
245 }
246 }
247 read_raw_vec(f, declared_length, r)
248}