1 /*------------------------------------------------------------------------
2 / OCB Version 3 Reference Code (Optimized C) Last modified 12-JUN-2013
3 /-------------------------------------------------------------------------
4 / Copyright (c) 2013 Ted Krovetz.
5 /
6 / Permission to use, copy, modify, and/or distribute this software for any
7 / purpose with or without fee is hereby granted, provided that the above
8 / copyright notice and this permission notice appear in all copies.
9 /
10 / THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
11 / WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12 / MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
13 / ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14 / WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
15 / ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
16 / OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17 /
18 / Phillip Rogaway holds patents relevant to OCB. See the following for
19 / his patent grant: http://www.cs.ucdavis.edu/~rogaway/ocb/grant.htm
20 /
21 / Special thanks to Keegan McAllister for suggesting several good improvements
22 /
23 / Comments are welcome: Ted Krovetz <ted@krovetz.net> - Dedicated to Laurel K
24 /------------------------------------------------------------------------- */
25
26 /* ----------------------------------------------------------------------- */
27 /* Usage notes */
28 /* ----------------------------------------------------------------------- */
29
30 /* - When AE_PENDING is passed as the 'final' parameter of any function,
31 / the length parameters must be a multiple of (BPI*16).
32 / - When available, SSE or AltiVec registers are used to manipulate data.
33 / So, when on machines with these facilities, all pointers passed to
34 / any function should be 16-byte aligned.
35 / - Plaintext and ciphertext pointers may be equal (ie, plaintext gets
36 / encrypted in-place), but no other pair of pointers may be equal.
37 / - This code assumes all x86 processors have SSE2 and SSSE3 instructions
38 / when compiling under MSVC. If untrue, alter the #define.
39 / - This code is tested for C99 and recent versions of GCC and MSVC. */
40
41 /* ----------------------------------------------------------------------- */
42 /* User configuration options */
43 /* ----------------------------------------------------------------------- */
44
45 /* Set the AES key length to use and length of authentication tag to produce.
46 / Setting either to 0 requires the value be set at runtime via ae_init().
47 / Some optimizations occur for each when set to a fixed value. */
48 #define OCB_KEY_LEN 16 /* 0, 16, 24 or 32. 0 means set in ae_init */
49 #define OCB_TAG_LEN 16 /* 0 to 16. 0 means set in ae_init */
50
51 /* This implementation has built-in support for multiple AES APIs. Set any
52 / one of the following to non-zero to specify which to use. */
53 #define USE_OPENSSL_AES 1 /* http://openssl.org */
54 #define USE_REFERENCE_AES 0 /* Internet search: rijndael-alg-fst.c */
55 #define USE_AES_NI 0 /* Uses compiler's intrinsics */
56
57 /* During encryption and decryption, various "L values" are required.
58 / The L values can be precomputed during initialization (requiring extra
59 / space in ae_ctx), generated as needed (slightly slowing encryption and
60 / decryption), or some combination of the two. L_TABLE_SZ specifies how many
61 / L values to precompute. L_TABLE_SZ must be at least 3. L_TABLE_SZ*16 bytes
62 / are used for L values in ae_ctx. Plaintext and ciphertexts shorter than
63 / 2^L_TABLE_SZ blocks need no L values calculated dynamically. */
64 #define L_TABLE_SZ 16
65
66 /* Set L_TABLE_SZ_IS_ENOUGH non-zero iff you know that all plaintexts
67 / will be shorter than 2^(L_TABLE_SZ+4) bytes in length. This results
68 / in better performance. */
69 #define L_TABLE_SZ_IS_ENOUGH 1
70
71 /* ----------------------------------------------------------------------- */
72 /* Includes and compiler specific definitions */
73 /* ----------------------------------------------------------------------- */
74
75 #include <keymaster/key_blob_utils/ae.h>
76 #include <stdlib.h>
77 #include <string.h>
78
79 /* Define standard sized integers */
80 #if defined(_MSC_VER) && (_MSC_VER < 1600)
81 typedef unsigned __int8 uint8_t;
82 typedef unsigned __int32 uint32_t;
83 typedef unsigned __int64 uint64_t;
84 typedef __int64 int64_t;
85 #else
86 #include <stdint.h>
87 #endif
88
89 /* Compiler-specific intrinsics and fixes: bswap64, ntz */
90 #if _MSC_VER
91 #define inline __inline /* MSVC doesn't recognize "inline" in C */
92 #define restrict __restrict /* MSVC doesn't recognize "restrict" in C */
93 #define __SSE2__ (_M_IX86 || _M_AMD64 || _M_X64) /* Assume SSE2 */
94 #define __SSSE3__ (_M_IX86 || _M_AMD64 || _M_X64) /* Assume SSSE3 */
95 #include <intrin.h>
96 #pragma intrinsic(_byteswap_uint64, _BitScanForward, memcpy)
97 #define bswap64(x) _byteswap_uint64(x)
/* Count trailing zero bits of x (x must be non-zero). */
static inline unsigned ntz(unsigned x) {
    unsigned long index; /* _BitScanForward requires an unsigned long out-param */
    _BitScanForward(&index, x);
    return (unsigned)index;
}
102 #elif __GNUC__
103 #define inline __inline__ /* No "inline" in GCC ansi C mode */
104 #define restrict __restrict__ /* No "restrict" in GCC ansi C mode */
105 #define bswap64(x) __builtin_bswap64(x) /* Assuming GCC 4.3+ */
106 #define ntz(x) __builtin_ctz((unsigned)(x)) /* Assuming GCC 3.4+ */
107 #else /* Assume some C99 features: stdint.h, inline, restrict */
108 #define bswap32(x) \
109 ((((x)&0xff000000u) >> 24) | (((x)&0x00ff0000u) >> 8) | (((x)&0x0000ff00u) << 8) | \
110 (((x)&0x000000ffu) << 24))
111
/* Reverse the byte order of a 64-bit word. */
static inline uint64_t bswap64(uint64_t x) {
    return ((x & 0x00000000000000ffULL) << 56) |
           ((x & 0x000000000000ff00ULL) << 40) |
           ((x & 0x0000000000ff0000ULL) << 24) |
           ((x & 0x00000000ff000000ULL) << 8)  |
           ((x & 0x000000ff00000000ULL) >> 8)  |
           ((x & 0x0000ff0000000000ULL) >> 24) |
           ((x & 0x00ff000000000000ULL) >> 40) |
           ((x & 0xff00000000000000ULL) >> 56);
}
122
123 #if (L_TABLE_SZ <= 9) && (L_TABLE_SZ_IS_ENOUGH) /* < 2^13 byte texts */
/* Trailing-zero count via table lookup, indexed by x/4.  The table is
   sized for the multiples of BPI*4 that OCB's block counter produces
   when all texts fit in the precomputed L table. */
static inline unsigned ntz(unsigned x) {
    static const unsigned char trailing[] = {
        0, 2, 3, 2, 4, 2, 3, 2, 5, 2, 3, 2, 4, 2, 3, 2, 6, 2, 3, 2, 4, 2, 3, 2, 5, 2,
        3, 2, 4, 2, 3, 2, 7, 2, 3, 2, 4, 2, 3, 2, 5, 2, 3, 2, 4, 2, 3, 2, 6, 2, 3, 2,
        4, 2, 3, 2, 5, 2, 3, 2, 4, 2, 3, 2, 8, 2, 3, 2, 4, 2, 3, 2, 5, 2, 3, 2, 4, 2,
        3, 2, 6, 2, 3, 2, 4, 2, 3, 2, 5, 2, 3, 2, 4, 2, 3, 2, 7, 2, 3, 2, 4, 2, 3, 2,
        5, 2, 3, 2, 4, 2, 3, 2, 6, 2, 3, 2, 4, 2, 3, 2, 5, 2, 3, 2, 4, 2, 3, 2};
    return trailing[x >> 2]; /* x / 4 on an unsigned value */
}
133 #else /* From http://supertech.csail.mit.edu/papers/debruijn.pdf */
/* Trailing-zero count with the de Bruijn multiply trick: isolate the
   lowest set bit, multiply by the de Bruijn constant, and use the top
   five bits of the product to index a position table. */
static inline unsigned ntz(unsigned x) {
    static const unsigned char debruijn_pos[32] = {0,  1,  28, 2,  29, 14, 24, 3,
                                                   30, 22, 20, 15, 25, 17, 4,  8,
                                                   31, 27, 13, 23, 21, 19, 16, 7,
                                                   26, 12, 18, 6,  11, 5,  10, 9};
    uint32_t lsb = x & (0u - x); /* same as x & -x for unsigned x */
    return debruijn_pos[(uint32_t)(lsb * 0x077CB531u) >> 27];
}
140 #endif
141 #endif
142
143 /* ----------------------------------------------------------------------- */
144 /* Define blocks and operations -- Patch if incorrect on your compiler. */
145 /* ----------------------------------------------------------------------- */
146
147 #if __SSE2__ && !KEYMASTER_CLANG_TEST_BUILD
148 #include <xmmintrin.h> /* SSE instructions and _mm_malloc */
149 #include <emmintrin.h> /* SSE2 instructions */
150 typedef __m128i block;
151 #define xor_block(x, y) _mm_xor_si128(x, y)
152 #define zero_block() _mm_setzero_si128()
153 #define unequal_blocks(x, y) (_mm_movemask_epi8(_mm_cmpeq_epi8(x, y)) != 0xffff)
154 #if __SSSE3__ || USE_AES_NI
155 #include <tmmintrin.h> /* SSSE3 instructions */
156 #define swap_if_le(b) \
157 _mm_shuffle_epi8(b, _mm_set_epi8(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15))
158 #else
/* Byte-reverse a 16-byte block (SSE2-only path; without SSSE3 there is
   no single-instruction byte shuffle).  Steps: reverse the four 32-bit
   lanes, swap the 16-bit halves within each lane, then swap the two
   bytes within each 16-bit unit — net effect is a full 16-byte reversal. */
