dalek_ff_group/
lib.rs

1#![allow(deprecated)]
2#![cfg_attr(docsrs, feature(doc_auto_cfg))]
3#![no_std] // Prevents writing new code, in what should be a simple wrapper, which requires std
4#![doc = include_str!("../README.md")]
5#![allow(clippy::redundant_closure_call)]
6
7use core::{
8  borrow::Borrow,
9  ops::{Deref, Add, AddAssign, Sub, SubAssign, Neg, Mul, MulAssign},
10  iter::{Iterator, Sum, Product},
11  hash::{Hash, Hasher},
12};
13
14use zeroize::Zeroize;
15use subtle::{ConstantTimeEq, ConditionallySelectable};
16
17use rand_core::RngCore;
18use digest::{consts::U64, Digest, HashMarker};
19
20use subtle::{Choice, CtOption};
21
22pub use curve25519_dalek as dalek;
23
24use dalek::{
25  constants::{self, BASEPOINT_ORDER},
26  scalar::Scalar as DScalar,
27  edwards::{EdwardsPoint as DEdwardsPoint, EdwardsBasepointTable, CompressedEdwardsY},
28  ristretto::{RistrettoPoint as DRistrettoPoint, RistrettoBasepointTable, CompressedRistretto},
29};
30pub use constants::{ED25519_BASEPOINT_TABLE, RISTRETTO_BASEPOINT_TABLE};
31
32use group::{
33  ff::{Field, PrimeField, FieldBits, PrimeFieldBits},
34  Group, GroupEncoding,
35  prime::PrimeGroup,
36};
37
38mod field;
39pub use field::FieldElement;
40
41// Use black_box when possible
42#[rustversion::since(1.66)]
43mod black_box {
44  pub(crate) fn black_box<T>(val: T) -> T {
45    #[allow(clippy::incompatible_msrv)]
46    core::hint::black_box(val)
47  }
48}
49#[rustversion::before(1.66)]
50mod black_box {
51  pub(crate) fn black_box<T>(val: T) -> T {
52    val
53  }
54}
55use black_box::black_box;
56
/// Convert a `bool` into a `u8` (`0` or `1`), wiping the referenced `bool` afterwards.
///
/// The reference, the copied bit, and the resulting byte are all passed through `black_box` to
/// discourage the optimizer from specializing on the bit's value. This is a best-effort measure, not
/// a guarantee of constant-time execution. Both the local copy of the bit and the caller's `bool` are
/// zeroized before returning.
fn u8_from_bool(bit_ref: &mut bool) -> u8 {
  let bit_ref = black_box(bit_ref);

  // Copy the bit out so it can be converted and then wiped independently of the caller's storage
  let mut bit = black_box(*bit_ref);
  #[allow(clippy::cast_lossless)]
  let res = black_box(bit as u8);
  bit.zeroize();
  // A `bool` cast to `u8` may only ever be 0 or 1
  debug_assert!((res | 1) == 1);

  // Also clear the caller's copy so the value doesn't linger in their memory
  bit_ref.zeroize();
  res
}
69
70// Convert a boolean to a Choice in a *presumably* constant time manner
71fn choice(mut value: bool) -> Choice {
72  Choice::from(u8_from_bool(&mut value))
73}
74
75macro_rules! deref_borrow {
76  ($Source: ident, $Target: ident) => {
77    impl Deref for $Source {
78      type Target = $Target;
79
80      fn deref(&self) -> &Self::Target {
81        &self.0
82      }
83    }
84
85    impl Borrow<$Target> for $Source {
86      fn borrow(&self) -> &$Target {
87        &self.0
88      }
89    }
90
91    impl Borrow<$Target> for &$Source {
92      fn borrow(&self) -> &$Target {
93        &self.0
94      }
95    }
96  };
97}
98
// Implement `subtle`'s `ConditionallySelectable` and `ConstantTimeEq` for a single-field tuple struct
// `$Wrapper($Inner)`, by delegating to the wrapped `$Inner` value.
macro_rules! constant_time {
  ($Wrapper: ident, $Inner: ident) => {
    impl ConditionallySelectable for $Wrapper {
      fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
        let selected = $Inner::conditional_select(&a.0, &b.0, choice);
        $Wrapper(selected)
      }
    }

    impl ConstantTimeEq for $Wrapper {
      fn ct_eq(&self, other: &Self) -> Choice {
        self.0.ct_eq(&other.0)
      }
    }
  };
}
pub(crate) use constant_time;
115
// Implement a binary operator and its compound-assignment form for a single-field tuple struct.
//
// `$combine` is called with the inner values of both operands and returns the new inner value.
// Impls are generated for both an owned and a borrowed right-hand side.
macro_rules! math_op {
  (
    $Lhs: ident,
    $Rhs: ident,
    $OpTrait: ident,
    $op_method: ident,
    $AssignTrait: ident,
    $assign_method: ident,
    $combine: expr
  ) => {
    impl $OpTrait<$Rhs> for $Lhs {
      type Output = Self;
      fn $op_method(self, rhs: $Rhs) -> Self {
        Self($combine(self.0, rhs.0))
      }
    }
    impl<'r> $OpTrait<&'r $Rhs> for $Lhs {
      type Output = Self;
      fn $op_method(self, rhs: &'r $Rhs) -> Self {
        Self($combine(self.0, rhs.0))
      }
    }
    impl $AssignTrait<$Rhs> for $Lhs {
      fn $assign_method(&mut self, rhs: $Rhs) {
        self.0 = $combine(self.0, rhs.0);
      }
    }
    impl<'r> $AssignTrait<&'r $Rhs> for $Lhs {
      fn $assign_method(&mut self, rhs: &'r $Rhs) {
        self.0 = $combine(self.0, rhs.0);
      }
    }
  };
}
pub(crate) use math_op;
151
// Implement `+`/`-` between two `$Elem`s and `*` by a `$Multiplier`, along with their assigning
// forms, all via `math_op!`.
macro_rules! math {
  ($Elem: ident, $Multiplier: ident, $add_fn: expr, $sub_fn: expr, $mul_fn: expr) => {
    math_op!($Elem, $Multiplier, Mul, mul, MulAssign, mul_assign, $mul_fn);
    math_op!($Elem, $Elem, Sub, sub, SubAssign, sub_assign, $sub_fn);
    math_op!($Elem, $Elem, Add, add, AddAssign, add_assign, $add_fn);
  };
}
pub(crate) use math;
160
// `math!` plus unary negation, for wrappers whose inner type implements `Neg`.
macro_rules! math_neg {
  ($Value: ident, $Factor: ident, $add: expr, $sub: expr, $mul: expr) => {
    impl Neg for $Value {
      type Output = $Value;
      fn neg(self) -> $Value {
        $Value(-self.0)
      }
    }

    math!($Value, $Factor, $add, $sub, $mul);
  };
}
173
174/// Wrapper around the dalek Scalar type.
175#[derive(Clone, Copy, PartialEq, Eq, Default, Debug, Zeroize)]
176pub struct Scalar(pub DScalar);
177deref_borrow!(Scalar, DScalar);
178constant_time!(Scalar, DScalar);
179math_neg!(Scalar, Scalar, DScalar::add, DScalar::sub, DScalar::mul);
180
181macro_rules! from_wrapper {
182  ($uint: ident) => {
183    impl From<$uint> for Scalar {
184      fn from(a: $uint) -> Scalar {
185        Scalar(DScalar::from(a))
186      }
187    }
188  };
189}
190
191from_wrapper!(u8);
192from_wrapper!(u16);
193from_wrapper!(u32);
194from_wrapper!(u64);
195from_wrapper!(u128);
196
197impl Scalar {
198  pub fn pow(&self, other: Scalar) -> Scalar {
199    let mut table = [Scalar::ONE; 16];
200    table[1] = *self;
201    for i in 2 .. 16 {
202      table[i] = table[i - 1] * self;
203    }
204
205    let mut res = Scalar::ONE;
206    let mut bits = 0;
207    for (i, mut bit) in other.to_le_bits().iter_mut().rev().enumerate() {
208      bits <<= 1;
209      let mut bit = u8_from_bool(&mut bit);
210      bits |= bit;
211      bit.zeroize();
212
213      if ((i + 1) % 4) == 0 {
214        if i != 3 {
215          for _ in 0 .. 4 {
216            res *= res;
217          }
218        }
219
220        let mut scale_by = Scalar::ONE;
221        #[allow(clippy::needless_range_loop)]
222        for i in 0 .. 16 {
223          #[allow(clippy::cast_possible_truncation)] // Safe since 0 .. 16
224          {
225            scale_by = <_>::conditional_select(&scale_by, &table[i], bits.ct_eq(&(i as u8)));
226          }
227        }
228        res *= scale_by;
229        bits = 0;
230      }
231    }
232    res
233  }
234
235  /// Perform wide reduction on a 64-byte array to create a Scalar without bias.
236  pub fn from_bytes_mod_order_wide(bytes: &[u8; 64]) -> Scalar {
237    Self(DScalar::from_bytes_mod_order_wide(bytes))
238  }
239
240  /// Derive a Scalar without bias from a digest via wide reduction.
241  pub fn from_hash<D: Digest<OutputSize = U64> + HashMarker>(hash: D) -> Scalar {
242    let mut output = [0u8; 64];
243    output.copy_from_slice(&hash.finalize());
244    let res = Scalar(DScalar::from_bytes_mod_order_wide(&output));
245    output.zeroize();
246    res
247  }
248}
249
250impl Field for Scalar {
251  const ZERO: Scalar = Scalar(DScalar::ZERO);
252  const ONE: Scalar = Scalar(DScalar::ONE);
253
254  fn random(rng: impl RngCore) -> Self {
255    Self(<DScalar as Field>::random(rng))
256  }
257
258  fn square(&self) -> Self {
259    Self(self.0.square())
260  }
261  fn double(&self) -> Self {
262    Self(self.0.double())
263  }
264  fn invert(&self) -> CtOption<Self> {
265    <DScalar as Field>::invert(&self.0).map(Self)
266  }
267
268  fn sqrt(&self) -> CtOption<Self> {
269    self.0.sqrt().map(Self)
270  }
271
272  fn sqrt_ratio(num: &Self, div: &Self) -> (Choice, Self) {
273    let (choice, res) = DScalar::sqrt_ratio(num, div);
274    (choice, Self(res))
275  }
276}
277
278impl PrimeField for Scalar {
279  type Repr = [u8; 32];
280
281  const MODULUS: &'static str = <DScalar as PrimeField>::MODULUS;
282
283  const NUM_BITS: u32 = <DScalar as PrimeField>::NUM_BITS;
284  const CAPACITY: u32 = <DScalar as PrimeField>::CAPACITY;
285
286  const TWO_INV: Scalar = Scalar(<DScalar as PrimeField>::TWO_INV);
287
288  const MULTIPLICATIVE_GENERATOR: Scalar =
289    Scalar(<DScalar as PrimeField>::MULTIPLICATIVE_GENERATOR);
290  const S: u32 = <DScalar as PrimeField>::S;
291
292  const ROOT_OF_UNITY: Scalar = Scalar(<DScalar as PrimeField>::ROOT_OF_UNITY);
293  const ROOT_OF_UNITY_INV: Scalar = Scalar(<DScalar as PrimeField>::ROOT_OF_UNITY_INV);
294
295  const DELTA: Scalar = Scalar(<DScalar as PrimeField>::DELTA);
296
297  fn from_repr(bytes: [u8; 32]) -> CtOption<Self> {
298    <DScalar as PrimeField>::from_repr(bytes).map(Scalar)
299  }
300  fn to_repr(&self) -> [u8; 32] {
301    self.0.to_repr()
302  }
303
304  fn is_odd(&self) -> Choice {
305    self.0.is_odd()
306  }
307
308  fn from_u128(num: u128) -> Self {
309    Scalar(DScalar::from_u128(num))
310  }
311}
312
313impl PrimeFieldBits for Scalar {
314  type ReprBits = [u8; 32];
315
316  fn to_le_bits(&self) -> FieldBits<Self::ReprBits> {
317    self.to_repr().into()
318  }
319
320  fn char_le_bits() -> FieldBits<Self::ReprBits> {
321    BASEPOINT_ORDER.to_bytes().into()
322  }
323}
324
325impl Sum<Scalar> for Scalar {
326  fn sum<I: Iterator<Item = Scalar>>(iter: I) -> Scalar {
327    Self(DScalar::sum(iter))
328  }
329}
330
331impl<'a> Sum<&'a Scalar> for Scalar {
332  fn sum<I: Iterator<Item = &'a Scalar>>(iter: I) -> Scalar {
333    Self(DScalar::sum(iter))
334  }
335}
336
337impl Product<Scalar> for Scalar {
338  fn product<I: Iterator<Item = Scalar>>(iter: I) -> Scalar {
339    Self(DScalar::product(iter))
340  }
341}
342
343impl<'a> Product<&'a Scalar> for Scalar {
344  fn product<I: Iterator<Item = &'a Scalar>>(iter: I) -> Scalar {
345    Self(DScalar::product(iter))
346  }
347}
348
// Define a `group`-crate wrapper type around a dalek point type.
//
// Arguments:
// - `$Point`: name of the wrapper type to define.
// - `$DPoint`: the wrapped dalek point type.
// - `$torsion_free`: called with each decompressed `$DPoint`; decoding only succeeds if it returns
//   `true`. Used to restrict Ed25519 decoding to the prime-order subgroup.
// - `$Table`: dalek's basepoint table type, which gains a `&$Table * Scalar` impl.
// - `$DCompressed`: dalek's compressed point type, used for decoding.
// - `$BASEPOINT_POINT`: name of the dalek constant (within `constants`), reused as the name of the
//   wrapper constant defined here.
// - `$BASEPOINT_TABLE`: NOTE(review): this argument isn't referenced within the macro body.
macro_rules! dalek_group {
  (
    $Point: ident,
    $DPoint: ident,
    $torsion_free: expr,

    $Table: ident,

    $DCompressed: ident,

    $BASEPOINT_POINT: ident,
    $BASEPOINT_TABLE: ident
  ) => {
    /// Wrapper around the dalek Point type. For Ed25519, this is restricted to the prime subgroup.
    #[derive(Clone, Copy, PartialEq, Eq, Debug, Zeroize)]
    pub struct $Point(pub $DPoint);
    deref_borrow!($Point, $DPoint);
    constant_time!($Point, $DPoint);
    math_neg!($Point, Scalar, $DPoint::add, $DPoint::sub, $DPoint::mul);

    /// The basepoint for this curve.
    pub const $BASEPOINT_POINT: $Point = $Point(constants::$BASEPOINT_POINT);

    // Summation defers to dalek, via the `Borrow<$DPoint>` impls from `deref_borrow!`
    impl Sum<$Point> for $Point {
      fn sum<I: Iterator<Item = $Point>>(iter: I) -> $Point {
        Self($DPoint::sum(iter))
      }
    }
    impl<'a> Sum<&'a $Point> for $Point {
      fn sum<I: Iterator<Item = &'a $Point>>(iter: I) -> $Point {
        Self($DPoint::sum(iter))
      }
    }

    impl Group for $Point {
      type Scalar = Scalar;
      // Rejection sampling: draw 32 random bytes until they decode (via `from_bytes`, including its
      // torsion check) to a point which isn't the identity
      fn random(mut rng: impl RngCore) -> Self {
        loop {
          let mut bytes = [0; 32];
          rng.fill_bytes(&mut bytes);
          let Some(point) = Option::<$Point>::from($Point::from_bytes(&bytes)) else {
            continue;
          };
          // Ban identity, per the trait specification
          if !bool::from(point.is_identity()) {
            return point;
          }
        }
      }
      fn identity() -> Self {
        Self($DPoint::identity())
      }
      fn generator() -> Self {
        $BASEPOINT_POINT
      }
      fn is_identity(&self) -> Choice {
        self.0.ct_eq(&$DPoint::identity())
      }
      fn double(&self) -> Self {
        Self(self.0.double())
      }
    }

    impl GroupEncoding for $Point {
      type Repr = [u8; 32];

      // Succeeds only if the bytes decompress AND `$torsion_free` accepts the resulting point
      fn from_bytes(bytes: &Self::Repr) -> CtOption<Self> {
        let decompressed = $DCompressed(*bytes).decompress();
        // If decompression failed, fall back to the identity so the torsion check below still has a
        // point to operate on. The returned CtOption is marked as none in that case regardless.
        // TODO: `Option::unwrap_or`/`Option::is_some` presumably aren't guaranteed to execute in
        // constant time (hence the `black_box` below) — confirm this is acceptable.
        let point = decompressed.unwrap_or($DPoint::identity());
        CtOption::new(
          $Point(point),
          choice(black_box(decompressed).is_some()) & choice($torsion_free(point)),
        )
      }

      // Performs the exact same validation as `from_bytes`
      fn from_bytes_unchecked(bytes: &Self::Repr) -> CtOption<Self> {
        $Point::from_bytes(bytes)
      }

      fn to_bytes(&self) -> Self::Repr {
        self.0.to_bytes()
      }
    }

    impl PrimeGroup for $Point {}

    // Fixed-base multiplication using dalek's basepoint table type
    impl Mul<Scalar> for &$Table {
      type Output = $Point;
      fn mul(self, b: Scalar) -> $Point {
        $Point(&b.0 * self)
      }
    }

    // Support being used as a key in a table
    // While it is expensive as a key, due to the field operations required, there are frequently
    // use cases for public key -> value lookups
    #[allow(unknown_lints, renamed_and_removed_lints)]
    #[allow(clippy::derived_hash_with_manual_eq, clippy::derive_hash_xor_eq)]
    impl Hash for $Point {
      fn hash<H: Hasher>(&self, state: &mut H) {
        self.to_bytes().hash(state);
      }
    }
  };
}
455
456dalek_group!(
457  EdwardsPoint,
458  DEdwardsPoint,
459  |point: DEdwardsPoint| point.is_torsion_free(),
460  EdwardsBasepointTable,
461  CompressedEdwardsY,
462  ED25519_BASEPOINT_POINT,
463  ED25519_BASEPOINT_TABLE
464);
465
466impl EdwardsPoint {
467  pub fn mul_by_cofactor(&self) -> EdwardsPoint {
468    EdwardsPoint(self.0.mul_by_cofactor())
469  }
470}
471
472dalek_group!(
473  RistrettoPoint,
474  DRistrettoPoint,
475  |_| true,
476  RistrettoBasepointTable,
477  CompressedRistretto,
478  RISTRETTO_BASEPOINT_POINT,
479  RISTRETTO_BASEPOINT_TABLE
480);
481
#[test]
fn test_ed25519_group() {
  let mut rng = rand_core::OsRng;
  ff_group_tests::group::test_prime_group_bits::<_, EdwardsPoint>(&mut rng);
}
486
#[test]
fn test_ristretto_group() {
  let mut rng = rand_core::OsRng;
  ff_group_tests::group::test_prime_group_bits::<_, RistrettoPoint>(&mut rng);
}