• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2013, Kenneth MacKay
3  * All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions are
7  * met:
8  *  * Redistributions of source code must retain the above copyright
9  *   notice, this list of conditions and the following disclaimer.
10  *  * Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
15  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
16  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
17  * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
18  * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
19  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
20  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
21  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
22  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
24  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25  */
26 
27 #include <linux/random.h>
28 
29 #include "ecc.h"
30 
31 /* 256-bit curve */
32 #define ECC_BYTES 32
33 
34 #define MAX_TRIES 16
35 
36 /* Number of u64's needed */
37 #define NUM_ECC_DIGITS (ECC_BYTES / 8)
38 
39 struct ecc_point {
40 	u64 x[NUM_ECC_DIGITS];
41 	u64 y[NUM_ECC_DIGITS];
42 };
43 
44 typedef struct {
45 	u64 m_low;
46 	u64 m_high;
47 } uint128_t;
48 
49 #define CURVE_P_32 {	0xFFFFFFFFFFFFFFFFull, 0x00000000FFFFFFFFull, \
50 			0x0000000000000000ull, 0xFFFFFFFF00000001ull }
51 
52 #define CURVE_G_32 { \
53 		{	0xF4A13945D898C296ull, 0x77037D812DEB33A0ull,	\
54 			0xF8BCE6E563A440F2ull, 0x6B17D1F2E12C4247ull }, \
55 		{	0xCBB6406837BF51F5ull, 0x2BCE33576B315ECEull,	\
56 			0x8EE7EB4A7C0F9E16ull, 0x4FE342E2FE1A7F9Bull }	\
57 }
58 
59 #define CURVE_N_32 {	0xF3B9CAC2FC632551ull, 0xBCE6FAADA7179E84ull,	\
60 			0xFFFFFFFFFFFFFFFFull, 0xFFFFFFFF00000000ull }
61 
62 static u64 curve_p[NUM_ECC_DIGITS] = CURVE_P_32;
63 static struct ecc_point curve_g = CURVE_G_32;
64 static u64 curve_n[NUM_ECC_DIGITS] = CURVE_N_32;
65 
/* Sets every digit of vli to zero. */
static void vli_clear(u64 *vli)
{
	unsigned int d = NUM_ECC_DIGITS;

	while (d--)
		vli[d] = 0;
}
73 
74 /* Returns true if vli == 0, false otherwise. */
vli_is_zero(const u64 * vli)75 static bool vli_is_zero(const u64 *vli)
76 {
77 	int i;
78 
79 	for (i = 0; i < NUM_ECC_DIGITS; i++) {
80 		if (vli[i])
81 			return false;
82 	}
83 
84 	return true;
85 }
86 
87 /* Returns nonzero if bit bit of vli is set. */
vli_test_bit(const u64 * vli,unsigned int bit)88 static u64 vli_test_bit(const u64 *vli, unsigned int bit)
89 {
90 	return (vli[bit / 64] & ((u64) 1 << (bit % 64)));
91 }
92 
93 /* Counts the number of 64-bit "digits" in vli. */
vli_num_digits(const u64 * vli)94 static unsigned int vli_num_digits(const u64 *vli)
95 {
96 	int i;
97 
98 	/* Search from the end until we find a non-zero digit.
99 	 * We do it in reverse because we expect that most digits will
100 	 * be nonzero.
101 	 */
102 	for (i = NUM_ECC_DIGITS - 1; i >= 0 && vli[i] == 0; i--);
103 
104 	return (i + 1);
105 }
106 
107 /* Counts the number of bits required for vli. */
vli_num_bits(const u64 * vli)108 static unsigned int vli_num_bits(const u64 *vli)
109 {
110 	unsigned int i, num_digits;
111 	u64 digit;
112 
113 	num_digits = vli_num_digits(vli);
114 	if (num_digits == 0)
115 		return 0;
116 
117 	digit = vli[num_digits - 1];
118 	for (i = 0; digit; i++)
119 		digit >>= 1;
120 
121 	return ((num_digits - 1) * 64 + i);
122 }
123 
124 /* Sets dest = src. */
vli_set(u64 * dest,const u64 * src)125 static void vli_set(u64 *dest, const u64 *src)
126 {
127 	int i;
128 
129 	for (i = 0; i < NUM_ECC_DIGITS; i++)
130 		dest[i] = src[i];
131 }
132 
133 /* Returns sign of left - right. */
vli_cmp(const u64 * left,const u64 * right)134 static int vli_cmp(const u64 *left, const u64 *right)
135 {
136     int i;
137 
138     for (i = NUM_ECC_DIGITS - 1; i >= 0; i--) {
139 	    if (left[i] > right[i])
140 		    return 1;
141 	    else if (left[i] < right[i])
142 		    return -1;
143     }
144 
145     return 0;
146 }
147 
148 /* Computes result = in << c, returning carry. Can modify in place
149  * (if result == in). 0 < shift < 64.
150  */
vli_lshift(u64 * result,const u64 * in,unsigned int shift)151 static u64 vli_lshift(u64 *result, const u64 *in,
152 			   unsigned int shift)
153 {
154 	u64 carry = 0;
155 	int i;
156 
157 	for (i = 0; i < NUM_ECC_DIGITS; i++) {
158 		u64 temp = in[i];
159 
160 		result[i] = (temp << shift) | carry;
161 		carry = temp >> (64 - shift);
162 	}
163 
164 	return carry;
165 }
166 
167 /* Computes vli = vli >> 1. */
vli_rshift1(u64 * vli)168 static void vli_rshift1(u64 *vli)
169 {
170 	u64 *end = vli;
171 	u64 carry = 0;
172 
173 	vli += NUM_ECC_DIGITS;
174 
175 	while (vli-- > end) {
176 		u64 temp = *vli;
177 		*vli = (temp >> 1) | carry;
178 		carry = temp << 63;
179 	}
180 }
181 
182 /* Computes result = left + right, returning carry. Can modify in place. */
vli_add(u64 * result,const u64 * left,const u64 * right)183 static u64 vli_add(u64 *result, const u64 *left,
184 			const u64 *right)
185 {
186 	u64 carry = 0;
187 	int i;
188 
189 	for (i = 0; i < NUM_ECC_DIGITS; i++) {
190 		u64 sum;
191 
192 		sum = left[i] + right[i] + carry;
193 		if (sum != left[i])
194 			carry = (sum < left[i]);
195 
196 		result[i] = sum;
197 	}
198 
199 	return carry;
200 }
201 
202 /* Computes result = left - right, returning borrow. Can modify in place. */
vli_sub(u64 * result,const u64 * left,const u64 * right)203 static u64 vli_sub(u64 *result, const u64 *left, const u64 *right)
204 {
205 	u64 borrow = 0;
206 	int i;
207 
208 	for (i = 0; i < NUM_ECC_DIGITS; i++) {
209 		u64 diff;
210 
211 		diff = left[i] - right[i] - borrow;
212 		if (diff != left[i])
213 			borrow = (diff > left[i]);
214 
215 		result[i] = diff;
216 	}
217 
218 	return borrow;
219 }
220 
mul_64_64(u64 left,u64 right)221 static uint128_t mul_64_64(u64 left, u64 right)
222 {
223 	u64 a0 = left & 0xffffffffull;
224 	u64 a1 = left >> 32;
225 	u64 b0 = right & 0xffffffffull;
226 	u64 b1 = right >> 32;
227 	u64 m0 = a0 * b0;
228 	u64 m1 = a0 * b1;
229 	u64 m2 = a1 * b0;
230 	u64 m3 = a1 * b1;
231 	uint128_t result;
232 
233 	m2 += (m0 >> 32);
234 	m2 += m1;
235 
236 	/* Overflow */
237 	if (m2 < m1)
238 		m3 += 0x100000000ull;
239 
240 	result.m_low = (m0 & 0xffffffffull) | (m2 << 32);
241 	result.m_high = m3 + (m2 >> 32);
242 
243 	return result;
244 }
245 
/* Returns a + b modulo 2^128. */
static uint128_t add_128_128(uint128_t a, uint128_t b)
{
	const u64 low = a.m_low + b.m_low;
	const u64 carry = low < a.m_low;
	uint128_t sum = { low, a.m_high + b.m_high + carry };

	return sum;
}
255 
/*
 * Computes the double-width product result = left * right.
 *
 * result must hold 2 * NUM_ECC_DIGITS digits. It must not alias left or
 * right, because result[k] is written while later columns still read the
 * operands.
 *
 * Column-wise (product-scanning) multiplication: for each output digit k,
 * every partial product left[i] * right[k - i] is summed into a 192-bit
 * accumulator r2:r01. The low 64 bits become result[k], and the rest is
 * carried into the next column.
 */