static inline block swap_if_le(block b) {
    block a = _mm_shuffle_epi32(b, _MM_SHUFFLE(0, 1, 2, 3));
    a = _mm_shufflehi_epi16(a, _MM_SHUFFLE(2, 3, 0, 1));
    a = _mm_shufflelo_epi16(a, _MM_SHUFFLE(2, 3, 0, 1));
    return _mm_xor_si128(_mm_srli_epi16(a, 8), _mm_slli_epi16(a, 8));
}
165 #endif
/* Extract the 128 bits that begin `bot` bits into the 192-bit Ktop
   string: per 64-bit lane, (hi << bot) | (lo >> (64 - bot)), then
   convert register-correct to memory-correct order.
   NOTE(review): when bot == 0 the right-shift count is 64; SSE
   _mm_srl_epi64 yields zero for counts >= 64, so no guard is needed
   here — confirm against the Intel intrinsics reference. */
static inline block gen_offset(uint64_t KtopStr[3], unsigned bot) {
    block hi = _mm_load_si128((__m128i*)(KtopStr + 0)); /* hi = B A */
    block lo = _mm_loadu_si128((__m128i*)(KtopStr + 1)); /* lo = C B */
    __m128i lshift = _mm_cvtsi32_si128(bot);
    __m128i rshift = _mm_cvtsi32_si128(64 - bot);
    lo = _mm_xor_si128(_mm_sll_epi64(hi, lshift), _mm_srl_epi64(lo, rshift));
#if __SSSE3__ || USE_AES_NI
    return _mm_shuffle_epi8(lo, _mm_set_epi8(8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7));
#else
    return swap_if_le(_mm_shuffle_epi32(lo, _MM_SHUFFLE(1, 0, 3, 2)));
#endif
}
/* Double a block in GF(2^128) (reduction polynomial x^128+x^7+x^2+x+1).
   SSE2 has no 128-bit shift, so each 32-bit lane is shifted left by one
   and the bit lost from each lane is re-injected into the next lane up
   (mask value 1); the bit lost from the top lane instead selects the
   reduction constant 135 (0x87) into the low lane. */
static inline block double_block(block bl) {
    const __m128i mask = _mm_set_epi32(135, 1, 1, 1);
    __m128i tmp = _mm_srai_epi32(bl, 31); /* per-lane all-ones iff lane MSB set */
    tmp = _mm_and_si128(tmp, mask);
    tmp = _mm_shuffle_epi32(tmp, _MM_SHUFFLE(2, 1, 0, 3)); /* rotate carries to receiving lanes */
    bl = _mm_slli_epi32(bl, 1);
    return _mm_xor_si128(bl, tmp);
}
186 #elif __ALTIVEC__
187 #include <altivec.h>
188 typedef vector unsigned block;
189 #define xor_block(x, y) vec_xor(x, y)
190 #define zero_block() vec_splat_u32(0)
191 #define unequal_blocks(x, y) vec_any_ne(x, y)
192 #define swap_if_le(b) (b)
193 #if __PPC64__
gen_offset(uint64_t KtopStr[3],unsigned bot)194 block gen_offset(uint64_t KtopStr[3], unsigned bot) {
195 union {
196 uint64_t u64[2];
197 block bl;
198 } rval;
199 rval.u64[0] = (KtopStr[0] << bot) | (KtopStr[1] >> (64 - bot));
200 rval.u64[1] = (KtopStr[1] << bot) | (KtopStr[2] >> (64 - bot));
201 return rval.bl;
202 }
203 #else
204 /* Special handling: Shifts are mod 32, and no 64-bit types */
/* 32-bit AltiVec variant: compute the 128 bits that start `bot` bits
   into the 192-bit Ktop string using only 32-bit vector operations.
   First reduce bot below 32 by sliding whole 32-bit words (vec_sld),
   then perform the remaining sub-word shift with vec_sl/vec_sr. */
block gen_offset(uint64_t KtopStr[3], unsigned bot) {
    const vector unsigned k32 = {32, 32, 32, 32};
    vector unsigned hi = *(vector unsigned*)(KtopStr + 0);
    vector unsigned lo = *(vector unsigned*)(KtopStr + 2);
    vector unsigned bot_vec;
    if (bot < 32) {
        lo = vec_sld(hi, lo, 4);
    } else {
        vector unsigned t = vec_sld(hi, lo, 4);
        lo = vec_sld(hi, lo, 8);
        hi = t;
        bot = bot - 32;
    }
    if (bot == 0) /* word-aligned: no sub-word shift needed */
        return hi;
    /* Write one lane of bot_vec (presumably element 0 on this target —
       verify), then broadcast it to all lanes with vec_splat. */
    *(unsigned*)&bot_vec = bot;
    vector unsigned lshift = vec_splat(bot_vec, 0);
    vector unsigned rshift = vec_sub(k32, lshift);
    hi = vec_sl(hi, lshift);
    lo = vec_sr(lo, rshift);
    return vec_xor(hi, lo);
}
227 #endif
/* GF(2^128) doubling with byte-granular AltiVec ops: a per-byte
   arithmetic shift by 7 propagates each byte's MSB into a carry mask,
   the AND keeps a carry-in of 1 per byte (135 = 0x87 for the byte that
   receives the reduction), vec_perm rotates the carry bytes into the
   positions that receive them, and the final XOR merges the carries
   with the left-shifted value. */
static inline block double_block(block b) {
    const vector unsigned char mask = {135, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
    const vector unsigned char perm = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0};
    const vector unsigned char shift7 = vec_splat_u8(7);
    const vector unsigned char shift1 = vec_splat_u8(1);
    vector unsigned char c = (vector unsigned char)b;
    vector unsigned char t = vec_sra(c, shift7);
    t = vec_and(t, mask);
    t = vec_perm(t, t, perm);
    c = vec_sl(c, shift1);
    return (block)vec_xor(c, t);
}
240 #elif __ARM_NEON__
241 #include <arm_neon.h>
242 typedef int8x16_t block __attribute__ ((aligned (16))); /* Yay! Endian-neutral reads! */
243 #define xor_block(x, y) veorq_s8(x, y)
244 #define zero_block() vdupq_n_s8(0)
/* Nonzero iff the two 128-bit blocks differ: XOR them and test whether
   either 64-bit half of the difference is nonzero. */
static inline int unequal_blocks(block a, block b) {
    int64x2_t diff = veorq_s64((int64x2_t)a, (int64x2_t)b);
    int64_t low_half = vgetq_lane_s64(diff, 0);
    int64_t high_half = vgetq_lane_s64(diff, 1);
    return (low_half | high_half) != 0;
}
249 #define swap_if_le(b) (b) /* Using endian-neutral int8x16_t */
250 /* KtopStr is reg correct by 64 bits, return mem correct */
block gen_offset(uint64_t KtopStr[3], unsigned bot) {
    /* Runtime endianness probe: first byte of {1} is 1 iff little-endian. */
    const union {
        unsigned x;
        unsigned char endian;
    } little = {1};
    const int64x2_t k64 = {-64, -64};
    /* Copy hi and lo into local variables to ensure proper alignment */
    uint64x2_t hi = vld1q_u64(KtopStr + 0); /* hi = A B */
    uint64x2_t lo = vld1q_u64(KtopStr + 1); /* lo = B C */
    int64x2_t ls = vdupq_n_s64(bot);
    int64x2_t rs = vqaddq_s64(k64, ls); /* rs = bot - 64: negative means shift right */
    /* NOTE(review): when bot == 0, rs == -64; this relies on vshlq_u64
       producing zero for a right shift of 64 — confirm against the ARM
       NEON shift semantics. */
    block rval = (block)veorq_u64(vshlq_u64(hi, ls), vshlq_u64(lo, rs));
    if (little.endian)
        rval = vrev64q_s8(rval); /* to memory-correct byte order */
    return rval;
}
/* GF(2^128) doubling: per-byte arithmetic shift right by 7 builds a
   carry mask from each byte's MSB; the AND keeps carry-in 1 per byte
   (135 = 0x87 reduction constant for the byte fed by the overall MSB);
   vextq rotates the mask bytes into the receiving positions; then the
   value is shifted left one bit and the carries are XORed in. */
