1 /* $OpenBSD: sshkey-xmss.c,v 1.8 2019/11/13 07:53:10 markus Exp $ */
2 /*
3 * Copyright (c) 2017 Markus Friedl. All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions
7 * are met:
8 * 1. Redistributions of source code must retain the above copyright
9 * notice, this list of conditions and the following disclaimer.
10 * 2. Redistributions in binary form must reproduce the above copyright
11 * notice, this list of conditions and the following disclaimer in the
12 * documentation and/or other materials provided with the distribution.
13 *
14 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24 */
25
26 #include "includes.h"
27 #ifdef WITH_XMSS
28
29 #include <sys/types.h>
30 #include <sys/uio.h>
31
32 #include <stdio.h>
33 #include <string.h>
34 #include <unistd.h>
35 #include <fcntl.h>
36 #include <errno.h>
37 #ifdef HAVE_SYS_FILE_H
38 # include <sys/file.h>
39 #endif
40
41 #include "ssh2.h"
42 #include "ssherr.h"
43 #include "sshbuf.h"
44 #include "cipher.h"
45 #include "sshkey.h"
46 #include "sshkey-xmss.h"
47 #include "atomicio.h"
48
49 #include "xmss_fast.h"
50
/*
 * Opaque internal XMSS state, stored in key->xmss_state and only
 * accessed through the functions in this file.
 */
/* prefix of every encrypted state blob; sizeof() includes the NUL */
#define XMSS_MAGIC "xmss-state-v1"
/* cipher used to encrypt the state blob written to "<file>.state" */
#define XMSS_CIPHERNAME "aes256-gcm@openssh.com"
struct ssh_xmss_state {
	xmss_params params;		/* filled in by xmss_set_params() */
	u_int32_t n, w, h, k;		/* parameter set, see sshkey_xmss_init() */

	bds_state bds;			/* traversal state; points at the buffers below */
	u_char *stack;			/* (h+1)*n bytes */
	u_int32_t stackoffset;		/* copy of bds.stackoffset for (de)serialization */
	u_char *stacklevels;		/* h+1 bytes */
	u_char *auth;			/* h*n bytes */
	u_char *keep;			/* (h/2)*n bytes */
	u_char *th_nodes;		/* (h-k)*n bytes, node storage for treehash[] */
	u_char *retain;			/* (2^k-k-1)*n bytes */
	treehash_inst *treehash;	/* h-k instances */

	u_int32_t idx;			/* state read from file */
	u_int32_t maxidx;		/* restricted # of signatures (0: unlimited) */
	int have_state;			/* .state file exists */
	int lockfd;			/* locked in sshkey_xmss_get_state(), -1 if none */
	u_char allow_update;		/* allow sshkey_xmss_update_state() */
	char *enc_ciphername;		/* encrypt state with cipher */
	u_char *enc_keyiv;		/* encrypt state with key (key || iv) */
	u_int32_t enc_keyiv_len;	/* length of enc_keyiv */
};
77
78 int sshkey_xmss_init_bds_state(struct sshkey *);
79 int sshkey_xmss_init_enc_key(struct sshkey *, const char *);
80 void sshkey_xmss_free_bds(struct sshkey *);
81 int sshkey_xmss_get_state_from_file(struct sshkey *, const char *,
82 int *, sshkey_printfn *);
83 int sshkey_xmss_encrypt_state(const struct sshkey *, struct sshbuf *,
84 struct sshbuf **);
85 int sshkey_xmss_decrypt_state(const struct sshkey *, struct sshbuf *,
86 struct sshbuf **);
87 int sshkey_xmss_serialize_enc_key(const struct sshkey *, struct sshbuf *);
88 int sshkey_xmss_deserialize_enc_key(struct sshkey *, struct sshbuf *);
89
/*
 * Emit a diagnostic through the printf-style callback "pr", which must be
 * in scope at the point of use; nothing happens when pr is NULL.
 * Written with standard C99 __VA_ARGS__ instead of the GNU-only named
 * variadic form "s...".
 */
#define PRINT(...) do { if (pr) pr(__VA_ARGS__); } while (0)
91
92 int
sshkey_xmss_init(struct sshkey * key,const char * name)93 sshkey_xmss_init(struct sshkey *key, const char *name)
94 {
95 struct ssh_xmss_state *state;
96
97 if (key->xmss_state != NULL)
98 return SSH_ERR_INVALID_FORMAT;
99 if (name == NULL)
100 return SSH_ERR_INVALID_FORMAT;
101 state = calloc(sizeof(struct ssh_xmss_state), 1);
102 if (state == NULL)
103 return SSH_ERR_ALLOC_FAIL;
104 if (strcmp(name, XMSS_SHA2_256_W16_H10_NAME) == 0) {
105 state->n = 32;
106 state->w = 16;
107 state->h = 10;
108 } else if (strcmp(name, XMSS_SHA2_256_W16_H16_NAME) == 0) {
109 state->n = 32;
110 state->w = 16;
111 state->h = 16;
112 } else if (strcmp(name, XMSS_SHA2_256_W16_H20_NAME) == 0) {
113 state->n = 32;
114 state->w = 16;
115 state->h = 20;
116 } else {
117 free(state);
118 return SSH_ERR_KEY_TYPE_UNKNOWN;
119 }
120 if ((key->xmss_name = strdup(name)) == NULL) {
121 free(state);
122 return SSH_ERR_ALLOC_FAIL;
123 }
124 state->k = 2; /* XXX hardcoded */
125 state->lockfd = -1;
126 if (xmss_set_params(&state->params, state->n, state->h, state->w,
127 state->k) != 0) {
128 free(state);
129 return SSH_ERR_INVALID_FORMAT;
130 }
131 key->xmss_state = state;
132 return 0;
133 }
134
135 void
sshkey_xmss_free_state(struct sshkey * key)136 sshkey_xmss_free_state(struct sshkey *key)
137 {
138 struct ssh_xmss_state *state = key->xmss_state;
139
140 sshkey_xmss_free_bds(key);
141 if (state) {
142 if (state->enc_keyiv) {
143 explicit_bzero(state->enc_keyiv, state->enc_keyiv_len);
144 free(state->enc_keyiv);
145 }
146 free(state->enc_ciphername);
147 free(state);
148 }
149 key->xmss_state = NULL;
150 }
151
/* magic string that starts every serialized BDS state */
#define SSH_XMSS_K2_MAGIC "k=2"
/*
 * Sizes of the BDS traversal buffers for a state "x" (a pointer to
 * something with n/h/k members).  num_stacklevels and num_treehash are
 * element counts; the others are byte counts (multiples of n).
 * Arguments are fully parenthesized so expressions such as "&st" or
 * "p + i" expand correctly.
 */
