• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * This file is part of the openHiTLS project.
3  *
4  * openHiTLS is licensed under the Mulan PSL v2.
5  * You can use this software according to the terms and conditions of the Mulan PSL v2.
6  * You may obtain a copy of Mulan PSL v2 at:
7  *
8  *     http://license.coscl.org.cn/MulanPSL2
9  *
10  * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11  * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12  * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13  * See the Mulan PSL v2 for more details.
14  */
15 
16 #include "hitls_build.h"
17 #ifdef HITLS_CRYPTO_MLKEM
18 #include "ml_kem_local.h"
19 
20 // basecase multiplication
BaseMul(int16_t polyH[2],int16_t f0,int16_t f1,int16_t g0,int16_t g1,int16_t factor)21 static void BaseMul(int16_t polyH[2], int16_t f0, int16_t f1, int16_t g0, int16_t g1, int16_t factor)
22 {
23     polyH[0] = (int16_t)(((int32_t)(f0 * g0) + (int32_t)(((int32_t)(f1 * g1) % MLKEM_Q) * factor)) % MLKEM_Q);
24     polyH[1] = (int16_t)(((int32_t)(f0 * g1) + (int32_t)(f1 * g0)) % MLKEM_Q);
25     MlKemAddModQ(&polyH[0]);
26     MlKemAddModQ(&polyH[1]);
27 }
28 
29 // circle multiplication
CircMul(int16_t dest[MLKEM_N],int16_t src1[MLKEM_N],int16_t src2[MLKEM_N],const int16_t * factor)30 static void CircMul(int16_t dest[MLKEM_N], int16_t src1[MLKEM_N], int16_t src2[MLKEM_N], const int16_t *factor)
31 {
32     for (uint32_t i = 0; i < MLKEM_N / 4; i++) {
33         // 4-byte data is calculated in each round.
34         BaseMul(&dest[4 * i], src1[4 * i], src1[4 * i + 1], src2[4 * i], src2[4 * i + 1], factor[i]);
35         BaseMul(&dest[4 * i + 2], src1[4 * i + 2], src1[4 * i + 3], src2[4 * i + 2], src2[4 * i + 3], -1 * factor[i]);
36     }
37 }
38 
MLKEM_MatrixMulAdd(uint8_t k,int16_t * matrix[],int16_t * vectorS[],int16_t * vectorE,int16_t * vectorT,const int16_t * factor)39 void MLKEM_MatrixMulAdd(uint8_t k, int16_t *matrix[], int16_t *vectorS[], int16_t *vectorE,
40     int16_t *vectorT, const int16_t *factor)
41 {
42     int16_t dest[MLKEM_N] = { 0 };
43     for (uint8_t j = 0; j < k; j++) {
44         // factor is a half of the NTT table.
45         CircMul(dest, matrix[j], vectorS[j], factor + MLKEM_N_HALF / 2);
46         for (uint32_t n = 0; n < MLKEM_N; n++) {
47             if (j == 0) {
48                 vectorT[n] = (vectorE == NULL) ? dest[n] : (vectorE[n] + dest[n]);
49             } else if (j != 0 && j != (k - 1)) {
50                 vectorT[n] += dest[n];
51             } else if (j == (k - 1)) {
52                 vectorT[n] = (vectorT[n] + dest[n]) % MLKEM_Q;
53             }
54         }
55     }
56 }
57 
MLKEM_SamplePolyCBD(int16_t * polyF,uint8_t * buf,uint8_t eta)58 void MLKEM_SamplePolyCBD(int16_t *polyF, uint8_t *buf, uint8_t eta)
59 {
60     uint32_t i;
61     uint32_t j;
62     uint8_t a;
63     uint8_t b;
64     uint32_t t1;
65     if (eta == 3) {  // The value of eta can only be 2 or 3.
66         for (i = 0; i < MLKEM_N / 4; i++) {
67             uint32_t temp = (uint32_t)buf[eta * i];
68             temp |= (uint32_t)buf[eta * i + 1] << 8;
69             temp |= (uint32_t)buf[eta * i + 2] << 16;
70             t1 = temp & 0x00249249;  // temp & 0x00249249 is used to obtain a specific bit in temp.
71             t1 += (temp >> 1) & 0x00249249;
72             t1 += (temp >> 2) & 0x00249249;
73 
74             for (j = 0; j < 4; j++) {
75                 a = (t1 >> (6 * j)) & 0x3;
76                 b = (t1 >> (6 * j + eta)) & 0x3;
77                 polyF[4 * i + j] = a - b;
78             }
79         }
80     } else if (eta == 2) {
81         for (i = 0; i < MLKEM_N / 4; i++) {
82             uint16_t temp = (uint16_t)buf[eta * i];
83             temp |= (uint16_t)buf[eta * i + 1] << 0x8;
84             t1 = temp & 0x5555;  // temp & 0x5555 is used to obtain a specific bit in temp.
85             t1 += (temp >> 1) & 0x5555;
86 
87             for (j = 0; j < 4; j++) {
88                 a = (t1 >> (4 * j)) & 0x3;
89                 b = (t1 >> (4 * j + eta)) & 0x3;
90                 polyF[4 * i + j] = a - b;
91             }
92         }
93     }
94 }
95 
96 #endif
97