static inline block double_block(block b) {
    const block mask = {135, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
    block tmp = vshrq_n_s8(b, 7);
    tmp = vandq_s8(tmp, mask);
    tmp = vextq_s8(tmp, tmp, 1); /* Rotate high byte to end */
    b = vshlq_n_s8(b, 1);
    return veorq_s8(tmp, b);
}
275 #else
typedef struct { uint64_t l, r; } block;
/* Bitwise XOR of two 128-bit blocks, one 64-bit half at a time. */
static inline block xor_block(block x, block y) {
    block result;
    result.l = x.l ^ y.l;
    result.r = x.r ^ y.r;
    return result;
}
zero_block(void)282 static inline block zero_block(void) {
283 const block t = {0, 0};
284 return t;
285 }
286 #define unequal_blocks(x, y) ((((x).l ^ (y).l) | ((x).r ^ (y).r)) != 0)
/* Convert between memory-correct and register-correct form: on a
   little-endian host byte-reverse each 64-bit half in place; on a
   big-endian host the block is returned unchanged. */
static inline block swap_if_le(block b) {
    const union {
        unsigned whole;
        unsigned char first_byte;
    } probe = {1}; /* first byte is 1 iff host is little-endian */
    block out = b;
    if (probe.first_byte) {
        out.l = bswap64(b.l);
        out.r = bswap64(b.r);
    }
    return out;
}
300
301 /* KtopStr is reg correct by 64 bits, return mem correct */
/* Return the 128 bits that begin `bot` bits into the 192-bit Ktop
   string (register-correct input), converted to memory-correct order.
   bot == 0 is special-cased because a 64-bit shift by 64 would be
   undefined behavior in C. */
block gen_offset(uint64_t KtopStr[3], unsigned bot) {
    block rval;
    if (bot != 0) {
        rval.l = (KtopStr[0] << bot) | (KtopStr[1] >> (64 - bot));
        rval.r = (KtopStr[1] << bot) | (KtopStr[2] >> (64 - bot));
    } else {
        rval.l = KtopStr[0];
        rval.r = KtopStr[1];
    }
    return swap_if_le(rval);
}
313
314 #if __GNUC__ && __arm__
/* GF(2^128) doubling via 32-bit ARM inline assembly: the adds/adcs
   chain shifts the 128-bit value left by one with carry propagation
   across the four 32-bit halves; if the former MSB carried out, the
   conditional eor folds in the reduction constant 135 (0x87). */
static inline block double_block(block b) {
    __asm__("adds %1,%1,%1\n\t"
            "adcs %H1,%H1,%H1\n\t"
            "adcs %0,%0,%0\n\t"
            "adcs %H0,%H0,%H0\n\t"
            "it cs\n\t"
            "eorcs %1,%1,#135"
            : "+r"(b.l), "+r"(b.r)
            :
            : "cc");
    return b;
}
327 #else
/* Double a block in GF(2^128) with polynomial x^128+x^7+x^2+x+1:
   shift the 128-bit value left one bit; if the former top bit was set,
   XOR the reduction constant 135 (0x87) into the low byte. */
static inline block double_block(block b) {
    /* All-ones if the MSB of b.l is set, else zero.  NOTE(review):
       relies on arithmetic right shift of a negative int64_t, which is
       implementation-defined (universal on mainstream compilers). */
    uint64_t t = (uint64_t)((int64_t)b.l >> 63);
    b.l = (b.l + b.l) ^ (b.r >> 63); /* shift left; carry top bit of low half in */
    b.r = (b.r + b.r) ^ (t & 135);   /* shift left; conditionally XOR reduction */
    return b;
}
334 #endif
335
336 #endif
337
338 /* ----------------------------------------------------------------------- */
339 /* AES - Code uses OpenSSL API. Other implementations get mapped to it. */
340 /* ----------------------------------------------------------------------- */
341
342 /*---------------*/
343 #if USE_OPENSSL_AES
344 /*---------------*/
345
346 #include <openssl/aes.h> /* http://openssl.org/ */
347
348 /* How to ECB encrypt an array of blocks, in place */
/* ECB-encrypt nblks 16-byte blocks in place via the OpenSSL AES API.
   Blocks are independent in ECB mode, so iteration order is irrelevant. */
static inline void AES_ecb_encrypt_blks(block* blks, unsigned nblks, AES_KEY* key) {
    unsigned i;
    for (i = 0; i < nblks; ++i) {
        unsigned char* p = (unsigned char*)(blks + i);
        AES_encrypt(p, p, key);
    }
}
355
/* ECB-decrypt nblks 16-byte blocks in place via the OpenSSL AES API. */
static inline void AES_ecb_decrypt_blks(block* blks, unsigned nblks, AES_KEY* key) {
    unsigned i;
    for (i = 0; i < nblks; ++i) {
        unsigned char* p = (unsigned char*)(blks + i);
        AES_decrypt(p, p, key);
    }
}
362
363 #define BPI 4 /* Number of blocks in buffer per ECB call */
364
365 /*-------------------*/
366 #elif USE_REFERENCE_AES
367 /*-------------------*/
368
369 #include "rijndael-alg-fst.h" /* Barreto's Public-Domain Code */
370 #if (OCB_KEY_LEN == 0)
371 typedef struct {
372 uint32_t rd_key[60];
373 int rounds;
374 } AES_KEY;
375 #define ROUNDS(ctx) ((ctx)->rounds)
376 #define AES_set_encrypt_key(x, y, z) \
377 do { \
378 rijndaelKeySetupEnc((z)->rd_key, x, y); \
379 (z)->rounds = y / 32 + 6; \
380 } while (0)
381 #define AES_set_decrypt_key(x, y, z) \
382 do { \
383 rijndaelKeySetupDec((z)->rd_key, x, y); \
384 (z)->rounds = y / 32 + 6; \
385 } while (0)
386 #else
387 typedef struct { uint32_t rd_key[OCB_KEY_LEN + 28]; } AES_KEY;
388 #define ROUNDS(ctx) (6 + OCB_KEY_LEN / 4)
389 #define AES_set_encrypt_key(x, y, z) rijndaelKeySetupEnc((z)->rd_key, x, y)
390 #define AES_set_decrypt_key(x, y, z) rijndaelKeySetupDec((z)->rd_key, x, y)
391 #endif
392 #define AES_encrypt(x, y, z) rijndaelEncrypt((z)->rd_key, ROUNDS(z), x, y)
393 #define AES_decrypt(x, y, z) rijndaelDecrypt((z)->rd_key, ROUNDS(z), x, y)
394
/* ECB-encrypt nblks 16-byte blocks in place with the reference cipher. */
static void AES_ecb_encrypt_blks(block* blks, unsigned nblks, AES_KEY* key) {
    unsigned done;
    for (done = 0; done < nblks; ++done) {
        unsigned char* cur = (unsigned char*)(blks + done);
        AES_encrypt(cur, cur, key);
    }
}
401
/* ECB-decrypt nblks 16-byte blocks in place with the reference cipher.
 * Marked static to match its encrypt counterpart above and the other AES
 * backends in this file: it is an internal helper and should not be
 * exported from this translation unit. */
static void AES_ecb_decrypt_blks(block* blks, unsigned nblks, AES_KEY* key) {
    while (nblks) {
        --nblks;
        AES_decrypt((unsigned char*)(blks + nblks), (unsigned char*)(blks + nblks), key);
    }
}
408
409 #define BPI 4 /* Number of blocks in buffer per ECB call */
410
411 /*----------*/
412 #elif USE_AES_NI
413 /*----------*/
414
415 #include <wmmintrin.h>
416
417 #if (OCB_KEY_LEN == 0)
418 typedef struct {
419 __m128i rd_key[15];
420 int rounds;
421 } AES_KEY;
422 #define ROUNDS(ctx) ((ctx)->rounds)
423 #else
424 typedef struct { __m128i rd_key[7 + OCB_KEY_LEN / 4]; } AES_KEY;
425 #define ROUNDS(ctx) (6 + OCB_KEY_LEN / 4)
426 #endif
427
428 #define EXPAND_ASSIST(v1, v2, v3, v4, shuff_const, aes_const) \
429 v2 = _mm_aeskeygenassist_si128(v4, aes_const); \
430 v3 = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(v3), _mm_castsi128_ps(v1), 16)); \
431 v1 = _mm_xor_si128(v1, v3); \
432 v3 = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(v3), _mm_castsi128_ps(v1), 140)); \
433 v1 = _mm_xor_si128(v1, v3); \
434 v2 = _mm_shuffle_epi32(v2, shuff_const); \
435 v1 = _mm_xor_si128(v1, v2)
436
437 #define EXPAND192_STEP(idx, aes_const) \
438 EXPAND_ASSIST(x0, x1, x2, x3, 85, aes_const); \
439 x3 = _mm_xor_si128(x3, _mm_slli_si128(x3, 4)); \
440 x3 = _mm_xor_si128(x3, _mm_shuffle_epi32(x0, 255)); \
441 kp[idx] = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(tmp), _mm_castsi128_ps(x0), 68)); \
442 kp[idx + 1] = \
443 _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(x0), _mm_castsi128_ps(x3), 78)); \
444 EXPAND_ASSIST(x0, x1, x2, x3, 85, (aes_const * 2)); \
445 x3 = _mm_xor_si128(x3, _mm_slli_si128(x3, 4)); \
446 x3 = _mm_xor_si128(x3, _mm_shuffle_epi32(x0, 255)); \
447 kp[idx + 2] = x0; \
448 tmp = x3
449
/* Expand a 16-byte user key into the 11 AES-128 round keys with the
   aeskeygenassist-based EXPAND_ASSIST schedule; kp[0] is the raw key and
   each subsequent call derives one more round key. */
static void AES_128_Key_Expansion(const unsigned char* userkey, void* key) {
    __m128i x0, x1, x2;
    __m128i* kp = (__m128i*)key;
    kp[0] = x0 = _mm_loadu_si128((__m128i*)userkey);
    x2 = _mm_setzero_si128();
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 1);
    kp[1] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 2);
    kp[2] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 4);
    kp[3] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 8);
    kp[4] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 16);
    kp[5] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 32);
    kp[6] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 64);
    kp[7] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 128);
    kp[8] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 27);
    kp[9] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 54);
    kp[10] = x0;
}
476
/* Expand a 24-byte user key into the 13 AES-192 round keys.  Each
   EXPAND192_STEP derives three round keys from the 192-bit rolling
   state (two full 128-bit words plus the half word carried in `tmp`). */
static void AES_192_Key_Expansion(const unsigned char* userkey, void* key) {
    __m128i x0, x1, x2, x3, tmp, *kp = (__m128i*)key;
    kp[0] = x0 = _mm_loadu_si128((__m128i*)userkey);
    tmp = x3 = _mm_loadu_si128((__m128i*)(userkey + 16));
    x2 = _mm_setzero_si128();
    EXPAND192_STEP(1, 1);
    EXPAND192_STEP(4, 4);
    EXPAND192_STEP(7, 16);
    EXPAND192_STEP(10, 64);
}
487
/* Expand a 32-byte user key into the 15 AES-256 round keys.  The two
   128-bit key halves alternate: shuffle constant 255 generates the
   even-indexed keys, 170 the odd-indexed keys. */
