1use std_shims::{vec, vec::Vec};
2
3use rand_core::{RngCore, CryptoRng};
4use zeroize::{Zeroize, ZeroizeOnDrop};
5
6use curve25519_dalek::{Scalar, EdwardsPoint};
7
8use monero_ed25519::CompressedPoint;
9use crate::{
10 core::{multiexp, multiexp_vartime, challenge_products},
11 batch_verifier::BulletproofsPlusBatchVerifier,
12 plus::{ScalarVector, PointVector, GeneratorsList, BpPlusGenerators, padded_pow_of_2},
13};
14
15const INV_EIGHT: monero_ed25519::Scalar = monero_ed25519::Scalar::INV_EIGHT;
16
17#[derive(Clone, Debug)]
19pub(crate) struct WipStatement {
20 generators: BpPlusGenerators,
21 P: EdwardsPoint,
22 y: ScalarVector,
23}
24
25impl Zeroize for WipStatement {
26 fn zeroize(&mut self) {
27 self.P.zeroize();
28 self.y.zeroize();
29 }
30}
31
32#[derive(Clone, Zeroize, ZeroizeOnDrop)]
33pub(crate) struct WipWitness {
34 a: ScalarVector,
35 b: ScalarVector,
36 alpha: Scalar,
37}
38
39impl WipWitness {
40 pub(crate) fn new(mut a: ScalarVector, mut b: ScalarVector, alpha: Scalar) -> Option<Self> {
41 if a.0.is_empty() || (a.len() != b.len()) {
42 return None;
43 }
44
45 let missing = padded_pow_of_2(a.len()) - a.len();
47 a.0.reserve(missing);
48 b.0.reserve(missing);
49 for _ in 0 .. missing {
50 a.0.push(Scalar::ZERO);
51 b.0.push(Scalar::ZERO);
52 }
53
54 Some(Self { a, b, alpha })
55 }
56}
57
58#[derive(Clone, PartialEq, Eq, Debug, Zeroize)]
59pub(crate) struct WipProof {
60 pub(crate) L: Vec<CompressedPoint>,
61 pub(crate) R: Vec<CompressedPoint>,
62 pub(crate) A: CompressedPoint,
63 pub(crate) B: CompressedPoint,
64 pub(crate) r_answer: Scalar,
65 pub(crate) s_answer: Scalar,
66 pub(crate) delta_answer: Scalar,
67}
68
impl WipStatement {
  /// Create the statement for a weighted inner-product argument.
  ///
  /// `P` is the point the proof is made against and `y` the weighting challenge. `y` is expanded
  /// into the vector `[y, y^2, ..., y^n]`, where `n = generators.len()`.
  ///
  /// `generators.len()` is expected to already be padded (`padded_pow_of_2(len) == len`); this
  /// is only checked in debug builds. `generators` must be non-empty, as `y_vec[0]` is written.
  pub(crate) fn new(generators: BpPlusGenerators, P: EdwardsPoint, y: Scalar) -> Self {
    debug_assert_eq!(generators.len(), padded_pow_of_2(generators.len()));

    // y_vec[i] = y^(i + 1)
    let mut y_vec = ScalarVector::new(generators.len());
    y_vec[0] = y;
    for i in 1 .. y_vec.len() {
      y_vec[i] = y_vec[i - 1] * y;
    }

    Self { generators, P, y: y_vec }
  }

  /// Absorb one round's `L` and `R` into the running transcript.
  ///
  /// The challenge is `hash(transcript || L || R)`; it replaces `*transcript` and is returned.
  fn transcript_L_R(transcript: &mut Scalar, L: CompressedPoint, R: CompressedPoint) -> Scalar {
    let e =
      monero_ed25519::Scalar::hash([transcript.to_bytes(), L.to_bytes(), R.to_bytes()].concat())
        .into();
    *transcript = e;
    e
  }

  /// Absorb the final-round `A` and `B` into the running transcript.
  ///
  /// The challenge is `hash(transcript || A || B)`; it replaces `*transcript` and is returned.
  fn transcript_A_B(transcript: &mut Scalar, A: CompressedPoint, B: CompressedPoint) -> Scalar {
    let e =
      monero_ed25519::Scalar::hash([transcript.to_bytes(), A.to_bytes(), B.to_bytes()].concat())
        .into();
    *transcript = e;
    e
  }

