1 /*
2 * Copyright 2014-2022 The GmSSL Project. All Rights Reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the License); you may
5 * not use this file except in compliance with the License.
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 */
9
10
11 #include <gmssl/sm4.h>
12 #include <gmssl/mem.h>
13 #include <gmssl/gcm.h>
14 #include <gmssl/error.h>
15
/*
 * CBC-encrypt `nblocks` whole 16-byte blocks from `in` into `out`.
 *
 * Each plaintext block is XORed with the previous ciphertext block (the
 * caller's `iv` for the first block) and then run through sm4_encrypt().
 * No padding is applied. The caller's `iv` buffer is only read, never
 * updated; the next chaining value is the last 16 bytes written to `out`.
 */
void sm4_cbc_encrypt(const SM4_KEY *key, const uint8_t iv[16],
	const uint8_t *in, size_t nblocks, uint8_t *out)
{
	const uint8_t *chain = iv;
	size_t n;

	for (n = 0; n < nblocks; n++) {
		uint8_t *dst = out + 16 * n;

		gmssl_memxor(dst, in + 16 * n, chain, 16);
		sm4_encrypt(key, dst, dst);
		/* this ciphertext block feeds the next XOR */
		chain = dst;
	}
}
27
/*
 * CBC-decrypt `nblocks` whole 16-byte blocks from `in` into `out`.
 *
 * The block transform is sm4_encrypt() applied with `key`. Callers in this
 * file pass a key prepared by sm4_set_decrypt_key(). Each transformed block
 * is XORed with the previous ciphertext block (the caller's `iv` for the
 * first block). No padding is removed.
 *
 * In-place operation (out == in) is supported. The chaining value and the
 * current ciphertext block are copied to locals before `out` is written,
 * so overwriting the input cannot corrupt later blocks. The caller's `iv`
 * buffer is only read.
 */
void sm4_cbc_decrypt(const SM4_KEY *key, const uint8_t iv[16],
	const uint8_t *in, size_t nblocks, uint8_t *out)
{
	uint8_t chain[16];	/* previous ciphertext block, starts as iv */
	uint8_t saved[16];	/* current ciphertext block, saved before out is written */

	memcpy(chain, iv, 16);
	while (nblocks--) {
		memcpy(saved, in, 16);
		sm4_encrypt(key, in, out);
		memxor(out, chain, 16);
		memcpy(chain, saved, 16);
		in += 16;
		out += 16;
	}
}
39
/*
 * CBC-encrypt `inlen` bytes with PKCS#7-style padding.
 *
 * All whole 16-byte blocks of `in` are encrypted directly. The remaining
 * 0..15 bytes are copied into a local block and followed by `pad_len`
 * bytes, each equal to `pad_len` (1..16). That block is then encrypted
 * as the final block, so a full extra block of padding is always added.
 * `out` must have room for (inlen - inlen % 16) + 16 bytes, which is the
 * value stored in *outlen. Always returns 1.
 */
int sm4_cbc_padding_encrypt(const SM4_KEY *key, const uint8_t iv[16],
	const uint8_t *in, size_t inlen,
	uint8_t *out, size_t *outlen)
{
	uint8_t last[16];
	size_t full_len = inlen - inlen % 16;	/* bytes covered by whole blocks */
	size_t tail_len = inlen - full_len;	/* trailing 0..15 bytes */
	size_t pad_len = 16 - tail_len;		/* 1..16, never zero */

	if (in != NULL) {
		memcpy(last, in + full_len, tail_len);
	}
	memset(last + tail_len, (int)pad_len, pad_len);

	if (full_len > 0) {
		sm4_cbc_encrypt(key, iv, in, full_len / 16, out);
		/* the last ciphertext block written becomes the chaining value */
		iv = out + full_len - 16;
		out += full_len;
	}
	sm4_cbc_encrypt(key, iv, last, 1, out);

	*outlen = full_len + 16;
	return 1;
}
61
/*
 * CBC-decrypt `inlen` bytes and strip PKCS#7 padding.
 *
 * in/inlen : ciphertext; inlen must be a positive multiple of 16.
 * out      : receives the plaintext. All blocks but the last are written
 *            in full; the last block is written only up to the unpadded
 *            length. On a padding error, earlier blocks may already have
 *            been written to `out`.
 * outlen   : set to inlen - padding on success.
 *
 * Returns 1 on success. Returns 0 when inlen == 0 (*outlen is not set).
 * Returns -1 on an invalid length or malformed padding.
 */
