• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *  Copyright 2014-2022 The GmSSL Project. All Rights Reserved.
3  *
4  *  Licensed under the Apache License, Version 2.0 (the License); you may
5  *  not use this file except in compliance with the License.
6  *
7  *  http://www.apache.org/licenses/LICENSE-2.0
8  */
9 
10 
11 #include <gmssl/sm4.h>
12 #include <gmssl/mem.h>
13 #include <gmssl/gcm.h>
14 #include <gmssl/error.h>
15 
/*
 * SM4-CBC encryption of `nblocks` whole 16-byte blocks (no padding).
 *
 * Output block i is E_K(P_i XOR C_{i-1}) with C_{-1} = iv. The chaining
 * value for each block is the previously written output block.
 */
void sm4_cbc_encrypt(const SM4_KEY *key, const uint8_t iv[16],
	const uint8_t *in, size_t nblocks, uint8_t *out)
{
	const uint8_t *chain = iv;
	size_t i;

	for (i = 0; i < nblocks; i++) {
		const uint8_t *src = in + 16 * i;
		uint8_t *dst = out + 16 * i;

		gmssl_memxor(dst, src, chain, 16);
		sm4_encrypt(key, dst, dst);
		chain = dst;
	}
}
27 
/*
 * SM4-CBC decryption of `nblocks` whole 16-byte blocks (no padding removal).
 *
 * Plaintext block i is D_K(C_i) XOR C_{i-1} with C_{-1} = iv. The block
 * cipher call is sm4_encrypt(); presumably `key` must hold a *decryption*
 * key schedule (sm4_cbc_decrypt_init sets one up with sm4_set_decrypt_key)
 * -- TODO confirm for other callers.
 *
 * Each ciphertext block is copied into a local before its plaintext is
 * written. The IV is copied up front as well. This keeps the chaining value
 * intact when `out` overlaps `in` (in-place decryption) or `iv`. The old
 * code re-read the chaining value from `in` after `out` had been written.
 */
void sm4_cbc_decrypt(const SM4_KEY *key, const uint8_t iv[16],
	const uint8_t *in, size_t nblocks, uint8_t *out)
{
	uint8_t prev[16];	/* C_{i-1} */
	uint8_t cur[16];	/* C_i, saved before out[i] is written */

	memcpy(prev, iv, 16);
	while (nblocks--) {
		memcpy(cur, in, 16);
		sm4_encrypt(key, cur, out);
		memxor(out, prev, 16);
		memcpy(prev, cur, 16);
		in += 16;
		out += 16;
	}
}
39 
/*
 * SM4-CBC encryption with PKCS#7 padding.
 *
 * Encrypts every complete 16-byte block of `in`, then one final block made
 * of the trailing inlen % 16 input bytes followed by `padding` bytes, each
 * equal to `padding` (1..16). When inlen is a multiple of 16 a whole block
 * of padding is appended, so *outlen is always (inlen / 16 + 1) * 16 and
 * `out` must have room for that many bytes.
 *
 * `in` may be NULL only when inlen == 0. Previously a NULL `in` with a
 * non-zero inlen either dereferenced NULL or encrypted an uninitialized tail.
 *
 * Returns 1 on success, -1 on invalid arguments.
 */
