1 /*------------------------------------------------------------------------
2 / OCB Version 3 Reference Code (Optimized C) Last modified 12-JUN-2013
3 /-------------------------------------------------------------------------
4 / Copyright (c) 2013 Ted Krovetz.
5 /
6 / Permission to use, copy, modify, and/or distribute this software for any
7 / purpose with or without fee is hereby granted, provided that the above
8 / copyright notice and this permission notice appear in all copies.
9 /
10 / THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
11 / WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12 / MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
13 / ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14 / WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
15 / ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
16 / OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17 /
18 / Phillip Rogaway holds patents relevant to OCB. See the following for
19 / his patent grant: http://www.cs.ucdavis.edu/~rogaway/ocb/grant.htm
20 /
21 / Special thanks to Keegan McAllister for suggesting several good improvements
22 /
23 / Comments are welcome: Ted Krovetz <ted@krovetz.net> - Dedicated to Laurel K
24 /------------------------------------------------------------------------- */
25
26 /* ----------------------------------------------------------------------- */
27 /* Usage notes */
28 /* ----------------------------------------------------------------------- */
29
30 /* - When AE_PENDING is passed as the 'final' parameter of any function,
31 / the length parameters must be a multiple of (BPI*16).
32 / - When available, SSE or AltiVec registers are used to manipulate data.
33 / So, when on machines with these facilities, all pointers passed to
34 / any function should be 16-byte aligned.
35 / - Plaintext and ciphertext pointers may be equal (ie, plaintext gets
36 / encrypted in-place), but no other pair of pointers may be equal.
37 / - This code assumes all x86 processors have SSE2 and SSSE3 instructions
38 / when compiling under MSVC. If untrue, alter the #define.
39 / - This code is tested for C99 and recent versions of GCC and MSVC. */
40
41 /* ----------------------------------------------------------------------- */
42 /* User configuration options */
43 /* ----------------------------------------------------------------------- */
44
45 /* Set the AES key length to use and length of authentication tag to produce.
46 / Setting either to 0 requires the value be set at runtime via ae_init().
47 / Some optimizations occur for each when set to a fixed value. */
48 #define OCB_KEY_LEN 16 /* 0, 16, 24 or 32. 0 means set in ae_init */
49 #define OCB_TAG_LEN 16 /* 0 to 16. 0 means set in ae_init */
50
51 /* This implementation has built-in support for multiple AES APIs. Set any
52 / one of the following to non-zero to specify which to use. */
53 #define USE_OPENSSL_AES 1 /* http://openssl.org */
54 #define USE_REFERENCE_AES 0 /* Internet search: rijndael-alg-fst.c */
55 #define USE_AES_NI 0 /* Uses compiler's intrinsics */
56
57 /* During encryption and decryption, various "L values" are required.
58 / The L values can be precomputed during initialization (requiring extra
59 / space in ae_ctx), generated as needed (slightly slowing encryption and
60 / decryption), or some combination of the two. L_TABLE_SZ specifies how many
61 / L values to precompute. L_TABLE_SZ must be at least 3. L_TABLE_SZ*16 bytes
62 / are used for L values in ae_ctx. Plaintext and ciphertexts shorter than
63 / 2^L_TABLE_SZ blocks need no L values calculated dynamically. */
64 #define L_TABLE_SZ 16
65
66 /* Set L_TABLE_SZ_IS_ENOUGH non-zero iff you know that all plaintexts
67 / will be shorter than 2^(L_TABLE_SZ+4) bytes in length. This results
68 / in better performance. */
69 #define L_TABLE_SZ_IS_ENOUGH 1
70
71 /* ----------------------------------------------------------------------- */
72 /* Includes and compiler specific definitions */
73 /* ----------------------------------------------------------------------- */
74
75 #include "ae.h"
76 #include <stdlib.h>
77 #include <string.h>
78
79 /* Define standard sized integers */
80 #if defined(_MSC_VER) && (_MSC_VER < 1600)
81 typedef unsigned __int8 uint8_t;
82 typedef unsigned __int32 uint32_t;
83 typedef unsigned __int64 uint64_t;
84 typedef __int64 int64_t;
85 #else
86 #include <stdint.h>
87 #endif
88
89 /* Compiler-specific intrinsics and fixes: bswap64, ntz */
90 #if _MSC_VER
91 #define inline __inline /* MSVC doesn't recognize "inline" in C */
92 #define restrict __restrict /* MSVC doesn't recognize "restrict" in C */
93 #define __SSE2__ (_M_IX86 || _M_AMD64 || _M_X64) /* Assume SSE2 */
94 #define __SSSE3__ (_M_IX86 || _M_AMD64 || _M_X64) /* Assume SSSE3 */
95 #include <intrin.h>
96 #pragma intrinsic(_byteswap_uint64, _BitScanForward, memcpy)
97 #define bswap64(x) _byteswap_uint64(x)
/* Count the trailing zero bits of x with the MSVC bit-scan intrinsic.
   _BitScanForward writes its result through an 'unsigned long*' and takes an
   'unsigned long' mask, so a correctly typed temporary is used instead of
   passing the address of the 'unsigned' argument. The temporary starts at 0
   so that x == 0 (where the intrinsic does not define the index) gives a
   deterministic 0. */
static inline unsigned ntz(unsigned x) {
    unsigned long index = 0;
    _BitScanForward(&index, (unsigned long)x);
    return (unsigned)index;
}
102 #elif __GNUC__
103 #define inline __inline__ /* No "inline" in GCC ansi C mode */
104 #define restrict __restrict__ /* No "restrict" in GCC ansi C mode */
105 #define bswap64(x) __builtin_bswap64(x) /* Assuming GCC 4.3+ */
106 #define ntz(x) __builtin_ctz((unsigned)(x)) /* Assuming GCC 3.4+ */
107 #else /* Assume some C99 features: stdint.h, inline, restrict */
/* Reverse the four bytes of a 32-bit value. The argument is evaluated
   several times, so it must be free of side effects. */
#define bswap32(x)                                                                         \
    ((((x)&0xff000000u) >> 24) | (((x)&0x00ff0000u) >> 8) | (((x)&0x0000ff00u) << 8) |     \
     (((x)&0x000000ffu) << 24))

/* Reverse the eight bytes of a 64-bit value: each 32-bit half is byte
   swapped and the two halves trade places. */
static inline uint64_t bswap64(uint64_t x) {
    const uint32_t upper = (uint32_t)(x >> 32);
    const uint32_t lower = (uint32_t)x;
    return ((uint64_t)bswap32(lower) << 32) | (uint64_t)bswap32(upper);
}
122
123 #if (L_TABLE_SZ <= 9) && (L_TABLE_SZ_IS_ENOUGH) /* < 2^13 byte texts */
/* Trailing-zero count by table lookup. The table is indexed by x / 4 and
   holds 128 entries, so x must be below 512; entry 0 is 0. */
static inline unsigned ntz(unsigned x) {
    static const unsigned char tz_table[] = {
        0, 2, 3, 2, 4, 2, 3, 2, 5, 2, 3, 2, 4, 2, 3, 2,
        6, 2, 3, 2, 4, 2, 3, 2, 5, 2, 3, 2, 4, 2, 3, 2,
        7, 2, 3, 2, 4, 2, 3, 2, 5, 2, 3, 2, 4, 2, 3, 2,
        6, 2, 3, 2, 4, 2, 3, 2, 5, 2, 3, 2, 4, 2, 3, 2,
        8, 2, 3, 2, 4, 2, 3, 2, 5, 2, 3, 2, 4, 2, 3, 2,
        6, 2, 3, 2, 4, 2, 3, 2, 5, 2, 3, 2, 4, 2, 3, 2,
        7, 2, 3, 2, 4, 2, 3, 2, 5, 2, 3, 2, 4, 2, 3, 2,
        6, 2, 3, 2, 4, 2, 3, 2, 5, 2, 3, 2, 4, 2, 3, 2};
    const unsigned slot = x >> 2;
    return tz_table[slot];
}
133 #else /* From http://supertech.csail.mit.edu/papers/debruijn.pdf */
/* Trailing-zero count via a de Bruijn multiplication: the lowest set bit of
   x is isolated, multiplied by the de Bruijn constant, and the top five bits
   of the 32-bit product select the answer from a 32-entry table. */
static inline unsigned ntz(unsigned x) {
    static const unsigned char tz_table[32] = {0,  1,  28, 2,  29, 14, 24, 3,  30, 22, 20,
                                               15, 25, 17, 4,  8,  31, 27, 13, 23, 21, 19,
                                               16, 7,  26, 12, 18, 6,  11, 5,  10, 9};
    const unsigned lowest_bit = x & (0u - x);
    const uint32_t product = (uint32_t)(lowest_bit * 0x077CB531u);
    return tz_table[product >> 27];
}
140 #endif
141 #endif
142
143 /* ----------------------------------------------------------------------- */
144 /* Define blocks and operations -- Patch if incorrect on your compiler. */
145 /* ----------------------------------------------------------------------- */
146
147 #if __SSE2__ && !KEYMASTER_CLANG_TEST_BUILD
148 #include <xmmintrin.h> /* SSE instructions and _mm_malloc */
149 #include <emmintrin.h> /* SSE2 instructions */
150 typedef __m128i block;
151 #define xor_block(x, y) _mm_xor_si128(x, y)
152 #define zero_block() _mm_setzero_si128()
153 #define unequal_blocks(x, y) (_mm_movemask_epi8(_mm_cmpeq_epi8(x, y)) != 0xffff)
154 #if __SSSE3__ || USE_AES_NI
155 #include <tmmintrin.h> /* SSSE3 instructions */
156 #define swap_if_le(b) \
157 _mm_shuffle_epi8(b, _mm_set_epi8(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15))
158 #else
/* SSE2-only byte reversal of a 128-bit block (used when SSSE3's byte
   shuffle is unavailable). The reversal is built in three stages:
   reverse the four 32-bit lanes, swap the two 16-bit words inside every
   32-bit lane, then swap the two bytes inside every 16-bit word. */