static void AES_256_Key_Expansion(const unsigned char* userkey, void* key) {
    __m128i x0, x1, x2, x3, *kp = (__m128i*)key;
    kp[0] = x0 = _mm_loadu_si128((__m128i*)userkey);
    kp[1] = x3 = _mm_loadu_si128((__m128i*)(userkey + 16));
    x2 = _mm_setzero_si128();
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 1);
    kp[2] = x0;
    EXPAND_ASSIST(x3, x1, x2, x0, 170, 1);
    kp[3] = x3;
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 2);
    kp[4] = x0;
    EXPAND_ASSIST(x3, x1, x2, x0, 170, 2);
    kp[5] = x3;
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 4);
    kp[6] = x0;
    EXPAND_ASSIST(x3, x1, x2, x0, 170, 4);
    kp[7] = x3;
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 8);
    kp[8] = x0;
    EXPAND_ASSIST(x3, x1, x2, x0, 170, 8);
    kp[9] = x3;
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 16);
    kp[10] = x0;
    EXPAND_ASSIST(x3, x1, x2, x0, 170, 16);
    kp[11] = x3;
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 32);
    kp[12] = x0;
    EXPAND_ASSIST(x3, x1, x2, x0, 170, 32);
    kp[13] = x3;
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 64);
    kp[14] = x0;
}
520
/* Expand userKey into an AES-NI encryption key schedule.  Supported key
 * sizes are 128, 192, and 256 bits.  Returns 0 on success, or -2 for an
 * unsupported size (mirroring OpenSSL's AES_set_encrypt_key); previously
 * an invalid size silently left *key uninitialized and returned 0. */
static int AES_set_encrypt_key(const unsigned char* userKey, const int bits, AES_KEY* key) {
    if (bits == 128) {
        AES_128_Key_Expansion(userKey, key);
    } else if (bits == 192) {
        AES_192_Key_Expansion(userKey, key);
    } else if (bits == 256) {
        AES_256_Key_Expansion(userKey, key);
    } else {
        return -2; /* unsupported key size */
    }
#if (OCB_KEY_LEN == 0)
    key->rounds = 6 + bits / 32;
#endif
    return 0;
}
534
/* Derive the decryption ("equivalent inverse cipher") schedule from an
   already-expanded encryption schedule: reverse the round-key order and
   run every key except the first and last through InvMixColumns (aesimc). */
static void AES_set_decrypt_key_fast(AES_KEY* dkey, const AES_KEY* ekey) {
    int j = 0;
    int i = ROUNDS(ekey);
#if (OCB_KEY_LEN == 0)
    dkey->rounds = i;
#endif
    dkey->rd_key[i--] = ekey->rd_key[j++]; /* last dec key = first enc key, unchanged */
    while (i)
        dkey->rd_key[i--] = _mm_aesimc_si128(ekey->rd_key[j++]);
    dkey->rd_key[i] = ekey->rd_key[j]; /* first dec key = last enc key, unchanged */
}
546
/* Expand userKey into a decryption schedule: build the encryption
   schedule in a temporary, then invert it in place into *key. */
static int AES_set_decrypt_key(const unsigned char* userKey, const int bits, AES_KEY* key) {
    AES_KEY enc_schedule;
    AES_set_encrypt_key(userKey, bits, &enc_schedule);
    AES_set_decrypt_key_fast(key, &enc_schedule);
    return 0;
}
553
/* Encrypt one 16-byte block with AES-NI.  NOTE(review): uses aligned
   load/store (_mm_load_si128/_mm_store_si128), so in/out must be
   16-byte aligned, as required by this file's usage notes. */
static inline void AES_encrypt(const unsigned char* in, unsigned char* out, const AES_KEY* key) {
    int j, rnds = ROUNDS(key);
    const __m128i* sched = ((__m128i*)(key->rd_key));
    __m128i tmp = _mm_load_si128((__m128i*)in);
    tmp = _mm_xor_si128(tmp, sched[0]);        /* initial AddRoundKey */
    for (j = 1; j < rnds; j++)
        tmp = _mm_aesenc_si128(tmp, sched[j]); /* middle rounds */
    tmp = _mm_aesenclast_si128(tmp, sched[j]); /* final round */
    _mm_store_si128((__m128i*)out, tmp);
}
564
/* Decrypt one 16-byte block with AES-NI.  Requires a schedule produced
   by AES_set_decrypt_key/_fast (InvMixColumns-transformed round keys).
   NOTE(review): aligned load/store — in/out must be 16-byte aligned. */
static inline void AES_decrypt(const unsigned char* in, unsigned char* out, const AES_KEY* key) {
    int j, rnds = ROUNDS(key);
    const __m128i* sched = ((__m128i*)(key->rd_key));
    __m128i tmp = _mm_load_si128((__m128i*)in);
    tmp = _mm_xor_si128(tmp, sched[0]);        /* initial AddRoundKey */
    for (j = 1; j < rnds; j++)
        tmp = _mm_aesdec_si128(tmp, sched[j]); /* middle rounds */
    tmp = _mm_aesdeclast_si128(tmp, sched[j]); /* final round */
    _mm_store_si128((__m128i*)out, tmp);
}
575
/* ECB-encrypt nblks blocks in place, interleaving each AES round across
   all blocks so consecutive aesenc instructions are independent and the
   pipeline stays full. */
static inline void AES_ecb_encrypt_blks(block* blks, unsigned nblks, AES_KEY* key) {
    unsigned i, j, rnds = ROUNDS(key);
    const __m128i* sched = ((__m128i*)(key->rd_key));
    for (i = 0; i < nblks; ++i)
        blks[i] = _mm_xor_si128(blks[i], sched[0]);
    for (j = 1; j < rnds; ++j)
        for (i = 0; i < nblks; ++i)
            blks[i] = _mm_aesenc_si128(blks[i], sched[j]);
    for (i = 0; i < nblks; ++i)
        blks[i] = _mm_aesenclast_si128(blks[i], sched[j]);
}
587
/* ECB-decrypt nblks blocks in place, interleaving each AES round across
   all blocks for pipelining (mirror of AES_ecb_encrypt_blks). */
static inline void AES_ecb_decrypt_blks(block* blks, unsigned nblks, AES_KEY* key) {
    unsigned i, j, rnds = ROUNDS(key);
    const __m128i* sched = ((__m128i*)(key->rd_key));
    for (i = 0; i < nblks; ++i)
        blks[i] = _mm_xor_si128(blks[i], sched[0]);
    for (j = 1; j < rnds; ++j)
        for (i = 0; i < nblks; ++i)
            blks[i] = _mm_aesdec_si128(blks[i], sched[j]);
    for (i = 0; i < nblks; ++i)
        blks[i] = _mm_aesdeclast_si128(blks[i], sched[j]);
}
599
600 #define BPI 8 /* Number of blocks in buffer per ECB call */
601 /* Set to 4 for Westmere, 8 for Sandy Bridge */
602
603 #endif
604
605 /* ----------------------------------------------------------------------- */
606 /* Define OCB context structure. */
607 /* ----------------------------------------------------------------------- */
608
609 /*------------------------------------------------------------------------
610 / Each item in the OCB context is stored either "memory correct" or
611 / "register correct". On big-endian machines, this is identical. On
612 / little-endian machines, one must choose whether the byte-string
613 / is in the correct order when it resides in memory or in registers.
614 / It must be register correct whenever it is to be manipulated
615 / arithmetically, but must be memory correct whenever it interacts
616 / with the plaintext or ciphertext.
617 /------------------------------------------------------------------------- */
618
struct _ae_ctx {
    block offset;                 /* Memory correct */
    block checksum;               /* Memory correct */
    block Lstar;                  /* Memory correct */
    block Ldollar;                /* Memory correct */
    block L[L_TABLE_SZ];          /* Memory correct */
    block ad_checksum;            /* Memory correct */
    block ad_offset;              /* Memory correct */
    block cached_Top;             /* Memory correct */
    uint64_t KtopStr[3];          /* Register correct, each item */
    uint32_t ad_blocks_processed; /* count of complete AD blocks consumed so far */
    uint32_t blocks_processed;    /* count of complete text blocks consumed so far */
    AES_KEY decrypt_key;          /* expanded AES decryption schedule */
    AES_KEY encrypt_key;          /* expanded AES encryption schedule */
#if (OCB_TAG_LEN == 0)
    unsigned tag_len;             /* runtime tag length in bytes, set by ae_init */
#endif
};
637
638 /* ----------------------------------------------------------------------- */
639 /* L table lookup (or on-the-fly generation) */
640 /* ----------------------------------------------------------------------- */
641
642 #if L_TABLE_SZ_IS_ENOUGH
643 #define getL(_ctx, _tz) ((_ctx)->L[_tz])
644 #else
getL(const ae_ctx * ctx,unsigned tz)645 static block getL(const ae_ctx* ctx, unsigned tz) {
646 if (tz < L_TABLE_SZ)
647 return ctx->L[tz];
648 else {
649 unsigned i;
650 /* Bring L[MAX] into registers, make it register correct */
651 block rval = swap_if_le(ctx->L[L_TABLE_SZ - 1]);
652 rval = double_block(rval);
653 for (i = L_TABLE_SZ; i < tz; i++)
654 rval = double_block(rval);
655 return swap_if_le(rval); /* To memory correct */
656 }
657 }
658 #endif
659
660 /* ----------------------------------------------------------------------- */
661 /* Public functions */
662 /* ----------------------------------------------------------------------- */
663
664 /* 32-bit SSE2 and Altivec systems need to be forced to allocate memory
665 on 16-byte alignments. (I believe all major 64-bit systems do already.) */
666
/* Allocate an ae_ctx, using a 16-byte-aligned allocator on the targets
   whose SIMD block type requires alignment (32-bit SSE2, 32-bit AltiVec,
   NEON); plain malloc elsewhere.  Returns NULL on allocation failure. */
ae_ctx* ae_allocate(void* misc) {
    void* p;
    (void)misc; /* misc unused in this implementation */
#if (__SSE2__ && !_M_X64 && !_M_AMD64 && !__amd64__)
    p = _mm_malloc(sizeof(ae_ctx), 16);
#elif(__ALTIVEC__ && !__PPC64__)
    if (posix_memalign(&p, 16, sizeof(ae_ctx)) != 0)
        p = NULL;
#elif __ARM_NEON__
    p = memalign(16, sizeof(ae_ctx));
#else
    p = malloc(sizeof(ae_ctx));
#endif
    return (ae_ctx*)p;
}
682
/* Release a context from ae_allocate.  Must mirror the allocator used
   there: _mm_malloc pairs with _mm_free; posix_memalign/memalign/malloc
   all pair with free. */