int sm4_cbc_padding_encrypt(const SM4_KEY *key, const uint8_t iv[16],
	const uint8_t *in, size_t inlen,
	uint8_t *out, size_t *outlen)
{
	uint8_t block[16];
	size_t rem = inlen % 16;
	size_t full = inlen - rem;	/* bytes covered by whole blocks */
	size_t padding = 16 - rem;	/* always 1..16 */

	if (!in && inlen) {
		error_print();
		return -1;
	}

	/* Build the final padded block before anything is written to `out`,
	 * so the input tail is captured even if `out` overlaps `in`. */
	if (rem) {
		memcpy(block, in + full, rem);
	}
	memset(block + rem, (int)padding, padding);

	if (full) {
		sm4_cbc_encrypt(key, iv, in, full / 16, out);
		out += full;
		iv = out - 16;	/* chain from the last ciphertext block */
	}
	sm4_cbc_encrypt(key, iv, block, 1, out);
	*outlen = full + 16;
	return 1;
}
61 
/*
 * SM4-CBC decryption with PKCS#7 padding removal.
 *
 * `inlen` must be a non-zero multiple of 16. The last ciphertext block is
 * decrypted first and its padding validated: the final byte `padding` must
 * be in 1..16, and every one of the last `padding` bytes must equal it.
 * Only after that are the earlier blocks decrypted into `out`, so no
 * plaintext is written when the padding is malformed.
 *
 * On success `out` receives inlen - padding bytes and *outlen is set to
 * that length.
 *
 * Returns 1 on success, 0 (with a warning) when inlen == 0, and -1 for an
 * invalid ciphertext length or invalid padding.
 */
int sm4_cbc_padding_decrypt(const SM4_KEY *key, const uint8_t iv[16],
	const uint8_t *in, size_t inlen,
	uint8_t *out, size_t *outlen)
{
	uint8_t block[16];
	const uint8_t *last_iv = iv;
	size_t nblocks;
	size_t padding;
	size_t i;
	uint8_t diff = 0;

	if (inlen == 0) {
		error_puts("warning: input lenght = 0");
		return 0;
	}
	if (inlen%16 != 0 || inlen < 16) {
		error_puts("invalid cbc ciphertext length");
		return -1;
	}
	nblocks = inlen / 16;

	/* The last block chains from the previous ciphertext block (or from
	 * iv when there is only one block). It is read from `in` before
	 * anything is written to `out`. */
	if (nblocks > 1) {
		last_iv = in + inlen - 32;
	}
	sm4_cbc_decrypt(key, last_iv, in + inlen - 16, 1, block);

	padding = block[15];
	if (padding < 1 || padding > 16) {
		error_print();
		return -1;
	}
	/* Every padding byte must equal the padding length. Accumulate the
	 * differences instead of returning at the first mismatch. */
	for (i = 16 - padding; i < 15; i++) {
		diff |= (uint8_t)(block[i] ^ block[15]);
	}
	if (diff) {
		error_print();
		return -1;
	}

	if (nblocks > 1) {
		sm4_cbc_decrypt(key, iv, in, nblocks - 1, out);
	}
	memcpy(out + inlen - 16, block, 16 - padding);
	*outlen = inlen - padding;
	return 1;
}
94 
/*
 * Increment a 16-byte counter in place, treating it as one 128-bit
 * big-endian integer. The carry ripples from a[15] toward a[0], and the
 * value wraps to all zeros after 0xff..ff.
 */
static void ctr_incr(uint8_t a[16])
{
	int pos = 16;

	while (pos-- > 0) {
		if (++a[pos] != 0) {
			return;
		}
	}
}
103 
/*
 * SM4-CTR: XOR `inlen` bytes of `in` with the keystream E_K(ctr),
 * E_K(ctr+1), ... and write the result to `out`.
 *
 * The counter (a full 128-bit big-endian value) is advanced once per 16-byte
 * chunk, including a final partial chunk. The updated counter is left in
 * `ctr`, so a later call continues the stream seamlessly only if this call
 * processed a multiple of 16 bytes.
 */