static inline block swap_if_le(block b) {
    /* Reverse the order of the four 32-bit lanes. */
    block a = _mm_shuffle_epi32(b, _MM_SHUFFLE(0, 1, 2, 3));
    /* Swap adjacent 16-bit words in the upper, then the lower, 64 bits. */
    a = _mm_shufflehi_epi16(a, _MM_SHUFFLE(2, 3, 0, 1));
    a = _mm_shufflelo_epi16(a, _MM_SHUFFLE(2, 3, 0, 1));
    /* Swap the two bytes of every 16-bit word. */
    return _mm_xor_si128(_mm_srli_epi16(a, 8), _mm_slli_epi16(a, 8));
}
165 #endif
/* Extract a 128-bit window, starting 'bot' bits in, from the three 64-bit
   words of KtopStr, and return it with the bytes of each 64-bit half
   reversed (register-correct input, memory-correct output).
   KtopStr + 0 is read with an aligned load, so KtopStr must be 16-byte
   aligned; KtopStr + 1 is read with an unaligned load. */
static inline block gen_offset(uint64_t KtopStr[3], unsigned bot) {
    block hi = _mm_load_si128((__m128i*)(KtopStr + 0));  /* hi = B A */
    block lo = _mm_loadu_si128((__m128i*)(KtopStr + 1)); /* lo = C B */
    __m128i lshift = _mm_cvtsi32_si128(bot);
    __m128i rshift = _mm_cvtsi32_si128(64 - bot);
    /* Per 64-bit lane: (word[i] << bot) combined with (word[i+1] >> (64 - bot)). */
    lo = _mm_xor_si128(_mm_sll_epi64(hi, lshift), _mm_srl_epi64(lo, rshift));
#if __SSSE3__ || USE_AES_NI
    /* Byte-reverse each 64-bit lane in a single shuffle. */
    return _mm_shuffle_epi8(lo, _mm_set_epi8(8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7));
#else
    /* Swap the 64-bit halves, then reverse all 16 bytes: the net effect is
       a byte reversal inside each 64-bit lane. */
    return swap_if_le(_mm_shuffle_epi32(lo, _MM_SHUFFLE(1, 0, 3, 2)));
#endif
}
/* Shift a 128-bit block left by one bit, treating 32-bit lane 3 as the most
   significant, and XOR 135 into the lowest lane when the top bit was set.
   The shift is done per 32-bit lane; the bits lost between lanes are
   re-inserted from a separately computed carry vector. */
static inline block double_block(block bl) {
    const __m128i mask = _mm_set_epi32(135, 1, 1, 1);
    /* Each lane becomes all-ones if its top bit is set, else zero. */
    __m128i tmp = _mm_srai_epi32(bl, 31);
    /* Keep 1 per lane as the inter-lane carry, 135 for the top lane. */
    tmp = _mm_and_si128(tmp, mask);
    /* Rotate lanes up by one: carries move to the next lane, and the top
       lane's 135 wraps around to lane 0. */
    tmp = _mm_shuffle_epi32(tmp, _MM_SHUFFLE(2, 1, 0, 3));
    bl = _mm_slli_epi32(bl, 1);
    return _mm_xor_si128(bl, tmp);
}
186 #elif __ALTIVEC__
187 #include <altivec.h>
188 typedef vector unsigned block;
189 #define xor_block(x, y) vec_xor(x, y)
190 #define zero_block() vec_splat_u32(0)
191 #define unequal_blocks(x, y) vec_any_ne(x, y)
192 #define swap_if_le(b) (b)
193 #if __PPC64__
gen_offset(uint64_t KtopStr[3],unsigned bot)194 block gen_offset(uint64_t KtopStr[3], unsigned bot) {
195 union {
196 uint64_t u64[2];
197 block bl;
198 } rval;
199 rval.u64[0] = (KtopStr[0] << bot) | (KtopStr[1] >> (64 - bot));
200 rval.u64[1] = (KtopStr[1] << bot) | (KtopStr[2] >> (64 - bot));
201 return rval.bl;
202 }
203 #else
204 /* Special handling: Shifts are mod 32, and no 64-bit types */
/* 32-bit AltiVec version: extract the 128-bit window that starts 'bot' bits
   into KtopStr, working on 32-bit vector elements. A shift of 32 or more
   is first reduced to a whole-word realignment (vec_sld) plus a residual
   shift of bot - 32.
   NOTE(review): the second load reads a full 16-byte vector starting at
   KtopStr + 2, i.e. 8 bytes beyond the three-word array; only its first
   8 bytes are selected by the vec_sld calls below -- confirm the memory
   after KtopStr is always readable. */
block gen_offset(uint64_t KtopStr[3], unsigned bot) {
    const vector unsigned k32 = {32, 32, 32, 32};
    vector unsigned hi = *(vector unsigned*)(KtopStr + 0);
    vector unsigned lo = *(vector unsigned*)(KtopStr + 2);
    vector unsigned bot_vec;
    if (bot < 32) {
        /* lo = the input advanced by one 32-bit word. */
        lo = vec_sld(hi, lo, 4);
    } else {
        /* hi = input advanced by one word, lo = advanced by two words;
           the remaining shift is bot - 32. */
        vector unsigned t = vec_sld(hi, lo, 4);
        lo = vec_sld(hi, lo, 8);
        hi = t;
        bot = bot - 32;
    }
    /* No residual bit shift: hi already is the answer. */
    if (bot == 0)
        return hi;
    /* Broadcast the shift count to all four elements. */
    *(unsigned*)&bot_vec = bot;
    vector unsigned lshift = vec_splat(bot_vec, 0);
    vector unsigned rshift = vec_sub(k32, lshift);
    hi = vec_sl(hi, lshift);
    lo = vec_sr(lo, rshift);
    return vec_xor(hi, lo);
}
227 #endif
/* Shift a 128-bit block left by one bit, with byte 0 as the most significant
   byte, and XOR 135 into the last byte when the top bit was set. The shift
   is done per byte; carries between bytes come from a rotated mask. */
static inline block double_block(block b) {
    const vector unsigned char mask = {135, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
    /* Rotate the byte vector left by one position (byte 0 wraps to byte 15). */
    const vector unsigned char perm = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0};
    const vector unsigned char shift7 = vec_splat_u8(7);
    const vector unsigned char shift1 = vec_splat_u8(1);
    vector unsigned char c = (vector unsigned char)b;
    /* Spread each byte's top bit across the whole byte. */
    vector unsigned char t = vec_sra(c, shift7);
    /* Keep 1 as an inter-byte carry, 135 for byte 0's carry-out. */
    t = vec_and(t, mask);
    /* Move each carry to the byte it feeds into. */
    t = vec_perm(t, t, perm);
    c = vec_sl(c, shift1);
    return (block)vec_xor(c, t);
}
240 #elif __ARM_NEON__
241 #include <arm_neon.h>
242 typedef int8x16_t block; /* Yay! Endian-neutral reads! */
243 #define xor_block(x, y) veorq_s8(x, y)
244 #define zero_block() vdupq_n_s8(0)
/* Return non-zero when the two blocks differ in any bit: the blocks are
   XORed and the two 64-bit halves of the difference are OR-ed together. */
static inline int unequal_blocks(block a, block b) {
    const int64x2_t diff = veorq_s64((int64x2_t)a, (int64x2_t)b);
    const int64_t low_half = vgetq_lane_s64(diff, 0);
    const int64_t high_half = vgetq_lane_s64(diff, 1);
    return (low_half | high_half) != 0;
}
249 #define swap_if_le(b) (b) /* Using endian-neutral int8x16_t */
250 /* KtopStr is reg correct by 64 bits, return mem correct */
/* Extract the 128-bit window that starts 'bot' bits into the three 64-bit
   words of KtopStr. The right shift by (64 - bot) is expressed as a left
   shift by the negative count (bot - 64). On little-endian targets (detected
   at run time) the bytes of each 64-bit lane are reversed before returning. */
block gen_offset(uint64_t KtopStr[3], unsigned bot) {
    /* Run-time endianness probe: first byte of the integer 1. */
    const union {
        unsigned x;
        unsigned char endian;
    } little = {1};
    const int64x2_t k64 = {-64, -64};
    uint64x2_t hi = *(uint64x2_t*)(KtopStr + 0); /* hi = A B */
    uint64x2_t lo = *(uint64x2_t*)(KtopStr + 1); /* lo = B C */
    int64x2_t ls = vdupq_n_s64(bot);
    /* rs = bot - 64: a negative count, i.e. a right shift by 64 - bot. */
    int64x2_t rs = vqaddq_s64(k64, ls);
    block rval = (block)veorq_u64(vshlq_u64(hi, ls), vshlq_u64(lo, rs));
    if (little.endian)
        rval = vrev64q_s8(rval);
    return rval;
}
/* Shift a 128-bit block left by one bit, with byte 0 as the most significant
   byte, and XOR 135 into the last byte when the top bit was set. The shift
   is done per byte; carries between bytes come from a rotated mask. */
