• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "DnsTlsQueryMap"
18 //#define LOG_NDEBUG 0
19 
20 #include "dns/DnsTlsTransport.h"
21 
22 #include "log/log.h"
23 
24 namespace android {
25 namespace net {
26 
recordQuery(const Slice query)27 std::unique_ptr<DnsTlsQueryMap::QueryFuture> DnsTlsQueryMap::recordQuery(const Slice query) {
28     std::lock_guard<std::mutex> guard(mLock);
29 
30     // Store the query so it can be matched to the response or reissued.
31     if (query.size() < 2) {
32         ALOGW("Query is too short");
33         return nullptr;
34     }
35     int32_t newId = getFreeId();
36     if (newId < 0) {
37         ALOGW("All query IDs are in use");
38         return nullptr;
39     }
40     Query q = { .newId = static_cast<uint16_t>(newId), .query = query };
41     std::map<uint16_t, QueryPromise>::iterator it;
42     bool inserted;
43     std::tie(it, inserted) = mQueries.emplace(newId, q);
44     if (!inserted) {
45         ALOGE("Failed to store pending query");
46         return nullptr;
47     }
48     return std::make_unique<QueryFuture>(q, it->second.result.get_future());
49 }
50 
expire(QueryPromise * p)51 void DnsTlsQueryMap::expire(QueryPromise* p) {
52     Result r = { .code = Response::network_error };
53     p->result.set_value(r);
54 }
55 
markTried(uint16_t newId)56 void DnsTlsQueryMap::markTried(uint16_t newId) {
57     std::lock_guard<std::mutex> guard(mLock);
58     auto it = mQueries.find(newId);
59     if (it != mQueries.end()) {
60         it->second.tries++;
61     }
62 }
63 
cleanup()64 void DnsTlsQueryMap::cleanup() {
65     std::lock_guard<std::mutex> guard(mLock);
66     for (auto it = mQueries.begin(); it != mQueries.end();) {
67         auto& p = it->second;
68         if (p.tries >= kMaxTries) {
69             expire(&p);
70             it = mQueries.erase(it);
71         } else {
72             ++it;
73         }
74     }
75 }
76 
getFreeId()77 int32_t DnsTlsQueryMap::getFreeId() {
78     if (mQueries.empty()) {
79         return 0;
80     }
81     uint16_t maxId = mQueries.rbegin()->first;
82     if (maxId < UINT16_MAX) {
83         return maxId + 1;
84     }
85     if (mQueries.size() == UINT16_MAX + 1) {
86         // Map is full.
87         return -1;
88     }
89     // Linear scan.
90     uint16_t nextId = 0;
91     for (auto& pair : mQueries) {
92         uint16_t id = pair.first;
93         if (id != nextId) {
94             // Found a gap.
95             return nextId;
96         }
97         nextId = id + 1;
98     }
99     // Unreachable (but the compiler isn't smart enough to prove it).
100     return -1;
101 }
102 
getAll()103 std::vector<DnsTlsQueryMap::Query> DnsTlsQueryMap::getAll() {
104     std::lock_guard<std::mutex> guard(mLock);
105     std::vector<DnsTlsQueryMap::Query> queries;
106     for (auto& q : mQueries) {
107         queries.push_back(q.second.query);
108     }
109     return queries;
110 }
111 
empty()112 bool DnsTlsQueryMap::empty() {
113     std::lock_guard<std::mutex> guard(mLock);
114     return mQueries.empty();
115 }
116 
clear()117 void DnsTlsQueryMap::clear() {
118     std::lock_guard<std::mutex> guard(mLock);
119     for (auto& q : mQueries) {
120         expire(&q.second);
121     }
122     mQueries.clear();
123 }
124 
onResponse(std::vector<uint8_t> response)125 void DnsTlsQueryMap::onResponse(std::vector<uint8_t> response) {
126     ALOGV("Got response of size %zu", response.size());
127     if (response.size() < 2) {
128         ALOGW("Response is too short");
129         return;
130     }
131     uint16_t id = response[0] << 8 | response[1];
132     std::lock_guard<std::mutex> guard(mLock);
133     auto it = mQueries.find(id);
134     if (it == mQueries.end()) {
135         ALOGW("Discarding response: unknown ID %d", id);
136         return;
137     }
138     Result r = { .code = Response::success, .response = std::move(response) };
139     // Rewrite ID to match the query
140     const uint8_t* data = it->second.query.query.base();
141     r.response[0] = data[0];
142     r.response[1] = data[1];
143     ALOGV("Sending result to dispatcher");
144     it->second.result.set_value(std::move(r));
145     mQueries.erase(it);
146 }
147 
148 }  // end of namespace net
149 }  // end of namespace android
150