void ae_free(ae_ctx* ctx) {
#if (__SSE2__ && !_M_X64 && !_M_AMD64 && !__amd64__)
    _mm_free(ctx);
#else
    free(ctx);
#endif
}
690
691 /* ----------------------------------------------------------------------- */
692
int ae_clear(ae_ctx* ctx) /* Zero ae_ctx and undo initialization */
{
    /* NOTE(review): a plain memset of key material can be optimized away
       by some compilers in some contexts; consider a secure-zero
       primitive (explicit_bzero/memset_s) — confirm the threat model. */
    memset(ctx, 0, sizeof(ae_ctx));
    return AE_SUCCESS;
}
698
ae_ctx_sizeof(void)699 int ae_ctx_sizeof(void) {
700 return (int)sizeof(ae_ctx);
701 }
702
703 /* ----------------------------------------------------------------------- */
704
/* Initialize an OCB context: expand the AES key schedules, then derive
   the key-dependent values Lstar = E_K(0^128), Ldollar = 2*Lstar, and
   L[i] = 2*L[i-1] (GF(2^128) doubling, L[0] = 4*Lstar).
   Returns AE_SUCCESS, or AE_NOT_SUPPORTED when nonce_len != 12: this
   implementation only supports 96-bit nonces. */
int ae_init(ae_ctx* ctx, const void* key, int key_len, int nonce_len, int tag_len) {
    unsigned i;
    block tmp_blk;

    if (nonce_len != 12)
        return AE_NOT_SUPPORTED;

    /* Initialize encryption & decryption keys */
#if (OCB_KEY_LEN > 0)
    key_len = OCB_KEY_LEN; /* key length fixed at compile time; argument ignored */
#endif
    AES_set_encrypt_key((unsigned char*)key, key_len * 8, &ctx->encrypt_key);
#if USE_AES_NI
    AES_set_decrypt_key_fast(&ctx->decrypt_key, &ctx->encrypt_key);
#else
    AES_set_decrypt_key((unsigned char*)key, (int)(key_len * 8), &ctx->decrypt_key);
#endif

    /* Zero things that need zeroing */
    ctx->cached_Top = ctx->ad_checksum = zero_block();
    ctx->ad_blocks_processed = 0;

    /* Compute key-dependent values */
    /* cached_Top is still all-zero here, so this computes Lstar = E_K(0^128) */
    AES_encrypt((unsigned char*)&ctx->cached_Top, (unsigned char*)&ctx->Lstar, &ctx->encrypt_key);
    tmp_blk = swap_if_le(ctx->Lstar); /* to register-correct form for doubling */
    tmp_blk = double_block(tmp_blk);
    ctx->Ldollar = swap_if_le(tmp_blk); /* Ldollar = 2 * Lstar */
    tmp_blk = double_block(tmp_blk);
    ctx->L[0] = swap_if_le(tmp_blk);    /* L[0] = 4 * Lstar */
    for (i = 1; i < L_TABLE_SZ; i++) {
        tmp_blk = double_block(tmp_blk);
        ctx->L[i] = swap_if_le(tmp_blk); /* L[i] = 2 * L[i-1] */
    }

#if (OCB_TAG_LEN == 0)
    ctx->tag_len = tag_len;
#else
    (void)tag_len; /* Suppress var not used error */
#endif

    return AE_SUCCESS;
}
747
748 /* ----------------------------------------------------------------------- */
749
/* Build the initial per-message offset for a 96-bit nonce: format the
   16-byte "Top" block (tag-length bits, a set bit, the nonce, with the
   nonce's low 6 bits zeroed), encrypt it with the AES key (the result is
   cached across calls that share the same Top), stretch it to 192 bits
   in KtopStr, and finally slide by the nonce's low 6 bits. */
static block gen_offset_from_nonce(ae_ctx* ctx, const void* nonce) {
    const union {
        unsigned x;
        unsigned char endian;
    } little = {1}; /* first byte is 1 iff host is little-endian */
    union {
        uint32_t u32[4];
        uint8_t u8[16];
        block bl;
    } tmp;
    unsigned idx;

    /* Replace cached nonce Top if needed */
#if (OCB_TAG_LEN > 0)
    if (little.endian)
        tmp.u32[0] = 0x01000000 + ((OCB_TAG_LEN * 8 % 128) << 1);
    else
        tmp.u32[0] = 0x00000001 + ((OCB_TAG_LEN * 8 % 128) << 25);
#else
    if (little.endian)
        tmp.u32[0] = 0x01000000 + ((ctx->tag_len * 8 % 128) << 1);
    else
        tmp.u32[0] = 0x00000001 + ((ctx->tag_len * 8 % 128) << 25);
#endif
    /* NOTE(review): reads the nonce through a uint32_t*; assumes the
       caller's nonce pointer is suitably aligned — see usage notes. */
    tmp.u32[1] = ((uint32_t*)nonce)[0];
    tmp.u32[2] = ((uint32_t*)nonce)[1];
    tmp.u32[3] = ((uint32_t*)nonce)[2];
    idx = (unsigned)(tmp.u8[15] & 0x3f);           /* Get low 6 bits of nonce */
    tmp.u8[15] = tmp.u8[15] & 0xc0;                /* Zero low 6 bits of nonce */
    if (unequal_blocks(tmp.bl, ctx->cached_Top)) { /* Cached? */
        ctx->cached_Top = tmp.bl;                  /* Update cache, KtopStr */
        AES_encrypt(tmp.u8, (unsigned char*)&ctx->KtopStr, &ctx->encrypt_key);
        if (little.endian) { /* Make Register Correct */
            ctx->KtopStr[0] = bswap64(ctx->KtopStr[0]);
            ctx->KtopStr[1] = bswap64(ctx->KtopStr[1]);
        }
        /* Stretch: third word = Ktop[0..63] ^ (Ktop shifted left 8 bits) */
        ctx->KtopStr[2] = ctx->KtopStr[0] ^ (ctx->KtopStr[0] << 8) ^ (ctx->KtopStr[1] >> 56);
    }
    return gen_offset(ctx->KtopStr, idx);
}
790
/* Accumulate associated data into ctx->ad_offset / ctx->ad_checksum using
 * the OCB offset sequence, BPI blocks per iteration.  The trailing group
 * (fewer than BPI*16 bytes, possibly ending in a partial block) is only
 * consumed when 'final' is non-zero; otherwise ad_len must be a multiple
 * of BPI*16 (see the usage notes at the top of this file). */
static void process_ad(ae_ctx* ctx, const void* ad, int ad_len, int final) {
    union {
        uint32_t u32[4];
        uint8_t u8[16];
        block bl;
    } tmp;
    block ad_offset, ad_checksum;
    const block* adp = (block*)ad;
    unsigned i, k, tz, remaining;

    ad_offset = ctx->ad_offset;
    ad_checksum = ctx->ad_checksum;
    i = ad_len / (BPI * 16); /* number of full BPI-block groups */
    if (i) {
        unsigned ad_block_num = ctx->ad_blocks_processed;
        do {
            block ta[BPI], oa[BPI];
            ad_block_num += BPI;
            tz = ntz(ad_block_num); /* selects L_i for the group's last block */
            oa[0] = xor_block(ad_offset, ctx->L[0]);
            ta[0] = xor_block(oa[0], adp[0]);
            oa[1] = xor_block(oa[0], ctx->L[1]);
            ta[1] = xor_block(oa[1], adp[1]);
            oa[2] = xor_block(ad_offset, ctx->L[1]);
            ta[2] = xor_block(oa[2], adp[2]);
#if BPI == 4
            ad_offset = xor_block(oa[2], getL(ctx, tz));
            ta[3] = xor_block(ad_offset, adp[3]);
#elif BPI == 8
            oa[3] = xor_block(oa[2], ctx->L[2]);
            ta[3] = xor_block(oa[3], adp[3]);
            oa[4] = xor_block(oa[1], ctx->L[2]);
            ta[4] = xor_block(oa[4], adp[4]);
            oa[5] = xor_block(oa[0], ctx->L[2]);
            ta[5] = xor_block(oa[5], adp[5]);
            oa[6] = xor_block(ad_offset, ctx->L[2]);
            ta[6] = xor_block(oa[6], adp[6]);
            ad_offset = xor_block(oa[6], getL(ctx, tz));
            ta[7] = xor_block(ad_offset, adp[7]);
#endif
            AES_ecb_encrypt_blks(ta, BPI, &ctx->encrypt_key);
            /* AD checksum is the XOR of all the encrypted blocks */
            ad_checksum = xor_block(ad_checksum, ta[0]);
            ad_checksum = xor_block(ad_checksum, ta[1]);
            ad_checksum = xor_block(ad_checksum, ta[2]);
            ad_checksum = xor_block(ad_checksum, ta[3]);
#if (BPI == 8)
            ad_checksum = xor_block(ad_checksum, ta[4]);
            ad_checksum = xor_block(ad_checksum, ta[5]);
            ad_checksum = xor_block(ad_checksum, ta[6]);
            ad_checksum = xor_block(ad_checksum, ta[7]);
#endif
            adp += BPI;
        } while (--i);
        ctx->ad_blocks_processed = ad_block_num;
        ctx->ad_offset = ad_offset;
        ctx->ad_checksum = ad_checksum;
    }

    if (final) {
        block ta[BPI];

        /* Process remaining associated data, compute its tag contribution */
        remaining = ((unsigned)ad_len) % (BPI * 16);
        if (remaining) {
            k = 0; /* number of blocks queued in ta[] for ECB encryption */
#if (BPI == 8)
            if (remaining >= 64) {
                tmp.bl = xor_block(ad_offset, ctx->L[0]);
                ta[0] = xor_block(tmp.bl, adp[0]);
                tmp.bl = xor_block(tmp.bl, ctx->L[1]);
                ta[1] = xor_block(tmp.bl, adp[1]);
                ad_offset = xor_block(ad_offset, ctx->L[1]);
                ta[2] = xor_block(ad_offset, adp[2]);
                ad_offset = xor_block(ad_offset, ctx->L[2]);
                ta[3] = xor_block(ad_offset, adp[3]);
                remaining -= 64;
                k = 4;
            }
#endif
            if (remaining >= 32) {
                ad_offset = xor_block(ad_offset, ctx->L[0]);
                ta[k] = xor_block(ad_offset, adp[k]);
                ad_offset = xor_block(ad_offset, getL(ctx, ntz(k + 2)));
                ta[k + 1] = xor_block(ad_offset, adp[k + 1]);
                remaining -= 32;
                k += 2;
            }
            if (remaining >= 16) {
                ad_offset = xor_block(ad_offset, ctx->L[0]);
                ta[k] = xor_block(ad_offset, adp[k]);
                remaining = remaining - 16;
                ++k;
            }
            if (remaining) {
                /* Partial final block: 10* padding, offset uses L_* */
                ad_offset = xor_block(ad_offset, ctx->Lstar);
                tmp.bl = zero_block();
                memcpy(tmp.u8, adp + k, remaining);
                tmp.u8[remaining] = (unsigned char)0x80u;
                ta[k] = xor_block(ad_offset, tmp.bl);
                ++k;
            }
            AES_ecb_encrypt_blks(ta, k, &ctx->encrypt_key);
            /* Intentional fallthrough: XOR in ta[k-1] .. ta[0] */
            switch (k) {
#if (BPI == 8)
            case 8:
                ad_checksum = xor_block(ad_checksum, ta[7]);
            case 7:
                ad_checksum = xor_block(ad_checksum, ta[6]);
            case 6:
                ad_checksum = xor_block(ad_checksum, ta[5]);
            case 5:
                ad_checksum = xor_block(ad_checksum, ta[4]);
#endif
            case 4:
                ad_checksum = xor_block(ad_checksum, ta[3]);
            case 3:
                ad_checksum = xor_block(ad_checksum, ta[2]);
            case 2:
                ad_checksum = xor_block(ad_checksum, ta[1]);
            case 1:
                ad_checksum = xor_block(ad_checksum, ta[0]);
            }
            ctx->ad_checksum = ad_checksum;
        }
    }
}
917
918 /* ----------------------------------------------------------------------- */
919
/* OCB3 encryption.  Encrypts pt_len bytes from pt into ct, authenticating
 * them together with ad_len bytes of associated data.  A non-NULL nonce
 * begins a new message; pass NULL to continue an earlier one (incremental
 * use -- lengths must then be multiples of BPI*16 until 'final', per the
 * usage notes at the top of this file).  When 'final' is set, the tag is
 * written to 'tag' if non-NULL, otherwise appended to ct.  Returns the
 * number of bytes written to ct (including any appended tag). */
