• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "DnsTlsTransport"
18 //#define LOG_NDEBUG 0
19 
#include "DnsTlsTransport.h"

#include <arpa/inet.h>
#include <arpa/nameser.h>

#include <iterator>
#include <memory>
#include <thread>

#include "DnsTlsSocketFactory.h"
#include "IDnsTlsSocketFactory.h"

#include "log/log.h"
29 
30 namespace android {
31 namespace net {
32 
query(const netdutils::Slice query)33 std::future<DnsTlsTransport::Result> DnsTlsTransport::query(const netdutils::Slice query) {
34     std::lock_guard guard(mLock);
35 
36     auto record = mQueries.recordQuery(query);
37     if (!record) {
38         return std::async(std::launch::deferred, []{
39             return (Result) { .code = Response::internal_error };
40         });
41     }
42 
43     if (!mSocket) {
44         ALOGV("No socket for query.  Opening socket and sending.");
45         doConnect();
46     } else {
47         sendQuery(record->query);
48     }
49 
50     return std::move(record->result);
51 }
52 
sendQuery(const DnsTlsQueryMap::Query q)53 bool DnsTlsTransport::sendQuery(const DnsTlsQueryMap::Query q) {
54     // Strip off the ID number and send the new ID instead.
55     bool sent = mSocket->query(q.newId, netdutils::drop(q.query, 2));
56     if (sent) {
57         mQueries.markTried(q.newId);
58     }
59     return sent;
60 }
61 
doConnect()62 void DnsTlsTransport::doConnect() {
63     ALOGV("Constructing new socket");
64     mSocket = mFactory->createDnsTlsSocket(mServer, mMark, this, &mCache);
65 
66     if (mSocket) {
67         auto queries = mQueries.getAll();
68         ALOGV("Initialization succeeded.  Reissuing %zu queries.", queries.size());
69         for(auto& q : queries) {
70             if (!sendQuery(q)) {
71                 break;
72             }
73         }
74     } else {
75         ALOGV("Initialization failed.");
76         mSocket.reset();
77         ALOGV("Failing all pending queries.");
78         mQueries.clear();
79     }
80 }
81 
onResponse(std::vector<uint8_t> response)82 void DnsTlsTransport::onResponse(std::vector<uint8_t> response) {
83     mQueries.onResponse(std::move(response));
84 }
85 
onClosed()86 void DnsTlsTransport::onClosed() {
87     std::lock_guard guard(mLock);
88     if (mClosing) {
89         return;
90     }
91     // Move remaining operations to a new thread.
92     // This is necessary because
93     // 1. onClosed is currently running on a thread that blocks mSocket's destructor
94     // 2. doReconnect will call that destructor
95     if (mReconnectThread) {
96         // Complete cleanup of a previous reconnect thread, if present.
97         mReconnectThread->join();
98         // Joining a thread that is trying to acquire mLock, while holding mLock,
99         // looks like it risks a deadlock.  However, a deadlock will not occur because
100         // once onClosed is called, it cannot be called again until after doReconnect
101         // acquires mLock.
102     }
103     mReconnectThread.reset(new std::thread(&DnsTlsTransport::doReconnect, this));
104 }
105 
doReconnect()106 void DnsTlsTransport::doReconnect() {
107     std::lock_guard guard(mLock);
108     if (mClosing) {
109         return;
110     }
111     mQueries.cleanup();
112     if (!mQueries.empty()) {
113         ALOGV("Fast reconnect to retry remaining queries");
114         doConnect();
115     } else {
116         ALOGV("No pending queries.  Going idle.");
117         mSocket.reset();
118     }
119 }
120 
~DnsTlsTransport()121 DnsTlsTransport::~DnsTlsTransport() {
122     ALOGV("Destructor");
123     {
124         std::lock_guard guard(mLock);
125         ALOGV("Locked destruction procedure");
126         mQueries.clear();
127         mClosing = true;
128     }
129     // It's possible that a reconnect thread was spawned and waiting for mLock.
130     // It's safe for that thread to run now because mClosing is true (and mQueries is empty),
131     // but we need to wait for it to finish before allowing destruction to proceed.
132     if (mReconnectThread) {
133         ALOGV("Waiting for reconnect thread to terminate");
134         mReconnectThread->join();
135         mReconnectThread.reset();
136     }
137     // Ensure that the socket is destroyed, and can clean up its callback threads,
138     // before any of this object's fields become invalid.
139     mSocket.reset();
140     ALOGV("Destructor completed");
141 }
142 
143 // static
144 // TODO: Use this function to preheat the session cache.
145 // That may require moving it to DnsTlsDispatcher.
validate(const DnsTlsServer & server,unsigned netid,uint32_t mark)146 bool DnsTlsTransport::validate(const DnsTlsServer& server, unsigned netid, uint32_t mark) {
147     ALOGV("Beginning validation on %u", netid);
148     // Generate "<random>-dnsotls-ds.metric.gstatic.com", which we will lookup through |ss| in
149     // order to prove that it is actually a working DNS over TLS server.
150     static const char kDnsSafeChars[] =
151             "abcdefhijklmnopqrstuvwxyz"
152             "ABCDEFHIJKLMNOPQRSTUVWXYZ"
153             "0123456789";
154     const auto c = [](uint8_t rnd) -> uint8_t {
155         return kDnsSafeChars[(rnd % std::size(kDnsSafeChars))];
156     };
157     uint8_t rnd[8];
158     arc4random_buf(rnd, std::size(rnd));
159     // We could try to use res_mkquery() here, but it's basically the same.
160     uint8_t query[] = {
161         rnd[6], rnd[7],  // [0-1]   query ID
162         1, 0,  // [2-3]   flags; query[2] = 1 for recursion desired (RD).
163         0, 1,  // [4-5]   QDCOUNT (number of queries)
164         0, 0,  // [6-7]   ANCOUNT (number of answers)
165         0, 0,  // [8-9]   NSCOUNT (number of name server records)
166         0, 0,  // [10-11] ARCOUNT (number of additional records)
167         17, c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]),
168             '-', 'd', 'n', 's', 'o', 't', 'l', 's', '-', 'd', 's',
169         6, 'm', 'e', 't', 'r', 'i', 'c',
170         7, 'g', 's', 't', 'a', 't', 'i', 'c',
171         3, 'c', 'o', 'm',
172         0,  // null terminator of FQDN (root TLD)
173         0, ns_t_aaaa,  // QTYPE
174         0, ns_c_in     // QCLASS
175     };
176     const int qlen = std::size(query);
177 
178     int replylen = 0;
179     DnsTlsSocketFactory factory;
180     DnsTlsTransport transport(server, mark, &factory);
181     auto r = transport.query(netdutils::Slice(query, qlen)).get();
182     if (r.code != Response::success) {
183         ALOGV("query failed");
184         return false;
185     }
186 
187     const std::vector<uint8_t>& recvbuf = r.response;
188     if (recvbuf.size() < NS_HFIXEDSZ) {
189         ALOGW("short response: %d", replylen);
190         return false;
191     }
192 
193     const int qdcount = (recvbuf[4] << 8) | recvbuf[5];
194     if (qdcount != 1) {
195         ALOGW("reply query count != 1: %d", qdcount);
196         return false;
197     }
198 
199     const int ancount = (recvbuf[6] << 8) | recvbuf[7];
200     ALOGV("%u answer count: %d", netid, ancount);
201 
202     // TODO: Further validate the response contents (check for valid AAAA record, ...).
203     // Note that currently, integration tests rely on this function accepting a
204     // response with zero records.
205 #if 0
206     for (int i = 0; i < resplen; i++) {
207         ALOGD("recvbuf[%d] = %d %c", i, recvbuf[i], recvbuf[i]);
208     }
209 #endif
210     return true;
211 }
212 
213 }  // end of namespace net
214 }  // end of namespace android
215