  /// Run one folding round over the generator vectors.
  ///
  /// Derives the round challenge `e` from `(L, R)` via `transcript_L_R`, then folds the halves:
  ///
  /// - `g_bold'[i] = e^-1 * g_bold1[i] + (e * y_inv_n_hat) * g_bold2[i]`
  /// - `h_bold'[i] = e * h_bold1[i] + e^-1 * h_bold2[i]`
  ///
  /// `prove` passes `y^-n_hat` as `y_inv_n_hat`. The input vectors are drained.
  ///
  /// Returns `(e, e^-1, e^2, e^-2, g_bold', h_bold')`.
  #[allow(clippy::too_many_arguments)]
  fn next_G_H(
    transcript: &mut Scalar,
    mut g_bold1: PointVector,
    mut g_bold2: PointVector,
    mut h_bold1: PointVector,
    mut h_bold2: PointVector,
    L: CompressedPoint,
    R: CompressedPoint,
    y_inv_n_hat: Scalar,
  ) -> (Scalar, Scalar, Scalar, Scalar, PointVector, PointVector) {
    debug_assert_eq!(g_bold1.len(), g_bold2.len());
    debug_assert_eq!(h_bold1.len(), h_bold2.len());
    debug_assert_eq!(g_bold1.len(), h_bold1.len());

    let e = Self::transcript_L_R(transcript, L, R);
    let inv_e = e.invert();

    // The vartime multiexp is used here. Its inputs are generators and transcript-derived
    // challenges only.
    let mut new_g_bold = Vec::with_capacity(g_bold1.len());
    let e_y_inv = e * y_inv_n_hat;
    for g_bold in g_bold1.0.drain(..).zip(g_bold2.0.drain(..)) {
      new_g_bold.push(multiexp_vartime(&[(inv_e, g_bold.0), (e_y_inv, g_bold.1)]));
    }

    let mut new_h_bold = Vec::with_capacity(h_bold1.len());
    for h_bold in h_bold1.0.drain(..).zip(h_bold2.0.drain(..)) {
      new_h_bold.push(multiexp_vartime(&[(e, h_bold.0), (inv_e, h_bold.1)]));
    }

    let e_square = e * e;
    let inv_e_square = inv_e * inv_e;

    (e, inv_e, e_square, inv_e_square, PointVector(new_g_bold), PointVector(new_h_bold))
  }