int ae_encrypt(ae_ctx* ctx, const void* nonce, const void* pt, int pt_len, const void* ad,
               int ad_len, void* ct, void* tag, int final) {
    union {
        uint32_t u32[4];
        uint8_t u8[16];
        block bl;
    } tmp;
    block offset, checksum;
    unsigned i, k;
    block* ctp = (block*)ct;
    const block* ptp = (block*)pt;

    /* Non-null nonce means start of new message, init per-message values */
    if (nonce) {
        ctx->offset = gen_offset_from_nonce(ctx, nonce);
        ctx->ad_offset = ctx->checksum = zero_block();
        ctx->ad_blocks_processed = ctx->blocks_processed = 0;
        if (ad_len >= 0)
            ctx->ad_checksum = zero_block();
    }

    /* Process associated data */
    if (ad_len > 0)
        process_ad(ctx, ad, ad_len, final);

    /* Encrypt plaintext data BPI blocks at a time */
    offset = ctx->offset;
    checksum = ctx->checksum;
    i = pt_len / (BPI * 16); /* number of full BPI-block groups */
    if (i) {
        block oa[BPI];
        unsigned block_num = ctx->blocks_processed;
        oa[BPI - 1] = offset; /* seed: oa[BPI-1] carries the running offset */
        do {
            block ta[BPI];
            block_num += BPI;
            oa[0] = xor_block(oa[BPI - 1], ctx->L[0]);
            ta[0] = xor_block(oa[0], ptp[0]);
            checksum = xor_block(checksum, ptp[0]);
            oa[1] = xor_block(oa[0], ctx->L[1]);
            ta[1] = xor_block(oa[1], ptp[1]);
            checksum = xor_block(checksum, ptp[1]);
            oa[2] = xor_block(oa[1], ctx->L[0]);
            ta[2] = xor_block(oa[2], ptp[2]);
            checksum = xor_block(checksum, ptp[2]);
#if BPI == 4
            oa[3] = xor_block(oa[2], getL(ctx, ntz(block_num)));
            ta[3] = xor_block(oa[3], ptp[3]);
            checksum = xor_block(checksum, ptp[3]);
#elif BPI == 8
            oa[3] = xor_block(oa[2], ctx->L[2]);
            ta[3] = xor_block(oa[3], ptp[3]);
            checksum = xor_block(checksum, ptp[3]);
            oa[4] = xor_block(oa[1], ctx->L[2]);
            ta[4] = xor_block(oa[4], ptp[4]);
            checksum = xor_block(checksum, ptp[4]);
            oa[5] = xor_block(oa[0], ctx->L[2]);
            ta[5] = xor_block(oa[5], ptp[5]);
            checksum = xor_block(checksum, ptp[5]);
            /* oa[7] still holds the previous group's final offset here */
            oa[6] = xor_block(oa[7], ctx->L[2]);
            ta[6] = xor_block(oa[6], ptp[6]);
            checksum = xor_block(checksum, ptp[6]);
            oa[7] = xor_block(oa[6], getL(ctx, ntz(block_num)));
            ta[7] = xor_block(oa[7], ptp[7]);
            checksum = xor_block(checksum, ptp[7]);
#endif
            AES_ecb_encrypt_blks(ta, BPI, &ctx->encrypt_key);
            ctp[0] = xor_block(ta[0], oa[0]);
            ctp[1] = xor_block(ta[1], oa[1]);
            ctp[2] = xor_block(ta[2], oa[2]);
            ctp[3] = xor_block(ta[3], oa[3]);
#if (BPI == 8)
            ctp[4] = xor_block(ta[4], oa[4]);
            ctp[5] = xor_block(ta[5], oa[5]);
            ctp[6] = xor_block(ta[6], oa[6]);
            ctp[7] = xor_block(ta[7], oa[7]);
#endif
            ptp += BPI;
            ctp += BPI;
        } while (--i);
        ctx->offset = offset = oa[BPI - 1];
        ctx->blocks_processed = block_num;
        ctx->checksum = checksum;
    }

    if (final) {
        block ta[BPI + 1], oa[BPI];

        /* Process remaining plaintext and compute its tag contribution */
        unsigned remaining = ((unsigned)pt_len) % (BPI * 16);
        k = 0; /* How many blocks in ta[] need ECBing */
        if (remaining) {
#if (BPI == 8)
            if (remaining >= 64) {
                oa[0] = xor_block(offset, ctx->L[0]);
                ta[0] = xor_block(oa[0], ptp[0]);
                checksum = xor_block(checksum, ptp[0]);
                oa[1] = xor_block(oa[0], ctx->L[1]);
                ta[1] = xor_block(oa[1], ptp[1]);
                checksum = xor_block(checksum, ptp[1]);
                oa[2] = xor_block(oa[1], ctx->L[0]);
                ta[2] = xor_block(oa[2], ptp[2]);
                checksum = xor_block(checksum, ptp[2]);
                offset = oa[3] = xor_block(oa[2], ctx->L[2]);
                ta[3] = xor_block(offset, ptp[3]);
                checksum = xor_block(checksum, ptp[3]);
                remaining -= 64;
                k = 4;
            }
#endif
            if (remaining >= 32) {
                oa[k] = xor_block(offset, ctx->L[0]);
                ta[k] = xor_block(oa[k], ptp[k]);
                checksum = xor_block(checksum, ptp[k]);
                offset = oa[k + 1] = xor_block(oa[k], ctx->L[1]);
                ta[k + 1] = xor_block(offset, ptp[k + 1]);
                checksum = xor_block(checksum, ptp[k + 1]);
                remaining -= 32;
                k += 2;
            }
            if (remaining >= 16) {
                offset = oa[k] = xor_block(offset, ctx->L[0]);
                ta[k] = xor_block(offset, ptp[k]);
                checksum = xor_block(checksum, ptp[k]);
                remaining -= 16;
                ++k;
            }
            if (remaining) {
                /* Partial final block: pad with 10*, encrypt offset as pad */
                tmp.bl = zero_block();
                memcpy(tmp.u8, ptp + k, remaining);
                tmp.u8[remaining] = (unsigned char)0x80u;
                checksum = xor_block(checksum, tmp.bl);
                ta[k] = offset = xor_block(offset, ctx->Lstar);
                ++k;
            }
        }
        offset = xor_block(offset, ctx->Ldollar); /* Part of tag gen */
        ta[k] = xor_block(offset, checksum);      /* Part of tag gen */
        AES_ecb_encrypt_blks(ta, k + 1, &ctx->encrypt_key);
        offset = xor_block(ta[k], ctx->ad_checksum); /* Part of tag gen */
        if (remaining) {
            /* XOR the pad (encrypted offset) into the padded final bytes */
            --k;
            tmp.bl = xor_block(tmp.bl, ta[k]);
            memcpy(ctp + k, tmp.u8, remaining);
        }
        /* Intentional fallthrough: emit ciphertext blocks k-1 .. 0 */
        switch (k) {
#if (BPI == 8)
        case 7:
            ctp[6] = xor_block(ta[6], oa[6]);
        case 6:
            ctp[5] = xor_block(ta[5], oa[5]);
        case 5:
            ctp[4] = xor_block(ta[4], oa[4]);
        case 4:
            ctp[3] = xor_block(ta[3], oa[3]);
#endif
        case 3:
            ctp[2] = xor_block(ta[2], oa[2]);
        case 2:
            ctp[1] = xor_block(ta[1], oa[1]);
        case 1:
            ctp[0] = xor_block(ta[0], oa[0]);
        }

        /* Tag is placed at the correct location
         */
        if (tag) {
#if (OCB_TAG_LEN == 16)
            *(block*)tag = offset;
#elif(OCB_TAG_LEN > 0)
            memcpy((char*)tag, &offset, OCB_TAG_LEN);
#else
            memcpy((char*)tag, &offset, ctx->tag_len);
#endif
        } else {
#if (OCB_TAG_LEN > 0)
            memcpy((char*)ct + pt_len, &offset, OCB_TAG_LEN);
            pt_len += OCB_TAG_LEN;
#else
            memcpy((char*)ct + pt_len, &offset, ctx->tag_len);
            pt_len += ctx->tag_len;
#endif
        }
    }
    return (int)pt_len;
}
1106
1107 /* ----------------------------------------------------------------------- */
1108
1109 /* Compare two regions of memory, taking a constant amount of time for a
1110 given buffer size -- under certain assumptions about the compiler
1111 and machine, of course.
1112
1113 Use this to avoid timing side-channel attacks.
1114
1115 Returns 0 for memory regions with equal contents; non-zero otherwise. */
/* Compare n bytes at av and bv in time that depends only on n, not on
   the position of the first difference: the OR of every per-byte XOR is
   accumulated so no early exit is possible.  Returns 0 iff the two
   regions are byte-for-byte equal, non-zero otherwise.  Use this instead
   of memcmp() when comparing authentication tags to avoid leaking match
   length through timing. */