int sm4_cbc_padding_decrypt(const SM4_KEY *key, const uint8_t iv[16],
	const uint8_t *in, size_t inlen,
	uint8_t *out, size_t *outlen)
{
	uint8_t block[16];
	size_t len = sizeof(block);
	size_t i;
	uint8_t diff = 0;
	int padding;

	if (inlen == 0) {
		error_puts("warning: input lenght = 0");
		return 0;
	}
	if (inlen%16 != 0 || inlen < 16) {
		error_puts("invalid cbc ciphertext length");
		return -1;
	}
	if (inlen > 16) {
		sm4_cbc_decrypt(key, iv, in, inlen/16 - 1, out);
		/* the second-to-last ciphertext block chains into the last one */
		iv = in + inlen - 32;
	}
	/* decrypt the final (padded) block into a local buffer */
	sm4_cbc_decrypt(key, iv, in + inlen - 16, 1, block);

	padding = block[15];
	if (padding < 1 || padding > 16) {
		error_print();
		return -1;
	}
	/*
	 * PKCS#7 requires every one of the last `padding` bytes to equal
	 * `padding`. Accumulate the differences instead of returning at the
	 * first mismatch, so every padding byte is examined.
	 */
	for (i = sizeof(block) - (size_t)padding; i < sizeof(block); i++) {
		diff |= (uint8_t)(block[i] ^ (uint8_t)padding);
	}
	if (diff != 0) {
		error_print();
		return -1;
	}
	len -= padding;
	memcpy(out + inlen - 16, block, len);
	*outlen = inlen - padding;
	return 1;
}
94
/*
 * Add 1 to the 16-byte array viewed as a big-endian 128-bit integer.
 * Carries propagate toward a[0]. An all-0xff input wraps to all zeros.
 */
static void ctr_incr(uint8_t a[16])
{
	size_t pos = 16;

	while (pos > 0) {
		pos--;
		if (++a[pos] != 0) {
			/* no carry out of this byte: done */
			return;
		}
	}
}
103
/*
 * XOR `inlen` bytes of `in` with an SM4-CTR keystream, writing to `out`.
 *
 * For every 16-byte chunk (the last may be shorter), the counter block
 * `ctr` is encrypted to produce keystream and then advanced with
 * ctr_incr(). This includes the final partial chunk, so on return `ctr`
 * holds the next unused counter value. Because the operation is a plain
 * XOR, running it again from the same starting counter restores the input.
 */
void sm4_ctr_encrypt(const SM4_KEY *key, uint8_t ctr[16], const uint8_t *in, size_t inlen, uint8_t *out)
{
	uint8_t keystream[16];
	size_t done = 0;

	while (done < inlen) {
		size_t chunk = inlen - done;

		if (chunk > 16) {
			chunk = 16;
		}
		sm4_encrypt(key, ctr, keystream);
		gmssl_memxor(out + done, in + done, keystream, chunk);
		ctr_incr(ctr);
		done += chunk;
	}
}
119
sm4_gcm_encrypt(const SM4_KEY * key,const uint8_t * iv,size_t ivlen,const uint8_t * aad,size_t aadlen,const uint8_t * in,size_t inlen,uint8_t * out,size_t taglen,uint8_t * tag)120 int sm4_gcm_encrypt(const SM4_KEY *key, const uint8_t *iv, size_t ivlen,
121 const uint8_t *aad, size_t aadlen, const uint8_t *in, size_t inlen,
122 uint8_t *out, size_t taglen, uint8_t *tag)
123 {
124 const uint8_t *pin = in;
125 uint8_t *pout = out;
126 size_t left = inlen;
127 uint8_t H[16] = {0};
128 uint8_t Y[16];
129 uint8_t T[16];
130
131 if (taglen > SM4_GCM_MAX_TAG_SIZE) {
132 error_print();
133 return -1;
134 }
135
136 sm4_encrypt(key, H, H);
137
138 if (ivlen == 12) {
139 memcpy(Y, iv, 12);
140 Y[12] = Y[13] = Y[14] = 0;
141 Y[15] = 1;
142 } else {
143 ghash(H, NULL, 0, iv, ivlen, Y);
144 }
145
146 sm4_encrypt(key, Y, T);
147
148 while (left) {
149 uint8_t block[16];
150 size_t len = left < 16 ? left : 16;
151 ctr_incr(Y);
152 sm4_encrypt(key, Y, block);
153 gmssl_memxor(pout, pin, block, len);
154 pin += len;
155 pout += len;
156 left -= len;
157 }
158
159 ghash(H, aad, aadlen, out, inlen, H);
160 gmssl_memxor(tag, T, H, taglen);
161 return 1;
162 }
163
/*
 * SM4-GCM authenticated decryption (NIST SP 800-38D construction).
 *
 * The tag is recomputed over aad and the ciphertext `in` and compared with
 * `tag` before anything is written. On a mismatch `out` is left untouched.
 *
 * iv/ivlen   : nonce; 12 bytes are used directly, other lengths go
 *              through ghash() (must match the encrypting side).
 * aad/aadlen : additional authenticated data.
 * in/inlen   : ciphertext; `out` receives inlen bytes of plaintext.
 * tag/taglen : expected tag; taglen must be 1..16.
 *
 * Returns 1 on success. Returns -1 on a bad taglen or an authentication
 * failure.
 */
