1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *
16 */
17
18 #define LOG_TAG "resolv"
19
20 #include "DnsStats.h"
21
22 #include <android-base/logging.h>
23 #include <android-base/stringprintf.h>
24
25 namespace android::net {
26
27 using base::StringPrintf;
28 using netdutils::DumpWriter;
29 using netdutils::IPAddress;
30 using netdutils::IPSockAddr;
31 using netdutils::ScopedIndent;
32 using std::chrono::duration_cast;
33 using std::chrono::microseconds;
34 using std::chrono::milliseconds;
35 using std::chrono::seconds;
36
37 namespace {
38
39 static constexpr IPAddress INVALID_IPADDRESS = IPAddress();
40
rcodeToName(int rcode)41 std::string rcodeToName(int rcode) {
42 // clang-format off
43 switch (rcode) {
44 case NS_R_NO_ERROR: return "NOERROR";
45 case NS_R_FORMERR: return "FORMERR";
46 case NS_R_SERVFAIL: return "SERVFAIL";
47 case NS_R_NXDOMAIN: return "NXDOMAIN";
48 case NS_R_NOTIMPL: return "NOTIMP";
49 case NS_R_REFUSED: return "REFUSED";
50 case NS_R_YXDOMAIN: return "YXDOMAIN";
51 case NS_R_YXRRSET: return "YXRRSET";
52 case NS_R_NXRRSET: return "NXRRSET";
53 case NS_R_NOTAUTH: return "NOTAUTH";
54 case NS_R_NOTZONE: return "NOTZONE";
55 case NS_R_INTERNAL_ERROR: return "INTERNAL_ERROR";
56 case NS_R_TIMEOUT: return "TIMEOUT";
57 default: return StringPrintf("UNKNOWN(%d)", rcode);
58 }
59 // clang-format on
60 }
61
ensureNoInvalidIp(const std::vector<IPSockAddr> & servers)62 bool ensureNoInvalidIp(const std::vector<IPSockAddr>& servers) {
63 for (const auto& server : servers) {
64 if (server.ip() == INVALID_IPADDRESS || server.port() == 0) {
65 LOG(WARNING) << "Invalid server: " << server;
66 return false;
67 }
68 }
69 return true;
70 }
71
72 } // namespace
73
74 // The comparison ignores the last update time.
operator ==(const StatsData & o) const75 bool StatsData::operator==(const StatsData& o) const {
76 return std::tie(serverSockAddr, total, rcodeCounts, latencyUs) ==
77 std::tie(o.serverSockAddr, o.total, o.rcodeCounts, o.latencyUs);
78 }
79
averageLatencyMs() const80 int StatsData::averageLatencyMs() const {
81 return (total == 0) ? 0 : duration_cast<milliseconds>(latencyUs).count() / total;
82 }
83
toString() const84 std::string StatsData::toString() const {
85 if (total == 0) return StringPrintf("%s <no data>", serverSockAddr.ip().toString().c_str());
86
87 const auto now = std::chrono::steady_clock::now();
88 const int lastUpdateSec = duration_cast<seconds>(now - lastUpdate).count();
89 std::string buf;
90 for (const auto& [rcode, counts] : rcodeCounts) {
91 if (counts != 0) {
92 buf += StringPrintf("%s:%d ", rcodeToName(rcode).c_str(), counts);
93 }
94 }
95 return StringPrintf("%s (%d, %dms, [%s], %ds)", serverSockAddr.ip().toString().c_str(), total,
96 averageLatencyMs(), buf.c_str(), lastUpdateSec);
97 }
98
StatsRecords(const IPSockAddr & ipSockAddr,size_t size)99 StatsRecords::StatsRecords(const IPSockAddr& ipSockAddr, size_t size)
100 : mCapacity(size), mStatsData(ipSockAddr) {}
101
push(const Record & record)102 void StatsRecords::push(const Record& record) {
103 updateStatsData(record, true);
104 mRecords.push_back(record);
105
106 if (mRecords.size() > mCapacity) {
107 updateStatsData(mRecords.front(), false);
108 mRecords.pop_front();
109 }
110
111 // Update the quality factors.
112 mSkippedCount = 0;
113
114 // Because failures due to no permission can't prove that the quality of DNS server is bad,
115 // skip the penalty update. The average latency, however, has been updated. For short-latency
116 // servers, it will be fine. For long-latency servers, their average latency will be
117 // decreased but the latency-based algorithm will adjust their average latency back to the
118 // right range after few attempts when network is not restricted.
119 // The check is synced from isNetworkRestricted() in res_send.cpp.
120 if (record.linux_errno != EPERM) {
121 updatePenalty(record);
122 }
123 }
124
updateStatsData(const Record & record,const bool add)125 void StatsRecords::updateStatsData(const Record& record, const bool add) {
126 const int rcode = record.rcode;
127 if (add) {
128 mStatsData.total += 1;
129 mStatsData.rcodeCounts[rcode] += 1;
130 mStatsData.latencyUs += record.latencyUs;
131 } else {
132 mStatsData.total -= 1;
133 mStatsData.rcodeCounts[rcode] -= 1;
134 mStatsData.latencyUs -= record.latencyUs;
135 }
136 mStatsData.lastUpdate = std::chrono::steady_clock::now();
137 }
138
updatePenalty(const Record & record)139 void StatsRecords::updatePenalty(const Record& record) {
140 switch (record.rcode) {
141 case NS_R_NO_ERROR:
142 case NS_R_NXDOMAIN:
143 case NS_R_NOTAUTH:
144 mPenalty = 0;
145 return;
146 default:
147 // NS_R_TIMEOUT and NS_R_INTERNAL_ERROR are in this case.
148 if (mPenalty == 0) {
149 mPenalty = 100;
150 } else {
151 // The evaluated quality drops more quickly when continuous failures happen.
152 mPenalty = std::min(mPenalty * 2, kMaxQuality);
153 }
154 return;
155 }
156 }
157
score() const158 double StatsRecords::score() const {
159 const int avgRtt = mStatsData.averageLatencyMs();
160
161 // Set the lower bound to -1 in case of "avgRtt + mPenalty < mSkippedCount"
162 // 1) when the server doesn't have any stats yet.
163 // 2) when the sorting has been disabled while it was enabled before.
164 int quality = std::clamp(avgRtt + mPenalty - mSkippedCount, -1, kMaxQuality);
165
166 // Normalization.
167 return static_cast<double>(kMaxQuality - quality) * 100 / kMaxQuality;
168 }
169
incrementSkippedCount()170 void StatsRecords::incrementSkippedCount() {
171 mSkippedCount = std::min(mSkippedCount + 1, kMaxQuality);
172 }
173
setServers(const std::vector<netdutils::IPSockAddr> & servers,Protocol protocol)174 bool DnsStats::setServers(const std::vector<netdutils::IPSockAddr>& servers, Protocol protocol) {
175 if (!ensureNoInvalidIp(servers)) return false;
176
177 ServerStatsMap& statsMap = mStats[protocol];
178 for (const auto& server : servers) {
179 statsMap.try_emplace(server, StatsRecords(server, kLogSize));
180 }
181
182 // Clean up the map to eliminate the nodes not belonging to the given list of servers.
183 const auto cleanup = [&](ServerStatsMap* statsMap) {
184 ServerStatsMap tmp;
185 for (const auto& server : servers) {
186 if (statsMap->find(server) != statsMap->end()) {
187 tmp.insert(statsMap->extract(server));
188 }
189 }
190 statsMap->swap(tmp);
191 };
192
193 cleanup(&statsMap);
194
195 return true;
196 }
197
addStats(const IPSockAddr & ipSockAddr,const DnsQueryEvent & record)198 bool DnsStats::addStats(const IPSockAddr& ipSockAddr, const DnsQueryEvent& record) {
199 if (ipSockAddr.ip() == INVALID_IPADDRESS) return false;
200
201 bool added = false;
202 for (auto& [serverSockAddr, statsRecords] : mStats[record.protocol()]) {
203 if (serverSockAddr == ipSockAddr) {
204 const StatsRecords::Record rec = {
205 .rcode = record.rcode(),
206 .linux_errno = record.linux_errno(),
207 .latencyUs = microseconds(record.latency_micros()),
208 };
209 statsRecords.push(rec);
210 added = true;
211 } else {
212 statsRecords.incrementSkippedCount();
213 }
214 }
215
216 return added;
217 }
218
getSortedServers(Protocol protocol) const219 std::vector<IPSockAddr> DnsStats::getSortedServers(Protocol protocol) const {
220 // DoT unsupported. The handshake overhead is expensive, and the connection will hang for a
221 // while. Need to figure out if it is worth doing for DoT servers.
222 if (protocol == PROTO_DOT) return {};
223
224 auto it = mStats.find(protocol);
225 if (it == mStats.end()) return {};
226
227 // Sorting on insertion in decreasing order.
228 std::multimap<double, IPSockAddr, std::greater<double>> sortedData;
229 for (const auto& [ip, statsRecords] : it->second) {
230 sortedData.insert({statsRecords.score(), ip});
231 }
232
233 std::vector<IPSockAddr> ret;
234 ret.reserve(sortedData.size());
235 for (auto& [_, v] : sortedData) {
236 ret.push_back(v); // IPSockAddr is trivially-copyable.
237 }
238
239 return ret;
240 }
241
getAverageLatencyUs(Protocol protocol) const242 std::optional<microseconds> DnsStats::getAverageLatencyUs(Protocol protocol) const {
243 const auto stats = getStats(protocol);
244
245 int count = 0;
246 microseconds sum;
247 for (const auto& v : stats) {
248 count += v.total;
249 sum += v.latencyUs;
250 }
251
252 if (count == 0) return std::nullopt;
253 return sum / count;
254 }
255
getStats(Protocol protocol) const256 std::vector<StatsData> DnsStats::getStats(Protocol protocol) const {
257 std::vector<StatsData> ret;
258
259 if (mStats.find(protocol) != mStats.end()) {
260 for (const auto& [_, statsRecords] : mStats.at(protocol)) {
261 ret.push_back(statsRecords.getStatsData());
262 }
263 }
264 return ret;
265 }
266
dump(DumpWriter & dw)267 void DnsStats::dump(DumpWriter& dw) {
268 const auto dumpStatsMap = [&](ServerStatsMap& statsMap) {
269 ScopedIndent indentLog(dw);
270 if (statsMap.size() == 0) {
271 dw.println("<no server>");
272 return;
273 }
274 for (const auto& [_, statsRecords] : statsMap) {
275 const StatsData& data = statsRecords.getStatsData();
276 std::string str = data.toString();
277 str += StringPrintf(" score{%.1f}", statsRecords.score());
278 dw.println("%s", str.c_str());
279 }
280 };
281
282 dw.println("Server statistics: (total, RTT avg, {rcode:counts}, last update)");
283 ScopedIndent indentStats(dw);
284
285 dw.println("over UDP");
286 dumpStatsMap(mStats[PROTO_UDP]);
287
288 dw.println("over TLS");
289 dumpStatsMap(mStats[PROTO_DOT]);
290
291 dw.println("over TCP");
292 dumpStatsMap(mStats[PROTO_TCP]);
293 }
294
295 } // namespace android::net
296