#define num_stack(x)		(((x)->h + 1) * ((x)->n))
#define num_stacklevels(x)	((x)->h + 1)
#define num_auth(x)		(((x)->h) * ((x)->n))
#define num_keep(x)		(((x)->h >> 1) * ((x)->n))
#define num_th_nodes(x)		(((x)->h - (x)->k) * ((x)->n))
#define num_retain(x)		(((1ULL << (x)->k) - (x)->k - 1) * ((x)->n))
#define num_treehash(x)		(((x)->h) - ((x)->k))
160
161 int
sshkey_xmss_init_bds_state(struct sshkey * key)162 sshkey_xmss_init_bds_state(struct sshkey *key)
163 {
164 struct ssh_xmss_state *state = key->xmss_state;
165 u_int32_t i;
166
167 state->stackoffset = 0;
168 if ((state->stack = calloc(num_stack(state), 1)) == NULL ||
169 (state->stacklevels = calloc(num_stacklevels(state), 1))== NULL ||
170 (state->auth = calloc(num_auth(state), 1)) == NULL ||
171 (state->keep = calloc(num_keep(state), 1)) == NULL ||
172 (state->th_nodes = calloc(num_th_nodes(state), 1)) == NULL ||
173 (state->retain = calloc(num_retain(state), 1)) == NULL ||
174 (state->treehash = calloc(num_treehash(state),
175 sizeof(treehash_inst))) == NULL) {
176 sshkey_xmss_free_bds(key);
177 return SSH_ERR_ALLOC_FAIL;
178 }
179 for (i = 0; i < state->h - state->k; i++)
180 state->treehash[i].node = &state->th_nodes[state->n*i];
181 xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
182 state->stacklevels, state->auth, state->keep, state->treehash,
183 state->retain, 0);
184 return 0;
185 }
186
187 void
sshkey_xmss_free_bds(struct sshkey * key)188 sshkey_xmss_free_bds(struct sshkey *key)
189 {
190 struct ssh_xmss_state *state = key->xmss_state;
191
192 if (state == NULL)
193 return;
194 free(state->stack);
195 free(state->stacklevels);
196 free(state->auth);
197 free(state->keep);
198 free(state->th_nodes);
199 free(state->retain);
200 free(state->treehash);
201 state->stack = NULL;
202 state->stacklevels = NULL;
203 state->auth = NULL;
204 state->keep = NULL;
205 state->th_nodes = NULL;
206 state->retain = NULL;
207 state->treehash = NULL;
208 }
209
210 void *
sshkey_xmss_params(const struct sshkey * key)211 sshkey_xmss_params(const struct sshkey *key)
212 {
213 struct ssh_xmss_state *state = key->xmss_state;
214
215 if (state == NULL)
216 return NULL;
217 return &state->params;
218 }
219
220 void *
sshkey_xmss_bds_state(const struct sshkey * key)221 sshkey_xmss_bds_state(const struct sshkey *key)
222 {
223 struct ssh_xmss_state *state = key->xmss_state;
224
225 if (state == NULL)
226 return NULL;
227 return &state->bds;
228 }
229
230 int
sshkey_xmss_siglen(const struct sshkey * key,size_t * lenp)231 sshkey_xmss_siglen(const struct sshkey *key, size_t *lenp)
232 {
233 struct ssh_xmss_state *state = key->xmss_state;
234
235 if (lenp == NULL)
236 return SSH_ERR_INVALID_ARGUMENT;
237 if (state == NULL)
238 return SSH_ERR_INVALID_FORMAT;
239 *lenp = 4 + state->n +
240 state->params.wots_par.keysize +
241 state->h * state->n;
242 return 0;
243 }
244
245 size_t
sshkey_xmss_pklen(const struct sshkey * key)246 sshkey_xmss_pklen(const struct sshkey *key)
247 {
248 struct ssh_xmss_state *state = key->xmss_state;
249
250 if (state == NULL)
251 return 0;
252 return state->n * 2;
253 }
254
255 size_t
sshkey_xmss_sklen(const struct sshkey * key)256 sshkey_xmss_sklen(const struct sshkey *key)
257 {
258 struct ssh_xmss_state *state = key->xmss_state;
259
260 if (state == NULL)
261 return 0;
262 return state->n * 4 + 4;
263 }
264
265 int
sshkey_xmss_init_enc_key(struct sshkey * k,const char * ciphername)266 sshkey_xmss_init_enc_key(struct sshkey *k, const char *ciphername)
267 {
268 struct ssh_xmss_state *state = k->xmss_state;
269 const struct sshcipher *cipher;
270 size_t keylen = 0, ivlen = 0;
271
272 if (state == NULL)
273 return SSH_ERR_INVALID_ARGUMENT;
274 if ((cipher = cipher_by_name(ciphername)) == NULL)
275 return SSH_ERR_INTERNAL_ERROR;
276 if ((state->enc_ciphername = strdup(ciphername)) == NULL)
277 return SSH_ERR_ALLOC_FAIL;
278 keylen = cipher_keylen(cipher);
279 ivlen = cipher_ivlen(cipher);
280 state->enc_keyiv_len = keylen + ivlen;
281 if ((state->enc_keyiv = calloc(state->enc_keyiv_len, 1)) == NULL) {
282 free(state->enc_ciphername);
283 state->enc_ciphername = NULL;
284 return SSH_ERR_ALLOC_FAIL;
285 }
286 arc4random_buf(state->enc_keyiv, state->enc_keyiv_len);
287 return 0;
288 }
289
290 int
sshkey_xmss_serialize_enc_key(const struct sshkey * k,struct sshbuf * b)291 sshkey_xmss_serialize_enc_key(const struct sshkey *k, struct sshbuf *b)
292 {
293 struct ssh_xmss_state *state = k->xmss_state;
294 int r;
295
296 if (state == NULL || state->enc_keyiv == NULL ||
297 state->enc_ciphername == NULL)
298 return SSH_ERR_INVALID_ARGUMENT;
299 if ((r = sshbuf_put_cstring(b, state->enc_ciphername)) != 0 ||
300 (r = sshbuf_put_string(b, state->enc_keyiv,
301 state->enc_keyiv_len)) != 0)
302 return r;
303 return 0;
304 }
305
306 int
sshkey_xmss_deserialize_enc_key(struct sshkey * k,struct sshbuf * b)307 sshkey_xmss_deserialize_enc_key(struct sshkey *k, struct sshbuf *b)
308 {
309 struct ssh_xmss_state *state = k->xmss_state;
310 size_t len;
311 int r;
312
313 if (state == NULL)
314 return SSH_ERR_INVALID_ARGUMENT;
315 if ((r = sshbuf_get_cstring(b, &state->enc_ciphername, NULL)) != 0 ||
316 (r = sshbuf_get_string(b, &state->enc_keyiv, &len)) != 0)
317 return r;
318 state->enc_keyiv_len = len;
319 return 0;
320 }
321
322 int
sshkey_xmss_serialize_pk_info(const struct sshkey * k,struct sshbuf * b,enum sshkey_serialize_rep opts)323 sshkey_xmss_serialize_pk_info(const struct sshkey *k, struct sshbuf *b,
324 enum sshkey_serialize_rep opts)
325 {
326 struct ssh_xmss_state *state = k->xmss_state;
327 u_char have_info = 1;
328 u_int32_t idx;
329 int r;
330
331 if (state == NULL)
332 return SSH_ERR_INVALID_ARGUMENT;
333 if (opts != SSHKEY_SERIALIZE_INFO)
334 return 0;
335 idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
336 if ((r = sshbuf_put_u8(b, have_info)) != 0 ||
337 (r = sshbuf_put_u32(b, idx)) != 0 ||
338 (r = sshbuf_put_u32(b, state->maxidx)) != 0)
339 return r;
340 return 0;
341 }
342
343 int
sshkey_xmss_deserialize_pk_info(struct sshkey * k,struct sshbuf * b)344 sshkey_xmss_deserialize_pk_info(struct sshkey *k, struct sshbuf *b)
345 {
346 struct ssh_xmss_state *state = k->xmss_state;
347 u_char have_info;
348 int r;
349
350 if (state == NULL)
351 return SSH_ERR_INVALID_ARGUMENT;
352 /* optional */
353 if (sshbuf_len(b) == 0)
354 return 0;
355 if ((r = sshbuf_get_u8(b, &have_info)) != 0)
356 return r;
357 if (have_info != 1)
358 return SSH_ERR_INVALID_ARGUMENT;
359 if ((r = sshbuf_get_u32(b, &state->idx)) != 0 ||
360 (r = sshbuf_get_u32(b, &state->maxidx)) != 0)
361 return r;
362 return 0;
363 }
364
365 int
sshkey_xmss_generate_private_key(struct sshkey * k,u_int bits)366 sshkey_xmss_generate_private_key(struct sshkey *k, u_int bits)
367 {
368 int r;
369 const char *name;
370
371 if (bits == 10) {
372 name = XMSS_SHA2_256_W16_H10_NAME;
373 } else if (bits == 16) {
374 name = XMSS_SHA2_256_W16_H16_NAME;
375 } else if (bits == 20) {
376 name = XMSS_SHA2_256_W16_H20_NAME;
377 } else {
378 name = XMSS_DEFAULT_NAME;
379 }
380 if ((r = sshkey_xmss_init(k, name)) != 0 ||
381 (r = sshkey_xmss_init_bds_state(k)) != 0 ||
382 (r = sshkey_xmss_init_enc_key(k, XMSS_CIPHERNAME)) != 0)
383 return r;
384 if ((k->xmss_pk = malloc(sshkey_xmss_pklen(k))) == NULL ||
385 (k->xmss_sk = malloc(sshkey_xmss_sklen(k))) == NULL) {
386 return SSH_ERR_ALLOC_FAIL;
387 }
388 xmss_keypair(k->xmss_pk, k->xmss_sk, sshkey_xmss_bds_state(k),
389 sshkey_xmss_params(k));
390 return 0;
391 }
392
393 int
sshkey_xmss_get_state_from_file(struct sshkey * k,const char * filename,int * have_file,sshkey_printfn * pr)394 sshkey_xmss_get_state_from_file(struct sshkey *k, const char *filename,
395 int *have_file, sshkey_printfn *pr)
396 {
397 struct sshbuf *b = NULL, *enc = NULL;
398 int ret = SSH_ERR_SYSTEM_ERROR, r, fd = -1;
399 u_int32_t len;
400 unsigned char buf[4], *data = NULL;
401
402 *have_file = 0;
403 if ((fd = open(filename, O_RDONLY)) >= 0) {
404 *have_file = 1;
405 if (atomicio(read, fd, buf, sizeof(buf)) != sizeof(buf)) {
406 PRINT("%s: corrupt state file: %s", __func__, filename);
407 goto done;
408 }
409 len = PEEK_U32(buf);
410 if ((data = calloc(len, 1)) == NULL) {
411 ret = SSH_ERR_ALLOC_FAIL;
412 goto done;
413 }
414 if (atomicio(read, fd, data, len) != len) {
415 PRINT("%s: cannot read blob: %s", __func__, filename);
416 goto done;
417 }
418 if ((enc = sshbuf_from(data, len)) == NULL) {
419 ret = SSH_ERR_ALLOC_FAIL;
420 goto done;
421 }
422 sshkey_xmss_free_bds(k);
423 if ((r = sshkey_xmss_decrypt_state(k, enc, &b)) != 0) {
424 ret = r;
425 goto done;
426 }
427 if ((r = sshkey_xmss_deserialize_state(k, b)) != 0) {
428 ret = r;
429 goto done;
430 }
431 ret = 0;
432 }
433 done:
434 if (fd != -1)
435 close(fd);
436 free(data);
437 sshbuf_free(enc);
438 sshbuf_free(b);
439 return ret;
440 }
441
/*
 * Prepare the XMSS key k for signing.
 *
 * With a maxidx limit (see sshkey_xmss_enable_maxsign()) no files are
 * used: signing is allowed while the index in xmss_sk is below maxidx.
 *
 * Otherwise an exclusive flock() is taken on "<xmss_filename>.lock"
 * (created if needed).  Up to 10 retries are made, sleeping tries*100ms
 * between attempts.  The state is then loaded from "<file>.state",
 * falling back to the backup "<file>.ostate".  If neither file exists,
 * the in-memory BDS state is used as-is (starting from scratch).
 *
 * On success the lock descriptor is kept in state->lockfd and
 * allow_update is set; sshkey_xmss_update_state() later releases the
 * lock.  On failure the lock is released here.
 *
 * Returns 0 or an SSH_ERR_* code.
 */
