• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.net.cts;
18 
19 import static android.system.OsConstants.IPPROTO_IPV6;
20 import static android.system.OsConstants.IPPROTO_UDP;
21 
22 import android.util.ArraySet;
23 
24 import com.android.internal.net.ipsec.ike.crypto.AesXCbcImpl;
25 
26 import java.net.Inet4Address;
27 import java.net.Inet6Address;
28 import java.net.InetAddress;
29 import java.nio.ByteBuffer;
30 import java.nio.ShortBuffer;
31 import java.security.GeneralSecurityException;
32 import java.security.SecureRandom;
33 import java.util.Arrays;
34 import java.util.Set;
35 
36 import javax.crypto.Cipher;
37 import javax.crypto.Mac;
38 import javax.crypto.spec.IvParameterSpec;
39 import javax.crypto.spec.SecretKeySpec;
40 
41 public class PacketUtils {
42     private static final String TAG = PacketUtils.class.getSimpleName();
43 
44     private static final int DATA_BUFFER_LEN = 4096;
45 
46     static final int IP4_HDRLEN = 20;
47     static final int IP6_HDRLEN = 40;
48     static final int UDP_HDRLEN = 8;
49     static final int TCP_HDRLEN = 20;
50     static final int TCP_HDRLEN_WITH_TIMESTAMP_OPT = TCP_HDRLEN + 12;
51     static final int ESP_HDRLEN = 8;
52     static final int ESP_BLK_SIZE = 4; // ESP has to be 4-byte aligned
53     static final int ESP_TRAILER_LEN = 2;
54 
55     // Not defined in OsConstants
56     static final int IPPROTO_IPV4 = 4;
57     static final int IPPROTO_ESP = 50;
58 
59     // Encryption parameters
60     static final int AES_CBC_IV_LEN = 16;
61     static final int AES_CBC_BLK_SIZE = 16;
62     static final int AES_CTR_SALT_LEN = 4;
63 
64     static final int AES_CTR_KEY_LEN_20 = 20;
65     static final int AES_CTR_KEY_LEN_28 = 28;
66     static final int AES_CTR_KEY_LEN_36 = 36;
67     static final int AES_CTR_BLK_SIZE = ESP_BLK_SIZE;
68     static final int AES_CTR_IV_LEN = 8;
69 
70     // AEAD parameters
71     static final int AES_GCM_IV_LEN = 8;
72     static final int AES_GCM_BLK_SIZE = 4;
73     static final int CHACHA20_POLY1305_KEY_LEN = 36;
74     static final int CHACHA20_POLY1305_BLK_SIZE = ESP_BLK_SIZE;
75     static final int CHACHA20_POLY1305_IV_LEN = 8;
76     static final int CHACHA20_POLY1305_SALT_LEN = 4;
77     static final int CHACHA20_POLY1305_ICV_LEN = 16;
78 
79     // Authentication parameters
80     static final int HMAC_SHA256_ICV_LEN = 16;
81     static final int HMAC_SHA512_KEY_LEN = 64;
82     static final int HMAC_SHA512_ICV_LEN = 32;
83     static final int AES_XCBC_KEY_LEN = 16;
84     static final int AES_XCBC_ICV_LEN = 12;
85     static final int AES_CMAC_KEY_LEN = 16;
86     static final int AES_CMAC_ICV_LEN = 12;
87 
88     // Block counter field should be 32 bits and starts from value one as per RFC 3686
89     static final byte[] AES_CTR_INITIAL_COUNTER = new byte[] {0x00, 0x00, 0x00, 0x01};
90 
91     // Encryption algorithms
92     static final String AES = "AES";
93     static final String AES_CBC = "AES/CBC/NoPadding";
94     static final String AES_CTR = "AES/CTR/NoPadding";
95 
96     // AEAD algorithms
97     static final String CHACHA20_POLY1305 = "ChaCha20/Poly1305/NoPadding";
98 
99     // Authentication algorithms
100     static final String HMAC_MD5 = "HmacMD5";
101     static final String HMAC_SHA1 = "HmacSHA1";
102     static final String HMAC_SHA_256 = "HmacSHA256";
103     static final String HMAC_SHA_384 = "HmacSHA384";
104     static final String HMAC_SHA_512 = "HmacSHA512";
105     static final String AES_CMAC = "AESCMAC";
106     static final String AES_XCBC = "AesXCbc";
107 
108     public interface Payload {
getPacketBytes(IpHeader header)109         byte[] getPacketBytes(IpHeader header) throws Exception;
110 
addPacketBytes(IpHeader header, ByteBuffer resultBuffer)111         void addPacketBytes(IpHeader header, ByteBuffer resultBuffer) throws Exception;
112 
length()113         short length();
114 
getProtocolId()115         int getProtocolId();
116     }
117 
118     public abstract static class IpHeader {
119 
120         public final byte proto;
121         public final InetAddress srcAddr;
122         public final InetAddress dstAddr;
123         public final Payload payload;
124 
IpHeader(int proto, InetAddress src, InetAddress dst, Payload payload)125         public IpHeader(int proto, InetAddress src, InetAddress dst, Payload payload) {
126             this.proto = (byte) proto;
127             this.srcAddr = src;
128             this.dstAddr = dst;
129             this.payload = payload;
130         }
131 
getPacketBytes()132         public abstract byte[] getPacketBytes() throws Exception;
133 
getProtocolId()134         public abstract int getProtocolId();
135     }
136 
137     public static class Ip4Header extends IpHeader {
138         private short checksum;
139 
Ip4Header(int proto, Inet4Address src, Inet4Address dst, Payload payload)140         public Ip4Header(int proto, Inet4Address src, Inet4Address dst, Payload payload) {
141             super(proto, src, dst, payload);
142         }
143 
getPacketBytes()144         public byte[] getPacketBytes() throws Exception {
145             ByteBuffer resultBuffer = buildHeader();
146             payload.addPacketBytes(this, resultBuffer);
147 
148             return getByteArrayFromBuffer(resultBuffer);
149         }
150 
buildHeader()151         public ByteBuffer buildHeader() {
152             ByteBuffer bb = ByteBuffer.allocate(DATA_BUFFER_LEN);
153 
154             // Version, IHL
155             bb.put((byte) (0x45));
156 
157             // DCSP, ECN
158             bb.put((byte) 0);
159 
160             // Total Length
161             bb.putShort((short) (IP4_HDRLEN + payload.length()));
162 
163             // Empty for Identification, Flags and Fragment Offset
164             bb.putShort((short) 0);
165             bb.put((byte) 0x40);
166             bb.put((byte) 0x00);
167 
168             // TTL
169             bb.put((byte) 64);
170 
171             // Protocol
172             bb.put(proto);
173 
174             // Header Checksum
175             final int ipChecksumOffset = bb.position();
176             bb.putShort((short) 0);
177 
178             // Src/Dst addresses
179             bb.put(srcAddr.getAddress());
180             bb.put(dstAddr.getAddress());
181 
182             bb.putShort(ipChecksumOffset, calculateChecksum(bb));
183 
184             return bb;
185         }
186 
calculateChecksum(ByteBuffer bb)187         private short calculateChecksum(ByteBuffer bb) {
188             int checksum = 0;
189 
190             // Calculate sum of 16-bit values, excluding checksum. IPv4 headers are always 32-bit
191             // aligned, so no special cases needed for unaligned values.
192             ShortBuffer shortBuffer = ByteBuffer.wrap(getByteArrayFromBuffer(bb)).asShortBuffer();
193             while (shortBuffer.hasRemaining()) {
194                 short val = shortBuffer.get();
195 
196                 // Wrap as needed
197                 checksum = addAndWrapForChecksum(checksum, val);
198             }
199 
200             return onesComplement(checksum);
201         }
202 
getProtocolId()203         public int getProtocolId() {
204             return IPPROTO_IPV4;
205         }
206     }
207 
208     public static class Ip6Header extends IpHeader {
Ip6Header(int nextHeader, Inet6Address src, Inet6Address dst, Payload payload)209         public Ip6Header(int nextHeader, Inet6Address src, Inet6Address dst, Payload payload) {
210             super(nextHeader, src, dst, payload);
211         }
212 
getPacketBytes()213         public byte[] getPacketBytes() throws Exception {
214             ByteBuffer bb = ByteBuffer.allocate(DATA_BUFFER_LEN);
215 
216             // Version | Traffic Class (First 4 bits)
217             bb.put((byte) 0x60);
218 
219             // Traffic class (Last 4 bits), Flow Label
220             bb.put((byte) 0);
221             bb.put((byte) 0);
222             bb.put((byte) 0);
223 
224             // Payload Length
225             bb.putShort((short) payload.length());
226 
227             // Next Header
228             bb.put(proto);
229 
230             // Hop Limit
231             bb.put((byte) 64);
232 
233             // Src/Dst addresses
234             bb.put(srcAddr.getAddress());
235             bb.put(dstAddr.getAddress());
236 
237             // Payload
238             payload.addPacketBytes(this, bb);
239 
240             return getByteArrayFromBuffer(bb);
241         }
242 
getProtocolId()243         public int getProtocolId() {
244             return IPPROTO_IPV6;
245         }
246     }
247 
248     public static class BytePayload implements Payload {
249         public final byte[] payload;
250 
BytePayload(byte[] payload)251         public BytePayload(byte[] payload) {
252             this.payload = payload;
253         }
254 
getProtocolId()255         public int getProtocolId() {
256             return -1;
257         }
258 
getPacketBytes(IpHeader header)259         public byte[] getPacketBytes(IpHeader header) {
260             ByteBuffer bb = ByteBuffer.allocate(DATA_BUFFER_LEN);
261 
262             addPacketBytes(header, bb);
263             return getByteArrayFromBuffer(bb);
264         }
265 
addPacketBytes(IpHeader header, ByteBuffer resultBuffer)266         public void addPacketBytes(IpHeader header, ByteBuffer resultBuffer) {
267             resultBuffer.put(payload);
268         }
269 
length()270         public short length() {
271             return (short) payload.length;
272         }
273     }
274 
275     public static class UdpHeader implements Payload {
276 
277         public final short srcPort;
278         public final short dstPort;
279         public final Payload payload;
280 
UdpHeader(int srcPort, int dstPort, Payload payload)281         public UdpHeader(int srcPort, int dstPort, Payload payload) {
282             this.srcPort = (short) srcPort;
283             this.dstPort = (short) dstPort;
284             this.payload = payload;
285         }
286 
getProtocolId()287         public int getProtocolId() {
288             return IPPROTO_UDP;
289         }
290 
length()291         public short length() {
292             return (short) (payload.length() + 8);
293         }
294 
getPacketBytes(IpHeader header)295         public byte[] getPacketBytes(IpHeader header) throws Exception {
296             ByteBuffer bb = ByteBuffer.allocate(DATA_BUFFER_LEN);
297 
298             addPacketBytes(header, bb);
299             return getByteArrayFromBuffer(bb);
300         }
301 
addPacketBytes(IpHeader header, ByteBuffer resultBuffer)302         public void addPacketBytes(IpHeader header, ByteBuffer resultBuffer) throws Exception {
303             // Source, Destination port
304             resultBuffer.putShort(srcPort);
305             resultBuffer.putShort(dstPort);
306 
307             // Payload Length
308             resultBuffer.putShort(length());
309 
310             // Get payload bytes for checksum + payload
311             ByteBuffer payloadBuffer = ByteBuffer.allocate(DATA_BUFFER_LEN);
312             payload.addPacketBytes(header, payloadBuffer);
313             byte[] payloadBytes = getByteArrayFromBuffer(payloadBuffer);
314 
315             // Checksum
316             resultBuffer.putShort(calculateChecksum(header, payloadBytes));
317 
318             // Payload
319             resultBuffer.put(payloadBytes);
320         }
321 
calculateChecksum(IpHeader header, byte[] payloadBytes)322         private short calculateChecksum(IpHeader header, byte[] payloadBytes) throws Exception {
323             int newChecksum = 0;
324             ShortBuffer srcBuffer = ByteBuffer.wrap(header.srcAddr.getAddress()).asShortBuffer();
325             ShortBuffer dstBuffer = ByteBuffer.wrap(header.dstAddr.getAddress()).asShortBuffer();
326 
327             while (srcBuffer.hasRemaining() || dstBuffer.hasRemaining()) {
328                 short val = srcBuffer.hasRemaining() ? srcBuffer.get() : dstBuffer.get();
329 
330                 // Wrap as needed
331                 newChecksum = addAndWrapForChecksum(newChecksum, val);
332             }
333 
334             // Add pseudo-header values. Proto is 0-padded, so just use the byte.
335             newChecksum = addAndWrapForChecksum(newChecksum, header.proto);
336             newChecksum = addAndWrapForChecksum(newChecksum, length());
337             newChecksum = addAndWrapForChecksum(newChecksum, srcPort);
338             newChecksum = addAndWrapForChecksum(newChecksum, dstPort);
339             newChecksum = addAndWrapForChecksum(newChecksum, length());
340 
341             ShortBuffer payloadShortBuffer = ByteBuffer.wrap(payloadBytes).asShortBuffer();
342             while (payloadShortBuffer.hasRemaining()) {
343                 newChecksum = addAndWrapForChecksum(newChecksum, payloadShortBuffer.get());
344             }
345             if (payload.length() % 2 != 0) {
346                 newChecksum =
347                         addAndWrapForChecksum(
348                                 newChecksum, (payloadBytes[payloadBytes.length - 1] << 8));
349             }
350 
351             return onesComplement(newChecksum);
352         }
353     }
354 
355     public static class EspHeader implements Payload {
356         public final int nextHeader;
357         public final int spi;
358         public final int seqNum;
359         public final byte[] payload;
360         public final EspCipher cipher;
361         public final EspAuth auth;
362 
363         /**
364          * Generic constructor for ESP headers.
365          *
366          * <p>For Tunnel mode, payload will be a full IP header + attached payloads
367          *
368          * <p>For Transport mode, payload will be only the attached payloads, but with the checksum
369          * calculated using the pre-encryption IP header
370          */
EspHeader(int nextHeader, int spi, int seqNum, byte[] key, byte[] payload)371         public EspHeader(int nextHeader, int spi, int seqNum, byte[] key, byte[] payload) {
372             this(nextHeader, spi, seqNum, payload, getDefaultCipher(key), getDefaultAuth(key));
373         }
374 
375         /**
376          * Generic constructor for ESP headers that allows configuring encryption and authentication
377          * algortihms.
378          *
379          * <p>For Tunnel mode, payload will be a full IP header + attached payloads
380          *
381          * <p>For Transport mode, payload will be only the attached payloads, but with the checksum
382          * calculated using the pre-encryption IP header
383          */
EspHeader( int nextHeader, int spi, int seqNum, byte[] payload, EspCipher cipher, EspAuth auth)384         public EspHeader(
385                 int nextHeader,
386                 int spi,
387                 int seqNum,
388                 byte[] payload,
389                 EspCipher cipher,
390                 EspAuth auth) {
391             this.nextHeader = nextHeader;
392             this.spi = spi;
393             this.seqNum = seqNum;
394             this.payload = payload;
395             this.cipher = cipher;
396             this.auth = auth;
397 
398             if (cipher instanceof EspCipherNull && auth instanceof EspAuthNull) {
399                 throw new IllegalArgumentException("No algorithm is provided");
400             }
401 
402             if (cipher instanceof EspAeadCipher && !(auth instanceof EspAuthNull)) {
403                 throw new IllegalArgumentException(
404                         "AEAD is provided with an authentication" + " algorithm.");
405             }
406         }
407 
getDefaultCipher(byte[] key)408         private static EspCipher getDefaultCipher(byte[] key) {
409             return new EspCryptCipher(AES_CBC, AES_CBC_BLK_SIZE, key, AES_CBC_IV_LEN);
410         }
411 
getDefaultAuth(byte[] key)412         private static EspAuth getDefaultAuth(byte[] key) {
413             return new EspAuth(HMAC_SHA_256, key, HMAC_SHA256_ICV_LEN);
414         }
415 
getProtocolId()416         public int getProtocolId() {
417             return IPPROTO_ESP;
418         }
419 
length()420         public short length() {
421             final int icvLen =
422                     cipher instanceof EspAeadCipher ? ((EspAeadCipher) cipher).icvLen : auth.icvLen;
423             return calculateEspPacketSize(
424                     payload.length, cipher.ivLen, cipher.blockSize, icvLen * 8);
425         }
426 
getPacketBytes(IpHeader header)427         public byte[] getPacketBytes(IpHeader header) throws Exception {
428             ByteBuffer bb = ByteBuffer.allocate(DATA_BUFFER_LEN);
429 
430             addPacketBytes(header, bb);
431             return getByteArrayFromBuffer(bb);
432         }
433 
addPacketBytes(IpHeader header, ByteBuffer resultBuffer)434         public void addPacketBytes(IpHeader header, ByteBuffer resultBuffer) throws Exception {
435             ByteBuffer espPayloadBuffer = ByteBuffer.allocate(DATA_BUFFER_LEN);
436             espPayloadBuffer.putInt(spi);
437             espPayloadBuffer.putInt(seqNum);
438 
439             espPayloadBuffer.put(cipher.getCipherText(nextHeader, payload, spi, seqNum));
440             espPayloadBuffer.put(auth.getIcv(getByteArrayFromBuffer(espPayloadBuffer)));
441 
442             resultBuffer.put(getByteArrayFromBuffer(espPayloadBuffer));
443         }
444     }
445 
addAndWrapForChecksum(int currentChecksum, int value)446     private static int addAndWrapForChecksum(int currentChecksum, int value) {
447         currentChecksum += value & 0x0000ffff;
448 
449         // Wrap anything beyond the first 16 bits, and add to lower order bits
450         return (currentChecksum >>> 16) + (currentChecksum & 0x0000ffff);
451     }
452 
onesComplement(int val)453     private static short onesComplement(int val) {
454         val = (val >>> 16) + (val & 0xffff);
455 
456         if (val == 0) return 0;
457         return (short) ((~val) & 0xffff);
458     }
459 
calculateEspPacketSize( int payloadLen, int cryptIvLength, int cryptBlockSize, int authTruncLen)460     public static short calculateEspPacketSize(
461             int payloadLen, int cryptIvLength, int cryptBlockSize, int authTruncLen) {
462         final int ICV_LEN = authTruncLen / 8; // Auth trailer; based on truncation length
463 
464         // Align to block size of encryption algorithm
465         payloadLen = calculateEspEncryptedLength(payloadLen, cryptBlockSize);
466         payloadLen += cryptIvLength; // Initialization Vector
467         return (short) (payloadLen + ESP_HDRLEN + ICV_LEN);
468     }
469 
calculateEspEncryptedLength(int payloadLen, int cryptBlockSize)470     private static int calculateEspEncryptedLength(int payloadLen, int cryptBlockSize) {
471         payloadLen += 2; // ESP trailer
472 
473         // Align to block size of encryption algorithm
474         return payloadLen + calculateEspPadLen(payloadLen, cryptBlockSize);
475     }
476 
calculateEspPadLen(int payloadLen, int cryptBlockSize)477     private static int calculateEspPadLen(int payloadLen, int cryptBlockSize) {
478         return (cryptBlockSize - (payloadLen % cryptBlockSize)) % cryptBlockSize;
479     }
480 
getByteArrayFromBuffer(ByteBuffer buffer)481     private static byte[] getByteArrayFromBuffer(ByteBuffer buffer) {
482         return Arrays.copyOfRange(buffer.array(), 0, buffer.position());
483     }
484 
getIpHeader( int protocol, InetAddress src, InetAddress dst, Payload payload)485     public static IpHeader getIpHeader(
486             int protocol, InetAddress src, InetAddress dst, Payload payload) {
487         if ((src instanceof Inet6Address) != (dst instanceof Inet6Address)) {
488             throw new IllegalArgumentException("Invalid src/dst address combination");
489         }
490 
491         if (src instanceof Inet6Address) {
492             return new Ip6Header(protocol, (Inet6Address) src, (Inet6Address) dst, payload);
493         } else {
494             return new Ip4Header(protocol, (Inet4Address) src, (Inet4Address) dst, payload);
495         }
496     }
497 
498     public abstract static class EspCipher {
499         protected static final int SALT_LEN_UNUSED = 0;
500 
501         public final String algoName;
502         public final int blockSize;
503         public final byte[] key;
504         public final int ivLen;
505         public final int saltLen;
506         protected byte[] mIv;
507 
EspCipher(String algoName, int blockSize, byte[] key, int ivLen, int saltLen)508         public EspCipher(String algoName, int blockSize, byte[] key, int ivLen, int saltLen) {
509             this.algoName = algoName;
510             this.blockSize = blockSize;
511             this.key = key;
512             this.ivLen = ivLen;
513             this.saltLen = saltLen;
514             this.mIv = getIv(ivLen);
515         }
516 
updateIv(byte[] iv)517         public void updateIv(byte[] iv) {
518             this.mIv = iv;
519         }
520 
getPaddedPayload(int nextHeader, byte[] payload, int blockSize)521         public static byte[] getPaddedPayload(int nextHeader, byte[] payload, int blockSize) {
522             final int paddedLen = calculateEspEncryptedLength(payload.length, blockSize);
523             final ByteBuffer paddedPayload = ByteBuffer.allocate(paddedLen);
524             paddedPayload.put(payload);
525 
526             // Add padding - consecutive integers from 0x01
527             byte pad = 1;
528             while (paddedPayload.position() < paddedPayload.limit() - ESP_TRAILER_LEN) {
529                 paddedPayload.put((byte) pad++);
530             }
531 
532             // Add padding length and next header
533             paddedPayload.put((byte) (paddedLen - ESP_TRAILER_LEN - payload.length));
534             paddedPayload.put((byte) nextHeader);
535 
536             return getByteArrayFromBuffer(paddedPayload);
537         }
538 
getIv(int ivLen)539         private static byte[] getIv(int ivLen) {
540             final byte[] iv = new byte[ivLen];
541             new SecureRandom().nextBytes(iv);
542             return iv;
543         }
544 
getCipherText(int nextHeader, byte[] payload, int spi, int seqNum)545         public abstract byte[] getCipherText(int nextHeader, byte[] payload, int spi, int seqNum)
546                 throws GeneralSecurityException;
547     }
548 
549     public static final class EspCipherNull extends EspCipher {
550         private static final String CRYPT_NULL = "CRYPT_NULL";
551         private static final int IV_LEN_UNUSED = 0;
552         private static final byte[] KEY_UNUSED = new byte[0];
553 
554         private static final EspCipherNull sInstance = new EspCipherNull();
555 
EspCipherNull()556         private EspCipherNull() {
557             super(CRYPT_NULL, ESP_BLK_SIZE, KEY_UNUSED, IV_LEN_UNUSED, SALT_LEN_UNUSED);
558         }
559 
getInstance()560         public static EspCipherNull getInstance() {
561             return sInstance;
562         }
563 
564         @Override
getCipherText(int nextHeader, byte[] payload, int spi, int seqNum)565         public byte[] getCipherText(int nextHeader, byte[] payload, int spi, int seqNum)
566                 throws GeneralSecurityException {
567             return getPaddedPayload(nextHeader, payload, blockSize);
568         }
569     }
570 
571     public static final class EspCryptCipher extends EspCipher {
EspCryptCipher(String algoName, int blockSize, byte[] key, int ivLen)572         public EspCryptCipher(String algoName, int blockSize, byte[] key, int ivLen) {
573             this(algoName, blockSize, key, ivLen, SALT_LEN_UNUSED);
574         }
575 
EspCryptCipher(String algoName, int blockSize, byte[] key, int ivLen, int saltLen)576         public EspCryptCipher(String algoName, int blockSize, byte[] key, int ivLen, int saltLen) {
577             super(algoName, blockSize, key, ivLen, saltLen);
578         }
579 
580         @Override
getCipherText(int nextHeader, byte[] payload, int spi, int seqNum)581         public byte[] getCipherText(int nextHeader, byte[] payload, int spi, int seqNum)
582                 throws GeneralSecurityException {
583             final IvParameterSpec ivParameterSpec;
584             final SecretKeySpec secretKeySpec;
585 
586             if (AES_CBC.equals(algoName)) {
587                 ivParameterSpec = new IvParameterSpec(mIv);
588                 secretKeySpec = new SecretKeySpec(key, algoName);
589             } else if (AES_CTR.equals(algoName)) {
590                 // Provided key consists of encryption/decryption key plus 4-byte salt. Salt is used
591                 // with ESP payload IV and initial block counter value to build IvParameterSpec.
592                 final byte[] secretKey = Arrays.copyOfRange(key, 0, key.length - saltLen);
593                 final byte[] salt = Arrays.copyOfRange(key, secretKey.length, key.length);
594                 secretKeySpec = new SecretKeySpec(secretKey, algoName);
595 
596                 final ByteBuffer ivParameterBuffer =
597                         ByteBuffer.allocate(mIv.length + saltLen + AES_CTR_INITIAL_COUNTER.length);
598                 ivParameterBuffer.put(salt);
599                 ivParameterBuffer.put(mIv);
600                 ivParameterBuffer.put(AES_CTR_INITIAL_COUNTER);
601                 ivParameterSpec = new IvParameterSpec(ivParameterBuffer.array());
602             } else {
603                 throw new IllegalArgumentException("Invalid algorithm " + algoName);
604             }
605 
606             // Encrypt payload
607             final Cipher cipher = Cipher.getInstance(algoName);
608             cipher.init(Cipher.ENCRYPT_MODE, secretKeySpec, ivParameterSpec);
609             final byte[] encrypted =
610                     cipher.doFinal(getPaddedPayload(nextHeader, payload, blockSize));
611 
612             // Build ciphertext
613             final ByteBuffer cipherText = ByteBuffer.allocate(mIv.length + encrypted.length);
614             cipherText.put(mIv);
615             cipherText.put(encrypted);
616 
617             return getByteArrayFromBuffer(cipherText);
618         }
619     }
620 
621     public static final class EspAeadCipher extends EspCipher {
622         public final int icvLen;
623 
EspAeadCipher( String algoName, int blockSize, byte[] key, int ivLen, int icvLen, int saltLen)624         public EspAeadCipher(
625                 String algoName, int blockSize, byte[] key, int ivLen, int icvLen, int saltLen) {
626             super(algoName, blockSize, key, ivLen, saltLen);
627             this.icvLen = icvLen;
628         }
629 
630         @Override
getCipherText(int nextHeader, byte[] payload, int spi, int seqNum)631         public byte[] getCipherText(int nextHeader, byte[] payload, int spi, int seqNum)
632                 throws GeneralSecurityException {
633             // Provided key consists of encryption/decryption key plus salt. Salt is used
634             // with ESP payload IV to build IvParameterSpec.
635             final byte[] secretKey = Arrays.copyOfRange(key, 0, key.length - saltLen);
636             final byte[] salt = Arrays.copyOfRange(key, secretKey.length, key.length);
637 
638             final SecretKeySpec secretKeySpec = new SecretKeySpec(secretKey, algoName);
639 
640             final ByteBuffer ivParameterBuffer = ByteBuffer.allocate(saltLen + mIv.length);
641             ivParameterBuffer.put(salt);
642             ivParameterBuffer.put(mIv);
643             final IvParameterSpec ivParameterSpec = new IvParameterSpec(ivParameterBuffer.array());
644 
645             final ByteBuffer aadBuffer = ByteBuffer.allocate(ESP_HDRLEN);
646             aadBuffer.putInt(spi);
647             aadBuffer.putInt(seqNum);
648 
649             // Encrypt payload
650             final Cipher cipher = Cipher.getInstance(algoName);
651             cipher.init(Cipher.ENCRYPT_MODE, secretKeySpec, ivParameterSpec);
652             cipher.updateAAD(aadBuffer.array());
653             final byte[] encryptedTextAndIcv =
654                     cipher.doFinal(getPaddedPayload(nextHeader, payload, blockSize));
655 
656             // Build ciphertext
657             final ByteBuffer cipherText =
658                     ByteBuffer.allocate(mIv.length + encryptedTextAndIcv.length);
659             cipherText.put(mIv);
660             cipherText.put(encryptedTextAndIcv);
661 
662             return getByteArrayFromBuffer(cipherText);
663         }
664     }
665 
666     public static class EspAuth {
667         public final String algoName;
668         public final byte[] key;
669         public final int icvLen;
670 
671         private static final Set<String> JCE_SUPPORTED_MACS = new ArraySet<>();
672 
673         static {
674             JCE_SUPPORTED_MACS.add(HMAC_MD5);
675             JCE_SUPPORTED_MACS.add(HMAC_SHA1);
676             JCE_SUPPORTED_MACS.add(HMAC_SHA_256);
677             JCE_SUPPORTED_MACS.add(HMAC_SHA_384);
678             JCE_SUPPORTED_MACS.add(HMAC_SHA_512);
679             JCE_SUPPORTED_MACS.add(AES_CMAC);
680         }
681 
EspAuth(String algoName, byte[] key, int icvLen)682         public EspAuth(String algoName, byte[] key, int icvLen) {
683             this.algoName = algoName;
684             this.key = key;
685             this.icvLen = icvLen;
686         }
687 
getIcv(byte[] authenticatedSection)688         public byte[] getIcv(byte[] authenticatedSection) throws GeneralSecurityException {
689             if (AES_XCBC.equals(algoName)) {
690                 final Cipher aesCipher = Cipher.getInstance(AES_CBC);
691                 return new AesXCbcImpl().mac(key, authenticatedSection, true /* needTruncation */);
692             } else if (JCE_SUPPORTED_MACS.contains(algoName)) {
693                 final Mac mac = Mac.getInstance(algoName);
694                 final SecretKeySpec authKey = new SecretKeySpec(key, algoName);
695                 mac.init(authKey);
696 
697                 final ByteBuffer buffer = ByteBuffer.wrap(mac.doFinal(authenticatedSection));
698                 final byte[] icv = new byte[icvLen];
699                 buffer.get(icv);
700                 return icv;
701             } else {
702                 throw new IllegalArgumentException("Invalid algorithm: " + algoName);
703             }
704         }
705     }
706 
707     public static final class EspAuthNull extends EspAuth {
708         private static final String AUTH_NULL = "AUTH_NULL";
709         private static final int ICV_LEN_UNUSED = 0;
710         private static final byte[] KEY_UNUSED = new byte[0];
711         private static final byte[] ICV_EMPTY = new byte[0];
712 
713         private static final EspAuthNull sInstance = new EspAuthNull();
714 
EspAuthNull()715         private EspAuthNull() {
716             super(AUTH_NULL, KEY_UNUSED, ICV_LEN_UNUSED);
717         }
718 
getInstance()719         public static EspAuthNull getInstance() {
720             return sInstance;
721         }
722 
723         @Override
getIcv(byte[] authenticatedSection)724         public byte[] getIcv(byte[] authenticatedSection) throws GeneralSecurityException {
725             return ICV_EMPTY;
726         }
727     }
728 
729     /*
730      * Debug printing
731      */
732     private static final char[] hexArray = "0123456789ABCDEF".toCharArray();
733 
bytesToHex(byte[] bytes)734     public static String bytesToHex(byte[] bytes) {
735         StringBuilder sb = new StringBuilder();
736         for (byte b : bytes) {
737             sb.append(hexArray[b >>> 4]);
738             sb.append(hexArray[b & 0x0F]);
739             sb.append(' ');
740         }
741         return sb.toString();
742     }
743 }
744