static int constant_time_memcmp(const void* av, const void* bv, size_t n) {
    const uint8_t* p = (const uint8_t*)av;
    const uint8_t* q = (const uint8_t*)bv;
    uint8_t diff = 0;
    size_t idx;

    for (idx = 0; idx < n; idx++)
        diff |= (uint8_t)(p[idx] ^ q[idx]);

    return (int)diff;
}
1130
/* OCB3 decryption and tag verification.  Decrypts ct_len bytes of ct into
 * pt and checks authenticity against ad_len bytes of associated data.
 * Nonce/final/incremental semantics mirror ae_encrypt().  When tag is
 * NULL, the tag is taken from the last tag-length bytes of ct (and those
 * bytes are not counted as ciphertext).  Returns the plaintext length, or
 * AE_INVALID if authentication fails.  NOTE: plaintext is written to pt
 * before the tag is verified, so callers must discard pt on failure. */
int ae_decrypt(ae_ctx* ctx, const void* nonce, const void* ct, int ct_len, const void* ad,
               int ad_len, void* pt, const void* tag, int final) {
    union {
        uint32_t u32[4];
        uint8_t u8[16];
        block bl;
    } tmp;
    block offset, checksum;
    unsigned i, k;
    block* ctp = (block*)ct;
    block* ptp = (block*)pt;

    /* Reduce ct_len tag bundled in ct */
    if ((final) && (!tag))
#if (OCB_TAG_LEN > 0)
        ct_len -= OCB_TAG_LEN;
#else
        ct_len -= ctx->tag_len;
#endif

    /* Non-null nonce means start of new message, init per-message values */
    if (nonce) {
        ctx->offset = gen_offset_from_nonce(ctx, nonce);
        ctx->ad_offset = ctx->checksum = zero_block();
        ctx->ad_blocks_processed = ctx->blocks_processed = 0;
        if (ad_len >= 0)
            ctx->ad_checksum = zero_block();
    }

    /* Process associated data */
    if (ad_len > 0)
        process_ad(ctx, ad, ad_len, final);

    /* Encrypt plaintext data BPI blocks at a time */
    offset = ctx->offset;
    checksum = ctx->checksum;
    i = ct_len / (BPI * 16); /* number of full BPI-block groups */
    if (i) {
        block oa[BPI];
        unsigned block_num = ctx->blocks_processed;
        oa[BPI - 1] = offset; /* seed: oa[BPI-1] carries the running offset */
        do {
            block ta[BPI];
            block_num += BPI;
            oa[0] = xor_block(oa[BPI - 1], ctx->L[0]);
            ta[0] = xor_block(oa[0], ctp[0]);
            oa[1] = xor_block(oa[0], ctx->L[1]);
            ta[1] = xor_block(oa[1], ctp[1]);
            oa[2] = xor_block(oa[1], ctx->L[0]);
            ta[2] = xor_block(oa[2], ctp[2]);
#if BPI == 4
            oa[3] = xor_block(oa[2], getL(ctx, ntz(block_num)));
            ta[3] = xor_block(oa[3], ctp[3]);
#elif BPI == 8
            oa[3] = xor_block(oa[2], ctx->L[2]);
            ta[3] = xor_block(oa[3], ctp[3]);
            oa[4] = xor_block(oa[1], ctx->L[2]);
            ta[4] = xor_block(oa[4], ctp[4]);
            oa[5] = xor_block(oa[0], ctx->L[2]);
            ta[5] = xor_block(oa[5], ctp[5]);
            /* oa[7] still holds the previous group's final offset here */
            oa[6] = xor_block(oa[7], ctx->L[2]);
            ta[6] = xor_block(oa[6], ctp[6]);
            oa[7] = xor_block(oa[6], getL(ctx, ntz(block_num)));
            ta[7] = xor_block(oa[7], ctp[7]);
#endif
            AES_ecb_decrypt_blks(ta, BPI, &ctx->decrypt_key);
            /* Checksum is over the recovered plaintext blocks */
            ptp[0] = xor_block(ta[0], oa[0]);
            checksum = xor_block(checksum, ptp[0]);
            ptp[1] = xor_block(ta[1], oa[1]);
            checksum = xor_block(checksum, ptp[1]);
            ptp[2] = xor_block(ta[2], oa[2]);
            checksum = xor_block(checksum, ptp[2]);
            ptp[3] = xor_block(ta[3], oa[3]);
            checksum = xor_block(checksum, ptp[3]);
#if (BPI == 8)
            ptp[4] = xor_block(ta[4], oa[4]);
            checksum = xor_block(checksum, ptp[4]);
            ptp[5] = xor_block(ta[5], oa[5]);
            checksum = xor_block(checksum, ptp[5]);
            ptp[6] = xor_block(ta[6], oa[6]);
            checksum = xor_block(checksum, ptp[6]);
            ptp[7] = xor_block(ta[7], oa[7]);
            checksum = xor_block(checksum, ptp[7]);
#endif
            ptp += BPI;
            ctp += BPI;
        } while (--i);
        ctx->offset = offset = oa[BPI - 1];
        ctx->blocks_processed = block_num;
        ctx->checksum = checksum;
    }

    if (final) {
        block ta[BPI + 1], oa[BPI];

        /* Process remaining plaintext and compute its tag contribution */
        unsigned remaining = ((unsigned)ct_len) % (BPI * 16);
        k = 0; /* How many blocks in ta[] need ECBing */
        if (remaining) {
#if (BPI == 8)
            if (remaining >= 64) {
                oa[0] = xor_block(offset, ctx->L[0]);
                ta[0] = xor_block(oa[0], ctp[0]);
                oa[1] = xor_block(oa[0], ctx->L[1]);
                ta[1] = xor_block(oa[1], ctp[1]);
                oa[2] = xor_block(oa[1], ctx->L[0]);
                ta[2] = xor_block(oa[2], ctp[2]);
                offset = oa[3] = xor_block(oa[2], ctx->L[2]);
                ta[3] = xor_block(offset, ctp[3]);
                remaining -= 64;
                k = 4;
            }
#endif
            if (remaining >= 32) {
                oa[k] = xor_block(offset, ctx->L[0]);
                ta[k] = xor_block(oa[k], ctp[k]);
                offset = oa[k + 1] = xor_block(oa[k], ctx->L[1]);
                ta[k + 1] = xor_block(offset, ctp[k + 1]);
                remaining -= 32;
                k += 2;
            }
            if (remaining >= 16) {
                offset = oa[k] = xor_block(offset, ctx->L[0]);
                ta[k] = xor_block(offset, ctp[k]);
                remaining -= 16;
                ++k;
            }
            if (remaining) {
                /* Partial final block: pad stream = E_K(offset ^ L_*) */
                block pad;
                offset = xor_block(offset, ctx->Lstar);
                AES_encrypt((unsigned char*)&offset, tmp.u8, &ctx->encrypt_key);
                pad = tmp.bl;
                memcpy(tmp.u8, ctp + k, remaining);
                tmp.bl = xor_block(tmp.bl, pad);
                tmp.u8[remaining] = (unsigned char)0x80u;
                memcpy(ptp + k, tmp.u8, remaining);
                checksum = xor_block(checksum, tmp.bl);
            }
        }
        AES_ecb_decrypt_blks(ta, k, &ctx->decrypt_key);
        /* Intentional fallthrough: recover plaintext blocks k-1 .. 0 */
        switch (k) {
#if (BPI == 8)
        case 7:
            ptp[6] = xor_block(ta[6], oa[6]);
            checksum = xor_block(checksum, ptp[6]);
        case 6:
            ptp[5] = xor_block(ta[5], oa[5]);
            checksum = xor_block(checksum, ptp[5]);
        case 5:
            ptp[4] = xor_block(ta[4], oa[4]);
            checksum = xor_block(checksum, ptp[4]);
        case 4:
            ptp[3] = xor_block(ta[3], oa[3]);
            checksum = xor_block(checksum, ptp[3]);
#endif
        case 3:
            ptp[2] = xor_block(ta[2], oa[2]);
            checksum = xor_block(checksum, ptp[2]);
        case 2:
            ptp[1] = xor_block(ta[1], oa[1]);
            checksum = xor_block(checksum, ptp[1]);
        case 1:
            ptp[0] = xor_block(ta[0], oa[0]);
            checksum = xor_block(checksum, ptp[0]);
        }

        /* Calculate expected tag */
        offset = xor_block(offset, ctx->Ldollar);
        tmp.bl = xor_block(offset, checksum);
        AES_encrypt(tmp.u8, tmp.u8, &ctx->encrypt_key);
        tmp.bl = xor_block(tmp.bl, ctx->ad_checksum); /* Full tag */

        /* Compare with proposed tag, change ct_len if invalid */
        if ((OCB_TAG_LEN == 16) && tag) {
            if (unequal_blocks(tmp.bl, *(block*)tag))
                ct_len = AE_INVALID;
        } else {
#if (OCB_TAG_LEN > 0)
            int len = OCB_TAG_LEN;
#else
            int len = ctx->tag_len;
#endif
            /* Constant-time compare avoids leaking tag-match timing */
            if (tag) {
                if (constant_time_memcmp(tag, tmp.u8, len) != 0)
                    ct_len = AE_INVALID;
            } else {
                if (constant_time_memcmp((char*)ct + ct_len, tmp.u8, len) != 0)
                    ct_len = AE_INVALID;
            }
        }
    }
    return ct_len;
}
1324
1325 /* ----------------------------------------------------------------------- */
1326 /* Simple test program */
1327 /* ----------------------------------------------------------------------- */
1328
1329 #if 0
1330
1331 #include <stdio.h>
1332 #include <time.h>
1333
1334 #if __GNUC__
1335 #define ALIGN(n) __attribute__((aligned(n)))
1336 #elif _MSC_VER
1337 #define ALIGN(n) __declspec(align(n))
1338 #else /* Not GNU/Microsoft: delete alignment uses. */
1339 #define ALIGN(n)
1340 #endif
1341
/* Print an optional label s (if non-NULL) followed by len bytes of p as
   uppercase hex, then a newline. */