int
sshkey_xmss_get_state(const struct sshkey *k, sshkey_printfn *pr)
{
	struct ssh_xmss_state *state = k->xmss_state;
	u_int32_t idx = 0;
	char *filename = NULL;
	char *statefile = NULL, *ostatefile = NULL, *lockfile = NULL;
	/*
	 * have_ostate is only read when have_state is 0, i.e. after the
	 * fallback load below has assigned it.
	 */
	int lockfd = -1, have_state = 0, have_ostate, tries = 0;
	int ret = SSH_ERR_INVALID_ARGUMENT, r;

	if (state == NULL)
		goto done;
	/*
	 * If maxidx is set, then we are allowed a limited number
	 * of signatures, but don't need to access the disk.
	 * Otherwise we need to deal with the on-disk state.
	 */
	if (state->maxidx) {
		/* xmss_sk always contains the current state */
		idx = PEEK_U32(k->xmss_sk);
		if (idx < state->maxidx) {
			state->allow_update = 1;
			return 0;
		}
		return SSH_ERR_INVALID_ARGUMENT;
	}
	if ((filename = k->xmss_filename) == NULL)
		goto done;
	if (asprintf(&lockfile, "%s.lock", filename) == -1 ||
	    asprintf(&statefile, "%s.state", filename) == -1 ||
	    asprintf(&ostatefile, "%s.ostate", filename) == -1) {
		ret = SSH_ERR_ALLOC_FAIL;
		goto done;
	}
	if ((lockfd = open(lockfile, O_CREAT|O_RDONLY, 0600)) == -1) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("%s: cannot open/create: %s", __func__, lockfile);
		goto done;
	}
	/* non-blocking lock attempts with a growing back-off */
	while (flock(lockfd, LOCK_EX|LOCK_NB) == -1) {
		if (errno != EWOULDBLOCK) {
			ret = SSH_ERR_SYSTEM_ERROR;
			PRINT("%s: cannot lock: %s", __func__, lockfile);
			goto done;
		}
		if (++tries > 10) {
			ret = SSH_ERR_SYSTEM_ERROR;
			PRINT("%s: giving up on: %s", __func__, lockfile);
			goto done;
		}
		usleep(1000*100*tries);
	}
	/* XXX no longer const */
	if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k,
	    statefile, &have_state, pr)) != 0) {
		if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k,
		    ostatefile, &have_ostate, pr)) == 0) {
			/*
			 * Loaded the backup: burn one index past it,
			 * presumably so an index that may already have been
			 * used after the backup was taken is never reused
			 * (TODO confirm).
			 */
			state->allow_update = 1;
			r = sshkey_xmss_forward_state(k, 1);
			state->idx = PEEK_U32(k->xmss_sk);
			state->allow_update = 0;
		}
	}
	if (!have_state && !have_ostate) {
		/* check that bds state is initialized */
		if (state->bds.auth == NULL)
			goto done;
		PRINT("%s: start from scratch idx 0: %u", __func__, state->idx);
	} else if (r != 0) {
		ret = r;
		goto done;
	}
	/* refuse to continue if the next signature would wrap idx */
	if (state->idx + 1 < state->idx) {
		PRINT("%s: state wrap: %u", __func__, state->idx);
		goto done;
	}
	/* success: hand the lock over to the state */
	state->have_state = have_state;
	state->lockfd = lockfd;
	state->allow_update = 1;
	lockfd = -1;
	ret = 0;
