/*
 * Copyright (C) 2021 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.server.connectivity.mdns;

import static com.android.server.connectivity.mdns.MdnsServiceTypeClient.INVALID_TRANSACTION_ID;

import android.annotation.NonNull;
import android.os.Build;
import android.text.TextUtils;
import android.util.Pair;

import com.android.net.module.util.CollectionUtils;
import com.android.net.module.util.DnsUtils;
import com.android.net.module.util.SharedLog;
import com.android.server.connectivity.mdns.util.MdnsUtils;

import java.io.IOException;
import java.lang.ref.WeakReference;
import java.net.DatagramPacket;
import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Callable;

/**
 * A {@link Callable} that builds and enqueues a mDNS query to send over the multicast socket. If a
 * query is built and enqueued successfully, then call to {@link #call()} returns the transaction ID
 * and the list of the subtypes in the query as a {@link Pair}. If a query is failed to build, or if
 * it can not be enqueued, then call to {@link #call()} returns {@code null}.
 */
public class EnqueueMdnsQueryCallable implements Callable<Pair<Integer, List<String>>> {

    private static final String TAG = "MdnsQueryCallable";
    private static final List<Integer> castShellEmulatorMdnsPorts;

    static {
        castShellEmulatorMdnsPorts = new ArrayList<>();
        String[] stringPorts = MdnsConfigs.castShellEmulatorMdnsPorts();

        for (String port : stringPorts) {
            try {
                castShellEmulatorMdnsPorts.add(Integer.parseInt(port));
            } catch (NumberFormatException e) {
                // Ignore.
            }
        }
    }

    @NonNull
    private final WeakReference<MdnsSocketClientBase> weakRequestSender;
    @NonNull
    private final String[] serviceTypeLabels;
    @NonNull
    private final List<String> subtypes;
    private final boolean expectUnicastResponse;
    private final int transactionId;
    @NonNull
    private final SocketKey socketKey;
    private final boolean sendDiscoveryQueries;
    @NonNull
    private final List<MdnsResponse> servicesToResolve;
    @NonNull
    private final MdnsUtils.Clock clock;
    @NonNull
    private final SharedLog sharedLog;
    @NonNull
    private final MdnsServiceTypeClient.Dependencies dependencies;
    private final boolean onlyUseIpv6OnIpv6OnlyNetworks;
    private final byte[] packetCreationBuffer = new byte[1500]; // TODO: use interface MTU
    @NonNull
    private final List<MdnsResponse> existingServices;
    private final boolean isQueryWithKnownAnswer;

    EnqueueMdnsQueryCallable(
            @NonNull MdnsSocketClientBase requestSender,
            @NonNull String serviceType,
            @NonNull Collection<String> subtypes,
            boolean expectUnicastResponse,
            int transactionId,
            @NonNull SocketKey socketKey,
            boolean onlyUseIpv6OnIpv6OnlyNetworks,
            boolean sendDiscoveryQueries,
            @NonNull Collection<MdnsResponse> servicesToResolve,
            @NonNull MdnsUtils.Clock clock,
            @NonNull SharedLog sharedLog,
            @NonNull MdnsServiceTypeClient.Dependencies dependencies,
            @NonNull Collection<MdnsResponse> existingServices,
            boolean isQueryWithKnownAnswer) {
        weakRequestSender = new WeakReference<>(requestSender);
        serviceTypeLabels = TextUtils.split(serviceType, "\\.");
        this.subtypes = new ArrayList<>(subtypes);
        this.expectUnicastResponse = expectUnicastResponse;
        this.transactionId = transactionId;
        this.socketKey = socketKey;
        this.onlyUseIpv6OnIpv6OnlyNetworks = onlyUseIpv6OnIpv6OnlyNetworks;
        this.sendDiscoveryQueries = sendDiscoveryQueries;
        this.servicesToResolve = new ArrayList<>(servicesToResolve);
        this.clock = clock;
        this.sharedLog = sharedLog;
        this.dependencies = dependencies;
        this.existingServices = new ArrayList<>(existingServices);
        this.isQueryWithKnownAnswer = isQueryWithKnownAnswer;
    }

