1 /* 2 * Copyright (C) 2023 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.example.android.vdmdemo.common; 18 19 import android.content.Context; 20 import android.content.pm.PackageManager; 21 import android.net.ConnectivityManager; 22 import android.net.Network; 23 import android.net.NetworkCapabilities; 24 import android.net.NetworkRequest; 25 import android.net.wifi.aware.AttachCallback; 26 import android.net.wifi.aware.DiscoverySession; 27 import android.net.wifi.aware.DiscoverySessionCallback; 28 import android.net.wifi.aware.PeerHandle; 29 import android.net.wifi.aware.PublishConfig; 30 import android.net.wifi.aware.PublishDiscoverySession; 31 import android.net.wifi.aware.SubscribeConfig; 32 import android.net.wifi.aware.SubscribeDiscoverySession; 33 import android.net.wifi.aware.WifiAwareManager; 34 import android.net.wifi.aware.WifiAwareNetworkInfo; 35 import android.net.wifi.aware.WifiAwareNetworkSpecifier; 36 import android.net.wifi.aware.WifiAwareSession; 37 import android.os.Build; 38 import android.os.Handler; 39 import android.os.HandlerThread; 40 import android.util.Log; 41 42 import androidx.annotation.GuardedBy; 43 import androidx.annotation.NonNull; 44 45 import dagger.hilt.android.qualifiers.ApplicationContext; 46 47 import java.io.IOException; 48 import java.net.Inet6Address; 49 import java.net.ServerSocket; 50 import java.net.Socket; 51 import java.net.SocketTimeoutException; 52 import java.util.ArrayList; 53 import java.util.List; 54 import java.util.Optional; 55 import java.util.concurrent.CompletableFuture; 56 import java.util.function.Consumer; 57 58 import javax.inject.Inject; 59 import javax.inject.Singleton; 60 61 /** Shared class between the client and the host, managing the connection between them. */ 62 @Singleton 63 public class ConnectionManager { 64 65 private static final String TAG = "VdmConnectionManager"; 66 private static final String CONNECTION_SERVICE_ID = "com.example.android.vdmdemo"; 67 private static final int NETWORK_TIMEOUT_MS = 5000; 68 69 private final RemoteIo mRemoteIo; 70 71 @ApplicationContext private final Context mContext; 72 private final ConnectivityManager mConnectivityManager; 73 private final Handler mBackgroundHandler; 74 private String mChannel = ""; 75 76 private CompletableFuture<WifiAwareSession> mWifiAwareSessionFuture = new CompletableFuture<>(); 77 78 private DiscoverySession mDiscoverySession; 79 80 /** Simple data structure to allow clients to query the current status. */ 81 public static final class ConnectionStatus { 82 public String remoteDeviceName = null; 83 public String errorMessage = null; 84 public State state = State.DISCONNECTED; 85 86 /** Enum indicating the current connection state. */ 87 public enum State { 88 DISCONNECTED, INITIALIZED, CONNECTING, CONNECTED, ERROR 89 } 90 } 91 92 @GuardedBy("mConnectionStatus") 93 private final ConnectionStatus mConnectionStatus = new ConnectionStatus(); 94 95 @GuardedBy("mConnectionCallbacks") 96 private final List<Consumer<ConnectionStatus>> mConnectionCallbacks = new ArrayList<>(); 97 98 private final RemoteIo.StreamClosedCallback mStreamClosedCallback = this::onInitialized; 99 100 @Inject ConnectionManager(@pplicationContext Context context, RemoteIo remoteIo)101 ConnectionManager(@ApplicationContext Context context, RemoteIo remoteIo) { 102 mRemoteIo = remoteIo; 103 mContext = context; 104 105 mConnectivityManager = context.getSystemService(ConnectivityManager.class); 106 final HandlerThread backgroundThread = new HandlerThread("ConnectionThread"); 107 backgroundThread.start(); 108 mBackgroundHandler = new Handler(backgroundThread.getLooper()); 109 } 110 getLocalEndpointId()111 static String getLocalEndpointId() { 112 return Build.MODEL; 113 } 114 115 /** Registers a listener for connection events. */ addConnectionCallback(Consumer<ConnectionStatus> callback)116 public void addConnectionCallback(Consumer<ConnectionStatus> callback) { 117 synchronized (mConnectionCallbacks) { 118 mConnectionCallbacks.add(callback); 119 } 120 } 121 122 /** Registers a listener for connection events. */ removeConnectionCallback(Consumer<ConnectionStatus> callback)123 public void removeConnectionCallback(Consumer<ConnectionStatus> callback) { 124 synchronized (mConnectionCallbacks) { 125 mConnectionCallbacks.remove(callback); 126 } 127 } 128 129 /** Returns the current connection status. */ getConnectionStatus()130 public ConnectionStatus getConnectionStatus() { 131 synchronized (mConnectionStatus) { 132 return mConnectionStatus; 133 } 134 } 135 136 /** Publish a local service so remote devices can discover this device. */ startHostSession(String channel)137 public void startHostSession(String channel) { 138 mChannel = channel; 139 final String serviceName = getServiceName(channel); 140 var unused = createWifiAwareSession().thenAccept(session -> session.publish( 141 new PublishConfig.Builder().setServiceName(serviceName).build(), 142 new HostDiscoverySessionCallback(), 143 mBackgroundHandler)); 144 } 145 146 /** Looks for published services from remote devices and subscribes to them. */ startClientSession(String channel)147 public void startClientSession(String channel) { 148 mChannel = channel; 149 final String serviceName = getServiceName(channel); 150 var unused = createWifiAwareSession().thenAccept(session -> session.subscribe( 151 new SubscribeConfig.Builder().setServiceName(serviceName).build(), 152 new ClientDiscoverySessionCallback(), 153 mBackgroundHandler)); 154 } 155 getServiceName(String channel)156 private String getServiceName(String channel) { 157 return CONNECTION_SERVICE_ID + channel; 158 } 159 isConnected()160 private boolean isConnected() { 161 synchronized (mConnectionStatus) { 162 return mConnectionStatus.state == ConnectionStatus.State.CONNECTED; 163 } 164 } 165 createWifiAwareSession()166 private CompletableFuture<WifiAwareSession> createWifiAwareSession() { 167 if (mWifiAwareSessionFuture.isDone() 168 && !mWifiAwareSessionFuture.isCompletedExceptionally()) { 169 return mWifiAwareSessionFuture; 170 } 171 172 Log.d(TAG, "Creating a new Wifi Aware session."); 173 WifiAwareManager wifiAwareManager = mContext.getSystemService(WifiAwareManager.class); 174 if (!mContext.getPackageManager().hasSystemFeature(PackageManager.FEATURE_WIFI_AWARE) 175 || wifiAwareManager == null 176 || !wifiAwareManager.isAvailable()) { 177 mWifiAwareSessionFuture.completeExceptionally( 178 new Exception("Wifi Aware is not available.")); 179 } else { 180 wifiAwareManager.attach( 181 new AttachCallback() { 182 @Override 183 public void onAttached(WifiAwareSession session) { 184 Log.d(TAG, "New Wifi Aware attached."); 185 mWifiAwareSessionFuture.complete(session); 186 } 187 188 @Override 189 public void onAttachFailed() { 190 mWifiAwareSessionFuture.completeExceptionally( 191 new Exception("Failed to attach Wifi Aware session.")); 192 } 193 }, 194 mBackgroundHandler); 195 } 196 mWifiAwareSessionFuture = mWifiAwareSessionFuture 197 .exceptionally(e -> { 198 Log.e(TAG, "Failed to create Wifi Aware session", e); 199 onError("Failed to create Wifi Aware session"); 200 return null; 201 }); 202 return mWifiAwareSessionFuture; 203 } 204 205 /** Explicitly terminate any existing connection. */ disconnect()206 public void disconnect() { 207 Log.d(TAG, "Terminating connections."); 208 if (mDiscoverySession != null) { 209 mDiscoverySession.close(); 210 mDiscoverySession = null; 211 } 212 synchronized (mConnectionStatus) { 213 mConnectionStatus.state = ConnectionStatus.State.DISCONNECTED; 214 notifyStateChangedLocked(); 215 } 216 } 217 onSocketAvailable(Socket socket)218 private void onSocketAvailable(Socket socket) throws IOException { 219 mRemoteIo.initialize(socket.getInputStream(), mStreamClosedCallback); 220 mRemoteIo.initialize(socket.getOutputStream(), mStreamClosedCallback); 221 synchronized (mConnectionStatus) { 222 mConnectionStatus.state = ConnectionStatus.State.CONNECTED; 223 notifyStateChangedLocked(); 224 } 225 } 226 onInitialized()227 private void onInitialized() { 228 if (mDiscoverySession == null) { 229 return; 230 } 231 synchronized (mConnectionStatus) { 232 mConnectionStatus.state = ConnectionStatus.State.INITIALIZED; 233 notifyStateChangedLocked(); 234 } 235 } 236 onConnecting(byte[] remoteDeviceName)237 private void onConnecting(byte[] remoteDeviceName) { 238 synchronized (mConnectionStatus) { 239 mConnectionStatus.state = ConnectionStatus.State.CONNECTING; 240 mConnectionStatus.remoteDeviceName = new String(remoteDeviceName); 241 Log.d(TAG, "Connecting to " + mConnectionStatus.remoteDeviceName); 242 notifyStateChangedLocked(); 243 } 244 } 245 onError(String message)246 private void onError(String message) { 247 Log.e(TAG, "Error: " + message); 248 synchronized (mConnectionStatus) { 249 mConnectionStatus.state = ConnectionStatus.State.ERROR; 250 mConnectionStatus.errorMessage = message; 251 notifyStateChangedLocked(); 252 } 253 } 254 255 @GuardedBy("mConnectionStatus") notifyStateChangedLocked()256 private void notifyStateChangedLocked() { 257 Log.d(TAG, "Connection state changed: " + mConnectionStatus.state); 258 synchronized (mConnectionCallbacks) { 259 for (Consumer<ConnectionStatus> callback : mConnectionCallbacks) { 260 callback.accept(mConnectionStatus); 261 } 262 } 263 } 264 265 private abstract class VdmDiscoverySessionCallback extends DiscoverySessionCallback { 266 267 private NetworkCallback mNetworkCallback; 268 269 @Override onPublishStarted(@onNull PublishDiscoverySession session)270 public void onPublishStarted(@NonNull PublishDiscoverySession session) { 271 mDiscoverySession = session; 272 onInitialized(); 273 } 274 275 @Override onSubscribeStarted(@onNull SubscribeDiscoverySession session)276 public void onSubscribeStarted(@NonNull SubscribeDiscoverySession session) { 277 mDiscoverySession = session; 278 onInitialized(); 279 } 280 281 @Override onServiceDiscovered( PeerHandle peerHandle, byte[] serviceSpecificInfo, List<byte[]> matchFilter)282 public void onServiceDiscovered( 283 PeerHandle peerHandle, byte[] serviceSpecificInfo, List<byte[]> matchFilter) { 284 Log.d(TAG, "Discovered service: " + new String(serviceSpecificInfo)); 285 sendLocalEndpointId(peerHandle); 286 } 287 288 @Override onSessionTerminated()289 public void onSessionTerminated() { 290 Log.d(TAG, "Discovery session terminated."); 291 if (mNetworkCallback != null) { 292 mConnectivityManager.unregisterNetworkCallback(mNetworkCallback); 293 mNetworkCallback = null; 294 } 295 } 296 sendLocalEndpointId(PeerHandle peerHandle)297 void sendLocalEndpointId(PeerHandle peerHandle) { 298 mDiscoverySession.sendMessage(peerHandle, 0, getLocalEndpointId().getBytes()); 299 } 300 301 @Override onMessageReceived(PeerHandle peerHandle, byte[] message)302 public void onMessageReceived(PeerHandle peerHandle, byte[] message) { 303 Log.d(TAG, "Received message: " + new String(message)); 304 if (isConnected()) { 305 return; 306 } 307 onConnecting(message); 308 establishConnection(peerHandle); 309 } 310 establishConnection(PeerHandle peerHandle)311 protected abstract void establishConnection(PeerHandle peerHandle); 312 requestNetwork( PeerHandle peerHandle, Optional<Integer> port, NetworkCallback networkCallback)313 void requestNetwork( 314 PeerHandle peerHandle, Optional<Integer> port, NetworkCallback networkCallback) { 315 WifiAwareNetworkSpecifier.Builder networkSpecifierBuilder; 316 networkSpecifierBuilder = 317 new WifiAwareNetworkSpecifier.Builder(mDiscoverySession, peerHandle) 318 .setPskPassphrase(CONNECTION_SERVICE_ID); 319 if (mNetworkCallback != null) { 320 mConnectivityManager.unregisterNetworkCallback(mNetworkCallback); 321 } 322 mNetworkCallback = networkCallback; 323 port.ifPresent(networkSpecifierBuilder::setPort); 324 325 NetworkRequest networkRequest = 326 new NetworkRequest.Builder() 327 .addTransportType(NetworkCapabilities.TRANSPORT_WIFI_AWARE) 328 .setNetworkSpecifier(networkSpecifierBuilder.build()) 329 .build(); 330 Log.d(TAG, "Requesting network"); 331 mConnectivityManager.requestNetwork( 332 networkRequest, mNetworkCallback, NETWORK_TIMEOUT_MS); 333 } 334 } 335 336 private final class HostDiscoverySessionCallback extends VdmDiscoverySessionCallback { 337 @Override establishConnection(PeerHandle peerHandle)338 protected void establishConnection(PeerHandle peerHandle) { 339 try { 340 ServerSocket serverSocket = new ServerSocket(0); 341 serverSocket.setSoTimeout(NETWORK_TIMEOUT_MS); 342 requestNetwork(peerHandle, Optional.of(serverSocket.getLocalPort()), 343 new NetworkCallback()); 344 sendLocalEndpointId(peerHandle); 345 onSocketAvailable(serverSocket.accept()); 346 } catch (SocketTimeoutException e) { 347 Log.e(TAG, "Socket timeout: " + e.getMessage()); 348 } catch (IOException e) { 349 onError("Failed to establish connection."); 350 } 351 } 352 } 353 354 private final class ClientDiscoverySessionCallback extends VdmDiscoverySessionCallback { 355 @Override establishConnection(PeerHandle peerHandle)356 protected void establishConnection(PeerHandle peerHandle) { 357 requestNetwork(peerHandle, /* port= */ Optional.empty(), new ClientNetworkCallback()); 358 } 359 } 360 361 private class NetworkCallback extends ConnectivityManager.NetworkCallback { 362 363 @Override onLost(@onNull Network network)364 public void onLost(@NonNull Network network) { 365 Log.d(TAG, "Network lost"); 366 onInitialized(); 367 } 368 369 @Override onUnavailable()370 public void onUnavailable() { 371 Log.d(TAG, "Network unavailable"); 372 onError("Network unavailable"); 373 } 374 } 375 376 private class ClientNetworkCallback extends NetworkCallback { 377 378 @Override onCapabilitiesChanged(@onNull Network network, @NonNull NetworkCapabilities networkCapabilities)379 public void onCapabilitiesChanged(@NonNull Network network, 380 @NonNull NetworkCapabilities networkCapabilities) { 381 if (isConnected()) { 382 return; 383 } 384 385 WifiAwareNetworkInfo peerAwareInfo = 386 (WifiAwareNetworkInfo) networkCapabilities.getTransportInfo(); 387 Inet6Address peerIpv6 = peerAwareInfo.getPeerIpv6Addr(); 388 int peerPort = peerAwareInfo.getPort(); 389 try { 390 Socket socket = network.getSocketFactory().createSocket(peerIpv6, peerPort); 391 onSocketAvailable(socket); 392 } catch (IOException e) { 393 Log.e(TAG, "Failed to establish connection.", e); 394 onError("Failed to establish connection."); 395 } 396 } 397 398 @Override onLost(@onNull Network network)399 public void onLost(@NonNull Network network) { 400 super.onLost(network); 401 startClientSession(mChannel); 402 } 403 } 404 } 405