static inline block double_block(block b) {
    const block mask = {135, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
    /* Arithmetic shift: each byte becomes all-ones if its top bit is set. */
    block tmp = vshrq_n_s8(b, 7);
    /* Keep 1 as an inter-byte carry, 135 for byte 0's carry-out. */
    tmp = vandq_s8(tmp, mask);
    tmp = vextq_s8(tmp, tmp, 1); /* Rotate high byte to end */
    b = vshlq_n_s8(b, 1);
    return veorq_s8(tmp, b);
}
274 #else
/* Portable 128-bit block: two 64-bit halves, 'l' first and 'r' second. */
typedef struct { uint64_t l, r; } block;

/* Bitwise XOR of two blocks. */
static inline block xor_block(block x, block y) {
    block out;
    out.l = x.l ^ y.l;
    out.r = x.r ^ y.r;
    return out;
}

/* A block with every bit clear. */
static inline block zero_block(void) {
    block z;
    z.l = 0;
    z.r = 0;
    return z;
}

/* Non-zero when x and y differ in any bit. Arguments are evaluated twice. */
#define unequal_blocks(x, y) ((((x).l ^ (y).l) | ((x).r ^ (y).r)) != 0)
/* Convert a block between register-correct and memory-correct form: on a
   little-endian machine (detected at run time) both 64-bit halves are byte
   swapped; on a big-endian machine the block is returned unchanged. */
static inline block swap_if_le(block b) {
    const union {
        unsigned x;
        unsigned char endian;
    } probe = {1};
    if (!probe.endian)
        return b;
    b.l = bswap64(b.l);
    b.r = bswap64(b.r);
    return b;
}
299
300 /* KtopStr is reg correct by 64 bits, return mem correct */
gen_offset(uint64_t KtopStr[3],unsigned bot)301 block gen_offset(uint64_t KtopStr[3], unsigned bot) {
302 block rval;
303 if (bot != 0) {
304 rval.l = (KtopStr[0] << bot) | (KtopStr[1] >> (64 - bot));
305 rval.r = (KtopStr[1] << bot) | (KtopStr[2] >> (64 - bot));
306 } else {
307 rval.l = KtopStr[0];
308 rval.r = KtopStr[1];
309 }
310 return swap_if_le(rval);
311 }
312
313 #if __GNUC__ && __arm__
/* 32-bit ARM assembly doubling of a 128-bit block held as two 64-bit
   halves (b.l is the upper half, b.r the lower). Each add doubles one
   32-bit register and the carry flag ripples upward through the chain; if
   a carry falls out of the top register, 135 is XORed into the lowest one. */
static inline block double_block(block b) {
    __asm__("adds %1,%1,%1\n\t"     /* low word of b.r  += itself, sets carry */
            "adcs %H1,%H1,%H1\n\t"  /* high word of b.r += itself + carry     */
            "adcs %0,%0,%0\n\t"     /* low word of b.l  += itself + carry     */
            "adcs %H0,%H0,%H0\n\t"  /* high word of b.l += itself + carry     */
            "it cs\n\t"             /* next instruction only if carry set     */
            "eorcs %1,%1,#135"      /* fold the carry-out into the low word   */
            : "+r"(b.l), "+r"(b.r)
            :
            : "cc");
    return b;
}
326 #else
/* Shift a 128-bit block (b.l = upper 64 bits, b.r = lower 64 bits) left by
   one bit and XOR 135 into the low half when the top bit was set.
   The reduction mask is derived with unsigned arithmetic only; no signed
   right shift of a negative value (implementation-defined in C) is used. */
static inline block double_block(block b) {
    const uint64_t top_bit = b.l >> 63;               /* 0 or 1 */
    const uint64_t reduce = (uint64_t)0 - top_bit;    /* 0 or all-ones */
    b.l = (b.l << 1) ^ (b.r >> 63);                   /* carry from low half */
    b.r = (b.r << 1) ^ (reduce & 135);
    return b;
}
333 #endif
334
335 #endif
336
337 /* ----------------------------------------------------------------------- */
338 /* AES - Code uses OpenSSL API. Other implementations get mapped to it. */
339 /* ----------------------------------------------------------------------- */
340
341 /*---------------*/
342 #if USE_OPENSSL_AES
343 /*---------------*/
344
345 #include <openssl/aes.h> /* http://openssl.org/ */
346
347 /* How to ECB encrypt an array of blocks, in place */
/* Encrypt nblks blocks in place, one AES_encrypt call per block, working
   from the last block back to the first. */
static inline void AES_ecb_encrypt_blks(block* blks, unsigned nblks, AES_KEY* key) {
    unsigned left;
    for (left = nblks; left > 0; --left) {
        unsigned char* cur = (unsigned char*)(blks + (left - 1));
        AES_encrypt(cur, cur, key);
    }
}
354
/* Decrypt nblks blocks in place, one AES_decrypt call per block, working
   from the last block back to the first. */
static inline void AES_ecb_decrypt_blks(block* blks, unsigned nblks, AES_KEY* key) {
    unsigned left;
    for (left = nblks; left > 0; --left) {
        unsigned char* cur = (unsigned char*)(blks + (left - 1));
        AES_decrypt(cur, cur, key);
    }
}
361
362 #define BPI 4 /* Number of blocks in buffer per ECB call */
363
364 /*-------------------*/
365 #elif USE_REFERENCE_AES
366 /*-------------------*/
367
368 #include "rijndael-alg-fst.h" /* Barreto's Public-Domain Code */
369 #if (OCB_KEY_LEN == 0)
370 typedef struct {
371 uint32_t rd_key[60];
372 int rounds;
373 } AES_KEY;
374 #define ROUNDS(ctx) ((ctx)->rounds)
/* Map the OpenSSL-style key setup calls onto the reference Rijndael code
   and record the round count (key bits / 32 + 6) in the key structure.
   x = user key bytes, y = key length in bits, z = AES_KEY to fill.
   y is parenthesised wherever it is used in arithmetic so that an expression
   argument such as 'len * 8' cannot be regrouped by operator precedence.
   These expand to statements, so they yield no return value. */
#define AES_set_encrypt_key(x, y, z)                                                       \
    do {                                                                                   \
        rijndaelKeySetupEnc((z)->rd_key, (x), (y));                                        \
        (z)->rounds = (y) / 32 + 6;                                                        \
    } while (0)
#define AES_set_decrypt_key(x, y, z)                                                       \
    do {                                                                                   \
        rijndaelKeySetupDec((z)->rd_key, (x), (y));                                        \
        (z)->rounds = (y) / 32 + 6;                                                        \
    } while (0)
385 #else
386 typedef struct { uint32_t rd_key[OCB_KEY_LEN + 28]; } AES_KEY;
387 #define ROUNDS(ctx) (6 + OCB_KEY_LEN / 4)
388 #define AES_set_encrypt_key(x, y, z) rijndaelKeySetupEnc((z)->rd_key, x, y)
389 #define AES_set_decrypt_key(x, y, z) rijndaelKeySetupDec((z)->rd_key, x, y)
390 #endif
391 #define AES_encrypt(x, y, z) rijndaelEncrypt((z)->rd_key, ROUNDS(z), x, y)
392 #define AES_decrypt(x, y, z) rijndaelDecrypt((z)->rd_key, ROUNDS(z), x, y)
393
/* Encrypt nblks blocks in place, one AES_encrypt call per block, working
   from the last block back to the first. */
static void AES_ecb_encrypt_blks(block* blks, unsigned nblks, AES_KEY* key) {
    unsigned left;
    for (left = nblks; left > 0; --left) {
        unsigned char* cur = (unsigned char*)(blks + (left - 1));
        AES_encrypt(cur, cur, key);
    }
}
400
/* Decrypt nblks blocks in place, one AES_decrypt call per block, working
   from the last block back to the first.
   Declared static like its encrypt counterpart and like the same helper in
   the other AES back ends, so the symbol does not leak out of this file. */
static void AES_ecb_decrypt_blks(block* blks, unsigned nblks, AES_KEY* key) {
    while (nblks) {
        --nblks;
        AES_decrypt((unsigned char*)(blks + nblks), (unsigned char*)(blks + nblks), key);
    }
}
407
408 #define BPI 4 /* Number of blocks in buffer per ECB call */
409
410 /*----------*/
411 #elif USE_AES_NI
412 /*----------*/
413
414 #include <wmmintrin.h>
415
416 #if (OCB_KEY_LEN == 0)
417 typedef struct {
418 __m128i rd_key[15];
419 int rounds;
420 } AES_KEY;
421 #define ROUNDS(ctx) ((ctx)->rounds)
422 #else
423 typedef struct { __m128i rd_key[7 + OCB_KEY_LEN / 4]; } AES_KEY;
424 #define ROUNDS(ctx) (6 + OCB_KEY_LEN / 4)
425 #endif
426
/* One AES key-schedule step, expanded as a sequence of statements (not
   wrapped in do/while, so it must be used as a full statement).
     v1          - previous round-key words; updated in place to the new ones
     v2          - scratch; receives the _mm_aeskeygenassist_si128 result
     v3          - scratch for the two shuffle/XOR folding passes; callers
                   set it to zero before the first use
     v4          - input to _mm_aeskeygenassist_si128
     shuff_const - selects which 32-bit word of v2 is broadcast
     aes_const   - round constant passed to _mm_aeskeygenassist_si128 */
#define EXPAND_ASSIST(v1, v2, v3, v4, shuff_const, aes_const)                              \
    v2 = _mm_aeskeygenassist_si128(v4, aes_const);                                         \
    v3 = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(v3), _mm_castsi128_ps(v1), 16)); \
    v1 = _mm_xor_si128(v1, v3);                                                            \
    v3 = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(v3), _mm_castsi128_ps(v1), 140)); \
    v1 = _mm_xor_si128(v1, v3);                                                            \
    v2 = _mm_shuffle_epi32(v2, shuff_const);                                               \
    v1 = _mm_xor_si128(v1, v2)
435
/* One step of the 192-bit key schedule: runs EXPAND_ASSIST twice (with
   round constants aes_const and aes_const * 2) and stores three round keys,
   kp[idx], kp[idx + 1] and kp[idx + 2]. It relies on the local variables
   x0, x1, x2, x3, tmp and kp of the enclosing function; tmp carries x3
   over to the next step. Expands to plain statements. */
#define EXPAND192_STEP(idx, aes_const)                                                     \
    EXPAND_ASSIST(x0, x1, x2, x3, 85, aes_const);                                          \
    x3 = _mm_xor_si128(x3, _mm_slli_si128(x3, 4));                                         \
    x3 = _mm_xor_si128(x3, _mm_shuffle_epi32(x0, 255));                                    \
    kp[idx] = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(tmp), _mm_castsi128_ps(x0), 68)); \
    kp[idx + 1] =                                                                          \
        _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(x0), _mm_castsi128_ps(x3), 78));  \
    EXPAND_ASSIST(x0, x1, x2, x3, 85, (aes_const * 2));                                    \
    x3 = _mm_xor_si128(x3, _mm_slli_si128(x3, 4));                                         \
    x3 = _mm_xor_si128(x3, _mm_shuffle_epi32(x0, 255));                                    \
    kp[idx + 2] = x0;                                                                      \
    tmp = x3
448
/* Expand a 128-bit user key into 11 round keys, written to kp[0..10].
   userkey is read with an unaligned 16-byte load; 'key' is treated as an
   array of __m128i. Each EXPAND_ASSIST call derives the next round key from
   the previous one using the round constants 1, 2, 4, ..., 128, 27, 54. */