  /// Prove knowledge of `witness` for this statement.
  ///
  /// `transcript` is the initial transcript state and `rng` supplies the blinding scalars.
  ///
  /// Each round halves the vectors, commits to the cross terms with `L`/`R`, and folds the
  /// witness and generators with the round challenge. Once a single element remains, `A`/`B`
  /// and the three answers are produced.
  ///
  /// Returns `None` if the number of generators differs from the witness length. In debug
  /// builds, panics if `P` does not match the commitment recomputed from `witness`.
  pub(crate) fn prove<R: RngCore + CryptoRng>(
    self,
    rng: &mut R,
    mut transcript: Scalar,
    witness: &WipWitness,
  ) -> Option<WipProof> {
    let WipStatement { generators, P, mut y } = self;
    // `P` is only read by the debug-build consistency check below. This silences the unused
    // binding in release builds.
    #[cfg(not(debug_assertions))]
    let _ = P;

    if generators.len() != witness.a.len() {
      return None;
    }
    let (g, h) = (BpPlusGenerators::g(), BpPlusGenerators::h());
    let mut g_bold = vec![];
    let mut h_bold = vec![];
    for i in 0 .. generators.len() {
      g_bold.push(generators.generator(GeneratorsList::GBold, i));
      h_bold.push(generators.generator(GeneratorsList::HBold, i));
    }
    let mut g_bold = PointVector(g_bold);
    let mut h_bold = PointVector(h_bold);

    // Inverses of y^1, y^2, y^4, ... (read from y[2^k - 1]) for every power of two below n.
    // Each round below halves n and pops from the end, so each round receives y^-n_hat.
    let mut y_inv = {
      let mut i = 1;
      let mut to_invert = vec![];
      while i < g_bold.len() {
        to_invert.push(y[i - 1]);
        i *= 2;
      }
      Scalar::batch_invert(&mut to_invert);
      to_invert
    };

    // Debug-only consistency check:
    // P == <a, g_bold> + <b, h_bold> + weighted_inner_product(a, b, y) * g + alpha * h
    #[cfg(debug_assertions)]
    {
      let mut P_terms = witness
        .a
        .0
        .iter()
        .copied()
        .zip(g_bold.0.iter().copied())
        .chain(witness.b.0.iter().copied().zip(h_bold.0.iter().copied()))
        .collect::<Vec<_>>();
      P_terms.push((witness.a.clone().weighted_inner_product(&witness.b, &y), g));
      P_terms.push((witness.alpha, h));
      debug_assert_eq!(multiexp(&P_terms), P);
      P_terms.zeroize();
    }

    // Working copies of the witness, folded in place each round.
    let mut a = witness.a.clone();
    let mut b = witness.b.clone();
    let mut alpha = witness.alpha;

    debug_assert_eq!(g_bold.len(), a.len());

    let mut L_vec = vec![];
    let mut R_vec = vec![];

    // Halve the vectors each round until a single element remains.
    while g_bold.len() > 1 {
      let (a1, a2) = a.clone().split();
      let (b1, b2) = b.clone().split();
      let (g_bold1, g_bold2) = g_bold.split();
      let (h_bold1, h_bold2) = h_bold.split();

      let n_hat = g_bold1.len();
      debug_assert_eq!(a1.len(), n_hat);
      debug_assert_eq!(a2.len(), n_hat);
      debug_assert_eq!(b1.len(), n_hat);
      debug_assert_eq!(b2.len(), n_hat);
      debug_assert_eq!(g_bold1.len(), n_hat);
      debug_assert_eq!(g_bold2.len(), n_hat);
      debug_assert_eq!(h_bold1.len(), n_hat);
      debug_assert_eq!(h_bold2.len(), n_hat);

      // y^n_hat, read before y is truncated to the length of the halves.
      let y_n_hat = y[n_hat - 1];
      y.0.truncate(n_hat);

      // Fresh blinding factors for this round's L and R.
      let d_l = monero_ed25519::Scalar::random(&mut *rng).into();
      let d_r = monero_ed25519::Scalar::random(&mut *rng).into();

      // Cross terms:
      // c_l = weighted_inner_product(a1, b2, y)
      // c_r = weighted_inner_product(a2 * y^n_hat, b1, y)
      let c_l = a1.clone().weighted_inner_product(&b2, &y);
      let c_r = (a2.clone() * y_n_hat).weighted_inner_product(&b1, &y);

      let y_inv_n_hat = y_inv
        .pop()
        .expect("couldn't pop y_inv despite y_inv being of same length as times iterated");

      // L = (1/8) * (<a1 * y^-n_hat, g_bold2> + <b2, h_bold1> + c_l * g + d_l * h)
      // The 1/8 factor is undone by `mul_by_cofactor` in `verify`. Secret-dependent terms go
      // through `multiexp` rather than `multiexp_vartime`, and are wiped after use.
      let mut L_terms = (a1.clone() * y_inv_n_hat)
        .0
        .drain(..)
        .zip(g_bold2.0.iter().copied())
        .chain(b2.0.iter().copied().zip(h_bold1.0.iter().copied()))
        .collect::<Vec<_>>();
      L_terms.push((c_l, g));
      L_terms.push((d_l, h));
      let L = CompressedPoint::from((multiexp(&L_terms) * INV_EIGHT.into()).compress().to_bytes());
      L_vec.push(L);
      L_terms.zeroize();

      // R = (1/8) * (<a2 * y^n_hat, g_bold1> + <b1, h_bold2> + c_r * g + d_r * h)
      let mut R_terms = (a2.clone() * y_n_hat)
        .0
        .drain(..)
        .zip(g_bold1.0.iter().copied())
        .chain(b1.0.iter().copied().zip(h_bold2.0.iter().copied()))
        .collect::<Vec<_>>();
      R_terms.push((c_r, g));
      R_terms.push((d_r, h));
      let R = CompressedPoint::from((multiexp(&R_terms) * INV_EIGHT.into()).compress().to_bytes());
      R_vec.push(R);
      R_terms.zeroize();

      // Derive this round's challenge from (L, R) and fold the generators.
      let (e, inv_e, e_square, inv_e_square);
      (e, inv_e, e_square, inv_e_square, g_bold, h_bold) =
        Self::next_G_H(&mut transcript, g_bold1, g_bold2, h_bold1, h_bold2, L, R, y_inv_n_hat);

      // Fold the witness:
      // a' = e * a1 + (y^n_hat * e^-1) * a2
      // b' = e^-1 * b1 + e * b2
      // alpha' = alpha + e^2 * d_l + e^-2 * d_r
      a = (a1 * e) + &(a2 * (y_n_hat * inv_e));
      b = (b1 * inv_e) + &(b2 * e);
      alpha += (d_l * e_square) + (d_r * inv_e_square);

      debug_assert_eq!(g_bold.len(), a.len());
      debug_assert_eq!(g_bold.len(), h_bold.len());
      debug_assert_eq!(g_bold.len(), b.len());
    }

    debug_assert_eq!(g_bold.len(), 1);
    debug_assert_eq!(h_bold.len(), 1);

    debug_assert_eq!(a.len(), 1);
    debug_assert_eq!(b.len(), 1);

    // Final round: blinding scalars for A and B.
    let r = monero_ed25519::Scalar::random(&mut *rng).into();
    let s = monero_ed25519::Scalar::random(&mut *rng).into();
    let delta = monero_ed25519::Scalar::random(&mut *rng).into();
    let eta = monero_ed25519::Scalar::random(&mut *rng).into();

    // y has only ever been truncated, so y[0] is still the challenge y itself.
    let ry = r * y[0];

    // A = (1/8) * (r * g_bold[0] + s * h_bold[0] + (r*y*b + s*y*a) * g + delta * h)
    let mut A_terms =
      vec![(r, g_bold[0]), (s, h_bold[0]), ((ry * b[0]) + (s * y[0] * a[0]), g), (delta, h)];
    let A = CompressedPoint::from((multiexp(&A_terms) * INV_EIGHT.into()).compress().to_bytes());
    A_terms.zeroize();

    // B = (1/8) * (r*y*s * g + eta * h)
    let mut B_terms = vec![(ry * s, g), (eta, h)];
    let B = CompressedPoint::from((multiexp(&B_terms) * INV_EIGHT.into()).compress().to_bytes());
    B_terms.zeroize();

    let e = Self::transcript_A_B(&mut transcript, A, B);

    // r' = r + a*e, s' = s + b*e, delta' = eta + delta*e + alpha*e^2
    let r_answer = r + (a[0] * e);
    let s_answer = s + (b[0] * e);
    let delta_answer = eta + (delta * e) + (alpha * (e * e));

    Some(WipProof { L: L_vec, R: R_vec, A, B, r_answer, s_answer, delta_answer })
  }

