1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "DnsTlsQueryMap"
18 //#define LOG_NDEBUG 0
19
20 #include "DnsTlsQueryMap.h"
21
22 #include "log/log.h"
23
24 namespace android {
25 namespace net {
26
recordQuery(const netdutils::Slice query)27 std::unique_ptr<DnsTlsQueryMap::QueryFuture> DnsTlsQueryMap::recordQuery(
28 const netdutils::Slice query) {
29 std::lock_guard guard(mLock);
30
31 // Store the query so it can be matched to the response or reissued.
32 if (query.size() < 2) {
33 ALOGW("Query is too short");
34 return nullptr;
35 }
36 int32_t newId = getFreeId();
37 if (newId < 0) {
38 ALOGW("All query IDs are in use");
39 return nullptr;
40 }
41 Query q = { .newId = static_cast<uint16_t>(newId), .query = query };
42 std::map<uint16_t, QueryPromise>::iterator it;
43 bool inserted;
44 std::tie(it, inserted) = mQueries.emplace(newId, q);
45 if (!inserted) {
46 ALOGE("Failed to store pending query");
47 return nullptr;
48 }
49 return std::make_unique<QueryFuture>(q, it->second.result.get_future());
50 }
51
expire(QueryPromise * p)52 void DnsTlsQueryMap::expire(QueryPromise* p) {
53 Result r = { .code = Response::network_error };
54 p->result.set_value(r);
55 }
56
markTried(uint16_t newId)57 void DnsTlsQueryMap::markTried(uint16_t newId) {
58 std::lock_guard guard(mLock);
59 auto it = mQueries.find(newId);
60 if (it != mQueries.end()) {
61 it->second.tries++;
62 }
63 }
64
cleanup()65 void DnsTlsQueryMap::cleanup() {
66 std::lock_guard guard(mLock);
67 for (auto it = mQueries.begin(); it != mQueries.end();) {
68 auto& p = it->second;
69 if (p.tries >= kMaxTries) {
70 expire(&p);
71 it = mQueries.erase(it);
72 } else {
73 ++it;
74 }
75 }
76 }
77
getFreeId()78 int32_t DnsTlsQueryMap::getFreeId() {
79 if (mQueries.empty()) {
80 return 0;
81 }
82 uint16_t maxId = mQueries.rbegin()->first;
83 if (maxId < UINT16_MAX) {
84 return maxId + 1;
85 }
86 if (mQueries.size() == UINT16_MAX + 1) {
87 // Map is full.
88 return -1;
89 }
90 // Linear scan.
91 uint16_t nextId = 0;
92 for (auto& pair : mQueries) {
93 uint16_t id = pair.first;
94 if (id != nextId) {
95 // Found a gap.
96 return nextId;
97 }
98 nextId = id + 1;
99 }
100 // Unreachable (but the compiler isn't smart enough to prove it).
101 return -1;
102 }
103
getAll()104 std::vector<DnsTlsQueryMap::Query> DnsTlsQueryMap::getAll() {
105 std::lock_guard guard(mLock);
106 std::vector<DnsTlsQueryMap::Query> queries;
107 for (auto& q : mQueries) {
108 queries.push_back(q.second.query);
109 }
110 return queries;
111 }
112
empty()113 bool DnsTlsQueryMap::empty() {
114 std::lock_guard guard(mLock);
115 return mQueries.empty();
116 }
117
clear()118 void DnsTlsQueryMap::clear() {
119 std::lock_guard guard(mLock);
120 for (auto& q : mQueries) {
121 expire(&q.second);
122 }
123 mQueries.clear();
124 }
125
onResponse(std::vector<uint8_t> response)126 void DnsTlsQueryMap::onResponse(std::vector<uint8_t> response) {
127 ALOGV("Got response of size %zu", response.size());
128 if (response.size() < 2) {
129 ALOGW("Response is too short");
130 return;
131 }
132 uint16_t id = response[0] << 8 | response[1];
133 std::lock_guard guard(mLock);
134 auto it = mQueries.find(id);
135 if (it == mQueries.end()) {
136 ALOGW("Discarding response: unknown ID %d", id);
137 return;
138 }
139 Result r = { .code = Response::success, .response = std::move(response) };
140 // Rewrite ID to match the query
141 const uint8_t* data = it->second.query.query.base();
142 r.response[0] = data[0];
143 r.response[1] = data[1];
144 ALOGV("Sending result to dispatcher");
145 it->second.result.set_value(std::move(r));
146 mQueries.erase(it);
147 }
148
149 } // end of namespace net
150 } // end of namespace android
151