void sm4_ctr_encrypt(const SM4_KEY *key, uint8_t ctr[16], const uint8_t *in, size_t inlen, uint8_t *out)
{
	uint8_t keystream[16];
	size_t done = 0;

	while (done < inlen) {
		size_t chunk = inlen - done;

		if (chunk > 16) {
			chunk = 16;
		}
		sm4_encrypt(key, ctr, keystream);
		gmssl_memxor(out + done, in + done, keystream, chunk);
		ctr_incr(ctr);
		done += chunk;
	}
}
119 
sm4_gcm_encrypt(const SM4_KEY * key,const uint8_t * iv,size_t ivlen,const uint8_t * aad,size_t aadlen,const uint8_t * in,size_t inlen,uint8_t * out,size_t taglen,uint8_t * tag)120 int sm4_gcm_encrypt(const SM4_KEY *key, const uint8_t *iv, size_t ivlen,
121 	const uint8_t *aad, size_t aadlen, const uint8_t *in, size_t inlen,
122 	uint8_t *out, size_t taglen, uint8_t *tag)
123 {
124 	const uint8_t *pin = in;
125 	uint8_t *pout = out;
126 	size_t left = inlen;
127 	uint8_t H[16] = {0};
128 	uint8_t Y[16];
129 	uint8_t T[16];
130 
131 	if (taglen > SM4_GCM_MAX_TAG_SIZE) {
132 		error_print();
133 		return -1;
134 	}
135 
136 	sm4_encrypt(key, H, H);
137 
138 	if (ivlen == 12) {
139 		memcpy(Y, iv, 12);
140 		Y[12] = Y[13] = Y[14] = 0;
141 		Y[15] = 1;
142 	} else {
143 		ghash(H, NULL, 0, iv, ivlen, Y);
144 	}
145 
146 	sm4_encrypt(key, Y, T);
147 
148 	while (left) {
149 		uint8_t block[16];
150 		size_t len = left < 16 ? left : 16;
151 		ctr_incr(Y);
152 		sm4_encrypt(key, Y, block);
153 		gmssl_memxor(pout, pin, block, len);
154 		pin += len;
155 		pout += len;
156 		left -= len;
157 	}
158 
159 	ghash(H, aad, aadlen, out, inlen, H);
160 	gmssl_memxor(tag, T, H, taglen);
161 	return 1;
162 }
163 
sm4_gcm_decrypt(const SM4_KEY * key,const uint8_t * iv,size_t ivlen,const uint8_t * aad,size_t aadlen,const uint8_t * in,size_t inlen,const uint8_t * tag,size_t taglen,uint8_t * out)164 int sm4_gcm_decrypt(const SM4_KEY *key, const uint8_t *iv, size_t ivlen,
165 	const uint8_t *aad, size_t aadlen, const uint8_t *in, size_t inlen,
166 	const uint8_t *tag, size_t taglen, uint8_t *out)
167 {
168 	const uint8_t *pin = in;
169 	uint8_t *pout = out;
170 	size_t left = inlen;
171 	uint8_t H[16] = {0};
172 	uint8_t Y[16];
173 	uint8_t T[16];
174 
175 	sm4_encrypt(key, H, H);
176 
177 	if (ivlen == 12) {
178 		memcpy(Y, iv, 12);
179 		Y[12] = Y[13] = Y[14] = 0;
180 		Y[15] = 1;
181 	} else {
182 		ghash(H, NULL, 0, iv, ivlen, Y);
183 	}
184 
185 	ghash(H, aad, aadlen, in, inlen, H);
186 	sm4_encrypt(key, Y, T);
187 	gmssl_memxor(T, T, H, taglen);
188 	if (memcmp(T, tag, taglen) != 0) {
189 		error_print();
190 		return -1;
191 	}
192 
193 	while (left) {
194 		uint8_t block[16];
195 		size_t len = left < 16 ? left : 16;
196 		ctr_incr(Y);
197 		sm4_encrypt(key, Y, block);
198 		gmssl_memxor(pout, pin, block, len);
199 		pin += len;
200 		pout += len;
201 		left -= len;
202 	}
203 	return 1;
204 }
205 
/*
 * Prepare `ctx` for streaming SM4-CBC encryption. The steps are:
 *   - reset the partial-block counter;
 *   - expand `key` into an encryption key schedule;
 *   - record `iv` as the initial chaining value;
 *   - clear the partial-block buffer.
 * Always returns 1.
 */