static void AES_128_Key_Expansion(const unsigned char* userkey, void* key) {
    __m128i x0, x1, x2;
    __m128i* kp = (__m128i*)key;
    /* Round key 0 is the user key itself. */
    kp[0] = x0 = _mm_loadu_si128((__m128i*)userkey);
    x2 = _mm_setzero_si128(); /* scratch for EXPAND_ASSIST starts at zero */
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 1);
    kp[1] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 2);
    kp[2] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 4);
    kp[3] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 8);
    kp[4] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 16);
    kp[5] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 32);
    kp[6] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 64);
    kp[7] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 128);
    kp[8] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 27);
    kp[9] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 54);
    kp[10] = x0;
}
475
/* Expand a 192-bit user key into 13 round keys, written to kp[0..12]:
   kp[0] directly, then four EXPAND192_STEP invocations of three keys each.
   NOTE(review): the second load below reads 16 bytes starting at
   userkey + 16, i.e. bytes 16..31, while a 192-bit key is only 24 bytes
   long -- confirm that callers always supply a buffer of at least 32
   readable bytes. */
static void AES_192_Key_Expansion(const unsigned char* userkey, void* key) {
    __m128i x0, x1, x2, x3, tmp, *kp = (__m128i*)key;
    kp[0] = x0 = _mm_loadu_si128((__m128i*)userkey);
    tmp = x3 = _mm_loadu_si128((__m128i*)(userkey + 16));
    x2 = _mm_setzero_si128(); /* scratch for EXPAND_ASSIST starts at zero */
    EXPAND192_STEP(1, 1);
    EXPAND192_STEP(4, 4);
    EXPAND192_STEP(7, 16);
    EXPAND192_STEP(10, 64);
}
486
/* Expand a 256-bit user key into 15 round keys, written to kp[0..14].
   The two halves of the user key become kp[0] and kp[1]; after that the
   even-numbered keys are derived into x0 (broadcast constant 255) and the
   odd-numbered keys into x3 (broadcast constant 170), alternating. */
static void AES_256_Key_Expansion(const unsigned char* userkey, void* key) {
    __m128i x0, x1, x2, x3, *kp = (__m128i*)key;
    kp[0] = x0 = _mm_loadu_si128((__m128i*)userkey);
    kp[1] = x3 = _mm_loadu_si128((__m128i*)(userkey + 16));
    x2 = _mm_setzero_si128(); /* scratch for EXPAND_ASSIST starts at zero */
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 1);
    kp[2] = x0;
    EXPAND_ASSIST(x3, x1, x2, x0, 170, 1);
    kp[3] = x3;
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 2);
    kp[4] = x0;
    EXPAND_ASSIST(x3, x1, x2, x0, 170, 2);
    kp[5] = x3;
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 4);
    kp[6] = x0;
    EXPAND_ASSIST(x3, x1, x2, x0, 170, 4);
    kp[7] = x3;
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 8);
    kp[8] = x0;
    EXPAND_ASSIST(x3, x1, x2, x0, 170, 8);
    kp[9] = x3;
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 16);
    kp[10] = x0;
    EXPAND_ASSIST(x3, x1, x2, x0, 170, 16);
    kp[11] = x3;
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 32);
    kp[12] = x0;
    EXPAND_ASSIST(x3, x1, x2, x0, 170, 32);
    kp[13] = x3;
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 64);
    kp[14] = x0;
}
519
/* Build the AES-NI encryption key schedule for a 128-, 192- or 256-bit key.
     userKey - raw key bytes
     bits    - key length in bits
     key     - schedule to fill; when OCB_KEY_LEN is 0 its 'rounds' field is
               also set to 6 + bits / 32
   Returns 0 on success. Any other key size returns -2 and leaves *key
   untouched, instead of reporting success for a schedule that was never
   written. */
static int AES_set_encrypt_key(const unsigned char* userKey, const int bits, AES_KEY* key) {
    switch (bits) {
    case 128:
        AES_128_Key_Expansion(userKey, key);
        break;
    case 192:
        AES_192_Key_Expansion(userKey, key);
        break;
    case 256:
        AES_256_Key_Expansion(userKey, key);
        break;
    default:
        return -2; /* unsupported key size */
    }
#if (OCB_KEY_LEN == 0)
    key->rounds = 6 + bits / 32;
#endif
    return 0;
}
533
/* Derive a decryption key schedule from an existing encryption schedule.
   The round keys are stored in reverse order; the first and last are copied
   unchanged and every key in between goes through _mm_aesimc_si128. */
static void AES_set_decrypt_key_fast(AES_KEY* dkey, const AES_KEY* ekey) {
    const int rounds = ROUNDS(ekey);
    int src;
#if (OCB_KEY_LEN == 0)
    dkey->rounds = rounds;
#endif
    dkey->rd_key[rounds] = ekey->rd_key[0];
    for (src = 1; src < rounds; ++src) {
        dkey->rd_key[rounds - src] = _mm_aesimc_si128(ekey->rd_key[src]);
    }
    dkey->rd_key[0] = ekey->rd_key[rounds];
}
545
/* Build the AES-NI decryption key schedule: expand the encryption schedule
   into a temporary and convert it with AES_set_decrypt_key_fast.
   The status of the encryption key setup is passed back to the caller; on a
   non-zero status the temporary is not converted and *key is left untouched.
   Returns 0 on success. */
static int AES_set_decrypt_key(const unsigned char* userKey, const int bits, AES_KEY* key) {
    AES_KEY temp_key;
    const int status = AES_set_encrypt_key(userKey, bits, &temp_key);
    if (status != 0)
        return status;
    AES_set_decrypt_key_fast(key, &temp_key);
    return 0;
}
552
AES_encrypt(const unsigned char * in,unsigned char * out,const AES_KEY * key)553 static inline void AES_encrypt(const unsigned char* in, unsigned char* out, const AES_KEY* key) {
554 int j, rnds = ROUNDS(key);
555 const __m128i* sched = ((__m128i*)(key->rd_key));
556 __m128i tmp = _mm_load_si128((__m128i*)in);
557 tmp = _mm_xor_si128(tmp, sched[0]);
558 for (j = 1; j < rnds; j++)
559 tmp = _mm_aesenc_si128(tmp, sched[j]);
560 tmp = _mm_aesenclast_si128(tmp, sched[j]);
561 _mm_store_si128((__m128i*)out, tmp);
562 }
563
AES_decrypt(const unsigned char * in,unsigned char * out,const AES_KEY * key)564 static inline void AES_decrypt(const unsigned char* in, unsigned char* out, const AES_KEY* key) {
565 int j, rnds = ROUNDS(key);
566 const __m128i* sched = ((__m128i*)(key->rd_key));
567 __m128i tmp = _mm_load_si128((__m128i*)in);
568 tmp = _mm_xor_si128(tmp, sched[0]);
569 for (j = 1; j < rnds; j++)
570 tmp = _mm_aesdec_si128(tmp, sched[j]);
571 tmp = _mm_aesdeclast_si128(tmp, sched[j]);
572 _mm_store_si128((__m128i*)out, tmp);
573 }
574
/* Encrypt nblks blocks in place with AES-NI. Every round is applied to all
   blocks before moving on to the next round, so the per-block AES
   instructions within a round are independent of each other. */
static inline void AES_ecb_encrypt_blks(block* blks, unsigned nblks, AES_KEY* key) {
    const __m128i* rk = ((__m128i*)(key->rd_key));
    const unsigned last = ROUNDS(key);
    unsigned b, r;
    for (b = 0; b < nblks; ++b) {
        blks[b] = _mm_xor_si128(blks[b], rk[0]);
    }
    for (r = 1; r < last; ++r) {
        for (b = 0; b < nblks; ++b) {
            blks[b] = _mm_aesenc_si128(blks[b], rk[r]);
        }
    }
    for (b = 0; b < nblks; ++b) {
        blks[b] = _mm_aesenclast_si128(blks[b], rk[last]);
    }
}
586
/* Decrypt nblks blocks in place with AES-NI. Every round is applied to all
   blocks before moving on to the next round, so the per-block AES
   instructions within a round are independent of each other. */
static inline void AES_ecb_decrypt_blks(block* blks, unsigned nblks, AES_KEY* key) {
    const __m128i* rk = ((__m128i*)(key->rd_key));
    const unsigned last = ROUNDS(key);
    unsigned b, r;
    for (b = 0; b < nblks; ++b) {
        blks[b] = _mm_xor_si128(blks[b], rk[0]);
    }
    for (r = 1; r < last; ++r) {
        for (b = 0; b < nblks; ++b) {
            blks[b] = _mm_aesdec_si128(blks[b], rk[r]);
        }
    }
    for (b = 0; b < nblks; ++b) {
        blks[b] = _mm_aesdeclast_si128(blks[b], rk[last]);
    }
}
598
599 #define BPI 8 /* Number of blocks in buffer per ECB call */
600 /* Set to 4 for Westmere, 8 for Sandy Bridge */
601
602 #endif
603
604 /* ----------------------------------------------------------------------- */
605 /* Define OCB context structure. */
606 /* ----------------------------------------------------------------------- */
607
608 /*------------------------------------------------------------------------
609 / Each item in the OCB context is stored either "memory correct" or
610 / "register correct". On big-endian machines, this is identical. On
611 / little-endian machines, one must choose whether the byte-string
612 / is in the correct order when it resides in memory or in registers.
613 / It must be register correct whenever it is to be manipulated
614 / arithmetically, but must be memory correct whenever it interacts
615 / with the plaintext or ciphertext.
616 /------------------------------------------------------------------------- */
617
/* OCB context. All 'block' members come first, immediately followed by
   KtopStr: the SSE2 gen_offset reads KtopStr with an aligned 16-byte load,
   so it must stay on a 16-byte boundary within the structure. */
