1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "DnsTlsQueryMap"
18 //#define LOG_NDEBUG 0
19
20 #include "dns/DnsTlsTransport.h"
21
22 #include "log/log.h"
23
24 namespace android {
25 namespace net {
26
recordQuery(const Slice query)27 std::unique_ptr<DnsTlsQueryMap::QueryFuture> DnsTlsQueryMap::recordQuery(const Slice query) {
28 std::lock_guard<std::mutex> guard(mLock);
29
30 // Store the query so it can be matched to the response or reissued.
31 if (query.size() < 2) {
32 ALOGW("Query is too short");
33 return nullptr;
34 }
35 int32_t newId = getFreeId();
36 if (newId < 0) {
37 ALOGW("All query IDs are in use");
38 return nullptr;
39 }
40 Query q = { .newId = static_cast<uint16_t>(newId), .query = query };
41 std::map<uint16_t, QueryPromise>::iterator it;
42 bool inserted;
43 std::tie(it, inserted) = mQueries.emplace(newId, q);
44 if (!inserted) {
45 ALOGE("Failed to store pending query");
46 return nullptr;
47 }
48 return std::make_unique<QueryFuture>(q, it->second.result.get_future());
49 }
50
expire(QueryPromise * p)51 void DnsTlsQueryMap::expire(QueryPromise* p) {
52 Result r = { .code = Response::network_error };
53 p->result.set_value(r);
54 }
55
markTried(uint16_t newId)56 void DnsTlsQueryMap::markTried(uint16_t newId) {
57 std::lock_guard<std::mutex> guard(mLock);
58 auto it = mQueries.find(newId);
59 if (it != mQueries.end()) {
60 it->second.tries++;
61 }
62 }
63
cleanup()64 void DnsTlsQueryMap::cleanup() {
65 std::lock_guard<std::mutex> guard(mLock);
66 for (auto it = mQueries.begin(); it != mQueries.end();) {
67 auto& p = it->second;
68 if (p.tries >= kMaxTries) {
69 expire(&p);
70 it = mQueries.erase(it);
71 } else {
72 ++it;
73 }
74 }
75 }
76
getFreeId()77 int32_t DnsTlsQueryMap::getFreeId() {
78 if (mQueries.empty()) {
79 return 0;
80 }
81 uint16_t maxId = mQueries.rbegin()->first;
82 if (maxId < UINT16_MAX) {
83 return maxId + 1;
84 }
85 if (mQueries.size() == UINT16_MAX + 1) {
86 // Map is full.
87 return -1;
88 }
89 // Linear scan.
90 uint16_t nextId = 0;
91 for (auto& pair : mQueries) {
92 uint16_t id = pair.first;
93 if (id != nextId) {
94 // Found a gap.
95 return nextId;
96 }
97 nextId = id + 1;
98 }
99 // Unreachable (but the compiler isn't smart enough to prove it).
100 return -1;
101 }
102
getAll()103 std::vector<DnsTlsQueryMap::Query> DnsTlsQueryMap::getAll() {
104 std::lock_guard<std::mutex> guard(mLock);
105 std::vector<DnsTlsQueryMap::Query> queries;
106 for (auto& q : mQueries) {
107 queries.push_back(q.second.query);
108 }
109 return queries;
110 }
111
empty()112 bool DnsTlsQueryMap::empty() {
113 std::lock_guard<std::mutex> guard(mLock);
114 return mQueries.empty();
115 }
116
clear()117 void DnsTlsQueryMap::clear() {
118 std::lock_guard<std::mutex> guard(mLock);
119 for (auto& q : mQueries) {
120 expire(&q.second);
121 }
122 mQueries.clear();
123 }
124
onResponse(std::vector<uint8_t> response)125 void DnsTlsQueryMap::onResponse(std::vector<uint8_t> response) {
126 ALOGV("Got response of size %zu", response.size());
127 if (response.size() < 2) {
128 ALOGW("Response is too short");
129 return;
130 }
131 uint16_t id = response[0] << 8 | response[1];
132 std::lock_guard<std::mutex> guard(mLock);
133 auto it = mQueries.find(id);
134 if (it == mQueries.end()) {
135 ALOGW("Discarding response: unknown ID %d", id);
136 return;
137 }
138 Result r = { .code = Response::success, .response = std::move(response) };
139 // Rewrite ID to match the query
140 const uint8_t* data = it->second.query.query.base();
141 r.response[0] = data[0];
142 r.response[1] = data[1];
143 ALOGV("Sending result to dispatcher");
144 it->second.result.set_value(std::move(r));
145 mQueries.erase(it);
146 }
147
148 } // end of namespace net
149 } // end of namespace android
150