static void vli_mult(u64 *result, const u64 *left, const u64 *right)
{
	uint128_t r01 = { 0, 0 };
	u64 r2 = 0;
	unsigned int i, k;

	for (k = 0; k < NUM_ECC_DIGITS * 2 - 1; k++) {
		unsigned int min;

		if (k < NUM_ECC_DIGITS)
			min = 0;
		else
			min = (k + 1) - NUM_ECC_DIGITS;

		for (i = min; i <= k && i < NUM_ECC_DIGITS; i++) {
			uint128_t product;

			product = mul_64_64(left[i], right[k - i]);

			r01 = add_128_128(r01, product);
			/*
			 * The 128-bit addition overflowed iff the sum is
			 * smaller than the addend as a 128-bit value. Comparing
			 * only the high halves misses one case: r01.m_high was
			 * all ones and the low half carried, so the high halves
			 * come out equal.
			 */
			r2 += (r01.m_high < product.m_high) ||
			      (r01.m_high == product.m_high &&
			       r01.m_low < product.m_low);
		}

		result[k] = r01.m_low;
		r01.m_low = r01.m_high;
		r01.m_high = r2;
		r2 = 0;
	}

	result[NUM_ECC_DIGITS * 2 - 1] = r01.m_low;
}
290 
/*
 * Computes the double-width square result = left^2.
 *
 * result must hold 2 * NUM_ECC_DIGITS digits and must not alias left.
 *
 * Same column-wise scheme as the general multiplication, except that each
 * cross product left[i] * left[k - i] with i < k - i appears twice in the
 * square. It is therefore computed once and doubled, and the bit shifted
 * out of the doubled product goes straight into r2.
 */