struct _ae_ctx {
    block offset;        /* Memory correct */
    block checksum;      /* Memory correct */
    block Lstar;         /* Memory correct; AES encryption of the zero block (ae_init) */
    block Ldollar;       /* Memory correct; Lstar doubled once (ae_init) */
    block L[L_TABLE_SZ]; /* Memory correct; L[0] is Ldollar doubled, L[i] is L[i-1] doubled */
    block ad_checksum;   /* Memory correct; running checksum of associated data */
    block ad_offset;     /* Memory correct; running offset for associated data */
    block cached_Top;    /* Memory correct; last nonce block (low 6 bits cleared) that was
                            encrypted into KtopStr, so a repeat can skip the AES call */
    uint64_t KtopStr[3]; /* Register correct, each item; [0..1] hold the encrypted
                            cached_Top, [2] is derived from them in gen_offset_from_nonce */
    uint32_t ad_blocks_processed; /* Associated-data blocks consumed so far (process_ad) */
    uint32_t blocks_processed;
    AES_KEY decrypt_key;
    AES_KEY encrypt_key;
#if (OCB_TAG_LEN == 0)
    unsigned tag_len;    /* Tag length in bytes, stored by ae_init when not fixed at build time */
#endif
};
636
637 /* ----------------------------------------------------------------------- */
638 /* L table lookup (or on-the-fly generation) */
639 /* ----------------------------------------------------------------------- */
640
641 #if L_TABLE_SZ_IS_ENOUGH
642 #define getL(_ctx, _tz) ((_ctx)->L[_tz])
643 #else
getL(const ae_ctx * ctx,unsigned tz)644 static block getL(const ae_ctx* ctx, unsigned tz) {
645 if (tz < L_TABLE_SZ)
646 return ctx->L[tz];
647 else {
648 unsigned i;
649 /* Bring L[MAX] into registers, make it register correct */
650 block rval = swap_if_le(ctx->L[L_TABLE_SZ - 1]);
651 rval = double_block(rval);
652 for (i = L_TABLE_SZ; i < tz; i++)
653 rval = double_block(rval);
654 return swap_if_le(rval); /* To memory correct */
655 }
656 }
657 #endif
658
659 /* ----------------------------------------------------------------------- */
660 /* Public functions */
661 /* ----------------------------------------------------------------------- */
662
663 /* 32-bit SSE2 and Altivec systems need to be forced to allocate memory
664 on 16-byte alignments. (I believe all major 64-bit systems do already.) */
665
ae_allocate(void * misc)666 ae_ctx* ae_allocate(void* misc) {
667 void* p;
668 (void)misc; /* misc unused in this implementation */
669 #if (__SSE2__ && !_M_X64 && !_M_AMD64 && !__amd64__)
670 p = _mm_malloc(sizeof(ae_ctx), 16);
671 #elif(__ALTIVEC__ && !__PPC64__)
672 if (posix_memalign(&p, 16, sizeof(ae_ctx)) != 0)
673 p = NULL;
674 #else
675 p = malloc(sizeof(ae_ctx));
676 #endif
677 return (ae_ctx*)p;
678 }
679
/* Release a context obtained from ae_allocate, using the deallocator that
   matches the allocator chosen there (_mm_free on 32-bit SSE2 builds,
   free everywhere else). */
void ae_free(ae_ctx* ctx) {
#if !(__SSE2__ && !_M_X64 && !_M_AMD64 && !__amd64__)
    free(ctx);
#else
    _mm_free(ctx);
#endif
}
687
688 /* ----------------------------------------------------------------------- */
689
/* Zero ae_ctx and undo initialization.
   The context holds AES key schedules, so it is wiped through a volatile
   pointer: a plain memset of memory that is not read again may be removed
   by the optimizer as a dead store. Always returns AE_SUCCESS. */
int ae_clear(ae_ctx* ctx)
{
    volatile unsigned char* cursor = (volatile unsigned char*)ctx;
    size_t remaining = sizeof(ae_ctx);
    while (remaining--)
        *cursor++ = 0;
    return AE_SUCCESS;
}
695
ae_ctx_sizeof(void)696 int ae_ctx_sizeof(void) {
697 return (int)sizeof(ae_ctx);
698 }
699
700 /* ----------------------------------------------------------------------- */
701
/* Initialise an OCB context for a key.
     ctx       - context to set up
     key       - raw AES key bytes
     key_len   - key length in bytes; replaced by OCB_KEY_LEN when that is
                 non-zero, otherwise it must be 16, 24 or 32
     nonce_len - must be 12
     tag_len   - tag length in bytes; only used when OCB_TAG_LEN is 0, in
                 which case it must be in the range 0..16
   Returns AE_SUCCESS, or AE_NOT_SUPPORTED when a parameter is rejected (the
   context is not modified in that case).
   Sets up both AES key schedules, clears the nonce cache and the associated
   data state, and precomputes Lstar, Ldollar and the L table. */
int ae_init(ae_ctx* ctx, const void* key, int key_len, int nonce_len, int tag_len) {
    unsigned i;
    block tmp_blk;

    if (nonce_len != 12)
        return AE_NOT_SUPPORTED;

/* Settle and validate the run-time parameters before touching ctx */
#if (OCB_KEY_LEN > 0)
    key_len = OCB_KEY_LEN;
#else
    if (key_len != 16 && key_len != 24 && key_len != 32)
        return AE_NOT_SUPPORTED;
#endif
#if (OCB_TAG_LEN == 0)
    if (tag_len < 0 || tag_len > 16)
        return AE_NOT_SUPPORTED;
#endif

/* Initialize encryption & decryption keys */
    AES_set_encrypt_key((unsigned char*)key, key_len * 8, &ctx->encrypt_key);
#if USE_AES_NI
    AES_set_decrypt_key_fast(&ctx->decrypt_key, &ctx->encrypt_key);
#else
    AES_set_decrypt_key((unsigned char*)key, (int)(key_len * 8), &ctx->decrypt_key);
#endif

    /* Zero things that need zeroing */
    ctx->cached_Top = ctx->ad_checksum = zero_block();
    ctx->ad_blocks_processed = 0;

    /* Compute key-dependent values: Lstar is the encryption of the zero
       block (cached_Top is zero here); every later value is the previous
       one doubled, done in register-correct form and stored memory correct */
    AES_encrypt((unsigned char*)&ctx->cached_Top, (unsigned char*)&ctx->Lstar, &ctx->encrypt_key);
    tmp_blk = swap_if_le(ctx->Lstar);
    tmp_blk = double_block(tmp_blk);
    ctx->Ldollar = swap_if_le(tmp_blk);
    tmp_blk = double_block(tmp_blk);
    ctx->L[0] = swap_if_le(tmp_blk);
    for (i = 1; i < L_TABLE_SZ; i++) {
        tmp_blk = double_block(tmp_blk);
        ctx->L[i] = swap_if_le(tmp_blk);
    }

#if (OCB_TAG_LEN == 0)
    ctx->tag_len = tag_len;
#else
    (void)tag_len; /* Suppress var not used error */
#endif

    return AE_SUCCESS;
}
744
745 /* ----------------------------------------------------------------------- */
746
gen_offset_from_nonce(ae_ctx * ctx,const void * nonce)747 static block gen_offset_from_nonce(ae_ctx* ctx, const void* nonce) {
748 const union {
749 unsigned x;
750 unsigned char endian;
751 } little = {1};
752 union {
753 uint32_t u32[4];
754 uint8_t u8[16];
755 block bl;
756 } tmp;
757 unsigned idx;
758 uint32_t tagadd;
759
760 /* Replace cached nonce Top if needed */
761 #if (OCB_TAG_LEN > 0)
762 if (little.endian)
763 tmp.u32[0] = 0x01000000 + ((OCB_TAG_LEN * 8 % 128) << 1);
764 else
765 tmp.u32[0] = 0x00000001 + ((OCB_TAG_LEN * 8 % 128) << 25);
766 #else
767 if (little.endian)
768 tmp.u32[0] = 0x01000000 + ((ctx->tag_len * 8 % 128) << 1);
769 else
770 tmp.u32[0] = 0x00000001 + ((ctx->tag_len * 8 % 128) << 25);
771 #endif
772 tmp.u32[1] = ((uint32_t*)nonce)[0];
773 tmp.u32[2] = ((uint32_t*)nonce)[1];
774 tmp.u32[3] = ((uint32_t*)nonce)[2];
775 idx = (unsigned)(tmp.u8[15] & 0x3f); /* Get low 6 bits of nonce */
776 tmp.u8[15] = tmp.u8[15] & 0xc0; /* Zero low 6 bits of nonce */
777 if (unequal_blocks(tmp.bl, ctx->cached_Top)) { /* Cached? */
778 ctx->cached_Top = tmp.bl; /* Update cache, KtopStr */
779 AES_encrypt(tmp.u8, (unsigned char*)&ctx->KtopStr, &ctx->encrypt_key);
780 if (little.endian) { /* Make Register Correct */
781 ctx->KtopStr[0] = bswap64(ctx->KtopStr[0]);
782 ctx->KtopStr[1] = bswap64(ctx->KtopStr[1]);
783 }
784 ctx->KtopStr[2] = ctx->KtopStr[0] ^ (ctx->KtopStr[0] << 8) ^ (ctx->KtopStr[1] >> 56);
785 }
786 return gen_offset(ctx->KtopStr, idx);
787 }
788
/* Fold associated data into the running AD checksum kept in ctx.
 *
 *   ctx    - supplies L[], Lstar, the encryption key and the per-message
 *            AD state (ad_offset, ad_checksum, ad_blocks_processed)
 *   ad     - associated data, read as an array of 16-byte blocks
 *   ad_len - byte count of 'ad' (the callers in this file only pass > 0)
 *   final  - non-zero: also absorb the tail that does not fill a whole
 *            group of BPI blocks, including a trailing partial block
 *
 * Whole groups of BPI blocks are always absorbed and the updated state is
 * stored back into ctx.  The tail is absorbed only when 'final' is set, and
 * in that case only ctx->ad_checksum is updated. */