int sm4_gcm_decrypt(const SM4_KEY *key, const uint8_t *iv, size_t ivlen,
	const uint8_t *aad, size_t aadlen, const uint8_t *in, size_t inlen,
	const uint8_t *tag, size_t taglen, uint8_t *out)
{
	const uint8_t *pin = in;
	uint8_t *pout = out;
	size_t left = inlen;
	uint8_t H[16] = {0};
	uint8_t Y[16];
	uint8_t T[16];
	uint8_t diff = 0;
	size_t k;
	int i;

	/*
	 * T holds only 16 bytes, so a larger taglen would overflow it in the
	 * XOR below. A zero taglen would accept every ciphertext as authentic.
	 */
	if (taglen == 0 || taglen > sizeof(T)) {
		error_print();
		return -1;
	}

	/* hash subkey H = E_K(0^128) */
	sm4_encrypt(key, H, H);

	/* pre-counter block J0 */
	if (ivlen == 12) {
		memcpy(Y, iv, 12);
		Y[12] = Y[13] = Y[14] = 0;
		Y[15] = 1;
	} else {
		ghash(H, NULL, 0, iv, ivlen, Y);
	}

	/* expected tag = E_K(J0) ^ GHASH_H(A, C) */
	ghash(H, aad, aadlen, in, inlen, H);
	sm4_encrypt(key, Y, T);
	gmssl_memxor(T, T, H, taglen);

	/*
	 * Constant-time comparison: look at every byte regardless of where
	 * the first mismatch is, so timing does not reveal how much matched.
	 */
	for (k = 0; k < taglen; k++) {
		diff |= (uint8_t)(T[k] ^ tag[k]);
	}
	if (diff != 0) {
		error_print();
		return -1;
	}

	while (left) {
		uint8_t block[16];
		size_t len = left < 16 ? left : 16;

		/* inc32: increment only the rightmost 32 bits (SP 800-38D) */
		for (i = 15; i >= 12; i--) {
			if (++Y[i] != 0) {
				break;
			}
		}
		sm4_encrypt(key, Y, block);
		gmssl_memxor(pout, pin, block, len);
		pin += len;
		pout += len;
		left -= len;
	}
	return 1;
}
205
/*
 * Prepare `ctx` for streaming CBC encryption. This empties the
 * partial-block buffer, stores `iv` as the first chaining value and
 * expands the encryption key schedule. Always returns 1.
 */
int sm4_cbc_encrypt_init(SM4_CBC_CTX *ctx,
	const uint8_t key[SM4_BLOCK_SIZE], const uint8_t iv[SM4_BLOCK_SIZE])
{
	ctx->block_nbytes = 0;
	memset(ctx->block, 0, SM4_BLOCK_SIZE);
	memcpy(ctx->iv, iv, SM4_BLOCK_SIZE);
	sm4_set_encrypt_key(&ctx->sm4_key, key);
	return 1;
}
215
/*
 * Feed `inlen` more plaintext bytes into a streaming CBC encryption.
 *
 * Only complete 16-byte blocks are encrypted. Leftover bytes (0..15) stay
 * in ctx->block until the next update or sm4_cbc_encrypt_finish().
 * ctx->iv always holds the last ciphertext block produced. *outlen is set
 * to the number of bytes written to `out`, which is always a multiple
 * of 16. Returns 1, or -1 if the context's buffer count is already
 * out of range.
 */