static void vli_square(u64 *result, const u64 *left)
{
	uint128_t r01 = { 0, 0 };
	u64 r2 = 0;
	int i, k;

	for (k = 0; k < NUM_ECC_DIGITS * 2 - 1; k++) {
		unsigned int min;

		if (k < NUM_ECC_DIGITS)
			min = 0;
		else
			min = (k + 1) - NUM_ECC_DIGITS;

		for (i = min; i <= k && i <= k - i; i++) {
			uint128_t product;

			product = mul_64_64(left[i], left[k - i]);

			if (i < k - i) {
				/* Double the cross product. */
				r2 += product.m_high >> 63;
				product.m_high = (product.m_high << 1) |
						 (product.m_low >> 63);
				product.m_low <<= 1;
			}

			r01 = add_128_128(r01, product);
			/*
			 * Full 128-bit overflow test: the high halves can be
			 * equal after an overflow when r01.m_high was all ones
			 * and the low half carried.
			 */
			r2 += (r01.m_high < product.m_high) ||
			      (r01.m_high == product.m_high &&
			       r01.m_low < product.m_low);
		}

		result[k] = r01.m_low;
		r01.m_low = r01.m_high;
		r01.m_high = r2;
		r2 = 0;
	}

	result[NUM_ECC_DIGITS * 2 - 1] = r01.m_low;
}
329 
330 /* Computes result = (left + right) % mod.
331  * Assumes that left < mod and right < mod, result != mod.
332  */
vli_mod_add(u64 * result,const u64 * left,const u64 * right,const u64 * mod)333 static void vli_mod_add(u64 *result, const u64 *left, const u64 *right,
334 			const u64 *mod)
335 {
336 	u64 carry;
337 
338 	carry = vli_add(result, left, right);
339 
340 	/* result > mod (result = mod + remainder), so subtract mod to
341 	 * get remainder.
342 	 */
343 	if (carry || vli_cmp(result, mod) >= 0)
344 		vli_sub(result, result, mod);
345 }
346 
347 /* Computes result = (left - right) % mod.
348  * Assumes that left < mod and right < mod, result != mod.
349  */
vli_mod_sub(u64 * result,const u64 * left,const u64 * right,const u64 * mod)350 static void vli_mod_sub(u64 *result, const u64 *left, const u64 *right,
351 			const u64 *mod)
352 {
353 	u64 borrow = vli_sub(result, left, right);
354 
355 	/* In this case, p_result == -diff == (max int) - diff.
356 	 * Since -x % d == d - x, we can get the correct result from
357 	 * result + mod (with overflow).
358 	 */
359 	if (borrow)
360 		vli_add(result, result, mod);
361 }
362 
363 /* Computes result = product % curve_p
364    from http://www.nsa.gov/ia/_files/nist-routines.pdf */
vli_mmod_fast(u64 * result,const u64 * product)365 static void vli_mmod_fast(u64 *result, const u64 *product)
366 {
367 	u64 tmp[NUM_ECC_DIGITS];
368 	int carry;
369 
370 	/* t */
371 	vli_set(result, product);
372 
373 	/* s1 */
374 	tmp[0] = 0;
375 	tmp[1] = product[5] & 0xffffffff00000000ull;
376 	tmp[2] = product[6];
377 	tmp[3] = product[7];
378 	carry = vli_lshift(tmp, tmp, 1);
379 	carry += vli_add(result, result, tmp);
380 
381 	/* s2 */
382 	tmp[1] = product[6] << 32;
383 	tmp[2] = (product[6] >> 32) | (product[7] << 32);
384 	tmp[3] = product[7] >> 32;
385 	carry += vli_lshift(tmp, tmp, 1);
386 	carry += vli_add(result, result, tmp);
387 
388 	/* s3 */
389 	tmp[0] = product[4];
390 	tmp[1] = product[5] & 0xffffffff;
391 	tmp[2] = 0;
392 	tmp[3] = product[7];
393 	carry += vli_add(result, result, tmp);
394 
395 	/* s4 */
396 	tmp[0] = (product[4] >> 32) | (product[5] << 32);
397 	tmp[1] = (product[5] >> 32) | (product[6] & 0xffffffff00000000ull);
398 	tmp[2] = product[7];
399 	tmp[3] = (product[6] >> 32) | (product[4] << 32);
400 	carry += vli_add(result, result, tmp);
401 
402 	/* d1 */
403 	tmp[0] = (product[5] >> 32) | (product[6] << 32);
404 	tmp[1] = (product[6] >> 32);
405 	tmp[2] = 0;
406 	tmp[3] = (product[4] & 0xffffffff) | (product[5] << 32);
407 	carry -= vli_sub(result, result, tmp);
408 
409 	/* d2 */
410 	tmp[0] = product[6];
411 	tmp[1] = product[7];
412 	tmp[2] = 0;
413 	tmp[3] = (product[4] >> 32) | (product[5] & 0xffffffff00000000ull);
414 	carry -= vli_sub(result, result, tmp);
415 
416 	/* d3 */
417 	tmp[0] = (product[6] >> 32) | (product[7] << 32);
418 	tmp[1] = (product[7] >> 32) | (product[4] << 32);
419 	tmp[2] = (product[4] >> 32) | (product[5] << 32);
420 	tmp[3] = (product[6] << 32);
421 	carry -= vli_sub(result, result, tmp);
422 
423 	/* d4 */
424 	tmp[0] = product[7];
425 	tmp[1] = product[4] & 0xffffffff00000000ull;
426 	tmp[2] = product[5];
427 	tmp[3] = product[6] & 0xffffffff00000000ull;
428 	carry -= vli_sub(result, result, tmp);
429 
430 	if (carry < 0) {
431 		do {
432 			carry += vli_add(result, result, curve_p);
433 		} while (carry < 0);
434 	} else {
435 		while (carry || vli_cmp(curve_p, result) != 1)
436 			carry -= vli_sub(result, result, curve_p);
437 	}
438 }
439 
440 /* Computes result = (left * right) % curve_p. */
vli_mod_mult_fast(u64 * result,const u64 * left,const u64 * right)441 static void vli_mod_mult_fast(u64 *result, const u64 *left, const u64 *right)
442 {
443 	u64 product[2 * NUM_ECC_DIGITS];
444 
445 	vli_mult(product, left, right);
446 	vli_mmod_fast(result, product);
447 }
448 
449 /* Computes result = left^2 % curve_p. */
vli_mod_square_fast(u64 * result,const u64 * left)450 static void vli_mod_square_fast(u64 *result, const u64 *left)
451 {
452 	u64 product[2 * NUM_ECC_DIGITS];
453 
454 	vli_square(product, left);
455 	vli_mmod_fast(result, product);
456 }
457 
458 #define EVEN(vli) (!(vli[0] & 1))
459 /* Computes result = (1 / p_input) % mod. All VLIs are the same size.
460  * See "From Euclid's GCD to Montgomery Multiplication to the Great Divide"
461  * https://labs.oracle.com/techrep/2001/smli_tr-2001-95.pdf
462  */
vli_mod_inv(u64 * result,const u64 * input,const u64 * mod)463 static void vli_mod_inv(u64 *result, const u64 *input, const u64 *mod)
464 {
465 	u64 a[NUM_ECC_DIGITS], b[NUM_ECC_DIGITS];
466 	u64 u[NUM_ECC_DIGITS], v[NUM_ECC_DIGITS];
467 	u64 carry;
468 	int cmp_result;
469 
470 	if (vli_is_zero(input)) {
471 		vli_clear(result);
472 		return;
473 	}
474 
475 	vli_set(a, input);
476 	vli_set(b, mod);
477 	vli_clear(u);
478 	u[0] = 1;
479 	vli_clear(v);
480 
481 	while ((cmp_result = vli_cmp(a, b)) != 0) {
482 		carry = 0;
483 
484 		if (EVEN(a)) {
485 			vli_rshift1(a);
486 
487 			if (!EVEN(u))
488 				carry = vli_add(u, u, mod);
489 
490 			vli_rshift1(u);
491 			if (carry)
492 				u[NUM_ECC_DIGITS - 1] |= 0x8000000000000000ull;
493 		} else if (EVEN(b)) {
494 			vli_rshift1(b);
495 
496 			if (!EVEN(v))
497 				carry = vli_add(v, v, mod);
498 
499 			vli_rshift1(v);
500 			if (carry)
501 				v[NUM_ECC_DIGITS - 1] |= 0x8000000000000000ull;
502 		} else if (cmp_result > 0) {
503 			vli_sub(a, a, b);
504 			vli_rshift1(a);
505 
506 			if (vli_cmp(u, v) < 0)
507 				vli_add(u, u, mod);
508 
509 			vli_sub(u, u, v);
510 			if (!EVEN(u))
511 				carry = vli_add(u, u, mod);
512 
513 			vli_rshift1(u);
514 			if (carry)
515 				u[NUM_ECC_DIGITS - 1] |= 0x8000000000000000ull;
516 		} else {
517 			vli_sub(b, b, a);
518 			vli_rshift1(b);
519 
520 			if (vli_cmp(v, u) < 0)
521 				vli_add(v, v, mod);
522 
523 			vli_sub(v, v, u);
524 			if (!EVEN(v))
525 				carry = vli_add(v, v, mod);
526 
527 			vli_rshift1(v);
528 			if (carry)
529 				v[NUM_ECC_DIGITS - 1] |= 0x8000000000000000ull;
530 		}
531 	}
532 
533 	vli_set(result, u);
534 }
535 
536 /* ------ Point operations ------ */
537 
538 /* Returns true if p_point is the point at infinity, false otherwise. */
ecc_point_is_zero(const struct ecc_point * point)539 static bool ecc_point_is_zero(const struct ecc_point *point)
540 {
541 	return (vli_is_zero(point->x) && vli_is_zero(point->y));
542 }
543 
544 /* Point multiplication algorithm using Montgomery's ladder with co-Z
545  * coordinates. From http://eprint.iacr.org/2011/338.pdf
546  */
547 
548 /* Double in place */
ecc_point_double_jacobian(u64 * x1,u64 * y1,u64 * z1)549 static void ecc_point_double_jacobian(u64 *x1, u64 *y1, u64 *z1)
550 {
551 	/* t1 = x, t2 = y, t3 = z */
552 	u64 t4[NUM_ECC_DIGITS];
553 	u64 t5[NUM_ECC_DIGITS];
554 
555 	if (vli_is_zero(z1))
556 		return;
557 
558 	vli_mod_square_fast(t4, y1);   /* t4 = y1^2 */
559 	vli_mod_mult_fast(t5, x1, t4); /* t5 = x1*y1^2 = A */
560 	vli_mod_square_fast(t4, t4);   /* t4 = y1^4 */
561 	vli_mod_mult_fast(y1, y1, z1); /* t2 = y1*z1 = z3 */
562 	vli_mod_square_fast(z1, z1);   /* t3 = z1^2 */
563 
564 	vli_mod_add(x1, x1, z1, curve_p); /* t1 = x1 + z1^2 */
565 	vli_mod_add(z1, z1, z1, curve_p); /* t3 = 2*z1^2 */
566 	vli_mod_sub(z1, x1, z1, curve_p); /* t3 = x1 - z1^2 */
567 	vli_mod_mult_fast(x1, x1, z1);    /* t1 = x1^2 - z1^4 */
568 
569 	vli_mod_add(z1, x1, x1, curve_p); /* t3 = 2*(x1^2 - z1^4) */
570 	vli_mod_add(x1, x1, z1, curve_p); /* t1 = 3*(x1^2 - z1^4) */
571 	if (vli_test_bit(x1, 0)) {
572 		u64 carry = vli_add(x1, x1, curve_p);
573 		vli_rshift1(x1);
574 		x1[NUM_ECC_DIGITS - 1] |= carry << 63;
575 	} else {
576 		vli_rshift1(x1);
577 	}
578 	/* t1 = 3/2*(x1^2 - z1^4) = B */
579 
580 	vli_mod_square_fast(z1, x1);      /* t3 = B^2 */
581 	vli_mod_sub(z1, z1, t5, curve_p); /* t3 = B^2 - A */
582 	vli_mod_sub(z1, z1, t5, curve_p); /* t3 = B^2 - 2A = x3 */
583 	vli_mod_sub(t5, t5, z1, curve_p); /* t5 = A - x3 */
584 	vli_mod_mult_fast(x1, x1, t5);    /* t1 = B * (A - x3) */
585 	vli_mod_sub(t4, x1, t4, curve_p); /* t4 = B * (A - x3) - y1^4 = y3 */
586 
587 	vli_set(x1, z1);
588 	vli_set(z1, y1);
589 	vli_set(y1, t4);
590 }
591 
592 /* Modify (x1, y1) => (x1 * z^2, y1 * z^3) */
apply_z(u64 * x1,u64 * y1,u64 * z)593 static void apply_z(u64 *x1, u64 *y1, u64 *z)
594 {
595 	u64 t1[NUM_ECC_DIGITS];
596 
597 	vli_mod_square_fast(t1, z);    /* z^2 */
598 	vli_mod_mult_fast(x1, x1, t1); /* x1 * z^2 */
599 	vli_mod_mult_fast(t1, t1, z);  /* z^3 */
600 	vli_mod_mult_fast(y1, y1, t1); /* y1 * z^3 */
601 }
602 
603 /* P = (x1, y1) => 2P, (x2, y2) => P' */
xycz_initial_double(u64 * x1,u64 * y1,u64 * x2,u64 * y2,u64 * p_initial_z)604 static void xycz_initial_double(u64 *x1, u64 *y1, u64 *x2, u64 *y2,
605 				u64 *p_initial_z)
606 {
607 	u64 z[NUM_ECC_DIGITS];
608 
609 	vli_set(x2, x1);
610 	vli_set(y2, y1);
611 
612 	vli_clear(z);
613 	z[0] = 1;
614 
615 	if (p_initial_z)
616 		vli_set(z, p_initial_z);
617 
618 	apply_z(x1, y1, z);
619 
620 	ecc_point_double_jacobian(x1, y1, z);
621 
622 	apply_z(x2, y2, z);
623 }
624 
625 /* Input P = (x1, y1, Z), Q = (x2, y2, Z)
626  * Output P' = (x1', y1', Z3), P + Q = (x3, y3, Z3)
627  * or P => P', Q => P + Q
628  */
xycz_add(u64 * x1,u64 * y1,u64 * x2,u64 * y2)629 static void xycz_add(u64 *x1, u64 *y1, u64 *x2, u64 *y2)
630 {
631 	/* t1 = X1, t2 = Y1, t3 = X2, t4 = Y2 */
632 	u64 t5[NUM_ECC_DIGITS];
633 
634 	vli_mod_sub(t5, x2, x1, curve_p); /* t5 = x2 - x1 */
635 	vli_mod_square_fast(t5, t5);      /* t5 = (x2 - x1)^2 = A */
636 	vli_mod_mult_fast(x1, x1, t5);    /* t1 = x1*A = B */
637 	vli_mod_mult_fast(x2, x2, t5);    /* t3 = x2*A = C */
638 	vli_mod_sub(y2, y2, y1, curve_p); /* t4 = y2 - y1 */
639 	vli_mod_square_fast(t5, y2);      /* t5 = (y2 - y1)^2 = D */
640 
641 	vli_mod_sub(t5, t5, x1, curve_p); /* t5 = D - B */
642 	vli_mod_sub(t5, t5, x2, curve_p); /* t5 = D - B - C = x3 */
643 	vli_mod_sub(x2, x2, x1, curve_p); /* t3 = C - B */
644 	vli_mod_mult_fast(y1, y1, x2);    /* t2 = y1*(C - B) */
645 	vli_mod_sub(x2, x1, t5, curve_p); /* t3 = B - x3 */
646 	vli_mod_mult_fast(y2, y2, x2);    /* t4 = (y2 - y1)*(B - x3) */
647 	vli_mod_sub(y2, y2, y1, curve_p); /* t4 = y3 */
648 
649 	vli_set(x2, t5);
650 }
651 
652 /* Input P = (x1, y1, Z), Q = (x2, y2, Z)
653  * Output P + Q = (x3, y3, Z3), P - Q = (x3', y3', Z3)
654  * or P => P - Q, Q => P + Q
655  */
xycz_add_c(u64 * x1,u64 * y1,u64 * x2,u64 * y2)656 static void xycz_add_c(u64 *x1, u64 *y1, u64 *x2, u64 *y2)
657 {
658 	/* t1 = X1, t2 = Y1, t3 = X2, t4 = Y2 */
659 	u64 t5[NUM_ECC_DIGITS];
660 	u64 t6[NUM_ECC_DIGITS];
661 	u64 t7[NUM_ECC_DIGITS];
662 
663 	vli_mod_sub(t5, x2, x1, curve_p); /* t5 = x2 - x1 */
664 	vli_mod_square_fast(t5, t5);      /* t5 = (x2 - x1)^2 = A */
665 	vli_mod_mult_fast(x1, x1, t5);    /* t1 = x1*A = B */
666 	vli_mod_mult_fast(x2, x2, t5);    /* t3 = x2*A = C */
667 	vli_mod_add(t5, y2, y1, curve_p); /* t4 = y2 + y1 */
668 	vli_mod_sub(y2, y2, y1, curve_p); /* t4 = y2 - y1 */
669 
670 	vli_mod_sub(t6, x2, x1, curve_p); /* t6 = C - B */
671 	vli_mod_mult_fast(y1, y1, t6);    /* t2 = y1 * (C - B) */
672 	vli_mod_add(t6, x1, x2, curve_p); /* t6 = B + C */
673 	vli_mod_square_fast(x2, y2);      /* t3 = (y2 - y1)^2 */
674 	vli_mod_sub(x2, x2, t6, curve_p); /* t3 = x3 */
675 
676 	vli_mod_sub(t7, x1, x2, curve_p); /* t7 = B - x3 */
677 	vli_mod_mult_fast(y2, y2, t7);    /* t4 = (y2 - y1)*(B - x3) */
678 	vli_mod_sub(y2, y2, y1, curve_p); /* t4 = y3 */
679 
680 	vli_mod_square_fast(t7, t5);      /* t7 = (y2 + y1)^2 = F */
681 	vli_mod_sub(t7, t7, t6, curve_p); /* t7 = x3' */
682 	vli_mod_sub(t6, t7, x1, curve_p); /* t6 = x3' - B */
683 	vli_mod_mult_fast(t6, t6, t5);    /* t6 = (y2 + y1)*(x3' - B) */
684 	vli_mod_sub(y1, t6, y1, curve_p); /* t2 = y3' */
685 
686 	vli_set(x1, t7);
687 }
688 
/*
 * Computes result = scalar * point using a Montgomery ladder with co-Z
 * coordinates (http://eprint.iacr.org/2011/338.pdf).
 *
 * num_bits is the bit length of scalar. Its top bit (index num_bits - 1)
 * is consumed by the initial doubling; bits num_bits - 2 down to 0 drive
 * the ladder. initial_z, when non-NULL, seeds the starting Z coordinate.
 *
 * NOTE(review): for num_bits <= 1, bit 0 is consumed both by the initial
 * doubling and by the final ladder step. The number of ladder iterations
 * also depends on num_bits.
 */