    /**
     * Call to execute the mdns query.
     *
     * @return The pair of transaction id and the subtypes for the query.
     */
    // Incompatible return type for override of Callable#call().
    @SuppressWarnings("nullness:override.return.invalid")
    @Override
    public Pair<Integer, List<String>> call() {
        try {
            MdnsSocketClientBase requestSender = weakRequestSender.get();
            if (requestSender == null) {
                return Pair.create(INVALID_TRANSACTION_ID, new ArrayList<>());
            }

            final List<MdnsRecord> questions = new ArrayList<>();

            if (sendDiscoveryQueries) {
                // Base service type
                questions.add(new MdnsPointerRecord(serviceTypeLabels, expectUnicastResponse));
                for (String subtype : subtypes) {
                    final String[] labels = MdnsUtils.constructFullSubtype(serviceTypeLabels,
                            MdnsConstants.SUBTYPE_PREFIX + subtype);
                    questions.add(new MdnsPointerRecord(labels, expectUnicastResponse));
                }
            }

            // List of (name, type) to query
            final long now = clock.elapsedRealtime();
            for (MdnsResponse response : servicesToResolve) {
                final String[] serviceName = response.getServiceName();
                if (serviceName == null) continue;
                boolean renewTxt = !response.hasTextRecord() || MdnsUtils.isRecordRenewalNeeded(
                        response.getTextRecord(), now);
                boolean renewSrv = !response.hasServiceRecord() || MdnsUtils.isRecordRenewalNeeded(
                        response.getServiceRecord(), now);
                if (renewSrv && renewTxt) {
                    questions.add(new MdnsAnyRecord(serviceName, expectUnicastResponse));
                } else {
                    if (renewTxt) {
                        questions.add(new MdnsTextRecord(serviceName, expectUnicastResponse));
                    }
                    if (renewSrv) {
                        questions.add(new MdnsServiceRecord(serviceName, expectUnicastResponse));
                        // The hostname is not yet known, so queries for address records will be
                        // sent the next time the EnqueueMdnsQueryCallable is enqueued if the reply
                        // does not contain them. In practice, advertisers should include the
                        // address records when queried for SRV, although it's not a MUST
                        // requirement (RFC6763 12.2).
                    } else if (!response.hasInet4AddressRecord()
                            && !response.hasInet6AddressRecord()) {
                        final String[] host = response.getServiceRecord().getServiceHost();
                        questions.add(new MdnsInetAddressRecord(
                                host, MdnsRecord.TYPE_A, expectUnicastResponse));
                        questions.add(new MdnsInetAddressRecord(
                                host, MdnsRecord.TYPE_AAAA, expectUnicastResponse));
                    }
                }
            }

            if (questions.size() == 0) {
                // No query to send
                return Pair.create(INVALID_TRANSACTION_ID, new ArrayList<>());
            }

            // Put the existing ptr records into known-answer section.
            final List<MdnsRecord> knownAnswers = new ArrayList<>();
            if (sendDiscoveryQueries) {
                for (MdnsResponse existingService : existingServices) {
                    for (MdnsPointerRecord ptrRecord : existingService.getPointerRecords()) {
                        // Ignore any PTR records that don't match the current query.
                        if (!CollectionUtils.any(questions,
                                q -> q instanceof MdnsPointerRecord
                                        && DnsUtils.equalsDnsLabelIgnoreDnsCase(
                                                q.getName(), ptrRecord.getName()))) {
                            continue;
                        }

                        knownAnswers.add(new MdnsPointerRecord(
                                ptrRecord.getName(),
                                ptrRecord.getReceiptTime(),
                                ptrRecord.getCacheFlush(),
                                ptrRecord.getRemainingTTL(now), // Put the remaining ttl.
                                ptrRecord.getPointer()));
                    }
                }
            }

            final MdnsPacket queryPacket = new MdnsPacket(
                    transactionId,
                    MdnsConstants.FLAGS_QUERY,
                    questions,
                    knownAnswers,
                    Collections.emptyList(), /* authorityRecords */
                    Collections.emptyList() /* additionalRecords */);
            sendPacketToIpv4AndIpv6(requestSender, MdnsConstants.MDNS_PORT, queryPacket);
            for (Integer emulatorPort : castShellEmulatorMdnsPorts) {
                sendPacketToIpv4AndIpv6(requestSender, emulatorPort, queryPacket);
            }
            return Pair.create(transactionId, subtypes);
        } catch (Exception e) {
            sharedLog.e(String.format("Failed to create mDNS packet for subtype: %s.",
                    TextUtils.join(",", subtypes)), e);
            return Pair.create(INVALID_TRANSACTION_ID, new ArrayList<>());
        }
    }

    private void sendPacket(MdnsSocketClientBase requestSender, InetSocketAddress address,
            MdnsPacket mdnsPacket) throws IOException {
        final List<DatagramPacket> packets = dependencies.getDatagramPacketsFromMdnsPacket(
                packetCreationBuffer, mdnsPacket, address, isQueryWithKnownAnswer);
        if (expectUnicastResponse) {
            // MdnsMultinetworkSocketClient is only available on T+
            if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU
                    && requestSender instanceof MdnsMultinetworkSocketClient) {
                ((MdnsMultinetworkSocketClient) requestSender).sendPacketRequestingUnicastResponse(
                        packets, socketKey, onlyUseIpv6OnIpv6OnlyNetworks);
            } else {
                requestSender.sendPacketRequestingUnicastResponse(
                        packets, onlyUseIpv6OnIpv6OnlyNetworks);
            }
        } else {
            if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU
                    && requestSender instanceof MdnsMultinetworkSocketClient) {
                ((MdnsMultinetworkSocketClient) requestSender)
                        .sendPacketRequestingMulticastResponse(
                                packets, socketKey, onlyUseIpv6OnIpv6OnlyNetworks);
            } else {
                requestSender.sendPacketRequestingMulticastResponse(
                        packets, onlyUseIpv6OnIpv6OnlyNetworks);
            }
        }
    }

    private void sendPacketToIpv4AndIpv6(MdnsSocketClientBase requestSender, int port,
            MdnsPacket mdnsPacket) {
        try {
            sendPacket(requestSender,
                    new InetSocketAddress(MdnsConstants.getMdnsIPv4Address(), port), mdnsPacket);
        } catch (IOException e) {
            sharedLog.e("Can't send packet to IPv4", e);
        }
        try {
            sendPacket(requestSender,
                    new InetSocketAddress(MdnsConstants.getMdnsIPv6Address(), port), mdnsPacket);
        } catch (IOException e) {
            sharedLog.e("Can't send packet to IPv6", e);
        }
    }
}