int sm4_cbc_encrypt_update(SM4_CBC_CTX *ctx,
	const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen)
{
	size_t produced = 0;

	if (ctx->block_nbytes >= SM4_BLOCK_SIZE) {
		error_print();
		return -1;
	}
	*outlen = 0;

	if (ctx->block_nbytes > 0) {
		size_t need = SM4_BLOCK_SIZE - ctx->block_nbytes;

		if (inlen < need) {
			/* still short of a full block: buffer and wait for more */
			memcpy(ctx->block + ctx->block_nbytes, in, inlen);
			ctx->block_nbytes += inlen;
			return 1;
		}
		/* complete the buffered block and encrypt it */
		memcpy(ctx->block + ctx->block_nbytes, in, need);
		sm4_cbc_encrypt(&ctx->sm4_key, ctx->iv, ctx->block, 1, out);
		memcpy(ctx->iv, out, SM4_BLOCK_SIZE);
		in += need;
		inlen -= need;
		out += SM4_BLOCK_SIZE;
		produced = SM4_BLOCK_SIZE;
	}

	if (inlen >= SM4_BLOCK_SIZE) {
		size_t nblocks = inlen / SM4_BLOCK_SIZE;
		size_t nbytes = nblocks * SM4_BLOCK_SIZE;

		sm4_cbc_encrypt(&ctx->sm4_key, ctx->iv, in, nblocks, out);
		memcpy(ctx->iv, out + nbytes - SM4_BLOCK_SIZE, SM4_BLOCK_SIZE);
		in += nbytes;
		inlen -= nbytes;
		produced += nbytes;
	}

	if (inlen > 0) {
		memcpy(ctx->block, in, inlen);
	}
	ctx->block_nbytes = inlen;
	*outlen = produced;
	return 1;
}
259
/*
 * Finish a streaming CBC encryption.
 *
 * The 0..15 buffered bytes in ctx->block are padded and encrypted with
 * sm4_cbc_padding_encrypt(), chained from ctx->iv. Because the buffer
 * never holds a full block here, exactly one 16-byte block is written to
 * `out` and *outlen is set by that call.
 * Returns 1 on success. Returns -1 if the buffer count is out of range or
 * the padding encryption reports failure.
 */
int sm4_cbc_encrypt_finish(SM4_CBC_CTX *ctx, uint8_t *out, size_t *outlen)
{
	if (ctx->block_nbytes >= SM4_BLOCK_SIZE) {
		error_print();
		return -1;
	}
	if (sm4_cbc_padding_encrypt(&ctx->sm4_key, ctx->iv, ctx->block, ctx->block_nbytes, out, outlen) != 1) {
		error_print();
		return -1;
	}
	return 1;
}
275
/*
 * Prepare `ctx` for streaming CBC decryption. This empties the
 * partial-block buffer, stores `iv` as the first chaining value and
 * expands the decryption key schedule. Always returns 1.
 */
int sm4_cbc_decrypt_init(SM4_CBC_CTX *ctx,
	const uint8_t key[SM4_BLOCK_SIZE], const uint8_t iv[SM4_BLOCK_SIZE])
{
	ctx->block_nbytes = 0;
	memset(ctx->block, 0, SM4_BLOCK_SIZE);
	memcpy(ctx->iv, iv, SM4_BLOCK_SIZE);
	sm4_set_decrypt_key(&ctx->sm4_key, key);
	return 1;
}
285
/*
 * Feed `inlen` more ciphertext bytes into a streaming CBC decryption.
 *
 * The last 1..16 bytes seen so far are always held back in ctx->block.
 * sm4_cbc_decrypt_finish() needs that final block to strip the padding.
 * Only blocks known not to be the last are decrypted here. ctx->iv tracks
 * the most recently consumed ciphertext block. *outlen is set to the
 * number of bytes written to `out`. Returns 1, or -1 if the buffer count
 * is out of range.
 */
int sm4_cbc_decrypt_update(SM4_CBC_CTX *ctx,
	const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen)
{
	size_t produced = 0;

	if (ctx->block_nbytes > SM4_BLOCK_SIZE) {
		error_print();
		return -1;
	}
	*outlen = 0;

	if (ctx->block_nbytes != 0) {
		size_t room = SM4_BLOCK_SIZE - ctx->block_nbytes;

		/* without input beyond this block, it may be the padded final one */
		if (inlen <= room) {
			memcpy(ctx->block + ctx->block_nbytes, in, inlen);
			ctx->block_nbytes += inlen;
			return 1;
		}
		memcpy(ctx->block + ctx->block_nbytes, in, room);
		sm4_cbc_decrypt(&ctx->sm4_key, ctx->iv, ctx->block, 1, out);
		/* the ciphertext block just consumed chains into the next one */
		memcpy(ctx->iv, ctx->block, SM4_BLOCK_SIZE);
		in += room;
		inlen -= room;
		out += SM4_BLOCK_SIZE;
		produced = SM4_BLOCK_SIZE;
	}

	if (inlen > SM4_BLOCK_SIZE) {
		/* decrypt everything except the last 1..16 bytes */
		size_t nblocks = (inlen - 1) / SM4_BLOCK_SIZE;
		size_t nbytes = nblocks * SM4_BLOCK_SIZE;

		sm4_cbc_decrypt(&ctx->sm4_key, ctx->iv, in, nblocks, out);
		memcpy(ctx->iv, in + nbytes - SM4_BLOCK_SIZE, SM4_BLOCK_SIZE);
		in += nbytes;
		inlen -= nbytes;
		produced += nbytes;
	}

	memcpy(ctx->block, in, inlen);
	ctx->block_nbytes = inlen;
	*outlen = produced;
	return 1;
}
326
/*
 * Finish a streaming CBC decryption.
 *
 * The held-back final block must be complete (exactly 16 bytes). It is
 * decrypted with sm4_cbc_padding_decrypt(), which strips the padding and
 * sets *outlen. Returns 1 on success. Returns -1 if the final block is
 * incomplete or padding removal fails.
 */