static void pbuf(void *p, unsigned len, const void *s)
{
    const unsigned char *bytes = (const unsigned char *)p;
    unsigned idx;

    if (s != NULL)
        fputs((const char *)s, stdout);
    for (idx = 0; idx < len; idx++)
        printf("%02X", (unsigned)bytes[idx]);
    putchar('\n');
}
1351
/* Print test vectors for one length: (P=len,A=len), (P=0,A=len) and
   (P=len,A=0).  The tag parameter is NULL so each tag is appended to the
   ciphertext and included in the hex dump. */
static void vectors(ae_ctx *ctx, int len)
{
    ALIGN(16) char pt[128];
    ALIGN(16) char ct[144];   /* 128 bytes plus room for an appended tag */
    ALIGN(16) char nonce[] = {0,1,2,3,4,5,6,7,8,9,10,11};
    int i;
    for (i=0; i < 128; i++) pt[i] = i;
    i = ae_encrypt(ctx,nonce,pt,len,pt,len,ct,NULL,AE_FINALIZE);
    printf("P=%d,A=%d: ",len,len); pbuf(ct, i, NULL);
    i = ae_encrypt(ctx,nonce,pt,0,pt,len,ct,NULL,AE_FINALIZE);
    printf("P=%d,A=%d: ",0,len); pbuf(ct, i, NULL);
    i = ae_encrypt(ctx,nonce,pt,len,pt,0,ct,NULL,AE_FINALIZE);
    printf("P=%d,A=%d: ",len,0); pbuf(ct, i, NULL);
}
1366
1367 void validate()
1368 {
1369 ALIGN(16) char pt[1024];
1370 ALIGN(16) char ct[1024];
1371 ALIGN(16) char tag[16];
1372 ALIGN(16) char nonce[12] = {0,};
1373 ALIGN(16) char key[32] = {0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31};
1374 ae_ctx ctx;
1375 char *val_buf, *next;
1376 int i, len;
1377
1378 val_buf = (char *)malloc(22400 + 16);
1379 next = val_buf = (char *)(((size_t)val_buf + 16) & ~((size_t)15));
1380
1381 if (0) {
1382 ae_init(&ctx, key, 16, 12, 16);
1383 /* pbuf(&ctx, sizeof(ctx), "CTX: "); */
1384 vectors(&ctx,0);
1385 vectors(&ctx,8);
1386 vectors(&ctx,16);
1387 vectors(&ctx,24);
1388 vectors(&ctx,32);
1389 vectors(&ctx,40);
1390 }
1391
1392 memset(key,0,32);
1393 memset(pt,0,128);
1394 ae_init(&ctx, key, OCB_KEY_LEN, 12, OCB_TAG_LEN);
1395
1396 /* RFC Vector test */
1397 for (i = 0; i < 128; i++) {
1398 int first = ((i/3)/(BPI*16))*(BPI*16);
1399 int second = first;
1400 int third = i - (first + second);
1401
1402 nonce[11] = i;
1403
1404 if (0) {
1405 ae_encrypt(&ctx,nonce,pt,i,pt,i,ct,NULL,AE_FINALIZE);
1406 memcpy(next,ct,(size_t)i+OCB_TAG_LEN);
1407 next = next+i+OCB_TAG_LEN;
1408
1409 ae_encrypt(&ctx,nonce,pt,i,pt,0,ct,NULL,AE_FINALIZE);
1410 memcpy(next,ct,(size_t)i+OCB_TAG_LEN);
1411 next = next+i+OCB_TAG_LEN;
1412
1413 ae_encrypt(&ctx,nonce,pt,0,pt,i,ct,NULL,AE_FINALIZE);
1414 memcpy(next,ct,OCB_TAG_LEN);
1415 next = next+OCB_TAG_LEN;
1416 } else {
1417 ae_encrypt(&ctx,nonce,pt,first,pt,first,ct,NULL,AE_PENDING);
1418 ae_encrypt(&ctx,NULL,pt+first,second,pt+first,second,ct+first,NULL,AE_PENDING);
1419 ae_encrypt(&ctx,NULL,pt+first+second,third,pt+first+second,third,ct+first+second,NULL,AE_FINALIZE);
1420 memcpy(next,ct,(size_t)i+OCB_TAG_LEN);
1421 next = next+i+OCB_TAG_LEN;
1422
1423 ae_encrypt(&ctx,nonce,pt,first,pt,0,ct,NULL,AE_PENDING);
1424 ae_encrypt(&ctx,NULL,pt+first,second,pt,0,ct+first,NULL,AE_PENDING);
1425 ae_encrypt(&ctx,NULL,pt+first+second,third,pt,0,ct+first+second,NULL,AE_FINALIZE);
1426 memcpy(next,ct,(size_t)i+OCB_TAG_LEN);
1427 next = next+i+OCB_TAG_LEN;
1428
1429 ae_encrypt(&ctx,nonce,pt,0,pt,first,ct,NULL,AE_PENDING);
1430 ae_encrypt(&ctx,NULL,pt,0,pt+first,second,ct,NULL,AE_PENDING);
1431 ae_encrypt(&ctx,NULL,pt,0,pt+first+second,third,ct,NULL,AE_FINALIZE);
1432 memcpy(next,ct,OCB_TAG_LEN);
1433 next = next+OCB_TAG_LEN;
1434 }
1435
1436 }
1437 nonce[11] = 0;
1438 ae_encrypt(&ctx,nonce,NULL,0,val_buf,next-val_buf,ct,tag,AE_FINALIZE);
1439 pbuf(tag,OCB_TAG_LEN,0);
1440
1441
1442 /* Encrypt/Decrypt test */
1443 for (i = 0; i < 128; i++) {
1444 int first = ((i/3)/(BPI*16))*(BPI*16);
1445 int second = first;
1446 int third = i - (first + second);
1447
1448 nonce[11] = i%128;
1449
1450 if (1) {
1451 len = ae_encrypt(&ctx,nonce,val_buf,i,val_buf,i,ct,tag,AE_FINALIZE);
1452 len = ae_encrypt(&ctx,nonce,val_buf,i,val_buf,-1,ct,tag,AE_FINALIZE);
1453 len = ae_decrypt(&ctx,nonce,ct,len,val_buf,-1,pt,tag,AE_FINALIZE);
1454 if (len == -1) { printf("Authentication error: %d\n", i); return; }
1455 if (len != i) { printf("Length error: %d\n", i); return; }
1456 if (memcmp(val_buf,pt,i)) { printf("Decrypt error: %d\n", i); return; }
1457 } else {
1458 len = ae_encrypt(&ctx,nonce,val_buf,i,val_buf,i,ct,NULL,AE_FINALIZE);
1459 ae_decrypt(&ctx,nonce,ct,first,val_buf,first,pt,NULL,AE_PENDING);
1460 ae_decrypt(&ctx,NULL,ct+first,second,val_buf+first,second,pt+first,NULL,AE_PENDING);
1461 len = ae_decrypt(&ctx,NULL,ct+first+second,len-(first+second),val_buf+first+second,third,pt+first+second,NULL,AE_FINALIZE);
1462 if (len == -1) { printf("Authentication error: %d\n", i); return; }
1463 if (memcmp(val_buf,pt,i)) { printf("Decrypt error: %d\n", i); return; }
1464 }
1465
1466 }
1467 printf("Decrypt: PASS\n");
1468 }
1469
/* Standalone test entry point: run the validation suite and exit. */
int main(void)
{
    validate();
    return 0;
}
1475 #endif
1476
/* Human-readable name of the AES backend this file was compiled with. */
#if USE_AES_NI
char infoString[] = "OCB3 (AES-NI)";
#elif USE_REFERENCE_AES
char infoString[] = "OCB3 (Reference)";
#elif USE_OPENSSL_AES
char infoString[] = "OCB3 (OpenSSL)";
#endif
1484