static void ecc_point_mult(struct ecc_point *result,
			   const struct ecc_point *point, u64 *scalar,
			   u64 *initial_z, int num_bits)
{
	/* Ladder registers R0 = (rx[0], ry[0]) and R1 = (rx[1], ry[1]). */
	u64 rx[2][NUM_ECC_DIGITS];
	u64 ry[2][NUM_ECC_DIGITS];
	u64 z[NUM_ECC_DIGITS];
	int bit, lo, hi;

	vli_set(rx[1], point->x);
	vli_set(ry[1], point->y);

	/* R1 = 2P, R0 = P, sharing Z. */
	xycz_initial_double(rx[1], ry[1], rx[0], ry[0], initial_z);

	for (bit = num_bits - 2; bit > 0; bit--) {
		lo = !vli_test_bit(scalar, bit);
		hi = 1 - lo;
		xycz_add_c(rx[hi], ry[hi], rx[lo], ry[lo]);
		xycz_add(rx[lo], ry[lo], rx[hi], ry[hi]);
	}

	/*
	 * Bit 0: do the first half of the ladder step now. The second half
	 * follows once the final Z inverse has been computed.
	 */
	lo = !vli_test_bit(scalar, 0);
	hi = 1 - lo;
	xycz_add_c(rx[hi], ry[hi], rx[lo], ry[lo]);

	/* Final 1/Z = Xb * yP / (xP * Yb * (X1 - X0)), with b = hi. */
	vli_mod_sub(z, rx[1], rx[0], curve_p);	/* X1 - X0 */
	vli_mod_mult_fast(z, z, ry[hi]);	/* Yb * (X1 - X0) */
	vli_mod_mult_fast(z, z, point->x);	/* xP * Yb * (X1 - X0) */
	vli_mod_inv(z, z, curve_p);		/* 1 / (xP * Yb * (X1 - X0)) */
	vli_mod_mult_fast(z, z, point->y);	/* yP / (xP * Yb * (X1 - X0)) */
	vli_mod_mult_fast(z, z, rx[hi]);	/* Xb * yP / (xP * Yb * (X1 - X0)) */

	xycz_add(rx[lo], ry[lo], rx[hi], ry[hi]);

	apply_z(rx[0], ry[0], z);

	vli_set(result->x, rx[0]);
	vli_set(result->y, ry[0]);
}
729 
/*
 * Decodes a little-endian byte string (bytes[0] is the least significant
 * byte) into native digits (native[0] is the least significant digit).
 */
