// monero_bulletproofs/plus/weighted_inner_product.rs

1use std_shims::{vec, vec::Vec};
2
3use rand_core::{RngCore, CryptoRng};
4use zeroize::{Zeroize, ZeroizeOnDrop};
5
6use curve25519_dalek::{Scalar, EdwardsPoint};
7
8use monero_ed25519::CompressedPoint;
9use crate::{
10  core::{multiexp, multiexp_vartime, challenge_products},
11  batch_verifier::BulletproofsPlusBatchVerifier,
12  plus::{ScalarVector, PointVector, GeneratorsList, BpPlusGenerators, padded_pow_of_2},
13};
14
15const INV_EIGHT: monero_ed25519::Scalar = monero_ed25519::Scalar::INV_EIGHT;
16
17// Figure 1 of the Bulletproofs+ paper
18#[derive(Clone, Debug)]
19pub(crate) struct WipStatement {
20  generators: BpPlusGenerators,
21  P: EdwardsPoint,
22  y: ScalarVector,
23}
24
25impl Zeroize for WipStatement {
26  fn zeroize(&mut self) {
27    self.P.zeroize();
28    self.y.zeroize();
29  }
30}
31
32#[derive(Clone, Zeroize, ZeroizeOnDrop)]
33pub(crate) struct WipWitness {
34  a: ScalarVector,
35  b: ScalarVector,
36  alpha: Scalar,
37}
38
39impl WipWitness {
40  pub(crate) fn new(mut a: ScalarVector, mut b: ScalarVector, alpha: Scalar) -> Option<Self> {
41    if a.0.is_empty() || (a.len() != b.len()) {
42      return None;
43    }
44
45    // Pad to the nearest power of 2
46    let missing = padded_pow_of_2(a.len()) - a.len();
47    a.0.reserve(missing);
48    b.0.reserve(missing);
49    for _ in 0 .. missing {
50      a.0.push(Scalar::ZERO);
51      b.0.push(Scalar::ZERO);
52    }
53
54    Some(Self { a, b, alpha })
55  }
56}
57
58#[derive(Clone, PartialEq, Eq, Debug, Zeroize)]
59pub(crate) struct WipProof {
60  pub(crate) L: Vec<CompressedPoint>,
61  pub(crate) R: Vec<CompressedPoint>,
62  pub(crate) A: CompressedPoint,
63  pub(crate) B: CompressedPoint,
64  pub(crate) r_answer: Scalar,
65  pub(crate) s_answer: Scalar,
66  pub(crate) delta_answer: Scalar,
67}
68
69impl WipStatement {
70  pub(crate) fn new(generators: BpPlusGenerators, P: EdwardsPoint, y: Scalar) -> Self {
71    debug_assert_eq!(generators.len(), padded_pow_of_2(generators.len()));
72
73    // y ** n
74    let mut y_vec = ScalarVector::new(generators.len());
75    y_vec[0] = y;
76    for i in 1 .. y_vec.len() {
77      y_vec[i] = y_vec[i - 1] * y;
78    }
79
80    Self { generators, P, y: y_vec }
81  }
82
83  fn transcript_L_R(transcript: &mut Scalar, L: CompressedPoint, R: CompressedPoint) -> Scalar {
84    let e =
85      monero_ed25519::Scalar::hash([transcript.to_bytes(), L.to_bytes(), R.to_bytes()].concat())
86        .into();
87    *transcript = e;
88    e
89  }
90
91  fn transcript_A_B(transcript: &mut Scalar, A: CompressedPoint, B: CompressedPoint) -> Scalar {
92    let e =
93      monero_ed25519::Scalar::hash([transcript.to_bytes(), A.to_bytes(), B.to_bytes()].concat())
94        .into();
95    *transcript = e;
96    e
97  }
98
99  // Prover's variant of the shared code block to calculate G/H/P when n > 1
100  // Returns each permutation of G/H since the prover needs to do operation on each permutation
101  // P is dropped as it's unused in the prover's path
102  #[allow(clippy::too_many_arguments)]
103  fn next_G_H(
104    transcript: &mut Scalar,
105    mut g_bold1: PointVector,
106    mut g_bold2: PointVector,
107    mut h_bold1: PointVector,
108    mut h_bold2: PointVector,
109    L: CompressedPoint,
110    R: CompressedPoint,
111    y_inv_n_hat: Scalar,
112  ) -> (Scalar, Scalar, Scalar, Scalar, PointVector, PointVector) {
113    debug_assert_eq!(g_bold1.len(), g_bold2.len());
114    debug_assert_eq!(h_bold1.len(), h_bold2.len());
115    debug_assert_eq!(g_bold1.len(), h_bold1.len());
116
117    let e = Self::transcript_L_R(transcript, L, R);
118    let inv_e = e.invert();
119
120    // This vartime is safe as all of these arguments are public
121    let mut new_g_bold = Vec::with_capacity(g_bold1.len());
122    let e_y_inv = e * y_inv_n_hat;
123    for g_bold in g_bold1.0.drain(..).zip(g_bold2.0.drain(..)) {
124      new_g_bold.push(multiexp_vartime(&[(inv_e, g_bold.0), (e_y_inv, g_bold.1)]));
125    }
126
127    let mut new_h_bold = Vec::with_capacity(h_bold1.len());
128    for h_bold in h_bold1.0.drain(..).zip(h_bold2.0.drain(..)) {
129      new_h_bold.push(multiexp_vartime(&[(e, h_bold.0), (inv_e, h_bold.1)]));
130    }
131
132    let e_square = e * e;
133    let inv_e_square = inv_e * inv_e;
134
135    (e, inv_e, e_square, inv_e_square, PointVector(new_g_bold), PointVector(new_h_bold))
136  }
137
138  pub(crate) fn prove<R: RngCore + CryptoRng>(
139    self,
140    rng: &mut R,
141    mut transcript: Scalar,
142    witness: &WipWitness,
143  ) -> Option<WipProof> {
144    let WipStatement { generators, P, mut y } = self;
145    #[cfg(not(debug_assertions))]
146    let _ = P;
147
148    if generators.len() != witness.a.len() {
149      return None;
150    }
151    let (g, h) = (BpPlusGenerators::g(), BpPlusGenerators::h());
152    let mut g_bold = vec![];
153    let mut h_bold = vec![];
154    for i in 0 .. generators.len() {
155      g_bold.push(generators.generator(GeneratorsList::GBold, i));
156      h_bold.push(generators.generator(GeneratorsList::HBold, i));
157    }
158    let mut g_bold = PointVector(g_bold);
159    let mut h_bold = PointVector(h_bold);
160
161    let mut y_inv = {
162      let mut i = 1;
163      let mut to_invert = vec![];
164      while i < g_bold.len() {
165        to_invert.push(y[i - 1]);
166        i *= 2;
167      }
168      Scalar::batch_invert(&mut to_invert);
169      to_invert
170    };
171
172    // Check P has the expected relationship
173    #[cfg(debug_assertions)]
174    {
175      let mut P_terms = witness
176        .a
177        .0
178        .iter()
179        .copied()
180        .zip(g_bold.0.iter().copied())
181        .chain(witness.b.0.iter().copied().zip(h_bold.0.iter().copied()))
182        .collect::<Vec<_>>();
183      P_terms.push((witness.a.clone().weighted_inner_product(&witness.b, &y), g));
184      P_terms.push((witness.alpha, h));
185      debug_assert_eq!(multiexp(&P_terms), P);
186      P_terms.zeroize();
187    }
188
189    let mut a = witness.a.clone();
190    let mut b = witness.b.clone();
191    let mut alpha = witness.alpha;
192
193    // From here on, g_bold.len() is used as n
194    debug_assert_eq!(g_bold.len(), a.len());
195
196    let mut L_vec = vec![];
197    let mut R_vec = vec![];
198
199    // else n > 1 case from figure 1
200    while g_bold.len() > 1 {
201      let (a1, a2) = a.clone().split();
202      let (b1, b2) = b.clone().split();
203      let (g_bold1, g_bold2) = g_bold.split();
204      let (h_bold1, h_bold2) = h_bold.split();
205
206      let n_hat = g_bold1.len();
207      debug_assert_eq!(a1.len(), n_hat);
208      debug_assert_eq!(a2.len(), n_hat);
209      debug_assert_eq!(b1.len(), n_hat);
210      debug_assert_eq!(b2.len(), n_hat);
211      debug_assert_eq!(g_bold1.len(), n_hat);
212      debug_assert_eq!(g_bold2.len(), n_hat);
213      debug_assert_eq!(h_bold1.len(), n_hat);
214      debug_assert_eq!(h_bold2.len(), n_hat);
215
216      let y_n_hat = y[n_hat - 1];
217      y.0.truncate(n_hat);
218
219      let d_l = monero_ed25519::Scalar::random(&mut *rng).into();
220      let d_r = monero_ed25519::Scalar::random(&mut *rng).into();
221
222      let c_l = a1.clone().weighted_inner_product(&b2, &y);
223      let c_r = (a2.clone() * y_n_hat).weighted_inner_product(&b1, &y);
224
225      let y_inv_n_hat = y_inv
226        .pop()
227        .expect("couldn't pop y_inv despite y_inv being of same length as times iterated");
228
229      let mut L_terms = (a1.clone() * y_inv_n_hat)
230        .0
231        .drain(..)
232        .zip(g_bold2.0.iter().copied())
233        .chain(b2.0.iter().copied().zip(h_bold1.0.iter().copied()))
234        .collect::<Vec<_>>();
235      L_terms.push((c_l, g));
236      L_terms.push((d_l, h));
237      let L = CompressedPoint::from((multiexp(&L_terms) * INV_EIGHT.into()).compress().to_bytes());
238      L_vec.push(L);
239      L_terms.zeroize();
240
241      let mut R_terms = (a2.clone() * y_n_hat)
242        .0
243        .drain(..)
244        .zip(g_bold1.0.iter().copied())
245        .chain(b1.0.iter().copied().zip(h_bold2.0.iter().copied()))
246        .collect::<Vec<_>>();
247      R_terms.push((c_r, g));
248      R_terms.push((d_r, h));
249      let R = CompressedPoint::from((multiexp(&R_terms) * INV_EIGHT.into()).compress().to_bytes());
250      R_vec.push(R);
251      R_terms.zeroize();
252
253      let (e, inv_e, e_square, inv_e_square);
254      (e, inv_e, e_square, inv_e_square, g_bold, h_bold) =
255        Self::next_G_H(&mut transcript, g_bold1, g_bold2, h_bold1, h_bold2, L, R, y_inv_n_hat);
256
257      a = (a1 * e) + &(a2 * (y_n_hat * inv_e));
258      b = (b1 * inv_e) + &(b2 * e);
259      alpha += (d_l * e_square) + (d_r * inv_e_square);
260
261      debug_assert_eq!(g_bold.len(), a.len());
262      debug_assert_eq!(g_bold.len(), h_bold.len());
263      debug_assert_eq!(g_bold.len(), b.len());
264    }
265
266    // n == 1 case from figure 1
267    debug_assert_eq!(g_bold.len(), 1);
268    debug_assert_eq!(h_bold.len(), 1);
269
270    debug_assert_eq!(a.len(), 1);
271    debug_assert_eq!(b.len(), 1);
272
273    let r = monero_ed25519::Scalar::random(&mut *rng).into();
274    let s = monero_ed25519::Scalar::random(&mut *rng).into();
275    let delta = monero_ed25519::Scalar::random(&mut *rng).into();
276    let eta = monero_ed25519::Scalar::random(&mut *rng).into();
277
278    let ry = r * y[0];
279
280    let mut A_terms =
281      vec![(r, g_bold[0]), (s, h_bold[0]), ((ry * b[0]) + (s * y[0] * a[0]), g), (delta, h)];
282    let A = CompressedPoint::from((multiexp(&A_terms) * INV_EIGHT.into()).compress().to_bytes());
283    A_terms.zeroize();
284
285    let mut B_terms = vec![(ry * s, g), (eta, h)];
286    let B = CompressedPoint::from((multiexp(&B_terms) * INV_EIGHT.into()).compress().to_bytes());
287    B_terms.zeroize();
288
289    let e = Self::transcript_A_B(&mut transcript, A, B);
290
291    let r_answer = r + (a[0] * e);
292    let s_answer = s + (b[0] * e);
293    let delta_answer = eta + (delta * e) + (alpha * (e * e));
294
295    Some(WipProof { L: L_vec, R: R_vec, A, B, r_answer, s_answer, delta_answer })
296  }
297
298  pub(crate) fn verify<R: RngCore + CryptoRng>(
299    self,
300    rng: &mut R,
301    verifier: &mut BulletproofsPlusBatchVerifier,
302    mut transcript: Scalar,
303    WipProof { L, R, A, B, r_answer, s_answer, delta_answer }: WipProof,
304  ) -> bool {
305    let verifier_weight = monero_ed25519::Scalar::random(rng).into();
306
307    let WipStatement { generators, P, y } = self;
308
309    // Verify the L/R lengths
310    {
311      let mut lr_len = 0;
312      while (1 << lr_len) < generators.len() {
313        lr_len += 1;
314      }
315      if (L.len() != lr_len) || (R.len() != lr_len) || (generators.len() != (1 << lr_len)) {
316        return false;
317      }
318    }
319
320    let inv_y = {
321      let inv_y = y[0].invert();
322      let mut res = Vec::with_capacity(y.len());
323      res.push(inv_y);
324      while res.len() < y.len() {
325        res.push(
326          inv_y * res.last().expect("couldn't get last inv_y despite inv_y always being non-empty"),
327        );
328      }
329      res
330    };
331
332    let mut e_is = Vec::with_capacity(L.len());
333    let mut L_decomp = Vec::with_capacity(L.len());
334    let mut R_decomp = Vec::with_capacity(R.len());
335
336    let decomp_mul_cofactor =
337      |p| CompressedPoint::decompress(&p).map(|p| EdwardsPoint::mul_by_cofactor(&p.into()));
338
339    for (L_i, R_i) in L.into_iter().zip(R.into_iter()) {
340      e_is.push(Self::transcript_L_R(&mut transcript, L_i, R_i));
341
342      let (Some(L_i), Some(R_i)) = (decomp_mul_cofactor(L_i), decomp_mul_cofactor(R_i)) else {
343        return false;
344      };
345
346      L_decomp.push(L_i);
347      R_decomp.push(R_i);
348    }
349
350    let L = L_decomp;
351    let R = R_decomp;
352
353    let e = Self::transcript_A_B(&mut transcript, A, B);
354
355    let (Some(A), Some(B)) = (decomp_mul_cofactor(A), decomp_mul_cofactor(B)) else {
356      return false;
357    };
358
359    let neg_e_square = verifier_weight * -(e * e);
360
361    verifier.0.other.push((neg_e_square, P));
362
363    let mut challenges = Vec::with_capacity(L.len());
364    let product_cache = {
365      let mut inv_e_is = e_is.clone();
366      Scalar::batch_invert(&mut inv_e_is);
367
368      debug_assert_eq!(e_is.len(), inv_e_is.len());
369      debug_assert_eq!(e_is.len(), L.len());
370      debug_assert_eq!(e_is.len(), R.len());
371      for ((e_i, inv_e_i), (L, R)) in
372        e_is.drain(..).zip(inv_e_is.drain(..)).zip(L.iter().zip(R.iter()))
373      {
374        debug_assert_eq!(e_i.invert(), inv_e_i);
375
376        challenges.push((e_i, inv_e_i));
377
378        let e_i_square = e_i * e_i;
379        let inv_e_i_square = inv_e_i * inv_e_i;
380        verifier.0.other.push((neg_e_square * e_i_square, *L));
381        verifier.0.other.push((neg_e_square * inv_e_i_square, *R));
382      }
383
384      challenge_products(&challenges)
385    };
386
387    while verifier.0.g_bold.len() < generators.len() {
388      verifier.0.g_bold.push(Scalar::ZERO);
389    }
390    while verifier.0.h_bold.len() < generators.len() {
391      verifier.0.h_bold.push(Scalar::ZERO);
392    }
393
394    let re = r_answer * e;
395    for i in 0 .. generators.len() {
396      let mut scalar = product_cache[i] * re;
397      if i > 0 {
398        scalar *= inv_y[i - 1];
399      }
400      verifier.0.g_bold[i] += verifier_weight * scalar;
401    }
402
403    let se = s_answer * e;
404    for i in 0 .. generators.len() {
405      verifier.0.h_bold[i] += verifier_weight * (se * product_cache[product_cache.len() - 1 - i]);
406    }
407
408    verifier.0.other.push((verifier_weight * -e, A));
409    verifier.0.g += verifier_weight * (r_answer * y[0] * s_answer);
410    verifier.0.h += verifier_weight * delta_answer;
411    verifier.0.other.push((-verifier_weight, B));
412
413    true
414  }
415}