/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package android.net.cts;

import static android.net.IpSecAlgorithm.AUTH_CRYPT_AES_GCM;
import static android.net.IpSecAlgorithm.AUTH_HMAC_MD5;
import static android.net.IpSecAlgorithm.AUTH_HMAC_SHA1;
import static android.net.IpSecAlgorithm.AUTH_HMAC_SHA256;
import static android.net.IpSecAlgorithm.AUTH_HMAC_SHA384;
import static android.net.IpSecAlgorithm.AUTH_HMAC_SHA512;
import static android.net.IpSecAlgorithm.CRYPT_AES_CBC;
import static android.system.OsConstants.FIONREAD;

import static org.junit.Assert.assertArrayEquals;

import android.content.Context;
import android.net.ConnectivityManager;
import android.net.IpSecAlgorithm;
import android.net.IpSecManager;
import android.net.IpSecTransform;
import android.platform.test.annotations.AppModeFull;
import android.system.ErrnoException;
import android.system.Os;
import android.system.OsConstants;
import android.system.StructTimeval;
import android.util.Log;

import androidx.test.InstrumentationRegistry;
import androidx.test.runner.AndroidJUnit4;

import com.android.modules.utils.build.SdkLevel;

import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;

import java.io.FileDescriptor;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.DatagramPacket;
import java.net.DatagramSocket;
import java.net.Inet6Address;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketAddress;
import java.net.SocketException;
import java.net.SocketImpl;
import java.net.SocketOptions;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;

