1 /* 2 * Copyright (C) 2020 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.net.module.util.netlink; 18 19 import static android.system.OsConstants.AF_INET; 20 import static android.system.OsConstants.AF_INET6; 21 import static android.system.OsConstants.IPPROTO_TCP; 22 import static android.system.OsConstants.IPPROTO_UDP; 23 import static android.system.OsConstants.SOCK_DGRAM; 24 import static android.system.OsConstants.SOCK_STREAM; 25 26 import static org.junit.Assert.assertEquals; 27 import static org.junit.Assert.assertNotEquals; 28 import static org.junit.Assume.assumeTrue; 29 30 import android.app.Instrumentation; 31 import android.content.Context; 32 import android.net.ConnectivityManager; 33 import android.os.Process; 34 import android.system.Os; 35 36 import androidx.test.InstrumentationRegistry; 37 import androidx.test.filters.SmallTest; 38 import androidx.test.runner.AndroidJUnit4; 39 40 import com.android.networkstack.apishim.common.ShimUtils; 41 42 import org.junit.Before; 43 import org.junit.Test; 44 import org.junit.runner.RunWith; 45 46 import java.io.FileDescriptor; 47 import java.net.Inet4Address; 48 import java.net.Inet6Address; 49 import java.net.InetAddress; 50 import java.net.InetSocketAddress; 51 52 @RunWith(AndroidJUnit4.class) 53 @SmallTest 54 public class InetDiagSocketIntegrationTest { 55 private ConnectivityManager mCm; 56 private Context mContext; 57 58 @Before setUp()59 public void setUp() throws Exception { 60 Instrumentation instrumentation = InstrumentationRegistry.getInstrumentation(); 61 mContext = instrumentation.getTargetContext(); 62 mCm = (ConnectivityManager) mContext.getSystemService(Context.CONNECTIVITY_SERVICE); 63 } 64 65 private class Connection { 66 public int socketDomain; 67 public int socketType; 68 public InetAddress localAddress; 69 public InetAddress remoteAddress; 70 public InetAddress localhostAddress; 71 public InetSocketAddress local; 72 public InetSocketAddress remote; 73 public int protocol; 74 public FileDescriptor localFd; 75 public FileDescriptor remoteFd; 76 createSocket()77 public FileDescriptor createSocket() throws Exception { 78 return Os.socket(socketDomain, socketType, protocol); 79 } 80 Connection(String to, String from)81 Connection(String to, String from) throws Exception { 82 remoteAddress = InetAddress.getByName(to); 83 if (from != null) { 84 localAddress = InetAddress.getByName(from); 85 } else { 86 localAddress = (remoteAddress instanceof Inet4Address) 87 ? Inet4Address.getByName("localhost") : Inet6Address.getByName("::"); 88 } 89 if ((localAddress instanceof Inet4Address) && (remoteAddress instanceof Inet4Address)) { 90 socketDomain = AF_INET; 91 localhostAddress = Inet4Address.getByName("localhost"); 92 } else { 93 socketDomain = AF_INET6; 94 localhostAddress = Inet6Address.getByName("::"); 95 } 96 } 97 close()98 public void close() throws Exception { 99 Os.close(localFd); 100 } 101 } 102 103 private class TcpConnection extends Connection { TcpConnection(String to, String from)104 TcpConnection(String to, String from) throws Exception { 105 super(to, from); 106 protocol = IPPROTO_TCP; 107 socketType = SOCK_STREAM; 108 109 remoteFd = createSocket(); 110 Os.bind(remoteFd, remoteAddress, 0); 111 Os.listen(remoteFd, 10); 112 int remotePort = ((InetSocketAddress) Os.getsockname(remoteFd)).getPort(); 113 114 localFd = createSocket(); 115 Os.bind(localFd, localAddress, 0); 116 Os.connect(localFd, remoteAddress, remotePort); 117 118 local = (InetSocketAddress) Os.getsockname(localFd); 119 remote = (InetSocketAddress) Os.getpeername(localFd); 120 } 121 close()122 public void close() throws Exception { 123 super.close(); 124 Os.close(remoteFd); 125 } 126 } 127 private class UdpConnection extends Connection { UdpConnection(String to, String from)128 UdpConnection(String to, String from) throws Exception { 129 super(to, from); 130 protocol = IPPROTO_UDP; 131 socketType = SOCK_DGRAM; 132 133 remoteFd = null; 134 localFd = createSocket(); 135 Os.bind(localFd, localAddress, 0); 136 137 Os.connect(localFd, remoteAddress, 7); 138 local = (InetSocketAddress) Os.getsockname(localFd); 139 remote = new InetSocketAddress(remoteAddress, 7); 140 } 141 } 142 checkConnectionOwnerUid(int protocol, InetSocketAddress local, InetSocketAddress remote, boolean expectSuccess)143 private void checkConnectionOwnerUid(int protocol, InetSocketAddress local, 144 InetSocketAddress remote, boolean expectSuccess) { 145 final int uid = mCm.getConnectionOwnerUid(protocol, local, remote); 146 147 if (expectSuccess) { 148 assertEquals(Process.myUid(), uid); 149 } else { 150 assertNotEquals(Process.myUid(), uid); 151 } 152 } 153 findLikelyFreeUdpPort(UdpConnection conn)154 private int findLikelyFreeUdpPort(UdpConnection conn) throws Exception { 155 UdpConnection udp = new UdpConnection(conn.remoteAddress.getHostAddress(), 156 conn.localAddress.getHostAddress()); 157 final int localPort = udp.local.getPort(); 158 udp.close(); 159 return localPort; 160 } 161 162 /** 163 * Create a test connection for UDP and TCP sockets and verify that this 164 * {protocol, local, remote} socket result in receiving a valid UID. 165 */ checkGetConnectionOwnerUid(String to, String from)166 public void checkGetConnectionOwnerUid(String to, String from) throws Exception { 167 TcpConnection tcp = new TcpConnection(to, from); 168 checkConnectionOwnerUid(tcp.protocol, tcp.local, tcp.remote, true); 169 checkConnectionOwnerUid(IPPROTO_UDP, tcp.local, tcp.remote, false); 170 checkConnectionOwnerUid(tcp.protocol, new InetSocketAddress(0), tcp.remote, false); 171 checkConnectionOwnerUid(tcp.protocol, tcp.local, new InetSocketAddress(0), false); 172 tcp.close(); 173 174 UdpConnection udp = new UdpConnection(to, from); 175 checkConnectionOwnerUid(udp.protocol, udp.local, udp.remote, true); 176 checkConnectionOwnerUid(IPPROTO_TCP, udp.local, udp.remote, false); 177 checkConnectionOwnerUid(udp.protocol, new InetSocketAddress(findLikelyFreeUdpPort(udp)), 178 udp.remote, false); 179 udp.close(); 180 } 181 182 @Test testGetConnectionOwnerUid()183 public void testGetConnectionOwnerUid() throws Exception { 184 // Skip the test for API <= Q, as b/141603906 this was only fixed in Q-QPR2 185 assumeTrue(ShimUtils.isAtLeastR()); 186 checkGetConnectionOwnerUid("::", null); 187 checkGetConnectionOwnerUid("::", "::"); 188 checkGetConnectionOwnerUid("0.0.0.0", null); 189 checkGetConnectionOwnerUid("0.0.0.0", "0.0.0.0"); 190 checkGetConnectionOwnerUid("127.0.0.1", null); 191 checkGetConnectionOwnerUid("127.0.0.1", "127.0.0.2"); 192 checkGetConnectionOwnerUid("::1", null); 193 checkGetConnectionOwnerUid("::1", "::1"); 194 } 195 196 /* Verify fix for b/141603906 */ 197 @Test testB141603906()198 public void testB141603906() throws Exception { 199 // Skip the test for API <= Q, as b/141603906 this was only fixed in Q-QPR2 200 assumeTrue(ShimUtils.isAtLeastR()); 201 final InetSocketAddress src = new InetSocketAddress(0); 202 final InetSocketAddress dst = new InetSocketAddress(0); 203 final int numThreads = 8; 204 final int numSockets = 5000; 205 final Thread[] threads = new Thread[numThreads]; 206 207 for (int i = 0; i < numThreads; i++) { 208 threads[i] = new Thread(() -> { 209 for (int j = 0; j < numSockets; j++) { 210 mCm.getConnectionOwnerUid(IPPROTO_TCP, src, dst); 211 } 212 }); 213 } 214 215 for (Thread thread : threads) { 216 thread.start(); 217 } 218 219 for (Thread thread : threads) { 220 thread.join(); 221 } 222 } 223 } 224