int sm4_cbc_encrypt_init(SM4_CBC_CTX *ctx,
	const uint8_t key[SM4_BLOCK_SIZE], const uint8_t iv[SM4_BLOCK_SIZE])
{
	ctx->block_nbytes = 0;
	sm4_set_encrypt_key(&ctx->sm4_key, key);
	memcpy(ctx->iv, iv, SM4_BLOCK_SIZE);
	memset(ctx->block, 0, SM4_BLOCK_SIZE);
	return 1;
}
215 
/*
 * Feed `inlen` bytes into a streaming SM4-CBC encryption.
 *
 * Whole blocks are encrypted and written to `out`, and *outlen receives the
 * number of bytes written (a multiple of SM4_BLOCK_SIZE, possibly 0). Any
 * trailing partial block is kept in ctx->block for the next update or for
 * sm4_cbc_encrypt_finish(). After each encryption ctx->iv is set to the
 * last ciphertext block produced.
 *
 * Returns 1 on success, -1 if the context's buffer count is corrupt.
 */
int sm4_cbc_encrypt_update(SM4_CBC_CTX *ctx,
	const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen)
{
	size_t produced = 0;
	size_t need;
	size_t bulk;

	if (ctx->block_nbytes >= SM4_BLOCK_SIZE) {
		error_print();
		return -1;
	}
	*outlen = 0;

	/* Top up a previously buffered partial block first */
	if (ctx->block_nbytes != 0) {
		need = SM4_BLOCK_SIZE - ctx->block_nbytes;
		if (inlen < need) {
			memcpy(ctx->block + ctx->block_nbytes, in, inlen);
			ctx->block_nbytes += inlen;
			return 1;
		}
		memcpy(ctx->block + ctx->block_nbytes, in, need);
		sm4_cbc_encrypt(&ctx->sm4_key, ctx->iv, ctx->block, 1, out);
		memcpy(ctx->iv, out, SM4_BLOCK_SIZE);
		in += need;
		inlen -= need;
		out += SM4_BLOCK_SIZE;
		produced = SM4_BLOCK_SIZE;
	}

	/* Encrypt all remaining whole blocks straight from the caller's buffer */
	bulk = inlen - inlen % SM4_BLOCK_SIZE;
	if (bulk != 0) {
		sm4_cbc_encrypt(&ctx->sm4_key, ctx->iv, in, bulk / SM4_BLOCK_SIZE, out);
		memcpy(ctx->iv, out + bulk - SM4_BLOCK_SIZE, SM4_BLOCK_SIZE);
		in += bulk;
		inlen -= bulk;
		produced += bulk;
	}

	/* Stash the leftover tail (fewer than SM4_BLOCK_SIZE bytes) */
	if (inlen != 0) {
		memcpy(ctx->block, in, inlen);
	}
	ctx->block_nbytes = inlen;
	*outlen = produced;
	return 1;
}
259 
/*
 * Finish a streaming SM4-CBC encryption.
 *
 * The buffered partial block (0..15 bytes) is padded (PKCS#7) and encrypted
 * with the current chaining value via sm4_cbc_padding_encrypt(). This always
 * writes exactly one block to `out` and sets *outlen.
 *
 * Returns 1 on success, -1 if the context is corrupt or padding encryption
 * fails. The unused locals `left` and `i` have been removed.
 */
int sm4_cbc_encrypt_finish(SM4_CBC_CTX *ctx, uint8_t *out, size_t *outlen)
{
	if (ctx->block_nbytes >= SM4_BLOCK_SIZE) {
		error_print();
		return -1;
	}
	if (sm4_cbc_padding_encrypt(&ctx->sm4_key, ctx->iv, ctx->block, ctx->block_nbytes, out, outlen) != 1) {
		error_print();
		return -1;
	}
	return 1;
}
275 
/*
 * Prepare `ctx` for streaming SM4-CBC decryption. The steps are:
 *   - reset the buffered-byte counter;
 *   - expand `key` into a decryption key schedule;
 *   - record `iv` as the initial chaining value;
 *   - clear the block buffer.
 * Always returns 1.
 */