72 @RunWith(AndroidJUnit4.class)
73 public class IpSecBaseTest {
74 
75     private static final String TAG = IpSecBaseTest.class.getSimpleName();
76 
77     protected static final String IPV4_LOOPBACK = "127.0.0.1";
78     protected static final String IPV6_LOOPBACK = "::1";
79     protected static final String[] LOOPBACK_ADDRS = new String[] {IPV4_LOOPBACK, IPV6_LOOPBACK};
80     protected static final int[] DIRECTIONS =
81             new int[] {IpSecManager.DIRECTION_IN, IpSecManager.DIRECTION_OUT};
82 
83     protected static final byte[] TEST_DATA = "Best test data ever!".getBytes();
84     protected static final int DATA_BUFFER_LEN = 4096;
85     protected static final int SOCK_TIMEOUT = 500;
86 
87     private static final byte[] KEY_DATA = {
88         0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
89         0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F,
90         0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
91         0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F,
92         0x20, 0x21, 0x22, 0x23
93     };
94 
95     private static final Set<String> MANDATORY_IPSEC_ALGOS_SINCE_P = new HashSet<>();
96 
97     static {
98         MANDATORY_IPSEC_ALGOS_SINCE_P.add(CRYPT_AES_CBC);
99         MANDATORY_IPSEC_ALGOS_SINCE_P.add(AUTH_HMAC_MD5);
100         MANDATORY_IPSEC_ALGOS_SINCE_P.add(AUTH_HMAC_SHA1);
101         MANDATORY_IPSEC_ALGOS_SINCE_P.add(AUTH_HMAC_SHA256);
102         MANDATORY_IPSEC_ALGOS_SINCE_P.add(AUTH_HMAC_SHA384);
103         MANDATORY_IPSEC_ALGOS_SINCE_P.add(AUTH_HMAC_SHA512);
104         MANDATORY_IPSEC_ALGOS_SINCE_P.add(AUTH_CRYPT_AES_GCM);
105     }
106 
107     protected static final byte[] AUTH_KEY = getKey(256);
108     protected static final byte[] CRYPT_KEY = getKey(256);
109 
110     protected ConnectivityManager mCM;
111     protected IpSecManager mISM;
112 
113     @Before
setUp()114     public void setUp() throws Exception {
115         mISM =
116                 (IpSecManager)
117                         InstrumentationRegistry.getContext()
118                                 .getSystemService(Context.IPSEC_SERVICE);
119         mCM =
120                 (ConnectivityManager)
121                         InstrumentationRegistry.getContext()
122                                 .getSystemService(Context.CONNECTIVITY_SERVICE);
123     }
124 
125     /** Checks if an IPsec algorithm is enabled on the device */
hasIpSecAlgorithm(String algorithm)126     protected static boolean hasIpSecAlgorithm(String algorithm) {
127         if (SdkLevel.isAtLeastS()) {
128             return IpSecAlgorithm.getSupportedAlgorithms().contains(algorithm);
129         } else {
130             return MANDATORY_IPSEC_ALGOS_SINCE_P.contains(algorithm);
131         }
132     }
133 
getKeyBytes(int byteLength)134     protected static byte[] getKeyBytes(int byteLength) {
135         return Arrays.copyOf(KEY_DATA, byteLength);
136     }
137 
getKey(int bitLength)138     protected static byte[] getKey(int bitLength) {
139         if (bitLength % 8 != 0) {
140             throw new IllegalArgumentException("Invalid key length in bits" + bitLength);
141         }
142         return getKeyBytes(bitLength / 8);
143     }
144 
getDomain(InetAddress address)145     protected static int getDomain(InetAddress address) {
146         int domain;
147         if (address instanceof Inet6Address) {
148             domain = OsConstants.AF_INET6;
149         } else {
150             domain = OsConstants.AF_INET;
151         }
152         return domain;
153     }
154 
getPort(FileDescriptor sock)155     protected static int getPort(FileDescriptor sock) throws Exception {
156         return ((InetSocketAddress) Os.getsockname(sock)).getPort();
157     }
158 
159     public static interface GenericSocket extends AutoCloseable {
send(byte[] data)160         void send(byte[] data) throws Exception;
161 
receive()162         byte[] receive() throws Exception;
163 
getPort()164         int getPort() throws Exception;
165 
close()166         void close() throws Exception;
167 
applyTransportModeTransform( IpSecManager ism, int direction, IpSecTransform transform)168         void applyTransportModeTransform(
169                 IpSecManager ism, int direction, IpSecTransform transform) throws Exception;
170 
removeTransportModeTransforms(IpSecManager ism)171         void removeTransportModeTransforms(IpSecManager ism) throws Exception;
172     }
173 
174     public static interface GenericTcpSocket extends GenericSocket {}
175 
176     public static interface GenericUdpSocket extends GenericSocket {
sendTo(byte[] data, InetAddress dstAddr, int port)177         void sendTo(byte[] data, InetAddress dstAddr, int port) throws Exception;
178     }
179 
180     public abstract static class NativeSocket implements GenericSocket {
181         public FileDescriptor mFd;
182 
NativeSocket(FileDescriptor fd)183         public NativeSocket(FileDescriptor fd) {
184             mFd = fd;
185         }
186 
187         @Override
send(byte[] data)188         public void send(byte[] data) throws Exception {
189             Os.write(mFd, data, 0, data.length);
190         }
191 
192         @Override
receive()193         public byte[] receive() throws Exception {
194             byte[] in = new byte[DATA_BUFFER_LEN];
195             AtomicInteger bytesRead = new AtomicInteger(-1);
196 
197             Thread readSockThread = new Thread(() -> {
198                 long startTime = System.currentTimeMillis();
199                 while (bytesRead.get() < 0 && System.currentTimeMillis() < startTime + SOCK_TIMEOUT) {
200                     try {
201                         bytesRead.set(Os.recvfrom(mFd, in, 0, DATA_BUFFER_LEN, 0, null));
202                     } catch (Exception e) {
203                         Log.e(TAG, "Error encountered reading from socket", e);
204                     }
205                 }
206             });
207 
208             readSockThread.start();
209             readSockThread.join(SOCK_TIMEOUT);
210 
211             if (bytesRead.get() < 0) {
212                 throw new IOException("No data received from socket");
213             }
214 
215             return Arrays.copyOfRange(in, 0, bytesRead.get());
216         }
217 
218         @Override
getPort()219         public int getPort() throws Exception {
220             return IpSecBaseTest.getPort(mFd);
221         }
222 
223         @Override
close()224         public void close() throws Exception {
225             Os.close(mFd);
226         }
227 
228         @Override
applyTransportModeTransform( IpSecManager ism, int direction, IpSecTransform transform)229         public void applyTransportModeTransform(
230                 IpSecManager ism, int direction, IpSecTransform transform) throws Exception {
231             ism.applyTransportModeTransform(mFd, direction, transform);
232         }
233 
234         @Override
removeTransportModeTransforms(IpSecManager ism)235         public void removeTransportModeTransforms(IpSecManager ism) throws Exception {
236             ism.removeTransportModeTransforms(mFd);
237         }
238     }
239 
240     public static class NativeTcpSocket extends NativeSocket implements GenericTcpSocket {
NativeTcpSocket(FileDescriptor fd)241         public NativeTcpSocket(FileDescriptor fd) {
242             super(fd);
243         }
244 
acceptToJavaSocket()245         public JavaTcpSocket acceptToJavaSocket() throws Exception {
246             InetSocketAddress peer = new InetSocketAddress(0);
247             FileDescriptor newFd = Os.accept(mFd, peer);
248             return new JavaTcpSocket(new AcceptedTcpFileDescriptorSocket(newFd, peer, getPort()));
249         }
250     }
251 
252     public static class NativeUdpSocket extends NativeSocket implements GenericUdpSocket {
NativeUdpSocket(FileDescriptor fd)253         public NativeUdpSocket(FileDescriptor fd) {
254             super(fd);
255         }
256 
257         @Override
sendTo(byte[] data, InetAddress dstAddr, int port)258         public void sendTo(byte[] data, InetAddress dstAddr, int port) throws Exception {
259             Os.sendto(mFd, data, 0, data.length, 0, dstAddr, port);
260         }
261     }
262 
263     public static class JavaUdpSocket implements GenericUdpSocket {
264         public final DatagramSocket mSocket;
265 
JavaUdpSocket(InetAddress localAddr, int port)266         public JavaUdpSocket(InetAddress localAddr, int port) {
267             try {
268                 mSocket = new DatagramSocket(port, localAddr);
269                 mSocket.setSoTimeout(SOCK_TIMEOUT);
270             } catch (SocketException e) {
271                 // Fail loudly if we can't set up sockets properly. And without the timeout, we
272                 // could easily end up in an endless wait.
273                 throw new RuntimeException(e);
274             }
275         }
276 
JavaUdpSocket(InetAddress localAddr)277         public JavaUdpSocket(InetAddress localAddr) {
278             try {
279                 mSocket = new DatagramSocket(0, localAddr);
280                 mSocket.setSoTimeout(SOCK_TIMEOUT);
281             } catch (SocketException e) {
282                 // Fail loudly if we can't set up sockets properly. And without the timeout, we
283                 // could easily end up in an endless wait.
284                 throw new RuntimeException(e);
285             }
286         }
287 
288         @Override
send(byte[] data)289         public void send(byte[] data) throws Exception {
290             mSocket.send(new DatagramPacket(data, data.length));
291         }
292 
293         @Override
sendTo(byte[] data, InetAddress dstAddr, int port)294         public void sendTo(byte[] data, InetAddress dstAddr, int port) throws Exception {
295             mSocket.send(new DatagramPacket(data, data.length, dstAddr, port));
296         }
297 
298         @Override
getPort()299         public int getPort() throws Exception {
300             return mSocket.getLocalPort();
301         }
302 
303         @Override
close()304         public void close() throws Exception {
305             mSocket.close();
306         }
307 
308         @Override
receive()309         public byte[] receive() throws Exception {
310             DatagramPacket data = new DatagramPacket(new byte[DATA_BUFFER_LEN], DATA_BUFFER_LEN);
311             mSocket.receive(data);
312             return Arrays.copyOfRange(data.getData(), 0, data.getLength());
313         }
314 
315         @Override
applyTransportModeTransform( IpSecManager ism, int direction, IpSecTransform transform)316         public void applyTransportModeTransform(
317                 IpSecManager ism, int direction, IpSecTransform transform) throws Exception {
318             ism.applyTransportModeTransform(mSocket, direction, transform);
319         }
320 
321         @Override
removeTransportModeTransforms(IpSecManager ism)322         public void removeTransportModeTransforms(IpSecManager ism) throws Exception {
323             ism.removeTransportModeTransforms(mSocket);
324         }
325     }
326 
327     public static class JavaTcpSocket implements GenericTcpSocket {
328         public final Socket mSocket;
329 
JavaTcpSocket(Socket socket)330         public JavaTcpSocket(Socket socket) {
331             mSocket = socket;
332             try {
333                 mSocket.setSoTimeout(SOCK_TIMEOUT);
334             } catch (SocketException e) {
335                 // Fail loudly if we can't set up sockets properly. And without the timeout, we
336                 // could easily end up in an endless wait.
337                 throw new RuntimeException(e);
338             }
339         }
340 
341         @Override
send(byte[] data)342         public void send(byte[] data) throws Exception {
343             mSocket.getOutputStream().write(data);
344         }
345 
346         @Override
receive()347         public byte[] receive() throws Exception {
348             byte[] in = new byte[DATA_BUFFER_LEN];
349             int bytesRead = mSocket.getInputStream().read(in);
350             return Arrays.copyOfRange(in, 0, bytesRead);
351         }
352 
353         @Override
getPort()354         public int getPort() throws Exception {
355             return mSocket.getLocalPort();
356         }
357 
358         @Override
close()359         public void close() throws Exception {
360             mSocket.close();
361         }
362 
363         @Override
applyTransportModeTransform( IpSecManager ism, int direction, IpSecTransform transform)364         public void applyTransportModeTransform(
365                 IpSecManager ism, int direction, IpSecTransform transform) throws Exception {
366             ism.applyTransportModeTransform(mSocket, direction, transform);
367         }
368 
369         @Override
removeTransportModeTransforms(IpSecManager ism)370         public void removeTransportModeTransforms(IpSecManager ism) throws Exception {
371             ism.removeTransportModeTransforms(mSocket);
372         }
373     }
374 
375     private static class AcceptedTcpFileDescriptorSocket extends Socket {
376 
AcceptedTcpFileDescriptorSocket(FileDescriptor fd, InetSocketAddress remote, int localPort)377         AcceptedTcpFileDescriptorSocket(FileDescriptor fd, InetSocketAddress remote,
378                 int localPort) throws IOException {
379             super(new FileDescriptorSocketImpl(fd, remote, localPort));
380             connect(remote);
381         }
382 
383         private static class FileDescriptorSocketImpl extends SocketImpl {
384 
FileDescriptorSocketImpl(FileDescriptor fd, InetSocketAddress remote, int localPort)385             private FileDescriptorSocketImpl(FileDescriptor fd, InetSocketAddress remote,
386                     int localPort) {
387                 this.fd = fd;
388                 this.address = remote.getAddress();
389                 this.port = remote.getPort();
390                 this.localport = localPort;
391             }
392 
393             @Override
create(boolean stream)394             protected void create(boolean stream) throws IOException {
395                 // The socket has been created.
396             }
397 
398             @Override
connect(String host, int port)399             protected void connect(String host, int port) throws IOException {
400                 // The socket has connected.
401             }
402 
403             @Override
connect(InetAddress address, int port)404             protected void connect(InetAddress address, int port) throws IOException {
405                 // The socket has connected.
406             }
407 
408             @Override
connect(SocketAddress address, int timeout)409             protected void connect(SocketAddress address, int timeout) throws IOException {
410                 // The socket has connected.
411             }
412 
413             @Override
bind(InetAddress host, int port)414             protected void bind(InetAddress host, int port) throws IOException {
415                 // The socket is bounded.
416             }
417 
418             @Override
listen(int backlog)419             protected void listen(int backlog) throws IOException {
420                 throw new UnsupportedOperationException("listen");
421             }
422 
423             @Override
accept(SocketImpl s)424             protected void accept(SocketImpl s) throws IOException {
425                 throw new UnsupportedOperationException("accept");
426             }
427 
428             @Override
getInputStream()429             protected InputStream getInputStream() throws IOException {
430                 return new FileInputStream(fd);
431             }
432 
433             @Override
getOutputStream()434             protected OutputStream getOutputStream() throws IOException {
435                 return new FileOutputStream(fd);
436             }
437 
438             @Override
available()439             protected int available() throws IOException {
440                 try {
441                     return Os.ioctlInt(fd, FIONREAD);
442                 } catch (ErrnoException e) {
443                     throw new IOException(e);
444                 }
445             }
446 
447             @Override
close()448             protected void close() throws IOException {
449                 try {
450                     Os.close(fd);
451                 } catch (ErrnoException e) {
452                     throw new IOException(e);
453                 }
454             }
455 
456             @Override
sendUrgentData(int data)457             protected void sendUrgentData(int data) throws IOException {
458                 throw new UnsupportedOperationException("sendUrgentData");
459             }
460 
461             @Override
setOption(int optID, Object value)462             public void setOption(int optID, Object value) throws SocketException {
463                 try {
464                     setOptionInternal(optID, value);
465                 } catch (ErrnoException e) {
466                     throw new SocketException(e.getMessage());
467                 }
468             }
469 
setOptionInternal(int optID, Object value)470             private void setOptionInternal(int optID, Object value) throws ErrnoException,
471                     SocketException {
472                 switch(optID) {
473                     case SocketOptions.SO_TIMEOUT:
474                         int millis = (Integer) value;
475                         StructTimeval tv = StructTimeval.fromMillis(millis);
476                         Os.setsockoptTimeval(fd, OsConstants.SOL_SOCKET, OsConstants.SO_RCVTIMEO,
477                                 tv);
478                         return;
479                     default:
480                         throw new SocketException("Unknown socket option: " + optID);
481                 }
482             }
483 
484             @Override
getOption(int optID)485             public Object getOption(int optID) throws SocketException {
486                 try {
487                     return getOptionInternal(optID);
488                 } catch (ErrnoException e) {
489                     throw new SocketException(e.getMessage());
490                 }
491             }
492 
getOptionInternal(int optID)493             private Object getOptionInternal(int optID) throws ErrnoException, SocketException {
494                 switch (optID) {
495                     case SocketOptions.SO_LINGER:
496                         // Returns an arbitrary value because IpSecManager doesn't actually
497                         // use this value.
498                         return 10;
499                     default:
500                         throw new SocketException("Unknown socket option: " + optID);
501                 }
502             }
503         }
504     }
505 
506     public static class SocketPair<T> {
507         public final T mLeftSock;
508         public final T mRightSock;
509 
SocketPair(T leftSock, T rightSock)510         public SocketPair(T leftSock, T rightSock) {
511             mLeftSock = leftSock;
512             mRightSock = rightSock;
513         }
514     }
515 
applyTransformBidirectionally( IpSecManager ism, IpSecTransform transform, GenericSocket socket)516     protected static void applyTransformBidirectionally(
517             IpSecManager ism, IpSecTransform transform, GenericSocket socket) throws Exception {
518         for (int direction : DIRECTIONS) {
519             socket.applyTransportModeTransform(ism, direction, transform);
520         }
521     }
522 
getNativeUdpSocketPair( InetAddress localAddr, IpSecManager ism, IpSecTransform transform, boolean connected)523     public static SocketPair<NativeUdpSocket> getNativeUdpSocketPair(
524             InetAddress localAddr, IpSecManager ism, IpSecTransform transform, boolean connected)
525             throws Exception {
526         int domain = getDomain(localAddr);
527 
528         NativeUdpSocket leftSock = new NativeUdpSocket(
529             Os.socket(domain, OsConstants.SOCK_DGRAM, OsConstants.IPPROTO_UDP));
530         NativeUdpSocket rightSock = new NativeUdpSocket(
531             Os.socket(domain, OsConstants.SOCK_DGRAM, OsConstants.IPPROTO_UDP));
532 
533         for (NativeUdpSocket sock : new NativeUdpSocket[] {leftSock, rightSock}) {
534             applyTransformBidirectionally(ism, transform, sock);
535             Os.bind(sock.mFd, localAddr, 0);
536         }
537 
538         if (connected) {
539             Os.connect(leftSock.mFd, localAddr, rightSock.getPort());
540             Os.connect(rightSock.mFd, localAddr, leftSock.getPort());
541         }
542 
543         return new SocketPair<>(leftSock, rightSock);
544     }
545 
getNativeTcpSocketPair( InetAddress localAddr, IpSecManager ism, IpSecTransform transform)546     public static SocketPair<NativeTcpSocket> getNativeTcpSocketPair(
547             InetAddress localAddr, IpSecManager ism, IpSecTransform transform) throws Exception {
548         int domain = getDomain(localAddr);
549 
550         NativeTcpSocket server = new NativeTcpSocket(
551                 Os.socket(domain, OsConstants.SOCK_STREAM, OsConstants.IPPROTO_TCP));
552         NativeTcpSocket client = new NativeTcpSocket(
553                 Os.socket(domain, OsConstants.SOCK_STREAM, OsConstants.IPPROTO_TCP));
554 
555         Os.bind(server.mFd, localAddr, 0);
556 
557         applyTransformBidirectionally(ism, transform, server);
558         applyTransformBidirectionally(ism, transform, client);
559 
560         Os.listen(server.mFd, 10);
561         Os.connect(client.mFd, localAddr, server.getPort());
562         NativeTcpSocket accepted = new NativeTcpSocket(Os.accept(server.mFd, null));
563 
564         applyTransformBidirectionally(ism, transform, accepted);
565         server.close();
566 
567         return new SocketPair<>(client, accepted);
568     }
569 
getJavaUdpSocketPair( InetAddress localAddr, IpSecManager ism, IpSecTransform transform, boolean connected)570     public static SocketPair<JavaUdpSocket> getJavaUdpSocketPair(
571             InetAddress localAddr, IpSecManager ism, IpSecTransform transform, boolean connected)
572             throws Exception {
573         JavaUdpSocket leftSock = new JavaUdpSocket(localAddr);
574         JavaUdpSocket rightSock = new JavaUdpSocket(localAddr);
575 
576         applyTransformBidirectionally(ism, transform, leftSock);
577         applyTransformBidirectionally(ism, transform, rightSock);
578 
579         if (connected) {
580             leftSock.mSocket.connect(localAddr, rightSock.mSocket.getLocalPort());
581             rightSock.mSocket.connect(localAddr, leftSock.mSocket.getLocalPort());
582         }
583 
584         return new SocketPair<>(leftSock, rightSock);
585     }
586 
getJavaTcpSocketPair( InetAddress localAddr, IpSecManager ism, IpSecTransform transform)587     public static SocketPair<JavaTcpSocket> getJavaTcpSocketPair(
588             InetAddress localAddr, IpSecManager ism, IpSecTransform transform) throws Exception {
589         JavaTcpSocket clientSock = new JavaTcpSocket(new Socket());
590 
591         // While technically the client socket does not need to be bound, the OpenJDK implementation
592         // of Socket only allocates an FD when bind() or connect() or other similar methods are
593         // called. So we call bind to force the FD creation, so that we can apply a transform to it
594         // prior to socket connect.
595         clientSock.mSocket.bind(new InetSocketAddress(localAddr, 0));
596 
597         // IpSecService doesn't support serverSockets at the moment; workaround using FD
598         NativeTcpSocket server = new NativeTcpSocket(
599                 Os.socket(getDomain(localAddr), OsConstants.SOCK_STREAM, OsConstants.IPPROTO_TCP));
600         Os.bind(server.mFd, localAddr, 0);
601 
602         applyTransformBidirectionally(ism, transform, server);
603         applyTransformBidirectionally(ism, transform, clientSock);
604 
605         Os.listen(server.mFd, 10 /* backlog */);
606         clientSock.mSocket.connect(new InetSocketAddress(localAddr, server.getPort()));
607         JavaTcpSocket acceptedSock = server.acceptToJavaSocket();
608 
609         applyTransformBidirectionally(ism, transform, acceptedSock);
610         server.close();
611 
612         return new SocketPair<>(clientSock, acceptedSock);
613     }
614 
checkSocketPair(GenericSocket left, GenericSocket right)615     private void checkSocketPair(GenericSocket left, GenericSocket right) throws Exception {
616         left.send(TEST_DATA);
617         assertArrayEquals(TEST_DATA, right.receive());
618 
619         right.send(TEST_DATA);
620         assertArrayEquals(TEST_DATA, left.receive());
621 
622         left.close();
623         right.close();
624     }
625 
checkUnconnectedUdpSocketPair( GenericUdpSocket left, GenericUdpSocket right, InetAddress localAddr)626     private void checkUnconnectedUdpSocketPair(
627             GenericUdpSocket left, GenericUdpSocket right, InetAddress localAddr) throws Exception {
628         left.sendTo(TEST_DATA, localAddr, right.getPort());
629         assertArrayEquals(TEST_DATA, right.receive());
630 
631         right.sendTo(TEST_DATA, localAddr, left.getPort());
632         assertArrayEquals(TEST_DATA, left.receive());
633 
634         left.close();
635         right.close();
636     }
637 
buildIpSecTransform( Context context, IpSecManager.SecurityParameterIndex spi, IpSecManager.UdpEncapsulationSocket encapSocket, InetAddress remoteAddr)638     protected static IpSecTransform buildIpSecTransform(
639             Context context,
640             IpSecManager.SecurityParameterIndex spi,
641             IpSecManager.UdpEncapsulationSocket encapSocket,
642             InetAddress remoteAddr)
643             throws Exception {
644         IpSecTransform.Builder builder =
645                 new IpSecTransform.Builder(context)
646                         .setEncryption(new IpSecAlgorithm(IpSecAlgorithm.CRYPT_AES_CBC, CRYPT_KEY))
647                         .setAuthentication(
648                                 new IpSecAlgorithm(
649                                         IpSecAlgorithm.AUTH_HMAC_SHA256,
650                                         AUTH_KEY,
651                                         AUTH_KEY.length * 4));
652 
653         if (encapSocket != null) {
654             builder.setIpv4Encapsulation(encapSocket, encapSocket.getPort());
655         }
656 
657         return builder.buildTransportModeTransform(remoteAddr, spi);
658     }
659 
buildDefaultTransform(InetAddress localAddr)660     private IpSecTransform buildDefaultTransform(InetAddress localAddr) throws Exception {
661         try (IpSecManager.SecurityParameterIndex spi =
662                 mISM.allocateSecurityParameterIndex(localAddr)) {
663             return buildIpSecTransform(InstrumentationRegistry.getContext(), spi, null, localAddr);
664         }
665     }
666 
667     @Test
668     @AppModeFull(reason = "Socket cannot bind in instant app mode")
testJavaTcpSocketPair()669     public void testJavaTcpSocketPair() throws Exception {
670         for (String addr : LOOPBACK_ADDRS) {
671             InetAddress local = InetAddress.getByName(addr);
672             try (IpSecTransform transform = buildDefaultTransform(local)) {
673                 SocketPair<JavaTcpSocket> sockets = getJavaTcpSocketPair(local, mISM, transform);
674                 checkSocketPair(sockets.mLeftSock, sockets.mRightSock);
675             }
676         }
677     }
678 
679     @Test
680     @AppModeFull(reason = "Socket cannot bind in instant app mode")
testJavaUdpSocketPair()681     public void testJavaUdpSocketPair() throws Exception {
682         for (String addr : LOOPBACK_ADDRS) {
683             InetAddress local = InetAddress.getByName(addr);
684             try (IpSecTransform transform = buildDefaultTransform(local)) {
685                 SocketPair<JavaUdpSocket> sockets =
686                         getJavaUdpSocketPair(local, mISM, transform, true);
687                 checkSocketPair(sockets.mLeftSock, sockets.mRightSock);
688             }
689         }
690     }
691 
692     @Test
693     @AppModeFull(reason = "Socket cannot bind in instant app mode")
testJavaUdpSocketPairUnconnected()694     public void testJavaUdpSocketPairUnconnected() throws Exception {
695         for (String addr : LOOPBACK_ADDRS) {
696             InetAddress local = InetAddress.getByName(addr);
697             try (IpSecTransform transform = buildDefaultTransform(local)) {
698                 SocketPair<JavaUdpSocket> sockets =
699                         getJavaUdpSocketPair(local, mISM, transform, false);
700                 checkUnconnectedUdpSocketPair(sockets.mLeftSock, sockets.mRightSock, local);
701             }
702         }
703     }
704 
705     @Test
706     @AppModeFull(reason = "Socket cannot bind in instant app mode")
testNativeTcpSocketPair()707     public void testNativeTcpSocketPair() throws Exception {
708         for (String addr : LOOPBACK_ADDRS) {
709             InetAddress local = InetAddress.getByName(addr);
710             try (IpSecTransform transform = buildDefaultTransform(local)) {
711                 SocketPair<NativeTcpSocket> sockets =
712                         getNativeTcpSocketPair(local, mISM, transform);
713                 checkSocketPair(sockets.mLeftSock, sockets.mRightSock);
714             }
715         }
716     }
717 
718     @Test
719     @AppModeFull(reason = "Socket cannot bind in instant app mode")
testNativeUdpSocketPair()720     public void testNativeUdpSocketPair() throws Exception {
721         for (String addr : LOOPBACK_ADDRS) {
722             InetAddress local = InetAddress.getByName(addr);
723             try (IpSecTransform transform = buildDefaultTransform(local)) {
724                 SocketPair<NativeUdpSocket> sockets =
725                         getNativeUdpSocketPair(local, mISM, transform, true);
726                 checkSocketPair(sockets.mLeftSock, sockets.mRightSock);
727             }
728         }
729     }
730 
731     @Test
732     @AppModeFull(reason = "Socket cannot bind in instant app mode")
testNativeUdpSocketPairUnconnected()733     public void testNativeUdpSocketPairUnconnected() throws Exception {
734         for (String addr : LOOPBACK_ADDRS) {
735             InetAddress local = InetAddress.getByName(addr);
736             try (IpSecTransform transform = buildDefaultTransform(local)) {
737                 SocketPair<NativeUdpSocket> sockets =
738                         getNativeUdpSocketPair(local, mISM, transform, false);
739                 checkUnconnectedUdpSocketPair(sockets.mLeftSock, sockets.mRightSock, local);
740             }
741         }
742     }
743 }
744