1 /*
2 * Copyright 2014-2022 The GmSSL Project. All Rights Reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the License); you may
5 * not use this file except in compliance with the License.
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 */
9
10
11
12 #include <stdio.h>
13 #include <string.h>
14 #include <stdlib.h>
15 #include <gmssl/aes.h>
16 #include <gmssl/gcm.h>
17 #include <gmssl/error.h>
18 #include <gmssl/mem.h>
19
20
/*
 * CBC-mode encryption of `nblocks` whole 16-byte blocks.
 *
 * Each input block is XORed with the previous ciphertext block (with `iv`
 * for the first block) and the result is encrypted in place in `out`.
 * Nothing is read or written when `nblocks` is 0.
 */
void aes_cbc_encrypt(const AES_KEY *key, const uint8_t iv[16],
	const uint8_t *in, size_t nblocks, uint8_t *out)
{
	const uint8_t *chain = iv;	/* value XORed into the next block */
	size_t i;

	for (i = 0; i < nblocks; i++, in += 16, out += 16) {
		gmssl_memxor(out, in, chain, 16);
		aes_encrypt(key, out, out);
		/* the ciphertext just produced chains into the next block */
		chain = out;
	}
}
32
/*
 * CBC-mode decryption of `nblocks` whole 16-byte blocks.
 *
 * Each ciphertext block is decrypted into `out` and then XORed with the
 * previous ciphertext block (with `iv` for the first block).
 *
 * The chaining value is read from `in` after the matching `out` block has
 * been written, so `in` and `out` must not overlap.
 */
void aes_cbc_decrypt(const AES_KEY *key, const uint8_t iv[16],
	const uint8_t *in, size_t nblocks, uint8_t *out)
{
	const uint8_t *chain = iv;
	size_t i;

	for (i = 0; i < nblocks; i++) {
		const uint8_t *cipher = in + 16 * i;
		uint8_t *plain = out + 16 * i;

		aes_decrypt(key, cipher, plain);
		memxor(plain, chain, 16);
		/* this ciphertext block is the chaining value for the next one */
		chain = cipher;
	}
}
44
/*
 * CBC encryption with padding appended to the plaintext.
 *
 * The input is split into inlen/16 whole blocks plus a tail of inlen%16
 * bytes. The tail is completed to 16 bytes with `pad` bytes of value `pad`
 * (pad = 16 - inlen%16, so a full padding block is added when inlen is a
 * multiple of 16) and encrypted as the final block.
 *
 * `out` receives 16 * (inlen/16 + 1) bytes; that size is stored in
 * `*outlen`. `in` may be NULL, in which case no tail bytes are copied.
 * Always returns 1.
 */
int aes_cbc_padding_encrypt(const AES_KEY *key, const uint8_t iv[16],
	const uint8_t *in, size_t inlen,
	uint8_t *out, size_t *outlen)
{
	uint8_t last[16];
	const size_t whole = inlen / 16;
	const size_t tail = inlen % 16;
	const int pad = 16 - (int)tail;

	/* build the final block: trailing input bytes followed by padding */
	if (in != NULL) {
		memcpy(last, in + (inlen - tail), tail);
	}
	memset(last + tail, pad, (size_t)pad);

	if (whole > 0) {
		aes_cbc_encrypt(key, iv, in, whole, out);
		out += 16 * whole;
		/* chain the final block off the last ciphertext block written */
		iv = out - 16;
	}
	aes_cbc_encrypt(key, iv, last, 1, out);

	*outlen = 16 * whole + 16;
	return 1;
}
66
/*
 * CBC decryption that strips and validates the trailing padding.
 *
 * All blocks except the last are decrypted straight into `out`; the last
 * block is decrypted into a local buffer so that only the unpadded bytes
 * are copied out. The final plaintext byte gives the padding length
 * (1..16) and every padding byte must carry that same value.
 *
 * `in` and `out` must not overlap: the chaining value for the last block is
 * read from `in` after the leading blocks have been written to `out`.
 *
 * Returns 1 on success with the plaintext length in `*outlen`,
 * 0 when `inlen` is 0, and -1 when `inlen` is not a multiple of 16 or the
 * padding is malformed. On -1 the leading blocks may already have been
 * written to `out` and `*outlen` is left untouched.
 */
int aes_cbc_padding_decrypt(const AES_KEY *key, const uint8_t iv[16],
	const uint8_t *in, size_t inlen,
	uint8_t *out, size_t *outlen)
{
	uint8_t block[16];
	uint8_t pad;
	uint8_t bad = 0;
	size_t len;
	size_t i;

	if (inlen == 0) {
		error_print();
		return 0;
	}
	/* a non-zero multiple of 16 is necessarily >= 16 */
	if (inlen % 16 != 0) {
		error_print();
		return -1;
	}

	if (inlen > 16) {
		aes_cbc_decrypt(key, iv, in, inlen/16 - 1, out);
		/* the last block chains off the second-to-last ciphertext block */
		iv = in + inlen - 32;
	}
	aes_cbc_decrypt(key, iv, in + inlen - 16, 1, block);

	pad = block[15];
	if (pad < 1 || pad > 16) {
		error_print();
		return -1;
	}

	/*
	 * Check every padding byte, not just the last one. The differences
	 * are accumulated instead of returning at the first mismatch.
	 */
	for (i = 16 - (size_t)pad; i < 16; i++) {
		bad |= (uint8_t)(block[i] ^ pad);
	}
	if (bad != 0) {
		error_print();
		return -1;
	}

	len = sizeof(block) - pad;
	memcpy(out + inlen - 16, block, len);
	*outlen = inlen - pad;
	return 1;
}
98
/*
 * Add one to the 16-byte big-endian counter `a`, in place.
 * The carry propagates from a[15] towards a[0]; an all-0xff counter
 * wraps around to all zeros.
 */