int sm4_cbc_decrypt_init(SM4_CBC_CTX *ctx,
	const uint8_t key[SM4_BLOCK_SIZE], const uint8_t iv[SM4_BLOCK_SIZE])
{
	ctx->block_nbytes = 0;
	sm4_set_decrypt_key(&ctx->sm4_key, key);
	memcpy(ctx->iv, iv, SM4_BLOCK_SIZE);
	memset(ctx->block, 0, SM4_BLOCK_SIZE);
	return 1;
}
285 
/*
 * Feed `inlen` bytes of ciphertext into a streaming SM4-CBC decryption.
 *
 * Decrypted whole blocks are written to `out`, and *outlen receives the
 * number of bytes written. Between 1 and SM4_BLOCK_SIZE bytes (or 0 if no
 * input has been seen) are always held back in ctx->block. This leaves the
 * final, padded block for sm4_cbc_decrypt_finish().
 *
 * The next chaining value is now copied from the ciphertext *before* the
 * bulk decryption writes to `out`. Previously it was read from `in`
 * afterwards, which yields plaintext when `out` overlaps `in`.
 *
 * Returns 1 on success, -1 if the context's buffer count is corrupt.
 */
int sm4_cbc_decrypt_update(SM4_CBC_CTX *ctx,
	const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen)
{
	size_t left, len, nblocks;
	uint8_t next_iv[SM4_BLOCK_SIZE];

	if (ctx->block_nbytes > SM4_BLOCK_SIZE) {
		error_print();
		return -1;
	}

	*outlen = 0;
	if (ctx->block_nbytes) {
		left = SM4_BLOCK_SIZE - ctx->block_nbytes;
		/* Not enough to go beyond the buffered block: just accumulate */
		if (inlen <= left) {
			if (inlen) {
				memcpy(ctx->block + ctx->block_nbytes, in, inlen);
			}
			ctx->block_nbytes += inlen;
			return 1;
		}
		memcpy(ctx->block + ctx->block_nbytes, in, left);
		sm4_cbc_decrypt(&ctx->sm4_key, ctx->iv, ctx->block, 1, out);
		memcpy(ctx->iv, ctx->block, SM4_BLOCK_SIZE);
		in += left;
		inlen -= left;
		out += SM4_BLOCK_SIZE;
		*outlen += SM4_BLOCK_SIZE;
	}
	if (inlen > SM4_BLOCK_SIZE) {
		/* (inlen - 1) keeps at least one byte back for the buffer */
		nblocks = (inlen - 1) / SM4_BLOCK_SIZE;
		len = nblocks * SM4_BLOCK_SIZE;
		memcpy(next_iv, in + len - SM4_BLOCK_SIZE, SM4_BLOCK_SIZE);
		sm4_cbc_decrypt(&ctx->sm4_key, ctx->iv, in, nblocks, out);
		memcpy(ctx->iv, next_iv, SM4_BLOCK_SIZE);
		in += len;
		inlen -= len;
		out += len;
		*outlen += len;
	}
	if (inlen) {
		memcpy(ctx->block, in, inlen);
	}
	ctx->block_nbytes = inlen;
	return 1;
}
326 
/*
 * Finish a streaming SM4-CBC decryption.
 *
 * The context must hold exactly one buffered ciphertext block (the last
 * one). It is decrypted with the current chaining value and stripped of its
 * padding by sm4_cbc_padding_decrypt(). The remaining plaintext goes to
 * `out`, and its length goes to *outlen.
 *
 * Returns 1 on success, -1 if the buffer is not a full block or the padding
 * is invalid.
 */
