1 /*
2 * This file is part of the openHiTLS project.
3 *
4 * openHiTLS is licensed under the Mulan PSL v2.
5 * You can use this software according to the terms and conditions of the Mulan PSL v2.
6 * You may obtain a copy of Mulan PSL v2 at:
7 *
8 * http://license.coscl.org.cn/MulanPSL2
9 *
10 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13 * See the Mulan PSL v2 for more details.
14 */
15
16 #include "hitls_build.h"
17 #ifdef HITLS_CRYPTO_MLKEM
18 #include "ml_kem_local.h"
19
20 // basecase multiplication
BaseMul(int16_t polyH[2],int16_t f0,int16_t f1,int16_t g0,int16_t g1,int16_t factor)21 static void BaseMul(int16_t polyH[2], int16_t f0, int16_t f1, int16_t g0, int16_t g1, int16_t factor)
22 {
23 polyH[0] = (int16_t)(((int32_t)(f0 * g0) + (int32_t)(((int32_t)(f1 * g1) % MLKEM_Q) * factor)) % MLKEM_Q);
24 polyH[1] = (int16_t)(((int32_t)(f0 * g1) + (int32_t)(f1 * g0)) % MLKEM_Q);
25 MlKemAddModQ(&polyH[0]);
26 MlKemAddModQ(&polyH[1]);
27 }
28
29 // circle multiplication
CircMul(int16_t dest[MLKEM_N],int16_t src1[MLKEM_N],int16_t src2[MLKEM_N],const int16_t * factor)30 static void CircMul(int16_t dest[MLKEM_N], int16_t src1[MLKEM_N], int16_t src2[MLKEM_N], const int16_t *factor)
31 {
32 for (uint32_t i = 0; i < MLKEM_N / 4; i++) {
33 // 4-byte data is calculated in each round.
34 BaseMul(&dest[4 * i], src1[4 * i], src1[4 * i + 1], src2[4 * i], src2[4 * i + 1], factor[i]);
35 BaseMul(&dest[4 * i + 2], src1[4 * i + 2], src1[4 * i + 3], src2[4 * i + 2], src2[4 * i + 3], -1 * factor[i]);
36 }
37 }
38
MLKEM_MatrixMulAdd(uint8_t k,int16_t * matrix[],int16_t * vectorS[],int16_t * vectorE,int16_t * vectorT,const int16_t * factor)39 void MLKEM_MatrixMulAdd(uint8_t k, int16_t *matrix[], int16_t *vectorS[], int16_t *vectorE,
40 int16_t *vectorT, const int16_t *factor)
41 {
42 int16_t dest[MLKEM_N] = { 0 };
43 for (uint8_t j = 0; j < k; j++) {
44 // factor is a half of the NTT table.
45 CircMul(dest, matrix[j], vectorS[j], factor + MLKEM_N_HALF / 2);
46 for (uint32_t n = 0; n < MLKEM_N; n++) {
47 if (j == 0) {
48 vectorT[n] = (vectorE == NULL) ? dest[n] : (vectorE[n] + dest[n]);
49 } else if (j != 0 && j != (k - 1)) {
50 vectorT[n] += dest[n];
51 } else if (j == (k - 1)) {
52 vectorT[n] = (vectorT[n] + dest[n]) % MLKEM_Q;
53 }
54 }
55 }
56 }
57
MLKEM_SamplePolyCBD(int16_t * polyF,uint8_t * buf,uint8_t eta)58 void MLKEM_SamplePolyCBD(int16_t *polyF, uint8_t *buf, uint8_t eta)
59 {
60 uint32_t i;
61 uint32_t j;
62 uint8_t a;
63 uint8_t b;
64 uint32_t t1;
65 if (eta == 3) { // The value of eta can only be 2 or 3.
66 for (i = 0; i < MLKEM_N / 4; i++) {
67 uint32_t temp = (uint32_t)buf[eta * i];
68 temp |= (uint32_t)buf[eta * i + 1] << 8;
69 temp |= (uint32_t)buf[eta * i + 2] << 16;
70 t1 = temp & 0x00249249; // temp & 0x00249249 is used to obtain a specific bit in temp.
71 t1 += (temp >> 1) & 0x00249249;
72 t1 += (temp >> 2) & 0x00249249;
73
74 for (j = 0; j < 4; j++) {
75 a = (t1 >> (6 * j)) & 0x3;
76 b = (t1 >> (6 * j + eta)) & 0x3;
77 polyF[4 * i + j] = a - b;
78 }
79 }
80 } else if (eta == 2) {
81 for (i = 0; i < MLKEM_N / 4; i++) {
82 uint16_t temp = (uint16_t)buf[eta * i];
83 temp |= (uint16_t)buf[eta * i + 1] << 0x8;
84 t1 = temp & 0x5555; // temp & 0x5555 is used to obtain a specific bit in temp.
85 t1 += (temp >> 1) & 0x5555;
86
87 for (j = 0; j < 4; j++) {
88 a = (t1 >> (4 * j)) & 0x3;
89 b = (t1 >> (4 * j + eta)) & 0x3;
90 polyF[4 * i + j] = a - b;
91 }
92 }
93 }
94 }
95
96 #endif
97