static void ctr_incr(uint8_t a[16])
{
	size_t pos = 16;

	while (pos > 0) {
		pos--;
		/* a byte that did not wrap to zero absorbs the carry */
		if (++a[pos] != 0) {
			return;
		}
	}
}
107
/*
 * CTR-mode encryption/decryption of `inlen` bytes.
 *
 * For every chunk of up to 16 bytes the current counter block is encrypted
 * to form a keystream block, which is XORed with the input into `out`.
 * `ctr` is incremented once per chunk (including a short final chunk) and
 * is left updated for the caller.
 */
void aes_ctr_encrypt(const AES_KEY *key, uint8_t ctr[16], const uint8_t *in, size_t inlen, uint8_t *out)
{
	uint8_t keystream[16];
	size_t done;
	size_t chunk;

	for (done = 0; done < inlen; done += chunk) {
		chunk = inlen - done;
		if (chunk > 16) {
			chunk = 16;
		}
		aes_encrypt(key, ctr, keystream);
		gmssl_memxor(out + done, in + done, keystream, chunk);
		ctr_incr(ctr);
	}
}
123
aes_gcm_encrypt(const AES_KEY * key,const uint8_t * iv,size_t ivlen,const uint8_t * aad,size_t aadlen,const uint8_t * in,size_t inlen,uint8_t * out,size_t taglen,uint8_t * tag)124 int aes_gcm_encrypt(const AES_KEY *key, const uint8_t *iv, size_t ivlen,
125 const uint8_t *aad, size_t aadlen, const uint8_t *in, size_t inlen,
126 uint8_t *out, size_t taglen, uint8_t *tag)
127 {
128 const uint8_t *pin = in;
129 uint8_t *pout = out;
130 size_t left = inlen;
131 uint8_t H[16] = {0};
132 uint8_t Y[16];
133 uint8_t T[16];
134
135 if (taglen > AES_GCM_MAX_TAG_SIZE) {
136 error_print();
137 return -1;
138 }
139
140 aes_encrypt(key, H, H);
141
142 if (ivlen == 12) {
143 memcpy(Y, iv, 12);
144 Y[12] = Y[13] = Y[14] = 0;
145 Y[15] = 1;
146 } else {
147 ghash(H, NULL, 0, iv, ivlen, Y);
148 }
149
150 aes_encrypt(key, Y, T);
151
152 while (left) {
153 uint8_t block[16];
154 size_t len = left < 16 ? left : 16;
155 ctr_incr(Y);
156 aes_encrypt(key, Y, block);
157 gmssl_memxor(pout, pin, block, len);
158 pin += len;
159 pout += len;
160 left -= len;
161 }
162
163 ghash(H, aad, aadlen, out, inlen, H);
164 gmssl_memxor(tag, T, H, taglen);
165 return 1;
166 }
167
/*
 * AES-GCM authenticated decryption.
 *
 * The tag is recomputed over `aad` and the ciphertext `in` and compared
 * with `tag` before any plaintext is produced; `out` is only written when
 * the tag matches.
 *
 * Returns 1 on success. Returns -1, without touching `out`, when `taglen`
 * is larger than the 16-byte tag buffer or when the tag does not match.
 */
int aes_gcm_decrypt(const AES_KEY *key, const uint8_t *iv, size_t ivlen,
	const uint8_t *aad, size_t aadlen, const uint8_t *in, size_t inlen,
	const uint8_t *tag, size_t taglen, uint8_t *out)
{
	const uint8_t *pin = in;
	uint8_t *pout = out;
	size_t left = inlen;
	uint8_t H[16] = {0};
	uint8_t Y[16];
	uint8_t T[16];
	uint8_t diff = 0;
	size_t n;
	int i;

	/* T holds the computed tag; a longer taglen would overrun it */
	if (taglen > sizeof(T)) {
		error_print();
		return -1;
	}

	/* hash subkey: encryption of the all-zero block */
	aes_encrypt(key, H, H);

	if (ivlen == 12) {
		memcpy(Y, iv, 12);
		Y[12] = Y[13] = Y[14] = 0;
		Y[15] = 1;
	} else {
		ghash(H, NULL, 0, iv, ivlen, Y);
	}

	/* recompute the tag: AES(key, Y) XOR ghash over aad and ciphertext */
	ghash(H, aad, aadlen, in, inlen, H);
	aes_encrypt(key, Y, T);
	gmssl_memxor(T, T, H, taglen);

	/*
	 * Compare without an early exit so the time taken does not reveal
	 * how many leading tag bytes matched.
	 */
	for (n = 0; n < taglen; n++) {
		diff |= (uint8_t)(T[n] ^ tag[n]);
	}
	if (diff != 0) {
		error_print();
		return -1;
	}

	while (left) {
		uint8_t block[16];
		size_t len = left < 16 ? left : 16;

		/*
		 * GCM's inc32: only the low 32 bits (Y[12..15], big-endian)
		 * count, wrapping modulo 2^32 without carrying into Y[0..11].
		 * ctr_incr() is not used here because it carries across all
		 * 16 bytes.
		 */
		for (i = 15; i >= 12; i--) {
			if (++Y[i] != 0) {
				break;
			}
		}
		aes_encrypt(key, Y, block);
		gmssl_memxor(pout, pin, block, len);
		pin += len;
		pout += len;
		left -= len;
	}
	return 1;
}
209