static void process_ad(ae_ctx* ctx, const void* ad, int ad_len, int final) {
  union {
    uint32_t u32[4];
    uint8_t u8[16];
    block bl;
  } tmp;
  block ad_offset, ad_checksum;
  const block* adp = (const block*)ad;
  unsigned i, j, k, tz, remaining;

  ad_offset = ctx->ad_offset;
  ad_checksum = ctx->ad_checksum;

  /* Bulk phase: one group of BPI blocks per iteration */
  i = ad_len / (BPI * 16);
  if (i) {
    unsigned ad_block_num = ctx->ad_blocks_processed;
    for (; i; --i) {
      block ta[BPI], oa[BPI];
      ad_block_num += BPI;
      tz = ntz(ad_block_num);
      /* Per-block offsets are derived from the group's starting offset;
         the last block of the group leaves its offset in ad_offset. */
      oa[0] = xor_block(ad_offset, ctx->L[0]);
      ta[0] = xor_block(oa[0], adp[0]);
      oa[1] = xor_block(oa[0], ctx->L[1]);
      ta[1] = xor_block(oa[1], adp[1]);
      oa[2] = xor_block(ad_offset, ctx->L[1]);
      ta[2] = xor_block(oa[2], adp[2]);
#if BPI == 4
      ad_offset = xor_block(oa[2], getL(ctx, tz));
      ta[3] = xor_block(ad_offset, adp[3]);
#elif BPI == 8
      oa[3] = xor_block(oa[2], ctx->L[2]);
      ta[3] = xor_block(oa[3], adp[3]);
      oa[4] = xor_block(oa[1], ctx->L[2]);
      ta[4] = xor_block(oa[4], adp[4]);
      oa[5] = xor_block(oa[0], ctx->L[2]);
      ta[5] = xor_block(oa[5], adp[5]);
      oa[6] = xor_block(ad_offset, ctx->L[2]);
      ta[6] = xor_block(oa[6], adp[6]);
      ad_offset = xor_block(oa[6], getL(ctx, tz));
      ta[7] = xor_block(ad_offset, adp[7]);
#endif
      AES_ecb_encrypt_blks(ta, BPI, &ctx->encrypt_key);
      /* Accumulate every enciphered block of the group */
      for (j = 0; j < BPI; ++j)
        ad_checksum = xor_block(ad_checksum, ta[j]);
      adp += BPI;
    }
    ctx->ad_blocks_processed = ad_block_num;
    ctx->ad_offset = ad_offset;
    ctx->ad_checksum = ad_checksum;
  }

  if (final) {
    block ta[BPI];

    /* Tail phase: fewer than BPI*16 bytes are left */
    remaining = ((unsigned)ad_len) % (BPI * 16);
    if (remaining) {
      k = 0; /* Number of ta[] entries awaiting encryption */
#if (BPI == 8)
      if (remaining >= 64) {
        tmp.bl = xor_block(ad_offset, ctx->L[0]);
        ta[0] = xor_block(tmp.bl, adp[0]);
        tmp.bl = xor_block(tmp.bl, ctx->L[1]);
        ta[1] = xor_block(tmp.bl, adp[1]);
        ad_offset = xor_block(ad_offset, ctx->L[1]);
        ta[2] = xor_block(ad_offset, adp[2]);
        ad_offset = xor_block(ad_offset, ctx->L[2]);
        ta[3] = xor_block(ad_offset, adp[3]);
        remaining -= 64;
        k = 4;
      }
#endif
      if (remaining >= 32) {
        ad_offset = xor_block(ad_offset, ctx->L[0]);
        ta[k] = xor_block(ad_offset, adp[k]);
        ad_offset = xor_block(ad_offset, getL(ctx, ntz(k + 2)));
        ta[k + 1] = xor_block(ad_offset, adp[k + 1]);
        remaining -= 32;
        k += 2;
      }
      if (remaining >= 16) {
        ad_offset = xor_block(ad_offset, ctx->L[0]);
        ta[k] = xor_block(ad_offset, adp[k]);
        remaining = remaining - 16;
        ++k;
      }
      if (remaining) {
        /* Partial block: zero-fill, append a 0x80 marker byte, use Lstar */
        ad_offset = xor_block(ad_offset, ctx->Lstar);
        tmp.bl = zero_block();
        memcpy(tmp.u8, adp + k, remaining);
        tmp.u8[remaining] = (unsigned char)0x80u;
        ta[k] = xor_block(ad_offset, tmp.bl);
        ++k;
      }
      AES_ecb_encrypt_blks(ta, k, &ctx->encrypt_key);
      /* XOR is order-independent, so a plain ascending loop suffices */
      for (j = 0; j < k; ++j)
        ad_checksum = xor_block(ad_checksum, ta[j]);
      ctx->ad_checksum = ad_checksum;
    }
  }
}
915
916 /* ----------------------------------------------------------------------- */
917
/* Encrypt (part of) a message and, on the final call, produce its tag.
 *
 *   ctx    - keyed context; also carries state between incremental calls
 *   nonce  - non-NULL starts a new message (offset re-derived, running
 *            state cleared); NULL continues the current message
 *   pt     - plaintext, pt_len bytes
 *   ad     - associated data, ad_len bytes; absorbed only when ad_len > 0.
 *            With a non-NULL nonce, a negative ad_len leaves
 *            ctx->ad_checksum as it was instead of clearing it.
 *   ct     - receives the ciphertext (and the tag too when 'tag' is NULL)
 *   tag    - receives the tag on the final call, or NULL to have the tag
 *            appended at ct + pt_len
 *   final  - zero: only whole groups of BPI blocks are handled (see the
 *            usage notes at the top of the file); non-zero: the remaining
 *            bytes are handled and the tag is emitted
 *
 * Returns pt_len, increased by the tag length when the tag was appended
 * to ct. */
int ae_encrypt(ae_ctx* ctx, const void* nonce, const void* pt, int pt_len, const void* ad,
               int ad_len, void* ct, void* tag, int final) {
  union {
    uint32_t u32[4];
    uint8_t u8[16];
    block bl;
  } tmp;
  block offset, checksum;
  unsigned i, j, k;
  block* ctp = (block*)ct;
  const block* ptp = (const block*)pt;

  /* A nonce marks the first call for a message: reset per-message state */
  if (nonce) {
    ctx->offset = gen_offset_from_nonce(ctx, nonce);
    ctx->ad_offset = ctx->checksum = zero_block();
    ctx->ad_blocks_processed = ctx->blocks_processed = 0;
    if (ad_len >= 0)
      ctx->ad_checksum = zero_block();
  }

  /* Absorb associated data, if any was supplied with this call */
  if (ad_len > 0)
    process_ad(ctx, ad, ad_len, final);

  /* Bulk phase: one group of BPI plaintext blocks per iteration */
  offset = ctx->offset;
  checksum = ctx->checksum;
  i = pt_len / (BPI * 16);
  if (i) {
    block oa[BPI];
    unsigned block_num = ctx->blocks_processed;
    oa[BPI - 1] = offset; /* Last slot carries the offset between groups */
    for (; i; --i) {
      block ta[BPI];
      block_num += BPI;
      oa[0] = xor_block(oa[BPI - 1], ctx->L[0]);
      ta[0] = xor_block(oa[0], ptp[0]);
      oa[1] = xor_block(oa[0], ctx->L[1]);
      ta[1] = xor_block(oa[1], ptp[1]);
      oa[2] = xor_block(oa[1], ctx->L[0]);
      ta[2] = xor_block(oa[2], ptp[2]);
#if BPI == 4
      oa[3] = xor_block(oa[2], getL(ctx, ntz(block_num)));
      ta[3] = xor_block(oa[3], ptp[3]);
#elif BPI == 8
      oa[3] = xor_block(oa[2], ctx->L[2]);
      ta[3] = xor_block(oa[3], ptp[3]);
      oa[4] = xor_block(oa[1], ctx->L[2]);
      ta[4] = xor_block(oa[4], ptp[4]);
      oa[5] = xor_block(oa[0], ctx->L[2]);
      ta[5] = xor_block(oa[5], ptp[5]);
      oa[6] = xor_block(oa[7], ctx->L[2]);
      ta[6] = xor_block(oa[6], ptp[6]);
      oa[7] = xor_block(oa[6], getL(ctx, ntz(block_num)));
      ta[7] = xor_block(oa[7], ptp[7]);
#endif
      /* Plaintext is read here, before any ciphertext is stored, so the
         result is the same when pt and ct are the same buffer. */
      for (j = 0; j < BPI; ++j)
        checksum = xor_block(checksum, ptp[j]);
      AES_ecb_encrypt_blks(ta, BPI, &ctx->encrypt_key);
      for (j = 0; j < BPI; ++j)
        ctp[j] = xor_block(ta[j], oa[j]);
      ptp += BPI;
      ctp += BPI;
    }
    ctx->offset = offset = oa[BPI - 1];
    ctx->blocks_processed = block_num;
    ctx->checksum = checksum;
  }

  if (final) {
    block ta[BPI + 1], oa[BPI];

    /* Tail phase: leftover plaintext plus the tag computation */
    unsigned remaining = ((unsigned)pt_len) % (BPI * 16);
    k = 0; /* Number of ta[] entries queued for encryption */
    if (remaining) {
#if (BPI == 8)
      if (remaining >= 64) {
        oa[0] = xor_block(offset, ctx->L[0]);
        ta[0] = xor_block(oa[0], ptp[0]);
        checksum = xor_block(checksum, ptp[0]);
        oa[1] = xor_block(oa[0], ctx->L[1]);
        ta[1] = xor_block(oa[1], ptp[1]);
        checksum = xor_block(checksum, ptp[1]);
        oa[2] = xor_block(oa[1], ctx->L[0]);
        ta[2] = xor_block(oa[2], ptp[2]);
        checksum = xor_block(checksum, ptp[2]);
        offset = oa[3] = xor_block(oa[2], ctx->L[2]);
        ta[3] = xor_block(offset, ptp[3]);
        checksum = xor_block(checksum, ptp[3]);
        remaining -= 64;
        k = 4;
      }
#endif
      if (remaining >= 32) {
        oa[k] = xor_block(offset, ctx->L[0]);
        ta[k] = xor_block(oa[k], ptp[k]);
        checksum = xor_block(checksum, ptp[k]);
        offset = oa[k + 1] = xor_block(oa[k], ctx->L[1]);
        ta[k + 1] = xor_block(offset, ptp[k + 1]);
        checksum = xor_block(checksum, ptp[k + 1]);
        remaining -= 32;
        k += 2;
      }
      if (remaining >= 16) {
        offset = oa[k] = xor_block(offset, ctx->L[0]);
        ta[k] = xor_block(offset, ptp[k]);
        checksum = xor_block(checksum, ptp[k]);
        remaining -= 16;
        ++k;
      }
      if (remaining) {
        /* Partial block: the zero-filled, 0x80-marked plaintext goes into
           the checksum; ta[k] holds the offset that is encrypted into the
           pad for these bytes. */
        tmp.bl = zero_block();
        memcpy(tmp.u8, ptp + k, remaining);
        tmp.u8[remaining] = (unsigned char)0x80u;
        checksum = xor_block(checksum, tmp.bl);
        ta[k] = offset = xor_block(offset, ctx->Lstar);
        ++k;
      }
    }
    offset = xor_block(offset, ctx->Ldollar);    /* Part of tag gen */
    ta[k] = xor_block(offset, checksum);         /* Part of tag gen */
    AES_ecb_encrypt_blks(ta, k + 1, &ctx->encrypt_key);
    offset = xor_block(ta[k], ctx->ad_checksum); /* Part of tag gen */
    if (remaining) {
      /* Step back onto the pad entry and emit the partial ciphertext */
      --k;
      tmp.bl = xor_block(tmp.bl, ta[k]);
      memcpy(ctp + k, tmp.u8, remaining);
    }
    /* Emit the k whole ciphertext blocks of the tail */
    for (j = 0; j < k; ++j)
      ctp[j] = xor_block(ta[j], oa[j]);

    /* Store the tag: in the caller's buffer, or right after the ciphertext */
    if (tag) {
#if (OCB_TAG_LEN == 16)
      *(block*)tag = offset;
#elif (OCB_TAG_LEN > 0)
      memcpy((char*)tag, &offset, OCB_TAG_LEN);
#else
      memcpy((char*)tag, &offset, ctx->tag_len);
#endif
    } else {
#if (OCB_TAG_LEN > 0)
      memcpy((char*)ct + pt_len, &offset, OCB_TAG_LEN);
      pt_len += OCB_TAG_LEN;
#else
      memcpy((char*)ct + pt_len, &offset, ctx->tag_len);
      pt_len += ctx->tag_len;
#endif
    }
  }
  return (int)pt_len;
}
1104
1105 /* ----------------------------------------------------------------------- */
1106
/* Report whether two n-byte memory regions differ, never stopping early.

   Every byte pair is visited and the XOR of each pair is OR-ed into one
   accumulator, so the work done depends only on n and not on where the
   first difference lies -- given cooperative behaviour from the compiler
   and the machine, of course.

   Intended for comparing secrets such as tags without leaking timing.

   Returns 0 when the regions have equal contents; non-zero otherwise. */