done:
	if (lockfd != -1)
		close(lockfd);
	free(lockfile);
	free(statefile);
	free(ostatefile);
	return ret;
}
531
532 int
sshkey_xmss_forward_state(const struct sshkey * k,u_int32_t reserve)533 sshkey_xmss_forward_state(const struct sshkey *k, u_int32_t reserve)
534 {
535 struct ssh_xmss_state *state = k->xmss_state;
536 u_char *sig = NULL;
537 size_t required_siglen;
538 unsigned long long smlen;
539 u_char data;
540 int ret, r;
541
542 if (state == NULL || !state->allow_update)
543 return SSH_ERR_INVALID_ARGUMENT;
544 if (reserve == 0)
545 return SSH_ERR_INVALID_ARGUMENT;
546 if (state->idx + reserve <= state->idx)
547 return SSH_ERR_INVALID_ARGUMENT;
548 if ((r = sshkey_xmss_siglen(k, &required_siglen)) != 0)
549 return r;
550 if ((sig = malloc(required_siglen)) == NULL)
551 return SSH_ERR_ALLOC_FAIL;
552 while (reserve-- > 0) {
553 state->idx = PEEK_U32(k->xmss_sk);
554 smlen = required_siglen;
555 if ((ret = xmss_sign(k->xmss_sk, sshkey_xmss_bds_state(k),
556 sig, &smlen, &data, 0, sshkey_xmss_params(k))) != 0) {
557 r = SSH_ERR_INVALID_ARGUMENT;
558 break;
559 }
560 }
561 free(sig);
562 return r;
563 }
564
/*
 * Persist the XMSS state after a signature and release the lock taken
 * by sshkey_xmss_get_state().
 *
 * Requires allow_update.  With a maxidx limit nothing is written.  The
 * index in xmss_sk must equal state->idx (no signature happened:
 * nothing to do) or state->idx + 1 (exactly one signature); anything
 * else is refused.  state->idx is advanced before the files are written
 * and is not rolled back if writing fails.
 *
 * The new state is serialized, encrypted and written to
 * "<file>.nstate" (created exclusively, mode 0600, fsync'ed).  When a
 * state file was loaded, the current "<file>.state" is hard-linked to
 * "<file>.ostate" as a backup.  Finally .nstate is renamed over .state.
 *
 * On every path past the initial argument check, the lock descriptor
 * is closed and a leftover .nstate is removed.
 */
