1package hrss 2 3import ( 4 "crypto/hmac" 5 "crypto/sha256" 6 "crypto/subtle" 7 "encoding/binary" 8 "io" 9 "math/bits" 10) 11 12const ( 13 PublicKeySize = modQBytes 14 CiphertextSize = modQBytes 15) 16 17const ( 18 N = 701 19 Q = 8192 20 mod3Bytes = 140 21 modQBytes = 1138 22) 23 24const ( 25 bitsPerWord = bits.UintSize 26 wordsPerPoly = (N + bitsPerWord - 1) / bitsPerWord 27 fullWordsPerPoly = N / bitsPerWord 28 bitsInLastWord = N % bitsPerWord 29) 30 31// poly3 represents a degree-N polynomial over GF(3). Each coefficient is 32// bitsliced across the |s| and |a| arrays, like this: 33// 34// s | a | value 35// ----------------- 36// 0 | 0 | 0 37// 0 | 1 | 1 38// 1 | 0 | 2 (aka -1) 39// 1 | 1 | <invalid> 40// 41// ('s' is for sign, and 'a' is just a letter.) 42// 43// Once bitsliced as such, the following circuits can be used to implement 44// addition and multiplication mod 3: 45// 46// (s3, a3) = (s1, a1) × (s2, a2) 47// s3 = (s2 ∧ a1) ⊕ (s1 ∧ a2) 48// a3 = (s1 ∧ s2) ⊕ (a1 ∧ a2) 49// 50// (s3, a3) = (s1, a1) + (s2, a2) 51// t1 = ~(s1 ∨ a1) 52// t2 = ~(s2 ∨ a2) 53// s3 = (a1 ∧ a2) ⊕ (t1 ∧ s2) ⊕ (t2 ∧ s1) 54// a3 = (s1 ∧ s2) ⊕ (t1 ∧ a2) ⊕ (t2 ∧ a1) 55// 56// Negating a value just involves swapping s and a. 57type poly3 struct { 58 s [wordsPerPoly]uint 59 a [wordsPerPoly]uint 60} 61 62func (p *poly3) trim() { 63 p.s[wordsPerPoly-1] &= (1 << bitsInLastWord) - 1 64 p.a[wordsPerPoly-1] &= (1 << bitsInLastWord) - 1 65} 66 67func (p *poly3) zero() { 68 for i := range p.a { 69 p.s[i] = 0 70 p.a[i] = 0 71 } 72} 73 74func (p *poly3) fromDiscrete(in *poly) { 75 var shift uint 76 s := p.s[:] 77 a := p.a[:] 78 s[0] = 0 79 a[0] = 0 80 81 for _, v := range in { 82 s[0] >>= 1 83 s[0] |= uint((v>>1)&1) << (bitsPerWord - 1) 84 a[0] >>= 1 85 a[0] |= uint(v&1) << (bitsPerWord - 1) 86 shift++ 87 if shift == bitsPerWord { 88 s = s[1:] 89 a = a[1:] 90 s[0] = 0 91 a[0] = 0 92 shift = 0 93 } 94 } 95 96 a[0] >>= bitsPerWord - shift 97 s[0] >>= bitsPerWord - shift 98} 99 100func (p *poly3) fromModQ(in *poly) int { 101 var shift uint 102 s := p.s[:] 103 a := p.a[:] 104 s[0] = 0 105 a[0] = 0 106 ok := 1 107 108 for _, v := range in { 109 vMod3, vOk := modQToMod3(v) 110 ok &= vOk 111 112 s[0] >>= 1 113 s[0] |= uint((vMod3>>1)&1) << (bitsPerWord - 1) 114 a[0] >>= 1 115 a[0] |= uint(vMod3&1) << (bitsPerWord - 1) 116 shift++ 117 if shift == bitsPerWord { 118 s = s[1:] 119 a = a[1:] 120 s[0] = 0 121 a[0] = 0 122 shift = 0 123 } 124 } 125 126 a[0] >>= bitsPerWord - shift 127 s[0] >>= bitsPerWord - shift 128 129 return ok 130} 131 132func (p *poly3) fromDiscreteMod3(in *poly) { 133 var shift uint 134 s := p.s[:] 135 a := p.a[:] 136 s[0] = 0 137 a[0] = 0 138 139 for _, v := range in { 140 // This duplicates the 13th bit upwards to the top of the 141 // uint16, essentially treating it as a sign bit and converting 142 // into a signed int16. The signed value is reduced mod 3, 143 // yeilding {-2, -1, 0, 1, 2}. 144 v = uint16((int16(v<<3)>>3)%3) & 7 145 146 // We want to map v thus: 147 // {-2, -1, 0, 1, 2} -> {1, 2, 0, 1, 2}. We take the bottom 148 // three bits and then the constants below, when shifted by 149 // those three bits, perform the required mapping. 150 s[0] >>= 1 151 s[0] |= (0xbc >> v) << (bitsPerWord - 1) 152 a[0] >>= 1 153 a[0] |= (0x7a >> v) << (bitsPerWord - 1) 154 shift++ 155 if shift == bitsPerWord { 156 s = s[1:] 157 a = a[1:] 158 s[0] = 0 159 a[0] = 0 160 shift = 0 161 } 162 } 163 164 a[0] >>= bitsPerWord - shift 165 s[0] >>= bitsPerWord - shift 166} 167 168func (p *poly3) marshal(out []byte) { 169 s := p.s[:] 170 a := p.a[:] 171 sw := s[0] 172 aw := a[0] 173 var shift int 174 175 for i := 0; i < 700; i += 5 { 176 acc, scale := 0, 1 177 for j := 0; j < 5; j++ { 178 v := int(aw&1) | int(sw&1)<<1 179 acc += scale * v 180 scale *= 3 181 182 shift++ 183 if shift == bitsPerWord { 184 s = s[1:] 185 a = a[1:] 186 sw = s[0] 187 aw = a[0] 188 shift = 0 189 } else { 190 sw >>= 1 191 aw >>= 1 192 } 193 } 194 195 out[0] = byte(acc) 196 out = out[1:] 197 } 198} 199 200func (p *poly) fromMod2(in *poly2) { 201 var shift uint 202 words := in[:] 203 word := words[0] 204 205 for i := range p { 206 p[i] = uint16(word & 1) 207 word >>= 1 208 shift++ 209 if shift == bitsPerWord { 210 words = words[1:] 211 word = words[0] 212 shift = 0 213 } 214 } 215} 216 217func (p *poly) fromMod3(in *poly3) { 218 var shift uint 219 s := in.s[:] 220 a := in.a[:] 221 sw := s[0] 222 aw := a[0] 223 224 for i := range p { 225 p[i] = uint16(aw&1 | (sw&1)<<1) 226 aw >>= 1 227 sw >>= 1 228 shift++ 229 if shift == bitsPerWord { 230 a = a[1:] 231 s = s[1:] 232 aw = a[0] 233 sw = s[0] 234 shift = 0 235 } 236 } 237} 238 239func (p *poly) fromMod3ToModQ(in *poly3) { 240 var shift uint 241 s := in.s[:] 242 a := in.a[:] 243 sw := s[0] 244 aw := a[0] 245 246 for i := range p { 247 p[i] = mod3ToModQ(uint16(aw&1 | (sw&1)<<1)) 248 aw >>= 1 249 sw >>= 1 250 shift++ 251 if shift == bitsPerWord { 252 a = a[1:] 253 s = s[1:] 254 aw = a[0] 255 sw = s[0] 256 shift = 0 257 } 258 } 259} 260 261func lsbToAll(v uint) uint { 262 return uint(int(v<<(bitsPerWord-1)) >> (bitsPerWord - 1)) 263} 264 265func (p *poly3) mulConst(ms, ma uint) { 266 ms = lsbToAll(ms) 267 ma = lsbToAll(ma) 268 269 for i := range p.a { 270 p.s[i], p.a[i] = (ma&p.s[i])^(ms&p.a[i]), (ma&p.a[i])^(ms&p.s[i]) 271 } 272} 273 274func cmovWords(out, in *[wordsPerPoly]uint, mov uint) { 275 for i := range out { 276 out[i] = (out[i] & ^mov) | (in[i] & mov) 277 } 278} 279 280func rotWords(out, in *[wordsPerPoly]uint, bits uint) { 281 start := bits / bitsPerWord 282 n := (N - bits) / bitsPerWord 283 284 for i := uint(0); i < n; i++ { 285 out[i] = in[start+i] 286 } 287 288 carry := in[wordsPerPoly-1] 289 290 for i := uint(0); i < start; i++ { 291 out[n+i] = carry | in[i]<<bitsInLastWord 292 carry = in[i] >> (bitsPerWord - bitsInLastWord) 293 } 294 295 out[wordsPerPoly-1] = carry 296} 297 298// rotBits right-rotates the bits in |in|. bits must be a non-zero power of two 299// and less than bitsPerWord. 300func rotBits(out, in *[wordsPerPoly]uint, bits uint) { 301 if (bits == 0 || (bits & (bits - 1)) != 0 || bits > bitsPerWord/2 || bitsInLastWord < bitsPerWord/2) { 302 panic("internal error"); 303 } 304 305 carry := in[wordsPerPoly-1] << (bitsPerWord - bits) 306 307 for i := wordsPerPoly - 2; i >= 0; i-- { 308 out[i] = carry | in[i]>>bits 309 carry = in[i] << (bitsPerWord - bits) 310 } 311 312 out[wordsPerPoly-1] = carry>>(bitsPerWord-bitsInLastWord) | in[wordsPerPoly-1]>>bits 313} 314 315func (p *poly3) rotWords(bits uint, in *poly3) { 316 rotWords(&p.s, &in.s, bits) 317 rotWords(&p.a, &in.a, bits) 318} 319 320func (p *poly3) rotBits(bits uint, in *poly3) { 321 rotBits(&p.s, &in.s, bits) 322 rotBits(&p.a, &in.a, bits) 323} 324 325func (p *poly3) cmov(in *poly3, mov uint) { 326 cmovWords(&p.s, &in.s, mov) 327 cmovWords(&p.a, &in.a, mov) 328} 329 330func (p *poly3) rot(bits uint) { 331 if bits > N { 332 panic("invalid") 333 } 334 var shifted poly3 335 336 shift := uint(9) 337 for ; (1 << shift) >= bitsPerWord; shift-- { 338 shifted.rotWords(1<<shift, p) 339 p.cmov(&shifted, lsbToAll(bits>>shift)) 340 } 341 for ; shift < 9; shift-- { 342 shifted.rotBits(1<<shift, p) 343 p.cmov(&shifted, lsbToAll(bits>>shift)) 344 } 345} 346 347func (p *poly3) fmadd(ms, ma uint, in *poly3) { 348 ms = lsbToAll(ms) 349 ma = lsbToAll(ma) 350 351 for i := range p.a { 352 products := (ma & in.s[i]) ^ (ms & in.a[i]) 353 producta := (ma & in.a[i]) ^ (ms & in.s[i]) 354 355 ns1Ana1 := ^p.s[i] & ^p.a[i] 356 ns2Ana2 := ^products & ^producta 357 358 p.s[i], p.a[i] = (p.a[i]&producta)^(ns1Ana1&products)^(p.s[i]&ns2Ana2), (p.s[i]&products)^(ns1Ana1&producta)^(p.a[i]&ns2Ana2) 359 } 360} 361 362func (p *poly3) modPhiN() { 363 factora := uint(int(p.s[wordsPerPoly-1]<<(bitsPerWord-bitsInLastWord)) >> (bitsPerWord - 1)) 364 factors := uint(int(p.a[wordsPerPoly-1]<<(bitsPerWord-bitsInLastWord)) >> (bitsPerWord - 1)) 365 ns2Ana2 := ^factors & ^factora 366 367 for i := range p.s { 368 ns1Ana1 := ^p.s[i] & ^p.a[i] 369 p.s[i], p.a[i] = (p.a[i]&factora)^(ns1Ana1&factors)^(p.s[i]&ns2Ana2), (p.s[i]&factors)^(ns1Ana1&factora)^(p.a[i]&ns2Ana2) 370 } 371} 372 373func (p *poly3) cswap(other *poly3, swap uint) { 374 for i := range p.s { 375 sums := swap & (p.s[i] ^ other.s[i]) 376 p.s[i] ^= sums 377 other.s[i] ^= sums 378 379 suma := swap & (p.a[i] ^ other.a[i]) 380 p.a[i] ^= suma 381 other.a[i] ^= suma 382 } 383} 384 385func (p *poly3) mulx() { 386 carrys := (p.s[wordsPerPoly-1] >> (bitsInLastWord - 1)) & 1 387 carrya := (p.a[wordsPerPoly-1] >> (bitsInLastWord - 1)) & 1 388 389 for i := range p.s { 390 outCarrys := p.s[i] >> (bitsPerWord - 1) 391 outCarrya := p.a[i] >> (bitsPerWord - 1) 392 p.s[i] <<= 1 393 p.a[i] <<= 1 394 p.s[i] |= carrys 395 p.a[i] |= carrya 396 carrys = outCarrys 397 carrya = outCarrya 398 } 399} 400 401func (p *poly3) divx() { 402 var carrys, carrya uint 403 404 for i := len(p.s) - 1; i >= 0; i-- { 405 outCarrys := p.s[i] & 1 406 outCarrya := p.a[i] & 1 407 p.s[i] >>= 1 408 p.a[i] >>= 1 409 p.s[i] |= carrys << (bitsPerWord - 1) 410 p.a[i] |= carrya << (bitsPerWord - 1) 411 carrys = outCarrys 412 carrya = outCarrya 413 } 414} 415 416type poly2 [wordsPerPoly]uint 417 418func (p *poly2) fromDiscrete(in *poly) { 419 var shift uint 420 words := p[:] 421 words[0] = 0 422 423 for _, v := range in { 424 words[0] >>= 1 425 words[0] |= uint(v&1) << (bitsPerWord - 1) 426 shift++ 427 if shift == bitsPerWord { 428 words = words[1:] 429 words[0] = 0 430 shift = 0 431 } 432 } 433 434 words[0] >>= bitsPerWord - shift 435} 436 437func (p *poly2) setPhiN() { 438 for i := range p { 439 p[i] = ^uint(0) 440 } 441 p[wordsPerPoly-1] &= (1 << bitsInLastWord) - 1 442} 443 444func (p *poly2) cswap(other *poly2, swap uint) { 445 for i := range p { 446 sum := swap & (p[i] ^ other[i]) 447 p[i] ^= sum 448 other[i] ^= sum 449 } 450} 451 452func (p *poly2) fmadd(m uint, in *poly2) { 453 m = ^(m - 1) 454 455 for i := range p { 456 p[i] ^= in[i] & m 457 } 458} 459 460func (p *poly2) lshift1() { 461 var carry uint 462 for i := range p { 463 nextCarry := p[i] >> (bitsPerWord - 1) 464 p[i] <<= 1 465 p[i] |= carry 466 carry = nextCarry 467 } 468} 469 470func (p *poly2) rshift1() { 471 var carry uint 472 for i := len(p) - 1; i >= 0; i-- { 473 nextCarry := p[i] & 1 474 p[i] >>= 1 475 p[i] |= carry << (bitsPerWord - 1) 476 carry = nextCarry 477 } 478} 479 480func (p *poly2) rot(bits uint) { 481 if bits > N { 482 panic("invalid") 483 } 484 var shifted [wordsPerPoly]uint 485 out := (*[wordsPerPoly]uint)(p) 486 487 shift := uint(9) 488 for ; (1 << shift) >= bitsPerWord; shift-- { 489 rotWords(&shifted, out, 1<<shift) 490 cmovWords(out, &shifted, lsbToAll(bits>>shift)) 491 } 492 for ; shift < 9; shift-- { 493 rotBits(&shifted, out, 1<<shift) 494 cmovWords(out, &shifted, lsbToAll(bits>>shift)) 495 } 496} 497 498type poly [N]uint16 499 500func (in *poly) marshal(out []byte) { 501 p := in[:] 502 503 for len(p) >= 8 { 504 out[0] = byte(p[0]) 505 out[1] = byte(p[0]>>8) | byte((p[1]&0x07)<<5) 506 out[2] = byte(p[1] >> 3) 507 out[3] = byte(p[1]>>11) | byte((p[2]&0x3f)<<2) 508 out[4] = byte(p[2]>>6) | byte((p[3]&0x01)<<7) 509 out[5] = byte(p[3] >> 1) 510 out[6] = byte(p[3]>>9) | byte((p[4]&0x0f)<<4) 511 out[7] = byte(p[4] >> 4) 512 out[8] = byte(p[4]>>12) | byte((p[5]&0x7f)<<1) 513 out[9] = byte(p[5]>>7) | byte((p[6]&0x03)<<6) 514 out[10] = byte(p[6] >> 2) 515 out[11] = byte(p[6]>>10) | byte((p[7]&0x1f)<<3) 516 out[12] = byte(p[7] >> 5) 517 518 p = p[8:] 519 out = out[13:] 520 } 521 522 // There are four remaining values. 523 out[0] = byte(p[0]) 524 out[1] = byte(p[0]>>8) | byte((p[1]&0x07)<<5) 525 out[2] = byte(p[1] >> 3) 526 out[3] = byte(p[1]>>11) | byte((p[2]&0x3f)<<2) 527 out[4] = byte(p[2]>>6) | byte((p[3]&0x01)<<7) 528 out[5] = byte(p[3] >> 1) 529 out[6] = byte(p[3] >> 9) 530} 531 532func (out *poly) unmarshal(in []byte) bool { 533 p := out[:] 534 for i := 0; i < 87; i++ { 535 p[0] = uint16(in[0]) | uint16(in[1]&0x1f)<<8 536 p[1] = uint16(in[1]>>5) | uint16(in[2])<<3 | uint16(in[3]&3)<<11 537 p[2] = uint16(in[3]>>2) | uint16(in[4]&0x7f)<<6 538 p[3] = uint16(in[4]>>7) | uint16(in[5])<<1 | uint16(in[6]&0xf)<<9 539 p[4] = uint16(in[6]>>4) | uint16(in[7])<<4 | uint16(in[8]&1)<<12 540 p[5] = uint16(in[8]>>1) | uint16(in[9]&0x3f)<<7 541 p[6] = uint16(in[9]>>6) | uint16(in[10])<<2 | uint16(in[11]&7)<<10 542 p[7] = uint16(in[11]>>3) | uint16(in[12])<<5 543 544 p = p[8:] 545 in = in[13:] 546 } 547 548 // There are four coefficients left over 549 p[0] = uint16(in[0]) | uint16(in[1]&0x1f)<<8 550 p[1] = uint16(in[1]>>5) | uint16(in[2])<<3 | uint16(in[3]&3)<<11 551 p[2] = uint16(in[3]>>2) | uint16(in[4]&0x7f)<<6 552 p[3] = uint16(in[4]>>7) | uint16(in[5])<<1 | uint16(in[6]&0xf)<<9 553 554 if in[6]&0xf0 != 0 { 555 return false 556 } 557 558 out[N-1] = 0 559 var top int 560 for _, v := range out { 561 top += int(v) 562 } 563 564 out[N-1] = uint16(-top) % Q 565 return true 566} 567 568func (in *poly) marshalS3(out []byte) { 569 p := in[:] 570 for len(p) >= 5 { 571 out[0] = byte(p[0] + p[1]*3 + p[2]*9 + p[3]*27 + p[4]*81) 572 out = out[1:] 573 p = p[5:] 574 } 575} 576 577func (out *poly) unmarshalS3(in []byte) bool { 578 p := out[:] 579 for i := 0; i < 140; i++ { 580 c := in[0] 581 if c >= 243 { 582 return false 583 } 584 p[0] = uint16(c % 3) 585 p[1] = uint16((c / 3) % 3) 586 p[2] = uint16((c / 9) % 3) 587 p[3] = uint16((c / 27) % 3) 588 p[4] = uint16((c / 81) % 3) 589 590 p = p[5:] 591 in = in[1:] 592 } 593 594 out[N-1] = 0 595 return true 596} 597 598func (p *poly) modPhiN() { 599 for i := range p { 600 p[i] = (p[i] + Q - p[N-1]) % Q 601 } 602} 603 604func (out *poly) shortSample(in []byte) { 605 // b a result 606 // 00 00 00 607 // 00 01 01 608 // 00 10 10 609 // 00 11 11 610 // 01 00 10 611 // 01 01 00 612 // 01 10 01 613 // 01 11 11 614 // 10 00 01 615 // 10 01 10 616 // 10 10 00 617 // 10 11 11 618 // 11 00 11 619 // 11 01 11 620 // 11 10 11 621 // 11 11 11 622 623 // 1111 1111 1100 1001 1101 0010 1110 0100 624 // f f c 9 d 2 e 4 625 const lookup = uint32(0xffc9d2e4) 626 627 p := out[:] 628 for i := 0; i < 87; i++ { 629 v := binary.LittleEndian.Uint32(in) 630 v2 := (v & 0x55555555) + ((v >> 1) & 0x55555555) 631 for j := 0; j < 8; j++ { 632 p[j] = uint16(lookup >> ((v2 & 15) << 1) & 3) 633 v2 >>= 4 634 } 635 p = p[8:] 636 in = in[4:] 637 } 638 639 // There are four values remaining. 640 v := binary.LittleEndian.Uint32(in) 641 v2 := (v & 0x55555555) + ((v >> 1) & 0x55555555) 642 for j := 0; j < 4; j++ { 643 p[j] = uint16(lookup >> ((v2 & 15) << 1) & 3) 644 v2 >>= 4 645 } 646 647 out[N-1] = 0 648} 649 650func (out *poly) shortSamplePlus(in []byte) { 651 out.shortSample(in) 652 653 var sum uint16 654 for i := 0; i < N-1; i++ { 655 sum += mod3ResultToModQ(out[i] * out[i+1]) 656 } 657 658 scale := 1 + (1 & (sum >> 12)) 659 for i := 0; i < len(out); i += 2 { 660 out[i] = (out[i] * scale) % 3 661 } 662} 663 664func mul(out, scratch, a, b []uint16) { 665 const schoolbookLimit = 32 666 if len(a) < schoolbookLimit { 667 for i := 0; i < len(a)*2; i++ { 668 out[i] = 0 669 } 670 for i := range a { 671 for j := range b { 672 out[i+j] += a[i] * b[j] 673 } 674 } 675 return 676 } 677 678 lowLen := len(a) / 2 679 highLen := len(a) - lowLen 680 aLow, aHigh := a[:lowLen], a[lowLen:] 681 bLow, bHigh := b[:lowLen], b[lowLen:] 682 683 for i := 0; i < lowLen; i++ { 684 out[i] = aHigh[i] + aLow[i] 685 } 686 if highLen != lowLen { 687 out[lowLen] = aHigh[lowLen] 688 } 689 690 for i := 0; i < lowLen; i++ { 691 out[highLen+i] = bHigh[i] + bLow[i] 692 } 693 if highLen != lowLen { 694 out[highLen+lowLen] = bHigh[lowLen] 695 } 696 697 mul(scratch, scratch[2*highLen:], out[:highLen], out[highLen:highLen*2]) 698 mul(out[lowLen*2:], scratch[2*highLen:], aHigh, bHigh) 699 mul(out, scratch[2*highLen:], aLow, bLow) 700 701 for i := 0; i < lowLen*2; i++ { 702 scratch[i] -= out[i] + out[lowLen*2+i] 703 } 704 if lowLen != highLen { 705 scratch[lowLen*2] -= out[lowLen*4] 706 } 707 708 for i := 0; i < 2*highLen; i++ { 709 out[lowLen+i] += scratch[i] 710 } 711} 712 713func (out *poly) mul(a, b *poly) { 714 var prod, scratch [2 * N]uint16 715 mul(prod[:], scratch[:], a[:], b[:]) 716 for i := range out { 717 out[i] = (prod[i] + prod[i+N]) % Q 718 } 719} 720 721func (p3 *poly3) mulMod3(x, y *poly3) { 722 // (^n - 1) is a multiple of Φ(N) so we can work mod (^n - 1) here and 723 // (reduce mod Φ(N) afterwards. 724 x3 := *x 725 y3 := *y 726 s := x3.s[:] 727 a := x3.a[:] 728 sw := s[0] 729 aw := a[0] 730 p3.zero() 731 var shift uint 732 for i := 0; i < N; i++ { 733 p3.fmadd(sw, aw, &y3) 734 sw >>= 1 735 aw >>= 1 736 shift++ 737 if shift == bitsPerWord { 738 s = s[1:] 739 a = a[1:] 740 sw = s[0] 741 aw = a[0] 742 shift = 0 743 } 744 y3.mulx() 745 } 746 p3.modPhiN() 747} 748 749// mod3ToModQ maps {0, 1, 2, 3} to {0, 1, Q-1, 0xffff} 750// The case of n == 3 should never happen but is included so that modQToMod3 751// can easily catch invalid inputs. 752func mod3ToModQ(n uint16) uint16 { 753 return uint16(uint64(0xffff1fff00010000) >> (16 * n)) 754} 755 756// modQToMod3 maps {0, 1, Q-1} to {(0, 0), (0, 1), (1, 0)} and also returns an int 757// which is one if the input is in range and zero otherwise. 758func modQToMod3(n uint16) (uint16, int) { 759 result := (n&3 - (n>>1)&1) 760 return result, subtle.ConstantTimeEq(int32(mod3ToModQ(result)), int32(n)) 761} 762 763// mod3ResultToModQ maps {0, 1, 2, 4} to {0, 1, Q-1, 1} 764func mod3ResultToModQ(n uint16) uint16 { 765 return ((((uint16(0x13) >> n) & 1) - 1) & 0x1fff) | ((uint16(0x12) >> n) & 1) 766 //shift := (uint(0x324) >> (2 * n)) & 3 767 //return uint16(uint64(0x00011fff00010000) >> (16 * shift)) 768} 769 770// mulXMinus1 sets out to a×( - 1) mod (^n - 1) 771func (out *poly) mulXMinus1() { 772 // Multiplying by ( - 1) means negating each coefficient and adding in 773 // the value of the previous one. 774 origOut700 := out[700] 775 776 for i := N - 1; i > 0; i-- { 777 out[i] = (Q - out[i] + out[i-1]) % Q 778 } 779 out[0] = (Q - out[0] + origOut700) % Q 780} 781 782func (out *poly) lift(a *poly) { 783 // We wish to calculate a/(-1) mod Φ(N) over GF(3), where Φ(N) is the 784 // Nth cyclotomic polynomial, i.e. 1 + + … + ^700 (since N is prime). 785 786 // 1/(-1) has a fairly basic structure that we can exploit to speed this up: 787 // 788 // R.<x> = PolynomialRing(GF(3)…) 789 // inv = R.cyclotomic_polynomial(1).inverse_mod(R.cyclotomic_polynomial(n)) 790 // list(inv)[:15] 791 // [1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2] 792 // 793 // This three-element pattern of coefficients repeats for the whole 794 // polynomial. 795 // 796 // Next define the overbar operator such that z̅ = z[0] + 797 // reverse(z[1:]). (Index zero of a polynomial here is the coefficient 798 // of the constant term. So index one is the coefficient of and so 799 // on.) 800 // 801 // A less odd way to define this is to see that z̅ negates the indexes, 802 // so z̅[0] = z[-0], z̅[1] = z[-1] and so on. 803 // 804 // The use of z̅ is that, when working mod (^701 - 1), vz[0] = <v, 805 // z̅>, vz[1] = <v, z̅>, …. (Where <a, b> is the inner product: the sum 806 // of the point-wise products.) Although we calculated the inverse mod 807 // Φ(N), we can work mod (^N - 1) and reduce mod Φ(N) at the end. 808 // (That's because (^N - 1) is a multiple of Φ(N).) 809 // 810 // When working mod (^N - 1), multiplication by is a right-rotation 811 // of the list of coefficients. 812 // 813 // Thus we can consider what the pattern of z̅, z̅, ^2z̅, … looks like: 814 // 815 // def reverse(xs): 816 // suffix = list(xs[1:]) 817 // suffix.reverse() 818 // return [xs[0]] + suffix 819 // 820 // def rotate(xs): 821 // return [xs[-1]] + xs[:-1] 822 // 823 // zoverbar = reverse(list(inv) + [0]) 824 // xzoverbar = rotate(reverse(list(inv) + [0])) 825 // x2zoverbar = rotate(rotate(reverse(list(inv) + [0]))) 826 // 827 // zoverbar[:15] 828 // [1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1] 829 // xzoverbar[:15] 830 // [0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0] 831 // x2zoverbar[:15] 832 // [2, 0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2] 833 // 834 // (For a formula for z̅, see lemma two of appendix B.) 835 // 836 // After the first three elements have been taken care of, all then have 837 // a repeating three-element cycle. The next value (^3z̅) involves 838 // three rotations of the first pattern, thus the three-element cycle 839 // lines up. However, the discontinuity in the first three elements 840 // obviously moves to a different position. Consider the difference 841 // between ^3z̅ and z̅: 842 // 843 // [x-y for (x,y) in zip(zoverbar, x3zoverbar)][:15] 844 // [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 845 // 846 // This pattern of differences is the same for all elements, although it 847 // obviously moves right with the rotations. 848 // 849 // From this, we reach algorithm eight of appendix B. 850 851 // Handle the first three elements of the inner products. 852 out[0] = a[0] + a[2] 853 out[1] = a[1] 854 out[2] = 2*a[0] + a[2] 855 856 // Use the repeating pattern to complete the first three inner products. 857 for i := 3; i < 699; i += 3 { 858 out[0] += 2*a[i] + a[i+2] 859 out[1] += a[i] + 2*a[i+1] 860 out[2] += a[i+1] + 2*a[i+2] 861 } 862 863 // Handle the fact that the three-element pattern doesn't fill the 864 // polynomial exactly (since 701 isn't a multiple of three). 865 out[2] += a[700] 866 out[0] += 2 * a[699] 867 out[1] += a[699] + 2*a[700] 868 869 out[0] = out[0] % 3 870 out[1] = out[1] % 3 871 out[2] = out[2] % 3 872 873 // Calculate the remaining inner products by taking advantage of the 874 // fact that the pattern repeats every three cycles and the pattern of 875 // differences is moves with the rotation. 876 for i := 3; i < N; i++ { 877 // Add twice something is the same as subtracting when working 878 // mod 3. Doing it this way avoids underflow. Underflow is bad 879 // because "% 3" doesn't work correctly for negative numbers 880 // here since underflow will wrap to 2^16-1 and 2^16 isn't a 881 // multiple of three. 882 out[i] = (out[i-3] + 2*(a[i-2]+a[i-1]+a[i])) % 3 883 } 884 885 // Reduce mod Φ(N) by subtracting a multiple of out[700] from every 886 // element and convert to mod Q. (See above about adding twice as 887 // subtraction.) 888 v := out[700] * 2 889 for i := range out { 890 out[i] = mod3ToModQ((out[i] + v) % 3) 891 } 892 893 out.mulXMinus1() 894} 895 896func (a *poly) cswap(b *poly, swap uint16) { 897 for i := range a { 898 sum := swap & (a[i] ^ b[i]) 899 a[i] ^= sum 900 b[i] ^= sum 901 } 902} 903 904func lt(a, b uint) uint { 905 if a < b { 906 return ^uint(0) 907 } 908 return 0 909} 910 911func bsMul(s1, a1, s2, a2 uint) (s3, a3 uint) { 912 s3 = (a1 & s2) ^ (s1 & a2) 913 a3 = (a1 & a2) ^ (s1 & s2) 914 return 915} 916 917func (out *poly3) invertMod3(in *poly3) { 918 // This algorithm follows algorithm 10 in the paper. (Although note that 919 // the paper appears to have a bug: k should start at zero, not one.) 920 // The best explanation for why it works is in the "Why it works" 921 // section of 922 // https://assets.onboardsecurity.com/static/downloads/NTRU/resources/NTRUTech014.pdf. 923 var k uint 924 degF, degG := uint(N-1), uint(N-1) 925 926 var b, c, g poly3 927 f := *in 928 929 for i := range g.a { 930 g.a[i] = ^uint(0) 931 } 932 933 b.a[0] = 1 934 935 var f0s, f0a uint 936 stillGoing := ^uint(0) 937 for i := 0; i < 2*(N-1)-1; i++ { 938 ss, sa := bsMul(f.s[0], f.a[0], g.s[0], g.a[0]) 939 ss, sa = sa&stillGoing&1, ss&stillGoing&1 940 shouldSwap := ^uint(int((ss|sa)-1)>>(bitsPerWord-1)) & lt(degF, degG) 941 f.cswap(&g, shouldSwap) 942 b.cswap(&c, shouldSwap) 943 degF, degG = (degG&shouldSwap)|(degF & ^shouldSwap), (degF&shouldSwap)|(degG&^shouldSwap) 944 f.fmadd(ss, sa, &g) 945 b.fmadd(ss, sa, &c) 946 947 f.divx() 948 f.s[wordsPerPoly-1] &= ((1 << bitsInLastWord) - 1) >> 1 949 f.a[wordsPerPoly-1] &= ((1 << bitsInLastWord) - 1) >> 1 950 c.mulx() 951 c.s[0] &= ^uint(1) 952 c.a[0] &= ^uint(1) 953 954 degF-- 955 k += 1 & stillGoing 956 f0s = (stillGoing & f.s[0]) | (^stillGoing & f0s) 957 f0a = (stillGoing & f.a[0]) | (^stillGoing & f0a) 958 stillGoing = ^uint(int(degF-1) >> (bitsPerWord - 1)) 959 } 960 961 k -= N & lt(N, k) 962 *out = b 963 out.rot(k) 964 out.mulConst(f0s, f0a) 965 out.modPhiN() 966} 967 968func (out *poly) invertMod2(a *poly) { 969 // This algorithm follows mix of algorithm 10 in the paper and the first 970 // page of the PDF linked below. (Although note that the paper appears 971 // to have a bug: k should start at zero, not one.) The best explanation 972 // for why it works is in the "Why it works" section of 973 // https://assets.onboardsecurity.com/static/downloads/NTRU/resources/NTRUTech014.pdf. 974 var k uint 975 degF, degG := uint(N-1), uint(N-1) 976 977 var f poly2 978 f.fromDiscrete(a) 979 var b, c, g poly2 980 g.setPhiN() 981 b[0] = 1 982 983 stillGoing := ^uint(0) 984 for i := 0; i < 2*(N-1)-1; i++ { 985 s := uint(f[0]&1) & stillGoing 986 shouldSwap := ^(s - 1) & lt(degF, degG) 987 f.cswap(&g, shouldSwap) 988 b.cswap(&c, shouldSwap) 989 degF, degG = (degG&shouldSwap)|(degF & ^shouldSwap), (degF&shouldSwap)|(degG&^shouldSwap) 990 f.fmadd(s, &g) 991 b.fmadd(s, &c) 992 993 f.rshift1() 994 c.lshift1() 995 996 degF-- 997 k += 1 & stillGoing 998 stillGoing = ^uint(int(degF-1) >> (bitsPerWord - 1)) 999 } 1000 1001 k -= N & lt(N, k) 1002 b.rot(k) 1003 out.fromMod2(&b) 1004} 1005 1006func (out *poly) invert(origA *poly) { 1007 // Inversion mod Q, which is done based on the result of inverting mod 1008 // 2. See the NTRU paper, page three. 1009 var a, tmp, tmp2, b poly 1010 b.invertMod2(origA) 1011 1012 // Negate a. 1013 for i := range a { 1014 a[i] = Q - origA[i] 1015 } 1016 1017 // We are working mod Q=2**13 and we need to iterate ceil(log_2(13)) 1018 // times, which is four. 1019 for i := 0; i < 4; i++ { 1020 tmp.mul(&a, &b) 1021 tmp[0] += 2 1022 tmp2.mul(&b, &tmp) 1023 b = tmp2 1024 } 1025 1026 *out = b 1027} 1028 1029type PublicKey struct { 1030 h poly 1031} 1032 1033func ParsePublicKey(in []byte) (*PublicKey, bool) { 1034 ret := new(PublicKey) 1035 if !ret.h.unmarshal(in) { 1036 return nil, false 1037 } 1038 return ret, true 1039} 1040 1041func (pub *PublicKey) Marshal() []byte { 1042 ret := make([]byte, modQBytes) 1043 pub.h.marshal(ret) 1044 return ret 1045} 1046 1047func (pub *PublicKey) Encap(rand io.Reader) (ciphertext []byte, sharedKey []byte) { 1048 var randBytes [352 + 352]byte 1049 if _, err := io.ReadFull(rand, randBytes[:]); err != nil { 1050 panic("rand failed") 1051 } 1052 1053 var m, r poly 1054 m.shortSample(randBytes[:352]) 1055 r.shortSample(randBytes[352:]) 1056 1057 var mBytes, rBytes [mod3Bytes]byte 1058 m.marshalS3(mBytes[:]) 1059 r.marshalS3(rBytes[:]) 1060 1061 ciphertext = pub.owf(&m, &r) 1062 1063 h := sha256.New() 1064 h.Write([]byte("shared key\x00")) 1065 h.Write(mBytes[:]) 1066 h.Write(rBytes[:]) 1067 h.Write(ciphertext) 1068 sharedKey = h.Sum(nil) 1069 1070 return ciphertext, sharedKey 1071} 1072 1073func (pub *PublicKey) owf(m, r *poly) []byte { 1074 for i := range r { 1075 r[i] = mod3ToModQ(r[i]) 1076 } 1077 1078 var mq poly 1079 mq.lift(m) 1080 1081 var e poly 1082 e.mul(r, &pub.h) 1083 for i := range e { 1084 e[i] = (e[i] + mq[i]) % Q 1085 } 1086 1087 ret := make([]byte, modQBytes) 1088 e.marshal(ret[:]) 1089 return ret 1090} 1091 1092type PrivateKey struct { 1093 PublicKey 1094 f, fp poly3 1095 hInv poly 1096 hmacKey [32]byte 1097} 1098 1099func (priv *PrivateKey) Marshal() []byte { 1100 var ret [2*mod3Bytes + modQBytes]byte 1101 priv.f.marshal(ret[:]) 1102 priv.fp.marshal(ret[mod3Bytes:]) 1103 priv.h.marshal(ret[2*mod3Bytes:]) 1104 return ret[:] 1105} 1106 1107func (priv *PrivateKey) Decap(ciphertext []byte) (sharedKey []byte, ok bool) { 1108 if len(ciphertext) != modQBytes { 1109 return nil, false 1110 } 1111 1112 var e poly 1113 if !e.unmarshal(ciphertext) { 1114 return nil, false 1115 } 1116 1117 var f poly 1118 f.fromMod3ToModQ(&priv.f) 1119 1120 var v1, m poly 1121 v1.mul(&e, &f) 1122 1123 var v13 poly3 1124 v13.fromDiscreteMod3(&v1) 1125 // Note: v13 is not reduced mod phi(n). 1126 1127 var m3 poly3 1128 m3.mulMod3(&v13, &priv.fp) 1129 m3.modPhiN() 1130 m.fromMod3(&m3) 1131 1132 var mLift, delta poly 1133 mLift.lift(&m) 1134 for i := range delta { 1135 delta[i] = (e[i] - mLift[i] + Q) % Q 1136 } 1137 delta.mul(&delta, &priv.hInv) 1138 delta.modPhiN() 1139 1140 var r poly3 1141 allOk := r.fromModQ(&delta) 1142 1143 var mBytes, rBytes [mod3Bytes]byte 1144 m.marshalS3(mBytes[:]) 1145 r.marshal(rBytes[:]) 1146 1147 var rPoly poly 1148 rPoly.fromMod3(&r) 1149 expectedCiphertext := priv.PublicKey.owf(&m, &rPoly) 1150 1151 allOk &= subtle.ConstantTimeCompare(ciphertext, expectedCiphertext) 1152 1153 hmacHash := hmac.New(sha256.New, priv.hmacKey[:]) 1154 hmacHash.Write(ciphertext) 1155 hmacDigest := hmacHash.Sum(nil) 1156 1157 h := sha256.New() 1158 h.Write([]byte("shared key\x00")) 1159 h.Write(mBytes[:]) 1160 h.Write(rBytes[:]) 1161 h.Write(ciphertext) 1162 sharedKey = h.Sum(nil) 1163 1164 mask := uint8(allOk - 1) 1165 for i := range sharedKey { 1166 sharedKey[i] = (sharedKey[i] & ^mask) | (hmacDigest[i] & mask) 1167 } 1168 1169 return sharedKey, true 1170} 1171 1172func GenerateKey(rand io.Reader) PrivateKey { 1173 var randBytes [352 + 352]byte 1174 if _, err := io.ReadFull(rand, randBytes[:]); err != nil { 1175 panic("rand failed") 1176 } 1177 1178 var f poly 1179 f.shortSamplePlus(randBytes[:352]) 1180 var priv PrivateKey 1181 priv.f.fromDiscrete(&f) 1182 priv.fp.invertMod3(&priv.f) 1183 1184 var g poly 1185 g.shortSamplePlus(randBytes[352:]) 1186 1187 var pgPhi1 poly 1188 for i := range g { 1189 pgPhi1[i] = mod3ToModQ(g[i]) 1190 } 1191 for i := range pgPhi1 { 1192 pgPhi1[i] = (pgPhi1[i] * 3) % Q 1193 } 1194 pgPhi1.mulXMinus1() 1195 1196 var fModQ poly 1197 fModQ.fromMod3ToModQ(&priv.f) 1198 1199 var pfgPhi1 poly 1200 pfgPhi1.mul(&fModQ, &pgPhi1) 1201 1202 var i poly 1203 i.invert(&pfgPhi1) 1204 1205 priv.h.mul(&i, &pgPhi1) 1206 priv.h.mul(&priv.h, &pgPhi1) 1207 1208 priv.hInv.mul(&i, &fModQ) 1209 priv.hInv.mul(&priv.hInv, &fModQ) 1210 1211 return priv 1212} 1213