static void ecc_bytes2native(const u8 bytes[ECC_BYTES],
			     u64 native[NUM_ECC_DIGITS])
{
	int d, b;

	for (d = 0; d < NUM_ECC_DIGITS; d++) {
		u64 word = 0;

		/* Build the digit from its most significant byte downward. */
		for (b = 7; b >= 0; b--)
			word = (word << 8) | bytes[8 * d + b];

		native[d] = word;
	}
}
749 
/*
 * Encodes native digits as a little-endian byte string. bytes[0] receives
 * the least significant byte of native[0].
 */
static void ecc_native2bytes(const u64 native[NUM_ECC_DIGITS],
			     u8 bytes[ECC_BYTES])
{
	int d, b;

	for (d = 0; d < NUM_ECC_DIGITS; d++) {
		u64 word = native[d];

		for (b = 0; b < 8; b++) {
			bytes[8 * d + b] = (u8)word;
			word >>= 8;
		}
	}
}
768 
/*
 * Generates a P-256 key pair.
 *
 * Draws random private scalars until one lies in [1, n - 1] and gives a
 * public point other than the all-zero encoding. Gives up after MAX_TRIES
 * draws.
 *
 * On success, private_key receives the scalar and public_key receives x
 * (bytes 0-31) followed by y (bytes 32-63), all little-endian as written
 * by ecc_native2bytes(). The function then returns true. On failure it
 * returns false and leaves both output buffers untouched.
 *
 * NOTE(review): priv and pk are left on the stack unwiped; consider
 * memzero_explicit() before returning.
 */
