1 /*
2 * This file is part of the openHiTLS project.
3 *
4 * openHiTLS is licensed under the Mulan PSL v2.
5 * You can use this software according to the terms and conditions of the Mulan PSL v2.
6 * You may obtain a copy of Mulan PSL v2 at:
7 *
8 * http://license.coscl.org.cn/MulanPSL2
9 *
10 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13 * See the Mulan PSL v2 for more details.
14 */
15
16 #include "hitls_build.h"
17 #ifdef HITLS_CRYPTO_MLKEM
18 #include "ml_kem_local.h"
19
MLKEM_ComputNTT(int16_t * a,const int16_t * psi,uint32_t pruLength)20 void MLKEM_ComputNTT(int16_t *a, const int16_t *psi, uint32_t pruLength)
21 {
22 uint32_t t = MLKEM_N;
23 for (uint32_t m = 1; m < pruLength; m <<= 1) {
24 t >>= 1;
25 for (uint32_t i = 0; i < m; i++) {
26 uint32_t j1 = (i << 1) * t;
27 int16_t s = psi[m + i];
28 int16_t *x = a + j1;
29 int16_t *y = x + (int16_t)t;
30 for (uint32_t j = j1; j < j1 + t; j++) {
31 int32_t ys = (*y) * s;
32 *y = (*x - ys) % MLKEM_Q;
33 *x = (*x + ys) % MLKEM_Q;
34 MlKemAddModQ(y);
35 MlKemAddModQ(x);
36 y++;
37 x++;
38 }
39 }
40 }
41 }
42
MLKEM_ComputINTT(int16_t * a,const int16_t * psiInv,uint32_t pruLength)43 void MLKEM_ComputINTT(int16_t *a, const int16_t *psiInv, uint32_t pruLength)
44 {
45 uint32_t t = MLKEM_N / pruLength;
46 for (uint32_t m = pruLength; m > 1; m >>= 1) {
47 uint32_t j1 = 0;
48 uint32_t h = m >> 1;
49 for (uint32_t i = 0; i < h; i++) {
50 int16_t s = psiInv[h + i];
51 for (uint32_t j = j1; j < j1 + t; j++) {
52 int16_t u = a[j];
53 int16_t v = a[j + t];
54 a[j] = (u + v) % MLKEM_Q;
55 // Both u and v are smaller than MLKEM_Q, temp not overflow.
56 int16_t temp = u - v;
57 MlKemAddModQ(&a[j]);
58 MlKemAddModQ(&temp);
59 a[j + t] = ((int32_t)temp * s) % MLKEM_Q;
60 }
61 j1 += (t << 1);
62 }
63 t <<= 1;
64 }
65 for (uint32_t n = 0; n < MLKEM_N; n++) {
66 a[n] = (a[n] * MLKEM_INVN) % MLKEM_Q;
67 MlKemAddModQ(&a[n]);
68 }
69 }
70
71 #endif