• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "DnsTlsTransport"
18 //#define LOG_NDEBUG 0
19 
20 #include "dns/DnsTlsTransport.h"
21 
22 #include <arpa/inet.h>
23 #include <arpa/nameser.h>
24 
25 #include "dns/DnsTlsServer.h"
26 #include "dns/DnsTlsSocketFactory.h"
27 #include "dns/IDnsTlsSocketFactory.h"
28 
29 #include "log/log.h"
30 #include "Fwmark.h"
31 #undef ADD  // already defined in nameser.h
32 #include "NetdConstants.h"
33 #include "Permission.h"
34 
35 namespace android {
36 namespace net {
37 
query(const netdutils::Slice query)38 std::future<DnsTlsTransport::Result> DnsTlsTransport::query(const netdutils::Slice query) {
39     std::lock_guard<std::mutex> guard(mLock);
40 
41     auto record = mQueries.recordQuery(query);
42     if (!record) {
43         return std::async(std::launch::deferred, []{
44             return (Result) { .code = Response::internal_error };
45         });
46     }
47 
48     if (!mSocket) {
49         ALOGV("No socket for query.  Opening socket and sending.");
50         doConnect();
51     } else {
52         sendQuery(record->query);
53     }
54 
55     return std::move(record->result);
56 }
57 
sendQuery(const DnsTlsQueryMap::Query q)58 bool DnsTlsTransport::sendQuery(const DnsTlsQueryMap::Query q) {
59     // Strip off the ID number and send the new ID instead.
60     bool sent = mSocket->query(q.newId, netdutils::drop(q.query, 2));
61     if (sent) {
62         mQueries.markTried(q.newId);
63     }
64     return sent;
65 }
66 
doConnect()67 void DnsTlsTransport::doConnect() {
68     ALOGV("Constructing new socket");
69     mSocket = mFactory->createDnsTlsSocket(mServer, mMark, this, &mCache);
70 
71     if (mSocket) {
72         auto queries = mQueries.getAll();
73         ALOGV("Initialization succeeded.  Reissuing %zu queries.", queries.size());
74         for(auto& q : queries) {
75             if (!sendQuery(q)) {
76                 break;
77             }
78         }
79     } else {
80         ALOGV("Initialization failed.");
81         mSocket.reset();
82         ALOGV("Failing all pending queries.");
83         mQueries.clear();
84     }
85 }
86 
onResponse(std::vector<uint8_t> response)87 void DnsTlsTransport::onResponse(std::vector<uint8_t> response) {
88     mQueries.onResponse(std::move(response));
89 }
90 
onClosed()91 void DnsTlsTransport::onClosed() {
92     std::lock_guard<std::mutex> guard(mLock);
93     if (mClosing) {
94         return;
95     }
96     // Move remaining operations to a new thread.
97     // This is necessary because
98     // 1. onClosed is currently running on a thread that blocks mSocket's destructor
99     // 2. doReconnect will call that destructor
100     if (mReconnectThread) {
101         // Complete cleanup of a previous reconnect thread, if present.
102         mReconnectThread->join();
103         // Joining a thread that is trying to acquire mLock, while holding mLock,
104         // looks like it risks a deadlock.  However, a deadlock will not occur because
105         // once onClosed is called, it cannot be called again until after doReconnect
106         // acquires mLock.
107     }
108     mReconnectThread.reset(new std::thread(&DnsTlsTransport::doReconnect, this));
109 }
110 
doReconnect()111 void DnsTlsTransport::doReconnect() {
112     std::lock_guard<std::mutex> guard(mLock);
113     if (mClosing) {
114         return;
115     }
116     mQueries.cleanup();
117     if (!mQueries.empty()) {
118         ALOGV("Fast reconnect to retry remaining queries");
119         doConnect();
120     } else {
121         ALOGV("No pending queries.  Going idle.");
122         mSocket.reset();
123     }
124 }
125 
~DnsTlsTransport()126 DnsTlsTransport::~DnsTlsTransport() {
127     ALOGV("Destructor");
128     {
129         std::lock_guard<std::mutex> guard(mLock);
130         ALOGV("Locked destruction procedure");
131         mQueries.clear();
132         mClosing = true;
133     }
134     // It's possible that a reconnect thread was spawned and waiting for mLock.
135     // It's safe for that thread to run now because mClosing is true (and mQueries is empty),
136     // but we need to wait for it to finish before allowing destruction to proceed.
137     if (mReconnectThread) {
138         ALOGV("Waiting for reconnect thread to terminate");
139         mReconnectThread->join();
140         mReconnectThread.reset();
141     }
142     // Ensure that the socket is destroyed, and can clean up its callback threads,
143     // before any of this object's fields become invalid.
144     mSocket.reset();
145     ALOGV("Destructor completed");
146 }
147 
148 // static
149 // TODO: Use this function to preheat the session cache.
150 // That may require moving it to DnsTlsDispatcher.
validate(const DnsTlsServer & server,unsigned netid)151 bool DnsTlsTransport::validate(const DnsTlsServer& server, unsigned netid) {
152     ALOGV("Beginning validation on %u", netid);
153     // Generate "<random>-dnsotls-ds.metric.gstatic.com", which we will lookup through |ss| in
154     // order to prove that it is actually a working DNS over TLS server.
155     static const char kDnsSafeChars[] =
156             "abcdefhijklmnopqrstuvwxyz"
157             "ABCDEFHIJKLMNOPQRSTUVWXYZ"
158             "0123456789";
159     const auto c = [](uint8_t rnd) -> uint8_t {
160         return kDnsSafeChars[(rnd % ARRAY_SIZE(kDnsSafeChars))];
161     };
162     uint8_t rnd[8];
163     arc4random_buf(rnd, ARRAY_SIZE(rnd));
164     // We could try to use res_mkquery() here, but it's basically the same.
165     uint8_t query[] = {
166         rnd[6], rnd[7],  // [0-1]   query ID
167         1, 0,  // [2-3]   flags; query[2] = 1 for recursion desired (RD).
168         0, 1,  // [4-5]   QDCOUNT (number of queries)
169         0, 0,  // [6-7]   ANCOUNT (number of answers)
170         0, 0,  // [8-9]   NSCOUNT (number of name server records)
171         0, 0,  // [10-11] ARCOUNT (number of additional records)
172         17, c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]),
173             '-', 'd', 'n', 's', 'o', 't', 'l', 's', '-', 'd', 's',
174         6, 'm', 'e', 't', 'r', 'i', 'c',
175         7, 'g', 's', 't', 'a', 't', 'i', 'c',
176         3, 'c', 'o', 'm',
177         0,  // null terminator of FQDN (root TLD)
178         0, ns_t_aaaa,  // QTYPE
179         0, ns_c_in     // QCLASS
180     };
181     const int qlen = ARRAY_SIZE(query);
182 
183     // At validation time, we only know the netId, so we have to guess/compute the
184     // corresponding socket mark.
185     Fwmark fwmark;
186     fwmark.permission = PERMISSION_SYSTEM;
187     fwmark.explicitlySelected = true;
188     fwmark.protectedFromVpn = true;
189     fwmark.netId = netid;
190     unsigned mark = fwmark.intValue;
191     int replylen = 0;
192     DnsTlsSocketFactory factory;
193     DnsTlsTransport transport(server, mark, &factory);
194     auto r = transport.query(Slice(query, qlen)).get();
195     if (r.code != Response::success) {
196         ALOGV("query failed");
197         return false;
198     }
199 
200     const std::vector<uint8_t>& recvbuf = r.response;
201     if (recvbuf.size() < NS_HFIXEDSZ) {
202         ALOGW("short response: %d", replylen);
203         return false;
204     }
205 
206     const int qdcount = (recvbuf[4] << 8) | recvbuf[5];
207     if (qdcount != 1) {
208         ALOGW("reply query count != 1: %d", qdcount);
209         return false;
210     }
211 
212     const int ancount = (recvbuf[6] << 8) | recvbuf[7];
213     ALOGV("%u answer count: %d", netid, ancount);
214 
215     // TODO: Further validate the response contents (check for valid AAAA record, ...).
216     // Note that currently, integration tests rely on this function accepting a
217     // response with zero records.
218 #if 0
219     for (int i = 0; i < resplen; i++) {
220         ALOGD("recvbuf[%d] = %d %c", i, recvbuf[i], recvbuf[i]);
221     }
222 #endif
223     return true;
224 }
225 
226 }  // end of namespace net
227 }  // end of namespace android
228