bool ecc_make_key(u8 public_key[64], u8 private_key[32])
{
	struct ecc_point pk;
	u64 priv[NUM_ECC_DIGITS];
	unsigned int tries;

	for (tries = 0; tries < MAX_TRIES; tries++) {
		get_random_bytes(priv, ECC_BYTES);

		/*
		 * Reject scalars outside [1, n - 1]. A rejected draw must not
		 * reach the success test: pk would hold stale data, or
		 * uninitialised data on the first attempt.
		 */
		if (vli_is_zero(priv) || vli_cmp(curve_n, priv) != 1)
			continue;

		ecc_point_mult(&pk, &curve_g, priv, NULL, vli_num_bits(priv));
		if (ecc_point_is_zero(&pk))
			continue;

		ecc_native2bytes(priv, private_key);
		ecc_native2bytes(pk.x, public_key);
		ecc_native2bytes(pk.y, &public_key[32]);

		return true;
	}

	return false;
}
797 
ecdh_shared_secret(const u8 public_key[64],const u8 private_key[32],u8 secret[32])798 bool ecdh_shared_secret(const u8 public_key[64], const u8 private_key[32],
799 		        u8 secret[32])
800 {
801 	u64 priv[NUM_ECC_DIGITS];
802 	u64 rand[NUM_ECC_DIGITS];
803 	struct ecc_point product, pk;
804 
805 	get_random_bytes(rand, ECC_BYTES);
806 
807 	ecc_bytes2native(public_key, pk.x);
808 	ecc_bytes2native(&public_key[32], pk.y);
809 	ecc_bytes2native(private_key, priv);
810 
811 	ecc_point_mult(&product, &pk, priv, rand, vli_num_bits(priv));
812 
813 	ecc_native2bytes(product.x, secret);
814 
815 	return !ecc_point_is_zero(&product);
816 }
817