static int constant_time_memcmp(const void* av, const void* bv, size_t n) {
  const uint8_t* lhs = (const uint8_t*)av;
  const uint8_t* rhs = (const uint8_t*)bv;
  uint8_t diff = 0;
  size_t idx = 0;

  while (idx < n) {
    diff |= lhs[idx] ^ rhs[idx];
    ++idx;
  }

  return (int)diff;
}
1128
/* Decrypt (part of) a message and, on the final call, verify its tag.
 *
 *   ctx    - keyed context; also carries state between incremental calls
 *   nonce  - non-NULL starts a new message (offset re-derived, running
 *            state cleared); NULL continues the current message
 *   ct     - ciphertext; when 'tag' is NULL on the final call the tag is
 *            expected at the end of ct and ct_len includes it
 *   ad     - associated data, ad_len bytes; absorbed only when ad_len > 0.
 *            With a non-NULL nonce, a negative ad_len leaves
 *            ctx->ad_checksum as it was instead of clearing it.
 *   pt     - receives the plaintext
 *   tag    - tag to verify on the final call, or NULL when bundled in ct
 *   final  - zero: only whole groups of BPI blocks are handled (see the
 *            usage notes at the top of the file); non-zero: the remaining
 *            bytes are handled and the tag is checked
 *
 * Returns the number of plaintext bytes (ct_len without a bundled tag), or
 * AE_INVALID when the tag does not match or the length is negative.
 * Plaintext is written to pt before the tag is checked, so the caller must
 * discard pt when AE_INVALID is returned. */
int ae_decrypt(ae_ctx* ctx, const void* nonce, const void* ct, int ct_len, const void* ad,
               int ad_len, void* pt, const void* tag, int final) {
  union {
    uint32_t u32[4];
    uint8_t u8[16];
    block bl;
  } tmp;
  block offset, checksum;
  unsigned i, k;
  const block* ctp = (const block*)ct;
  block* ptp = (block*)pt;

  /* Reduce ct_len when the tag is bundled in ct */
  if ((final) && (!tag)) {
#if (OCB_TAG_LEN > 0)
    ct_len -= OCB_TAG_LEN;
#else
    ct_len -= ctx->tag_len;
#endif
  }

  /* A negative length (for instance an input shorter than its bundled tag)
     would turn into a bogus block count / remainder below and cause reads
     and writes outside the caller's buffers: reject it up front. */
  if (ct_len < 0)
    return AE_INVALID;

  /* Non-null nonce means start of new message, init per-message values */
  if (nonce) {
    ctx->offset = gen_offset_from_nonce(ctx, nonce);
    ctx->ad_offset = ctx->checksum = zero_block();
    ctx->ad_blocks_processed = ctx->blocks_processed = 0;
    if (ad_len >= 0)
      ctx->ad_checksum = zero_block();
  }

  /* Process associated data */
  if (ad_len > 0)
    process_ad(ctx, ad, ad_len, final);

  /* Decrypt ciphertext data BPI blocks at a time */
  offset = ctx->offset;
  checksum = ctx->checksum;
  i = ct_len / (BPI * 16);
  if (i) {
    block oa[BPI];
    unsigned block_num = ctx->blocks_processed;
    oa[BPI - 1] = offset; /* Last slot carries the offset between groups */
    do {
      block ta[BPI];
      block_num += BPI;
      oa[0] = xor_block(oa[BPI - 1], ctx->L[0]);
      ta[0] = xor_block(oa[0], ctp[0]);
      oa[1] = xor_block(oa[0], ctx->L[1]);
      ta[1] = xor_block(oa[1], ctp[1]);
      oa[2] = xor_block(oa[1], ctx->L[0]);
      ta[2] = xor_block(oa[2], ctp[2]);
#if BPI == 4
      oa[3] = xor_block(oa[2], getL(ctx, ntz(block_num)));
      ta[3] = xor_block(oa[3], ctp[3]);
#elif BPI == 8
      oa[3] = xor_block(oa[2], ctx->L[2]);
      ta[3] = xor_block(oa[3], ctp[3]);
      oa[4] = xor_block(oa[1], ctx->L[2]);
      ta[4] = xor_block(oa[4], ctp[4]);
      oa[5] = xor_block(oa[0], ctx->L[2]);
      ta[5] = xor_block(oa[5], ctp[5]);
      oa[6] = xor_block(oa[7], ctx->L[2]);
      ta[6] = xor_block(oa[6], ctp[6]);
      oa[7] = xor_block(oa[6], getL(ctx, ntz(block_num)));
      ta[7] = xor_block(oa[7], ctp[7]);
#endif
      AES_ecb_decrypt_blks(ta, BPI, &ctx->decrypt_key);
      ptp[0] = xor_block(ta[0], oa[0]);
      checksum = xor_block(checksum, ptp[0]);
      ptp[1] = xor_block(ta[1], oa[1]);
      checksum = xor_block(checksum, ptp[1]);
      ptp[2] = xor_block(ta[2], oa[2]);
      checksum = xor_block(checksum, ptp[2]);
      ptp[3] = xor_block(ta[3], oa[3]);
      checksum = xor_block(checksum, ptp[3]);
#if (BPI == 8)
      ptp[4] = xor_block(ta[4], oa[4]);
      checksum = xor_block(checksum, ptp[4]);
      ptp[5] = xor_block(ta[5], oa[5]);
      checksum = xor_block(checksum, ptp[5]);
      ptp[6] = xor_block(ta[6], oa[6]);
      checksum = xor_block(checksum, ptp[6]);
      ptp[7] = xor_block(ta[7], oa[7]);
      checksum = xor_block(checksum, ptp[7]);
#endif
      ptp += BPI;
      ctp += BPI;
    } while (--i);
    ctx->offset = offset = oa[BPI - 1];
    ctx->blocks_processed = block_num;
    ctx->checksum = checksum;
  }

  if (final) {
    block ta[BPI + 1], oa[BPI];

    /* Process remaining ciphertext and compute its tag contribution */
    unsigned remaining = ((unsigned)ct_len) % (BPI * 16);
    k = 0; /* How many blocks in ta[] need ECBing */
    if (remaining) {
#if (BPI == 8)
      if (remaining >= 64) {
        oa[0] = xor_block(offset, ctx->L[0]);
        ta[0] = xor_block(oa[0], ctp[0]);
        oa[1] = xor_block(oa[0], ctx->L[1]);
        ta[1] = xor_block(oa[1], ctp[1]);
        oa[2] = xor_block(oa[1], ctx->L[0]);
        ta[2] = xor_block(oa[2], ctp[2]);
        offset = oa[3] = xor_block(oa[2], ctx->L[2]);
        ta[3] = xor_block(offset, ctp[3]);
        remaining -= 64;
        k = 4;
      }
#endif
      if (remaining >= 32) {
        oa[k] = xor_block(offset, ctx->L[0]);
        ta[k] = xor_block(oa[k], ctp[k]);
        offset = oa[k + 1] = xor_block(oa[k], ctx->L[1]);
        ta[k + 1] = xor_block(offset, ctp[k + 1]);
        remaining -= 32;
        k += 2;
      }
      if (remaining >= 16) {
        offset = oa[k] = xor_block(offset, ctx->L[0]);
        ta[k] = xor_block(offset, ctp[k]);
        remaining -= 16;
        ++k;
      }
      if (remaining) {
        block pad;
        offset = xor_block(offset, ctx->Lstar);
        AES_encrypt((unsigned char*)&offset, tmp.u8, &ctx->encrypt_key);
        pad = tmp.bl;
        /* Overlay the ciphertext bytes on the pad and XOR with the pad:
           the first 'remaining' bytes become plaintext, the rest zero. */
        memcpy(tmp.u8, ctp + k, remaining);
        tmp.bl = xor_block(tmp.bl, pad);
        tmp.u8[remaining] = (unsigned char)0x80u;
        memcpy(ptp + k, tmp.u8, remaining);
        checksum = xor_block(checksum, tmp.bl);
      }
    }
    AES_ecb_decrypt_blks(ta, k, &ctx->decrypt_key);
    /* Emit the k whole plaintext blocks; every case deliberately falls
       through to the ones below it. */
    switch (k) {
#if (BPI == 8)
      case 7:
        ptp[6] = xor_block(ta[6], oa[6]);
        checksum = xor_block(checksum, ptp[6]);
        /* fallthrough */
      case 6:
        ptp[5] = xor_block(ta[5], oa[5]);
        checksum = xor_block(checksum, ptp[5]);
        /* fallthrough */
      case 5:
        ptp[4] = xor_block(ta[4], oa[4]);
        checksum = xor_block(checksum, ptp[4]);
        /* fallthrough */
      case 4:
        ptp[3] = xor_block(ta[3], oa[3]);
        checksum = xor_block(checksum, ptp[3]);
        /* fallthrough */
#endif
      case 3:
        ptp[2] = xor_block(ta[2], oa[2]);
        checksum = xor_block(checksum, ptp[2]);
        /* fallthrough */
      case 2:
        ptp[1] = xor_block(ta[1], oa[1]);
        checksum = xor_block(checksum, ptp[1]);
        /* fallthrough */
      case 1:
        ptp[0] = xor_block(ta[0], oa[0]);
        checksum = xor_block(checksum, ptp[0]);
    }

    /* Calculate expected tag */
    offset = xor_block(offset, ctx->Ldollar);
    tmp.bl = xor_block(offset, checksum);
    AES_encrypt(tmp.u8, tmp.u8, &ctx->encrypt_key);
    tmp.bl = xor_block(tmp.bl, ctx->ad_checksum); /* Full tag */

    /* Compare with proposed tag, change ct_len if invalid */
    if ((OCB_TAG_LEN == 16) && tag) {
      if (unequal_blocks(tmp.bl, *(const block*)tag))
        ct_len = AE_INVALID;
    } else {
#if (OCB_TAG_LEN > 0)
      int len = OCB_TAG_LEN;
#else
      int len = ctx->tag_len;
#endif
      if (tag) {
        if (constant_time_memcmp(tag, tmp.u8, len) != 0)
          ct_len = AE_INVALID;
      } else {
        /* Bundled tag sits immediately after the ciphertext proper */
        if (constant_time_memcmp((const char*)ct + ct_len, tmp.u8, len) != 0)
          ct_len = AE_INVALID;
      }
    }
  }
  return ct_len;
}
1322
1323 /* ----------------------------------------------------------------------- */
1324 /* Simple test program */
1325 /* ----------------------------------------------------------------------- */
1326
1327 #if 0
1328
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
1331
1332 #if __GNUC__
1333 #define ALIGN(n) __attribute__((aligned(n)))
1334 #elif _MSC_VER
1335 #define ALIGN(n) __declspec(align(n))
1336 #else /* Not GNU/Microsoft: delete alignment uses. */
1337 #define ALIGN(n)
1338 #endif
1339
/* Print an optional label followed by 'len' bytes of 'p' as upper-case
   hex digits and a newline.  's' is a NUL-terminated label, or NULL for
   no label. */