  /// Queue verification of a proof against this statement into `verifier`.
  ///
  /// The proof is not checked directly. Instead, weighted terms are pushed into the batch
  /// verifier, all scaled by a fresh random `verifier_weight` (presumably so that several proofs
  /// can be combined in one batch).
  ///
  /// Returns `false`, without modifying `verifier`, if the proof is malformed: wrong number of
  /// rounds, a generator count that is not a power of two, or an undecodable point. A `true`
  /// return only means the terms were queued. NOTE(review): the actual check presumably happens
  /// when the batch verifier is evaluated; confirm in `BulletproofsPlusBatchVerifier`.
  pub(crate) fn verify<R: RngCore + CryptoRng>(
    self,
    rng: &mut R,
    verifier: &mut BulletproofsPlusBatchVerifier,
    mut transcript: Scalar,
    WipProof { L, R, A, B, r_answer, s_answer, delta_answer }: WipProof,
  ) -> bool {
    let verifier_weight = monero_ed25519::Scalar::random(rng).into();

    let WipStatement { generators, P, y } = self;

    // lr_len is the smallest value with 2^lr_len >= n. The proof must contain exactly lr_len L
    // and R values, and n must equal 2^lr_len exactly.
    {
      let mut lr_len = 0;
      while (1 << lr_len) < generators.len() {
        lr_len += 1;
      }
      if (L.len() != lr_len) || (R.len() != lr_len) || (generators.len() != (1 << lr_len)) {
        return false;
      }
    }

    // inv_y[i] = y^-(i + 1), for i in 0 .. y.len()
    let inv_y = {
      let inv_y = y[0].invert();
      let mut res = Vec::with_capacity(y.len());
      res.push(inv_y);
      while res.len() < y.len() {
        res.push(
          inv_y * res.last().expect("couldn't get last inv_y despite inv_y always being non-empty"),
        );
      }
      res
    };

    let mut e_is = Vec::with_capacity(L.len());
    let mut L_decomp = Vec::with_capacity(L.len());
    let mut R_decomp = Vec::with_capacity(R.len());

    // Decompress and multiply by the cofactor, undoing the prover's INV_EIGHT scaling.
    let decomp_mul_cofactor =
      |p| CompressedPoint::decompress(&p).map(|p| EdwardsPoint::mul_by_cofactor(&p.into()));

    // Replay the round challenges. As in the prover, the transcript absorbs the compressed
    // encodings; decompression is attempted afterwards.
    for (L_i, R_i) in L.into_iter().zip(R.into_iter()) {
      e_is.push(Self::transcript_L_R(&mut transcript, L_i, R_i));

      let (Some(L_i), Some(R_i)) = (decomp_mul_cofactor(L_i), decomp_mul_cofactor(R_i)) else {
        return false;
      };

      L_decomp.push(L_i);
      R_decomp.push(R_i);
    }

    let L = L_decomp;
    let R = R_decomp;

    let e = Self::transcript_A_B(&mut transcript, A, B);

    let (Some(A), Some(B)) = (decomp_mul_cofactor(A), decomp_mul_cofactor(B)) else {
      return false;
    };

    // Everything from here on is queued into `verifier`. All terms are scaled by
    // verifier_weight and together encode:
    //   sum_i (r' * e * y^-i * product_cache[i]) * g_bold[i]
    //   + sum_i (s' * e * product_cache[len - 1 - i]) * h_bold[i]
    //   + (r' * y * s') * g + delta' * h
    //   - e^2 * (P + sum_j (e_j^2 * L_j + e_j^-2 * R_j)) - e * A - B
    let neg_e_square = verifier_weight * -(e * e);

    verifier.0.other.push((neg_e_square, P));

    let mut challenges = Vec::with_capacity(L.len());
    let product_cache = {
      let mut inv_e_is = e_is.clone();
      Scalar::batch_invert(&mut inv_e_is);

      debug_assert_eq!(e_is.len(), inv_e_is.len());
      debug_assert_eq!(e_is.len(), L.len());
      debug_assert_eq!(e_is.len(), R.len());
      for ((e_i, inv_e_i), (L, R)) in
        e_is.drain(..).zip(inv_e_is.drain(..)).zip(L.iter().zip(R.iter()))
      {
        debug_assert_eq!(e_i.invert(), inv_e_i);

        challenges.push((e_i, inv_e_i));

        let e_i_square = e_i * e_i;
        let inv_e_i_square = inv_e_i * inv_e_i;
        verifier.0.other.push((neg_e_square * e_i_square, *L));
        verifier.0.other.push((neg_e_square * inv_e_i_square, *R));
      }

      // NOTE(review): `challenge_products` is defined elsewhere. Presumably entry i is the
      // product of e_j or e_j^-1 selected by the bits of i; confirm.
      challenge_products(&challenges)
    };

    // Grow the shared generator coefficient vectors to cover this statement's generators.
    while verifier.0.g_bold.len() < generators.len() {
      verifier.0.g_bold.push(Scalar::ZERO);
    }
    while verifier.0.h_bold.len() < generators.len() {
      verifier.0.h_bold.push(Scalar::ZERO);
    }

    // g_bold[i] coefficient: r' * e * product_cache[i] * y^-i (with y^0 = 1 for i = 0).
    let re = r_answer * e;
    for i in 0 .. generators.len() {
      let mut scalar = product_cache[i] * re;
      if i > 0 {
        scalar *= inv_y[i - 1];
      }
      verifier.0.g_bold[i] += verifier_weight * scalar;
    }

    // h_bold[i] coefficient: s' * e * product_cache read in reverse order.
    let se = s_answer * e;
    for i in 0 .. generators.len() {
      verifier.0.h_bold[i] += verifier_weight * (se * product_cache[product_cache.len() - 1 - i]);
    }

    verifier.0.other.push((verifier_weight * -e, A));
    verifier.0.g += verifier_weight * (r_answer * y[0] * s_answer);
    verifier.0.h += verifier_weight * delta_answer;
    verifier.0.other.push((-verifier_weight, B));

    true
  }
}