/*
* Copyright (C) 2021 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.server.connectivity.mdns;
import android.Manifest.permission;
import android.annotation.NonNull;
import android.annotation.Nullable;
import android.annotation.RequiresPermission;
import android.content.Context;
import android.net.Network;
import android.net.wifi.WifiManager.MulticastLock;
import android.os.SystemClock;
import android.text.format.DateUtils;
import android.util.Log;
import com.android.internal.annotations.VisibleForTesting;
import com.android.net.module.util.CollectionUtils;
import com.android.net.module.util.SharedLog;
import com.android.server.connectivity.mdns.util.MdnsUtils;
import java.io.IOException;
import java.net.DatagramPacket;
import java.net.Inet4Address;
import java.net.Inet6Address;
import java.net.InetSocketAddress;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;
import java.util.Queue;
import java.util.Timer;
import java.util.TimerTask;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
/**
* The {@link MdnsSocketClient} maintains separate threads to send and receive mDNS packets for all
* the requested service types.
*
*
See https://tools.ietf.org/html/rfc6763 (namely sections 4 and 5).
*/
public class MdnsSocketClient implements MdnsSocketClientBase {
private static final String TAG = "MdnsClient";
// TODO: The following values are copied from cast module. We need to think about the
// better way to share those.
private static final String CAST_SENDER_LOG_SOURCE = "CAST_SENDER_SDK";
private static final String CAST_PREFS_NAME = "google_cast";
private static final String PREF_CAST_SENDER_ID = "PREF_CAST_SENDER_ID";
private static final String MULTICAST_TYPE = "multicast";
private static final String UNICAST_TYPE = "unicast";
private static final long SLEEP_TIME_FOR_SOCKET_THREAD_MS =
MdnsConfigs.sleepTimeForSocketThreadMs();
// A value of 0 leads to an infinite wait.
private static final long THREAD_JOIN_TIMEOUT_MS = DateUtils.SECOND_IN_MILLIS;
private static final int RECEIVER_BUFFER_SIZE = 2048;
@VisibleForTesting
final Queue multicastPacketQueue = new ArrayDeque<>();
@VisibleForTesting
final Queue unicastPacketQueue = new ArrayDeque<>();
private final Context context;
private final byte[] multicastReceiverBuffer = new byte[RECEIVER_BUFFER_SIZE];
@Nullable private final byte[] unicastReceiverBuffer;
private final MulticastLock multicastLock;
private final boolean useSeparateSocketForUnicast =
MdnsConfigs.useSeparateSocketToSendUnicastQuery();
private final boolean checkMulticastResponse = MdnsConfigs.checkMulticastResponse();
private final long checkMulticastResponseIntervalMs =
MdnsConfigs.checkMulticastResponseIntervalMs();
private final boolean propagateInterfaceIndex =
MdnsConfigs.allowNetworkInterfaceIndexPropagation();
private final Object socketLock = new Object();
private final Object timerObject = new Object();
// If multicast response was received in the current session. The value is reset in the
// beginning of each session.
@VisibleForTesting
boolean receivedMulticastResponse;
// If unicast response was received in the current session. The value is reset in the beginning
// of each session.
@VisibleForTesting
boolean receivedUnicastResponse;
// If the phone is the bad state where it can't receive any multicast response.
@VisibleForTesting
AtomicBoolean cannotReceiveMulticastResponse = new AtomicBoolean(false);
@VisibleForTesting @Nullable volatile Thread sendThread;
@VisibleForTesting @Nullable Thread multicastReceiveThread;
@VisibleForTesting @Nullable Thread unicastReceiveThread;
private volatile boolean shouldStopSocketLoop;
@Nullable private Callback callback;
@Nullable private MdnsSocket multicastSocket;
@Nullable private MdnsSocket unicastSocket;
private int receivedPacketNumber = 0;
@Nullable private Timer logMdnsPacketTimer;
private AtomicInteger packetsCount;
@Nullable private Timer checkMulticastResponseTimer;
private final SharedLog sharedLog;
@NonNull private final MdnsFeatureFlags mdnsFeatureFlags;
private final MulticastNetworkInterfaceProvider interfaceProvider;
public MdnsSocketClient(@NonNull Context context, @NonNull MulticastLock multicastLock,
SharedLog sharedLog, @NonNull MdnsFeatureFlags mdnsFeatureFlags) {
this.sharedLog = sharedLog;
this.context = context;
this.multicastLock = multicastLock;
if (useSeparateSocketForUnicast) {
unicastReceiverBuffer = new byte[RECEIVER_BUFFER_SIZE];
} else {
unicastReceiverBuffer = null;
}
this.mdnsFeatureFlags = mdnsFeatureFlags;
this.interfaceProvider = new MulticastNetworkInterfaceProvider(context, sharedLog);
}
@Override
public synchronized void setCallback(@Nullable Callback callback) {
this.callback = callback;
}
@RequiresPermission(permission.CHANGE_WIFI_MULTICAST_STATE)
@Override
public synchronized void startDiscovery() throws IOException {
if (multicastSocket != null) {
sharedLog.w("Discovery is already in progress.");
return;
}
receivedMulticastResponse = false;
receivedUnicastResponse = false;
cannotReceiveMulticastResponse.set(false);
shouldStopSocketLoop = false;
interfaceProvider.startWatchingConnectivityChanges();
try {
// TODO (changed when importing code): consider setting thread stats tag
multicastSocket = createMdnsSocket(MdnsConstants.MDNS_PORT, sharedLog);
multicastSocket.joinGroup();
if (useSeparateSocketForUnicast) {
// For unicast, use port 0 and the system will assign it with any available port.
unicastSocket = createMdnsSocket(0, sharedLog);
}
multicastLock.acquire();
} catch (IOException e) {
multicastLock.release();
if (multicastSocket != null) {
multicastSocket.close();
multicastSocket = null;
}
if (unicastSocket != null) {
unicastSocket.close();
unicastSocket = null;
}
throw e;
} finally {
// TODO (changed when importing code): consider resetting thread stats tag
}
createAndStartSendThread();
createAndStartReceiverThreads();
}
@RequiresPermission(permission.CHANGE_WIFI_MULTICAST_STATE)
@Override
public void stopDiscovery() {
sharedLog.log("Stop discovery.");
if (multicastSocket == null && unicastSocket == null) {
return;
}
if (MdnsConfigs.clearMdnsPacketQueueAfterDiscoveryStops()) {
synchronized (multicastPacketQueue) {
multicastPacketQueue.clear();
}
synchronized (unicastPacketQueue) {
unicastPacketQueue.clear();
}
}
multicastLock.release();
interfaceProvider.stopWatchingConnectivityChanges();
shouldStopSocketLoop = true;
waitForSendThreadToStop();
waitForReceiverThreadsToStop();
synchronized (socketLock) {
multicastSocket = null;
unicastSocket = null;
}
synchronized (timerObject) {
if (checkMulticastResponseTimer != null) {
checkMulticastResponseTimer.cancel();
checkMulticastResponseTimer = null;
}
}
}
@Override
public void sendPacketRequestingMulticastResponse(@NonNull List packets,
boolean onlyUseIpv6OnIpv6OnlyNetworks) {
sendMdnsPackets(packets, multicastPacketQueue, onlyUseIpv6OnIpv6OnlyNetworks);
}
@Override
public void sendPacketRequestingUnicastResponse(@NonNull List packets,
boolean onlyUseIpv6OnIpv6OnlyNetworks) {
if (useSeparateSocketForUnicast) {
sendMdnsPackets(packets, unicastPacketQueue, onlyUseIpv6OnIpv6OnlyNetworks);
} else {
sendMdnsPackets(packets, multicastPacketQueue, onlyUseIpv6OnIpv6OnlyNetworks);
}
}
@Override
public void notifyNetworkRequested(
@NonNull MdnsServiceBrowserListener listener,
@Nullable Network network,
@NonNull SocketCreationCallback socketCreationCallback) {
if (network != null) {
throw new IllegalArgumentException("This socket client does not support requesting "
+ "specific networks");
}
socketCreationCallback.onSocketCreated(new SocketKey(multicastSocket.getInterfaceIndex()));
}
@Override
public boolean supportsRequestingSpecificNetworks() {
return false;
}
private void sendMdnsPackets(List packets,
Queue packetQueueToUse, boolean onlyUseIpv6OnIpv6OnlyNetworks) {
if (shouldStopSocketLoop && !MdnsConfigs.allowAddMdnsPacketAfterDiscoveryStops()) {
sharedLog.w("sendMdnsPacket() is called after discovery already stopped");
return;
}
if (packets.isEmpty()) {
Log.wtf(TAG, "No mDns packets to send");
return;
}
// Check all packets with the same address
if (!MdnsUtils.checkAllPacketsWithSameAddress(packets)) {
Log.wtf(TAG, "Some mDNS packets have a different target address. addresses="
+ CollectionUtils.map(packets, DatagramPacket::getSocketAddress));
return;
}
final boolean isIpv4 = ((InetSocketAddress) packets.get(0).getSocketAddress())
.getAddress() instanceof Inet4Address;
final boolean isIpv6 = ((InetSocketAddress) packets.get(0).getSocketAddress())
.getAddress() instanceof Inet6Address;
final boolean ipv6Only = multicastSocket != null && multicastSocket.isOnIPv6OnlyNetwork();
if (isIpv4 && ipv6Only) {
return;
}
if (isIpv6 && !ipv6Only && onlyUseIpv6OnIpv6OnlyNetworks) {
return;
}
synchronized (packetQueueToUse) {
while ((packetQueueToUse.size() + packets.size())
> MdnsConfigs.mdnsPacketQueueMaxSize()) {
packetQueueToUse.remove();
}
packetQueueToUse.addAll(packets);
}
triggerSendThread();
}
private void createAndStartSendThread() {
if (sendThread != null) {
sharedLog.w("A socket thread already exists.");
return;
}
sendThread = new Thread(this::sendThreadMain);
sendThread.setName("mdns-send");
sendThread.start();
}
private void createAndStartReceiverThreads() {
if (multicastReceiveThread != null) {
sharedLog.w("A multicast receiver thread already exists.");
return;
}
multicastReceiveThread =
new Thread(() -> receiveThreadMain(multicastReceiverBuffer, multicastSocket));
multicastReceiveThread.setName("mdns-multicast-receive");
multicastReceiveThread.start();
if (useSeparateSocketForUnicast) {
unicastReceiveThread =
new Thread(
() -> {
if (unicastReceiverBuffer != null) {
receiveThreadMain(unicastReceiverBuffer, unicastSocket);
}
});
unicastReceiveThread.setName("mdns-unicast-receive");
unicastReceiveThread.start();
}
}
private void triggerSendThread() {
sharedLog.log("Trigger send thread.");
Thread sendThread = this.sendThread;
if (sendThread != null) {
sendThread.interrupt();
} else {
sharedLog.w("Socket thread is null");
}
}
private void waitForReceiverThreadsToStop() {
if (multicastReceiveThread != null) {
waitForThread(multicastReceiveThread);
multicastReceiveThread = null;
}
if (unicastReceiveThread != null) {
waitForThread(unicastReceiveThread);
unicastReceiveThread = null;
}
}
private void waitForSendThreadToStop() {
sharedLog.log("wait For Send Thread To Stop");
if (sendThread == null) {
sharedLog.w("socket thread is already dead.");
return;
}
waitForThread(sendThread);
sendThread = null;
}
private void waitForThread(Thread thread) {
long startMs = SystemClock.elapsedRealtime();
long waitMs = THREAD_JOIN_TIMEOUT_MS;
while (thread.isAlive() && (waitMs > 0)) {
try {
thread.interrupt();
thread.join(waitMs);
if (thread.isAlive()) {
sharedLog.w("Failed to join thread: " + thread);
}
break;
} catch (InterruptedException e) {
// Compute remaining time after at least a single join call, in case the clock
// resolution is poor.
waitMs = THREAD_JOIN_TIMEOUT_MS - (SystemClock.elapsedRealtime() - startMs);
}
}
}
private void sendThreadMain() {
List multicastPacketsToSend = new ArrayList<>();
List unicastPacketsToSend = new ArrayList<>();
boolean shouldThreadSleep;
try {
while (!shouldStopSocketLoop) {
try {
// Make a local copy of all packets, and clear the queue.
// Send packets that ask for multicast response.
multicastPacketsToSend.clear();
synchronized (multicastPacketQueue) {
multicastPacketsToSend.addAll(multicastPacketQueue);
multicastPacketQueue.clear();
}
// Send packets that ask for unicast response.
if (useSeparateSocketForUnicast) {
unicastPacketsToSend.clear();
synchronized (unicastPacketQueue) {
unicastPacketsToSend.addAll(unicastPacketQueue);
unicastPacketQueue.clear();
}
if (unicastSocket != null) {
sendPackets(unicastPacketsToSend, unicastSocket);
}
}
// Send multicast packets.
if (multicastSocket != null) {
sendPackets(multicastPacketsToSend, multicastSocket);
}
// Sleep ONLY if no more packets have been added to the queue, while packets
// were being sent.
synchronized (multicastPacketQueue) {
synchronized (unicastPacketQueue) {
shouldThreadSleep =
multicastPacketQueue.isEmpty() && unicastPacketQueue.isEmpty();
}
}
if (shouldThreadSleep) {
Thread.sleep(SLEEP_TIME_FOR_SOCKET_THREAD_MS);
}
} catch (InterruptedException e) {
// Don't log the interruption as it's expected.
}
}
} finally {
sharedLog.log("Send thread stopped.");
try {
if (multicastSocket != null) {
multicastSocket.leaveGroup();
}
} catch (Exception t) {
sharedLog.e("Failed to leave the group.", t);
}
// Close the socket first. This is the only way to interrupt a blocking receive.
try {
// This is a race with the use of the file descriptor (b/27403984).
if (multicastSocket != null) {
multicastSocket.close();
}
if (unicastSocket != null) {
unicastSocket.close();
}
} catch (RuntimeException t) {
sharedLog.e("Failed to close the mdns socket.", t);
}
}
}
private void receiveThreadMain(byte[] receiverBuffer, @Nullable MdnsSocket socket) {
DatagramPacket packet = new DatagramPacket(receiverBuffer, receiverBuffer.length);
while (!shouldStopSocketLoop) {
try {
// This is a race with the use of the file descriptor (b/27403984).
synchronized (socketLock) {
// This checks is to make sure the socket was not set to null.
if (socket != null && (socket == multicastSocket || socket == unicastSocket)) {
socket.receive(packet);
}
}
if (!shouldStopSocketLoop) {
String responseType = socket == multicastSocket ? MULTICAST_TYPE : UNICAST_TYPE;
processResponsePacket(
packet,
responseType,
/* interfaceIndex= */ (socket == null || !propagateInterfaceIndex)
? MdnsSocket.INTERFACE_INDEX_UNSPECIFIED
: socket.getInterfaceIndex(),
/* network= */ socket.getNetwork());
}
} catch (IOException e) {
if (!shouldStopSocketLoop) {
sharedLog.e("Failed to receive mDNS packets.", e);
}
}
}
sharedLog.log("Receive thread stopped.");
}
private int processResponsePacket(@NonNull DatagramPacket packet, String responseType,
int interfaceIndex, @Nullable Network network) {
int packetNumber = ++receivedPacketNumber;
final MdnsPacket response;
try {
response = MdnsResponseDecoder.parseResponse(
packet.getData(), packet.getLength(), mdnsFeatureFlags);
} catch (MdnsPacket.ParseException e) {
sharedLog.w(String.format("Error while decoding %s packet (%d): %d",
responseType, packetNumber, e.code));
if (callback != null) {
callback.onFailedToParseMdnsResponse(packetNumber, e.code,
new SocketKey(network, interfaceIndex));
}
return e.code;
}
if (response == null) {
return MdnsResponseErrorCode.ERROR_NOT_RESPONSE_MESSAGE;
}
if (callback != null) {
callback.onResponseReceived(
response, new SocketKey(network, interfaceIndex));
}
return MdnsResponseErrorCode.SUCCESS;
}
@VisibleForTesting
MdnsSocket createMdnsSocket(int port, SharedLog sharedLog) throws IOException {
return new MdnsSocket(interfaceProvider, port, sharedLog);
}
private void sendPackets(List packets, MdnsSocket socket) {
String requestType = socket == multicastSocket ? "multicast" : "unicast";
for (DatagramPacket packet : packets) {
if (shouldStopSocketLoop) {
break;
}
try {
sharedLog.log(String.format("Sending a %s mDNS packet...", requestType));
socket.send(packet);
// Start the timer task to monitor the response.
synchronized (timerObject) {
if (socket == multicastSocket) {
if (cannotReceiveMulticastResponse.get()) {
// Don't schedule the timer task if we are already in the bad state.
return;
}
if (checkMulticastResponseTimer != null) {
// Don't schedule the timer task if it's already scheduled.
return;
}
if (checkMulticastResponse && useSeparateSocketForUnicast) {
// Only when useSeparateSocketForUnicast is true, we can tell if we
// received a multicast or unicast response.
checkMulticastResponseTimer = new Timer();
checkMulticastResponseTimer.schedule(
new TimerTask() {
@Override
public void run() {
synchronized (timerObject) {
if (checkMulticastResponseTimer == null) {
// Discovery already stopped.
return;
}
if ((!receivedMulticastResponse)
&& receivedUnicastResponse) {
sharedLog.e(String.format(
"Haven't received multicast response"
+ " in the last %d ms.",
checkMulticastResponseIntervalMs));
cannotReceiveMulticastResponse.set(true);
}
checkMulticastResponseTimer = null;
}
}
},
checkMulticastResponseIntervalMs);
}
}
}
} catch (IOException e) {
sharedLog.e(String.format("Failed to send a %s mDNS packet.", requestType), e);
}
}
packets.clear();
}
}