static void pbuf(void *p, unsigned len, const void *s)
{
    const unsigned char *bytes = (const unsigned char *)p;
    unsigned pos = 0;

    if (s)
        printf("%s", (char *)s);
    while (pos < len) {
        printf("%02X", (unsigned)bytes[pos]);
        ++pos;
    }
    printf("\n");
}
1349
1350 static void vectors(ae_ctx *ctx, int len)
1351 {
1352 ALIGN(16) char pt[128];
1353 ALIGN(16) char ct[144];
1354 ALIGN(16) char nonce[] = {0,1,2,3,4,5,6,7,8,9,10,11};
1355 int i;
1356 for (i=0; i < 128; i++) pt[i] = i;
1357 i = ae_encrypt(ctx,nonce,pt,len,pt,len,ct,NULL,AE_FINALIZE);
1358 printf("P=%d,A=%d: ",len,len); pbuf(ct, i, NULL);
1359 i = ae_encrypt(ctx,nonce,pt,0,pt,len,ct,NULL,AE_FINALIZE);
1360 printf("P=%d,A=%d: ",0,len); pbuf(ct, i, NULL);
1361 i = ae_encrypt(ctx,nonce,pt,len,pt,0,ct,NULL,AE_FINALIZE);
1362 printf("P=%d,A=%d: ",len,0); pbuf(ct, i, NULL);
1363 }
1364
1365 void validate()
1366 {
1367 ALIGN(16) char pt[1024];
1368 ALIGN(16) char ct[1024];
1369 ALIGN(16) char tag[16];
1370 ALIGN(16) char nonce[12] = {0,};
1371 ALIGN(16) char key[32] = {0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31};
1372 ae_ctx ctx;
1373 char *val_buf, *next;
1374 int i, len;
1375
1376 val_buf = (char *)malloc(22400 + 16);
1377 next = val_buf = (char *)(((size_t)val_buf + 16) & ~((size_t)15));
1378
1379 if (0) {
1380 ae_init(&ctx, key, 16, 12, 16);
1381 /* pbuf(&ctx, sizeof(ctx), "CTX: "); */
1382 vectors(&ctx,0);
1383 vectors(&ctx,8);
1384 vectors(&ctx,16);
1385 vectors(&ctx,24);
1386 vectors(&ctx,32);
1387 vectors(&ctx,40);
1388 }
1389
1390 memset(key,0,32);
1391 memset(pt,0,128);
1392 ae_init(&ctx, key, OCB_KEY_LEN, 12, OCB_TAG_LEN);
1393
1394 /* RFC Vector test */
1395 for (i = 0; i < 128; i++) {
1396 int first = ((i/3)/(BPI*16))*(BPI*16);
1397 int second = first;
1398 int third = i - (first + second);
1399
1400 nonce[11] = i;
1401
1402 if (0) {
1403 ae_encrypt(&ctx,nonce,pt,i,pt,i,ct,NULL,AE_FINALIZE);
1404 memcpy(next,ct,(size_t)i+OCB_TAG_LEN);
1405 next = next+i+OCB_TAG_LEN;
1406
1407 ae_encrypt(&ctx,nonce,pt,i,pt,0,ct,NULL,AE_FINALIZE);
1408 memcpy(next,ct,(size_t)i+OCB_TAG_LEN);
1409 next = next+i+OCB_TAG_LEN;
1410
1411 ae_encrypt(&ctx,nonce,pt,0,pt,i,ct,NULL,AE_FINALIZE);
1412 memcpy(next,ct,OCB_TAG_LEN);
1413 next = next+OCB_TAG_LEN;
1414 } else {
1415 ae_encrypt(&ctx,nonce,pt,first,pt,first,ct,NULL,AE_PENDING);
1416 ae_encrypt(&ctx,NULL,pt+first,second,pt+first,second,ct+first,NULL,AE_PENDING);
1417 ae_encrypt(&ctx,NULL,pt+first+second,third,pt+first+second,third,ct+first+second,NULL,AE_FINALIZE);
1418 memcpy(next,ct,(size_t)i+OCB_TAG_LEN);
1419 next = next+i+OCB_TAG_LEN;
1420
1421 ae_encrypt(&ctx,nonce,pt,first,pt,0,ct,NULL,AE_PENDING);
1422 ae_encrypt(&ctx,NULL,pt+first,second,pt,0,ct+first,NULL,AE_PENDING);
1423 ae_encrypt(&ctx,NULL,pt+first+second,third,pt,0,ct+first+second,NULL,AE_FINALIZE);
1424 memcpy(next,ct,(size_t)i+OCB_TAG_LEN);
1425 next = next+i+OCB_TAG_LEN;
1426
1427 ae_encrypt(&ctx,nonce,pt,0,pt,first,ct,NULL,AE_PENDING);
1428 ae_encrypt(&ctx,NULL,pt,0,pt+first,second,ct,NULL,AE_PENDING);
1429 ae_encrypt(&ctx,NULL,pt,0,pt+first+second,third,ct,NULL,AE_FINALIZE);
1430 memcpy(next,ct,OCB_TAG_LEN);
1431 next = next+OCB_TAG_LEN;
1432 }
1433
1434 }
1435 nonce[11] = 0;
1436 ae_encrypt(&ctx,nonce,NULL,0,val_buf,next-val_buf,ct,tag,AE_FINALIZE);
1437 pbuf(tag,OCB_TAG_LEN,0);
1438
1439
1440 /* Encrypt/Decrypt test */
1441 for (i = 0; i < 128; i++) {
1442 int first = ((i/3)/(BPI*16))*(BPI*16);
1443 int second = first;
1444 int third = i - (first + second);
1445
1446 nonce[11] = i%128;
1447
1448 if (1) {
1449 len = ae_encrypt(&ctx,nonce,val_buf,i,val_buf,i,ct,tag,AE_FINALIZE);
1450 len = ae_encrypt(&ctx,nonce,val_buf,i,val_buf,-1,ct,tag,AE_FINALIZE);
1451 len = ae_decrypt(&ctx,nonce,ct,len,val_buf,-1,pt,tag,AE_FINALIZE);
1452 if (len == -1) { printf("Authentication error: %d\n", i); return; }
1453 if (len != i) { printf("Length error: %d\n", i); return; }
1454 if (memcmp(val_buf,pt,i)) { printf("Decrypt error: %d\n", i); return; }
1455 } else {
1456 len = ae_encrypt(&ctx,nonce,val_buf,i,val_buf,i,ct,NULL,AE_FINALIZE);
1457 ae_decrypt(&ctx,nonce,ct,first,val_buf,first,pt,NULL,AE_PENDING);
1458 ae_decrypt(&ctx,NULL,ct+first,second,val_buf+first,second,pt+first,NULL,AE_PENDING);
1459 len = ae_decrypt(&ctx,NULL,ct+first+second,len-(first+second),val_buf+first+second,third,pt+first+second,NULL,AE_FINALIZE);
1460 if (len == -1) { printf("Authentication error: %d\n", i); return; }
1461 if (memcmp(val_buf,pt,i)) { printf("Decrypt error: %d\n", i); return; }
1462 }
1463
1464 }
1465 printf("Decrypt: PASS\n");
1466 }
1467
/* Entry point of the stand-alone test program: run the self-test once and
   report success to the environment. */
int main(void)
{
    validate();

    return 0;
}
1473 #endif
1474
/* Identification string naming which USE_* build switch is active.  The
   first non-zero switch, tested in the order below, wins; when none is
   set, infoString is not defined at all. */
#if USE_AES_NI
#define OCB_INFO_STRING "OCB3 (AES-NI)"
#elif USE_REFERENCE_AES
#define OCB_INFO_STRING "OCB3 (Reference)"
#elif USE_OPENSSL_AES
#define OCB_INFO_STRING "OCB3 (OpenSSL)"
#endif

#ifdef OCB_INFO_STRING
char infoString[] = OCB_INFO_STRING;
#undef OCB_INFO_STRING
#endif
1482