int sm4_cbc_decrypt_finish(SM4_CBC_CTX *ctx, uint8_t *out, size_t *outlen)
{
	int ret = -1;

	if (ctx->block_nbytes != SM4_BLOCK_SIZE) {
		error_print();
	} else if (sm4_cbc_padding_decrypt(&ctx->sm4_key, ctx->iv, ctx->block,
			SM4_BLOCK_SIZE, out, outlen) != 1) {
		error_print();
	} else {
		ret = 1;
	}
	return ret;
}
339
/*
 * Prepare `ctx` for streaming CTR encryption. This empties the
 * partial-block buffer, stores the initial counter block and expands the
 * encryption key schedule. Always returns 1.
 */
int sm4_ctr_encrypt_init(SM4_CTR_CTX *ctx,
	const uint8_t key[SM4_BLOCK_SIZE], const uint8_t ctr[SM4_BLOCK_SIZE])
{
	ctx->block_nbytes = 0;
	memset(ctx->block, 0, SM4_BLOCK_SIZE);
	memcpy(ctx->ctr, ctr, SM4_BLOCK_SIZE);
	sm4_set_encrypt_key(&ctx->sm4_key, key);
	return 1;
}
349
/*
 * Feed `inlen` more bytes into a streaming CTR encryption.
 *
 * Only whole 16-byte blocks are processed, so the keystream stays aligned
 * to counter blocks across calls. The 0..15 leftover bytes are buffered in
 * ctx->block for the next update or sm4_ctr_encrypt_finish(). *outlen is
 * set to the number of bytes written to `out`. Returns 1, or -1 if the
 * buffer count is out of range.
 */
int sm4_ctr_encrypt_update(SM4_CTR_CTX *ctx,
	const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen)
{
	size_t produced = 0;

	if (ctx->block_nbytes >= SM4_BLOCK_SIZE) {
		error_print();
		return -1;
	}
	*outlen = 0;

	if (ctx->block_nbytes > 0) {
		size_t need = SM4_BLOCK_SIZE - ctx->block_nbytes;

		if (inlen < need) {
			/* still short of a full block: buffer and wait for more */
			memcpy(ctx->block + ctx->block_nbytes, in, inlen);
			ctx->block_nbytes += inlen;
			return 1;
		}
		memcpy(ctx->block + ctx->block_nbytes, in, need);
		sm4_ctr_encrypt(&ctx->sm4_key, ctx->ctr, ctx->block, SM4_BLOCK_SIZE, out);
		in += need;
		inlen -= need;
		out += SM4_BLOCK_SIZE;
		produced = SM4_BLOCK_SIZE;
	}

	if (inlen >= SM4_BLOCK_SIZE) {
		size_t nbytes = inlen - inlen % SM4_BLOCK_SIZE;

		sm4_ctr_encrypt(&ctx->sm4_key, ctx->ctr, in, nbytes, out);
		in += nbytes;
		inlen -= nbytes;
		produced += nbytes;
	}

	if (inlen > 0) {
		memcpy(ctx->block, in, inlen);
	}
	ctx->block_nbytes = inlen;
	*outlen = produced;
	return 1;
}
391
/*
 * Finish a streaming CTR encryption by encrypting the 0..15 bytes still
 * buffered in ctx->block. Exactly ctx->block_nbytes bytes are written to
 * `out`, and *outlen is set to that count. No padding is applied.
 * Returns 1, or -1 if the buffer count is out of range.
 */
int sm4_ctr_encrypt_finish(SM4_CTR_CTX *ctx, uint8_t *out, size_t *outlen)
{
	if (ctx->block_nbytes >= SM4_BLOCK_SIZE) {
		error_print();
		return -1;
	}
	sm4_ctr_encrypt(&ctx->sm4_key, ctx->ctr, ctx->block, ctx->block_nbytes, out);
	*outlen = ctx->block_nbytes;
	return 1;
}
403