1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package android.net.cts; 18 19 import static android.system.OsConstants.IPPROTO_IPV6; 20 import static android.system.OsConstants.IPPROTO_UDP; 21 22 import android.util.ArraySet; 23 24 import com.android.internal.net.ipsec.ike.crypto.AesXCbcImpl; 25 26 import java.net.Inet4Address; 27 import java.net.Inet6Address; 28 import java.net.InetAddress; 29 import java.nio.ByteBuffer; 30 import java.nio.ShortBuffer; 31 import java.security.GeneralSecurityException; 32 import java.security.SecureRandom; 33 import java.util.Arrays; 34 import java.util.Set; 35 36 import javax.crypto.Cipher; 37 import javax.crypto.Mac; 38 import javax.crypto.spec.IvParameterSpec; 39 import javax.crypto.spec.SecretKeySpec; 40 41 public class PacketUtils { 42 private static final String TAG = PacketUtils.class.getSimpleName(); 43 44 private static final int DATA_BUFFER_LEN = 4096; 45 46 static final int IP4_HDRLEN = 20; 47 static final int IP6_HDRLEN = 40; 48 static final int UDP_HDRLEN = 8; 49 static final int TCP_HDRLEN = 20; 50 static final int TCP_HDRLEN_WITH_TIMESTAMP_OPT = TCP_HDRLEN + 12; 51 static final int ESP_HDRLEN = 8; 52 static final int ESP_BLK_SIZE = 4; // ESP has to be 4-byte aligned 53 static final int ESP_TRAILER_LEN = 2; 54 55 // Not defined in OsConstants 56 static final int IPPROTO_IPV4 = 4; 57 static final int IPPROTO_ESP = 50; 58 59 // Encryption parameters 60 static final int AES_CBC_IV_LEN = 16; 61 static final int AES_CBC_BLK_SIZE = 16; 62 static final int AES_CTR_SALT_LEN = 4; 63 64 static final int AES_CTR_KEY_LEN_20 = 20; 65 static final int AES_CTR_KEY_LEN_28 = 28; 66 static final int AES_CTR_KEY_LEN_36 = 36; 67 static final int AES_CTR_BLK_SIZE = ESP_BLK_SIZE; 68 static final int AES_CTR_IV_LEN = 8; 69 70 // AEAD parameters 71 static final int AES_GCM_IV_LEN = 8; 72 static final int AES_GCM_BLK_SIZE = 4; 73 static final int CHACHA20_POLY1305_KEY_LEN = 36; 74 static final int CHACHA20_POLY1305_BLK_SIZE = ESP_BLK_SIZE; 75 static final int CHACHA20_POLY1305_IV_LEN = 8; 76 static final int CHACHA20_POLY1305_SALT_LEN = 4; 77 static final int CHACHA20_POLY1305_ICV_LEN = 16; 78 79 // Authentication parameters 80 static final int HMAC_SHA256_ICV_LEN = 16; 81 static final int HMAC_SHA512_KEY_LEN = 64; 82 static final int HMAC_SHA512_ICV_LEN = 32; 83 static final int AES_XCBC_KEY_LEN = 16; 84 static final int AES_XCBC_ICV_LEN = 12; 85 static final int AES_CMAC_KEY_LEN = 16; 86 static final int AES_CMAC_ICV_LEN = 12; 87 88 // Block counter field should be 32 bits and starts from value one as per RFC 3686 89 static final byte[] AES_CTR_INITIAL_COUNTER = new byte[] {0x00, 0x00, 0x00, 0x01}; 90 91 // Encryption algorithms 92 static final String AES = "AES"; 93 static final String AES_CBC = "AES/CBC/NoPadding"; 94 static final String AES_CTR = "AES/CTR/NoPadding"; 95 96 // AEAD algorithms 97 static final String CHACHA20_POLY1305 = "ChaCha20/Poly1305/NoPadding"; 98 99 // Authentication algorithms 100 static final String HMAC_MD5 = "HmacMD5"; 101 static final String HMAC_SHA1 = "HmacSHA1"; 102 static final String HMAC_SHA_256 = "HmacSHA256"; 103 static final String HMAC_SHA_384 = "HmacSHA384"; 104 static final String HMAC_SHA_512 = "HmacSHA512"; 105 static final String AES_CMAC = "AESCMAC"; 106 static final String AES_XCBC = "AesXCbc"; 107 108 public interface Payload { getPacketBytes(IpHeader header)109 byte[] getPacketBytes(IpHeader header) throws Exception; 110 addPacketBytes(IpHeader header, ByteBuffer resultBuffer)111 void addPacketBytes(IpHeader header, ByteBuffer resultBuffer) throws Exception; 112 length()113 short length(); 114 getProtocolId()115 int getProtocolId(); 116 } 117 118 public abstract static class IpHeader { 119 120 public final byte proto; 121 public final InetAddress srcAddr; 122 public final InetAddress dstAddr; 123 public final Payload payload; 124 IpHeader(int proto, InetAddress src, InetAddress dst, Payload payload)125 public IpHeader(int proto, InetAddress src, InetAddress dst, Payload payload) { 126 this.proto = (byte) proto; 127 this.srcAddr = src; 128 this.dstAddr = dst; 129 this.payload = payload; 130 } 131 getPacketBytes()132 public abstract byte[] getPacketBytes() throws Exception; 133 getProtocolId()134 public abstract int getProtocolId(); 135 } 136 137 public static class Ip4Header extends IpHeader { 138 private short checksum; 139 Ip4Header(int proto, Inet4Address src, Inet4Address dst, Payload payload)140 public Ip4Header(int proto, Inet4Address src, Inet4Address dst, Payload payload) { 141 super(proto, src, dst, payload); 142 } 143 getPacketBytes()144 public byte[] getPacketBytes() throws Exception { 145 ByteBuffer resultBuffer = buildHeader(); 146 payload.addPacketBytes(this, resultBuffer); 147 148 return getByteArrayFromBuffer(resultBuffer); 149 } 150 buildHeader()151 public ByteBuffer buildHeader() { 152 ByteBuffer bb = ByteBuffer.allocate(DATA_BUFFER_LEN); 153 154 // Version, IHL 155 bb.put((byte) (0x45)); 156 157 // DCSP, ECN 158 bb.put((byte) 0); 159 160 // Total Length 161 bb.putShort((short) (IP4_HDRLEN + payload.length())); 162 163 // Empty for Identification, Flags and Fragment Offset 164 bb.putShort((short) 0); 165 bb.put((byte) 0x40); 166 bb.put((byte) 0x00); 167 168 // TTL 169 bb.put((byte) 64); 170 171 // Protocol 172 bb.put(proto); 173 174 // Header Checksum 175 final int ipChecksumOffset = bb.position(); 176 bb.putShort((short) 0); 177 178 // Src/Dst addresses 179 bb.put(srcAddr.getAddress()); 180 bb.put(dstAddr.getAddress()); 181 182 bb.putShort(ipChecksumOffset, calculateChecksum(bb)); 183 184 return bb; 185 } 186 calculateChecksum(ByteBuffer bb)187 private short calculateChecksum(ByteBuffer bb) { 188 int checksum = 0; 189 190 // Calculate sum of 16-bit values, excluding checksum. IPv4 headers are always 32-bit 191 // aligned, so no special cases needed for unaligned values. 192 ShortBuffer shortBuffer = ByteBuffer.wrap(getByteArrayFromBuffer(bb)).asShortBuffer(); 193 while (shortBuffer.hasRemaining()) { 194 short val = shortBuffer.get(); 195 196 // Wrap as needed 197 checksum = addAndWrapForChecksum(checksum, val); 198 } 199 200 return onesComplement(checksum); 201 } 202 getProtocolId()203 public int getProtocolId() { 204 return IPPROTO_IPV4; 205 } 206 } 207 208 public static class Ip6Header extends IpHeader { Ip6Header(int nextHeader, Inet6Address src, Inet6Address dst, Payload payload)209 public Ip6Header(int nextHeader, Inet6Address src, Inet6Address dst, Payload payload) { 210 super(nextHeader, src, dst, payload); 211 } 212 getPacketBytes()213 public byte[] getPacketBytes() throws Exception { 214 ByteBuffer bb = ByteBuffer.allocate(DATA_BUFFER_LEN); 215 216 // Version | Traffic Class (First 4 bits) 217 bb.put((byte) 0x60); 218 219 // Traffic class (Last 4 bits), Flow Label 220 bb.put((byte) 0); 221 bb.put((byte) 0); 222 bb.put((byte) 0); 223 224 // Payload Length 225 bb.putShort((short) payload.length()); 226 227 // Next Header 228 bb.put(proto); 229 230 // Hop Limit 231 bb.put((byte) 64); 232 233 // Src/Dst addresses 234 bb.put(srcAddr.getAddress()); 235 bb.put(dstAddr.getAddress()); 236 237 // Payload 238 payload.addPacketBytes(this, bb); 239 240 return getByteArrayFromBuffer(bb); 241 } 242 getProtocolId()243 public int getProtocolId() { 244 return IPPROTO_IPV6; 245 } 246 } 247 248 public static class BytePayload implements Payload { 249 public final byte[] payload; 250 BytePayload(byte[] payload)251 public BytePayload(byte[] payload) { 252 this.payload = payload; 253 } 254 getProtocolId()255 public int getProtocolId() { 256 return -1; 257 } 258 getPacketBytes(IpHeader header)259 public byte[] getPacketBytes(IpHeader header) { 260 ByteBuffer bb = ByteBuffer.allocate(DATA_BUFFER_LEN); 261 262 addPacketBytes(header, bb); 263 return getByteArrayFromBuffer(bb); 264 } 265 addPacketBytes(IpHeader header, ByteBuffer resultBuffer)266 public void addPacketBytes(IpHeader header, ByteBuffer resultBuffer) { 267 resultBuffer.put(payload); 268 } 269 length()270 public short length() { 271 return (short) payload.length; 272 } 273 } 274 275 public static class UdpHeader implements Payload { 276 277 public final short srcPort; 278 public final short dstPort; 279 public final Payload payload; 280 UdpHeader(int srcPort, int dstPort, Payload payload)281 public UdpHeader(int srcPort, int dstPort, Payload payload) { 282 this.srcPort = (short) srcPort; 283 this.dstPort = (short) dstPort; 284 this.payload = payload; 285 } 286 getProtocolId()287 public int getProtocolId() { 288 return IPPROTO_UDP; 289 } 290 length()291 public short length() { 292 return (short) (payload.length() + 8); 293 } 294 getPacketBytes(IpHeader header)295 public byte[] getPacketBytes(IpHeader header) throws Exception { 296 ByteBuffer bb = ByteBuffer.allocate(DATA_BUFFER_LEN); 297 298 addPacketBytes(header, bb); 299 return getByteArrayFromBuffer(bb); 300 } 301 addPacketBytes(IpHeader header, ByteBuffer resultBuffer)302 public void addPacketBytes(IpHeader header, ByteBuffer resultBuffer) throws Exception { 303 // Source, Destination port 304 resultBuffer.putShort(srcPort); 305 resultBuffer.putShort(dstPort); 306 307 // Payload Length 308 resultBuffer.putShort(length()); 309 310 // Get payload bytes for checksum + payload 311 ByteBuffer payloadBuffer = ByteBuffer.allocate(DATA_BUFFER_LEN); 312 payload.addPacketBytes(header, payloadBuffer); 313 byte[] payloadBytes = getByteArrayFromBuffer(payloadBuffer); 314 315 // Checksum 316 resultBuffer.putShort(calculateChecksum(header, payloadBytes)); 317 318 // Payload 319 resultBuffer.put(payloadBytes); 320 } 321 calculateChecksum(IpHeader header, byte[] payloadBytes)322 private short calculateChecksum(IpHeader header, byte[] payloadBytes) throws Exception { 323 int newChecksum = 0; 324 ShortBuffer srcBuffer = ByteBuffer.wrap(header.srcAddr.getAddress()).asShortBuffer(); 325 ShortBuffer dstBuffer = ByteBuffer.wrap(header.dstAddr.getAddress()).asShortBuffer(); 326 327 while (srcBuffer.hasRemaining() || dstBuffer.hasRemaining()) { 328 short val = srcBuffer.hasRemaining() ? srcBuffer.get() : dstBuffer.get(); 329 330 // Wrap as needed 331 newChecksum = addAndWrapForChecksum(newChecksum, val); 332 } 333 334 // Add pseudo-header values. Proto is 0-padded, so just use the byte. 335 newChecksum = addAndWrapForChecksum(newChecksum, header.proto); 336 newChecksum = addAndWrapForChecksum(newChecksum, length()); 337 newChecksum = addAndWrapForChecksum(newChecksum, srcPort); 338 newChecksum = addAndWrapForChecksum(newChecksum, dstPort); 339 newChecksum = addAndWrapForChecksum(newChecksum, length()); 340 341 ShortBuffer payloadShortBuffer = ByteBuffer.wrap(payloadBytes).asShortBuffer(); 342 while (payloadShortBuffer.hasRemaining()) { 343 newChecksum = addAndWrapForChecksum(newChecksum, payloadShortBuffer.get()); 344 } 345 if (payload.length() % 2 != 0) { 346 newChecksum = 347 addAndWrapForChecksum( 348 newChecksum, (payloadBytes[payloadBytes.length - 1] << 8)); 349 } 350 351 return onesComplement(newChecksum); 352 } 353 } 354 355 public static class EspHeader implements Payload { 356 public final int nextHeader; 357 public final int spi; 358 public final int seqNum; 359 public final byte[] payload; 360 public final EspCipher cipher; 361 public final EspAuth auth; 362 363 /** 364 * Generic constructor for ESP headers. 365 * 366 * <p>For Tunnel mode, payload will be a full IP header + attached payloads 367 * 368 * <p>For Transport mode, payload will be only the attached payloads, but with the checksum 369 * calculated using the pre-encryption IP header 370 */ EspHeader(int nextHeader, int spi, int seqNum, byte[] key, byte[] payload)371 public EspHeader(int nextHeader, int spi, int seqNum, byte[] key, byte[] payload) { 372 this(nextHeader, spi, seqNum, payload, getDefaultCipher(key), getDefaultAuth(key)); 373 } 374 375 /** 376 * Generic constructor for ESP headers that allows configuring encryption and authentication 377 * algortihms. 378 * 379 * <p>For Tunnel mode, payload will be a full IP header + attached payloads 380 * 381 * <p>For Transport mode, payload will be only the attached payloads, but with the checksum 382 * calculated using the pre-encryption IP header 383 */ EspHeader( int nextHeader, int spi, int seqNum, byte[] payload, EspCipher cipher, EspAuth auth)384 public EspHeader( 385 int nextHeader, 386 int spi, 387 int seqNum, 388 byte[] payload, 389 EspCipher cipher, 390 EspAuth auth) { 391 this.nextHeader = nextHeader; 392 this.spi = spi; 393 this.seqNum = seqNum; 394 this.payload = payload; 395 this.cipher = cipher; 396 this.auth = auth; 397 398 if (cipher instanceof EspCipherNull && auth instanceof EspAuthNull) { 399 throw new IllegalArgumentException("No algorithm is provided"); 400 } 401 402 if (cipher instanceof EspAeadCipher && !(auth instanceof EspAuthNull)) { 403 throw new IllegalArgumentException( 404 "AEAD is provided with an authentication" + " algorithm."); 405 } 406 } 407 getDefaultCipher(byte[] key)408 private static EspCipher getDefaultCipher(byte[] key) { 409 return new EspCryptCipher(AES_CBC, AES_CBC_BLK_SIZE, key, AES_CBC_IV_LEN); 410 } 411 getDefaultAuth(byte[] key)412 private static EspAuth getDefaultAuth(byte[] key) { 413 return new EspAuth(HMAC_SHA_256, key, HMAC_SHA256_ICV_LEN); 414 } 415 getProtocolId()416 public int getProtocolId() { 417 return IPPROTO_ESP; 418 } 419 length()420 public short length() { 421 final int icvLen = 422 cipher instanceof EspAeadCipher ? ((EspAeadCipher) cipher).icvLen : auth.icvLen; 423 return calculateEspPacketSize( 424 payload.length, cipher.ivLen, cipher.blockSize, icvLen * 8); 425 } 426 getPacketBytes(IpHeader header)427 public byte[] getPacketBytes(IpHeader header) throws Exception { 428 ByteBuffer bb = ByteBuffer.allocate(DATA_BUFFER_LEN); 429 430 addPacketBytes(header, bb); 431 return getByteArrayFromBuffer(bb); 432 } 433 addPacketBytes(IpHeader header, ByteBuffer resultBuffer)434 public void addPacketBytes(IpHeader header, ByteBuffer resultBuffer) throws Exception { 435 ByteBuffer espPayloadBuffer = ByteBuffer.allocate(DATA_BUFFER_LEN); 436 espPayloadBuffer.putInt(spi); 437 espPayloadBuffer.putInt(seqNum); 438 439 espPayloadBuffer.put(cipher.getCipherText(nextHeader, payload, spi, seqNum)); 440 espPayloadBuffer.put(auth.getIcv(getByteArrayFromBuffer(espPayloadBuffer))); 441 442 resultBuffer.put(getByteArrayFromBuffer(espPayloadBuffer)); 443 } 444 } 445 addAndWrapForChecksum(int currentChecksum, int value)446 private static int addAndWrapForChecksum(int currentChecksum, int value) { 447 currentChecksum += value & 0x0000ffff; 448 449 // Wrap anything beyond the first 16 bits, and add to lower order bits 450 return (currentChecksum >>> 16) + (currentChecksum & 0x0000ffff); 451 } 452 onesComplement(int val)453 private static short onesComplement(int val) { 454 val = (val >>> 16) + (val & 0xffff); 455 456 if (val == 0) return 0; 457 return (short) ((~val) & 0xffff); 458 } 459 calculateEspPacketSize( int payloadLen, int cryptIvLength, int cryptBlockSize, int authTruncLen)460 public static short calculateEspPacketSize( 461 int payloadLen, int cryptIvLength, int cryptBlockSize, int authTruncLen) { 462 final int ICV_LEN = authTruncLen / 8; // Auth trailer; based on truncation length 463 464 // Align to block size of encryption algorithm 465 payloadLen = calculateEspEncryptedLength(payloadLen, cryptBlockSize); 466 payloadLen += cryptIvLength; // Initialization Vector 467 return (short) (payloadLen + ESP_HDRLEN + ICV_LEN); 468 } 469 calculateEspEncryptedLength(int payloadLen, int cryptBlockSize)470 private static int calculateEspEncryptedLength(int payloadLen, int cryptBlockSize) { 471 payloadLen += 2; // ESP trailer 472 473 // Align to block size of encryption algorithm 474 return payloadLen + calculateEspPadLen(payloadLen, cryptBlockSize); 475 } 476 calculateEspPadLen(int payloadLen, int cryptBlockSize)477 private static int calculateEspPadLen(int payloadLen, int cryptBlockSize) { 478 return (cryptBlockSize - (payloadLen % cryptBlockSize)) % cryptBlockSize; 479 } 480 getByteArrayFromBuffer(ByteBuffer buffer)481 private static byte[] getByteArrayFromBuffer(ByteBuffer buffer) { 482 return Arrays.copyOfRange(buffer.array(), 0, buffer.position()); 483 } 484 getIpHeader( int protocol, InetAddress src, InetAddress dst, Payload payload)485 public static IpHeader getIpHeader( 486 int protocol, InetAddress src, InetAddress dst, Payload payload) { 487 if ((src instanceof Inet6Address) != (dst instanceof Inet6Address)) { 488 throw new IllegalArgumentException("Invalid src/dst address combination"); 489 } 490 491 if (src instanceof Inet6Address) { 492 return new Ip6Header(protocol, (Inet6Address) src, (Inet6Address) dst, payload); 493 } else { 494 return new Ip4Header(protocol, (Inet4Address) src, (Inet4Address) dst, payload); 495 } 496 } 497 498 public abstract static class EspCipher { 499 protected static final int SALT_LEN_UNUSED = 0; 500 501 public final String algoName; 502 public final int blockSize; 503 public final byte[] key; 504 public final int ivLen; 505 public final int saltLen; 506 protected byte[] mIv; 507 EspCipher(String algoName, int blockSize, byte[] key, int ivLen, int saltLen)508 public EspCipher(String algoName, int blockSize, byte[] key, int ivLen, int saltLen) { 509 this.algoName = algoName; 510 this.blockSize = blockSize; 511 this.key = key; 512 this.ivLen = ivLen; 513 this.saltLen = saltLen; 514 this.mIv = getIv(ivLen); 515 } 516 updateIv(byte[] iv)517 public void updateIv(byte[] iv) { 518 this.mIv = iv; 519 } 520 getPaddedPayload(int nextHeader, byte[] payload, int blockSize)521 public static byte[] getPaddedPayload(int nextHeader, byte[] payload, int blockSize) { 522 final int paddedLen = calculateEspEncryptedLength(payload.length, blockSize); 523 final ByteBuffer paddedPayload = ByteBuffer.allocate(paddedLen); 524 paddedPayload.put(payload); 525 526 // Add padding - consecutive integers from 0x01 527 byte pad = 1; 528 while (paddedPayload.position() < paddedPayload.limit() - ESP_TRAILER_LEN) { 529 paddedPayload.put((byte) pad++); 530 } 531 532 // Add padding length and next header 533 paddedPayload.put((byte) (paddedLen - ESP_TRAILER_LEN - payload.length)); 534 paddedPayload.put((byte) nextHeader); 535 536 return getByteArrayFromBuffer(paddedPayload); 537 } 538 getIv(int ivLen)539 private static byte[] getIv(int ivLen) { 540 final byte[] iv = new byte[ivLen]; 541 new SecureRandom().nextBytes(iv); 542 return iv; 543 } 544 getCipherText(int nextHeader, byte[] payload, int spi, int seqNum)545 public abstract byte[] getCipherText(int nextHeader, byte[] payload, int spi, int seqNum) 546 throws GeneralSecurityException; 547 } 548 549 public static final class EspCipherNull extends EspCipher { 550 private static final String CRYPT_NULL = "CRYPT_NULL"; 551 private static final int IV_LEN_UNUSED = 0; 552 private static final byte[] KEY_UNUSED = new byte[0]; 553 554 private static final EspCipherNull sInstance = new EspCipherNull(); 555 EspCipherNull()556 private EspCipherNull() { 557 super(CRYPT_NULL, ESP_BLK_SIZE, KEY_UNUSED, IV_LEN_UNUSED, SALT_LEN_UNUSED); 558 } 559 getInstance()560 public static EspCipherNull getInstance() { 561 return sInstance; 562 } 563 564 @Override getCipherText(int nextHeader, byte[] payload, int spi, int seqNum)565 public byte[] getCipherText(int nextHeader, byte[] payload, int spi, int seqNum) 566 throws GeneralSecurityException { 567 return getPaddedPayload(nextHeader, payload, blockSize); 568 } 569 } 570 571 public static final class EspCryptCipher extends EspCipher { EspCryptCipher(String algoName, int blockSize, byte[] key, int ivLen)572 public EspCryptCipher(String algoName, int blockSize, byte[] key, int ivLen) { 573 this(algoName, blockSize, key, ivLen, SALT_LEN_UNUSED); 574 } 575 EspCryptCipher(String algoName, int blockSize, byte[] key, int ivLen, int saltLen)576 public EspCryptCipher(String algoName, int blockSize, byte[] key, int ivLen, int saltLen) { 577 super(algoName, blockSize, key, ivLen, saltLen); 578 } 579 580 @Override getCipherText(int nextHeader, byte[] payload, int spi, int seqNum)581 public byte[] getCipherText(int nextHeader, byte[] payload, int spi, int seqNum) 582 throws GeneralSecurityException { 583 final IvParameterSpec ivParameterSpec; 584 final SecretKeySpec secretKeySpec; 585 586 if (AES_CBC.equals(algoName)) { 587 ivParameterSpec = new IvParameterSpec(mIv); 588 secretKeySpec = new SecretKeySpec(key, algoName); 589 } else if (AES_CTR.equals(algoName)) { 590 // Provided key consists of encryption/decryption key plus 4-byte salt. Salt is used 591 // with ESP payload IV and initial block counter value to build IvParameterSpec. 592 final byte[] secretKey = Arrays.copyOfRange(key, 0, key.length - saltLen); 593 final byte[] salt = Arrays.copyOfRange(key, secretKey.length, key.length); 594 secretKeySpec = new SecretKeySpec(secretKey, algoName); 595 596 final ByteBuffer ivParameterBuffer = 597 ByteBuffer.allocate(mIv.length + saltLen + AES_CTR_INITIAL_COUNTER.length); 598 ivParameterBuffer.put(salt); 599 ivParameterBuffer.put(mIv); 600 ivParameterBuffer.put(AES_CTR_INITIAL_COUNTER); 601 ivParameterSpec = new IvParameterSpec(ivParameterBuffer.array()); 602 } else { 603 throw new IllegalArgumentException("Invalid algorithm " + algoName); 604 } 605 606 // Encrypt payload 607 final Cipher cipher = Cipher.getInstance(algoName); 608 cipher.init(Cipher.ENCRYPT_MODE, secretKeySpec, ivParameterSpec); 609 final byte[] encrypted = 610 cipher.doFinal(getPaddedPayload(nextHeader, payload, blockSize)); 611 612 // Build ciphertext 613 final ByteBuffer cipherText = ByteBuffer.allocate(mIv.length + encrypted.length); 614 cipherText.put(mIv); 615 cipherText.put(encrypted); 616 617 return getByteArrayFromBuffer(cipherText); 618 } 619 } 620 621 public static final class EspAeadCipher extends EspCipher { 622 public final int icvLen; 623 EspAeadCipher( String algoName, int blockSize, byte[] key, int ivLen, int icvLen, int saltLen)624 public EspAeadCipher( 625 String algoName, int blockSize, byte[] key, int ivLen, int icvLen, int saltLen) { 626 super(algoName, blockSize, key, ivLen, saltLen); 627 this.icvLen = icvLen; 628 } 629 630 @Override getCipherText(int nextHeader, byte[] payload, int spi, int seqNum)631 public byte[] getCipherText(int nextHeader, byte[] payload, int spi, int seqNum) 632 throws GeneralSecurityException { 633 // Provided key consists of encryption/decryption key plus salt. Salt is used 634 // with ESP payload IV to build IvParameterSpec. 635 final byte[] secretKey = Arrays.copyOfRange(key, 0, key.length - saltLen); 636 final byte[] salt = Arrays.copyOfRange(key, secretKey.length, key.length); 637 638 final SecretKeySpec secretKeySpec = new SecretKeySpec(secretKey, algoName); 639 640 final ByteBuffer ivParameterBuffer = ByteBuffer.allocate(saltLen + mIv.length); 641 ivParameterBuffer.put(salt); 642 ivParameterBuffer.put(mIv); 643 final IvParameterSpec ivParameterSpec = new IvParameterSpec(ivParameterBuffer.array()); 644 645 final ByteBuffer aadBuffer = ByteBuffer.allocate(ESP_HDRLEN); 646 aadBuffer.putInt(spi); 647 aadBuffer.putInt(seqNum); 648 649 // Encrypt payload 650 final Cipher cipher = Cipher.getInstance(algoName); 651 cipher.init(Cipher.ENCRYPT_MODE, secretKeySpec, ivParameterSpec); 652 cipher.updateAAD(aadBuffer.array()); 653 final byte[] encryptedTextAndIcv = 654 cipher.doFinal(getPaddedPayload(nextHeader, payload, blockSize)); 655 656 // Build ciphertext 657 final ByteBuffer cipherText = 658 ByteBuffer.allocate(mIv.length + encryptedTextAndIcv.length); 659 cipherText.put(mIv); 660 cipherText.put(encryptedTextAndIcv); 661 662 return getByteArrayFromBuffer(cipherText); 663 } 664 } 665 666 public static class EspAuth { 667 public final String algoName; 668 public final byte[] key; 669 public final int icvLen; 670 671 private static final Set<String> JCE_SUPPORTED_MACS = new ArraySet<>(); 672 673 static { 674 JCE_SUPPORTED_MACS.add(HMAC_MD5); 675 JCE_SUPPORTED_MACS.add(HMAC_SHA1); 676 JCE_SUPPORTED_MACS.add(HMAC_SHA_256); 677 JCE_SUPPORTED_MACS.add(HMAC_SHA_384); 678 JCE_SUPPORTED_MACS.add(HMAC_SHA_512); 679 JCE_SUPPORTED_MACS.add(AES_CMAC); 680 } 681 EspAuth(String algoName, byte[] key, int icvLen)682 public EspAuth(String algoName, byte[] key, int icvLen) { 683 this.algoName = algoName; 684 this.key = key; 685 this.icvLen = icvLen; 686 } 687 getIcv(byte[] authenticatedSection)688 public byte[] getIcv(byte[] authenticatedSection) throws GeneralSecurityException { 689 if (AES_XCBC.equals(algoName)) { 690 final Cipher aesCipher = Cipher.getInstance(AES_CBC); 691 return new AesXCbcImpl().mac(key, authenticatedSection, true /* needTruncation */); 692 } else if (JCE_SUPPORTED_MACS.contains(algoName)) { 693 final Mac mac = Mac.getInstance(algoName); 694 final SecretKeySpec authKey = new SecretKeySpec(key, algoName); 695 mac.init(authKey); 696 697 final ByteBuffer buffer = ByteBuffer.wrap(mac.doFinal(authenticatedSection)); 698 final byte[] icv = new byte[icvLen]; 699 buffer.get(icv); 700 return icv; 701 } else { 702 throw new IllegalArgumentException("Invalid algorithm: " + algoName); 703 } 704 } 705 } 706 707 public static final class EspAuthNull extends EspAuth { 708 private static final String AUTH_NULL = "AUTH_NULL"; 709 private static final int ICV_LEN_UNUSED = 0; 710 private static final byte[] KEY_UNUSED = new byte[0]; 711 private static final byte[] ICV_EMPTY = new byte[0]; 712 713 private static final EspAuthNull sInstance = new EspAuthNull(); 714 EspAuthNull()715 private EspAuthNull() { 716 super(AUTH_NULL, KEY_UNUSED, ICV_LEN_UNUSED); 717 } 718 getInstance()719 public static EspAuthNull getInstance() { 720 return sInstance; 721 } 722 723 @Override getIcv(byte[] authenticatedSection)724 public byte[] getIcv(byte[] authenticatedSection) throws GeneralSecurityException { 725 return ICV_EMPTY; 726 } 727 } 728 729 /* 730 * Debug printing 731 */ 732 private static final char[] hexArray = "0123456789ABCDEF".toCharArray(); 733 bytesToHex(byte[] bytes)734 public static String bytesToHex(byte[] bytes) { 735 StringBuilder sb = new StringBuilder(); 736 for (byte b : bytes) { 737 sb.append(hexArray[b >>> 4]); 738 sb.append(hexArray[b & 0x0F]); 739 sb.append(' '); 740 } 741 return sb.toString(); 742 } 743 } 744