int
sshkey_xmss_update_state(const struct sshkey *k, sshkey_printfn *pr)
{
	struct ssh_xmss_state *state = k->xmss_state;
	struct sshbuf *b = NULL, *enc = NULL;
	u_int32_t idx = 0;
	unsigned char buf[4];
	char *filename = NULL;
	char *statefile = NULL, *ostatefile = NULL, *nstatefile = NULL;
	int fd = -1;
	int ret = SSH_ERR_INVALID_ARGUMENT;

	if (state == NULL || !state->allow_update)
		return ret;
	if (state->maxidx) {
		/* no update since the number of signatures is limited */
		ret = 0;
		goto done;
	}
	idx = PEEK_U32(k->xmss_sk);
	if (idx == state->idx) {
		/* no signature happened, no need to update */
		ret = 0;
		goto done;
	} else if (idx != state->idx + 1) {
		PRINT("%s: more than one signature happened: idx %u state %u",
		    __func__, idx, state->idx);
		goto done;
	}
	state->idx = idx;
	if ((filename = k->xmss_filename) == NULL)
		goto done;
	if (asprintf(&statefile, "%s.state", filename) == -1 ||
	    asprintf(&ostatefile, "%s.ostate", filename) == -1 ||
	    asprintf(&nstatefile, "%s.nstate", filename) == -1) {
		ret = SSH_ERR_ALLOC_FAIL;
		goto done;
	}
	/* remove a stale .nstate so the O_EXCL open below can succeed */
	unlink(nstatefile);
	if ((b = sshbuf_new()) == NULL) {
		ret = SSH_ERR_ALLOC_FAIL;
		goto done;
	}
	if ((ret = sshkey_xmss_serialize_state(k, b)) != 0) {
		PRINT("%s: SERLIALIZE FAILED: %d", __func__, ret);
		goto done;
	}
	if ((ret = sshkey_xmss_encrypt_state(k, b, &enc)) != 0) {
		PRINT("%s: ENCRYPT FAILED: %d", __func__, ret);
		goto done;
	}
	if ((fd = open(nstatefile, O_CREAT|O_WRONLY|O_EXCL, 0600)) == -1) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("%s: open new state file: %s", __func__, nstatefile);
		goto done;
	}
	/* file layout: 4-byte length, then the encrypted blob */
	POKE_U32(buf, sshbuf_len(enc));
	if (atomicio(vwrite, fd, buf, sizeof(buf)) != sizeof(buf)) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("%s: write new state file hdr: %s", __func__, nstatefile);
		close(fd);
		goto done;
	}
	if (atomicio(vwrite, fd, sshbuf_mutable_ptr(enc), sshbuf_len(enc)) !=
	    sshbuf_len(enc)) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("%s: write new state file data: %s", __func__, nstatefile);
		close(fd);
		goto done;
	}
	/* make the new state durable before it replaces the old one */
	if (fsync(fd) == -1) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("%s: sync new state file: %s", __func__, nstatefile);
		close(fd);
		goto done;
	}
	if (close(fd) == -1) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("%s: close new state file: %s", __func__, nstatefile);
		goto done;
	}
	/* keep the previous state as .ostate */
	if (state->have_state) {
		unlink(ostatefile);
		if (link(statefile, ostatefile)) {
			ret = SSH_ERR_SYSTEM_ERROR;
			PRINT("%s: backup state %s to %s", __func__, statefile,
			    ostatefile);
			goto done;
		}
	}
	if (rename(nstatefile, statefile) == -1) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("%s: rename %s to %s", __func__, nstatefile, statefile);
		goto done;
	}
	ret = 0;
