• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "DnsTlsQueryMap"
18 //#define LOG_NDEBUG 0
19 
20 #include "DnsTlsQueryMap.h"
21 
22 #include "log/log.h"
23 
24 namespace android {
25 namespace net {
26 
recordQuery(const netdutils::Slice query)27 std::unique_ptr<DnsTlsQueryMap::QueryFuture> DnsTlsQueryMap::recordQuery(
28         const netdutils::Slice query) {
29     std::lock_guard guard(mLock);
30 
31     // Store the query so it can be matched to the response or reissued.
32     if (query.size() < 2) {
33         ALOGW("Query is too short");
34         return nullptr;
35     }
36     int32_t newId = getFreeId();
37     if (newId < 0) {
38         ALOGW("All query IDs are in use");
39         return nullptr;
40     }
41     Query q = { .newId = static_cast<uint16_t>(newId), .query = query };
42     std::map<uint16_t, QueryPromise>::iterator it;
43     bool inserted;
44     std::tie(it, inserted) = mQueries.emplace(newId, q);
45     if (!inserted) {
46         ALOGE("Failed to store pending query");
47         return nullptr;
48     }
49     return std::make_unique<QueryFuture>(q, it->second.result.get_future());
50 }
51 
expire(QueryPromise * p)52 void DnsTlsQueryMap::expire(QueryPromise* p) {
53     Result r = { .code = Response::network_error };
54     p->result.set_value(r);
55 }
56 
markTried(uint16_t newId)57 void DnsTlsQueryMap::markTried(uint16_t newId) {
58     std::lock_guard guard(mLock);
59     auto it = mQueries.find(newId);
60     if (it != mQueries.end()) {
61         it->second.tries++;
62     }
63 }
64 
cleanup()65 void DnsTlsQueryMap::cleanup() {
66     std::lock_guard guard(mLock);
67     for (auto it = mQueries.begin(); it != mQueries.end();) {
68         auto& p = it->second;
69         if (p.tries >= kMaxTries) {
70             expire(&p);
71             it = mQueries.erase(it);
72         } else {
73             ++it;
74         }
75     }
76 }
77 
getFreeId()78 int32_t DnsTlsQueryMap::getFreeId() {
79     if (mQueries.empty()) {
80         return 0;
81     }
82     uint16_t maxId = mQueries.rbegin()->first;
83     if (maxId < UINT16_MAX) {
84         return maxId + 1;
85     }
86     if (mQueries.size() == UINT16_MAX + 1) {
87         // Map is full.
88         return -1;
89     }
90     // Linear scan.
91     uint16_t nextId = 0;
92     for (auto& pair : mQueries) {
93         uint16_t id = pair.first;
94         if (id != nextId) {
95             // Found a gap.
96             return nextId;
97         }
98         nextId = id + 1;
99     }
100     // Unreachable (but the compiler isn't smart enough to prove it).
101     return -1;
102 }
103 
getAll()104 std::vector<DnsTlsQueryMap::Query> DnsTlsQueryMap::getAll() {
105     std::lock_guard guard(mLock);
106     std::vector<DnsTlsQueryMap::Query> queries;
107     for (auto& q : mQueries) {
108         queries.push_back(q.second.query);
109     }
110     return queries;
111 }
112 
empty()113 bool DnsTlsQueryMap::empty() {
114     std::lock_guard guard(mLock);
115     return mQueries.empty();
116 }
117 
clear()118 void DnsTlsQueryMap::clear() {
119     std::lock_guard guard(mLock);
120     for (auto& q : mQueries) {
121         expire(&q.second);
122     }
123     mQueries.clear();
124 }
125 
onResponse(std::vector<uint8_t> response)126 void DnsTlsQueryMap::onResponse(std::vector<uint8_t> response) {
127     ALOGV("Got response of size %zu", response.size());
128     if (response.size() < 2) {
129         ALOGW("Response is too short");
130         return;
131     }
132     uint16_t id = response[0] << 8 | response[1];
133     std::lock_guard guard(mLock);
134     auto it = mQueries.find(id);
135     if (it == mQueries.end()) {
136         ALOGW("Discarding response: unknown ID %d", id);
137         return;
138     }
139     Result r = { .code = Response::success, .response = std::move(response) };
140     // Rewrite ID to match the query
141     const uint8_t* data = it->second.query.query.base();
142     r.response[0] = data[0];
143     r.response[1] = data[1];
144     ALOGV("Sending result to dispatcher");
145     it->second.result.set_value(std::move(r));
146     mQueries.erase(it);
147 }
148 
149 }  // end of namespace net
150 }  // end of namespace android
151