int sm4_cbc_decrypt_finish(SM4_CBC_CTX *ctx, uint8_t *out, size_t *outlen)
{
	int ret;

	if (ctx->block_nbytes == SM4_BLOCK_SIZE) {
		ret = sm4_cbc_padding_decrypt(&ctx->sm4_key, ctx->iv, ctx->block,
			SM4_BLOCK_SIZE, out, outlen);
		if (ret == 1) {
			return 1;
		}
	}
	error_print();
	return -1;
}
339 
/*
 * Prepare `ctx` for streaming SM4-CTR. The steps are:
 *   - reset the partial-block counter;
 *   - expand `key` into an encryption key schedule (CTR only ever encrypts);
 *   - copy the initial counter block;
 *   - clear the partial-block buffer.
 * Always returns 1.
 */
int sm4_ctr_encrypt_init(SM4_CTR_CTX *ctx,
	const uint8_t key[SM4_BLOCK_SIZE], const uint8_t ctr[SM4_BLOCK_SIZE])
{
	ctx->block_nbytes = 0;
	sm4_set_encrypt_key(&ctx->sm4_key, key);
	memcpy(ctx->ctr, ctr, SM4_BLOCK_SIZE);
	memset(ctx->block, 0, SM4_BLOCK_SIZE);
	return 1;
}
349 
/*
 * Feed `inlen` bytes into a streaming SM4-CTR encryption.
 *
 * Input is processed only in whole SM4_BLOCK_SIZE units, so the keystream
 * stays block-aligned across calls. Any trailing partial block is buffered
 * in ctx->block until more data arrives or sm4_ctr_encrypt_finish() is
 * called. *outlen receives the number of bytes written to `out`.
 *
 * Returns 1 on success, -1 if the context's buffer count is corrupt.
 */
int sm4_ctr_encrypt_update(SM4_CTR_CTX *ctx,
	const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen)
{
	size_t produced = 0;
	size_t need;
	size_t bulk;

	if (ctx->block_nbytes >= SM4_BLOCK_SIZE) {
		error_print();
		return -1;
	}
	*outlen = 0;

	/* Complete a previously buffered partial block first */
	if (ctx->block_nbytes != 0) {
		need = SM4_BLOCK_SIZE - ctx->block_nbytes;
		if (inlen < need) {
			memcpy(ctx->block + ctx->block_nbytes, in, inlen);
			ctx->block_nbytes += inlen;
			return 1;
		}
		memcpy(ctx->block + ctx->block_nbytes, in, need);
		sm4_ctr_encrypt(&ctx->sm4_key, ctx->ctr, ctx->block, SM4_BLOCK_SIZE, out);
		in += need;
		inlen -= need;
		out += SM4_BLOCK_SIZE;
		produced = SM4_BLOCK_SIZE;
	}

	/* Process every remaining whole block directly */
	bulk = inlen - inlen % SM4_BLOCK_SIZE;
	if (bulk != 0) {
		sm4_ctr_encrypt(&ctx->sm4_key, ctx->ctr, in, bulk, out);
		in += bulk;
		inlen -= bulk;
		produced += bulk;
	}

	/* Keep the leftover tail for later */
	if (inlen != 0) {
		memcpy(ctx->block, in, inlen);
	}
	ctx->block_nbytes = inlen;
	*outlen = produced;
	return 1;
}
391 
/*
 * Finish a streaming SM4-CTR encryption.
 *
 * Encrypts the buffered partial block (0..15 bytes) with the current counter
 * and writes it to `out`. CTR needs no padding, so *outlen equals the number
 * of buffered bytes.
 *
 * Returns 1 on success, -1 if the context's buffer count is corrupt. The
 * unused local `left` has been removed.
 */
int sm4_ctr_encrypt_finish(SM4_CTR_CTX *ctx, uint8_t *out, size_t *outlen)
{
	if (ctx->block_nbytes >= SM4_BLOCK_SIZE) {
		error_print();
		return -1;
	}
	sm4_ctr_encrypt(&ctx->sm4_key, ctx->ctr, ctx->block, ctx->block_nbytes, out);
	*outlen = ctx->block_nbytes;
	return 1;
}
403