done:
	/* release the lock taken in sshkey_xmss_get_state() */
	if (state->lockfd != -1) {
		close(state->lockfd);
		state->lockfd = -1;
	}
	if (nstatefile)
		unlink(nstatefile);
	free(statefile);
	free(ostatefile);
	free(nstatefile);
	sshbuf_free(b);
	sshbuf_free(enc);
	return ret;
}
675
676 int
sshkey_xmss_serialize_state(const struct sshkey * k,struct sshbuf * b)677 sshkey_xmss_serialize_state(const struct sshkey *k, struct sshbuf *b)
678 {
679 struct ssh_xmss_state *state = k->xmss_state;
680 treehash_inst *th;
681 u_int32_t i, node;
682 int r;
683
684 if (state == NULL)
685 return SSH_ERR_INVALID_ARGUMENT;
686 if (state->stack == NULL)
687 return SSH_ERR_INVALID_ARGUMENT;
688 state->stackoffset = state->bds.stackoffset; /* copy back */
689 if ((r = sshbuf_put_cstring(b, SSH_XMSS_K2_MAGIC)) != 0 ||
690 (r = sshbuf_put_u32(b, state->idx)) != 0 ||
691 (r = sshbuf_put_string(b, state->stack, num_stack(state))) != 0 ||
692 (r = sshbuf_put_u32(b, state->stackoffset)) != 0 ||
693 (r = sshbuf_put_string(b, state->stacklevels, num_stacklevels(state))) != 0 ||
694 (r = sshbuf_put_string(b, state->auth, num_auth(state))) != 0 ||
695 (r = sshbuf_put_string(b, state->keep, num_keep(state))) != 0 ||
696 (r = sshbuf_put_string(b, state->th_nodes, num_th_nodes(state))) != 0 ||
697 (r = sshbuf_put_string(b, state->retain, num_retain(state))) != 0 ||
698 (r = sshbuf_put_u32(b, num_treehash(state))) != 0)
699 return r;
700 for (i = 0; i < num_treehash(state); i++) {
701 th = &state->treehash[i];
702 node = th->node - state->th_nodes;
703 if ((r = sshbuf_put_u32(b, th->h)) != 0 ||
704 (r = sshbuf_put_u32(b, th->next_idx)) != 0 ||
705 (r = sshbuf_put_u32(b, th->stackusage)) != 0 ||
706 (r = sshbuf_put_u8(b, th->completed)) != 0 ||
707 (r = sshbuf_put_u32(b, node)) != 0)
708 return r;
709 }
710 return 0;
711 }
712
713 int
sshkey_xmss_serialize_state_opt(const struct sshkey * k,struct sshbuf * b,enum sshkey_serialize_rep opts)714 sshkey_xmss_serialize_state_opt(const struct sshkey *k, struct sshbuf *b,
715 enum sshkey_serialize_rep opts)
716 {
717 struct ssh_xmss_state *state = k->xmss_state;
718 int r = SSH_ERR_INVALID_ARGUMENT;
719 u_char have_stack, have_filename, have_enc;
720
721 if (state == NULL)
722 return SSH_ERR_INVALID_ARGUMENT;
723 if ((r = sshbuf_put_u8(b, opts)) != 0)
724 return r;
725 switch (opts) {
726 case SSHKEY_SERIALIZE_STATE:
727 r = sshkey_xmss_serialize_state(k, b);
728 break;
729 case SSHKEY_SERIALIZE_FULL:
730 if ((r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
731 return r;
732 r = sshkey_xmss_serialize_state(k, b);
733 break;
734 case SSHKEY_SERIALIZE_SHIELD:
735 /* all of stack/filename/enc are optional */
736 have_stack = state->stack != NULL;
737 if ((r = sshbuf_put_u8(b, have_stack)) != 0)
738 return r;
739 if (have_stack) {
740 state->idx = PEEK_U32(k->xmss_sk); /* update */
741 if ((r = sshkey_xmss_serialize_state(k, b)) != 0)
742 return r;
743 }
744 have_filename = k->xmss_filename != NULL;
745 if ((r = sshbuf_put_u8(b, have_filename)) != 0)
746 return r;
747 if (have_filename &&
748 (r = sshbuf_put_cstring(b, k->xmss_filename)) != 0)
749 return r;
750 have_enc = state->enc_keyiv != NULL;
751 if ((r = sshbuf_put_u8(b, have_enc)) != 0)
752 return r;
753 if (have_enc &&
754 (r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
755 return r;
756 if ((r = sshbuf_put_u32(b, state->maxidx)) != 0 ||
757 (r = sshbuf_put_u8(b, state->allow_update)) != 0)
758 return r;
759 break;
760 case SSHKEY_SERIALIZE_DEFAULT:
761 r = 0;
762 break;
763 default:
764 r = SSH_ERR_INVALID_ARGUMENT;
765 break;
766 }
767 return r;
768 }
769
770 int
sshkey_xmss_deserialize_state(struct sshkey * k,struct sshbuf * b)771 sshkey_xmss_deserialize_state(struct sshkey *k, struct sshbuf *b)
772 {
773 struct ssh_xmss_state *state = k->xmss_state;
774 treehash_inst *th;
775 u_int32_t i, lh, node;
776 size_t ls, lsl, la, lk, ln, lr;
777 char *magic;
778 int r = SSH_ERR_INTERNAL_ERROR;
779
780 if (state == NULL)
781 return SSH_ERR_INVALID_ARGUMENT;
782 if (k->xmss_sk == NULL)
783 return SSH_ERR_INVALID_ARGUMENT;
784 if ((state->treehash = calloc(num_treehash(state),
785 sizeof(treehash_inst))) == NULL)
786 return SSH_ERR_ALLOC_FAIL;
787 if ((r = sshbuf_get_cstring(b, &magic, NULL)) != 0 ||
788 (r = sshbuf_get_u32(b, &state->idx)) != 0 ||
789 (r = sshbuf_get_string(b, &state->stack, &ls)) != 0 ||
790 (r = sshbuf_get_u32(b, &state->stackoffset)) != 0 ||
791 (r = sshbuf_get_string(b, &state->stacklevels, &lsl)) != 0 ||
792 (r = sshbuf_get_string(b, &state->auth, &la)) != 0 ||
793 (r = sshbuf_get_string(b, &state->keep, &lk)) != 0 ||
794 (r = sshbuf_get_string(b, &state->th_nodes, &ln)) != 0 ||
795 (r = sshbuf_get_string(b, &state->retain, &lr)) != 0 ||
796 (r = sshbuf_get_u32(b, &lh)) != 0)
797 goto out;
798 if (strcmp(magic, SSH_XMSS_K2_MAGIC) != 0) {
799 r = SSH_ERR_INVALID_ARGUMENT;
800 goto out;
801 }
802 /* XXX check stackoffset */
803 if (ls != num_stack(state) ||
804 lsl != num_stacklevels(state) ||
805 la != num_auth(state) ||
806 lk != num_keep(state) ||
807 ln != num_th_nodes(state) ||
808 lr != num_retain(state) ||
809 lh != num_treehash(state)) {
810 r = SSH_ERR_INVALID_ARGUMENT;
811 goto out;
812 }
813 for (i = 0; i < num_treehash(state); i++) {
814 th = &state->treehash[i];
815 if ((r = sshbuf_get_u32(b, &th->h)) != 0 ||
816 (r = sshbuf_get_u32(b, &th->next_idx)) != 0 ||
817 (r = sshbuf_get_u32(b, &th->stackusage)) != 0 ||
818 (r = sshbuf_get_u8(b, &th->completed)) != 0 ||
819 (r = sshbuf_get_u32(b, &node)) != 0)
820 goto out;
821 if (node < num_th_nodes(state))
822 th->node = &state->th_nodes[node];
823 }
824 POKE_U32(k->xmss_sk, state->idx);
825 xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
826 state->stacklevels, state->auth, state->keep, state->treehash,
827 state->retain, 0);
828 /* success */
829 r = 0;
830 out:
831 free(magic);
832 return r;
833 }
834
835 int
sshkey_xmss_deserialize_state_opt(struct sshkey * k,struct sshbuf * b)836 sshkey_xmss_deserialize_state_opt(struct sshkey *k, struct sshbuf *b)
837 {
838 struct ssh_xmss_state *state = k->xmss_state;
839 enum sshkey_serialize_rep opts;
840 u_char have_state, have_stack, have_filename, have_enc;
841 int r;
842
843 if ((r = sshbuf_get_u8(b, &have_state)) != 0)
844 return r;
845
846 opts = have_state;
847 switch (opts) {
848 case SSHKEY_SERIALIZE_DEFAULT:
849 r = 0;
850 break;
851 case SSHKEY_SERIALIZE_SHIELD:
852 if ((r = sshbuf_get_u8(b, &have_stack)) != 0)
853 return r;
854 if (have_stack &&
855 (r = sshkey_xmss_deserialize_state(k, b)) != 0)
856 return r;
857 if ((r = sshbuf_get_u8(b, &have_filename)) != 0)
858 return r;
859 if (have_filename &&
860 (r = sshbuf_get_cstring(b, &k->xmss_filename, NULL)) != 0)
861 return r;
862 if ((r = sshbuf_get_u8(b, &have_enc)) != 0)
863 return r;
864 if (have_enc &&
865 (r = sshkey_xmss_deserialize_enc_key(k, b)) != 0)
866 return r;
867 if ((r = sshbuf_get_u32(b, &state->maxidx)) != 0 ||
868 (r = sshbuf_get_u8(b, &state->allow_update)) != 0)
869 return r;
870 break;
871 case SSHKEY_SERIALIZE_STATE:
872 if ((r = sshkey_xmss_deserialize_state(k, b)) != 0)
873 return r;
874 break;
875 case SSHKEY_SERIALIZE_FULL:
876 if ((r = sshkey_xmss_deserialize_enc_key(k, b)) != 0 ||
877 (r = sshkey_xmss_deserialize_state(k, b)) != 0)
878 return r;
879 break;
880 default:
881 r = SSH_ERR_INVALID_FORMAT;
882 break;
883 }
884 return r;
885 }
886
887 int
sshkey_xmss_encrypt_state(const struct sshkey * k,struct sshbuf * b,struct sshbuf ** retp)888 sshkey_xmss_encrypt_state(const struct sshkey *k, struct sshbuf *b,
889 struct sshbuf **retp)
890 {
891 struct ssh_xmss_state *state = k->xmss_state;
892 struct sshbuf *encrypted = NULL, *encoded = NULL, *padded = NULL;
893 struct sshcipher_ctx *ciphercontext = NULL;
894 const struct sshcipher *cipher;
895 u_char *cp, *key, *iv = NULL;
896 size_t i, keylen, ivlen, blocksize, authlen, encrypted_len, aadlen;
897 int r = SSH_ERR_INTERNAL_ERROR;
898
899 if (retp != NULL)
900 *retp = NULL;
901 if (state == NULL ||
902 state->enc_keyiv == NULL ||
903 state->enc_ciphername == NULL)
904 return SSH_ERR_INTERNAL_ERROR;
905 if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
906 r = SSH_ERR_INTERNAL_ERROR;
907 goto out;
908 }
909 blocksize = cipher_blocksize(cipher);
910 keylen = cipher_keylen(cipher);
911 ivlen = cipher_ivlen(cipher);
912 authlen = cipher_authlen(cipher);
913 if (state->enc_keyiv_len != keylen + ivlen) {
914 r = SSH_ERR_INVALID_FORMAT;
915 goto out;
916 }
917 key = state->enc_keyiv;
918 if ((encrypted = sshbuf_new()) == NULL ||
919 (encoded = sshbuf_new()) == NULL ||
920 (padded = sshbuf_new()) == NULL ||
921 (iv = malloc(ivlen)) == NULL) {
922 r = SSH_ERR_ALLOC_FAIL;
923 goto out;
924 }
925
926 /* replace first 4 bytes of IV with index to ensure uniqueness */
927 memcpy(iv, key + keylen, ivlen);
928 POKE_U32(iv, state->idx);
929
930 if ((r = sshbuf_put(encoded, XMSS_MAGIC, sizeof(XMSS_MAGIC))) != 0 ||
931 (r = sshbuf_put_u32(encoded, state->idx)) != 0)
932 goto out;
933
934 /* padded state will be encrypted */
935 if ((r = sshbuf_putb(padded, b)) != 0)
936 goto out;
937 i = 0;
938 while (sshbuf_len(padded) % blocksize) {
939 if ((r = sshbuf_put_u8(padded, ++i & 0xff)) != 0)
940 goto out;
941 }
942 encrypted_len = sshbuf_len(padded);
943
944 /* header including the length of state is used as AAD */
945 if ((r = sshbuf_put_u32(encoded, encrypted_len)) != 0)
946 goto out;
947 aadlen = sshbuf_len(encoded);
948
949 /* concat header and state */
950 if ((r = sshbuf_putb(encoded, padded)) != 0)
951 goto out;
952
953 /* reserve space for encryption of encoded data plus auth tag */
954 /* encrypt at offset addlen */
955 if ((r = sshbuf_reserve(encrypted,
956 encrypted_len + aadlen + authlen, &cp)) != 0 ||
957 (r = cipher_init(&ciphercontext, cipher, key, keylen,
958 iv, ivlen, 1)) != 0 ||
959 (r = cipher_crypt(ciphercontext, 0, cp, sshbuf_ptr(encoded),
960 encrypted_len, aadlen, authlen)) != 0)
961 goto out;
962
963 /* success */
964 r = 0;
965 out:
966 if (retp != NULL) {
967 *retp = encrypted;
968 encrypted = NULL;
969 }
970 sshbuf_free(padded);
971 sshbuf_free(encoded);
972 sshbuf_free(encrypted);
973 cipher_free(ciphercontext);
974 free(iv);
975 return r;
976 }
977
978 int
sshkey_xmss_decrypt_state(const struct sshkey * k,struct sshbuf * encoded,struct sshbuf ** retp)979 sshkey_xmss_decrypt_state(const struct sshkey *k, struct sshbuf *encoded,
980 struct sshbuf **retp)
981 {
982 struct ssh_xmss_state *state = k->xmss_state;
983 struct sshbuf *copy = NULL, *decrypted = NULL;
984 struct sshcipher_ctx *ciphercontext = NULL;
985 const struct sshcipher *cipher = NULL;
986 u_char *key, *iv = NULL, *dp;
987 size_t keylen, ivlen, authlen, aadlen;
988 u_int blocksize, encrypted_len, index;
989 int r = SSH_ERR_INTERNAL_ERROR;
990
991 if (retp != NULL)
992 *retp = NULL;
993 if (state == NULL ||
994 state->enc_keyiv == NULL ||
995 state->enc_ciphername == NULL)
996 return SSH_ERR_INTERNAL_ERROR;
997 if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
998 r = SSH_ERR_INVALID_FORMAT;
999 goto out;
1000 }
1001 blocksize = cipher_blocksize(cipher);
1002 keylen = cipher_keylen(cipher);
1003 ivlen = cipher_ivlen(cipher);
1004 authlen = cipher_authlen(cipher);
1005 if (state->enc_keyiv_len != keylen + ivlen) {
1006 r = SSH_ERR_INTERNAL_ERROR;
1007 goto out;
1008 }
1009 key = state->enc_keyiv;
1010
1011 if ((copy = sshbuf_fromb(encoded)) == NULL ||
1012 (decrypted = sshbuf_new()) == NULL ||
1013 (iv = malloc(ivlen)) == NULL) {
1014 r = SSH_ERR_ALLOC_FAIL;
1015 goto out;
1016 }
1017
1018 /* check magic */
1019 if (sshbuf_len(encoded) < sizeof(XMSS_MAGIC) ||
1020 memcmp(sshbuf_ptr(encoded), XMSS_MAGIC, sizeof(XMSS_MAGIC))) {
1021 r = SSH_ERR_INVALID_FORMAT;
1022 goto out;
1023 }
1024 /* parse public portion */
1025 if ((r = sshbuf_consume(encoded, sizeof(XMSS_MAGIC))) != 0 ||
1026 (r = sshbuf_get_u32(encoded, &index)) != 0 ||
1027 (r = sshbuf_get_u32(encoded, &encrypted_len)) != 0)
1028 goto out;
1029
1030 /* check size of encrypted key blob */
1031 if (encrypted_len < blocksize || (encrypted_len % blocksize) != 0) {
1032 r = SSH_ERR_INVALID_FORMAT;
1033 goto out;
1034 }
1035 /* check that an appropriate amount of auth data is present */
1036 if (sshbuf_len(encoded) < authlen ||
1037 sshbuf_len(encoded) - authlen < encrypted_len) {
1038 r = SSH_ERR_INVALID_FORMAT;
1039 goto out;
1040 }
1041
1042 aadlen = sshbuf_len(copy) - sshbuf_len(encoded);
1043
1044 /* replace first 4 bytes of IV with index to ensure uniqueness */
1045 memcpy(iv, key + keylen, ivlen);
1046 POKE_U32(iv, index);
1047
1048 /* decrypt private state of key */
1049 if ((r = sshbuf_reserve(decrypted, aadlen + encrypted_len, &dp)) != 0 ||
1050 (r = cipher_init(&ciphercontext, cipher, key, keylen,
1051 iv, ivlen, 0)) != 0 ||
1052 (r = cipher_crypt(ciphercontext, 0, dp, sshbuf_ptr(copy),
1053 encrypted_len, aadlen, authlen)) != 0)
1054 goto out;
1055
1056 /* there should be no trailing data */
1057 if ((r = sshbuf_consume(encoded, encrypted_len + authlen)) != 0)
1058 goto out;
1059 if (sshbuf_len(encoded) != 0) {
1060 r = SSH_ERR_INVALID_FORMAT;
1061 goto out;
1062 }
1063
1064 /* remove AAD */
1065 if ((r = sshbuf_consume(decrypted, aadlen)) != 0)
1066 goto out;
1067 /* XXX encrypted includes unchecked padding */
1068
1069 /* success */
1070 r = 0;
1071 if (retp != NULL) {
1072 *retp = decrypted;
1073 decrypted = NULL;
1074 }
1075 out:
1076 cipher_free(ciphercontext);
1077 sshbuf_free(copy);
1078 sshbuf_free(decrypted);
1079 free(iv);
1080 return r;
1081 }
1082
1083 u_int32_t
sshkey_xmss_signatures_left(const struct sshkey * k)1084 sshkey_xmss_signatures_left(const struct sshkey *k)
1085 {
1086 struct ssh_xmss_state *state = k->xmss_state;
1087 u_int32_t idx;
1088
1089 if (sshkey_type_plain(k->type) == KEY_XMSS && state &&
1090 state->maxidx) {
1091 idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
1092 if (idx < state->maxidx)
1093 return state->maxidx - idx;
1094 }
1095 return 0;
1096 }
1097
1098 int
sshkey_xmss_enable_maxsign(struct sshkey * k,u_int32_t maxsign)1099 sshkey_xmss_enable_maxsign(struct sshkey *k, u_int32_t maxsign)
1100 {
1101 struct ssh_xmss_state *state = k->xmss_state;
1102
1103 if (sshkey_type_plain(k->type) != KEY_XMSS)
1104 return SSH_ERR_INVALID_ARGUMENT;
1105 if (maxsign == 0)
1106 return 0;
1107 if (state->idx + maxsign < state->idx)
1108 return SSH_ERR_INVALID_ARGUMENT;
1109 state->maxidx = state->idx + maxsign;
1110 return 0;
1111 }
1112 #endif /* WITH_XMSS */
1113